Datasets and data loading#

The torch.utils.data subpackage is an important part of PyTorch for developing neural networks. The Dataset class represents a dataset and provides an interface to access the data samples. The DataLoader class helps fetch data from the dataset and prepare it for passing to your neural network.
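
To see how the two classes fit together before tackling real data, here is a minimal sketch using a toy dataset (the names here are hypothetical):

import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """A toy map-style dataset: sample i is the pair (i, i**2)."""
    def __len__(self):
        return 10
    def __getitem__(self, index):
        return torch.tensor(index), torch.tensor(index ** 2)

# the DataLoader bundles individual samples into batches for us
loader = DataLoader(SquaresDataset(), batch_size=4)
for x, y in loader:
    print(x, y)  # first batch: tensor([0, 1, 2, 3]) tensor([0, 1, 4, 9])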

Case study: ImageNet data#

The ImageNet 1000-class image classification challenge (ILSVRC) has been a huge driver of progress in deep learning. Let’s get to know this dataset.

import os

# Locate the images
image_dir = '/project/rcde/datasets/imagenet/ILSVRC/Data/CLS-LOC/'
os.listdir(image_dir)
['test', 'train', 'val']

ImageNet has 1000 different classes. In the train and val splits, each class has its own sub-folder (the test split is organized differently):

for d in os.listdir(image_dir):
    print(d, len(os.listdir(image_dir+d)))
test 100000
train 1000
val 1000

These classes have uninformative directory names:

os.listdir(image_dir+'train')[:5]
['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475']

There’s a file that maps from these strange names to human-readable names:

! head -n 5 '/project/rcde/datasets/imagenet/LOC_synset_mapping.txt'
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark
with open('/project/rcde/datasets/imagenet/LOC_synset_mapping.txt') as f: 
    lines = f.readlines()
    
    # synset ID (first 9 characters) -> human-readable label
    class2label = {l[:9].strip(): l[10:].strip() for l in lines}
    # synset ID -> integer class index (the line number in the mapping file)
    class2ix = {l[:9].strip(): ix for ix, l in enumerate(lines)}
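
For example, the two dictionaries translate a synset ID into a human-readable label or an integer class index:

print(class2label['n01440764'])  # tench, Tinca tinca
print(class2ix['n01440764'])     # 0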

Most classes have 1300 training images (a few, such as espresso maker, have fewer):

for cls in os.listdir(image_dir+'train')[::50]:
    print(class2label[cls], len(os.listdir(f"{image_dir}train/{cls}/")))
tench, Tinca tinca 1300
American alligator, Alligator mississipiensis 1300
black swan, Cygnus atratus 1300
sea lion 1300
Tibetan terrier, chrysanthemum dog 1300
Siberian husky 1300
tiger beetle 1300
ibex, Capra ibex 1300
academic gown, academic robe, judge's robe 1300
bobsled, bobsleigh, bob 1300
cliff dwelling 1300
espresso maker 1136
hook, claw 1300
microphone, mike 1300
paper towel 1300
quilt, comforter, comfort, puff 1300
slot, one-armed bandit 1300
teddy, teddy bear 1300
water tower 1300
orange 1300

Each class has only 50 validation samples:

for cls in os.listdir(image_dir+'val')[::50]:
    print(class2label[cls], len(os.listdir(f"{image_dir}val/{cls}/")))
tench, Tinca tinca 50
American alligator, Alligator mississipiensis 50
black swan, Cygnus atratus 50
sea lion 50
Tibetan terrier, chrysanthemum dog 50
Siberian husky 50
tiger beetle 50
ibex, Capra ibex 50
academic gown, academic robe, judge's robe 50
bobsled, bobsleigh, bob 50
cliff dwelling 50
espresso maker 50
hook, claw 50
microphone, mike 50
paper towel 50
quilt, comforter, comfort, puff 50
slot, one-armed bandit 50
teddy, teddy bear 50
water tower 50
orange 50

Let’s look at a few images:

from glob import glob
import matplotlib.pyplot as plt
import matplotlib.image as img

num_images = 25
sample_images = []
image_classes = []
for cls in os.listdir(image_dir+'train')[:num_images]:
    sample_images.append(glob(f"{image_dir}train/{cls}/*.*")[0])
    image_classes.append(cls)
    
fig, ax = plt.subplots(5, 5)
fig.set_size_inches(8,8)
for ix, a in enumerate(ax.flatten()):
    a.imshow(img.imread(sample_images[ix]))
    a.set_title(class2label[image_classes[ix]].split(',')[0])
    a.axis('off')
[Figure: a 5×5 grid of sample training images, each titled with its class label.]

Map-style dataset#

Use this when you have a well-defined set of samples that you will use to train your model. This is the most common case and the natural choice for ImageNet because we have a well-defined set of images that we want to feed to our model. Let’s see how to create a map-style dataset class for ImageNet.

from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.io import ImageReadMode
from torchvision.transforms import transforms
from pathlib import Path

# subclass Dataset
class Imagenet(Dataset):
    
    def __init__(self, root_dir: str, split: str, class2ix: dict, tfms = None):
        """
        The __init__ method is called when creating an instance of the class.
        This is where we put setup code.
        
        Args:
            root_dir: full path to the imagenet CLS-LOC folder containing train and val subfolders
            split: which subfolder to use, either "train" or "val"
            class2ix: dict mapping synset directory names to integer class indices
            tfms: image transforms
        
        """
        self.root_dir = root_dir
        self.split = split
        self.class2ix = class2ix
        self.tfms = tfms
        
        # make sure split is supported
        assert split in {'train', 'val'}, f"Split must be one of 'train' or 'val', not {split}."
        
        # get a list of all the images
        self.image_paths = list(Path(f"{self.root_dir}/{self.split}").rglob("*.JPEG"))
        
                                
    def __len__(self):
        """
        Map-style datasets must define the __len__ method, which returns the
        number of samples in the dataset.
        """
        return len(self.image_paths)
    
    def __getitem__(self, index):
        """
        Map-style datasets must define __getitem__ which takes an index and returns
        a sample. This puts the "map" in "map-style dataset" because it represents
        a mapping from some keys (indices) to the actual data. __getitem__ must
        return a PyTorch tensor or NumPy array (or a collection thereof).
        """
        # the path to the selected image
        path = self.image_paths[index].as_posix()
        
        # get the class index
        # the class is the next-to-last location in the file path
        y = self.class2ix[path.split('/')[-2]]
        
        # read the instance
        x = read_image(path, mode = ImageReadMode.RGB)
        
        # scale from uint8 [0, 255] to float [0, 1]
        x = x / 255
        
        if self.tfms:
            x = self.tfms(x)
        
        # return the (input, target) pair
        return x, y
        
tfms = transforms.Compose([
    # standard ImageNet channel statistics
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((224, 224), antialias=True)
])
        
imagenet = Imagenet('/project/rcde/datasets/imagenet/ILSVRC/Data/CLS-LOC/', split='val', class2ix=class2ix, tfms=tfms)

print("Number of samples:", len(imagenet))
x, y = imagenet[5]
print("(x.shape, y)=", x.shape, y)
Number of samples: 50000
(x.shape, y)= torch.Size([3, 224, 224]) 0
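
The integer target can be decoded back into a human-readable label by inverting class2ix:

# invert class2ix, then look up the label for the synset ID
ix2label = {ix: class2label[cls] for cls, ix in class2ix.items()}
print(ix2label[y])  # tench, Tinca tinca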

Aside: Mini-Batch Gradient Descent#

In the Regression and Classification notebook, we trained the model by computing the loss for the entire dataset multiple times. Our training loop looked something like:

for i in range(num_epochs):
    # reset gradients accumulated from the previous iteration
    optimizer.zero_grad()

    # forward pass
    y_hat = model(x)

    # measure the loss
    # this is the mean squared error
    loss = loss_func(y_hat, y)

    # gradient computation
    loss.backward()

    # parameter updates
    optimizer.step()

Here, x and y represent the entire input data and target data, respectively.


Question: What problems would we run into if we applied this to ImageNet?


In most applications of deep learning, we will instead loop over mini-batches (small subsets) of our training data. Our modified training loop will look something like:

for i in range(num_epochs):
    # Now we have an inner loop over batches of data
    for x_batch, y_batch in batches:
        # reset gradients accumulated from the previous batch
        optimizer.zero_grad()

        # forward pass
        y_hat_batch = model(x_batch)

        # measure the loss for this batch
        loss_batch = loss_func(y_hat_batch, y_batch)

        # gradient computation
        loss_batch.backward()

        # parameter updates
        optimizer.step()

Here, batches is an iterable that yields tuples of the form (x_batch, y_batch), each representing a small subset of the full dataset.
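
To make this concrete, here is a sketch of how such an iterable could be built by hand, assuming the full dataset fits in memory as tensors x and y; this is essentially what the DataLoader automates:

import torch

def make_batches(x, y, batch_size):
    """Yield (x_batch, y_batch) tuples covering the dataset in random order."""
    perm = torch.randperm(len(x))  # shuffle the sample indices each epoch
    for start in range(0, len(x), batch_size):
        ix = perm[start:start + batch_size]
        yield x[ix], y[ix]

# hypothetical usage with in-memory tensors x and y:
# batches = make_batches(x, y, batch_size=256)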

It turns out that using very large batch sizes tends to lead to worse performance:

[Figure: the effect of large batch size on model performance.]

Question: Why do you think large batch size leads to worse performance?


Mini-batch gradient descent with the DataLoader class#

The Dataset class is our interface to the individual samples within our dataset. The DataLoader utility class provides an interface to batches of data. It also supports multiprocessing out of the box.

from torch.utils.data import DataLoader

# the DataLoader takes the dataset class as input
# batch_size: how many samples per mini batch
# num_workers: how many parallel processes for data loading
dl = DataLoader(imagenet, batch_size=256, num_workers=8)

print(dl)
<torch.utils.data.dataloader.DataLoader object at 0x153b730a3430>

PyTorch fetches the batches of data on the fly, so we have to request them one at a time:

x, y = next(iter(dl))  # get the first batch
print(x.shape, y.shape) 
torch.Size([256, 3, 224, 224]) torch.Size([256])

Notice the new leading dimension: the DataLoader has stacked the individual samples into a single batch tensor.
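
The DataLoader also knows how many batches it will produce: with 50,000 validation samples and a batch size of 256, we expect ceil(50000/256) = 196 batches.

import math

print(len(dl), math.ceil(len(imagenet) / 256))  # 196 196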

We’re now ready to write our new training loop:

ImageNet Training/Testing Loop#

# make the datasets
imagenet_train = Imagenet('/project/rcde/datasets/imagenet/ILSVRC/Data/CLS-LOC/', split='train', class2ix=class2ix, tfms=tfms)
imagenet_val = Imagenet('/project/rcde/datasets/imagenet/ILSVRC/Data/CLS-LOC/', split='val', class2ix=class2ix, tfms=tfms)

# create dataloaders for training and validation
# (for real training we would typically also pass shuffle=True to the training loader)
dl_train = DataLoader(imagenet_train, batch_size=256, num_workers=9)
dl_val = DataLoader(imagenet_val, batch_size=256, num_workers=9)

Question: Why is it good to have separate training and validation sets?


import torch

device = torch.device('cuda')

num_epochs=3
# Take a look at `htop` and `nvidia-smi` when running this...
for i in range(num_epochs):
    print(f"[Epoch {i}] Training...")
    for ix, (x,y) in enumerate(dl_train):
        print(f"\r[Epoch {i}] Batch {ix}. x.shape={x.shape}", end='')
        
        # move to device
        x = x.to(device)
        y = y.to(device)
        
        # this is just a test, so break early
        if ix==9:
            break
            
    print(f"\n[Epoch {i}] Testing...")
    for ix, (x, y) in enumerate(dl_val):
        print(f"\r[Epoch {i}] Batch {ix}. x.shape={x.shape}", end='')
        
        # move to device
        x = x.to(device)
        y = y.to(device)
        
        # this is where we put the model evaluation logic
        
        # this is just a test, so break early
        if ix==3:
            break
            
    print()
[Epoch 0] Training...
[Epoch 0] Batch 9. x.shape=torch.Size([256, 3, 224, 224])
[Epoch 0] Testing...
[Epoch 0] Batch 3. x.shape=torch.Size([256, 3, 224, 224])
[Epoch 1] Training...
[Epoch 1] Batch 9. x.shape=torch.Size([256, 3, 224, 224])
[Epoch 1] Testing...
[Epoch 1] Batch 3. x.shape=torch.Size([256, 3, 224, 224])
[Epoch 2] Training...
[Epoch 2] Batch 9. x.shape=torch.Size([256, 3, 224, 224])
[Epoch 2] Testing...
[Epoch 2] Batch 3. x.shape=torch.Size([256, 3, 224, 224])
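
The skeleton above only moves data to the GPU. As a sketch of where the model logic would go, here is one way the training inner loop could look, assuming a model, loss function, and optimizer have been created (the ones below are illustrative, not a tuned recipe):

import torch
from torch import nn
from torchvision.models import resnet50

model = resnet50(weights=None).to(device)  # untrained network, for illustration
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model.train()
for x, y in dl_train:
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()        # reset gradients from the previous step
    y_hat = model(x)             # forward pass
    loss = loss_func(y_hat, y)   # cross-entropy loss for classification
    loss.backward()              # gradient computation
    optimizer.step()             # parameter update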

Let’s actually evaluate a trained model#

Training from scratch would take far too long for this session, so let’s evaluate a pretrained model instead.

from torchvision.models import resnet50, ResNet50_Weights

# pretrained weights with advertised accuracy of 80.858% on the validation set
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
dl_val = DataLoader(imagenet_val, batch_size=256, num_workers=9)

preds_ls = []
targs_ls = []

# put the model in eval mode
model.eval()
for ix, (x, y) in enumerate(dl_val):
    
    # move to device
    x = x.to(device)
    y = y.to(device)

    # this is where we put the model evaluation logic
    with torch.no_grad():
        y_pred = model(x)
    
    # compute batch-level performance metrics
    pred_cls = y_pred.argmax(-1)
    top1_acc = (pred_cls == y).type(torch.float32).mean().item()
    
    # save preds for final acc calc
    preds_ls.append(pred_cls.cpu().squeeze())
    targs_ls.append(y.cpu().squeeze())
    
    print(f"Batch {ix}. Accuracy={100*top1_acc:0.1f}%")
Batch 0. Accuracy=89.8%
Batch 1. Accuracy=87.1%
Batch 2. Accuracy=98.4%
Batch 3. Accuracy=94.9%
Batch 4. Accuracy=93.8%
...
Batch 194. Accuracy=89.1%
Batch 195. Accuracy=51.2%

preds = torch.concat(preds_ls)
targs = torch.concat(targs_ls)
mean_top1_acc = (preds==targs).type(torch.float32).mean()

print(f"Average Accuracy={100*mean_top1_acc:0.4f}%")
Average Accuracy=80.1420%
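
This lands close to, but slightly below, the advertised 80.858%, likely because our tfms differ from the official evaluation preprocessing. torchvision bundles the canonical transforms with the weights themselves, so one way to check (and, presumably, close the gap) would be to use those instead:

# the weights object carries the preprocessing used for the advertised accuracy
official_tfms = ResNet50_Weights.IMAGENET1K_V2.transforms()
print(official_tfms)  # shows the resize, crop, and normalization settings

# hedged sketch: rebuild the validation set with the official transforms
# imagenet_val = Imagenet(image_dir, split='val', class2ix=class2ix, tfms=official_tfms)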