The algorithm splits the Q, K, and V matrices into blocks of size B_r × d (for Q) and B_c × d (for K and V) that fit in on-chip SRAM (typically 100-200 KB per SM). Each Q block is loaded once; the kernel then iterates over the K and V blocks, computing partial attention results and accumulating them with a numerically stable online softmax. It maintains a running row maximum m and a running sum of exponentials l, and for each new (K_j, V_j) pair updates O ← exp(m_old - m_new) · O_prev + exp(S_new - m_new) · V_j, dividing by l once all blocks have been processed. The n×n attention matrix is never materialized in HBM. The backward pass recomputes attention blocks instead of reading a saved attention matrix, similar in spirit to gradient checkpointing.
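The blockwise forward pass described above can be sketched in plain Python. This is a minimal illustration rather than the actual CUDA kernel: the function names, the tiny default block sizes, and the use of nested lists are assumptions made for readability, and SRAM residency is only simulated by slicing.

```python
import math

def naive_attention(Q, K, V):
    """Reference implementation: builds every full score row, then softmax."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / l
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_r=2, block_c=3):
    """Blockwise attention with online softmax (illustrative block sizes)."""
    d, d_v = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for r0 in range(0, len(Q), block_r):          # each Q block is read once
        q_blk = Q[r0:r0 + block_r]
        m = [-math.inf] * len(q_blk)              # running row maximum
        l = [0.0] * len(q_blk)                    # running sum of exponentials
        acc = [[0.0] * d_v for _ in q_blk]        # unnormalized output rows
        for c0 in range(0, len(K), block_c):      # stream over K/V blocks
            k_blk = K[c0:c0 + block_c]
            v_blk = V[c0:c0 + block_c]
            for i, q in enumerate(q_blk):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)    # rescales state to the new max
                p = [math.exp(x - m_new) for x in s]
                l[i] = alpha * l[i] + sum(p)
                acc[i] = [alpha * acc[i][c]
                          + sum(p[j] * v_blk[j][c] for j in range(len(v_blk)))
                          for c in range(d_v)]
                m[i] = m_new
        # normalize once per Q block, after all K/V blocks have been seen
        out.extend([x / l[i] for x in acc[i]] for i in range(len(q_blk)))
    return out
```

Only one block of scores exists at any time inside the inner loop, which is the property that lets a real kernel keep the working set in SRAM. Comparing `tiled_attention` against `naive_attention` on the same inputs is a quick way to confirm that the result is the same up to floating-point rounding.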
The standard attention implementation materializes the n×n matrix in HBM and is memory-bound — the dominant cost is not softmax FLOPs but data transfer. This limits maximum context length and throughput.
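To make the memory pressure concrete, the following back-of-the-envelope sketch counts the bytes needed to hold the n×n score matrix versus the Q, K, V inputs. The helper names and the default of 2 bytes per element (fp16/bf16) are assumptions chosen for illustration.

```python
def attention_matrix_bytes(n, heads=1, batch=1, bytes_per_element=2):
    """Bytes to store the n x n score matrix for every head and batch item."""
    return batch * heads * n * n * bytes_per_element

def qkv_bytes(n, d, heads=1, batch=1, bytes_per_element=2):
    """Bytes to store Q, K and V together (three n x d matrices per head)."""
    return 3 * batch * heads * n * d * bytes_per_element

MIB = 2 ** 20
n, d = 8192, 128
matrix_mib = attention_matrix_bytes(n) / MIB   # grows quadratically in n
inputs_mib = qkv_bytes(n, d) / MIB             # grows linearly in n
print(f"score matrix: {matrix_mib:.0f} MiB per head, Q/K/V: {inputs_mib:.0f} MiB per head")
```

Under these assumptions a single head at n = 8192 needs 128 MiB for the score matrix against 6 MiB for its inputs, and the matrix must be written to and read back from HBM.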
Splitting the Q, K, and V matrices into blocks that fit in GPU SRAM (typically B_r and d each in the range 64-128).
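Whether a given block size fits can be estimated with simple accounting. The sketch below is a deliberately simplified model, not how any particular kernel allocates memory: it assumes one Q block, one output accumulator, one K block, one V block, and one block of scores all live in SRAM at the same precision. Real kernels keep some of these in registers and accumulate in fp32, so treat the numbers as rough.

```python
def tile_sram_bytes(block_r, block_c, d, bytes_per_element=2):
    """Approximate SRAM footprint of one tile step (simplified model)."""
    q_and_o = 2 * block_r * d      # Q block plus output accumulator
    k_and_v = 2 * block_c * d      # K block plus V block
    scores = block_r * block_c     # one block of attention scores
    return (q_and_o + k_and_v + scores) * bytes_per_element

def fits_in_sram(block_r, block_c, d, sram_bytes, bytes_per_element=2):
    """True if the simplified footprint is within the given SRAM budget."""
    return tile_sram_bytes(block_r, block_c, d, bytes_per_element) <= sram_bytes

KIB = 1024
small = tile_sram_bytes(64, 64, 64) // KIB       # 64 x 64 tiles, d = 64
large = tile_sram_bytes(128, 128, 128) // KIB    # 128 x 128 tiles, d = 128
print(f"64-wide tiles: {small} KiB, 128-wide tiles: {large} KiB")
```

In this model the 64-wide configuration needs 40 KiB and the 128-wide one needs 160 KiB, which is why block sizes have to be tuned per head dimension and per GPU.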
Numerically stable recurrence maintaining running max and exponential sum — allows computing softmax block-wise without materializing the full matrix.
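The recurrence can be shown on a single row of scores. This is a minimal sketch with an assumed function name and chunk size; unlike the fused kernel, it makes a second pass over the scores to produce the probabilities, because here there is no output accumulator to fold the normalization into.

```python
import math

def online_softmax(scores, block=4):
    """Softmax computed from a running max and running sum, chunk by chunk."""
    m, l = -math.inf, 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_new = max(m, max(chunk))
        # rescale the old sum to the new maximum, then add the new terms
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return [math.exp(x - m) / l for x in scores]

# Large scores would overflow a naive exp(x); subtracting the running max avoids that.
probs = online_softmax([1000.0, 1001.0, 999.0, 1002.5, 3.0, -7.0, 1002.0], block=3)
```

The key step is the factor `exp(m - m_new)`: whenever a later chunk raises the maximum, the sum accumulated so far is rescaled so that all terms share the same reference point.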
The backward pass does not read a stored attention matrix; it recomputes each block from Q, K, V together with the saved output O and the row-wise logsumexp L, trading extra FLOPs for memory.
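The role of the saved logsumexp can be illustrated for one query row: given L, the softmax probabilities are recovered as exp(S - L) without recomputing the row maximum or sum. The helper names below are hypothetical and the sketch covers only the recomputation of probabilities, not the gradient formulas themselves.

```python
import math

def scores_for_row(q, K):
    """Scaled dot-product scores of one query against all keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in K]

def forward_logsumexp(q, K):
    """What the forward pass would save per row: L = m + log(sum(exp(S - m)))."""
    s = scores_for_row(q, K)
    m = max(s)
    return m + math.log(sum(math.exp(x - m) for x in s))

def recompute_probs(q, K, lse):
    """Backward-style recomputation: P = exp(S - L), no stored matrix needed."""
    return [math.exp(x - lse) for x in scores_for_row(q, K)]

q = [0.5, -1.0, 2.0, 0.25]
K = [[1.0, 0.0, 0.5, -0.5], [0.2, 0.3, -1.0, 1.0], [2.0, -2.0, 0.1, 0.0]]
L = forward_logsumexp(q, K)          # one scalar saved per query row
P = recompute_probs(q, K, L)         # softmax row rebuilt on demand
```

Storing one scalar per row costs O(n) memory, compared with O(n²) for the full probability matrix.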
FlashAttention-3 requires Hopper (H100/H200) and does not run on Ampere (A100), where v2 is the standard choice. Choosing a version that does not match the hardware can forfeit a 2-4× speedup.
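One way to guard against a mismatch is to derive the version from the GPU's CUDA compute capability. The helper below is hypothetical and its mapping is a rough assumption (Hopper reports major version 9, Ampere and Ada report 8); the authoritative support matrix is whatever the installed FlashAttention release documents. In PyTorch, the `(major, minor)` pair can be obtained from `torch.cuda.get_device_capability()`.

```python
def suggest_flash_attention_version(major, minor=0):
    """Hypothetical mapping from compute capability to a FlashAttention generation.

    Returns 3 for Hopper (sm_90), 2 for Ampere/Ada (sm_8x), and None otherwise,
    meaning the caller should fall back to another attention implementation.
    """
    if major == 9:
        return 3
    if major == 8:
        return 2
    return None

# Example capabilities: H100 is (9, 0), A100 is (8, 0), V100 is (7, 0).
for name, cap in [("H100", (9, 0)), ("A100", (8, 0)), ("V100", (7, 0))]:
    print(name, suggest_flash_attention_version(*cap))
```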
FlashAttention assumes standard scaled dot-product attention with an optional causal mask. Non-standard masks or attention biases (e.g., ALiBi, block-sparse patterns) require special variants or explicit kernel support, or prevent its use.
First publication — tiling + online softmax, 2-4× speedup, O(n) memory for exact attention.
Better work partitioning across GPU warps and parallelism along the sequence dimension: about 2× faster than v1, reaching ~50-70% of peak FLOPs/s on A100.
PyTorch adds FlashAttention as a built-in backend for F.scaled_dot_product_attention, selected automatically when the inputs qualify, which drives mass adoption across the ecosystem.
Hopper (H100) support: asynchronous TMA, warp specialization, FP8. Reaches up to 75% of peak FP16 FLOPs/s on H100 and is 1.5-2× faster than v2.
FlashAttention shifts attention from memory-bound toward compute-bound by maximizing data reuse in SRAM. The HBM bandwidth limit remains for loading Q/K/V blocks.
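The shift can be quantified with the asymptotic HBM access counts usually quoted for the two approaches: on the order of N·d + N² elements for standard attention versus N²·d²/M for the tiled algorithm, where M is the SRAM size in elements. The sketch below drops all constant factors, and the chosen N, d, and M are illustrative assumptions, so the ratio indicates scale only.

```python
def standard_hbm_accesses(n, d):
    """Order-of-magnitude HBM traffic for standard attention: N*d + N^2."""
    return n * d + n * n

def flash_hbm_accesses(n, d, sram_elements):
    """Order-of-magnitude HBM traffic for tiled attention: N^2 * d^2 / M."""
    return n * n * d * d / sram_elements

n, d = 4096, 64
sram_elements = 100 * 1024 // 2       # assumed 100 KiB of SRAM holding fp16 values
ratio = standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, sram_elements)
print(f"estimated reduction in HBM traffic: about {ratio:.1f}x")
```

The reduction holds as long as d² is small relative to M, which is the regime the typical head dimensions of 64-128 fall into. The remaining traffic is the repeated loading of Q, K, and V blocks.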
Exact attention — all token-token pairs are computed (modulo causal mask). A purely algorithmic optimization, no change to model mathematics.