第 4 章 · 自注意力Flash Attention

Flash Attention

深入阅读 · 第 4 章 自注意力——同样的结果,却没有 N×N 矩阵。

Flash Attention 计算的仍然是同一个 softmax(QKᵀ/√d)·V——精确的注意力运算本身,而不是稀疏注意力、低秩注意力那样的近似——但从不显式构造 N×N 矩阵。它把 QKV 切成小块(tile),对每个 query 块流式扫过各个 key 块,同时为每个 query 行维护一份很小的滚动状态:两个标量——迄今的最大分数 m 和指数的滚动和 ——外加一个滚动的输出向量 o,每个输出维度一个分量。

这里的“精确”说的是算法,不是比特:因为它按不同顺序累加各个 key 块,输出与朴素路径只在浮点舍入意义上一致,并非逐比特相同。这是常规的浮点注意事项——改变求和顺序会改变最后几位——而不是对注意力本身的近似。

诀窍在于在线 softmax(online softmax):当后来的块抬高了滚动最大值时,把已经累积的量重新缩放,归一化就始终严格正确。下面的更新式里,这一块自己针对更新后的最大值 计算的和与加权值——因此只要滚动总量乘上 ,新块就能直接并入,不需要任何额外因子。

这些小块住在 GPU 的快速片上内存(SRAM)里,巨大的矩阵从头到尾不去慢速内存。同样的答案,一小部分的内存流量——这种技术叫“IO 感知”(IO-aware)。

一段简史

Flash Attention 不是凭空出现的——它是“计算 softmax 而不必把整行同时放进内存”这条思想链的最终回报。逐步看这条脉络:

Flash Attention 家族谱系
虚线 = 先驱想法 · 实线 = 带编号的 FlashAttention 正式发布
FlashAttentionv1
2022 年 5 月 · NeurIPS 2022
Dao, Fu, Ermon, Rudra, Ré

让它感知 IO:把 Q/K/V 切块装进片上 SRAM,把整个运算融合成一个内核,并在反向传播时重算——N×N 矩阵从不写入慢速 HBM。精确,而非近似(另有块稀疏变体)。

影响BERT-large 快 15%(对比 MLPerf 1.1 纪录),GPT-2 快 3×;第一个达到 16K token 上下文(Path-X)的精确注意力。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness · arXiv 2205.14135

同样的想法很快传到了原始代码之外:xformers 的 memory_efficient_attention、OpenAI 的 Triton flash 内核、NVIDIA 的 cuDNN 融合注意力,以及 PyTorch 的 scaled_dot_product_attention(它会自动分派到最快的实现)。而 PagedAttention(来自 vLLM)是表亲,不是某个版本——它是围绕分页 KV 缓存布局构建的注意力内核,解决的是内存高效的推理服务,与 FlashAttention 的稠密 SRAM 内分块是不同的问题。你浏览器里的 Qwen 在 prefill 阶段运行的融合内核是 FlashAttention-2 风格的分块内核——v3 的技巧专属于 NVIDIA 的 Hopper GPU 和 FP8,在这里用不上。

每个版本到底是怎么干的

上面的时间线是地图;下面是地形。每一步都修掉了上一步的瓶颈,所以按顺序读最容易理解——一次一个机制。

地基:流式 softmax

一切始于一个技巧,所以我们放慢脚步把它吃透。先看 softmax 本身,用真实数字。softmax 通过对每个分数取指数、再除以总和,把一行分数变成和为 1 的权重。拿这一行 [2, 1, 3] 来说:指数分别是 ;它们的和是 30.19;逐个除一下,权重就是 0.245, 0.090, 0.665。最大的分数赢走大部分权重——这就是它的全部工作。

一个实践层面的麻烦: 会爆炸。 已经超出 32 位浮点数的表示范围(上限约为 ),而原始注意力分数可能很大。标准解法是在取指数之前给每个分数减去该行的最大值——最大的那一项变成 ,什么都不会溢出。除以总和时这个平移会被抵消,所以权重完全相同。但注意这个修复的代价:你必须先拿到整行的最大值,才能对任何一个数取指数。于是要扫两遍——一遍找最大值,一遍累加指数。

在线 softmax(2018)把它压缩成一遍:维护一个滚动最大值和一个滚动和;当后来的值超过迄今的最大值时,给已经累出来的和打个补丁——乘以 。为什么这恰好是正确的补丁?你累加过的每一项都形如 ,而

