import pymc as pm
import nutpie
import pandas as pd
= {"observation": range(3)}
coords
with pm.Model(coords=coords) as model:
# Prior distributions for the intercept and slope
= pm.Normal("intercept", mu=0, sigma=1)
intercept = pm.Normal("slope", mu=0, sigma=1)
slope
# Likelihood (sampling distribution) of observations
= [1, 2, 3]
x
= intercept + slope * x
mu = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3], dims="observation") y
Usage with PyMC models
This document shows how to use nutpie
with PyMC models. We will use the pymc
package to define a simple model and sample from it using nutpie
.
Installation
The recommended way to install pymc
is through the conda
ecosystem. A good package manager for conda packages is pixi
. See for the pixi documentation for instructions on how to install it.
We create a new project for this example:
pixi new pymc-example
This will create a new directory pymc-example
with a pixi.toml
file, that you can edit to add meta information.
We then add the pymc
and nutpie
packages to the project:
cd pymc-example
pixi add pymc nutpie arviz
You can use Visual Studio Code (VSCode) or JupyterLab to write and run our code. Both are excellent tools for working with Python and data science projects.
Using VSCode
- Open VSCode.
- Open the
pymc-example
directory created earlier. - Create a new file named
model.ipynb
. - Select the pixi kernel to run the code.
Using JupyterLab
- Add jupyter labs to the project by running
pixi add jupyterlab
. - Open JupyterLab by running
pixi run jupyter lab
in your terminal. - Create a new Python notebook.
Defining and Sampling a Simple Model
We will define a simple Bayesian model using pymc
and sample from it using nutpie
.
Model Definition
In your model.ipypy
file or Jupyter notebook, add the following code:
Sampling
We can now compile the model using the numba backend:
= nutpie.compile_pymc_model(model)
compiled = nutpie.sample(compiled) trace
Sampler Progress
Total Chains: 6
Active Chains: 0
Finished Chains: 6
Sampling for now
Estimated Time to Completion: now
Progress | Draws | Divergences | Step Size | Gradients/Draw |
---|---|---|---|---|
1400 | 0 | 0.59 | 3 | |
1400 | 0 | 0.65 | 9 | |
1400 | 0 | 0.55 | 1 | |
1400 | 0 | 0.58 | 15 | |
1400 | 0 | 0.67 | 7 | |
1400 | 0 | 0.58 | 3 |
Alternatively, we can also sample through the pymc
API:
with model:
= pm.sample(model, nuts_sampler="nutpie") trace
While sampling, nutpie shows a progress bar for each chain. It also includes information about how each chain is doing:
- It shows the current number of draws
- The step size of the integrator (very small stepsizes are typically a bad sign)
- The number of divergences (if there are divergences, that means that nutpie is probably not sampling the posterior correctly)
- The number of gradient evaluation nutpie uses for each draw. Large numbers (100 to 1000) are a sign that the parameterization of the model is not ideal, and the sampler is very inefficient.
After sampling, this returns an arviz
InferenceData object that you can use to analyze the trace.
For example, we should check the effective sample size:
import arviz as az
az.ess(trace)
<xarray.Dataset> Size: 16B Dimensions: () Data variables: intercept float64 8B 1.517e+03 slope float64 8B 1.517e+03
and take a look at a trace plot:
; az.plot_trace(trace)
Choosing the backend
Right now, we have been using the numba backend. This is the default backend for nutpie
, when sampling from pymc models. It tends to have relatively long compilation times, but samples small models very efficiently. For larger models the jax
backend sometimes outperforms numba
.
First, we need to install the jax
package:
pixi add jax
We can select the backend by passing the backend
argument to the compile_pymc_model
:
= nutpie.compiled_pymc_model(model, backend="jax")
compiled_jax = nutpie.sample(compiled_jax) trace
Or through the pymc API:
with model:
= pm.sample(
trace
model,="nutpie",
nuts_sampler={"backend": "jax"},
nuts_sampler_kwargs )
If you have an nvidia GPU, you can also use the jax
backend with the gpu
. We will have to install the jaxlib
package with the cuda
option
pixi add jaxlib --build 'cuda12'
Restart the kernel and check that the GPU is available:
import jax
# Should list the cuda device
jax.devices()
Sampling again, should now use the GPU, which you can observe by checking the GPU usage with nvidia-smi
or nvtop
.
Changing the dataset without recompilation
If you want to use the same model with different datasets, you can modify datasets after compilation. Since jax does not like changes in shapes, this is only recommended with the numba backend.
First, we define the model, but put our dataset in a pm.Data
structure:
with pm.Model() as model:
= pm.Data("x", [1, 2, 3])
x = pm.Normal("intercept", mu=0, sigma=1)
intercept = pm.Normal("slope", mu=0, sigma=1)
slope = intercept + slope * x
mu = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3]) y
We can now compile the model:
= nutpie.compile_pymc_model(model)
compiled = nutpie.sample(compiled) trace
Sampler Progress
Total Chains: 6
Active Chains: 0
Finished Chains: 6
Sampling for now
Estimated Time to Completion: now
Progress | Draws | Divergences | Step Size | Gradients/Draw |
---|---|---|---|---|
1400 | 0 | 0.64 | 1 | |
1400 | 0 | 0.61 | 11 | |
1400 | 0 | 0.68 | 3 | |
1400 | 0 | 0.55 | 9 | |
1400 | 0 | 0.57 | 9 | |
1400 | 0 | 0.65 | 3 |
After compilation, we can change the dataset:
= compiled.with_data(x=[4, 5, 6])
compiled2 = nutpie.sample(compiled2) trace2
Sampler Progress
Total Chains: 6
Active Chains: 0
Finished Chains: 6
Sampling for now
Estimated Time to Completion: now
Progress | Draws | Divergences | Step Size | Gradients/Draw |
---|---|---|---|---|
1400 | 0 | 0.42 | 7 | |
1400 | 0 | 0.35 | 27 | |
1400 | 0 | 0.42 | 13 | |
1400 | 0 | 0.45 | 3 | |
1400 | 0 | 0.40 | 3 | |
1400 | 0 | 0.39 | 19 |