Training Techniques#
The PyTorch Lightning Trainer class implements many advanced features to improve training speed, convergence, reproducibility, and more. In this notebook, we apply a number of these features to our EMNIST dataset.
import os
import torch
from torchvision import transforms
from utils import models
from torchvision.models import resnet18, ResNet18_Weights
Settings#
data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"
# Model and Training
epochs = 2              # number of training epochs
batch_size = 128        # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 2         # parallel data loading to speed things up
lr = 0.1                # learning rate (default: 0.1)
gamma = 0.7             # learning rate step gamma (default: 0.7)
seed = 42               # random seed (default: 42)
EMNIST Dataset#
from utils import data
# transforms (we may wish to experiment with these so leave as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms
train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)
# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
# test batch
x, y = next(iter(train_loader))
Baseline#
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
# Note: since the last notebook, we moved the LitModel logic into utils.models
model = models.LitModel(pt_model, lr, gamma)
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch import Trainer
# a logger to save results
csv_logger = pl_loggers.CSVLogger(save_dir="logs/")
# the trainer class has about a million arguments. For now, the defaults will suffice. See docs here: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html
trainer = Trainer(max_epochs=epochs, logger=csv_logger)
trainer.fit(model, train_loader, test_loader)
See logs/lightning_logs for results.
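To inspect the logged metrics programmatically, you can read them back with pandas. A minimal sketch, assuming the default layout used by CSVLogger (logs/lightning_logs/version_<n>/metrics.csv; adjust the version number to your most recent run):
import pandas as pd

# the CSV logger writes one metrics.csv per run under logs/lightning_logs/version_<n>/
metrics = pd.read_csv('logs/lightning_logs/version_0/metrics.csv')
print(metrics.tail())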
Model profiling#
Profiling tools show you how much time each part of your training code takes. This can help you identify which parts of your program to optimize in order to speed things up.
PyTorch has a built-in profiler. We can easily turn it on in Lightning by setting the profiler argument of the Trainer. Note that the PyTorch profiler forces synchronous CUDA execution, which slows training down.
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
# just need to set the profiler argument
trainer = Trainer(max_epochs=epochs, logger=csv_logger, profiler='simple')
trainer.fit(model, train_loader, test_loader)
Lightning supports several other performance profilers.
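For example, the advanced profiler (built on Python's cProfile) can be selected by name or configured explicitly. A sketch, assuming the lightning.pytorch.profilers module:
from lightning.pytorch.profilers import AdvancedProfiler

# write a detailed per-function profiling report to a file instead of printing it
profiler = AdvancedProfiler(dirpath='logs/', filename='perf_report')
trainer = Trainer(max_epochs=epochs, logger=csv_logger, profiler=profiler)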
Logging with Weights and Biases#
Logging is simply the act of recording data throughout model training and evaluation that can be used to make decisions about model development. Earlier, we used the CSV logger to record experimental results. Weights and Biases (WandB) is an online platform for logging training experiments. It provides a range of data collection and visualization tools to help you understand how your training is going. WandB is free for academic use, and it is very easy to integrate with PyTorch Lightning.
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(name='initial_run')
# just need to set the logger argument
trainer = Trainer(max_epochs=epochs, logger=wandb_logger)
trainer.fit(model, train_loader, test_loader)
# indicate that the run has finished
wandb_logger.experiment.finish()
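WandB can also track gradient and parameter histograms during training. One option (a sketch, using the WandbLogger.watch wrapper around wandb.watch; call it before trainer.fit in a future run) is:
# log gradient and parameter histograms every 100 training steps
wandb_logger.watch(model, log='all', log_freq=100)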
Automatic mixed precision (AMP)#
By default, PyTorch uses 32-bit floating point numbers. This means that each element of a tensor takes up 32 bits (4 bytes) of memory.
torch.randn(3).dtype
32-bit floats can store about 7 decimal digits of precision. This level of precision may not be necessary for many of the computations performed by a neural network. PyTorch supports several other floating point formats that we can make use of. For instance, we can allocate 16-bit tensors:
torch.randn(3, dtype=torch.float16).dtype
16-bit floats take up half the memory of 32-bit floats, so using them may allow us to train larger models on the same GPU hardware. In addition, modern GPU architectures can perform some calculations more efficiently with 16-bit numbers.
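You can check the per-element memory cost directly with Tensor.element_size(), which reports the number of bytes per element:
# 4 bytes per element for float32, 2 bytes for float16
print(torch.randn(3).element_size(), torch.randn(3, dtype=torch.float16).element_size())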
Unfortunately, it turns out that it's usually not a good idea to convert every part of the computation to 16-bit floats. The research community has developed good approaches for mixing 32-bit and 16-bit computation that get the benefits of lower precision without hurting model convergence. Manually setting all of this up is a headache, so PyTorch supports "Automatic Mixed Precision" (AMP) to perform the conversion automatically under the hood.
To use this functionality in plain PyTorch, you wrap the forward pass in the autocast context manager:
# Illustrative snippet: Net, data, and loss_fn stand in for your own model, dataloader, and loss.
# In recent PyTorch releases autocast and GradScaler live in torch.amp
# (older releases expose GradScaler under torch.cuda.amp).
import torch
from torch import optim
from torch.amp import autocast, GradScaler

# Creates model and optimizer in default (float32) precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
# This improves convergence by preventing gradient underflow.
scaler = GradScaler()

for epoch in range(epochs):
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss and calls backward() on the scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for the corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for the next iteration.
        scaler.update()
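On GPUs that support it, bfloat16 is a common alternative to float16: it keeps float32's exponent range, so gradient underflow is much less of a concern and the GradScaler can typically be dropped. A minimal sketch of the same inner loop under that assumption:
# bfloat16 keeps float32's exponent range, so no GradScaler is needed here
with autocast(device_type='cuda', dtype=torch.bfloat16):
    output = model(input)
    loss = loss_fn(output, target)
loss.backward()
optimizer.step()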
AMP in PyTorch Lightning#
Fortunately for us, it is extremely easy to enable AMP now that we have our model set up in PyTorch Lightning.
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
from lightning.pytorch.loggers import WandbLogger
wandb_logger = WandbLogger(name='AMP run')
trainer = Trainer(max_epochs=epochs, logger=wandb_logger,
                  precision=16)  # <-- this
trainer.fit(model, train_loader, test_loader)
wandb_logger.experiment.finish()
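Note that recent Lightning releases also accept string values for precision, such as '16-mixed' and 'bf16-mixed', which make the mixed-precision intent explicit. For example:
# equivalent request for 16-bit automatic mixed precision in Lightning 2.x
trainer = Trainer(max_epochs=epochs, precision='16-mixed')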
Model checkpointing#
Simply put, checkpointing is the process of periodically saving the model based on a metric that you monitor: if the metric has improved, save the model. In addition to saving the model weights, we need to save the hyperparameters. We do this by calling self.save_hyperparameters() in the initializer of the Lightning model:
class LitModel(pl.LightningModule):
    def __init__(self, pytorch_model, lr, gamma):
        super().__init__()
        self.save_hyperparameters()
        ...
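With that call in place, the saved values are available anywhere in the module through self.hparams, for example when configuring the optimizer. A sketch continuing the class above (the optimizer and scheduler shown are illustrative, not necessarily what utils.models uses):
    def configure_optimizers(self):
        # values recorded by save_hyperparameters() are exposed as self.hparams.<name>
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma)
        return [optimizer], [scheduler]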
Unlike the previous options, we need to use a callback to set up checkpointing.
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
# need to import the checkpoint callback object
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint("checkpoints/", save_top_k=3, monitor='val_loss')
# we add log_model='all' to save models on wandb
wandb_logger = WandbLogger(name='Run with Checkpointing', log_model='all')
# need to pass the checkpoint callback
trainer = Trainer(max_epochs=epochs, logger=wandb_logger, precision=16, callbacks = [checkpoint_callback])
trainer.fit(model, train_loader, test_loader)
wandb_logger.experiment.finish()
Creating the model from a local checkpoint#
To load a model from a checkpoint, we use the load_from_checkpoint method, passing in the path to the checkpoint file.
model = models.LitModel.load_from_checkpoint('checkpoints/epoch=1-step=1764.ckpt')
# when creating a trainer just for validation, we don't need to fuss over the arguments.
trainer = Trainer()
trainer.validate(model, test_loader)
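Instead of hard-coding a filename, you can also ask the checkpoint callback for the best checkpoint it saved during training (a sketch, assuming the checkpoint_callback from the run above is still in scope):
# the ModelCheckpoint callback records the path of its best-scoring checkpoint
print(checkpoint_callback.best_model_path)
model = models.LitModel.load_from_checkpoint(checkpoint_callback.best_model_path)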
Early Stopping#
Early stopping, as the name implies, is a technique for stopping training early if a monitored metric is not improving. This can save lots of time and compute resources. We set this up in Lightning using a callback.
# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)
# init the lazy layers
with torch.no_grad():
    pt_model(x)
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
# need to import the early stop callback object
from lightning.pytorch.callbacks import EarlyStopping
earlystop_callback = EarlyStopping(monitor="val_loss", min_delta=0.005, patience=3, verbose=False)
checkpoint_callback = ModelCheckpoint("checkpoints/", save_top_k=3, monitor='val_loss')
wandb_logger = WandbLogger(name='Early stopping run')
# need to pass the checkpoint callback
trainer = Trainer(max_epochs=20, logger=wandb_logger, precision=16,
                  callbacks=[earlystop_callback, checkpoint_callback])
trainer.fit(model, train_loader, test_loader)
wandb_logger.experiment.finish()
Multi-GPU#
Single-node, multi-GPU#
PyTorch Lightning makes this very easy. In fact, Lightning will automatically use all available GPUs by default. However, it is tricky to get this working in Jupyter notebooks. To demonstrate, we have created a script that you can download here.
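For reference, in a script the relevant Trainer arguments are accelerator, devices, and strategy; a minimal sketch for two GPUs on one node:
# use 2 GPUs on this node with distributed data parallel (one process per GPU)
trainer = Trainer(max_epochs=epochs, accelerator='gpu', devices=2, strategy='ddp')
trainer.fit(model, train_loader, test_loader)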
Multi-node, multi-GPU#
This is a bit trickier because it involves setting up communication across nodes. We have an example of how to set this up in our Palmetto Examples repository here.
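The Trainer side of a multi-node job looks much the same, with num_nodes added; the rest is handled by the job script (for example, one task per GPU under SLURM, which Lightning detects automatically). A sketch assuming 2 nodes with 2 GPUs each:
# 2 nodes x 2 GPUs = 4 DDP processes; node/rank details come from the SLURM environment
trainer = Trainer(max_epochs=epochs, accelerator='gpu', devices=2, num_nodes=2, strategy='ddp')
trainer.fit(model, train_loader, test_loader)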