0%

归一化层的演进

训练深度神经网络有一个反复出现的噩梦:网络越深,训练越难。梯度消失、梯度爆炸、收敛慢、对初始化极度敏感——这些问题在 2015 年之前几乎是绕不过去的墙。归一化层的故事,就是一代代研究者在这堵墙上凿出一条路的故事。


问题:网络为什么越深越难训

深层网络训练困难的核心原因之一,是 Ioffe 和 Szegedy 在 2015 年命名的 Internal Covariate Shift(内部协变量偏移)

论文:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift(Ioffe & Szegedy,ICML 2015)

所谓 Internal Covariate Shift,是指在训练过程中,每一层的输入分布会随着前面层的参数更新而不断变化。举个具体的例子:假设第 3 层网络学会了“当输入分布在 \([-1, 1]\) 之间时,激活某个特征”。

但训练中第 1、2 层的参数在不断更新,第 3 层的输入分布也在跟着漂移,可能从 \([-1, 1]\) 变成了 \([0, 5]\)。第 3 层学到的东西就失效了,必须重新适应新的分布。这个“追着变化的分布跑”的过程,让训练非常低效。

更坏的是,这个问题在深层网络里会被放大:层数越多,下游层的输入分布变化越剧烈,每一层都在做无用功。实践中的表现是:需要非常小的学习率(不然容易发散)、需要精心设计的初始化(不然梯度消失),而且收敛极慢。


Batch Normalization:强制每层输入保持稳定分布

Batch Norm 的思路非常直接:既然输入分布会乱跑,那就强制把它固定住。 在每一层的输入进入激活函数之前,先把它归一化成均值为 0、方差为 1 的分布。

对一个 mini-batch \(\mathcal{B} = \{x_1, \ldots, x_m\}\),BN 的计算是:

\[\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^m x_i, \quad \sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_\mathcal{B})^2\]

\[\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}, \quad y_i = \gamma \hat{x}_i + \beta\]

最后两个可学习参数 \(\gamma\)(缩放)和 \(\beta\)(偏移)是为了让网络在需要时能够恢复原始分布——如果强制归一化反而有害,网络可以通过学习 \(\gamma\)\(\beta\) 来“撤销”归一化。

效果非常显著。 BN 允许使用更大的学习率,大幅加快收敛;降低了对初始化的敏感性;甚至有一定的正则化效果(因为每个样本的归一化统计量会受到 batch 中其他样本的影响,引入了噪声)。2015 年发表后,BN 几乎立刻成为深度网络训练的标配,尤其在当时以 CNN 为主的视觉任务里被广泛采用。

训练和推理的不一致

BN 有一个实现上的细节需要注意。训练时,均值和方差是从当前 mini-batch 计算的;但推理时,一次只处理一个样本,没有 batch,无法计算 batch 统计量。

解决方案是在训练过程中维护一个滑动平均(running mean 和 running variance),用指数加权的方式记录历史上所有 batch 的均值和方差:

\[\mu_{\text{running}} \leftarrow \alpha \mu_{\text{running}} + (1-\alpha)\mu_\mathcal{B}\]

推理时直接用这个全局统计量来归一化,不再依赖当前 batch。

用测试集统计量“作弊”

这里有一个值得注意的细节:running statistics 是在训练集上积累的,如果测试集的分布和训练集存在偏移,推理时用的统计量就不是最优的。

一个顺理成章的想法是:在正式推理之前,先用测试集(或目标域数据)跑一遍前向传播,重新估计每层 BN 的均值和方差,再用这组统计量做真正的推理。这个技巧有时被称为 Test-Time Batch Normalization。操作上只需把模型切换回 train 模式(让 BN 重新计算 batch 统计量),在测试集上过一遍数据,收集好新的统计量,然后再切换回 eval 模式推理。

这在领域自适应(domain adaptation)场景下尤其有效:模型在源域训练,在目标域推理,两者分布不同,用目标域统计量替换训练时积累的 running statistics,往往能显著提升指标。