——一次乘法就把每个旧项重新表达到新最大值之下,仿佛你早就知道它一样。具体地,在下面的动画里,第一块 [2, 1, 3] 结束时最大值是 3;接着第二块里出现了 5,于是滚动和与滚动输出都先乘上 ,新的项再并入。2021 年的 “memory-efficient attention” 论文注意到,这个流程可以按 key 和 value 的来跑——于是你永远不需要把整行同时放进内存。

流式 softmax——一遍扫完整行
朴素 softmax:2 遍扫过整行流式:1 遍
213051421362块 1块 2块 3块 4一行注意力 · 分数滚动最大值m = −∞escore − m 的滚动和ℓ = 0.00输出累加器  o = Σ escore − m·vo = 0.00当前答案  o / ℓ =

从空状态开始:滚动最大值 m = −∞,滚动和 ℓ = 0。我们将从左到右扫过整行,每次一个 3 元素的块——只扫一遍

0 / 4 · 尚未开始
滚动 m = −∞ · 滚动 ℓ = 0.00

这个单遍重缩放,正是后来 FlashAttention 完全在 GPU 微小片上内存(SRAM)里运行的种子——因此 N×N 分数矩阵永远不必写出到慢速主内存。

FlashAttention——把整件事分块塞进 SRAM

GPU 有两种内存:HBM,遥远而慢的大主存;SRAM,紧贴计算单元的微型暂存区,大约快十倍。朴素做法在 HBM 里构造完整的 N×N 分数矩阵并来回拖拽——就是上一个子章节的“显存带宽墙”。FlashAttention(2022)让分数完全不进 HBM。它把 QKV 切成小块;对每个 query 块流式扫过 key/value 块,在 SRAM 内部算出每个小分数块,直接并入滚动 softmax,然后丢掉。只有算完的输出行才写回。巨大的分数矩阵在慢速内存里从未存在过——答案一模一样,流量只剩一个零头。

块要多大?刚好让所有在途数据都摆得上桌面。以常见的 64 宽的头为例,一个 128 行的 query 块是 128 × 64 个数,在 fp32 下 ≈ 32 KB。同样形状的一个 key 块和一个 value 块再各占 32 KB,小小的 128 × 128 分数块是 64 KB——滚动输出 o 自己也是一个 128 × 64 的块,又 32 KB(只有 m 是每行一个标量)。加起来:≈ 190 KB,正好顶到计算单元旁边那 ~200 KB SRAM 的边缘——这恰恰是块取这个尺寸、不能更大的原因。生产内核还会用 16 位的块和更窄的 key 块进一步压榨,但决定块尺寸的,正是“SRAM 能装下什么”这笔预算账。

一个诚实的脚注:v1 确立的不变量是 N×N 分数矩阵永不触碰 HBM。这里展示的循环顺序——query 在外层、key/value 在内层流式扫过、每个输出行只写一次——是 v2 才定下来的更干净的调度。最初的 v1 内核实际上反着循环(key/value 在外、query 在内),一路上反复回写输出。下图画的是现代顺序,因为那才是今天的内核真正运行的方式。

Flash 分块——分数块在 SRAM 中生灭
HBM · 主内存巨大——但遥远而慢Q 块Q₀Q₁K 块K₀K₁K₂V 块V₀V₁V₂O 输出O₀O₁← 加载块SRAM · 片上暂存区极小——但快约 10×,整个运算在此融合Q₀KⱼVⱼSᵢⱼ流式 softmax滚动 m, ℓ, O₀完整的 N×N 分数矩阵从未在此处或 HBM 中构造

外层循环,块 0把 query 块 Q 从 HBM 加载进快速的片上 SRAM。它会在每个 K/V 块上被复用。

循环位置 i = 0 (2 个 query 块 × 3 个 K/V 块)
一次分发,全融合:加载 → 算分 → 更新 → 丢弃,最后写回 Oᵢ。
HBM 分数矩阵写入: 0
朴素做法让 N×N 在慢速总线上往返约 4 趟——flash 的分数过桥次数是 0。

图中调度:query 在外层循环,key/value 在内层流式扫过——现代(FA2 风格)顺序。v1 的贡献是两种顺序共守的不变量:N×N 分数矩阵永不触及 HBM。

这到底省下多少流量?在 4,096 token 时,朴素做法让分数矩阵在慢速 HBM 总线上往返四趟——写出原始分数、读回做 softmax、写出概率、再读回给 V 加权。对这个模型的 8 个头、bf16 而言,每个全量注意力层就是 ≈ 1.07 GB 的总线流量——而这些分数用一次就被扔掉。分块内核移动的这类字节恰好为零。拖动序列长度,看差距:

