import pymc as pm
import nutpie
import numpy as np
import arviz
# Define a 100-dimensional funnel model
with pm.Model() as model:
    log_sigma = pm.Normal("log_sigma")
    pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)
# Compile the model with the jax backend
compiled = nutpie.compile_pymc_model(
    model, backend="jax", gradient_backend="jax"
)
Adaptation with Normalizing Flows
Experimental and subject to change
Normalizing flow adaptation through Fisher HMC is a new sampling algorithm that automatically reparameterizes a model. It adds some computational cost outside model log-density evaluations, but allows sampling from much more difficult posterior distributions. For models with expensive log-density evaluations, the normalizing flow adaptation can also be much faster, if it can reduce the number of log-density evaluations needed to reach a given effective sample size.
The normalizing flow adaptation works by learning a transformation of the parameter space that makes the posterior distribution more amenable to sampling. This is done by fitting a sequence of invertible transformations (the “flow”) that maps the original parameter space to a space where the posterior is closer to a standard normal distribution. The flow is trained during warmup.
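To make this concrete: for the funnel model above, a transformation with the effect the flow has to learn is the familiar hand-written non-centered parameterization. The following sketch is only for intuition; it is ordinary PyMC code, not part of nutpie's API, and the flow adaptation finds an equivalent transformation automatically during warmup:

import pymc as pm

# Hand-written non-centered funnel: sample x_raw in a space that is close to a
# standard normal, then map it back to the original "x". The normalizing flow
# learns a transformation with a similar effect, without rewriting the model.
with pm.Model() as reparam_model:
    log_sigma = pm.Normal("log_sigma")
    x_raw = pm.Normal("x_raw", mu=0, sigma=1, shape=100)
    x = pm.Deterministic("x", x_raw * pm.math.exp(log_sigma / 2))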
For more information about the algorithm, see the (still work in progress) paper If only my posterior were normal: Introducing Fisher HMC.
Currently, a lot of time is spent compiling the various parts of the normalizing flow, and for small models this compilation can account for a large share of the total runtime. Hopefully, we will be able to reduce this overhead in the future.
Requirements
Install the optional dependencies for normalizing flow adaptation:
pip install 'nutpie[nnflow]'
If you use nutpie with PyMC, this will only work if the model is compiled using the jax backend, and if the gradient_backend is also set to jax.
Training of the normalizing flow can often be accelerated by using a GPU (even if the model itself is written in Stan, without any GPU support). To enable the GPU you need to make sure your jax installation comes with GPU support, for instance by installing it with pip install 'jax[cuda12]', or by selecting the jaxlib version with GPU support if you are using conda-forge. You can check whether your installation has GPU support by inspecting the output of:
import jax
jax.devices()
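# With GPU support this lists a CUDA/GPU device; a CPU-only installation
# typically shows only a CPU device (the exact names depend on the jax version).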
Usage
To use normalizing flow adaptation in nutpie, you need to enable the transform_adapt option during sampling. Here is an example of how we can use it to sample from a difficult posterior, the 100-dimensional funnel model defined and compiled at the top of this page.
If we sample this model without normalizing flow adaptation, we encounter convergence issues: often divergences, and always low effective sample sizes:
# Sample without normalizing flow adaptation
trace_no_nf = nutpie.sample(compiled, seed=1)
assert (arviz.ess(trace_no_nf) < 100).any().to_array().any()
Sampler Progress
Total Chains: 6
Active Chains: 0
Finished Chains: 6
Sampling for 16 seconds
Estimated Time to Completion: now
Draws | Divergences | Step Size | Gradients/Draw
---|---|---|---
1400 | 0 | 0.45 | 7
1400 | 0 | 0.31 | 15
1400 | 0 | 0.31 | 7
1400 | 0 | 0.28 | 7
1400 | 0 | 0.39 | 15
1400 | 0 | 0.34 | 7
# We can add further arguments for the normalizing flow:
compiled = compiled.with_transform_adapt(
    num_layers=5,  # Number of layers in the normalizing flow
    nn_width=32,  # Neural networks with 32 hidden units
    num_diag_windows=6,  # Number of windows with a diagonal mass matrix instead of a flow
    verbose=False,  # Whether to print details about the adaptation process
    show_progress=False,  # Whether to show a progress bar for each optimization step
)
# Sample with normalizing flow adaptation
trace_nf = nutpie.sample(
    compiled,
    transform_adapt=True,  # Enable the normalizing flow adaptation
    seed=1,
    chains=2,
    cores=1,  # Running chains in parallel can be slow
    window_switch_freq=150,  # Optimize the normalizing flow every 150 iterations
)
assert trace_nf.sample_stats.diverging.sum() == 0
assert (arviz.ess(trace_nf) > 1000).all().to_array().all()
Sampler Progress
Total Chains: 2
Active Chains: 0
Finished Chains: 2
Sampling for 18 minutes
Estimated Time to Completion: now
Draws | Divergences | Step Size | Gradients/Draw
---|---|---|---
2500 | 0 | 0.52 | 7
2500 | 0 | 0.53 | 7
The sampler used fewer gradient evaluations with the normalizing flow adaptation, but still converged and produced a good effective sample size:
n_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum())
ess = float(arviz.ess(trace_nf).min().to_array().min())
print(f"Number of gradient evaluations: {n_steps}")
print(f"Minimum effective sample size: {ess}")
Number of gradient evaluations: 42527
Minimum effective sample size: 1835.9674640023168
Without the normalizing flow, the sampler used more gradient evaluations and still wasn't able to reach a good effective sample size:
n_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum())
ess = float(arviz.ess(trace_no_nf).min().to_array().min())
print(f"Number of gradient evaluations: {n_steps}")
print(f"Minimum effective sample size: {ess}")
Number of gradient evaluations: 124219
Minimum effective sample size: 31.459420094540565
The flow adaptation occurs during warmup, so the number of warmup draws should be large enough to allow the flow to converge. For more complex posteriors, you may need to increase the number of layers (using the num_layers argument), or you might want to increase the number of warmup draws.
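As a rough sketch, a deeper flow combined with a longer warmup might look as follows. The specific numbers are illustrative only, and the tune argument is assumed here to control the number of warmup draws, as in other nutpie sampling calls:

# Illustrative only: a deeper flow and more warmup draws for a harder posterior,
# reusing the compiled model from above.
compiled_deep = compiled.with_transform_adapt(num_layers=9)
trace = nutpie.sample(
    compiled_deep,
    transform_adapt=True,
    tune=2000,  # assumed: number of warmup draws, giving the flow more training time
    chains=2,
    cores=1,
    window_switch_freq=150,
)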
To monitor the progress of the flow adaptation, you can set verbose=True or show_progress=True, but show_progress should only be used if you sample just one chain.
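For example, a minimal sketch of a single verbose chain, reusing the compiled model and the options shown above:

# Enable the flow diagnostics for a single chain so the output stays readable
compiled_verbose = compiled.with_transform_adapt(verbose=True, show_progress=True)
trace_debug = nutpie.sample(
    compiled_verbose,
    transform_adapt=True,
    chains=1,
    cores=1,
    seed=1,
)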
All losses are on a log scale. Negative values smaller than -2 are a good sign that the adaptation was successful. If the loss stays positive, the flow is either not expressive enough or the training period is too short; the sampler might still converge, but will probably need more gradient evaluations per effective draw. Losses larger than 6 tend to indicate that the posterior is too difficult to sample with the current flow, and the sampler will probably not converge.