02. ZeRO-1/2/3 分别优化了什么?和普通 DDP 有什么区别?

整理 ZeRO-1/2/3 的优化对象、显存收益与 DDP 的区别。

简单回答

ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 提出的显存优化方案,核心思路是消除数据并行训练中多卡之间的冗余状态存储。普通 DDP 每张卡都完整存储模型参数、梯度、优化器状态三份数据;ZeRO-1 只对优化器状态做分片;ZeRO-2 进一步对梯度做分片;ZeRO-3 把参数本身也做了分片,理论上显存消耗随 GPU 数量线性下降。代价是需要更多的通信量来聚合被分片的数据。

详细解答

训练状态的显存构成

要理解 ZeRO,先要清楚训练时显存里存了什么。以一个参数量为 的模型、用 Adam 优化器、混合精度(FP16 参数 + FP32 Master Weight)训练为例:

  • 模型参数:FP16 存储, Bytes
  • 梯度:FP16, Bytes
  • 优化器状态:Adam 有 FP32 的 Master Weight、一阶动量、二阶动量,共 Bytes

总计 Bytes。一个 7B 参数的模型,光这些状态就需要约 112 GB 显存,远超单张 A100 的 80 GB。

普通 DDP 里,每张 GPU 都完整保存上面所有状态,N 张卡就是 份冗余。

ZeRO-1:分片优化器状态

ZeRO-1 只对优化器状态(Optimizer States)做分片。 张 GPU 各只保存 份优化器状态,参数和梯度仍然每张卡都有完整副本。

通信模式变化:梯度汇总还是 AllReduce;但更新参数时,每张卡只更新自己负责的那 份参数,然后用 AllGather 把更新后的参数广播给所有卡。

显存节省:优化器状态从 降到 ,总显存从 降到

ZeRO-2:进一步分片梯度

ZeRO-2 在 ZeRO-1 的基础上,把梯度也做了分片。每张 GPU 在反向传播过程中,只保留自己负责的那 份参数对应的梯度,其他梯度在用完后立即释放。

通信模式:反向传播不再做 AllReduce,改为 ReduceScatter(每张卡得到所有卡梯度平均后自己负责的那 份)。

显存节省:梯度从 降到 ,总显存从 降到

ZeRO-3:参数本身也分片

ZeRO-3 是最激进的方案,把模型参数本身也做了分片——每张 GPU 只持有 份模型参数。需要用到某层参数时,通过 AllGather 从所有卡收集完整参数进行计算,用完立即释放,只保留自己负责的那份。

理论上,随着 增大,每张卡的显存需求趋近于 大小,参数、梯度、优化器状态的显存几乎可以忽略。

代价:相比 ZeRO-1/2,ZeRO-3 的通信量更大——每次前向和反向都需要 AllGather 参数。通信和计算的重叠(Overlap)是 ZeRO-3 能否高效运行的关键。

ZeRO++ 和 ZeRO-Infinity

ZeRO++ 是在 ZeRO-3 基础上的进一步优化,主要通过量化通信(把 AllGather 的参数做量化,减少通信量)和分层 AllGather(hierarchical AllGather,先在节点内 AllGather,减少跨节点通信)来降低通信开销。

ZeRO-Infinity 把优化器状态甚至参数 offload 到 CPU 内存和 NVMe SSD,进一步降低 GPU 显存压力,代价是速度更慢(CPU/NVMe 带宽比显存带宽小很多)。适合显存严重受限但对训练速度要求不高的场景。

和 DDP 的实际区别

普通 DDP:每张卡完整存储所有训练状态,梯度汇总用 AllReduce,实现简单,通信量 (AllReduce 等效于 ReduceScatter + AllGather)。

ZeRO-1/2:通信量和 DDP 相当,但显存消耗更低。是显存受限时最轻量的升级方案。

ZeRO-3:通信量增加(多了前向的 AllGather),但显存可以做到极低。在 GPU 数量多的情况下最合算。

面试时可以这样答

ZeRO 是 DeepSpeed 的核心显存优化方案,本质上是消除数据并行多卡之间的冗余存储。

普通 DDP 里每张卡都完整存参数、梯度、优化器状态三份,一个 7B 模型用 Adam 混合精度训练光这三样就要 112 GB,单卡装不下。

ZeRO 按照优化粒度分三个 Stage:ZeRO-1 只对优化器状态分片,每张卡只存 份 Adam 状态,最容易实现,通信模式变化最小;ZeRO-2 进一步把梯度也分片,AllReduce 梯度换成了 ReduceScatter;ZeRO-3 最激进,参数本身也分片,用的时候 AllGather 进来算完就扔,理论上显存随 GPU 数量线性下降。

代价是通信量增加,特别是 ZeRO-3 前向也要做 AllGather,通信和计算的重叠做得好不好直接决定实际效率。实际工程选型上:模型能放进单卡就用 DDP;放不下先试 ZeRO-2,大多数场景够用;需要训超大模型或者 GPU 数量很多,上 ZeRO-3,然后仔细调 overlap 参数。

常见追问

  1. ZeRO-3 通信量比 DDP 多多少?能不能量化?
  2. ZeRO-Infinity 把参数 offload 到 CPU/NVMe,速度影响有多大?
  3. ZeRO 和张量并行可以同时用吗?怎么配合?