This blog post discusses layer normalization in deep learning.

In the last article on LSTM, we discussed how its structure can mitigate the vanishing gradient problem, although it can still suffer from exploding gradients. In Road to ML Engineer #15 - Unstable Gradients, we saw that gradient clipping can address exploding gradients, and we singled out another technique as the most powerful one against unstable gradients: batch normalization.
However, batch normalization is a poor fit when sequence lengths differ among the data points in a batch. It normalizes each activation using the mean and variance of that activation across the batch, so at later time steps only some of the sequences contribute real values while the rest are padding or missing. Batch normalization can also perform poorly when the batch size is small, because the batch statistics become noisy estimates. Hence, batch normalization is ill-suited to RNNs and LSTMs, which can take inputs and produce outputs of varying lengths.
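To make these two failure modes concrete, here is a minimal NumPy sketch of training-mode batch normalization. The helper name `batch_norm_train` is hypothetical, the learnable scale and shift are omitted, and zero-padding is an assumed way of batching sequences of different lengths. With a batch of one, every activation equals its own batch mean, so the output collapses to zeros; with padded sequences, the padding values enter the per-time-step statistics.

```python
import numpy as np


def batch_norm_train(x, epsilon=1e-5):
    """Training-mode batch normalization without gamma/beta (hypothetical helper).

    x has shape (batch, features); statistics are taken over the batch axis,
    so each example's output depends on the other examples in the batch.
    """
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


# Batch of a single example: each activation equals its own batch mean,
# so the normalized output is all zeros and carries no information.
single = np.array([[3.0, -1.0, 7.0]])
print(batch_norm_train(single))  # → [[0. 0. 0.]]

# Two sequences of different real lengths, zero-padded to length 4.
# Here each time step is treated as one activation normalized across the batch.
seq_a = np.array([0.5, 1.0, 1.5, 2.0])  # real length 4
seq_b = np.array([0.4, 0.8, 0.0, 0.0])  # real length 2, last two steps are padding
batch = np.stack([seq_a, seq_b])        # shape (2, 4): (batch, time steps)

# At the third time step, the statistics mix a real value (1.5) with padding (0.0)
print(batch_norm_train(batch)[:, 2])
```

The padded zeros shift the mean and variance of the later time steps, so the normalized value of `seq_a` depends on how short its batch neighbors happen to be.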
Layer Normalization
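To solve this problem, we can introduce layer normalization, where the activations are normalized using the mean and variance of all the activations of each data point, instead of the statistics of each activation across the batch. This allows the outputs to have varying lengths and makes normalization independent of batch size. For a data point with $H$ activations $x_1, \dots, x_H$, layer normalization is expressed as follows.

$$\mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2$$

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y_i = \gamma_i \hat{x}_i + \beta_i$$

Here, $\epsilon$ is a small constant for numerical stability, and $\gamma$ and $\beta$ are learnable scale and shift parameters.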
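To check the computation numerically, here is a minimal NumPy sketch of layer normalization over the last (feature) axis. The helper name `layer_norm` and the example values are illustrative assumptions. It shows that a data point is normalized identically whether it is alone or inside a batch, and that sequences of different lengths need no padding because each time step's feature vector is normalized on its own.

```python
import numpy as np


def layer_norm(x, gamma, beta, epsilon=1e-5):
    """Layer normalization over the last (feature) axis (hypothetical helper).

    Each data point (each row, or each time step of a sequence) uses its own
    mean and variance, so no batch statistics or running averages are needed.
    """
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance (divide by H), as in the formula
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return gamma * x_hat + beta


H = 4
gamma, beta = np.ones(H), np.zeros(H)  # identity scale/shift, the usual initial values

x = np.array([1.0, 2.0, 3.0, 4.0])
# mean = 2.5 and variance = 1.25, so x_hat is roughly [-1.342, -0.447, 0.447, 1.342]
print(np.round(layer_norm(x, gamma, beta), 3))

# The same data point gives the same result alone or inside a batch
batch = np.stack([x, np.array([10.0, -3.0, 0.5, 8.0])])
print(np.allclose(layer_norm(batch, gamma, beta)[0], layer_norm(x, gamma, beta)))  # → True

# Sequences of different lengths, shape (time steps, features), need no padding
short_seq = np.random.default_rng(0).normal(size=(3, H))  # 3 time steps
long_seq = np.random.default_rng(1).normal(size=(7, H))   # 7 time steps
print(layer_norm(short_seq, gamma, beta).shape, layer_norm(long_seq, gamma, beta).shape)  # → (3, 4) (7, 4)
```

Because the function keeps no state, the exact same code path serves training, testing, and inference.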
The normalization steps are nearly identical to those of batch normalization, except that the statistics are computed for each data point, so they do not depend on the batch size. This allows the sequence length of each data point to vary while still allowing parallel computation within the batch. Additionally, batch normalization must track a running mean and variance during training, which are then used for inference (where inputs may not arrive in batches). Layer normalization, on the other hand, uses the same computation for training, testing, and inference. The following is an implementation of layer normalization in TensorFlow.
```python
import tensorflow as tf
from tensorflow.keras import layers


class LayerNormalization(layers.Layer):
    def __init__(self, embed_dim, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        # Learnable scale (gamma) and shift (beta), one value per feature
        self.gamma = self.add_weight(name="gamma", shape=(embed_dim,), initializer="ones", trainable=True)
        self.beta = self.add_weight(name="beta", shape=(embed_dim,), initializer="zeros", trainable=True)
        self.epsilon = epsilon

    def call(self, x):
        # Mean and variance over the feature axis of each data point,
        # independent of the batch size and the sequence length
        mean = tf.reduce_mean(x, axis=-1, keepdims=True)
        var = tf.math.reduce_variance(x, axis=-1, keepdims=True)
        normalized = (x - mean) / tf.math.sqrt(var + self.epsilon)
        return self.gamma * normalized + self.beta
```
I recommend trying to implement the above in PyTorch as well. The original paper by Ba, L. J. et al. (2016) empirically showed that RNNs benefit from layer normalization especially for long sequences and small mini-batches, while batch normalization outperformed layer normalization for CNNs. Thanks to its parallelizability, faster training, and compatibility with varying lengths, layer normalization remains a popular technique for sequential tasks like NLP even today. Like batch normalization, layer normalization is available as a built-in layer in both TensorFlow (tf.keras.layers.LayerNormalization) and PyTorch (torch.nn.LayerNorm). If you're interested, try adding a layer normalization layer to the RNN or LSTM model you built in the last two articles.
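As a starting point for that exercise, here is a minimal NumPy sketch of a layer-normalized vanilla RNN, following the formulation in Ba, L. J. et al. (2016) where the summed inputs to the hidden units are normalized at every time step before the nonlinearity. The weight names `W_xh` and `W_hh`, the row-vector convention, and the random initialization are illustrative assumptions, not code from the earlier articles.

```python
import numpy as np


def layer_norm(a, gamma, beta, epsilon=1e-5):
    # Normalize over the hidden units of each example at the current time step
    mean = a.mean(axis=-1, keepdims=True)
    var = a.var(axis=-1, keepdims=True)
    return gamma * (a - mean) / np.sqrt(var + epsilon) + beta


def ln_rnn_forward(xs, W_xh, W_hh, gamma, beta):
    """Layer-normalized vanilla RNN over xs of shape (time, batch, input_dim).

    At each step the summed inputs a_t = x_t W_xh + h_{t-1} W_hh are
    layer-normalized and passed through tanh (beta plays the role of the bias).
    """
    T, B, _ = xs.shape
    h = np.zeros((B, W_hh.shape[0]))  # initial hidden state
    hs = []
    for t in range(T):
        a_t = xs[t] @ W_xh + h @ W_hh
        h = np.tanh(layer_norm(a_t, gamma, beta))
        hs.append(h)
    return np.stack(hs)  # (time, batch, hidden_dim)


rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 5
W_xh = rng.normal(scale=0.5, size=(input_dim, hidden_dim))  # illustrative random weights
W_hh = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))
gamma, beta = np.ones(hidden_dim), np.zeros(hidden_dim)

xs = rng.normal(size=(6, 2, input_dim))  # 6 time steps, batch of 2
hs = ln_rnn_forward(xs, W_xh, W_hh, gamma, beta)
print(hs.shape)  # → (6, 2, 5)
```

Only the normalization is new here; the loop over time steps is unchanged, which is exactly the limitation discussed in the next section.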
Elephant in the Room
Across this article and the last one, we touched on various solutions to the unstable gradient problem in RNNs, such as LSTM, GRU, gradient clipping, and layer normalization. We also talked about bidirectional and deep RNNs, which can improve the model’s performance.
While these solutions provided incremental improvements to model performance and learning, RNNs continued to struggle with long-term memory retention and the exploding gradient problem. The model’s performance plateaued below the level needed to generate sensible text.
Most importantly, the biggest problem discussed in the previous article has yet to be addressed by these solutions: the lack of parallelizability in RNNs. Regardless of the cell type or model architecture chosen, the computation and backpropagation through time (BPTT) remain sequential, because each hidden state depends on the previous one, which prevents us from fully leveraging parallel hardware. As a result, training and inference time grows linearly with sequence length, and the time steps of a sequence cannot be processed in parallel. Although recent research has proposed workarounds, such as simplified RNNs that can be trained in parallel (Feng, L. et al., 2024), it remains uncertain how well these approaches will hold up in practice.
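The sequential bottleneck is easy to see in code. The following NumPy sketch (sequence length, feature size, and weights are arbitrary assumptions) contrasts a position-wise transformation, which handles all time steps with one matrix multiplication, with an RNN recurrence, where step t cannot start until step t-1 has produced its hidden state.

```python
import numpy as np

rng = np.random.default_rng(42)
T, d = 8, 4  # sequence length and feature size (arbitrary)
xs = rng.normal(size=(T, d))
W_x = rng.normal(scale=0.3, size=(d, d))
W_h = rng.normal(scale=0.3, size=(d, d))

# Position-wise: time steps are independent, so one matmul covers all T of them
positionwise = np.tanh(xs @ W_x)  # shape (T, d), parallel across time

# Recurrent: h_t depends on h_{t-1}, so the T steps must run one after another
h = np.zeros(d)
recurrent = []
for t in range(T):
    h = np.tanh(xs[t] @ W_x + h @ W_h)  # cannot run before the previous h exists
    recurrent.append(h)
recurrent = np.stack(recurrent)  # shape (T, d)

# The first step has no predecessor (h starts at zero), so both paths agree there;
# every later step differs because it carries information from earlier steps.
print(np.allclose(recurrent[0], positionwise[0]))  # → True
print(recurrent.shape)  # → (8, 4)
```

However many processors are available, the loop still performs T dependent steps, so its running time grows with the sequence length.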
Therefore, starting in the next article, we will finally move on to an alternative, parallelizable approach that has been widely used in recent years.
Resources
- Ba, L. J. et al. 2016. Layer Normalization. arXiv.
- Feng, L. et al. 2024. Were RNNs All We Needed? arXiv.
- Hayashi, M. 2023. レイヤー正規化 (layer normalization) [Transformerでよく用いるバッチ正規化層] (Layer normalization: a batch normalization layer commonly used in Transformers). CVMLエキスパートガイド (CVML Expert Guide).