.. _sec_batch_norm:
Batch Normalization
===================
Training deep neural networks is difficult. Getting them to converge in
a reasonable amount of time can be tricky. In this section, we describe
*batch normalization*, a popular and effective technique that
consistently accelerates the convergence of deep networks
:cite:`Ioffe.Szegedy.2015`. Together with residual blocks—covered
later in :numref:`sec_resnet`—batch normalization has made it possible
for practitioners to routinely train networks with over 100 layers. A
secondary (serendipitous) benefit of batch normalization lies in its
inherent regularization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import autograd, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from functools import partial
import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
.. raw:: html
.. raw:: html
Training Deep Networks
----------------------
When working with data, we often preprocess before training. Choices
regarding data preprocessing often make an enormous difference in the
final results. Recall our application of MLPs to predicting house prices
(:numref:`sec_kaggle_house`). Our first step when working with real
data was to standardize our input features to have zero mean
:math:`\boldsymbol{\mu} = 0` and unit variance
:math:`\boldsymbol{\Sigma} = \boldsymbol{1}` across multiple
observations :cite:`friedman1987exploratory`, frequently rescaling the
latter so that the diagonal is unity, i.e., :math:`\Sigma_{ii} = 1`. Yet
another strategy is to rescale vectors to unit length, possibly zero
mean *per observation*. This can work well, e.g., for spatial sensor
data. These preprocessing techniques, and many others, are beneficial for
keeping the estimation problem well controlled. For a review of feature
selection and extraction see, for example, the article of
:cite:t:`guyon2008feature`. Standardizing vectors also has the nice side
effect of constraining the complexity of functions that act upon them. For
instance, the celebrated radius-margin bound :cite:`Vapnik95` in
support vector machines and the Perceptron Convergence Theorem
:cite:`Novikoff62` rely on inputs of bounded norm.
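As a small illustration of the standardization step described above, the following NumPy sketch rescales each feature column to zero mean and unit variance. The function name ``standardize``, the constant ``eps``, and the toy data are our own assumptions and are not part of the ``d2l`` library.

```python
import numpy as np

def standardize(X, eps=1e-8):
    """Rescale each column of X to zero mean and unit variance."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # eps guards against division by zero for constant columns
    return (X - mean) / (std + eps)

# Hypothetical data: 4 observations, 2 features on very different scales
X = np.array([[1.0, 1000.0],
              [2.0, 2000.0],
              [3.0, 3000.0],
              [4.0, 4000.0]])
X_std = standardize(X)
```

After this step both columns live on the same scale, even though the raw features differed by three orders of magnitude.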
First, this standardization plays nicely with our optimizers since
it puts the parameters *a priori* on a similar scale. As such, it is
only natural to ask whether a corresponding normalization step *inside*
a deep network might not be beneficial. While this is not quite the
reasoning that led to the invention of batch normalization
:cite:`Ioffe.Szegedy.2015`, it is a useful way of understanding it and
its cousin, layer normalization :cite:`Ba.Kiros.Hinton.2016`, within a
unified framework.
Second, for a typical MLP or CNN, as we train, the variables in
intermediate layers (e.g., affine transformation outputs in an MLP) may
take values with widely varying magnitudes: along the layers
from input to output, across units in the same layer, and over time due
to our updates to the model parameters. The inventors of batch
normalization postulated informally that this drift in the distribution
of such variables could hamper the convergence of the network.
Intuitively, we might conjecture that if the activations of one layer
are 100 times as large as those of another layer, this might
necessitate compensatory adjustments in the learning rates. Adaptive
solvers such as AdaGrad :cite:`Duchi.Hazan.Singer.2011`, Adam
:cite:`Kingma.Ba.2014`, Yogi :cite:`Zaheer.Reddi.Sachan.ea.2018`, or
Distributed Shampoo :cite:`anil2020scalable` aim to address this from
the viewpoint of optimization, e.g., by adding aspects of second-order
methods. The alternative is to prevent the problem from occurring,
simply by adaptive normalization.
Third, deeper networks are complex and tend to be more liable to
overfitting. This means that regularization becomes more critical. A
common technique for regularization is noise injection. This has been
known for a long time, e.g., with regard to noise injection for the
inputs :cite:`Bishop.1995`. It also forms the basis of dropout in
:numref:`sec_dropout`. As it turns out, quite serendipitously, batch
normalization conveys all three benefits: preprocessing, numerical
stability, and regularization.
Batch normalization is applied to individual layers, or optionally, to
all of them: In each training iteration, we first normalize the inputs
(of batch normalization) by subtracting their mean and dividing by their
standard deviation, where both are estimated based on the statistics of
the current minibatch. Next, we apply a scale coefficient and an offset
to recover the lost degrees of freedom. It is precisely due to this
*normalization* based on *batch* statistics that *batch normalization*
derives its name.
Note that if we tried to apply batch normalization with minibatches of
size 1, we would not be able to learn anything. That is because after
subtracting the means, each hidden unit would take value 0. As you might
guess, since we are devoting a whole section to batch normalization,
with large enough minibatches the approach proves effective and stable.
One takeaway here is that when applying batch normalization, the choice
of batch size is even more significant than without batch normalization,
or at least, suitable calibration is needed whenever we adjust the batch
size.
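To make the degenerate case concrete, here is a minimal NumPy sketch. The helper ``normalize`` and the numbers are illustrative assumptions. It normalizes a minibatch of size 1 and, for comparison, a minibatch of size 2.

```python
import numpy as np

def normalize(X, eps=1e-5):
    """Normalize each column of X using the statistics of the minibatch."""
    mean = X.mean(axis=0)
    var = ((X - mean) ** 2).mean(axis=0)
    return (X - mean) / np.sqrt(var + eps)

# Minibatch of size 1: the mean equals the single example
single = np.array([[3.0, -7.0, 42.0]])
out_single = normalize(single)

# Minibatch of size 2 with hypothetical values
pair = np.array([[3.0, -7.0, 42.0],
                 [1.0, 5.0, 40.0]])
out_pair = normalize(pair)
```

Here ``out_single`` is identically zero, so nothing about the input survives. With two examples the outputs are no longer zero, but every normalized value has magnitude close to 1, which hints at how little signal very small minibatches retain.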
Denote by :math:`\mathcal{B}` a minibatch and let
:math:`\mathbf{x} \in \mathcal{B}` be an input to batch normalization
(:math:`\textrm{BN}`). In this case the batch normalization is defined
as follows:
.. math:: \textrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}.
:label: eq_batchnorm
In :eq:`eq_batchnorm`, :math:`\hat{\boldsymbol{\mu}}_\mathcal{B}`
is the sample mean and :math:`\hat{\boldsymbol{\sigma}}_\mathcal{B}` is
the sample standard deviation of the minibatch :math:`\mathcal{B}`.
After applying standardization, the resulting minibatch has zero mean
and unit variance. The choice of unit variance (rather than some other
magic number) is arbitrary. We recover this degree of freedom by
including an elementwise *scale parameter* :math:`\boldsymbol{\gamma}`
and *shift parameter* :math:`\boldsymbol{\beta}` that have the same
shape as :math:`\mathbf{x}`. Both are parameters that need to be learned
as part of model training.
The variable magnitudes for intermediate layers cannot diverge during
training since batch normalization actively centers and rescales them
back to a given mean and size (via
:math:`\hat{\boldsymbol{\mu}}_\mathcal{B}` and
:math:`{\hat{\boldsymbol{\sigma}}_\mathcal{B}}`). Practical experience
confirms that, as alluded to when discussing feature rescaling, batch
normalization seems to allow for more aggressive learning rates. We
calculate :math:`\hat{\boldsymbol{\mu}}_\mathcal{B}` and
:math:`{\hat{\boldsymbol{\sigma}}_\mathcal{B}}` in
:eq:`eq_batchnorm` as follows:
.. math::
\hat{\boldsymbol{\mu}}_\mathcal{B} = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}
\textrm{ and }
\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.
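The two estimates and :eq:`eq_batchnorm` can be sketched in a few lines of NumPy. This is a training-mode sketch only, under our own naming (``bn_train``, ``eps``) and with synthetic data. It is not the implementation used later in this section.

```python
import numpy as np

def bn_train(X, gamma, beta, eps=1e-5):
    """Batch normalization of a (batch, features) array in training mode."""
    mu = X.mean(axis=0)
    # eps is folded into the variance estimate, as in the formula above
    var = ((X - mu) ** 2).mean(axis=0) + eps
    X_hat = (X - mu) / np.sqrt(var)
    return gamma * X_hat + beta  # elementwise scale and shift

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # synthetic minibatch

# With gamma = 1 and beta = 0 the output is simply standardized
Y = bn_train(X, gamma=np.ones(4), beta=np.zeros(4))
# Other values of gamma and beta set the output scale and offset
Z = bn_train(X, gamma=np.full(4, 2.0), beta=np.full(4, -1.0))
```

In this sketch ``Y`` has per-feature mean close to 0 and standard deviation close to 1, while ``Z`` has mean close to -1 and standard deviation close to 2, which shows how the scale and shift parameters recover the lost degrees of freedom.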
Note that we add a small constant :math:`\epsilon > 0` to the variance
estimate to ensure that we never attempt division by zero, even in cases
where the empirical variance estimate might be very small or vanish. The
estimates :math:`\hat{\boldsymbol{\mu}}_\mathcal{B}` and
:math:`{\hat{\boldsymbol{\sigma}}_\mathcal{B}}` counteract the scaling
issue by using noisy estimates of mean and variance. You might think
that this noisiness should be a problem. On the contrary, it is actually
beneficial.
This turns out to be a recurring theme in deep learning. For reasons
that are not yet well-characterized theoretically, various sources of
noise in optimization often lead to faster training and less
overfitting: this variation appears to act as a form of regularization.
:cite:t:`Teye.Azizpour.Smith.2018` and :cite:t:`Luo.Wang.Shao.ea.2018`
related the properties of batch normalization to Bayesian priors and
penalties, respectively. In particular, this sheds some light on the
puzzle of why batch normalization works best for moderate minibatch
sizes in the 50–100 range. This particular size of minibatch seems to
inject just the “right amount” of noise per layer, both in terms of
scale via :math:`\hat{\boldsymbol{\sigma}}`, and in terms of offset via
:math:`\hat{\boldsymbol{\mu}}`: a larger minibatch regularizes less due
to the more stable estimates, whereas tiny minibatches destroy useful
signal due to high variance. Exploring this direction further and
considering alternative types of preprocessing and filtering may yet
lead to other effective types of regularization.
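The dependence of the noise level on the minibatch size can be checked numerically. The sketch below draws minibatches from a synthetic population of activations and measures how much the minibatch mean fluctuates. The population, the helper ``spread_of_batch_means``, and the chosen sizes are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic activations of a single hidden unit
population = rng.normal(size=100_000)

def spread_of_batch_means(batch_size, num_batches=2000):
    """Standard deviation of the minibatch mean across many minibatches."""
    batches = rng.choice(population, size=(num_batches, batch_size))
    return batches.mean(axis=1).std()

spreads = {b: spread_of_batch_means(b) for b in (2, 64, 1024)}
```

The spread shrinks roughly in proportion to the inverse square root of the minibatch size, so large minibatches inject little noise while tiny ones inject a lot. The sketch only illustrates this trend and does not by itself identify a best minibatch size.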
Fixing a trained model, you might think that we would prefer using the
entire dataset to estimate the mean and variance. Once training is
complete, why would we want the same image to be classified differently,
depending on the batch in which it happens to reside? During training,
such exact calculation is infeasible because the intermediate variables
for all data examples change every time we update our model. However,
once the model is trained, we can calculate the means and variances of
each layer’s variables based on the entire dataset. Indeed this is
standard practice for models employing batch normalization; thus batch
normalization layers function differently in *training mode*
(normalizing by minibatch statistics) than in *prediction mode*
(normalizing by dataset statistics). In this form they closely resemble
the behavior of dropout regularization of :numref:`sec_dropout`, where
noise is only injected during training.
Batch Normalization Layers
--------------------------
Batch normalization implementations for fully connected layers and
convolutional layers are slightly different. One key difference between
batch normalization and other layers is that because the former operates
on a full minibatch at a time, we cannot just ignore the batch dimension
as we did before when introducing other layers.
Fully Connected Layers
~~~~~~~~~~~~~~~~~~~~~~
When applying batch normalization to fully connected layers,
:cite:t:`Ioffe.Szegedy.2015`, in their original paper, inserted batch
normalization after the affine transformation and *before* the nonlinear
activation function. Later applications experimented with inserting
batch normalization right *after* activation functions. Denoting the
input to the fully connected layer by :math:`\mathbf{x}`, the affine
transformation by :math:`\mathbf{W}\mathbf{x} + \mathbf{b}` (with the
weight parameter :math:`\mathbf{W}` and the bias parameter
:math:`\mathbf{b}`), and the activation function by :math:`\phi`, we can
express the computation of a batch-normalization-enabled, fully
connected layer output :math:`\mathbf{h}` as follows:
.. math:: \mathbf{h} = \phi(\textrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ).
Recall that mean and variance are computed on the *same* minibatch on
which the transformation is applied.
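The computation above can be sketched in NumPy as follows. We use the row-vector convention ``X @ W + b`` for a whole minibatch, a sigmoid as the activation :math:`\phi`, and random placeholder values. The helper names ``bn`` and ``sigmoid`` are our own.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bn(Z, gamma, beta, eps=1e-5):
    """Training-mode batch normalization over the batch axis."""
    mu = Z.mean(axis=0)
    var = ((Z - mu) ** 2).mean(axis=0)
    return gamma * (Z - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))   # minibatch: 32 examples, 5 inputs
W = rng.normal(size=(5, 3))    # weights producing 3 outputs
b = rng.normal(size=3)         # bias
gamma, beta = np.ones(3), np.zeros(3)

# Affine transformation, then batch normalization, then activation
H = sigmoid(bn(X @ W + b, gamma, beta))
```

One observation from this sketch: because the minibatch mean is subtracted, the bias ``b`` has no effect on the output, and the shift parameter ``beta`` takes over its role.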
Convolutional Layers
~~~~~~~~~~~~~~~~~~~~
Similarly, with convolutional layers, we can apply batch normalization
after the convolution but before the nonlinear activation function. The
key difference from batch normalization in fully connected layers is
that we apply the operation on a per-channel basis *across all
locations*. This is compatible with our assumption of translation
invariance that led to convolutions: we assumed that the specific
location of a pattern within an image was not critical for the purpose
of understanding.
Assume that our minibatches contain :math:`m` examples and that for each
channel, the output of the convolution has height :math:`p` and width
:math:`q`. For convolutional layers, we carry out each batch
normalization over the :math:`m \cdot p \cdot q` elements per output
channel simultaneously. Thus, we collect the values over all spatial
locations when computing the mean and variance and consequently apply
the same mean and variance within a given channel to normalize the value
at each spatial location. Each channel has its own scale and shift
parameters, both of which are scalars.
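The per-channel reduction can be sketched in NumPy as follows, assuming a channels-first layout of shape :math:`(m, c, p, q)`. The helper ``bn_conv`` and the synthetic tensor are illustrative assumptions.

```python
import numpy as np

def bn_conv(X, gamma, beta, eps=1e-5):
    """Training-mode batch normalization for X of shape (m, channels, p, q)."""
    # Reduce over the batch axis and both spatial axes: m * p * q values
    # contribute to the statistics of each channel
    mean = X.mean(axis=(0, 2, 3), keepdims=True)
    var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    return gamma * (X - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
m, c, p, q = 8, 3, 4, 4
X = rng.normal(loc=2.0, scale=5.0, size=(m, c, p, q))

# One scalar scale and one scalar shift per channel
gamma = np.ones((1, c, 1, 1))
beta = np.zeros((1, c, 1, 1))
Y = bn_conv(X, gamma, beta)

stats_shape = X.mean(axis=(0, 2, 3), keepdims=True).shape
```

The statistics have shape ``(1, c, 1, 1)``, i.e., one mean and one variance per channel, which broadcast over the batch and all spatial locations.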
.. _subsec_layer-normalization-in-bn:
Layer Normalization
~~~~~~~~~~~~~~~~~~~
Note that in the context of convolutions the batch normalization is well
defined even for minibatches of size 1: after all, we have all the
locations across an image to average. Consequently, mean and variance
are well defined, even if it is just within a single observation. This
consideration led :cite:t:`Ba.Kiros.Hinton.2016` to introduce the
notion of *layer normalization*. It works just like batch normalization,
except that it is applied to one observation at a time. Consequently both
the offset and the scaling factor are scalars. For an :math:`n`-dimensional
vector :math:`\mathbf{x}`, layer normalization is given by
.. math:: \mathbf{x} \rightarrow \textrm{LN}(\mathbf{x}) = \frac{\mathbf{x} - \hat{\mu}}{\hat\sigma},
where scaling and offset are applied coefficient-wise and given by
.. math::
\hat{\mu} \stackrel{\textrm{def}}{=} \frac{1}{n} \sum_{i=1}^n x_i \textrm{ and }
\hat{\sigma}^2 \stackrel{\textrm{def}}{=} \frac{1}{n} \sum_{i=1}^n (x_i - \hat{\mu})^2 + \epsilon.
As before we add a small offset :math:`\epsilon > 0` to prevent division
by zero. One of the major benefits of using layer normalization is that
it prevents divergence. After all, ignoring :math:`\epsilon`, the output
of the layer normalization is scale independent. That is, we have
:math:`\textrm{LN}(\mathbf{x}) \approx \textrm{LN}(\alpha \mathbf{x})`
for any choice of :math:`\alpha > 0`; a negative :math:`\alpha` merely
flips the sign of the output. The equality is only approximate because of
the offset :math:`\epsilon` for the variance, whose influence vanishes as
:math:`\alpha \to \infty`.
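A quick numerical check of this scale independence, using a NumPy sketch with our own helper ``layer_norm`` and an arbitrary example vector, might look as follows. We restrict ourselves to positive scaling factors.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Layer normalization of a single n-dimensional vector."""
    mu = x.mean()
    var = ((x - mu) ** 2).mean() + eps
    return (x - mu) / np.sqrt(var)

x = np.array([1.0, 2.0, 4.0, 8.0])  # arbitrary example vector
base = layer_norm(x)

# Largest elementwise deviation from LN(x) after rescaling the input
gaps = {alpha: np.abs(layer_norm(alpha * x) - base).max()
        for alpha in (0.01, 1.0, 100.0)}
```

The deviation is negligible for the large scaling factor and becomes noticeable only when the input is shrunk so far that its variance is comparable to ``eps``.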
Another advantage of the layer normalization is that it does not depend
on the minibatch size. It is also independent of whether we are in
training or test regime. In other words, it is simply a deterministic
transformation that standardizes the activations to a given scale. This
can be very beneficial in preventing divergence in optimization. We skip
further details and recommend that interested readers consult the
original paper.
Batch Normalization During Prediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
As we mentioned earlier, batch normalization typically behaves
differently in training mode than in prediction mode. First, the noise
in the sample mean and the sample variance arising from estimating each
on minibatches is no longer desirable once we have trained the model.
Second, we might not have the luxury of computing per-batch
normalization statistics. For example, we might need to apply our model
to make one prediction at a time.
Typically, after training, we use the entire dataset to compute stable
estimates of the variable statistics and then fix them at prediction
time. Hence, batch normalization behaves differently during training
than at test time. Recall that dropout also exhibits this
characteristic.
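The following NumPy sketch illustrates prediction mode under simplifying assumptions: the statistics are computed once from a synthetic "training set" and then held fixed. The helper ``bn_predict`` and all data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic inputs to a batch normalization layer over the whole dataset
train = rng.normal(loc=10.0, scale=2.0, size=(1000, 3))

# Statistics estimated once and fixed afterwards
mean, var = train.mean(axis=0), train.var(axis=0)

def bn_predict(X, gamma, beta, mean, var, eps=1e-5):
    """Batch normalization with fixed statistics (prediction mode)."""
    return gamma * (X - mean) / np.sqrt(var + eps) + beta

gamma, beta = np.ones(3), np.zeros(3)

# The same example, processed alone and as part of a minibatch of 16
alone = bn_predict(train[:1], gamma, beta, mean, var)
in_batch = bn_predict(train[:16], gamma, beta, mean, var)[:1]
```

With fixed statistics the output for an example no longer depends on which other examples accompany it, and a minibatch of size 1 poses no problem.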
Implementation from Scratch
---------------------------
To see how batch normalization works in practice, we implement one from
scratch below.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
       # Use is_grad_enabled to determine whether we are in training mode
       if not torch.is_grad_enabled():
           # In prediction mode, use mean and variance obtained by moving average
           X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
       else:
           assert len(X.shape) in (2, 4)
           if len(X.shape) == 2:
               # When using a fully connected layer, calculate the mean and
               # variance on the feature dimension
               mean = X.mean(dim=0)
               var = ((X - mean) ** 2).mean(dim=0)
           else:
               # When using a two-dimensional convolutional layer, calculate the
               # mean and variance on the channel dimension (axis=1). Here we
               # need to maintain the shape of X, so that the broadcasting
               # operation can be carried out later
               mean = X.mean(dim=(0, 2, 3), keepdim=True)
               var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
           # In training mode, the current mean and variance are used
           X_hat = (X - mean) / torch.sqrt(var + eps)
           # Update the mean and variance using moving average
           moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
           moving_var = (1.0 - momentum) * moving_var + momentum * var
       Y = gamma * X_hat + beta  # Scale and shift
       return Y, moving_mean.data, moving_var.data
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
       # Use autograd to determine whether we are in training mode
       if not autograd.is_training():
           # In prediction mode, use mean and variance obtained by moving average
           X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
       else:
           assert len(X.shape) in (2, 4)
           if len(X.shape) == 2:
               # When using a fully connected layer, calculate the mean and
               # variance on the feature dimension
               mean = X.mean(axis=0)
               var = ((X - mean) ** 2).mean(axis=0)
           else:
               # When using a two-dimensional convolutional layer, calculate the
               # mean and variance on the channel dimension (axis=1). Here we
               # need to maintain the shape of X, so that the broadcasting
               # operation can be carried out later
               mean = X.mean(axis=(0, 2, 3), keepdims=True)
               var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
           # In training mode, the current mean and variance are used
           X_hat = (X - mean) / np.sqrt(var + eps)
           # Update the mean and variance using moving average
           moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
           moving_var = (1.0 - momentum) * moving_var + momentum * var
       Y = gamma * X_hat + beta  # Scale and shift
       return Y, moving_mean, moving_var
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   def batch_norm(X, deterministic, gamma, beta, moving_mean, moving_var, eps,
                  momentum):
       # Use `deterministic` to determine whether the current mode is training
       # mode or prediction mode
       if deterministic:
           # In prediction mode, use mean and variance obtained by moving average
           # `linen.Module.variables` have a `value` attribute containing the array
           X_hat = (X - moving_mean.value) / jnp.sqrt(moving_var.value + eps)
       else:
           assert len(X.shape) in (2, 4)
           if len(X.shape) == 2:
               # When using a fully connected layer, calculate the mean and
               # variance on the feature dimension
               mean = X.mean(axis=0)
               var = ((X - mean) ** 2).mean(axis=0)
           else:
               # When using a two-dimensional convolutional layer, calculate the
               # mean and variance per channel. Flax convolutions are
               # channels-last (axis=3), so we reduce over axes (0, 1, 2). Here
               # we need to maintain the shape of `X`, so that the broadcasting
               # operation can be carried out later
               mean = X.mean(axis=(0, 1, 2), keepdims=True)
               var = ((X - mean) ** 2).mean(axis=(0, 1, 2), keepdims=True)
           # In training mode, the current mean and variance are used
           X_hat = (X - mean) / jnp.sqrt(var + eps)
           # Update the mean and variance using moving average
           moving_mean.value = momentum * moving_mean.value + (1.0 - momentum) * mean
           moving_var.value = momentum * moving_var.value + (1.0 - momentum) * var
       Y = gamma * X_hat + beta  # Scale and shift
       return Y
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   def batch_norm(X, gamma, beta, moving_mean, moving_var, eps):
       # Compute reciprocal of square root of the moving variance elementwise
       inv = tf.cast(tf.math.rsqrt(moving_var + eps), X.dtype)
       # Scale and shift
       inv *= gamma
       Y = X * inv + (beta - moving_mean * inv)
       return Y
.. raw:: html
.. raw:: html
We can now create a proper ``BatchNorm`` layer. Our layer will maintain
proper parameters for scale ``gamma`` and shift ``beta``, both of which
will be updated in the course of training. Additionally, our layer will
maintain moving averages of the means and variances for subsequent use
during model prediction.
Putting aside the algorithmic details, note the design pattern
underlying our implementation of the layer. Typically, we define the
mathematics in a separate function, say ``batch_norm``. We then
integrate this functionality into a custom layer, whose code mostly
addresses bookkeeping matters, such as moving data to the right device
context, allocating and initializing any required variables, keeping
track of moving averages (here for mean and variance), and so on. This
pattern enables a clean separation of mathematics from boilerplate code.
Also note that for the sake of convenience we did not worry about
automatically inferring the input shape here; thus we need to specify
the number of features throughout. By now all modern deep learning
frameworks offer automatic detection of size and shape in the high-level
batch normalization APIs (in practice we will use this instead).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BatchNorm(nn.Module):
       # num_features: the number of outputs for a fully connected layer or the
       # number of output channels for a convolutional layer. num_dims: 2 for a
       # fully connected layer and 4 for a convolutional layer
       def __init__(self, num_features, num_dims):
           super().__init__()
           if num_dims == 2:
               shape = (1, num_features)
           else:
               shape = (1, num_features, 1, 1)
           # The scale parameter and the shift parameter (model parameters) are
           # initialized to 1 and 0, respectively
           self.gamma = nn.Parameter(torch.ones(shape))
           self.beta = nn.Parameter(torch.zeros(shape))
           # The variables that are not model parameters are initialized to 0 and
           # 1
           self.moving_mean = torch.zeros(shape)
           self.moving_var = torch.ones(shape)

       def forward(self, X):
           # If X is not on the main memory, copy moving_mean and moving_var to
           # the device where X is located
           if self.moving_mean.device != X.device:
               self.moving_mean = self.moving_mean.to(X.device)
               self.moving_var = self.moving_var.to(X.device)
           # Save the updated moving_mean and moving_var
           Y, self.moving_mean, self.moving_var = batch_norm(
               X, self.gamma, self.beta, self.moving_mean,
               self.moving_var, eps=1e-5, momentum=0.1)
           return Y
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BatchNorm(nn.Block):
       # `num_features`: the number of outputs for a fully connected layer
       # or the number of output channels for a convolutional layer. `num_dims`:
       # 2 for a fully connected layer and 4 for a convolutional layer
       def __init__(self, num_features, num_dims, **kwargs):
           super().__init__(**kwargs)
           if num_dims == 2:
               shape = (1, num_features)
           else:
               shape = (1, num_features, 1, 1)
           # The scale parameter and the shift parameter (model parameters) are
           # initialized to 1 and 0, respectively
           self.gamma = self.params.get('gamma', shape=shape, init=init.One())
           self.beta = self.params.get('beta', shape=shape, init=init.Zero())
           # The variables that are not model parameters are initialized to 0 and
           # 1
           self.moving_mean = np.zeros(shape)
           self.moving_var = np.ones(shape)

       def forward(self, X):
           # If `X` is not on the main memory, copy `moving_mean` and
           # `moving_var` to the device where `X` is located
           if self.moving_mean.ctx != X.ctx:
               self.moving_mean = self.moving_mean.copyto(X.ctx)
               self.moving_var = self.moving_var.copyto(X.ctx)
           # Save the updated `moving_mean` and `moving_var`
           Y, self.moving_mean, self.moving_var = batch_norm(
               X, self.gamma.data(), self.beta.data(), self.moving_mean,
               self.moving_var, eps=1e-12, momentum=0.1)
           return Y
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BatchNorm(nn.Module):
       # `num_features`: the number of outputs for a fully connected layer
       # or the number of output channels for a convolutional layer.
       # `num_dims`: 2 for a fully connected layer and 4 for a convolutional layer
       # Use `deterministic` to determine whether the current mode is training
       # mode or prediction mode
       num_features: int
       num_dims: int
       deterministic: bool = False

       @nn.compact
       def __call__(self, X):
           if self.num_dims == 2:
               shape = (1, self.num_features)
           else:
               shape = (1, 1, 1, self.num_features)
           # The scale parameter and the shift parameter (model parameters) are
           # initialized to 1 and 0, respectively
           gamma = self.param('gamma', jax.nn.initializers.ones, shape)
           beta = self.param('beta', jax.nn.initializers.zeros, shape)
           # The variables that are not model parameters are initialized to 0 and
           # 1. Save them to the 'batch_stats' collection
           moving_mean = self.variable('batch_stats', 'moving_mean', jnp.zeros, shape)
           moving_var = self.variable('batch_stats', 'moving_var', jnp.ones, shape)
           Y = batch_norm(X, self.deterministic, gamma, beta,
                          moving_mean, moving_var, eps=1e-5, momentum=0.9)
           return Y
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BatchNorm(tf.keras.layers.Layer):
       def __init__(self, **kwargs):
           super(BatchNorm, self).__init__(**kwargs)

       def build(self, input_shape):
           weight_shape = [input_shape[-1], ]
           # The scale parameter and the shift parameter (model parameters) are
           # initialized to 1 and 0, respectively
           self.gamma = self.add_weight(name='gamma', shape=weight_shape,
               initializer=tf.initializers.ones, trainable=True)
           self.beta = self.add_weight(name='beta', shape=weight_shape,
               initializer=tf.initializers.zeros, trainable=True)
           # The variables that are not model parameters are initialized to 0
           # (moving mean) and 1 (moving variance)
           self.moving_mean = self.add_weight(name='moving_mean',
               shape=weight_shape, initializer=tf.initializers.zeros,
               trainable=False)
           self.moving_variance = self.add_weight(name='moving_variance',
               shape=weight_shape, initializer=tf.initializers.ones,
               trainable=False)
           super(BatchNorm, self).build(input_shape)

       def assign_moving_average(self, variable, value):
           momentum = 0.1
           delta = (1.0 - momentum) * variable + momentum * value
           return variable.assign(delta)

       @tf.function
       def call(self, inputs, training):
           if training:
               axes = list(range(len(inputs.shape) - 1))
               batch_mean = tf.reduce_mean(inputs, axes, keepdims=True)
               batch_variance = tf.reduce_mean(tf.math.squared_difference(
                   inputs, tf.stop_gradient(batch_mean)), axes, keepdims=True)
               batch_mean = tf.squeeze(batch_mean, axes)
               batch_variance = tf.squeeze(batch_variance, axes)
               mean_update = self.assign_moving_average(
                   self.moving_mean, batch_mean)
               variance_update = self.assign_moving_average(
                   self.moving_variance, batch_variance)
               self.add_update(mean_update)
               self.add_update(variance_update)
               mean, variance = batch_mean, batch_variance
           else:
               mean, variance = self.moving_mean, self.moving_variance
           output = batch_norm(inputs, moving_mean=mean, moving_var=variance,
               beta=self.beta, gamma=self.gamma, eps=1e-5)
           return output
.. raw:: html
.. raw:: html
We used ``momentum`` to govern the aggregation over past mean and
variance estimates. This is somewhat of a misnomer as it has nothing
whatsoever to do with the *momentum* term of optimization. Nonetheless,
it is the commonly adopted name for this term and in deference to API
naming convention we use the same variable name in our code.
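To see what this aggregation does, consider the following plain Python sketch of the update rule used in the PyTorch, MXNet, and TensorFlow variants above, where ``momentum`` is the weight given to the *new* minibatch statistic. The helper ``update_moving`` and the constant statistic 5.0 are illustrative assumptions.

```python
def update_moving(moving, batch_stat, momentum=0.1):
    """One exponential-moving-average step; momentum weights the new value."""
    return (1.0 - momentum) * moving + momentum * batch_stat

moving_mean = 0.0  # initial value, as in the layers above
history = []
for _ in range(50):
    # Pretend every minibatch reports the same mean of 5.0
    moving_mean = update_moving(moving_mean, 5.0)
    history.append(moving_mean)
```

The moving average starts at the initial value and approaches the minibatch statistic geometrically, so the initialization is gradually forgotten. Note that the JAX variant above uses the opposite convention: there ``momentum`` (set to 0.9) is the weight of the *old* value, which yields the same update as ``momentum=0.1`` here.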
LeNet with Batch Normalization
------------------------------
To see how to apply ``BatchNorm`` in context, below we apply it to a
traditional LeNet model (:numref:`sec_lenet`). Recall that batch
normalization is applied after the convolutional layers or fully
connected layers but before the corresponding activation functions.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNetScratch(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = nn.Sequential(
               nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),
               nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
               nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),
               nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
               nn.Flatten(), nn.LazyLinear(120),
               BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),
               BatchNorm(84, num_dims=2), nn.Sigmoid(),
               nn.LazyLinear(num_classes))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNetScratch(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = nn.Sequential()
           self.net.add(
               nn.Conv2D(6, kernel_size=5), BatchNorm(6, num_dims=4),
               nn.Activation('sigmoid'),
               nn.AvgPool2D(pool_size=2, strides=2),
               nn.Conv2D(16, kernel_size=5), BatchNorm(16, num_dims=4),
               nn.Activation('sigmoid'),
               nn.AvgPool2D(pool_size=2, strides=2), nn.Dense(120),
               BatchNorm(120, num_dims=2), nn.Activation('sigmoid'),
               nn.Dense(84), BatchNorm(84, num_dims=2),
               nn.Activation('sigmoid'), nn.Dense(num_classes))
           self.initialize()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNetScratch(d2l.Classifier):
       lr: float = 0.1
       num_classes: int = 10
       training: bool = True

       def setup(self):
           self.net = nn.Sequential([
               nn.Conv(6, kernel_size=(5, 5)),
               BatchNorm(6, num_dims=4, deterministic=not self.training),
               nn.sigmoid,
               lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
               nn.Conv(16, kernel_size=(5, 5)),
               BatchNorm(16, num_dims=4, deterministic=not self.training),
               nn.sigmoid,
               lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
               lambda x: x.reshape((x.shape[0], -1)),
               nn.Dense(120),
               BatchNorm(120, num_dims=2, deterministic=not self.training),
               nn.sigmoid,
               nn.Dense(84),
               BatchNorm(84, num_dims=2, deterministic=not self.training),
               nn.sigmoid,
               nn.Dense(self.num_classes)])
Since ``BatchNorm`` layers need to calculate the batch statistics (mean
and variance), Flax keeps track of them in the ``batch_stats`` collection,
updating it with every minibatch. Collections like ``batch_stats`` can
be stored in the ``TrainState`` object (in the ``d2l.Trainer`` class
defined in :numref:`oo-design-training`) as an attribute. During
the model’s forward pass, the collection name should be passed to the
``mutable`` argument, so that Flax returns the mutated variables.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   @d2l.add_to_class(d2l.Classifier)  #@save
   @partial(jax.jit, static_argnums=(0, 5))
   def loss(self, params, X, Y, state, averaged=True):
       Y_hat, updates = state.apply_fn({'params': params,
                                        'batch_stats': state.batch_stats},
                                       *X, mutable=['batch_stats'],
                                       rngs={'dropout': state.dropout_rng})
       Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
       Y = Y.reshape((-1,))
       fn = optax.softmax_cross_entropy_with_integer_labels
       return (fn(Y_hat, Y).mean(), updates) if averaged else (fn(Y_hat, Y), updates)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNetScratch(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = tf.keras.models.Sequential([
               tf.keras.layers.Conv2D(filters=6, kernel_size=5,
                                      input_shape=(28, 28, 1)),
               BatchNorm(), tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
               tf.keras.layers.Conv2D(filters=16, kernel_size=5),
               BatchNorm(), tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
               tf.keras.layers.Flatten(), tf.keras.layers.Dense(120),
               BatchNorm(), tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.Dense(84), BatchNorm(),
               tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.Dense(num_classes)])
.. raw:: html
.. raw:: html
As before, we will train our network on the Fashion-MNIST dataset. This
code is virtually identical to the code we used when we first trained LeNet.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_65_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_68_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_71_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   trainer = d2l.Trainer(max_epochs=10)
   data = d2l.FashionMNIST(batch_size=128)
   with d2l.try_gpu():
       model = BNLeNetScratch(lr=0.5)
       trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_74_0.svg
.. raw:: html
.. raw:: html
Let’s have a look at the scale parameter ``gamma`` and the shift
parameter ``beta`` learned from the first batch normalization layer.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.net[1].gamma.reshape((-1,)), model.net[1].beta.reshape((-1,))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([1.4334, 1.9905, 1.8584, 2.0740, 2.0522, 1.8877], device='cuda:0',
grad_fn=),
tensor([ 0.7354, -1.3538, -0.2567, -0.9991, -0.3028, 1.3125], device='cuda:0',
grad_fn=))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.net[1].gamma.data().reshape(-1,), model.net[1].beta.data().reshape(-1,)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([2.130113 , 1.560813 , 1.461431 , 1.9807949, 2.2318861, 1.551563 ], ctx=gpu(0)),
array([ 1.2277379 , 1.514598 , -1.014917 , 0.19028394, 0.6355166 ,
0.5642359 ], ctx=gpu(0)))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer.state.params['net']['layers_1']['gamma'].reshape((-1,)), \
trainer.state.params['net']['layers_1']['beta'].reshape((-1,))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([1.6042372, 2.336828 , 1.6758277, 2.621987 , 1.1613046, 2.3155787], dtype=float32),
Array([-0.21946815, 0.64822614, -0.3630768 , 0.88309413, -0.04114898,
-0.610043 ], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.reshape(model.net.layers[1].gamma, (-1,)), tf.reshape(
model.net.layers[1].beta, (-1,))
.. raw:: html
.. raw:: html
Concise Implementation
----------------------
Instead of the ``BatchNorm`` class that we just defined ourselves, we
can directly use the ``BatchNorm`` class provided by the high-level APIs
of the deep learning framework. The code looks virtually identical to
our implementation above, except that we no longer need to provide
additional arguments for it to get the dimensions right.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNet(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = nn.Sequential(
               nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),
               nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
               nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),
               nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
               nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),
               nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),
               nn.Sigmoid(), nn.LazyLinear(num_classes))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNet(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = nn.Sequential()
           self.net.add(
               nn.Conv2D(6, kernel_size=5), nn.BatchNorm(),
               nn.Activation('sigmoid'),
               nn.AvgPool2D(pool_size=2, strides=2),
               nn.Conv2D(16, kernel_size=5), nn.BatchNorm(),
               nn.Activation('sigmoid'),
               nn.AvgPool2D(pool_size=2, strides=2),
               nn.Dense(120), nn.BatchNorm(), nn.Activation('sigmoid'),
               nn.Dense(84), nn.BatchNorm(), nn.Activation('sigmoid'),
               nn.Dense(num_classes))
           self.initialize()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNet(d2l.Classifier):
       lr: float = 0.1
       num_classes: int = 10
       training: bool = True

       def setup(self):
           self.net = nn.Sequential([
               nn.Conv(6, kernel_size=(5, 5)),
               nn.BatchNorm(not self.training),
               nn.sigmoid,
               lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
               nn.Conv(16, kernel_size=(5, 5)),
               nn.BatchNorm(not self.training),
               nn.sigmoid,
               lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
               lambda x: x.reshape((x.shape[0], -1)),
               nn.Dense(120),
               nn.BatchNorm(not self.training),
               nn.sigmoid,
               nn.Dense(84),
               nn.BatchNorm(not self.training),
               nn.sigmoid,
               nn.Dense(self.num_classes)])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   class BNLeNet(d2l.Classifier):
       def __init__(self, lr=0.1, num_classes=10):
           super().__init__()
           self.save_hyperparameters()
           self.net = tf.keras.models.Sequential([
               tf.keras.layers.Conv2D(filters=6, kernel_size=5,
                                      input_shape=(28, 28, 1)),
               tf.keras.layers.BatchNormalization(),
               tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
               tf.keras.layers.Conv2D(filters=16, kernel_size=5),
               tf.keras.layers.BatchNormalization(),
               tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.AvgPool2D(pool_size=2, strides=2),
               tf.keras.layers.Flatten(), tf.keras.layers.Dense(120),
               tf.keras.layers.BatchNormalization(),
               tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.Dense(84),
               tf.keras.layers.BatchNormalization(),
               tf.keras.layers.Activation('sigmoid'),
               tf.keras.layers.Dense(num_classes)])
.. raw:: html
.. raw:: html
Below, we use the same hyperparameters to train our model. Note that as
usual, the high-level API variant runs much faster because its code has
been compiled to C++ or CUDA while our custom implementation must be
interpreted by Python.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_110_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_113_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_116_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

   trainer = d2l.Trainer(max_epochs=10)
   data = d2l.FashionMNIST(batch_size=128)
   with d2l.try_gpu():
       model = BNLeNet(lr=0.5)
       trainer.fit(model, data)
.. figure:: output_batch-norm_cf033c_119_0.svg
.. raw:: html
.. raw:: html
Discussion
----------
Intuitively, batch normalization is thought to make the optimization
landscape smoother. However, we must be careful to distinguish between
speculative intuitions and true explanations for the phenomena that we
observe when training deep models. Recall that we do not even know why
simpler deep neural networks (MLPs and conventional CNNs) generalize
well in the first place. Even with dropout and weight decay, they remain
so flexible that explaining their ability to generalize to unseen data
likely requires significantly more refined learning-theoretic guarantees.
The original paper proposing batch normalization
:cite:`Ioffe.Szegedy.2015`, in addition to introducing a powerful and
useful tool, offered an explanation for why it works: by reducing
*internal covariate shift*. Presumably by *internal covariate shift*
they meant something like the intuition expressed above—the notion that
the distribution of variable values changes over the course of training.
However, there were two problems with this explanation: i) This drift is
very different from *covariate shift*, rendering the name a misnomer. If
anything, it is closer to concept drift. ii) The explanation offers an
under-specified intuition but leaves open the question of *why precisely
this technique works*, which still awaits a rigorous explanation.
Throughout this book, we aim to convey the intuitions that practitioners
use to guide their development of deep neural networks. However, we
believe that it is important to separate these guiding intuitions from
established scientific fact. Eventually, when you master this material
and start writing your own research papers, you will want to clearly
delineate between technical claims and hunches.
Following the success of batch normalization, its explanation in terms
of *internal covariate shift* has repeatedly surfaced in debates in the
technical literature and broader discourse about how to present machine
learning research. In a memorable speech given while accepting a Test of
Time Award at the 2017 NeurIPS conference, Ali Rahimi used *internal
covariate shift* as a focal point in an argument likening the modern
practice of deep learning to alchemy. Subsequently, the example was
revisited in detail in a position paper outlining troubling trends in
machine learning :cite:`Lipton.Steinhardt.2018`. Other authors have
proposed alternative explanations for the success of batch
normalization, some :cite:`Santurkar.Tsipras.Ilyas.ea.2018` claiming
that batch normalization’s success comes despite exhibiting behavior
that is in some ways opposite to that claimed in the original paper.
We note that *internal covariate shift* is no more worthy of
criticism than any of thousands of similarly vague claims made every
year in the technical machine learning literature. Likely, its resonance
as a focal point of these debates owes to its broad recognizability to
the target audience. Batch normalization has proven an indispensable
method, applied in nearly all deployed image classifiers, earning the
paper that introduced the technique tens of thousands of citations. We
conjecture, though, that the guiding principles of regularization
through noise injection, acceleration through rescaling and lastly
preprocessing may well lead to further inventions of layers and
techniques in the future.
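The remark about regularization through noise injection can be made
concrete with a small experiment. The following is a minimal NumPy
sketch, not taken from any of the frameworks above; the helper
``normalize_minibatch``, the minibatch size of 8, and the example values
are assumptions chosen purely for illustration. It places one fixed
example into several random minibatches and standardizes each minibatch
with its own statistics, as batch normalization does in training mode
(the learnable scale and shift are omitted).

```python
import numpy as np

def normalize_minibatch(X, eps=1e-5):
    """Standardize each feature using the statistics of the minibatch X.

    Hypothetical helper for illustration only.
    """
    mean = X.mean(axis=0)
    var = X.var(axis=0)
    return (X - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = np.array([1.0, -2.0, 0.5])  # one fixed example with three features

outputs = []
for _ in range(5):
    # Put the same example into a fresh minibatch with 7 random companions
    companions = rng.normal(size=(7, 3))
    X = np.vstack([x, companions])
    # Keep only the normalized version of our fixed example (row 0)
    outputs.append(normalize_minibatch(X)[0])
outputs = np.stack(outputs)

# Spread of the normalized values of the *same* example across minibatches
spread = outputs.std(axis=0)
print(outputs)
print(spread)
```

Because the mean and variance are estimated from whichever examples
happen to share the minibatch, the same example is mapped to a slightly
different value each time. With larger minibatches the estimates become
more stable and this source of noise shrinks.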
On a more practical note, there are a number of aspects worth
remembering about batch normalization:
- During model training, batch normalization continuously adjusts the
  intermediate output of the network by utilizing the mean and standard
  deviation of the minibatch, so that the values of the intermediate
  output in each layer throughout the neural network are more stable.
- Batch normalization is slightly different for fully connected layers
  than for convolutional layers. In fact, for convolutional layers,
  layer normalization can sometimes be used as an alternative.
- Like a dropout layer, batch normalization layers have different
  behaviors in training mode than in prediction mode.
- Batch normalization is useful for regularization and improving
  convergence in optimization. By contrast, the original motivation of
  reducing internal covariate shift seems not to be a valid
  explanation.
- For more robust models that are less sensitive to input
  perturbations, consider removing batch normalization
  :cite:`wang2022removing`.
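To illustrate the difference between training mode and prediction mode
noted in the list above, here is a minimal NumPy sketch for inputs of
shape (batch size, number of features). It is not the API of any
framework: the function name ``batch_norm_1d``, the ``momentum`` of 0.1,
and the ``eps`` of 1e-5 are illustrative assumptions, and the
moving-average update shown is just one common convention.

```python
import numpy as np

def batch_norm_1d(X, gamma, beta, moving_mean, moving_var,
                  training, momentum=0.1, eps=1e-5):
    """Batch normalization for X of shape (batch_size, num_features).

    Hypothetical helper for illustration only.
    """
    if training:
        # Training mode: normalize with the statistics of the current minibatch
        mean = X.mean(axis=0)
        var = ((X - mean) ** 2).mean(axis=0)
        X_hat = (X - mean) / np.sqrt(var + eps)
        # Update the moving averages that prediction mode relies on
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    else:
        # Prediction mode: normalize with the accumulated moving averages
        X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
    # Scale and shift with the learnable parameters
    return gamma * X_hat + beta, moving_mean, moving_var

rng = np.random.default_rng(0)
num_features = 3
gamma, beta = np.ones(num_features), np.zeros(num_features)
moving_mean, moving_var = np.zeros(num_features), np.ones(num_features)

# Simulate training on minibatches drawn from a normal distribution
# with mean 5 and standard deviation 2
for _ in range(200):
    X = rng.normal(loc=5.0, scale=2.0, size=(128, num_features))
    Y_train, moving_mean, moving_var = batch_norm_1d(
        X, gamma, beta, moving_mean, moving_var, training=True)

# In prediction mode even a single example can be normalized, and the
# result does not depend on the other examples in the minibatch
x = np.full((1, num_features), 5.0)
Y_pred, _, _ = batch_norm_1d(
    x, gamma, beta, moving_mean, moving_var, training=False)
```

In training mode each output feature has (almost exactly) zero mean and
unit variance within the minibatch, whereas in prediction mode the
output is a fixed function of the input determined by the moving
averages gathered during training.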
Exercises
---------
1. Should we remove the bias parameter from the fully connected layer or
   the convolutional layer before the batch normalization? Why?
2. Compare the learning rates for LeNet with and without batch
   normalization.

   1. Plot the increase in validation accuracy.
   2. How large can you make the learning rate before the optimization
      fails in both cases?

3. Do we need batch normalization in every layer? Experiment with it.
4. Implement a “lite” version of batch normalization that only removes
   the mean, or alternatively one that only removes the variance. How
   does it behave?
5. Fix the parameters ``beta`` and ``gamma``. Observe and analyze the
   results.
6. Can you replace dropout by batch normalization? How does the behavior
   change?
7. Research ideas: think of other normalization transforms that you can
   apply:

   1. Can you apply the probability integral transform?
   2. Can you use a full-rank covariance estimate? Why should you
      probably not do that?
   3. Can you use other compact matrix variants (block-diagonal,
      low-displacement rank, Monarch, etc.)?
   4. Does a sparsification compression act as a regularizer?
   5. Are there other projections (e.g., convex cone, symmetry
      group-specific transforms) that you can use?
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html