事实上,一些论文在报告跨域迁移性能时确实会用这个手段,但如果不明确说明,读者会默认统计量来自训练集——这是一个容易被忽略的“比较不公平”来源。

BN 的根本局限:它本质上是一个 batch 级别的操作

BN 的统计量是在 batch 维度上计算的——均值和方差都是“这个 batch 里所有样本在这个通道上的平均”。这个设计带来一个根本性的约束:batch 必须足够大,统计量才可靠。

当 batch size 从 32 降到 2 时,ResNet-50 在 ImageNet 上的错误率从 23.6% 上升到 34.7%——差了整整 11 个点。这在目标检测、视频理解这类内存密集型任务里是致命的:这些场景下每张 GPU 只能放下 1-2 张图,BN 直接失效。

更根本的问题在于 RNN 和 Transformer:序列模型的每个 timestep 的“batch”语义不同,无法直接套用 BN 的框架。


Layer Normalization:换一个维度来归一化

论文:Layer Normalization(Ba et al.,2016)

BN 的问题出在归一化的维度上。BN 是在 batch 维度上做统计——对每个特征,收集这个 batch 里所有样本的值来算均值和方差。

Layer Norm 换了一个维度:在特征维度上做统计——对每个样本,收集这个样本在所有特征上的值来算均值和方差。

BN 与 LN 归一化维度对比

具体来说,对一个隐层有 \(H\) 个神经元的网络,LN 的计算是:

\[\mu = \frac{1}{H}\sum_{i=1}^H a_i, \quad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^H (a_i - \mu)^2}\]

\[\bar{a}_i = \frac{a_i - \mu}{\sigma}, \quad h_i = f(\gamma \bar{a}_i + \beta)\]

统计量 \(\mu\)\(\sigma\) 只依赖当前这个样本的当前这一层,和 batch 里的其他样本完全无关。这意味着:

  • batch size 可以是 1:推理和训练完全一致,不需要维护 running statistics
  • 适用于 RNN:每个 timestep 独立归一化,不受序列长度影响
  • 适用于 Transformer:每个 token 的表示独立归一化

Layer Norm 在 2017 年 Transformer 提出时就被采用(原始论文用的是 Post-LN,即归一化放在残差连接之后),从此成为序列模型的标配。

BN 和 LN 各有擅长

LN 在视觉任务上效果通常不如 BN,根本原因在于 CNN 的 channel 有一个特殊性质。所谓 channel,是卷积层输出的“特征图的层数”——一个卷积层有多少个滤波器,输出就有多少个 channel,每个 channel 就是对应滤波器在整张图上滑动后产生的二维激活图。

每个 channel 都是同一个卷积滤波器在整张图上滑动的输出,所有空间位置检测的是同一种特征(比如横向边缘、或某种纹理)。这意味着同一个 channel 内的值,无论是同一张图的不同位置,还是这个 batch 里不同图的对应位置,都在描述同一件事。把它们放在一起统计均值和方差,得到的是“这个滤波器在当前 batch 上的激活分布”——一个有明确语义的量。BN 恰好就是这样设计的,统计维度与 CNN 的结构天然吻合。

LN 则是把同一个样本的所有 channel 混在一起统计。在 CNN 里,这意味着把检测横边缘的 channel、检测竖边缘的 channel、检测颜色的 channel 的激活值全部混合求均值——这个均值对应的是什么?很难赋予它清晰的物理含义。归一化的统计量越有意义,模型从中受益越多;统计量越随意,归一化的效果就越接近噪声。


Group Normalization:在 BN 和 LN 之间找平衡

论文:Group Normalization(Wu & He,ECCV 2018)

既然 BN 在小 batch 下失效,LN 在视觉任务上效果有限,能不能找一个中间方案?

Group Norm 的思路是:把 channel 分成若干组,在每组 channel 内做归一化。

