Summary

Ainslie et al. (2023) address a fundamental bottleneck in autoregressive Transformer decoding: the key-value (KV) cache grows proportionally with both sequence length and the number of attention heads, and repeatedly loading it dominates memory bandwidth in large-batch or long-context inference. The paper builds on Multi-Query Attention (MQA; Shazeer, 2019), which shares a single KV head across all query heads to dramatically reduce KV cache size, and proposes Grouped-Query Attention (GQA) as a general interpolation. GQA partitions the H query heads into G groups, each sharing one KV head. G=1 recovers MQA; G=H recovers standard Multi-Head Attention (MHA).
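The group structure above is easy to see in code. The following is a minimal NumPy sketch (not the paper's implementation; batching, masking, and positional encodings are omitted) in which G shared KV heads are broadcast to the H query heads:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_attention(q, k, v):
    """Grouped-query attention over one sequence (batch dim omitted).

    q: (H, T, d)    -- one set of queries per query head
    k, v: (G, T, d) -- G shared KV heads; G must divide H
    G == 1 is MQA, G == H is standard MHA.
    """
    H, T, d = q.shape
    G = k.shape[0]
    assert H % G == 0, "query heads must divide evenly into KV groups"
    # Broadcast each KV head to the H/G query heads in its group.
    k = np.repeat(k, H // G, axis=0)                  # (H, T, d)
    v = np.repeat(v, H // G, axis=0)                  # (H, T, d)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)    # (H, T, T)
    return softmax(scores) @ v                        # (H, T, d)
```

With G=H the `repeat` is a no-op and this reduces exactly to per-head MHA; with G=1 every query head attends over the same single K/V pair, which is MQA.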

The authors also propose a practical “uptraining” recipe: an existing MHA checkpoint can be converted to GQA or MQA by mean-pooling the original KV head projections within each group, then continuing pretraining for just 5% of the original compute. This avoids training from scratch. Experiments on T5-Large and T5-XXL (both encoder-decoder) show that uptrained GQA matches MHA quality while achieving MQA-level inference speed. GQA was adopted in LLaMA 2 70B, Mistral 7B, Mixtral 8x7B, Gemma, and most subsequent open-weight models, making it the de facto standard attention variant for efficient LLM inference.

Key Claims

  • MQA (1 KV head) reduces KV cache size by H× (H = number of heads) and achieves near-MHA quality when uptrained on 5% of original compute.
  • GQA with G=8 groups achieves quality nearly indistinguishable from MHA on T5-Large/XXL benchmarks while closely matching MQA inference speed.
  • Uptraining recipe: convert MHA checkpoint by mean-pooling KV projections per group, continue training on 5% of original steps — sufficient to recover quality.
  • GQA reduces memory bandwidth requirements proportionally to the number of KV heads (H/G× reduction in KV cache memory vs MHA).
  • Encoder models (which don’t use KV caching) show no inference benefit from GQA; the gain is specific to autoregressive decoding.
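The H/G× cache-reduction claim above follows from simple arithmetic. The sketch below uses illustrative 70B-class numbers (80 layers, 64 query heads, head dim 128, fp16, 4096-token context) — these figures are assumptions for the example, not taken from the paper:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    # Each layer caches K and V, each of shape (batch, n_kv_heads, seq_len, head_dim).
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

cfg = dict(n_layers=80, head_dim=128, seq_len=4096)
mha = kv_cache_bytes(n_kv_heads=64, **cfg)  # G = H = 64: standard MHA
gqa = kv_cache_bytes(n_kv_heads=8,  **cfg)  # G = 8
mqa = kv_cache_bytes(n_kv_heads=1,  **cfg)  # G = 1
print(mha / 2**30, gqa / 2**30, mqa / 2**30)  # → 10.0 1.25 0.15625 (GiB)
```

The ratios are exactly H/G: GQA-8 shrinks the cache 8×, MQA 64×, independent of sequence length and batch size.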

Methods

In standard MHA, each of H heads has its own Q, K, and V projections. In GQA, the H query heads are divided into G groups of H/G heads each, and each group shares a single K and V projection. During decoding, the KV cache therefore stores only G key-value pairs per layer per token instead of H, reducing cache memory by H/G×.

The uptraining procedure: (1) take a pretrained MHA checkpoint; (2) for each new KV group, initialize its projection as the mean of the original per-head KV projections assigned to that group; (3) continue pretraining with the GQA architecture for a small fraction of the original steps. Experiments use T5 pretrained on C4; quality is evaluated on downstream tasks including CNN/DM summarization, translation, and question answering.
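Step (2), the mean-pool initialization, can be sketched as follows. This assumes per-head projection matrices stacked along a leading head axis with contiguous group assignment — the layout is an assumption for illustration; T5's actual parameter shapes differ:

```python
import numpy as np

def mean_pool_kv_heads(w_per_head, num_groups):
    """Convert H per-head K (or V) projection matrices into G group projections.

    w_per_head: (H, d_model, d_head); heads are assigned to groups contiguously.
    Returns (G, d_model, d_head), each entry the mean of its group's head projections.
    """
    n_heads = w_per_head.shape[0]
    assert n_heads % num_groups == 0
    grouped = w_per_head.reshape(num_groups, n_heads // num_groups,
                                 *w_per_head.shape[1:])
    return grouped.mean(axis=1)
```

With `num_groups == n_heads` this returns the original weights unchanged (MHA), and with `num_groups == 1` it mean-pools every head into one (the paper's MQA conversion), so the same routine covers the whole MHA–GQA–MQA spectrum.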

Failure modes

  • GQA reduces KV cache memory but does not reduce the compute of the attention operation itself (still O(N²) in sequence length for full attention).
  • The quality–efficiency tradeoff depends on the number of groups G; choosing G requires hyperparameter search or empirical validation per model.
  • Uptraining on 5% compute is sufficient for modest quality recovery but may not fully close the gap for tasks highly sensitive to attention diversity.
  • GQA does not address quadratic sequence length scaling — FlashAttention or sparse attention is still needed for long context.

Connections

Citation

arXiv:2305.13245

Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. https://arxiv.org/abs/2305.13245