Model fine-tuning
Model fine-tuning is the process of taking a model that has already been pre-trained on a large, diverse task and then continuing to train it on the task of interest. This can substantially improve performance for a given training-set size. The closer the target task is to the pre-training task, the better the transfer. However, one usually sees benefits even when the tasks are quite different (e.g., ImageNet → medical ultrasound).
To demonstrate, we will start with the previous notebook and swap in a pre-trained model.
# use autoreload because, by default, Python will not re-import a module that has already been imported, so edits to our utils modules would be ignored
%load_ext autoreload
%autoreload 2
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
Settings
data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"
# Model and Training
epochs = 5  # number of training epochs
batch_size = 128  # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10  # parallel data loading to speed things up
lr = 0.1  # learning rate (default: 0.1)
gamma = 0.7  # learning rate step gamma (default: 0.7)
no_cuda = False  # disables CUDA training (default: False)
seed = 42  # random seed (default: 42)
log_interval = 10  # how many batches to wait before logging training status (default: 10)
save_model = False  # save the trained model (default: False)
# additional derived settings
use_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
print("Device:", device)
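The `gamma` setting is described only as a "learning rate step gamma". The sketch below assumes that `utils.training` (not shown here) wraps its optimizer in `torch.optim.lr_scheduler.StepLR` with `step_size=1`. This is an assumption, since these settings resemble the PyTorch MNIST example. Under it, the learning rate used in epoch k (counting from 0) is `lr * gamma**k`.

```python
import torch

# Assumption: the training code multiplies the learning rate by gamma once per
# epoch via StepLR(step_size=1, gamma=gamma).
lr, gamma, epochs = 0.1, 0.7, 5

# A throwaway parameter and optimizer, used only to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

lr_per_epoch = []
for epoch in range(epochs):
    # Record the learning rate that would be used during this epoch.
    lr_per_epoch.append(optimizer.param_groups[0]["lr"])
    optimizer.step()  # a real loop would train on the whole epoch here
    scheduler.step()  # decay: lr <- lr * gamma

print([round(v, 5) for v in lr_per_epoch])  # [0.1, 0.07, 0.049, 0.0343, 0.02401]
```

With `epochs = 5` the last epoch therefore runs at roughly a quarter of the initial learning rate.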
Dataset
from utils import data
# transforms (we may wish to experiment with these, so keep them as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms
train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)
# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
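The constants in `Normalize((0.1307,), (0.3081,))` are the widely used MNIST pixel mean and standard deviation. They are not verified here against this dataset. Below is a minimal sketch for measuring per-channel statistics yourself. `channel_mean_std` is a hypothetical helper, and it assumes the loader yields `(images, labels)` with images already scaled to [0, 1] by `ToTensor()` and *not* yet normalized.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader):
    """Per-channel mean and population std over every pixel in `loader`.

    Assumes batches are (images, labels) pairs with images shaped (N, C, H, W)
    and scaled to [0, 1] (e.g. by transforms.ToTensor() alone).
    """
    total = total_sq = None
    count = 0
    for images, _ in loader:
        images = images.double()  # float64 limits rounding error in the running sums
        s = images.sum(dim=(0, 2, 3))           # sum over batch, height, width
        sq = (images ** 2).sum(dim=(0, 2, 3))   # sum of squares, same dims
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    std = (total_sq / count - mean ** 2).sqrt()  # sqrt(E[x^2] - E[x]^2)
    return mean, std


# Usage on synthetic single-channel images standing in for the real data.
x = torch.rand(100, 1, 28, 28)
loader = DataLoader(TensorDataset(x, torch.zeros(100)), batch_size=32)
mean, std = channel_mean_std(loader)
print(mean, std)  # roughly 0.5 and 0.29 for uniform noise

# In this notebook one could pass a loader built with ToTensor() only, e.g.
# data.get_train_dataloader(data_dir, transforms.ToTensor(), batch_size, num_workers)
```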
Model definition
The torchvision library provides many predefined model architectures along with pretrained weights. Many more models can be downloaded through the PyTorch Image Models (timm) and Hugging Face libraries.
from torchvision.models import resnet18, ResNet18_Weights
# pretrained ImageNet weights; torchvision reports 69.758% top-1 accuracy on the ImageNet-1K validation set
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# note: the input convolution expects 3 channels. Why is this a problem?
model
# let's make sure we can run a batch of data through the model
with torch.no_grad():
    x, y = next(iter(train_loader))
    try:
        y_hat = model(x)
        print(y_hat.shape, y_hat)
    except RuntimeError as e:
        print("RuntimeError:", e)
# we can't because we didn't deal with the fact that the model expects 3 color channels
To solve this issue, let’s just swap out the initial convolution layer with one expecting a single channel. This convolution will be trained from scratch.
model.conv1
model.conv1 = torch.nn.Conv2d(
    in_channels=1,  # we changed this (our images are grayscale)
    out_channels=model.conv1.out_channels,
    kernel_size=model.conv1.kernel_size,
    stride=model.conv1.stride,
    padding=model.conv1.padding,
    bias=model.conv1.bias is not None,  # Conv2d expects a bool here, not the bias tensor itself
)
model
print("Number of parameters:", sum(p.numel() for p in model.parameters()))
# let's try again
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
y_hat.shape, y_hat
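# The output size is wrong: the ImageNet head predicts 1000 classes, but our dataset has 47.
# let's install a fresh classification head
model.fc = torch.nn.Linear(model.fc.in_features, 47, bias=True)  # in_features is 512 for ResNet18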
# let's try again
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
y_hat.shape, y_hat
Looks good! Our model is ready for training.
Scriptify model creation
Now that we’ve got this working, it would be a good idea to put this logic into our models.py script. When doing so, we might want to pass in the weights as an argument. This will allow us to load in different pretrained weights or none at all for random initialization.
from utils import models
model_pretrained = models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
model_random = models.make_resnet18_model(weights=None)
Training and testing
We can reuse our training code. Note: training will take much longer because ResNet18 (roughly 11 million parameters) is much larger than our custom model.
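For readers without the previous notebook at hand, here is a hypothetical sketch of what a function with the signature `train_and_test(model, train_loader, test_loader, epochs, lr, gamma, device)` might do. The choice of SGD, cross-entropy loss and a per-epoch `StepLR` decay are assumptions, not the actual contents of `utils/training.py`.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_and_test(model, train_loader, test_loader, epochs, lr, gamma, device):
    """Hypothetical sketch: SGD + cross-entropy, lr multiplied by gamma each epoch."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
        scheduler.step()  # decay the learning rate once per epoch

        model.eval()
        correct, total, test_loss = 0, 0, 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                test_loss += F.cross_entropy(logits, y, reduction="sum").item()
                correct += (logits.argmax(dim=1) == y).sum().item()
                total += y.numel()
        acc = correct / total
        print(f"epoch {epoch}: test loss {test_loss / total:.4f}, accuracy {acc:.2%}")
        history.append(acc)
    return history


# Usage on tiny synthetic data with a tiny stand-in model (47 classes, 1x28x28 inputs).
x = torch.randn(64, 1, 28, 28)
y = torch.randint(0, 47, (64,))
loader = DataLoader(TensorDataset(x, y), batch_size=16)
tiny = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 47))
hist = train_and_test(tiny, loader, loader, epochs=2, lr=0.1, gamma=0.7, device="cpu")
```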
from utils import training
Randomly initialized model
_ = model_random.to(device)
training.train_and_test(model_random, train_loader, test_loader, epochs, lr, gamma, device)
Pretrained model
_ = model_pretrained.to(device)
training.train_and_test(model_pretrained, train_loader, test_loader, epochs, lr, gamma, device)
Conclusions: both versions outperform our custom architecture, though fine-tuning from pretrained weights does not give a large additional benefit. This is probably because ImageNet (natural color photographs) is very different from EMNIST (grayscale handwritten characters).