This blog post introduces some discrete variational autoencoders in deep learning.

In the previous article, we introduced DreamerV2, which uses a VAE that produces discrete latent state representations to improve performance. However, how can we implement such a VAE when the discretization step blocks the gradient? In this article, we will cover some clever tricks that researchers have come up with for obtaining discrete latent representations, and discuss how discretization might improve performance.
Note: If you are unfamiliar with variational autoencoders, I highly recommend checking out the article, Road to ML Engineer #18 - Variational Autoencoders.
Straight-Through Gradients
The simplest way to obtain a discrete latent representation is to sample from a categorical distribution. However, sampling is not differentiable, so it stops backpropagation to the encoder. To get around this problem, DreamerV2 uses the simplest solution, the straight-through gradient estimator, which ignores the sampling step (or any other non-differentiable operation) in the backward pass and passes the gradient from the first layer of the decoder directly back to the encoder as an estimate of the true gradient.
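Straight-through estimators are also commonly written without a custom autograd function, using `detach`: the forward pass returns the hard one-hot sample, while the backward pass only sees the differentiable probabilities. The sketch below is a minimal illustration assuming the encoder outputs logits of shape (batch, num_categories); `straight_through_sample` is a hypothetical helper name, and, unlike the `STEFunction` version, it does not clip gradients.

```python
import torch
import torch.nn.functional as F

def straight_through_sample(logits):
    """Hypothetical helper: one-hot categorical sample with straight-through gradients."""
    probs = F.softmax(logits, dim=-1)
    # Non-differentiable step: draw a category index and one-hot encode it
    index = torch.distributions.Categorical(probs=probs).sample()
    one_hot = F.one_hot(index, num_classes=logits.shape[-1]).to(probs.dtype)
    # Forward value equals one_hot (up to rounding); in the backward pass the
    # gradient flows through probs as if sampling were the identity
    return one_hot + probs - probs.detach()

torch.manual_seed(0)
logits = torch.randn(4, 5, requires_grad=True)
sample = straight_through_sample(logits)
# Weighted sum so that the gradient does not cancel across categories
(sample * torch.arange(5.0)).sum().backward()
```

The encoder-side gradient here is the softmax Jacobian applied to the decoder's gradient, which is the "pretend the sample was the probabilities" approximation described above.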
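import torch
import torch.nn as nn
import torch.nn.functional as F

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # x: class probabilities of shape (..., num_categories)
        dist = torch.distributions.Categorical(probs=x)
        # One-hot encode the sample with the same number of categories and dtype as x
        return F.one_hot(dist.sample(), num_classes=x.shape[-1]).to(x.dtype)

    @staticmethod
    def backward(ctx, g):
        # Straight-through: treat sampling as the identity, clipping the
        # incoming gradient to [-1, 1] for stability
        return F.hardtanh(g)

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return STEFunction.apply(x)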
In DreamerV2, the VAE applies a KL penalty that pulls the categorical posterior distribution toward the prior predicted from the hidden state. Since the prior itself is learnable, DreamerV2 uses KL balancing, which weights the KL terms so that the prior is learned faster than the posterior (in effect, a higher learning rate for the prior). The STEFunction and StraightThroughEstimator code above is a simplified PyTorch implementation of the straight-through gradient estimator for DreamerV2. We could simply pass the incoming gradient through unchanged, but here we apply hardtanh to stabilize the gradient by clipping it to the range [-1, 1].
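The KL balancing mentioned above can be sketched as computing the same KL divergence twice: once with the posterior detached, so that only the prior is trained, and once with the prior detached, so that only the posterior is trained, then mixing the two with a weight alpha. This is a minimal sketch assuming both distributions are given as logits of shape (batch, num_latents, num_classes); alpha = 0.8 is the value reported in the DreamerV2 paper, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def categorical_kl(post_logits, prior_logits):
    """KL(post || prior) per example, summed over categories and latent variables."""
    post_log_probs = F.log_softmax(post_logits, dim=-1)
    prior_log_probs = F.log_softmax(prior_logits, dim=-1)
    kl = (post_log_probs.exp() * (post_log_probs - prior_log_probs)).sum(-1)
    return kl.sum(-1)

def kl_balancing_loss(post_logits, prior_logits, alpha=0.8):
    """Hypothetical helper: mixes two stop-gradient variants of the same KL."""
    # Trains only the prior, toward the (frozen) posterior
    prior_term = categorical_kl(post_logits.detach(), prior_logits)
    # Trains only the posterior, toward the (frozen) prior
    post_term = categorical_kl(post_logits, prior_logits.detach())
    return (alpha * prior_term + (1 - alpha) * post_term).mean()

torch.manual_seed(0)
post = torch.randn(2, 32, 32, requires_grad=True)   # (batch, num_latents, num_classes)
prior = torch.randn(2, 32, 32, requires_grad=True)
loss = kl_balancing_loss(post, prior)
loss.backward()
```

The loss value equals the plain KL divergence; only the gradients change, with the prior receiving alpha times and the posterior (1 - alpha) times the usual gradient.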
Gumbel Softmax Trick
An alternative workaround for the non-differentiability of sampling from a categorical distribution is a trick similar to the reparameterization trick used for Gaussian distributions. Here, we simulate sampling from the categorical distribution by drawing noise from a Gumbel distribution, which models the distribution of the maximum (or minimum) of many samples, adding that noise to the log probabilities, and taking the argmax of the sum. We use log probabilities because adding standard Gumbel noise to them and taking the argmax yields a sample distributed exactly according to the categorical probabilities (and, since the log is monotonically increasing, taking logs does not change the ordering of the categories). In a video, Bechtel (2022) concisely describes the Gumbel distribution and its shape; the density of the standard Gumbel distribution peaks at 0.
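The claim that the argmax of Gumbel-perturbed log probabilities reproduces categorical sampling can be checked numerically. In this minimal sketch, the probabilities and the number of samples are arbitrary illustrative values.

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
num_samples = 200_000

# Standard Gumbel noise g = -log(-log(u)), u ~ U(0, 1); clamp avoids log(0)
uniform = torch.rand(num_samples, probs.numel()).clamp(min=1e-10)
gumbel_noise = -torch.log(-torch.log(uniform))

# Gumbel-max trick: argmax(log p + g) is a sample from Categorical(p)
samples = torch.argmax(torch.log(probs) + gumbel_noise, dim=-1)

# Empirical frequencies should be close to probs
freqs = torch.bincount(samples, minlength=probs.numel()).float() / num_samples
print(freqs)
```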
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelSoftmaxLayer(nn.Module):
    def __init__(self, temperature=10, eps=1e-10):
        super().__init__()
        self.temperature = temperature  # usually annealed during training
        self.eps = eps  # small constant to avoid log(0)

    def forward(self, x):
        # x: probabilities of size (batch_size, num_latent, num_categories)
        # Sample Gumbel noise g ~ G(mu=0, beta=1) as g = -log(-log(u)), u ~ U(0, 1)
        uniform = torch.rand_like(x)
        gumbel_noise = -torch.log(-torch.log(uniform + self.eps) + self.eps)
        # Add the noise to the log probabilities
        log_prob = torch.log(x + self.eps) + gumbel_noise
        # Use softmax when training and argmax for inference
        if self.training:
            return F.softmax(log_prob / self.temperature, dim=-1)
        return F.one_hot(torch.argmax(log_prob, dim=-1), num_classes=x.shape[-1]).to(x.dtype)
By adding noise sampled from the Gumbel distribution, we account for the probability of each category ending up as the maximum and thus being chosen, which effectively simulates sampling from the categorical distribution. This technique, the Gumbel-max trick, is a valid reparameterization, but it still contains an argmax operation that is non-differentiable. Hence, we can replace argmax with a softmax with a temperature parameter $\tau$, $y_i = \frac{\exp((\log \pi_i + g_i)/\tau)}{\sum_j \exp((\log \pi_j + g_j)/\tau)}$, where $\pi_i$ are the category probabilities and $g_i$ the Gumbel noise (the standard softmax implicitly assumes $\tau = 1$). The closer $\tau$ is to 0, the sharper the distribution becomes and the closer it gets to the argmax.
By annealing $\tau$ gradually from a high value toward 0, we get smooth samples with well-behaved gradients early in training and move toward a latent representation that is closer to discrete. During inference, we can use argmax to produce the discrete latent representation. This trick of substituting argmax with a softmax whose temperature varies during training is called the Gumbel softmax trick, and the GumbelSoftmaxLayer above is a PyTorch implementation of it. As the code shows, noise from the standard Gumbel distribution with $\mu = 0$ and $\beta = 1$ can be sampled by applying the negative log twice to a sample from the uniform distribution between 0 and 1, i.e., $g = -\log(-\log u)$ with $u \sim \mathcal{U}(0, 1)$.
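PyTorch also ships this relaxation as `torch.nn.functional.gumbel_softmax`, which takes unnormalized logits rather than probabilities and, with `hard=True`, returns one-hot samples in the forward pass while using the relaxed sample's gradient in the backward pass (a straight-through variant). The sketch below pairs it with an exponential temperature schedule with a floor; Jang et al. (2017) report a schedule of this form, but `temperature_at`, its rate, and its floor are illustrative assumptions.

```python
import math
import torch
import torch.nn.functional as F

def temperature_at(step, rate=1e-4, tau_min=0.5):
    """Hypothetical schedule: exponential decay of tau with a lower bound."""
    return max(tau_min, math.exp(-rate * step))

torch.manual_seed(0)
# Unnormalized logits of shape (batch_size, num_latent, num_categories)
logits = torch.randn(8, 32, 32, requires_grad=True)

for step in [0, 1_000, 10_000]:
    tau = temperature_at(step)
    # Relaxed sample: softmax((logits + g) / tau) with Gumbel noise g
    soft = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)
    # One-hot in the forward pass, relaxed-sample gradients in the backward pass
    hard = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
    print(f"step={step} tau={tau:.3f} mean max prob={soft.max(dim=-1).values.mean().item():.3f}")

# Weighted sum so that the gradient does not cancel across categories
(hard * torch.arange(32.0)).sum().backward()
```

If an encoder outputs probabilities, as GumbelSoftmaxLayer assumes, passing their logarithm as `logits` gives the same distribution.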
VQ-VAE
Another alternative approach is to introduce a learnable codebook (embedding) of $K$ latent vectors $e_1, \dots, e_K$, each with $D$ elements. In this approach, instead of categorical distributions, the encoder outputs vectors $z_e(x)$ whose last dimension is $D$, and for each of them we pass the codebook vector closest in Euclidean distance to the decoder as $z_q(x)$. Through this process, called vector quantization, we obtain one-hot encoded vectors corresponding to the indices $k$ of the chosen $e_k$ as the discrete latent representation.
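A tiny worked example of the nearest-neighbor lookup: with a made-up codebook of $K = 3$ vectors in $D = 2$ dimensions, each encoder output is mapped to the index of its closest codebook vector, which gives both the one-hot code and the quantized vector $z_q(x)$. All numbers are illustrative.

```python
import torch
import torch.nn.functional as F

# Toy codebook with K = 3 vectors of dimension D = 2
codebook = torch.tensor([[0.0, 0.0],
                         [1.0, 1.0],
                         [-1.0, 2.0]])
# Two encoder outputs z_e(x), each of dimension D
z_e = torch.tensor([[0.9, 0.8],
                    [-0.7, 1.6]])

distances = torch.cdist(z_e, codebook)   # (2, K) Euclidean distances
indices = distances.argmin(dim=-1)       # nearest codebook entry per vector
one_hot = F.one_hot(indices, num_classes=codebook.shape[0])
z_q = codebook[indices]                  # quantized vectors passed to the decoder

print(indices)  # tensor([1, 2])
print(z_q)
```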

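Here, we use a deterministic posterior (based on Euclidean distance) and assume the prior to be the discrete uniform distribution during training, which makes the KL term a constant ($\log K$) that can be dropped from the objective. Since we can assume that the encoder output $z_e(x)$ and the selected codebook vector $e$ are close, especially after training with squared distance as the loss for both $z_e(x)$ and $e$ (the codebook loss $\lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2$ and the commitment loss $\beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$, where sg denotes stop-gradient, similar in spirit to KL balancing), we can confidently pass gradients straight from the decoder to the encoder.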
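Passing gradients straight from the decoder to the encoder is usually implemented with the same `detach` pattern used in the VectorQuantizer implementation in this section (`inputs + (quantized - inputs).detach()`). This minimal check uses random stand-in tensors rather than a full VQ-VAE: the output takes the quantized values, while the encoder output receives the decoder's gradient unchanged.

```python
import torch

torch.manual_seed(0)
z_e = torch.randn(4, 8, requires_grad=True)  # stand-in for the encoder output z_e(x)
z_q = torch.randn(4, 8)                       # stand-in for the selected codebook vectors

# Straight-through: the forward value equals z_q, while the backward pass
# treats quantization as the identity map from z_e to z_q
z_st = z_e + (z_q - z_e).detach()

decoder_grad = torch.randn(4, 8)              # stand-in for the decoder's gradient w.r.t. z_q
z_st.backward(decoder_grad)
```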
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        # Codebook of num_embeddings vectors, each of size embedding_dim
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost  # beta in the commitment loss

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input to (B*H*W, embedding_dim)
        flat_input = inputs.view(-1, self._embedding_dim)
        # Squared Euclidean distances: |z|^2 + |e|^2 - 2 z.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
        # Encoding: one-hot vector of the nearest codebook entry
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # Loss: commitment term trains the encoder, codebook term trains the embeddings
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        # Straight-through: forward value is quantized, gradient goes to inputs
        quantized = inputs + (quantized - inputs).detach()
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous()
Though we assume a uniform prior over all codebook vectors, the true prior over which vectors are selected for the training data would not be uniform and would have sequential dependence. (For instance, if the training data contains 10 digits, selecting one latent vector would influence which digit the decoder generates, and therefore the distribution over the second latent vector and every subsequent one.)
Thus, the original paper suggests training an autoregressive model (such as an RNN or Transformer; the paper itself uses a PixelCNN for images) that outputs the true categorical prior distribution conditioned on previously sampled vectors, i.e., next-vector prediction on the latent representation. This model allows us to perform ancestral sampling for generative tasks, where each vector in the sequence is sampled in turn from its conditional distribution given the previously sampled ones.
The VectorQuantizer above is an example PyTorch implementation of a vector quantization layer for VQ-VAE. It computes the quantization loss inside the module, where we have access to the relevant tensors, and uses detach to stop gradients: once in each loss term, and once in the straight-through line that copies the decoder's gradient to the encoder output.
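The autoregressive prior described above can be sketched as follows, using a small GRU over a flattened sequence of codebook indices as a stand-in for the PixelCNN used in the original paper. The class `LatentPrior`, its sizes, and the extra start-token index are hypothetical; the weights here are untrained, so the samples only become meaningful after fitting the model with cross-entropy on index sequences produced by a trained encoder.

```python
import torch
import torch.nn as nn

class LatentPrior(nn.Module):
    """Hypothetical autoregressive prior over codebook indices (GRU stand-in)."""
    def __init__(self, num_embeddings, hidden_size=64):
        super().__init__()
        self.start_token = num_embeddings                   # extra index used as <start>
        self.embed = nn.Embedding(num_embeddings + 1, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_embeddings)  # logits over codebook indices

    def forward(self, indices):
        # Teacher forcing: predict indices[:, t] from indices[:, :t]
        start = torch.full_like(indices[:, :1], self.start_token)
        inputs = torch.cat([start, indices[:, :-1]], dim=1)
        hidden_states, _ = self.gru(self.embed(inputs))
        return self.head(hidden_states)                     # (batch, seq_len, K)

    @torch.no_grad()
    def sample(self, batch_size, seq_len):
        # Ancestral sampling: draw each index from p(z_t | z_<t), then feed it back
        token = torch.full((batch_size, 1), self.start_token, dtype=torch.long)
        hidden, samples = None, []
        for _ in range(seq_len):
            out, hidden = self.gru(self.embed(token), hidden)
            probs = torch.softmax(self.head(out[:, -1]), dim=-1)
            token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            samples.append(token)
        return torch.cat(samples, dim=1)                     # (batch, seq_len)

torch.manual_seed(0)
prior = LatentPrior(num_embeddings=512)
codes = prior.sample(batch_size=2, seq_len=16)  # e.g. a flattened 4x4 latent grid
logits = prior(codes)
loss = nn.functional.cross_entropy(logits.reshape(-1, 512), codes.reshape(-1))
```

Training goes through `forward` with teacher forcing, while `sample` draws one index at a time; the sampled indices would then be looked up in the codebook and passed to the decoder.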
Why Discrete Over Continuous?
We've introduced three techniques for using discrete latent representations, but haven't discussed why we would want a discrete representation when a simple continuous representation is available. The obvious benefit of a discrete representation is data compression. While a continuous representation requires storing one 32-bit floating-point number per latent dimension, a discrete representation only requires $\lceil \log_2 K \rceil$ bits per latent dimension ($K$ is the number of categories), which is usually far smaller than 32 bits (e.g., 5 bits for $K = 32$).
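As a quick worked example, the sketch below compares raw storage for 32 latent dimensions stored as 32-class categorical indices (the latent configuration reported for DreamerV2) versus the same number of float32 values. It only counts index bits and ignores any entropy coding; `bits_per_latent` is a hypothetical helper.

```python
import math

def bits_per_latent(num_categories):
    """Bits needed to store one categorical latent as an integer index."""
    return math.ceil(math.log2(num_categories))

num_latents, num_classes = 32, 32
discrete_bits = num_latents * bits_per_latent(num_classes)  # 32 * 5 = 160
continuous_bits = num_latents * 32                          # 32 * 32 = 1024 (float32)
print(discrete_bits, continuous_bits)  # 160 1024
```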
Although discrete representations achieve better data compression, intuitively they seem to limit the latent space to a finite set of $K^N$ possible codes (for $N$ latent dimensions with $K$ categories each) and thus to give up performance compared with continuous representations. However, DreamerV2 counterintuitively achieves better results with categorical latents than with Gaussian latents on 42 tasks. The original paper suggests three possible reasons. First, Gaussian latents assume a unimodal distribution, whereas categorical latents can fit multi-modal distributions. Second, the reparameterization gradient of Gaussian latents contains a term that scales the gradient, which the straight-through estimator ignores, so Gaussian latents may suffer more from exploding and vanishing gradients. Third, Gaussian latents are a less natural fit for modeling inherently discrete states.
Although the second point does not apply to the Gumbel softmax trick, I'd argue these three are sensible explanations for why discrete latents can outperform continuous latents on some tasks despite their limited expressivity. As we often deal with finite MDPs with discrete state and action spaces, and with natural language and speech data that are inherently discrete (having already been quantized and tokenized), it might be more natural to use discrete latent representations for them than to force them into a continuous space. In addition, VQ-VAE can be a natural model for tokenization (and is actually used for tokenization in practice), especially for continuous data like audio and images, since it can learn embeddings appropriate for downstream tasks.
Conclusion
In this article, we introduced the straight-through gradient estimator, the Gumbel softmax trick, and vector quantization, which allow us to produce categorical latents. We also discussed potential reasons why categorical latents sometimes perform well and are preferred over continuous ones. For more details on these techniques, I recommend checking out the original papers and supplementary resources cited below.
Resources
- Benjaminson, E. 2020. The Gumbel-Softmax Distribution. Emma Benjaminson.
- DeepBean. 2024. Vector-Quantized Variational Autoencoders (VQ-VAEs) | Deep Learning. YouTube.
- Hafner, D. et al. 2021. Mastering Atari With Discrete World Models. ICLR 2021.
- Huijben, I. et al. 2022. A Review of the Gumbel-max Trick and its Extensions for Discrete Stochasticity in Machine Learning. ArXiv.
- Jang, E. et al. 2017. Categorical Reparameterization with Gumbel-Softmax. ICLR 2017.
- Van den Oord, A. et al. 2018. Neural Discrete Representation Learning. ArXiv.