Data preparation
from workshop_utils import display_pdf
display_pdf("Slides_part5.pdf")
Let’s work through an example of preparing data for fine-tuning our model. Along the way, we’ll develop functions that we can put in our workshop_utils.py
script so we can easily import them into any script or notebook.
from datasets import load_dataset
import torch
# Load the dataset
ds = load_dataset("HuggingFaceH4/MATH-500")
# Split the dataset into training and validation sets.
# MATH-500 only ships a "test" split, so for demonstration purposes we split that.
train_val_dataset = ds["test"].train_test_split(test_size=0.1)
train_dataset = train_val_dataset["train"]
eval_dataset = train_val_dataset["test"]
import pandas as pd
pd.DataFrame(train_dataset[:5])
import textwrap
import random
wrapper = textwrap.TextWrapper(width=80)
sample = train_dataset[random.randint(0, len(train_dataset) - 1)]
problem = wrapper.fill(sample['problem'])
solution = wrapper.fill(sample['solution'])
print('Problem:\n', problem, '\n\n')
print('Solution:\n', solution)
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# The model may not have a pad token set by default, so set it (using the EOS token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Set a system prompt to be prepended to the user prompt.
# This is a simple example, but you could use a more complex system prompt.
system_prompt = "Solve the following math problem."
# Define a function to format the prompt and apply loss masking.
# It builds the full conversation (system + user problem + assistant solution) with the
# model's chat template, then computes which tokens belong to the prompt so they can be
# masked out of the loss (label positions set to -100 are ignored by the loss function).
# The function assumes that each example has a `problem` and a `solution`, which is true for the MATH-500 dataset.
def tokenize_and_mask(example, tokenizer, max_length=1024, system_prompt=system_prompt):
    # Build a conversation with the system, user, and assistant messages.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example['problem']},
        {"role": "assistant", "content": example['solution']}
    ]
    # Tokenize just the prompt portion (system + user) to measure its length in tokens.
    prompt_ids = tokenizer.apply_chat_template(
        messages[:-1],
        return_tensors='pt',
        return_dict=True,
        add_special_tokens=False
    )["input_ids"][0]  # Remove batch dimension
    # Tokenize the full conversation (with special tokens and truncation)
    tokenized = tokenizer.apply_chat_template(
        messages,
        truncation=True,
        max_length=max_length,
        add_special_tokens=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = tokenized["input_ids"][0]  # Remove batch dimension
    attention_mask = tokenized["attention_mask"][0]  # Remove batch dimension
    # Create labels as a copy of input_ids, then mask the prompt tokens with -100.
    # (The assistant header tokens that follow the prompt are left unmasked here.)
    labels = input_ids.clone()
    prompt_length = len(prompt_ids)
    labels[:prompt_length] = -100
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
def tokenize_for_generation(example, tokenizer):
    # Build a prompt with the system and user messages only, for generation (not for training)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example['problem']}
    ]
    # Tokenize the prompt. add_generation_prompt=True appends the assistant header
    # (for Qwen2.5 this is '<|im_start|>assistant\n'), so the model continues with its answer.
    tokenized = tokenizer.apply_chat_template(
        messages,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    input_ids = tokenized["input_ids"][0]  # Remove batch dimension
    attention_mask = tokenized["attention_mask"][0]
    return {"input_ids": input_ids, "attention_mask": attention_mask}
# Map the tokenization function over the dataset.
# The result is a dataset in which each math problem is formatted as a prompt for the model,
# the solution is formatted as the response the model should generate, and each example is tokenized.
# (For a large dataset you might write a batched version and pass batched=True; here we keep it simple.)
train_dataset_tokenized = train_dataset.map(tokenize_and_mask, batched=False, fn_kwargs={"tokenizer": tokenizer})
eval_dataset_tokenized = eval_dataset.map(tokenize_and_mask, batched=False, fn_kwargs={"tokenizer": tokenizer})
# Get a sample dataset so we can examine model generations before and after training
sample_dataset = eval_dataset.select(range(3))
sample_dataset_tokenized = sample_dataset.map(tokenize_for_generation, batched=False, fn_kwargs={"tokenizer": tokenizer})
train_dataset_tokenized.set_format(type="torch", columns=["input_ids", "labels", "attention_mask"])
eval_dataset_tokenized.set_format(type="torch", columns=["input_ids", "labels", "attention_mask"])
sample_dataset_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
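It can be helpful to peek at one formatted training example to see exactly which tensor columns will be handed to the data collator later in this section; this is purely an inspection step and does not change the datasets.
# Inspect the tensor columns of one tokenized training example
example = train_dataset_tokenized[0]
print({k: (v.dtype, tuple(v.shape)) for k, v in example.items()})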
def generate_and_print(sample_dataset, sample_dataset_tokenized, model, tokenizer):
    outputs = []
    for sample in sample_dataset_tokenized:
        input_ids_batch = sample['input_ids'].unsqueeze(0).to(model.device)
        attention_mask_batch = sample['attention_mask'].unsqueeze(0).to(model.device)
        generated_ids = model.generate(
            input_ids=input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=512,  # change as needed
        )
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        outputs.append(generated_text)
    for i, sample in enumerate(sample_dataset_tokenized):
        # Retrieve the original problem and solution from the un-tokenized dataset
        original = sample_dataset[i]
        print("Problem:")
        print(wrapper.fill(original["problem"]))
        print("\nTrue Solution:")
        print(wrapper.fill(original["solution"]))
        print("\nModel's Solution:")
        model_output = outputs[i].split("assistant\n")[-1].strip()
        print(wrapper.fill(model_output))
        print("\n" + "-" * 80 + "\n")
# Generate and print model outputs before training
generate_and_print(sample_dataset, sample_dataset_tokenized, model, tokenizer)
# Create a simple data collator
def data_collator(features):
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [f["input_ids"].clone().detach() for f in features],
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [f["labels"].clone().detach() for f in features],
        batch_first=True,
        padding_value=-100,
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [f["attention_mask"].clone().detach() for f in features],
        batch_first=True,
        padding_value=0,
    )
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
    }
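To see the collator in action, we can batch a few tokenized examples through a PyTorch DataLoader and check that every field is padded to a common length. (transformers also ships ready-made collators such as DataCollatorForSeq2Seq that handle this padding, but the hand-rolled version keeps the logic explicit.)
# Try the collator on a small batch and check the padded shapes
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset_tokenized, batch_size=4, collate_fn=data_collator)
batch = next(iter(loader))
print({k: tuple(v.shape) for k, v in batch.items()})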