\[\mu_{n,g} = \frac{1}{(C/G) \cdot H \cdot W}\sum_{c \in \mathcal{G}_g}\sum_{h,w} x_{n,c,h,w}\]

\[\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sigma_{n,g}}, \quad y = \gamma \hat{x} + \beta\]

其中 \(G\) 是组数(默认 32),每组包含 \(C/G\) 个 channel。GN 在每个样本内、每组 channel 内独立计算统计量,完全不依赖 batch。

GN 的关键价值是在小 batch 场景下对 BN 的替代:

方法 batch size = 32 batch size = 2
BN 23.6% 错误率 34.7% 错误率
GN 24.1% 错误率 24.1% 错误率

batch size 从 32 降到 2,BN 错误率暴涨 11 点,GN 纹丝不动。在目标检测、视频分类这类只能用小 batch 的任务上,GN 成为 BN 的主要替代方案。

GN、BN、LN、以及另一个变体 Instance Norm(IN)的归一化维度可以用一张图统一描述:

BN、LN、IN、GN 归一化维度示意

回到 Transformer:归一化放哪里?

到 2019 年,Transformer 已经席卷 NLP,LN 是标配。但 LN 具体放在哪里,是有争议的。

原始 Transformer(2017 年 Vaswani 等人)用的是 Post-LN:先做子层计算(Attention 或 FFN),再加残差,再归一化:

\[x_{i+1} = \text{LN}(x_i + \text{SubLayer}(x_i))\]

这个结构有一个隐患,根源在于 LN 在残差连接之后的位置。

梯度从最终输出往回传播时,每经过一层,都要穿过一次 LN。LN 会对输入做重新缩放,其 Jacobian(局部梯度)的范数并不固定——在训练早期,参数还没有收敛,LN 的缩放行为可以放大或压缩梯度。

更关键的是,Post-LN 的残差路径和 LN 是串联的:梯度无法绕开 LN 直接沿残差流回传,每一层都必须穿过 LN 的变换。层数越多,这些变换叠加的效果越不可控。

理论分析表明,在初始化阶段,Post-LN 最底层参数收到的梯度范数大约是 \(O(\sqrt{d} \cdot L)\),其中 \(d\) 是隐层维度,\(L\) 是层数。这意味着随着网络加深,底层参数的梯度会线性放大,一旦学习率稍大,底层参数更新幅度过大,整个训练过程就会发散。

这正是为什么原始 Transformer 训练时必须使用 learning rate warm-up:从一个极小的学习率开始,让模型在初期“小心翼翼”地更新,等参数逐渐进入稳定区域(LN 的缩放行为趋于合理)之后,再逐步把学习率提升到正常水平。warm-up 不是一个经验技巧,它是 Post-LN 结构不稳定性的直接补偿手段。一旦跳过 warm-up 直接用大学习率,Post-LN 的深层模型几乎必然在训练早期发散。


Pre-LN:把归一化移到残差连接之前

论文:On Layer Normalization in the Transformer Architecture(Xiong et al.,ICML 2020)

Pre-LN 把归一化的位置挪到子层计算之前:

\[x_{i+1} = x_i + \text{SubLayer}(\text{LN}(x_i))\]

位置变了一行,梯度的行为就完全不同了。

Pre-LN 中,残差流 \(x_i\) 是“干净的”——没有被 LN 处理过,梯度可以直接沿残差路径畅通无阻地回传。理论分析给出:Pre-LN 的梯度范数是 \(O(1)\),与层数 \(L\) 无关。

这带来了两个实际好处:

  • 不再需要 warm-up:梯度始终稳定,可以一开始就用大学习率
  • 更容易训练深层模型:100 层以上的 Transformer 用 Post-LN 几乎无法训练,Pre-LN 可以

代价是 Pre-LN 的最终精度有时略低于 Post-LN(因为最后一层的输出没有经过 LN,尺度不那么统一),但工程上的稳定性优势让 Pre-LN 成为现代大模型的默认选择。GPT 系列、LLaMA、Mistral 都用的是 Pre-LN 结构。


