Bayesian deep learning methods often look like a theoretical curiosity, rather than a practically useful tool, and I’m personally a bit skeptical about the practical usefulness of some of the work. However, there are some situations where a decent method of handling and representing residual uncertainties about model parameters might prove crucial. These applications include active learning, reinforcement learning and online/continual learning.
So when I recently read a paper by Tencent, I was surprised to learn that an online Bayesian deep learning algorithm is apparently deployed in production to power click-through-rate prediction in their ad system. I thought, therefore, that this is worth a mention. Here is the paper:
Though the paper has a few typos and sometimes inconsistent notation (like going back and forth between using $\omega$ and $w$), which can make it tricky to read, I like the paper’s layout: it starts from desiderata (what the system has to be able to do) and arrives at an elegant solution.
Assumed Density Filtering
The method relies on the approximate Bayesian online-learning technique often referred to as assumed density filtering (ADF).
ADF has been independently discovered by the statistics, machine learning and control communities (for citations see Minka, 2001). It is perhaps most elegantly described by Opper (1998) and by Minka (2001), who also extended the idea to develop expectation propagation, a highly influential method for approximate Bayesian inference, not just for online settings.
ADF can be explained as a recursive algorithm repeating the following steps:
- The starting point for the algorithm at time $t$ is a “temporal prior” distribution $q_{t-1}(w)$ over parameters $w$. This $q_{t-1}(w)$ incorporates all evidence from datapoints observed in previous timesteps, and it is assumed to approximate the posterior $p(w\vert x_{1:t-1})$ conditioned on all data observed so far. $q_{t-1}(w)$ is assumed to take some simple form, like a Gaussian distribution factorized over the elements of $w$.
- Then, some new observations $x_t$ arrive. The Bayesian way to learn from these new observations would be to update the posterior: $p_t(w\vert x_t) \propto p(x_t\vert w)q_{t-1}(w)$
- However, this posterior might be complicated, and may not be available in a simple form. Therefore, at this point we approximate $p_t(w\vert x_t)$ by a new simple distribution $q_{t}$, which is in the same family as $q_{t-1}$. This step can involve a KL divergence-based approximation, e.g. (Opper, 1998; Minka, 2001), or a Laplace approximation, e.g. (Kirkpatrick et al, 2017; Huszár, 2017), or a more involved inner loop such as EP in MatchBox (Stern et al, 2009) or probabilistic backpropagation (Hernandez-Lobato and Adams, 2015), as in this paper.
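To make the recursion above concrete, here is a minimal sketch of ADF for a much simpler model than the one in the paper. It assumes a linear probit classifier, $p(y\vert x, w) = \Phi(y\, w^\top x)$ with $y \in \{-1, +1\}$, and a factorized Gaussian $q$. The function name `adf_probit_step` and the toy data stream are my own illustration, not anything taken from the Tencent paper. For this model the KL-based projection has a closed form: the new $q_t$ simply matches the marginal means and variances of $p(y_t\vert x_t, w)q_{t-1}(w)$.

```python
import math

def norm_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    """Standard normal CDF, written with erfc for accuracy in the lower tail."""
    return 0.5 * math.erfc(-z / math.sqrt(2.0))

def adf_probit_step(mean, var, x, y):
    """One ADF update of a factorized Gaussian q(w) = prod_i N(mean[i], var[i]).

    Assumed likelihood (illustrative): p(y | x, w) = Phi(y * <w, x>), y in {-1, +1}.
    The returned q matches the marginal means and variances of
    p(y | x, w) * q_old(w), i.e. the KL(p || q) projection onto the family.
    """
    s_mean = sum(m * xi for m, xi in zip(mean, x))      # mean of <w, x> under q
    s_var = sum(v * xi * xi for v, xi in zip(var, x))   # variance of <w, x> under q
    scale = math.sqrt(1.0 + s_var)
    z = y * s_mean / scale
    ratio = norm_pdf(z) / norm_cdf(z)                   # derivative of log Phi(z)
    new_mean = [m + v * xi * y * ratio / scale
                for m, v, xi in zip(mean, var, x)]
    new_var = [v - (v * xi) ** 2 * ratio * (ratio + z) / (1.0 + s_var)
               for v, xi in zip(var, x)]
    return new_mean, new_var

# Hypothetical stream of (x_t, y_t) pairs; each one is used once, then discarded.
stream = [([1.0, 0.0], 1), ([1.0, 1.0], -1), ([0.0, 1.0], -1)]
mean, var = [0.0, 0.0], [1.0, 1.0]   # q_0: the prior
for x_t, y_t in stream:
    mean, var = adf_probit_step(mean, var, x_t, y_t)
```

Note how the only state carried between timesteps is `mean` and `var`. A weight whose feature is zero in $x_t$ is left untouched, and the variances can only shrink, which is relevant to the drawbacks discussed later.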
Probabilistic Backpropagation
The Tencent paper is based on probabilistic backpropagation (Hernandez-Lobato and Adams, 2015), which uses the ADF idea multiple times to perform the two non-trivial tasks required in a supervised Bayesian deep network: inference and learning.
- forward propagation: In Bayesian deep learning, we maintain a distribution $q(w)$ over neural network weights, and each value $w$ defines a conditional probability $p(y\vert x, w)$. To predict the label $y$ from the input $x$, we have to average over $q(w)$, that is, calculate $p(y\vert x) = \int q(w)p(y\vert x,w)dw$, which is difficult due to the nonlinearities. Probabilistic backprop uses an ADF-like algorithm to approximate this predictive distribution. Starting from the bottom of the network, it approximates the distribution of the first hidden layer’s activations given the input $x$ with a Gaussian. (The first hidden layer will have a distribution of activations because we have a distribution over the weights.) It then propagates that distribution to the second layer, and approximates the result with a Gaussian. This process is repeated until the distribution of $y$ given $x$ is calculated in the last step, which can be done easily if one uses a probit activation at the last layer.
- backward propagation: Forward propagation allows us to make a prediction given an input. The second task we have to be able to perform is to incorporate evidence from a new datapoint $(x_t, y_t)$ by updating the distribution over weights from $q_{t-1}(w)$ to an approximation $q_t(w)$ of $p_t(w\vert x_t, y_t) \propto p(y_t\vert x_t, w) q_{t-1}(w)$. We approximate this $p_t(w\vert x_t, y_t)$ in an inner loop, by first running probabilistic forward propagation, then a similar ADF-like sweep backwards through the network.
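The forward pass described above can be sketched in a few lines. This is an illustration under assumptions of my own, not the paper's exact layer definitions: factorized Gaussian weights, no biases, ReLU hidden units, and a single probit output. All function names (`linear_moments`, `relu_moments`, `predict_proba`) are hypothetical. Each layer only passes on a mean and a variance per unit, and the backward sweep is not shown.

```python
import math

def norm_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))

def linear_moments(in_mean, in_var, w_mean, w_var):
    """Mean and variance of z_j = sum_i W[j][i] * a[i], all terms independent.

    Uses Var(W * a) = Var(W) * (E[a]^2 + Var(a)) + E[W]^2 * Var(a).
    """
    out_mean, out_var = [], []
    for wm_row, wv_row in zip(w_mean, w_var):
        out_mean.append(sum(wm * am for wm, am in zip(wm_row, in_mean)))
        out_var.append(sum(wv * (am * am + av) + wm * wm * av
                           for wm, wv, am, av
                           in zip(wm_row, wv_row, in_mean, in_var)))
    return out_mean, out_var

def relu_moments(mu, var):
    """Mean and variance of max(0, z) for z ~ N(mu, var) (Gaussian moment matching)."""
    if var < 1e-12:                       # deterministic input: ordinary ReLU
        return max(mu, 0.0), 0.0
    sd = math.sqrt(var)
    alpha = mu / sd
    mean = mu * norm_cdf(alpha) + sd * norm_pdf(alpha)
    second = (mu * mu + var) * norm_cdf(alpha) + mu * sd * norm_pdf(alpha)
    return mean, max(second - mean * mean, 0.0)

def predict_proba(x, layers):
    """Approximate p(y = 1 | x); `layers` is a list of (w_mean, w_var) pairs.

    The last layer must have one output unit, which is passed through a probit.
    """
    a_mean, a_var = list(x), [0.0] * len(x)
    for depth, (w_mean, w_var) in enumerate(layers):
        z_mean, z_var = linear_moments(a_mean, a_var, w_mean, w_var)
        if depth == len(layers) - 1:
            # E[Phi(z)] for Gaussian z is available in closed form.
            return norm_cdf(z_mean[0] / math.sqrt(1.0 + z_var[0]))
        moments = [relu_moments(m, v) for m, v in zip(z_mean, z_var)]
        a_mean = [m for m, _ in moments]
        a_var = [v for _, v in moments]

# Toy 2-2-1 network with made-up weight means and variances.
layers = [
    ([[1.0, -1.0], [0.5, 0.5]], [[0.1, 0.1], [0.1, 0.1]]),
    ([[1.0, -2.0]], [[0.2, 0.2]]),
]
p = predict_proba([1.0, 2.0], layers)
```

Setting every weight variance to zero recovers an ordinary deterministic forward pass, which is a handy sanity check for an implementation like this.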
I find the similarity of this algorithm to gradient calculations in neural network training beautiful: forward propagation to perform inference, forward and backward propagation for learning and parameter updates.
Crucially though, there is no gradient descent, or indeed any gradient computation or iterative optimization, taking place here. After a single forward-backward cycle, information about the new datapoint (or minibatch of data) is approximately incorporated into the updated distribution $q_{t}(w)$ of network weights. Once this is done, the minibatch can be discarded and never revisited. This makes this method a great candidate for online learning in large-scale streaming data situations.
Parallel operation
Another advantage of this method is that it can be parallelized in a data-parallel fashion: multiple workers can update the posterior simultaneously on different minibatches of data. Expectation propagation provides a meaningful way of combining the parameter updates computed by the parallel workers. Indeed, this is what the PBODL algorithm of Tencent does.
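As a rough illustration of how such a combination can work, here is a generic EP-style rule for a single Gaussian parameter. It is a sketch under my own assumptions and not necessarily the exact rule PBODL uses. Each worker starts from the same shared $q_{\text{old}}$ and returns its own updated $q_k$; the worker's contribution is the ratio $q_k / q_{\text{old}}$, and the combined posterior is $q_{\text{old}} \prod_k q_k / q_{\text{old}}$. In natural parameters (precision and precision times mean) this product becomes a sum of differences. The function names are hypothetical.

```python
def to_natural(mean, var):
    """(mean, variance) -> (precision * mean, precision)."""
    return mean / var, 1.0 / var

def from_natural(eta, tau):
    """(precision * mean, precision) -> (mean, variance)."""
    return eta / tau, 1.0 / tau

def combine_workers(shared, worker_posteriors):
    """Combine per-worker Gaussian posteriors that all started from `shared`.

    shared: (mean, var) of the distribution every worker started from.
    worker_posteriors: list of (mean, var), one per worker.
    """
    eta0, tau0 = to_natural(*shared)
    eta, tau = eta0, tau0
    for mean_k, var_k in worker_posteriors:
        eta_k, tau_k = to_natural(mean_k, var_k)
        eta += eta_k - eta0   # add this worker's message q_k / q_old
        tau += tau_k - tau0
    if tau <= 0.0:
        raise ValueError("combined precision is not positive")
    return from_natural(eta, tau)

# Two workers each saw one Gaussian observation with unit noise (values 1 and 3),
# starting from a shared N(0, 1).
combined = combine_workers((0.0, 1.0), [(0.5, 0.5), (1.5, 0.5)])
```

In this conjugate toy case the combined result equals the exact posterior given both observations. With approximate updates that is no longer guaranteed, and the combined precision can in principle become non-positive, hence the guard.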
Advantages
Bayesian online learning comes with great advantages. In recommender systems and ads marketplaces, the Bayesian online learning approach handles the user and item cold-start problem gracefully. If there is a new ad in the system that has to be recommended, initially our network might say “I don’t really know”, and then gradually home in on a confident prediction as more data is observed.
One can use the uncertainty estimates in a Bayesian network to perform active learning: actively proposing which items to label in order to increase predictive performance in the future. In the context of ads, this may be a useful feature if implemented with care.
Finally, there is the useful principle that Bayes never forgets: if we perform exact Bayesian updates on global parameters of an exchangeable model, the posterior will store information about all datapoints indefinitely, including very early ones. This advantage is clearly demonstrated in DeepMind’s work on catastrophic forgetting (Kirkpatrick et al, 2017). This capacity to keep a long memory of course diminishes the more approximations we have to make, which leads me to drawbacks.
Drawbacks
The ADF method is only approximate, and over time the approximation errors in each step may accumulate, causing the network to essentially forget what it learned from earlier datapoints. It is also worth pointing out that in stationary situations, the posterior over parameters is expected to shrink, especially in the last hidden layers of the network where there are often fewer parameters. This posterior compression is often countered by introducing explicit forgetting: without evidence to the contrary, the variance of each parameter is marginally increased in each step. Forgetting and the lack of infinite memory may turn out to be a feature, not a bug, in non-stationary situations where more recent datapoints are more representative of future test cases.
A practical drawback of a method like this in production environments is that probabilistic forward- and backpropagation require non-trivial custom implementations. Probabilistic backprop does not use reverse-mode automatic differentiation (i.e. vanilla backprop) as a subroutine. As a consequence, one cannot rely on the extensively tested and widely used autodiff packages in TensorFlow or PyTorch, and performing probabilistic backprop in new layer types might require significant effort and custom implementations. It is possible that high-quality message passing packages such as the recently open-sourced Infer.NET will see a renaissance after all.
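One simple form of explicit forgetting, shown here as an assumption of mine rather than the scheme used in the paper, is to add a small constant to every variance between updates. This corresponds to assuming each weight performs a small Gaussian random walk over time. The cap at the prior variance is an extra design choice to stop the uncertainty from growing without bound on parameters that never receive evidence.

```python
def inflate_variance(var, step, prior_var):
    """Add `step` to each posterior variance, capped at the prior variance.

    Equivalent to assuming w_t = w_{t-1} + noise with noise ~ N(0, step),
    followed by clipping (the clipping is an illustrative choice).
    """
    return [min(v + step, prior_var) for v in var]

# Applied once per timestep, before incorporating the next minibatch.
var = inflate_variance([0.5, 0.99], step=0.05, prior_var=1.0)
```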