Recap: Attention Essentials
Recall: Scaled dot-product attention
\[
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
\]
- Q, K, V are linear projections of the input: \(Q = XW^Q\), \(K = XW^K\), \(V = XW^V\)
- Scaling by \(\sqrt{d_k}\) prevents softmax saturation
- Output for each position = weighted sum of all value vectors
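The formula above can be sketched as a minimal single-head implementation in NumPy (the toy shapes and random input here are illustrative, not from any particular model):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: Q, K of shape (T, d_k), V of shape (T, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)              # (T, T) similarity matrix
    # Row-wise softmax (max-subtracted for numerical stability)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V                                 # each output row = weighted sum of value rows

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8))                  # 4 tokens, d_model = 8
W_Q, W_K, W_V = (rng.standard_normal((8, 8)) for _ in range(3))
out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)
print(out.shape)  # (4, 8)
```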
Recall: Why scale by \(\sqrt{d_k}\)? Preventing softmax saturation
For random vectors \(\mathbf{q}, \mathbf{k} \in \mathbb{R}^{d_k}\) with entries ~ \(\mathcal{N}(0,1)\): \(\text{Var}(\mathbf{q} \cdot \mathbf{k}) = d_k\)
Without scaling (\(d_k = 64\))
Dot products have std dev ≈ 8
scores = [12, 8, 5, 1]
softmax ≈ [0.981, 0.018, 0.001, 0.000]
Near one-hot → vanishing gradients for non-max tokens
With scaling: scores / √64 = scores / 8
Dot products rescaled to std dev ≈ 1
scaled = [1.5, 1.0, 0.625, 0.125]
softmax ≈ [0.439, 0.266, 0.183, 0.111]
Smooth distribution → healthy gradients for all tokens
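The two columns above can be reproduced directly with a standard max-subtracted softmax:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())      # subtract max for numerical stability
    return e / e.sum()

scores = np.array([12.0, 8.0, 5.0, 1.0])   # unscaled dot products, std dev ~ 8
print(softmax(scores))                      # near one-hot
print(softmax(scores / np.sqrt(64)))        # smooth distribution
```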
Why \(\sqrt{d_k}\)? The variance argument
Assume \(q_i, k_i \sim \text{i.i.d.}\) with mean 0, variance 1. The dot product is:
\[
q \cdot k = \sum_{i=1}^{d_k} q_i \, k_i
\]
Each term \(q_i k_i\) has: \(\;\mathbb{E}[q_i k_i] = 0\), \(\;\text{Var}(q_i k_i) = 1\)
By independence, the variance of the sum is:
\[
\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k
\]
So dot products grow as \(O(d_k)\). Dividing by \(\sqrt{d_k}\) restores unit variance:
\[
\text{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1
\]
Recall: Multi-head attention and causal masking
Multi-head: Run \(h\) parallel attention operations with \(d_k = d_{model}/h\)
\[
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
\]
Causal masking: Prevent attending to future tokens
\[
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top + M}{\sqrt{d_k}}\right)V
\]
\[
\text{where } M_{ij} = \begin{cases} 0 & j \leq i \\ -\infty & j > i \end{cases}
\]
Multi-head: each head learns different relationships
Causal mask: lower-triangular; token i sees only positions ≤ i
Transformer block: attention + FFN + residuals + LayerNorm
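The mask definition above can be sketched in NumPy; with all-zero scores, the masked softmax becomes uniform over the visible prefix, making the lower-triangular structure easy to see:

```python
import numpy as np

def causal_mask(T):
    """M[i, j] = 0 for j <= i, -inf for j > i (added to scores before softmax)."""
    return np.where(np.tril(np.ones((T, T), dtype=bool)), 0.0, -np.inf)

scores = np.zeros((4, 4)) + causal_mask(4)
w = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = w / w.sum(axis=-1, keepdims=True)
print(weights)   # row i is uniform over positions 0..i, zero above the diagonal
```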
Part 3: Pretraining at Scale
The pretraining objective: predict the next token
Causal Language Modeling (CLM) loss:
\[
\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log P_\theta(x_t \mid x_1, \ldots, x_{t-1})
\]
Training Example: "The quick brown fox jumps"
| Context | Target | Loss contribution |
|---|---|---|
| &lt;bos&gt; | The | −log P(The \| &lt;bos&gt;) |
| The | quick | −log P(quick \| The) |
| The quick | brown | −log P(brown \| The quick) |
| The quick brown | fox | −log P(fox \| The quick brown) |
| The quick brown fox | jumps | −log P(jumps \| The quick brown fox) |
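The per-position terms in the table sum to the CLM loss. A minimal NumPy sketch, with random logits standing in for a real model's output (the loss is averaged over positions rather than summed, as is standard in practice):

```python
import numpy as np

def clm_loss(logits, tokens):
    """CLM loss: logits has shape (T, V); tokens has shape (T+1,).
    Position t's logits predict token t+1, so targets are tokens shifted by one."""
    targets = tokens[1:]
    z = logits - logits.max(axis=-1, keepdims=True)              # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

rng = np.random.default_rng(0)
logits = rng.standard_normal((5, 10))      # T = 5 positions, vocab of 10 (toy)
tokens = np.array([0, 3, 1, 4, 1, 5])      # id 0 stands in for <bos>
print(clm_loss(logits, tokens))            # mean of -log P(x_t | x_<t) over the 5 targets
```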
Self-supervision: the data labels itself
Traditional supervised learning: image → human labels "cat"; requires expensive manual annotation.
Self-supervised (LLMs): "The cat" → next word is "sat"; labels come from the text itself, free and unlimited.
Key insight: The structure of language provides free supervision at massive scale
Next-token prediction implicitly learns many skills
To predict the next token well, the model must learn:
Grammar & Syntax
"She runs" not "She run"
World Knowledge
"Paris is the capital of France"
Reasoning Patterns
"If A then B. A is true. Therefore B"
Style & Tone
Formal vs. casual, technical vs. simple
The simple objective captures complex structure
Figure: GPT-2 paper
Scaling laws describe predictable improvement with resources
Empirical finding (Kaplan et al., 2020; Hoffmann et al., 2022):
\[
L(N, D) \approx \left(\frac{N_c}{N}\right)^{\alpha_N} + \left(\frac{D_c}{D}\right)^{\alpha_D} + L_\infty
\]
N (parameters): more parameters → lower loss
D (data, tokens): more data → lower loss
C (compute, FLOPs): more FLOPs → lower loss
Loss decreases as a power law in each resource
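The power-law form can be explored numerically. The constants below are illustrative placeholders on the rough order of published fits, not exact values; the qualitative behavior, predictable diminishing returns in each resource, is the point:

```python
# Illustrative constants only -- not fitted values from any paper.
N_c, D_c = 8.8e13, 5.4e13
alpha_N, alpha_D = 0.076, 0.095
L_inf = 1.69

def loss(N, D):
    """Loss as a power law in parameters N and tokens D, plus an irreducible floor."""
    return (N_c / N) ** alpha_N + (D_c / D) ** alpha_D + L_inf

for N in [1e9, 1e10, 1e11]:                # scaling N at fixed data
    print(f"N={N:.0e}: L={loss(N, 2e11):.3f}")
```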
Figure: scaling law curves (Kaplan et al., 2020)
Chinchilla scaling: balance parameters and data
Key insight (Hoffmann et al., 2022):
For a fixed compute budget, there’s an optimal ratio:
\[
N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}
\]
Rule of thumb: Train on ~20 tokens per parameter
Undertrained: GPT-3, 175B params, 300B tokens (~1.7 tokens/param)
Compute-optimal: Chinchilla, 70B params, 1.4T tokens (20 tokens/param)
Inference-optimal: LLaMA 2, 7B params, 2T tokens (~285 tokens/param)
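Combining the 20-tokens-per-parameter rule with the common estimate \(C \approx 6ND\) training FLOPs pins down the compute-optimal split for any budget. A sketch (the \(C \approx 6ND\) approximation is an assumption of this example, not the papers' fitting procedure):

```python
def chinchilla_split(C):
    """Split a FLOPs budget C so that D = 20 * N (tokens = 20 x params),
    using the common training-compute estimate C ~= 6 * N * D."""
    # C = 6 * N * (20 * N) = 120 * N^2  =>  N = sqrt(C / 120)
    N = (C / 120) ** 0.5
    D = 20 * N
    return N, D

for C in [1e21, 1e23, 5.76e23]:
    N, D = chinchilla_split(C)
    print(f"C={C:.2e} FLOPs -> N ~ {N:.2e} params, D ~ {D:.2e} tokens")
# The last budget (~5.8e23 FLOPs) recovers roughly 70B params / 1.4T tokens
```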