06. Gradient Checkpointing 是什么?它的代价和适用场景是什么?
整理 Gradient Checkpointing 的原理、计算代价与适用场景。
简单回答
Gradient Checkpointing(梯度检查点,也叫 Activation Recomputation)是一种用计算换显存的技术:前向传播时不保存所有中间激活值(Activation),只保存若干"检查点";反向传播需要某个层的激活时,从最近的检查点重新前向计算一遍,再做反向。显存节省可以非常可观(接近线性于层数),代价是增加了约 30%~40% 的额外计算量。适合显存严重受限、但计算资源相对宽裕的场景。
详细解答
为什么激活值占显存这么多
训练时显存占用来自三部分:模型参数(相对固定)、优化器状态(通过 ZeRO 等方法可以分片)、以及Activation(前向传播的中间结果)。
Activation 是反向传播必须用到的——计算某层的梯度时,需要该层前向时的输入值。对于一个 L 层的网络,如果每层的输入维度是 (Batch、Sequence、Hidden),全部保存的话,Activation 的显存消耗是 。
以 LLaMA-7B 为例,BF16 精度,batch size=1,seq len=2048,hidden size=4096,32 层:
这看起来不多,但 Attention 层的 Activation 还包括 矩阵(),实际占用要大得多,可能到几 GB。对于长序列(seq len=8192 甚至更长)训练,Activation 显存是主要瓶颈。
Gradient Checkpointing 的工作原理
朴素的梯度检查点:把模型分成 段,每段之间的连接处保存激活(这些是检查点),段内的激活在前向时全部丢弃,反向时需要某段的激活就从该段起始的检查点重新前向算一遍。
通过选择合理的检查点间隔,可以把 Activation 显存从 降到 。
PyTorch 提供的 torch.utils.checkpoint.checkpoint 实现了更细粒度的控制——可以对任意子模块应用 checkpointing:
在大模型训练框架里(DeepSpeed、Megatron、FSDP),通常提供更高级的接口,比如按 Transformer 层来应用 checkpointing:
计算代价
重新计算的代价:每个被 checkpointing 的层,在反向传播时需要额外做一次前向。如果对所有层都应用 checkpointing,相当于整个网络做了两次前向传播,总训练时间增加约 30%~40%(不是 100%,因为反向本身比前向慢,两者加起来前向只占总时间的约 30%)。
这个代价在不同场景下的接受程度不同。如果显存是瓶颈(比如用 gradient checkpointing 能把 batch size 从 1 提到 4),那多 30% 计算换 4 倍 batch size,整体 throughput 还是提升的。如果显存不是瓶颈,加 checkpointing 只是增加计算量,没有意义。
选择性 Checkpointing
不是所有层都需要 checkpointing,可以按计算量和显存占比做取舍。Attention 层的 Softmax 是 Activation 显存的大头,对 Attention 单独做 checkpointing 效果很好。FFN 层的 Activation 相对小,可以不做或做粗粒度的。FlashAttention 本身就自带了 Attention 的重计算逻辑(反向时重算 Softmax,不保存 Attention 矩阵),相当于免费获得了 Attention 层的 checkpointing。
面试时可以这样答
Gradient Checkpointing 是用计算换显存——前向传播时不存所有 Activation,只在特定检查点存;反向需要某层的 Activation 时,从最近的检查点重新前向算一遍。
显存节省很可观,对所有层都做的话,Activation 显存从 降到 ,长序列场景下能省非常多显存。代价是多了一次前向,总训练时间大约增加 30%~40%。
实际选择上,如果显存是瓶颈,checkpointing 很值——省了显存能上更大 batch,throughput 反而可能提升。如果显存够用,就不要加,纯粹浪费计算。还有个细节,FlashAttention 自带 Attention 的重计算,等于免费拿到了 Attention 层的 checkpointing。所以现在很多训练配置是:FlashAttention 处理 Attention、FFN 用 full checkpointing,效果和效率都不错。
常见追问
- 最优的检查点间隔怎么计算?显存节省和计算开销的 trade-off 公式是什么?
- Selective Checkpointing(选择性 checkpointing)的策略是什么?
- Gradient Checkpointing 和 ZeRO 配合使用,显存优化是叠加的吗?