23. Pre-Norm 和 Post-Norm 在训练动态上有什么本质区别?为什么大模型几乎都用 Pre-Norm?
整理 Pre-Norm 与 Post-Norm 在梯度流、训练稳定性上的差异。
简单回答
Pre-Norm 和 Post-Norm 的区别只是 LayerNorm 放在残差连接之内还是之外。Post-Norm 是 ,原始 Transformer 用的就是它。Pre-Norm 是 ,现在大模型几乎都用。本质区别在梯度流——Pre-Norm 的残差路径上没有 Norm,梯度可以直接从顶层"高速公路"传到底层;Post-Norm 的残差路径上每层都过一次 Norm,深层时梯度容易爆炸或消失。代价是 Pre-Norm 训练更稳但最终性能上限略低,Post-Norm 训得起来效果可能更好但训练极难。在大模型规模下稳定性压倒一切,Pre-Norm 几乎是默认选择。
详细解释
两者的公式差异
把一层 Transformer 的子层(Attention 或 FFN)记作 ,残差连接记作 ,LayerNorm 记作 。
Post-Norm 的形式:
Pre-Norm 的形式:
差别看起来只是 LN 的位置——一个在残差和外面,一个在 sublayer 输入端。但这个位置变化导致两者训练动态完全不同。
残差路径上有没有 Norm 是关键
把上面两个公式按层堆 L 层展开,看从输出 到输入 的反向传播路径长什么样。
Pre-Norm 下,残差展开是:
注意 这一项是直接加进 的——存在一条从 直接到 的恒等通路。这条路径上没有 Norm、没有非线性、没有任何会让梯度衰减的东西。反向传播时梯度可以沿着这条"高速公路"无损地从顶层传回底层。
Post-Norm 下,每一层的输出都要过一次 LN。沿着残差展开,从顶层到底层的路径上每一层都嵌着一个 LN。LN 的本质是按 RMS(或方差)做缩放——如果输入的尺度被 LN 压回了固定范围,那残差从底层带过来的"原始幅度"信息会被反复重置,深层时这条路径就不再是恒等映射了。
直观结果:Pre-Norm 在深层模型里训练稳定,learning rate 可以开大,几乎不需要 warmup 也能训。Post-Norm 在深层(24 层以上)时几乎训不起来,必须配合非常仔细的 learning rate warmup、预热阶段和小的初始化才能稳住。
Post-Norm 真的更"差"吗
这是个有趣的问题。理论分析和一些实验表明,Post-Norm 在能训得起来的前提下,最终性能可能略高于 Pre-Norm。原因是 Post-Norm 强制每层输出都被归一化,对中间表征的"形状"约束更强,模型最终学到的特征更紧致;而 Pre-Norm 因为残差路径上没有规范化,深层时浅层信息会通过残差不断累积,到顶层时残差信号占比可能过大,反而稀释了深层 sublayer 的贡献。
DeepNorm(微软 2022 年提出)是个折中尝试:保持 Post-Norm 的形式,但通过给残差和 sublayer 输出加上精心设计的缩放系数 和 ,让深层模型也能稳定训练。DeepNorm 论文成功训了 1000 层的 Transformer,但工程上调参复杂度比 Pre-Norm 高得多。
Sandwich-Norm、ResiDual 等也是这个方向的变体,思路都是"想要 Post-Norm 的效果,但要解决它的训练稳定性"。这些方案在论文里都能展示一些性能提升,但工业界采用度都不高——大模型训练成本太高,没人愿意拿训练失败的风险换那一两个百分点的提升。
为什么大模型几乎都选 Pre-Norm
最直接的原因就是稳定性。70B、175B 规模的模型训练一次要烧几百万到几千万美元,训到一半 loss 发散整个 run 就废了。在这种成本下,"略低的性能上限"远比"稳定收敛的训练"次要。
第二个原因是 Pre-Norm 对超参数容忍度高。Post-Norm 对 learning rate、warmup 长度、初始化方差都很敏感,调一组好的超参要做大量小规模实验。Pre-Norm 对超参不敏感,新模型架构出来后调参成本低很多。
第三个原因是 Pre-Norm 配合现代优化技术更顺畅。FlashAttention、张量并行、流水线并行这些工程优化都在 Pre-Norm 架构上做了大量适配。生态上 Pre-Norm 已经是事实标准。
一个细节:最后一层的额外 Norm
Pre-Norm 架构里有一个容易忽视的细节——堆叠 L 层 Pre-Norm 之后,输出 。这个 没有过 LN(因为 LN 在每层 sublayer 的输入端,输出端没有)。如果直接拿这个 接 LM head,数值范围可能很不稳定(残差累积没有 normalize)。
实际实现里 Pre-Norm 模型在最后一层之后会再额外加一个 final LayerNorm(GPT、LLaMA 都这么做),把整个 stack 的输出做一次归一化再送给 LM head。这是 Pre-Norm 的标准做法但很少有文档单独提到。
工程上的副作用
Pre-Norm 有一个值得注意的副作用——浅层信息会通过残差路径"漏"到顶层。极端情况下,模型可能学到"前面几层的输出几乎不变,到深层才做实质处理"——因为深层只要负责对残差做小幅修正就够了。一些研究观察到 Pre-Norm 大模型在中后段层的 sublayer 输出范数远小于残差范数,意味着这些层的实际贡献有限。
这也是为什么近期一些研究(如 DeepNorm 的后续工作、Layer Pruning 工作)会探讨"大模型的某些深层是否冗余"——Pre-Norm 的架构特性使得"深层贡献小"是结构性现象,不是训练不充分。
面试时可以这样答
Pre-Norm 和 Post-Norm 的差异本质在残差路径上有没有 Norm。Post-Norm 是先残差相加再 Norm,所以从顶层反向传到底层的路径上每层都嵌着一个 Norm。Pre-Norm 是先 Norm 再过 sublayer 再残差相加,残差路径上是干净的恒等通路,梯度可以无损从顶层传回底层。
这导致两者训练动态完全不同。Pre-Norm 在深层模型里训练稳定,learning rate 容忍度高,warmup 也可以短甚至不要。Post-Norm 在 24 层以上的模型里几乎训不起来,必须非常小心的 warmup 和初始化才能稳。
有意思的是 Post-Norm 在能训起来的前提下最终性能可能略好——它对每层输出做了更强的归一化约束,特征更紧致;Pre-Norm 因为残差路径无 Norm,深层时浅层信息会累积稀释深层贡献。DeepNorm 这类工作就是在保留 Post-Norm 形式的同时解决稳定性问题,但工程上调参太复杂。
大模型几乎都选 Pre-Norm 的根本原因是训练成本——一次大模型训练几百万美元起,稳定性压倒一切,那点性能上限差别不值得冒训练发散的风险。Pre-Norm 对超参不敏感,调参成本低,配套生态也成熟。
顺带提一个细节,Pre-Norm 模型最后一层后面还要补一个 final LayerNorm 再接 LM head,因为残差累积导致顶层数值范围不稳定。这是 LLaMA、GPT 都这么做的标准实现。
常见追问
- DeepNorm 具体怎么稳住 Post-Norm 的训练? 和 怎么设?
- Pre-Norm 下深层 sublayer 贡献被稀释,怎么验证?怎么缓解?
- 既然 Pre-Norm 顶层还要补一个 final LN,为什么不直接全用 Post-Norm?