Variational Bayes and the evidence lower bound

by benmoran

Variational methods for Bayesian inference have been enjoying a renaissance recently in machine learning.

Problem: normalization can be intractable when applying Bayes’ Theorem

Given a likelihood function p(y \vert z) and a prior distribution p(z) that we can evaluate, their product p(y \vert z)p(z) = p(y,z) is the joint distribution.

The posterior is just p(z\vert y) = p(y\vert z)p(z)/p(y) where we divide by the evidence p(y):

p(y) = \int p(y\vert z) p(z) dz = \mathbb{E}_{p(z)}[p(y\vert z)]

However, this integral is frequently intractable. It may not admit a closed-form solution, and z is often high-dimensional, so even numerical methods like quadrature may not help much.
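
To see the expectation form in action, here is a minimal sketch using a toy one-dimensional model (z \sim N(0,1), y \vert z \sim N(z,1), chosen purely for illustration) where the exact evidence is available in closed form. Naive Monte Carlo, averaging the likelihood over draws from the prior, works fine here, but in high dimensions this kind of estimate quickly becomes useless, which is exactly the problem described above.

```python
# Naive Monte Carlo estimate of the evidence p(y) = E_{p(z)}[p(y|z)]
# for a toy model: z ~ N(0, 1), y | z ~ N(z, 1), observed y = 0.5.
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
y = 0.5
prior = stats.norm(0.0, 1.0)

# Average the likelihood over samples drawn from the prior.
z_samples = prior.rvs(size=100_000, random_state=rng)
p_y_mc = np.mean(stats.norm(z_samples, 1.0).pdf(y))

# Exact evidence for comparison: marginally, y ~ N(0, sqrt(2)).
p_y_exact = stats.norm(0.0, np.sqrt(2.0)).pdf(y)

print(p_y_mc, p_y_exact)  # the two values should agree to a few decimal places
```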

Here are three techniques we can use to approach it:

  1. If the prior p(z) and the likelihood p(y\vert z) have particular forms such that they are conjugate to one another, then the posterior takes the same form as the prior and the integral has a simple closed form. The updates to the prior from the likelihood are then often trivial to calculate (see the sketch just after this list).
  2. It is possible to draw samples from the posterior even without knowing the normalization factor, and to approximate expectations stochastically by sample averages. With enough samples the averages converge to the true values; this is the theory behind MCMC techniques. Two hindrances arise: it is difficult to know how many samples are “enough”, and “enough” samples can take a long time to generate.
  3. We can introduce an approximating distribution q(z) \approx p(z\vert y). If we can somehow measure the quality of the approximation and iteratively improve it as far as possible, we can use this approximation with confidence. This is the variational approach derived below.
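
As an illustration of the first option, here is a minimal sketch of a conjugate update, using the standard Beta-Bernoulli example (a coin with unknown bias): the Beta prior is conjugate to the Bernoulli likelihood, so the posterior is again a Beta whose parameters we get by simply counting successes and failures.

```python
# Conjugate update: Beta prior + Bernoulli likelihood.
# The posterior is Beta(alpha + heads, beta + tails); no integration required.
import numpy as np

def beta_bernoulli_update(alpha, beta, observations):
    """Return posterior Beta parameters after observing 0/1 outcomes."""
    heads = int(np.sum(observations))
    tails = len(observations) - heads
    return alpha + heads, beta + tails

# Start from a uniform Beta(1, 1) prior, then observe 7 heads and 3 tails.
alpha_post, beta_post = beta_bernoulli_update(1.0, 1.0, [1, 1, 1, 0, 1, 1, 0, 1, 1, 0])
print(alpha_post, beta_post)                   # Beta(8.0, 4.0)
print(alpha_post / (alpha_post + beta_post))   # posterior mean of the bias ≈ 0.67
```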

Variational lower bound

Start with the KL from q to p, and rearrange to isolate the interesting quantity p(y):

D_{KL}(q(z) \Vert p(z \vert y)) = \int q(z) \log \frac{q(z)}{p(z \vert y)} dz = \int q(z) \log \frac{q(z)p(y)}{p(z, y)} dz

= \mathbb{E}_{q(z)}[\log q(z) - \log p(z, y) + \log p(y)]

But p(y) doesn’t depend on z so we can pull it out of the expectation:

D_{KL}(q(z) \Vert p(z \vert y)) = \mathbb{E}_{q(z)}[\log q(z) - \log p(z, y)] + \log p(y)

Rearranging, we get

\log p(y) = D_{KL}(q(z) \Vert p(z \vert y) ) + \mathbb{E}_{q(z)}[\log p(z, y) - \log q(z) ]

Because we saw previously that D_{KL} \geq 0, we have

\log p(y) \geq \mathbb{E}_{q(z)}[\log p(z, y) - \log q(z)] = L[q]

This quantity L[q] is the evidence lower bound (ELBO). It is a functional of the approximating distribution q(z).

This bound is valuable because it can be calculated without the unknown normalizing constant p(y). Moreover it is equal to \log p(y) at its maximum, which is attained when D_{KL}(q(z)\Vert p(z\vert y))=0, i.e. when q(z) = p(z \vert y).
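
To make the bound concrete, here is a minimal sketch (reusing the toy Gaussian model from the earlier snippet, where \log p(y) is known exactly) that evaluates a Monte Carlo estimate of L[q] for two choices of Gaussian q(z): the bound is tight when q(z) is the true posterior, and strictly looser for a worse approximation.

```python
# Check numerically that the ELBO lower-bounds log p(y), with equality when
# q(z) is the true posterior. Toy model: z ~ N(0, 1), y | z ~ N(z, 1), y = 0.5.
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
y = 0.5
prior = stats.norm(0.0, 1.0)
log_p_y = stats.norm(0.0, np.sqrt(2.0)).logpdf(y)   # exact log evidence

def elbo(q_mean, q_std, n_samples=200_000):
    """Monte Carlo estimate of E_q[log p(y, z) - log q(z)] for a Gaussian q."""
    q = stats.norm(q_mean, q_std)
    z = q.rvs(size=n_samples, random_state=rng)
    log_joint = stats.norm(z, 1.0).logpdf(y) + prior.logpdf(z)
    return np.mean(log_joint - q.logpdf(z))

# The true posterior for this model is N(y/2, 1/2) (mean y/2, variance 1/2).
print(log_p_y)
print(elbo(y / 2, np.sqrt(0.5)))   # essentially equal to log p(y)
print(elbo(0.0, 1.0))              # a worse q gives a strictly smaller value
```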

We have transformed the problem of taking expectations into one of optimization. Now a new question arises – is this problem any easier than the integral we started with? Not necessarily! However we can now make different choices for the form of q(z), so if we can find a family of distributions that is amenable to our available optimization techniques and which also contains a good approximation to the true posterior, we will be happy with the trade-off.

For example, we can rewrite L in terms of yet another KL divergence, this time between the approximate posterior q(z) and the prior p(z) on the latent variables:

L[q] = \mathbb{E}_{q(z)}[\log p(y\vert z)p(z) - \log q(z)]

= \mathbb{E}_{q(z)}[\log p(y\vert z) - \log \frac{q(z)}{p(z)}]

= \mathbb{E}_{q(z)}[\log p(y\vert z) ] - D_{KL}(q(z) \Vert p(z))

If q(z) and p(z) have the same form – for instance we have chosen them both to be Gaussian – then the second term will have a closed form expression, so if we have a good way to evaluate the first term, as in Kingma & Welling 2013, then we’ll be able to optimize q(z) and solve the problem.
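
Here is a minimal sketch of that decomposition on the same toy Gaussian model (not the model from the paper): the Gaussian-to-Gaussian KL term is computed in closed form, the expected log-likelihood is estimated by Monte Carlo using the reparameterization z = \mu + \sigma \epsilon with \epsilon \sim N(0,1), and a generic optimizer fits the parameters of q(z). The optimum lands near the true posterior, and the optimal value of the bound approaches \log p(y).

```python
# Decomposed ELBO with Gaussian q(z) and prior p(z) = N(0, 1):
# closed-form KL term plus a reparameterized Monte Carlo reconstruction term.
import numpy as np
from scipy import stats, optimize

rng = np.random.default_rng(0)
y = 0.5
eps = rng.standard_normal(100_000)   # fixed noise, reused across evaluations

def neg_elbo(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) )
    kl = 0.5 * (sigma**2 + mu**2 - 1.0) - log_sigma
    # Reparameterized Monte Carlo estimate of E_q[log p(y|z)]
    z = mu + sigma * eps
    recon = np.mean(stats.norm(z, 1.0).logpdf(y))
    return -(recon - kl)

result = optimize.minimize(neg_elbo, x0=np.array([0.0, 0.0]), method="Nelder-Mead")
mu_opt, sigma_opt = result.x[0], np.exp(result.x[1])
print(mu_opt, sigma_opt)   # close to the true posterior: y/2 = 0.25, sqrt(0.5) ≈ 0.71
print(-result.fun)         # close to the exact log p(y) from the previous snippet
```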

(This bound gets its name from the variational principle, applying the calculus of variations to optimize the function q(z) without assuming a particular form. However, we frequently assume a fixed-form approximation using a parametric form of this density. In this case no calculus of variations is required, and the problem reduces to an ordinary optimization.)