Model fine-tuning#

Model fine-tuning is the process of taking a model that has already been pre-trained on some large, diverse task and adapting it to the task of interest. This can yield large performance gains for a given sample size. The closer the target task is to the pre-training task, the better the transfer; however, one usually sees benefits even when the tasks are quite different (e.g. ImageNet -> medical ultrasound).

To demonstrate, we will start with the previous notebook and swap in a pre-trained model.

# use autoreload because, by default, Python will not re-import modules that have changed
%load_ext autoreload
%autoreload 2
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

Settings#

data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"

# Model and Training
epochs = 5              # number of training epochs
batch_size = 128        # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10        # parallel data loading to speed things up
lr = 0.1                # learning rate (default: 0.1)
gamma = 0.7             # learning rate step gamma (default: 0.7)
no_cuda = False         # disables CUDA training (default: False)
seed = 42               # random seed (default: 42)
log_interval = 10       # how many batches to wait before logging training status (default: 10)
save_model = False      # save the trained model (default: False)

# additional derived settings
use_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

print("Device:", device)
Device: cuda

Dataset#

from utils import data

# transforms (we may wish to experiment with these, so leave them as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    # the canonical MNIST mean/std, commonly reused for EMNIST
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms

train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)

# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
Training dataset: Dataset EMNIST
    Number of datapoints: 112800
    Root location: /scratch/dane2/data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
Testing dataset: Dataset EMNIST
    Number of datapoints: 18800
    Root location: /scratch/dane2/data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
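
The data helpers above wrap torchvision's EMNIST dataset. Their implementation isn't shown in this notebook, but a minimal sketch consistent with the output above (the "balanced" split, with 112,800 training and 18,800 test images across 47 classes) might look like this:

# possible utils/data.py (a sketch; the real helpers may differ)
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST

def get_train_dataloader(data_dir, train_transforms, batch_size, num_workers):
    ds = EMNIST(data_dir, split="balanced", train=True,
                download=True, transform=train_transforms)
    return DataLoader(ds, batch_size=batch_size,
                      num_workers=num_workers, shuffle=True)

def get_test_dataloader(data_dir, test_transforms, batch_size, num_workers):
    ds = EMNIST(data_dir, split="balanced", train=False,
                download=True, transform=test_transforms)
    return DataLoader(ds, batch_size=batch_size,
                      num_workers=num_workers, shuffle=False)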

Model definition#

The torchvision library provides many pre-defined model architectures and trained model weights. Many more models can be downloaded using PyTorch Image Models (timm) and the Hugging Face libraries.
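
For instance, timm can load a pretrained backbone in one call and can even adapt the input channels and classifier head for us (a sketch, assuming timm is installed; in_chans and num_classes are timm arguments, not torchvision's):

import timm

# one-line alternative (not used below): ResNet18 with ImageNet weights,
# adapted to 1-channel inputs and a 47-class head
model_timm = timm.create_model("resnet18", pretrained=True,
                               in_chans=1, num_classes=47)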

from torchvision.models import resnet18, ResNet18_Weights

# pretrained weights, with an advertised top-1 accuracy of 69.758% on the ImageNet validation set
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# note: the input convolution expects 3 channels. Why is this a problem?
model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
# let's make sure we can run a batch of data through the model
with torch.no_grad():
    x, y = next(iter(train_loader))
    
    try:
        y_hat = model(x)
        print(y_hat.shape, y_hat)
    except RuntimeError as e: 
        print("RuntimeError:", e)
    
# it fails because the pretrained stem expects 3 color channels, but EMNIST images have only 1
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[128, 1, 28, 28] to have 3 channels, but got 1 channels instead

To solve this issue, let’s swap out the initial convolution layer for one that expects a single channel. This new convolution will be trained from scratch.

model.conv1
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.conv1 = torch.nn.Conv2d(
    in_channels=1,  # we changed this
    out_channels=model.conv1.out_channels,
    kernel_size=model.conv1.kernel_size,
    stride=model.conv1.stride,
    padding=model.conv1.padding,
    bias=model.conv1.bias is not None  # Conv2d expects a bool here, not a tensor or None
)
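
Note that replacing the stem throws away its pretrained weights. An alternative worth knowing about (not applied here) is to keep them by collapsing the three pretrained RGB filters into one channel, which often transfers a bit better:

# optional (a sketch, not used in this notebook): initialize the 1-channel
# stem from the pretrained 3-channel weights by summing over the color
# dimension, preserving the learned filters
pretrained_stem = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).conv1
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_stem.weight.sum(dim=1, keepdim=True))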

model
ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
print("Number of parameters:", sum(p.numel() for p in model.parameters()))
Number of parameters: 11683240
# let's try again
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
    
y_hat.shape, y_hat
# The output size is wrong: 1000 ImageNet classes instead of EMNIST's 47!
(torch.Size([128, 1000]),
 tensor([[ 2.8179,  1.0689, -1.9409,  ...,  5.1819, -0.3923, -2.8032],
         [-1.2234, -1.5119, -2.3858,  ..., -1.5502,  3.3604, -0.6424],
         [ 0.6216,  2.9419,  2.5593,  ...,  1.7307,  1.4611,  0.2213],
         ...,
         [ 0.9893, -1.8194, -2.8329,  ...,  1.8187,  1.2180,  3.1364],
         [ 7.4379,  0.9004,  0.0970,  ..., -3.9618,  4.8143, -1.9797],
         [-1.3743,  0.2013, -0.7154,  ..., -1.4267, -0.3651, -0.3076]]))
# let's install a fresh classification head sized for EMNIST's 47 classes
model.fc = torch.nn.Linear(512, 47, bias=True)
# let's try again
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
    
y_hat.shape, y_hat
(torch.Size([128, 47]),
 tensor([[-0.2310,  2.5642, -1.1058,  ..., -1.1087, -1.2465,  1.3283],
         [-0.5069,  0.3748,  0.3644,  ...,  0.5988,  0.6146,  0.0753],
         [ 0.0221,  0.8552,  0.5377,  ..., -0.3202, -0.4839,  0.7633],
         ...,
         [-0.8688,  1.5012,  0.1901,  ..., -0.5180,  0.5192,  0.0882],
         [-0.2565,  1.0510, -0.1019,  ..., -0.7193,  1.3646,  0.1110],
         [ 0.5696,  0.2535, -0.4605,  ...,  1.0174, -0.0918,  2.3605]]))

Looks good! Our model is ready for training.
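
Here we fine-tune every weight in the network. Another common option (not used in this notebook) is feature extraction: freeze the pretrained backbone and train only the layers we replaced, which is cheaper and can help when data is scarce. A sketch:

# feature extraction variant (hypothetical; the rest of this notebook
# fine-tunes all weights instead)
for p in model.parameters():
    p.requires_grad = False          # freeze everything...
for p in model.conv1.parameters():
    p.requires_grad = True           # ...except the new stem
for p in model.fc.parameters():
    p.requires_grad = True           # ...and the new head

print("Trainable parameters:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))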

Scriptify model creation#

Now that we’ve got this working, it would be a good idea to move this logic into our models.py script. When doing so, we might want to pass the weights in as an argument. This will allow us to load different pretrained weights, or none at all for random initialization.
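
A possible implementation (a sketch; the actual contents of utils/models.py are not shown in this notebook):

# sketch of utils/models.py (hypothetical; the real file may differ)
import torch
from torchvision.models import resnet18

def make_resnet18_model(weights=None, num_classes=47):
    model = resnet18(weights=weights)
    # 1-channel stem, trained from scratch
    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2,
                                  padding=3, bias=False)
    # fresh classification head
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes, bias=True)
    return model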

from utils import models
model_pretrained = models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
model_random = models.make_resnet18_model(weights=None)

Training and testing#

We can re-use our training code. Note: training will take much longer than before because ResNet18 is a much larger model.

from utils import training
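
For reference, here is a minimal sketch of what training.train_and_test might look like (hypothetical; the real utils/training.py is not shown, and the optimizer and logging details are assumptions chosen to match the output below):

import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

def train_and_test(model, train_loader, test_loader, epochs, lr, gamma, device):
    optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        # training pass
        model.train()
        total_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[TRAIN] epoch {epoch}: Average loss: {total_loss / len(train_loader.dataset):.4f}")
        # evaluation pass
        model.eval()
        test_loss, correct = 0.0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                test_loss += F.cross_entropy(out, y, reduction="sum").item()
                correct += (out.argmax(dim=1) == y).sum().item()
        n = len(test_loader.dataset)
        print(f"[TEST] epoch {epoch}: Average loss: {test_loss / n:.4f}, "
              f"Accuracy: {correct}/{n} ({100.0 * correct / n:.2f}%)")
        scheduler.step()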

Random weight model#

_ = model_random.to(device)
training.train_and_test(model_random, train_loader, test_loader, epochs, lr, gamma, device)
[TRAIN] epoch 1: Average loss: 0.0047                                                                 
[TEST] epoch 1: Average loss: 0.4532, Accuracy: 15790/18800 (83.99%)
[TRAIN] epoch 2: Average loss: 0.0025                                                                 
[TEST] epoch 2: Average loss: 0.3840, Accuracy: 16197/18800 (86.15%)
[TRAIN] epoch 3: Average loss: 0.0020                                                                 
[TEST] epoch 3: Average loss: 0.3500, Accuracy: 16482/18800 (87.67%)
[TRAIN] epoch 4: Average loss: 0.0017                                                                 
[TEST] epoch 4: Average loss: 0.3358, Accuracy: 16568/18800 (88.13%)
[TRAIN] epoch 5: Average loss: 0.0014                                                                 
[TEST] epoch 5: Average loss: 0.3333, Accuracy: 16624/18800 (88.43%)

Pretrained model#

_ = model_pretrained.to(device)
training.train_and_test(model_pretrained, train_loader, test_loader, epochs, lr, gamma, device)
[TRAIN] epoch 1: Average loss: 0.0055                                                                 
[TEST] epoch 1: Average loss: 0.5303, Accuracy: 15412/18800 (81.98%)
[TRAIN] epoch 2: Average loss: 0.0028                                                                 
[TEST] epoch 2: Average loss: 0.3655, Accuracy: 16403/18800 (87.25%)
[TRAIN] epoch 3: Average loss: 0.0022                                                                 
[TEST] epoch 3: Average loss: 0.3439, Accuracy: 16555/18800 (88.06%)
[TRAIN] epoch 4: Average loss: 0.0020                                                                 
[TEST] epoch 4: Average loss: 0.3309, Accuracy: 16597/18800 (88.28%)
[TRAIN] epoch 5: Average loss: 0.0017                                                                 
[TEST] epoch 5: Average loss: 0.3347, Accuracy: 16634/18800 (88.48%)

Conclusions: both versions outperform our custom architecture, though there is not a large benefit from fine-tuning here. This is probably because ImageNet images are quite different from EMNIST’s handwritten characters, so the pretrained features transfer only weakly.