Latent Variable Models

Directed latent variable models provide a powerful way to represent complex distributions by combining simple ones. However, they often have intractable log-likelihoods, yielding complicated learning algorithms. In this post, I hope to build intuition for these concepts.

It’s taken me a while to fully grasp directed latent variable models, and I hope to use this post as a way to collect my thoughts and provide a concise examination of latent variable representation, learning, and inference. Significant portions of this material are inspired by Stanford’s CS 228 and CS 236 courses, taught by Dr. Stefano Ermon.

Introduction

Many modeling problems are framed in a supervised setting, in which one is provided a dataset $X$ along with outcomes $Y$, with the task of predicting outcomes for new, unseen samples drawn from the same distribution as $X$. Discriminative models learn the conditional distribution $p(Y \mid X)$ directly, and therefore directly predict outcomes given new samples $X$. On the other hand, generative models specify or learn both $p(Y)$ and $p(X \mid Y)$, and compute $p(Y \mid X)$ via Bayes’ rule. Both models are powerful in their own respects: while discriminative models tend to be more expressive as they are only required to learn the conditional, generative models allow for sampling new data from $p(X \mid Y)$ and for performing inference when some variables $X_i$ are unobserved, by marginalizing over the unseen variables.

In an unsupervised setting, in which one is provided a dataset $X$ without associated outcomes, discriminative modeling assumptions are no longer meaningful. However, generative models remain powerful: instead of specifying distributions $p(Y)$ and $p(X \mid Y)$, we now specify distributions $p(Z)$ and $p(X \mid Z)$ for latent variables $Z$. Intuitively, these variables $Z$ represent unobserved factors of variation that contribute to diversity in $X$; for example, hair color and eye color in a dataset of faces or pencil stroke width and shape in a dataset of handwritten images. Learning such latent representations is both immensely powerful and incredibly challenging: while effective latent variables can improve modeling and understanding of the dataset, the unobserved nature of these variables implies that maximum likelihood cannot be directly applied as in supervised models.

The task of latent variable models (LVMs) is to explicitly model latent variables as defined by the Bayes net $Z \to X$, with $Z$ unobserved and $X$ observed. They therefore learn the joint distribution $p(X, Z; \theta)$ for parameters $\theta$. The remainder of this post will provide intuition for, and derivations of, learning algorithms for shallow and deep LVMs.

Shallow Latent Variable Models

We begin our discussion of latent variable models with shallow LVMs, models that consist of a simple relationship between $Z$ and $X$. In particular, these models specify distributions $p(Z)$ and $p(X \mid Z)$ such that the computation of $p(Z \mid X)$ is tractable. One common example is the Gaussian mixture model, which specifies $z \sim \text{Categorical}(\pi_1, \dots, \pi_K)$ over the $K$ components $\{1, \dots, K\}$ and $p(x \mid z = k) = \mathcal{N}(x \mid \mu_k, \sigma_k^2)$; here $p(z \mid x)$ is easy to compute because $z$ takes only $K$ values.
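To make the shallow setting concrete, here is a minimal NumPy sketch of a 1-D Gaussian mixture with made-up parameters (`pi`, `mu`, and `sigma` are illustrative, with `sigma` holding standard deviations). It draws samples by first drawing $z$ and then $x$, and computes the posterior $p(z \mid x)$ with Bayes' rule by enumerating all $K$ components, which is exactly what makes this model shallow.

```python
import numpy as np

# Illustrative mixture with K = 3 components: p(z = k) = pi[k], p(x | z = k) = N(x; mu[k], sigma[k]^2).
pi = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])
sigma = np.array([0.5, 1.0, 0.8])   # standard deviations

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x; broadcasts over arrays."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def sample(n, rng):
    """Ancestral sampling: z ~ Categorical(pi), then x ~ N(mu_z, sigma_z^2)."""
    z = rng.choice(len(pi), size=n, p=pi)
    x = rng.normal(mu[z], sigma[z])
    return x, z

def posterior(x):
    """p(z = k | x) for every k, by Bayes' rule: normalize p(z = k) p(x | z = k) over k."""
    joint = pi * normal_pdf(x, mu, sigma)   # p(x, z = k), shape (K,)
    return joint / joint.sum()

rng = np.random.default_rng(0)
x, z = sample(5, rng)
print("samples:", x.round(2), "components:", z)
print("p(z | x = 0.1):", posterior(0.1).round(4))
```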

To learn such a model via maximum likelihood, we hope to solve the optimization problem

\[\text{argmax}_\theta \prod_{x \in \mathcal{D}} p(x; \theta) = \text{argmax}_\theta \prod_{x \in \mathcal{D}} \sum_{z} p(x, z; \theta)\]

Our log-likelihood function is therefore

\[\ell(\theta; \mathcal{D}) = \sum_{x \in \mathcal{D}} \log \sum_z p(x, z; \theta)\]

which does not admit a decomposition conducive to optimization due to the summation within the logarithm.
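In practice the inner sum is evaluated in log space. The sketch below uses a hypothetical two-component 1-D mixture; `log_normal` and `logsumexp` are helpers written here, not library calls. It computes $\ell(\theta)$ with a numerically stable log-sum-exp per data point, which also makes visible why the objective does not split into separate per-component terms.

```python
import numpy as np

def log_normal(x, mean, std):
    """log N(x; mean, std^2), broadcasting over arrays."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along an axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def gmm_log_likelihood(data, pi, mu, sigma):
    """ell(theta) = sum_x log sum_z p(x, z; theta) for a 1-D Gaussian mixture."""
    # log p(x, z = k): one row per data point, one column per component
    log_joint = np.log(pi) + log_normal(data[:, None], mu, sigma)
    # the sum over z sits inside the log, so each data point needs its own log-sum-exp
    return logsumexp(log_joint, axis=1).sum()

pi = np.array([0.5, 0.5])
mu = np.array([-1.0, 1.0])
sigma = np.array([1.0, 1.0])
data = np.array([-1.2, 0.3, 2.0])
print(gmm_log_likelihood(data, pi, mu, sigma))
```

Evaluating the naive `np.log(np.sum(np.exp(...)))` on a far-away point such as `1000.0` underflows to `-inf`, while the log-sum-exp version stays finite; SciPy's `scipy.special.logsumexp` provides the same operation.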

A First Attempt: Sampling

One potential solution to the issues posed by the marginalization over $z$ in the likelihood function is to perform a Monte Carlo estimate of the inner sum by sampling $z$ at random and approximating the inner sum with a sample average:

\[\sum_z p(x, z; \theta) = |\mathcal{Z}| \sum_z \frac{1}{|\mathcal{Z}|} p(x, z; \theta) = |\mathcal{Z}| \mathbf{E}_{z \sim \text{Uniform}(\mathcal{Z})} p(x, z; \theta)\]
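The uniform estimator above can be sketched in a few lines. The setup is an assumption made for illustration: a 1-D mixture with $|\mathcal{Z}| = 10{,}000$ narrow components under a uniform prior, of which only a handful lie near the observed $x$. Uniformly drawn $z$ rarely land on those components, so with small sample sizes the estimate tends to be zero or far from the exact sum, even though its expectation is exact.

```python
import numpy as np

K = 10_000                             # |Z|, the number of values the latent z can take
prior = np.full(K, 1.0 / K)            # p(z): uniform over components (illustrative)
means = np.linspace(-100.0, 100.0, K)  # component means, spaced about 0.02 apart
std = 0.05                             # shared component standard deviation

def joint(x, z):
    """p(x, z) = p(z) * N(x; means[z], std^2), vectorized over an array of z."""
    dens = np.exp(-0.5 * ((x - means[z]) / std) ** 2) / (std * np.sqrt(2 * np.pi))
    return prior[z] * dens

def uniform_mc(x, n, rng):
    """Estimate sum_z p(x, z) as |Z| times the mean of p(x, z) over n uniform draws of z."""
    z = rng.integers(0, K, size=n)
    return K * joint(x, z).mean()

x = 0.37
all_z = np.arange(K)
exact = joint(x, all_z).sum()          # exhaustive inner sum: feasible here, not in general

# Fraction of latent values whose joint probability is within a factor 1e3 of the largest.
p_all = joint(x, all_z)
hit_rate = np.mean(p_all > 1e-3 * p_all.max())

rng = np.random.default_rng(0)
for n in (10, 100, 1_000, 10_000):
    print(f"n={n:>6}  estimate={uniform_mc(x, n, rng):.4e}  exact={exact:.4e}")
print(f"fraction of z with non-negligible p(x, z): {hit_rate:.4f}")
```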

While this works in theory, in practice such estimates tend to perform poorly: the number of possible values of $z$ grows exponentially with the number of latent variables, and the vast majority of uniformly sampled $z$ yield a negligible joint probability. A second, more intricate, attempt via importance sampling with proposal distribution $q(z \mid x)$ yields

\[\begin{align} \sum_z p(x, z; \theta) = \sum_z q(z \mid x) \frac{p(x, z; \theta)}{q(z \mid x)} = \mathbf{E}_{z \sim q(z \mid x)} \left[ \frac{p(x, z; \theta)}{q(z \mid x)} \right] \end{align}\]

where we again approximate the expectation with a sample average, this time from the proposal distribution $q(z \mid x)$. Doing so alleviates the issue of few “hits” with uniform random sampling in the naive Monte Carlo estimate, given an appropriate choice of $q(z \mid x)$.

But what should our proposal distribution be? Ideally, we’d like to be sampling $z \sim p(z \mid x; \theta)$ to choose likely values of the latent variables, and so a reasonable choice would be $q(z \mid x) = p (z \mid x; \theta)$. Since we’re working with shallow LVMs, we can safely assume this is a simple distribution we can sample from (we can compute it with Bayes’ theorem).¹
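A small sketch shows why the posterior is the ideal proposal. With $q(z \mid x) = p(z \mid x)$, every importance weight $p(x, z)/q(z \mid x)$ equals $p(x)$ exactly, so the estimator has zero variance; with the prior $p(z)$ as proposal, the weights vary widely when $x$ is explained mostly by a rarely sampled component. The three-component mixture and the value $x = 3.5$ are made up for illustration.

```python
import numpy as np

# Illustrative 1-D mixture: p(z = k) = pi[k], p(x | z = k) = N(x; mu[k], sigma[k]^2).
pi = np.array([0.6, 0.3, 0.1])
mu = np.array([-3.0, 0.0, 4.0])
sigma = np.array([1.0, 0.5, 1.5])

def joint(x, z):
    """p(x, z) for an array of latent values z."""
    dens = np.exp(-0.5 * ((x - mu[z]) / sigma[z]) ** 2) / (sigma[z] * np.sqrt(2 * np.pi))
    return pi[z] * dens

def importance_estimate(x, q, n, rng):
    """Estimate p(x) = E_{z ~ q}[p(x, z) / q(z)] from n draws of the proposal q (a probability vector).
    Returns the sample mean and the sample standard deviation of the importance weights."""
    z = rng.choice(len(q), size=n, p=q)
    weights = joint(x, z) / q[z]
    return weights.mean(), weights.std()

x = 3.5
all_z = np.arange(len(pi))
p_x = joint(x, all_z).sum()         # exact p(x), by summing over z
posterior = joint(x, all_z) / p_x   # p(z | x), via Bayes' rule

rng = np.random.default_rng(0)
est_prior, sd_prior = importance_estimate(x, pi, 100, rng)       # proposal q = p(z)
est_post, sd_post = importance_estimate(x, posterior, 100, rng)  # proposal q = p(z | x)
print(f"exact p(x)         = {p_x:.6f}")
print(f"prior proposal     : estimate {est_prior:.6f}, weight std {sd_prior:.2e}")
print(f"posterior proposal : estimate {est_post:.6f}, weight std {sd_post:.2e}")
```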

However, while this approach lets us estimate the inner sum using only the joint distribution $p(x, z; \theta)$ and samples from $q(z \mid x)$, writing the log-likelihood function yields

\[\ell(\theta) = \sum_{x \in \mathcal{D}} \log \mathbf{E}_{z \sim q(z \mid x)} \left[ \frac{p(x, z; \theta)}{q(z \mid x)} \right]\]

which presents us with another problem: it is difficult to optimize the logarithm of a sum.

Achieving Tractability: A Lower Bound on $\ell(\theta)$

In order to transform the logarithm of a sum into a sum of logarithms in the log-likelihood function, we apply Jensen’s inequality. Doing so provides us a lower bound on $\ell(\theta)$:

\[\begin{align} \ell(\theta) &= \sum_{x \in \mathcal{D}} \log \mathbf{E}_{z \sim q(z \mid x)} \left[ \frac{p(x, z; \theta)}{q(z \mid x)} \right] \\ &\geq \sum_{x \in \mathcal{D}} \mathbf{E}_{z \sim q(z \mid x)} \log \left[ \frac{p(x, z; \theta)}{q(z \mid x)} \right] \label{jensen} \tag{1} \\ &= \sum_{x \in \mathcal{D}} \sum_z q(z \mid x) \log \frac{p(x, z; \theta)}{q(z \mid x)} \label{a} \tag{2} \\ &= \sum_{x \in \mathcal{D}} \sum_z q(z \mid x) \log p(x, z; \theta) - \sum_{x \in \mathcal{D}} \sum_z q(z \mid x) \log q(z \mid x) \label{elbo} \tag{3} \\ \end{align}\]

where Equation \ref{jensen} follows from Jensen’s inequality, since the logarithm is concave. We have therefore arrived at a lower bound on the log-likelihood $\ell(\theta)$ in which the logarithm acts directly on $p(x, z; \theta)$, making it far easier to optimize! In fact, the lower bound in Equation \ref{elbo} is so important that we’ll give it a special name: the evidence lower bound (ELBO).

Note that our derivation of the ELBO holds for any choice of $q(z \mid x)$. However, as it turns out, our intuitive choice of $q(z \mid x) = p(z \mid x; \theta)$ has a beautiful property: it makes the bound tight! To see this, substituting this distribution into Equation \ref{a} yields

\[\begin{align} \sum_{x \in \mathcal{D}} \sum_z q(z \mid x) \log \frac{p(x, z; \theta)}{q(z \mid x)} &= \sum_{x \in \mathcal{D}} \sum_z p(z \mid x; \theta) \log \frac{p(x, z; \theta)}{p(z \mid x; \theta)} \\ &= \sum_{x \in \mathcal{D}} \sum_z p(z \mid x; \theta) \log \frac{p(z \mid x; \theta) p(x; \theta)}{p(z \mid x; \theta)} \\ &= \sum_{x \in \mathcal{D}} \log p(x; \theta) \sum_z p(z \mid x; \theta) \\ &= \sum_{x \in \mathcal{D}} \log p(x; \theta) = \ell(\theta) \end{align}\]

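as desired. As a result, if $q(z \mid x) = p(z \mid x; \theta_t)$ is computed at the current parameters $\theta_t$, the ELBO equals $\ell(\theta_t)$. Because the ELBO lower-bounds $\ell(\theta)$ for every $\theta$, any $\theta_{t+1}$ that increases the ELBO (with $q$ held fixed) satisfies $\ell(\theta_{t+1}) \geq \text{ELBO}(\theta_{t+1}) \geq \text{ELBO}(\theta_t) = \ell(\theta_t)$: optimizing the ELBO never decreases the likelihood.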
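The bound and its tightness are easy to check numerically for a single data point. This sketch assumes an illustrative 1-D mixture; `elbo` evaluates the inner sum of Equation \ref{a} for one $x$, arbitrary proposals are drawn from a Dirichlet distribution, and the posterior is obtained with Bayes' rule.

```python
import numpy as np

# Illustrative 1-D mixture: p(z = k) = pi[k], p(x | z = k) = N(x; mu[k], sigma[k]^2).
pi = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])
sigma = np.array([0.5, 1.0, 0.8])

def log_joint(x):
    """log p(x, z = k) for every component k."""
    return np.log(pi) - 0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

def elbo(x, q):
    """sum_z q(z) log p(x, z) - sum_z q(z) log q(z) for one data point (q must be strictly positive)."""
    return np.sum(q * log_joint(x)) - np.sum(q * np.log(q))

x = 0.7
lj = log_joint(x)
log_px = np.log(np.sum(np.exp(lj)))   # log p(x) = log sum_z p(x, z)
posterior = np.exp(lj - log_px)        # p(z | x) via Bayes' rule

rng = np.random.default_rng(0)
for _ in range(3):
    q = rng.dirichlet(np.ones(3))      # an arbitrary proposal over the 3 latent values
    print(f"ELBO with random q: {elbo(x, q):.4f}  <=  log p(x) = {log_px:.4f}")
print(f"ELBO with posterior: {elbo(x, posterior):.4f}  ==  log p(x) = {log_px:.4f}")
```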

Expectation-Maximization

So far, we’ve built intuition for maximizing the LVM log-likelihood function by drawing insights from importance sampling and subsequently obtaining a tractable lower bound on the log-likelihood function (the ELBO). With proposal distribution $q(z \mid x) = p(z \mid x; \theta)$ tractable for shallow LVMs, we are guaranteed that the ELBO is tight.

The expectation-maximization algorithm builds upon these ideas, iteratively optimizing the ELBO over $q$ in the expectation step and over the model parameters $\theta$ in the maximization step. Since the ELBO is tight after the expectation step, the optimization over $\theta$ in the maximization step cannot decrease the log-likelihood, so each iteration of the algorithm makes (weak) progress. In particular, the algorithm proceeds as follows:

\[\theta_{t+1} = \text{argmax}_\theta \sum_{x \in \mathcal{D}} \mathbf{E}_{z \sim p(z \mid x; \theta_t)} \log p(x, z; \theta)\]

which is broken down into the “E” and “M” steps as follows.

E(xpectation) step. For each $x \in \mathcal{D}$, compute the proposal distribution $q(z \mid x) = p(z \mid x; \theta_t)$, i.e. the posterior probability of every value $z$ can take. A common interpretation is that we “hallucinate” the missing values of the latent variables $z$ by computing a distribution over $z$ using our current parameters $\theta_t$. Note that this computation requires summing over all values of $z$ in the discrete case and integrating in the continuous case, and is therefore tractable only for shallow LVMs.

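M(aximization) step. Compute $\theta_{t+1}$ by maximizing the expected complete-data log-likelihood $\sum_{x \in \mathcal{D}} \mathbf{E}_{z \sim q(z \mid x)} \log p(x, z; \theta)$ under the posterior computed in the E step. Because the logarithm now sits inside the expectation and acts directly on the joint $p(x, z; \theta)$, this objective is tractable: it can be maximized by gradient ascent and, for models such as Gaussian mixtures, in closed form.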

An illustrative example of expectation-maximization for Gaussian mixture models is located here.
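As a stand-in sketch (not the linked example), here is expectation-maximization for a 1-D Gaussian mixture in NumPy. The E step computes responsibilities $p(z = k \mid x_i; \theta_t)$; the M step uses the standard closed-form updates for mixture weights, means, and variances. The synthetic data, the quantile-based initialization, and the fixed iteration count are assumptions chosen to keep the example short. The returned `trace` records $\ell(\theta_t)$ at each iteration, so the monotonicity argument above can be checked directly.

```python
import numpy as np

def log_normal(x, mean, var):
    """log N(x; mean, var), broadcasting over arrays."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def em_gmm(data, K, n_iters):
    """Fit a 1-D Gaussian mixture with expectation-maximization.
    Returns the final (pi, mu, var) and the log-likelihood ell(theta_t) at every iteration."""
    n = len(data)
    pi = np.full(K, 1.0 / K)
    mu = np.quantile(data, np.linspace(0.1, 0.9, K))   # spread the initial means over the data
    var = np.full(K, data.var())
    trace = []
    for _ in range(n_iters):
        # E step: responsibilities r[i, k] = p(z = k | x_i; theta_t), via Bayes' rule in log space.
        log_joint = np.log(pi) + log_normal(data[:, None], mu, var)     # shape (n, K)
        m = log_joint.max(axis=1, keepdims=True)
        log_px = m[:, 0] + np.log(np.exp(log_joint - m).sum(axis=1))    # log p(x_i; theta_t)
        trace.append(log_px.sum())                                      # ell(theta_t)
        r = np.exp(log_joint - log_px[:, None])
        # M step: closed-form argmax of sum_i E_{z ~ r_i}[log p(x_i, z; theta)].
        nk = r.sum(axis=0)                                              # effective counts per component
        pi = nk / n
        mu = (r * data[:, None]).sum(axis=0) / nk
        var = (r * (data[:, None] - mu) ** 2).sum(axis=0) / nk
    return pi, mu, var, np.array(trace)

rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-2.0, 0.5, 300), rng.normal(3.0, 1.0, 200)])
pi, mu, var, trace = em_gmm(data, K=2, n_iters=50)
print("weights", pi.round(3), "means", mu.round(3), "variances", var.round(3))
print("log-likelihood never decreases:", bool(np.all(np.diff(trace) >= -1e-8)))
```

Practical implementations usually add a small floor to the variances (or a prior) so that a component cannot collapse onto a single point; it is omitted here so that each M step is the exact maximizer and the trace is guaranteed not to decrease, up to floating-point rounding.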

Deep Latent Variable Models

We continue our discussion of latent variable models with deep LVMs, models that consist of a more complicated relationship between $Z$ and $X$. In particular, we remove the assumption that $p(Z)$ and $p(X \mid Z)$ are chosen so that $p(Z \mid X)$ is tractable. While doing so allows for greater expressivity, it also means that the posterior $p(z \mid x; \theta)$ is generally intractable, and tractability of this posterior is exactly what made the ELBO tight in expectation-maximization.

One common example of a deep LVM is the variational autoencoder (VAE), which extends the Gaussian mixture model to a mixture of an infinite number of Gaussian distributions (one for each value of the continuous latent $z$). VAEs are specified as $z \sim \mathcal{N}(0, I)$, $p(x \mid z; \theta) = \mathcal{N}(\mu_\theta(z), \Sigma_\theta(z))$, and $q(z \mid x; \lambda) = \mathcal{N}(\mu_\lambda(x), \text{diag}(\sigma^2_\lambda(x)))$, where $\mu_\theta$, $\Sigma_\theta$, $\mu_\lambda$, and $\sigma_\lambda$ are typically neural networks. The necessity of specifying a proposal distribution $q$ will become evident as we build intuition for learning deep LVMs.
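The generative side of a VAE is ancestral sampling through a neural network. In the sketch below the decoder weights are random placeholders rather than trained parameters, and $\Sigma_\theta(z)$ is assumed diagonal; both choices are illustrative. Because $z$ is continuous, $p(x; \theta) = \mathbf{E}_{z \sim p(z)}[p(x \mid z; \theta)]$ averages over infinitely many Gaussians, and the hypothetical helper `log_px_naive` can only estimate it by sampling from the prior, which suffers from the same few-hits problem as the uniform Monte Carlo estimator discussed earlier.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, hidden_dim, data_dim = 2, 16, 3

# Hypothetical decoder parameters theta: random weights of a one-hidden-layer network,
# standing in for a trained mu_theta(z) and a diagonal Sigma_theta(z).
W1 = rng.normal(size=(latent_dim, hidden_dim))
W_mean = rng.normal(size=(hidden_dim, data_dim)) / 4
W_logvar = rng.normal(size=(hidden_dim, data_dim)) / 4

def decoder(z):
    """Map latent codes z of shape (n, latent_dim) to the mean and diagonal variance of p(x | z; theta)."""
    h = np.tanh(z @ W1)
    return h @ W_mean, np.exp(h @ W_logvar)   # exp keeps the variances positive

def sample_x(n):
    """Ancestral sampling: z ~ N(0, I), then x ~ N(mu_theta(z), diag(var_theta(z)))."""
    z = rng.standard_normal((n, latent_dim))
    mean, var = decoder(z)
    return mean + np.sqrt(var) * rng.standard_normal(mean.shape)

def log_px_naive(x, n):
    """Estimate log p(x) = log E_{z ~ p(z)}[p(x | z; theta)] by averaging over n prior samples;
    each prior sample selects one Gaussian out of the continuous mixture."""
    z = rng.standard_normal((n, latent_dim))
    mean, var = decoder(z)
    log_lik = -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mean) ** 2 / var, axis=1)
    m = log_lik.max()
    return m + np.log(np.mean(np.exp(log_lik - m)))   # stable log-mean-exp

samples = sample_x(5)
print(samples.shape)
print(log_px_naive(samples[0], 10_000))
```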

The log-likelihood function and learning problem for deep LVMs are the same as those of shallow LVMs.

Revisiting the ELBO

Since the posterior distribution $p(z \mid x; \theta)$ is no longer guaranteed to be tractable, we can no longer tractably compute the expectations with respect to the posterior in the E-step of expectation-maximization. We’ll therefore need a new learning algorithm for deep LVMs; to derive one, let’s begin by revisiting the evidence lower bound (Equation \ref{elbo}).

Recall that, for a single data point $x$, the ELBO is a lower bound on the log-likelihood for every choice of proposal distribution $q(z)$. To quantify how loose the bound is for an arbitrary choice of $q(z)$, we can expand the KL-divergence between $q(z)$ and $p(z \mid x; \theta)$ using $p(z \mid x; \theta) = p(z, x; \theta) / p(x; \theta)$:

\[D_{KL} (q(z) \| p(z \mid x; \theta)) = \sum_z q(z) \log \frac{q(z)}{p(z \mid x; \theta)} = -\sum_z q(z) \log p(z, x; \theta) + \log p(x; \theta) + \sum_z q(z) \log q(z) \geq 0\]

which we rearrange to obtain

\[\log p(x; \theta) = \underbrace{\sum_z q(z) \log p(z, x; \theta) - \sum_z q(z) \log q(z)}_{\text{ELBO}} + D_{KL} (q(z) \| p(z \mid x; \theta))\]

As expected, setting $q(z) = p(z \mid x; \theta)$ makes the ELBO tight since the KL-divergence between identical distributions is zero. More importantly, since $p(z \mid x; \theta)$ is intractable for deep LVMs, this formulation of the ELBO motivates a variational learning algorithm: can we learn a tractable distribution $q(z; \phi)$ to closely approximate $p(z \mid x; \theta)$? Because $\log p(x; \theta)$ does not depend on $\phi$, maximizing the ELBO over $\phi$ is equivalent to minimizing $D_{KL}(q(z; \phi) \| p(z \mid x; \theta))$. Doing so tightens the ELBO, improving our ability to increase $\ell(\theta)$.
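The decomposition can be verified numerically on a toy discrete model. The sketch assumes a single fixed $x$ with $K = 5$ latent values and randomly generated joint probabilities $p(x, z)$, scaled so that $p(x) = 0.3$; for any proposal $q$, the ELBO plus $D_{KL}(q(z) \| p(z \mid x; \theta))$ recovers $\log p(x; \theta)$, and the KL term is exactly the gap.

```python
import numpy as np

rng = np.random.default_rng(1)
K = 5
# Toy joint probabilities p(x, z) at one fixed x, for z = 0..K-1, scaled so that p(x) = 0.3.
p_xz = 0.3 * rng.dirichlet(np.ones(K))
log_px = np.log(p_xz.sum())
posterior = p_xz / p_xz.sum()            # p(z | x)

def elbo(q):
    """sum_z q(z) log p(z, x) - sum_z q(z) log q(z)."""
    return np.sum(q * np.log(p_xz)) - np.sum(q * np.log(q))

def kl(q, p):
    """D_KL(q || p) for strictly positive probability vectors."""
    return np.sum(q * (np.log(q) - np.log(p)))

q = rng.dirichlet(np.ones(K))             # an arbitrary proposal
print(f"log p(x)    = {log_px:.6f}")
print(f"ELBO + KL   = {elbo(q) + kl(q, posterior):.6f}")
print(f"gap (KL)    = {kl(q, posterior):.6f}")
```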

This process is termed variational learning² as it involves the optimization of $q(z; \phi)$ in function space. Jointly optimizing over our original parameters $\theta$ and our variational parameters $\phi$ thus provides a reasonable way to maximize the ELBO over a dataset.

Variational Learning

Building upon the intuition derived in the previous section, we can write the ELBO with variational parameters as

\[\begin{align} \mathcal{L}(x; \theta, \phi) &= \sum_z q(z; \phi) \log p(z, x; \theta) - \sum_z q(z; \phi) \log q (z; \phi) \\ &= \mathbf{E}_{z \sim q(z; \phi)} [\log p(z, x; \theta) - \log q (z; \phi)] \end{align}\]

Learning over our dataset then amounts to maximizing a lower bound on $\ell(\theta)$:

\[\ell (\theta) = \sum_{x^{(i)} \in \mathcal{D}} \log p(x^{(i)}; \theta) \geq \sum_{x^{(i)} \in \mathcal{D}} \mathcal{L}(x^{(i)}; \theta, \phi^{(i)})\]

where we note that each data point $x^{(i)}$ has an associated set of variational parameters $\phi^{(i)}$, since the true posterior $p(z \mid x^{(i)}; \theta)$ is different for each data point $x^{(i)}$. Doing so can be challenging for large datasets (where so many parameters make optimization expensive), so we instead choose to learn how to map each $x^{(i)}$ to a good set of parameters $\phi^{(i)}$ via a function $f_\lambda$.³ Specifically, we work with $q(z; f_\lambda(x))$ for each $x$; in the literature (and for the remainder of this post), we write $q(z; f_\lambda(x))$ as $q(z \mid x; \lambda)$. Our ELBO thus has the form

\[\begin{align} \mathcal{L}(x; \theta, \lambda) &= \mathbf{E}_{z \sim q(z \mid x; \lambda)} [\log p(z, x; \theta) - \log q(z \mid x; \lambda)] \end{align}\]
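A minimal sketch of amortized inference, under stated assumptions: a hypothetical 1-D linear-Gaussian model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(wz + b, s^2)$), whose exact $\log p(x)$ is available for comparison, and a hypothetical linear encoder $f_\lambda(x) = (a x + c, \log \sigma)$ shared by all data points. `elbo_mc` estimates $\mathcal{L}(x; \theta, \lambda)$ by Monte Carlo. For this model the true posterior mean is linear in $x$, so one setting of $\lambda$ makes the bound tight for every data point at once, while a rough $\lambda$ gives a looser bound.

```python
import numpy as np

# Hypothetical 1-D linear-Gaussian model: z ~ N(0, 1), x | z ~ N(w z + b, s^2).
w, b, s = 2.0, 0.5, 0.7

def log_normal(v, mean, var):
    """log N(v; mean, var), broadcasting over arrays."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

def encoder(x, lam):
    """Amortized inference f_lambda: map x to the mean and std of q(z | x; lambda).
    lam = (a, c, log_std) is a hypothetical linear parameterization shared by all x."""
    a, c, log_std = lam
    return a * x + c, np.exp(log_std) * np.ones_like(x)

def elbo_mc(x, lam, n_samples, rng):
    """Monte Carlo estimate of E_q[log p(z, x) - log q(z | x)] for each data point in x."""
    mean, std = encoder(x, lam)
    z = mean + std * rng.standard_normal((n_samples, len(x)))   # samples from q(z | x)
    log_p = log_normal(z, 0.0, 1.0) + log_normal(x, w * z + b, s ** 2)
    log_q = log_normal(z, mean, std ** 2)
    return (log_p - log_q).mean(axis=0)

# The exact posterior is N(post_var * w * (x - b) / s^2, post_var): its mean is linear in x.
post_var = 1.0 / (1.0 + w ** 2 / s ** 2)
lam_exact = (post_var * w / s ** 2, -post_var * w * b / s ** 2, 0.5 * np.log(post_var))
lam_rough = (0.3, 0.0, 0.0)

rng = np.random.default_rng(0)
x = np.array([-1.0, 0.5, 2.0])
log_px = log_normal(x, b, w ** 2 + s ** 2)        # exact log-likelihood, available for this model
print("log p(x)          ", log_px)
print("ELBO, rough lambda", elbo_mc(x, lam_rough, 10_000, rng))
print("ELBO, exact lambda", elbo_mc(x, lam_exact, 10_000, rng))
```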

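We maximize the ELBO with (stochastic) gradient ascent, updating both the model parameters $\theta$ and the variational parameters $\lambda$. Our learning algorithm therefore repeats the following until convergence: sample a data point (or minibatch) $x$ from $\mathcal{D}$, estimate $\nabla_\theta \mathcal{L}(x; \theta, \lambda)$ and $\nabla_\lambda \mathcal{L}(x; \theta, \lambda)$, and update $\theta \leftarrow \theta + \eta \nabla_\theta \mathcal{L}(x; \theta, \lambda)$ and $\lambda \leftarrow \lambda + \eta \nabla_\lambda \mathcal{L}(x; \theta, \lambda)$ for a step size $\eta$.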

Computing Variational Gradients

Now that we have a learning algorithm, the final piece of the puzzle is to compute the gradients of the ELBO with respect to $\theta$ and $\lambda$.

Let’s first examine the gradient with respect to $\theta$. We simply have

\[\begin{align} \nabla_\theta \mathcal{L}(x; \theta, \lambda) &= \nabla_\theta \mathbf{E}_{z \sim q(z \mid x; \lambda)} [\log p(z, x; \theta) - \log q(z \mid x; \lambda)] \\ &= \mathbf{E}_{z \sim q(z \mid x; \lambda)} [\nabla_\theta \log p(z, x; \theta)] \end{align}\]

which we can approximate with Monte Carlo sampling from $q(z \mid x; \lambda)$.

Let’s next consider the gradient with respect to $\lambda$. We have

\[\begin{align} \nabla_\lambda \mathcal{L}(x; \theta, \lambda) &= \nabla_\lambda \mathbf{E}_{z \sim q(z \mid x; \lambda)} [\log p(z, x; \theta) - \log q(z \mid x; \lambda)] \end{align}\]

but we can’t simply pass the gradient through the expectation as before, since the distribution the expectation is taken over is itself parameterized by $\lambda$. We can solve this problem in two common ways: a general technique from reinforcement learning called REINFORCE (the score-function estimator), and a lower-variance but less general technique called the reparameterization trick. An excellent article explaining and comparing the two is here. REINFORCE yields

\[\nabla_\lambda \mathcal{L}(x; \theta, \lambda) = \mathbf{E}_{z \sim q(z \mid x; \lambda)} [(\log p(z, x; \theta) - \log q(z \mid x; \lambda)) \nabla_\lambda \log q(z \mid x; \lambda)]\]

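where this form follows from $\nabla_\lambda q(z \mid x; \lambda) = q(z \mid x; \lambda) \nabla_\lambda \log q(z \mid x; \lambda)$ together with $\mathbf{E}_{z \sim q(z \mid x; \lambda)}[\nabla_\lambda \log q(z \mid x; \lambda)] = 0$. The reparameterization trick, by contrast, varies depending on the choice of $q(z \mid x; \lambda)$ and only works for continuous $q$ with specific properties; for a Gaussian $q$, for instance, one writes $z = \mu_\lambda(x) + \sigma_\lambda(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so that the expectation is taken over $\epsilon$, which does not depend on $\lambda$, and the gradient can be moved inside it. Further information can be found here.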
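The two estimators can be compared on a toy problem where the exact gradient is known. The sketch assumes $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z; \theta) = \mathcal{N}(z, 1)$, and $q(z \mid x; \lambda) = \mathcal{N}(\mu, \sigma^2)$ with $\lambda = (\mu, \log \sigma)$; for this model the ELBO gradient is $(x - 2\mu,\ 1 - 2\sigma^2)$ in closed form. Both estimators average to that value, but in this setting the reparameterized samples have noticeably lower variance.

```python
import numpy as np

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1); q(z | x; lambda) = N(mu, sigma^2)
# with lambda = (mu, log sigma). Exact ELBO gradient: (x - 2 mu, 1 - 2 sigma^2).
x = 1.0
mu, log_sigma = 0.0, 0.0
sigma = np.exp(log_sigma)

def log_p(z):
    """log p(z, x) = log N(z; 0, 1) + log N(x; z, 1)."""
    return -0.5 * z ** 2 - 0.5 * (x - z) ** 2 - np.log(2 * np.pi)

def log_q(z):
    """log q(z | x; lambda) = log N(z; mu, sigma^2)."""
    return -0.5 * ((z - mu) / sigma) ** 2 - log_sigma - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(0)
n = 200_000
eps = rng.standard_normal(n)
z = mu + sigma * eps                               # z ~ q(z | x; lambda)

# REINFORCE (score function): (log p - log q) * grad_lambda log q(z | x; lambda).
f = log_p(z) - log_q(z)
score = np.stack([(z - mu) / sigma ** 2,            # d/d mu of log q
                  ((z - mu) / sigma) ** 2 - 1.0])   # d/d log sigma of log q
reinforce = f * score

# Reparameterization: z = mu + sigma * eps, differentiate through the sample.
# At z = mu + sigma * eps, log q = -eps^2 / 2 - log sigma + const, so -log q adds +1 to d/d log sigma.
dlogp_dz = x - 2 * z                                # d/dz log p(z, x)
reparam = np.stack([dlogp_dz,                       # d/d mu
                    dlogp_dz * sigma * eps + 1.0])  # d/d log sigma

exact = np.array([x - 2 * mu, 1 - 2 * sigma ** 2])
print("exact    ", exact)
print("REINFORCE", reinforce.mean(axis=1), "variance", reinforce.var(axis=1))
print("reparam  ", reparam.mean(axis=1), "variance", reparam.var(axis=1))
```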

Interpreting Variational Autoencoders

Earlier in this section, discussing the VAE model required specification of the variational proposal distribution $q(z \mid x; \lambda)$; as evident from our derivation of the variational learning algorithm, specifying the class of distributions from which $q$ is to be learned is necessary. A common interpretation is that $q(z \mid x; \lambda)$ acts as an “encoder” from $x$ to a latent representation $z$, and $p(x \mid z; \theta)$ acts as a “decoder” from $z$ back to the data space.

Since $\log p(z, x; \theta) = \log p(z) + \log p(x \mid z; \theta)$, the ELBO can always be rewritten as below; with the Gaussian prior $p(z) = \mathcal{N}(0, I)$ and Gaussian $q(z \mid x; \lambda)$ used in (Kingma & Welling, 2013), the KL term can moreover be computed analytically:

\[\mathcal{L}(x; \theta, \lambda) = \underbrace{-D_{KL} (q(z \mid x; \lambda) \| p(z))}_{\text{Analytically compute this}} + \underbrace{\mathbf{E}_{z \sim q(z \mid x; \lambda)} \log p(x \mid z; \theta)}_{\text{Monte Carlo estimate this}}\]

This representation also has a nice interpretation: the first term keeps the approximate posterior $q(z \mid x; \lambda)$ close to the prior $p(z)$, acting as a regularizer on the latent representations, and the second term encourages $x$ to be likely given latent representations drawn from $q(z \mid x; \lambda)$.
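For a diagonal Gaussian $q$ and a standard normal prior, the KL term has the closed form $\frac{1}{2} \sum_j (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)$. The sketch below checks it against a purely Monte Carlo ELBO; the encoder outputs `mean` and `log_var` are fixed made-up numbers, and the decoder $p(x \mid z) = \mathcal{N}(x; z, I)$ is a hypothetical stand-in for $p(x \mid z; \theta)$. The two estimates agree up to Monte Carlo noise, and computing the KL term exactly removes that part of the noise.

```python
import numpy as np

def kl_to_std_normal(mean, log_var):
    """Closed-form D_KL(N(mean, diag(exp(log_var))) || N(0, I))."""
    return 0.5 * np.sum(np.exp(log_var) + mean ** 2 - 1.0 - log_var)

def log_diag_normal(v, mean, var):
    """log N(v; mean, diag(var)), summed over the last axis."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (v - mean) ** 2 / var, axis=-1)

rng = np.random.default_rng(0)
x = np.array([0.8, -0.3])                                        # one observation (illustrative)
mean, log_var = np.array([0.4, -0.1]), np.array([-1.0, -0.5])    # assumed encoder outputs for x
var = np.exp(log_var)

z = mean + np.sqrt(var) * rng.standard_normal((100_000, 2))      # z ~ q(z | x)
log_lik = log_diag_normal(x, z, np.ones(2))    # hypothetical decoder: p(x | z) = N(x; z, I)
log_prior = log_diag_normal(z, np.zeros(2), np.ones(2))
log_q = log_diag_normal(z, mean, var)

elbo_direct = np.mean(log_lik + log_prior - log_q)                # every term by Monte Carlo
elbo_split = np.mean(log_lik) - kl_to_std_normal(mean, log_var)   # KL term computed exactly
print(elbo_direct, elbo_split)
```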

Summary & Further Reading

Latent variable models are incredibly useful frameworks that combine simple distributions to create more complicated ones. Defined by the Bayes net structure $Z \to X$, they permit ancestral sampling for efficient generation ($z \sim p(z)$ and $x \sim p(x \mid z; \theta)$), but often have intractable log-likelihoods, making learning difficult.

Both shallow and deep LVMs therefore optimize a lower bound to the log likelihood, called the ELBO. While shallow LVMs make the ELBO tight by explicitly computing $q(z \mid x; \theta) = p(z \mid x; \theta)$, this computation is intractable for deep LVMs, which use variational learning to learn a distribution $q(z \mid x; \lambda)$ that best approximates $p(z \mid x; \theta)$. Jointly learning the model $\theta$ and the amortized inference component $\lambda$ helps deep LVMs achieve tractability for ELBO optimization.

Many other types of latent variable models which perform learning without worrying about the ELBO weren’t covered in this post. For the interested reader, normalizing flow models (using invertible transformations) and GANs (likelihood-free models) are exciting avenues for further reading.

Notes

  1. If this worries you, that’s good! We throw away this assumption with deep latent variable models, but doing so makes learning far more complicated… 

  2. “Variational” is a term borrowed from variational calculus; in our context, it refers to the process of optimizing over functions. 

  3. Learning a single parametric function $f_\lambda : x \to \phi$ that maps each $x$ to a set of variational parameters is called amortized inference, as the cost of inferring $z$ given $x$ is amortized over all training examples for the sake of tractability. 

References

Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114.