PyTorch Lightning
PyTorch Lightning wraps your PyTorch code and implements many common workflows. For instance, training and testing loops look very similar from one project to the next, and Lightning means you don’t have to rewrite this boilerplate every time. The best way to understand it is simply to implement a PyTorch Lightning model. Let’s take our previous EMNIST code and refactor it as a PyTorch Lightning model.
# use autoreload because, by default, Python will not re-import modules after they change
%load_ext autoreload
%autoreload 2
import os
import torch
from torchvision import transforms
from utils import models
from torchvision.models import resnet18, ResNet18_Weights
Settings
We don’t specify anything about the device here. PyTorch Lightning automatically detects and uses our GPU.
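For comparison, a plain PyTorch script has to choose a device and move the model and every batch onto it by hand. The minimal sketch below (assuming only torch, with a hypothetical toy linear layer) shows that manual step. In recent Lightning versions, the Trainer's default accelerator="auto" and devices="auto" settings make this choice for us and also move each batch to the selected device.

```python
import torch

def pick_device() -> torch.device:
    # Manual device choice: prefer a CUDA GPU, otherwise fall back to the CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
# Plain PyTorch: the model AND every batch must be moved to the device by hand
toy_model = torch.nn.Linear(4, 2).to(device)  # hypothetical toy model
batch = torch.randn(3, 4).to(device)
out = toy_model(batch)
print(device, tuple(out.shape))
```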
data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"
# Model and Training
epochs = 5              # number of training epochs
batch_size = 128        # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10        # parallel data loading to speed things up
lr = 0.1                # learning rate (default: 0.1)
gamma = 0.7             # learning rate step gamma (default: 0.7)
seed = 42               # random seed (default: 42)
torch.manual_seed(seed) # apply the seed so weight init and shuffling are reproducible
Dataset
from utils import data
# transforms (we may wish to experiment with these so leave as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms
train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)
# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
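To see what the transform pipeline does to a single pixel, here is a pure-Python sketch. It assumes 8-bit input pixels, which ToTensor() scales into [0, 1], and it uses the mean 0.1307 and standard deviation 0.3081 passed to Normalize above (the widely used MNIST statistics, reused here for EMNIST). Normalize then computes (x - mean) / std for each pixel.

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize above

def to_unit_range(byte_value: int) -> float:
    # ToTensor() maps uint8 pixel values 0..255 to floats in [0.0, 1.0]
    return byte_value / 255.0

def normalize(pixel: float, mean: float = MEAN, std: float = STD) -> float:
    # Normalize subtracts the mean and divides by the standard deviation
    return (pixel - mean) / std

for raw in (0, 128, 255):
    print(raw, round(normalize(to_unit_range(raw)), 4))
```

A black pixel ends up slightly negative and a white pixel ends up near 2.8. This roughly centers the inputs around zero, which tends to make optimization better behaved.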
The Lightning Model
We implement Lightning models like normal PyTorch models: we define the architecture and a forward method that passes data through the model. In addition, we implement methods that define training, validation, and optimization.
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import lightning.pytorch as pl
import torchmetrics
# define the LightningModule
class LitModel(pl.LightningModule):
    def __init__(self, pytorch_model, lr, gamma):
        super().__init__()
        self.save_hyperparameters("lr", "gamma")
        self.model = pytorch_model
        self.lr = lr
        self.gamma = gamma
        # metrics
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=47)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=47)
    def forward(self, x):
        return self.model(x)
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # Lightning automatically puts the model in train mode;
        # gradient updates etc. are handled automatically
        # but can be customized if desired
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)
        self.log("train_loss", loss)
        self.train_acc(output, target)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss
    def validation_step(self, batch, batch_idx):
        # Lightning automatically puts the model in eval mode
        # and turns off gradient tracking
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)
        self.log("val_loss", loss)
        self.test_acc(output, target)
        self.log("test_acc", self.test_acc, on_step=True, on_epoch=True)
    def configure_optimizers(self):
        # optimize this module's own parameters (self.parameters())
        optimizer = optim.Adadelta(self.parameters(), lr=self.lr)
        # decay the learning rate by gamma once per epoch
        scheduler = StepLR(optimizer, step_size=1, gamma=self.gamma)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
# init the model
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
model = LitModel(pt_model, lr, gamma)
# we can call a Lightning model just like a normal PyTorch model
with torch.no_grad():
x, y = next(iter(train_loader))
y_hat = model(x)
y_hat.shape, y_hat
Looks good! Our model is ready for training.
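Our LitModel logs test_acc both per step and per epoch. The per-epoch value is computed from counts accumulated over all validation batches, not by averaging the per-batch accuracies. As a sketch of that idea, here is a hypothetical pure-Python RunningAccuracy class (not torchmetrics itself) that follows the same update / compute / reset pattern with micro averaging, meaning correct predictions divided by all samples seen. torchmetrics.Accuracy supports several averaging modes, so treat this as an illustration of accumulation, not of its exact behavior.

```python
class RunningAccuracy:
    """Hypothetical, simplified metric: accumulate counts over batches, compute at the end."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated state (e.g., at the start of each epoch)
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # Add this batch's counts to the running totals
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        # Accuracy over every sample seen since the last reset
        return self.correct / self.total if self.total else 0.0

acc = RunningAccuracy()
acc.update([1, 2, 3], [1, 2, 0])  # batch 1: 2 of 3 correct
acc.update([4], [4])              # batch 2: 1 of 1 correct
epoch_acc = acc.compute()         # 3 / 4 = 0.75, not the mean of 2/3 and 1/1
acc.reset()
```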
Training and testing
We no longer need our own training and testing functions. Lightning constructs the appropriate training loop from the methods defined in our LightningModule.
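To make that concrete, here is a simplified, hypothetical manual_fit function that calls the same hooks our LitModel defines (configure_optimizers, training_step, validation_step) in roughly the order Trainer.fit does. This is a sketch under stated assumptions. It uses a toy ToyModule, which is a plain torch.nn.Module with one linear layer rather than a LightningModule, and plain SGD instead of Adadelta. It also leaves out most of what the real Trainer handles, such as device placement, logging, callbacks, checkpointing, and the validation sanity check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModule(nn.Module):
    """Hypothetical stand-in exposing the same hook names as LitModel."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

def manual_fit(module, train_batches, val_batches, max_epochs):
    """Simplified version of the loop Trainer.fit runs; returns mean val loss per epoch."""
    config = module.configure_optimizers()
    optimizer, scheduler = config["optimizer"], config["lr_scheduler"]
    history = []
    for epoch in range(max_epochs):
        module.train()  # train mode (matters for dropout / batch norm)
        for batch_idx, batch in enumerate(train_batches):
            loss = module.training_step(batch, batch_idx)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        module.eval()  # eval mode plus no gradient tracking for validation
        with torch.no_grad():
            val_losses = [module.validation_step(b, i) for i, b in enumerate(val_batches)]
        history.append(torch.stack(val_losses).mean().item())
        scheduler.step()  # epoch-interval scheduler: one step per epoch
    return history

# Usage on a tiny synthetic regression problem (y = 2x + 1)
torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
batches = [(x[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]
history = manual_fit(ToyModule(), batches, batches, max_epochs=5)
print(history)
```

Every line of manual_fit is code we would otherwise rewrite for each project. With Lightning, we only supply the hooks.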
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch import Trainer
# Without this line, there's a warning that points us to using it.
# This allows a slight tradeoff of precision for speed with our GPU's tensor cores.
# See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
torch.set_float32_matmul_precision('medium')
# a logger to save results
csv_logger = pl_loggers.CSVLogger(save_dir="logs/")
# the trainer class has about a million arguments. For now, the defaults will suffice.
trainer = Trainer(max_epochs=epochs, logger=csv_logger)
trainer.fit(model, train_loader, test_loader)
See logs/lightning_logs for results.
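With save_dir="logs/" and the default logger name, CSVLogger gives each run its own version_N folder under logs/lightning_logs, and each folder contains a metrics.csv file. Each row holds only the values logged at that point, so the other metric columns are left blank. The helpers below are a hypothetical pure-Python sketch, assuming that layout, for finding the newest run and pulling one metric out of it.

```python
import csv
import re
from pathlib import Path

def latest_metrics_csv(root="logs/lightning_logs"):
    """Return metrics.csv from the highest-numbered version_N folder (assumed layout)."""
    runs = []
    for d in Path(root).glob("version_*"):
        m = re.fullmatch(r"version_(\d+)", d.name)
        if m and (d / "metrics.csv").is_file():
            runs.append((int(m.group(1)), d / "metrics.csv"))  # sort numerically, not as text
    if not runs:
        raise FileNotFoundError(f"no version_*/metrics.csv under {root}")
    return max(runs)[1]

def column(path, name):
    """Non-empty values of one logged metric, as floats (blank cells are skipped)."""
    with open(path, newline="") as f:
        return [float(row[name]) for row in csv.DictReader(f) if row.get(name)]
```

For example, column(latest_metrics_csv(), "val_loss") would return the validation losses from the most recent run.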
Next session, we will look at some of the advanced features that we can access now that we have our model set up in Lightning:
Multi-gpu training
Automatic mixed precision
Advanced logging and dashboards
Performance profiling
Hyperparameter tuning
Coding challenge
The current model uses StepLR (inside the configure_optimizers method), which reduces the learning rate by a factor of gamma every step_size epochs. This creates a “staircase” pattern where the learning rate drops abruptly at regular intervals.
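The staircase can be written in closed form as lr * gamma ** (epoch // step_size), assuming the scheduler steps once per epoch (Lightning's default interval for a scheduler returned from configure_optimizers). The pure-Python sketch below evaluates it for this notebook's settings. Note that with step_size=1, each stair is exactly one epoch wide.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    # StepLR closed form: multiply by gamma once every step_size epochs
    return base_lr * gamma ** (epoch // step_size)

# This notebook's settings: lr=0.1, gamma=0.7, step_size=1
print([round(step_lr(0.1, 0.7, 1, e), 4) for e in range(5)])
# A wider staircase for comparison: the rate drops every 2 epochs
print([round(step_lr(0.1, 0.7, 2, e), 4) for e in range(5)])
```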
Try experimenting with a different scheduling strategy. Popular alternatives include ExponentialLR, CosineAnnealingLR, or ReduceLROnPlateau. You can select one of these or another built-in PyTorch scheduler, or implement your own custom scheduler. See the documentation for available options.
Documentation: PyTorch LR Schedulers | Lightning LR Scheduler Config
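Before editing LitModel, it can help to compare the candidates numerically. The pure-Python sketch below computes ExponentialLR and the closed form of CosineAnnealingLR, and includes a hypothetical, simplified PlateauSketch class that mimics the idea behind ReduceLROnPlateau(mode="min") but ignores its threshold and cooldown options. The factor and patience values are illustrative, not PyTorch's defaults. With this notebook's step_size=1, ExponentialLR with the same gamma produces exactly the same schedule as the current StepLR.

```python
import math

def exponential_lr(base_lr, gamma, epoch):
    # ExponentialLR: multiply by gamma every epoch
    # (identical to StepLR with step_size=1 and the same gamma)
    return base_lr * gamma ** epoch

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    # CosineAnnealingLR closed form: half a cosine from base_lr down to eta_min over t_max epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

class PlateauSketch:
    """Hypothetical, simplified ReduceLROnPlateau(mode="min"): multiply the LR by
    `factor` once the monitored value has not improved for more than `patience` epochs."""
    def __init__(self, lr, factor=0.1, patience=2):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric):
        if metric < self.best:      # improvement: remember it, reset the counter
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr

# Because ReduceLROnPlateau needs a metric, Lightning expects configure_optimizers to name it, e.g.
# {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
# ("val_loss" is what our validation_step logs).
print([round(exponential_lr(0.1, 0.7, e), 4) for e in range(5)])
print([round(cosine_annealing_lr(0.1, e, t_max=5), 4) for e in range(6)])
plateau = PlateauSketch(lr=0.1, factor=0.1, patience=2)
lrs = [plateau.step(v) for v in [1.0, 0.8, 0.8, 0.81, 0.79, 0.79, 0.8, 0.85]]
print(lrs)
```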
# Copy/paste the LitModel definition above, and modify it to use a different learning rate scheduler.
# init the model
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
model = LitModel(pt_model, lr, gamma)
torch.set_float32_matmul_precision('medium')
# a logger to save results
csv_logger = pl_loggers.CSVLogger(save_dir="logs/")
# the trainer class has about a million arguments. For now, the defaults will suffice.
trainer = Trainer(max_epochs=epochs, logger=csv_logger)
trainer.fit(model, train_loader, test_loader)
from utils.response import create_answer_box
create_answer_box(
"Which alternative did you use? What results did you observe?", "04-01")