
Minibatch Metropolis-Hastings

 Over the last few years we have experienced an enormous data deluge, which has played a key role in the surge of interest in AI. A partial list of some large datasets:

  • ImageNet, with over 14 million images for classification and object detection.
  • Movielens, with 20 million user ratings of movies for collaborative filtering.
  • Udacity’s car dataset (at least 223GB) for training self-driving cars.
  • Yahoo’s 13.5 TB dataset of user-news interaction for studying human behavior.

Stochastic Gradient Descent (SGD) has been the engine fueling the development of large-scale models for these datasets. SGD is remarkably well-suited to large datasets: it estimates the gradient of the loss function on a full dataset using only a fixed-size minibatch, and updates a model many times with each pass over the dataset.
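As a minimal sketch of that update pattern (hypothetical names; `loss_grad` stands in for the gradient of the loss on a batch):

```python
import numpy as np

def sgd_epoch(theta, X, loss_grad, batch_size=128, lr=0.01):
    """One pass over the dataset: many updates per epoch, each using
    only a fixed-size minibatch to estimate the full-data gradient."""
    indices = np.random.permutation(len(X))
    for start in range(0, len(X), batch_size):
        batch = X[indices[start:start + batch_size]]
        theta = theta - lr * loss_grad(theta, batch)  # noisy but unbiased step
    return theta
```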

But SGD has limitations. When we construct a model, we use a loss function $L_\theta(x)$ with dataset $x$ and model parameters $\theta$, and attempt to minimize the loss by gradient descent on $\theta$. This shortcut approach makes optimization easy, but is vulnerable to a variety of problems including over-fitting, excessively sensitive coefficient values, and possibly slow convergence. A more robust approach is to treat the inference problem for $\theta$ as a full-blown posterior inference, deriving a joint distribution $p(x, \theta)$ from the loss function, and computing the posterior $p(\theta \mid x)$. This is the Bayesian modeling approach, and specifically the Bayesian Neural Network approach when applied to deep models. This recent tutorial by Zoubin Ghahramani discusses some of the advantages of this approach.

The model posterior p(θ|x) for most problems is intractable (no closed form). There are two methods in Machine Learning to work around intractable posteriors: Variational Bayesian methods and Markov Chain Monte Carlo (MCMC). In variational methods, the posterior is approximated with a simpler distribution (e.g. a normal distribution) and its distance to the true posterior is minimized. In MCMC methods, the posterior is approximated as a sequence of correlated samples (points or particle densities). Variational Bayes methods have been widely used but often introduce significant error — see this recent comparison with Gibbs Sampling, also Figure 3 from the Variational Autoencoder (VAE) paper. Variational methods are also more computationally expensive than direct parameter SGD (it’s a small constant factor, but a small constant times 1-10 days can be quite important).

MCMC methods have no such bias. You can think of MCMC particles as rather like quantum-mechanical particles: you only observe individual instances, but they follow an arbitrarily-complex joint distribution. By taking multiple samples you can infer useful statistics, apply regularizing terms, etc. But MCMC methods have one overriding problem with respect to large datasets: other than the important class of conjugate models which admit Gibbs sampling, there has been no efficient way to do the Metropolis-Hastings tests required by general MCMC methods on minibatches of data (we will define/review MH tests in a moment). In response, researchers had to design models to make inference tractable, e.g. Restricted Boltzmann Machines (RBMs) use a layered, undirected design to make Gibbs sampling possible. In a recent breakthrough, VAEs use variational methods to support more general posterior distributions in probabilistic auto-encoders. But with VAEs, like other variational models, one has to live with the fact that the model is a best-fit approximation, with (usually) no quantification of how close the approximation is. Although they typically offer better accuracy, MCMC methods have been sidelined recently in auto-encoder applications, lacking an efficient scalable MH test.

A bridge between SGD and Bayesian modeling has been forged recently by papers on Stochastic Gradient Langevin Dynamics (SGLD) and Stochastic Gradient Hamiltonian Monte Carlo (SGHMC). These methods involve minor variations to typical SGD updates which generate samples from a probability distribution which is approximately the Bayesian model posterior p(θ|x). These approaches turn SGD into an MCMC method, and as such require Metropolis-Hastings (MH) tests for accurate results, the topic of this blog post.
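For reference, the SGLD update of Welling & Teh (2011) is just a minibatch gradient step on the log posterior plus injected Gaussian noise whose variance is tied to the step size $\epsilon_t$:

$$\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2}\left(\nabla \log p_0(\theta_t) + \frac{N}{b} \sum_{i=1}^{b} \nabla \log p(x_i \mid \theta_t)\right) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon_t)$$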

Because of these developments, interest has warmed recently in scalable MCMC and in particular in doing the MH tests required by general MCMC models on large datasets. Normally an MH test requires a scan of the full dataset and is applied each time one wants a data sample. Clearly for large datasets, it’s intractable to do this. Two papers from ICML 2014, Korattikara et al. and Bardenet et al., attempt to reduce the cost of MH tests. They both use concentration bounds, and both achieve constant-factor improvements relative to a full dataset scan. Other recent work improves performance but makes even stronger assumptions about the model which limits applicability, especially for deep networks. None of these approaches come close to matching the performance of SGD, i.e. generating a posterior sample from small constant-size batches of data.

In this post we describe a new approach to MH testing which moves the cost of MH testing from O(N) to O(1) relative to dataset size. It avoids the need for global statistics and does not use tail bounds (which lead to long-tailed distributions for the amount of data required for a test). Instead we use a novel correction distribution to directly “morph” the distribution of a noisy minibatch estimator into a smooth MH test distribution. Our method is a true “black-box” method which provides estimates on the accuracy of each MH test using only data from a small expected size minibatch. It can even be applied to unbounded data streams. It can be “piggy-backed” on existing SGD implementations to provide full posterior samples (via SGLD or SGHMC) for almost the same cost as SGD samples. Thus full Bayesian neural network modeling is now possible for about the same cost as SGD optimization. Our approach is also a potential substitute for variational methods and VAEs, providing unbiased posterior samples at lower cost.

To explain the approach, we review the role of MH tests in MCMC models.

Markov Chain Monte Carlo Review

Markov Chains

MCMC methods are designed to sample from a target distribution which is difficult to compute. To generate samples, they utilize Markov Chains, which consist of nodes representing states of the system and probability distributions for transitioning from one state to another.

A key concept is the Markovian assumption, which states that the probability of being in a state at time $t+1$ can be inferred entirely based on the current state at time $t$. Mathematically, letting $\theta_t$ represent the current state of the Markov chain at time $t$, we have $p(\theta_{t+1} \mid \theta_t, \ldots, \theta_0) = p(\theta_{t+1} \mid \theta_t)$. By using these probability distributions, we can generate a chain of samples $(\theta_i)_{i=1}^T$ for some large $T$.

Since the probability of being in state θt+1 directly depends on θt, the samples are correlated. Rather surprisingly, it can be shown that, under mild assumptions, in the limit of many samples the distribution of the chain’s samples approximates the target distribution.
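As a toy illustration of this convergence (not from the post; the transition matrix here is made up), simulating even a two-state chain shows the empirical distribution of visited states settling to the chain's stationary distribution:

```python
import numpy as np

# Transition matrix: P[i, j] = probability of moving from state i to state j.
P = np.array([[0.9, 0.1],
              [0.5, 0.5]])

rng = np.random.default_rng(0)
state, counts = 0, np.zeros(2)
for _ in range(100_000):
    state = rng.choice(2, p=P[state])  # Markovian: depends only on the current state
    counts[state] += 1

# For this P, the stationary distribution is [5/6, 1/6] ~ [0.833, 0.167].
print(counts / counts.sum())
```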

A full review of MCMC methods is beyond the scope of this post, but a good reference is the Handbook of Markov Chain Monte Carlo (2011). Standard machine learning textbooks such as Koller & Friedman (2009) and Murphy (2012) also cover MCMC methods.

Metropolis-Hastings

One of the most general and powerful MCMC methods is Metropolis-Hastings. This uses a test to filter samples. To define it properly, let $p(\theta)$ be the target distribution we want to approximate. In general, it’s intractable to sample directly from it. Metropolis-Hastings uses a simpler proposal distribution $q(\theta' \mid \theta)$ to generate samples. Here, $\theta$ represents our current sample in the chain, and $\theta'$ represents the proposed sample. For simple cases, it’s common to use a Gaussian proposal centered at $\theta$.

If we were to just use a Gaussian to generate samples in our chain, there’s no way we could approximate our target $p$, since the samples would form a random walk. The MH test cleverly resolves this by filtering samples with the following test. Draw a uniform random variable $u \in [0,1]$ and determine whether the following is true:

$$u \overset{?}{<} \min\left\{ \frac{p(\theta')\, q(\theta \mid \theta')}{p(\theta)\, q(\theta' \mid \theta)},\ 1 \right\}$$

If true, we accept $\theta'$. Otherwise, we reject and reuse the old sample $\theta$. Notice that

  • It doesn’t require knowledge of a normalizing constant (independent of $\theta$ and $\theta'$), because that cancels out in the $p(\theta')/p(\theta)$ ratio. This is great, because normalizing constants are arguably the biggest reason why distributions become intractable.
  • The higher the value of $p(\theta')$, the more likely we are to accept.
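Concretely, here is a minimal sketch of the accept/reject loop, assuming a symmetric Gaussian proposal (so the $q$ ratio cancels) and a user-supplied unnormalized log-density `log_p`:

```python
import numpy as np

def metropolis_hastings(log_p, theta0, n_samples, step=0.5, rng=None):
    """Random-walk MH. The proposal is Gaussian and symmetric, so the
    q(theta | theta') / q(theta' | theta) factor cancels in the test."""
    rng = rng or np.random.default_rng()
    theta = np.asarray(theta0, dtype=float)
    samples = []
    for _ in range(n_samples):
        proposal = theta + step * rng.standard_normal(theta.shape)
        # In log space the test reads: log u < log p(theta') - log p(theta).
        if np.log(rng.uniform()) < log_p(proposal) - log_p(theta):
            theta = proposal  # accept theta'
        samples.append(theta.copy())  # on rejection, the old theta is reused
    return np.array(samples)
```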

To get more intuition on how the test works, we’ve created the following figure from this Jupyter Notebook, showing the progression of samples to approximate a target posterior. This example is derived from Welling & Teh (2011).

[Figure] A quick example of the MH test in action on a mixture of Gaussians example. The parameter is $\theta \in \mathbb{R}^2$, with the x and y axes representing $\theta_1$ and $\theta_2$, respectively. The target posterior has contours shown in the fourth plot; the probability mass is concentrated in the diagonal between points $(0,1)$ and $(1,-1)$. (This posterior depends on sampled Gaussians.) The plots show the progression of the MH test after 50, 500, and 5000 samples in our MCMC chain. After 5000 samples, it's clear that our samples are concentrated in the regions with higher posterior probability.

Reducing Metropolis-Hastings Data Usage

What happens when we consider the Bayesian posterior inference case with large datasets? (Perhaps we’re interested in the same example in the figure above, except that the posterior is based on more data points.) Then our goal is to sample to approximate the distribution $p(\theta \mid x_1, \ldots, x_N)$ for large $N$. By Bayes’ rule, this is $\frac{p_0(\theta)\, p(x_1, \ldots, x_N \mid \theta)}{p(x_1, \ldots, x_N)}$, where $p_0$ is the prior. We additionally assume that the $x_i$ are conditionally independent given $\theta$. The MH test therefore becomes:

$$u \overset{?}{<} \min\left\{ \frac{p_0(\theta') \prod_{i=1}^N p(x_i \mid \theta')\, q(\theta \mid \theta')}{p_0(\theta) \prod_{i=1}^N p(x_i \mid \theta)\, q(\theta' \mid \theta)},\ 1 \right\}$$

Or, after taking logarithms and rearranging (while ignoring the minimum operator, which technically isn’t needed here), we get

$$\log\left( u\, \frac{q(\theta' \mid \theta)\, p_0(\theta)}{q(\theta \mid \theta')\, p_0(\theta')} \right) \overset{?}{<} \sum_{i=1}^N \log \frac{p(x_i \mid \theta')}{p(x_i \mid \theta)}$$

The problem now is apparent: it’s expensive to compute all the $p(x_i \mid \theta')$ terms, and this has to be done every time we sample since it depends on $\theta'$.
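To make the $O(N)$ cost concrete, the right-hand side is a full scan over the data for every proposal (a sketch; `log_lik` is a stand-in for the per-example log-likelihood):

```python
def log_likelihood_ratio(theta_new, theta_old, data, log_lik):
    """Right-hand side of the full-data MH test: one O(N) pass per proposal."""
    return sum(log_lik(x, theta_new) - log_lik(x, theta_old) for x in data)
```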

The naive way to deal with this is to apply the same test, but with a minibatch of b elements:

$$\log\left( u\, \frac{q(\theta' \mid \theta)\, p_0(\theta)}{q(\theta \mid \theta')\, p_0(\theta')} \right) \overset{?}{<} \frac{N}{b} \sum_{i=1}^b \log \frac{p(x_i \mid \theta')}{p(x_i \mid \theta)}$$

Unfortunately, this won’t sample from the correct target distribution; see Section 6.1 in Bardenet et al. (2017) for details.

A better strategy is to start with the same batch of b points, but then gauge the confidence of the batch test relative to using the full data. If, after seeing b points, we already know that our proposed sample θ is significantly worse than our current sample θ, then we should reject right away. If θ is significantly better, we should accept. If it’s ambiguous, then we increase the size of our test batch, perhaps to 2b elements, and then measure the test’s confidence. Lather, rinse, repeat. As mentioned earlier, Korattikara et al. (2014) and Bardenet et al. (2014) developed algorithms following this framework.
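A hypothetical sketch of this batch-growing loop, in the spirit of those papers but not their exact algorithms (`log_ratio(x)` returns $\log p(x \mid \theta') - \log p(x \mid \theta)$, and `threshold` is the left-hand side of the test divided by $N$):

```python
import numpy as np
from scipy import stats

def sequential_mh_test(threshold, data, log_ratio, b=100, eps=0.05, rng=None):
    """Grow the test batch until the accept/reject decision is confident.
    Returns True to accept the proposal."""
    rng = rng or np.random.default_rng()
    idx, n = rng.permutation(len(data)), b
    while True:
        vals = np.array([log_ratio(data[i]) for i in idx[:n]])
        sem = vals.std(ddof=1) / np.sqrt(n)
        t = (vals.mean() - threshold) / sem
        # Decide once a t-test rejects "batch mean == threshold",
        # or once the batch has grown to the full dataset.
        if 2 * stats.t.sf(abs(t), df=n - 1) < eps or n >= len(data):
            return vals.mean() > threshold
        n = min(2 * n, len(data))  # ambiguous: double the batch and retest
```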

A weakness of the above approach is that it’s doing repeated testing and one must reduce the allowable test error each time one increments the test batch size. Unfortunately, there is also a significant probability that the approaches above will grow the test batch all the way to the full dataset, and they offer at most constant factor speedups over testing the full dataset.

Minibatch Metropolis-Hastings: Our Contribution

Change the Acceptance Function

To set up our test, we first define the log transition probability ratio Δ:

$$\Delta(\theta, \theta') = \log \frac{p_0(\theta') \prod_{i=1}^N p(x_i \mid \theta')\, q(\theta \mid \theta')}{p_0(\theta) \prod_{i=1}^N p(x_i \mid \theta)\, q(\theta' \mid \theta)}$$

This log ratio factors into a sum of per-sample terms, so when we approximate its value by computing on a minibatch we get an unbiased estimator of its full-data value plus some noise (which is asymptotically normal by the Central Limit Theorem).
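In code, a minibatch estimate of $\Delta$ looks like the following sketch (helper arguments are illustrative stand-ins; the $N/b$ rescaling makes the likelihood term unbiased):

```python
import numpy as np

def delta_estimate(theta_new, theta_old, data, log_lik,
                   log_prior_ratio, log_q_ratio, b=100, rng=None):
    """Unbiased estimate of Delta from a size-b minibatch (data is an array).
    By the CLT, its error relative to the full-data Delta is approximately
    normal."""
    rng = rng or np.random.default_rng()
    N = len(data)
    batch = data[rng.choice(N, size=b, replace=False)]
    lik = sum(log_lik(x, theta_new) - log_lik(x, theta_old) for x in batch)
    return log_prior_ratio + log_q_ratio + (N / b) * lik
```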

The first step for applying our MH test is to use a different acceptance function. Expressed in terms of Δ, the classical MH accepts a transition with probability given by the blue curve.

[Figure] Functions $f$ and $g$ can serve as acceptance tests for Metropolis-Hastings. Given current sample $\theta$ and proposed sample $\theta'$, the vertical axis represents the probability of accepting $\theta'$.

Instead of using the classical test, we’ll use the sigmoid function. It might not be apparent why this is allowed, but there’s some elegant theory that explains why using this alternative function as the acceptance test for MH still results in the correct semantics of MCMC. That is, under the same mild assumptions, the distribution of samples $(\theta_i)_{i=1}^T$ approaches the target distribution.

[Figure] The density of the standard logistic random variable, denoted $X_{\text{log}}$, along with the equivalent MH test expression ($X_{\text{log}} + \Delta > 0$) with the sigmoid acceptance function.

Our acceptance test is now the sigmoid function. Note that the sigmoid function is the cumulative distribution function of a (standard) logistic random variable; the figure above plots the density. One can show that the MH test under the sigmoid acceptance function reduces to determining whether $X_{\text{log}} + \Delta > 0$ for a sampled $X_{\text{log}}$ value.
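This equivalence is easy to sanity-check numerically: accepting with probability $\text{sigmoid}(\Delta)$ matches testing whether a standard logistic draw plus $\Delta$ is positive.

```python
import numpy as np

rng = np.random.default_rng(0)
delta, trials = 0.3, 1_000_000  # any fixed value of Delta

# Direct sigmoid test: accept with probability 1 / (1 + exp(-delta)).
accept_sigmoid = rng.uniform(size=trials) < 1 / (1 + np.exp(-delta))

# Equivalent form: draw X_log from a standard logistic, test X_log + delta > 0.
accept_logistic = rng.logistic(size=trials) + delta > 0

print(accept_sigmoid.mean(), accept_logistic.mean())  # both ~ sigmoid(delta)
```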

New MH Test

This is nice, but we don’t want to compute $\Delta$ because it depends on all $p(x_i \mid \theta)$ terms. When we estimate $\Delta$ using a minibatch, we introduce an additive error which is approximately normal, $X_{\text{normal}}$. The key observation in our work is that the distribution of the minibatch estimate of $\Delta$ (approximately Gaussian) is already very close to the desired test distribution $X_{\text{log}}$, as shown below.

[Figure] A plot of the logistic CDF in red (as we had earlier) along with a normal CDF curve, colored in lime, which corresponds to a standard deviation of 1.7.
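One can verify this closeness directly; the maximum gap between the two CDFs is only about 0.01:

```python
import numpy as np
from scipy.stats import logistic, norm

x = np.linspace(-8, 8, 1001)
gap = np.abs(logistic.cdf(x) - norm.cdf(x, scale=1.7)).max()
print(gap)  # roughly 0.01
```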

Rather than resorting to tail bounds as in prior work, we directly bridge these two distributions using an additive correction variable $X_{\text{correction}}$:

[Figure] A diagram of our minibatch MH test. On the right we have the full-data test that we want, but we can't use it since $\Delta$ is intractable. Instead, we have $\Delta + X_{\text{normal}}$ (from the left side) and must add a correction $X_{\text{correction}}$.

We want to make the LHS and RHS distributions equal, so we add in a correction $X_{\text{correction}}$, which is a symmetric random variable centered at zero. Adding independent random variables gives a random variable whose distribution is the convolution of the summands’ distributions. So finding the correction distribution involves “deconvolution” of a logistic and normal distribution. It’s not always possible to do this, and several conditions must be met (e.g., the tails of the normal distribution must be weaker than the logistic), but luckily for us they are. In our paper, to appear at UAI 2017, we show that the correction distribution can be approximated to essentially single-precision floating-point accuracy by tabulation.
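Schematically, the resulting test has the following form (a sketch: `sample_correction` is assumed to draw $X_{\text{correction}}$ from the tabulated deconvolution distribution, and the minibatch noise must first be normalized to the variance the table assumes; see the paper for those details):

```python
def minibatch_mh_accept(delta_hat, sample_correction):
    """delta_hat is the minibatch estimate of Delta (i.e., Delta + X_normal).
    Accept the proposal iff delta_hat + X_correction > 0."""
    return delta_hat + sample_correction() > 0
```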

In our paper, we also prove theoretical results bounding the error of our test, and present experimental results showing that our method results in accurate posterior estimation for a Gaussian Mixture Model, and that it is also highly sample-efficient in Logistic Regression for classification of MNIST digits.

[Figure] Histograms showing the batch sizes used for Metropolis-Hastings for the three algorithms benchmarked in our paper. The posterior is similar to the earlier example from the Jupyter Notebook, except generated with one million data points. The left histogram is our result; the other two are from Korattikara et al. (2014) and Bardenet et al. (2014), respectively. Our algorithm uses an average of just 172 data points each iteration. Note the log-log scale of the histograms.
