14. FlashAttention 的核心思路是什么?它为什么能同时节省显存和加速?
整理 FlashAttention 的核心思路及其显存与速度优势。
简单回答
FlashAttention 的核心思路是 tiling + 重计算:把 Q、K、V 分成小块,在 GPU SRAM(快速缓存)中完成注意力计算,避免将完整的 N×N 注意力矩阵写入 HBM(显存)。它节省显存是因为不用存中间的大注意力矩阵,加速是因为大幅减少了 HBM 读写次数——标准 attention 是 memory-bound 的,瓶颈不在计算而在访存。
详细解释
为什么标准 Attention 慢?
标准 attention 的计算流程是:算 QK^T → softmax → 乘 V。关键问题是 QK^T 产生的注意力矩阵是 N×N 的,需要先完整写入显存(HBM),再从显存读出来做 softmax,然后再读出来乘 V。
GPU 的计算单元(SM)本身很快,但 HBM 的带宽是瓶颈。标准 attention 是典型的 memory-bound 操作:大量时间花在"搬数据"上,而不是"算数据"上。N×N 的注意力矩阵既占空间(显存 O(N²)),又需要反复读写。
FlashAttention 的做法
核心是 IO-aware 的算法设计。具体来说:
把 Q 分成若干小块,把 K 和 V 也分成小块,然后外循环遍历 K/V 块,内循环遍历 Q 块。每次只把一小块 Q、K、V 加载到 SRAM 中,在 SRAM 里完成 QK^T、softmax(局部)、乘 V 的全部计算,然后把结果累积写回 HBM。
这里有一个技术难点:softmax 需要全局的归一化因子(分母),但 tiling 后每次只看到局部的 K。FlashAttention 用 online softmax 算法解决了这个问题——在遍历 K 块的过程中增量地更新 softmax 的分母,最终结果是精确的(不是近似)。
为什么能同时节省显存和加速?
节省显存的原因很直接:不需要存储完整的 N×N 注意力矩阵,显存复杂度从 O(N²) 降到 O(N)。
加速的原因是减少了 HBM 访问量。标准 attention 需要把 N×N 矩阵多次在 HBM 和 SRAM 之间搬运。FlashAttention 的 HBM 访问量从 O(N² d) 降到了 O(N² d² / M)(M 是 SRAM 大小)。在 SRAM 足够大的情况下,这个减少是非常显著的。
FlashAttention 2 和 3
FlashAttention 2 进一步优化了并行度和 warp 级别的调度,减少了非计算指令的开销,速度比 v1 又提升了约 2 倍。FlashAttention 3 则针对 Hopper 架构(H100)的新硬件特性做了优化。
关键认知
FlashAttention 没有改变 attention 的计算结果,它是精确的、数学等价的。它改变的是计算的执行顺序,把一个 memory-bound 的操作变得更 compute-bound,从而更好地利用 GPU 的计算能力。
面试时可以这样答
FlashAttention 的核心出发点是:标准 attention 是一个典型的 memory-bound 操作,瓶颈不在计算而在显存访问。QK^T 生成的 N×N 注意力矩阵需要写入 HBM 再读出来,这个反复搬运的过程非常慢。
FlashAttention 的做法是 tiling——把 Q、K、V 分成小块,每次只加载一小块到 GPU 的 SRAM(片上高速缓存)中,在 SRAM 里完成全部计算再写回。这样就不需要存完整的 N×N 矩阵了。技术上的难点是 softmax 需要全局归一化因子,FlashAttention 用 online softmax 的方式增量地更新分母,最终结果和标准 attention 完全一致,是精确等价的,不是近似。
它能同时省显存和加速:省显存是因为不存 N×N 矩阵了,复杂度从 O(N²) 降到 O(N);加速是因为大幅减少了 HBM 读写次数,把操作从 memory-bound 推向 compute-bound。
现在 FlashAttention 基本是所有大模型训练和推理的标配了,PyTorch 2.0 也原生集成了。FA2 在 FA1 基础上进一步优化了 warp 级调度,FA3 针对 H100 的 Hopper 架构做了适配。
常见追问
- online softmax 具体怎么做到增量更新的?为什么结果是精确的?
- FlashAttention 的反向传播也需要特殊处理,它怎么在不存注意力矩阵的情况下算梯度的?
- FlashAttention 对所有序列长度都有加速吗?什么情况下加速效果最明显?