The Problem

Autoregressive generation is serial by design. To generate token 50, you must have already generated tokens 1–49. Each token requires one full forward pass through a 70B model. There’s no shortcut — or so it seemed.

The forward pass for a Transformer is naturally parallelizable over sequence positions. During training, we exploit this fully: we feed in the full sequence and predict all tokens simultaneously. During inference, we can’t — we don’t know what token 49 is until we’ve generated it.

Speculative decoding breaks this assumption by making a bet.

Analogy

Imagine you’re a slow, meticulous editor (the large model) reviewing text from a fast, sloppy intern (the draft model).

The intern writes 5 words. You scan them in one pass — accepting correct words, stopping at the first mistake, crossing it out, writing the correct word, and handing back to the intern.

On average you’re doing far less work per accepted word than if you’d written every word yourself. And the output quality is determined entirely by your standards — the intern’s words only appear if you approve them.

The key insight: reading is faster than writing. The large model can verify K tokens in parallel in one forward pass (that’s just reading), while generating K tokens would require K serial forward passes (writing).

Mechanism in Plain English

  1. A small, fast draft model generates K tokens autoregressively (e.g., K=5). This is cheap.

  2. The target model (large, slow) runs a single forward pass on all K+1 positions simultaneously — the original context plus the K draft tokens.

  3. Compare the target model’s distribution at each position to the draft model’s:

    • If the draft token is “good enough” (formally: accepted with probability min(1, P_target/P_draft)), keep it.
    • At the first rejection, sample a corrected token from an adjusted distribution and stop.
  4. The result is the same distribution as if the target model had generated every token itself — this is provably exact, not an approximation.

  5. On average, more than one token is produced per target model call. The speedup is roughly (1 - α^K)/(1 - α), where α is the per-token acceptance rate.
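The acceptance test in step 3 reduces to a one-line rule. A minimal sketch with made-up probabilities (the helper names are illustrative, not from any library):

```python
from random import random

def accept_prob(p_target, p_draft):
    """Probability of keeping a draft token: min(1, p_target / p_draft)."""
    return min(1.0, p_target / p_draft)

def keep_draft_token(p_target, p_draft):
    """Stochastic accept/reject decision for one draft token."""
    return random() < accept_prob(p_target, p_draft)

# Target agrees at least as strongly as the draft -> always kept
print(accept_prob(0.72, 0.68))            # 1.0
# Draft over-proposed the token -> kept only ~37% of the time
print(round(accept_prob(0.15, 0.41), 2))  # 0.37
```

When the target assigns the token at least as much probability as the draft did, the ratio is clipped to 1 and the token is always accepted.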

ASCII Diagram

  Step 1: Draft model generates 5 tokens (fast, sequential)
  
  Context: "The cat sat on the"
  Draft:    "mat" "and" "the" "dog" "slept"
  
  Step 2: Target model verifies all 5 in one parallel pass
  
  Target evaluates:
    P_target("mat"|context)   = 0.72   P_draft = 0.68  → ratio 1.06 → ACCEPT ✓
    P_target("and"|...mat)    = 0.15   P_draft = 0.41  → ratio 0.37 → REJECT ✗ (with prob 0.63)
  
  At rejection: sample corrected token from the normalized (P_target - P_draft)+ distribution
  Say we get: "." 
  
  Step 3: Output this decode step: "mat ."
  
  We generated 2 tokens with 1 target model call.
  Without speculative decoding: 2 target model calls.
  
  ─────────────────────────────────────────────────────
  
  Best case (α=1.0, all accepted, K=5):
    "mat and the dog slept" → 5 tokens, 1 call → 5x speedup
  
  Typical case (α=0.8, K=5):
    Expected tokens per call = (1 - 0.8^5)/(1 - 0.8) = (1 - 0.33)/0.2 ≈ 3.4
    Speedup: ~2-3x

Math with Translation

Expected tokens per target forward pass:

  E[tokens] = (1 - α^K) / (1 - α)

  • α — per-token acceptance rate (0 to 1, higher = faster)
  • K — number of draft tokens generated per step
  • At α = 0.8, K = 5: (1 - 0.8^5)/(1 - 0.8) ≈ 3.4 tokens per target call
  • At α = 0.85, K = 4: (1 - 0.85^4)/(1 - 0.85) ≈ 3.2 tokens per target call

Speedup = E[tokens] × t_target / (t_target + K × t_draft), which is approximately (1 - α^K)/(1 - α) when the draft model is cheap (t_draft ≪ t_target).
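A quick sanity check of expected tokens per target call and overall speedup, as a runnable sketch (the function names are mine, not from any library):

```python
def expected_tokens(alpha: float, K: int) -> float:
    """Expected tokens per target forward pass: (1 - alpha^K) / (1 - alpha)."""
    if alpha == 1.0:
        return float(K)  # every draft token accepted
    return (1 - alpha**K) / (1 - alpha)

def speedup(alpha: float, K: int, t_target: float, t_draft: float) -> float:
    """Tokens per step, discounted by the relative cost of drafting."""
    return expected_tokens(alpha, K) * t_target / (t_target + K * t_draft)

print(round(expected_tokens(0.8, 5), 1))    # 3.4
print(round(expected_tokens(0.85, 4), 1))   # 3.2
print(round(speedup(0.85, 4, 100, 10), 1))  # 2.3
```

Note the diminishing returns: each extra draft token only helps if all the tokens before it were accepted, which is why the sum is geometric in α.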

Concrete Walkthrough

Setup:

  • Target: 70B model, 100ms per forward pass
  • Draft: 7B model, 10ms per token (5x faster)
  • K = 4, α = 0.85

Without speculative decoding:

  • Cost per token: 100ms × 1 = 100ms

With speculative decoding:

  • Draft cost: 4 × 10ms = 40ms
  • Target cost: 100ms (one pass over 5 positions)
  • Total: 140ms
  • Expected tokens: (1 - 0.85^4)/(1 - 0.85) ≈ 3.2 tokens
  • Cost per token: 140ms / 3.2 ≈ 44ms
  • Speedup: 2.3x
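The arithmetic above as a runnable sanity check (latencies are the hypothetical ones from the setup):

```python
t_target, t_draft, K, alpha = 100.0, 10.0, 4, 0.85  # ms, from the setup above

draft_cost = K * t_draft                    # 40 ms of drafting per step
step_cost = draft_cost + t_target           # 140 ms per speculative step
exp_tokens = (1 - alpha**K) / (1 - alpha)   # ~3.2 tokens per step
cost_per_token = step_cost / exp_tokens     # ~44 ms

print(round(cost_per_token))                # 44
print(round(t_target / cost_per_token, 1))  # 2.3  (x speedup)
```
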

The rejection sampling scheme guarantees the output distribution is identical to sampling from the target model alone. The intern’s mistakes don’t propagate.

What’s Clever

The insight that unlocked this: the target model’s forward pass cost is nearly constant whether you verify 1 token or K tokens, because the batch computation over sequence positions is fully parallel.

Verifying K tokens costs almost the same as verifying 1 token: small-batch decoding is memory-bandwidth-bound, so the dominant cost — streaming the weights from memory — is unchanged, and the slightly larger activations and extra KV cache reads barely add latency. So every extra token the draft gets right is almost free.

The second non-obvious thing: the rejection sampling scheme is exact. Early drafting proposals (before the 2022 Leviathan et al. paper) worried that using a draft model would bias outputs. The correction distribution — sampling from the normalized residual max(0, p_target - p_draft) at the first rejected position — mathematically guarantees the output matches the target distribution. No approximation.
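This exactness claim can be verified numerically on a toy vocabulary. The sketch below (made-up three-token distributions) computes the probability of each possible output at one position — drafted-and-accepted plus rejected-and-corrected — and checks that it equals the target distribution:

```python
p = {"a": 0.5, "b": 0.3, "c": 0.2}  # target distribution (made up)
q = {"a": 0.2, "b": 0.6, "c": 0.2}  # draft distribution (made up)

residual = {t: max(0.0, p[t] - q[t]) for t in p}
Z = sum(residual.values())  # total rejection probability mass

# Probability that whatever the draft proposed at this position gets rejected
reject_mass = sum(q[t] * (1 - min(1.0, p[t] / q[t])) for t in q)

out = {}
for t in p:
    drafted_and_kept = q[t] * min(1.0, p[t] / q[t])
    corrected_to_t = reject_mass * residual[t] / Z
    out[t] = drafted_and_kept + corrected_to_t

for t in p:
    assert abs(out[t] - p[t]) < 1e-12  # output is exactly target-distributed
print("exact:", out)
```

The identity behind it: q·min(1, p/q) = min(p, q), and the rejection mass always equals the residual mass Z, so min(p, q) + max(0, p - q) = p at every token.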

Self-speculative decoding extends this: instead of a separate draft model, use early transformer layers as the draft (exit early, verify with the full model). No separate model needed at all.

Code

Draft-verify loop skeleton illustrating the accept/reject logic:

from random import random

def sample_from(dist):
    # Sample a token from a {token: probability} dict.
    r, acc = random(), 0.0
    for tok, p in dist.items():
        acc += p
        if r < acc:
            return tok
    return tok  # guard against floating-point rounding

def speculative_decode(draft_model, target_model, prompt, K=4, max_tokens=100):
    tokens = list(prompt)
    while len(tokens) < max_tokens:
        # Step 1: draft model autoregressively proposes K candidate tokens (fast)
        drafts = []
        x = tokens[:]
        for _ in range(K):
            next_tok = draft_model.sample(x)   # cheap: small model, one step
            drafts.append(next_tok)
            x.append(next_tok)

        # Step 2: target model scores every position in ONE parallel forward
        # pass; target_probs[j] is its distribution over the token at position j
        target_probs = target_model.forward(tokens + drafts)  # O(1) passes for K tokens

        # Step 3: accept/reject each draft token using rejection sampling
        accepted = []
        for i, tok in enumerate(drafts):
            p = target_probs[len(tokens) + i]           # target distribution here
            q = draft_model.probs(tokens + drafts[:i])  # draft distribution here
            if random() < min(1.0, p[tok] / q[tok]):    # accept with prob min(1, p/q)
                accepted.append(tok)
            else:
                # First rejection: sample a corrected token from the normalized
                # residual max(0, p - q) — this is what keeps the output exact
                residual = {t: max(0.0, p[t] - q[t]) for t in p}
                total = sum(residual.values())
                accepted.append(sample_from({t: r / total for t, r in residual.items()}))
                break  # stop at the first rejection

        tokens.extend(accepted)
    return tokens
# Output distribution is identical to sampling from the target model alone — provably exact.

Key Sources

  • Leviathan, Kalman & Matias, “Fast Inference from Transformers via Speculative Decoding” (2022)

Open Questions

  • Optimal draft model size and architecture relative to target
  • Speculative decoding with quantized or approximate target models