This blog post discusses Vision Transformers (ViT) in deep learning.

So far, we have only been discussing attention and transformers in the context of natural language processing and analogous tasks that involve sequential data. However, they can also be applied to computer vision tasks.
Images as Words
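One of the challenges of dealing with images is their sheer size. A typical image in computer vision is 224x224 pixels with three color channels, totaling approximately 150,000 pixel values. Treating each pixel as an input token is infeasible for transformers. Self-attention compares every token with every other token, so its cost grows quadratically with the number of tokens, and 50,176 pixel tokens would already mean an attention matrix with about 2.5 billion entries per head and layer. This is why convolutional neural networks (CNNs), whose small shared kernels only look at local neighborhoods, have dominated computer vision despite the success of transformers in natural language processing.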
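To make the size argument concrete, the short calculation below compares the sequence length and the number of entries in one attention matrix (one head, one layer) when every pixel is a token versus when 16x16 patches are tokens. It is a back-of-the-envelope sketch that assumes a 224x224x3 input and ignores the class token and the batch dimension; the function name is hypothetical.

```python
def attention_cost(image_size, patch_size, channels=3):
    """Return (num_tokens, values_per_token, attention_entries) for one head and one layer."""
    tokens_per_side = image_size // patch_size
    num_tokens = tokens_per_side ** 2
    values_per_token = patch_size * patch_size * channels
    # Self-attention scores every token against every other token
    attention_entries = num_tokens ** 2
    return num_tokens, values_per_token, attention_entries


pixel_tokens = attention_cost(224, patch_size=1)
patch_tokens = attention_cost(224, patch_size=16)

print("pixels as tokens: ", pixel_tokens)   # (50176, 3, 2517630976)
print("patches as tokens:", patch_tokens)   # (196, 768, 38416)
```

Using patches shrinks the attention matrix by a factor of (16²)² = 65,536 while still covering every pixel value, since 196 × 768 = 150,528.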
However, Dosovitskiy et al. (2021) proposed that transformers can also be applied to images without relying on CNNs by dividing the image into patches (typically 16x16 pixels) and treating these patches as words. This shortens the input to a sequence length that transformers can handle: a 224x224 image yields 14 × 14 = 196 patches instead of roughly 150,000 individual pixel values. A simple trainable linear layer then projects each flattened patch of 768 values (16×16×3) to an embedding of the model's dimension, using the same weights for every patch.

The patch extraction step can be implemented as a Keras layer:

```python
import tensorflow as tf
from tensorflow.keras import layers


class Patches(layers.Layer):
    """Splits a batch of images into non-overlapping, flattened patches."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # images: (batch_size, height, width, channels)
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],  # stride = size: no overlap
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # patches: (batch_size, rows, cols, patch_size * patch_size * channels)
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches  # (batch_size, num_patches, patch_dims)
```
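The same patch split and linear projection can be sketched in NumPy, which makes the shapes easy to inspect. This is an illustration rather than the article's TensorFlow code: the function name `patchify`, the random projection weights, and the embedding size `d_model = 64` are assumptions chosen for the example.

```python
import numpy as np


def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Patches are ordered row by row, and each patch is flattened in
    (row, column, channel) order, intended to match tf.image.extract_patches.
    """
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    x = image[: rows * patch_size, : cols * patch_size]  # drop any remainder, like padding="VALID"
    x = x.reshape(rows, patch_size, cols, patch_size, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (rows, cols, patch_size, patch_size, c)
    return x.reshape(rows * cols, patch_size * patch_size * c)


rng = np.random.default_rng(0)
image = rng.random((224, 224, 3)).astype(np.float32)
patches = patchify(image, patch_size=16)  # (196, 768)

# One linear projection shared by all patches (random here, learned in practice)
d_model = 64
w = rng.normal(scale=0.02, size=(768, d_model)).astype(np.float32)
b = np.zeros(d_model, dtype=np.float32)
patch_embeddings = patches @ w + b  # (196, 64)
print(patches.shape, patch_embeddings.shape)
```

Because every patch goes through the same weights, splitting an image into non-overlapping patches and applying one linear layer is equivalent to a convolution whose kernel size and stride both equal the patch size.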
Learnable Positional Embeddings
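A significant problem with dividing images into patches is the loss of positional information: self-attention treats its inputs as an unordered set, so without extra information the model cannot tell where each patch came from. In NLP, fixed positional encodings built from sine and cosine waves work well for sentences, which are 1D sequences. Applied to a flattened row of patches, however, such encodings only describe each patch's index in the sequence, not its row and column in the 2D image.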
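A small example shows what a 1D index misses. Once a 14×14 grid of patches is flattened row by row, index distance and spatial distance disagree: the last patch of one row sits next to the first patch of the next row in the sequence, even though the two are on opposite sides of the image. The helpers below are hypothetical and only compute these two distances.

```python
def grid_position(index, grid_size=14):
    """Map a flattened (row-major) patch index to its (row, column) in the image."""
    return divmod(index, grid_size)


def index_distance(i, j):
    """Distance between two patches as seen by a 1D sequence position."""
    return abs(i - j)


def grid_distance(i, j, grid_size=14):
    """Euclidean distance between two patches on the 2D grid, in patch units."""
    (r1, c1), (r2, c2) = grid_position(i, grid_size), grid_position(j, grid_size)
    return ((r1 - r2) ** 2 + (c1 - c2) ** 2) ** 0.5


# Patch 13 ends row 0 and patch 14 starts row 1: adjacent in the sequence, far apart in the image
print(index_distance(13, 14), grid_distance(13, 14))  # 1, then about 13.04
# Patches 0 and 14 are vertical neighbours, yet 14 positions apart in the sequence
print(index_distance(0, 14), grid_distance(0, 14))    # 14 1.0
```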

To address this, Dosovitskiy et al. (2021) used learnable positional embeddings of the same size as the patch embeddings, so that the two can simply be added. Rather than being fixed in advance, these embeddings are learned together with the rest of the model. After training, the authors found that the embeddings recover the 2D layout on their own: closer patches tend to have more similar positional embeddings, and patches in the same row or column have similar embeddings. With both patch embeddings and positional embeddings prepared, they can be passed to a transformer for computer vision tasks.
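The paper's check of what the embeddings learned is a cosine-similarity map: pick one patch's positional embedding, compare it with every other patch's, and arrange the scores on the patch grid. The NumPy sketch below implements that comparison. Since no trained model is available here, it runs on a hand-built, hypothetical table in which each embedding is a one-hot row code concatenated with a one-hot column code. This is a synthetic stand-in for the row/column structure described above, not a result.

```python
import numpy as np


def similarity_map(pos_table, patch_index, grid_size):
    """Cosine similarity between one patch's positional embedding and all others.

    pos_table: (grid_size * grid_size, d_model), one row per patch (class token excluded).
    Returns a (grid_size, grid_size) array laid out like the image.
    """
    normed = pos_table / np.linalg.norm(pos_table, axis=1, keepdims=True)
    sims = normed @ normed[patch_index]
    return sims.reshape(grid_size, grid_size)


# Synthetic table: [one-hot(row) | one-hot(col)] for each patch on a 14x14 grid
grid = 14
rows, cols = np.divmod(np.arange(grid * grid), grid)
eye = np.eye(grid)
synthetic_table = np.concatenate([eye[rows], eye[cols]], axis=1)  # (196, 28)

sim = similarity_map(synthetic_table, patch_index=0, grid_size=grid)
print(sim[0, 0], sim[0, 5], sim[5, 0], sim[5, 5])  # approximately 1.0 0.5 0.5 0.0
```

For the model in this article, the learned table could be read from `model.embedding.position_embedding.embeddings` (the Keras `Embedding` layer's weight, converted with `.numpy()`), dropping row 0, which belongs to the class token, before calling `similarity_map`.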
Vision Transformers
A notable transformer architecture for image classification is the Vision Transformer (ViT). It is a transformer encoder whose input sequence consists of the patch embeddings plus an extra learnable class embedding (class token) prepended at position 0, with a positional embedding added to every element. The class token does not correspond to any patch; through self-attention it gathers information from all patches, and the encoder's output at its position is passed to a classification head to predict the image's class. Below is an illustration of the ViT architecture.

The TensorFlow implementation of learnable positional embeddings, including the extra class token and the ViT model, is presented below. If you are unfamiliar with transformer encoders, refer to the previous article, Road to ML Engineer #29 - BERT vs GPT.
```python
import tensorflow as tf
from tensorflow.keras import layers


class PatchClassEmbedding(layers.Layer):
    """Prepends a learnable class token and adds learnable positional embeddings."""

    def __init__(self, d_model, n_patches, kernel_initializer="he_normal", **kwargs):
        super().__init__(**kwargs)
        self.n_tot_patches = n_patches + 1  # +1 for the class token
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        # Extra learnable class token, shared across the batch
        self.class_embed = self.add_weight(
            shape=(1, 1, d_model),
            initializer=self.kernel_initializer,
            name="class_token",
        )
        # One learnable vector per position (class token + patches)
        self.position_embedding = layers.Embedding(
            input_dim=self.n_tot_patches, output_dim=d_model
        )

    def call(self, inputs):
        # inputs: (batch_size, n_patches, d_model)
        positions = tf.range(start=0, limit=self.n_tot_patches, delta=1)
        x = tf.repeat(self.class_embed, tf.shape(inputs)[0], axis=0)  # (batch_size, 1, d_model)
        x = tf.concat((x, inputs), axis=1)  # (batch_size, n_patches + 1, d_model)
        # (n_patches + 1, d_model) broadcasts over the batch dimension
        return x + self.position_embedding(positions)


class ViT(tf.keras.Model):
    def __init__(self, num_classes, patch_size, num_patches, embed_dim,
                 hidden_dim, num_layers=4, num_heads=8):
        super().__init__()
        self.patches = Patches(patch_size)  # Patches layer defined earlier
        self.linear = layers.Dense(embed_dim)  # linear patch projection
        self.embedding = PatchClassEmbedding(embed_dim, num_patches)
        # TransformerEncoderBlock is not defined in this article (see the previous
        # article on transformer encoders). As written, num_heads is not passed to it.
        self.encoder = tf.keras.Sequential([
            TransformerEncoderBlock(embed_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        self.classifier = tf.keras.Sequential([
            layers.Lambda(lambda x: x[:, 0, :]),  # take the class token's output
            layers.Dense(num_classes, activation="softmax"),
        ])

    def call(self, x):
        x = self.patches(x)    # (batch, num_patches, patch_dims)
        x = self.linear(x)     # (batch, num_patches, embed_dim)
        x = self.embedding(x)  # (batch, num_patches + 1, embed_dim)
        x = self.encoder(x)    # (batch, num_patches + 1, embed_dim)
        return self.classifier(x)  # (batch, num_classes)
```
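To see how the class token collects information from the patches, the NumPy sketch below traces one image through the same steps as `ViT.call`, with a single-head self-attention layer standing in for the encoder stack. Everything here is a simplification for illustration: the weights and inputs are random, there is no MLP, layer normalization, or residual connection, and `d_model = 32` and 10 classes are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dims, d_model, num_classes = 196, 768, 32, 10

# 1) Flattened patches, as produced by the Patches layer (random stand-in values)
patches = rng.random((num_patches, patch_dims))

# 2) Linear projection to d_model
w_proj = rng.normal(scale=0.02, size=(patch_dims, d_model))
tokens = patches @ w_proj  # (196, 32)

# 3) Prepend the class token and add positional embeddings
class_token = rng.normal(size=(1, d_model))
pos_embed = rng.normal(scale=0.02, size=(num_patches + 1, d_model))
tokens = np.concatenate([class_token, tokens], axis=0) + pos_embed  # (197, 32)

# 4) One single-head self-attention layer (simplified encoder)
w_q, w_k, w_v = (rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(3))
q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
scores = q @ k.T / np.sqrt(d_model)          # (197, 197)
scores -= scores.max(axis=1, keepdims=True)  # numerical stability
attn = np.exp(scores)
attn /= attn.sum(axis=1, keepdims=True)      # each row sums to 1
encoded = attn @ v                           # (197, 32)

# 5) Classify from the class token's output only (index 0)
cls_out = encoded[0]
w_head = rng.normal(scale=0.02, size=(d_model, num_classes))
logits = cls_out @ w_head
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(attn[0].shape, probs.shape)  # (197,) (10,)
```

Row 0 of `attn` holds the weights the class token assigns to itself and to all 196 patches, so its output is a weighted mix of every patch's value vector. Training shapes those weights so that this single vector carries what the classifier needs.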
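The ViT architecture is very similar to BERT, differing primarily in the preprocessing phase. When pretrained on large datasets and fine-tuned, it has been shown to match or exceed state-of-the-art CNNs on small- and medium-sized image recognition benchmarks such as ImageNet, CIFAR-100, and VTAB, while requiring substantially less compute to pretrain. This success may stem from the transformer's weaker inductive bias: unlike CNNs, ViT does not build in locality or translation equivariance, so it is free to learn complex spatial patterns when enough data is available. The flip side is that ViT trained on a mid-sized dataset such as ImageNet alone falls a few percentage points short of comparably sized ResNets, and only overtakes them when pretrained on larger datasets such as ImageNet-21k or JFT-300M.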
Conclusion
As with the previous article, I have intentionally omitted pipeline implementations and PyTorch examples to encourage you to try them yourself. For an alternative approach, you can replace the raw-pixel patches with patches taken from a CNN feature map (the paper uses a ResNet backbone) to build a hybrid Vision Transformer. In the paper, hybrids slightly outperform vanilla ViT at small compute budgets, but the difference vanishes for larger models. Implementing a hybrid Vision Transformer is another worthwhile exercise.
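As a starting point for that exercise, the NumPy sketch below covers only the hybrid's embedding stage: a backbone turns the image into a feature map, and each spatial position of that map (a 1×1 patch) becomes one token. The two-layer "backbone" is a hypothetical toy built from non-overlapping 4×4 convolutions with random weights, standing in for the ResNet used in the paper; the helper name, channel counts, and `d_model` are arbitrary assumptions.

```python
import numpy as np


def conv_nonoverlapping(x, weights, k):
    """Convolution with kernel size k and stride k (no overlap), followed by ReLU.

    x: (H, W, C_in); weights: (k * k * C_in, C_out). Returns (H // k, W // k, C_out).
    """
    h, w, c = x.shape
    rows, cols = h // k, w // k
    windows = (
        x[: rows * k, : cols * k]
        .reshape(rows, k, cols, k, c)
        .transpose(0, 2, 1, 3, 4)
        .reshape(rows, cols, k * k * c)
    )
    return np.maximum(windows @ weights, 0.0)


rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))

# Toy CNN backbone (stand-in for a ResNet): 224 -> 56 -> 14 spatial positions
w1 = rng.normal(scale=0.1, size=(4 * 4 * 3, 64))
w2 = rng.normal(scale=0.1, size=(4 * 4 * 64, 128))
feature_map = conv_nonoverlapping(conv_nonoverlapping(image, w1, 4), w2, 4)  # (14, 14, 128)

# 1x1 patches: flatten the spatial grid, then project each position to d_model
d_model = 64
tokens = feature_map.reshape(-1, feature_map.shape[-1])  # (196, 128)
w_proj = rng.normal(scale=0.02, size=(128, d_model))
patch_embeddings = tokens @ w_proj                       # (196, 64)
print(feature_map.shape, patch_embeddings.shape)
```

From here, the class token, positional embeddings, and encoder stay exactly as in the plain ViT, which is why the hybrid only changes the preprocessing phase.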
ViT highlights the flexibility of transformer architectures, which are now used for a variety of data types and tasks beyond NLP. Many details and findings have been left out of this article, so I highly recommend reading the original paper cited below.
Resources
- Dosovitskiy, A. et al. 2021. An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale. arXiv.
- Pinecone. 2023. Vision Transformers (ViT) Explained. Pinecone.