02. Self-Attention 的计算流程是什么?为什么要除以 √d_k?
整理 Self-Attention 的计算流程、缩放原因与核心公式。
简单回答
Self-Attention 的计算流程是:输入向量分别线性投影得到 Q、K、V,然后用 Q 和 K 做点积算相似度,除以 √d_k 做缩放,经 softmax 归一化后作为权重对 V 加权求和。除以 √d_k 是为了防止点积值在高维下方差过大,导致 softmax 输出接近 one-hot 分布,梯度消失。
详细解释
整体计算流程
给定输入序列 (n 个 token,每个 d 维),Self-Attention 的步骤如下:
第一步:线性投影得到 Q、K、V。 用三个可学习的权重矩阵把输入分别映射到 Query、Key、Value 空间:
其中 ,。Q 和 K 的维度 必须一致(因为要做点积),V 的维度 可以不同(实践中通常也等于 )。
第二步:计算注意力分数。 Q 和 K 做矩阵乘法,得到每对 token 之间的相似度分数:
表示第 i 个 token 的 Query 和第 j 个 token 的 Key 之间的点积相似度。这一步的计算复杂度是 ,也是长序列的主要瓶颈。
第三步:缩放(Scale)。 将分数除以 :
这一步的动机后面展开讲。
第四步:Masking(可选)。 如果是 Decoder 架构,在 softmax 之前需要加 Causal Mask,把上三角部分设为负无穷。如果有 Padding Mask,也在这一步叠加。
第五步:Softmax 归一化。 对每一行做 softmax,得到注意力权重分布:
每一行是一个概率分布,和为 1。 表示第 i 个 token 对第 j 个 token 的注意力权重。
第六步:加权求和。 用注意力权重对 V 做加权:
每个 token 的输出是所有 token 的 Value 向量的加权组合,权重就是注意力分数。
完整公式写在一起就是:
为什么要除以 √d_k?——深入理解
这是面试高频追问,需要从数学层面讲清楚。
假设 Q 和 K 的每个元素都是均值为 0、方差为 1 的独立随机变量。那么它们的点积 ,根据独立随机变量求和的方差公式:
也就是说,点积的方差随 线性增长。当 时,点积值的标准差就是 ,有些值可能到几十甚至更大。
这些大数值进 softmax 后会导致什么?softmax 函数在输入绝对值很大时,输出趋近于 one-hot 分布——概率集中在最大值上,其他位置几乎为零。这带来两个问题:第一,梯度几乎为零(softmax 在饱和区梯度极小),训练困难;第二,注意力分布过于尖锐,模型失去了"综合多个 token 信息"的能力。
除以 后,点积的方差被拉回到 1 附近,softmax 的输入分布更温和,注意力权重更平滑,梯度传播正常。
一个直觉理解: 可以把这个缩放理解为一种"温度调节"。不缩放等于用了一个很低的温度(softmax 输入值大 → 分布尖锐),缩放后温度回到合理范围(分布平滑)。这和 softmax 温度参数 的概念是一致的:,这里 就扮演了温度的角色。
常见误区
很多人以为除以 是为了"归一化"或"让数值更小"。更准确的说法是:让点积的方差不随维度增长,保持 softmax 的有效工作区间。另外,这个缩放是在 softmax 之前做的,如果在 softmax 之后做就没意义了。
面试时可以这样答
Self-Attention 的计算流程分几步:首先对输入做三次线性投影得到 Q、K、V;然后 Q 和 K 做矩阵乘法得到 n×n 的注意力分数矩阵;接着除以根号 d_k 做缩放;如果是 Decoder 架构需要加 causal mask;然后过 softmax 归一化成权重分布;最后用这个权重对 V 做加权求和,得到输出。
关于为什么要除以根号 d_k,核心原因是点积的方差会随维度线性增长。假设 Q 和 K 的元素是标准分布的,维度 d_k 的点积方差就是 d_k。如果 d_k 是 128,有些点积值就会很大。这些大值进 softmax 后会导致分布极度尖锐,接近 one-hot,一方面梯度消失训练不动,另一方面注意力过度集中在单个 token 上。除以根号 d_k 就是把方差拉回到 1 附近,让 softmax 在一个比较好的工作区间里。
直觉上可以理解为一种温度调节。不缩放等于用了很低的温度,分布太尖锐;缩放后温度合理,分布平滑。
常见追问
- 如果不除以 √d_k 而是用一个可学习的温度参数,效果会怎样?
- Multi-Head Attention 和 Single-Head Attention 的区别是什么?多头的 d_k 怎么定?
- Attention 的计算复杂度是多少?为什么长序列下它是瓶颈?