Training Techniques

The PyTorch Lightning Trainer class implements many advanced features to improve training speed, convergence, reproducibility, etc. In this notebook, we apply a number of these features to our EMNIST dataset.

import os
import torch
from torchvision import transforms
from utils import models
from torchvision.models import resnet18, ResNet18_Weights

Settings

data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"

# Model and Training
epochs = 2              # number of training epochs
batch_size = 128        # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10        # parallel data-loading workers to speed things up
lr = 0.1                # learning rate (default: 0.1)
gamma = 0.7             # learning-rate step gamma (default: 0.7)
seed = 42               # random seed (default: 42)
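The lr and gamma settings are handed to LitModel, whose code lives in utils.models and is not shown here. Assuming it pairs its optimizer with torch.optim.lr_scheduler.StepLR(step_size=1, gamma=gamma), a common choice that is consistent with the two lr_scheduler_step calls in the two-epoch profiler report further down, the learning rate is multiplied by gamma after every epoch. The sketch below uses a throwaway SGD optimizer (an assumption, not necessarily what LitModel uses) purely to show the schedule.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

# same values as the settings above, renamed so the notebook variables are untouched
base_lr, step_gamma, num_epochs = 0.1, 0.7, 2

# a throwaway parameter so the optimizer has something to manage (assumed SGD)
param = nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=base_lr)
scheduler = StepLR(optimizer, step_size=1, gamma=step_gamma)  # assumed step_size=1

lrs = []
for epoch in range(num_epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()  # decay once per epoch: lr <- lr * gamma

print(lrs)
```

With lr=0.1 and gamma=0.7, epoch k (counting from 0) trains with 0.1 × 0.7^k, i.e. 0.1 and then 0.07 for the two epochs used here.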

EMNIST Dataset

from utils import data

# transforms (we may wish to experiment with these, so we keep them as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms

train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)

# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)
Training dataset: Dataset EMNIST
    Number of datapoints: 112800
    Root location: /scratch/dane2/data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
Testing dataset: Dataset EMNIST
    Number of datapoints: 18800
    Root location: /scratch/dane2/data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
# grab one training batch; it is used below to initialize the lazy layers
x, y = next(iter(train_loader))

Baseline

# init the classifier
pt_model = models.Classifier()  # alternative: models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)

# create the Lightning model
# Note: since the last notebook, we moved the LitModel logic into utils.models
model = models.LitModel(pt_model, lr, gamma)
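The dry forward pass above is needed because, as its comment says, models.Classifier contains lazy layers: modules such as torch.nn.LazyLinear only infer their weight shapes from the first batch they see. Until then the parameters have no shape, so nothing that needs it (the optimizer, or the 23.1 K parameter count in Lightning's model summary) can work with them. TinyLazyNet below is a hypothetical stand-in, not the notebook's Classifier, and its 47 output classes are an assumption matching EMNIST Balanced, whose 112,800/18,800 split sizes are printed above.

```python
import torch
from torch import nn


class TinyLazyNet(nn.Module):
    """Hypothetical stand-in for models.Classifier with one lazy layer."""

    def __init__(self, num_classes: int = 47):  # 47 classes assumed (EMNIST Balanced)
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.LazyLinear(num_classes)  # in_features is inferred on first call

    def forward(self, x):
        return self.fc(self.flatten(x))


net = TinyLazyNet()
print(isinstance(net.fc.weight, nn.parameter.UninitializedParameter))  # True: no shape yet

fake_batch = torch.randn(8, 1, 28, 28)  # same shape as an EMNIST batch, random values
with torch.no_grad():  # the dry run only needs to materialize the parameters
    net(fake_batch)

print(net.fc.weight.shape)  # torch.Size([47, 784]): in_features = 28 * 28
print(sum(p.numel() for p in net.parameters()))  # 47 * 784 + 47 = 36895
```

After the dry run the lazy layer turns into an ordinary nn.Linear(784, 47), so the optimizer and the model summary see fixed shapes.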
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch import Trainer

# a logger to save results
csv_logger = pl_loggers.CSVLogger(save_dir="logs/")

# the trainer class has about a million arguments. For now, the defaults will suffice.
trainer = Trainer(max_epochs=epochs, logger=csv_logger)
trainer.fit(model, train_loader, test_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.

The CSVLogger writes the logged metrics to logs/lightning_logs/version_N/metrics.csv, where N increases with each run.
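Because CSVLogger stores everything as plain CSV, a few lines of standard-library Python are enough to pull out the final values. This is a sketch: latest_metrics and last_values are made-up helpers, numbered version folders are assumed, and the column names in the demo are placeholders, since the real ones depend on what LitModel logs.

```python
import csv
import glob
import os
import tempfile


def latest_metrics(log_root="logs/lightning_logs"):
    """Rows of the newest version_N/metrics.csv under `log_root`.

    Assumes CSVLogger's default numbered version folders (version_0, version_1, ...).
    """
    paths = glob.glob(os.path.join(log_root, "version_*", "metrics.csv"))
    if not paths:
        raise FileNotFoundError(f"no metrics.csv found under {log_root}")

    def version_number(path):
        # ".../version_10/metrics.csv" -> 10; compare numerically so 10 beats 2
        return int(os.path.basename(os.path.dirname(path)).split("_", 1)[1])

    with open(max(paths, key=version_number), newline="") as f:
        return list(csv.DictReader(f))


def last_values(rows):
    """Each column's last non-empty value (metrics not logged at a step are blank)."""
    last = {}
    for row in rows:
        for key, value in row.items():
            if value:
                last[key] = value
    return last


# demo on a fake log tree; in the notebook, call last_values(latest_metrics())
with tempfile.TemporaryDirectory() as root:
    for version, loss in [(2, "0.9"), (10, "0.5")]:
        run_dir = os.path.join(root, f"version_{version}")
        os.makedirs(run_dir)
        with open(os.path.join(run_dir, "metrics.csv"), "w", newline="") as f:
            f.write("epoch,step,train_loss,val_acc\n")  # hypothetical column names
            f.write(f"0,49,{loss},\n")
            f.write("0,99,,0.8\n")
    demo = last_values(latest_metrics(root))
print(demo)
```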

Model profiling

Profiling tools show you how much time each part of your training code is taking. This can help you identify areas where your program should be optimized in order to speed things up.

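PyTorch has a built-in profiler, and Lightning makes profiling just as easy: set the profiler argument of the Trainer. Passing profiler='simple', as below, uses Lightning's SimpleProfiler, which records the wall-clock time spent in each training hook and loop step; passing a PyTorchProfiler instance (or profiler='pytorch') uses the PyTorch profiler instead. Note that the PyTorch profiler measures CUDA operations synchronously, even though many of them normally run asynchronously, so its wall-clock timings make training look slower than it really is.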
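To see what the underlying PyTorch profiler reports on its own, outside Lightning, here is a minimal sketch using torch.profiler directly. The small linear model and random batch are hypothetical stand-ins shaped like EMNIST data, not the notebook's Classifier; CUDA activity is only recorded when a GPU is available.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

# hypothetical stand-in model and batch (not the notebook's Classifier or data)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 47))
demo_batch = torch.randn(128, 1, 28, 28)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    # also record GPU kernels when a GPU is present
    activities.append(ProfilerActivity.CUDA)
    demo_model, demo_batch = demo_model.cuda(), demo_batch.cuda()

with profile(activities=activities) as prof:
    for _ in range(5):
        loss = demo_model(demo_batch).sum()
        loss.backward()

# group events by operator name and list the most expensive ones
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Unlike the hook-level report below, this table is broken down by individual operators (aten::addmm, aten::sum, and so on).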

# init the classifier
pt_model = models.Classifier()  # alternative: models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)

# create the Lightning model
model = models.LitModel(pt_model, lr, gamma)
# import from lightning.pytorch (not pytorch_lightning) to match the Trainer imported above
from lightning.pytorch.profilers import PyTorchProfiler
profiler = PyTorchProfiler()  # pass profiler=profiler below to use the PyTorch profiler instead

# just need to set the profiler argument; 'simple' produced the report below
trainer = Trainer(max_epochs=epochs, logger=csv_logger, profiler='simple')
trainer.fit(model, train_loader, test_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                          	|  -              	|  64189          	|  14.797         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                                                                                             	|  6.672          	|  2              	|  13.344         	|  90.182         	|
|  run_training_batch                                                                                                                                             	|  0.0041982      	|  1764           	|  7.4056         	|  50.049         	|
|  [LightningModule]LitModel.optimizer_step                                                                                                                       	|  0.0041312      	|  1764           	|  7.2875         	|  49.251         	|
|  [Strategy]SingleDeviceStrategy.training_step                                                                                                                   	|  0.0023097      	|  1764           	|  4.0742         	|  27.535         	|
|  [Callback]TQDMProgressBar.on_train_batch_end                                                                                                                   	|  0.0010392      	|  1764           	|  1.8332         	|  12.389         	|
|  [Strategy]SingleDeviceStrategy.backward                                                                                                                        	|  0.00097618     	|  1764           	|  1.722          	|  11.638         	|
|  [_EvaluationLoop].val_next                                                                                                                                     	|  0.016218       	|  41             	|  0.66495        	|  4.4939         	|
|  [_TrainingEpochLoop].train_dataloader_next                                                                                                                     	|  0.00020906     	|  1764           	|  0.36878        	|  2.4923         	|
|  [Strategy]SingleDeviceStrategy.batch_to_device                                                                                                                 	|  0.00016747     	|  1804           	|  0.30211        	|  2.0417         	|
|  [LightningModule]LitModel.transfer_batch_to_device                                                                                                             	|  0.0001305      	|  1804           	|  0.23541        	|  1.591          	|
|  [Strategy]SingleDeviceStrategy.validation_step                                                                                                                 	|  0.0034089      	|  40             	|  0.13635        	|  0.92152        	|
|  [LightningModule]LitModel.optimizer_zero_grad                                                                                                                  	|  7.4632e-05     	|  1764           	|  0.13165        	|  0.88972        	|
|  [Callback]TQDMProgressBar.on_validation_start                                                                                                                  	|  0.040434       	|  3              	|  0.1213         	|  0.81979        	|
|  [LightningModule]LitModel.configure_gradient_clipping                                                                                                          	|  1.8325e-05     	|  1764           	|  0.032326       	|  0.21846        	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_batch_end       	|  1.7976e-05     	|  1764           	|  0.031709       	|  0.2143         	|
|  [Callback]TQDMProgressBar.on_validation_batch_end                                                                                                              	|  0.00062314     	|  40             	|  0.024926       	|  0.16845        	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_epoch_end       	|  0.01069        	|  2              	|  0.02138        	|  0.14449        	|
|  [Callback]TQDMProgressBar.on_train_start                                                                                                                       	|  0.012711       	|  1              	|  0.012711       	|  0.085902       	|
|  [Callback]TQDMProgressBar.on_sanity_check_start                                                                                                                	|  0.011467       	|  1              	|  0.011467       	|  0.077495       	|
|  [Callback]ModelSummary.on_train_batch_end                                                                                                                      	|  3.5871e-06     	|  1764           	|  0.0063276      	|  0.042764       	|
|  [Callback]TQDMProgressBar.on_before_zero_grad                                                                                                                  	|  3.0775e-06     	|  1764           	|  0.0054287      	|  0.036688       	|
|  [Callback]TQDMProgressBar.on_validation_batch_start                                                                                                            	|  0.000128       	|  40             	|  0.0051199      	|  0.034601       	|
|  [Callback]TQDMProgressBar.on_train_batch_start                                                                                                                 	|  2.6424e-06     	|  1764           	|  0.0046612      	|  0.031502       	|
|  [Callback]TQDMProgressBar.on_after_backward                                                                                                                    	|  2.566e-06      	|  1764           	|  0.0045265      	|  0.030591       	|
|  [Callback]TQDMProgressBar.on_before_backward                                                                                                                   	|  2.4907e-06     	|  1764           	|  0.0043937      	|  0.029694       	|
|  [Callback]TQDMProgressBar.on_before_optimizer_step                                                                                                             	|  1.8815e-06     	|  1764           	|  0.0033191      	|  0.022431       	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_after_backward        	|  1.6706e-06     	|  1764           	|  0.0029469      	|  0.019916       	|
|  [LightningModule]LitModel.on_before_batch_transfer                                                                                                             	|  1.5896e-06     	|  1804           	|  0.0028677      	|  0.01938        	|
|  [Callback]ModelSummary.on_before_zero_grad                                                                                                                     	|  1.5617e-06     	|  1764           	|  0.0027549      	|  0.018618       	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_before_zero_grad      	|  1.5176e-06     	|  1764           	|  0.002677       	|  0.018092       	|
|  [Callback]ModelSummary.on_train_batch_start                                                                                                                    	|  1.4731e-06     	|  1764           	|  0.0025985      	|  0.017561       	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_batch_start     	|  1.4309e-06     	|  1764           	|  0.0025241      	|  0.017058       	|
|  [Callback]TQDMProgressBar.on_validation_end                                                                                                                    	|  0.00081106     	|  3              	|  0.0024332      	|  0.016444       	|
|  [Callback]ModelSummary.on_before_backward                                                                                                                      	|  1.3715e-06     	|  1764           	|  0.0024193      	|  0.01635        	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_before_backward       	|  1.3586e-06     	|  1764           	|  0.0023965      	|  0.016196       	|
|  [Callback]ModelSummary.on_after_backward                                                                                                                       	|  1.2958e-06     	|  1764           	|  0.0022858      	|  0.015448       	|
|  [LightningModule]LitModel.on_after_batch_transfer                                                                                                              	|  1.2316e-06     	|  1804           	|  0.0022217      	|  0.015015       	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_before_optimizer_step 	|  1.2388e-06     	|  1764           	|  0.0021852      	|  0.014768       	|
|  [Callback]ModelSummary.on_before_optimizer_step                                                                                                                	|  1.2288e-06     	|  1764           	|  0.0021675      	|  0.014649       	|
|  [LightningModule]LitModel.on_train_batch_end                                                                                                                   	|  1.1941e-06     	|  1764           	|  0.0021064      	|  0.014236       	|
|  [Callback]TQDMProgressBar.on_train_epoch_start                                                                                                                 	|  0.00097899     	|  2              	|  0.001958       	|  0.013233       	|
|  [LightningModule]LitModel.on_train_batch_start                                                                                                                 	|  1.0119e-06     	|  1764           	|  0.0017849      	|  0.012063       	|
|  [LightningModule]LitModel.on_before_zero_grad                                                                                                                  	|  9.9306e-07     	|  1764           	|  0.0017518      	|  0.011839       	|
|  [LightningModule]LitModel.on_after_backward                                                                                                                    	|  9.5494e-07     	|  1764           	|  0.0016845      	|  0.011384       	|
|  [LightningModule]LitModel.on_before_backward                                                                                                                   	|  9.1773e-07     	|  1764           	|  0.0016189      	|  0.010941       	|
|  [LightningModule]LitModel.on_before_optimizer_step                                                                                                             	|  8.0452e-07     	|  1764           	|  0.0014192      	|  0.0095912      	|
|  [Strategy]SingleDeviceStrategy.on_train_batch_start                                                                                                            	|  8.0104e-07     	|  1764           	|  0.001413       	|  0.0095497      	|
|  [Callback]ModelSummary.on_fit_start                                                                                                                            	|  0.001351       	|  1              	|  0.001351       	|  0.0091307      	|
|  [LightningModule]LitModel.on_validation_model_eval                                                                                                             	|  0.00039541     	|  3              	|  0.0011862      	|  0.0080169      	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.setup                    	|  0.00079733     	|  1              	|  0.00079733     	|  0.0053885      	|
|  [Callback]TQDMProgressBar.on_train_epoch_end                                                                                                                   	|  0.00034932     	|  2              	|  0.00069864     	|  0.0047216      	|
|  [Callback]TQDMProgressBar.on_train_end                                                                                                                         	|  0.00055111     	|  1              	|  0.00055111     	|  0.0037245      	|
|  [LightningModule]LitModel.on_validation_model_train                                                                                                            	|  0.00011774     	|  3              	|  0.00035323     	|  0.0023872      	|
|  [LightningModule]LitModel.configure_optimizers                                                                                                                 	|  0.00014666     	|  1              	|  0.00014666     	|  0.00099116     	|
|  [Callback]ModelSummary.on_validation_batch_end                                                                                                                 	|  3.5676e-06     	|  40             	|  0.0001427      	|  0.00096444     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_end        	|  4.3696e-05     	|  3              	|  0.00013109     	|  0.00088594     	|
|  [LightningModule]LitModel.lr_scheduler_step                                                                                                                    	|  5.1686e-05     	|  2              	|  0.00010337     	|  0.00069861     	|
|  [Callback]ModelSummary.on_validation_batch_start                                                                                                               	|  1.9563e-06     	|  40             	|  7.8253e-05     	|  0.00052886     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_batch_end  	|  1.7672e-06     	|  40             	|  7.0687e-05     	|  0.00047772     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_batch_start	|  1.6458e-06     	|  40             	|  6.5831e-05     	|  0.00044491     	|
|  [LightningModule]LitModel.on_validation_batch_end                                                                                                              	|  1.1852e-06     	|  40             	|  4.7408e-05     	|  0.0003204      	|
|  [LightningModule]LitModel.on_validation_batch_start                                                                                                            	|  9.9554e-07     	|  40             	|  3.9821e-05     	|  0.00026912     	|
|  [Strategy]SingleDeviceStrategy.on_validation_start                                                                                                             	|  7.5934e-06     	|  3              	|  2.278e-05      	|  0.00015395     	|
|  [Callback]ModelSummary.on_validation_start                                                                                                                     	|  7.4717e-06     	|  3              	|  2.2415e-05     	|  0.00015149     	|
|  [Callback]ModelSummary.on_validation_end                                                                                                                       	|  3.2149e-06     	|  3              	|  9.6448e-06     	|  6.5182e-05     	|
|  [Callback]TQDMProgressBar.on_validation_epoch_end                                                                                                              	|  2.7524e-06     	|  3              	|  8.2571e-06     	|  5.5804e-05     	|
|  [LightningModule]LitModel.on_train_epoch_end                                                                                                                   	|  3.9898e-06     	|  2              	|  7.9796e-06     	|  5.3928e-05     	|
|  [Callback]ModelSummary.on_train_epoch_end                                                                                                                      	|  3.2689e-06     	|  2              	|  6.5379e-06     	|  4.4185e-05     	|
|  [Callback]ModelSummary.on_train_epoch_start                                                                                                                    	|  3.2475e-06     	|  2              	|  6.495e-06      	|  4.3895e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_start      	|  2.0663e-06     	|  3              	|  6.1989e-06     	|  4.1894e-05     	|
|  [Callback]ModelSummary.on_train_start                                                                                                                          	|  5.8357e-06     	|  1              	|  5.8357e-06     	|  3.9439e-05     	|
|  [Callback]TQDMProgressBar.on_save_checkpoint                                                                                                                   	|  2.9039e-06     	|  2              	|  5.8077e-06     	|  3.925e-05      	|
|  [Callback]TQDMProgressBar.on_validation_epoch_start                                                                                                            	|  1.7403e-06     	|  3              	|  5.221e-06      	|  3.5285e-05     	|
|  [Callback]TQDMProgressBar.setup                                                                                                                                	|  4.6473e-06     	|  1              	|  4.6473e-06     	|  3.1408e-05     	|
|  [Callback]ModelSummary.on_validation_epoch_end                                                                                                                 	|  1.5199e-06     	|  3              	|  4.5598e-06     	|  3.0816e-05     	|
|  [LightningModule]LitModel.on_validation_end                                                                                                                    	|  1.4435e-06     	|  3              	|  4.3306e-06     	|  2.9268e-05     	|
|  [Callback]TQDMProgressBar.on_fit_end                                                                                                                           	|  4.1407e-06     	|  1              	|  4.1407e-06     	|  2.7984e-05     	|
|  [Callback]TQDMProgressBar.on_sanity_check_end                                                                                                                  	|  4.122e-06      	|  1              	|  4.122e-06      	|  2.7858e-05     	|
|  [LightningModule]LitModel.on_train_epoch_start                                                                                                                 	|  2.0051e-06     	|  2              	|  4.0103e-06     	|  2.7102e-05     	|
|  [LightningModule]LitModel.on_validation_start                                                                                                                  	|  1.3088e-06     	|  3              	|  3.9265e-06     	|  2.6536e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_epoch_end  	|  1.302e-06      	|  3              	|  3.906e-06      	|  2.6398e-05     	|
|  [Callback]ModelSummary.on_validation_epoch_start                                                                                                               	|  1.2126e-06     	|  3              	|  3.6377e-06     	|  2.4585e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_validation_epoch_start	|  1.1958e-06     	|  3              	|  3.5875e-06     	|  2.4245e-05     	|
|  [LightningModule]LitModel.on_validation_epoch_end                                                                                                              	|  1.1759e-06     	|  3              	|  3.5278e-06     	|  2.3842e-05     	|
|  [Callback]ModelSummary.on_save_checkpoint                                                                                                                      	|  1.7025e-06     	|  2              	|  3.4049e-06     	|  2.3011e-05     	|
|  [Callback]ModelSummary.on_sanity_check_start                                                                                                                   	|  3.295e-06      	|  1              	|  3.295e-06      	|  2.2269e-05     	|
|  [Callback]ModelSummary.on_train_end                                                                                                                            	|  3.295e-06      	|  1              	|  3.295e-06      	|  2.2269e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_start           	|  3.2205e-06     	|  1              	|  3.2205e-06     	|  2.1765e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_fit_start             	|  3.1982e-06     	|  1              	|  3.1982e-06     	|  2.1614e-05     	|
|  [Strategy]SingleDeviceStrategy.on_train_start                                                                                                                  	|  3.0417e-06     	|  1              	|  3.0417e-06     	|  2.0557e-05     	|
|  [Strategy]SingleDeviceStrategy.on_validation_end                                                                                                               	|  9.7789e-07     	|  3              	|  2.9337e-06     	|  1.9826e-05     	|
|  [LightningModule]LitModel.on_validation_epoch_start                                                                                                            	|  9.6609e-07     	|  3              	|  2.8983e-06     	|  1.9587e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_epoch_start     	|  1.4193e-06     	|  2              	|  2.8387e-06     	|  1.9184e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_save_checkpoint       	|  1.3905e-06     	|  2              	|  2.7809e-06     	|  1.8794e-05     	|
|  [LightningModule]LitModel.on_train_start                                                                                                                       	|  2.5611e-06     	|  1              	|  2.5611e-06     	|  1.7309e-05     	|
|  [LightningModule]LitModel.configure_callbacks                                                                                                                  	|  2.5574e-06     	|  1              	|  2.5574e-06     	|  1.7284e-05     	|
|  [Callback]TQDMProgressBar.on_fit_start                                                                                                                         	|  2.5127e-06     	|  1              	|  2.5127e-06     	|  1.6982e-05     	|
|  [Callback]TQDMProgressBar.teardown                                                                                                                             	|  2.4326e-06     	|  1              	|  2.4326e-06     	|  1.644e-05      	|
|  [LightningModule]LitModel.on_save_checkpoint                                                                                                                   	|  1.0366e-06     	|  2              	|  2.0731e-06     	|  1.4011e-05     	|
|  [LightningModule]LitModel.setup                                                                                                                                	|  2.0284e-06     	|  1              	|  2.0284e-06     	|  1.3709e-05     	|
|  [Callback]ModelSummary.on_sanity_check_end                                                                                                                     	|  2.0135e-06     	|  1              	|  2.0135e-06     	|  1.3608e-05     	|
|  [Callback]ModelSummary.setup                                                                                                                                   	|  1.7472e-06     	|  1              	|  1.7472e-06     	|  1.1808e-05     	|
|  [LightningModule]LitModel.prepare_data                                                                                                                         	|  1.6913e-06     	|  1              	|  1.6913e-06     	|  1.143e-05      	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.teardown                 	|  1.628e-06      	|  1              	|  1.628e-06      	|  1.1002e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_sanity_check_start    	|  1.6037e-06     	|  1              	|  1.6037e-06     	|  1.0838e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_fit_end               	|  1.5199e-06     	|  1              	|  1.5199e-06     	|  1.0272e-05     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_train_end             	|  1.492e-06      	|  1              	|  1.492e-06      	|  1.0083e-05     	|
|  [LightningModule]LitModel.on_fit_end                                                                                                                           	|  1.4734e-06     	|  1              	|  1.4734e-06     	|  9.9573e-06     	|
|  [Callback]ModelSummary.on_fit_end                                                                                                                              	|  1.4585e-06     	|  1              	|  1.4585e-06     	|  9.8566e-06     	|
|  [LightningModule]LitModel.configure_sharded_model                                                                                                              	|  1.4193e-06     	|  1              	|  1.4193e-06     	|  9.5922e-06     	|
|  [Strategy]SingleDeviceStrategy.on_train_end                                                                                                                    	|  1.4193e-06     	|  1              	|  1.4193e-06     	|  9.5922e-06     	|
|  [LightningModule]LitModel.on_train_end                                                                                                                         	|  1.4156e-06     	|  1              	|  1.4156e-06     	|  9.5671e-06     	|
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}.on_sanity_check_end      	|  1.3392e-06     	|  1              	|  1.3392e-06     	|  9.0509e-06     	|
|  [LightningModule]LitModel.teardown                                                                                                                             	|  1.3392e-06     	|  1              	|  1.3392e-06     	|  9.0509e-06     	|
|  [Callback]ModelSummary.teardown                                                                                                                                	|  1.3318e-06     	|  1              	|  1.3318e-06     	|  9.0006e-06     	|
|  [LightningModule]LitModel.on_fit_start                                                                                                                         	|  1.302e-06      	|  1              	|  1.302e-06      	|  8.7992e-06     	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Lightning supports several other performance profilers.

Logging with Weights and Biases#

Logging is the act of recording data during model training and evaluation so that we can use it to make decisions about model development. Earlier, we used the CSV logger to record experimental results. Weights and Biases (WandB) is an online platform for logging training experiments. It provides a range of data collection and visualization tools that help you understand how training is progressing. WandB is free for academic use, and it is easy to integrate with PyTorch Lightning through the WandbLogger class.

# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)
    
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)

# log to Weights and Biases; name sets the run name shown in the dashboard
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name='initial_run')

# just need to set the logger argument
trainer = Trainer(max_epochs=epochs, logger=wandb_logger)
trainer.fit(model, train_loader, test_loader)

# indicate that the run has finished
wandb_logger.experiment.finish()
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: dhudsmith (wficai-fast). Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.15.4
Run data is saved locally in ./wandb/run-20230621_221708-uugnumm8
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
Waiting for W&B process to finish... (success).

Run history:

epoch                ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███████████████████
train_acc            ▁▄▄▅▆▅▇▆▇▆▆▆▇▆▇▆▆▇▇█▇▇▇▇▇██▇▇▇▇▇▇██
train_loss           █▆▅▄▃▃▃▃▂▂▂▃▂▂▂▂▂▂▂▁▁▂▂▂▂▁▁▂▂▂▂▂▁▂▁
trainer/global_step  ▁▁▂▂▃▃▃▄▄▁▁▁▁▁▁▁▁▁▁▄▅▅▅▆▆▇▇▇█▁▁▁▁▁▁▁▁▁▁█
val_acc_epoch        ▁█
val_acc_step         ▅▅▅▅▃▂▄▃▅▄▂▄▂▃▁▄▄▂▄█▇▇▇▆▅▆▆▇█▆▇▆▆▄▇▇▅█
val_loss_epoch       █▁
val_loss_step        ▅▅▄▅▆▇▅▆▄▄▆▄▆▅█▆▄▅▄▂▂▂▂▃▄▃▃▂▂▃▁▃▂▆▃▂▃▁

Run summary:

epoch                1
train_acc            0.85938
train_loss           0.46852
trainer/global_step  1763
val_acc_epoch        0.79718
val_acc_step         0.8125
val_loss_epoch       0.67434
val_loss_step        0.60827

View run initial_run at: https://wandb.ai/wficai-fast/lightning_logs/runs/uugnumm8
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20230621_221708-uugnumm8/logs

Automatic mixed precision (AMP)#

By default, PyTorch uses 32-bit floating-point numbers. This means that each element of a tensor takes up 32 bits (4 bytes) of memory.

torch.randn(3).dtype
torch.float32

32-bit floats store about 7 significant decimal digits. This level of precision may not be necessary for many of the computations performed by a neural network. PyTorch supports several other floating-point formats that we can use. For instance, we can allocate 16-bit tensors:

torch.randn(3, dtype=torch.float16).dtype
torch.float16
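To see the storage and precision trade-off concretely without a GPU, here is a small sketch using Python's standard struct module, which can pack floats in IEEE 754 single precision (format 'f', the same layout as torch.float32) and half precision (format 'e', the same layout as torch.float16). Round-tripping 1/3 through each format shows how many digits survive. The helper round_trip is our own illustrative name, not part of PyTorch.

```python
import struct

def round_trip(value, fmt):
    """Pack `value` into the given struct float format and unpack it again."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]

# Bytes per element: 'f' is IEEE 754 single precision, 'e' is half precision
print("float32 bytes:", struct.calcsize("f"))  # 4
print("float16 bytes:", struct.calcsize("e"))  # 2

x = 1 / 3
f32 = round_trip(x, "f")
f16 = round_trip(x, "e")
print("float64:", x)    # 0.3333333333333333
print("float32:", f32)  # 0.3333333432674408
print("float16:", f16)  # 0.333251953125
```

float32 keeps roughly 7 correct digits of 1/3, while float16 keeps only about 3 to 4.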

16-bit floats take up half the memory of 32-bit floats, so they may allow us to train larger models (or use larger batches) on the same GPU hardware. In addition, modern GPU architectures, such as NVIDIA GPUs with Tensor Cores, can perform some calculations (notably matrix multiplications) much faster with 16-bit numbers.

Unfortunately, it is usually not a good idea to convert all of our computation to 16-bit floats: float16 has a much narrower range and only about 3 significant decimal digits, so small gradient values can underflow to zero and long reductions can lose accuracy. The research community has developed approaches that mix 32-bit and 16-bit computation to get the benefits of lower precision without hurting model convergence. Setting all of this up by hand is a headache, so PyTorch supports “Automatic Mixed Precision” (AMP) to perform the conversion automatically under the hood.
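One concrete failure mode of pure 16-bit training is underflow: float16 cannot represent magnitudes below about 6e-8 (its smallest subnormal is 2**-24), so a tiny gradient is flushed to zero, while values above 65504 overflow. The sketch below again uses the struct half-precision format as a stand-in for torch.float16. The gradient value 1e-8 is made up for illustration; the scale factor 2**16 matches the default initial scale of PyTorch's GradScaler. Multiplying by the scale before converting keeps the value representable, and dividing by the same factor in higher precision recovers it. This is the idea behind loss scaling.

```python
import struct

def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]

grad = 1e-8          # hypothetical tiny gradient value
scale = 2.0 ** 16    # default initial scale of torch GradScaler

unscaled = to_float16(grad)          # underflows to 0.0 in float16
scaled = to_float16(grad * scale)    # 6.5536e-4 is comfortably representable
recovered = scaled / scale           # unscale in higher precision

print("float16(grad)        :", unscaled)
print("float16(grad * scale):", scaled)
print("recovered            :", recovered)

# Overflow is the opposite problem: the largest finite float16 is 65504
try:
    to_float16(70000.0)
except OverflowError as err:
    print("overflow:", err)
```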

In plain PyTorch, you enable AMP with the torch.autocast context manager, usually together with a GradScaler:

import torch
from torch import nn, optim

# Creates model and optimizer in default (float32) precision.
# Net is a placeholder for your own nn.Module.
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

# Creates a GradScaler once at the beginning of training.
# Scaling the loss prevents small float16 gradients from underflowing to zero.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
    for input, target in train_loader:
        input, target = input.cuda(), target.cuda()
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
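The scaler.update() call adjusts the scale factor dynamically. The following is a simplified pure-Python sketch of that rule (not PyTorch's actual implementation), assuming GradScaler's documented defaults growth_factor=2.0, backoff_factor=0.5 and growth_interval=2000: a step whose gradients contain infs or NaNs is skipped and the scale is halved, while 2000 consecutive clean steps double it. The function name update_scale is hypothetical.

```python
def update_scale(scale, growth_tracker, found_inf,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Sketch of dynamic loss scaling.

    Returns (new_scale, new_growth_tracker, step_taken).
    """
    if found_inf:
        # Overflow detected: skip optimizer.step() and shrink the scale.
        return scale * backoff_factor, 0, False
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # Many clean steps in a row: try a larger scale.
        return scale * growth_factor, 0, True
    return scale, growth_tracker, True


# Simulate a few iterations starting from the default initial scale of 2**16
scale, tracker = 2.0 ** 16, 0
for found_inf in [False, True, False]:
    scale, tracker, stepped = update_scale(scale, tracker, found_inf)
    print(f"found_inf={found_inf!s:5}  scale={scale:8.0f}  step_taken={stepped}")
```

Shrinking quickly on overflow but growing only after a long run of clean steps keeps the scale close to the largest value that does not overflow.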

AMP in PyTorch Lightning#

Fortunately, AMP is extremely easy to enable now that our model is set up in PyTorch Lightning: we only need to pass the precision argument to the Trainer.

# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)
    
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
wandb_logger = WandbLogger(name='AMP run')

trainer = Trainer(max_epochs=epochs, logger=wandb_logger,
                  precision=16)  # <-- this enables 16-bit automatic mixed precision
trainer.fit(model, train_loader, test_loader)

wandb_logger.experiment.finish()
Tracking run with wandb version 0.15.4
Run data is saved locally in ./wandb/run-20230621_221732-4r3rplac
Syncing run AMP run to Weights & Biases
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
Waiting for W&B process to finish... (success).

Run history:

epoch                ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███████████████████
train_acc            ▁▃▄▅▅▆▆▆▆▇▆▆▆▆▆▇▆▆▇▇██▇▇▆▇▆▇▇▇▇▇▇▇▇
train_loss           █▆▅▄▄▃▃▃▃▂▃▃▂▃▂▂▂▂▂▂▁▁▂▂▂▁▂▁▁▂▁▁▂▂▁
trainer/global_step  ▁▁▂▂▃▃▃▄▄▁▁▁▁▁▁▁▁▁▁▄▅▅▅▆▆▇▇▇█▁▁▁▁▁▁▁▁▁▁█
val_acc_epoch        ▁█
val_acc_step         ▄▃▃▄▃▂▂▂▃▃▃▄▂▁▂▅▃▄▂▇▇▆█▆▅▆▇▇▇▆▇▅▅▄█▇▇▆
val_loss_epoch       █▁
val_loss_step        ▅▆▅▅▇▇▆▆▅▅▆▄▇▆█▆▅▅▅▂▂▂▂▃▄▂▃▂▁▃▁▄▃▅▂▁▂▂

Run summary:

epoch                1
train_acc            0.79688
train_loss           0.62448
trainer/global_step  1763
val_acc_epoch        0.77883
val_acc_step         0.7775
val_loss_epoch       0.75217
val_loss_step        0.71923

View run AMP run at: https://wandb.ai/wficai-fast/lightning_logs/runs/4r3rplac
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20230621_221732-4r3rplac/logs

Model checkpointing#

Checkpointing is the process of periodically saving the model during training. Typically, we monitor a metric (such as the validation loss) and save the model whenever that metric improves. In addition to the model weights, we need to save the hyperparameters, so that the model can later be re-created from the checkpoint without passing the constructor arguments again. We do this by calling self.save_hyperparameters() in the initializer of the Lightning model:

class LitModel(pl.LightningModule):
    def __init__(self, pytorch_model, lr, gamma):
        super().__init__()
        self.save_hyperparameters()
        ...
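To see why the hyperparameters matter, here is a stripped-down, framework-free sketch of the idea. The class TinyModel and its methods are hypothetical and are not Lightning's API: a checkpoint stores the constructor arguments next to the weights, so a class-level loader can call the constructor again before restoring the weights. Lightning checkpoints do use the keys "hyper_parameters" and "state_dict" for these two parts; everything else here is simplified.

```python
class TinyModel:
    """Hypothetical stand-in for a LightningModule, for illustration only."""

    def __init__(self, lr, gamma):
        # roughly what save_hyperparameters() records: the constructor arguments
        self.hparams = {"lr": lr, "gamma": gamma}
        self.weights = [0.0, 0.0]  # stand-in for the model's parameters

    def save_checkpoint(self):
        # store the arguments next to the weights
        return {"hyper_parameters": dict(self.hparams),
                "state_dict": list(self.weights)}

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # re-run the constructor with the saved arguments, then restore weights
        model = cls(**checkpoint["hyper_parameters"])
        model.weights = list(checkpoint["state_dict"])
        return model


trained = TinyModel(lr=0.1, gamma=0.7)
trained.weights = [0.25, -1.5]          # pretend training changed the weights
ckpt = trained.save_checkpoint()
restored = TinyModel.load_from_checkpoint(ckpt)
print(restored.hparams, restored.weights)  # {'lr': 0.1, 'gamma': 0.7} [0.25, -1.5]
```

Without the saved arguments, the loader would not know how to call the constructor and the caller would have to supply them by hand.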

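Unlike the previous options, which were plain Trainer arguments, checkpointing is configured through a callback: we create a ModelCheckpoint object and pass it to the Trainer in the callbacks list. Here, save_top_k=3 keeps the three checkpoints with the lowest val_loss.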

# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)
     
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
# need to import the checkpoint callback object
from lightning.pytorch.callbacks import ModelCheckpoint
# keep the 3 checkpoints with the lowest val_loss in the checkpoints/ directory
checkpoint_callback = ModelCheckpoint("checkpoints/", save_top_k=3, monitor='val_loss')

# we add log_model='all' to also log the checkpoints to wandb during training
wandb_logger = WandbLogger(name='Run with Checkpointing', log_model='all')

# need to pass the checkpoint callback
trainer = Trainer(max_epochs=epochs, logger=wandb_logger, precision=16, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, test_loader)

wandb_logger.experiment.finish()
Tracking run with wandb version 0.15.4
Run data is saved locally in ./wandb/run-20230621_221757-tgzrctm2
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=2` reached.
Waiting for W&B process to finish... (success).

Run history:

epoch                ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███████████████████
train_acc            ▁▂▄▄▅▅▆▅▆▆▆▇▆▆▇▇▇▆▆▇█▇▇▇▇█▇▇█▇▇▇██▇
train_loss           █▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▂▁▁▂▁▁▁▂
trainer/global_step  ▁▁▂▂▃▃▃▄▄▁▁▁▁▁▁▁▁▁▁▄▅▅▅▆▆▇▇▇█▁▁▁▁▁▁▁▁▁▁█
val_acc_epoch        ▁█
val_acc_step         ▅▅▄▃▁▁▃▃▄▅▂▄▄▁▂▅▃▄▃██▇▆▆▅▆▆▇▇▇█▇▇▆█▇▆▇
val_loss_epoch       █▁
val_loss_step        ▅▅▄▄▆▇▆▆▄▄▆▄▆▆█▅▄▅▄▂▂▁▁▃▄▃▂▁▁▃▁▃▂▅▂▁▂▁

Run summary:

epoch                1
train_acc            0.77344
train_loss           0.85265
trainer/global_step  1763
val_acc_epoch        0.78878
val_acc_step         0.79125
val_loss_epoch       0.70496
val_loss_step        0.64834

View run Run with Checkpointing at: https://wandb.ai/wficai-fast/lightning_logs/runs/tgzrctm2
Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20230621_221757-tgzrctm2/logs
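With save_top_k=3 and monitor='val_loss' (and the default mode='min'), ModelCheckpoint keeps at most three checkpoint files: the ones with the lowest validation loss seen so far. The bookkeeping can be sketched in plain Python as below. This is a simplified illustration rather than Lightning's implementation, and the loss values are made up. The file names mimic the epoch=1-step=1764.ckpt pattern used in this notebook, assuming 882 steps per epoch as in the runs above.

```python
def update_top_k(kept, val_loss, filename, k=3):
    """Add a candidate checkpoint and keep only the k with the lowest val_loss.

    `kept` is a list of (val_loss, filename) pairs. Returns the new list and
    the filename that would be deleted from disk (or None).
    """
    kept = sorted(kept + [(val_loss, filename)])
    if len(kept) > k:
        _, dropped = kept.pop()  # the worst (highest) val_loss
        return kept, dropped
    return kept, None


kept = []
steps_per_epoch = 882
for epoch, val_loss in enumerate([0.90, 0.70, 0.80, 0.60]):  # hypothetical losses
    name = f"epoch={epoch}-step={(epoch + 1) * steps_per_epoch}.ckpt"
    kept, dropped = update_top_k(kept, val_loss, name)
    print(f"epoch {epoch}: keep {[n for _, n in kept]}, delete {dropped}")
```

After the fourth epoch the checkpoint from epoch 0 (the highest loss) is the one discarded.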

Creating the model from a local checkpoint#

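To load a model from a checkpoint, we call the load_from_checkpoint class method, passing in the path to the checkpoint file. By default, Lightning names checkpoint files after the epoch and global step at which they were saved.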

model = models.LitModel.load_from_checkpoint('checkpoints/epoch=1-step=1764.ckpt')
# when creating a trainer just for validation, we don't need to fuss over the arguments.
trainer = Trainer()
trainer.validate(model, test_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val_acc_epoch           0.7675532102584839     │
│      val_loss_epoch           0.7949550747871399     │
└───────────────────────────┴───────────────────────────┘
[{'val_loss_epoch': 0.7949550747871399, 'val_acc_epoch': 0.7675532102584839}]

Early Stopping#

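Early stopping, as the name implies, is a technique for stopping training early if a monitored metric is not improving. This can save lots of time and compute resources. We set this up in Lightning using a callback. In the example below, training stops once val_loss has failed to improve by more than min_delta=0.005 for patience=3 consecutive validation checks, even though max_epochs is set to 20.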

# init the classifier
pt_model = models.Classifier() #models.make_resnet18_model(weights=ResNet18_Weights.IMAGENET1K_V1)

# init the lazy layers
with torch.no_grad():
    pt_model(x)
    
# create the lightning model
model = models.LitModel(pt_model, lr, gamma)
# need to import the early stop callback object
from lightning.pytorch.callbacks import EarlyStopping
earlystop_callback = EarlyStopping(monitor="val_loss", min_delta=0.005, patience=3, verbose=False)

checkpoint_callback = ModelCheckpoint("checkpoints/", save_top_k=3, monitor='val_loss')
wandb_logger = WandbLogger(name='Early stopping run')

# pass both the early stopping and the checkpoint callbacks
trainer = Trainer(max_epochs=20, logger=wandb_logger, precision=16,
                  callbacks=[earlystop_callback, checkpoint_callback])
trainer.fit(model, train_loader, test_loader)

wandb_logger.experiment.finish()
Tracking run with wandb version 0.15.4
Run data is saved locally in ./wandb/run-20230621_221825-s41wgsxm
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Classifier         | 23.1 K
1 | train_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
23.1 K    Trainable params
0         Non-trainable params
23.1 K    Total params
0.093     Total estimated model params size (MB)
Waiting for W&B process to finish... (success).

Run history:

epoch                ▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇████
train_acc            ▁▆▆▅▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇█▇▆▇█▇▆▇▇▇▇▇▇▇██▇██▇▇
train_loss           █▃▃▃▂▂▂▁▂▂▁▁▂▂▁▁▁▁▁▁▁▂▁▁▁▂▂▁▂▂▁▂▂▁▁▁▁▁▁▂
trainer/global_step  ▁▁▁▁▂▁▁▂▃▁▁▃▄▁▁▄▄▁▄▅▁▁▅▁▁▆▆▁▁▆▇▁▁▇█▁██▁▁
val_acc_epoch        ▁▄▆▇▇▇▇████
val_acc_step         ▁▁▃▂▄▄▃▃▅▇▇▅▆▅▆▆▇▆▇▇▆█▆▇▇█▇▇▄▆▇▇▇▆▇▃▆▇▇▇
val_loss_epoch       █▄▃▂▂▁▁▁▁▁▁
val_loss_step        ▇█▅▇▄▃▅▄▆▂▃▃▄▄▂▃▂▃▃▃▄▂▂▃▂▃▄▄▆▃▃▁▂▄▁▆▃▅▁▂

Run summary:

epoch                10
train_acc            0.82031
train_loss           0.72232
trainer/global_step  9701
val_acc_epoch        0.82207
val_acc_step         0.825
val_loss_epoch       0.58447
val_loss_step        0.54862

View run Early stopping run at: https://wandb.ai/wficai-fast/lightning_logs/runs/s41wgsxm
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20230621_221825-s41wgsxm/logs
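The run above stopped after 11 epochs (the summary counts epochs from 0) instead of the 20 allowed by max_epochs. The stopping rule can be sketched in plain Python. This is a simplified illustration of the behavior described for EarlyStopping with mode='min', not Lightning's source, and the loss values are made up: a value only counts as an improvement if it beats the best value so far by more than min_delta, and training stops once patience checks in a row fail to improve.

```python
def early_stop_epoch(val_losses, min_delta=0.005, patience=3):
    """Return the index of the validation check at which training would stop,
    or None if the patience is never exhausted (mode='min')."""
    best = float("inf")
    wait = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0   # improvement: reset the counter
        else:
            wait += 1              # no sufficient improvement
            if wait >= patience:
                return i
    return None


# hypothetical epoch-level validation losses
history = [1.00, 0.80, 0.70, 0.698, 0.697, 0.696, 0.60]
print(early_stop_epoch(history))  # 5
```

The small decreases after 0.70 are each below min_delta, so patience runs out at index 5 and the later drop to 0.60 is never reached.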

Multi-gpu#

Single-node, multi-gpu#

PyTorch Lightning makes this very easy. In fact, Lightning will automatically use all available GPUs by default. However, it is tricky to get this working in Jupyter notebooks. To demonstrate, we have created a script that you can download here.
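One practical consequence of multi-GPU (DDP) training is that each process receives its own shard of the data, so the number of optimizer steps per epoch shrinks and the effective batch size grows. The arithmetic below is a sketch assuming the behavior of PyTorch's DistributedSampler (which Lightning inserts by default under DDP): each of n processes gets ceil(N / n) samples, and the DataLoader keeps the last partial batch. With this notebook's 112,800 training images and a batch size of 128, one GPU gives 882 steps per epoch, which matches the step=1764 checkpoint saved after two epochs.

```python
import math

def steps_per_epoch(num_samples, batch_size, num_devices=1):
    """Optimizer steps per epoch when data is sharded across num_devices."""
    # DistributedSampler gives each process ceil(N / n) samples (padding if needed)
    per_device = math.ceil(num_samples / num_devices)
    # a DataLoader with drop_last=False yields a final partial batch
    return math.ceil(per_device / batch_size)


num_train, per_device_batch = 112800, 128   # EMNIST train size and batch size from this notebook
for devices in (1, 2, 4):
    print(f"{devices} GPU(s): {steps_per_epoch(num_train, per_device_batch, devices)} steps/epoch, "
          f"effective batch size {per_device_batch * devices}")
```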

Multi-node, multi-gpu#

This is a bit trickier because it involves setting up communication across nodes. We have an example of how to set this up in our Palmetto Examples repository here.