05. 混合精度训练的原理是什么?为什么要保留 FP32 的 Master Weight?
整理混合精度训练原理、FP16/BF16 差异与 FP32 Master Weight 的作用。
简单回答
混合精度训练是指参数计算用低精度(FP16 或 BF16),但优化器更新时维护 FP32 的主权重(Master Weight)的训练策略。这样能大幅减少显存占用和计算量(低精度矩阵乘法快 2~4 倍),同时通过 FP32 主权重保证参数更新的精度不丢失。保留 FP32 主权重的核心原因是:当学习率乘以梯度远小于参数值时,FP16 的精度不足以表示这个微小变化,更新会直接被舍入为零。
详细解答
FP16 和 BF16 的精度特点
理解混合精度,先要了解不同数值格式的特性。
FP32(单精度浮点):1 位符号 + 8 位指数 + 23 位尾数,可表示约 到 的数,精度约 7 位有效数字。
FP16(半精度浮点):1 位符号 + 5 位指数 + 10 位尾数,可表示约 到 65504 的数,精度约 3~4 位有效数字。最大值只有 65504,溢出就变成 inf。
BF16(Brain Float 16,Google 提出):1 位符号 + 8 位指数 + 7 位尾数,指数位和 FP32 相同,所以动态范围和 FP32 一样大(不容易溢出),但精度更低(只有约 2~3 位有效数字)。
FP16 在训练大模型时最大的问题是溢出——梯度数值范围可以很大,超过 65504 就直接变成 inf,导致训练崩溃。这就是为什么 FP16 训练需要配合损失缩放(Loss Scaling)。BF16 因为动态范围和 FP32 一样,溢出问题少得多,但精度更低,在某些需要高精度累加的操作上会有误差。A100 和 H100 对 BF16 有专门的硬件加速,目前大模型训练里 BF16 比 FP16 更主流。
为什么要保留 FP32 的 Master Weight
这是混合精度训练中最容易被忽视、但最关键的设计之一。
假设模型训练中某个参数的当前值是 (FP16 可以精确表示),学习率是 ,某步的梯度是 ,那么参数更新量是:
FP16 的最小精度(机器精度)约为 ,对于量级为 1.0 的参数,FP16 能表示的最小相对变化大约是 。 远小于这个精度,在 FP16 中直接被舍入为 0——参数根本没有被更新!
如果训练数百万步,有大量参数更新被舍入为零,模型就无法继续收敛,或者收敛到一个次优解。
FP32 的机器精度约为 ,对于量级为 1.0 的参数, 完全可以被精确表示,更新不会丢失。
所以混合精度训练的做法是:保存两份参数——一份 FP16 用于前向和反向计算(快,省显存),一份 FP32 用于优化器状态存储和参数更新(精度高,更新不会丢失)。每步训练结束后,把 FP32 主权重同步到 FP16 用于下一步计算。
损失缩放(Loss Scaling)
在使用 FP16 时(BF16 通常不需要),还需要配合损失缩放来防止梯度下溢。FP16 的最小正数约为 ,训练中很多梯度值(特别是深层网络的梯度)会比这个小,直接变成 0(梯度下溢),导致参数没有更新。
解决方法:在计算 loss 之后乘以一个大的缩放因子 (比如 2048 或 32768),使梯度也被等比例放大,不容易下溢;反向传播完成后,把梯度除以 恢复原始大小,再做参数更新。
动态损失缩放(Dynamic Loss Scaling)会根据是否出现 overflow 来自动调整 :如果本步出现了 inf/NaN,说明缩放太大了,减小 并跳过这步更新;如果连续 步没有 overflow,说明缩放可能可以更大,增大 。PyTorch 的 torch.cuda.amp.GradScaler 实现了这个逻辑。
BF16 不需要损失缩放,这也是它比 FP16 更方便的原因。
面试时可以这样答
混合精度训练是用低精度(BF16 或 FP16)做计算,但维护 FP32 主权重做参数更新的方案。低精度计算快 2~4 倍,显存占用也更低,但更新精度不够。
FP32 主权重存在的必要性是这样的:假设参数值是 1.0,学习率乘以梯度得到的更新量是 ,FP16 的精度只到 级别,这个更新直接被舍入为 0,参数根本没动。FP32 精度到 ,这个更新能被精确表示。所以必须在 FP32 上做更新,算完再把结果 cast 回 FP16 用于下一步计算。
FP16 还需要配合损失缩放——梯度值很小时会下溢变成 0,解决方法是 loss 先乘一个大系数,让梯度放大不下溢,更新前再除回去。BF16 的动态范围和 FP32 一样,不容易溢出,不需要损失缩放,所以现在大模型训练基本都用 BF16。
常见追问
- BF16 精度比 FP16 低,为什么训练效果通常和 FP16 相当甚至更好?
- 动态损失缩放的
scale_factor初始值怎么设?出现 NaN 时它是怎么反应的? - 推理时能不能只用 FP16,不保留 FP32 主权重?