04. Megatron-LM 的张量并行是怎么切的?它的通信模式是什么?
整理 Megatron-LM 张量并行的切分方式与通信模式。
简单回答
Megatron-LM 的张量并行(Tensor Parallelism)把 Transformer 每一层的大矩阵乘法切分到多张 GPU 上并行计算。核心思路是针对 MLP 和 Self-Attention 的矩阵结构,设计了特定的列切分和行切分方案,使得每层只需要两次集合通信(AllReduce 或 ReduceScatter + AllGather),且这两次通信可以和下一层的计算重叠,通信效率高。
详细解答
为什么需要张量并行
流水线并行解决的是"模型层太多,按层分配"的问题;张量并行解决的是"单层的权重矩阵太大,单卡算不下"的问题。对于 hidden size 4096、FFN 维度 16384 的模型,一个 FFN 层的权重就有约 (FP16),多层累加起来单卡装不下,而且矩阵乘法本身就很费时,切开并行计算也能加速。
MLP 层的张量并行切分
Transformer 的 MLP(FFN)层由两个线性变换组成:
其中 是输入,,。
Megatron 的切法是: 按列切分, 按行切分。
对于 张 GPU,每张 GPU 持有 的 列(即 )和 对应的 行(即 )。
前向计算流程:
- 输入 在所有 GPU 上完整存在(无需通信)
- 每张 GPU 独立计算 ( 列切分,无需通信)
- 每张 GPU 独立计算
- 最后 AllReduce 把 相加得到完整输出
这样整个 MLP 层只需要一次 AllReduce,而且这次通信是在层的末尾,可以和下一层的计算做一定程度的重叠。
反向传播同样只需要一次 AllReduce(对输入 的梯度需要汇总)。所以整个 MLP 层前向 + 反向共 2 次 AllReduce。
Self-Attention 层的张量并行切分
Self-Attention 的结构更复杂,有多个头(Multi-Head Attention)。Megatron 的做法是按照注意力头来切分——每张 GPU 负责 个注意力头( 是总头数)。
具体来说,Query、Key、Value 的投影矩阵 各按列切分,每张卡持有 份;输出投影矩阵 按行切分。
每张 GPU 独立计算自己负责的那些头的 Attention,然后对输出做 AllReduce 合并。同样是前向 + 反向共 2 次 AllReduce。
通信量分析
设序列长度为 ,batch size 为 ,hidden size 为 ,对于 卡张量并行:
每次 AllReduce 的通信量约为 (前向一次,反向一次),一个 Transformer 层(MLP + Attention)共 4 次 AllReduce。
其中 是层数。相比于数据并行的 AllReduce(只在训练步末尾通信一次),张量并行的通信频率高很多,这就是为什么张量并行必须在高带宽互联(NVLink)的节点内使用,跨节点走 InfiniBand(带宽低 10~20 倍)会严重拖慢速度。
Megatron 的序列并行(Sequence Parallelism)
更新的 Megatron 版本还引入了序列并行(Sequence Parallelism)——在张量并行的 GPU 之间,对序列维度也做切分,主要是为了减少 Activation 的显存占用。LayerNorm 和 Dropout 这类不需要通信的操作就在切分后的序列上各自独立计算。通过 ReduceScatter 和 AllGather 配合张量并行,可以进一步降低每张卡的 Activation 显存,使得同样的 GPU 配置能训练更大的 batch size。
面试时可以这样答
Megatron 的张量并行核心思路是:MLP 的两个权重矩阵,第一个按列切,第二个按行切,每张 GPU 各自计算一部分,最后 AllReduce 合并。设计的巧妙之处在于,按列切第一个矩阵的话,每张卡的中间结果可以直接喂给按行切的第二个矩阵,中间不需要通信,只在最后做一次 AllReduce。
Attention 层类似,按注意力头来切,每张卡负责一部分头,最后 AllReduce。所以整个 Transformer 层一次前向 + 反向只需要 4 次 AllReduce。
关键约束是这个通信量很大,必须用节点内的 NVLink(带宽 600 GB/s 以上)才够,跨节点走 InfiniBand 就不行了。所以张量并行度通常和单节点 GPU 数量一致,比如 8 卡节点就用 TP=8。
更新版的 Megatron 还加了序列并行,在 LayerNorm 这类操作上把序列维度也切开,进一步降低 Activation 显存,让同样的配置能跑更大的 batch。
常见追问
- MLP 的两个矩阵为什么一个按列切、一个按行切?反过来行不行?
- 张量并行里,LayerNorm 怎么处理?为什么它是序列并行的关键点?
- GQA(Grouped Query Attention)对张量并行的切分方式有影响吗?