09. 大规模训练的断点续训和容错怎么做?

整理大规模训练断点续训、Checkpoint 完整状态与容错机制。

简单回答

大规模训练动辄跑几周到几个月,硬件故障、集群抢占是常态而非异常。断点续训的核心是定期保存完整的训练状态(模型参数、优化器状态、随机数状态、数据索引),恢复时从最近的 checkpoint 继续。容错设计则需要在单点故障不影响整体训练的架构上,加上快速检测和自动重启的机制。大规模训练的稳定性工程往往和模型本身一样重要。

详细解答

Checkpoint 的完整状态

一个完整的 checkpoint 不只是模型权重,还包括所有能让训练从断点精确续训的状态:

模型参数:所有层的权重,如果用 ZeRO-3 或 FSDP,还要正确地从分片状态聚合或直接保存分片。

优化器状态:Adam 的一阶动量()和二阶动量()、当前步数、FP32 主权重(混合精度时)。优化器状态的大小通常是模型参数的 2~3 倍,不保存的话续训时优化器需要重新热身,前期学习率效果会不准。

学习率调度状态:当前 step、learning rate scheduler 的状态,确保恢复后学习率曲线连续。

随机数状态torch.random.get_rng_state()numpy.random.get_state()、CUDA 的随机数状态。如果不保存,续训后的随机采样(Dropout、Data Augmentation)和原来就不一致,可能影响可重复性。

数据状态:当前训练到了数据集的哪个位置(哪个 epoch 的第几个 batch,或者 token 级别的计数)。DataLoader 的 sampler 状态也需要保存,否则续训时会重复见到训过的数据,或者打乱重来,都不好。

RNG 状态(多 GPU):每张 GPU 的随机数状态要单独保存,因为数据并行里每张卡处理不同的数据,随机数状态也不同。

Checkpoint 的保存策略

保存频率:太频繁(比如每步)会占用大量存储和 I/O 时间,影响训练速度;太少(比如每天一次)出问题时会丢失大量进度。通常按时间(每 3060 分钟)或按步数(每 5001000 步)定期保存,另外在关键节点(评测前、配比切换前)额外保存一次。

异步保存:把 checkpoint 写入磁盘的操作放在后台线程异步执行,不阻塞训练主循环。PyTorch 2.0 引入了 torch.save 的异步版本,大模型保存可能需要几分钟,异步化是必须的。

保存格式和工具:HuggingFace 的 safetensors 格式比 PyTorch 的 pkl 格式更安全(不会执行任意代码)且支持 lazy loading,适合大模型 checkpoint。DeepSpeed 有自己的 checkpoint 格式(每张卡保存分片,恢复时需要合并)。Megatron 的 checkpoint 也有类似的分片格式。

保留多个历史 checkpoint:不只保留最新的,至少保留最近 3~5 个,防止最新的 checkpoint 本身有问题(比如刚好在训练不稳定时保存的)。自动清理旧的 checkpoint 节省存储。

容错架构

大规模训练(几百张 GPU 以上)中,按统计规律,每天都可能有硬件故障。容错设计需要:

快速故障检测:心跳机制,监控每张 GPU 的状态和训练进度,发现某张 GPU 停止响应立刻报警。在分布式训练里,一张卡挂了往往会导致整个训练卡死(AllReduce 等通信操作会无限等待),检测要够快。

自动重启:检测到故障后,自动从最近的 checkpoint 恢复训练,替换掉故障节点(如果集群有空闲节点)或者在剩余节点上重新配置并行策略。像 Kubernetes 上的训练作业可以配置自动重启策略。

弹性训练(Elastic Training):更高级的容错,支持在训练过程中动态增减 GPU 数量,节点故障时自动剔除,有新节点加入时自动纳入。PyTorch 的 Elastic Distributed Training(torchrun 的弹性模式)支持这个,但配置和稳定性仍有挑战。

节点健康检查:在开始训练之前做集群健康检查(NCCL 通信测试、GPU 显存测试),提前发现有问题的节点,不让它参与训练。这能减少训练中途因硬件问题中断的概率。

实践中的稳定性工程

大规模训练的稳定性工程是一个经常被低估的工程量大户。以训练一个 70B 模型为例,可能需要 512 张 A100,连续跑 46 周。这期间:平均每天可能有 12 张 GPU 或 1~2 台机器出故障;存储 I/O 可能因为 checkpoint 频繁读写而成为瓶颈;NCCL 通信库本身可能有 bug 导致死锁;节点间网络偶发超时导致训练卡死。

解决方案是"watchdog + 自动重启 + 定期 checkpoint"的三件套,再加上详细的训练日志(每步的 loss、grad_norm、速度指标)用于事后分析。一些团队还会定期做"全量 checkpoint 验证"——加载最近的 checkpoint 验证是否能正常恢复,而不是等到真的出事才发现 checkpoint 有问题。

面试时可以这样答

大规模训练的断点续训和容错是工程上非常重要的一块,因为跑几周的训练里,硬件故障是必然会发生的。

Checkpoint 保存必须是完整状态:模型参数、优化器状态(Adam 的 m 和 v 要保存,否则续训后优化器要重新热身)、学习率调度状态、随机数状态、数据进度。保存频率通常是每 30~60 分钟或几百步异步保存一次,保留最近的几个版本。

容错架构上有几个要点:快速故障检测(心跳机制,一张卡挂了 AllReduce 会卡死,必须快速发现);自动重启(从最近 checkpoint 恢复,替换故障节点);训练前做集群健康检查(NCCL 通信测试),提前发现有问题的节点。

实践中容错这块的工程量经常超出预期,日志一定要够详细,每步的 loss 和 grad_norm 要留存,这些是事后分析的依据。另外 checkpoint 本身要定期验证,不然等真的要恢复的时候才发现 checkpoint 损坏就麻烦了。

常见追问

  1. ZeRO-3 分片保存的 checkpoint 恢复时,如果 GPU 数量变了怎么办?
  2. 异步 checkpoint 保存期间,如果训练继续了怎么保证一致性?
  3. 在训练过程中动态切换并行策略(比如增加 GPU 数量),checkpoint 怎么兼容?