RMSNorm:LN 还能再简化吗

论文:Root Mean Square Layer Normalization(Zhang & Sennrich,NeurIPS 2019)

Layer Norm 的公式里有两步:先减均值(中心化),再除标准差(缩放)。RMSNorm 问了一个问题:减均值这一步是必要的吗?

论文的论证是:归一化起作用的核心是重缩放不变性(re-scaling invariance)——输入 \(x\) 被放大 \(\alpha\) 倍,输出保持不变。这个性质只需要除以 RMS(均方根)就能实现,不需要减均值。均值中心化提供的是平移不变性,但实验表明这对模型性能的贡献很有限。

RMSNorm 直接去掉均值,只保留缩放:

\[\text{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^n a_i^2}, \quad \bar{a}_i = \frac{a_i}{\text{RMS}(a)} \cdot g_i\]

其中 \(g_i\) 是可学习的缩放参数(等价于 LN 里的 \(\gamma\)),没有偏移参数 \(\beta\)

省去了均值计算,RMSNorm 在不同模型上有 7%–64% 的速度提升,而精度和 LN 基本持平。LLaMA 和 Gemma 系列都用了 RMSNorm 替代 LN,是目前大模型里最常见的归一化方式。


DeepNorm:走向 1000 层

论文:DeepNet: Scaling Transformers to 1,000 Layers(Wang et al.,2022)

Pre-LN 解决了训练不稳定的问题,但走到几百层时,另一个问题浮现了:模型更新爆炸。训练初期,参数变化过于剧烈,优化轨迹不稳定,极深模型(500 层以上)依然容易发散。

注意 DeepNorm 并没有切换到 Pre-LN——它仍然是 Post-LN 的结构,LN 还是在最外层。DeepNorm 的思路不是改变归一化的位置,而是在 Post-LN 的框架内,通过修改残差连接和初始化来压制不稳定性。

具体地,DeepNorm 的出发点不只是梯度稳定,而是同时控制梯度和参数更新幅度。它在残差连接上引入了一个缩放系数 \(\alpha\)

\[x_{i+1} = \text{LN}(\alpha \cdot x_i + G_i(x_i, \theta_i))\]

\(\alpha > 1\) 的作用是:把残差流放大,相当于让子层的输出 \(G_i\) 在整体更新中的占比相对缩小。

同时,DeepNorm 在初始化时对子层权重乘以一个与层数相关的缩放因子,让初始时的参数更新更小。两者配合,可以证明训练过程中的模型更新幅度有界。

\(\alpha\) 的取值和架构深度有关。对 decoder-only 或 encoder-only 的架构,\(\alpha = \sqrt{3L}\)\(L\) 是层数);encoder-decoder 架构有更复杂的规则。

实验结果很直观:

层数 参数量 多语言翻译 BLEU
6 层(基线) 28.1
100 层 28.8
1000 层 3.8B 28.9

200 层的 DeepNorm(863M 参数)甚至超过了当时 48 层 12B 参数的 SOTA 模型 5 个 BLEU 点——深度的提升比单纯堆参数更有效。

为什么不在浅层也用 DeepNorm?

一个自然的问题是:DeepNorm 既然这么有效,能不能直接替代 Pre-LN 成为所有 Transformer 的默认选择,哪怕只有 6 层?

理论上可以,但实践中没人这么做,原因是代价和收益完全不对称。

DeepNorm 解决的核心问题——Post-LN 在极深模型上的训练不稳定——在浅层模型上根本不尖锐。6 层或 12 层的 Transformer,Pre-LN 已经足够稳定,甚至不需要 warm-up,训练过程毫无障碍。把 DeepNorm 用在这里,等于专门去解决一个不存在的问题。

与此同时,DeepNorm 的调参负担是实实在在的:\(\alpha\) 的取值依赖层数 \(L\),初始化的缩放系数也要随着架构深度重新计算。每次改变网络深度,这些超参数都得重新推导。Pre-LN 完全没有这个负担。

