Flash Attention
Go deeper · Chapter 4, Self-attention — the same result, without the N×N matrix.
Flash Attention computes the same softmax(QKᵀ/√d)·V — the exact attention operation, not an approximation like sparse or low-rank attention — without ever materializing the N×N matrix. It tiles Q, K, and V into small blocks and, for each query tile, streams over the key tiles while keeping a small running state per query row: two single numbers — the max score so far m and the running sum of exponentials ℓ — plus a running output vector o, one entry per output dimension.
“Exact” here means the algorithm, not the bits: because it sums the key tiles in a different order, the output matches the naïve path up to floating-point rounding, not bit-for-bit. That is the normal floating-point caveat — reordering a sum changes the last few bits — not an approximation of attention itself.
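Here is the caveat on an ordinary machine, using Python doubles rather than GPU floats. The same three numbers summed in two orders disagree in the last bit but agree to within rounding.

```python
# Same three terms, summed in two orders (illustrative values, not attention scores).
left_first = (0.1 + 0.2) + 0.3
right_first = 0.1 + (0.2 + 0.3)

print(left_first)                              # 0.6000000000000001
print(right_first)                             # 0.6
print(abs(left_first - right_first) < 1e-15)   # True: equal up to rounding
```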
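The trick is the online softmax: when a later tile raises the running max, you rescale what you have already accumulated so the normalization stays exactly correct. For a tile with scores $s_j$ and value vectors $v_j$, the update is

$$m_{\text{new}} = \max\Big(m_{\text{old}},\ \max_j s_j\Big), \qquad \ell \leftarrow e^{m_{\text{old}} - m_{\text{new}}}\,\ell + \tilde{\ell}, \qquad o \leftarrow e^{m_{\text{old}} - m_{\text{new}}}\,o + \tilde{o},$$

where $\tilde{\ell} = \sum_j e^{s_j - m_{\text{new}}}$ and $\tilde{o} = \sum_j e^{s_j - m_{\text{new}}}\,v_j$ are this tile's own sum and weighted values measured against the updated max. Once the running totals are scaled by $e^{m_{\text{old}} - m_{\text{new}}}$, the new tile folds straight in with no extra factor. After the last tile, the output row is $o / \ell$.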
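Below is a minimal pure-Python sketch of that update for a single query row. It assumes the scores have already been divided by √d. The function name `update` and the toy numbers are illustrative and do not come from any library. The check at the end compares the two-tile result against an ordinary full-row softmax.

```python
import math

def update(state, scores, values):
    """Fold one key/value tile into a query row's running (m, l, o) state.

    state  -- (m, l, o): running max, running sum of exponentials, unnormalized output
    scores -- this tile's scores s_j for the row (already divided by sqrt(d))
    values -- this tile's value vectors v_j, one per score
    """
    m_old, l, o = state
    m_new = max(m_old, max(scores))
    rescale = math.exp(m_old - m_new)        # e^(m_old - m_new); 0.0 on the first tile
    weights = [math.exp(s - m_new) for s in scores]
    l_tile = sum(weights)                    # tile's own sum, against the updated max
    o_tile = [sum(w * v[k] for w, v in zip(weights, values)) for k in range(len(o))]
    return m_new, rescale * l + l_tile, [rescale * a + b for a, b in zip(o, o_tile)]

# Toy row: four keys split into two tiles of two (made-up numbers).
scores = [2.0, 1.0, 3.0, 0.5]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]

state = (-math.inf, 0.0, [0.0, 0.0])          # start empty
for start in (0, 2):
    state = update(state, scores[start:start + 2], values[start:start + 2])
m, l, o = state
streamed = [x / l for x in o]                 # normalize once, after the last tile

# Reference: ordinary softmax over the whole row, then the weighted sum of values.
full = [math.exp(s - max(scores)) for s in scores]
reference = [sum(w * v[k] for w, v in zip(full, values)) / sum(full) for k in range(2)]
```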
Those tiles live in the GPU's fast on-chip memory (SRAM), so the giant matrix never travels to slow memory at all. Same answer, a fraction of the memory traffic — the technique is “IO-aware.”
A short history
Flash Attention didn't arrive fully formed. It is the payoff of a chain of ideas about computing softmax without ever holding the whole row in memory: the online softmax (2018), memory-efficient attention (2021), and then FlashAttention itself (2022):
FlashAttention (2022) made it IO-aware: tile Q/K/V into on-chip SRAM, fuse the whole operation into one kernel, and recompute in the backward pass, so the N×N matrix is never written to slow HBM. It is exact, not an approximation (the paper also adds a block-sparse variant).
The same idea spread quickly beyond the original code. Examples include xformers' memory_efficient_attention, OpenAI's Triton flash kernels, NVIDIA's cuDNN fused attention, and PyTorch's scaled_dot_product_attention, which picks among its flash, memory-efficient, cuDNN, and plain math backends based on the inputs and hardware. PagedAttention (from vLLM) is a cousin, not a version. It is an attention kernel built around a paged KV-cache layout for memory-efficient serving, which is a different problem from FlashAttention's dense in-SRAM tiling. The fused kernel your in-browser Qwen runs during prefill is a FlashAttention-2-style tile kernel. v3's tricks are specific to NVIDIA's Hopper GPUs and FP8, so they don't apply here.
How each version actually works
The timeline above is the map; here is the territory. Each step fixed the previous one's bottleneck, so they make the most sense in order — one mechanism at a time.
The foundation: streaming softmax
Everything starts with one trick, so let's slow down and earn it. First, softmax itself, with real numbers. Softmax turns a row of scores into weights that sum to 1 by exponentiating each score and dividing by the total. Take the row [2, 1, 3]. The exponentials are $e^2 \approx 7.389$, $e^1 \approx 2.718$ and $e^3 \approx 20.086$, and their sum is 30.19. Dividing each by the sum gives the weights 0.245, 0.090 and 0.665. The biggest score wins most of the weight, and that is the whole job.
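Here is the same arithmetic in a few lines of Python. This is a plain sketch without the max subtraction, which the next paragraph introduces.

```python
import math

def softmax(row):
    exps = [math.exp(x) for x in row]   # e^2, e^1, e^3 for the example row
    total = sum(exps)                   # about 30.19
    return [e / total for e in exps]

weights = softmax([2.0, 1.0, 3.0])
print([round(w, 3) for w in weights])   # [0.245, 0.09, 0.665]
```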
One practical wrinkle is that $e^x$ explodes. $e^{89}$ already overflows a 32-bit float, whose limit is about $3.4 \times 10^{38}$, and raw attention scores can get big. The standard fix is to subtract the row's max from every score before exponentiating. The largest term then becomes $e^0 = 1$ and nothing overflows. Dividing by the sum cancels the shift, so the weights come out identical. But notice what that fix costs: you need the max of the whole row before you can exponentiate anything. That means two passes, one to find the max and one to sum the exponentials.
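Below is a sketch of the stable two-pass version next to the naive one. Note that Python floats are 64-bit, so `math.exp` only overflows past roughly e^709, not e^89. To force the failure, the example shifts the [2, 1, 3] row by 1,000.

```python
import math

def softmax_naive(row):
    exps = [math.exp(x) for x in row]           # raises OverflowError for big scores
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(row):
    m = max(row)                                # pass 1: the max of the whole row
    exps = [math.exp(x - m) for x in row]       # pass 2: largest term is e^0 = 1
    total = sum(exps)
    return [e / total for e in exps]

big_row = [1002.0, 1001.0, 1003.0]              # the example row shifted by +1000
try:
    softmax_naive(big_row)
    overflowed = False
except OverflowError:                           # math.exp(1002.0) is out of range
    overflowed = True

print(overflowed)                                        # True
print([round(w, 3) for w in softmax_stable(big_row)])    # [0.245, 0.09, 0.665]
```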
The online softmax (2018) collapses that to a single pass. Keep a running max and a running sum. When a later value beats the max so far, patch the sum you already built by multiplying it by $e^{m_{\text{old}} - m_{\text{new}}}$. Why is that exactly the right patch? Every term you accumulated has the form $e^{x_i - m_{\text{old}}}$, and

$$e^{x_i - m_{\text{old}}} \cdot e^{m_{\text{old}} - m_{\text{new}}} = e^{x_i - m_{\text{new}}},$$

so one multiply re-expresses every old term against the new max, as if you had known it all along. Concretely, walk the row in chunks of 3. The first chunk [2, 1, 3] finishes with max 3. Chunk two contains a 5, so the running sum and the running output are both multiplied by $e^{3 - 5} = e^{-2}$ before the new terms fold in. The 2021 “memory-efficient attention” paper noticed you can run this over chunks of keys and values, so you never need the whole row in memory at once.
The walk starts empty: running max m = −∞, running sum ℓ = 0. It then moves through the row left to right, one chunk of 3 at a time, in a single pass.
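Here is the same single pass in Python. It tracks only the running max and sum, which is the normalizer the 2018 online softmax computes. The text only says that chunk two contains a 5, so the chunk [0, 5, 4] below is made up for illustration. The result is checked against the two-pass reference.

```python
import math

def online_softmax_stats(row, chunk=3):
    """Single pass over `row`, one chunk at a time, keeping running max m and sum l."""
    m, l = -math.inf, 0.0                         # start empty
    history = []
    for start in range(0, len(row), chunk):
        block = row[start:start + chunk]
        m_new = max(m, max(block))
        l = l * math.exp(m - m_new)               # re-express old terms against the new max
        l += sum(math.exp(x - m_new) for x in block)
        m = m_new
        history.append((m, l))                    # state after each chunk
    return m, l, history

row = [2.0, 1.0, 3.0, 0.0, 5.0, 4.0]              # chunk two [0, 5, 4] is made up
m, l, history = online_softmax_stats(row)
weights = [math.exp(x - m) / l for x in row]

# Two-pass reference: find the max first, then sum the exponentials.
m_ref = max(row)
l_ref = sum(math.exp(x - m_ref) for x in row)
```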
This single-pass rescale is the seed FlashAttention later runs entirely inside the GPU's tiny on-chip memory (SRAM) — so the N×N score matrix never has to be written out to slow main memory.
FlashAttention — tiling the whole thing into SRAM
A GPU has two kinds of memory: HBM, the big main memory that is far away and slow, and SRAM, a tiny scratchpad right next to the compute units that is roughly ten times faster. The naïve recipe builds the full N×N score matrix in HBM and drags it back and forth — the “memory wall” from the previous sub-chapter. FlashAttention (2022) keeps the scores out of HBM entirely. It cuts Q, K, and V into small tiles; for each tile of queries it streams over the key/value tiles, computes each little score tile inside SRAM, folds it straight into the running softmax, and throws it away. Only the finished output rows are written back. The giant score matrix never exists in slow memory — the same exact answer, a fraction of the traffic.
How big is a tile? Just small enough that everything in flight fits on the desk. With a typical 64-wide head, a 128-row query tile is 128 × 64 numbers = 32 KB in fp32. A key tile and a value tile of the same shape are another 32 KB each, and the little 128 × 128 score tile is 64 KB. The running output o is itself a 128 × 64 tile, another 32 KB. Only m and ℓ are single numbers per row, about 1 KB for both. Add it up: 32 + 32 + 32 + 64 + 32 = 192 KB. That already fills the ~192 KB of SRAM next to the compute units before m and ℓ are even counted, which is exactly why the tiles are this size and not bigger. Production kernels squeeze further with 16-bit tiles and narrower key blocks, but the budget arithmetic (“what fits in SRAM”) is what picks the tile sizes.
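The budget arithmetic fits in a small function, counting KB as 1,024 bytes. It assumes every buffer uses the same element size. That is a simplification: real kernels typically keep the score tile and accumulators in fp32 even when Q, K and V are 16-bit.

```python
def tile_bytes(br=128, bc=128, d=64, bytes_per_elem=4):
    """SRAM bytes in flight for one step of the tiled loop (simplified accounting)."""
    q_tile = br * d            # query tile
    k_tile = bc * d            # key tile
    v_tile = bc * d            # value tile
    s_tile = br * bc           # score tile
    o_tile = br * d            # running output, one row per query
    stats = 2 * br             # m and l: one number each per query row
    return (q_tile + k_tile + v_tile + s_tile + o_tile + stats) * bytes_per_elem

print(tile_bytes() / 1024)                    # 193.0  (fp32: 192 KB of tiles + 1 KB of m, l)
print(tile_bytes(bytes_per_elem=2) / 1024)    # 96.5   (every buffer in 16-bit)
```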
One honest footnote: the invariant v1 introduced is that the N×N score matrix never touches HBM. The exact loop order shown here — queries on the outside, keys/values streamed on the inside, each output row written once — is the cleaner schedule that v2 settled on. The original v1 kernel actually looped the other way (keys/values outside, queries inside), revisiting the output as it went. The diagram below draws the modern order because that is what kernels run today.
Outer loop, tile 0. Load query tile Q₀ from HBM into fast on-chip SRAM. We will reuse it across every K/V tile.
Schedule shown: queries on the outer loop, keys/values streamed on the inner — the modern (FA2-style) order. v1's contribution is the invariant either order keeps: the N×N score matrix never reaches HBM.
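Below is a pure-Python sketch of that schedule. Python lists stand in for SRAM. The tile sizes are deliberately tiny and ragged (`br=2`, `bc=3`, 7 tokens) so the edge cases get exercised. Only one row's slice of the score tile exists at any moment, and the output matches `naive_attention` up to rounding. The sketch illustrates the loop order only. It is not how a GPU kernel is actually written, and the function names are illustrative.

```python
import math
import random

def scores_for(q, keys):
    d = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]

def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, weight V."""
    out = []
    for q in Q:
        s = scores_for(q, K)
        m = max(s)
        p = [math.exp(x - m) for x in s]
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / sum(p) for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, br=2, bc=3):
    """Queries on the outer loop, K/V tiles streamed on the inner loop."""
    dv = len(V[0])
    out = []
    for i in range(0, len(Q), br):                    # load one query tile
        q_tile = Q[i:i + br]
        m = [-math.inf] * len(q_tile)                 # running max per row
        l = [0.0] * len(q_tile)                       # running sum per row
        o = [[0.0] * dv for _ in q_tile]              # unnormalized output per row
        for j in range(0, len(K), bc):                # stream the K/V tiles
            k_tile, v_tile = K[j:j + bc], V[j:j + bc]
            for r, q in enumerate(q_tile):
                s = scores_for(q, k_tile)             # this row's slice of the score tile only
                m_new = max(m[r], max(s))
                rescale = math.exp(m[r] - m_new)
                p = [math.exp(x - m_new) for x in s]
                l[r] = rescale * l[r] + sum(p)
                o[r] = [rescale * o[r][c] + sum(pj * v[c] for pj, v in zip(p, v_tile))
                        for c in range(dv)]
                m[r] = m_new
        out.extend([x / l[r] for x in o[r]] for r in range(len(q_tile)))  # write each row once
    return out

random.seed(0)
N, d = 7, 4                                           # 7 tokens: ragged final tiles on purpose
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
```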
How much traffic does that actually save? At 4,096 tokens, the naïve recipe pushes the score matrix across the slow HBM bus four times. It is written as raw scores, read back for softmax, written as probabilities, and read again to weight V. For this model's 8 heads in bf16, that is ≈ 1.07 GB of bus traffic per full-attention layer (4 passes × 8 heads × 4,096² scores × 2 bytes), all for scores that are used once and thrown away. The tiled kernel moves exactly zero of those bytes, and the gap grows with the square of the sequence length.
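The traffic formula from that paragraph can be written as a function: four passes over an N × N matrix per head, at 2 bytes per bf16 element. The defaults (8 heads, bf16, four passes) come from the text. The tiled kernel's count for this term is zero at every length.

```python
def score_traffic_bytes(n_tokens, heads=8, bytes_per_elem=2, passes=4):
    """Naive recipe: HBM bytes spent moving the N x N score matrix, one layer."""
    return passes * heads * n_tokens * n_tokens * bytes_per_elem

for n in (1024, 4096, 16384):
    naive_gb = score_traffic_bytes(n) / 1e9
    print(f"{n:>6} tokens: naive {naive_gb:8.3f} GB, tiled 0 GB")
```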
Honest footnote: Q, K, V and the output still cross the bus in both schemes — tens of MB at these sizes (≈42 MB at 4K tokens for this model), similar for both, growing only linearly with N — and flash re-reads the K/V tiles a few times. The point is that the unbounded N² score term disappears entirely.
There is a second half to v1 that is easy to miss: training. The backward pass (computing gradients) normally needs the attention weights again — and storing the N×N matrix for later would defeat the entire point. FlashAttention's answer is recomputation: keep only the output and two small per-row statistics (that same max m and sum ℓ), then re-derive each score tile from Q and K inside SRAM when the backward pass needs it. That is deliberately doing more arithmetic to do less memory traffic — and it wins, because on a modern GPU the math is cheap and the trips to HBM are not. This counterintuitive trade is the heart of the whole “IO-aware” idea.
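The sketch below covers only the recomputation step, not the full gradient math. The forward pass keeps the output plus m and ℓ. The backward helper then rebuilds the softmax weights tile by tile from Q, K and those two numbers. The function names and data are hypothetical.

```python
import math

def forward_row(q, K, V):
    """Forward for one query row: returns the output and the (m, l) statistics only."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    l = sum(math.exp(x - m) for x in s)
    out = [sum(math.exp(x - m) * v[c] for x, v in zip(s, V)) / l for c in range(len(V[0]))]
    return out, m, l                               # the probabilities are not kept

def recompute_probs(q, K_tile, m, l):
    """Backward helper: rebuild one tile of softmax weights from q, K and the saved (m, l)."""
    d = len(q)
    return [math.exp(sum(a * b for a, b in zip(q, k)) / math.sqrt(d) - m) / l for k in K_tile]

q = [0.5, -1.0, 0.25, 2.0]                         # made-up query and keys/values
K = [[1.0, 0.0, 0.5, 0.2], [0.3, -0.7, 1.0, 0.0], [-0.4, 0.9, 0.1, 1.5], [0.0, 0.2, -1.2, 0.8]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0]]

out, m, l = forward_row(q, K, V)
p = recompute_probs(q, K[:2], m, l) + recompute_probs(q, K[2:], m, l)   # two tiles, independently
```

Each tile's weights can be rebuilt on its own because m and ℓ are already the final, row-wide statistics.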
A 30-second tour of the GPU
The next two versions are about filling the machine, so you need a picture of the machine. A GPU is not one processor. It is more like an office park. An NVIDIA A100 has 108 SMs (streaming multiprocessors). These are independent little processors, each with its own SRAM scratchpad (~192 KB) and its own Tensor Cores, the dedicated matrix-multiply units. Work arrives as thread blocks, and each block is assigned to one SM. Inside a block, the threads run in bundles of 32 called warps, which execute in lockstep. Two consequences matter here: an SM with no block assigned does nothing, and the kernel decides how many blocks exist.
Each square is an SM — a streaming multiprocessor, one of 108 independent little processors on an A100. Work is handed to SMs as thread blocks; an SM with no block assigned does nothing. FlashAttention-1 launches batch × heads = 1 × 8 = 8 blocks, so only 8 of the 108 SMs have anything to do.
FlashAttention-2 — keep every core busy
FlashAttention-2 (2023) changes none of the math — it changes how the work is spread across the chip. Run the numbers with the office-park picture in mind. The original kernel launched roughly one thread block per (batch × attention head): one prompt times 8 heads is 8 blocks — on a 108-SM A100, about 7% of the chip working and 100 SMs sitting dark. v2's headline fix is to also split the query sequence into blocks: a 4,096-token prompt in 128-row tiles gives 32 query tiles, so 8 heads × 32 tiles = 256 blocks — more than enough to light up every SM.
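Here is the block-count arithmetic as a sketch. It simplifies by assuming one wave of work and counting an SM as busy if it has any block at all. Real GPUs can co-schedule several blocks per SM.

```python
import math

def thread_blocks(batch, heads, seq_len=None, q_tile=128):
    """v1-style: one block per (batch, head). v2-style: also one per query tile."""
    blocks = batch * heads
    if seq_len is not None:
        blocks *= math.ceil(seq_len / q_tile)
    return blocks

def busy_fraction(blocks, sms=108):
    """Share of SMs with any work in a single wave (simplified: ignores several blocks per SM)."""
    return min(blocks, sms) / sms

v1 = thread_blocks(batch=1, heads=8)                     # 8 blocks
v2 = thread_blocks(batch=1, heads=8, seq_len=4096)       # 8 x 32 = 256 blocks
print(v1, f"{busy_fraction(v1):.1%}")                    # 8 7.4%
print(v2, f"{busy_fraction(v2):.1%}")                    # 256 100.0%
```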
It also re-partitions the work inside a block. Previously, every warp computed a slice of each row and then combined results through shared memory. Now each warp owns a band of query rows outright, so there is no cross-warp bookkeeping. v2 also shaves the non-matmul arithmetic, because regular math runs far slower than the Tensor Cores. The running output is kept unnormalized and divided by ℓ exactly once at the very end, instead of being normalized at every step. The max-rescale by $e^{m_{\text{old}} - m_{\text{new}}}$ still happens each tile; only the ℓ-division is deferred. For causal masks (where future tokens cannot be attended anyway), score tiles that would be fully masked are simply never computed. That alone makes causal attention roughly 1.7–1.8× faster than the unmasked case. Together, these changes make v2 about twice as fast as v1, with the attention kernel hitting up to 73% of the A100's theoretical FP16 peak. (That occupancy idea is exactly what drives the decode story below.)
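The pure-Python sketch below contrasts per-step normalization with v2's deferred division for one query row, and adds the causal tile count. The `divisions` counter and both function names are illustrative. `causal_tiles` assumes square tiles and counts the partially masked diagonal tiles as computed.

```python
import math

def row_attention(scores, values, tile, deferred=True):
    """One query row over key tiles; returns (output, per-element divisions performed)."""
    dv = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * dv
    divisions = 0
    for j in range(0, len(scores), tile):
        s, v = scores[j:j + tile], values[j:j + tile]
        m_new = max(m, max(s))
        rescale = math.exp(m - m_new)                 # max-rescale still runs every tile
        p = [math.exp(x - m_new) for x in s]
        l_new = rescale * l + sum(p)
        pv = [sum(pk * vk[c] for pk, vk in zip(p, v)) for c in range(dv)]
        if deferred:                                  # v2-style: o stays unnormalized
            o = [rescale * oc + pc for oc, pc in zip(o, pv)]
        else:                                         # normalize o at every step
            o = [(rescale * l * oc + pc) / l_new for oc, pc in zip(o, pv)]
            divisions += dv
        m, l = m_new, l_new
    if deferred:
        o = [oc / l for oc in o]                      # divide by l exactly once
        divisions += dv
    return o, divisions

def causal_tiles(n_tiles):
    """Square score tiles computed under a causal mask when fully masked ones are skipped."""
    return n_tiles * (n_tiles + 1) // 2

scores = [0.3, 1.2, -0.5, 2.0, 0.8, 1.9, -1.0, 0.4]   # made-up row of 8 scores
values = [[float(i), 1.0 - i] for i in range(8)]
out_step, div_step = row_attention(scores, values, tile=2, deferred=False)
out_once, div_once = row_attention(scores, values, tile=2, deferred=True)
print(div_step, div_once)                             # 8 2
print(causal_tiles(32), 32 * 32)                      # 528 1024
```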
FlashAttention-2 also splits the query rows into tiles. With 4 query tiles per head, it launches 8 × 4 = 32 thread blocks, four times as many as v1. At a real 4,096-token prompt with 128-row tiles, that grows to 256 blocks, enough to fill all 108 SMs. Inside each block it uses split-Q: every warp owns its own band of rows end-to-end, so the warps never combine partials through shared memory.
FlashAttention-3 — overlap and low precision
FlashAttention-3 (2024) is tuned for NVIDIA's Hopper GPUs (the H100), and it starts from a wild imbalance in the hardware. The H100's Tensor Cores deliver about 989 TFLOPS of fp16 matrix multiply — but the little units that compute exponentials (which every softmax needs) manage only about 3.9 TFLOPS. That is a 256× speed gap between the two kinds of math attention alternates between. Run matmul → softmax → matmul strictly in order, and the most expensive silicon on the chip spends much of its time waiting for the cheap part to finish.
FlashAttention-3's answer is to stop taking turns. Hopper lets warps specialize: producer warps do nothing but tell the TMA (a dedicated copy engine) to fetch the next tiles from HBM, while consumer warpgroups do nothing but math. Then two consumer warpgroups play pingpong: while one runs its softmax on the exp units, the other's matmuls keep the Tensor Cores fed — and they swap. Each tile is really two matmuls — the scores QKᵀ, then the value mix P·V — with the softmax wedged between, and it is those that get overlapped:
FlashAttention-3 overlaps the two kinds of work: while the Tensor Cores grind tile j+1's matmul, the exp units already run tile j's softmax. The Tensor-Core lane stays nearly solid — only 16% idle — so the schedule finishes sooner. (Each “matmul” bar is simplified: a real tile runs two GEMMs, QKᵀ then P·V, with the softmax between.)
The speedup comes from two levers: keeping the Tensor Cores busy by overlapping matmul with softmax (warp-specialized “pingpong” scheduling), and cheaper FP8 math. Block sizes here are illustrative, not measured.
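Below is a toy two-stage pipeline model of that overlap, using made-up durations. It deliberately ignores that a real tile's P·V matmul needs the softmax output. It also models neither the two pingpong warpgroups nor the producer warps. It only shows why keeping both units busy shortens the schedule.

```python
def sequential(n_tiles, mm, sm):
    """Matmul and softmax take turns on every tile."""
    return n_tiles * (mm + sm)

def overlapped(n_tiles, mm, sm):
    """Two-stage pipeline: the Tensor Cores start tile j+1 while the exp units do tile j."""
    mm_done = sm_done = 0.0
    for _ in range(n_tiles):
        mm_done += mm                              # in this toy, matmuls never wait for softmax
        sm_done = max(mm_done, sm_done) + sm       # softmax needs its matmul and free exp units
    return sm_done

# Made-up durations: a matmul step 3 time units, a softmax step 1 unit, 8 tiles.
print(sequential(8, 3, 1), overlapped(8, 3, 1))    # 32 25.0
```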
The second lever is FP8, and it deserves a slower look, because “8-bit float” sounds innocent until you see what it costs. An fp8 number (the E4M3 format) has 256 possible bit patterns and a maximum value of 448, and matmuls in fp8 run about twice as fast as in fp16. The catch is that with so few representable values, everything hinges on the scale factor that maps your data onto them. Real attention inputs contain occasional outliers: one huge value in a sea of small ones. Pick one scale for the whole tensor, and that single outlier stretches the grid so far that the small values collapse onto a handful of levels and their information is lost. FlashAttention-3 rescues fp8 in two ways. Block quantization gives each small tile its own scale, so one outlier only coarsens its own block. Incoherent processing multiplies Q and K by a fixed random rotation that smears the outlier's energy across every dimension before rounding. Before any rounding, that rotation is mathematically free: rotate both sides and the dot products are unchanged. After fp8 rounding, nothing is exactly preserved. The rotation's job is to make the rounding cheaper, and the error it saves is what the round-trip measurements below show:
The outlier wins. One scale must fit 52, so scale = 448/52 ≈ 8.6 — every row gets the same coarse grid (the ticks). The 15 small values snap onto a handful of levels, and the two near-zero values round almost entirely away (up to 100% error). Avg round-trip error of the 15 small values: 11.5%.
Before any rounding, the rotation is exactly free. Rotate every q and k vector with the same orthogonal R, and it cancels inside the matmul: (QRᵀ)(KRᵀ)ᵀ = Q(RᵀR)Kᵀ = QKᵀ, since RᵀR = I. After FP8 rounding, nothing is exact anymore. The rotation's job is to make that rounding cheaper, which is exactly the error drop shown above. All percentages are real round-trips through an E4M3 quantizer (1 sign + 4 exponent + 3 mantissa bits, max 448) on these 16 fixed values. When the rotation is applied, the error is measured on the rotated values, because those are what actually get stored and multiplied.
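The sketch below is a hand-rolled E4M3 round-trip, not a library API. It compares one scale for the whole tensor against one scale per 4-value block. The eight input values are made up and are not the 16 values used above. Ties round toward zero here, whereas hardware rounds to even, and the rotation step is left out.

```python
import bisect

def e4m3_grid():
    """Non-negative finite E4M3 values: 4 exponent bits (bias 7), 3 mantissa bits."""
    vals = set()
    for e in range(16):
        for man in range(8):
            if e == 15 and man == 7:
                continue                                  # this bit pattern is NaN
            if e == 0:
                vals.add(man / 8 * 2.0 ** -6)             # zero and the subnormals
            else:
                vals.add((1 + man / 8) * 2.0 ** (e - 7))  # normal numbers
    return sorted(vals)

GRID = e4m3_grid()                                        # 127 values, GRID[-1] == 448.0

def to_e4m3(x):
    """Round |x| to the nearest grid point (ties toward zero, a simplification), keep the sign."""
    a = min(abs(x), GRID[-1])                             # saturate at 448
    i = bisect.bisect_left(GRID, a)
    if i == 0:
        return 0.0
    lo, hi = GRID[i - 1], GRID[i]
    q = lo if a - lo <= hi - a else hi
    return q if x >= 0 else -q

def round_trip(xs, block):
    """Quantize with one scale per `block` values (the block max maps to 448), then dequantize."""
    out = []
    for start in range(0, len(xs), block):
        chunk = xs[start:start + block]
        scale = GRID[-1] / max(abs(v) for v in chunk)
        out.extend(to_e4m3(v * scale) / scale for v in chunk)
    return out

xs = [52.0, 0.9, -1.3, 0.6, 0.0003, -0.0005, 0.0007, 0.0011]  # made-up: one outlier, four tiny values
tiny = range(4, 8)

def avg_rel_error(approx):
    return sum(abs(approx[i] - xs[i]) / abs(xs[i]) for i in tiny) / len(tiny)

per_tensor = avg_rel_error(round_trip(xs, block=len(xs)))    # one scale, set by the 52
per_block = avg_rel_error(round_trip(xs, block=4))           # the tiny values get their own scale
print(f"per-tensor {per_tensor:.1%}, per-block {per_block:.1%}")
```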
And the story keeps going: FlashAttention-4 (2026) carries the same invariant onto NVIDIA's next generation, the Blackwell B200 — where matmul throughput roughly doubled again while the rest of the chip didn't — reaching up to 1,613 TFLOPS (71% utilization). Our deep-dive stops at v3 because v4's tricks are Blackwell-specific, but notice the pattern across every version: the algorithm's invariant never changes — the N×N matrix never touches slow memory — and each release just re-tunes that invariant to whatever the newest silicon's bottleneck happens to be.
Does your in-browser Qwen use it?
Partly — and the honest answer is the interesting one. The WebGPU backend running this model has a fused, tiled flash-attention kernel with the online softmax above, and it runs during prefill of the 6 full-attention layers. But the token-by-token decode you watch in the demo deliberately does not use it: at 0.8B with 8 query heads the fused kernel would launch only 8 GPU workgroups and leave most of the chip idle, so an occupancy gate routes decode back to the plain three-step path (matmul → softmax → matmul), which fans out far more work and measured ~90% faster here. And only 6 of 24 layers run softmax attention at all — the other 18 are GatedDeltaNet. So: flash during prefill, plain during decode, on purpose.
Here is how the dispatch routes a decode call to the kernel that actually runs:
Decode generates one token, so Tq = 1. The fused kernel would launch only B·H = 8 workgroups and leave most of the chip idle — so the gate is true and routes decode to the decomposed path, which fans out far more work and measured ~90% faster here.
A debug switch (?sdpa_fallback=1) forces every call, prefill and decode, through the decomposed path, ahead of this gate.
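Below is a hypothetical sketch of that routing rule. The function name, the `min_workgroups` threshold of 32, and the exact condition are assumptions for illustration, not the WebGPU backend's actual code. `force_fallback` stands in for the `?sdpa_fallback=1` switch.

```python
def route_sdpa(batch, heads, tq, min_workgroups=32, force_fallback=False):
    """Hypothetical routing rule: 'fused' (flash tile kernel) or 'decomposed' path."""
    if force_fallback:                     # a debug override wins before the gate is checked
        return "decomposed"
    if tq == 1 and batch * heads < min_workgroups:
        return "decomposed"                # decode: too few workgroups to fill the GPU
    return "fused"

print(route_sdpa(batch=1, heads=8, tq=4096))                        # fused (prefill)
print(route_sdpa(batch=1, heads=8, tq=1))                           # decomposed (decode)
print(route_sdpa(batch=1, heads=8, tq=4096, force_fallback=True))   # decomposed
```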