09. 什么是灾难性遗忘?怎么缓解?

整理灾难性遗忘的成因、典型表现与缓解方法。

简单回答

灾难性遗忘(Catastrophic Forgetting)是指模型在学习新任务后,对之前学过的任务能力严重退化。在大模型微调场景中,典型表现是 SFT 后模型在微调任务上表现很好,但通用能力(如多语言、常识推理、代码)明显下降。缓解手段包括参数高效微调(LoRA 等)、混入通用数据、控制学习率和训练步数、以及正则化方法。

详细解释

为什么会遗忘?

神经网络的参数是共享的——所有任务的知识都编码在同一套权重中。当你用新数据去更新这些权重时,原来存储旧知识的参数分布被改变了,旧知识就被"覆盖"了。微调数据越集中在某个领域,模型参数朝这个领域偏移的幅度越大,通用能力退化也越严重。

在大模型微调中的典型表现

比较常见的情况是:一个 base model 本来中英文都不错,微调时用的全是中文数据,微调完发现英文能力断崖式下降。或者用纯客服数据微调后,模型的数学和代码能力明显变差。还有一种隐蔽的情况是模型的"格式偏好"被覆盖——本来能输出 JSON、Markdown 等多种格式,微调后只会用一种固定的回答格式。

缓解方法

第一,用参数高效微调(LoRA 等)代替全量微调。LoRA 只更新极少量参数,原始权重完全冻结,对已有知识的扰动天然就小。这是目前最常用也最有效的缓解手段。

第二,在微调数据中混入通用数据。比如微调数据有 5K 条领域数据,再混入 1K~2K 条通用对话数据、代码数据、多语言数据等。这样训练过程中模型不会完全"忘记"通用能力。混入比例需要实验调整,通常领域数据占 60%~80%,通用数据占 20%~40%。

第三,控制学习率和训练步数。学习率过大会导致权重更新幅度过大,遗忘加速。通常 SFT 的学习率比预训练低一到两个数量级(比如 1e-5 到 2e-5)。训练步数也不是越多越好,过拟合微调数据同时意味着遗忘通用能力。观察验证集 loss 拐头或者通用 benchmark 开始下降就该停了。

第四,正则化方法。最直接的是 L2 正则,约束参数不要偏离初始值太远。也有一些更精细的方法如 EWC(Elastic Weight Consolidation),给每个参数根据其对旧任务的重要性设定不同的正则强度——重要参数少动,不重要的参数多动。但 EWC 在大模型上计算 Fisher information matrix 的开销很大,实际用的不多。

第五,在继续预训练阶段混入通用语料。如果你在做领域继续预训练,一定不能只用领域数据。一般建议通用语料和领域语料至少 1:1 混合,有的实践甚至推荐 3:1(通用:领域)。

一个工程上的实用原则

微调前后都要跑一遍通用 benchmark(如 MMLU、HumanEval、C-Eval 等),定量监控通用能力的变化。不要只看微调任务的效果好不好,通用能力的退化往往是隐性的,不测不知道。

面试时可以这样答

灾难性遗忘就是模型学了新东西以后,旧能力严重退化。原因很直觉——所有知识编码在同一套参数里,微调更新参数时旧知识被覆盖了。

在大模型微调中这个问题很常见。比如全中文数据微调后英文能力暴跌,纯客服数据微调后代码和数学能力下降。

缓解的手段我一般会用这几个。最有效的是用 LoRA 而不是全量微调,参数冻结天然减少遗忘。其次是在微调数据里混入通用数据,保持模型对通用任务的"记忆"。然后是控制学习率和训练步数,学习率一般 1e-5 到 2e-5,训练不要过拟合。

工程上有一个很重要的习惯:微调前后都要跑通用 benchmark。只看微调任务的指标是不够的,通用能力退化是隐性的,必须定量监控。

常见追问

  1. LoRA 真的能完全避免遗忘吗?在什么情况下 LoRA 也会出现遗忘?
  2. 混入通用数据的比例怎么定?有没有系统化的方法?
  3. 继续预训练和 SFT 阶段的遗忘表现有什么不同?