解释注意力为什么要拆成多个头并行运行,以及分组查询注意力(GQA)如何缩小 KV 缓存。
单个注意力头一次只能建模一种模式,而给每个 query 头都配一份独立的 K/V,会让长上下文下的缓存内存爆炸。
术语表 · 8 个术语
- head
- 注意力的一个并行副本,拥有自己的 Q/K/V 切片。每个头可以专注于一种不同的关系。
- num_heads (H)
- 每层并行运行多少个 query 头。Qwen3.5-0.8B 用 8 个。
- d_model (hidden)
- 在层与层之间流动的残差流的宽度。Qwen3.5-0.8B 用 1024——W_O 把拼接后的各头输出(8 × 256 = 2048)投影回这个宽度。
- num_kv_heads (G)
- 一共有多少个 K/V 头。在 GQA 下,每个 K/V 头被 H/G 个 query 头共享。Qwen3.5-0.8B 用 2 个。
- group_size
- H / G——多少个 query 头共享同一个 K/V 头。对 Qwen3.5-0.8B 来说是 4。
- GQA
- 分组查询注意力:把 H 个 query 头划分成 G 个 K/V 组,使缓存缩小 H/G 倍,而质量损失很小。
- MQA
- 多查询注意力:GQA 的极端形式,即 num_kv_heads=1。缓存最小,但各头能寻找的不同模式数量也被削减得最狠。
- output projection (W_O)
- 把拼接后的各头映射回残差宽度的可学习矩阵(这里是 2048 × 1024)。它的每个 256 宽切片就是一个头写回残差流的专属“笔”——先拼接再乘 W_O,与把 8 份逐头贡献相加是同一套算术。
多头注意力与 GQA
第 4 章把注意力当作一个单一操作来介绍:softmax(QKᵀ/√d)·V(这里的 √d 是 √head_dim,即每个头的宽度——对 Qwen3.5 来说 √256 = 16——而不是完整的隐藏宽度)。用文字描述就是:先给每个 token 的 query 与所有 token 的 key 的匹配程度打分,除以 √d 以免向量变长时分数爆炸,softmax 把这些分数变成总和为 1 的权重,再按这些权重混合各个 value。在真实的 LLM 里,这个操作会被并行运行很多次,每个头(head)各跑一份,最后把各头的结果拼接起来、投影回去。原因在于容量:单个头一次只能建模一种关系——比如“看上一个 token”。多个头并行让模型可以同时关注不同的模式:一个头追上一个 token,另一个头找匹配的括号,还有一个头处理远距离的主谓一致。
可以把每个头看作比较 token 的一面“透镜”:它学会留意一种特定的关系。一个头可能追踪哪个形容词属于哪个名词;另一个头追踪代词指代前文的哪个词。多面透镜并行运行,模型就能同时盯住许多种关系。
在深入头和 KV 缓存之前,先看一张把整个注意力模块画在一起的图——本章和上一章分开讲的每个阶段,端到端连成一条线。先逐步走一遍,再切换开关,看看 Qwen3.5 的线性层如何用一个循环状态取代整条二次复杂度的流水线。
当前 token 归一化后的残差向量进入注意力模块——宽 1024 维。
QKᵀ softmax 与序列长度成平方关系,KV 缓存随每个 token 增长(8 个 query 头 / 2 个 KV 头,head_dim 256)。Qwen3.5-0.8B 的 24 层中只有 6 层走这条通路。
示意图——全量注意力通路对应 Qwen3.5 真实的门控 GQA(QK-norm、输出门);线性通路只是高层速写。这里没有张量流动;不是模型的实时输出。
标准多头注意力(MHA)
对于隐藏宽度 d 和 H 个头(H = 并行运行的 query 头数),每个头有自己的维度 d_head(d_head = 每个头的向量宽度)——经典设定是 d_head = d / H,让各个头恰好铺满整个隐藏宽度(很多模型,包括 Qwen3.5,把这两者解耦了——下文细说)。模型用三个权重矩阵把 token 向量投影成形状为 [H, seq, d_head] 的 Q、K、V,再做 reshape(把 [H, seq, d_head] 读成一个三维数组:H 个头 × seq 个位置 × 每个位置 d_head 个数)。每个头在自己的 Q/K/V 切片上独立运行注意力;H 个输出被拼接起来,投影回一个 d 宽的向量,继续流过这一层。
下面把这个切分画了出来——一个 token 的向量进入,八条头带出来(还有那条更窄的 K/V 通路,正是本章 GQA 部分要讲的):
一个 token 以单个 1024 宽的向量到达。此刻还没有任何头。
颜色 = 头的身份,与下方的拼接图一一对应——在这里拆出去的那条带,正是在那里合并回来的同一个头。宽度按真实比例绘制:2048 个 query 特征 vs 512 个 K/V 特征。
推理时,每个 token 的 K 和 V 会跨生成步缓存下来——这正是从第二个 token 起速度变快的原因。为什么缓存 K 和 V,却从不缓存 Q?解码时模型只需要新 token 的 query——但这个 query 必须和每个历史 token 的 key 做点积,并混合每个历史 token 的 value。所以历史的 K 和 V 要保存下来、每一步都重新读取,而历史 token 的 Q 只在它自己那一步用过一次,之后再也用不到。缓存大小是每层 2 × H × d_head × seq_len 个浮点数。对一个 32k 上下文、层数深、隐藏宽度大的模型来说,这就是 GB 量级——而且与权重不同,它会随每个新 token 持续增长。在长上下文下占满内存的是 KV 缓存,不是权重。
8 个头各产出一片形状为 [seq_len, d_head] = [5, 256] 的切片。我们沿特征轴把它们并排堆起来,得到一个 [5, 2048] 矩阵,再用学到的输出矩阵 WO 投影回 1024 维的残差流。注意力内部各头之间互不交流——它们只在之后、也就是这里混合。
分组查询注意力(GQA)
GQA 是如今缩小缓存的标准技巧。它不再给每个 query 头单独配一个 K/V 头,而是把 query 头划分成 num_kv_heads 组,让同组的每个头共享同一对 K/V:
Q 投影仍然有 H 个头(各自有自己的权重),但 K 和 V 投影只有 G 个头(G = key/value 组的数量,每组共享一份 K/V)。计算 query 头 h 的注意力时,它和组 floor(h / group_size) 的 K 做点积——并读取同一组的 V。KV 缓存缩小 H/G 倍,Q 的维度完全不变,实践中的精度损失极小。
Qwen3.5-0.8B 的具体配置
Qwen3.5-0.8B 使用 num_heads = 8、num_kv_heads = 2、head_dim = 256。每个 K/V 头被 4 个 query 头共享,所以它的 KV 缓存比等价的完整 MHA 模型小 4×。更大的 Qwen3.5 变体保持同样的 group_size = 4 比例。
注意 Qwen3.5 把 head_dim 和 hidden / H 解耦了:经典的默认值应该是 1024 / 8 = 128,但它独立选择了 head_dim = 256——每个头的宽度是一个自由的超参数。因此 query 投影做的映射是 hidden = 1024 → H · head_dim = 8 · 256 = 2048,而不是一个方形的 d × d 矩阵。(Qwen 还给这个投影加了几处小变化——见下方进阶。)
Qwen3.5 还交替使用两种层,下方的层选择器只列出可以检查 softmax 分数的经典全量注意力层(另一种层在下方进阶中介绍)。只有每四层中的一层——24 层中的 6 层——运行这种完整的 softmax 注意力并保有持续增长的 KV 缓存;其余 18 层是线性(GatedDeltaNet)层,携带的是固定大小的循环状态。所以下面的缓存数字只统计这 6 层,而不是全部 24 层。
进阶:Qwen-3.5 特有细节(输出门、QK-norm、线性层) · 可选,给好奇的读者
在 Qwen3.5 中,完整的 q_proj 权重实际上是两倍宽——4096 个输出——因为它在 query 旁边还输出一个逐头的输出门,正如上方流水线图所示;上文写的 2048 是 query 那一半。(Qwen 还在点积之前对 query 和 key 做逐头的 RMSNorm——q_norm / k_norm;架构一章会把两者都讲到。)
Qwen3.5 的层是交替排布的:大多数层是线性注意力(一种循环变体,超出本章范围),每第四层才是可以检查 softmax 分数的经典全量注意力层。下方的层选择器只列出后者。
W_O 是怎么来的
流水线图里有一个矩阵值得回头多看一眼:最末端的输出投影 W_O。先退一步,暂时忽略注意力权重,只追踪单个头对残差流做了什么。它的 value 通路是两段线性映射拼起来的:v_proj 把 1024 宽的残差向量降到 256 维的 value,这个头在 W_O 里那条 256 宽的切片再把结果升回 1024 维。这是一个低秩映射:它只能在一个 256 维的子空间里移动残差向量,但代价比自由映射小得多。如果给每个头配一个满秩映射,那是一个 1024 × 1024 矩阵——每个头约 1.05M 个参数——而分解后的降/升一对只要 2 × (1024 × 256) ≈ 0.52M,还没算上 GQA 对“降”那一半的共享(下文细说),成本已经减半。用四个不受约束的头的价钱买八个受约束的头,这就是架构做的交易。
“先拼接、再乘 W_O”这套配方看上去也比实际更神秘。把这个 2048 × 1024 的矩阵横向切成 8 条 256 × 1024 的带子,每头一条。拼接后的 2048 维向量乘 W_O,与让每个头把自己的 256 维输出穿过自己那条带子、再把 8 个结果相加,是完全相同的算术——拼接加一次大投影 ≡ 逐头贡献求和。所以比“把各头粘在一起”更好的心智模型是:每个头独立提出一份小小的低秩修改,这些修改被加到残差流上。没有哪个头会覆盖别的头;它们是累加的。
对我们这个模型,GQA 给这幅图景添了一道褶皱。“降”的那一半按组共享:v_proj 只有 2 个 KV 头(1024 → 512),所以同组的四个 query 头读到的是同一份 256 维 value 向量。但“升”的那一半仍然完全逐头:o_proj 是 2048 → 1024,给全部 8 个 query 头各留了一条自己的 256 → 1024 带子。同组的头用不同的注意力权重混合同一份 value,再经由不同的投影把结果写进残差流——共享地读,各自地写。
取舍,以及 MQA 的位置
GQA 放弃了一点“各头能寻找多少种不同东西”的能力——同组的 query 头无法关注不同的 key,因为它们共享 K。但它们仍然可以通过各自独立的 Q 投影,在这些共享的 key 上学出不同的注意力模式(下面的并排对比就能看到)。内存上的收益很大,质量损失很小,更小的缓存带来的解码加速是实打实的。Llama 3、Qwen3.5 这些中等规模的开源模型全都在用 GQA。
把 GQA 推到极端——num_kv_heads = 1——就得到了多查询注意力(Multi-Query Attention,MQA):所有 query 头共享一份 K/V。早期面向推理优化的模型(PaLM、Falcon)用的是 MQA。如今的共识是:组大小适中(通常 4–8)的 GQA 能在保住 MQA 大部分缓存节省的同时,拿回 MHA 的大部分质量。
在下面把整个谱系扫一遍。从 MHA(8 个 KV 头)一路滑到 MQA(1 个),每层的 KV 缓存线性缩小。每个变体都保留全部八个 query 头,所以 query 侧不受影响——缩小的是各头可以匹配的不同 K/V 特征组的数量,因为同组的 query 头现在必须共享一份 K/V。
★ = Qwen3.5-0.8B 的选择。1 = MQA(共享一份 K/V),8 = MHA(每个 query 头一份 K/V)。
KV 缓存随 KV 头数线性缩小。每个变体都保留全部 8 个 query 头,所以 query 侧不受影响——但共享 K/V 意味着各头可以匹配的不同 K/V 特征组变少了,这是一次真实(尽管通常不大)的容量削减,而不是零削减。经验上,GQA 仍能用一小部分缓存拿回 MHA 的大部分质量(这是一个普遍结论,不是这个玩具扫描能测出来的),这正是 Qwen3.5-0.8B 使用 2 个 KV 头的原因——相比 MHA 节省 4× 的缓存。
示意——缓存浮点数是精确的(2 · #KV · head_dim);不是模型的实时输出。
不同的头最终学会做什么
你没法从权重矩阵上直接读出一个头的职责——必须看它在真实输入上产生的模式。下面是三种常见原型,用示意性的热力图画出来。
不同的头学会检测不同的东西。光看权重没法判断一个头在检测什么——必须看它在真实输入上产生的模式。下面是一个典型中等规模 LLM 里常见的三种原型,用示意性热力图画出来。
检测器:位置。每个 token 主要关注自己;少量泄漏到相邻 token。
检测器:近因。对紧邻的前一个 token 投以强注意力。
检测器:句法。靠后的 token 回头看位置 0 的限定词("t0")。
这些是手工构造的示意图,并非从某个具体模型测得。真实的头更杂乱,常常混合多种原型。
KV 缓存到底有多大?
数字能让节省变得具体。下面的柱状图比较了一个假想的 MHA 版 Qwen3.5-0.8B 与真实 GQA 布局在 8k 上下文(一次聊天几轮就能达到的水平)下的全模型 KV 缓存。
右侧的小部件运行与第 4 章相同的 "The cat sat on the" 提示词,并在末尾用彩虹流光“幽灵”出模型预测的下一个 token。上方的通道图展示哪些 query 头共享哪个 K/V 头;下方的双联热力图对比同组的两个 query 头——相同的 K、不同的 Q,因此在同样的 key 上呈现出不同的注意力模式。
- 同一 K/V 组里的头共享 key 和 value——但它们的 Q 投影各自独立,所以仍然可以关注不同的内容。
- GQA 把 KV 缓存缩小 H/G 倍,质量几乎不受影响;对 Qwen3.5-0.8B 来说,相比完整的 MHA 是 4x 的节省。
- 在长上下文下,占满内存、主导解码延迟的是 KV 缓存,而不是权重。
自动运行结束后,通道图默认选中 Q0,右侧的并排对比把它和同组的 Q1 放在一起。点一点同组里其余的 Q0–Q3 芯片,对比它们的热力图。这些模式是完全相同、彼此相似,还是差异很大?