Concepts: batch normalization | optimization | vanishing gradients | SGD
Builds on: plain SGD and deep network instability (the problem BN was invented to solve)
Leads to: ResNet (152-layer networks only became stable because of BN)
Every layer is chasing a moving target
Let’s think about what happens inside a deep network as it trains. You update the weights in layer 3. That changes what values layer 4 receives. Layer 4 then updates to handle the new values. But those updates change what layer 5 receives. Layer 5 updates. And so on, all the way down.
From any individual layer’s perspective, its inputs are never stable. Early layers keep changing the distribution of values that deep layers receive — not slowly, but every single batch. Each layer is trying to learn its job while the input format keeps shifting underneath it.
“Training Deep Neural Networks is complicated by the fact that the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change.”
The authors call this internal covariate shift. Covariate shift is a known problem in statistics: when the distribution of inputs changes between training and deployment, models fail. Internal covariate shift is the same problem happening inside the network, between layers, during every training step.
The consequences: you need lower learning rates (large updates on shifting inputs destabilize training quickly), you need careful weight initialization (a bad starting point cascades), and saturating nonlinearities like sigmoid become nightmares (large input values push activations into flat regions where gradients vanish).
“We refer to this phenomenon as internal covariate shift, and address the problem by normalizing layer inputs.”
The fix sounds almost too obvious. If the problem is that layer inputs keep shifting, just stop them from shifting. Normalize the inputs to each layer using the current mini-batch.
The factory quality inspector
Let’s start with the right analogy before touching any equations.
Imagine you are a quality control inspector at a factory. Your job is to decide if a part passes: too small, too large, or acceptable. You are excellent at this job when parts are measured in millimeters. But every hour, a different supplier ships parts in different units: sometimes millimeters, sometimes inches, sometimes centimeters. You cannot develop reliable intuitions because your frame of reference keeps shifting. Some days “5” means tiny, some days it means huge.
Now imagine a standardization station at the entrance to your workstation. Every part, before it reaches you, is converted: subtract the average of today’s batch, divide by the spread. The result is always a normalized score — how far this part is from the batch average in standard deviations. You can now fully focus on learning your actual job, because the input format is stable.
That standardization station is batch normalization.
But there is a subtlety. If you only normalize and stop there, you have stripped the network’s ability to represent non-normalized distributions. A sigmoid neuron that should operate near saturation cannot, because its inputs are always near zero. The network loses expressivity.
Batch normalization adds a second step: after normalizing, apply a learned scale (called gamma) and learned shift (called beta). The network can now undo the normalization if it wants to. If the best representation for a layer is non-normalized, it will learn gamma and beta such that the output recovers the original distribution. Normalization is the default. The network has the override.
The mechanism, step by step
During training, for each mini-batch and each feature:
- Compute the mean of that feature’s values across the batch
- Compute the variance across the batch
- Normalize: subtract the mean, divide by the standard deviation (plus a tiny epsilon to prevent division by zero)
- Rescale and shift: multiply by learned parameter gamma, add learned parameter beta
During inference, you do not have a mini-batch. So during training, batch normalization maintains a running average of the batch means and variances. At test time, those running statistics are used instead of fresh batch statistics. The layer behaves deterministically and consistently.
TRAINING — forward pass through one BN layer:
Raw activations from layer above:
Batch of 4 examples, one feature: [2.0, 4.0, 6.0, 8.0]
|
v
Step 1: compute batch mean
μ_B = 5.0
Step 2: compute batch variance
σ²_B = 5.0
|
v
Step 3: normalize (zero mean, unit variance)
x̂ = (x - μ_B) / sqrt(σ²_B + ε)
→ [-1.34, -0.45, +0.45, +1.34]
|
v
Step 4: scale + shift (γ, β learned via backprop)
y = γ · x̂ + β
→ the distribution the network chose
INFERENCE — no batch, use running stats:
Single example x
↓
Use μ_running, σ²_running accumulated during training
↓
Same formula: y = γ · (x - μ_running) / sqrt(σ²_running + ε) + β
↓
Deterministic output
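The training/inference split above can be sketched as a small NumPy class. This is an illustrative sketch, not the paper's code: the class name, the momentum value, and the exponential-moving-average update for the running statistics are assumptions (frameworks differ in the exact momentum convention).

```python
import numpy as np

class BatchNorm1D:
    """Minimal batch normalization over a (batch, features) array.
    Sketch of the algorithm only; gamma/beta updates via backprop are omitted."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.gamma = np.ones(num_features)     # learned scale (γ)
        self.beta = np.zeros(num_features)     # learned shift (β)
        self.eps = eps
        self.momentum = momentum               # assumed EMA convention
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def forward(self, x, training=True):
        if training:
            mu = x.mean(axis=0)                # Step 1: batch mean
            var = x.var(axis=0)                # Step 2: batch variance
            # Accumulate running statistics for use at inference time
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)   # Step 3: normalize
        return self.gamma * x_hat + self.beta        # Step 4: scale + shift

bn = BatchNorm1D(1)
x = np.array([[2.0], [4.0], [6.0], [8.0]])
y = bn.forward(x, training=True)
# With the default γ = 1, β = 0, y ≈ [-1.34, -0.45, +0.45, +1.34]
```

Calling `forward(x, training=False)` after training uses the accumulated running statistics instead, which is what makes the layer deterministic at test time.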
The four equations
Mean of the mini-batch (m examples). This is the reference point: what is the average activation this batch?

μ_B = (1/m) · Σᵢ xᵢ

Variance of the mini-batch. This is the spread: how much do activations differ from the mean?

σ²_B = (1/m) · Σᵢ (xᵢ - μ_B)²

Normalize. After this step, the activations have mean 0 and variance 1 (approximately). The ε (typically 1e-5) prevents division by zero when the variance collapses.

x̂ᵢ = (xᵢ - μ_B) / sqrt(σ²_B + ε)

Rescale and shift. γ (scale) and β (shift) are learned parameters, one per feature. If the network learns γ = sqrt(σ²_B + ε) and β = μ_B, it exactly undoes the normalization. The parameters are the escape hatch.

yᵢ = γ · x̂ᵢ + β
Walkthrough with actual numbers
Mini-batch of 4 pre-activation values from one neuron: {2.0, 4.0, 6.0, 8.0}.
Step 1 — batch mean:
μ_B = (2.0 + 4.0 + 6.0 + 8.0) / 4 = 5.0
Step 2 — batch variance:
σ²_B = ((2-5)² + (4-5)² + (6-5)² + (8-5)²) / 4 = (9 + 1 + 1 + 9) / 4 = 5.0
Step 3 — normalize (using ε = 1e-5, so sqrt(σ²_B + ε) ≈ sqrt(5.0) ≈ 2.236):
x̂₁ = (2.0 - 5.0) / 2.236 = -1.342
x̂₂ = (4.0 - 5.0) / 2.236 = -0.447
x̂₃ = (6.0 - 5.0) / 2.236 = +0.447
x̂₄ = (8.0 - 5.0) / 2.236 = +1.342
Verify: mean of {-1.342, -0.447, +0.447, +1.342} = 0. Variance ≈ 1.0. Exactly what we want.
Step 4 — scale and shift with γ = 1.5, β = 0.3:
y₁ = 1.5 × (-1.342) + 0.3 = -2.013 + 0.3 = -1.713
y₂ = 1.5 × (-0.447) + 0.3 = -0.671 + 0.3 = -0.371
y₃ = 1.5 × (+0.447) + 0.3 = +0.671 + 0.3 = +0.971
y₄ = 1.5 × (+1.342) + 0.3 = +2.013 + 0.3 = +2.313
The output is now centered around β = 0.3 with spread controlled by γ = 1.5. The next layer receives these values. The mean is 0.3. The variance is γ² = 1.5² = 2.25. Not zero and one — whatever the network learned was useful.
For comparison: without BN, the next layer would have received {2.0, 4.0, 6.0, 8.0} or something with a shifted mean and scale depending on what the previous layer’s weights happened to be at that training step.
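The walkthrough can be checked mechanically. A minimal pure-Python script reproducing the four steps with the same batch, γ, and β:

```python
import math

batch = [2.0, 4.0, 6.0, 8.0]
gamma, beta, eps = 1.5, 0.3, 1e-5       # γ and β as in the walkthrough

mu = sum(batch) / len(batch)                           # Step 1: 5.0
var = sum((v - mu) ** 2 for v in batch) / len(batch)   # Step 2: 5.0
x_hat = [(v - mu) / math.sqrt(var + eps) for v in batch]  # Step 3
y = [gamma * xh + beta for xh in x_hat]                   # Step 4

print([round(v, 2) for v in x_hat])  # [-1.34, -0.45, 0.45, 1.34]
print([round(v, 2) for v in y])      # [-1.71, -0.37, 0.97, 2.31]
```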
What is clever about it
The key instinct that leads to batch normalization: instead of trying to stabilize training by tuning learning rates and initialization, what if you directly controlled the distribution of each layer’s inputs?
Once you have that instinct, the design follows. Normalize by the batch. But normalization alone kills expressivity — so add learnable scale and shift. Make it differentiable so it trains end-to-end. Keep running statistics for inference so the layer is deterministic at test time.
The non-obvious benefit is what it does to gradients. Saturating activations (sigmoid, tanh) produce near-zero gradients in their flat regions. Inputs far from zero push activations into those regions. By keeping inputs near zero, BN keeps activations in the linear region where gradients are useful. It is a partial, practical solution to the vanishing gradient problem.
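To make the saturation point concrete: the sigmoid's gradient is σ'(x) = σ(x)(1 - σ(x)), which peaks at 0.25 at x = 0 and collapses for large |x|. A quick check:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    # Derivative of the sigmoid: σ'(x) = σ(x) · (1 - σ(x))
    s = sigmoid(x)
    return s * (1.0 - s)

print(sigmoid_grad(0.0))   # 0.25, the linear region BN keeps inputs in
print(sigmoid_grad(5.0))   # ≈ 0.0066, saturating
print(sigmoid_grad(10.0))  # ≈ 0.000045, gradient effectively vanished
```

Multiply a few of those near-zero factors across layers and the gradient signal reaching early layers is gone; keeping inputs near zero keeps each factor near its 0.25 peak.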
“Batch Normalization allows us to use much higher learning rates and be less careful about initialization.”
Translation: the optimization landscape becomes smoother. High learning rates that would have caused instability now converge. The sensitivity to initialization — which historically required careful tuning — largely disappears.
There is also a regularization effect that the authors did not fully anticipate:
“Batch Normalization also acts as a regularizer, in some cases eliminating the need for Dropout.”
Translation: because each mini-batch has different statistics, each example is normalized slightly differently depending on which batch it appears in. This stochastic noise acts like a regularizer. Batch normalization introduces a small but consistent source of noise that generalizes similarly to dropout.
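A tiny demonstration of that batch dependence, with made-up numbers: the same activation value 4.0 receives a different normalized value depending on which mini-batch it lands in.

```python
import math

def normalize(batch, eps=1e-5):
    # Plain batch normalization without γ and β, for illustration
    mu = sum(batch) / len(batch)
    var = sum((v - mu) ** 2 for v in batch) / len(batch)
    return [(v - mu) / math.sqrt(var + eps) for v in batch]

# The value 4.0 appears in two different (hypothetical) mini-batches:
batch_a = [2.0, 4.0, 6.0, 8.0]   # mean 5.0 → 4.0 normalizes below zero
batch_b = [1.0, 2.0, 3.0, 4.0]   # mean 2.5 → 4.0 normalizes above zero

print(round(normalize(batch_a)[1], 3))  # -0.447
print(round(normalize(batch_b)[3], 3))  # 1.342
```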
Does it actually work
| Model | ImageNet Top-5 Error | Training Steps |
|---|---|---|
| Inception (baseline) | 6.7% | 31 million |
| BN-Inception | 4.9% | 2.2 million (14x fewer) |
| BN-Inception ensemble | 4.8% (test) | — |
Human rater performance on ImageNet classification: ~5.1%. The BN ensemble exceeded it.
The 14x speedup comes directly from using higher learning rates. The authors describe a configuration they call BN-x30: starting from the baseline learning rate, they multiplied by 30 and used a more aggressive schedule. Without BN, training at 30x the learning rate diverges immediately. With BN, it converges to a better solution in a fraction of the steps.
“Applied to a state-of-the-art image classification model, Batch Normalization achieves the same accuracy with 14 times fewer training steps, and beats the original model by a significant margin.”
What does not work: batch normalization breaks down with small batch sizes. With batch size 1, you compute mean and variance from a single example — the statistics are meaningless. With batch size 2 or 4, the statistics are noisy and training becomes unstable. This is why Layer Normalization (Ba et al., 2016) was invented: it normalizes across the feature dimension rather than the batch dimension, so it works regardless of batch size and is the standard for transformers and language models.

BN also introduces a gap between training and inference behavior (batch statistics vs. running statistics) that can cause subtle bugs if the running statistics are stale or the model is evaluated under distribution shift. And for recurrent networks, BN transfers poorly: sequences have variable length and the statistics differ at each timestep.
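The difference between batch norm and layer norm is just which axis the statistics are computed over: per feature across the batch, versus per example across the features. A minimal sketch on a (batch, features) array, without the learned γ and β for brevity:

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Statistics per feature, across the batch (axis 0)
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # Statistics per example, across the features (axis 1): batch-size independent
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

x = np.array([[1.0, 2.0, 3.0]])   # batch size 1
print(batch_norm(x))              # all zeros: per-feature variance is 0
print(layer_norm(x))              # still meaningful: normalizes within the example
```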
If you are building ML systems
Batch normalization is table stakes for convolutional networks. If you are training a CNN on image data with batch sizes of 16 or larger, use it by default after every convolution and before the nonlinearity. Do not train without it unless you have a specific reason.
The practical decision is mostly about which normalization to use. For images with large batches: BN. For transformers and sequence models: Layer Normalization. For image generation at small batch sizes: Group Normalization. If you are debugging a training run that diverges or trains slowly, check normalization first — missing BN layers or incorrect placement (after vs. before the nonlinearity) is a common source of instability.
The conceptual connection to earlier work: Adam attacks the optimization problem from the gradient side — adaptive learning rates per parameter. Batch normalization attacks it from the activation side — stable input distributions per layer. They address different symptoms of the same underlying challenge: that neural network optimization is sensitive and fragile without active stabilization mechanisms. In practice, most models use both.
Batch normalization also made ResNet possible. The 152-layer ResNet applies BN after every single convolution — without it, gradients through that many layers would vanish or explode regardless of skip connections. The two ideas together (skip connections + BN) are the foundation of every modern image model.
Normalize the inputs to each layer, let the network relearn the distribution it actually wants, and watch training stability appear from almost nothing.
Connections
- Batch Normalization
- Optimization
- Vanishing Gradients
- Stochastic Gradient Descent
- Deep Residual Learning (ResNet)
- Adam: A Method for Stochastic Optimization
Citation
Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Proceedings of the 32nd International Conference on Machine Learning (ICML 2015).