Flash Attention

Go deeper · Chapter 4, Self-attention — the same result, without the N×N matrix.

Flash Attention computes the same softmax(QKᵀ/√d)·V without ever materializing the N×N matrix. It is the exact attention operation, not an approximation like sparse or low-rank attention. It tiles Q, K, and V into small blocks and, for each query tile, streams over the key tiles while keeping a small running state per query row: two single numbers (the max score so far, m, and the running sum of exponentials, ℓ) plus a running output vector o, one entry per output dimension.

“Exact” here means the algorithm, not the bits: because it sums the key tiles in a different order, the output matches the naïve path up to floating-point rounding, not bit-for-bit. That is the normal floating-point caveat — reordering a sum changes the last few bits — not an approximation of attention itself.

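The trick is the online softmax: when a later tile raises the running max, you rescale what you have already accumulated so the normalization stays exactly correct. In the update below, $\tilde{\ell}$ and $\tilde{o}$ are this tile's own sum and weighted values measured against the updated max, so once the running totals are scaled by $e^{m_{\text{old}} - m_{\text{new}}}$, the new tile folds straight in with no extra factor.

$$m_{\text{new}} = \max\bigl(m_{\text{old}},\ \max_j s_j\bigr), \qquad \ell \leftarrow e^{m_{\text{old}} - m_{\text{new}}}\,\ell + \tilde{\ell}, \qquad o \leftarrow e^{m_{\text{old}} - m_{\text{new}}}\,o + \tilde{o}$$

Here $\tilde{\ell} = \sum_j e^{s_j - m_{\text{new}}}$ and $\tilde{o} = \sum_j e^{s_j - m_{\text{new}}}\,v_j$, with $s_j$ the scores in the current tile and $v_j$ the matching value rows.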

Those tiles live in the GPU's fast on-chip memory (SRAM), so the giant matrix never travels to slow memory at all. Same answer, a fraction of the memory traffic — the technique is “IO-aware.”

A short history

Flash Attention didn't arrive fully formed — it is the payoff of a chain of ideas about computing softmax without ever holding the whole row in memory. Step through the lineage:

The Flash Attention family tree
dashed = precursor ideas · solid = the numbered FlashAttention releases
FlashAttention v1
May 2022 · NeurIPS 2022
Dao, Fu, Ermon, Rudra, Ré

Make it IO-aware: tile Q/K/V into on-chip SRAM, fuse the whole operation into one kernel, and recompute in the backward pass — so the N×N matrix is never written to slow HBM. Exact, not an approximation (plus a block-sparse variant).

Impact: 15% faster BERT-large (vs the MLPerf 1.1 record), 3× GPT-2; first exact attention to reach 16K-token context (Path-X).
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness · arXiv 2205.14135

The same idea spread fast beyond the original code: xformers' memory_efficient_attention, OpenAI's Triton flash kernels, NVIDIA's cuDNN fused attention, and PyTorch's scaled_dot_product_attention (which dispatches to whichever is fastest). And PagedAttention (from vLLM) is a cousin, not a version — it is an attention kernel built around a paged KV-cache layout for memory-efficient serving, a different problem from FlashAttention's dense in-SRAM tiling. The fused kernel your in-browser Qwen runs during prefill is a FlashAttention-2-style tile kernel — v3's tricks are specific to NVIDIA's Hopper GPUs and FP8, so they don't apply here.

How each version actually works

The timeline above is the map; here is the territory. Each step fixed the previous one's bottleneck, so they make the most sense in order — one mechanism at a time.

The foundation: streaming softmax

Everything starts with one trick, so let's slow down and earn it. First, softmax itself, with real numbers. Softmax turns a row of scores into weights that sum to 1 by exponentiating each score and dividing by the total. Take the row [2, 1, 3]: the exponentials are $e^{2} \approx 7.39$, $e^{1} \approx 2.72$, $e^{3} \approx 20.09$; their sum is 30.19; divide through and the weights are 0.245, 0.090, 0.665. The biggest score wins most of the weight. That is the whole job.
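The arithmetic above can be checked in a few lines. This is a minimal sketch in plain Python; the function name `softmax` is chosen here for illustration and is not taken from any kernel discussed on this page.

```python
import math

def softmax(scores):
    """Exponentiate each score, then divide by the total so the weights sum to 1."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

row = [2.0, 1.0, 3.0]
weights = softmax(row)

# Print each score next to its exponential and its final weight.
for score, weight in zip(row, weights):
    print(f"score {score:.0f} -> exp {math.exp(score):6.2f} -> weight {weight:.3f}")
```

Note that this version exponentiates the raw scores directly, so it is only safe for small inputs like this row.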

One practical wrinkle: $e^{x}$ explodes. $e^{89}$ already overflows a 32-bit float (the limit is about $3.4 \times 10^{38}$), and raw attention scores can get big. The standard fix is to subtract the row's max from every score before exponentiating: the largest term becomes $e^{0} = 1$ and nothing overflows. Dividing by the sum cancels the shift, so the weights come out identical. But notice what that fix costs: you need the max of the whole row before you can exponentiate anything. Hence two passes, one to find the max and one to sum the exponentials.
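A small sketch of the max-subtraction fix, again in plain Python. One assumption to keep in mind: Python floats are 64-bit, so the overflow point sits near $e^{709}$ rather than the 32-bit limit described above; the example therefore uses scores around 1000 to trigger it. The name `softmax_stable` is illustrative.

```python
import math

def softmax_stable(scores):
    """Two passes: find the row max, then exponentiate the shifted scores."""
    m = max(scores)                            # pass 1: the max of the whole row
    exps = [math.exp(s - m) for s in scores]   # pass 2: largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

# Exponentiating a large raw score fails outright in 64-bit floats.
try:
    math.exp(1000.0)
    overflowed = False
except OverflowError:
    overflowed = True

# The shifted version handles the same magnitude without trouble.
shifted = softmax_stable([1000.0, 999.0, 1001.0])
small = softmax_stable([2.0, 1.0, 3.0])
print(overflowed, [round(w, 3) for w in shifted])
```

Because only the differences between scores matter, `shifted` and `small` hold the same weights.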

The online softmax (2018) collapses that to a single pass: keep a running max and a running sum, and when a later value beats the max so far, patch the sum you already built by multiplying it by $e^{m_{\text{old}} - m_{\text{new}}}$. Why is that the exact right patch? Every term you accumulated has the form $e^{s - m_{\text{old}}}$, and

$$e^{s - m_{\text{old}}} \cdot e^{m_{\text{old}} - m_{\text{new}}} = e^{s - m_{\text{new}}}$$

so one multiply re-expresses every old term against the new max, as if you had known it all along. Concretely, in the animation below the first chunk [2, 1, 3] finishes with max 3; then chunk two contains a 5, so the running sum and the running output are both multiplied by $e^{3 - 5} = e^{-2} \approx 0.135$ before the new terms fold in. The 2021 “memory-efficient attention” paper noticed you can run this over chunks of keys and values, so you never need the whole row in memory at once.

Streaming softmax: the whole row in one pass
Interactive figure. Naïve softmax needs 2 passes over the row; streaming needs 1. The row holds 12 scores in four chunks of 3: [2, 1, 3], [0, 5, 1], [4, 2, 1], [3, 6, 2]. The animation tracks the running max m, the running sum ℓ = Σ e^(score − m), the output accumulator o = Σ e^(score − m)·v, and the answer so far, o / ℓ.

Start empty: running max m = −∞, running sum ℓ = 0. We will walk the row left to right, one chunk of 3 at a time, in a single pass.

This single-pass rescale is the seed FlashAttention later runs entirely inside the GPU's tiny on-chip memory (SRAM) — so the N×N score matrix never has to be written out to slow main memory.
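The single-pass rescale can be sketched directly on the row used in the animation. The scores come from the figure; the values in `value_chunks` are made up for illustration, and scalar values stand in for the value vectors a real attention row would carry.

```python
import math

def streaming_softmax_average(score_chunks, value_chunks):
    """One pass over the row: returns (o / l, m, l) for scalar values."""
    m = -math.inf  # running max
    l = 0.0        # running sum of exp(score - m)
    o = 0.0        # running sum of exp(score - m) * value
    for scores, values in zip(score_chunks, value_chunks):
        m_new = max(m, max(scores))
        rescale = math.exp(m - m_new)  # exp(-inf) is 0.0 on the first chunk
        l = l * rescale + sum(math.exp(s - m_new) for s in scores)
        o = o * rescale + sum(math.exp(s - m_new) * v for s, v in zip(scores, values))
        m = m_new
    return o / l, m, l

def two_pass_softmax_average(scores, values):
    """Reference: needs the whole row in memory to find the max first."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)

score_chunks = [[2, 1, 3], [0, 5, 1], [4, 2, 1], [3, 6, 2]]
# Hypothetical values, one per score, chosen only to have something to average.
value_chunks = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]

streamed, final_max, final_sum = streaming_softmax_average(score_chunks, value_chunks)
flat_scores = [s for chunk in score_chunks for s in chunk]
flat_values = [v for chunk in value_chunks for v in chunk]
reference = two_pass_softmax_average(flat_scores, flat_values)
print(streamed, reference, final_max)
```

The two results agree up to floating-point rounding, which is the same caveat that applies to the real kernels.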

FlashAttention — tiling the whole thing into SRAM

A GPU has two kinds of memory: HBM, the big main memory that is far away and slow, and SRAM, a tiny scratchpad right next to the compute units that is roughly ten times faster. The naïve recipe builds the full N×N score matrix in HBM and drags it back and forth — the “memory wall” from the previous sub-chapter. FlashAttention (2022) keeps the scores out of HBM entirely. It cuts Q, K, and V into small tiles; for each tile of queries it streams over the key/value tiles, computes each little score tile inside SRAM, folds it straight into the running softmax, and throws it away. Only the finished output rows are written back. The giant score matrix never exists in slow memory — the same exact answer, a fraction of the traffic.
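The loop structure described above can be sketched in plain Python. This is a toy under stated assumptions: lists stand in for tiles, there is no actual SRAM or HBM, the tile sizes are tiny, and the function names are illustrative. It shows only the invariant that no full score row is ever held, and that the result matches the naïve recipe up to rounding.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention(Q, K, V):
    """Reference: builds every full score row before the softmax."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [dot(q, k) * scale for k in K]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / total
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, q_tile=2, kv_tile=3):
    """Queries on the outer loop, K/V tiles streamed on the inner loop."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for i0 in range(0, len(Q), q_tile):
        q_rows = Q[i0:i0 + q_tile]
        m = [-math.inf] * len(q_rows)      # running max per query row
        l = [0.0] * len(q_rows)            # running sum per query row
        o = [[0.0] * dv for _ in q_rows]   # unnormalized running output
        for j0 in range(0, len(K), kv_tile):
            k_rows, v_rows = K[j0:j0 + kv_tile], V[j0:j0 + kv_tile]
            for r, q in enumerate(q_rows):
                s = [dot(q, k) * scale for k in k_rows]  # one row of the score tile
                m_new = max(m[r], max(s))
                rescale = math.exp(m[r] - m_new)
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * rescale + sum(p)
                o[r] = [o[r][c] * rescale + sum(pj * v[c] for pj, v in zip(p, v_rows))
                        for c in range(dv)]
                m[r] = m_new
            # s and p are dropped here; only m, l and o survive between tiles
        out.extend([x / l[r] for x in o[r]] for r in range(len(q_rows)))
    return out

random.seed(0)
N, d = 7, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]

reference = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V)
max_diff = max(abs(a - b) for ra, rb in zip(reference, tiled) for a, b in zip(ra, rb))
print(max_diff)
```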

How big is a tile? Just small enough that everything in flight fits on the desk. With a typical 64-wide head, a 128-row query tile is 128 × 64 numbers ≈ 32 KB in fp32. A key tile and a value tile of the same shape are another 32 KB each, the little 128 × 128 score tile is 64 KB, and the running output o is itself a 128 × 64 tile, another 32 KB (only m and ℓ are single numbers per row). Add it up: 32 + 32 + 32 + 64 + 32 = 192 KB, right at the limit of the ~192 KB of SRAM sitting next to the compute units, which is exactly why the tiles are this size and not bigger. Production kernels squeeze further with 16-bit tiles and narrower key blocks, but the budget arithmetic (“what fits in SRAM”) is what picks the tile sizes.
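The budget arithmetic is easy to reproduce. The sketch below counts only the five tiles listed above and, as in that count, leaves out the small per-row m and ℓ; the helper name `tile_bytes` and the dictionary labels are illustrative.

```python
def tile_bytes(rows, cols, bytes_per_number=4):
    """Size of a rows x cols tile; 4 bytes per number means fp32."""
    return rows * cols * bytes_per_number

head_dim, q_rows, kv_rows = 64, 128, 128

budget = {
    "Q tile": tile_bytes(q_rows, head_dim),
    "K tile": tile_bytes(kv_rows, head_dim),
    "V tile": tile_bytes(kv_rows, head_dim),
    "score tile": tile_bytes(q_rows, kv_rows),
    "running output o": tile_bytes(q_rows, head_dim),
}

total_bytes = sum(budget.values())
total_kib = total_bytes / 1024

for name, size in budget.items():
    print(f"{name:>18}: {size / 1024:.0f} KB")
print(f"{'total':>18}: {total_kib:.0f} KB")

# The same shapes with 16-bit numbers, as production kernels tend to use.
half_precision_kib = sum(
    tile_bytes(r, c, bytes_per_number=2)
    for r, c in [(q_rows, head_dim), (kv_rows, head_dim), (kv_rows, head_dim),
                 (q_rows, kv_rows), (q_rows, head_dim)]
) / 1024
```

Changing `q_rows` or `kv_rows` shows how quickly a larger tile would exceed the scratchpad.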

One honest footnote: the invariant v1 introduced is that the N×N score matrix never touches HBM. The exact loop order shown here — queries on the outside, keys/values streamed on the inside, each output row written once — is the cleaner schedule that v2 settled on. The original v1 kernel actually looped the other way (keys/values outside, queries inside), revisiting the output as it went. The diagram below draws the modern order because that is what kernels run today.

Flash tiling: the score tile lives and dies in SRAM
Interactive figure. HBM (main memory: huge, but far away and slow) holds the Q tiles (Q₀, Q₁), the K tiles (K₀, K₁, K₂), the V tiles (V₀, V₁, V₂) and the output O (O₀, O₁). SRAM (on-chip scratchpad: tiny, but ~10× faster) is where the whole operation fuses: it holds Q₀, the current Kⱼ and Vⱼ, the score tile Sᵢⱼ, and the streaming-softmax state (running m, ℓ, O₀). The full N×N score matrix is never built, in SRAM or in HBM.

Outer loop, tile 0. Load query tile Q₀ from HBM into fast on-chip SRAM. We will reuse it across every K/V tile.

Loop position i = 0 (2 query tiles × 3 K/V tiles). One dispatch, fused: load → score → update → discard, then write Oᵢ. HBM score-matrix writes: 0. The naïve recipe crosses the slow bus ~4× to move N×N; flash crosses it 0× for scores.

Schedule shown: queries on the outer loop, keys/values streamed on the inner — the modern (FA2-style) order. v1's contribution is the invariant either order keeps: the N×N score matrix never reaches HBM.

How much traffic does that actually save? At 4,096 tokens, the naïve recipe pushes the score matrix across the slow HBM bus four times — written as raw scores, read back for softmax, written as probabilities, read again to weight V. For this model's 8 heads in bf16 that is ≈ 1.07 GB of bus traffic per full-attention layer, for scores that are used once and thrown away. The tiled kernel moves exactly zero of those bytes. Drag the sequence length and watch the gap:

Score traffic over the bus: naïve vs flash, in bytes
Interactive figure, shown at sequence length N = 4,096 tokens (drag to stretch the context).
Naïve attention: the full N×N score matrix rides the bus between slow HBM and fast SRAM four times: ① write scores, ② read for softmax, ③ write probabilities, ④ read for P·V. Each trip moves 268.4 MB, so 4 trips = 1.07 GB.
FlashAttention: the score tile never leaves SRAM. Each tile is computed, folded into the running softmax, and discarded, so score traffic is 0 B at every N.

Score bytes crossing the HBM bus: naïve 1.07 GB · flash 0 B
(naïve: 4 trips × 8 heads × N² scores in bf16, 2 bytes each · flash: the tile dies in SRAM, nothing to move)
Largest live score tensor: naïve 268.4 MB · flash 65.5 KB
(naïve: all 8 heads' N×N scores at once, growing with N² · flash: one 128×128 fp32 tile, flat at every N)

Bars are drawn on a square-root scale so the small values stay visible; the numbers are exact.

Honest footnote: Q, K, V and the output still cross the bus in both schemes — tens of MB at these sizes (≈42 MB at 4K tokens for this model), similar for both, growing only linearly with N — and flash re-reads the K/V tiles a few times. The point is that the unbounded N² score term disappears entirely.
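The score-traffic figures quoted in this section follow from a one-line formula. A short sketch, assuming the same 8 heads, bf16 scores, and four trips described above; the function names are illustrative, and the Q/K/V/output traffic from the footnote is deliberately not modelled.

```python
def naive_score_traffic_bytes(n_tokens, heads=8, bytes_per_score=2, trips=4):
    """Bytes of N x N scores crossing the HBM bus in the naive recipe."""
    return trips * heads * n_tokens ** 2 * bytes_per_score

def flash_live_tile_bytes(tile=128, bytes_per_score=4):
    """The only score storage the tiled kernel needs: one fp32 tile in SRAM."""
    return tile * tile * bytes_per_score

per_trip = naive_score_traffic_bytes(4096, trips=1)
four_trips = naive_score_traffic_bytes(4096)
live_tile = flash_live_tile_bytes()

print(f"one trip:   {per_trip / 1e6:.1f} MB")
print(f"four trips: {four_trips / 1e9:.2f} GB")
print(f"flash tile: {live_tile / 1e3:.1f} KB")

# Doubling the sequence length quadruples the naive traffic.
growth = naive_score_traffic_bytes(8192) / four_trips
```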

There is a second half to v1 that is easy to miss: training. The backward pass (computing gradients) normally needs the attention weights again, and storing the N×N matrix for later would defeat the entire point. FlashAttention's answer is recomputation: keep only the output and two small per-row statistics (that same max m and sum ℓ), then re-derive each score tile from Q and K inside SRAM when the backward pass needs it. That is deliberately doing more arithmetic to do less memory traffic, and it wins, because on a modern GPU the math is cheap and the trips to HBM are not. This counterintuitive trade is the heart of the whole “IO-aware” idea.
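The recomputation idea can be illustrated for a single query row. This sketch assumes made-up vectors and illustrative function names, and it covers only the step of rebuilding attention weights from the saved m and ℓ, not the gradient computation itself.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward_row_stats(q, K, scale):
    """What the forward pass keeps for this row: just the max and the sum."""
    scores = [dot(q, k) * scale for k in K]
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def recompute_tile_weights(q, k_tile, m, l, scale):
    """Backward pass: rebuild one tile of attention weights from q, the K tile, m and l."""
    return [math.exp(dot(q, k) * scale - m) / l for k in k_tile]

# Hypothetical data: one query and six keys of width 2.
q = [1.0, 0.5]
K = [[0.2, -0.1], [1.5, 0.3], [-0.7, 0.9], [0.4, 0.4], [2.0, -1.0], [0.0, 0.6]]
scale = 1.0 / math.sqrt(len(q))

m, l = forward_row_stats(q, K, scale)

# Rebuild the weights two keys at a time, never holding the full row of scores.
rebuilt = []
for j0 in range(0, len(K), 2):
    rebuilt += recompute_tile_weights(q, K[j0:j0 + 2], m, l, scale)

# Direct softmax over the whole row, for comparison.
scores = [dot(q, k) * scale for k in K]
direct = [math.exp(s - m) / l for s in scores]
```

Two floats per row are enough to regenerate any tile of weights on demand, which is why the N×N matrix never needs to be stored.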

A 30-second tour of the GPU

The next two versions are about filling the machine, so you need a picture of the machine. A GPU is not one processor — it is more like an office park. An NVIDIA A100 has 108 SMs (streaming multiprocessors): independent little processors, each with its own SRAM scratchpad (~192 KB) and its own Tensor Cores, the dedicated matrix-multiply units. Work arrives as thread blocks — each block is assigned to one SM, and inside a block the threads run in bundles of 32 called warps, which execute in lockstep. Two consequences matter here: an SM with no block assigned does nothing, and the kernel decides how many blocks exist. Zoom through the levels:

Inside the GPU: chip → SM → warp
Interactive figure. A100 chip, 108 SMs (streaming multiprocessors). Blocks = 1 batch × 8 heads = 8, so 8 / 108 SMs are busy, about 7%.

Each square is an SM — a streaming multiprocessor, one of 108 independent little processors on an A100. Work is handed to SMs as thread blocks; an SM with no block assigned does nothing. FlashAttention-1 launches batch × heads = 1 × 8 = 8 blocks, so only 8 of the 108 SMs have anything to do.

8 blocks can't feed 108 SMs — fixing exactly that is FlashAttention-2's whole move (next).
A100 · 108 SMs · warp = 32 threads · shared memory ≈192 KB/SM

FlashAttention-2 — keep every core busy

FlashAttention-2 (2023) changes none of the math — it changes how the work is spread across the chip. Run the numbers with the office-park picture in mind. The original kernel launched roughly one thread block per (batch × attention head): one prompt times 8 heads is 8 blocks — on a 108-SM A100, about 7% of the chip working and 100 SMs sitting dark. v2's headline fix is to also split the query sequence into blocks: a 4,096-token prompt in 128-row tiles gives 32 query tiles, so 8 heads × 32 tiles = 256 blocks — more than enough to light up every SM.
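The block counts above can be reproduced with a few lines. This is a simplified occupancy model, assuming each thread block occupies one SM and ignoring that a real SM can host several blocks at once; the function names are illustrative.

```python
import math

A100_SMS = 108

def blocks_v1(batch, heads):
    """Roughly one thread block per (batch x head), as described for the original kernel."""
    return batch * heads

def blocks_v2(batch, heads, seq_len, q_tile=128):
    """v2 also splits the query sequence into tiles, multiplying the block count."""
    return batch * heads * math.ceil(seq_len / q_tile)

def busy_fraction(blocks, sms=A100_SMS):
    """Share of SMs with at least one block, under the one-block-per-SM simplification."""
    return min(blocks, sms) / sms

v1 = blocks_v1(batch=1, heads=8)
v2 = blocks_v2(batch=1, heads=8, seq_len=4096)

print(f"v1: {v1} blocks, {busy_fraction(v1):.0%} of SMs busy, {A100_SMS - v1} idle")
print(f"v2: {v2} blocks, {busy_fraction(v2):.0%} of SMs busy")
```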

It also re-partitions the work inside a block: instead of every warp computing a slice of each row and then combining results through shared memory, each warp owns a band of query rows outright, with no cross-warp bookkeeping. And it shaves the non-matmul arithmetic, because regular math runs far slower than the Tensor Cores: the running output is kept unnormalized and divided by ℓ exactly once at the very end instead of being normalized at every step (the running max-rescale by $e^{m_{\text{old}} - m_{\text{new}}}$ still happens each tile; it is only the ℓ-division that is deferred), and for causal masks (where future tokens cannot be attended anyway) whole score tiles that would be fully masked are simply never computed, a saving measured at ~1.7–1.8× by itself. Together: about twice as fast as v1, with the attention kernel hitting up to 73% of the A100's theoretical FP16 peak. (That occupancy idea is exactly what drives the decode story below.)
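The causal-mask saving comes from counting tiles. The sketch below assumes equal query and key tile sizes and simply counts which score tiles contain at least one allowed (query, key) pair; it gives an upper bound on the tile-count saving, not the measured speedup quoted above.

```python
def causal_tile_counts(seq_len, tile=128):
    """Return (total tiles, tiles that must be computed) under a causal mask."""
    n = -(-seq_len // tile)  # ceiling division: number of tiles per side
    total = n * n
    computed = 0
    for i in range(n):              # query tile index
        last_query = min((i + 1) * tile, seq_len) - 1
        for j in range(n):          # key tile index
            first_key = j * tile
            # A tile is fully masked when even its first key lies after the last query.
            if first_key <= last_query:
                computed += 1
    return total, computed

total, computed = causal_tile_counts(4096)
skipped = total - computed
print(f"{total} tiles, {computed} computed, {skipped} skipped "
      f"({total / computed:.2f}x fewer tiles)")
```

The tile-count ratio approaches 2× for long sequences; the diagonal tiles still need per-element masking, which is one reason the measured gain is lower.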

FlashAttention-2: same math, fuller GPU
Interactive figure. Left, GPU occupancy: 32 cores (SMs); blocks = 8×4 = 32, enough to saturate the cores. Right, inside one thread block (split-Q, warps own whole rows): the score tile is split by Q rows, with warp 1 on Q rows 1–2, warp 2 on rows 3–4, warp 3 on rows 5–6, warp 4 on rows 7–8, and no cross-warp sync.

FlashAttention-2 also splits the query rows into 4 tiles, so it launches 8×4 = 32 thread blocks and fills the whole GPU. Inside each block it uses split-Q: every warp owns its own band of rows end-to-end, so the warps never combine partials through shared memory.

more thread blocks (sequence parallelism) + cleaner warp split (less shared-memory traffic) → roughly 2× FA1, same result.
arXiv 2307.08691 · same online-softmax math, better GPU mechanics

FlashAttention-3 — overlap and low precision

FlashAttention-3 (2024) is tuned for NVIDIA's Hopper GPUs (the H100), and it starts from a wild imbalance in the hardware. The H100's Tensor Cores deliver about 989 TFLOPS of fp16 matrix multiply, but the little units that compute exponentials (which every softmax needs) manage only about 3.9 TFLOPS. That is a gap of roughly 250× between the two kinds of math attention alternates between. Run matmul → softmax → matmul strictly in order, and the most expensive silicon on the chip spends much of its time waiting for the cheap part to finish.

FlashAttention-3's answer is to stop taking turns. Hopper lets warps specialize: producer warps do nothing but tell the TMA (a dedicated copy engine) to fetch the next tiles from HBM, while consumer warpgroups do nothing but math. Then two consumer warpgroups play pingpong: while one runs its softmax on the exp units, the other's matmuls keep the Tensor Cores fed — and they swap. Each tile is really two matmuls — the scores QKᵀ, then the value mix P·V — with the softmax wedged between, and it is those that get overlapped:

FlashAttention-3: overlap the math so the Tensor Cores barely idle
Interactive figure. Two lanes over time: the Tensor Cores (matmul, very fast) run mm1 to mm4, and the Exp / Softmax units (slower) run soft1 to soft4. Makespan: 19 time-units, 1.47× vs FA2.

FlashAttention-3 overlaps the two kinds of work: while the Tensor Cores grind tile j+1's matmul, the exp units already run tile j's softmax. The Tensor-Core lane stays nearly solid — only 16% idle — so the schedule finishes sooner. (Each “matmul” bar is simplified: a real tile runs two GEMMs, QKᵀ then P·V, with the softmax between.)

Matmul precision control: FP16 (16-bit) matmuls, schedule length 19 time-units. Switch to FP8 to shrink the blue matmul blocks ~2×.

The speedup comes from two levers: keeping the Tensor Cores busy by overlapping matmul with softmax (warp-specialized “pingpong” scheduling), and cheaper FP8 math. Block sizes here are illustrative, not measured.
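The overlap can be sketched as a toy two-lane schedule. Everything here is an assumption for illustration: four tiles, a matmul block of 4 time-units and a softmax block of 3, and a simplified model in which the matmul lane never waits on the softmax lane. Those durations were picked because they happen to reproduce the 19 time-units, 1.47× and 16% idle shown in the diagram; they are not measurements.

```python
def serial_makespan(n_tiles, matmul, softmax):
    """Taking turns: each tile's softmax starts only after its matmul, and nothing overlaps."""
    return n_tiles * (matmul + softmax)

def overlapped_makespan(n_tiles, matmul, softmax):
    """Two lanes: matmuls run back to back while softmax of earlier tiles runs alongside."""
    tensor_free = 0   # time at which the Tensor Core lane is next free
    exp_free = 0      # time at which the exp/softmax lane is next free
    for _ in range(n_tiles):
        matmul_end = tensor_free + matmul
        tensor_free = matmul_end
        softmax_start = max(matmul_end, exp_free)  # needs its own matmul and a free lane
        exp_free = softmax_start + softmax
    return exp_free

N_TILES, MATMUL, SOFTMAX = 4, 4, 3  # illustrative durations, not measured

serial = serial_makespan(N_TILES, MATMUL, SOFTMAX)
overlapped = overlapped_makespan(N_TILES, MATMUL, SOFTMAX)
speedup = serial / overlapped
tensor_idle = (overlapped - N_TILES * MATMUL) / overlapped

print(f"serial {serial}, overlapped {overlapped}, "
      f"speedup {speedup:.2f}x, Tensor Cores idle {tensor_idle:.0%}")
```

Halving `MATMUL` mimics the FP8 lever; in this toy model the softmax lane then becomes the bottleneck, which is the imbalance the section describes.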

The second lever is FP8 — and it deserves a slower look, because “8-bit float” sounds innocent until you see what it costs. An fp8 number (the E4M3 format) has 256 possible bit patterns and a maximum value of 448; matmuls in fp8 run about twice as fast as fp16. The catch: with so few representable values, everything hinges on the scale factor that maps your data onto them. Real attention inputs contain occasional outliers — one huge value in a sea of small ones. Pick one scale for the whole tensor and that single outlier stretches the grid so far that all the small values collapse onto a handful of levels and their information is gone. FlashAttention-3 rescues fp8 twice: block quantization gives each small tile its own scale, so one outlier only coarsens its own block — and incoherent processing multiplies Q and K by a fixed random rotation that smears the outlier's energy across every dimension before rounding. Before any rounding that rotation is mathematically free — rotate both sides and the dot products are unchanged. After fp8 rounding nothing is exactly preserved; the rotation's job is to make the rounding cheap, and the error it saves is what the widget below measures:

FP8 rounding error, and the two tricks that tame it
Interactive figure, zoomed to ±1 (the outlier sits far off to the right). Ticks mark the fp8-representable levels; each entry shows the original value → the value after an fp8 round-trip. Mode shown: one scale (×8.6) for the whole tensor.
block A (×8.6): 0.62 → 0.64 (3.0% off) · -0.41 → -0.41 (0.9% off) · 0.0001 → 0.00 (100.0% off) · 0.93 → 0.93 (0.2% off)
block B (×8.6): -0.77 → -0.75 (2.0% off) · 0.05 → 0.05 (1.6% off) · 0.34 → 0.35 (2.4% off) · -0.59 → -0.58 (1.6% off)
block C (×8.6): 0.81 → 0.81 (0.3% off) · -0.00015 → -0.00023 (51.1% off) · outlier 52 · 0.47 → 0.46 (1.2% off)
block D (×8.6): -0.95 → -0.93 (2.3% off) · 0.26 → 0.26 (0.4% off) · -0.68 → -0.70 (2.4% off) · 0.73 → 0.75 (3.4% off)

The outlier wins. One scale must fit 52, so scale = 448/52 ≈ 8.6, and every row gets the same coarse grid (the ticks). The 15 small values snap onto a handful of levels, and the two near-zero values round almost entirely away (up to 100% error). Avg round-trip error of the 15 small values: 11.5%.
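The round-trip numbers can be approximated with a small E4M3 rounding routine. This is a sketch under stated assumptions: `round_to_e4m3` is a hand-written emulation (3 mantissa bits, smallest normal exponent −6, saturation at 448, round-to-nearest-even), not a library call, and the 16 values are the ones shown in the widget. It compares one scale for the whole tensor against one scale per block of 4.

```python
import math

E4M3_MAX = 448.0

def round_to_e4m3(x):
    """Round x to the nearest E4M3 value (3 mantissa bits, subnormals below 2**-6)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)                      # saturate instead of overflowing
    exponent = max(math.floor(math.log2(a)), -6)   # -6 is the smallest normal exponent
    step = 2.0 ** (exponent - 3)                   # grid spacing at this magnitude
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)

def round_trip(values, scale):
    """Scale onto the fp8 grid, round, and scale back."""
    return [round_to_e4m3(v * scale) / scale for v in values]

blocks = [[0.62, -0.41, 0.0001, 0.93], [-0.77, 0.05, 0.34, -0.59],
          [0.81, -0.00015, 52.0, 0.47], [-0.95, 0.26, -0.68, 0.73]]
values = [v for block in blocks for v in block]
small_idx = [i for i, v in enumerate(values) if abs(v) < 1.0]  # everything but the outlier

def small_error(restored):
    """Mean relative round-trip error over the 15 small values."""
    return sum(abs(restored[i] - values[i]) / abs(values[i]) for i in small_idx) / len(small_idx)

# One scale for the whole tensor: the outlier decides the grid for everyone.
tensor_scale = E4M3_MAX / max(abs(v) for v in values)
per_tensor = round_trip(values, tensor_scale)

# Block quantization: each block of 4 gets its own scale.
per_block = []
for block in blocks:
    per_block += round_trip(block, E4M3_MAX / max(abs(v) for v in block))

print(f"per-tensor error: {small_error(per_tensor):.1%}")
print(f"per-block error:  {small_error(per_block):.1%}")
```

With per-block scales, only the block that contains the outlier keeps the coarse grid; the other three blocks recover fine resolution.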

Before any rounding, the rotation is exactly free: rotate every q and k vector with the same orthogonal R and it cancels inside the matmul — (QRᵀ)(KRᵀ)ᵀ = Q(RᵀR)Kᵀ = QKᵀ, since RᵀR = I. After FP8 rounding nothing is exact anymore — the rotation's job is to make that rounding cheaper, which is exactly the error drop shown above. All percentages are real round-trips through an E4M3 quantizer (1 sign + 4 exponent + 3 mantissa bits, max 448) on these 16 fixed values; in mode 3 the error is measured on the rotated values, because those are what actually get stored and multiplied.
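The cancellation is easy to verify numerically. The sketch below uses a normalized 4×4 Hadamard matrix as a stand-in for the fixed orthogonal R; that choice, the vectors, and the function names are assumptions for illustration, not the rotation any particular kernel uses. It checks that the dot product is unchanged and that the outlier's magnitude is spread across all four dimensions.

```python
# Normalized 4x4 Hadamard matrix: its rows are orthonormal, so R R^T = I.
H4 = [[1, 1, 1, 1],
      [1, -1, 1, -1],
      [1, 1, -1, -1],
      [1, -1, -1, 1]]
R = [[h / 2 for h in row] for row in H4]

def rotate(vec):
    """Apply R to a vector (the per-row equivalent of multiplying by R^T on the right)."""
    return [sum(r * x for r, x in zip(row, vec)) for row in R]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical query with one outlier, and an ordinary key.
q = [0.81, -0.00015, 52.0, 0.47]
k = [0.62, -0.41, 0.0001, 0.93]

q_rot, k_rot = rotate(q), rotate(k)

before = dot(q, k)
after = dot(q_rot, k_rot)
largest_before = max(abs(x) for x in q)
largest_after = max(abs(x) for x in q_rot)

print(f"dot product: {before:.6f} -> {after:.6f}")
print(f"largest |q| entry: {largest_before:.2f} -> {largest_after:.2f}")
```

After the rotation every entry of q has a similar magnitude, so a single scale no longer sacrifices the small values to one large one.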

And the story keeps going: FlashAttention-4 (2026) carries the same invariant onto NVIDIA's next generation, the Blackwell B200 — where matmul throughput roughly doubled again while the rest of the chip didn't — reaching up to 1,613 TFLOPS (71% utilization). Our deep-dive stops at v3 because v4's tricks are Blackwell-specific, but notice the pattern across every version: the algorithm's invariant never changes — the N×N matrix never touches slow memory — and each release just re-tunes that invariant to whatever the newest silicon's bottleneck happens to be.

Does your in-browser Qwen use it?

Partly — and the honest answer is the interesting one. The WebGPU backend running this model has a fused, tiled flash-attention kernel with the online softmax above, and it runs during prefill of the 6 full-attention layers. But the token-by-token decode you watch in the demo deliberately does not use it: at 0.8B with 8 query heads the fused kernel would launch only 8 GPU workgroups and leave most of the chip idle, so an occupancy gate routes decode back to the plain three-step path (matmul → softmax → matmul), which fans out far more work and measured ~90% faster here. And only 6 of 24 layers run softmax attention at all — the other 18 are GatedDeltaNet. So: flash during prefill, plain during decode, on purpose.

Toggle between the two and watch the dispatch route the call to the kernel that actually runs:

The native dispatch: which kernel actually runs
Interactive figure. An attention call (Tq = 1 query rows, B·H = 8 heads × batch) reaches the occupancy gate: Tq == 1 and B·H < 32? If false, the call goes to the fused FlashAttention-2 kernel (one dispatch; tiles Q, K, V; online softmax m, ℓ, o in SRAM), which runs in prefill. If true, it goes to the decomposed path (matmul QKᵀ → softmax → matmul ·V), which fans out 100+ workgroups and runs in decode.

Decode generates one token, so Tq = 1. The fused kernel would launch only B·H = 8 workgroups and leave most of the chip idle — so the gate is true and routes decode to the decomposed path, which fans out far more work and measured ~90% faster here.

gate: Tq = 1 and B·H = 8 < 32 → true → decomposed
Only 6 of 24 layers reach this gate — the other 18 are GatedDeltaNet (linear).

Default route shown. A debug switch (?sdpa_fallback=1) forces every call — prefill and decode — through the decomposed path, ahead of this gate.
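The routing logic described in this section can be written as a small function. This is a hypothetical sketch: the function name, parameter names and return strings are invented here, and only the gate condition (Tq == 1 and B·H < 32) and the fallback-first ordering come from the description above.

```python
def choose_attention_path(tq, batch, heads, min_workgroups=32, force_fallback=False):
    """Pick the kernel for one attention call.

    tq: number of query rows in this call (1 during token-by-token decode).
    force_fallback: stands in for the debug switch, checked ahead of the gate.
    """
    if force_fallback:
        return "decomposed"
    # Occupancy gate: a single query row with few heads would launch too few workgroups.
    if tq == 1 and batch * heads < min_workgroups:
        return "decomposed"
    return "fused"

decode_route = choose_attention_path(tq=1, batch=1, heads=8)
prefill_route = choose_attention_path(tq=512, batch=1, heads=8)
forced_route = choose_attention_path(tq=512, batch=1, heads=8, force_fallback=True)

print(decode_route, prefill_route, forced_route)
```

The prefill length of 512 is an arbitrary example; any call with more than one query row takes the fused path unless the fallback is forced.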