Stan for multimodal mixtures—from exponential CPS to linear DP

This post is from Bob

I’ve been thinking about evaluation recently because I’ve been working with colleagues on new samplers, which means evaluating how well they work (more on that soon). This in turn means coming up with target densities on which to evaluate them.

A combinatorial multimodal test case

I wanted something clearly multimodal and hence not log concave. I remember somebody’s paper (help with citation?) used a mixture of four two-dimensional isotropic normals, separated enough that transitions between modes are possible but still a bit difficult. Not to give the game away, but here’s a posterior plot of a sample drawn from Stan; the imbalance in component weights is intentional, as I’ll describe below.

I’m still working on posteriordb with the Stan gang (see the authors of the linked paper) and Inference Gym with Reuben Cohn-Gordon (another linguist by training and programming language geek turned to MCMC), and thought it’d be nice to have something a little more general than just the 2D example. So I got out my notebook and realized that the generalization to D dimensions involves 2^D mixture components, each a normal with identity covariance located at one of the points in {-r, r}^D.

p(y | r) = SUM_{mu in {-r, r}^D} (1 / 2^D) * normal(y | mu, I).

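I then generalized to a non-uniform mixture by letting each coordinate of a mode be r with probability p in (0, 1) and -r otherwise, independently across dimensions. When r is large relative to the unit scale, p is approximately Pr[Y[d] > 0]. This leads to a slightly more complex density because of the non-uniformity.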

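p(y | r, p) = SUM_{mu in {-r, r}^D} p^k * (1 - p)^(D - k) * normal(y | mu, I),

where k = sum(mu == r) is the number of coordinates of mu equal to r. Each mode gets the probability of its specific sign pattern, so the weights sum to one over the 2^D modes.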
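To make the mixture concrete, here is a minimal brute-force sketch in plain Python. It assumes each mode's weight is the product of independent per-dimension probabilities, p^k * (1 - p)^(D - k), with k the number of coordinates at +r. The helper names (`normal_lpdf`, `mixture_weights`, `mixture_lpdf_brute`) are made up for this illustration and are not part of Stan or any library.

```python
import itertools
import math


def normal_lpdf(y, mu):
    """Log density of a unit-variance normal at y with mean mu."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * (y - mu) ** 2


def mixture_weights(D, p):
    """Map each mode's sign pattern to its weight p^k * (1 - p)^(D - k)."""
    weights = {}
    for signs in itertools.product([1, -1], repeat=D):
        k = sum(s == 1 for s in signs)  # number of coordinates at +r
        weights[signs] = p**k * (1.0 - p) ** (D - k)
    return weights


def mixture_lpdf_brute(y, r, p):
    """Log density by summing all 2^D components (exponential in D)."""
    D = len(y)
    total = 0.0
    for signs, w in mixture_weights(D, p).items():
        log_comp = sum(normal_lpdf(y[d], signs[d] * r) for d in range(D))
        total += w * math.exp(log_comp)
    return math.log(total)


weights = mixture_weights(D=2, p=2.0 / 3.0)
print(len(weights))           # one weight per mode, 2^D in total
print(sum(weights.values()))  # should be 1 up to rounding
print(mixture_lpdf_brute([2.5, 2.5], r=2.5, p=2.0 / 3.0))
```

The sum runs in probability space for readability, so it can underflow far from every mode; it is meant only as a reference for small D.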

Coding in Stan with continuation-passing style

So how do we code this in Stan? It needs to be recursive, or at least iterative, because D is unknown at compile time. Whenever I see recursion, I immediately think of continuation-passing style (CPS). So I came up with this Stan program to code the generalization to D dimensions.

functions {
  real mm(vector y, real r, real p, int d, real lp) {
    if (d == 0) {
      return lp;
    }
    real lp1 = mm(y, r, p, d - 1, lp + normal_lpdf(y[d] | r, 1));
    real lp2 = mm(y, r, p, d - 1, lp + normal_lpdf(y[d] | -r, 1));
    return log_mix(p, lp1, lp2);
  }

  real mm_lpdf(vector y, real r, real p, int D) {
    return mm(y, r, p, D, 0);
  }
}
data {
  int<lower=0> D;            // number of dimensions
  real<lower=0> r;           // modes in {-r, r}^D
  real<lower=0, upper=1> p;  // probability that a mode's coordinate is +r
}
parameters {
  vector[D] y;
}
model {
  y ~ mm(r, p, D);
}

The log_mix function is defined as follows, but implemented in a more stable way.

log_mix(p, lp1, lp2)
    = log_sum_exp(log(p) + lp1, log(1 - p) + lp2)
    = log(exp(log(p) + lp1) + exp(log(1 - p) + lp2))
    = log(p * exp(lp1) + (1 - p) * exp(lp2)).
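As a rough illustration of what "more stable" can mean here, the following Python sketch factors out the larger term before exponentiating. It is not Stan's actual implementation; the function names simply mirror the Stan ones, and it assumes 0 < p < 1.

```python
import math


def log_sum_exp(a, b):
    """Compute log(exp(a) + exp(b)) by factoring out the larger term."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def log_mix(p, lp1, lp2):
    """Log of p * exp(lp1) + (1 - p) * exp(lp2), assuming 0 < p < 1."""
    return log_sum_exp(math.log(p) + lp1, math.log1p(-p) + lp2)


print(log_mix(0.5, 0.0, 0.0))
# The naive formula fails here because exp(-1000) underflows to 0.
print(log_mix(0.25, -1000.0, -1001.0))
```

Evaluating the last line of the definition directly would take the log of zero for the second call, which is why the log-sum-exp form matters once the component log densities get very negative.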

If you unfold the recursion manually, the leaves wind up being the log densities and the weights wind up percolating as described in the definition. If you’re having trouble seeing this, manually expanding the D = 1 and then D = 2 cases will help. It’s compact, but it’s still exponential in cost to evaluate a log density and gradient (i.e., O(2^D)). Although it’s slow in higher dimensions, it works.
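For reference, here is that manual expansion written out as a sketch. It uses N(y_d | m, 1) for the unit-variance normal density and follows the order of the calls in the program, where mm(y, r, p, D, 0) splits on y[D] first.

```latex
% D = 1: a single split on y_1.
\mathrm{mm}(y, r, p, 1, 0)
  = \log\bigl( p \, \mathcal{N}(y_1 \mid r, 1)
             + (1 - p) \, \mathcal{N}(y_1 \mid -r, 1) \bigr)

% D = 2: the outer call splits on y_2, and each branch then splits on y_1.
\begin{aligned}
\exp \mathrm{mm}(y, r, p, 2, 0)
  &= p \cdot p \, \mathcal{N}(y_2 \mid r, 1) \, \mathcal{N}(y_1 \mid r, 1) \\
  &\quad + p \, (1 - p) \, \mathcal{N}(y_2 \mid r, 1) \, \mathcal{N}(y_1 \mid -r, 1) \\
  &\quad + (1 - p) \, p \, \mathcal{N}(y_2 \mid -r, 1) \, \mathcal{N}(y_1 \mid r, 1) \\
  &\quad + (1 - p)(1 - p) \, \mathcal{N}(y_2 \mid -r, 1) \, \mathcal{N}(y_1 \mid -r, 1)
\end{aligned}
```

Each of the four terms is one leaf of the recursion: the product of normals is what the accumulator lp carries down, and the factors of p and 1 - p are attached by log_mix on the way back up.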

Python scripts

The plot above is from the following Python code, which sets D = 2, r = 2.5, and p = 2.0/3.0.

For those of you considering a move to Python, having a clone of data frames (pandas) and ggplot2 (plotnine) is a godsend. And yes, of course the LLMs know how to code pandas and plotnine.

import cmdstanpy as csp
import pandas as pd
import plotnine as pn

model = csp.CmdStanModel(stan_file="mm.stan")
D = 2
r = 2.5
p = 2.0 / 3.0
data = {'D': D, 'r': r, 'p': p}
fit = model.sample(data=data, iter_sampling=5_000)
print(fit.summary(sig_figs=2))

y = fit.stan_variable('y')
df = pd.DataFrame({'y1': y[:, 0], 'y2': y[:, 1]})
plot = (
    pn.ggplot(df, pn.aes(x='y1', y='y2'))
    + pn.geom_vline(xintercept=[-r, r], color="red", linetype="dashed")
    + pn.geom_hline(yintercept=[-r, r], color="red", linetype="dashed")
    + pn.geom_point(alpha=0.1)
    + pn.scale_x_continuous(breaks=[-r, 0, r])
    + pn.scale_y_continuous(breaks=[-r, 0, r])
    + pn.coord_fixed()
    + pn.theme_minimal()
)
plot.save('mm.jpg', dpi=300)

The knockoffs of data frames in pandas and of ggplot2 in plotnine are a godsend if you’re transitioning to Python from R (which I would highly recommend).

Dynamic programming to the rescue

Because it involved CPS, I mailed it off to Brian Ward around midnight last night. I’m a decent programmer, but Brian’s next level. By the time I arrived today at 10 am, he had rewritten the target density as follows.

  real mm_lpdf(vector y, real r, real p, int d) {
    if (d == 0) {
      return 0;
    }
    real lower_mixture = mm_lpdf(y | r, p, d - 1);
    real lp1 = lower_mixture + normal_lpdf(y[d] | r, 1);
    real lp2 = lower_mixture + normal_lpdf(y[d] | -r, 1);
    return log_mix(p, lp1, lp2);
  }

[Edit: Switched everything to lpdf from a mix of lpdf and lupdf.]

He saw that the recursions were doing the same thing in each branch and could be shared. Because there’s only one recursive call per level, Brian’s code is linear (i.e., O(D)). It achieves this speedup using dynamic programming (DP): it calculates partial solutions once and combines them into larger solutions rather than recomputing them. Here, the log density of the first d - 1 dimensions is computed once and reused in both branches. DP’s the technique that you need to solve the harder L33T-code quizzes you’ll get during technical interviews these days. Other examples where DP can be helpful for statistical models include the fast Fourier transform (FFT), the forward algorithm for hidden Markov models (HMMs), and the Poisson-binomial distribution. The first two are coded efficiently in Stan, and I showed how to code the last in a Stan forum post on Poisson-binomial.
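One way to see why the sharing is valid: unrolling Brian's recursion gives a product over dimensions of independent two-component mixtures, which equals the full sum over 2^D modes. The Python sketch below checks that identity numerically on an arbitrary test point. The function names are made up for this illustration, and the code works in probability space for readability rather than using log_mix.

```python
import itertools
import math


def normal_pdf(y, mu):
    """Density of a unit-variance normal at y with mean mu."""
    return math.exp(-0.5 * (y - mu) ** 2) / math.sqrt(2.0 * math.pi)


def mm_lpdf_linear(y, r, p):
    """O(D) log density: one two-component mixture per dimension."""
    lp = 0.0
    for y_d in y:
        lp += math.log(p * normal_pdf(y_d, r) + (1.0 - p) * normal_pdf(y_d, -r))
    return lp


def mm_lpdf_brute(y, r, p):
    """O(2^D) log density: explicit sum over all modes in {-r, r}^D."""
    total = 0.0
    for signs in itertools.product([1, -1], repeat=len(y)):
        term = 1.0
        for y_d, s in zip(y, signs):
            term *= (p if s == 1 else 1.0 - p) * normal_pdf(y_d, s * r)
        total += term
    return math.log(total)


y = [0.3, -1.2, 2.8, -2.4, 0.9]  # arbitrary test point with D = 5
fast = mm_lpdf_linear(y, r=2.5, p=2.0 / 3.0)
slow = mm_lpdf_brute(y, r=2.5, p=2.0 / 3.0)
print(fast, slow)  # the two values should agree up to rounding
```

The loop in `mm_lpdf_linear` is the iterative counterpart of the single recursive call: the running total plays the role of lower_mixture.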

Try it yourself in the Stan Playground

If you want to play with this yourself, Brian built a version using the Stan Playground that you can run in the browser.

Here’s what it looks like after setting D = 3, running sampling, and then viewing a histogram with all three dimensions selected.

It’s a live demo, so you can edit the data to set r, D, and p. And it’s really fast due to the DP. Just like in ShinyStan and especially like its generic in-the-browser version MCMCMonitor (from many of the same developers as Stan Playground), you can view 3D projections of the higher-dimensional draws and rotate them to see it making 8 balls in 3D, 7 of which are visible in the screen grab. Or you can go to higher dimensions and view projections down to two or three dimensions. You might want to increase the number of draws per chain to get cleaner delineation of the posterior densities in the visualizations.
