Building the network

The torch.nn subpackage in PyTorch contains many neural network building blocks called “modules”, all of which are subclasses of nn.Module. We can compose these in arbitrary ways to build network architectures tailored to a given problem.

import torch
import torch.nn as nn

# do everything on gpu unless we explicitly say otherwise
torch.set_default_device('cuda')
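The line above assumes a CUDA GPU is present; on a CPU-only machine, later tensor creation will raise an error. A minimal sketch of a fallback, assuming you would rather have the notebook run slowly on the CPU than not at all:

```python
import torch

# use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# new tensors are now created on the chosen device without an explicit device= argument
x = torch.randn(2, 3)
print(x.device)
```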

The basics

We saw examples like this in earlier notebooks:

model = nn.Sequential(
    nn.Linear(10,10),
    nn.Tanh(),
    nn.Linear(10,10),
    nn.Tanh(),
    nn.Linear(10,3),
    nn.Sigmoid()
)

# printing the model shows the layers
model

nn.Sequential, nn.Linear, nn.Tanh, and nn.Sigmoid are all examples of modules. There are many more. You can see a full list here: https://pytorch.org/docs/stable/nn.html
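nn.Sequential is itself a module whose only job is to feed the output of each child module into the next. The sketch below (run on the CPU so it is self-contained) checks that calling the Sequential gives the same result as chaining the layers by hand, and that the children can be indexed like a list:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Tanh(), nn.Linear(10, 3), nn.Sigmoid())
x = torch.randn(4, 10)

# iterating over a Sequential yields its child modules in order
h = x
for layer in model:
    h = layer(h)

print(torch.allclose(model(x), h))  # the Sequential does exactly this chaining
print(model[0])                     # children can be accessed by index
```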

Callable. All modules are callable, meaning they can be evaluated like a function:

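layer = nn.Linear(4,5)
x = torch.randn(7, 4)  # a batch of 7 inputs with 4 features each
layer(x)               # output shape: (7, 5)
layer = nn.Tanh()
layer(x)               # applied elementwise, output shape: (7, 4)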
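To see what calling a layer actually computes, the sketch below checks nn.Linear against the matrix formula by hand. nn.Linear stores its weight with shape (out_features, in_features) and computes x @ weight.T + bias:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 5)
x = torch.randn(7, 4)

# weight is (out_features, in_features); bias has one entry per output feature
print(layer.weight.shape, layer.bias.shape)

y = layer(x)                                 # calling the module runs its forward()
y_manual = x @ layer.weight.T + layer.bias   # the same affine map written out by hand
print(y.shape)
print(torch.allclose(y, y_manual, atol=1e-6))
```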

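Changing device. Modules can be moved between devices with .to(). Unlike Tensor.to(), which returns a new tensor and leaves the original untouched, Module.to() modifies the module in place (it also returns the module, for convenience).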

layer = nn.Linear(4,5)
print("Before:", layer.weight.device)
layer.to('cpu')
print("After:", layer.weight.device)

All nested modules also move:

model = nn.Sequential(
    nn.Linear(10,10),
    nn.Tanh(),
    nn.Linear(10,3)
)

print("Before:", model[0].weight.device)
model.to('cpu')
print("After:", model[0].weight.device)

# back on gpu for later
model.to('cuda')
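The in-place behavior of Module.to() also applies to dtype conversions, which makes it easy to demonstrate without a GPU. This sketch uses float64 as a stand-in for a device change; the contrast between tensors and modules is the same:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 5)
t = torch.zeros(3)

# Tensor.to returns a new tensor; the original keeps its dtype
t64 = t.to(torch.float64)
print(t.dtype, t64.dtype)  # torch.float32 torch.float64

# Module.to converts the module's own parameters and returns the same module object
returned = layer.to(torch.float64)
print(returned is layer, layer.weight.dtype)  # True torch.float64
```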

Saving/loading. Model weights can be saved to and loaded from disk. There are a few ways to do this. The recommended way is to save only the weights, using the “state dict”: a dictionary that maps each parameter (and buffer) name to its tensor:

for k, v in model.state_dict().items():
    print(k, v.shape)
torch.save(model.state_dict(), 'model_weights.pt')
# torch.save pickles the object into a zip archive, so the file is mostly binary
!head -n 3 model_weights.pt
# some time later...
model2 = nn.Sequential(
    nn.Linear(10,10),
    nn.Tanh(),
    nn.Linear(10,3)
)

model2.load_state_dict(torch.load('model_weights.pt', weights_only=True))
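The sketch below runs the same round trip end to end and checks that the reloaded model gives identical outputs. It writes to an in-memory io.BytesIO buffer instead of model_weights.pt only so it leaves no file behind; torch.save and torch.load accept file-like objects as well as paths. The helper make_model is a hypothetical name for "the code that builds the architecture":

```python
import io
import torch
import torch.nn as nn

def make_model():
    # hypothetical helper: the same architecture must be rebuilt before loading
    return nn.Sequential(nn.Linear(10, 10), nn.Tanh(), nn.Linear(10, 3))

model = make_model()
print(list(model.state_dict().keys()))  # Tanh has no parameters, so only layers 0 and 2 appear

buffer = io.BytesIO()  # stands in for a file on disk
torch.save(model.state_dict(), buffer)

buffer.seek(0)
model2 = make_model()  # starts with different random weights
result = model2.load_state_dict(torch.load(buffer, weights_only=True))
print(result)  # reports any missing or unexpected keys

x = torch.randn(2, 10)
print(torch.allclose(model(x), model2(x)))  # the weights now match
```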

Loading a state dict requires that we instantiate the model first. We can also save the model structure and weights together by saving the whole model object:

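torch.save(model, 'model.pt')
# loading a full model unpickles arbitrary Python objects, so it must be requested
# explicitly (torch.load defaults to weights_only=True since PyTorch 2.6)
model2 = torch.load('model.pt', weights_only=False)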

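Saving model.state_dict() is generally preferred because it separates the model’s parameters from its architecture. The state dict contains only tensors, so it can be loaded with weights_only=True, and it stays usable after you refactor the model class, as long as the parameter names and shapes still match. Saving the entire model with torch.save(model, ...) is simpler, but pickle stores a reference to the model’s class (its module path and name) rather than the class code itself: the same class definition must be importable when loading, renaming or moving it breaks old files, and loading requires weights_only=False, which can execute arbitrary code, so only load such files from sources you trust.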
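Because a state dict is keyed by attribute names, you can even load weights into a modified architecture. A minimal sketch, assuming a hypothetical "updated" model that keeps the old encoder but replaces the head: with strict=False, load_state_dict copies every matching key and reports the mismatches instead of raising an error (mismatched shapes under a matching key would still raise):

```python
import torch
import torch.nn as nn

class Small(nn.Module):
    # hypothetical original model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 8)
        self.head = nn.Linear(8, 3)

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))

class Bigger(nn.Module):
    # hypothetical updated model: same encoder, new head with a different name and size
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 8)
        self.new_head = nn.Linear(8, 5)

    def forward(self, x):
        return self.new_head(torch.relu(self.encoder(x)))

old, new = Small(), Bigger()
result = new.load_state_dict(old.state_dict(), strict=False)
print(result.missing_keys)     # ['new_head.weight', 'new_head.bias']
print(result.unexpected_keys)  # ['head.weight', 'head.bias']
print(torch.equal(new.encoder.weight, old.encoder.weight))  # True: the encoder was reused
```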

eval/train modes. Some layers, such as nn.Dropout and nn.BatchNorm1d, need to behave differently at training time and at evaluation time. Calling train() or eval() on a module switches the mode of that module and all of its submodules:

layer = nn.Dropout(0.5)

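# the default mode is "training"
x = torch.randn(3, 5)
print(x)
# training mode: each entry is zeroed with probability 0.5, survivors are scaled by 1/(1-0.5) = 2
layer(x)
# switch to eval: dropout becomes the identity
layer.eval()
layer(x)
# switch back to train
layer.train()
layer(x)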
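Feeding dropout a tensor of ones makes its behavior easy to check numerically. The sketch below verifies the scaling in training mode, the identity in eval mode, and that eval() on a container propagates to its children through the training flag every module carries:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.5
layer = nn.Dropout(p)
x = torch.ones(1000)

layer.train()
y = layer(x)
# every entry is either zeroed or scaled by 1/(1-p) = 2.0
print(torch.unique(y))
frac_zero = (y == 0).float().mean().item()
print(frac_zero)  # close to p, but random

layer.eval()
print(torch.equal(layer(x), x))  # eval mode leaves the input unchanged
print(layer.training)            # False

# eval() on a container switches every submodule as well
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p))
model.eval()
print(all(not m.training for m in model.modules()))  # True
```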

Writing custom modules

You can make your own modules. To do so, subclass nn.Module and define the __init__ and forward methods; __init__ must call super().__init__() before assigning any submodules. These modules can be used just like any other module.

class NeuralNetwork(nn.Module):
    def __init__(self):
        """
        The __init__ method defines all of the modules/parameters that will 
        appear in the model.
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(256,1)
        )

    def forward(self, x):
        """
        Define how to get from the input to the output. 
        You can use arbitrary python code here so long as the 
        tensor operations are differentiable. 
        """
        x = self.flatten(x)
        h = self.encoder(x)
        y = self.classifier(h)
        return y
    
model = NeuralNetwork()
model
# simulate a batch of grayscale images:
x = torch.randn(5, 1, 28, 28)

model(x)
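Custom modules can also hold raw learnable tensors, not just other modules. A minimal sketch with a hypothetical ScaledTanh layer (not a built-in PyTorch module): wrapping a tensor in nn.Parameter registers it with the module, and the result composes with built-in layers like any other module:

```python
import torch
import torch.nn as nn

class ScaledTanh(nn.Module):
    """Hypothetical example layer: y = scale * tanh(x), with a learnable scale."""
    def __init__(self, init_scale=1.0):
        super().__init__()
        # nn.Parameter marks this tensor as a learnable parameter of the module
        self.scale = nn.Parameter(torch.tensor(init_scale))

    def forward(self, x):
        return self.scale * torch.tanh(x)

layer = ScaledTanh(2.0)
x = torch.zeros(3, 4)
y = layer(x)  # calling the module runs forward()
print(y.shape)
print([name for name, _ in layer.named_parameters()])  # ['scale']

# custom modules mix freely with built-in ones
model = nn.Sequential(nn.Linear(4, 8), ScaledTanh(), nn.Linear(8, 1))
print(model(torch.randn(5, 4)).shape)  # torch.Size([5, 1])
```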

You can customize your network however you see fit. For example, say we had a problem where the network took two images as input and made some decision about them. We could do something like this:

class PairNetwork(nn.Module):
    def __init__(self):
        """
        The __init__ method defines all of the modules/parameters that will 
        appear in the model.
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(2*256,1)  # double the representation size
        )

    def forward(self, x1, x2):
        """
        Define how to get from the input to the output. 
        You can use arbitrary python code here so long as the 
        tensor operations are differentiable. 
        """
        # the same encoder (shared weights) processes both images
        x1 = self.flatten(x1)
        h1 = self.encoder(x1)

        x2 = self.flatten(x2)
        h2 = self.encoder(x2)

        # fuse the representations: (batch, 256) + (batch, 256) -> (batch, 512)
        h = torch.cat([h1, h2], dim=-1)
        
        y = self.classifier(h)
        return y
    
pair_model = PairNetwork()
pair_model
# simulate a batch of grayscale images:
x1 = torch.randn(5, 1, 28, 28)
x2 = torch.randn(5, 1, 28, 28)

pair_model(x1, x2)
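Because both images go through the same encoder, an alternative to two separate encoder calls is to stack the pair along the batch dimension, encode once, and split the result. This is a sketch of a variation, not the forward pass used above; it uses a smaller stand-in encoder and checks that both approaches agree. The equivalence relies on every layer treating batch elements independently; a BatchNorm layer in training mode, for instance, would break it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
flatten = nn.Flatten()
encoder = nn.Sequential(nn.Linear(28*28, 64), nn.ReLU())  # smaller stand-in encoder

x1 = torch.randn(5, 1, 28, 28)
x2 = torch.randn(5, 1, 28, 28)

# two separate passes, as in PairNetwork.forward
h1, h2 = encoder(flatten(x1)), encoder(flatten(x2))

# one pass over a batch of 10, then split back into two halves of 5
h_both = encoder(flatten(torch.cat([x1, x2], dim=0)))
g1, g2 = h_both.chunk(2, dim=0)

print(torch.allclose(h1, g1, atol=1e-5), torch.allclose(h2, g2, atol=1e-5))
h = torch.cat([g1, g2], dim=-1)  # fused representation
print(h.shape)  # torch.Size([5, 128])
```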

Tracking parameters. PyTorch automatically registers the parameters of every submodule you assign as an attribute in __init__ (here, the layers inside self.encoder and self.classifier). model.parameters() iterates over all of them, which is how an optimizer knows what to update during training. It also lets you get diagnostic information, such as the number of parameters in your model:

num_pars = sum([p.numel() for p in model.parameters()])
print("Number of parameters:", num_pars)
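The count can be checked by hand: an nn.Linear(n_in, n_out) layer has an (n_out × n_in) weight plus n_out biases, while ReLU and Flatten have no parameters. The first part of this sketch applies that formula to the layer sizes of NeuralNetwork above. The second part shows a common pitfall, using a hypothetical ListNet class: layers stored in a plain Python list are not registered, so parameters(), state_dict(), .to(), and any optimizer silently skip them; nn.ModuleList registers them properly:

```python
import torch.nn as nn

def linear_params(n_in, n_out):
    # weight matrix (n_out x n_in) plus one bias per output feature
    return n_in * n_out + n_out

# the Linear layer sizes used in NeuralNetwork's encoder and classifier
sizes = [(28*28, 512), (512, 512), (512, 256), (256, 1)]
by_hand = sum(linear_params(i, o) for i, o in sizes)
layers = nn.Sequential(*[nn.Linear(i, o) for i, o in sizes])
by_torch = sum(p.numel() for p in layers.parameters())
print(by_hand, by_torch)  # 796161 796161

class ListNet(nn.Module):
    # hypothetical example of the plain-list pitfall
    def __init__(self):
        super().__init__()
        self.plain = [nn.Linear(4, 4), nn.Linear(4, 4)]                 # NOT registered
        self.tracked = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])  # registered

    def forward(self, x):
        for layer in self.tracked:
            x = layer(x)
        return x

net = ListNet()
print(sum(p.numel() for p in net.parameters()))  # 40: only the ModuleList layers count
```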