08. 训练损失不收敛或出现 NaN,应该怎么排查?

整理训练 Loss 不收敛或出现 NaN 的常见原因与排查路径。

简单回答

训练 Loss 不收敛和出现 NaN 是大模型训练中最常见的故障,原因复杂,排查要从学习率、梯度、数据、数值稳定性四个维度入手。出现 NaN 通常是数值溢出(梯度爆炸、Softmax 溢出、Loss Scaling 失控等),Loss 不下降通常是学习率设置、数据问题或模型初始化问题。系统化地保留中间状态日志是排查的基础。

详细解答

NaN 的常见来源

梯度爆炸是 NaN 的第一大来源。当梯度数值在反向传播中被连乘多次,可能指数级增大,超出 FP16(最大 65504)或 FP32(最大约 )的范围,变成 inf 或 NaN。梯度裁剪(Gradient Clipping)是标准的缓解手段:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

如果训练日志里发现 grad_norm 经常超过 10 甚至 100,就要警觉,快爆炸了。

Loss Scaling 失控:FP16 训练时,如果 scale factor 设得太大,乘上去的梯度会溢出变成 inf,触发 optimizer step skip,但如果连续很多步都在 skip,训练实际上没有进展。应该监控 loss_scale 的变化和 skip 的频率,如果 scale 一直在下降,说明训练不稳定。

Softmax 数值溢出:Attention 里的 Softmax 计算 ,如果 的值很大, 会溢出。标准实现会减去最大值做数值稳定化(stable softmax),但如果实现有 bug,或者在某些边界 case 下(比如全是 padding 的位置),可能出问题。FlashAttention 实现了数值稳定的 Attention,这类问题少很多。

数据中的 NaN/Inf:如果训练数据里有异常值(比如文本被错误解码、数值特征未归一化),可能直接注入 NaN 到计算图里。应该在数据 loader 里加检查,对 batch 的统计量做断言。

模型初始化问题:权重初始化值过大或过小,导致前向传播一开始就出现极端值。标准 Transformer 的初始化(对权重乘以 的缩放,其中 是层数)是经过验证的,自定义架构要注意遵循类似原则。

Loss 不下降的常见原因

学习率设置错误是最常见的。学习率太大,Loss 震荡不下降甚至爆炸;学习率太小,Loss 下降极慢。还要检查学习率调度是否正确——Warmup 没有正常工作、余弦调度的周期设置错误等。一个常见的坑是把学习率设成了相对值(比如相对于 batch size 的缩放),但 batch size 和预期不符。

数据 Label Mask 错误:SFT 训练时,如果 Loss mask 没有正确屏蔽 system 和 user 部分,模型在学"复现输入"而不是"生成输出",Loss 可能偏低但效果差;反过来如果 mask 覆盖了 output 部分,Loss 会是 0,梯度为 0,什么都学不到。

梯度消失:某些激活函数(早期的 sigmoid)在深层网络中容易梯度消失。Transformer 的 Pre-LayerNorm(在 Attention 前做 LN)相比 Post-LN 在深层模型中更稳定,就是出于梯度流动的考虑。

数据本身问题:训练数据质量差(噪音太多、指令回答不匹配),模型学不到有用的模式,Loss 下降很慢甚至在某个高值停滞。这种情况下要抽查数据,看 batch 里的样本质量。

系统化排查步骤

首先确认是否真的出现 NaN:检查训练日志里的 lossgrad_normloss_scale(FP16 时),确认 NaN 第一次出现的步数。

然后二分定位:恢复到 NaN 前的最近 checkpoint,继续训练,用相同数据看是否复现。如果是固定 batch 触发的,把这个 batch 单独拿出来调试,检查数据是否有问题。

缩小范围:注册 forward hook 和 backward hook,监控每一层的 Activation 和梯度的统计量(max、min、均值、是否有 NaN/Inf),找到第一个出现异常的层。

def hook_fn(name):
    def fn(module, input, output):
        if isinstance(output, torch.Tensor):
            if torch.isnan(output).any():
                print(f"NaN in {name} output")
    return fn

for name, module in model.named_modules():
    module.register_forward_hook(hook_fn(name))

面试时可以这样答

Loss 不收敛和 NaN 要分开看。NaN 通常是数值溢出,排查方向有几个:梯度爆炸(看 grad_norm 是否经常超大,加 gradient clipping);FP16 Loss Scaling 失控(scale 一直在降、step 一直在 skip);Softmax 溢出(QK^T 值过大,FlashAttention 自带稳定化基本能解决);数据里有 NaN(loader 加检查)。

Loss 不下降通常是:学习率设置问题(太大震荡,太小进展慢,调度配置有误);SFT 的 loss mask 配错(mask 掉了 output 就梯度为 0,没 mask 掉 input 就在学复现输入);数据质量差,模型学不到有用模式。

排查的基础是日志要全——每步记录 loss、grad_norm、loss_scale,这些是定位问题的依据。出现 NaN 的话,用 forward/backward hook 监控每层的输出统计量,二分法找到第一个出问题的层,再往那里深挖。

常见追问

  1. Gradient Clipping 的 max_norm 怎么定?太小会影响收敛吗?
  2. 用 BF16 是不是就不会出现 NaN 了?BF16 下还有哪些数值稳定性问题?
  3. 训练了几千步之后 Loss 突然上升,但没有 NaN,是什么原因?