Usage with PyMC models

This document shows how to use nutpie with PyMC models. We will use the pymc package to define a simple model and sample from it using nutpie.

Installation

The recommended way to install pymc is through the conda ecosystem. A good package manager for conda packages is pixi. See for the pixi documentation for instructions on how to install it.

We create a new project for this example:

pixi new pymc-example

This will create a new directory pymc-example with a pixi.toml file, that you can edit to add meta information.

We then add the pymc and nutpie packages to the project:

cd pymc-example
pixi add pymc nutpie arviz

You can use Visual Studio Code (VSCode) or JupyterLab to write and run our code. Both are excellent tools for working with Python and data science projects.

Using VSCode

  1. Open VSCode.
  2. Open the pymc-example directory created earlier.
  3. Create a new file named model.ipynb.
  4. Select the pixi kernel to run the code.

Using JupyterLab

  1. Add jupyter labs to the project by running pixi add jupyterlab.
  2. Open JupyterLab by running pixi run jupyter lab in your terminal.
  3. Create a new Python notebook.

Defining and Sampling a Simple Model

We will define a simple Bayesian model using pymc and sample from it using nutpie.

Model Definition

In your model.ipypy file or Jupyter notebook, add the following code:

import pymc as pm
import nutpie
import pandas as pd

coords = {"observation": range(3)}

with pm.Model(coords=coords) as model:
    # Prior distributions for the intercept and slope
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)

    # Likelihood (sampling distribution) of observations
    x = [1, 2, 3]

    mu = intercept + slope * x
    y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3], dims="observation")

Sampling

We can now compile the model using the numba backend:

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)

Sampler Progress

Total Chains: 6

Active Chains: 0

Finished Chains: 6

Sampling for now

Estimated Time to Completion: now

Progress Draws Divergences Step Size Gradients/Draw
1400 0 0.59 3
1400 0 0.65 9
1400 0 0.55 1
1400 0 0.58 15
1400 0 0.67 7
1400 0 0.58 3

Alternatively, we can also sample through the pymc API:

with model:
    trace = pm.sample(model, nuts_sampler="nutpie")

While sampling, nutpie shows a progress bar for each chain. It also includes information about how each chain is doing:

  • It shows the current number of draws
  • The step size of the integrator (very small stepsizes are typically a bad sign)
  • The number of divergences (if there are divergences, that means that nutpie is probably not sampling the posterior correctly)
  • The number of gradient evaluation nutpie uses for each draw. Large numbers (100 to 1000) are a sign that the parameterization of the model is not ideal, and the sampler is very inefficient.

After sampling, this returns an arviz InferenceData object that you can use to analyze the trace.

For example, we should check the effective sample size:

import arviz as az
az.ess(trace)
<xarray.Dataset> Size: 16B
Dimensions:    ()
Data variables:
    intercept  float64 8B 1.517e+03
    slope      float64 8B 1.517e+03

and take a look at a trace plot:

az.plot_trace(trace);

Choosing the backend

Right now, we have been using the numba backend. This is the default backend for nutpie, when sampling from pymc models. It tends to have relatively long compilation times, but samples small models very efficiently. For larger models the jax backend sometimes outperforms numba.

First, we need to install the jax package:

pixi add jax

We can select the backend by passing the backend argument to the compile_pymc_model:

compiled_jax = nutpie.compiled_pymc_model(model, backend="jax")
trace = nutpie.sample(compiled_jax)

Or through the pymc API:

with model:
    trace = pm.sample(
        model,
        nuts_sampler="nutpie",
        nuts_sampler_kwargs={"backend": "jax"},
    )

If you have an nvidia GPU, you can also use the jax backend with the gpu. We will have to install the jaxlib package with the cuda option

pixi add jaxlib --build 'cuda12'

Restart the kernel and check that the GPU is available:

import jax

# Should list the cuda device
jax.devices()

Sampling again, should now use the GPU, which you can observe by checking the GPU usage with nvidia-smi or nvtop.

Changing the dataset without recompilation

If you want to use the same model with different datasets, you can modify datasets after compilation. Since jax does not like changes in shapes, this is only recommended with the numba backend.

First, we define the model, but put our dataset in a pm.Data structure:

with pm.Model() as model:
    x = pm.Data("x", [1, 2, 3])
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)
    mu = intercept + slope * x
    y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3])

We can now compile the model:

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)

Sampler Progress

Total Chains: 6

Active Chains: 0

Finished Chains: 6

Sampling for now

Estimated Time to Completion: now

Progress Draws Divergences Step Size Gradients/Draw
1400 0 0.64 1
1400 0 0.61 11
1400 0 0.68 3
1400 0 0.55 9
1400 0 0.57 9
1400 0 0.65 3

After compilation, we can change the dataset:

compiled2 = compiled.with_data(x=[4, 5, 6])
trace2 = nutpie.sample(compiled2)

Sampler Progress

Total Chains: 6

Active Chains: 0

Finished Chains: 6

Sampling for now

Estimated Time to Completion: now

Progress Draws Divergences Step Size Gradients/Draw
1400 0 0.42 7
1400 0 0.35 27
1400 0 0.42 13
1400 0 0.45 3
1400 0 0.40 3
1400 0 0.39 19