Batch normalization, as described in the March 2015 paper (the BN2015 paper) by Sergey Ioffe and Christian Szegedy, is a simple and effective way to improve the performance of a neural network. In the BN2015 paper, Ioffe and Szegedy show that batch normalization enables the use of higher learning rates, acts as a regularizer and can speed up training by 14 times. In this post, I show how to implement batch normalization in Tensorflow.
Edit 2018 (that should have been made back in 2016): If you’re just looking for a working implementation, Tensorflow has an easy to use batch_normalization layer in the tf.layers module. Just be sure to wrap your training step in a with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): block, and it will work.
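For example, a minimal sketch of that pattern (the model here is illustrative, not the one built in this post):

```python
import tensorflow as tf

is_training = tf.placeholder(tf.bool)  # True during training, False at test time

x = tf.placeholder(tf.float32, shape=[None, 784])
h = tf.layers.dense(x, 100, use_bias=False)
h = tf.layers.batch_normalization(h, training=is_training)
h = tf.nn.sigmoid(h)
loss = tf.reduce_mean(h)  # stand-in loss, just to make the sketch complete

# batch_normalization's moving-average updates are collected in UPDATE_OPS,
# so make the training step depend on them:
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
```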
Edit 07/12/16: I’ve updated this post to cover the calculation of population mean and variance at test time in more detail.
Edit 02/08/16: In case you are looking for recurrent batch normalization (i.e., from Cooijmans et al. (2016)), I have uploaded a working Tensorflow implementation here. The only tricky part of the implementation, as compared to the feedforward batch normalization presented in this post, is storing separate population variables for different timesteps.
The problem
Batch normalization is intended to solve the following problem: Changes in model parameters during learning change the distributions of the outputs of each hidden layer. This means that later layers need to adapt to these (often noisy) changes during training.
Batch normalization in brief
To solve this problem, the BN2015 paper proposes the batch normalization of the input to the activation function of each neuron (e.g., each sigmoid or ReLU function) during training, so that the input to the activation function across each training batch has a mean of 0 and a variance of 1. For example, applying batch normalization to the activation \(\sigma(Wx + b)\) would result in \(\sigma(BN(Wx + b))\), where \(BN\) is the batch normalizing transform.
The batch normalizing transform
To normalize a value across a batch (i.e., to batch normalize the value), we subtract the batch mean, \(\mu_B\), and divide the result by the batch standard deviation, \(\sqrt{\sigma^2_B + \epsilon}\). Note that a small constant \(\epsilon\) is added to the variance in order to avoid dividing by zero.

Thus, the initial batch normalizing transform of a given value, \(x_i\), is:

\[BN_{initial}(x_i) = \frac{x_i - \mu_B}{\sqrt{\sigma^2_B + \epsilon}}\]

Because the batch normalizing transform given above restricts the inputs to the activation function to a prescribed distribution (zero mean and unit variance), this can limit the representational power of the layer. Therefore, we allow the network to undo the batch normalizing transform by multiplying by a new scale parameter \(\gamma\) and adding a new shift parameter \(\beta\). \(\gamma\) and \(\beta\) are learnable parameters.

Adding in \(\gamma\) and \(\beta\) produces the following final batch normalizing transform:

\[BN(x_i) = \gamma\left(\frac{x_i - \mu_B}{\sqrt{\sigma^2_B + \epsilon}}\right) + \beta\]
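To make the transform concrete, here is a minimal numpy sketch (in practice, \(\gamma\) and \(\beta\) are learned rather than fixed):

```python
import numpy as np

def batch_norm(x, gamma, beta, epsilon=1e-3):
    # x has shape (batch_size, num_units); statistics are computed per unit
    mu = x.mean(axis=0)                        # batch mean
    var = x.var(axis=0)                        # batch variance
    x_hat = (x - mu) / np.sqrt(var + epsilon)  # initial transform
    return gamma * x_hat + beta                # scale and shift

batch = 5 * np.random.randn(64, 100) + 3      # mean ~3, std ~5
out = batch_norm(batch, gamma=1.0, beta=0.0)
print(out.mean(axis=0)[:3], out.std(axis=0)[:3])  # each unit is now ~0 mean, ~1 std
```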
Implementing batch normalization in Tensorflow
We will add batch normalization to a basic fully-connected neural network that has two hidden layers of 100 neurons each and show a similar result to Figure 1 (b) and (c) of the BN2015 paper.
Note that this network is not yet generally suitable for use at test time. See the section Making predictions with the model below for the reason why, as well as a fixed version.
Imports, config
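A minimal version of the imports and config (assuming the standard TF1.x MNIST tutorial loader):

```python
import numpy as np
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
```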
```python
# Generate predetermined random weights so the networks are similarly initialized
w1_initial = np.random.normal(size=(784,100)).astype(np.float32)
w2_initial = np.random.normal(size=(100,100)).astype(np.float32)
w3_initial = np.random.normal(size=(100,10)).astype(np.float32)

# Small epsilon value for the BN transform
epsilon = 1e-3
```
```python
# Placeholders
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
```
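Layer 1 without BN (a minimal sketch, mirroring the layer 2 code below; it produces the l1 that layer 2 consumes):

```python
w1 = tf.Variable(w1_initial)
b1 = tf.Variable(tf.zeros([100]))
z1 = tf.matmul(x, w1) + b1
l1 = tf.nn.sigmoid(z1)
```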
Here is the same layer 1 with batch normalization:
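This is a sketch that applies the transform manually, per the formula above (layer 2 will use Tensorflow's built-in function instead):

```python
w1_BN = tf.Variable(w1_initial)

# The bias can be omitted: subtracting the batch mean cancels its effect,
# and the learned shift beta takes over its role
z1_BN = tf.matmul(x, w1_BN)

# Calculate the batch mean and variance of z1_BN along the batch dimension
batch_mean1, batch_var1 = tf.nn.moments(z1_BN, [0])

# Apply the initial batch normalizing transform
z1_hat = (z1_BN - batch_mean1) / tf.sqrt(batch_var1 + epsilon)

# Create the learnable scale (gamma) and shift (beta) parameters
scale1 = tf.Variable(tf.ones([100]))
beta1 = tf.Variable(tf.zeros([100]))

# Scale and shift to obtain the final output of batch normalization
BN1 = scale1 * z1_hat + beta1
l1_BN = tf.nn.sigmoid(BN1)
```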
```python
# Layer 2 without BN
w2 = tf.Variable(w2_initial)
b2 = tf.Variable(tf.zeros([100]))
z2 = tf.matmul(l1, w2) + b2
l2 = tf.nn.sigmoid(z2)
```
```python
# Layer 2 with BN, using Tensorflow's built-in BN function
w2_BN = tf.Variable(w2_initial)
z2_BN = tf.matmul(l1_BN, w2_BN)
batch_mean2, batch_var2 = tf.nn.moments(z2_BN, [0])
scale2 = tf.Variable(tf.ones([100]))
beta2 = tf.Variable(tf.zeros([100]))
BN2 = tf.nn.batch_normalization(z2_BN, batch_mean2, batch_var2, beta2, scale2, epsilon)
l2_BN = tf.nn.sigmoid(BN2)
```
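Both networks need a softmax output layer to produce the y and y_BN used below (a minimal sketch, assuming a plain softmax layer on top of l2 and l2_BN):

```python
w3 = tf.Variable(w3_initial)
b3 = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(l2, w3) + b3)

w3_BN = tf.Variable(w3_initial)
b3_BN = tf.Variable(tf.zeros([10]))
y_BN = tf.nn.softmax(tf.matmul(l2_BN, w3_BN) + b3_BN)
```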
```python
# Loss, optimizer and predictions
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
cross_entropy_BN = -tf.reduce_sum(y_ * tf.log(y_BN))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
train_step_BN = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy_BN)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
correct_prediction_BN = tf.equal(tf.argmax(y_BN, 1), tf.argmax(y_, 1))
accuracy_BN = tf.reduce_mean(tf.cast(correct_prediction_BN, tf.float32))
```
```python
zs, BNs, acc, acc_BN = [], [], [], []

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
for i in tqdm.tqdm(range(40000)):
    batch = mnist.train.next_batch(60)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
    train_step_BN.run(feed_dict={x: batch[0], y_: batch[1]})
    if i % 50 == 0:
        res = sess.run([accuracy, accuracy_BN, z2, BN2],
                       feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        acc.append(res[0])
        acc_BN.append(res[1])
        zs.append(np.mean(res[2], axis=0))   # record the mean value of z2 over the entire test set
        BNs.append(np.mean(res[3], axis=0))  # record the mean value of BN2 over the entire test set

zs, BNs, acc, acc_BN = np.array(zs), np.array(BNs), np.array(acc), np.array(acc_BN)
```
```python
fig, ax = plt.subplots()

ax.plot(range(0, len(acc)*50, 50), acc, label='Without BN')
ax.plot(range(0, len(acc)*50, 50), acc_BN, label='With BN')
ax.set_xlabel('Training steps')
ax.set_ylabel('Accuracy')
ax.set_ylim([0.8, 1])
ax.set_title('Batch Normalization Accuracy')
ax.legend(loc=4)
plt.show()
```
```python
fig, axes = plt.subplots(5, 2, figsize=(6,12))
fig.tight_layout()

for i, ax in enumerate(axes):
    ax[0].set_title("Without BN")
    ax[1].set_title("With BN")
    ax[0].plot(zs[:,i])
    ax[1].plot(BNs[:,i])
```
Making predictions with the model

What happens when we use the trained model to make predictions one example at a time?
```python
predictions = []
correct = 0
for i in range(100):
    pred, corr = sess.run([tf.argmax(y_BN, 1), accuracy_BN],
                          feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})
    correct += corr
    predictions.append(pred[0])
print("PREDICTIONS:", predictions)
print("ACCURACY:", correct/100)
```
Our model always predicts 8, and there appear to be only two 8s in the first 100 MNIST test samples, for an accuracy of 2%.
Fixing the model for test time
To fix this, we need to replace the batch mean and batch variance in each batch normalization step with estimates of the population mean and population variance, respectively. See Section 3.1 of the BN2015 paper. Testing the model above only worked because the entire test set was predicted at once, so the “batch mean” and “batch variance” of the test set provided good estimates for the population mean and population variance.
To make a batch normalized model generally suitable for testing, we want to obtain estimates for the population mean and population variance at each batch normalization step before test time (i.e., during training), and use these values when making predictions. Note that for the same reason we need batch normalization in the first place (i.e., the mean and variance of the activation inputs change during training), it would be best to estimate the population mean and variance after the weights they depend on have been trained. Estimating them during training is not a serious compromise, however, since the weights are expected to have nearly converged by the end of training.
And now, to actually implement this in Tensorflow, we will write a batch_norm_wrapper function, which we will use to wrap the inputs to our activation functions. The function will store the population mean and variance as tf.Variables, and decide whether to use the batch statistics or the population statistics for normalization. To do this, it makes use of an is_training flag. Because we need to learn the population mean and variance during training, we do this when is_training == True. Here is an outline of the code:
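A minimal sketch of that outline (the population updates are filled in below):

```python
def batch_norm_wrapper(inputs, is_training):
    # Learnable scale (gamma) and shift (beta), one per unit
    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))

    # Population statistics, updated by us rather than by the optimizer
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        # Normalize with the batch statistics (and, below, update the
        # population statistics as a side effect)
        batch_mean, batch_var = tf.nn.moments(inputs, [0])
        return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
    else:
        # At test time, normalize with the estimated population statistics
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
```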
Note that the variables have been declared with a trainable=False argument, since we will be updating these ourselves rather than having the optimizer do it.
One approach to estimating the population mean and variance during training is to use an exponential moving average, though strictly speaking, a simple average over the sample would be (marginally) better. The exponential moving average is simple and lets us avoid extra work, so we use that:
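A sketch of those updates, which live inside batch_norm_wrapper alongside the variables declared above (decay is an illustrative hyperparameter):

```python
decay = 0.999  # use a value closer to 1 if you have more data / training steps

# Nudge the population statistics a small step toward the batch statistics
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
```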
Finally, we will need a way to call these training ops. For full control, you can add them to a graph collection (see the link to Tensorflow’s code below), but for simplicity, we will call them every time we calculate the batch_mean and batch_var. To do this, we add them as dependencies to the return value of batch_norm_wrapper when is_training is true. Here is the final batch_norm_wrapper function:
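A sketch of that final version:

```python
def batch_norm_wrapper(inputs, is_training, decay=0.999):
    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs, [0])
        train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        # Make the return value depend on the updates, so the population
        # statistics are refreshed whenever the batch statistics are used
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
```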
An implementation that works at test time
And now to demonstrate that this works, we rebuild/retrain the model with our batch_norm_wrapper function. Note that we need to build the graph once for training, and then again at test time, so we write a build_graph function (in practice, this would usually be encapsulated in a model object):
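A sketch of build_graph, matching the (x, y_), train_step, accuracy, y, saver signature used below:

```python
def build_graph(is_training):
    # Placeholders
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])

    # Layer 1
    w1 = tf.Variable(w1_initial)
    z1 = tf.matmul(x, w1)
    l1 = tf.nn.sigmoid(batch_norm_wrapper(z1, is_training))

    # Layer 2
    w2 = tf.Variable(w2_initial)
    z2 = tf.matmul(l1, w2)
    l2 = tf.nn.sigmoid(batch_norm_wrapper(z2, is_training))

    # Softmax output layer
    w3 = tf.Variable(w3_initial)
    b3 = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(l2, w3) + b3)

    # Loss, optimizer and predictions
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    return (x, y_), train_step, accuracy, y, tf.train.Saver()
```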
```python
# Build training graph, train and save the trained model
sess.close()
tf.reset_default_graph()
(x, y_), train_step, accuracy, _, saver = build_graph(is_training=True)

acc = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in tqdm.tqdm(range(10000)):
        batch = mnist.train.next_batch(60)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        if i % 50 == 0:
            res = sess.run([accuracy], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
            acc.append(res[0])
    saved_model = saver.save(sess, './temp-bn-save')

print("Final accuracy:", acc[-1])
```
Final accuracy: 0.9721
```python
# Build test graph and make predictions one example at a time
tf.reset_default_graph()
(x, y_), _, accuracy, y, saver = build_graph(is_training=False)

predictions = []
correct = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, './temp-bn-save')
    for i in range(100):
        pred, corr = sess.run([tf.argmax(y, 1), accuracy],
                              feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})
        correct += corr
        predictions.append(pred[0])
print("PREDICTIONS:", predictions)
print("ACCURACY:", correct/100)
```