Same cross-entropy loss, but on curated instruction data
Instruction tuning: Implementation details
Formatting data for instruction-tuned models:
# For summarization with instruction-tuned models
def prepare_for_summarization(article, tokenizer):
    prompt = f"Summarize the following article.\n\n{article}\n\nTL;DR:"
    return tokenizer(prompt, return_tensors="pt")

# System prompts set the model's "persona"
system_prompt = "You are a helpful assistant and an expert at summarizing articles."
Fine-tuning with label masking (for training summarization models):
# We only want to train on generating the SUMMARY, not the article
labels = input_ids.clone()
labels[:prompt_length] = -100  # -100 is ignored by CrossEntropyLoss
# The model learns: given article, predict summary tokens
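The masking step can be sketched end-to-end with plain Python lists (a minimal illustration; `mask_prompt_labels` and the token ids are made up for this example):

```python
# -100 is the ignore_index used by PyTorch's CrossEntropyLoss
IGNORE_INDEX = -100

def mask_prompt_labels(input_ids, prompt_length):
    """Copy input_ids, masking prompt positions so the loss covers only the response."""
    labels = list(input_ids)
    labels[:prompt_length] = [IGNORE_INDEX] * prompt_length
    return labels

# Example: 4 prompt tokens (article) + 3 response tokens (summary)
ids = [101, 2023, 2003, 102, 7680, 7716, 102]
print(mask_prompt_labels(ids, 4))  # [-100, -100, -100, -100, 7680, 7716, 102]
```

Only the unmasked positions contribute to the cross-entropy loss, so gradient updates come entirely from the summary tokens.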
Large models are expensive to deploy
70B+
parameters in frontier models
$$$
GPU costs for serving
High latency
slow responses at scale
We want similar quality at lower cost.
Preview: Knowledge Distillation (Part 8)
We can train a smaller “student” model to mimic the large model’s behavior — getting most of the quality at a fraction of the cost.
Before turning to efficiency, we’ll address a deeper limitation of SFT: It only learns from one correct answer per instruction. Can we do better?
Part 6: Preference Alignment with RLHF
Why RLHF? SFT isn’t enough
SFT limitations:
One "correct" answer per instruction
Can't encode preferences between valid responses
Expensive to write high-quality responses
May learn to imitate surface patterns
RLHF advantages:
Captures human preferences between options
Easier to judge than to write
Learns subtle quality distinctions
Optimizes for what humans actually want
Key insight: It’s easier to say “A is better than B” than to write the perfect A.
RLHF: The three-step process
1
Collect Preference Data
For each prompt, generate multiple responses. Humans rank: A > B or B > A
↓
2
Train Reward Model
Learn to predict human preferences: r(x, y) → scalar score
↓
3
Optimize LLM with RL
Use PPO to maximize expected reward while staying close to SFT model
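Step 2 is typically trained with a Bradley-Terry pairwise loss on the preference data: the reward model should score the preferred response higher than the dispreferred one. A minimal pure-Python sketch (function names and reward values are illustrative):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def pairwise_reward_loss(r_chosen, r_rejected):
    """Bradley-Terry pairwise loss: -log sigma(r_chosen - r_rejected)."""
    return -math.log(sigmoid(r_chosen - r_rejected))

# The loss shrinks as the reward gap grows in the right direction,
# and blows up when the rejected response is scored higher.
print(pairwise_reward_loss(2.0, 0.0))  # ~0.127 (correct ordering, confident)
print(pairwise_reward_loss(0.0, 2.0))  # ~2.127 (wrong ordering, heavily penalized)
```

In practice `r_chosen` and `r_rejected` come from one forward pass each through the reward model on (prompt, response) pairs.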
RLHF data flow
%%{init: {"theme":"base","themeVariables":{"fontFamily":"Inter, system-ui, sans-serif","fontSize":"16px","primaryColor":"#eef2f7","primaryTextColor":"#0f172a","lineColor":"#475569","tertiaryColor":"#e2e8f0"},"flowchart":{"curve":"basis","nodeSpacing":30,"rankSpacing":50}}}%%
flowchart LR
P["Prompts"] --> M["SFT Model πθ"]
M --> RA["Response A"]
M --> RB["Response B"]
RA --> H["Human Rater"]
RB --> H
H -->|"A ≻ B"| RM["Reward Model rθ"]
RM -->|"r(x,y)"| PPO["PPO"]
PPO -->|"Update πθ"| M
The reward model is a learned proxy for human judgment — it scores responses so we don’t need a human in the loop at training time.
Reward model intuition: Learning from comparisons
Prompt: "Explain gravity to a 5-year-old"
Response A ✓ preferred "Gravity is like an invisible magnet in the Earth that pulls everything down — that's why when you jump, you come back!"
Response B "Gravity is a fundamental force described by general relativity as the curvature of spacetime caused by mass-energy..."
PPO (Proximal Policy Optimization) is commonly used
KL penalty prevents “reward hacking”
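Concretely, the standard RLHF objective maximizes expected reward while a KL term keeps the policy close to the reference (SFT) model, with \(\beta\) controlling the penalty strength:

```latex
\max_{\pi_\theta}\;
\mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(\cdot \mid x)}\big[ r(x, y) \big]
\;-\; \beta\, \mathrm{KL}\!\big( \pi_\theta(\cdot \mid x) \,\|\, \pi_{\mathrm{ref}}(\cdot \mid x) \big)
```

Without the KL term the policy is free to drift to degenerate outputs that exploit flaws in the learned reward, which is exactly the "reward hacking" above.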
From RLHF to DPO: What changes?
RLHF
Train a separate reward model
Run PPO optimization loop
Complex, unstable training
Three separate stages
→
DPO
No reward model needed
No RL loop required
Directly optimize preference pairs
Standard supervised learning
Key insight: Same preference data, dramatically simpler training.
Alternatives to RLHF: Direct Preference Optimization
Key insight: The optimal RLHF policy has a closed-form relationship with the reward model. DPO exploits this to optimize directly on preference pairs — no reward model or RL needed.
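For reference, the DPO loss (Rafailov et al., 2023), where \(y_w\) is the preferred response, \(y_l\) the dispreferred one, and \(\pi_{\mathrm{ref}}\) the frozen SFT model:

```latex
\mathcal{L}_{\mathrm{DPO}}(\theta)
= -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\!\left[
\log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
\;-\; \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right)\right]
```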
Reading the formula: increase the probability of the preferred response \(y_w\) (relative to the reference model) and decrease the probability of the dispreferred response \(y_l\).
RLHF
Preferences → Reward Model → RL
Complex, unstable, expensive
DPO
Preferences → Direct optimization
Simpler, stable, supervised learning
Alignment goals: Helpful, Harmless, Honest
Helpful
Answers questions effectively
Completes requested tasks
Provides useful information
Harmless
Refuses dangerous requests
Avoids harmful content
Doesn't assist malicious use
Honest
Acknowledges uncertainty
Doesn't make things up
Corrects mistakes
Tension
These goals can conflict: A helpful answer to a harmful request is itself harmful.
Concept check
Quick questions (think for 30 seconds):
In the RL objective, what happens if we set β = 0?
An RLHF-trained chatbot starts adding “As an AI language model…” to every response. What went wrong?
Given a trained model and a prompt, how do we produce an output — one token at a time?
But note: what you put in matters as much as how tokens come out. We’ll cover the output side today (decoding), and the input side next time (prompting, in-context learning, reasoning).
Three control knobs in LLM systems
Training-time
(change weights)
SFT
RLHF
DPO
Inference-time
(change decoding)
Greedy
Temperature
Top-k / Top-p
Deployment-time
(change size)
Knowledge distillation
These mechanisms operate on different axes but interact in final behavior.
Training vs. inference: Different levers
Training (shapes probabilities)
Alignment modifies the token distribution
Changes relative likelihoods
Inference (selects from distribution)
Decoding determines which tokens get chosen
These interact in practice:
Greedy decoding can amplify repetitive artifacts from training
High temperature increases hallucination risk regardless of alignment
Beam search often reduces diversity even in well-aligned models
Part 7: Text Generation and Decoding
The autoregressive generation loop
def generate(model, prompt_ids, max_new_tokens=50):
    input_ids = prompt_ids
    for _ in range(max_new_tokens):
        # Forward pass: get logits for all positions
        logits = model(input_ids)  # (B, seq_len, vocab_size)
        # Extract logits for the LAST position only
        next_token_logits = logits[:, -1, :]  # (B, vocab_size)
        # DECODE: Select next token (this is where methods differ!)
        next_token = decode(next_token_logits)
        # Append and continue
        input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
    return input_ids
The decode() function is where greedy, sampling, top-k, and top-p differ!
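The simplest choice is greedy decoding: take the argmax of the logits at every step. A dependency-free sketch (in the PyTorch code above this would be `torch.argmax(logits, dim=-1)`):

```python
def greedy_decode(logits):
    """Greedy decoding: return the index of the highest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])

print(greedy_decode([4.0, 2.0, 1.0, 0.5]))  # 0 (the highest-logit token)
```

Greedy is deterministic: the same prompt always yields the same continuation, which the sampling methods below relax.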
def temperature(logits, t=1.0):
    """Apply temperature scaling before sampling"""
    scaled_logits = logits / t
    probs = F.softmax(scaled_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
Temperature
Effect
Distribution
T → 0
More deterministic
Approaches greedy
T = 1
Original distribution
Standard sampling
T > 1
More random
Flatter distribution
Temperature effects visualized
Original logits: [4.0, 2.0, 1.0, 0.5]
T = 0.5 (sharper)
probs ≈ [0.98, 0.02, 0.00, 0.00]
T = 1.0 (original)
probs ≈ [0.82, 0.11, 0.04, 0.02]
T = 2.0 (flatter)
probs ≈ [0.57, 0.21, 0.13, 0.10]
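These distributions are easy to recompute; a small pure-Python check (no PyTorch needed):

```python
import math

def softmax_with_temperature(logits, t):
    """Softmax over logits / t (pure-Python sketch of temperature scaling)."""
    scaled = [x / t for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 2.0, 1.0, 0.5]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 2) for p in softmax_with_temperature(logits, t)])
# 0.5 [0.98, 0.02, 0.0, 0.0]   <- sharper, nearly greedy
# 1.0 [0.82, 0.11, 0.04, 0.02] <- original distribution
# 2.0 [0.57, 0.21, 0.13, 0.1]  <- flatter, more random
```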
Top-k sampling: Only consider the k best tokens
def topk(logits, k=50):
    """Sample from only the top k tokens"""
    # Get top k logits and their indices
    top_logits, top_indices = torch.topk(logits, k, dim=-1)
    # Sample from the top k
    probs = F.softmax(top_logits, dim=-1)
    sampled_idx = torch.multinomial(probs, num_samples=1)
    # Map back to vocabulary indices
    return torch.gather(top_indices, -1, sampled_idx).squeeze(-1)
Idea: Truncate the long tail of unlikely tokens, then sample from the rest.
k = 50: Sample from top 50 tokens only
k = 1: Equivalent to greedy decoding
Visualizing truncation: top-k vs top-p
Same distribution, different truncation:
Top-k (k=3): Fixed — always 3 tokens
the .40
a .25
one .15
an .10
my .05
... .05
Covers 80% of probability mass
Top-p (p=0.9): Adaptive — cumsum ≥ 0.9
the .40
a .25
one .15
an .10
my .05
... .05
Covers 90% of probability mass (4 tokens)
Top-k misses “an” (10% probability) because it always takes exactly k. Top-p includes it because 90% mass hasn’t been reached yet.
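The adaptive cutoff can be verified on this exact distribution (a small illustrative helper; assumes the probabilities are already sorted in descending order):

```python
def nucleus_size(probs, p):
    """Number of tokens kept by top-p: the smallest prefix whose
    cumulative probability reaches p. Assumes probs are sorted descending."""
    cum = 0.0
    for i, prob in enumerate(probs):
        cum += prob
        if cum >= p:
            return i + 1
    return len(probs)

probs = [0.40, 0.25, 0.15, 0.10, 0.05, 0.05]  # the example distribution above
print(nucleus_size(probs, 0.9))  # 4 tokens kept, vs. a fixed 3 for top-k with k=3
```

On a peaked distribution the nucleus can shrink to one or two tokens, while on a flat one it grows, which is exactly the adaptivity top-k lacks.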
Top-p (nucleus) sampling: Adaptive truncation
Idea: Keep the smallest set of tokens whose cumulative probability ≥ p
Algorithm (Holtzman et al., 2020):
Sort tokens by probability (highest first)
Compute cumulative probabilities
Find the cutoff where cumulative probability exceeds p
Zero out all tokens beyond the cutoff
Sample from the remaining tokens
Top-p implementation
def topp(logits, p=0.9):
    """Nucleus sampling: dynamic truncation based on cumulative probability"""
    # Step 1: Sort by probability (descending)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    # Step 2: Compute cumulative probabilities
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    # Steps 3-4: Find and apply cutoff
    sorted_indices_to_remove = cumsum_probs > p
    # Tricky: shift right so we INCLUDE the token that crosses p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False  # Always keep at least one token
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    # Step 5: Sample from remaining tokens
    probs = F.softmax(sorted_logits, dim=-1)
    sampled_idx = torch.multinomial(probs, num_samples=1)
    return torch.gather(sorted_indices, -1, sampled_idx).squeeze(-1)
The right-shift trick
Without the shift, the cutoff mask would also remove the token whose cumulative probability first crosses the threshold. Shifting the mask right by one position keeps that token, so the nucleus always covers at least p probability mass.