This article explains the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Dao et al. (2022), which introduced an algorithm that makes transformer attention 2-4x faster without any approximation.
The Problem: Attention is Memory-Bound
Standard self-attention in transformers has a well-known issue: it scales quadratically with sequence length. For a sequence of length $N$, we need to compute and store an $N \times N$ attention matrix.
But here’s the insight that changes everything: the bottleneck isn’t computation—it’s memory access.
Modern GPUs are incredibly fast at arithmetic: the NVIDIA A100 peaks at 312 TFLOPS of FP16/BF16 Tensor Core compute. Moving data between memory levels is far slower, and that is what limits attention.
GPU Memory Hierarchy
To understand FlashAttention, we need to understand how GPU memory works.
GPUs have two main memory types:
HBM (High Bandwidth Memory)
- Size: 40-80 GB on A100
- Bandwidth: 1.5-2.0 TB/s
- Role: Main GPU memory where model weights and activations live
SRAM (On-chip Cache)
- Size: ~20 MB total (192 KB per streaming multiprocessor × 108 SMs)
- Bandwidth: ~19 TB/s
- Role: Fast scratch space for active computation
The key numbers: SRAM is ~10x faster than HBM but roughly 2000-4000x smaller (about 20 MB versus 40-80 GB).
Standard Attention: The Memory Problem
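Let’s trace what standard attention does:
- Load $Q$ and $K$ from HBM, compute $S = QK^T$, write $S$ to HBM
- Read $S$ from HBM, compute $P = \text{softmax}(S)$, write $P$ to HBM
- Load $P$ and $V$ from HBM, compute $O = PV$, write $O$ to HBM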
The problem: we read/write the $N \times N$ attention matrix multiple times. For long sequences, this dominates runtime.
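As a rough illustration, here is a minimal NumPy sketch of standard attention. It is a CPU stand-in rather than a GPU kernel, and it omits the usual $1/\sqrt{d}$ scaling to match the algorithm described later in this article. The function name and the fp16 byte counts are assumptions made for the example.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices."""
    S = Q @ K.T                                   # (N, N) scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # (N, N) unnormalized probabilities
    P = P / P.sum(axis=1, keepdims=True)          # (N, N) row-wise softmax
    return P @ V                                  # (N, d) output

rng = np.random.default_rng(0)
N, d = 1024, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)

# Assumption: 2 bytes per element (fp16), as is common on GPUs
bytes_inputs = 3 * N * d * 2   # Q, K and V together
bytes_scores = N * N * 2       # a single N x N intermediate
```

Even at $N = 1024$, one $N \times N$ intermediate is several times larger than all three inputs combined, and the gap grows linearly with $N$.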
FlashAttention: The Solution
FlashAttention’s key idea: never materialize the full attention matrix. Instead, compute attention in tiles that fit in SRAM.
The Tiling Strategy
Instead of computing the full attention matrix at once:
- Divide Q, K, V into blocks that fit in SRAM
- Compute attention for each block pair
- Accumulate results with proper normalization
But there’s a catch: softmax isn’t block-decomposable. You need the full row to compute the denominator.
The Online Softmax Trick
This is the algorithmic insight that makes FlashAttention possible.
Standard softmax requires two passes over the data:
- Find the maximum (for numerical stability)
- Compute exp and sum
Online softmax does it in one pass by maintaining running statistics and rescaling on the fly.
The key insight: when a new block contains a larger maximum, we rescale the partial results accumulated so far by $e^{m_{\text{old}} - m_{\text{new}}}$. Merging these running statistics is associative, so softmax can be computed block by block.
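The rescaling idea can be sketched for a single row of scores. This is a minimal NumPy illustration, not the kernel's code; the names `online_softmax_stats` and `merge` are hypothetical helpers introduced for this example.

```python
import numpy as np

def online_softmax_stats(x, block_size):
    """One pass over x in blocks; returns the max m and l = sum(exp(x - m))."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, block.max())
        # Rescale the old sum to the new maximum, then add the new block's terms
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m, l

def merge(a, b):
    """Combine two (max, sum) pairs computed on disjoint blocks."""
    (m1, l1), (m2, l2) = a, b
    m = max(m1, m2)
    return m, l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 0.5])
m, l = online_softmax_stats(x, block_size=2)
softmax = np.exp(x - m) / l   # normalize once the statistics are known
```

Because `merge` gives the same result however the blocks are grouped, the statistics for a row can be built up from any partition of that row.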
The Algorithm
Here’s the FlashAttention forward pass:
Input: Matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, block sizes $B_r, B_c$
Output: $O \in \mathbb{R}^{N \times d}$
- Divide $Q$ into $T_r = \lceil N/B_r \rceil$ blocks, $K, V$ into $T_c = \lceil N/B_c \rceil$ blocks
- Initialize $O = 0$, $\ell = 0$, $m = -\infty$ in HBM
- For $j = 1, \ldots, T_c$ (outer loop over K, V):
  - Load $K_j, V_j$ from HBM to SRAM
  - For $i = 1, \ldots, T_r$ (inner loop over Q):
    - Load $Q_i, O_i, \ell_i, m_i$ from HBM to SRAM
    - On chip, compute $S_{ij} = Q_i K_j^T \in \mathbb{R}^{B_r \times B_c}$
    - On chip, compute:
      - $\tilde{m}_{ij} = \text{rowmax}(S_{ij}) \in \mathbb{R}^{B_r}$
      - $\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c}$
      - $\tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_r}$
    - Update $m_i^{\text{new}}, \ell_i^{\text{new}}, O_i$ using online softmax
    - Write $O_i, \ell_i, m_i$ to HBM
- Return $O$
The critical property: the $N \times N$ attention matrix $S$ is never fully materialized in HBM. Each block $S_{ij}$ exists only briefly in SRAM.
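The loop structure above can be sketched in NumPy. This is an illustrative CPU stand-in under stated assumptions: array slicing plays the role of HBM-to-SRAM loads, there is no $1/\sqrt{d}$ scaling, masking, or dropout, and the function name and default block sizes are hypothetical. The update step uses the online softmax rescaling described earlier.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Tiled attention forward pass; only (Br, Bc) score tiles are ever formed."""
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)               # running softmax denominators
    m = np.full(N, -np.inf)       # running row maxima
    for j in range(0, N, Bc):     # outer loop over K, V blocks
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br): # inner loop over Q blocks
            rows = slice(i, i + Br)
            S = Q[rows] @ Kj.T                       # score tile
            m_tilde = S.max(axis=1)
            P_tilde = np.exp(S - m_tilde[:, None])
            l_tilde = P_tilde.sum(axis=1)
            m_new = np.maximum(m[rows], m_tilde)
            alpha = np.exp(m[rows] - m_new)          # rescales previous partial results
            beta = np.exp(m_tilde - m_new)           # rescales the current tile
            l_new = alpha * l[rows] + beta * l_tilde
            O[rows] = ((alpha * l[rows])[:, None] * O[rows]
                       + beta[:, None] * (P_tilde @ Vj)) / l_new[:, None]
            l[rows], m[rows] = l_new, m_new
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
O = flash_attention_forward(Q, K, V, Br=64, Bc=32)
```

The output should match ordinary softmax attention up to floating-point error, which is what "exact" means in the paper's title.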
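Standard attention needs $\Theta(Nd + N^2)$ HBM accesses, while FlashAttention needs $O(N^2 d^2 / M)$, where $M$ is the SRAM size. For typical values ($d = 64$, $M = 100$ KB), this is 5-20x fewer HBM accesses.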
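A back-of-the-envelope check of that ratio, as a sketch only. It plugs numbers into the asymptotic expressions $Nd + N^2$ and $N^2 d^2 / M$ with all constant factors dropped, and it assumes $M$ is counted in 2-byte (fp16) elements. Neither assumption comes from the paper's analysis.

```python
def hbm_accesses_standard(N, d):
    """Asymptotic access count for standard attention, constants dropped."""
    return N * d + N * N

def hbm_accesses_flash(N, d, M):
    """Asymptotic access count for FlashAttention, constants dropped."""
    return N * N * d * d / M

N, d = 4096, 64
M = 100 * 1024 // 2   # assumption: 100 KB of SRAM holding 2-byte elements
ratio = hbm_accesses_standard(N, d) / hbm_accesses_flash(N, d, M)
print(f"standard / flash HBM accesses: {ratio:.1f}x")
```

For large $N$ the ratio approaches $M / d^2$, so under these assumptions it depends on the head dimension and SRAM size rather than on the sequence length.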
Why It Works: Arithmetic Intensity
Arithmetic intensity = FLOPs / bytes moved
Standard attention has low arithmetic intensity: we move lots of data for relatively little compute. FlashAttention increases arithmetic intensity by reusing data in SRAM.
| Method | HBM Reads/Writes | Arithmetic Intensity |
|---|---|---|
| Standard Attention | $\Theta(Nd + N^2)$ | Low |
| FlashAttention | $O(N^2d^2/M)$ | High |
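To put rough numbers on "low" arithmetic intensity, the sketch below compares two steps of standard attention against the A100's compute-to-bandwidth ratio, using the figures quoted earlier in this article. The FLOP count per softmax element and the fp16 storage size are assumptions chosen for illustration.

```python
def arithmetic_intensity(flops, bytes_moved):
    return flops / bytes_moved

peak_flops = 312e12      # A100 peak quoted above
hbm_bandwidth = 1.5e12   # lower end of the 1.5-2.0 TB/s range quoted above
ridge = peak_flops / hbm_bandwidth   # FLOPs per byte needed to be compute-bound

N, d = 4096, 64
elem = 2                 # assumption: fp16, 2 bytes per element

# S = Q K^T: 2*N*N*d FLOPs; reads Q and K, writes S
matmul_ai = arithmetic_intensity(2 * N * N * d, elem * (2 * N * d + N * N))

# softmax over S: assume ~5 FLOPs per element; reads S and writes P
softmax_ai = arithmetic_intensity(5 * N * N, elem * (2 * N * N))

print(f"ridge: {ridge:.0f}, QK^T: {matmul_ai:.1f}, softmax: {softmax_ai:.2f} FLOPs/byte")
```

Under these assumptions both steps sit well below the ridge point, so the GPU waits on HBM rather than on arithmetic. Keeping the score tile in SRAM removes the $N \times N$ reads and writes from the denominator.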
Results
FlashAttention achieves significant speedups across different models and sequence lengths:
| Model / Benchmark | Sequence Length | Speedup |
|---|---|---|
| BERT-large | 512 | 15% faster end-to-end |
| GPT-2 | 1K | 3× faster |
| Long Range Arena | 1K-4K | 2.4× faster |
More importantly, FlashAttention enables much longer sequences. The paper shows the first Transformer to achieve better-than-random performance on Path-X (16K tokens) and Path-256 (64K tokens).
Extensions: FlashAttention-2 and 3
The original FlashAttention has been followed by improved versions:
FlashAttention-2 (2023) improves parallelism:
- Better work partitioning across GPU thread blocks
- Reduced non-matmul FLOPs
- Around 2× faster than the original FlashAttention
FlashAttention-3 (2024) leverages new hardware features:
- Asynchronous execution (overlap compute and memory access)
- FP8 low-precision support
- Further speedups on H100 GPUs
Key Takeaways
References
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.
- Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.
- Milakov, M., & Gimelshein, N. (2018). Online Normalizer Calculation for Softmax. arXiv:1805.02867.
Michael Wan Interactive Insights