The Problem
Autoregressive generation is serial by design. To generate token 50, you must have already generated tokens 1–49. Each token requires one full forward pass through a 70B model. There’s no shortcut — or so it seemed.
The forward pass for a Transformer is naturally parallelizable over sequence positions. During training, we exploit this fully: we feed in the full sequence and predict all tokens simultaneously. During inference, we can’t — we don’t know what token 49 is until we’ve generated it.
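The serial bottleneck can be made concrete with a toy sketch (the `ToyModel` class and its call counter are illustrative stand-ins, not a real API): generating N tokens takes N dependent forward passes, because each step consumes the previous step's output.

```python
class ToyModel:
    """Stand-in for a large LM: counts forward passes instead of running them."""
    def __init__(self):
        self.calls = 0

    def sample(self, tokens):
        self.calls += 1   # one full forward pass per generated token
        return "tok"      # dummy next token

def generate(model, prompt, n):
    tokens = list(prompt)
    for _ in range(n):    # inherently serial: step t needs the output of step t-1
        tokens.append(model.sample(tokens))
    return tokens

m = ToyModel()
out = generate(m, ["The", "cat"], 5)
print(m.calls)   # → 5: five new tokens, five serial forward passes
print(len(out))  # → 7: prompt (2) plus generated (5)
```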
Speculative decoding breaks this assumption by making a bet.
Analogy
Imagine you’re a slow, meticulous editor (the large model) reviewing text from a fast, sloppy intern (the draft model).
The intern writes 5 words. You scan them in one pass — accepting correct words, stopping at the first mistake, crossing it out, writing the correct word, and handing back to the intern.
On average you’re doing far less work per accepted word than if you’d written every word yourself. And the output quality is determined entirely by your standards — the intern’s words only appear if you approve them.
The key insight: reading is faster than writing. The large model can verify K tokens in parallel in one forward pass (that’s just reading), while generating K tokens would require K serial forward passes (writing).
Mechanism in Plain English
- A small, fast draft model generates K tokens autoregressively (e.g., K=5). This is cheap.
- The target model (large, slow) runs a single forward pass on all K+1 positions simultaneously — the original context plus the K draft tokens.
- Compare the target model’s distribution at each position to the draft model’s:
  - If the draft token is “good enough” (formally: accepted with probability min(1, p_target/p_draft)), keep it.
  - At the first rejection, sample a corrected token from an adjusted distribution and stop.
- The result is the same distribution as if the target model had generated every token itself — this is provably exact, not an approximation.
- On average, more than one token is produced per target model call. The speedup is roughly (1 - α^K)/(1 - α), where α is the per-token acceptance rate.
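The exactness claim can be checked empirically for a single draft step. Below is a minimal sketch with toy distributions of my own choosing: propose from the draft distribution q, accept with probability min(1, p/q), otherwise sample from the normalized residual max(0, p - q). The empirical output frequencies should match the target p, not the draft q.

```python
import random
from collections import Counter

p = {"a": 0.5, "b": 0.3, "c": 0.2}  # target distribution
q = {"a": 0.2, "b": 0.5, "c": 0.3}  # draft distribution (deliberately different)

def sample(dist):
    r, acc = random.random(), 0.0
    for tok, prob in dist.items():
        acc += prob
        if r < acc:
            return tok
    return tok  # guard against float rounding

def speculative_one_token():
    tok = sample(q)                                   # draft proposes
    if random.random() < min(1.0, p[tok] / q[tok]):   # accept with prob min(1, p/q)
        return tok
    residual = {t: max(0.0, p[t] - q[t]) for t in p}  # correction: max(0, p - q)
    z = sum(residual.values())
    return sample({t: v / z for t, v in residual.items()})

random.seed(0)
n = 100_000
freq = Counter(speculative_one_token() for _ in range(n))
print({t: round(freq[t] / n, 2) for t in "abc"})  # matches p, not q
```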
ASCII Diagram
Step 1: Draft model generates 5 tokens (fast, sequential)
Context: "The cat sat on the"
Draft: "mat" "and" "the" "dog" "slept"
Step 2: Target model verifies all 5 in one parallel pass
Target evaluates:
P_target("mat"|context) = 0.72 P_draft = 0.68 → ratio 1.06 → ACCEPT ✓
P_target("and"|...mat) = 0.15 P_draft = 0.41 → ratio 0.37 → REJECT ✗ (with prob 0.63)
At rejection: sample corrected token from (P_target - P_draft)+ distribution
Say we get: "."
Step 3: Output for this decode step: "mat ."
We generated 2 tokens with 1 target model call.
Without speculative decoding: 2 target model calls.
─────────────────────────────────────────────────────
Best case (α=1.0, all accepted, K=5):
"mat and the dog slept" → 5 tokens, 1 call → 5x speedup
Typical case (α=0.8, K=5):
Expected tokens per call = (1 - 0.8^5)/(1 - 0.8) = (1 - 0.33)/0.2 ≈ 3.4
Speedup: ~2-3x
Math with Translation
Expected tokens per target forward pass: E[tokens] = (1 - α^K)/(1 - α), where:
- α — per-token acceptance rate (0 to 1, higher = faster)
- K — number of draft tokens generated per step
- At α = 0.8, K = 5: (1 - 0.8^5)/(1 - 0.8) ≈ 3.4 tokens per target call
- At α = 0.85, K = 4: (1 - 0.85^4)/(1 - 0.85) ≈ 3.2 tokens per target call
Speedup = E[tokens] divided by the per-step cost relative to one target pass, approximately (1 - α^K)/(1 - α) when the draft model is cheap.
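A quick sanity check of the expected-tokens formula against the numbers used elsewhere in this note (the helper function is mine, just restating the formula):

```python
def expected_tokens(alpha, K):
    # Geometric series of acceptances: 1 + alpha + alpha^2 + ... + alpha^(K-1)
    return K if alpha == 1.0 else (1 - alpha**K) / (1 - alpha)

print(round(expected_tokens(0.8, 5), 1))   # the "typical case" above → 3.4
print(round(expected_tokens(0.85, 4), 1))  # the walkthrough setup → 3.2
```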
Concrete Walkthrough
Setup:
- Target: 70B model, 100ms per forward pass
- Draft: 7B model, 10ms per token (5x faster)
- K = 4, α = 0.85
Without speculative decoding:
- Cost per token: 100ms × 1 = 100ms
With speculative decoding:
- Draft cost: 4 × 10ms = 40ms
- Target cost: 100ms (one pass over 5 positions)
- Total: 140ms
- Expected tokens: (1 - 0.85^4)/(1 - 0.85) ≈ 3.2 tokens
- Cost per token: 140ms / 3.2 ≈ 44ms
- Speedup: 2.3x
The rejection sampling scheme guarantees the output distribution is identical to sampling from the target model alone. The intern’s mistakes don’t propagate.
What’s Clever
The insight that unlocked this: the target model’s forward pass cost is nearly constant whether you verify 1 token or K tokens, because the batch computation over sequence positions is fully parallel.
Verifying K tokens costs almost the same as verifying 1 token: at decode-time batch sizes the forward pass is dominated by streaming the model weights from memory, which happens either way, so the slightly larger matmuls and extra KV cache reads barely change wall-clock time. Every extra token the draft gets right is almost free.
The second non-obvious thing: the rejection sampling scheme is exact. Early drafting proposals (before the 2022 Leviathan et al. paper) worried that using a draft model would bias outputs. The correction distribution — sampling from max(0, p_target - p_draft) at the first rejected position — mathematically guarantees the output matches the target distribution. No approximation.
Self-speculative decoding extends this: instead of a separate draft model, use early transformer layers as the draft (exit early, verify with the full model). No separate model needed at all.
Code
Draft-verify loop skeleton illustrating the accept/reject logic:
from random import random

def speculative_decode(draft_model, target_model, prompt, K=4, max_tokens=100):
    tokens = list(prompt)
    while len(tokens) < max_tokens:
        # Step 1: draft model autoregressively proposes K candidate tokens (fast)
        drafts = []
        x = tokens[:]
        for _ in range(K):
            next_tok = draft_model.sample(x)  # cheap: small model, one step
            drafts.append(next_tok)
            x.append(next_tok)

        # Step 2: target model scores all draft positions in ONE parallel forward pass.
        # Convention: target_probs[p] maps each token to the target's probability of
        # it appearing at position p, given all earlier positions.
        target_probs = target_model.forward(tokens + drafts)  # O(1) passes for K tokens

        # Step 3: accept/reject each draft token using rejection sampling
        accepted = []
        for i, tok in enumerate(drafts):
            pos = len(tokens) + i
            p_target = target_probs[pos][tok]                     # target prob of draft token
            p_draft = draft_model.prob(tokens + drafts[:i], tok)  # draft prob of same token
            if random() < min(1.0, p_target / p_draft):           # accept with prob min(1, p_t/p_d)
                accepted.append(tok)
            else:
                # First rejection: sample the corrected token from the residual
                # distribution proportional to max(0, p_target - p_draft), which
                # is what makes the output exactly target-distributed.
                draft_dist = draft_model.dist(tokens + drafts[:i])  # full draft distribution (assumed API)
                residual = {t: max(0.0, pt - draft_dist.get(t, 0.0))
                            for t, pt in target_probs[pos].items()}
                accepted.append(sample_categorical(residual))  # normalize and sample (helper assumed)
                break  # stop at the first rejection
        tokens.extend(accepted)
    return tokens
# Output distribution is identical to sampling from the target model alone — provably exact.
Key Sources
Related Concepts
Open Questions
- Optimal draft model size and architecture relative to target
- Speculative decoding with quantized or approximate target models