Adaptation with Normalizing Flows

Experimental and subject to change

Normalizing flow adaptation through Fisher HMC is a new sampling algorithm that automatically reparameterizes a model. It adds some computational cost outside model log-density evaluations, but allows sampling from much more difficult posterior distributions. For models with expensive log-density evaluations, the normalizing flow adaptation can also be much faster, if it can reduce the number of log-density evaluations needed to reach a given effective sample size.

The normalizing flow adaptation works by learning a transformation of the parameter space that makes the posterior distribution more amenable to sampling. This is done by fitting a sequence of invertible transformations (the “flow”) that maps the original parameter space to a space where the posterior is closer to a standard normal distribution. The flow is trained during warmup.
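For intuition, the funnel example used later on this page has a well-known “ideal” reparameterization (the non-centered parameterization), which is exactly the kind of invertible map the flow has to learn. The following toy sketch is illustrative only and is not part of nutpie's API:

import numpy as np

# Illustrative only: the kind of invertible map a flow would have to learn for a funnel.
# If z is standard normal, flow_forward(z) is distributed like the funnel model below.
def flow_forward(z):
    log_sigma = z[0]                      # first coordinate plays the role of log_sigma
    x = np.exp(log_sigma / 2.0) * z[1:]   # remaining coordinates are rescaled accordingly
    return np.concatenate([[log_sigma], x])

rng = np.random.default_rng(0)
z = rng.normal(size=101)     # easy to sample: independent standard normals
theta = flow_forward(z)      # a draw in the original (funnel) parameter space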

For more information about the algorithm, see the (still work in progress) paper “If only my posterior were normal: Introducing Fisher HMC”.

Currently, a lot of time is spent on compiling various parts of the normalizing flow, and for small models this can take a large amount of the total time. Hopefully, we will be able to reduce this overhead in the future.

Requirements

Install the optional dependencies for normalizing flow adaptation:

pip install 'nutpie[nnflow]'

If you use nutpie with PyMC, this will only work if the model is compiled with the jax backend and gradient_backend is also set to jax.

Training of the normalizing flow can often be accelerated by using a GPU (even if the model itself is written in Stan, without any GPU support). To enable GPU support, make sure your jax installation includes it, for instance by installing jax with pip install 'jax[cuda12]', or, if you are using conda-forge, by selecting a jaxlib build with GPU support. You can check whether your installation sees a GPU by inspecting the output of:

import jax
jax.devices()

Usage

To use normalizing flow adaptation in nutpie, you need to enable the transform_adapt option during sampling. Here is an example of how we can use it to sample from a difficult posterior:

import pymc as pm
import nutpie
import numpy as np
import arviz

# Define a 100-dimensional funnel model
with pm.Model() as model:
    log_sigma = pm.Normal("log_sigma")
    pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)

# Compile the model with the jax backend
compiled = nutpie.compile_pymc_model(
    model, backend="jax", gradient_backend="jax"
)

If we sample this model without normalizing flow adaptation, we run into convergence problems: often divergences, and always low effective sample sizes:

# Sample without normalizing flow adaptation
trace_no_nf = nutpie.sample(compiled, seed=1)
assert (arviz.ess(trace_no_nf) < 100).any().to_array().any()

Sampler Progress (without flow adaptation): 6 chains, all finished after sampling for 16 seconds.

Draws   Divergences   Step Size   Gradients/Draw
1400    0             0.45        7
1400    0             0.31        15
1400    0             0.31        7
1400    0             0.28        7
1400    0             0.39        15
1400    0             0.34        7

To enable the normalizing flow adaptation, we first configure it on the compiled model and then pass transform_adapt=True to the sampler:
# We can add further arguments for the normalizing flow:
compiled = compiled.with_transform_adapt(
    num_layers=5,        # Number of layers in the normalizing flow
    nn_width=32,         # Neural networks with 32 hidden units
    num_diag_windows=6,  # Number of windows with a diagonal mass matrix instead of a flow
    verbose=False,       # Whether to print details about the adaptation process
    show_progress=False, # Whether to show a progress bar for each optimization step
)

# Sample with normalizing flow adaptation
trace_nf = nutpie.sample(
    compiled,
    transform_adapt=True,  # Enable the normalizing flow adaptation
    seed=1,
    chains=2,
    cores=1,  # Running chains in parallel can be slow
    window_switch_freq=150,  # Optimize the normalizing flow every 150 iterations
)
assert trace_nf.sample_stats.diverging.sum() == 0
assert (arviz.ess(trace_nf) > 1000).all().to_array().all()

Sampler Progress (with flow adaptation): 2 chains, all finished after sampling for 18 minutes.

Draws   Divergences   Step Size   Gradients/Draw
2500    0             0.52        7
2500    0             0.53        7

With normalizing flow adaptation the sampler used fewer gradient evaluations, yet it converged and produced a good effective sample size:

n_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum())
ess = float(arviz.ess(trace_nf).min().to_array().min())
print(f"Number of gradient evaluations: {n_steps}")
print(f"Minimum effective sample size: {ess}")
Number of gradient evaluations: 42527
Minimum effective sample size: 1835.9674640023168

Without the normalizing flow, the sampler used more gradient evaluations and still was not able to reach a good effective sample size:

n_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum())
ess = float(arviz.ess(trace_no_nf).min().to_array().min())
print(f"Number of gradient evaluations: {n_steps}")
print(f"Minimum effective sample size: {ess}")
Number of gradient evaluations: 124219
Minimum effective sample size: 31.459420094540565

The flow adaptation happens during warmup, so the number of warmup draws must be large enough for the flow to converge. For more complex posteriors, you may need to increase the number of layers (through the num_layers argument) or the number of warmup draws, as sketched below.
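For example, a possible configuration for a harder posterior might look like this. The values are placeholders, and the tune argument (assumed here to control the number of warmup draws) is not used elsewhere on this page:

# Sketch: a deeper flow and a longer warmup for a more difficult posterior.
compiled_deep = compiled.with_transform_adapt(num_layers=9)
trace_deep = nutpie.sample(
    compiled_deep,
    transform_adapt=True,
    tune=1200,  # more warmup draws give the flow more time to converge
    seed=1,
)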

To monitor the progress of the flow adaptation, you can set verbose=True or show_progress=True, but the latter should only be used when sampling a single chain.
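For example, a sketch that reuses the options shown above to inspect the adaptation of a single chain:

# Sketch: verbose adaptation output for a single chain, so the details stay readable.
compiled_verbose = compiled.with_transform_adapt(verbose=True, show_progress=True)
trace_verbose = nutpie.sample(
    compiled_verbose,
    transform_adapt=True,
    chains=1,  # one chain, so the progress output is not interleaved
    seed=1,
)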

All reported losses are on a log scale. Values below -2 are a good sign that the adaptation was successful. If the loss stays positive, the flow is either not expressive enough or the training period is too short; the sampler might still converge, but it will probably need more gradient evaluations per effective draw. Losses above 6 tend to indicate that the posterior is too difficult to sample with the current flow, and the sampler will probably not converge.