跨总线的分数流量——朴素 vs flash,以字节计
朴素注意力完整的 N×N 分数矩阵在总线上往返——整整四趟HBMSN×N 分数SRAMsoftmax · P·V① 写出分数② 读回做 softmax③ 写出概率④ 读回做 P·V第 1/4 趟 · 把原始的 N×N 分数写出到 HBM每趟搬 268.4 MB → 4 趟 = 1.07 GBFlashAttention分数块从不离开 SRAMHBM这里没有分数SRAMS分数块在此生灭跨总线的分数字节0 B每个分数块算完即并入滚动 softmax,然后丢弃分数流量 = 0 B——任何 N 都如此
序列长度 N = 4,096 个 token
拖动以拉长上下文
跨过 HBM 总线的分数字节
朴素1.07 GB
flash0 B
朴素:4 趟 × 8 个头 × N² 个 bf16 分数(2 字节)· flash:分数块死在 SRAM 里——无需搬运
最大的存活分数张量
朴素268.4 MB
flash65.5 KB
朴素:全部 8 个头的 N×N 分数同时存活——随 N² 增长 · flash:一个 128×128 fp32 块——任何 N 都不变

条形按平方根刻度绘制,让小值保持可见;数字本身是精确的。

诚实的脚注:两种方案里 Q、K、V 和输出仍要过总线——在这些尺寸下是几十 MB(此模型在 4K token 时 ≈42 MB),两边相近,且只随 N 线性增长——flash 还会把 K/V 块重读几次。重点在于:无界的 N² 分数项彻底消失了。

v1 还有容易被忽略的后一半:训练。反向传播(计算梯度)通常需要再次用到注意力权重——而把 N×N 矩阵存起来留给后面用,会让整件事前功尽弃。FlashAttention 的答案是重计算(recomputation):只保留输出和每行两个小统计量(还是那个最大值 m 和和 ),等反向传播需要时,再在 SRAM 里由 Q 和 K 重新推导出每个分数块。这是有意用更多算术去换更少的内存流量——而且是赚的,因为在现代 GPU 上数学便宜、去 HBM 的往返不便宜。这笔反直觉的交易正是整个“IO 感知”思想的核心。

30 秒逛完 GPU

接下来的两个版本讲的都是把机器填满,所以你需要一张机器的图。GPU 不是一颗处理器——它更像一个办公园区。NVIDIA A100 有 108 个 SM(streaming multiprocessor,流式多处理器):彼此独立的小处理器,各自带着自己的 SRAM 暂存区(~200 KB)和自己的 Tensor Core——专用的矩阵乘法单元。工作以线程块(thread block)的形式到达——每个块被指派给一个 SM;块内部的线程以 32 个一捆的 warp 为单位锁步执行。两条推论与本文相关:没有分到块的 SM 什么都不做;而存在多少个块,由内核说了算。逐级缩放看看:

GPU 内部——芯片 → SM → warp
A100 芯片108 个 SM(流式多处理器)线程块 = 1 batch × 8 头 = 88 / 108 个 SM 忙碌 ≈ 7%

每个方块是一个 SM——流式多处理器,A100 上 108 个彼此独立的小处理器之一。工作以线程块的形式交给 SM;没有分到块的 SM 什么都不做。FlashAttention-1 只启动 batch × heads = 1 × 8 = 8 个块,所以 108 个 SM 里只有 8 个有活可干。

8 个块喂不饱 108 个 SM——FlashAttention-2 的全部招数(下一节)修的正是这件事。
A100 · 108 个 SM · warp = 32 线程 · 共享内存 ≈192 KB/SM

FlashAttention-2——让每个核心都忙起来

FlashAttention-2(2023)一点数学都没改——它改的是工作如何铺满整块芯片。带着办公园区的画面算一笔账。原始内核大致按(batch × 注意力头)各启动一个线程块:一条提示词乘 8 个头是 8 个块——在 108 个 SM 的 A100 上,约 7% 的芯片在干活,100 个 SM 黑着灯。v2 的头号修复是把 query 序列切成块:4,096 token 的提示词按 128 行一块是 32 个 query 块,于是 8 头 × 32 块 = 256 个块——绰绰有余,点亮每一个 SM。

它还重新划分了块内部的工作:不再让每个 warp 各算每行的一个切片、再经共享内存合并结果,而是让每个 warp 整体拥有一段 query 行——没有任何跨 warp 的对账。它还削减了非矩阵乘法的算术,因为普通运算比 Tensor Core 慢得多:滚动输出保持未归一化、只在最后除一次 ,而不是每一步都归一化(按 做的滚动最大值重缩放每一块仍然发生——只是把对 的除法推迟了);对因果掩码而言(未来 token 反正不能被注意),注定被完全掩掉的分数块干脆不算——单这一项实测约 1.7–1.8×。合起来:比 v1 快约一倍,注意力内核最高达到 A100 理论 FP16 峰值的 73%。(这种占用率思路,正是下文解码故事的驱动力。)

