第 10 章 · LM head 与权重共享

解释最后的隐藏状态如何变成一个 logits 向量,以及为什么对 Qwen3.5 来说,LM head 实际上就是被复用的嵌入矩阵。

第 9 章结束于一个隐藏状态;第 11 章开始于 logits。中间必须有一步把前者变成后者。

6 分钟
前向传播分词嵌入查表× 24 层最终 RMSNormLM head采样
术语表 · 6 个术语
LM head
模型最后的线性投影。把最后的隐藏状态乘以一个 [d, V] 矩阵,为词表中的每个 token 得到一个 logit。
logit
词表中每个条目对应的一个无界实数分数——越高表示模型越偏好该 token。它们还不是概率;采样一章会用 softmax 把它们变成概率。
weight tying
把 embed_tokens.weight 复用为 LM head,让输入查表和输出投影共享同一个张量(在小模型上能省下相当大比例的参数)。
vocabulary size (V)
分词器能产出的不同 token 的数量。Qwen3.5 使用 V = 248,320。
hidden dim (d)
残差流的宽度。Qwen3.5-0.8B 使用 d = 1024——每个 token 在每一层都携带 1024 个浮点数。
inner product
逐元素乘积之和。logit_j = <last_hidden, W[j, :]>:模型给 token j 的分数,就是隐藏状态与该 token 在 LM head 中那一行的对齐程度。

LM head:隐藏状态 → logits

第 9 章把我们留在残差流的顶端:经过最后的 RMSNorm 之后,每个 token 一个宽度为 d = 1024 的隐藏向量。第 11 章(采样)将从一个 logits 向量出发——词表中每个 token 一个实数分数。连接两者的步骤就是一次矩阵乘法,它有自己的名字:LM head

一行数学

h_last 是最后一个 token 的隐藏状态——当模型逐个 token 生成时(这一步叫解码,decode——后面的章节会讲),形状为 [1, 1024]W_lm 是 LM head 的权重矩阵——对 Qwen3.5-0.8B 来说形状为 [V, d] = [248,320, 1024]。转置后变成 [1024, 248,320];相乘得到形状为 [1, 248,320] 的输出——词表中每个 token 一个 logit。

有个问题值得停下来想一想:层堆叠为提示词里的每个 token 都产出了隐藏向量,为什么只有最后一个进入 LM head?因为每个位置预测的是它自己的下一个 token——而模型在生成时,唯一需要的预测就是最后那个 token 之后的那个。其他位置的预测并没有浪费:训练会一次性给它们全部打分,训练一章会展示这一点。

动画逐格演示这次矩阵乘法。扫描光束一次高亮一个输出列,旁边同时点亮产生它的矩阵列。每个输出条目都是一次独立的内积——这也是这个运算能在 GPU 上如此干净地并行化的原因。矩阵以转置形式绘制,所以每个 token 的指纹显示为一——光束下方的标签标出当前正在打分的是哪个 token 的列。

最后的矩阵乘法——隐藏状态 → logits

堆叠最顶端的一次矩阵-向量乘法:logits = last_hidden @ embed_tokens.weight.T。对 Qwen3.5-0.8B 来说,就是一个 [1, 1024] 的向量乘以一个 [1024, 248320] 的矩阵 → 得到 [1, 248320] 的输出,词表中每个 token 一个分数。扫描光束标出当前正在产出的输出列。

隐藏状态d = 1024@embed_tokens.weight.T[d=1024, V=248,320]第 0 列 = " the"第 j 列 = token j 的“指纹”(W_lm 的第 j 行)=logitsV = 248,320 个条目

每个输出条目都是一次内积:logit_j = sum_i (last_hidden[i] · W[i, j])。这个 [d, V] 矩阵的每一列——等价于未转置的 [V, d] 权重的某一行——就是一个词表 token 的“指纹”;当这枚指纹与隐藏状态指向大致相同的方向时,该 token 的 logit 就高。

示意图——12×36 的网格和柱高只是真实 [1024 × 248,320] 矩阵乘法的替身,并非模型的真实 logits。

每个输出单元格意味着什么

W_lm 的每一想成某个词表 token 学到的“指纹”。token j 的 logit 就是 h_last 与该行的内积:

所以,logit 高 ⇔ 隐藏状态与该 token 的行指向大致相同的方向。决定下一个 token 分布的是隐藏状态的几何结构,而 LM head 的各行就是模型用来读出这份几何结构的字典。

内积可以展开为 h · w = |h|·|w|·cos θ,其中 cos θ 是两个向量夹角的余弦——同向时为 +1,正交时为 0,反向时为 −1。当每个 token 的行长度相当时,logit 实际上就是隐藏状态与该行之间夹角的读数。切换下面的三种情形——对齐、正交、相反——观察各维度的乘积(它们的和就是 logit)与余弦一起变化。

logit 就是一次内积

词表中的每个 token 在输出矩阵(也就是共享的嵌入矩阵)里都拥有一行 w_t;它的 logit 就是这一行与最后隐藏状态 h 的点积,而 softmax(课程其他地方展示过)会把整个 logits 向量变成概率。这里三个候选向量的长度相同|w| = 1.5),所以唯一能改变 logit 的只有它与 h夹角

每个维度的贡献 h_d · w_d正贡献负贡献
Σ_d h_d·w_d = logit = +2.230——上面八根柱子加起来就是这一个分数。
cos θ = (h·w)/(|h||w|) = +1.000θ ≈ 0°(|h| ≈ 1.487, |w| = 1.5)

