Concepts: attention | KV cache | inference efficiency | grouped-query attention
Builds on: Attention Is All You Need | FlashAttention
Leads to: vLLM
The problem nobody mentions at training time
When a language model generates text, it doesn’t start fresh at each token. It looks back at everything it’s seen — every previous token — and asks: which parts matter for the next word? That lookback is attention. And it requires loading something called the KV cache: matrices of keys and values, one set for every past token, stored in GPU memory.
Here’s what nobody tells you during training: at inference, compute isn’t the bottleneck. Memory bandwidth is. For every single token generated, the GPU must haul the entire KV cache out of high-bandwidth memory (HBM) and into compute units. The bigger the context, the more you haul. The more attention heads you have, the more copies you haul.
A typical 7B-class transformer has H=32 attention heads. In standard multi-head attention (MHA), every head gets its own K and V. That means 32 K matrices and 32 V matrices per layer loaded from memory, for every token, every request, every user. Multi-Query Attention (MQA, Shazeer 2019) went nuclear: reduce to 1 shared K/V head for all 32 query heads. Blazing fast, but quality degraded and training could become unstable. Grouped-Query Attention (GQA) finds the middle ground that actually shipped into production.
The core idea
Let’s start with a concrete picture of where the memory goes.
Think of a large hotel where concierge staff (key-value heads) maintain detailed guest records. Every time a guest (query head) needs to respond to a request, they consult those records. In standard MHA, each of 32 guests has their own dedicated concierge with a full copy of all records — premium service, but 32 full concierge desks running simultaneously. MQA says: one shared front desk, everyone consults the same generic records. Fast to maintain, but one desk can’t serve 32 specialists. GQA says: group guests into teams of 4, with one concierge per team. Eight desks instead of 32. Each desk’s records are more specialized than the single shared one. You get most of the bandwidth win with almost none of the quality loss.
The mechanism, step by step:
- Take H=32 query heads. Divide into G=8 groups of 4 heads each.
- Assign ONE key projection and ONE value projection to each group — shared by all 4 query heads in that group.
- Attention for query head $i$ uses its group’s K/V, not a private copy:
$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_{g(i)}^{\top}}{\sqrt{d_k}}\right) V_{g(i)}$$
where $g(i)$ maps each query head to its group index.
Translation: each of the 32 query heads still has its own $Q$ projection (what it’s searching for), but 4 heads share the same $K$ and $V$ (what they’re looking at and what they retrieve). The bet is that most of the diversity in attention patterns comes from the query side, so sharing K/V costs little.
“Grouped-query attention divides query heads into G groups, each of which shares a single key head and value head.”
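As a rough illustration of the grouping described above, here is a minimal NumPy sketch of one decoding step. It assumes consecutive query heads form a group (heads 0-3 share group 0, and so on), uses random tensors in place of real projections, and omits masking because a single new token attends to every cached position. The function name and shapes are illustrative, not taken from the paper.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Grouped-query attention for one sequence, no masking.

    q: (H, T, d) queries, one slice per query head
    k: (G, S, d) keys, one slice per group
    v: (G, S, d) values, one slice per group
    Query head i reads group g(i) = i // (H // G).
    """
    H, T, d = q.shape
    G = k.shape[0]
    assert H % G == 0, "query heads must split evenly into groups"
    group_of = np.arange(H) // (H // G)        # g(i) for every head i

    # Gather each head's shared K/V. This copies for clarity; a real kernel
    # would index the shared tensors instead of materializing H copies.
    k_per_head = k[group_of]                   # (H, S, d)
    v_per_head = v[group_of]                   # (H, S, d)

    scores = q @ k_per_head.transpose(0, 2, 1) / np.sqrt(d)   # (H, T, S)
    scores -= scores.max(axis=-1, keepdims=True)              # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_per_head                                # (H, T, d)

rng = np.random.default_rng(0)
H, G, S, d = 32, 8, 16, 128
q = rng.standard_normal((H, 1, d))   # one new token being decoded
k = rng.standard_normal((G, S, d))   # cached keys: only G heads stored
v = rng.standard_normal((G, S, d))   # cached values: only G heads stored
out = grouped_query_attention(q, k, v)
print(out.shape)
```

Setting G equal to H recovers ordinary multi-head attention, and G = 1 recovers MQA, which is why GQA is described as a generalization of both.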
| | MHA (32 K/V heads) | MQA (1 K/V head) | GQA-8 (8 K/V heads) |
|---|---|---|---|
| Query heads | Q1 ... Q32 | Q1 ... Q32 | Q1 ... Q32, in 8 groups of 4 (Q1-Q4, Q5-Q8, ...) |
| K/V heads | K1/V1 ... K32/V32, one per query head | one K/V shared by all query heads | K1/V1 ... K8/V8, one per group |
| KV reads per token (relative to MQA) | 32× | 1× | 8× |
| KV bandwidth vs. MHA | baseline | 32× less | 4× less |
| Quality | baseline | degraded | near-baseline |
Now the numbers. Assume a typical LLM config: H=32 heads, head dimension $d_k = 128$, 32 transformer layers, float16 precision (2 bytes per value), 2,048 tokens in context.
KV cache per layer per token:
MHA: 32 heads × 2 (K+V) × 128 values × 2 bytes = 16,384 bytes/token/layer
GQA-8: 8 heads × 2 × 128 × 2 bytes = 4,096 bytes/token/layer (4x smaller)
MQA: 1 head × 2 × 128 × 2 bytes = 512 bytes/token/layer (32x smaller)
Total KV cache for 2,048-token context across all 32 layers:
MHA: 2048 × 32 × 16,384 = 1,024 MB (~1 GB)
GQA-8: 2048 × 32 × 4,096 = 256 MB (4x smaller)
MQA: 2048 × 32 × 512 = 32 MB (32x smaller)
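The arithmetic above can be reproduced with a few lines of Python. This is a sketch under the same assumptions as the text (head dimension 128, 32 layers, 2,048 tokens, float16) plus one more: a single sequence, so no batch dimension. The function name is hypothetical, and "MB" in the text corresponds to MiB (2^20 bytes) here.

```python
def kv_cache_bytes(kv_heads, head_dim=128, layers=32, tokens=2048, bytes_per_value=2):
    """Bytes needed to hold K and V for every layer and every cached token."""
    per_token_per_layer = kv_heads * 2 * head_dim * bytes_per_value  # 2 = K and V
    return per_token_per_layer * layers * tokens

for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    total = kv_cache_bytes(kv_heads)
    print(f"{name:6s} {total / 2**20:6.0f} MiB")
```

Because the cache size is linear in the number of K/V heads, the ratio between any two variants is just the ratio of their K/V head counts.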
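Assume roughly 900 GB/s of HBM bandwidth (V100-class; an A100 offers roughly 1.5 to 2 TB/s, which shrinks every figure below proportionally but leaves the ratios unchanged). Loading the full KV cache for each new token costs:
MHA: 1024 MB / 900 GB/s ≈ 1.14 ms, just in memory reads
GQA-8: 256 MB / 900 GB/s ≈ 0.28 ms (4x faster)
MQA: 32 MB / 900 GB/s ≈ 0.036 ms (32x faster)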
This happens for every token generated. At 200 tokens per response, MHA spends ~228 ms just moving K/V data. GQA-8 brings that to ~57 ms. The win isn’t from being smarter — it’s from being leaner.
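A back-of-the-envelope version of that estimate, as a sketch: it follows the text's arithmetic (1 GB/s treated as 1000 MB/s), keeps the cache fixed at its 2,048-token size for all 200 generated tokens, and ignores model weight reads and any overlap with compute. The function name and defaults are illustrative assumptions.

```python
def kv_read_ms(cache_mb, bandwidth_gb_per_s=900.0):
    """Lower bound on the time to stream the KV cache once, in milliseconds."""
    seconds = cache_mb / (bandwidth_gb_per_s * 1000.0)  # MB / (MB per second)
    return seconds * 1000.0

for name, cache_mb in [("MHA", 1024), ("GQA-8", 256), ("MQA", 32)]:
    per_token = kv_read_ms(cache_mb)
    print(f"{name:6s} {per_token:.3f} ms/token, {per_token * 200:.0f} ms per 200 tokens")
```

Swapping in a different `bandwidth_gb_per_s` rescales every row equally, so the 4x and 32x ratios between variants do not depend on the GPU.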
“Going from MHA to MQA reduces H key and value heads to a single key and value head, reducing the size of the key-value cache and therefore amount of data that needs to be loaded by a factor of H.”
The uptraining trick. Already have an MHA model and don’t want to train from scratch? The paper proposes a two-step recipe:
- Convert: For each GQA group, take the original K/V projection matrices from all heads assigned to that group and average them (mean-pool).
- Uptrain: Continue pretraining on only 5% of the original training compute to let the model adapt.
“Mean pooling appears to work best, followed by selecting a single head and then random initialization. Intuitively, results are ordered by the degree to which information is preserved from the pre-trained model.”
Why mean pooling? It preserves the most signal from all original heads simultaneously. Selecting one head discards the others. Random init throws everything out and forces the model to relearn from scratch. Mean pooling is a compressed average of what all heads knew — the model then fine-tunes from that informed starting point.
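The conversion step can be sketched in NumPy as follows. This assumes the per-head K (or V) projection matrices are stacked along the first axis and that consecutive heads form a group; the paper does not prescribe this layout, and the function name and the small `d_model` are hypothetical choices for illustration.

```python
import numpy as np

def mean_pool_kv_heads(w, num_groups):
    """Collapse per-head projection matrices into one matrix per group.

    w: (H, d_model, d_head) stack of per-head K (or V) projection matrices.
    Returns (num_groups, d_model, d_head); each group is the mean of its heads.
    """
    H, d_model, d_head = w.shape
    assert H % num_groups == 0, "heads must split evenly into groups"
    heads_per_group = H // num_groups
    grouped = w.reshape(num_groups, heads_per_group, d_model, d_head)
    return grouped.mean(axis=1)

rng = np.random.default_rng(0)
w_k = rng.standard_normal((32, 64, 128))        # 32 heads, toy d_model of 64
w_k_gqa = mean_pool_kv_heads(w_k, num_groups=8)  # GQA-8 starting point
w_k_mqa = mean_pool_kv_heads(w_k, num_groups=1)  # MQA starting point
print(w_k_gqa.shape, w_k_mqa.shape)
```

The same function would be applied to the V projections. The result is only an initialization; the uptraining step is what lets the model adapt to the pooled heads.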
A key finding from the ablations:
“GQA already achieves reasonable performance after conversion while MQA requires uptraining to be useful.”
This tells you where the real compression pressure is. MQA is so extreme, with 32 heads crammed into 1, that the model needs substantial retraining to adapt. GQA-8 is close enough to MHA that the converted checkpoint already performs reasonably before any uptraining. Uptraining then closes the remaining gap.
What’s clever. The non-obvious insight is about how MQA gets worse as models scale. Larger models have more heads: 64, 96, or more. MQA always compresses to 1, so the compression ratio grows with model size. A 70B model with 64 heads would compress 64→1 under MQA, and the quality loss compounds. With GQA the group count G is a free parameter, so it can grow with the head count and hold the compression ratio H/G roughly constant across model sizes.
“larger models generally scale the number of heads, such that multi-query attention represents a more aggressive cut in both memory bandwidth and capacity. GQA lets us keep the same proportional decrease in bandwidth and capacity as model size increases.”
That’s the scaling argument that helped make GQA the standard. It’s not just that GQA beats MQA on quality at a given size; the argument implies the gap should widen as you scale up.
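To make the scaling argument concrete, the sketch below compares how many query heads share each K/V head under MQA and under a GQA setup where G grows with H. The 4:1 target ratio and the head counts are illustrative assumptions, not settings from the paper.

```python
def kv_compression(num_heads, num_kv_heads):
    """How many query heads share each K/V head (the H/G ratio)."""
    assert num_heads % num_kv_heads == 0, "heads must split evenly into groups"
    return num_heads // num_kv_heads

for H in (32, 64, 96):
    mqa = kv_compression(H, 1)        # MQA always collapses to one K/V head
    gqa = kv_compression(H, H // 4)   # GQA with G scaled to keep H/G at 4
    print(f"H={H:3d}  MQA {mqa}:1  GQA {gqa}:1")
```

Under MQA the ratio tracks the head count, while a GQA configuration can pin it to whatever value the quality budget tolerates.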
Does it work, and what breaks?
| Model | Inference time (s) | Speedup vs. MHA | Avg quality (7 benchmarks) |
|---|---|---|---|
| T5-XXL MHA (baseline) | 1.51 | 1.0x | 47.2 |
| T5-XXL MQA (5% uptrained) | 0.24 | 6.3x | 46.6 (-0.6) |
| T5-XXL GQA-8 (5% uptrained) | 0.28 | 5.4x | 47.1 (-0.1) |
GQA-8 is 5.4x faster than full MHA while losing only 0.1 average quality points across 7 NLP benchmarks (CNN/DM, arXiv, PubMed, MediaSum, MultiNews, WMT, TriviaQA). MQA is marginally faster but costs 0.6 quality points — six times the quality degradation for modest extra speed.
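The speedups and quality deltas quoted above follow directly from the reported times and scores. A small sketch that recomputes them (the dictionary layout and variable names are just for illustration):

```python
# Inference time (s) and average quality as reported for T5-XXL.
results = {
    "MHA":   {"time_s": 1.51, "quality": 47.2},
    "MQA":   {"time_s": 0.24, "quality": 46.6},
    "GQA-8": {"time_s": 0.28, "quality": 47.1},
}

baseline = results["MHA"]
summary = {}
for name, r in results.items():
    speedup = baseline["time_s"] / r["time_s"]   # how many times faster than MHA
    delta = r["quality"] - baseline["quality"]   # quality change vs. MHA
    summary[name] = (round(speedup, 1), round(delta, 1))
    print(f"{name:6s} {speedup:.1f}x faster, quality {delta:+.1f}")
```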
What doesn’t work:
GQA does nothing for encoder self-attention. Encoders compute all representations in parallel, so there’s no per-token sequential loading bottleneck. The speedup is entirely in autoregressive decoding.
MQA trained from scratch showed training instability in the paper’s experiments: frequent loss spikes during pretraining, then divergence when fine-tuning on long-input tasks. Uptrained MQA was more stable but still showed high variance:
“Uptrained grouped-query attention models, however, appear to be stable” — unlike MQA, which required multiple fine-tuning runs and averaged results.
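GQA doesn’t reduce quadratic sequence-length scaling. For very long contexts you still need other tools: FlashAttention to avoid materializing the full attention matrix and cut memory traffic, or sparse attention to cut the compute itself. GQA reduces the data volume, not the algorithmic complexity.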
The paper only tested uptraining, not GQA trained from scratch. Whether the 5% uptraining fully matches scratch-trained GQA remains an open question.
If you’re building ML systems
GQA is now the default for most recent open-weight LLMs. Llama 2 70B, Llama 3, Mistral 7B, Mixtral, and Gemma 2 all use it. If you’re training from scratch: use GQA-8 (or whatever group count divides cleanly into your head count). The paper’s results show little quality downside, and the inference throughput difference compounds at production scale.
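"Divides cleanly" just means the number of K/V heads must be a divisor of the number of query heads. A minimal sketch for listing the options (the helper name is hypothetical; in Hugging Face Llama-style configs the two quantities appear as `num_attention_heads` and `num_key_value_heads`):

```python
def valid_kv_head_counts(num_heads):
    """All K/V head counts G that split num_heads query heads into equal groups."""
    return [g for g in range(1, num_heads + 1) if num_heads % g == 0]

print(valid_kv_head_counts(32))  # [1, 2, 4, 8, 16, 32]
```

The smallest entry is MQA and the largest is plain MHA; everything in between is a GQA configuration.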
If you have an existing MHA model and want faster inference: the uptraining recipe works. Mean-pool the K/V projections, continue pretraining for 5% of original compute. You don’t need to restart training.
GQA matters most when memory bandwidth is your bottleneck, which it is during streaming token generation. If you’re doing large-batch prefill of long prompts (which is compute-bound, not bandwidth-bound), GQA helps less. For the typical generation workload, sequential token-by-token output, the reduction in KV reads (4x in the 32-head example above) translates fairly directly to throughput, consistent with the 5.4x end-to-end speedup the paper measured on T5-XXL.
This connects to what PagedAttention does next: once GQA shrinks the KV cache, PagedAttention manages that cache efficiently across variable-length sequences with virtual-memory-style paging. And FlashAttention reduces the memory traffic of the attention computation itself. GQA + FlashAttention + PagedAttention is a common stack that helps make modern LLM serving economically viable.
GQA shows you don’t have to choose between the quality of 32 attention heads and the speed of 1: for about 5% of the original pretraining compute, you get close to both.
Paper: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — Ainslie et al. — 2023
Connections
- attention-is-all-you-need — original multi-head attention that GQA generalizes
- flash-attention-fast-and-memory-efficient-exact-attention — complementary optimization: GQA reduces KV cache size, FlashAttention reduces bandwidth per attention call
- pagedattention-vllm — next step: managing the smaller KV cache GQA enables
- attention — the mechanism being optimized
- kv-cache — the inference bottleneck GQA directly targets
- inference-efficiency — broader category; GQA is one of the core techniques
- gqa — concept entry with the math
Citation
Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. https://arxiv.org/abs/2305.13245