解释最后的隐藏状态如何变成一个 logits 向量,以及为什么对 Qwen3.5 来说,LM head 实际上就是被复用的嵌入矩阵。
第 9 章结束于一个隐藏状态;第 11 章开始于 logits。中间必须有一步把前者变成后者。
术语表 · 6 个术语
- LM head
- 模型最后的线性投影。把最后的隐藏状态乘以一个 [d, V] 矩阵,为词表中的每个 token 得到一个 logit。
- logit
- 词表中每个条目对应的一个无界实数分数——越高表示模型越偏好该 token。它们还不是概率;采样一章会用 softmax 把它们变成概率。
- weight tying
- 把 embed_tokens.weight 复用为 LM head,让输入查表和输出投影共享同一个张量(在小模型上能省下相当大比例的参数)。
- vocabulary size (V)
- 分词器能产出的不同 token 的数量。Qwen3.5 使用 V = 248,320。
- hidden dim (d)
- 残差流的宽度。Qwen3.5-0.8B 使用 d = 1024——每个 token 在每一层都携带 1024 个浮点数。
- inner product
- 逐元素乘积之和。logit_j = <last_hidden, W[j, :]>:模型给 token j 的分数,就是隐藏状态与该 token 在 LM head 中那一行的对齐程度。
LM head:隐藏状态 → logits
第 9 章把我们留在残差流的顶端:经过最后的 RMSNorm 之后,每个 token 一个宽度为 d = 1024 的隐藏向量。第 11 章(采样)将从一个 logits 向量出发——词表中每个 token 一个实数分数。连接两者的步骤就是一次矩阵乘法,它有自己的名字:LM head。
一行数学
h_last 是最后一个 token 的隐藏状态——当模型逐个 token 生成时(这一步叫解码,decode——后面的章节会讲),形状为 [1, 1024]。W_lm 是 LM head 的权重矩阵——对 Qwen3.5-0.8B 来说形状为 [V, d] = [248,320, 1024]。转置后变成 [1024, 248,320];相乘得到形状为 [1, 248,320] 的输出——词表中每个 token 一个 logit。
有个问题值得停下来想一想:层堆叠为提示词里的每个 token 都产出了隐藏向量,为什么只有最后一个进入 LM head?因为每个位置预测的是它自己的下一个 token——而模型在生成时,唯一需要的预测就是最后那个 token 之后的那个。其他位置的预测并没有浪费:训练会一次性给它们全部打分,训练一章会展示这一点。
动画逐格演示这次矩阵乘法。扫描光束一次高亮一个输出列,旁边同时点亮产生它的矩阵列。每个输出条目都是一次独立的内积——这也是这个运算能在 GPU 上如此干净地并行化的原因。矩阵以转置形式绘制,所以每个 token 的指纹显示为一列——光束下方的标签标出当前正在打分的是哪个 token 的列。
堆叠最顶端的一次矩阵-向量乘法:logits = last_hidden @ embed_tokens.weight.T。对 Qwen3.5-0.8B 来说,就是一个 [1, 1024] 的向量乘以一个 [1024, 248320] 的矩阵 → 得到 [1, 248320] 的输出,词表中每个 token 一个分数。扫描光束标出当前正在产出的输出列。
每个输出条目都是一次内积:logit_j = sum_i (last_hidden[i] · W[i, j])。这个 [d, V] 矩阵的每一列——等价于未转置的 [V, d] 权重的某一行——就是一个词表 token 的“指纹”;当这枚指纹与隐藏状态指向大致相同的方向时,该 token 的 logit 就高。
示意图——12×36 的网格和柱高只是真实 [1024 × 248,320] 矩阵乘法的替身,并非模型的真实 logits。
每个输出单元格意味着什么
把 W_lm 的每一行想成某个词表 token 学到的“指纹”。token j 的 logit 就是 h_last 与该行的内积:
所以,logit 高 ⇔ 隐藏状态与该 token 的行指向大致相同的方向。决定下一个 token 分布的是隐藏状态的几何结构,而 LM head 的各行就是模型用来读出这份几何结构的字典。
内积可以展开为 h · w = |h|·|w|·cos θ,其中 cos θ 是两个向量夹角的余弦——同向时为 +1,正交时为 0,反向时为 −1。当每个 token 的行长度相当时,logit 实际上就是隐藏状态与该行之间夹角的读数。切换下面的三种情形——对齐、正交、相反——观察各维度的乘积(它们的和就是 logit)与余弦一起变化。
词表中的每个 token 在输出矩阵(也就是共享的嵌入矩阵)里都拥有一行 w_t;它的 logit 就是这一行与最后隐藏状态 h 的点积,而 softmax(课程其他地方展示过)会把整个 logits 向量变成概率。这里三个候选向量的长度相同(|w| = 1.5),所以唯一能改变 logit 的只有它与 h 的夹角。
示意用的 8 维玩具向量(真实隐藏状态是 1024 维)——并非模型的实时输出。
这就是为什么你有时会在 top-K 顶部看到一些出乎你意料、却看起来还算合理的 token:它们的行恰好与隐藏状态对齐,即使模型最终未必会采样它们。点击 Run 之后,右侧的 top-K 面板会具体展示这一点。
权重共享:同一个矩阵,用两次
接下来是让人意外的部分。上面公式里的矩阵 W_lm 并不是一个单独学习的张量。对 Qwen3.5(以及大多数 7B 以下的现代解码器 LLM)来说,它就是第 3 章的嵌入矩阵——在堆叠底部把 token id 映射成向量的那个 [248,320, 1024] 浮点数网格,在顶部(转置后)被复用。
Qwen3.5-0.8B(以及大多数现代解码器 LLM)设置了 tie_word_embeddings = true。这意味着输入端的嵌入矩阵和输出端的 LM head 在内存中就是同一个张量——同一个 [248,320 × 1024] 的浮点数网格,一次用于 id → vector,一次(转置后)用于 vector → vocab scores。
从概念上讲,权重共享是在说:把 token id 映射成输入表示的那本字典,同样把输出表示映射回词表分数。读和写共用同一套字母表。并非所有模型都做共享——大型 GPT 风格的模型有时为了一点质量提升而把两者分开——但对参数量低于十亿的模型来说,共享是标准做法。
第 3 章提到这省下了第二份拷贝——这一份拷贝,画出来就是这样。
| 嵌入表 | 254,279,680 |
| 第二份拷贝(LM head) | 只买一份,用两次 |
| 模型总计 | 852,985,920 |
| 嵌入表 | 254,279,680 |
| 第二份拷贝(LM head) | +254,279,680 |
| 嵌入小计 | 508,559,360 |
| 模型总计 | 1,107,265,600 |
这一张 254,279,680 个浮点数的表,占了整个模型的 29.8%。
规模决定利害。GPT-3 的这张表是 50,257 × 12,288 = 617,558,016(≈617.6M)——约是本模型这张表的 2.4 倍——而且 GPT-3 同样共享了嵌入。若让这么大的一张表不共享,就要再添一份拷贝、约 1,235,116,032(~1.24B) 个参数全花在查表上。词表越宽,共享省得越多,所以从 0.8B 到 175B 它都是标准做法。
这就是权重共享(weight tying),由配置中的 tie_word_embeddings = true 控制。它背后有个干净的直觉:模型读入 token 用的那套“字母表”,也应该是它写出 token 用的字母表。如果 embed_tokens.weight 的第 j 行是 "the" 作为隐藏向量的样子,那么每当残差流长得像这一行时,LM head 就应该给 "the" 打出高分——而 row_j · h 恰好就是这个检验。
实际收益:一个参数量不足 10 亿的模型靠不复制这个矩阵省下约 254M 参数(接近其总量的三分之一)。代价:在非常大的规模上有少量质量损失,这也是为什么几十亿参数以上的 GPT 风格模型有时会解开(untie)这个头。对 Qwen3.5-0.8B 来说,省下的参数占了上风。
右侧你将看到什么
该面板执行一次 inspector 调用:对提示词分词、运行一次前向传播、在最后一个位置捕获 top-K logits、把它们渲染成柱状图。每一行是 logit 最高的 12 个词表 token 之一;以主题色高亮的那一行是模型实际采样到的 token(此阶段为贪心——第 11 章会引入 temperature(温度)和 top-p)。
值得注意:面板对这些 logits 做了 softmax,所以柱子的高度读作概率(每个面板的柱子之和为 1)。但 LM head 输出的原始分数是 logits——可正可负的任意实数,其大小在不同提示词之间不可比。第 11 章将讲到 temperature 和 top-p 如何把这个 softmax 重塑成模型的最终选择。
- LM head 只是一次矩阵乘法:last_hidden @ embed_tokens.T → logits。词表中每个条目得到一个分数。
- 权重矩阵的每一行是某个词表 token 学到的“指纹”;logit 最高的那个,就是与隐藏状态最对齐的那个。
- Qwen3.5 把 embed_tokens.weight 与 lm_head.weight 绑定——同一个 ~254M 浮点数的张量在输入查表和输出投影两处使用。
对 'The cat sat on the' 点击 Run,查看 top-K logits 面板。最高的柱子就是模型的选择。现在在 top-K 靠后的位置中找一个在你看来"不像合理续写"的 token(例如像 ' his' 这样的代词混进了名词的位置)。它出现在 top-K 里,说明 LM head 的几何结构是什么样的?