Concepts: gqa | sliding-window-attention | inference-efficiency | kv-cache
Builds on: attention-is-all-you-need | llama-2-open-foundation-fine-tuned-chat-models
Leads to: Mixtral 8x7B (not yet ingested)
Deploying a 13B language model costs roughly twice as much per inference request as a 7B model — not just once at training time, but at every single request, forever. So when Mistral AI published a 7B model in October 2023 that outperformed LLaMA 2 13B across every benchmark, the question was immediate: where was the performance hiding?
Not in a clever training trick. Not in proprietary data. The answer was architectural: two targeted modifications to the standard Transformer made inference significantly cheaper at no accuracy cost. A 7B model behaving like a 13B is not magic — it’s what happens when you stop wasting memory on K/V caches that don’t need to exist.
The core idea
The analogy: When reading a dense technical document, you don’t re-read every sentence you’ve ever seen before processing the next word. You hold a working window of recent context vividly — the last few paragraphs — while earlier content is compressed into your general understanding of the document. A separate, related trick: four colleagues can share one reference shelf rather than each maintaining their own copy of the same books.
Mistral 7B encodes both observations directly into the architecture:
- Sliding Window Attention (SWA): each layer looks back at most W recent tokens (W = 4096) instead of the full sequence. Like keeping the last few paragraphs vivid in working memory.
- Rolling Buffer Cache: the KV cache is fixed-size, overwriting old entries as new tokens arrive. Memory footprint stays bounded regardless of sequence length.
- Grouped Query Attention (GQA): 32 query heads share 8 K/V head pairs — 4 queries per shared source. One reference shelf, four colleagues.
The result: reasoning, math, and code performance surpasses LLaMA 2 13B with roughly 45% fewer parameters.
The mechanism, step by step
Sliding Window Attention
Standard attention lets every token attend to every previous token. The KV cache grows linearly with sequence length L, and loading it costs O(L) per inference step.
SWA restricts attention to a local window of the W most recent tokens: position i attends only positions j with i − W < j ≤ i.
Attention cost drops to O(W) per step — constant regardless of how long the sequence gets. But doesn’t this lose all information beyond W tokens back?
No. Information percolates forward through stacked layers. After k attention layers, token i can incorporate information from up to k × W positions back:
SLIDING WINDOW ATTENTION — how information propagates across layers
Sequence of 9 tokens, window W=3:
Layer 1 — what each position directly attends to:
pos 0: {0}
pos 1: {0,1}
pos 2: {0,1,2}
pos 3: {1,2,3} ← pos 0 falls out of the window
pos 4: {2,3,4}
pos 5: {3,4,5}
Layer 2 — indirect reach (each position attends to Layer 1 representations):
pos 5 at L2 attends {3,4,5}
pos 3 (L1) had seen {1,2,3}
pos 4 (L1) had seen {2,3,4}
pos 5 (L1) had seen {3,4,5}
→ pos 5 at L2 indirectly touches tokens {1..5} = 2×W back
Pattern: after k layers, effective context = k × W positions
Mistral 7B: k=32 layers, W=4096 → 32 × 4096 = 131,072 tokens
The paper states: “using a window size of W = 4096, we have a theoretical attention span of approximately 131K tokens.”
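The layer-1 window behavior in the diagram above can be sketched as a boolean mask. This is a minimal illustrative helper, not Mistral’s actual code; `sliding_window_mask` is a hypothetical name:

```python
# Causal sliding-window attention mask sketch: position i may attend
# position j only when i - W < j <= i (the last W tokens, itself included).

def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """True where attention is allowed."""
    return [
        [i - window < j <= i for j in range(seq_len)]
        for i in range(seq_len)
    ]

mask = sliding_window_mask(seq_len=6, window=3)
# Position 3 attends {1, 2, 3}: position 0 has fallen out of the window.
print([j for j in range(6) if mask[3][j]])  # [1, 2, 3]
```

In a real model this mask (or an equivalent index computation) gates the attention logits, so each decoding step touches at most W cached key/value entries.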
Rolling Buffer Cache
With a bounded attention window, the KV cache only needs W slots per layer. Keys and values for position i are stored at cache slot i mod W:
When i ≥ W, old entries are overwritten. The cache never grows:
Rolling buffer (W=4), tokens arriving one by one:
After token 0: cache = [K₀, _, _, _]
After token 1: cache = [K₀, K₁, _, _]
After token 2: cache = [K₀, K₁, K₂, _]
After token 3: cache = [K₀, K₁, K₂, K₃]
After token 4: cache = [K₄, K₁, K₂, K₃] ← slot 0 overwritten (4 mod 4 = 0)
After token 5: cache = [K₄, K₅, K₂, K₃] ← slot 1 overwritten (5 mod 4 = 1)
After token 6: cache = [K₄, K₅, K₆, K₃]
Always holds the most recent W tokens. Memory: constant.
For W=4096 on a 32K-token sequence:
Without rolling buffer: 32,768 cache entries
With rolling buffer: 4,096 cache entries
Reduction: 32,768 / 4,096 = 8×
The paper confirms: “On a sequence length of 32k tokens, this reduces the cache memory usage by 8x, without impacting the model quality.”
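The slot assignment in the trace above can be sketched in a few lines. `RollingBufferCache` is a hypothetical class for illustration; a real implementation stores K/V tensors, not strings:

```python
# Rolling-buffer KV cache sketch: position i writes to slot i % W,
# so the buffer always holds the most recent W entries in bounded memory.

class RollingBufferCache:
    def __init__(self, window: int):
        self.window = window
        self.slots = [None] * window  # fixed size; never grows

    def store(self, position: int, kv) -> None:
        self.slots[position % self.window] = kv  # overwrites the oldest entry

cache = RollingBufferCache(window=4)
for pos in range(7):                # tokens 0..6 arrive one by one
    cache.store(pos, f"K{pos}")
print(cache.slots)  # ['K4', 'K5', 'K6', 'K3'], matching the trace above
```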
Grouped Query Attention
Standard multi-head attention (MHA) maintains independent K and V matrices for all query heads. During inference, every new token extends 32 separate K/V sequences. GQA groups query heads to share K/V heads. Mistral uses 32 query heads and 8 K/V heads — 4 query heads per K/V pair:
Query heads 0–3 share K/V head 0. Query heads 4–7 share K/V head 1. And so on.
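The grouping rule implied above amounts to an integer division (head counts taken from the text; the variable names are illustrative):

```python
# Head-to-group mapping: with 32 query heads and 8 K/V heads,
# query head q reads from K/V head q // 4.
n_heads, n_kv_heads = 32, 8
group_size = n_heads // n_kv_heads   # 4 query heads per shared K/V pair
kv_for = [q // group_size for q in range(n_heads)]
print(kv_for[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```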
Numeric walkthrough — KV cache size:
Mistral 7B: n_heads=32, n_kv_heads=8, head_dim=128, n_layers=32
Standard MHA — KV cache per token, per layer (fp16):
K: 32 heads × 128 dims × 2 bytes = 8,192 bytes
V: 32 heads × 128 dims × 2 bytes = 8,192 bytes
Total: 16,384 bytes/token/layer
GQA — KV cache per token, per layer (fp16):
K: 8 heads × 128 dims × 2 bytes = 2,048 bytes
V: 8 heads × 128 dims × 2 bytes = 2,048 bytes
Total: 4,096 bytes/token/layer
Reduction: 16,384 / 4,096 = 4×
For L=8,192 tokens, 32 layers:
Standard MHA: 8,192 × 32 × 16,384 bytes ≈ 4.3 GB
GQA: 8,192 × 32 × 4,096 bytes ≈ 1.1 GB
Freed: ≈ 3.2 GB — fits 4× more concurrent requests
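The walkthrough arithmetic can be checked directly. `kv_cache_bytes` is a hypothetical helper name, not a library function:

```python
# Reproduces the cache-size arithmetic above (fp16 = 2 bytes per element).

def kv_cache_bytes(n_kv_heads, head_dim, n_layers, seq_len, bytes_per_el=2):
    # K and V each contribute n_kv_heads * head_dim elements per token per layer
    per_token_per_layer = 2 * n_kv_heads * head_dim * bytes_per_el
    return seq_len * n_layers * per_token_per_layer

mha = kv_cache_bytes(n_kv_heads=32, head_dim=128, n_layers=32, seq_len=8192)
gqa = kv_cache_bytes(n_kv_heads=8, head_dim=128, n_layers=32, seq_len=8192)
print(mha / 2**30, gqa / 2**30, mha // gqa)  # 4.0 1.0 4  (4.0 GiB ≈ 4.3 GB)
```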
What’s clever — find the instinct:
The key realization is that K/V quality and K/V count are nearly decoupled. Each query head is asking a different question, but the source material they’re querying doesn’t need to be unique per head. The diversity in attention behavior comes from the query projections. The keys and values — what each token “offers up” for comparison — can be shared across four queries at minimal accuracy cost.
This is also the instinct behind SWA: empirically, most attention weights are already concentrated on nearby tokens. If you look at where transformers actually attend in practice, the long-tail attention to tokens 10,000 positions back is small. SWA formalizes the observation by hard-constraining it — and saves the cost of computing attention weights you’d mostly throw away.
“Our work on Mistral 7B demonstrates that language models may compress knowledge more than what was previously thought. This opens up interesting perspectives: the field has so far put the emphasis on scaling laws in 2 dimensions… the problem is rather 3 dimensional (model capabilities, training cost, inference cost).”
Results
| Model | MMLU | HellaSwag | ARC-C | GSM8K | MATH | HumanEval |
|---|---|---|---|---|---|---|
| LLaMA 2 7B | 44.4% | 77.1% | 43.2% | 16.0% | 3.9% | 11.6% |
| LLaMA 2 13B | 55.6% | 80.7% | 48.8% | 34.3% | 6.0% | 18.9% |
| Mistral 7B | 60.1% | 81.3% | 55.5% | 52.2% | 13.1% | 30.5% |
Mistral 7B surpasses LLaMA 2 13B on every benchmark despite having 45% fewer parameters. The largest gaps are math (+17.9 points on GSM8K) and code (+11.6 on HumanEval). On instruction-following: Mistral 7B Instruct scores MT-Bench 6.84 vs LLaMA 2 13B Chat’s 6.65, and Chatbot Arena ELO 1031 vs 1012.
The paper also reports that Mistral 7B performs like a LLaMA 2 model with more than 3× its parameters on reasoning and STEM benchmarks — the efficiency multiplier the architecture delivers.
What doesn’t work:
Knowledge retrieval benchmarks are the exception. On NaturalQuestions, Mistral 7B scores 28.8% vs LLaMA 2 13B’s 29.0% — essentially tied. The paper is direct: “this is likely due to its limited parameter count that restricts the amount of knowledge it can store.” GQA and SWA improve inference efficiency and reasoning; they can’t conjure parametric knowledge the model doesn’t have room to store.
SWA’s 131K-token attention span is theoretical, not a practical guarantee. Information from early tokens attenuates as it propagates through 32 selective attention layers, so precise recall of specific tokens from many layers’ worth of propagation back may degrade compared to full attention — making long-context needle-in-a-haystack tasks a likely weak point.
Practitioner notes
If you’re building ML systems that serve long prompts — RAG pipelines, document analysis, multi-turn conversations — GQA’s 4× KV cache reduction means 4× more concurrent requests on the same GPU. The rolling buffer makes KV cache growth predictable: regardless of whether the conversation is 1K or 32K tokens, peak memory stays bounded. Both matter enormously for production serving.
For RAG specifically, SWA’s local focus is well-matched. You’re reasoning about a retrieved chunk (local context), not a token from 50K positions back. A 4096-token window covers most practical retrieval granularities.
The broader lesson from the “3D scaling” framing: benchmark scores are one axis, training cost is another, and inference cost per token is the third — often the most important for high-traffic applications. A model that’s 15% better on MMLU but 4× slower to serve may be worse in production. Mistral 7B makes this tradeoff explicit and tilts it decisively toward efficiency.
Connections
- gqa — grouped query attention: multiple query heads sharing K/V heads, 4× KV cache reduction
- sliding-window-attention — each attention layer attends only to a W-token local window
- inference-efficiency — the primary design objective; GQA and SWA both target serving cost
- kv-cache — rolling buffer is a bounded-memory variant of the standard KV cache
- transformer — Mistral 7B is a transformer with targeted attention modifications
- attention-is-all-you-need — original transformer; Mistral modifies its attention and KV mechanism
- llama-2-open-foundation-fine-tuned-chat-models — primary baseline; Mistral 7B outperforms the 13B variant
Citation
Jiang, A. Q., Sablayrolles, A., Mensch, A., Bamford, C., Chaplot, D. S., de las Casas, D., Bressand, F., Lengyel, G., Lample, G., Saulnier, L., Lavaud, L. R., Lachaux, M.-A., Stock, P., Le Scao, T., Lavril, T., Wang, T., Lacroix, T., & El Sayed, W. (2023). Mistral 7B. arXiv preprint. https://arxiv.org/abs/2310.06825