Batching, multi-GPU, and multi-node inference for large data and large models#
We’ve seen how to run LLM inference with a high degree of control over the model’s inputs and outputs. The goal of this last notebook is to discuss ways to scale up the inference process to large data and large models.
There are three primary tools we will use:
Batching
Multi-GPU inference
Multi-node inference
We’ll discuss each of these in turn.
Batching#
Batching means processing multiple inputs at once. This is a common technique in deep learning, since it lets the model handle many inputs in parallel. The transformers library has built-in support for batching, and we can use it to speed up inference with minimal code changes.
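As a tiny illustration of what batching looks like at the tokenizer level (a sketch, not part of the pipeline below): shorter texts are padded so that every row of the resulting tensor has the same length.
# A small sketch of batched tokenization: the tokenizer pads shorter texts
# so the whole batch fits in one rectangular tensor.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", padding_side='left')
tok.pad_token = tok.eos_token
batch = tok(["a short text", "a somewhat longer piece of text"],
            padding=True, return_tensors="pt")
print(batch["input_ids"].shape)   # (2, length of the longest tokenized text)
print(batch["attention_mask"])    # zeros mark the padded (ignored) positions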
First, we’ll load a large number of pieces of text that we want to process using an LLM. Then, we’ll process them in batches and compare the time it takes to process them in batches versus one at a time.
# Get a list of texts from the 20 newsgroups dataset
# Each text is a post from a newsgroup
from sklearn.datasets import fetch_20newsgroups
docs = fetch_20newsgroups(subset='test')['data'][:64]
print(f'Number of documents: {len(docs)}')
for i, doc in enumerate(docs[:3]):
    print(f'\n\nDOCUMENT {i+1}:\n{doc}\n')
Suppose we want some piece of information about each of these newsgroup posts, and that information cannot be easily extracted in an automated way using traditional NLP techniques. An LLM might be a good choice for such a task.
For example, we might want a one-sentence summary of each post. We can craft a prompt that asks the model to generate such a summary.
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="auto")
device = model.device
system_prompt = "The user will supply a post from an online newsgroup. Summarize the post in a single, very short sentence."
# Define a function that will generate summaries for a batch of posts
def generate_summaries(texts, batch_size=8):
    results = []
    total_batches = (len(texts) + batch_size - 1) // batch_size
    with tqdm(total=total_batches, desc="Processing batches", leave=True, bar_format="{l_bar}{bar} | {n_fmt}/{total_fmt}") as pbar:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_messages = [[{"role": "system", "content": system_prompt},
                               {"role": "user", "content": text}] for text in batch]
            # Tokenize the messages using the chat template, padding to the longest prompt in the batch
            model_inputs = tokenizer.apply_chat_template(
                batch_messages,
                add_generation_prompt=True,
                return_tensors="pt",
                padding=True,
                return_dict=True,
            ).to(device)
            # Run the model to generate new tokens for every prompt in the batch
            with torch.no_grad():
                outputs = model.generate(
                    **model_inputs,
                    max_new_tokens=100,
                    return_dict_in_generate=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens (everything after the prompt)
            prompt_length = model_inputs["input_ids"].shape[1]
            generated_sequences = outputs.sequences[:, prompt_length:]
            decoded_outputs = tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)
            results.extend(decoded_outputs)
            pbar.update(1)
    return results
start = time.time()
# Generate summaries for the documents
summaries = generate_summaries(docs, batch_size=64)
end = time.time()
print(f"Total time taken: {end - start:.2f} seconds")
summaries
# Clear the model from memory
import torch
del model
torch.cuda.empty_cache()
Multi-GPU inference#
Thankfully, transformers makes multi-GPU inference easy.
Note that there are different kinds of parallelism you might want, depending on your situation. If you just want to speed up your LLM inference and your model fits on a single GPU, you can use data parallelism: each GPU holds a full copy of the model and processes a different slice of the data, as sketched below.
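Here is a minimal sketch of single-node data parallelism, reusing the docs and system_prompt defined above with the small 1B model. The thread-per-GPU approach and the variable names are illustrative assumptions, not the method used in the rest of this notebook.
# A minimal data-parallel sketch (illustrative, not used later in this notebook):
# one full replica of the model per GPU, with the documents split round-robin
# across the replicas and each slice summarized in its own thread.
from concurrent.futures import ThreadPoolExecutor
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
n_gpus = torch.cuda.device_count()
replicas = [AutoModelForCausalLM.from_pretrained(model_name).to(f"cuda:{i}")
            for i in range(n_gpus)]
def summarize_on(replica, texts):
    # Tokenize and generate on whichever GPU this replica lives on
    device = next(replica.parameters()).device
    messages = [[{"role": "system", "content": system_prompt},
                 {"role": "user", "content": t}] for t in texts]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt",
        padding=True, return_dict=True).to(device)
    with torch.no_grad():
        out = replica.generate(**inputs, max_new_tokens=100,
                               pad_token_id=tokenizer.eos_token_id)
    return tokenizer.batch_decode(out[:, inputs["input_ids"].shape[1]:],
                                  skip_special_tokens=True)
# Round-robin split: replica i handles docs[i], docs[i + n_gpus], ...
chunks = [docs[i::n_gpus] for i in range(n_gpus)]
with ThreadPoolExecutor(max_workers=n_gpus) as pool:
    parallel_summaries = [s for chunk in pool.map(summarize_on, replicas, chunks)
                          for s in chunk]
Note that with this round-robin split the results come back grouped by replica, so you would need to re-interleave them if the original document order matters.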
If your model is too large to fit on a single GPU, you can use model parallelism, in which each GPU holds a different part of the model. Luckily, transformers makes model parallelism easy: just set device_map when loading the model.
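Before the full example below, here is a small sketch of the knobs device_map gives you. The max_memory limits and the variable name are illustrative assumptions, not recommendations; device_map="auto" alone (as used below) is often enough.
# A sketch of controlling model parallelism via device_map. With "auto",
# accelerate decides how to spread the layers across the available devices;
# max_memory (optional) caps how much each device may receive.
from transformers import AutoModelForCausalLM
import torch
sharded_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    device_map="auto",
    max_memory={0: "10GiB", 1: "10GiB", "cpu": "30GiB"},  # illustrative limits
    torch_dtype=torch.bfloat16,
)
print(sharded_model.hf_device_map)  # shows which device each block of layers landed on
del sharded_model
torch.cuda.empty_cache()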
import time
start = time.time()
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
model_name = "meta-llama/Llama-3.3-70B-Instruct" # "meta-llama/Llama-3.1-8B-Instruct"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
device = model.device
system_prompt = "The user will supply a post from an online newsgroup. Summarize the post in a single, very short sentence."
# Generate summaries for the documents
summaries = generate_summaries(docs, batch_size=8)
end = time.time()
print(f"Total time taken: {end - start:.2f} seconds")
summaries
# Clear the model from memory
import torch
del model
torch.cuda.empty_cache()
Multi-node inference#
What if you have multiple nodes available and want to use them all to speed up your inference? Several kinds of parallelism are possible with multi-node inference.
For example, you can use data parallelism, in which the data is split across the nodes and each node processes a different portion of it. You can also use model parallelism, in which the model is split across the nodes and each node holds a different part of it. The former speeds up inference; the latter is for models too large to fit on a single node.
We will implement data parallelism. The code is in the scripts inference.slurm, helper_inference.sh, and inference.py. These three files work together to enable distributed inference:
inference.slurm: The SLURM job submission script that requests and configures computing resources (in this example, 2 nodes, each with 1 V100 GPU, 12GB memory, etc.)
helper_inference.sh: A shell script that sets up the environment and launches the distributed inference using torchrun. It handles environment modules, activates the Python environment, and configures the distributed setup parameters.
inference.py: The main Python script that performs the actual inference (a rough sketch of this pattern follows the list). It:
Initializes distributed processing across nodes
Loads the model and tokenizer
Splits the input prompts across the available nodes
Processes prompts in parallel
Gathers results back to the main node
Saves the combined output
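The core data-parallel pattern in inference.py looks roughly like the sketch below. This is an illustrative reconstruction rather than the actual script: the load_prompts helper and the output filename are hypothetical placeholders, and the real code differs in its details.
# A rough sketch of the data-parallel pattern used by inference.py
# (illustrative only; load_prompts and the output path are hypothetical).
import json, os
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process
dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
# Every rank loads the full prompt list, then keeps only its own share
prompts = load_prompts()              # hypothetical helper returning all prompts
my_prompts = prompts[rank::world_size]
inputs = tokenizer(my_prompts, padding=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=100,
                         pad_token_id=tokenizer.eos_token_id)
my_results = tokenizer.batch_decode(out[:, inputs["input_ids"].shape[1]:],
                                    skip_special_tokens=True)
# Gather every rank's results (here on all ranks, for simplicity), then save from rank 0 only
gathered = [None] * world_size
dist.all_gather_object(gathered, my_results)
if rank == 0:
    with open("summaries.json", "w") as f:   # hypothetical output path
        json.dump([r for chunk in gathered for r in chunk], f)
dist.destroy_process_group()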
The workflow is to use sbatch to submit inference.slurm, which calls helper_inference.sh on each node, which then launches inference.py in a coordinated way across all nodes.