Road to ML Engineer #19 - Generative Adversarial Networks

Last Edited: 9/14/2024

This blog post discusses generative adversarial networks in deep learning.

ML

Generative Adversarial Networks

In the previous article, we discussed how we can modify the autoencoder to create a VAE that can generate images. In this article, I would like to talk about an interesting alternative approach: Generative Adversarial Networks (GANs). As the name suggests, GANs are generative models trained with an adversary, as shown below.

[Figure: GAN architecture, with the generator competing against the discriminator]

The generator is trained to produce realistic images to deceive the discriminator model, while the discriminator is trained to distinguish whether an image is generated by the generator or is a real image. After competing with each other during training, the generator should eventually produce images so realistic that they are indistinguishable from real images to the human eye.
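The alternating updates described above can be sketched as follows in PyTorch. The tiny fully connected networks, the latent size of 16, and the learning rates are illustrative stand-ins, and `real_images` is random data in place of a real batch:

```python
import torch
import torch.nn as nn

# Illustrative generator and discriminator for flattened 28x28 images.
generator = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1), nn.Sigmoid())

opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
criterion = nn.BCELoss()

real_images = torch.rand(32, 784)  # stand-in for a batch of real data

# Discriminator step: push real images toward 1 and generated images toward 0.
noise = torch.randn(32, 16)
fake_images = generator(noise).detach()  # detach so only the discriminator is updated
d_loss = criterion(discriminator(real_images), torch.ones(32, 1)) + \
         criterion(discriminator(fake_images), torch.zeros(32, 1))
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: try to make the discriminator output 1 for generated images.
noise = torch.randn(32, 16)
g_loss = criterion(discriminator(generator(noise)), torch.ones(32, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
```

In a real training script these two steps run inside a loop over mini-batches; the key idea is that each network is optimized against the other's current behavior.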

Code Implementation

One advantage of GANs is their intuitive architecture, which doesn't require heavy mathematical machinery, so we can jump straight into the code implementation. Little data preprocessing is required either: the only step needed is to normalize the images, while the generated images and label tensors (containing just 0s and 1s) are created on the fly.
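As a sketch of that preprocessing, assuming MNIST-style 8-bit grayscale images (the array shapes and the [-1, 1] scaling here are illustrative assumptions, chosen to match a Tanh generator output):

```python
import numpy as np

# Hypothetical batch of 8-bit grayscale images, shape (N, 28, 28).
images = np.random.randint(0, 256, size=(64, 28, 28), dtype=np.uint8)

# Scale pixel values to [-1, 1], then flatten each image
# for a fully connected discriminator.
images = images.astype(np.float32) / 127.5 - 1.0
images = images.reshape(len(images), -1)  # (64, 784)

# Real/fake label tensors are created on the fly during training:
real_labels = np.ones((len(images), 1), dtype=np.float32)
fake_labels = np.zeros((len(images), 1), dtype=np.float32)
```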

Step 3. Model

The following is the implementation of an example GAN using PyTorch and TensorFlow.
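A minimal PyTorch sketch of such a GAN might look like the following; the layer sizes are illustrative choices, the latent dimension of 16 matches the evaluation code below, and the TensorFlow version would follow the same structure:

```python
import torch
import torch.nn as nn

class GAN(nn.Module):
    """A small fully connected GAN for flattened 28x28 grayscale images."""
    def __init__(self, latent_dim=16):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 784), nn.Tanh(),   # pixel values in [-1, 1]
        )
        self.discriminator = nn.Sequential(
            nn.Linear(784, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1), nn.Sigmoid(),  # probability the image is real
        )

gan = GAN()
noise = torch.randn(4, 16)
fake = gan.generator(noise)        # (4, 784)
score = gan.discriminator(fake)    # (4, 1), values in (0, 1)
```

Keeping both networks inside one module makes it easy to grab `gan.generator` on its own after training, which is exactly what the evaluation code below does.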

Step 4. Model Evaluation

After training the model, we can take the generator from the GAN and pass in noise of the appropriate size to generate new images. Let's see how the images look after training the PyTorch implementation of the GAN for 50 epochs.

import torch
import matplotlib.pyplot as plt

# Generate
latent = torch.randn(10, 16)
generated = gan.generator(latent)

# Preprocess
generated = generated.detach().numpy()
generated = generated.reshape(generated.shape[0], 28, 28)
 
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(generated[i], cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()
[Figure: digits generated by the PyTorch GAN after 50 epochs]

We can already see that the model is learning to draw lines that resemble handwritten digits. I recommend training the model for longer yourself to see the real potential of GANs. According to the original GAN paper, GANs can produce sharper images than VAEs without needing an explicit sampling step, although balancing the training of the two adversarial models can be more challenging.

Conclusion

In the last two articles, we discussed two generative models, VAE and GAN, and identified challenges with both. A common problem is how slow the training process is, even when dealing with small datasets and small images that don't reflect real-world scenarios. If we were to use real RGB images with resolutions of 1024x1024 pixels, training these models would become nearly infeasible. Therefore, we need to explore different layers that can process large images more efficiently, which we will cover in the next article.

Resources