The Problem

When a language model generates token 100 of a response, it needs to compute attention against all 99 previous tokens. Without any optimization, that means running all 99 tokens through the full Q/K/V computation again — for every single layer. Generating a 1000-token response costs 1+2+3+…+1000 = 500,500 forward-pass computations of attention rows. That’s quadratic in output length.
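The arithmetic can be sanity-checked directly (a toy count of attention rows, not a FLOP model):

```python
# Naive decoding recomputes attention for every prior token at each step,
# so the total over N generated tokens is 1 + 2 + ... + N.
def naive_attention_rows(n_tokens: int) -> int:
    return sum(range(1, n_tokens + 1))  # arithmetic series: n(n+1)/2

def cached_attention_rows(n_tokens: int) -> int:
    # with a KV cache, each step computes K/V for exactly one new token
    return n_tokens

print(naive_attention_rows(1000))   # 500500 — the figure above
print(cached_attention_rows(1000))  # 1000
```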

The fix is obvious once you see it. Most of that work is being repeated.

Analogy

You’re writing a long report and need to check earlier sections regularly. Without any optimization, you’d re-read the entire report from page 1 every time you want to reference anything. That’s what naive autoregressive generation does.

The KV cache is a notepad. The first time you read section 3, you jot down the key facts. Next time you need it, you check the notepad. The original report doesn’t change — your notes stay valid.

The crucial point: only the Keys and Values get cached. Queries are always fresh (you’re always generating something new). K and V represent “what past tokens contain” — that doesn’t change as new tokens arrive.

Mechanism in Plain English

Prefill phase (processing the user’s prompt):

  1. Run all prompt tokens through the model in one parallel forward pass.
  2. For every layer, store the K and V tensors for every prompt token.
  3. Compute the first output token.

Decode phase (generating the response, token by token):

  1. For each new token, compute Q for that token only.
  2. Retrieve all cached K, V tensors from previous tokens.
  3. Run attention: Q (1 token) × K (all previous tokens) = attention weights.
  4. Weighted sum of cached V tensors gives the new token’s attended representation.
  5. Compute next token. Append its K, V to the cache. Repeat.

The decode phase computes only the new token’s contribution — no re-computation of past K, V.
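The two phases can be illustrated with tensor shapes alone. A minimal NumPy sketch, with toy sizes (batch 1, 4 heads, head_dim 8, all assumed for illustration): the prefill K/V cover every prompt token, while the decode-phase Q covers exactly one.

```python
import numpy as np

batch, heads, d_head = 1, 4, 8   # toy sizes, not from any real model
prompt_len = 5

# Prefill: K and V for all prompt tokens, computed in one parallel pass.
K = np.random.randn(batch, heads, prompt_len, d_head)
V = np.random.randn(batch, heads, prompt_len, d_head)

# Decode: Q for the single new token only.
q = np.random.randn(batch, heads, 1, d_head)

scores = q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [batch, heads, 1, prompt_len]
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                # softmax over past tokens
out = weights @ V                                        # weighted blend of cached V
print(out.shape)  # (1, 4, 1, 8): one attended representation for the new token
```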

ASCII Diagram

  Without KV cache (naive):
  
  Token 1 generated:  [t1] → full pass → output
  Token 2 generated:  [t1, t2] → full pass again → output     ← t1 recomputed
  Token 3 generated:  [t1, t2, t3] → full pass again → output ← t1, t2 recomputed
  
  Cost: O(N²) per sequence

  ─────────────────────────────────────────────────────────────

  With KV cache:
  
  Prefill:            [t1, t2, t3, t4] → parallel → KV stored
                                          ┌─────────────────────┐
                                          │ Cache: K1,V1        │
                                          │        K2,V2        │
                                          │        K3,V3        │
                                          │        K4,V4        │
                                          └─────────────────────┘
  Decode t5:          Q5 × [K1..K4] → attn → V blend → output
                          append K5, V5 to cache
  Decode t6:          Q6 × [K1..K5] → attn → V blend → output
                          append K6, V6 to cache
  
  Cost: O(1) new K/V work per token (attention still reads the whole cache, so reads grow linearly, but nothing is recomputed)

Math: Memory Cost

For a single request, KV cache memory is:

bytes = 2 × num_layers × num_kv_heads × head_dim × seq_len × dtype_bytes

Breaking this down:

  • 2 — one K tensor, one V tensor
  • num_layers — one cache per transformer layer (e.g., 32 for LLaMA-7B)
  • num_kv_heads × head_dim — the KV dimension per layer (e.g., 32 heads × 128 = 4096)
  • seq_len — grows as generation proceeds
  • dtype_bytes — 2 for fp16, 1 for int8

Concrete example: LLaMA-7B (32 layers, 32 heads, head_dim=128), 4K context, fp16: 2 × 32 × 32 × 128 × 4096 × 2 bytes ≈ 2 GB per request

At batch size 32 that’s 64 GB — more than the entire memory of a 40 GB A100, and most of an 80 GB one.
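The formula translates directly into code. A small calculator, assuming the LLaMA-7B-like shapes from the example above:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, dtype_bytes):
    # 2 = one K tensor + one V tensor per layer
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * dtype_bytes

# LLaMA-7B-like config: 32 layers, 32 KV heads, head_dim 128, 4K context, fp16
per_request = kv_cache_bytes(32, 32, 128, 4096, 2)
print(per_request / 2**30)        # 2.0 GiB per request
print(32 * per_request / 2**30)   # 64.0 GiB at batch size 32
```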

This is why KV cache management is the binding constraint at inference time, not model weights.

What’s Clever

The observation is that K and V represent “past token content” which is purely a function of input tokens and model weights — both of which are fixed during decoding. Nothing about generating token 150 changes what token 7 meant. So recomputing K and V for token 7 when generating token 150 is pure waste.

The non-obvious implication: memory, not compute, becomes the inference bottleneck. Batch size is limited by how many KV caches you can fit in GPU memory simultaneously, not by FLOPs. This shifted the entire inference systems engineering problem from compute utilization to memory management.

GQA/MQA as a consequence: if KV cache is the bottleneck, why use full multi-head attention? Grouped Query Attention (GQA) uses fewer KV heads than Q heads — LLaMA-3 70B uses 8 KV heads instead of 64, reducing cache by 8x at modest quality cost.
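A rough comparison of the savings, assuming LLaMA-3-70B-like shapes (80 layers, head_dim 128) and the memory formula from the previous section:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, dtype_bytes):
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * dtype_bytes

# 8K context, fp16; 80 layers and head_dim 128 are assumed LLaMA-3-70B-like shapes
mha = kv_cache_bytes(80, 64, 128, 8192, 2)  # hypothetical full MHA: 64 KV heads
gqa = kv_cache_bytes(80, 8, 128, 8192, 2)   # GQA as shipped: 8 KV heads
print(mha // gqa)       # 8 — the 8x cache reduction
print(gqa / 2**30)      # 2.5 GiB per request instead of 20 GiB
```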

PagedAttention: vLLM treats the KV cache like virtual memory — allocating physical “pages” on demand rather than pre-reserving contiguous memory per request. This eliminates fragmentation and increases throughput 2-4x vs. naive pre-allocation.
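A toy sketch of the idea (an illustration only, not vLLM's actual implementation): a per-request page table maps to physical blocks, and a new block is allocated only when the previous one fills up, so no contiguous region is reserved up front.

```python
class PagedCache:
    def __init__(self, num_blocks, block_size=4):
        self.block_size = block_size           # tokens per physical block (toy value)
        self.free = list(range(num_blocks))    # pool of free physical block ids
        self.table = {}                        # request id -> list of block ids
        self.length = {}                       # request id -> tokens stored

    def append_token(self, req):
        n = self.length.get(req, 0)
        if n % self.block_size == 0:           # last block is full: grab a new page
            self.table.setdefault(req, []).append(self.free.pop())
        self.length[req] = n + 1

cache = PagedCache(num_blocks=8)
for _ in range(6):
    cache.append_token("req-A")                # 6 tokens -> ceil(6/4) = 2 blocks
print(len(cache.table["req-A"]), len(cache.free))  # 2 blocks used, 6 still free
```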

Code

# Without cache: recompute K,V for all previous tokens each step → O(n²) total work
# With cache: store K,V and append each new token's K,V → O(n) per step
 
import torch
 
class KVCache:
    def __init__(self):
        self.k_cache = []  # list of [batch, n_heads, 1, d_head] tensors, one per token
        self.v_cache = []
 
    def update(self, new_k, new_v):
        """Append the new token's K and V, return full K and V for attention."""
        self.k_cache.append(new_k)               # store new token's key
        self.v_cache.append(new_v)               # store new token's value
        # Concat along the sequence dimension (dim=2). Note: re-concatenating
        # every step is itself O(n); real implementations preallocate a buffer
        # and write each new token's K, V into the next slot instead.
        full_k = torch.cat(self.k_cache, dim=2)  # shape: [batch, heads, seq_len, d_head]
        full_v = torch.cat(self.v_cache, dim=2)
        return full_k, full_v                    # use these in attention for current token
 
# Decode loop (sketch):
# cache = KVCache()
# for each new token:
#     new_k, new_v = model.compute_kv(new_token)   # only for the new token
#     full_k, full_v = cache.update(new_k, new_v)  # retrieve full history
#     output = attention(Q_new, full_k, full_v)     # attend over all past tokens
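As a framework-free check of the core claim, the following NumPy sketch (toy sizes, random matrices standing in for real K/V projections) confirms that attending over an incrementally built cache gives the same output as recomputing everything:

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d = 6, 8
x_k = rng.standard_normal((seq, d))   # per-token keys (stand-ins for W_k @ x)
x_v = rng.standard_normal((seq, d))   # per-token values
q_last = rng.standard_normal((1, d))  # query for the newest token only

def attend(q, K, V):
    s = q @ K.T / np.sqrt(d)
    w = np.exp(s - s.max())
    w /= w.sum()                       # softmax over past tokens
    return w @ V

# "Naive": use K, V for the whole sequence at once.
full = attend(q_last, x_k, x_v)

# "Cached": accumulate K, V one token at a time, as a decode loop would.
K_cache, V_cache = [], []
for t in range(seq):
    K_cache.append(x_k[t:t+1])
    V_cache.append(x_v[t:t+1])
cached = attend(q_last, np.vstack(K_cache), np.vstack(V_cache))

print(np.allclose(full, cached))  # True — caching changes cost, not results
```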

Open Questions

  • KV cache compression (quantization, eviction policies) for very long contexts
  • KV cache offloading to CPU or disk