FlashAttention

How IO-awareness and tiling make exact attention fast and memory-efficient.

This article explains the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Dao et al. (2022), which introduced an algorithm that makes transformer attention 2-4x faster without any approximation.

The Problem: Attention is Memory-Bound

Standard self-attention in transformers has a well-known issue: it scales quadratically with sequence length. For a sequence of length $N$, we need to compute and store an $N \times N$ attention matrix.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

But here’s the insight that changes everything: the bottleneck isn’t computation—it’s memory access.

Modern GPUs are incredibly fast at arithmetic. The NVIDIA A100 can sustain up to 312 TFLOPS of FP16 tensor-core compute. But moving data between memory levels? That's the real bottleneck.

Key insight: Standard attention implementations are IO-bound, not compute-bound. We spend more time moving data than doing math.

GPU Memory Hierarchy

To understand FlashAttention, we need to understand how GPU memory works.

GPU memory hierarchy on A100: HBM (40-80 GB of main memory holding Q, K, V, and O, at ~1.5 TB/s) feeds SRAM (~20 MB of on-chip cache at ~19 TB/s, roughly 10× faster), which feeds the tensor cores (312 TFLOPS). Data must flow through SRAM to reach compute. The HBM→SRAM transfer is the bottleneck, not the computation itself.

GPUs have two main memory types:

HBM (High Bandwidth Memory): 40-80 GB of main memory where Q, K, V, and O live, with ~1.5 TB/s of bandwidth.

SRAM (on-chip cache): ~20 MB sitting next to the compute units, with ~19 TB/s of bandwidth.

The key numbers: SRAM is ~10× faster but ~1000× smaller than HBM.

Standard Attention: The Memory Problem

Let’s trace what standard attention does:

  1. Load Q, K from HBM to SRAM
  2. Compute $S = QK^T$ (on chip)
  3. Write $S$ to HBM (an $N \times N$ matrix!)
  4. Read $S$ from HBM to SRAM
  5. Compute $P = \text{softmax}(S)$ (on chip)
  6. Write $P$ to HBM (an $N \times N$ matrix!)
  7. Read $P$, V from HBM to SRAM
  8. Compute $O = PV$ (on chip)

The problem: we read/write the $N \times N$ attention matrix multiple times. For long sequences, this dominates runtime.

$$\text{HBM accesses} = \Theta(Nd + N^2)$$
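The trace above can be sketched in a few lines of NumPy. This is an illustrative sketch, not the actual GPU kernel: the point is that `S` and `P`, both $N \times N$, are explicitly materialized, which on a GPU means round-trips through HBM.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N matrices S and P.

    On a GPU, S and P would be written to and re-read from HBM --
    exactly the traffic FlashAttention eliminates.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # N x N score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # subtract rowmax for stability
    P /= P.sum(axis=-1, keepdims=True)             # N x N softmax matrix
    return P @ V                                   # N x d output

rng = np.random.default_rng(0)
N, d = 128, 16
Q, K, V = rng.standard_normal((3, N, d))
O = standard_attention(Q, K, V)
```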

FlashAttention: The Solution

FlashAttention’s key idea: never materialize the full attention matrix. Instead, compute attention in tiles that fit in SRAM.

FlashAttention tiling diagram
FlashAttention tiling and softmax rescaling. By operating on blocks and using online softmax to rescale partial results, we avoid writing the large N×N attention matrix to HBM. Image credit: Stanford CRFM

The Tiling Strategy

Instead of computing the full attention matrix at once:

  1. Divide Q, K, V into blocks that fit in SRAM
  2. Compute attention for each block pair
  3. Accumulate results with proper normalization

But there’s a catch: softmax isn’t block-decomposable. You need the full row to compute the denominator.

The Online Softmax Trick

This is the algorithmic insight that makes FlashAttention possible.

Standard softmax requires two passes over the data:

  1. Find the maximum (for numerical stability)
  2. Compute exp and sum

Online softmax does it in one pass by maintaining running statistics and rescaling on the fly.

Online Softmax Update Rule
$$m^{(new)} = \max(m^{(old)}, \tilde{m})$$ $$\ell^{(new)} = e^{m^{(old)} - m^{(new)}} \ell^{(old)} + e^{\tilde{m} - m^{(new)}} \tilde{\ell}$$ $$O^{(new)} = \frac{\ell^{(old)} e^{m^{(old)} - m^{(new)}} O^{(old)} + e^{\tilde{m} - m^{(new)}} \tilde{P}V}{\ell^{(new)}}$$
$m$: running maximum for numerical stability
$\ell$: running sum of exponentials (softmax denominator)
$O$: running output, rescaled as we see more blocks

The key insight: when we see a new block with a larger maximum, we can rescale our previous partial results. This makes softmax associative—we can compute it block by block.
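A small NumPy sketch (with hypothetical variable names, not the paper's code) shows the update rule in action: running statistics accumulated in a single pass over blocks reproduce the standard two-pass softmax.

```python
import numpy as np

def online_softmax(x, block=4):
    """Softmax statistics in one pass over blocks, via running (m, l).

    When a block raises the running max m, the accumulated sum l is
    rescaled by exp(m_old - m_new) -- the same rescaling FlashAttention
    applies to its partial outputs.
    """
    m, l = -np.inf, 0.0                # running max, running sum of exps
    for i in range(0, len(x), block):
        xb = x[i:i + block]
        m_new = max(m, xb.max())
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    return np.exp(x - m) / l           # normalize with the final (m, l)

x = np.random.default_rng(1).standard_normal(10)
y = online_softmax(x)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
```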

The Algorithm

FlashAttention algorithm loop structure
FlashAttention loop structure. The outer loop (orange) iterates over K,V blocks, loading them to SRAM. The inner loop (blue) iterates over Q blocks, computing attention in SRAM and writing output to HBM.

Here’s the FlashAttention forward pass:

FlashAttention Forward Pass

Input: Matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, block sizes $B_r, B_c$

Output: $O \in \mathbb{R}^{N \times d}$

  1. Divide $Q$ into $T_r = \lceil N/B_r \rceil$ blocks, and $K, V$ into $T_c = \lceil N/B_c \rceil$ blocks

  2. Initialize $O = 0$, $\ell = 0$, $m = -\infty$ in HBM

  3. For $j = 1, \ldots, T_c$ (outer loop over K, V blocks):

    • Load $K_j, V_j$ from HBM to SRAM
    • For $i = 1, \ldots, T_r$ (inner loop over Q blocks):
      • Load $Q_i, O_i, \ell_i, m_i$ from HBM to SRAM
      • On chip, compute $S_{ij} = Q_i K_j^T \in \mathbb{R}^{B_r \times B_c}$
      • On chip, compute:
        • $\tilde{m}_{ij} = \text{rowmax}(S_{ij}) \in \mathbb{R}^{B_r}$
        • $\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c}$
        • $\tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_r}$
      • Update $m_i$, $\ell_i$, $O_i$ using the online softmax update rule
      • Write $O_i, \ell_i, m_i$ back to HBM

  4. Return $O$

The critical property: the $N \times N$ attention matrix $S$ is never fully materialized in HBM. Each block $S_{ij}$ exists only briefly in SRAM.

$$\text{HBM accesses} = O\left(\frac{N^2 d^2}{M}\right)$$

where $M$ is the SRAM size. For typical values ($d = 64$, $M \approx 100$ KB), this works out to 5-20× fewer HBM accesses than standard attention.
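The algorithm translates directly into NumPy. This is a didactic sketch under simplifying assumptions (single head, no masking, fp64, block sizes dividing $N$), not a performant kernel; what it demonstrates is that only $B_r \times B_c$ score blocks ever exist, yet the result matches naive attention exactly.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=16, Bc=16):
    """Tiled attention forward pass (NumPy sketch of the algorithm above).

    The N x N score matrix is never built; each Br x Bc block S_ij is
    computed, folded into the running (m, l, O) statistics, and discarded.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)               # running softmax denominators
    m = np.full(N, -np.inf)       # running row maxima
    scale = 1.0 / np.sqrt(d)
    for j in range(0, N, Bc):               # outer loop over K, V blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):           # inner loop over Q blocks
            Qi = Q[i:i+Br]
            S_ij = Qi @ Kj.T * scale        # only a Br x Bc block
            m_t = S_ij.max(axis=1)
            P_t = np.exp(S_ij - m_t[:, None])
            l_t = P_t.sum(axis=1)
            m_new = np.maximum(m[i:i+Br], m_t)
            alpha = np.exp(m[i:i+Br] - m_new)   # rescale old contributions
            beta = np.exp(m_t - m_new)          # rescale the new block
            l_new = alpha * l[i:i+Br] + beta * l_t
            O[i:i+Br] = ((alpha * l[i:i+Br])[:, None] * O[i:i+Br]
                         + beta[:, None] * (P_t @ Vj)) / l_new[:, None]
            m[i:i+Br] = m_new
            l[i:i+Br] = l_new
    return O

rng = np.random.default_rng(2)
N, d = 64, 8
Q, K, V = rng.standard_normal((3, N, d))
O_flash = flash_attention_forward(Q, K, V)

# Reference: naive softmax attention for verification.
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V
```

Note that the loop order follows the original paper (K/V blocks outer, Q blocks inner); FlashAttention-2 swaps this so that independent Q blocks can be processed in parallel.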

Why It Works: Arithmetic Intensity

Arithmetic intensity = FLOPs / bytes moved

Standard attention has low arithmetic intensity: we move lots of data for relatively little compute. FlashAttention increases arithmetic intensity by reusing data in SRAM.

| Method | HBM Reads/Writes | Arithmetic Intensity |
| --- | --- | --- |
| Standard attention | $\Theta(Nd + N^2)$ | Low |
| FlashAttention | $O(N^2 d^2 / M)$ | High |
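A back-of-the-envelope calculation makes the imbalance concrete. The numbers below are assumptions for illustration (fp32 operands, counting only the $QK^T$ step, A100 peak figures from above), not measurements.

```python
N, d = 4096, 64                  # sequence length, head dim (assumed)
bytes_per_elem = 4               # fp32 (assumed)

# QK^T: 2*N*N*d FLOPs; reads Q and K, writes the N x N score matrix S.
flops = 2 * N * N * d
bytes_moved = (2 * N * d + N * N) * bytes_per_elem
intensity = flops / bytes_moved  # ~31 FLOPs per byte

# A100 ridge point: peak compute / peak HBM bandwidth.
ridge = 312e12 / 1.5e12          # ~208 FLOPs per byte
memory_bound = intensity < ridge
```

Since ~31 FLOPs/byte is far below the ~208 FLOPs/byte ridge point, this step is memory-bound: the tensor cores idle while waiting on HBM.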

Results

FlashAttention achieves significant speedups across different models and sequence lengths:

FlashAttention-2 A100 benchmark
FlashAttention-2 benchmark on A100 GPU. Forward and backward pass speedup compared to baseline attention across different sequence lengths and head dimensions. Image credit: Stanford CRFM
| Model | Sequence Length | Speedup |
| --- | --- | --- |
| BERT-large | 512 | 15% faster end-to-end |
| GPT-2 | 1K | 3× faster |
| Long Range Arena | 1K-4K | 2.4× faster |

More importantly, FlashAttention enables much longer sequences. The paper shows the first Transformer to achieve better-than-random performance on Path-X (16K tokens) and Path-256 (64K tokens).

Memory savings: For a 2K sequence with 16 heads and head dimension 64, standard attention needs ~1GB to store the attention matrices. FlashAttention only ever holds block-sized tiles in SRAM (a few MB), a reduction of roughly 1000×.

Extensions: FlashAttention-2 and 3

The original FlashAttention has been followed by improved versions:

FlashAttention-2 (2023) improves parallelism and work partitioning: it cuts non-matmul FLOPs, parallelizes attention across the sequence-length dimension in addition to batch and heads, and repartitions work between warps to reduce synchronization:

FlashAttention-2 work partitioning
Work partitioning improvement in FlashAttention-2. The original "sliced-K" approach (left) requires synchronization between warps. FlashAttention-2 (right) partitions work to reduce synchronization overhead. Image credit: Stanford CRFM

FlashAttention-3 (2024) leverages new hardware features on NVIDIA Hopper GPUs: asynchrony, overlapping tensor-core matmuls with data movement via warp specialization, and FP8 low-precision compute for additional throughput.

Key Takeaways

  1. IO-awareness matters. Modern GPUs are compute-rich but memory-bandwidth-limited. Algorithm design must account for the memory hierarchy.
  2. Tiling enables efficiency. By processing data in blocks that fit in fast cache, we can dramatically reduce slow memory access.
  3. Online algorithms unlock tiling. The online softmax trick makes softmax block-decomposable, enabling the tiled computation.
  4. Exact beats approximate. FlashAttention is faster than approximate attention methods while computing exactly the same result.

References

  1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

  2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.

  3. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.

  4. Milakov, M., & Gimelshein, N. (2018). Online Normalizer Calculation for Softmax. arXiv:1805.02867.