Explain why attention is run as multiple heads and how grouped-query attention shrinks the KV cache.
A single attention head can model one pattern at a time, and giving every query head its own K/V blows up cache memory at long context.
Glossary · 8 terms
- head
- One parallel copy of attention with its own Q/K/V slice. Each head can specialise on a different relationship.
- num_heads (H)
- How many query heads run in parallel per layer. Qwen3.5-0.8B uses 8.
- d_model (hidden)
- The width of the residual stream flowing between layers. Qwen3.5-0.8B uses 1024 — W_O projects the concatenated heads (8 × 256 = 2048) back to this width.
- num_kv_heads (G)
- How many K/V heads exist. Under GQA each K/V head is shared across H/G query heads. Qwen3.5-0.8B uses 2.
- group_size
- H / G — how many query heads share one K/V head. Four for Qwen3.5-0.8B.
- GQA
- Grouped-query attention: partitions H query heads into G K/V groups so the cache shrinks by H/G with little quality loss.
- MQA
- Multi-query attention: GQA's extreme where num_kv_heads=1. Smallest cache, biggest cut in how many different things the heads can look for.
- output projection (W_O)
- The learned matrix (2048 × 1024 here) that maps the concatenated heads back to the residual width. Each 256-wide slice of it is one head’s personal "write" back into the residual stream — concat-then-W_O is the same arithmetic as summing eight per-head contributions.
Multi-head attention & GQA
Chapter 4 introduced attention as a single operation: softmax(QKᵀ/√d)·V, where √d is √head_dim, the per-head width (√256 = 16 for Qwen3.5), not the full hidden size. In words: score how well each token's query matches every token's key, divide by √d so the scores don't blow up as the vectors get wider, let softmax turn those scores into weights that sum to 1, and blend the values by those weights. In real LLMs, that operation runs many times in parallel, once per head, and the results are concatenated and projected back together. The reason is capacity: a single head tends to model one type of relationship at a time, say, "look at the previous token." Multiple heads in parallel let the model attend to different patterns simultaneously: one head can chase the previous token, another the matching bracket, another subject-verb agreement at a distance.
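The formula can be sketched for one head in plain Python, using nested lists so it runs without any libraries. The toy Q, K and V values are made up, `softmax` and `attention` are hypothetical helper names, and the causal mask a decoder would apply is left out for brevity.

```python
import math


def softmax(scores):
    """Turn raw scores into weights that sum to 1 (max subtracted for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """One head of softmax(Q K^T / sqrt(d)) V over lists of row vectors.

    d is the per-head width (the length of each query/key vector), not the hidden size.
    No causal mask is applied, to keep the sketch short.
    """
    d = len(Q[0])
    out = []
    for q in Q:
        # Score this query against every key, then scale by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        weights = softmax(scores)
        # Blend the value vectors feature by feature using the weights.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out


# Toy inputs: 3 tokens, head width 2 (hand-picked numbers, not from any model).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
for row in attention(Q, K, V):
    print([round(x, 3) for x in row])
```

A zero query scores every key equally, so its output is simply the average of the values, which is a handy sanity check.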
Think of each head as one "lens" for comparing tokens: it learns to notice one kind of relationship. One head might track which adjective belongs to which noun; another, which earlier word a pronoun refers to. Running several in parallel lets the model watch many relationships at once.
Before zooming into heads and the KV cache, here is the whole attention block on one diagram — every stage this chapter and the last one taught separately, wired together end to end. Step through it, then flip the toggle to see how Qwen3.5's linear layers replace the whole quadratic pipeline with a single recurrent state.
The normalized residual vector for the current token enters the attention block — 1024 dims wide.
The QKᵀ softmax is quadratic in sequence length, and the KV cache grows with every token (8 query / 2 KV heads, head_dim 256). Only 6 of Qwen3.5-0.8B's 24 layers use this path.
Illustrative schematic — the full-attention path mirrors Qwen3.5's actual gated GQA (QK-norm, output gate); the linear path is a high-level sketch. No tensors flow here; not live output from the model.
Standard multi-head attention (MHA)
For hidden size d and H heads (H = the number of query heads running in parallel), each head gets its own width d_head (the per-head vector width). Classically d_head = d / H, so the heads tile the hidden width exactly; many models, including Qwen3.5, decouple the two (more below). The model multiplies each token vector by three weight matrices to get Q, K and V, each H · d_head wide, then reshapes each one to [H, seq, d_head] (read it as a 3-D array: H heads × seq positions × d_head numbers each). Each head runs its own attention over its own Q/K/V slice; the H outputs are concatenated and projected back into a d-wide vector that flows on through the layer.
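A minimal sketch of the reshape with plain Python lists, assuming the common layout in which head h owns a contiguous block of d_head columns of the projected vector (frameworks do the same thing with a reshape plus transpose). `split_heads` and `merge_heads` are hypothetical helper names.

```python
def split_heads(X, num_heads, d_head):
    """Reshape [seq, num_heads * d_head] into [num_heads, seq, d_head].

    Head h owns feature columns h*d_head .. (h+1)*d_head - 1 of every position.
    """
    assert all(len(row) == num_heads * d_head for row in X), "width must be H * d_head"
    return [[row[h * d_head:(h + 1) * d_head] for row in X] for h in range(num_heads)]


def merge_heads(heads):
    """Inverse of split_heads: concatenate each position's head outputs along the feature axis."""
    seq_len = len(heads[0])
    return [[x for head in heads for x in head[t]] for t in range(seq_len)]


# Toy projected queries: seq = 2 positions, H = 2 heads, d_head = 3.
Q_flat = [[0, 1, 2, 3, 4, 5],
          [6, 7, 8, 9, 10, 11]]
Q_heads = split_heads(Q_flat, num_heads=2, d_head=3)
print(Q_heads[0])  # → [[0, 1, 2], [6, 7, 8]]
print(Q_heads[1])  # → [[3, 4, 5], [9, 10, 11]]
assert merge_heads(Q_heads) == Q_flat
```

Each entry of `Q_heads` is what one head's attention sees; `merge_heads` is the concatenation step that happens before the output projection.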
Here is that split drawn out — one token's vector going in, eight head-bands coming out (and the narrower K/V path that this chapter's GQA section is about):
One token arrives as a single 1024-wide vector. No heads anywhere yet.
Color = head identity, matching the concat diagram further down — the band that splits off here is the same head whose output merges back there. Widths are to scale: 2048 query features vs 512 K/V features.
At inference, K and V for each token are cached across generation steps; that is what makes the second token onward fast. Why K and V, but never Q? At decode time the model only needs the new token's query, but that one query must dot against every past token's key and blend every past token's value. So past K and V get saved and re-read at every step, while a past token's Q was used once, at its own step, and is never needed again. With full MHA the cache holds 2 × H × d_head × seq_len numbers per layer: one K and one V vector per head per token. For a 32k-context model with many layers and wide heads, that adds up to gigabytes, and unlike the weights it grows with every new token. At long enough context, the KV cache rather than the weights is what comes to dominate memory.
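The decode-time argument can be sketched with a toy single-head cache. `ToyKVCache` and `kv_cache_floats` are hypothetical names and the vectors are made up; a real model would produce q, k and v from learned projections of each new token, separately for every head and layer.

```python
import math


def kv_cache_floats(num_kv_heads, d_head, seq_len):
    """Numbers held in one layer's cache: a K and a V (factor 2) per KV head per token."""
    return 2 * num_kv_heads * d_head * seq_len


class ToyKVCache:
    """Hypothetical single-head decode cache: one key and one value per past token."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, q, k, v):
        """Decode one token: cache its K and V, attend with its Q, then drop the Q."""
        self.keys.append(k)      # re-read by every later step
        self.values.append(v)    # re-read by every later step
        d = len(q)
        scores = [sum(a * b for a, b in zip(q, key)) / math.sqrt(d) for key in self.keys]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        # q is never stored: later steps only need their own new query.
        return [sum(w * val[j] for w, val in zip(weights, self.values)) / z
                for j in range(len(v))]


cache = ToyKVCache()
print(cache.step([1.0, 0.0], [1.0, 0.0], [2.0, 3.0]))  # → [2.0, 3.0] (only one token to attend to)
cache.step([0.0, 1.0], [0.0, 1.0], [4.0, 5.0])
print(len(cache.keys))  # → 2

# Full MHA at Qwen3.5-like widths (H = 8, d_head = 256) and a 32k context, one layer:
print(kv_cache_floats(num_kv_heads=8, d_head=256, seq_len=32768))  # → 134217728
```

With MHA, `num_kv_heads` equals H, which recovers the 2 × H × d_head × seq_len formula.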
Each of the 8 heads produces a slice of shape [seq_len, d_head] = [5, 256]. We stack them side by side along the feature axis to recover a [5, 2048] matrix, then project it with the learned output matrix W_O back to the 1024-dim residual stream. Heads don't talk to each other inside attention; they only mix afterward, here.
Grouped query attention (GQA)
GQA is the now-standard trick to shrink the cache. Instead of giving every query head its own K/V head, you partition the query heads into num_kv_heads groups and let every head in a group share a single K/V pair:
The Q projection still has H heads (each with their own weights), but the K and V projections only have G heads (G = the number of key/value groups, one shared K/V per group). When attention is computed for query head h, its query is dotted against the keys of K/V head floor(h / group_size), and the resulting weights blend that same head's values. The KV cache shrinks by a factor of H/G with no change to the Q dimension and, in practice, minimal accuracy loss.
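A minimal sketch of GQA in plain Python, assuming the contiguous grouping described above (query head h reads K/V head floor(h / group_size)). Real implementations usually get the same effect by repeating each K/V head group_size times before a batched attention call; `attend`, `kv_group` and `gqa` are hypothetical names, and causal masking is left out.

```python
import math


def attend(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for one head (no causal mask, for brevity)."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e * v[j] for e, v in zip(exps, V)) / z for j in range(len(V[0]))])
    return out


def kv_group(h, num_heads, num_kv_heads):
    """K/V head read by query head h, assuming contiguous groups of H // G heads."""
    assert num_heads % num_kv_heads == 0, "H must be a multiple of G"
    return h // (num_heads // num_kv_heads)


def gqa(Q_heads, K_heads, V_heads):
    """Q_heads: H lists of [seq][d]; K_heads / V_heads: G lists of [seq][d]."""
    H, G = len(Q_heads), len(K_heads)
    return [attend(Q, K_heads[kv_group(h, H, G)], V_heads[kv_group(h, H, G)])
            for h, Q in enumerate(Q_heads)]


print([kv_group(h, 8, 2) for h in range(8)])  # → [0, 0, 0, 0, 1, 1, 1, 1]
```

Setting G = H gives every query head its own K/V (MHA), and G = 1 sends every head to K/V head 0 (MQA); the attention math itself does not change.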
Qwen3.5-0.8B specifics
Qwen3.5-0.8B uses num_heads = 8, num_kv_heads = 2, head_dim = 256. Each K/V head is shared across 4 query heads, so its KV cache is 4× smaller than the equivalent full-MHA model. Larger Qwen3.5 variants keep the same group_size = 4 ratio.
Notice that Qwen3.5 decouples head_dim from hidden / H: the classic default would be 1024 / 8 = 128, but it picks head_dim = 256 independently — per-head width is a free hyperparameter. So the query projection maps hidden = 1024 → H · head_dim = 8 · 256 = 2048, not a square d × d matrix. (Qwen adds a couple of wrinkles to this projection — see Advanced below.)
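The shape bookkeeping can be checked with a few lines of arithmetic. This sketch counts weight-matrix entries only (no biases or norms), lists only the query half of q_proj (the gate half is described in Advanced), and writes each matrix as (in, out); PyTorch's `nn.Linear` stores the transpose.

```python
# Qwen3.5-0.8B sizes as given in this chapter; weights written as (in_features, out_features).
hidden, num_heads, num_kv_heads, head_dim = 1024, 8, 2, 256

classic_head_dim = hidden // num_heads   # 128: what d_head = d / H would give
shapes = {
    "q_proj (query half)": (hidden, num_heads * head_dim),     # 1024 -> 2048
    "k_proj": (hidden, num_kv_heads * head_dim),               # 1024 -> 512
    "v_proj": (hidden, num_kv_heads * head_dim),               # 1024 -> 512
    "o_proj": (num_heads * head_dim, hidden),                  # 2048 -> 1024
}
for name, (n_in, n_out) in shapes.items():
    print(f"{name:20s} {n_in:5d} -> {n_out:5d}  ({n_in * n_out:,} weights)")
```

Because head_dim is chosen independently of hidden / H, none of these matrices is square.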
Qwen3.5 also alternates two kinds of layer. Only every fourth layer, 6 of the 24, runs this full-softmax attention and keeps a growing KV cache; the other 18 are linear (GatedDeltaNet) layers that carry a fixed-size recurrent state instead, covered in the KV-cache chapter (chapter 12). The layer selector below therefore restricts to the full-attention layers, whose softmax scores you can inspect, and the cache numbers below count only those 6 layers, not all 24.
Advanced: Qwen3.5 specifics (output gate, QK-norm, linear layers) · optional, for the curious
In Qwen3.5 the full q_proj weight is actually twice as wide — 4096 outputs — because it also emits a per-head output gate next to the queries, as the pipeline diagram above shows; the 2048 above is the query half. (Qwen also applies a per-head RMSNorm to the queries and keys — q_norm / k_norm — before the dot products; the architecture chapter covers both.)
Layers in Qwen3.5 alternate: most are linear-attention (a recurrent variant outside this chapter's scope), and every fourth layer is the classic full-attention layer with softmax scores you can inspect. The layer selector below restricts to the latter.
Why W_O exists
One matrix in the pipeline diagram deserves a second look: the output projection W_O at the very end. Step back and trace what a single head does to the residual stream, ignoring attention weights for a moment. Its value path is two linear maps glued together: v_proj takes the 1024-wide residual vector down to a 256-dim value, and that head's 256-wide slice of W_O takes the result back up to 1024. That is a low-rank map: it can only move the residual vector within a 256-dimensional subspace, but it costs far less than a free-form one. A full-rank per-head map would be a 1024 × 1024 matrix — about 1.05M parameters per head — while the factored down/up pair is 2 × (1024 × 256) ≈ 0.52M, half the cost even before GQA shares the down half (more on that below). Eight constrained heads for the price of four unconstrained ones is the trade the architecture makes.
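The parameter arithmetic, plus a tiny exact check of the low-rank claim, in plain Python. The 4 → 2 → 4 toy matrices stand in for a 1024 → 256 → 1024 head; `rank` is a hypothetical helper that uses exact fractions so rounding cannot distort the count.

```python
from fractions import Fraction

d_model, d_head, num_heads = 1024, 256, 8
full_rank_params = d_model * d_model      # one free-form 1024 x 1024 map per head
factored_params = 2 * d_model * d_head    # v_proj slice (down) + W_O band (up)
print(full_rank_params, factored_params)  # → 1048576 524288
print(num_heads * factored_params == 4 * full_rank_params)  # → True


def matmul(A, B):
    """Plain matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def rank(matrix):
    """Exact rank by Gaussian elimination over Fractions."""
    m = [[Fraction(x) for x in row] for row in matrix]
    r = 0
    for c in range(len(m[0])):
        pivot = next((i for i in range(r, len(m)) if m[i][c] != 0), None)
        if pivot is None:
            continue
        m[r], m[pivot] = m[pivot], m[r]
        for i in range(len(m)):
            if i != r and m[i][c] != 0:
                f = m[i][c] / m[r][c]
                m[i] = [a - f * b for a, b in zip(m[i], m[r])]
        r += 1
        if r == len(m):
            break
    return r


down = [[1, 0], [2, 1], [0, 3], [1, 1]]   # 4 x 2, plays the v_proj slice
up = [[1, 2, 0, 1], [0, 1, 1, 3]]         # 2 x 4, plays the W_O band
composite = matmul(down, up)              # a 4 x 4 map, yet rank <= 2
print(rank(composite))  # → 2
```

However the toy weights are chosen, the composite can never exceed rank 2, just as a real head's value path can never exceed rank 256.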
The "concatenate, then multiply by W_O" recipe also looks more mysterious than it is. Slice the 2048 × 1024 matrix into 8 horizontal bands of 256 × 1024, one per head. Multiplying the concatenated 2048-vector by W_O is exactly the same arithmetic as letting each head push its own 256-dim output through its own band and then summing the 8 results — concatenation followed by one big projection ≡ a sum of per-head contributions. So a better mental model than "glue the heads together" is: each head independently proposes a small low-rank edit, and the edits are added onto the residual stream. No head overwrites another; they accumulate.
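A toy check of the equivalence, with 2 heads of width 2 writing into a 3-wide residual (sizes and numbers made up for illustration; `vecmat` is a hypothetical helper). In this x · W_O convention, rows h·d_head to (h+1)·d_head are head h's band.

```python
def vecmat(x, W):
    """Row vector x (length n) times matrix W (n rows x m columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


num_heads, d_head, d_model = 2, 2, 3        # toy sizes, not Qwen's 8, 256, 1024
head_outputs = [[1, 2], [3, -1]]            # each head's d_head-wide output
W_O = [[1, 0, 2],   # rows 0-1: head 0's band
       [0, 1, 1],
       [2, 1, 0],   # rows 2-3: head 1's band
       [1, 1, 1]]

# Route 1: concatenate the heads, then one big projection.
concat = [x for out in head_outputs for x in out]
via_concat = vecmat(concat, W_O)

# Route 2: each head projects through its own band; sum the contributions.
via_sum = [0] * d_model
for h, out in enumerate(head_outputs):
    band = W_O[h * d_head:(h + 1) * d_head]
    via_sum = [a + b for a, b in zip(via_sum, vecmat(out, band))]

print(via_concat, via_sum)  # → [6, 4, 3] [6, 4, 3]
```

Route 2 makes the "each head adds its own edit" reading explicit: head 0 contributes [1, 2, 4] and head 1 contributes [5, 2, -1].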
GQA adds one wrinkle to this picture for our model. The down half is shared per group: v_proj only has 2 KV heads (1024 → 512), so four query heads in a group read the same 256-dim value vectors. But the up half stays fully per-head: o_proj is 2048 → 1024, giving every one of the 8 query heads its own 256 → 1024 band. Group-mates blend the same values with different attention weights, then write the result into the residual stream through different projections — shared reading, private writing.
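To make "shared reading, private writing" concrete, this sketch lists which slice of v_proj's 512 outputs and which band of o_proj's 2048 inputs each query head uses. It assumes contiguous grouping (head h reads KV head h // 4, matching the floor formula earlier) and the x · W convention; in PyTorch's `nn.Linear` layout the o_proj band would be a block of columns instead. `read_write_map` is a hypothetical name.

```python
hidden, num_heads, num_kv_heads, head_dim = 1024, 8, 2, 256
group_size = num_heads // num_kv_heads   # 4 query heads per KV head


def read_write_map(h):
    """For query head h: its KV group, the v_proj output columns it reads,
    and the o_proj input rows it writes through (x @ W convention)."""
    g = h // group_size
    v_cols = (g * head_dim, (g + 1) * head_dim)   # shared with its 3 group-mates
    o_rows = (h * head_dim, (h + 1) * head_dim)   # private to this head
    return g, v_cols, o_rows


for h in range(num_heads):
    g, v_cols, o_rows = read_write_map(h)
    print(f"Q{h}: reads V head {g} {v_cols}, writes via o_proj rows {o_rows}")
```

Only two distinct value slices appear across the eight heads, but all eight o_proj bands are different.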
The trade-off, and where MQA fits
GQA gives up a little of how many different things the heads can look for: query heads in the same group all score against the same key vectors, because they share K. They can still learn different attention patterns over those shared keys via their distinct Q projections (you'll see this in the side-by-side below). The memory win is large, the quality loss is small, and the decode speedup from a smaller cache is real. Many open models, Llama 3 and Qwen3.5 among them, use GQA.
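A small experiment shows why sharing K does not force identical patterns: the same three keys, scored by two different queries, give two different weight distributions. The vectors are hypothetical toy values, not taken from Qwen3.5, and `attention_weights` is a hypothetical helper.

```python
import math


def attention_weights(q, keys):
    """Softmax of scaled dot products between one query and a list of keys."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


shared_keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # one K/V head, 3 past tokens
q_head0 = [3.0, 0.0]   # hypothetical query from Q0
q_head1 = [0.0, 3.0]   # hypothetical query from Q1, same group

w0 = attention_weights(q_head0, shared_keys)
w1 = attention_weights(q_head1, shared_keys)
print([round(w, 2) for w in w0])  # Q0 puts most of its weight on token 0
print([round(w, 2) for w in w1])  # Q1 puts most of its weight on token 1
```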
Push GQA to the extreme — num_kv_heads = 1 — and you get Multi-Query Attention (MQA): every query head shares one K/V. Earlier inference-focused models (PaLM, Falcon) used MQA. The consensus today is that GQA with a moderate group size (typically 4–8) recovers most of MHA's quality at most of MQA's cache savings.
Sweep the whole spectrum below. As you move from MHA (8 KV heads) down to MQA (1), the per-layer KV cache shrinks linearly. Every variant keeps all eight query heads, so the query side is untouched — what shrinks is the number of distinct sets of K/V features the heads can match against, since the query heads in a group must now share one K/V.
★ = Qwen3.5-0.8B's choice. 1 = MQA (one shared K/V), 8 = MHA (one K/V per query head).
Sharing K/V is not free: fewer distinct sets of K/V features for the heads to match against is a real, if usually modest, cut in capacity. Empirically, GQA still recovers most of MHA's quality at a fraction of the cache (a general result, not something this toy sweep measures), which is why Qwen3.5-0.8B uses 2 KV heads, a 4× cache saving over MHA.
Illustrative: the cache float counts are exact (2 · #KV · head_dim per token per layer); not live output from the model.
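The sweep's arithmetic fits in a few lines. This sketch assumes every group has the same size, so G must divide H; for H = 8 that leaves 8, 4, 2 and 1. The function name `sweep` is hypothetical.

```python
num_heads, head_dim = 8, 256


def sweep(num_heads=num_heads, head_dim=head_dim):
    """Per-token, per-layer KV cache floats for every G that splits H into equal groups."""
    rows = []
    for g in range(num_heads, 0, -1):
        if num_heads % g:
            continue  # unequal groups: skip (e.g. G = 3 does not divide H = 8)
        name = "MHA" if g == num_heads else "MQA" if g == 1 else "GQA"
        floats = 2 * g * head_dim           # one K and one V per KV head
        rows.append((g, name, floats, num_heads // g))
    return rows


for g, name, floats, saving in sweep():
    print(f"G={g} {name}: {floats} floats/token/layer, {saving}x smaller than MHA")
```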
What different heads end up doing
You can't read a head's job off its weight matrices — you have to look at the patterns it produces on real inputs. Here are three common archetypes drawn as illustrative heatmaps.
Detector: position. Each token mostly attends to itself; minor leak to neighbours.
Detector: recency. Strong attention to the immediately-prior token.
Detector: syntactic head. Late tokens look back at the determiner ("t0").
These are hand-built diagrams, not measured from any specific model. Real heads are messier and often mix several of these archetypes.
How big is the KV cache, really?
Numbers make the savings concrete. The bars below compare the whole-model KV cache for a hypothetical MHA Qwen3.5-0.8B against the actual GQA layout at an 8k context — the level a chat session can hit in a few turns.
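The bar heights can be reproduced by hand. This sketch assumes 2 bytes per cached number (bf16/fp16) and a batch of one, and it counts only the 6 full-attention layers, leaving out the linear layers' fixed-size state. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Whole-model KV cache for one sequence; the leading 2 is one K plus one V."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


FULL_ATTN_LAYERS = 6   # only these layers keep a growing cache
CTX = 8192             # the 8k context used for the bars

mha = kv_cache_bytes(FULL_ATTN_LAYERS, num_kv_heads=8, head_dim=256, seq_len=CTX)
gqa = kv_cache_bytes(FULL_ATTN_LAYERS, num_kv_heads=2, head_dim=256, seq_len=CTX)
print(mha / 2**20, gqa / 2**20, mha / gqa)  # → 384.0 96.0 4.0  (MiB, MiB, ratio)
```

Doubling the context doubles both numbers; the 4× ratio depends only on H/G.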
The widget on the right runs the same "The cat sat on the" prompt as chapter 4 and ghosts the model's predicted next token at the end with a rainbow shimmer. The lane diagram above shows which query heads share which K/V head, and the two-up heatmaps below compare two query heads in the same group — same K, different Q, so different attention patterns over the same keys.
- Heads in the same K/V group share keys and values — they can still attend differently because their Q projections are separate.
- GQA cuts the KV cache by H/G with little quality loss; for Qwen3.5-0.8B that's a 4× saving over full MHA.
- At long context, the KV cache (not the weights) is what fills up memory and dominates decode latency.
After the auto-run, the lane diagram starts on Q0 and the side-by-side compares it against Q1, its group-mate. Click the other chips in the same group (Q0–Q3) and compare the heatmaps. Are the patterns identical, similar, or very different?