示意用的 8 维玩具向量(真实隐藏状态是 1024 维)——并非模型的实时输出。

这就是为什么你有时会在 top-K 顶部看到一些出乎你意料、却看起来还算合理的 token:它们的行恰好与隐藏状态对齐,即使模型最终未必会采样它们。点击 Run 之后,右侧的 top-K 面板会具体展示这一点。

权重共享:同一个矩阵,用两次

接下来是让人意外的部分。上面公式里的矩阵 W_lm 并不是一个单独学习的张量。对 Qwen3.5(以及大多数 7B 以下的现代解码器 LLM)来说,它就是第 3 章的嵌入矩阵——在堆叠底部把 token id 映射成向量的那个 [248,320, 1024] 浮点数网格,在顶部(转置后)被复用。

权重共享——同一个矩阵,用两次
第 1/4 步

Qwen3.5-0.8B(以及大多数现代解码器 LLM)设置了 tie_word_embeddings = true。这意味着输入端的嵌入矩阵和输出端的 LM head 在内存中就是同一个张量——同一个 [248,320 × 1024] 的浮点数网格,一次用于 id → vector,一次(转置后)用于 vector → vocab scores

24 层embed_tokens.weighttoken id → 向量(查表)lm_head.weight (= embed_tokens.weight)向量 → 词表分数共享——同一个张量"the"
第 1 步:嵌入查表——读取 embed_tokens.weight 的一行
参数节省:这个矩阵有 248,320 × 1024 ≈ 254.3M 个浮点数。权重共享省掉了 LM head 处的第二份拷贝——在一个 0.8B 参数的模型上少了约 254M 个参数。仅仅是复用这本字典,就省掉了接近模型三分之一的参数。

从概念上讲,权重共享是在说:把 token id 映射成输入表示的那本字典,同样把输出表示映射回词表分数。读和写共用同一套字母表。并非所有模型都做共享——大型 GPT 风格的模型有时为了一点质量提升而把两者分开——但对参数量低于十亿的模型来说,共享是标准做法。

第 3 章提到这省下了第二份拷贝——这一份拷贝,画出来就是这样。

共享(Qwen3.5-0.8B 实际采用)
嵌入表254,279,680
第二份拷贝(LM head)只买一份,用两次
模型总计852,985,920
不共享(假设)
嵌入表254,279,680
第二份拷贝(LM head)+254,279,680
嵌入小计508,559,360
模型总计1,107,265,600

这一张 254,279,680 个浮点数的表,占了整个模型的 29.8%。

规模决定利害。GPT-3 的这张表是 50,257 × 12,288 = 617,558,016(≈617.6M)——约是本模型这张表的 2.4 倍——而且 GPT-3 同样共享了嵌入。若让这么大的一张表不共享,就要再添一份拷贝、约 1,235,116,032(~1.24B) 个参数全花在查表上。词表越宽,共享省得越多,所以从 0.8B 到 175B 它都是标准做法。

这就是权重共享(weight tying),由配置中的 tie_word_embeddings = true 控制。它背后有个干净的直觉:模型读入 token 用的那套“字母表”,也应该是它写出 token 用的字母表。如果 embed_tokens.weight 的第 j 行是 "the" 作为隐藏向量的样子,那么每当残差流长得像这一行时,LM head 就应该给 "the" 打出高分——而 row_j · h 恰好就是这个检验。

实际收益:一个参数量不足 10 亿的模型靠不复制这个矩阵省下约 254M 参数(接近其总量的三分之一)。代价:在非常大的规模上有少量质量损失,这也是为什么几十亿参数以上的 GPT 风格模型有时会解开(untie)这个头。对 Qwen3.5-0.8B 来说,省下的参数占了上风。

右侧你将看到什么

该面板执行一次 inspector 调用:对提示词分词、运行一次前向传播、在最后一个位置捕获 top-K logits、把它们渲染成柱状图。每一行是 logit 最高的 12 个词表 token 之一;以主题色高亮的那一行是模型实际采样到的 token(此阶段为贪心——第 11 章会引入 temperature(温度)和 top-p)。

值得注意:面板对这些 logits 做了 softmax,所以柱子的高度读作概率(每个面板的柱子之和为 1)。但 LM head 输出的原始分数是 logits——可正可负的任意实数,其大小在不同提示词之间不可比。第 11 章将讲到 temperature 和 top-p 如何把这个 softmax 重塑成模型的最终选择。

工程要点
  • LM head 只是一次矩阵乘法:last_hidden @ embed_tokens.T → logits。词表中每个条目得到一个分数。
  • 权重矩阵的每一行是某个词表 token 学到的“指纹”;logit 最高的那个,就是与隐藏状态最对齐的那个。
  • Qwen3.5 把 embed_tokens.weight 与 lm_head.weight 绑定——同一个 ~254M 浮点数的张量在输入查表和输出投影两处使用。
动手练习

对 'The cat sat on the' 点击 Run,查看 top-K logits 面板。最高的柱子就是模型的选择。现在在 top-K 靠后的位置中找一个在你看来"不像合理续写"的 token(例如像 ' his' 这样的代词混进了名词的位置)。它出现在 top-K 里,说明 LM head 的几何结构是什么样的?

随堂测验
1. Qwen3.5-0.8B 的 LM head 矩阵的形状是什么?
2. Qwen3.5 配置中的 'tie_word_embeddings = true' 是什么意思?
3. 关于 LM head 各行的说法,哪一项是对的?

动手试试

互动演示加载中……