DeepNorm 的另一个隐含前提是“Post-LN 的最终精度比 Pre-LN 更高”——这个优势在层数很深时才明显,浅层模型两者差距本来就很小,DeepNorm 费心保留 Post-LN 结构的动机也就不成立了。

所以 DeepNorm 是一把专门为极深 Transformer 打造的工具,它的存在价值在于打开“500 层以上”这个之前不可能的区间,而不是取代 Pre-LN 成为通用默认。


归一化进入 Attention 内部

到 2022 年,Pre-LN + RMSNorm 已经成为主流大模型的默认配置,归一化的位置讨论似乎已经有了答案。但这个答案只解决了“残差流上该怎么做”的问题。

从 BN 到 RMSNorm,所有这些归一化方案的作用对象都是同一个东西:每一层的输入或输出,也就是残差流。Attention 机制内部——Q、K、V 的计算过程——从来没有人专门对它做归一化处理,大家默认它不需要。

随着模型规模推进到百亿参数,这个默认开始被打破。研究者发现,Attention 内部同样存在数值失控的问题,而残差流上的 LN 管不到这里。于是“在 Attention 内部加归一化”成了新的方向。

QK-Norm:在 Q 和 K 上加归一化

论文:Scaling Vision Transformers to 22 Billion Parameters(Dehghani et al.,Google,2023)

ViT-22B 是 Google 在 2023 年发布的 220 亿参数视觉 Transformer。在把 ViT 推向这个规模的过程中,工程师发现了一个令人困扰的现象:随着训练步数的增加,注意力的 logit——即 \(QK^T\) 的输出值——会指数级增长,最终导致训练崩溃。

这个问题在小模型上几乎不可见。当参数量还在亿级时,logit 的增长被其他训练动态所掩盖,不至于失控。但一旦扩展到百亿参数,矩阵维度更大,参数更新的累积效应更强,logit 的增长就变成了压倒训练的主要问题。

解法出人意料地简单:在 Q 和 K 各自的投影输出上,各加一个 LayerNorm(称为 QK-LayerNorm)。这两个 LN 在注意力计算 \(\text{softmax}(QK^T / \sqrt{d})\) 之前独立作用,把 Q 和 K 的向量幅度限制在一个稳定的范围内,从根本上断掉了 logit 指数增长的来源。

这个做法的意义在于,它把“归一化放哪里”的问题从 block 级(Pre/Post)细化到了 attention 内部的向量级。之前我们讨论的所有归一化方案都作用于残差流,QK-Norm 是第一个系统性地在 attention 机制内部进行干预的方案。

后来,Llama 3、Gemma 2 等大模型都将这个做法纳入了默认架构。

值得一提的是,这里用 LayerNorm 还是 RMSNorm 并没有实质区别——QK-Norm 的目的只是限制 Q 和 K 的向量幅度,两者都能做到。原始论文用 LN 只是沿用了 ViT 系列的惯例,Llama 3 在实现 QK-Norm 时就换成了 RMSNorm,效果相当。


回望这条路

从 Batch Norm 到 DeepNorm,归一化层走过了近十年。每一步的动机都很清晰:

BatchNorm:输入分布不稳定导致训练困难——强制归一化到标准分布,训练终于可以用大学习率。代价:依赖 batch,小 batch 下失效,不适用于序列模型。

LayerNorm:BN 无法用于 RNN 和变长序列——换一个维度归一化(特征维度而非 batch 维度),彻底摆脱 batch 依赖。代价:在视觉任务上效果有时不如 BN。

GroupNorm:BN 在小 batch(检测、视频)下失效,LN 在视觉上效果有限——折中:在 channel 的分组内归一化,兼顾两者。

RMSNorm:LN 的均值中心化是不必要的计算开销——去掉均值,只保留 RMS 缩放,速度提升 7–64%,精度不损失。

