I'm implementing a model that relies on 3D convolutions (for a task similar to action recognition) and I want to use batch normalization (see [Ioffe & Szegedy 2015]). I could not find any tutorial focusing on 3D convs, so I'm writing a short one here that I'd like to review with you.
The code below targets TensorFlow r0.12 and instantiates variables explicitly. I mean I'm not using tf.contrib.learn; the only tf.contrib function I use is tf.contrib.layers.batch_norm(). I'm doing this both to better understand how things work under the hood and to have more implementation freedom (e.g., variable summaries).
I will work up to the 3D convolution case gradually: first an example for a fully-connected layer, then for a 2D convolution, and finally for the 3D case. While going through the code, it would be great if you could check whether everything is done correctly: the code runs, but I'm not 100% sure about the way I apply batch normalization. I end this post with a more detailed question.
import tensorflow as tf
# This flag is used to allow/prevent batch normalization params updates
# depending on whether the model is being trained or used for prediction.
training = tf.placeholder_with_default(True, shape=())
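To make the role of this flag concrete, here is a minimal plain-Python sketch (not TensorFlow code) of what I understand a batch normalization layer to do in each mode: in training it normalizes with the statistics of the current batch and updates moving averages, while at prediction time it only reads those averages. The names RunningStats, DECAY and EPSILON are hypothetical and the constants are assumed values chosen for illustration.

```python
# Sketch only: plain-Python stand-in for the train/inference switch.
# RunningStats, DECAY and EPSILON are hypothetical names, not TensorFlow API.

DECAY = 0.9      # assumed smoothing factor for the moving averages
EPSILON = 1e-3   # assumed small constant that avoids division by zero


class RunningStats(object):
    """Moving mean/variance for a single feature."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0

    def normalize(self, batch, training):
        if training:
            # Training: use the statistics of the current batch ...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ... and fold them into the moving averages.
            self.mean = DECAY * self.mean + (1 - DECAY) * mean
            self.var = DECAY * self.var + (1 - DECAY) * var
        else:
            # Inference: use the stored averages and leave them untouched.
            mean, var = self.mean, self.var
        return [(x - mean) / (var + EPSILON) ** 0.5 for x in batch]


stats = RunningStats()
train_out = stats.normalize([1.0, 2.0, 3.0, 4.0], training=True)
mean_after_training = stats.mean
test_out = stats.normalize([10.0, 20.0], training=False)
```

In the real graph the same switch would be driven by feeding False for the training placeholder at prediction time; the sketch only shows why the two modes must not be mixed up.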
Fully-connected (FC) case
INPUT_SIZE = 512
u = tf.placeholder(tf.float32, shape=(None, INPUT_SIZE))
# FC params: weights only, no bias as per [Ioffe & Szegedy 2015].
FC_OUTPUT_LAYER_SIZE = 1024
w = tf.Variable(tf.truncated_normal(
    [INPUT_SIZE, FC_OUTPUT_LAYER_SIZE], dtype=tf.float32, stddev=1e-1))
# Layer output with no activation function (yet).
fc = tf.matmul(u, w)
# Batch normalization.
fc_bn = tf.contrib.layers.batch_norm(
    fc,
    center=True,  # learn the offset beta
    scale=True,  # learn the scale gamma
    is_training=training)
# Activation function.
fc_bn_relu = tf.nn.relu(fc_bn)
print(fc_bn_relu) # Tensor("Relu:0", shape=(?, 1024), dtype=float32)
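As a sanity check on dropping the bias, here is a small plain-Python sketch (not TensorFlow code) of why a bias added before batch normalization should have no effect: the batch mean absorbs any constant offset, and beta takes over the role of the bias afterwards. It handles a single output unit with gamma = 1 and beta = 0; batch_normalize and EPSILON are hypothetical names and the input values are made up.

```python
# Sketch only: plain Python, one output unit, gamma = 1 and beta = 0.
# batch_normalize and EPSILON are hypothetical names, not TensorFlow API.

EPSILON = 1e-3  # assumed small constant that avoids division by zero


def batch_normalize(values):
    """Normalize one unit's pre-activations over the batch dimension."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / (var + EPSILON) ** 0.5 for v in values]


pre_activations = [0.5, -1.0, 2.0, 0.0]          # made-up outputs of u * w
with_bias = [v + 3.0 for v in pre_activations]   # same values plus a bias

without_bias_bn = batch_normalize(pre_activations)
with_bias_bn = batch_normalize(with_bias)

# The batch mean absorbs the bias, so both results agree up to rounding.
max_diff = max(abs(a - b) for a, b in zip(without_bias_bn, with_bias_bn))
print(max_diff)
```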
I guess the TensorFlow code above is also correct for the 3D conv case. In fact, when I define my model and print all the trainable variables, I see the expected number of beta and gamma variables. For instance:
Tensor("conv3a/conv3d_weights/read:0", shape=(3, 3, 3, 128, 256), dtype=float32)
Tensor("BatchNorm_2/beta/read:0", shape=(256,), dtype=float32)
Tensor("BatchNorm_2/gamma/read:0", shape=(256,), dtype=float32)
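The shapes printed above are consistent with statistics being computed per channel, that is, over every axis except the last one. Here is a minimal plain-Python sketch of that bookkeeping, assuming channels-last inputs. The helper per_channel_batch_norm_shapes is hypothetical, and the spatial and temporal sizes are made up; only the channel counts 1024 and 256 come from the post.

```python
# Sketch only: shape bookkeeping in plain Python, no TensorFlow needed.
# per_channel_batch_norm_shapes is a hypothetical helper.


def per_channel_batch_norm_shapes(input_shape):
    """Return (reduction_axes, param_shape) for channels-last inputs.

    Assumes statistics are computed over every axis except the last one,
    so beta and gamma each hold one value per channel.
    """
    rank = len(input_shape)
    reduction_axes = list(range(rank - 1))
    param_shape = (input_shape[-1],)
    return reduction_axes, param_shape


# None stands for the unknown batch size; the other sizes are made up.
fc_axes, fc_params = per_channel_batch_norm_shapes((None, 1024))
conv2d_axes, conv2d_params = per_channel_batch_norm_shapes((None, 28, 28, 64))
conv3d_axes, conv3d_params = per_channel_batch_norm_shapes(
    (None, 4, 28, 28, 256))

print(fc_axes, fc_params)
print(conv2d_axes, conv2d_params)
print(conv3d_axes, conv3d_params)
```

Under this assumption the 3D case differs from the FC and 2D cases only in how many axes are averaged over, which would explain why beta and gamma for conv3a have shape (256,).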