Batch Normalization is a technique used to address the vanishing/exploding gradients problem. It consists of adding a few extra operations: zero-centering and normalizing the inputs, then scaling and shifting the result. This sequence of operations is added just before the activation function of each layer of our deep neural network.
If you want to dive into the mathematical details of each of these operations, please refer to the original paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" by Sergey Ioffe and Christian Szegedy (https://arxiv.org/pdf/1502.03167v3.pdf).
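To get a rough feel for these operations before jumping into TensorFlow, here is a minimal NumPy sketch of the batch normalization transform applied to one mini-batch. The names gamma, beta and epsilon are purely illustrative; in a real network gamma and beta are trainable parameters learned along with the weights.
import numpy as np

def batch_norm_sketch(X_batch, gamma, beta, epsilon=1e-5):
    mu = X_batch.mean(axis=0)          # zero-center: per-feature mean of the mini-batch
    var = X_batch.var(axis=0)          # per-feature variance of the mini-batch
    X_hat = (X_batch - mu) / np.sqrt(var + epsilon)   # normalize
    return gamma * X_hat + beta        # scale and shift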
Here I will show you how to implement batch normalization on the MNIST dataset using TensorFlow. For starters, TensorFlow provides a tf.nn.batch_normalization() function that centers and normalizes the inputs. This operation requires the mean and standard deviation of the mini-batch (during training) or of the full dataset (during testing), which must be computed manually and passed as parameters to the function. Since TensorFlow also provides tf.layers.batch_normalization(), which computes the mean and standard deviation for us, I will be using that instead.
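For reference, the lower-level approach would look roughly like the sketch below, where z stands for a layer's pre-activation output and gamma and beta are trainable scale/shift variables you would have to create yourself (all three names are just placeholders for this illustration):
batch_mean, batch_var = tf.nn.moments(z, axes=[0])   # per-feature mean and variance of the mini-batch
z_bn = tf.nn.batch_normalization(z, mean=batch_mean, variance=batch_var,
                                 offset=beta, scale=gamma, variance_epsilon=1e-5)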
First, let's import the required libraries.
import tensorflow as tf
import numpy as np
Now we will load the MNIST dataset and scale the pixel values of each image to the range 0 to 1. Then we will split the training data into a training set and a validation set.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0   # flatten and scale to [0, 1]
X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]   # hold out 5,000 images for validation
y_valid, y_train = y_train[:5000], y_train[5000:]
The following function resets the default graph and sets the random seeds, so that the random initialization of the weights and the random shuffling of the images are reproducible.
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)
A helper function to shuffle the images and yield mini-batches.
def shuffle_batch(X, y, batch_size):
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch
Set the hyperparameters.
reset_graph()
n_inputs = 28 * 28 # image size
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10
learning_rate = 0.01
n_epochs = 25
batch_size = 250
Set the placeholders X, y and training. The training placeholder will be set to True during training and will default to False during testing. It acts as a flag that tells tf.layers.batch_normalization() whether it should use the current mini-batch's mean and standard deviation or the moving averages computed over the whole training set.
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int32, shape=(None), name="y")
training = tf.placeholder_with_default(False, shape=(), name='training')
Then we will alternate between two hidden layers and two batch normalization layers. We will not specify the activation function in the hidden layers, since we apply it after each batch normalization layer instead. We will also define the logits for the output layer and then apply batch normalization to it. tf.layers.batch_normalization() takes the layer's output, the training flag (True/False, indicating training or testing) and momentum as its parameters. The momentum is used to compute the moving averages via exponential decay; a typical value is 0.9.
with tf.name_scope("dnn"):
hiddenLayer1 = tf.layers.dense(X, n_hidden1, name="hidden1")
batchNorm1 = tf.layers.batch_normalization(hiddenLayer1, training=training, momentum=0.9)
batchNorm1_act = tf.nn.elu(batchNorm1) # ELU Activation
hiddenLayer2 = tf.layers.dense(batchNorm1_act, n_hidden2, name="hidden2")
batchNorm2 = tf.layers.batch_normalization(hiddenLayer2, training=training, momentum=0.9)
batchNorm2_act = tf.nn.elu(batchNorm2)
logits = tf.layers.dense(batchNorm2_act, n_outputs, name="outputs")
logits_batchNorm = tf.layers.batch_normalization(logits, training=training, momentum=0.9)
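As a side note, the training and momentum arguments are repeated for every batch normalization layer above. If you find that verbose, one common pattern is to wrap the call with functools.partial; the resulting graph is exactly the same:
from functools import partial

# convenience wrapper so we don't have to repeat training/momentum for every layer
my_batch_norm_layer = partial(tf.layers.batch_normalization, training=training, momentum=0.9)

# e.g. batchNorm1 = my_batch_norm_layer(hiddenLayer1)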
Next we will define the loss using tf.nn.sparse_softmax_cross_entropy_with_logits(), since we are dealing with multiclass classification and integer class labels.
with tf.name_scope("loss"):
x_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(x_entropy, name="loss")
We will use GradientDescentOptimizer as the optimizer.
with tf.name_scope("optimize"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
training_ops = optimizer.minimize(loss)
Calculate the accuracy using tf.nn.in_top_k() and tf.reduce_mean().
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()
saver = tf.train.Saver()
Batch normalization creates a few extra operations that need to be evaluated at each training step in order to update the moving averages (these moving averages estimate the training set's mean and standard deviation, which are used at test time).
extra_graphkeys_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
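An equivalent alternative, if you would rather not pass the update operations to sess.run() yourself, is to attach them to the training operation with a control dependency when defining it (inside the "optimize" scope above); running training_ops alone is then enough:
# Alternative: make the moving-average updates run automatically with the training op
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    training_ops = optimizer.minimize(loss)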
Finally we start the session.
with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
            sess.run([training_ops, extra_graphkeys_update_ops],
                     feed_dict={training: True, X: X_batch, y: y_batch})
        accuracy_val = accuracy.eval(feed_dict={X: X_valid, y: y_valid})
        print(epoch+1, "Validation accuracy:", accuracy_val)
    save_path = saver.save(sess, "./mnist-batch-normalized-model.ckpt")
That's not a great accuracy, especially for MNIST. Training for longer would certainly improve it, but with such a shallow network, batch normalization and ELU are unlikely to have a very positive impact; they mostly shine in much deeper networks trained on larger datasets. Still, I hope this gives you the gist of implementing batch normalization in TensorFlow.
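If you later want to check the test accuracy, you can restore the saved checkpoint in a new session, as in the minimal sketch below (it assumes the same graph is still constructed; since training defaults to False, the stored moving averages are used for normalization):
with tf.Session() as sess:
    saver.restore(sess, "./mnist-batch-normalized-model.ckpt")
    test_accuracy = accuracy.eval(feed_dict={X: X_test, y: y_test})
    print("Test accuracy:", test_accuracy)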
Full code at: https://github.com/jaynilpatel/mnist/blob/master/batch_normalization-mnist.py