Pre-LN:Post-LN 的梯度随层数放大,训练不稳定——把 LN 挪到残差之前,梯度与层数无关,warm-up 不再必要。

DeepNorm:Pre-LN 在极深模型(500+ 层)下参数更新仍不稳定——在残差上加缩放系数 \(\alpha\),同时控制梯度和参数更新,成功训练 1000 层 Transformer。

QK-Norm:大模型扩展时注意力 logit 指数增长导致崩溃——在 Q 和 K 上加归一化,把归一化的作用范围从残差流延伸进 attention 内部。

这条路上有一个反复出现的主题:归一化的本质是控制信息流的尺度。无论是控制输入分布(BN/LN)、控制归一化维度(GN)、简化统计量(RMSNorm),还是控制残差流的幅度(DeepNorm),背后都是同一件事——让每一层在训练时看到的信号处于一个可预期的范围内,不过大也不过小。

QK-Norm 进一步把这件事的作用范围从残差流延伸进了 attention 内部。归一化层的演进还没有结束。随着模型越来越大、架构越来越复杂,“如何让深层网络的信息流保持稳定”这个问题,依然是研究的核心之一。


值得关注但尚待验证的方向

这一节收录两篇近年的工作,它们的思路有价值,但在工程实践中尚未得到广泛验证,暂时放在主线之外,供感兴趣的读者延伸阅读。

σReparam:从理论上保证注意力熵不崩溃

论文:Stabilizing Transformer Training by Preventing Attention Entropy Collapse(Zhai et al.,Apple Research,ICML 2023)

Apple Research 的研究者识别出了注意力内部的另一种不稳定模式:熵崩溃(entropy collapse)。所谓熵崩溃,是指注意力权重过度集中在少数几个 token 上,注意力分布的熵趋近于零——注意力头“只盯着一两个位置”,模型从序列中整合信息的能力大幅退化。

σReparam 的思路是对网络中所有线性层做谱归一化(spectral normalization):把每个权重矩阵除以其最大奇异值,使其谱范数恒为 1,再乘以一个可学习标量 \(\sigma\)。作者严格证明了在此约束下,注意力熵有可证明的下界——模型从原理上无法退化到完全的熵崩溃。实验中,σReparam 让 ViT 训练可以不需要 warm-up、不需要 weight decay,甚至去掉 LayerNorm 仍能收敛。

理论完备性是 σReparam 的亮点。但它对架构的改动较深(所有线性层都要重新参数化),后续主流模型并没有跟进采用。

nGPT:用超球面约束替代统计归一化

论文:nGPT: Normalized Transformer with Representation Learning on the Hypersphere(Loshchilov et al.,NVIDIA,2024)

nGPT 走得更远——它从根本上重新思考归一化的作用。核心思路是:把所有向量强制约束为单位范数,让 token 的表示在高维超球面上“流动”。词向量、权重矩阵的每一行、隐层状态,全部归一化到范数为 1 的超球面上,LayerNorm 在这个框架里彻底消失。

论文声称,在达到相同精度的前提下,nGPT 只需要原来 1/4 到 1/20 的训练步数。这个数字本身值得思考:传统 LN/RMSNorm 做的是“数值尺度归一化”,nGPT 做的是“方向归一化”,后者可能比前者更触及 Transformer 表示的本质。

不过,nGPT 目前的可信度需要打一个问号。论文发表于 2024 年 10 月,截至目前没有主流模型(Llama、Mistral、Gemma 等)跟进采用,也缺乏大规模独立复现的公开报告。

“4–20× 训练步数减少”的结论如果成立,社区早应跟进。论文来自 NVIDIA,而“需要从头训练”的理由本身也站不住脚——Pre-LN 替换 Post-LN、RMSNorm 替换 LN,同样需要重新训练,但学界都跟上去了。真正让大家观望的,更可能是可重复性存疑。把它放在这里,是因为“方向约束”的视角本身有思想价值;但它离改变行业默认配置还有相当距离。


参考文献