FlashAttention-2——同样的数学,更满的 GPU
GPU 占用率32 个核心(SM)线程块 = 8×4 = 32 → 已点亮 0/32核心全部占满一个线程块内部split-Q · warp 拥有整行分数块按 Q 行切分 →warp 1Q 行 1–2warp 2Q 行 3–4warp 3Q 行 5–6warp 4Q 行 7–8无跨 warp 同步

FlashAttention-2 还把 query 行也切成 4 个小块,于是它启动 8×4 = 32 个线程块,铺满整块 GPU。在每个块内部它采用 split-Q:每个 warp 从头到尾拥有自己的一段行,warp 之间从不经由共享内存合并部分结果。

更多线程块(序列并行)+ 更干净的 warp 划分(更少共享内存流量) 大约 2× FA1,结果不变。
arXiv 2307.08691 · 同样的在线 softmax 数学,更好的 GPU 机制

FlashAttention-3——重叠与低精度

FlashAttention-3(2024)为 NVIDIA 的 Hopper GPU(H100)调校,它的起点是硬件里一处悬殊得离谱的失衡。H100 的 Tensor Core 提供约 989 TFLOPS 的 fp16 矩阵乘法——而计算指数(每个 softmax 都要用)的那些小单元只有约 3.9 TFLOPS。注意力来回切换的这两种数学之间,有 256× 的速度差。如果严格按 matmul → softmax → matmul 的顺序轮流跑,芯片上最昂贵的硅就要花大量时间,等便宜的那部分干完。

FlashAttention-3 的答案是不再轮流。Hopper 允许 warp 分工(specialize)producer warp 只负责指挥 TMA(一个专用拷贝引擎)从 HBM 取下一批块,consumer warpgroup 只负责算。然后两个 consumer warpgroup 打乒乓(pingpong):一个在指数单元上跑自己的 softmax 时,另一个的矩阵乘法持续喂饱 Tensor Core——然后互换。每个块其实是两次矩阵乘法——先是分数 QKᵀ,然后是 value 混合 P·V——softmax 夹在中间,被重叠起来的正是它们:

FlashAttention-3——重叠两种计算,让 Tensor Core 几乎不闲置
Tensor Coresmatmul · 极快Exp / Softmax缩放 · 更慢的单元mm1soft1mm2soft2mm3soft3mm4soft4总耗时总工期 19 个时间单位1.47× vs FA2时间 →

FlashAttention-3 把两种工作重叠起来:当 Tensor Core 还在啃第 j+1 块的矩阵乘法时,指数单元已经在跑第 j 块的 softmax。 Tensor Core 通道几乎排满——只有 16% 闲置——整个调度更早完工。(每条“matmul”柱是简化的:真实的一块要跑两次 GEMM,先 QKᵀ P·V,softmax 夹在中间。)

matmul 精度
调度长度: 19 时间单位
FP16(16 位)矩阵乘法——切到 FP8 可把蓝色块缩小约 2×。

加速来自两根杠杆:用 matmul 与 softmax 的重叠让 Tensor Core 保持忙碌(warp 分工的“乒乓(pingpong)”调度),以及更便宜的 FP8 数学。这里的块尺寸只是示意,并非实测。

第二根杠杆是 FP8——它值得放慢看,因为“8 位浮点”听上去人畜无害,直到你看清代价。一个 fp8 数(E4M3 格式)只有 256 种位模式,最大值 448;fp8 矩阵乘法约比 fp16 快一倍。问题在于:可表示的值这么少,一切都押在把你的数据映射上去的缩放因子(scale factor)上。真实的注意力输入偶尔会有离群值(outlier)——一片小数中的一个巨大值。给整个张量挑一个缩放,那个离群值就把网格拉得太开,所有小值坍缩到寥寥几级上,它们的信息就没了。FlashAttention-3 两次拯救 fp8:块量化(block quantization)给每个小块自己的缩放,一个离群值只会糟蹋自己那个块——不相干处理(incoherent processing)则用一个固定的随机旋转去乘 Q 和 K,在舍入之前把离群值的能量抹匀到每个维度上。在任何舍入之前,这个旋转在数学上是免费的——两边同转,点积不变。fp8 舍入之后没有什么是严格保持的;旋转的职责是让舍入变得便宜,它省下的误差正是下面这个小部件度量的东西:

