Over the last few years we have experienced an enormous data deluge, which has played a key role in the surge of interest in AI. A partial list of large datasets:
- ImageNet, with over 14 million images for classification and object detection.
- MovieLens, with 20 million user ratings of movies for collaborative filtering.
- Udacity’s car dataset (at least 223GB) for training self-driving cars.
- Yahoo’s 13.5 TB dataset of user-news interaction for studying human behavior.
Stochastic Gradient Descent (SGD) has been the engine fueling the development of large-scale models for these datasets. SGD is remarkably well-suited to large datasets: it estimates the gradient of the loss function on a full dataset using only a fixed-sized minibatch, and updates a model many times with each pass over the dataset.
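To make the minibatch idea concrete, here is a minimal SGD sketch in NumPy. The linear-regression model, step size, and batch size below are illustrative assumptions, not from the post; the point is simply that each update uses a fixed-size minibatch, so one pass over the data yields many updates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic linear-regression problem (purely illustrative): y = X @ w_true + noise.
N, d = 10_000, 5
X = rng.normal(size=(N, d))
w_true = rng.normal(size=d)
y = X @ w_true + 0.1 * rng.normal(size=N)

w = np.zeros(d)            # model parameters
lr, batch = 0.05, 100      # step size and minibatch size (assumed values)

for epoch in range(20):
    # One pass over the dataset = many parameter updates.
    for idx in rng.permutation(N).reshape(-1, batch):
        xb, yb = X[idx], y[idx]
        # Minibatch gradient of the mean squared error: an unbiased
        # estimate of the full-dataset gradient.
        grad = 2.0 * xb.T @ (xb @ w - yb) / batch
        w -= lr * grad

print(np.max(np.abs(w - w_true)))  # small: w has converged near w_true
```

Each epoch here performs 100 updates rather than one, which is exactly why SGD scales so well to large datasets.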
But SGD has limitations. When we construct a model, we define a loss function $L_\theta(x)$ over data $x$ and parameters $\theta$, and minimize it. The result is a single point estimate of the parameters. The model posterior $p(\theta \mid x)$, by contrast, describes the full distribution of plausible parameters given the data; a point estimate collapses it to one value, and any tractable approximation to it introduces bias.
MCMC methods have no such bias. You can think of MCMC particles as rather like quantum-mechanical particles: you only observe individual instances, but they follow an arbitrarily complex joint distribution. By taking multiple samples you can infer useful statistics, apply regularizing terms, etc. But MCMC methods have one overriding problem with respect to large datasets: other than the important class of conjugate models which admit Gibbs sampling, there has been no efficient way to do the Metropolis-Hastings tests required by general MCMC methods on minibatches of data (we will define/review MH tests in a moment). In response, researchers had to design models to make inference tractable, e.g. Restricted Boltzmann Machines (RBMs) use a layered, undirected design to make Gibbs sampling possible. In a recent breakthrough, variational auto-encoders (VAEs) use variational methods to support more general posterior distributions in probabilistic auto-encoders. But with VAEs, like other variational models, one has to live with the fact that the model is a best-fit approximation, with (usually) no quantification of how close the approximation is. Although they typically offer better accuracy, MCMC methods have been sidelined recently in auto-encoder applications, lacking an efficient scalable MH test.
A bridge between SGD and Bayesian modeling has been forged recently by papers on
Stochastic Gradient Langevin Dynamics (SGLD) and Stochastic Gradient
Hamiltonian Monte Carlo (SGHMC). These methods involve minor variations to
typical SGD updates which generate samples from a probability distribution which
is approximately the Bayesian model posterior $p(\theta \mid x)$.
Because of these developments, interest has warmed recently in scalable MCMC and in particular in doing the MH tests required by general MCMC models on large datasets. Normally an MH test requires a scan of the full dataset and is applied each time one wants a posterior sample. Clearly for large datasets, it’s intractable to do this. Two papers from ICML 2014, Korattikara et al. and Bardenet et al., attempt to reduce the cost of MH tests. They both use concentration bounds, and both achieve constant-factor improvements relative to a full dataset scan. Other recent work improves performance but makes even stronger assumptions about the model, which limits applicability, especially for deep networks. None of these approaches comes close to matching the performance of SGD, i.e. generating a posterior sample from small, constant-size batches of data.
In this post we describe a new approach to MH testing which moves the cost of an MH test from $O(N)$, a scan of the full dataset, to $O(1)$ relative to dataset size: on average the test consumes only a small, constant-size minibatch of data.
To explain the approach, we review the role of MH tests in MCMC models.
Markov Chain Monte Carlo Review
Markov Chains
MCMC methods are designed to sample from a target distribution which is difficult to compute. To generate samples, they utilize Markov Chains, which consist of nodes representing states of the system and probability distributions for transitioning from one state to another.
A key concept is the Markovian assumption, which states that the probability of being in a state at time $t+1$ depends only on the state at time $t$, not on the earlier history of the chain. Since the probability of being in state $\theta_{t+1}$ depends only on $\theta_t$, the chain is fully characterized by its transition distribution, and under mild conditions it converges to a unique stationary distribution. MCMC methods choose the transitions so that this stationary distribution is the target distribution we want to sample from.
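As a concrete sketch (the two-state chain and its transition matrix below are illustrative assumptions), simulating a chain long enough makes the fraction of time spent in each state approach the chain's stationary distribution:

```python
import numpy as np

rng = np.random.default_rng(0)

# Transition matrix: P[i, j] = probability of moving from state i to state j.
P = np.array([[0.9, 0.1],
              [0.3, 0.7]])

# The stationary distribution pi solves pi = pi @ P; here pi = (0.75, 0.25).
state = 0
counts = np.zeros(2)
for _ in range(100_000):
    counts[state] += 1
    # Markovian transition: the next state depends only on the current state.
    state = rng.choice(2, p=P[state])

print(counts / counts.sum())  # approaches (0.75, 0.25)
```

MCMC runs this idea in reverse: given a target distribution, design the transitions so that the stationary distribution is the target.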
A full review of MCMC methods is beyond the scope of this post, but a good reference is the Handbook of Markov Chain Monte Carlo (2011). Standard machine learning textbooks such as Koller & Friedman (2009) and Murphy (2012) also cover MCMC methods.
Metropolis-Hastings
One of the most general and powerful MCMC methods is
Metropolis-Hastings. This uses a test to filter samples. To define
it properly, let $p(\theta)$ be the target distribution we want to approximate, and let $q(\theta' \mid \theta)$ be a proposal distribution that generates a candidate next state $\theta'$ from the current state $\theta$ (for instance, a Gaussian centered at $\theta$).

If we were to just use a Gaussian to generate samples in our chain, there’s no way we could approximate our target $p$; the samples would simply follow a random walk. The MH test cleverly corrects for this by filtering the proposals. Draw a uniform random variable $u \sim \mathrm{Uniform}[0, 1]$ and test whether

$$u < \min\left(1, \frac{p(\theta')\,q(\theta \mid \theta')}{p(\theta)\,q(\theta' \mid \theta)}\right).$$

If true, we accept $\theta'$ as the next state; otherwise, we reject it and re-use $\theta$. Notable properties of the test:
- It doesn’t require knowledge of a normalizing constant (independent of $\theta$ and $\theta'$), because that cancels out in the $p(\theta')/p(\theta)$ ratio. This is great, because normalizing constants are arguably the biggest reason why distributions become intractable.
- The higher the value of $p(\theta')$, the more likely we are to accept.
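Putting the pieces together, here is a minimal MH sampler. The 1-D mixture-of-Gaussians target and the Gaussian proposal width are illustrative assumptions; note how an unnormalized log density is all the test needs.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_p(theta):
    # Unnormalized log density of an equal mixture of N(-2, 1) and N(2, 1).
    return np.log(np.exp(-0.5 * (theta - 2.0) ** 2)
                  + np.exp(-0.5 * (theta + 2.0) ** 2))

theta = 0.0
samples = []
for _ in range(100_000):
    theta_prime = theta + rng.normal(scale=1.0)  # symmetric Gaussian proposal
    # MH test: with a symmetric proposal q cancels, so accept when
    # log u < log p(theta') - log p(theta).
    if np.log(rng.uniform()) < log_p(theta_prime) - log_p(theta):
        theta = theta_prime      # accept the proposal
    samples.append(theta)        # on rejection we re-use the old theta

samples = np.array(samples)
print(samples.mean(), samples.std())  # mean near 0, std near sqrt(5)
```

The empirical mean and standard deviation match the mixture's true moments, even though we never computed a normalizing constant.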
To get more intuition on how the test works, we’ve created the following figure from this Jupyter Notebook, showing the progression of samples to approximate a target posterior. This example is derived from Welling & Teh (2011).
A quick example of the MH test in action on a mixture of Gaussians target. The parameter is $\theta = (\theta_1, \theta_2)$, and successive frames show the accumulating samples approximating the target posterior.
Reducing Metropolis-Hastings Data Usage
What happens when we consider the Bayesian posterior inference case with large
datasets? (Perhaps we’re interested in the same example in the figure above,
except that the posterior is based on more data points.) Then our goal is to
sample to approximate the distribution

$$p(\theta \mid x_1, \ldots, x_N) \propto p(\theta) \prod_{i=1}^N p(x_i \mid \theta)$$

for large $N$. The MH test now involves all $N$ likelihood terms: with a symmetric proposal $q$, we accept $\theta'$ when

$$u < \min\left(1, \frac{p(\theta') \prod_{i=1}^N p(x_i \mid \theta')}{p(\theta) \prod_{i=1}^N p(x_i \mid \theta)}\right).$$
Or, after taking logarithms and rearranging (while ignoring the minimum operator, which technically isn’t needed here), we get

$$\log u < \log \frac{p(\theta')}{p(\theta)} + \sum_{i=1}^N \log \frac{p(x_i \mid \theta')}{p(x_i \mid \theta)}.$$

The problem now is apparent: it’s expensive to compute all the $p(x_i \mid \theta)$ and $p(x_i \mid \theta')$ terms, and this cost recurs every time we run the test, which happens once per sample.
The naive way to deal with this is to apply the same test, but with a minibatch of $b$ elements sampled from the full dataset, rescaling the minibatch log-likelihood sum by $N/b$ so that it estimates the full-data sum:

$$\log u < \log \frac{p(\theta')}{p(\theta)} + \frac{N}{b} \sum_{i=1}^{b} \log \frac{p(x_i^* \mid \theta')}{p(x_i^* \mid \theta)},$$

where $x_1^*, \ldots, x_b^*$ denote the minibatch elements.
Unfortunately, this won’t sample from the correct target distribution; see Section 6.1 in Bardenet et al. (2017) for details.
A better strategy is to start with the same batch of $b$ points, but then gradually grow the test batch until a concentration bound certifies that the minibatch decision agrees with the full-data decision with high probability. This is the adaptive approach taken by the Korattikara et al. and Bardenet et al. papers mentioned above.
A weakness of the above approach is that it performs repeated testing: one must reduce the allowable test error each time the test batch size is incremented. Unfortunately, there is also a significant probability that these approaches will grow the test batch all the way to the full dataset, and they offer at most constant-factor speedups over testing on the full dataset.
Minibatch Metropolis-Hastings: Our Contribution
Change the Acceptance Function
To set up our test, we first define the log transition probability ratio

$$\Delta(\theta, \theta') = \log \frac{p(\theta') \prod_{i=1}^N p(x_i \mid \theta')}{p(\theta) \prod_{i=1}^N p(x_i \mid \theta)}.$$
This log ratio factors into a sum of per-sample terms, so when we approximate its value by computing on a minibatch we get an unbiased estimator of its full-data value plus some noise (which is asymptotically normal by the Central Limit Theorem).
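A small numerical check of this claim, with an illustrative Gaussian model (the data, parameter values, and batch size below are all assumptions): the rescaled minibatch sum is an unbiased estimator of the full-data log ratio, though a noisy one.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative model: x_i ~ N(theta, 1); theta and theta' are candidate means.
N, b = 10_000, 100
x = rng.normal(loc=1.0, size=N)
theta, theta_prime = 0.9, 1.1

# Per-sample terms of the log likelihood ratio; their sum is the data part
# of Delta (the prior/proposal term is omitted for simplicity).
terms = -0.5 * (x - theta_prime) ** 2 + 0.5 * (x - theta) ** 2
delta_full = terms.sum()

# Minibatch estimator: a size-b subsample's sum, rescaled by N / b.
estimates = np.array([
    (N / b) * terms[rng.choice(N, size=b, replace=False)].sum()
    for _ in range(5_000)
])

# The estimator's mean matches the full-data value (up to Monte Carlo error),
# but its per-draw noise is large: exactly the difficulty discussed here.
print(delta_full, estimates.mean(), estimates.std())
```

The large per-draw standard deviation printed at the end is the reason the naive minibatch test fails and a correction is needed.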
The first step for applying our MH test is to use a different acceptance function. Expressed in terms of $\Delta(\theta, \theta')$, the classical test accepts whenever $\log u < \Delta(\theta, \theta')$ for $u \sim \mathrm{Uniform}[0, 1]$; equivalently, its acceptance function is $\min(1, e^{\Delta})$.
The classical MH acceptance function $\min(1, e^{\Delta})$ and the sigmoid acceptance function $e^{\Delta}/(1 + e^{\Delta})$, plotted as functions of $\Delta$.
Instead of using the classical test, we’ll use the sigmoid function. It might
not be apparent why this is allowed, but there’s some elegant theory that
explains why using this alternative function as the acceptance test for MH
still results in the correct semantics of MCMC. That is, under the same mild
assumptions, the distribution of the samples still converges to the target posterior.
The density of the standard logistic random variable, denoted $X_{\mathrm{log}}$.
Our acceptance test is now the sigmoid function. Note that the sigmoid function
is the cumulative distribution function of a (standard) Logistic random
variable; the figure above plots the density. One can show that the MH test
under the sigmoid acceptance function reduces to determining whether $X_{\mathrm{log}} + \Delta(\theta, \theta') > 0$, where $X_{\mathrm{log}}$ is a fresh draw from the standard logistic distribution.
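This reduction is easy to verify numerically (the value of $\Delta$ below is an arbitrary illustrative choice), because the sigmoid is exactly the CDF of the standard logistic distribution:

```python
import numpy as np

rng = np.random.default_rng(0)

delta = 0.8          # an arbitrary illustrative value of the log ratio Delta
n = 1_000_000

# P(X_log + delta > 0) = P(X_log > -delta) = 1 - F(-delta) = sigmoid(delta),
# where F is the logistic CDF.
x_log = rng.logistic(size=n)
empirical = np.mean(x_log + delta > 0)
sigmoid = np.exp(delta) / (1.0 + np.exp(delta))
print(empirical, sigmoid)  # agree to about three decimal places
```

So instead of comparing a uniform draw against a probability, we can compare a logistic draw plus $\Delta$ against zero.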
New MH Test
This is nice, but we don’t want to compute $\Delta(\theta, \theta')$ over the full dataset; avoiding that cost is the whole point. When we estimate $\Delta$ on a minibatch instead, the estimate equals $\Delta$ plus approximately normal noise (again by the Central Limit Theorem). The key observation is that the standard logistic distribution is itself very close to a normal distribution:
A plot of the logistic CDF in red (as we had earlier) along with a normal CDF
curve, colored in lime, which corresponds to a standard deviation of 1.7.
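The visual closeness in the figure can be quantified directly (the grid below is an arbitrary choice, wide enough that both CDFs are essentially 0 or 1 outside it): the largest vertical gap between the logistic CDF and a zero-mean normal CDF with standard deviation 1.7 is roughly one percentage point.

```python
import numpy as np
from math import erf, sqrt

def logistic_cdf(x):
    return 1.0 / (1.0 + np.exp(-x))

def normal_cdf(x, sigma=1.7):
    # CDF of N(0, sigma^2) via the error function.
    return 0.5 * (1.0 + erf(x / (sigma * sqrt(2.0))))

# Scan a wide grid for the largest vertical gap between the two CDFs.
xs = np.linspace(-10.0, 10.0, 100_001)
gap = max(abs(logistic_cdf(float(x)) - normal_cdf(float(x))) for x in xs)
print(gap)  # roughly 0.01
```

The gap is small but not zero, which is why the minibatch noise cannot simply be treated as logistic; it must be corrected.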
Rather than resorting to tail bounds as in prior work, we directly bridge these two distributions with an additive correction variable: we add just enough extra noise to the minibatch estimate of $\Delta$ that the combined noise follows the standard logistic distribution exactly, and then apply the sigmoid-style test.
A diagram of our minibatch MH test. On the right we have the full data test that
we want, but we can't use it since computing $\Delta(\theta, \theta')$ requires a full pass over the data. On the left is our minibatch test: the minibatch estimate of $\Delta$ carries approximately normal noise. We want to make the LHS and RHS distributions equal, so we add in a correction variable, distributed so that the normal minibatch noise plus the correction matches the standard logistic distribution.
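The resulting test can be sketched as follows. Constructing the correction distribution is the technical core of the paper and is done numerically there; the sketch below simply assumes a sampler for it is available, and checks only the degenerate case where the minibatch estimate is exact, so the correction must be a full standard logistic draw. The function names here are illustrative, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def minibatch_mh_accept(delta_hat, sample_correction):
    # Accept iff the noisy estimate of Delta, plus correction noise, is
    # positive. The correction distribution is chosen so that (estimator
    # noise + correction) is exactly standard logistic; building that
    # sampler is the hard part, done numerically in the paper.
    return delta_hat + sample_correction() > 0

# Degenerate sanity check: with zero estimator noise the correction must
# be a full standard logistic draw, and the test should accept with
# probability sigmoid(delta).
delta = 0.5  # illustrative exact log ratio
accepts = np.mean([minibatch_mh_accept(delta, rng.logistic)
                   for _ in range(200_000)])
print(accepts, 1.0 / (1.0 + np.exp(-delta)))  # close agreement
```

In the general case the estimator noise is nonzero, and the correction supplies only the remaining gap between that noise and the logistic distribution.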
In our paper, we also prove theoretical bounds on the error of our test, and present experiments showing that our method yields accurate posterior estimates for a Gaussian mixture model and is highly sample-efficient for logistic regression on MNIST digit classification.
Histograms showing the batch sizes used for Metropolis-Hastings for the three
algorithms benchmarked in our paper. The posterior is similar to the earlier
example from the Jupyter Notebook, except generated with one million data
points. The left histogram is our result; the other two are from Korattikara et al. (2014) and Bardenet et al. (2014), respectively. Our algorithm uses an average of just 172 data points per iteration. Note the log-log scale of the histograms.