Attention is all you need#

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch 
from torch import nn

Make sure dataset.py and utils.py are in your working directory.

from dataset import PubMedDataset
from utils import train, test, generate

A little bit of history#

In the beginning was the Markov model:

[figure: autoregressive Markov chain]

Then came recurrent neural networks (GRU, LSTM, …):

[figure: RNN and LSTM]

Then came LSTMs with something called attention:

[figure: RNN and LSTM]

If you are feeling a little overwhelmed by this picture, you are not alone. In fact, exactly that feeling produced the title for the 2017 paper “Attention is all you need”:

[figure: RNN and LSTM]

And this is what their model architecture looks like:

[figure: transformer encoder-decoder]

However, they were focused on machine translation, which benefits from having a separate encoder and decoder. For language modeling, we only need half of the architecture. The picture simplifies to:

[figure: transformer decoder-only for language modeling]

Now this is starting to be a little less intimidating.

Let’s take stock of what we need to figure out:

  1. Input embedding

  2. Position encoding

  3. The circle plus thing

  4. Masked Multi-Head Attention (MMHA)

  5. Connections going around the MMHA

  6. Add and Norm

  7. Fully Connected

  8. Linear and Softmax

The heart of the transformer is the “Masked Multi-Head Attention” step. All of the other operations act at the single-token level.

Masked Multi-Head Attention#

What is attention?#

Attention selects information from a set of entries based on a query. To perform this operation we need to define:

  • \(Q\): the query, represented by a numeric vector. The query specifies what kind of information should be given more attention.

  • \(K\): the keys, also vectors. Each entry in the set has a key. We compare the query with each key to see how much attention the entry should be given.

  • \(V\): the values, also usually vectors. These represent the information associated with each entry that we are retrieving.

  • \(f(Q, K_i) = \alpha_i\): the “comparison” or “compatibility” function. This function compares \(Q\) with \(K_i\), the key for entry \(i\). The function returns the attention logit \(\alpha_i\).

The attention scores are computed from the attention logits with the softmax operation: \( a_i = \frac{\exp{(\alpha_i)}}{\sum_{j=1}^L\exp{(\alpha_j)}} \). In PyTorch, we will simply do a = alpha.softmax(dim=-1).

Let’s work out a simple example.

query = torch.tensor([1., 0.])

values = torch.tensor([0., 
                       1., 
                       0.])

keys = torch.tensor([
    [0.95, 0.05],  # goes with value 0.
    [0.1, 0.9],  # goes with value 1. 
    [0.8, 0.2]    # goes with value 0.
])
# for our comparison function, let's just use the dot product
alpha = keys @ query
print("alpha values:", alpha)

# now compute the normalized attention scores
attn = alpha.softmax(dim=-1)
print("attention values:", attn)

# now use the attention scores to aggregate the values
result = values @ attn
print("Result:", result.item())
alpha values: tensor([0.9500, 0.1000, 0.8000])
attention values: tensor([0.4370, 0.1868, 0.3762])
Result: 0.18679720163345337

Because the query vector was most similar to the keys associated with value 0., our result ended up closer to 0.

Check the result when using the query [0., 1.] instead.
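A quick way to check, reusing the keys and values defined above (since the new query is most similar to the key that goes with value 1., the result should now be pulled toward 1.):

query2 = torch.tensor([0., 1.])

alpha2 = keys @ query2          # compare the new query with each key
attn2 = alpha2.softmax(dim=-1)  # normalized attention scores
result2 = values @ attn2        # aggregate the values
print("Result:", result2.item())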

Masked Attention for autoregressive language models#

[figure: autoregressive language model]

Consider the figure above. In order to make a good prediction for token 4, we need to adaptively combine the information from tokens 1, 2, and 3. Let’s use attention to do this. Here’s how we define Q, K, and V:

  • \(Q_3 = W_Q h_3\), where \(h_3\) is the embedding for token 3 and \(W_Q\) is an embed_dim x embed_dim projection matrix.

  • \(K_{i\leq3} = W_K h_i\) where \(W_K\) is an embed_dim x embed_dim matrix.

  • \(V_{i\leq3} = W_V h_i\) where \(W_V\) is an embed_dim x embed_dim matrix.

  • \(\alpha_{i,3} = \frac{Q_3\cdot K_i}{\sqrt{|Q_3|}}\) where \(|Q_3|\) is the number of elements in \(Q_3\).

We then use softmax to normalize the attention logits, yielding the attention scores \(a_{i,3},\, i\leq3\). The output of the attention block is then \( h^{(\rm out)}_3 = \sum_{i=1}^3 a_{i,3}V_{i} \).
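Before wrapping this in a module, here is a bare-tensor sketch of that computation for a single query position (the embedding size and the random projection matrices below are made up purely for illustration):

embed_dim = 4
torch.manual_seed(0)

h = torch.randn(3, embed_dim)      # hypothetical embeddings h_1, h_2, h_3
W_Q = torch.randn(embed_dim, embed_dim)
W_K = torch.randn(embed_dim, embed_dim)
W_V = torch.randn(embed_dim, embed_dim)

Q3 = h[2] @ W_Q.T                  # query for token 3
K = h @ W_K.T                      # keys for tokens 1..3
V = h @ W_V.T                      # values for tokens 1..3

alpha = (K @ Q3) / embed_dim**0.5  # scaled dot-product attention logits
a = alpha.softmax(dim=-1)          # attention scores a_{i,3}
h_out_3 = a @ V                    # weighted sum of the values
print(h_out_3.shape)               # torch.Size([4])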

We’re now ready to implement this in code:

class MaskedAttention(nn.Module):
    def __init__(self, embed_dim, max_tokens, dropout_rate):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_tokens = max_tokens
        self.dropout_rate = dropout_rate
        
        self.scale_factor = embed_dim**-0.5
        
        # q,k,v
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)  # W_Q @ h_i
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)  # W_K @ h_i
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)  # W_V @ h_i
        self.attn_dropout = nn.Dropout(dropout_rate)
        
        # final projection
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim, bias=False),
            nn.Dropout(dropout_rate)
        )
        
        # autoregressive mask
        self.register_buffer(
            "ar_mask",
            torch.tril(torch.ones(max_tokens, max_tokens)).unsqueeze(0)
        )  
        # self.ar_mask.shape == (1, L, L)
        # for each batch, we need to select the sub-matrix
        # of size (1, L_batch, L_batch) where L_batch<=L
        # is the sequence length for the batch.

    def forward(self, x):
        # x.shape = (N, L_batch, embed_dim)
        L_batch = x.size(1)
        
        # qkv
        q = self.query(x) # (N, L_batch, embed_dim)
        k = self.key(x) # (N, L_batch, embed_dim)
        v = self.value(x) # (N, L_batch, embed_dim)
        
        # scaled dot-product attention
        # we use the einstein summation approach to avoid 
        # complicated reshape-then-permute operations
        alpha = torch.einsum("Nie,Nje->Nij", q, k) * self.scale_factor
        # alpha.shape = (N, L_batch, L_batch)
        # the 1st L_batch dim indexes the query token, 
        # the 2nd indexes the key/val token
        
        # autoregressive masking
        mask = self.ar_mask[:, :L_batch, :L_batch] # (1, L_batch, L_batch)
        alpha = alpha.masked_fill(mask==0, float("-inf"))  # why does this work? 
        
        # normalized attention scores
        attn = alpha.softmax(-1)  # (N, L_batch, L_batch)
        # dropout is applied to the normalized attention weights (after softmax),
        # as in the original Transformer
        attn = self.attn_dropout(attn)
        
        # aggregate
        v_agg = torch.einsum("Nij,Nje->Nie", attn, v)
        h_out = self.proj(v_agg)
        
        return h_out # (N, L_batch, embed_dim)

h = torch.randn(3, 462, 32)  # (N, L_batch, embed_dim)
ma = MaskedAttention(embed_dim=32, max_tokens=512, dropout_rate=0.1)
h_out = ma(h)  # expect (3, 462, 32)
h_out.shape
torch.Size([3, 462, 32])
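As an aside, one way to see why filling masked positions with -inf does the right thing (the question in the comment above) is to softmax a tiny hand-made example: exp(-inf) is 0, so masked positions receive exactly zero attention and the remaining scores still sum to 1.

logits = torch.tensor([1.0, 2.0, float("-inf")])
print(logits.softmax(dim=-1))  # tensor([0.2689, 0.7311, 0.0000])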

Multi-head Masked Attention#

Now we deal with the “Multi-head” part. The logic here is that using a single attention score to aggregate an entire token embedding may not have enough resolution. Perhaps there are two somewhat independent parts of the embedding that need to be attended to under different circumstances. Multi-head attention addresses this issue. Conceptually, we break the embedding vector into num_heads smaller embedding vectors and then perform the same attention mechanism as above independently for each sub-vector. We then concatenate the resulting sub-vectors before the final projection.
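To make the head-splitting concrete, here is a tiny shape-only sketch (the sizes are made up for illustration):

N, L, embed_dim, num_heads = 2, 5, 32, 4
head_dim = embed_dim // num_heads            # 8

x = torch.randn(N, L, embed_dim)

# split the last dimension into (num_heads, head_dim) sub-vectors
x_heads = x.view(N, L, num_heads, head_dim)  # (2, 5, 4, 8)

# each head attends independently; afterwards the heads are concatenated back
x_merged = x_heads.reshape(N, L, embed_dim)  # (2, 5, 32)
print(x_heads.shape, x_merged.shape)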

Once we’ve understood the single-head case, the multi-head case is not very difficult to implement: we copy the MaskedAttention class and modify it to incorporate multiple heads.

class MultiHeadMaskedAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, max_tokens, dropout_rate):
        super().__init__()        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_tokens = max_tokens
        self.dropout_rate = dropout_rate
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.scale_factor = self.head_dim**-0.5  # now we scale based on head size
        
        # q,k,v
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)  # W_Q @ h_i
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)  # W_K @ h_i
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)  # W_V @ h_i
        self.attn_dropout = nn.Dropout(dropout_rate)
        
        # final projection
        self.proj = nn.Sequential(
            nn.Linear(embed_dim, embed_dim, bias=False),
            nn.Dropout(dropout_rate)
        )
        
        # autoregressive mask
        # we need one extra dimension for the head
        self.register_buffer(
            "ar_mask",
            torch.tril(torch.ones(max_tokens, max_tokens)).unsqueeze(0).unsqueeze(0)
        )  
        # self.ar_mask.shape == (1, 1, L, L)
        # for each batch, we need to select the sub-matrix
        # of size (1, 1, L_batch, L_batch) where L_batch<=L
        # is the sequence length for the batch.

    def forward(self, x):
        # x.shape = (N, L_batch, embed_dim)
        L_batch = x.size(1)
        
        # qkv
        q = self.query(x) # (N, L_batch, num_heads * head_dim)
        k = self.key(x) # (N, L_batch, num_heads * head_dim)
        v = self.value(x) # (N, L_batch, num_heads * head_dim)
        
        # reshape to isolate head embedding
        q,k,v = [vec.view(-1, L_batch, self.num_heads, self.head_dim) for vec in (q,k,v)]
        # vec.shape == (N, L_batch, num_heads, head_dim)
        
        # scaled dot-product attention
        # we use the einstein summation approach to avoid 
        # complicated reshape-then-permute operations
        alpha = torch.einsum("Nihe,Njhe->Nhij", q, k) * self.scale_factor
        # alpha.shape = (N, num_heads, L_batch, L_batch)
        # the 1st L_batch dim indexes the query token, 
        # the 2nd indexes the key/val token
        
        # autoregressive masking
        mask = self.ar_mask[:, :, :L_batch, :L_batch] # (1, 1, L_batch, L_batch)
        alpha = alpha.masked_fill(mask==0, float("-inf")) 
        
        # normalized attention scores
        attn = alpha.softmax(-1)  # (N, num_heads, L_batch, L_batch)
        # dropout on the normalized attention weights (after softmax), as in the single-head case
        attn = self.attn_dropout(attn)
        
        # aggregate
        v_agg = torch.einsum("Nhij,Njhe->Nihe", attn, v)  # (N,L_batch,num_heads,head_dim)
        
        # reshape to concat the heads (view won't work)
        v_agg = v_agg.reshape(-1, L_batch, self.embed_dim) # (N, L_batch, embed_dim)
        h_out = self.proj(v_agg)
        
        return h_out # (N, L_batch, embed_dim)

h = torch.randn(3, 462, 32)  # (N, L_batch, embed_dim)
ma = MultiHeadMaskedAttention(embed_dim=32, num_heads=4, max_tokens=512, dropout_rate=0.1)
h_out = ma(h)  # expect (3, 462, 32)
h_out.shape
torch.Size([3, 462, 32])

The Transformer#

Now that we’ve tackled multi-head masked attention, the rest is easy. All other operations act at the individual token level.

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, max_tokens, dropout_rate):
        super().__init__()

        self.lay_norm1 = nn.LayerNorm(embed_dim)
        self.lay_norm2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadMaskedAttention(
            embed_dim=embed_dim, 
            num_heads = num_heads,
            max_tokens=max_tokens, 
            dropout_rate=dropout_rate
        )

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, 4*embed_dim),  # the factor of 4 comes from the original GPT paper.
            nn.GELU(),  # like relu but smooth
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        z = self.lay_norm1(x + self.attn(x))
        z = self.lay_norm2(z + self.feed_forward(z))

        return z

class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, max_tokens, num_blocks, dropout_rate):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate

        # embeddings
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_tokens, embed_dim)

        # sequence of transformer blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, max_tokens, dropout_rate) 
            for i in range(num_blocks)])

        # output linear layer
        self.fout = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x.shape = (N, L), a batch of token-id sequences
        L = x.shape[-1]
        pos = torch.arange(0, L, device=x.device, dtype=torch.long)

        # embeddings
        tok_embedding = self.tok_embed(x)  # (N, L, embed_dim)
        pos_embedding = self.pos_embed(pos)  # (L, embed_dim)
        embedding = tok_embedding + pos_embedding  # (N, L, embed_dim)

        # transformer blocks
        h = self.blocks(embedding)

        # output
        logits = self.fout(h)

        return logits
    
    def numpar(self):
        # count trainable parameters
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# let's test it
dataset = PubMedDataset("/project/rcde/datasets/pubmed/mesh_50k/splits/")
dl_train = dataset.get_dataloader('train', batch_size=3)
batch = next(iter(dl_train))
print(batch.keys())
Found cached dataset text (/home/dane2/.cache/huggingface/datasets/text/default-cadbbf8acc2e2b5a/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
model = Transformer(
    vocab_size = dataset.tokenizer.vocab_size,
    embed_dim = 64, 
    num_heads = 8,
    max_tokens = 512, 
    num_blocks = 3, 
    dropout_rate = 0.1
)
print(model)
print("Trainable parameters: ", model.numpar())
model(batch['input_ids']).shape
Transformer(
  (tok_embed): Embedding(28996, 64)
  (pos_embed): Embedding(512, 64)
  (blocks): Sequential(
    (0): TransformerBlock(
      (lay_norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (lay_norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadMaskedAttention(
        (query): Linear(in_features=64, out_features=64, bias=False)
        (key): Linear(in_features=64, out_features=64, bias=False)
        (value): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=False)
          (1): Dropout(p=0.1, inplace=False)
        )
      )
      (feed_forward): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (1): TransformerBlock(
      (lay_norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (lay_norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadMaskedAttention(
        (query): Linear(in_features=64, out_features=64, bias=False)
        (key): Linear(in_features=64, out_features=64, bias=False)
        (value): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=False)
          (1): Dropout(p=0.1, inplace=False)
        )
      )
      (feed_forward): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (2): TransformerBlock(
      (lay_norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (lay_norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadMaskedAttention(
        (query): Linear(in_features=64, out_features=64, bias=False)
        (key): Linear(in_features=64, out_features=64, bias=False)
        (value): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=False)
          (1): Dropout(p=0.1, inplace=False)
        )
      )
      (feed_forward): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fout): Linear(in_features=64, out_features=28996, bias=True)
)
Trainable parameters:  3922436
torch.Size([3, 463, 28996])

Time to train the model#

# training settings
num_epochs=20
batch_size=64
learning_rate=0.002  # We could get better performance by using a learning rate scheduler

# model settings
embed_dim = 128 # gpt-1 uses 768. We have a much smaller dataset.
num_heads = 4  # gpt-1 uses 12 heads of size 64.
max_tokens = 512 # gpt-1 uses 512
dropout_rate = 0.2 # gpt-1 uses 0.1
num_blocks = 6 # gpt-1 uses 12
# reinitialize dataset for good measure
dataset = PubMedDataset("/project/rcde/datasets/pubmed/mesh_50k/splits/", max_tokens=max_tokens)
vocab_size = dataset.tokenizer.vocab_size

# train/test dataloaders
dl_train = dataset.get_dataloader('train', batch_size=batch_size, num_workers=20)
dl_test = dataset.get_dataloader('test', batch_size=batch_size, num_workers=20)

# reinitialize the model on gpu
model = Transformer(
    vocab_size = dataset.tokenizer.vocab_size,
    embed_dim = embed_dim, 
    num_heads = num_heads,
    max_tokens = max_tokens, 
    num_blocks = num_blocks, 
    dropout_rate = dropout_rate
).to('cuda')

print("Trainable parameters:", model.numpar())

# create the pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
Found cached dataset text (/home/dane2/.cache/huggingface/datasets/text/default-cadbbf8acc2e2b5a/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
Trainable parameters: 8704068
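As the comment in the training settings suggests, a learning-rate scheduler would likely give a further boost. A minimal sketch of how one could be attached (not used in the run below) is:

# hypothetical: decay the learning rate over the planned number of epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# then call scheduler.step() once at the end of each training epoch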

This is going to take a while, even though we only have 45k training samples (63 MB) and a tiny model (yes, 8.7 million parameters is tiny by today’s LLM standards). GPT-3 has about 175 billion parameters and roughly 45 TB of text data. That’s roughly 20 thousand times more model and 700 thousand times more data… Be glad we don’t have to train that! Nevertheless, the basic architecture is very similar to what we wrote down above.

While it trains, try looking at your GPU utilization (for example nvidia-smi -l 3) and CPU utilization (top or htop). Can you identify the bottleneck in the training pipeline? How would we remedy this?

for epoch in range(num_epochs):
    print(f"START EPOCH {epoch+1}")
    print("TRAINING")
    train(model, dl_train, optimizer, reporting_interval=80)
    print("TESTING")
    test(model, dl_test, reporting_interval=20)
    print("-"*30)
START EPOCH 1
TRAINING
Batch 0 training loss: 10.453564643859863
Batch 80 training loss: 6.512960433959961
Batch 160 training loss: 5.864621639251709
Batch 240 training loss: 5.503118991851807
Batch 320 training loss: 5.313235282897949
Batch 400 training loss: 5.219045162200928
Batch 480 training loss: 5.240912437438965
Batch 560 training loss: 4.972538471221924
Batch 640 training loss: 5.059909820556641
TESTING
Batch 0 testing loss: 4.8257737159729
Batch 20 testing loss: 4.94636869430542
Batch 40 testing loss: 4.815924167633057
Batch 60 testing loss: 4.825933456420898
------------------------------
START EPOCH 2
TRAINING
Batch 0 training loss: 4.924393653869629
Batch 80 training loss: 5.031074523925781
Batch 160 training loss: 4.919568061828613
Batch 240 training loss: 4.847174644470215
Batch 320 training loss: 4.805762767791748
Batch 400 training loss: 4.793119430541992
Batch 480 training loss: 4.8642048835754395
Batch 560 training loss: 4.656754493713379
Batch 640 training loss: 4.764900207519531
TESTING
Batch 0 testing loss: 4.573569297790527
Batch 20 testing loss: 4.690598964691162
Batch 40 testing loss: 4.555648326873779
Batch 60 testing loss: 4.563152313232422
------------------------------
START EPOCH 3
TRAINING
Batch 0 training loss: 4.671664237976074
Batch 80 training loss: 4.795679569244385
Batch 160 training loss: 4.706965923309326
Batch 240 training loss: 4.662759304046631
Batch 320 training loss: 4.623923301696777
Batch 400 training loss: 4.622889995574951
Batch 480 training loss: 4.70373010635376
Batch 560 training loss: 4.514253616333008
Batch 640 training loss: 4.642773151397705
TESTING
Batch 0 testing loss: 4.458096981048584
Batch 20 testing loss: 4.56614351272583
Batch 40 testing loss: 4.429101467132568
Batch 60 testing loss: 4.430231094360352
------------------------------
START EPOCH 4
TRAINING
Batch 0 training loss: 4.535168647766113
Batch 80 training loss: 4.670892238616943
Batch 160 training loss: 4.584466934204102
Batch 240 training loss: 4.547022819519043
Batch 320 training loss: 4.5159711837768555
Batch 400 training loss: 4.534675598144531
Batch 480 training loss: 4.6015543937683105
Batch 560 training loss: 4.41730260848999
Batch 640 training loss: 4.544186115264893
TESTING
Batch 0 testing loss: 4.3786301612854
Batch 20 testing loss: 4.477602005004883
Batch 40 testing loss: 4.341939449310303
Batch 60 testing loss: 4.3493733406066895
------------------------------
START EPOCH 5
TRAINING
Batch 0 training loss: 4.447899341583252
Batch 80 training loss: 4.589491367340088
Batch 160 training loss: 4.508801460266113
Batch 240 training loss: 4.469975471496582
Batch 320 training loss: 4.443156719207764
Batch 400 training loss: 4.4578070640563965
Batch 480 training loss: 4.54600715637207
Batch 560 training loss: 4.363983154296875
Batch 640 training loss: 4.49476957321167
TESTING
Batch 0 testing loss: 4.328824996948242
Batch 20 testing loss: 4.426694869995117
Batch 40 testing loss: 4.289669036865234
Batch 60 testing loss: 4.291662216186523
------------------------------
START EPOCH 6
TRAINING
Batch 0 training loss: 4.382369518280029
Batch 80 training loss: 4.5319061279296875
Batch 160 training loss: 4.463258266448975
Batch 240 training loss: 4.418227672576904
Batch 320 training loss: 4.404354572296143
Batch 400 training loss: 4.403997898101807
Batch 480 training loss: 4.490046977996826
Batch 560 training loss: 4.307733535766602
Batch 640 training loss: 4.446858882904053
TESTING
Batch 0 testing loss: 4.279397010803223
Batch 20 testing loss: 4.383358478546143
Batch 40 testing loss: 4.245883464813232
Batch 60 testing loss: 4.250319004058838
------------------------------
START EPOCH 7
TRAINING
Batch 0 training loss: 4.338723182678223
Batch 80 training loss: 4.486749172210693
Batch 160 training loss: 4.416853427886963
Batch 240 training loss: 4.372591495513916
Batch 320 training loss: 4.354620933532715
Batch 400 training loss: 4.370021343231201
Batch 480 training loss: 4.452726364135742
Batch 560 training loss: 4.271473407745361
Batch 640 training loss: 4.400921821594238
TESTING
Batch 0 testing loss: 4.246220588684082
Batch 20 testing loss: 4.355269432067871
Batch 40 testing loss: 4.212291240692139
Batch 60 testing loss: 4.221573352813721
------------------------------
START EPOCH 8
TRAINING
Batch 0 training loss: 4.299985408782959
Batch 80 training loss: 4.4461565017700195
Batch 160 training loss: 4.374502182006836
Batch 240 training loss: 4.336207389831543
Batch 320 training loss: 4.321552753448486
Batch 400 training loss: 4.32875919342041
Batch 480 training loss: 4.408866882324219
Batch 560 training loss: 4.234516620635986
Batch 640 training loss: 4.371394634246826
TESTING
Batch 0 testing loss: 4.210877418518066
Batch 20 testing loss: 4.328013896942139
Batch 40 testing loss: 4.178816795349121
Batch 60 testing loss: 4.186997413635254
------------------------------
START EPOCH 9
TRAINING
Batch 0 training loss: 4.254145622253418
Batch 80 training loss: 4.417206287384033
Batch 160 training loss: 4.347931861877441
Batch 240 training loss: 4.318736553192139
Batch 320 training loss: 4.2999587059021
Batch 400 training loss: 4.295639514923096
Batch 480 training loss: 4.386085510253906
Batch 560 training loss: 4.218393325805664
Batch 640 training loss: 4.349240303039551
TESTING
Batch 0 testing loss: 4.187094688415527
Batch 20 testing loss: 4.297364234924316
Batch 40 testing loss: 4.153755187988281
Batch 60 testing loss: 4.161794185638428
------------------------------
START EPOCH 10
TRAINING
Batch 0 training loss: 4.230549335479736
Batch 80 training loss: 4.394379138946533
Batch 160 training loss: 4.307508945465088
Batch 240 training loss: 4.290331840515137
Batch 320 training loss: 4.256217956542969
Batch 400 training loss: 4.2689104080200195
Batch 480 training loss: 4.3596367835998535
Batch 560 training loss: 4.182875633239746
Batch 640 training loss: 4.32074499130249
TESTING
Batch 0 testing loss: 4.161235809326172
Batch 20 testing loss: 4.285717964172363
Batch 40 testing loss: 4.134584426879883
Batch 60 testing loss: 4.143280506134033
------------------------------
START EPOCH 11
TRAINING
Batch 0 training loss: 4.205535411834717
Batch 80 training loss: 4.3611979484558105
Batch 160 training loss: 4.291257858276367
Batch 240 training loss: 4.2684831619262695
Batch 320 training loss: 4.239354133605957
Batch 400 training loss: 4.234607696533203
Batch 480 training loss: 4.333159446716309
Batch 560 training loss: 4.152698040008545
Batch 640 training loss: 4.295297145843506
TESTING
Batch 0 testing loss: 4.143322944641113
Batch 20 testing loss: 4.2660346031188965
Batch 40 testing loss: 4.116174221038818
Batch 60 testing loss: 4.122979640960693
------------------------------
START EPOCH 12
TRAINING
Batch 0 training loss: 4.1926045417785645
Batch 80 training loss: 4.344197750091553
Batch 160 training loss: 4.273345947265625
Batch 240 training loss: 4.241662502288818
Batch 320 training loss: 4.220840930938721
Batch 400 training loss: 4.210723876953125
Batch 480 training loss: 4.308537483215332
Batch 560 training loss: 4.151541709899902
Batch 640 training loss: 4.26921272277832
TESTING
Batch 0 testing loss: 4.119717121124268
Batch 20 testing loss: 4.248164176940918
Batch 40 testing loss: 4.096585750579834
Batch 60 testing loss: 4.1085920333862305
------------------------------
START EPOCH 13
TRAINING
Batch 0 training loss: 4.16041898727417
Batch 80 training loss: 4.312416076660156
Batch 160 training loss: 4.239747047424316
Batch 240 training loss: 4.2181396484375
Batch 320 training loss: 4.1919941902160645
Batch 400 training loss: 4.188230514526367
Batch 480 training loss: 4.291740894317627
Batch 560 training loss: 4.119902610778809
Batch 640 training loss: 4.25164270401001
TESTING
Batch 0 testing loss: 4.109146595001221
Batch 20 testing loss: 4.236158847808838
Batch 40 testing loss: 4.083269119262695
Batch 60 testing loss: 4.096263408660889
------------------------------
START EPOCH 14
TRAINING
Batch 0 training loss: 4.1401686668396
Batch 80 training loss: 4.296884059906006
Batch 160 training loss: 4.219123363494873
Batch 240 training loss: 4.194148540496826
Batch 320 training loss: 4.17279052734375
Batch 400 training loss: 4.172521114349365
Batch 480 training loss: 4.268155097961426
Batch 560 training loss: 4.106161117553711
Batch 640 training loss: 4.234467029571533
TESTING
Batch 0 testing loss: 4.087616920471191
Batch 20 testing loss: 4.2239861488342285
Batch 40 testing loss: 4.060999870300293
Batch 60 testing loss: 4.0785231590271
------------------------------
START EPOCH 15
TRAINING
Batch 0 training loss: 4.124801158905029
Batch 80 training loss: 4.279020309448242
Batch 160 training loss: 4.201953887939453
Batch 240 training loss: 4.178887367248535
Batch 320 training loss: 4.160668849945068
Batch 400 training loss: 4.158737659454346
Batch 480 training loss: 4.256253242492676
Batch 560 training loss: 4.083315372467041
Batch 640 training loss: 4.213152885437012
TESTING
Batch 0 testing loss: 4.081240653991699
Batch 20 testing loss: 4.213397026062012
Batch 40 testing loss: 4.05709981918335
Batch 60 testing loss: 4.071685314178467
------------------------------
START EPOCH 16
TRAINING
Batch 0 training loss: 4.105445384979248
Batch 80 training loss: 4.257458209991455
Batch 160 training loss: 4.181327819824219
Batch 240 training loss: 4.160776615142822
Batch 320 training loss: 4.133213043212891
Batch 400 training loss: 4.14075231552124
Batch 480 training loss: 4.230212211608887
Batch 560 training loss: 4.068319797515869
Batch 640 training loss: 4.199864387512207
TESTING
Batch 0 testing loss: 4.05750036239624
Batch 20 testing loss: 4.198790550231934
Batch 40 testing loss: 4.032836437225342
Batch 60 testing loss: 4.047375679016113
------------------------------
START EPOCH 17
TRAINING
Batch 0 training loss: 4.079669952392578
Batch 80 training loss: 4.2428998947143555
Batch 160 training loss: 4.164575576782227
Batch 240 training loss: 4.144774913787842
Batch 320 training loss: 4.125186443328857
Batch 400 training loss: 4.118528842926025
Batch 480 training loss: 4.224449634552002
Batch 560 training loss: 4.043789863586426
Batch 640 training loss: 4.182216644287109
TESTING
Batch 0 testing loss: 4.0463104248046875
Batch 20 testing loss: 4.182623386383057
Batch 40 testing loss: 4.020364284515381
Batch 60 testing loss: 4.032769203186035
------------------------------
START EPOCH 18
TRAINING
Batch 0 training loss: 4.067461013793945
Batch 80 training loss: 4.221118450164795
Batch 160 training loss: 4.142800807952881
Batch 240 training loss: 4.129454135894775
Batch 320 training loss: 4.112790107727051
Batch 400 training loss: 4.101941108703613
Batch 480 training loss: 4.196477890014648
Batch 560 training loss: 4.035633087158203
Batch 640 training loss: 4.1695780754089355
TESTING
Batch 0 testing loss: 4.031458854675293
Batch 20 testing loss: 4.1668572425842285
Batch 40 testing loss: 4.003049373626709
Batch 60 testing loss: 4.024814605712891
------------------------------
START EPOCH 19
TRAINING
Batch 0 training loss: 4.051701068878174
Batch 80 training loss: 4.202447414398193
Batch 160 training loss: 4.129954814910889
Batch 240 training loss: 4.109521389007568
Batch 320 training loss: 4.086175441741943
Batch 400 training loss: 4.092165946960449
Batch 480 training loss: 4.185718536376953
Batch 560 training loss: 4.016693592071533
Batch 640 training loss: 4.15826416015625
TESTING
Batch 0 testing loss: 4.019284725189209
Batch 20 testing loss: 4.161983489990234
Batch 40 testing loss: 3.9886131286621094
Batch 60 testing loss: 4.014815807342529
------------------------------
START EPOCH 20
TRAINING
Batch 0 training loss: 4.036025524139404
Batch 80 training loss: 4.18900728225708
Batch 160 training loss: 4.11253023147583
Batch 240 training loss: 4.099170207977295
Batch 320 training loss: 4.078619480133057
Batch 400 training loss: 4.075125694274902
Batch 480 training loss: 4.162968635559082
Batch 560 training loss: 3.9928629398345947
Batch 640 training loss: 4.134066104888916
TESTING
Batch 0 testing loss: 4.008586883544922
Batch 20 testing loss: 4.14926815032959
Batch 40 testing loss: 3.982499837875366
Batch 60 testing loss: 4.0042524337768555
------------------------------

Generate#
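The generate helper is imported from utils.py, so its implementation isn’t shown here. Conceptually, autoregressive generation just feeds the model its own output one token at a time; a minimal sampling loop (a sketch of the general approach, not the actual utils.generate) might look like:

@torch.no_grad()
def generate_sketch(model, prompt_ids, max_new_tokens, temperature=1.0):
    # prompt_ids: (N, L) tensor of token ids
    model.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        ids_cond = ids[:, -model.max_tokens:]      # crop to the context window
        logits = model(ids_cond)                   # (N, L, vocab_size)
        logits = logits[:, -1, :] / temperature    # keep only the next-token logits
        probs = logits.softmax(dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # sample one token per sequence
        ids = torch.cat([ids, next_id], dim=1)
    return ids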

prompts = [
    "We compared the prevalence of", # organ-specific autoantibodies in a group of Helicobacter..."
    "We compared the prevalence of",
    "We compared the prevalence of",
]

prompt_ids = dataset.tokenizer(prompts, return_tensors='pt')['input_ids']

# trim off unwanted [SEP] tokens which act like our special end-of-sequence token.
prompt_ids = prompt_ids[:,:-1]

# generate ids
gen_ids = generate(model.to('cpu'), prompt_ids, 50)
dataset.decode_batch(gen_ids)
['[CLS] we compared the prevalence of - metabolites in swine ponds of high - substituted for piperine glutamase. to ouabain - sensitive catecholamine ( eps ) strips with enzyme inhibitory chemical analysis for analyses namely short - exercise. neuro',
 '[CLS] we compared the prevalence of cholinimedumann a significant clinical symptoms and in patients with albic symptoms, using histological tests, and patient clinicopathological examinations, offer prognosis and immunohistochemistry were done. the distinction between the records',
 '[CLS] we compared the prevalence of ventricular aortic ( bf ) in all3 increased arterioplasty compared with mortality in trachea, using aortic valve defects in 12 cardiac arrest. eighty - four patients treated with chronic aspiricular pressures']
prompts = [
    "The pytorch llm workshop was",
    "The pytorch llm workshop was",
    "The pytorch llm workshop was",
]

prompt_ids = dataset.tokenizer(prompts, return_tensors='pt')['input_ids']

# trim off unwanted [SEP] tokens which act like our special end-of-sequence token.
prompt_ids = prompt_ids[:,:-1]

# generate ids
gen_ids = generate(model, prompt_ids, 50)
dataset.decode_batch(gen_ids)
['[CLS] the pytorch llm workshop was assessed regarding the accuracy and validity effects of comorpharmacytric acute excretion in children with insulin activity in suckling for children and adults with impaired children ( palsycintimulating age 6 years ) in their daily red periodicity',
 '[CLS] the pytorch llm workshop was prepared on morphology of pigments at 14cncephalism carrier hooklets. dentinal coat v1 hams of the medialisport muscle caused by corneal vessels ( gsor2, and phospho phosphatase',
 '[CLS] the pytorch llm workshop was prepared using the cell - cell membrane protein membranes containing other materials displaying moderate - strains of microtimulating trophages. bacterization products of irratiate - phospholipids were used for the bacryl']