Road to ML Engineer #22 - Transfer Learning

Last Edited: 9/25/2024

The blog post discusses transfer learning in deep learning.

ML

In the previous article, we encountered challenges in training large models on large datasets, even with convolutional layers. To address this issue, many researchers and engineers have developed solutions and dedicated time and resources to creating highly capable models. In this article, we will explore how we can reduce effort by leveraging these solutions through transfer learning.

Transfer Learning

Transfer learning is a powerful machine learning technique that reuses models pre-trained on large, diverse datasets to improve performance on a specific task while reducing training time and data requirements. In transfer learning, the final classification layers of a pre-trained model can be removed and replaced with our own classifier layer(s). We then train the model for our specific task while retaining the valuable features extracted by the earlier convolutional layers. The following sections describe one dataset and several notable pre-trained models used for transfer learning in computer vision.
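To make the mechanics concrete, here is a toy NumPy sketch of the "frozen feature extractor + new classifier head" idea. It is an illustration under stated assumptions, not how a real pre-trained model is used: the hypothetical `W_base` (a fixed random projection followed by ReLU) stands in for a pre-trained convolutional base, the data is synthetic, and the new head is plain logistic regression trained with gradient descent. The labels are deliberately generated from the extracted features, mimicking the assumption that pre-trained features are useful for the new task.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained feature extractor (e.g. the
# convolutional base of ResNet50): a fixed random projection followed by ReLU.
W_base = rng.normal(size=(20, 8)) / np.sqrt(20)
W_base_before = W_base.copy()


def extract_features(x):
    return np.maximum(x @ W_base, 0.0)  # W_base is never updated (frozen)


# Synthetic task whose labels depend on the extracted features.
X = rng.normal(size=(200, 20))
scores = extract_features(X) @ rng.normal(size=8)
y = (scores > np.median(scores)).astype(float)

# Because the base is frozen, its features can be computed once and reused.
feats = extract_features(X)

# New classifier head: logistic regression trained with gradient descent.
w, b, lr = np.zeros(feats.shape[1]), 0.0, 0.5
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(-(feats @ w + b)))
    grad = p - y                       # dLoss/dlogit for binary cross-entropy
    w -= lr * feats.T @ grad / len(y)
    b -= lr * grad.mean()

p = 1.0 / (1.0 + np.exp(-(feats @ w + b)))
accuracy = np.mean((p > 0.5) == (y == 1.0))
print(f"training accuracy of the new head: {accuracy:.2f}")
```

Only `w` and `b` change during training; the "pre-trained" weights stay exactly as they were, which is also why their outputs can be cached.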

ImageNet

ImageNet is a large database containing over 14 million hand-annotated images across more than 20,000 categories. It has been widely used in computer vision as a benchmark dataset for image recognition. Most pre-trained models available today are trained on ImageNet-1k, the subset used for the ILSVRC challenge, which contains about 1.28 million training images across 1,000 classes. Due to the size and diversity of the dataset, we expect models trained on ImageNet to have learned general-purpose visual features, such as edges, textures, and object parts, that can be leveraged for a wide range of image recognition tasks.

VGG16

One of the pre-trained models available is VGG16, which has a straightforward architecture as illustrated below.

Figure: VGG16 architecture

Examining the architecture of VGG16, we see that the spatial resolution progressively decreases while the number of channels increases with depth. This pattern is based on the concept of receptive fields (very important!): as each pooling layer halves the spatial resolution, a 3x3 kernel in a deeper layer covers a larger portion of the original image, which facilitates the extraction of high-level features.
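The growth of the receptive field can be computed with the standard recurrence r_out = r_in + (k - 1) * j_in and j_out = j_in * s, where k is the kernel size, s the stride, and j the spacing between neighbouring outputs measured in input pixels. The sketch below applies it to VGG16's 13 convolutional layers (3x3, stride 1, padding 1) and 5 max-pooling layers (2x2, stride 2), assuming a 224x224 input; the layer list and helper name are our own illustration.

```python
# Layer spec: (kind, kernel_size, stride). VGG16 uses 3x3 convolutions with
# stride 1 and padding 1 (spatial size unchanged) and 2x2 max pooling with stride 2.
CONV, POOL = ("conv", 3, 1), ("pool", 2, 2)
VGG16_FEATURES = (
    [CONV] * 2 + [POOL]
    + [CONV] * 2 + [POOL]
    + [CONV] * 3 + [POOL]
    + [CONV] * 3 + [POOL]
    + [CONV] * 3 + [POOL]
)


def receptive_field(layers, input_size=224):
    """Track receptive field (in input pixels) and feature-map size layer by layer."""
    rf, jump, size = 1, 1, input_size
    history = []
    for kind, k, s in layers:
        rf += (k - 1) * jump   # widen by (k - 1) steps of the current input-space stride
        jump *= s              # distance, in input pixels, between neighbouring outputs
        size //= s             # stride-1 convs keep the size; pools halve it
        history.append((kind, size, rf))
    return rf, jump, size, history


rf, jump, size, history = receptive_field(VGG16_FEATURES)
for kind, size_i, rf_i in history:
    if kind == "pool":
        print(f"{size_i}x{size_i} feature map, receptive field {rf_i}x{rf_i} pixels")
```

This prints receptive fields of 6, 16, 44, 100, and 212 pixels for feature maps of 112, 56, 28, 14, and 7 pixels per side: by the last block, each position can "see" almost the whole 224x224 image, even though every kernel is only 3x3.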

ResNet50

Another notable pre-trained model is ResNet50, whose architecture is shown below.

Figure: ResNet50 architecture

ResNet50 introduces residual connections (very important!), also called skip connections, which add the input of a block to its output. Instead of learning a full mapping H(x), each block only needs to learn a residual F(x) = H(x) - x, so the original input is never lost and an identity mapping is easy to represent. This lets the model focus on adjusting its inputs rather than relearning them, and it allows many layers to be stacked without the training difficulties that plain deep networks run into. This is similar to how Gradient Boosting trains each new model on the residuals of the previous ones.
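A simplified NumPy sketch of a residual block, out = relu(F(x) + x) with F(x) = W2 @ relu(W1 @ x). It assumes dense layers in place of ResNet50's convolutions and omits batch normalization, so it illustrates the idea rather than reproducing the actual architecture. It checks two properties: with zero weights in the second layer the block is exactly the identity (for non-negative inputs), and a stack of 50 blocks with small random weights preserves the signal, while the same layers without shortcuts shrink it towards zero.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, W1, W2):
    """Simplified residual block: relu(F(x) + x) with F(x) = W2 @ relu(W1 @ x)."""
    residual = W2 @ relu(W1 @ x)   # F(x): the correction the block learns
    return relu(residual + x)      # identity shortcut adds the input back


rng = np.random.default_rng(0)
x = relu(rng.normal(size=16))      # non-negative, like the output of a previous block

# 1) Zero weights in the last layer: the block passes its input through unchanged.
out_zero = residual_block(x, rng.normal(size=(16, 16)), np.zeros((16, 16)))

# 2) Stack 50 blocks with small random weights, with and without the shortcut.
weights = [(rng.normal(size=(16, 16)) * 0.01, rng.normal(size=(16, 16)) * 0.01)
           for _ in range(50)]
h_res, h_plain = x, x
for W1, W2 in weights:
    h_res = residual_block(h_res, W1, W2)
    h_plain = relu(W2 @ relu(W1 @ h_plain))   # same layers, no shortcut

print(np.linalg.norm(x), np.linalg.norm(h_res), np.linalg.norm(h_plain))
```

The shortcut is what keeps the signal (and, during training, the gradient) alive through many layers.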

Other pre-trained models, like MobileNetV2, are also available. MobileNetV2 is particularly notable for its relatively low number of parameters, making it well-suited for mobile applications where computational resources are limited. I recommend exploring more pre-trained models to find the one that best fits your needs.
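Much of MobileNetV2's parameter saving comes from depthwise separable convolutions (used inside its inverted residual blocks), which factor a standard k x k convolution into a per-channel (depthwise) k x k convolution followed by a 1x1 (pointwise) convolution. The arithmetic below compares the two for one 3x3 layer with 256 input and output channels; the channel count is an arbitrary choice for illustration, and biases are ignored.

```python
def standard_conv_params(k, c_in, c_out):
    # One k x k kernel for every (input channel, output channel) pair.
    return k * k * c_in * c_out


def depthwise_separable_params(k, c_in, c_out):
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 convolution that mixes channels
    return depthwise + pointwise


std = standard_conv_params(3, 256, 256)
sep = depthwise_separable_params(3, 256, 256)
ratio = std / sep
print(std, sep, round(ratio, 2))
```

Here the standard layer needs 589,824 weights and the separable version 67,840, roughly 8.7 times fewer.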

Fine-Tuning

Typically, the pre-trained layers are frozen (set to non-trainable) while the classifier at the end is trained, to preserve the feature extraction learned by the pre-trained model. However, you can also make the last few pre-trained layers trainable to tailor the feature extraction to the specific task. This technique is called fine-tuning, and it can improve performance. The trade-off is that fine-tuning increases the number of trainable weights, which can make training more difficult (it requires more time and data to tailor the feature extraction), and for very simple tasks it may yield only minor performance gains.
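A toy NumPy sketch of the usual two-stage recipe: first train only the new head on top of frozen layers, then unfreeze the last "pre-trained" layer and continue training with a much smaller learning rate for it. Everything here is a hypothetical stand-in (synthetic data, a two-layer network whose first layer `W1` plays the role of the last pre-trained layer); the smaller base learning rate is a common practice, not something taken from this article's code.

```python
import numpy as np

rng = np.random.default_rng(0)

X = rng.normal(size=(200, 20))
y = (X[:, 0] - X[:, 1] > 0).astype(float)
params = {
    "W1": rng.normal(size=(20, 16)) / np.sqrt(20),  # pretend these are pre-trained
    "w2": np.zeros(16),                             # new classifier head
    "b2": 0.0,
}


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def loss(params):
    p = sigmoid(np.maximum(X @ params["W1"], 0.0) @ params["w2"] + params["b2"])
    return -np.mean(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12))


def train(params, steps, lr_head, lr_base):
    """Gradient descent on mean binary cross-entropy; lr_base=0 keeps W1 frozen."""
    for _ in range(steps):
        h_pre = X @ params["W1"]
        h = np.maximum(h_pre, 0.0)
        p = sigmoid(h @ params["w2"] + params["b2"])
        d = (p - y) / len(y)                               # dLoss/dlogit
        if lr_base > 0.0:
            dh = np.outer(d, params["w2"]) * (h_pre > 0)   # backprop through ReLU
            params["W1"] -= lr_base * (X.T @ dh)
        params["w2"] -= lr_head * (h.T @ d)
        params["b2"] -= lr_head * d.sum()


# Stage 1: train only the new head on top of the frozen layer.
W1_pretrained = params["W1"].copy()
train(params, steps=500, lr_head=0.5, lr_base=0.0)
W1_after_stage1 = params["W1"].copy()
loss_head_only = loss(params)

# Stage 2 (fine-tuning): unfreeze W1 with a much smaller learning rate
# so the pre-trained weights are only gently adjusted.
train(params, steps=500, lr_head=0.5, lr_base=0.01)
loss_finetuned = loss(params)
print(f"head only: {loss_head_only:.3f}, after fine-tuning: {loss_finetuned:.3f}")
```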

Data Augmentation

One of the biggest bottlenecks in training performant and robust machine learning models is the lack of high-quality data. Transfer learning helps address this, but the resulting model may still not be performant enough for practical use without fine-tuning, which requires additional data. This problem can be partially alleviated through data augmentation, which generates new training samples from existing ones by rotating, flipping, recoloring, adding noise, and so on. Data augmentation can also make the model more robust to the noisy inputs it will encounter in real-world scenarios.
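A minimal NumPy sketch of three of these augmentations (random horizontal flip, brightness shift as a simple form of recoloring, and additive Gaussian noise) applied to an H x W x C float image in [0, 1]. The function name, parameter values, and the random placeholder image are our own assumptions; rotation is left to library layers such as the ones used later in this article.

```python
import numpy as np


def augment(image, rng, noise_std=0.05, max_brightness=0.2):
    """Return a randomly augmented copy of an HxWxC float image with values in [0, 1]."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1, :]                                  # horizontal flip
    out = out + rng.uniform(-max_brightness, max_brightness)   # brightness shift
    out = out + rng.normal(0.0, noise_std, size=out.shape)     # additive Gaussian noise
    return np.clip(out, 0.0, 1.0)                              # keep a valid pixel range


rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))   # placeholder; use a real photo in practice
original = image.copy()
batch = np.stack([augment(image, rng) for _ in range(8)])
print(batch.shape)
```

Each call produces a different variant of the same image, so one labelled example yields many slightly different training samples.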

Code Implementation

Now that we've understood the concepts of transfer learning, fine-tuning, and data augmentation, let's implement them in TensorFlow and PyTorch. As an example, we will train a Cat vs Dog classifier using ResNet50 pre-trained on ImageNet.

Step 1 & 2. Data Exploration and Preprocessing

For this example, we will use the dataset provided in the TensorFlow Core tutorial Transfer learning and fine-tuning, which contains 2,000 training images and 1,000 validation images of cats and dogs. We will later split the validation images further into validation and test sets. The file structure is organized as follows:

cats_and_dogs_filtered/
  ├── train/
  │  ├── cats/
  │  └── dogs/
  ├── validation/
  │  ├── cats/
  │  └── dogs/
  └── vectorize.py

From the above directory, you can create datasets in both TensorFlow and PyTorch.

import os
 
import tensorflow as tf
import torch
from torchvision import datasets
 
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
 
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
 
BATCH_SIZE = 32
IMG_SIZE = (224, 224)
 
# TensorFlow
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)
 
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
 
AUTOTUNE = tf.data.AUTOTUNE  # Data prefetching
 
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
 
# PyTorch
train_dataset = datasets.ImageFolder(root=train_dir)
validation_dataset = datasets.ImageFolder(root=validation_dir)
validation_dataset, test_dataset = torch.utils.data.random_split(validation_dataset, [0.8, 0.2])
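The two frameworks split the 1,000 validation images slightly differently: in TensorFlow the dataset is already batched, so `take()`/`skip()` operate on whole batches, while PyTorch's `random_split` with fractions operates on individual samples. The arithmetic below reproduces both splits, assuming `BATCH_SIZE = 32` as above.

```python
import math

N_VAL_IMAGES = 1000   # images in cats_and_dogs_filtered/validation
BATCH_SIZE = 32

# TensorFlow: take()/skip() split whole batches.
val_batches = math.ceil(N_VAL_IMAGES / BATCH_SIZE)   # what cardinality() reports
tf_test_images = (val_batches // 5) * BATCH_SIZE     # the first batches are all full
tf_val_images = N_VAL_IMAGES - tf_test_images

# PyTorch: random_split with fractions splits individual samples.
torch_val_images = math.floor(0.8 * N_VAL_IMAGES)
torch_test_images = N_VAL_IMAGES - torch_val_images

print(val_batches, tf_val_images, tf_test_images)    # 32 808 192
print(torch_val_images, torch_test_images)           # 800 200
```

So the TensorFlow test set holds 192 images (6 batches) and the PyTorch one 200; the difference is harmless here but worth knowing when comparing results across frameworks.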

Unlike in the previous examples, we train on directories containing JPG images. Since there are only 2,000 training images, we use data augmentation to reduce overfitting. (The data augmentation is almost identical to the one specified for ResNet50 by PyTorch.)

# TensorFlow
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
])
 
# PyTorch
import torch
from torchvision import datasets, transforms
from torchvision.transforms import v2
 
train_transform = transforms.Compose([
    transforms.Resize(size=IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation((-72, 72)),  # matches Keras RandomRotation(0.2): 0.2 * 360 = 72 degrees
    ## Comment out the transformations below when plotting the images (plt.imshow needs PIL images)
    transforms.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet mean/std
])
test_transform = transforms.Compose([
    transforms.Resize(size=IMG_SIZE),
    transforms.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet mean/std
])
 
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
validation_dataset = datasets.ImageFolder(root=validation_dir, transform=test_transform)
validation_dataset, test_dataset = torch.utils.data.random_split(validation_dataset, [0.8, 0.2])
 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In TensorFlow, layers such as RandomFlip and RandomRotation exist solely for data augmentation and contain no trainable weights. They are only active during training; at inference time, including evaluation on the validation and test datasets, they pass their inputs through unchanged. PyTorch offers similar functionality, but the augmentation is attached to the dataset as a transform and applied on the fly each time a sample is loaded. In PyTorch, resizing, normalization, and other preprocessing steps therefore also need to be defined here for every dataset. (You can also use other libraries to perform augmentation at any point.) You can visualize the results of both data augmentations using the code below.
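The "applied on the fly" behaviour can be illustrated without PyTorch. The class below is a hypothetical minimal dataset that follows the same `__len__`/`__getitem__` protocol as PyTorch datasets and, like `ImageFolder(..., transform=...)`, stores raw samples and runs the transform every time a sample is read. A 1-D tuple stands in for an image, and `random_hflip` is our own helper, not a torchvision function.

```python
import random


class TransformedDataset:
    """Hypothetical minimal dataset with a PyTorch-style interface."""

    def __init__(self, samples, transform=None):
        self.samples = samples        # raw (input, label) pairs, never modified
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, label = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)     # applied lazily, on every read
        return x, label


def random_hflip(row, p=0.5):
    # Reversing the tuple plays the role of a horizontal flip.
    return row[::-1] if random.random() < p else row


random.seed(0)
dataset = TransformedDataset([((1, 2, 3), 0)], transform=random_hflip)
# Reading the same index repeatedly (as happens across epochs) yields different views.
views = {dataset[0][0] for _ in range(20)}
print(views)
```

Because the stored samples are never changed, every epoch sees a freshly augmented version of each image.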

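import matplotlib.pyplot as plt
 
# TensorFlow
for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 4))
  for i in range(9):
    ax = plt.subplot(2, 5, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(image[i], 0))
    plt.imshow(augmented_image[0] / 255)  # images are floats in [0, 255]
    plt.axis('off')
 
# PyTorch (with the tensor conversion and normalization transforms commented out,
# each sample is a PIL image that plt.imshow can display directly)
plt.figure(figsize=(10, 4))
for i in range(9):
  ax = plt.subplot(2, 5, i + 1)
  plt.imshow(train_dataset[i][0])
  plt.axis('off')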

The output below shows the result of data augmentation in the PyTorch implementation, confirming that the data augmentation is applied as expected. Thus, we can move on to Step 3.

Figure: Data augmentation results (PyTorch)

Step 3. Model

Below is an example implementation of transfer learning and fine-tuning in TensorFlow and PyTorch.

If you're following along, I highly recommend setting finetune=False: this task is very simple, and ImageNet already contains many cat and dog classes. Fine-tuning the feature extraction adds little to no value in this case and only complicates training. The above example includes fine-tuning purely for technical demonstration; it is useful when the task is complex and the images are unlikely to resemble those seen by the pre-trained model.

Since there’s not much to cover in Step 4, Model Evaluation, it is omitted from this article. I recommend you try evaluating the model for practice. (Spoiler Alert: The above model learns relatively quickly, classifying the test images with almost 100% accuracy after only a few epochs.)

Conclusion

There are many other pre-trained models available, so I recommend checking out Models and pre-trained weights by PyTorch or Module: tf.keras.applications by TensorFlow for alternative ways of using pre-trained models and for a list of the available ones.

The concepts of transfer learning, fine-tuning, and data augmentation also apply to other machine learning tasks and data types, and they remain highly relevant in today's machine learning landscape. Large Language Models (LLMs) are a prime example: fine-tuning is frequently used to adapt a generic language model to specific needs. Speaking of LLMs, starting with the next article, we will dive into natural language processing.

Resources

TensorFlow. n.d. Transfer learning and fine-tuning. TensorFlow Core.