Flash Attention
深入阅读 · 第 4 章 自注意力——同样的结果,却没有 N×N 矩阵。
Flash Attention 计算的仍然是同一个 softmax(QKᵀ/√d)·V——精确的注意力运算本身,而不是稀疏注意力、低秩注意力那样的近似——但从不显式构造 N×N 矩阵。它把 Q、K、V 切成小块(tile),对每个 query 块流式扫过各个 key 块,同时为每个 query 行维护一份很小的滚动状态:两个标量——迄今的最大分数 m 和指数的滚动和 ℓ——外加一个滚动的输出向量 o,每个输出维度一个分量。
这里的“精确”说的是算法,不是比特:因为它按不同顺序累加各个 key 块,输出与朴素路径只在浮点舍入意义上一致,并非逐比特相同。这是常规的浮点注意事项——改变求和顺序会改变最后几位——而不是对注意力本身的近似。
诀窍在于在线 softmax(online softmax):当后来的块抬高了滚动最大值时,把已经累积的量重新缩放,归一化就始终严格正确。下面的更新式里, 和 是这一块自己针对更新后的最大值 计算的和与加权值——因此只要滚动总量乘上 ,新块就能直接并入,不需要任何额外因子。
这些小块住在 GPU 的快速片上内存(SRAM)里,巨大的矩阵从头到尾不去慢速内存。同样的答案,一小部分的内存流量——这种技术叫“IO 感知”(IO-aware)。
一段简史
Flash Attention 不是凭空出现的——它是“计算 softmax 而不必把整行同时放进内存”这条思想链的最终回报。逐步看这条脉络:
让它感知 IO:把 Q/K/V 切块装进片上 SRAM,把整个运算融合成一个内核,并在反向传播时重算——N×N 矩阵从不写入慢速 HBM。精确,而非近似(另有块稀疏变体)。
同样的想法很快传到了原始代码之外:xformers 的 memory_efficient_attention、OpenAI 的 Triton flash 内核、NVIDIA 的 cuDNN 融合注意力,以及 PyTorch 的 scaled_dot_product_attention(它会自动分派到最快的实现)。而 PagedAttention(来自 vLLM)是表亲,不是某个版本——它是围绕分页 KV 缓存布局构建的注意力内核,解决的是内存高效的推理服务,与 FlashAttention 的稠密 SRAM 内分块是不同的问题。你浏览器里的 Qwen 在 prefill 阶段运行的融合内核是 FlashAttention-2 风格的分块内核——v3 的技巧专属于 NVIDIA 的 Hopper GPU 和 FP8,在这里用不上。
每个版本到底是怎么干的
上面的时间线是地图;下面是地形。每一步都修掉了上一步的瓶颈,所以按顺序读最容易理解——一次一个机制。
地基:流式 softmax
一切始于一个技巧,所以我们放慢脚步把它吃透。先看 softmax 本身,用真实数字。softmax 通过对每个分数取指数、再除以总和,把一行分数变成和为 1 的权重。拿这一行 [2, 1, 3] 来说:指数分别是 、、;它们的和是 30.19;逐个除一下,权重就是 0.245, 0.090, 0.665。最大的分数赢走大部分权重——这就是它的全部工作。
一个实践层面的麻烦: 会爆炸。 已经超出 32 位浮点数的表示范围(上限约为 ),而原始注意力分数可能很大。标准解法是在取指数之前给每个分数减去该行的最大值——最大的那一项变成 ,什么都不会溢出。除以总和时这个平移会被抵消,所以权重完全相同。但注意这个修复的代价:你必须先拿到整行的最大值,才能对任何一个数取指数。于是要扫两遍——一遍找最大值,一遍累加指数。
在线 softmax(2018)把它压缩成一遍:维护一个滚动最大值和一个滚动和;当后来的值超过迄今的最大值时,给已经累出来的和打个补丁——乘以 。为什么这恰好是正确的补丁?你累加过的每一项都形如 ,而
——一次乘法就把每个旧项重新表达到新最大值之下,仿佛你早就知道它一样。具体地,在下面的动画里,第一块 [2, 1, 3] 结束时最大值是 3;接着第二块里出现了 5,于是滚动和与滚动输出都先乘上 ,新的项再并入。2021 年的 “memory-efficient attention” 论文注意到,这个流程可以按 key 和 value 的块来跑——于是你永远不需要把整行同时放进内存。
从空状态开始:滚动最大值 m = −∞,滚动和 ℓ = 0。我们将从左到右扫过整行,每次一个 3 元素的块——只扫一遍。
这个单遍重缩放,正是后来 FlashAttention 完全在 GPU 微小片上内存(SRAM)里运行的种子——因此 N×N 分数矩阵永远不必写出到慢速主内存。
FlashAttention——把整件事分块塞进 SRAM
GPU 有两种内存:HBM,遥远而慢的大主存;SRAM,紧贴计算单元的微型暂存区,大约快十倍。朴素做法在 HBM 里构造完整的 N×N 分数矩阵并来回拖拽——就是上一个子章节的“显存带宽墙”。FlashAttention(2022)让分数完全不进 HBM。它把 Q、K、V 切成小块;对每个 query 块流式扫过 key/value 块,在 SRAM 内部算出每个小分数块,直接并入滚动 softmax,然后丢掉。只有算完的输出行才写回。巨大的分数矩阵在慢速内存里从未存在过——答案一模一样,流量只剩一个零头。
块要多大?刚好让所有在途数据都摆得上桌面。以常见的 64 宽的头为例,一个 128 行的 query 块是 128 × 64 个数,在 fp32 下 ≈ 32 KB。同样形状的一个 key 块和一个 value 块再各占 32 KB,小小的 128 × 128 分数块是 64 KB——滚动输出 o 自己也是一个 128 × 64 的块,又 32 KB(只有 m 和 ℓ 是每行一个标量)。加起来:≈ 190 KB,正好顶到计算单元旁边那 ~200 KB SRAM 的边缘——这恰恰是块取这个尺寸、不能更大的原因。生产内核还会用 16 位的块和更窄的 key 块进一步压榨,但决定块尺寸的,正是“SRAM 能装下什么”这笔预算账。
一个诚实的脚注:v1 确立的不变量是 N×N 分数矩阵永不触碰 HBM。这里展示的循环顺序——query 在外层、key/value 在内层流式扫过、每个输出行只写一次——是 v2 才定下来的更干净的调度。最初的 v1 内核实际上反着循环(key/value 在外、query 在内),一路上反复回写输出。下图画的是现代顺序,因为那才是今天的内核真正运行的方式。
外层循环,块 0。把 query 块 Q₀ 从 HBM 加载进快速的片上 SRAM。它会在每个 K/V 块上被复用。
图中调度:query 在外层循环,key/value 在内层流式扫过——现代(FA2 风格)顺序。v1 的贡献是两种顺序共守的不变量:N×N 分数矩阵永不触及 HBM。
这到底省下多少流量?在 4,096 token 时,朴素做法让分数矩阵在慢速 HBM 总线上往返四趟——写出原始分数、读回做 softmax、写出概率、再读回给 V 加权。对这个模型的 8 个头、bf16 而言,每个全量注意力层就是 ≈ 1.07 GB 的总线流量——而这些分数用一次就被扔掉。分块内核移动的这类字节恰好为零。拖动序列长度,看差距:
条形按平方根刻度绘制,让小值保持可见;数字本身是精确的。
诚实的脚注:两种方案里 Q、K、V 和输出仍要过总线——在这些尺寸下是几十 MB(此模型在 4K token 时 ≈42 MB),两边相近,且只随 N 线性增长——flash 还会把 K/V 块重读几次。重点在于:无界的 N² 分数项彻底消失了。
v1 还有容易被忽略的后一半:训练。反向传播(计算梯度)通常需要再次用到注意力权重——而把 N×N 矩阵存起来留给后面用,会让整件事前功尽弃。FlashAttention 的答案是重计算(recomputation):只保留输出和每行两个小统计量(还是那个最大值 m 和和 ℓ),等反向传播需要时,再在 SRAM 里由 Q 和 K 重新推导出每个分数块。这是有意用更多算术去换更少的内存流量——而且是赚的,因为在现代 GPU 上数学便宜、去 HBM 的往返不便宜。这笔反直觉的交易正是整个“IO 感知”思想的核心。
30 秒逛完 GPU
接下来的两个版本讲的都是把机器填满,所以你需要一张机器的图。GPU 不是一颗处理器——它更像一个办公园区。NVIDIA A100 有 108 个 SM(streaming multiprocessor,流式多处理器):彼此独立的小处理器,各自带着自己的 SRAM 暂存区(~200 KB)和自己的 Tensor Core——专用的矩阵乘法单元。工作以线程块(thread block)的形式到达——每个块被指派给一个 SM;块内部的线程以 32 个一捆的 warp 为单位锁步执行。两条推论与本文相关:没有分到块的 SM 什么都不做;而存在多少个块,由内核说了算。逐级缩放看看:
每个方块是一个 SM——流式多处理器,A100 上 108 个彼此独立的小处理器之一。工作以线程块的形式交给 SM;没有分到块的 SM 什么都不做。FlashAttention-1 只启动 batch × heads = 1 × 8 = 8 个块,所以 108 个 SM 里只有 8 个有活可干。
FlashAttention-2——让每个核心都忙起来
FlashAttention-2(2023)一点数学都没改——它改的是工作如何铺满整块芯片。带着办公园区的画面算一笔账。原始内核大致按(batch × 注意力头)各启动一个线程块:一条提示词乘 8 个头是 8 个块——在 108 个 SM 的 A100 上,约 7% 的芯片在干活,100 个 SM 黑着灯。v2 的头号修复是把 query 序列也切成块:4,096 token 的提示词按 128 行一块是 32 个 query 块,于是 8 头 × 32 块 = 256 个块——绰绰有余,点亮每一个 SM。
它还重新划分了块内部的工作:不再让每个 warp 各算每行的一个切片、再经共享内存合并结果,而是让每个 warp 整体拥有一段 query 行——没有任何跨 warp 的对账。它还削减了非矩阵乘法的算术,因为普通运算比 Tensor Core 慢得多:滚动输出保持未归一化、只在最后除一次 ℓ,而不是每一步都归一化(按 做的滚动最大值重缩放每一块仍然发生——只是把对 ℓ 的除法推迟了);对因果掩码而言(未来 token 反正不能被注意),注定被完全掩掉的分数块干脆不算——单这一项实测约 1.7–1.8×。合起来:比 v1 快约一倍,注意力内核最高达到 A100 理论 FP16 峰值的 73%。(这种占用率思路,正是下文解码故事的驱动力。)
FlashAttention-2 还把 query 行也切成 4 个小块,于是它启动 8×4 = 32 个线程块,铺满整块 GPU。在每个块内部它采用 split-Q:每个 warp 从头到尾拥有自己的一段行,warp 之间从不经由共享内存合并部分结果。
FlashAttention-3——重叠与低精度
FlashAttention-3(2024)为 NVIDIA 的 Hopper GPU(H100)调校,它的起点是硬件里一处悬殊得离谱的失衡。H100 的 Tensor Core 提供约 989 TFLOPS 的 fp16 矩阵乘法——而计算指数(每个 softmax 都要用)的那些小单元只有约 3.9 TFLOPS。注意力来回切换的这两种数学之间,有 256× 的速度差。如果严格按 matmul → softmax → matmul 的顺序轮流跑,芯片上最昂贵的硅就要花大量时间,等便宜的那部分干完。
FlashAttention-3 的答案是不再轮流。Hopper 允许 warp 分工(specialize):producer warp 只负责指挥 TMA(一个专用拷贝引擎)从 HBM 取下一批块,consumer warpgroup 只负责算。然后两个 consumer warpgroup 打乒乓(pingpong):一个在指数单元上跑自己的 softmax 时,另一个的矩阵乘法持续喂饱 Tensor Core——然后互换。每个块其实是两次矩阵乘法——先是分数 QKᵀ,然后是 value 混合 P·V——softmax 夹在中间,被重叠起来的正是它们:
FlashAttention-3 把两种工作重叠起来:当 Tensor Core 还在啃第 j+1 块的矩阵乘法时,指数单元已经在跑第 j 块的 softmax。 Tensor Core 通道几乎排满——只有 16% 闲置——整个调度更早完工。(每条“matmul”柱是简化的:真实的一块要跑两次 GEMM,先 QKᵀ 再 P·V,softmax 夹在中间。)
加速来自两根杠杆:用 matmul 与 softmax 的重叠让 Tensor Core 保持忙碌(warp 分工的“乒乓(pingpong)”调度),以及更便宜的 FP8 数学。这里的块尺寸只是示意,并非实测。
第二根杠杆是 FP8——它值得放慢看,因为“8 位浮点”听上去人畜无害,直到你看清代价。一个 fp8 数(E4M3 格式)只有 256 种位模式,最大值 448;fp8 矩阵乘法约比 fp16 快一倍。问题在于:可表示的值这么少,一切都押在把你的数据映射上去的缩放因子(scale factor)上。真实的注意力输入偶尔会有离群值(outlier)——一片小数中的一个巨大值。给整个张量挑一个缩放,那个离群值就把网格拉得太开,所有小值坍缩到寥寥几级上,它们的信息就没了。FlashAttention-3 两次拯救 fp8:块量化(block quantization)给每个小块自己的缩放,一个离群值只会糟蹋自己那个块——不相干处理(incoherent processing)则用一个固定的随机旋转去乘 Q 和 K,在舍入之前把离群值的能量抹匀到每个维度上。在任何舍入之前,这个旋转在数学上是免费的——两边同转,点积不变。fp8 舍入之后没有什么是严格保持的;旋转的职责是让舍入变得便宜,它省下的误差正是下面这个小部件度量的东西:
离群值赢了。一个缩放必须容下 52,于是 scale = 448/52 ≈ 8.6——每一行都拿到同一张粗网格(即那些刻度)。15 个小值只能吸附到寥寥几级上,两个近零值几乎被整个舍掉(误差最高 100%)。15 个小值的平均往返误差:11.5%。
在任何舍入发生之前,旋转是完全免费的:用同一个正交矩阵 R 旋转每个 q 和 k 向量,它会在矩阵乘法内部抵消——(QRᵀ)(KRᵀ)ᵀ = Q(RᵀR)Kᵀ = QKᵀ,因为 RᵀR = I。经过 FP8 舍入之后,一切都不再精确——旋转的职责是让那次舍入变得更便宜,也就是上面展示的误差下降。所有百分比都是这 16 个固定值经过 E4M3 量化器(1 个符号位 + 4 个指数位 + 3 个尾数位,最大 448)的真实往返;模式 3 的误差是在旋转后的值上测量的,因为真正被存储和相乘的就是它们。
故事还在继续:FlashAttention-4(2026)把同一个不变量带到 NVIDIA 的下一代——Blackwell B200——那里矩阵乘法吞吐又差不多翻倍,而芯片的其余部分没有——最高达到 1,613 TFLOPS(71% 利用率)。我们的深入阅读停在 v3,因为 v4 的技巧是 Blackwell 专属的,但请注意贯穿每个版本的模式:算法的不变量从未改变——N×N 矩阵永不触碰慢速内存——每个版本只是把这个不变量重新调校到最新硅片的瓶颈上。
你浏览器里的 Qwen 用它吗?
用了一部分——而诚实的答案才是有意思的那个。运行这个模型的 WebGPU 后端确实有一个带上面在线 softmax 的融合分块 flash-attention 内核,它在 6 个全量注意力层的 prefill 阶段运行。但你在演示里看到的逐 token 解码故意不用它:在 0.8B、8 个 query 头的情况下,融合内核只会启动 8 个 GPU workgroup,让芯片大部分闲置,所以一个占用率闸门把解码路由回普通的三步路径(matmul → softmax → matmul),它能撒出多得多的并行工作,在这里实测快约 90%。而且 24 层中只有 6 层运行 softmax 注意力——其余 18 层是 GatedDeltaNet。所以:prefill 用 flash,解码用普通路径,皆是有意为之。
在两者之间切换,看分派器把调用路由到真正运行的那个内核:
解码一次只生成一个 token,所以 Tq = 1。融合内核只会启动 B·H = 8 个 workgroup,让芯片大部分闲置——于是闸门为 true,把解码路由到分解路径;它能撒出多得多的并行工作,在这里实测快约 90%。
图中为默认路由。调试开关(?sdpa_fallback=1)会在此闸门之前,把每一次调用——prefill 与解码——强制走分解路径。