描述 SwiGLU 门控 MLP 如何变换单个 token 的隐藏状态,以及它与残差流之间的关系。
注意力在 token 之间混合信息,却做不了逐 token 的特征计算;这部分工作就发生在 MLP 里。
术语表 · 6 个术语
- MLP block
- Transformer 层内逐 token 的前馈子模块。在每个位置上运行同一个小网络。
- SwiGLU
- 门控 MLP 变体:silu(gate_proj(x)) ⊙ up_proj(x),再投影回去。⊙ 表示逐元素相乘。
- SiLU
- 激活函数 z * sigmoid(z)。平滑,z 为较大正数时接近线性,并抑制负值——起到软门控的作用。
- intermediate_dim
- MLP 内部加宽的草稿空间。通常约为 hidden_dim 的 3-4 倍。Qwen3.5-0.8B 为 3584(hidden 为 1024)。
- residual stream
- 贯穿模型全部深度、未归一化的隐藏状态;每个模块从中读取,再把结果加回去。
- residual connection
- 即 x_out = x_in + sub_block(norm(x_in)) 中的 x_in + 部分。让梯度可以跳过子模块,保持健康。
MLP 模块:逐 token 的前馈计算
先说名字。MLP 是多层感知机(multi-layer perceptron)的缩写——最朴素的神经网络:一次矩阵乘法,接一个压缩函数(即“激活函数”),再接一次矩阵乘法。“前馈”(feed-forward)是同一个东西的另一个名字——数字一路直行,没有回环。
注意力是 Transformer 中在 token 之间混合信息的部分:通过 softmax(QKᵀ)·V,每个位置都能读取其他所有位置。MLP 模块恰恰相反——它对每个 token 各自独立处理,在每个位置的隐藏向量上运行同一个小神经网络,并把结果写回该位置。跨 token 的混合发生在注意力里;逐 token 的计算发生在 MLP 里。一个 Transformer 层就在两者之间交替。
先 pre-MLP 归一化,再门控 MLP,最后残差相加
与它的注意力同胞一样,Qwen3.5 的 MLP 子模块也包在 pre-norm + 残差的模式里:
残差连接就是其中的 x_in + 部分。它的存在有两个理由。其一是梯度:反向传播可以沿恒等路径直接穿过,即使 24 层的堆叠也能保持可训练。其二是语义:把残差流想象成一条贯穿模型全部深度的高速公路(下一章——完整的 Transformer 块——会画出残差流)。每个模块从公路上读取,计算一个小的贡献,再加回去。模型的预测是每个模块贡献的累积——而不是最后一个模块单独的输出。
看公式之前还有一件事:为什么需要压缩函数?因为把矩阵乘法直接叠起来、中间什么都不放,等于白叠——A·(B·x) 就是 (A·B)·x,两个线性映射坍缩成一个。夹在中间的压缩函数正是阻止坍缩、让模块能算出真正新东西的关键。
“门控 MLP”到底指什么
Qwen3.5(与 Llama、Mistral 以及大多数现代开源 LLM 一样)使用 SwiGLU 风格的门控 MLP。每层有三个线性投影:
gate_proj:hidden_dim → intermediate_dim。其输出会经过 SiLU 激活函数(也叫 Swish):silu(z) = z · sigmoid(z)。SiLU 平滑,z为较大正数时近乎线性,但会抑制负值——一道软门。up_proj:hidden_dim → intermediate_dim。这是“取值”路径——真正被传递的特征。down_proj:intermediate_dim → hidden_dim。把逐元素乘积投影回残差流的宽度。
仔细读这个表达式。⊙ 是逐元素相乘,不是矩阵乘法(⊙ 表示逐元素相乘——对应位置的元素两两相乘):silu(gate_proj(x)) 中 intermediate_dim 个特征里的每一个,都与 up_proj(x) 中对应的特征相乘。门控投影决定每个特征通过多少;up 投影提供取值。两者按特征纠缠在一起后,再由 down_proj 压回 hidden_dim。
亲手走一个特征。假设门控路径产生 z = 2,up 路径产生 0.5。那么 silu(2) = 2 / (1 + e⁻²) ≈ 1.76,于是通过的特征是 1.76 × 0.5 ≈ 0.88。如果门控是很大的负数,同样的取值就会被压向零。
中间维度就是模型的“草稿空间”——这里约为隐藏维度的 3.5 倍(Qwen3.5-0.8B 在 1024 维隐藏状态上使用 3584;3–4 倍在各模型间都很常见)。这三个矩阵是模型中最大的参数块之一:3 个矩阵 × 每个 1024 × 3584 个数 × 24 层 = 264,241,152 ≈ 264M——约占 ~0.8B 总参数的三分之一,规模与嵌入表大致相当。
上面的矩阵乘法视角告诉你维度。神经元视角告诉你拓扑:两条并行的加宽投影(门控和取值)、一次逐元素相乘、再一条收窄的投影回来。图中宽度只是象征——Qwen3.5-0.8B 用的是 1024 → 3584 → 1024。
SiLU 平滑,并在 z = 0 附近轻微下探到负值(最低 ≈ −0.278,位于 z ≈ −1.278)——一道软门——因此小的负输入仍有梯度流动,不像 ReLU 那样硬归零。Qwen3.5-0.8B 在它的 SwiGLU MLP 里用的就是 SiLU。
示意——SiLU 与 ReLU 按精确公式绘制;Qwen3.5-0.8B 的 MLP 使用 SiLU。不是模型的实时输出。
为什么要门控——以及为什么它并没有变大
你可能会问为什么要三个矩阵。经典的前馈模块——原版 Transformer 里的那个——只用两个:一个升维投影,一个固定的非线性(ReLU 或 GELU),再一个降维投影。SwiGLU 保留了升维和降维矩阵,但增加了第三个——gate——并把固定的非线性换成一个可学习的、乘性的非线性。门控不再对每个特征施加同一个阈值,而是让网络自己决定——按 token、按特征——每个升维后的取值能存活多少。这种依赖输入的门控严格地比固定激活函数更具表达力,实践中也确实能训到更低的损失。
自然的反驳是:第三个矩阵不是要多花 50% 的参数吗?在经典的 4× 中间维度下确实如此——三个 4× 宽的矩阵合计 12 h²,而朴素模块是 8 h²。常规的解决办法(Llama、Mistral)是把中间维度缩到原来的约 ⅔(≈ 8⁄3 · hidden),于是三个更窄的矩阵正好落回 8 h²——也就是下图所示的收支平衡点。切换开关,看长条在总长度不变的情况下重新分配比例。(Qwen3.5 没有缩到那么窄——见图下方的说明。)
长条展示的是理想化的收支平衡点:三个矩阵收窄到 8⁄3 · hidden,总重与朴素 FFN 两个 4× 宽的矩阵相同,都是 8 h²——gate 就是这样被“免费”加进来的。这正是 Llama 和 Mistral 采用的取舍。真实 checkpoint 往往会宽一点:Qwen3.5-0.8B 用了 3584 的中间维度(是其 1024 hidden 的 3.5×),三个矩阵合计 ≈ 10.5 h²——比朴素模块多约 30%。门控很便宜,但并非字面意义上的免费。
管线讲完了——下面看看这些中间特征训练完成后究竟在做什么:
3584 个中间特征中的每一个都像一个小小的检测器:它的门只对匹配自己模式的 token 打开,对其余一切保持关闭。它一旦激活,它在 down_proj 中的那一片就把该特征的签名加进这个 token 的残差流。颜色越深 = 激活越强。
示意——真实特征更杂乱,是被发现的,不是被设计的。没有人把某个神经元指派给“城市名”;训练找到的是任何能降低损失的检测器,很多神经元会同时在几件不相干的事情上激活,干净的“一个神经元一个概念”是例外,不是常态。
MLP 贡献的是一个小修正
令人意外的部分来了:MLP 的原始输出比它写入的残差流小得多。由于每个模块都是把输出加进残差流——从不覆写——这些贡献逐层累积,所以即使每次写入都很小,残差流的逐 token L2 仍会增长(L2 范数就是向量的长度——各元素平方和的平方根)。每个 MLP 的输出只是叠在上面的一个小增量——是修正,不是替换。看下面的图表:MLP 输出的逐 token L2 只是该层完整输出 L2 的一小部分。残差流中的大部分幅值来自输入,而不是这一层的 MLP。
这就是被量化了的“高速公路”直觉。全层长条图显示,各层 MLP 贡献的大小随深度变化:有的层贡献多些,有的少些,但没有任何一层占据主导。预测从许多次小写入的总和中涌现——这恰恰是让深层 Transformer 可学习的设计选择。
关于 Qwen3.5 混合堆叠的一个统一要点:每一层——无论全量注意力还是线性注意力——都带有同一个 MLP 模块。两种层只在 token 混合的那一半有差别;其后的 SwiGLU MLP 处处相同。
- MLP 对每个 token 各自独立处理——MLP 模块内部不存在跨 token 的信息流动。
- gate_proj/up_proj/down_proj 是模型中最大的参数块之一——约占 Qwen3.5-0.8B 的三分之一(≈264M),规模与嵌入表大致相当。
- 每个 MLP 只向残差流写入一个小的修正量;预测来自许多次小写入的总和,而不是最后一个模块的输出。
自动运行结束后,查看 "MLP contribution across all layers" 长条图。点击长条最大的那一层,再点击非空长条中最小的某一层。两者的 "MLP / layer-output ratio" 百分比相比如何?