FP8 舍入误差——以及驯服它的两个技巧
原始值经 fp8 往返后fp8 可表示的级别块 A4 个值×8.60.62 → 0.64(误差 3.0%)-0.41 → -0.41(误差 0.9%)0.0001 → 0.00(误差 100.0%)0.93 → 0.93(误差 0.2%)块 B4 个值×8.6-0.77 → -0.75(误差 2.0%)0.05 → 0.05(误差 1.6%)0.34 → 0.35(误差 2.4%)-0.59 → -0.58(误差 1.6%)块 C4 个值×8.60.81 → 0.81(误差 0.3%)-0.00015 → -0.00022670200892857144(误差 51.1%)离群值 52 →0.47 → 0.46(误差 1.2%)块 D4 个值×8.6-0.95 → -0.93(误差 2.3%)0.26 → 0.26(误差 0.4%)-0.68 → -0.70(误差 2.4%)0.73 → 0.75(误差 3.4%)-10+1放大到 ±1——离群值远在右侧之外

离群值赢了。一个缩放必须容下 52,于是 scale = 448/52 8.6——每一行都拿到同一张粗网格(即那些刻度)。15 个小值只能吸附到寥寥几级上,两个近零值几乎被整个舍掉(误差最高 100%)。15 个小值的平均往返误差:11.5%

在任何舍入发生之前,旋转是完全免费的:用同一个正交矩阵 R 旋转每个 q 和 k 向量,它会在矩阵乘法内部抵消——(QRᵀ)(KRᵀ)ᵀ = Q(RᵀR)Kᵀ = QKᵀ,因为 RᵀR = I。经过 FP8 舍入之后,一切都不再精确——旋转的职责是让那次舍入变得更便宜,也就是上面展示的误差下降。所有百分比都是这 16 个固定值经过 E4M3 量化器(1 个符号位 + 4 个指数位 + 3 个尾数位,最大 448)的真实往返;模式 3 的误差是在旋转后的值上测量的,因为真正被存储和相乘的就是它们。

故事还在继续:FlashAttention-4(2026)把同一个不变量带到 NVIDIA 的下一代——Blackwell B200——那里矩阵乘法吞吐又差不多翻倍,而芯片的其余部分没有——最高达到 1,613 TFLOPS(71% 利用率)。我们的深入阅读停在 v3,因为 v4 的技巧是 Blackwell 专属的,但请注意贯穿每个版本的模式:算法的不变量从未改变——N×N 矩阵永不触碰慢速内存——每个版本只是把这个不变量重新调校到最新硅片的瓶颈上。

你浏览器里的 Qwen 用它吗?

用了一部分——而诚实的答案才是有意思的那个。运行这个模型的 WebGPU 后端确实有一个带上面在线 softmax 的融合分块 flash-attention 内核,它在 6 个全量注意力层的 prefill 阶段运行。但你在演示里看到的逐 token 解码故意用它:在 0.8B、8 个 query 头的情况下,融合内核只会启动 8 个 GPU workgroup,让芯片大部分闲置,所以一个占用率闸门把解码路由回普通的三步路径(matmul → softmax → matmul),它能撒出多得多的并行工作,在这里实测快约 90%。而且 24 层中只有 6 层运行 softmax 注意力——其余 18 层是 GatedDeltaNet。所以:prefill 用 flash,解码用普通路径,皆是有意为之。

在两者之间切换,看分派器把调用路由到真正运行的那个内核:

原生分派——真正运行的是哪个内核
注意力调用Tq = 1 行 queryB·H = 8 个头 × batch占用率闸门Tq==1 & B·H<32 ?Fused FlashAttention-2一次分发 · 给 Q,K,V 分块在 SRAM 里在线 softmax m, ℓ, o在 PREFILL 运行分解路径matmul QKᵀsoftmaxmatmul ·V撒出 100+ 个 workgroup在 DECODE 运行

解码一次只生成一个 token,所以 Tq = 1。融合内核只会启动 B·H = 8 个 workgroup,让芯片大部分闲置——于是闸门为 true,把解码路由到分解路径;它能撒出多得多的并行工作,在这里实测快约 90%

闸门: Tq = 1 and B·H = 8 < 32 → true decomposed
24 层中只有 6 层会到达这个闸门——其余 18 层是 GatedDeltaNet(线性)。

图中为默认路由。调试开关(?sdpa_fallback=1)会在此闸门之前,把每一次调用——prefill 与解码——强制走分解路径。