PyTorch Lightning#
PyTorch Lightning wraps your PyTorch code and implements many common workflows. For instance, training and testing loops look much the same from one project to the next, and Lightning saves you from rewriting that boilerplate each time. The best way to understand it is to implement a PyTorch Lightning model, so let's take our previous EMNIST code and refactor it as one.
# use autoreload because, by default, python will not re-import modules
%load_ext autoreload
%autoreload 2
import os
import torch
from torchvision import transforms
from utils import models
from torchvision.models import resnet18, ResNet18_Weights
We don't specify anything about the device here. PyTorch Lightning will automatically detect and use our GPU.
data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/"
# Model and Training
epochs = 5  # number of training epochs
batch_size = 128  # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10  # parallel data loading to speed things up
lr = 0.1  # learning rate (default: 0.1)
gamma = 0.7  # learning rate step gamma (default: 0.7)
seed = 42  # random seed (default: 42)
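To make the role of `gamma` concrete: later in this page the optimizer is paired with a `StepLR` scheduler using `step_size=1`, which multiplies the learning rate by `gamma` once per epoch. A small worked example of that arithmetic in plain Python (a sketch of the decay rule only, not of the scheduler's internals; the `schedule` name is illustrative):

```python
lr = 0.1      # initial learning rate
gamma = 0.7   # multiplicative decay applied after every epoch
epochs = 5

# With step_size=1, the learning rate during epoch k is lr * gamma**k.
schedule = [lr * gamma ** epoch for epoch in range(epochs)]

for epoch, value in enumerate(schedule):
    print(f"epoch {epoch}: lr = {value:.5f}")
# epoch 0: lr = 0.10000
# epoch 1: lr = 0.07000
# epoch 2: lr = 0.04900
# epoch 3: lr = 0.03430
# epoch 4: lr = 0.02401
```

So by the final epoch the step size has shrunk to roughly a quarter of its starting value.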
from utils import data
# transforms (we may wish to experiment with these so leave as inputs)
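train_transforms = transforms.Compose([
    # convert PIL images to tensors (assumed step; Normalize requires tensor input)
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])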
test_transforms = train_transforms
train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)
# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
The Lightning Model#
We implement Lightning Models like normal Pytorch models: we define the architecture and a forward method for passing data through the model. In addition, we implement methods defining training, validation, and optimization.
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import lightning.pytorch as pl
import torchmetrics
# define the LightningModule
class LitModel(pl.LightningModule):
    def __init__(self, pytorch_model, lr, gamma):
        super().__init__()
        self.model = pytorch_model
        self.lr = lr
        self.gamma = gamma
        # metrics
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=47)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=47)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # lightning automatically puts the model in train mode
        # gradient updates etc. are handled automatically
        # but can be customized if desired
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)
        self.log("train_loss", loss)
        self.train_acc(output, target)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        # lightning automatically puts the model in eval mode
        # and turns off gradient tracking
        data, target = batch
        output = self.model(data)
        loss = F.cross_entropy(output, target)
        self.log("val_loss", loss)
        self.test_acc(output, target)
        self.log("test_acc", self.test_acc, on_step=True, on_epoch=True)

    def configure_optimizers(self):
        # optimize the wrapped model's parameters with the stored learning rate
        optimizer = optim.Adadelta(self.model.parameters(), lr=self.lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=self.gamma)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
# init the model
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
model = LitModel(pt_model, lr, gamma)
# we use a pytorch lightning model just like a normal model
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
y_hat.shape, y_hat
Looks good! Our model is ready for training.
Training and testing#
We no longer need our training/testing functions. Lightning constructs the appropriate training loop based on the definitions in our Lightning module.
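For comparison, here is roughly the kind of hand-written loop that Lightning takes over. This is a simplified sketch in plain PyTorch, not Lightning's actual implementation (which also handles device placement, logging, checkpointing, and more). The toy data and the `toy_model`, `toy_loader`, and `epoch_losses` names are assumptions made so the example runs on its own.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy stand-ins (assumptions): random 28x28 "images" with 47 classes.
toy_x = torch.randn(64, 1, 28, 28)
toy_y = torch.randint(0, 47, (64,))
toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 47))

optimizer = optim.Adadelta(toy_model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

epoch_losses = []
for epoch in range(2):
    # training phase: the part that training_step describes
    toy_model.train()
    for data, target in toy_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(toy_model(data), target)
        loss.backward()
        optimizer.step()
    scheduler.step()  # decay the learning rate once per epoch

    # validation phase: the part that validation_step describes
    toy_model.eval()
    with torch.no_grad():
        batch_losses = [F.cross_entropy(toy_model(d), t).item() for d, t in toy_loader]
    epoch_losses.append(sum(batch_losses) / len(batch_losses))

print(epoch_losses)
```

In the Lightning version, only the bodies of the two inner phases remain in our own code; the surrounding loops, mode switches, and optimizer calls are driven by the `Trainer`.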
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch import Trainer
# a logger to save results
csv_logger = pl_loggers.CSVLogger(save_dir="logs/")
# the trainer class has about a million arguments. For now, the defaults will suffice.
trainer = Trainer(max_epochs=epochs, logger=csv_logger)
trainer.fit(model, train_loader, test_loader)
See logs/lightning_logs for results.
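One way to inspect those results from Python is to read the CSV file directly. The sketch below assumes the `CSVLogger` default layout, in which each run writes `metrics.csv` into a `version_N` subdirectory, and that rows leave a column empty when that metric was not logged at that step. The helper names `latest_version_dir` and `read_metric` are hypothetical, not part of Lightning.

```python
import csv
from pathlib import Path


def latest_version_dir(root):
    """Return the highest-numbered version_N directory under root, or None."""
    versions = [
        p for p in Path(root).glob("version_*")
        if p.is_dir() and p.name.split("_")[-1].isdigit()
    ]
    return max(versions, key=lambda p: int(p.name.split("_")[-1]), default=None)


def read_metric(metrics_file, column):
    """Collect the non-empty values of one column as floats."""
    with open(metrics_file, newline="") as f:
        return [float(row[column]) for row in csv.DictReader(f) if row.get(column)]


# Usage: print the validation loss recorded by the most recent run, if any.
run_dir = latest_version_dir("logs/lightning_logs")
if run_dir is not None and (run_dir / "metrics.csv").exists():
    print(read_metric(run_dir / "metrics.csv", "val_loss"))
```

The same `read_metric` call works for the other logged names, such as `train_loss`.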
Next session, we will look at some of the advanced features that we can access now that we have our model set up in Lightning:
Multi-gpu training
Automatic mixed precision
Advanced logging and dashboards
Performance profiling
Hyperparameter tuning