# Dirichlet process mixtures for density estimation

Author: Austin Rochford

## Dirichlet processes

The Dirichlet process is a flexible probability distribution over the space of distributions. Most generally, a probability distribution, \(P\), on a set \(\Omega\) is a [measure](https://en.wikipedia.org/wiki/Measure_(mathematics%29) that assigns measure one to the entire space (\(P(\Omega) = 1\)). A Dirichlet process \(P \sim \textrm{DP}(\alpha, P_0)\) is a measure that has the property that, for every finite disjoint partition \(S_1, \ldots, S_n\) of \(\Omega\),

\[(P(S_1), \ldots, P(S_n)) \sim \textrm{Dir}(\alpha P_0(S_1), \ldots, \alpha P_0(S_n)).\]

Here \(P_0\) is the base probability measure on the space \(\Omega\). The precision parameter \(\alpha > 0\) controls how close samples from the Dirichlet process are to the base measure, \(P_0\). As \(\alpha \to \infty\), samples from the Dirichlet process approach the base measure \(P_0\).
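One way to get a feel for the role of \(\alpha\) is through the finite-dimensional Dirichlet distributions that a Dirichlet process induces on a partition: the vector of cell probabilities is Dirichlet distributed with parameters \(\alpha P_0(S_i)\). The sketch below is an illustration, not part of the original analysis. The four-cell partition of the real line, the standard normal base measure, and the helper names `normal_cdf` and `partition_draws` are all assumptions made for the example.

```python
import math

import numpy as np


def normal_cdf(x):
    """CDF of the standard normal, used here as the assumed base measure P0."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# An arbitrary partition of the real line into four intervals.
edges = [-math.inf, -1.0, 0.0, 1.0, math.inf]
base_probs = np.array([normal_cdf(b) - normal_cdf(a)
                       for a, b in zip(edges[:-1], edges[1:])])

rng = np.random.default_rng(0)


def partition_draws(alpha, size=5000):
    """Draw (P(S_1), ..., P(S_4)) ~ Dir(alpha * P0(S_1), ..., alpha * P0(S_4))."""
    return rng.dirichlet(alpha * base_probs, size=size)


# Per-cell standard deviation of the draws for a small and a large precision.
spread_small = partition_draws(2.0).std(axis=0)
spread_large = partition_draws(200.0).std(axis=0)
```

Under these assumptions, the per-cell spread should be noticeably smaller for the larger \(\alpha\), which is consistent with samples concentrating around \(P_0\) as the precision grows.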

Dirichlet processes have several properties that make them quite suitable for MCMC simulation.

The posterior given i.i.d. observations \(\omega_1, \ldots, \omega_n\) from a Dirichlet process \(P \sim \textrm{DP}(\alpha, P_0)\) is also a Dirichlet process with

\[P\ |\ \omega_1, \ldots, \omega_n \sim \textrm{DP}\left(\alpha + n, \frac{\alpha}{\alpha + n} P_0 + \frac{1}{\alpha + n} \sum_{i = 1}^n \delta_{\omega_i}\right),\]

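where \(\delta_{\omega}\) is the Dirac delta measure concentrated at \(\omega\), that is, \(\delta_{\omega}(S) = 1\) if \(\omega \in S\) and \(\delta_{\omega}(S) = 0\) otherwise.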

The posterior predictive distribution of a new observation is a compromise between the base measure and the observations,

\[\omega\ |\ \omega_1, \ldots, \omega_n \sim \frac{\alpha}{\alpha + n} P_0 + \frac{1}{\alpha + n} \sum_{i = 1}^n \delta_{\omega_i}.\]

We see that the prior precision \(\alpha\) can naturally be interpreted as a prior sample size. The form of this posterior predictive distribution also lends itself to Gibbs sampling.
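The posterior predictive formula can be read as a sequential sampling rule: after \(n\) observations, the next one is a fresh draw from \(P_0\) with probability \(\alpha / (\alpha + n)\) and otherwise repeats a uniformly chosen earlier observation. The following is a minimal sketch of that rule, assuming a standard normal base measure. The function name `polya_urn_sample` and its arguments are hypothetical and chosen only for this illustration.

```python
import numpy as np


def polya_urn_sample(n, alpha, base_sampler, rng):
    """Draw n observations sequentially from the posterior predictive rule."""
    draws = []
    for i in range(n):
        # With i earlier observations, a new value from the base measure
        # is drawn with probability alpha / (alpha + i).
        if rng.uniform() < alpha / (alpha + i):
            draws.append(base_sampler(rng))
        else:
            # Otherwise repeat one of the i earlier observations uniformly.
            draws.append(draws[rng.integers(i)])
    return np.array(draws)


rng = np.random.default_rng(0)
urn_draws = polya_urn_sample(500, 2.0, lambda r: r.normal(), rng)
n_distinct = np.unique(urn_draws).size
```

Because earlier values are repeated, `n_distinct` is typically far smaller than the number of draws, which previews the discreteness discussed next.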

Samples, \(P \sim \textrm{DP}(\alpha, P_0)\), from a Dirichlet process are discrete with probability one. That is, there are elements \(\omega_1, \omega_2, \ldots\) in \(\Omega\) and weights \(w_1, w_2, \ldots\) with \(\sum_{i = 1}^{\infty} w_i = 1\) such that

\[P = \sum_{i = 1}^\infty w_i \delta_{\omega_i}.\]

The stick-breaking process gives an explicit construction of the weights \(w_i\) and samples \(\omega_i\) above that is straightforward to sample from. If \(\beta_1, \beta_2, \ldots \sim \textrm{Beta}(1, \alpha)\), then \(w_i = \beta_i \prod_{j = 1}^{i - 1} (1 - \beta_j)\). The relationship between this representation and stick breaking may be illustrated as follows:

- Start with a stick of length one.
- Break the stick into two portions, the first of proportion \(w_1 = \beta_1\) and the second of proportion \(1 - w_1\).
- Further break the second portion into two portions, the first of proportion \(\beta_2\) and the second of proportion \(1 - \beta_2\). The length of the first portion of this stick is \(\beta_2 (1 - \beta_1)\); the length of the second portion is \((1 - \beta_1) (1 - \beta_2)\).
- Continue breaking the second portion from the previous break in this manner forever.

If \(\omega_1, \omega_2, \ldots \sim P_0\), then

\[P = \sum_{i = 1}^\infty w_i \delta_{\omega_i} \sim \textrm{DP}(\alpha, P_0).\]
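The stick-breaking steps above can also be written as an explicit loop that tracks the length of stick still unbroken. The sketch below compares that loop with the vectorized cumulative-product formula for the weights. The names `stick_breaking_loop`, `beta_demo`, `w_loop`, and `w_vec` are hypothetical, and the choices of thirty breaks and \(\alpha = 2\) are arbitrary.

```python
import numpy as np


def stick_breaking_loop(beta):
    """Break a unit stick step by step; return the pieces and the leftover."""
    remaining = 1.0
    weights = []
    for b in beta:
        weights.append(b * remaining)  # piece broken off at this step
        remaining *= 1.0 - b           # stick left for later breaks
    return np.array(weights), remaining


rng = np.random.default_rng(0)
beta_demo = rng.beta(1.0, 2.0, size=30)  # beta_i ~ Beta(1, alpha) with alpha = 2

w_loop, leftover = stick_breaking_loop(beta_demo)

# Vectorized form: w_i = beta_i * prod_{j < i} (1 - beta_j).
w_vec = beta_demo * np.concatenate([[1.0], np.cumprod(1.0 - beta_demo)[:-1]])
```

The two computations should agree, and the pieces plus the leftover should add up to the original stick length of one.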

We can use the stick-breaking process above to easily sample from a Dirichlet process in Python. For this example, \(\alpha = 2\) and the base distribution is \(N(0, 1)\).

```
In [1]:
```

```
%matplotlib inline
```

```
In [2]:
```

```
from __future__ import division
```

```
In [3]:
```

```
from matplotlib import pyplot as plt
import numpy as np
import pymc3 as pm
import scipy as sp
import seaborn as sns
from statsmodels.datasets import get_rdataset
from theano import tensor as tt
```

```
In [4]:
```

```
blue, *_ = sns.color_palette()
```

```
In [5]:
```

```
SEED = 5132290 # from random.org
np.random.seed(SEED)
```

```
In [6]:
```

```
N = 20
K = 30
alpha = 2.
P0 = sp.stats.norm
```

We draw and plot samples from the stick-breaking process.

```
In [7]:
```

```
beta = sp.stats.beta.rvs(1, alpha, size=(N, K))
w = np.empty_like(beta)
w[:, 0] = beta[:, 0]
w[:, 1:] = beta[:, 1:] * (1 - beta[:, :-1]).cumprod(axis=1)
omega = P0.rvs(size=(N, K))
x_plot = np.linspace(-3, 3, 200)
sample_cdfs = (w[..., np.newaxis] * np.less.outer(omega, x_plot)).sum(axis=1)
```

```
In [8]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(x_plot, sample_cdfs[0], c='gray', alpha=0.75,
        label='DP sample CDFs');
ax.plot(x_plot, sample_cdfs[1:].T, c='gray', alpha=0.75);
ax.plot(x_plot, P0.cdf(x_plot), c='k', label='Base CDF');
ax.set_title(r'$\alpha = {}$'.format(alpha));
ax.legend(loc=2);
```


As stated above, as \(\alpha \to \infty\), samples from the Dirichlet process converge to the base distribution.

```
In [9]:
```

```
fig, (l_ax, r_ax) = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(16, 6))
K = 50
alpha = 10.
beta = sp.stats.beta.rvs(1, alpha, size=(N, K))
w = np.empty_like(beta)
w[:, 0] = beta[:, 0]
w[:, 1:] = beta[:, 1:] * (1 - beta[:, :-1]).cumprod(axis=1)
omega = P0.rvs(size=(N, K))
sample_cdfs = (w[..., np.newaxis] * np.less.outer(omega, x_plot)).sum(axis=1)
l_ax.plot(x_plot, sample_cdfs[0], c='gray', alpha=0.75,
          label='DP sample CDFs');
l_ax.plot(x_plot, sample_cdfs[1:].T, c='gray', alpha=0.75);
l_ax.plot(x_plot, P0.cdf(x_plot), c='k', label='Base CDF');
l_ax.set_title(r'$\alpha = {}$'.format(alpha));
l_ax.legend(loc=2);
K = 200
alpha = 50.
beta = sp.stats.beta.rvs(1, alpha, size=(N, K))
w = np.empty_like(beta)
w[:, 0] = beta[:, 0]
w[:, 1:] = beta[:, 1:] * (1 - beta[:, :-1]).cumprod(axis=1)
omega = P0.rvs(size=(N, K))
sample_cdfs = (w[..., np.newaxis] * np.less.outer(omega, x_plot)).sum(axis=1)
r_ax.plot(x_plot, sample_cdfs[0], c='gray', alpha=0.75,
          label='DP sample CDFs');
r_ax.plot(x_plot, sample_cdfs[1:].T, c='gray', alpha=0.75);
r_ax.plot(x_plot, P0.cdf(x_plot), c='k', label='Base CDF');
r_ax.set_title(r'$\alpha = {}$'.format(alpha));
r_ax.legend(loc=2);
```


## Dirichlet process mixtures

For the task of density estimation, the (almost sure) discreteness of samples from the Dirichlet process is a significant drawback. This problem can be solved with another level of indirection by using Dirichlet process mixtures for density estimation. A Dirichlet process mixture uses component densities from a parametric family \(\mathcal{F} = \{f_{\theta}\ |\ \theta \in \Theta\}\) and represents the mixture weights as a Dirichlet process. If \(P_0\) is a probability measure on the parameter space \(\Theta\), a Dirichlet process mixture is the hierarchical model

\[\begin{align*}
x_i\ |\ \theta_i & \sim f_{\theta_i} \\
\theta_1, \ldots, \theta_n & \sim P \\
P & \sim \textrm{DP}(\alpha, P_0).
\end{align*}\]

To illustrate this model, we simulate draws from a Dirichlet process mixture with \(\alpha = 2\), \(\theta \sim N(0, 1)\), \(x\ |\ \theta \sim N(\theta, (0.3)^2)\).
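Read generatively, the mixture first picks a component according to the stick-breaking weights and then draws an observation from that component's density. The sketch below draws observations this way from a truncated approximation, using the same \(\alpha = 2\), \(N(0, 1)\) base measure, and component standard deviation of 0.3. The variable names, the truncation at thirty components, and the renormalization of the truncated weights are assumptions of the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha_demo, K_demo, n_obs = 2.0, 30, 1000

# Truncated stick-breaking weights.
beta = rng.beta(1.0, alpha_demo, size=K_demo)
w = beta * np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
# Truncation leaves a little mass unassigned, so renormalize before sampling.
w_norm = w / w.sum()

# Component locations theta_k ~ P0 = N(0, 1).
theta = rng.normal(0.0, 1.0, size=K_demo)

# Pick a component for each observation, then draw x | theta ~ N(theta, 0.3^2).
z = rng.choice(K_demo, size=n_obs, p=w_norm)
x = rng.normal(theta[z], 0.3)
```

A histogram of `x` should resemble one of the multimodal densities simulated in this section, although each run of the sketch corresponds to a different random mixture.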

```
In [10]:
```

```
N = 5
K = 30
alpha = 2
P0 = sp.stats.norm
f = lambda x, theta: sp.stats.norm.pdf(x, theta, 0.3)
```

```
In [11]:
```

```
beta = sp.stats.beta.rvs(1, alpha, size=(N, K))
w = np.empty_like(beta)
w[:, 0] = beta[:, 0]
w[:, 1:] = beta[:, 1:] * (1 - beta[:, :-1]).cumprod(axis=1)
theta = P0.rvs(size=(N, K))
dpm_pdf_components = f(x_plot[np.newaxis, np.newaxis, :], theta[..., np.newaxis])
dpm_pdfs = (w[..., np.newaxis] * dpm_pdf_components).sum(axis=1)
```

```
In [12]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(x_plot, dpm_pdfs.T, c='gray');
ax.set_yticklabels([]);
```


We now focus on a single mixture and decompose it into its individual (weighted) mixture components.

```
In [13]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
ix = 1
ax.plot(x_plot, dpm_pdfs[ix], c='k', label='Density');
ax.plot(x_plot, (w[..., np.newaxis] * dpm_pdf_components)[ix, 0],
        '--', c='k', label='Mixture components (weighted)');
ax.plot(x_plot, (w[..., np.newaxis] * dpm_pdf_components)[ix].T,
        '--', c='k');
ax.set_yticklabels([]);
ax.legend(loc=1);
```


Sampling from these stochastic processes is fun, but these ideas become truly useful when we fit them to data. The discreteness of samples and the stick-breaking representation of the Dirichlet process lend themselves nicely to Markov chain Monte Carlo simulation of posterior distributions. We will perform this sampling using [`pymc3`](https://pymc-devs.github.io/pymc3/).

Our first example uses a Dirichlet process mixture to estimate the density of waiting times between eruptions of the Old Faithful geyser in Yellowstone National Park.

```
In [14]:
```

```
old_faithful_df = get_rdataset('faithful', cache=True).data[['waiting']]
```

For convenience in specifying the prior, we standardize the waiting time between eruptions.

```
In [15]:
```

```
old_faithful_df['std_waiting'] = (old_faithful_df.waiting - old_faithful_df.waiting.mean()) / old_faithful_df.waiting.std()
```

```
In [16]:
```

```
old_faithful_df.head()
```

```
Out[16]:
```

| | waiting | std_waiting |
|---|---|---|
| 0 | 79 | 0.596025 |
| 1 | 54 | -1.242890 |
| 2 | 74 | 0.228242 |
| 3 | 62 | -0.654437 |
| 4 | 85 | 1.037364 |

```
In [17]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
n_bins = 20
ax.hist(old_faithful_df.std_waiting, bins=n_bins, color=blue, lw=0, alpha=0.5);
ax.set_xlabel('Standardized waiting time between eruptions');
ax.set_ylabel('Number of eruptions');
```


Observant readers will have noted that we have not been continuing the stick-breaking process indefinitely, as its definition indicates, but have instead been truncating it after a finite number of breaks. When computing with Dirichlet processes, only a finite number of point masses and weights can be stored in memory. This restriction is not terribly onerous, since with a finite number of observations, it seems quite likely that the number of mixture components that contribute non-negligible mass to the mixture will grow more slowly than the number of samples. This intuition can be formalized to show that the (expected) number of components that contribute non-negligible mass to the mixture approaches \(\alpha \log N\), where \(N\) is the sample size.
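The logarithmic growth can be checked directly from the posterior predictive distribution given earlier: the observation that follows \(i\) earlier ones takes a previously unseen value with probability \(\alpha / (\alpha + i)\), so the expected number of distinct values among \(N\) draws is the sum of these probabilities. The sketch below evaluates that sum and compares it with \(\alpha \log N\). The function name `expected_num_clusters` and the choice \(\alpha = 2\) are assumptions made for illustration.

```python
import math


def expected_num_clusters(alpha, n):
    """Expected number of distinct values among n draws from a DP sample."""
    return sum(alpha / (alpha + i) for i in range(n))


alpha_demo = 2.0
ratios = {}
for n in (10, 100, 1000, 10000):
    exact = expected_num_clusters(alpha_demo, n)
    approx = alpha_demo * math.log(n)
    # The ratio drifts toward one as n grows, although only slowly.
    ratios[n] = exact / approx
    print(n, round(exact, 2), round(approx, 2))
```

For moderate sample sizes the approximation is rough, so it is best read as a statement about the growth rate rather than as a precise count.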

There are various clever Gibbs sampling techniques for Dirichlet processes that allow the number of components stored to grow as needed. Stochastic memoization is another powerful technique for simulating Dirichlet processes while only storing finitely many components in memory. In this introductory example, we take the much less sophisticated approach of simply truncating the Dirichlet process after a fixed number, \(K\), of components. Ohlssen, et al. provide justification for truncation, showing that \(K > 5 \alpha + 2\) is most likely sufficient to capture almost all of the mixture weight (\(\sum_{i = 1}^{K} w_i > 0.99\)). In practice, we can verify the suitability of our truncated approximation to the Dirichlet process by checking the number of components that contribute non-negligible mass to the mixture. If, in our simulations, all components contribute non-negligible mass to the mixture, we have truncated the Dirichlet process too early.
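A quick way to see why a truncation level near \(5 \alpha + 2\) is plausible is to look at the mass left over after \(K\) breaks, \(\prod_{j = 1}^{K} (1 - \beta_j)\). Since the \(\beta_j\) are independent with \(E[1 - \beta_j] = \alpha / (1 + \alpha)\), its expectation is \((\alpha / (1 + \alpha))^K\). The sketch below compares this expression with a simulation. It only addresses the expected leftover mass, not the probability statement in the cited result, and the function names and the choice \(\alpha = 2\) are assumptions.

```python
import numpy as np


def expected_leftover_mass(alpha, K):
    """E[prod_j (1 - beta_j)] for K independent beta_j ~ Beta(1, alpha)."""
    return (alpha / (1.0 + alpha)) ** K


def simulated_leftover_mass(alpha, K, n_draws, rng):
    """Simulate the mass not captured by the first K stick-breaking weights."""
    beta = rng.beta(1.0, alpha, size=(n_draws, K))
    return np.prod(1.0 - beta, axis=1)


rng = np.random.default_rng(0)
alpha_demo = 2.0
K_demo = int(5 * alpha_demo + 2)  # the truncation level suggested by the rule

leftover = simulated_leftover_mass(alpha_demo, K_demo, 20000, rng)
expected = expected_leftover_mass(alpha_demo, K_demo)
```

Under these assumptions the expected leftover mass is below one percent, in line with the rule of thumb.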

Our (truncated) Dirichlet process mixture model for the standardized waiting times is

\[\begin{align*}
\alpha & \sim \textrm{Gamma}(1, 1) \\
\beta_1, \ldots, \beta_K & \sim \textrm{Beta}(1, \alpha) \\
w_i & = \beta_i \prod_{j = 1}^{i - 1} (1 - \beta_j) \\
\lambda_1, \ldots, \lambda_K & \sim U(0, 5) \\
\tau_1, \ldots, \tau_K & \sim \textrm{Gamma}(1, 1) \\
\mu_i\ |\ \lambda_i, \tau_i & \sim N\left(0, (\lambda_i \tau_i)^{-1}\right) \\
x\ |\ w, \lambda, \tau, \mu & \sim \sum_{i = 1}^K w_i\ N\left(\mu_i, (\lambda_i \tau_i)^{-1}\right).
\end{align*}\]

Note that instead of fixing a value of \(\alpha\), as in our previous simulations, we specify a prior on \(\alpha\), so that we may learn its posterior distribution from the observations.

We now construct this model using `pymc3`.

```
In [18]:
```

```
N = old_faithful_df.shape[0]
K = 30
```

```
In [19]:
```

```
def stick_breaking(beta):
    # Fraction of the stick left before each break: 1, (1 - beta_1), (1 - beta_1)(1 - beta_2), ...
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining
```

```
In [20]:
```

```
with pm.Model() as model:
    alpha = pm.Gamma('alpha', 1., 1.)
    beta = pm.Beta('beta', 1., alpha, shape=K)
    w = pm.Deterministic('w', stick_breaking(beta))
    tau = pm.Gamma('tau', 1., 1., shape=K)
    lambda_ = pm.Uniform('lambda', 0, 5, shape=K)
    mu = pm.Normal('mu', 0, tau=lambda_ * tau, shape=K)
    obs = pm.NormalMixture('obs', w, mu, tau=lambda_ * tau,
                           observed=old_faithful_df.std_waiting.values)
```

We sample from the model 2,000 times using NUTS initialized with ADVI.

```
In [21]:
```

```
with model:
    trace = pm.sample(2000, n_init=50000, random_seed=SEED)
```

```
Auto-assigning NUTS sampler...
Initializing NUTS using advi...
Average ELBO = -9,730.2: 100%|██████████| 50000/50000 [00:42<00:00, 1172.42it/s]
Finished [100%]: Average ELBO = -9,729.8
100%|██████████| 2000/2000 [01:25<00:00, 23.40it/s]
```

The posterior distribution of \(\alpha\) is highly concentrated between 0.25 and 1.

```
In [22]:
```

```
pm.traceplot(trace, varnames=['alpha']);
```


To verify that truncation is not biasing our results, we plot the posterior expected mixture weight of each component.

```
In [23]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
plot_w = np.arange(K) + 1
ax.bar(plot_w - 0.5, trace['w'].mean(axis=0), width=1., lw=0);
ax.set_xlim(0.5, K);
ax.set_xlabel('Component');
ax.set_ylabel('Posterior expected mixture weight');
```


We see that only three mixture components have appreciable posterior expected weights, so we conclude that truncating the Dirichlet process to thirty components has not appreciably affected our estimates.
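The visual check can be complemented by a simple count of the components whose posterior expected weight exceeds some cutoff. The sketch below is one possible way to do this. The helper `count_appreciable_components` and its 0.01 threshold are assumptions, not part of `pymc3`. Since the fitted trace is not available in a standalone snippet, the array `w_standin` is simulated from a stick-breaking prior purely as a stand-in with the same `(samples, K)` layout as `trace['w']`.

```python
import numpy as np


def count_appreciable_components(w_samples, threshold=0.01):
    """Count components whose mean weight across samples exceeds threshold."""
    mean_w = np.asarray(w_samples).mean(axis=0)
    return int((mean_w > threshold).sum())


# Stand-in for trace['w']: 2000 simulated weight vectors with K = 30 components.
rng = np.random.default_rng(0)
n_samples, K_demo = 2000, 30
beta = rng.beta(1.0, 0.5, size=(n_samples, K_demo))
remaining = np.concatenate(
    [np.ones((n_samples, 1)), np.cumprod(1.0 - beta, axis=1)[:, :-1]], axis=1)
w_standin = beta * remaining

n_used = count_appreciable_components(w_standin)
```

With a fitted model, one would pass `trace['w']` instead of the stand-in. A count well below \(K\) suggests the truncation is not binding, whereas a count equal to \(K\) suggests it was made too early.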

We now compute and plot our posterior density estimate.

```
In [24]:
```

```
post_pdf_contribs = sp.stats.norm.pdf(np.atleast_3d(x_plot),
                                      trace['mu'][:, np.newaxis, :],
                                      1. / np.sqrt(trace['lambda'] * trace['tau'])[:, np.newaxis, :])
post_pdfs = (trace['w'][:, np.newaxis, :] * post_pdf_contribs).sum(axis=-1)
post_pdf_low, post_pdf_high = np.percentile(post_pdfs, [2.5, 97.5], axis=0)
```

```
In [25]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
n_bins = 20
ax.hist(old_faithful_df.std_waiting.values, bins=n_bins, normed=True,
        color=blue, lw=0, alpha=0.5);
ax.fill_between(x_plot, post_pdf_low, post_pdf_high,
                color='gray', alpha=0.45);
ax.plot(x_plot, post_pdfs[0],
        c='gray', label='Posterior sample densities');
ax.plot(x_plot, post_pdfs[::100].T, c='gray');
ax.plot(x_plot, post_pdfs.mean(axis=0),
        c='k', label='Posterior expected density');
ax.set_xlabel('Standardized waiting time between eruptions');
ax.set_yticklabels([]);
ax.set_ylabel('Density');
ax.legend(loc=2);
```


As above, we can decompose this density estimate into its (weighted) mixture components.

```
In [26]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
n_bins = 20
ax.hist(old_faithful_df.std_waiting.values, bins=n_bins, normed=True,
        color=blue, lw=0, alpha=0.5);
ax.plot(x_plot, post_pdfs.mean(axis=0),
        c='k', label='Posterior expected density');
ax.plot(x_plot, (trace['w'][:, np.newaxis, :] * post_pdf_contribs).mean(axis=0)[:, 0],
        '--', c='k', label='Posterior expected mixture\ncomponents\n(weighted)');
ax.plot(x_plot, (trace['w'][:, np.newaxis, :] * post_pdf_contribs).mean(axis=0),
        '--', c='k');
ax.set_xlabel('Standardized waiting time between eruptions');
ax.set_yticklabels([]);
ax.set_ylabel('Density');
ax.legend(loc=2);
```


The Dirichlet process mixture model is incredibly flexible in terms of the family of parametric component distributions \(\{f_{\theta}\ |\ \theta \in \Theta\}\). We illustrate this flexibility below by using Poisson component distributions to estimate the density of sunspots per year.
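To make the change of component family concrete, a mixture with Poisson components has probability mass function \(\sum_k w_k\, \textrm{Poisson}(x\ |\ \mu_k)\). The sketch below evaluates such a mixture with plain NumPy. The function name `poisson_mixture_pmf` and the three weights and rates are made up for illustration and are not estimates from the sunspot data.

```python
import math

import numpy as np


def poisson_mixture_pmf(x, weights, rates):
    """Evaluate sum_k weights[k] * Poisson(x | rates[k]) at integer counts x."""
    x = np.asarray(x, dtype=float)
    weights = np.asarray(weights, dtype=float)
    rates = np.asarray(rates, dtype=float)
    # log(x!) computed via the log-gamma function for numerical stability.
    log_factorial = np.array([math.lgamma(v + 1.0) for v in x])
    # Log pmf of each component at each count; shape (len(x), len(rates)).
    log_pmf = (x[:, None] * np.log(rates)[None, :]
               - rates[None, :]
               - log_factorial[:, None])
    return np.exp(log_pmf) @ weights


# Hypothetical weights and rates, chosen only to illustrate the computation.
weights_demo = np.array([0.5, 0.3, 0.2])
rates_demo = np.array([10.0, 50.0, 120.0])
x_grid = np.arange(400)
pmf_demo = poisson_mixture_pmf(x_grid, weights_demo, rates_demo)
mixture_mean = float((x_grid * pmf_demo).sum())
```

The only change relative to the normal mixture is the component density, which is why the same stick-breaking weights can be reused unchanged.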

```
In [27]:
```

```
sunspot_df = get_rdataset('sunspot.year', cache=True).data
```

```
In [28]:
```

```
sunspot_df.head()
```

```
Out[28]:
```

| | time | sunspot.year |
|---|---|---|
| 0 | 1700 | 5.0 |
| 1 | 1701 | 11.0 |
| 2 | 1702 | 16.0 |
| 3 | 1703 | 23.0 |
| 4 | 1704 | 36.0 |

For this example, the model is

\[\begin{align*}
\alpha & \sim \textrm{Gamma}(1, 1) \\
\beta_1, \ldots, \beta_K & \sim \textrm{Beta}(1, \alpha) \\
w_i & = \beta_i \prod_{j = 1}^{i - 1} (1 - \beta_j) \\
\mu_1, \ldots, \mu_K & \sim U(0, 300) \\
x\ |\ w, \mu & \sim \sum_{i = 1}^K w_i\ \textrm{Poisson}(\mu_i).
\end{align*}\]

```
In [29]:
```

```
K = 50
N = sunspot_df.shape[0]
```

```
In [30]:
```

```
with pm.Model() as model:
    alpha = pm.Gamma('alpha', 1., 1.)
    beta = pm.Beta('beta', 1, alpha, shape=K)
    w = pm.Deterministic('w', stick_breaking(beta))
    mu = pm.Uniform('mu', 0., 300., shape=K)
    obs = pm.Mixture('obs', w, pm.Poisson.dist(mu), observed=sunspot_df['sunspot.year'])
```

```
In [31]:
```

```
with model:
    step = pm.Metropolis()
    trace_ = pm.sample(100000, step=step, random_seed=SEED)
    trace = trace_[50000::50]
```

```
100%|██████████| 100000/100000 [03:41<00:00, 451.22it/s]
```

For the sunspot model, the posterior distribution of \(\alpha\) is concentrated between 0.6 and 1.2, indicating that we should expect more components to contribute non-negligible amounts to the mixture than for the Old Faithful waiting time model.

```
In [32]:
```

```
pm.traceplot(trace, varnames=['alpha']);
```


Indeed, we see that between ten and fifteen mixture components have appreciable posterior expected weight.

```
In [33]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
plot_w = np.arange(K) + 1
ax.bar(plot_w - 0.5, trace['w'].mean(axis=0), width=1., lw=0);
ax.set_xlim(0.5, K);
ax.set_xlabel('Component');
ax.set_ylabel('Posterior expected mixture weight');
```


We now calculate and plot the fitted density estimate.

```
In [34]:
```

```
x_plot = np.arange(250)
```

```
In [35]:
```

```
post_pmf_contribs = sp.stats.poisson.pmf(np.atleast_3d(x_plot),
                                         trace['mu'][:, np.newaxis, :])
post_pmfs = (trace['w'][:, np.newaxis, :] * post_pmf_contribs).sum(axis=-1)
post_pmf_low, post_pmf_high = np.percentile(post_pmfs, [2.5, 97.5], axis=0)
```

```
In [36]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(sunspot_df['sunspot.year'].values, bins=40, normed=True, lw=0, alpha=0.75);
ax.fill_between(x_plot, post_pmf_low, post_pmf_high,
                color='gray', alpha=0.45)
ax.plot(x_plot, post_pmfs[0],
        c='gray', label='Posterior sample densities');
ax.plot(x_plot, post_pmfs[::200].T, c='gray');
ax.plot(x_plot, post_pmfs.mean(axis=0),
        c='k', label='Posterior expected density');
ax.set_xlabel('Yearly sunspot count');
ax.set_yticklabels([]);
ax.legend(loc=1);
```


Again, we can decompose the posterior expected density into weighted mixture densities.

```
In [37]:
```

```
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(sunspot_df['sunspot.year'].values, bins=40, normed=True, lw=0, alpha=0.75);
ax.plot(x_plot, post_pmfs.mean(axis=0),
        c='k', label='Posterior expected density');
ax.plot(x_plot, (trace['w'][:, np.newaxis, :] * post_pmf_contribs).mean(axis=0)[:, 0],
        '--', c='k', label='Posterior expected\nmixture components\n(weighted)');
ax.plot(x_plot, (trace['w'][:, np.newaxis, :] * post_pmf_contribs).mean(axis=0),
        '--', c='k');
ax.set_xlabel('Yearly sunspot count');
ax.set_yticklabels([]);
ax.legend(loc=1);
```


An earlier version of this example first appeared here.