Road to ML Engineer #47 - DVAEs

Last Edited: 3/13/2025

This blog post introduces discrete variational autoencoders (DVAEs) in deep learning.


In the previous article, we introduced DreamerV2, which uses a VAE that produces discrete latent state representations to improve performance. But how can we implement such a VAE without stopping the gradient? In this article, we will cover some clever tricks that researchers have come up with to achieve discrete latent representations, and discuss how discretization might improve performance.

Note: If you are unfamiliar with variational autoencoders, I highly recommend checking out the article Road to ML Engineer #18 - Variational Autoencoders.

Straight-Through Gradients

The simplest way to obtain a discrete latent representation is to sample from a categorical distribution. The problem with sampling, however, is that it is not differentiable and stops backpropagation to the encoder. To get around this, DreamerV2 uses the simplest solution, the straight-through gradient estimator, which ignores the sampling (or any other non-differentiable operation) and passes the gradient from the first layer of the decoder directly backwards as an estimate of the true gradient.

import torch
import torch.nn as nn
import torch.nn.functional as F

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Sample a one-hot vector from the categorical distribution defined by x
        # (x is assumed to hold probabilities along the last dimension)
        dist = torch.distributions.categorical.Categorical(probs=x)
        return F.one_hot(dist.sample(), num_classes=x.size(-1)).float()

    @staticmethod
    def backward(ctx, g):
        # Straight-through: pass the decoder's gradient back unchanged,
        # clipped to [-1, 1] with hardtanh for stability
        return F.hardtanh(g)

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        return STEFunction.apply(x)

Here, the VAE of DreamerV2 applies a KL penalty to make the categorical posterior distribution similar to the prior produced from the hidden state. Since the prior itself is learnable, DreamerV2 uses KL balancing, which effectively applies a higher learning rate to the prior so that it learns faster than the posterior. The above is a simplified PyTorch implementation of the straight-through gradient estimator for DreamerV2: we could simply pass the gradient through unchanged, but we use hardtanh here to stabilize training by clipping the gradient to between -1 and 1.
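To make KL balancing concrete, here is a minimal sketch (not DreamerV2's actual code) of how the two stop-gradient terms can be weighted. It assumes post_probs and prior_probs are the categorical probabilities from the encoder and the hidden state, and alpha (around 0.8 in the paper) controls how much faster the prior is trained than the posterior:

import torch
from torch.distributions import Categorical, kl_divergence

def balanced_kl(post_probs, prior_probs, alpha=0.8):
    # Train the prior toward the posterior faster (weight alpha) than
    # the posterior toward the prior (weight 1 - alpha) by mixing two
    # KL terms with stop-gradients on opposite sides.
    post, prior = Categorical(probs=post_probs), Categorical(probs=prior_probs)
    post_sg = Categorical(probs=post_probs.detach())
    prior_sg = Categorical(probs=prior_probs.detach())
    return (alpha * kl_divergence(post_sg, prior)
            + (1 - alpha) * kl_divergence(post, prior_sg)).mean()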

Gumbel Softmax Trick

An alternative workaround to the non-differentiability of sampling from a categorical distribution is to apply a trick similar to the reparameterization trick used for Gaussian distributions. Here, we can simulate sampling from the categorical distribution by sampling noise from a Gumbel distribution, which models the maximum (or minimum) of samples from various other distributions, adding the noise to the log probabilities, and taking the argmax of the sum (we use log probabilities because the log function is monotonically increasing). A video by Bechtel (2022) describes the Gumbel distribution concisely, including its shape, where the probability density concentrates around 0.
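Before the full layer below, here is a minimal sketch of the Gumbel-max trick on its own, assuming probs holds categorical probabilities along the last dimension; it reproduces samples from the categorical distribution but is still non-differentiable because of the argmax:

import torch

def gumbel_max_sample(probs):
    # Gumbel noise g ~ Gumbel(0, 1) via inverse transform of uniform samples
    uniform = torch.rand_like(probs)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-9) + 1e-9)
    # argmax of log-probabilities plus Gumbel noise ~ Categorical(probs)
    return torch.argmax(torch.log(probs + 1e-9) + gumbel_noise, dim=-1)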

class GumbelSoftmaxLayer(nn.Module):
    def __init__(self, temperature=10):
        super(GumbelSoftmaxLayer, self).__init__()
        self.temperature = temperature  # usually annealed towards 0 during training

    def forward(self, x):
        # Sample Gumbel noise g ~ Gumbel(mu=0, beta=1)
        uniform = torch.rand_like(x)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-9) + 1e-9)

        # Add the noise to the log probabilities
        log_prob = torch.log(x + 1e-9) + gumbel_noise

        # Use softmax during training and argmax for inference,
        # assuming x of size (batch_size, num_latent, num_categories)
        if self.training:
            return F.softmax(log_prob / self.temperature, dim=2)
        else:
            return F.one_hot(torch.argmax(log_prob, dim=2), num_classes=x.size(2))

By adding sampled noise from the Gumbel distribution, we take into account the probability of each category being the maximum and thus chosen as the sample, effectively simulating sampling from the categorical distribution. This technique, the Gumbel-max trick, is a great reparameterization, but it still contains an argmax operation that is non-differentiable. Hence, we can replace the argmax with a softmax function with a temperature parameter τ (so far we have always assumed τ = 1). The closer τ is to 0, the sharper the distribution becomes and the closer it gets to the argmax.

By annealing τ gradually from high values towards 0, we can pass gradients for robust training early on and move towards a latent representation that is closer to discrete. During inference, we can use argmax to produce the discrete latent representation. This trick of substituting the argmax with a temperature-controlled softmax during training is called the Gumbel softmax trick. The above is a PyTorch implementation of it. As you can see, the noise can be sampled from the Gumbel distribution with μ = 0 and β = 1 by applying the negative log twice to a sample from the uniform distribution between 0 and 1.
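As a usage sketch (the shapes, number of steps, and decay schedule here are illustrative assumptions), the temperature can be annealed exponentially with a floor over training steps:

layer = GumbelSoftmaxLayer(temperature=10)
encoder_probs = torch.softmax(torch.randn(8, 32, 64), dim=-1)  # stand-in encoder output
for step in range(10000):                                      # illustrative number of steps
    layer.temperature = max(0.5, 10 * (0.999 ** step))         # exponential decay with a floor
    soft_latents = layer(encoder_probs)                        # soft, differentiable samples while training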

VQ-VAE

Another alternative approach is to introduce a learnable codebook or embedding containing K latent vectors e_i, each with D elements. In this approach, instead of outputting categorical distributions, the encoder outputs vectors z_e whose last dimension is D, and we choose the closest latent vector by Euclidean distance to pass to the decoder as z_q. Through this process, called vector quantization, we obtain one-hot encoded vectors, corresponding to the indices of the chosen e_i, as the discrete latent representation.

[Figure: VQ-VAE]

Here, we use a deterministic posterior (based on Euclidean distance) and assume a discrete uniform prior during training (for unbiased training). Since we can assume that the encoder outputs and the codebook vectors are similar, which is especially true after training with the squared distance between e and z_e as a loss on both (analogous to KL balancing), we can confidently pass gradients straight from the decoder to the encoder.

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost
 
    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        # Straight-through trick: copy gradients from quantized back to the encoder inputs
        quantized = inputs + (quantized - inputs).detach()
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous()
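As a usage sketch (the sizes and names here are illustrative assumptions), the layer expects the encoder output in BCHW format with the channel dimension equal to embedding_dim:

vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
z_e = torch.randn(8, 64, 32, 32)   # stand-in for encoder output (B, C=D, H, W)
vq_loss, z_q = vq(z_e)             # z_q has the same shape and is passed to the decoder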

Though we assume a uniform prior over the codebook vectors, the true prior that reflects the training data would not be uniform and would have sequential dependence. (If there are 10 digits in the training data, selecting one latent vector influences which digit the decoder generates, which in turn influences the distribution over the second latent vector and the subsequent vectors.)

Thus, the original paper suggests training an autoregressive model (like an RNN or a Transformer) that outputs the true categorical prior distribution conditioned on the previously sampled vectors (next-vector prediction on the latent representation). Such a model allows us to perform ancestral sampling, where we sample the sequence of latent vectors one at a time from the conditional distributions, for generative tasks. The above is an example PyTorch implementation of a vector quantization layer for VQ-VAE. In this implementation, we compute the quantization loss inside the module, where we have access to the relevant components, and use detach to stop the gradients.
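To illustrate ancestral sampling, here is a rough sketch assuming a hypothetical prior_model that returns logits over the K codebook entries for the next position given the indices sampled so far (the model itself and names like seq_len are assumptions for illustration; the lookup reuses the _embedding of the VectorQuantizer above):

import torch

@torch.no_grad()
def ancestral_sample(prior_model, vq_layer, seq_len):
    # indices of previously sampled codebook vectors, starting from an empty sequence
    indices = torch.zeros(1, 0, dtype=torch.long)
    for t in range(seq_len):
        # hypothetical prior_model: logits over the K codebook entries for the
        # next position, conditioned on the indices sampled so far
        logits = prior_model(indices)
        next_index = torch.distributions.Categorical(logits=logits).sample()
        indices = torch.cat([indices, next_index.view(1, 1)], dim=1)
    # look up the sampled codebook vectors (the decoder would consume these)
    return vq_layer._embedding(indices)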

Why Discrete Over Continuous?

We've introduced three techniques for using discrete latent representations, but we haven't discussed why we would want a discrete representation when we could use a simple continuous one. The obvious benefit of a discrete representation is data compression. While a continuous representation requires us to store one floating-point number, or 32 bits, per latent dimension, a discrete representation only requires log_2(K) bits (where K is the number of categories) per latent dimension, which is usually far fewer than 32 bits (for example, log_2(1024) = 10).
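As a toy calculation (the sizes here are illustrative assumptions), compare a 32-dimensional latent stored as 32-bit floats with 32 categorical latents of K = 1024 categories each:

import math

latent_dims, num_categories = 32, 1024                    # illustrative sizes
continuous_bits = latent_dims * 32                        # 32-bit float per dimension -> 1024 bits
discrete_bits = latent_dims * math.log2(num_categories)   # log2(K) bits per dimension -> 320 bits
print(continuous_bits, discrete_bits)                     # 1024 320.0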

Although it achieves better data compression, using a discrete representation intuitively seems to limit the latent space to K categories and, in exchange, result in lower performance than a continuous representation. However, DreamerV2 counterintuitively achieves better results with categorical latents than with Gaussian latents on 42 tasks. The original paper suggests that this might be because Gaussian latents assume a unimodal distribution, unlike categorical latents, which can fit multi-modal distributions; because Gaussian latents may worsen unstable gradients through the noise ε, unlike the straight-through estimator; and because they are less natural for modeling inherently discrete states.

Although the second point does not apply to the Gumbel softmax trick, I'd argue these three are sensible explanations for why discrete latents can outperform continuous latents in some tasks even with limited expressivity. As we often deal with finite MDPs with discrete spaces, and with natural language and speech that are inherently discrete (already quantized and tokenized), it might be more natural to use discrete latent representations for them than to forcefully map them into a continuous space. In addition, VQ-VAE can serve as a natural model for tokenization (and is actually used for tokenization in practice), especially for continuous data like audio and images, since it can learn embeddings appropriate for downstream tasks.

Conclusion

In this article, we introduced the straight-through gradient estimator, the Gumbel softmax trick, and vector quantization, techniques that allow us to produce categorical latents. We also discussed potential reasons why categorical latents are sometimes highly performant and preferred. For more details on these techniques, I recommend checking out the original papers and supplementary resources cited below.

Resources