我们在多达 512 个 GPU 上进行了超过 4000 次扩展实验,并测量了吞吐量(标记大小)和 GPU 利用率(标记颜色)。需要注意的是,在本可视化中,两者都根据模型大小进行归一化。
我们在多达 512 个 GPU 上进行了超过 4000 次扩展实验,并测量了吞吐量(标记大小)和 GPU 利用率(标记颜色)。需要注意的是,在本可视化中,两者都根据模型大小进行归一化。
数千块 GPU 协同工作,完美协调——这正是当今最强大 AI 模型所需的训练规模。直到最近,这种规模还只属于顶尖研究实验室的专有领域。开源让这块领域发生了翻天覆地的变化,但并没有彻底改变现状。是的,你可以下载最新的 Llama 或 DeepSeek 模型。是的,你能阅读它们的 技术和 实验报告。但最具挑战性的部分——用于协调 GPU 训练这些超大模型的训练代码、必要的知识和技术——仍然极其复杂,零散地分布于多篇论文或私有代码库中。
本开源书将尝试改变这一点。我们将从基础开始,循序渐进地带你了解如何将大型语言模型的训练从单卡 GPU,扩展到数十、数百甚至数千个 GPU,并用实践代码示例和可复现的基准来演示理论。
随着用于训练这些模型的集群规模不断增大,出现了如数据并行、张量并行、流水线并行、上下文并行,以及 ZeRO 或核融合等技术,以确保 GPU 始终保持高利用率,从而显著缩短训练时间并最大化昂贵硬件的使用效率。更进一步地,随着对大型模型的微调在真实应用中越来越受重视,人们发现微调在特定领域数据上能带来更好的效果,这通常也需要相同的分布式训练技术。在本书中,我们将从最简单的方法到最精巧的方法,依次介绍所有这些技术,同时围绕单一的故事线说明每种方法的来龙去脉。
本书假设你对当前 LLM 架构有一些基本了解,并大致熟悉深度学习模型的训练原理,但在分布式训练方面可以是新手。如果有需要,可在 DeepLearning.ai 或 PyTorch 教程等地方学习模型训练的基础知识。可以将本书视为我们第一篇关于预训练数据处理博客——即“FineWeb 博客”——的后续篇章。读完这两篇博客后,你几乎掌握了当今高性能 LLM 构建所需的核心知识,只差一些关于数据混合和模型架构选择等“最后的调味料”就可以完整配方了(敬请期待第三部分……)。
本书基于以下 三个核心支柱:
简明的原理与概念简介:在开始编码和实验前,我们希望先高层次地理解每种方法如何工作,及其优劣势何在。你会了解到语言模型在训练时哪些部分会占用最多显存,以及何时占用显存。你会了解到如何通过对模型进行并行拆分来缓解显存的限制,并如何通过扩展 GPU 集群来提高吞吐量。于是你就能理解下面这个小工具是如何计算 Transformer 模型显存占用的:
(如果你对这个小工具里发生的事情还摸不着头脑,不用担心——这正是我们要讲的内容。)
虽然上面的小工具给出的是理论分析,但我们还做了一个 在线工具,可用来预测训练过程中的实际显存用量:
清晰的代码实现:理论是一方面,但在真正实现过程中会遇到各种边界情况以及重要的细节。因此我们会在可能的地方关联到实现参考。我们会根据情况引用两份代码:
真实的训练效率基准:最后,如何在实际环境中真正扩展 LLM 训练还要取决于你的硬件环境,比如 GPU/TPU 芯片种类、网络带宽等。因此我们不可能给出一个适用于所有硬件的统一方案,但我们会介绍评测方法,而我们也已经在自己的集群上进行了此类评测!我们在最多 512 块 GPU 上运行了超过 4100 个分布式实验(加上测试一共做了 16000 多次),以探索众多可能的分布式训练布局和模型规模。
如你所见,还有很多内容需要讨论。在深入了解分布式训练的细节之前,让我们先从更高层面回顾一下在本书中会涉及到的挑战。
本书中所涉及到的所有技术都是为了应对如下三个关键挑战,而这三个挑战会在全书中反复出现:
在很多地方,我们会看到可以在(计算、通信、显存)三者之间做权衡(例如重计算或者 Tensor 并行)。找到合适的平衡是成功扩展训练的关键。
因为本书内容非常详实,我们做了一个 快速参考表,以便在阅读本书时辅助理解并帮你提炼要点。可以在阅读过程中随时翻看它!
如果你想添加“播客”式的阅读体验,可以在阅读本节时收听 NotebookLM 主播关于本书第一部分的讨论。
在扩展到多 GPU 之前,让我们先快速回顾在单卡 GPU 上训练模型的基本流程。单卡训练通常包括三步:
通常可以表示为下图:
在此图中,顶部每个彩色框代表模型的某一层(最后一行同理);红色框是对应层在反向传播时的梯度。
批大小(
小批量(batch size 小)有利于训练初期快速更新权重,且梯度带有随机性。但在训练后期,小 batch size 可能使梯度噪音过大,模型无法收敛到最优的点。而相反,如果 batch size 太大,虽然可以让梯度估计更精确,但会使每次更新投入的计算代价增大,从而降低了整体效率。关于这个话题可参见 OpenAI 的大 batch 训练论文
batch size 同时也影响训练相同数据集所需的训练时间:较小的 batch size 会导致需要更多次的优化步骤(optimizer steps),而优化步骤通常是最耗算力的。因此小 batch size 往往会拉长整体训练时间。话虽如此,在实际训练中,batch size 的最终性能对其附近的一些取值并没有那么敏感,通常存在一个相对平缓的区间。
在 LLM 预训练社区中,batch size 通常会用 token 而不是样本量来表示(
在最简单的单机训练场景下,
近来对于 LLM 预训练而言,较常见的 token batch 范围在几百万到几千万之间不等。随着预训练规模的不断增长:Llama 1 约在 4M tokens batch,对 1.4T token 进行训练;DeepSeek 则用到了 60M tokens batch,训练了 14T token。
而当我们将模型训练扩展到如此大的批次时,第一个挑战便出现了:内存不足问题。当我们的GPU没有足够的内存来容纳目标批次大小的完整数据时,我们该怎么办?
让我们先快速了解一下最初导致内存不足问题的原因。这将帮助我们对训练模型所需的内存有一些有用的直观认识。
在训练神经网络模型时,一般需要存储:
📝 注意
你可能认为对于一个模型,我们可以精确地计算出需要多少显存,但实际上显存里还会有以下内容:
import torch; torch.ones((1, 1)).to("cuda")
并配合 nvidia-smi
观测 GPU 显存来验证。这些需要存储的项目以张量形式存在,不同张量有不同的 形状(shape) 和 精度(precision)。形状由例如 batch size、序列长度、模型隐层维度、注意力头数、词表大小,以及是否进行模型切分等超参数决定;精度则对应 FP32、BF16 或 FP8 等格式,会影响每个元素所占的字节数(4、2 或 1 字节)。我们稍后会在 混合精度训练部分更详细讨论精度,这里只需知道不同精度会影响存储需求即可。
那如何快速确定模型的实际显存占用?一个简单的方法是直接在真实环境中测量。
借助 PyTorch 的 profiler 工具,我们可以查看训练过程中不同阶段的显存分配。可以发现显存使用并非静态,而是在训练过程(尤其是单个 step 内)不断变化:
显然,第一步看起来和后续步骤有些不同;但先让我们看看一次完整训练 step 的内存使用模式:前向传播时,会随着激活值的产生,显存占用快速上涨;随后在反向传播时,梯度逐渐累加,而用于计算梯度的激活值也在此过程中逐步释放。最后,我们进行优化步骤,此时需要所有梯度,然后更新优化器状态,接着进入下一次前向传播。
为什么第一步看上去与众不同:激活值先快速增加,然后会有一段时间保持在高位。这是因为第一步里 Torch 的缓存分配器(allocator)要初始化分配显存块,以便随后的训练步骤中不必再频繁搜索可用显存(见 Zach 的博客)。在第一步结束后,我们还要为优化器状态分配显存,这通常会让后续步骤的起始显存占用有所提高。
现在我们已经对显存变化有了初步认识,下面我们就来看看在扩展训练规模的同时,如何在保证计算效率的前提下,让模型的各种存储需求(激活值、参数、梯度、优化器状态)都在显存范围内。
让我们先看一下前面提到的前三项:模型权重、梯度和优化器状态。我们可以相对容易地对它们的显存需求做出估算。
对一个简化的 Transformer LLM,其参数量可按 如下公式计算:
其中
对于这些参数以及对应的梯度,我们在显存中需要的空间就是“参数个数乘以每个参数的字节数”。在传统的 32 位浮点(FP32)训练中,权重和梯度都是 4 字节,优化器在使用 Adam 时,需要存储动量和方差各 4 字节,还会加上一些管理用的结构。总结起来:
现在让我们看看如果我们使用较低精度,会发生怎样的变化。出于稳定性考虑(参见下文的混合精度训练部分),我们通常不会采用完全的低精度训练,而是使用一种高低精度混合的方法,称为“混合精度”
以下是总结:
📝 注意
有些库会将梯度以 FP32 存储,则需要额外的 bf16
对较小值是有损的,为了稳定性,一些库就会把梯度也保存在 FP32。可参考 DeepSpeed issue 了解更多信息。
📝 注意
在文献和代码中,有时会把这份 FP32 权重副本称作 “master weights”。
有意思的是,混合精度本身并不会减少总体显存需求,因为虽然参数本身变少了,但又多了一份 FP32 副本,甚至如果把梯度也保存在 FP32,整体开销还会上涨。但它能带来很大好处:用 BF16 进行前向/反向计算可用到 GPU 的低精度优化运算单元,速度更快,同时前向传播中激活所需的显存也能减少,这在大 batch 或长序列时尤为重要。
来看几个常见模型规模下,计算或 BF16 混合精度时,这些存储需求的大致量级:
模型参数量 | FP32 或 BF16(不含 FP32 梯度累加) | BF16 + FP32 梯度累加 |
---|---|---|
1B | 16 GB | 20 GB |
7B | 112 GB | 140 GB |
70B | 1120 GB | 1400 GB |
405B | 6480 GB | 8100 GB |
如上表所示,当模型达到 7B 规模时(!),仅权重和优化器状态就已经远超许多 GPU 的显存(如 H100 的 80GB)。
目前我们只考虑了模型可以单卡放得下的情况,接着让我们看看另一个显存主要来源:激活值。
激活值的占用比权重、梯度和优化器状态更为复杂,因为它会受输入有关。若你不太确定为什么反向传播需要存储激活值,可参考这篇参考文档。仔细分析 Transformer 反向传播的计算过程后,可估算训练时在混合精度下的总激活显存,结论可参见原版 NVIDIA 的重计算论文
这里
激活值需要在反向传播时用来计算梯度。与权重、梯度和优化器状态相比,激活值在计算图中随时在变化,因而需要在整个 forward + backward 周期进行分配和释放。
一个非常重要的观察是,激活值的占用会随序列长度 bs=1
)的激活值显存走势:
图中可以看到:对于短序列(或者 equivalently 小 batch size)时,激活值几乎可以忽略不计。但一旦序列长度达到 2-4k,它就会变成一个相当庞大的开销,而此时参数/梯度/优化器状态的开销也不再是主要矛盾。
对于大输入 tokens(即大 batch size/长序列),激活值会成为主要的显存负担。
有没有办法抑制这种“激活值膨胀”?你问得好!
下面来介绍我们的第一个技巧——激活重计算(Activation Recomputation)。它可以帮助我们将激活值的占用限制在合理范围。它是当下大型模型训练中不可或缺的重要技术。
激活重计算(也叫 gradient checkpointing 或 rematerialization)的核心思想是:在前向传播时丢弃部分激活值,从而省显存;需要它们做反向传播时,再运行一次子前向过程把它们计算回来,换取多一些计算量来节省显存。若不开启重计算,我们会在每个可学习操作(比如 feed-forward,layernorm 等)之间都保存激活值。启用重计算后,我们只保存少量关键位置的激活值,而丢弃其它激活值;当反向传播需要时,再用已保存的激活值进行部分前向运算重算出所需内容。可通过下图直观理解:
在实际实施中常见有几种策略来决定哪些激活值需要保存:
让我们用实践测量来看看各种重计算策略能带来多大程度的显存优化,以及 selective 重计算如何在内存节省和计算开销之间取得平衡:
图中还能看到,越小的模型(h 较小)在长序列(seq 较大)下,激活值占比就越明显,重计算带来的收益也更大。
📝 注意
在测量训练对 GPU/TPU/加速器的利用率时,一般要把重计算计算量纳入总 FLOPs,再和理论峰值 FLOPs 进行对比,以得到实际硬件 FLOPS 利用率 (Hardware FLOPS Utilization, HFU)。因为重计算会增加实际运算量。
但最终我们往往更关心的是,从头到尾完成同样数据量训练所需要的总时间,因此评估不同加速器时,若某加速器拥有足够的显存以完全跳过重计算,它所做的实际运算就会变少,这会导致它的 HFU(硬件利用率)看似降低,但训练速度可能反而更快。为此,有人提出仅统计模型本身前+后向步骤所需的 FLOPs(不含重算)来计算 Model FLOPS Utilization (MFU)。这在比较硬件时有时更有意义。
在当今的训练框架中,FlashAttention(后面会介绍)已成为注意力优化的标配,它在反向传播中就会自动重算注意力得分和中间矩阵,而不存储它们,这本质上就是 selective 重计算的一种。所以只要用上 FlashAttention,你其实已经在用 selective 重计算了。
小结:激活重计算增加了些许 FLOPs(约多 2~30%),却显著减少了显存访问带来的延迟和内存需求。从而带来显存大幅节省。
因此它对当前内存不大但运算能力强的 GPU 来说特别有用,即使多做了一些计算,整体上也会因为减少了更多的内存访问而变得更快。
不过即便如此,激活值的开销依然随 batch size 线性增长,那如果想要用很大的 global batch size,该怎么办?这就要看看我们的另一个法宝——梯度累加(gradient accumulation)!
梯度累加是个非常直观的方法,用于避免因过大 batch size 带来的显存爆炸。它做法是:把原本的一次大 batch,拆成多个更小的 micro-batch,每次只做一次前后向并计算出梯度,然后把这些梯度累加起来,最后再做一次优化器更新(optimizer.step)。这样,就能在保持相同 global batch size 的前提下减少每次前向时的激活显存占用。
我们把每次前向使用的 batch size 称作 micro batch size
(mbs),把整个全局 batch size(在每次 optimizer step 之间)称为 global batch size
(gbs)。如果我们在一次优化步骤中处理了 8 个 micro-batch,那么 global batch size 就是 mbs × 8
。
因此,之前我们文中统称的 batch size
其实就是 global batch size
。用符号表示就是:
不过,梯度累加也有个明显缺点:每个优化步骤需要重复多次前后向,因此增加了计算量,从而放慢训练速度。没有免费的午餐啊!
梯度累积使我们可以通过仅计算部分微批次来降低激活内存的使用,因为激活内存在批次大小增大时会线性增长。
然而,一个缺点是梯度累积要求在每次优化步骤中进行多次连续的前向/反向传播,这会增加计算开销并减慢训练速度。天下没有免费的午餐!
但是,如果你仔细跟随,你可能已经注意到,每个微批次的前向/反向传播实际上可以并行运行。前向和反向传播彼此独立,唯一的区别在于输入样本是独立的。看来是时候将我们的训练扩展到多GPU上了!
在此之前,让我们快速了解如何通过分布式训练工具箱中最有用的工具之一——性能分析器,来可视化计算和通信情况。这个工具对于理解和验证GPU之间以及计算之间的通信方式以及瓶颈所在将非常有用。
PyTorch 自带的 profiler 能够跟踪并可视化 CPU 和 GPU 在训练过程中的行为。它可以很方便地查看:
下面是一个简单示例:
这会生成一份 trace,可在 TensorBoard 或 Chrome tracing viewer 中查看,展示:
示例 trace,显示 CPU 线程异步向 GPU 提交 kernel,多条 CUDA stream 并行进行计算和通信
通过 trace 我们能发现很多瓶颈,比如:
这些信息对优化分布式训练性能至关重要。比如,你可以清晰看到梯度同步是否成功和反向计算重叠到一起等。
好了,现在我们可以正式进入多 GPU 训练世界,看看第一个扩展技巧 —— 数据并行!
如果想让阅读体验更像播客,可以播放此音频,收听 NotebookLM 主播对以下章节的讨论。
数据并行(DP)的核心思想是:在多块 GPU 上复制相同的模型副本,只是每张卡处理不同的微批数据。如此一来,称为“数据并行”。想必很多人已在简单训练示例中用过 DP,但我们这里会深入探讨更多实现细节。
在每个GPU使用不同的微批次意味着每个GPU上会有不同的梯度,为了保持各GPU上模型实例的同步,模型实例之间的梯度将在反向传播过程中、优化步骤之前,通过一种称为“all-reduce”的操作进行平均。
这涉及到我们的第一个“分布式通信”原语:all-reduce,它负责处理GPU实例和节点之间的同步与通信。
一个简单的DP实现会等待反向传播完成以获得所有梯度,然后触发一次all-reduce操作,对所有DP级别的梯度进行同步。但这种计算与通信依次进行的方式绝对不可取!因为我们不希望在通信过程中GPU处于空闲状态,如上图所示。
相反,我们应该尽可能地重叠通信与计算,使它们尽可能同时进行。
让我们看看三种优化方法,它们使我们比最初的简单实现做得更好!
我们刚刚描述的简单DDP方法的主要缺点在于,在反向传播(计算)之后,我们必须等待梯度同步(通信)完成,才能更新参数。我们能否将这种通信与计算重叠进行?答案是肯定的!
如上图所示,一个层的梯度(红色方框)可以在之前层的梯度(左侧红色方框)尚未计算完毕时就被收集和求和。例如,一旦最后一层的反向传播完成(右侧最后一个方框),这些梯度就可以在对前面各层继续进行反向计算时被收集和求和,计算过程向左推进。
这可以通过在pytorch中为每个参数附加一个all-reduce钩子函数来实现。只要该参数的梯度准备就绪,就会触发一次all-reduce操作,而其他参数的梯度仍在计算中。这种方法使大部分all-reduce操作与梯度计算重叠,从而提高了效率。下面是一个简单的函数,用于附加钩子:
计算与通信的重叠减少了整个模型中等待梯度同步所花费的时间。梯度同步(至少部分地)可以与反向传播并行进行,从而显著加速数据并行。下面是一个带有同步重叠的简单DP完整实现:
这是我们首次展示“计算与通信的重叠”,在本文中我们会多次讨论这一关键技术,它是实现最大扩展效率的必要手段。但我们仍可以进一步提高效率!
GPU 的效率在处理大张量时通常更高;若我们对每个参数单独执行 all-reduce,会有很多小通信操作,不够高效。更好的做法是把多个梯度合并到一个大的 bucket,一次 all-reduce,即可显著减少通信开销,提升效率。
示意图:
类似打包快递,把许多小包裹装进大箱子里,一次性发走。具体实现上,你会预先分配一个大的 buffer,把多个梯度放进来,然后做一次性 all-reduce。例如:
最后,前面我们提到的梯度累加与数据并行结合时,需要留意在累加阶段是否还要进行梯度同步。朴素做法可能每次反向传播都发起 all-reduce,但其实只有在完成了所有 micro-batch 的反向传播后再进行一次 all-reduce 就足够了,能减少通信开销。
在 PyTorch 中通常可用 model.no_sync()
decorator 来禁止某些 backward 阶段的梯度同步。
📝 注意
进行通信操作时,张量必须在内存中是连续的才能避免额外的拷贝。因此常常会预先分配与激活值或参数大小相同的连续内存块,专门用于通信。不过这样会带来额外的峰值内存占用。
接下来我们再回到 global batch size 的公式。
带有数据并行和梯度累加之后,全局 batch size 变为:
这里
给定目标全局批次大小,我们因此可以在梯度累积步数与数据并行进程之间进行权衡,以加速训练。
在实践中,人们倾向于尽可能增加数据并行节点(DP)的数量,而不是梯度累积步数,因为数据并行本质上是并行的,而梯度累积则具有顺序性质。当仅扩展数据并行还不足以达到目标全局批次大小且GPU资源耗尽时,才会在数据并行基础上增加梯度累积。
能够在不同样本上分布训练,为我们提供了并行化的第一个维度,从而构成了这种1D并行性(我们将逐步介绍另外4个维度)。
让我们快速总结一下如何通过一个草稿配方来设置我们的第一个1D并行训练,即一个最佳数据并行设置:
GBST
),可以通过查阅文献或进行实验测量模型收敛性来确定。如果梯度累加得到的值小于 1(相当于 GPU 多了?🤑),可以选择减少用到的 GPU 数量,也可以扩大 global batch size,或试着减小 mbs 看是否速度更快(增大数据并行度会带来通信开销,但减少 mbs 意味着可能浪费一部分硬件算力,需要实验来平衡)
举个例子:假设我们想要训练一个新模型,global batch size 设为 4M tokens,序列长度设为 4k,所以 batch size(以样本数计)约为 1024。我们发现单卡仅能放下 mbs=2(再大就 OOM ),若我们有 128 块 GPU,就可以在每个 GPU 上做 4 次梯度累加来得到全局的 1024 样本,这样就能到达目标 4M tokens。接着如果我们拥有 512 块 GPU,就可以把梯度累加减少到 1 并获得更快的训练速度。
📝 注意
当并行度到达 512 块 GPU 乃至以上时,网络瓶颈(ring latency 等)会使得 DP 的通信无法被完全隐藏,GPU 利用率会下降,吞吐量可能开始恶化。此时就要考虑其他并行维度。
虽然数据并行能够很好地将all-reduce梯度同步与反向计算重叠以节省时间,但当规模增大时,这一优势开始显现不足。原因何在?因为随着我们添加越来越多的GPU(数百甚至数千个),它们之间协调的开销显著增加,网络需求也变得过于庞大,从而抵消了并行带来的好处。结果是,每增加一块GPU,我们的设置效率就会逐渐降低。
让我们通过一些基准测试来看看这一现象在实际中的表现:
可以看到,当 GPU 数量超过某个范围后,吞吐量开始明显下降,而每卡所需内存并不会随着 GPU 数的增加而减少。
数据并行是扩展到更多 GPU 的第一种(也是相对简单的)并行策略。它和梯度累加的原理很像,但通过并行加速微批处理,来提高训练吞吐量!
然而,我们之前也提到,DP 默认要求单卡至少能放得下完整的一次前向(mbs≥1),对于更大的模型(如 70B+)即使激活重计算打开后,也还是放不下。比如:
与此同时,我们也发现当 DP 扩展到数百上千块 GPU 时,通信已经成为巨大瓶颈,还有没有其他策略?其实可以让某些张量放到 CPU ,或者把这些权重/梯度/优化器状态在 GPU 之间做分块切分。接下来我们将探讨两类思路:分片并行(tensor/context/pipeline 并行)和 共享(DeepSpeed Zero/FSDP)。它们彼此独立,也可以组合。
因为 ZeRO 方法和 DP 密切相关,所以先从它开始。
本节介绍 DeepSpeed ZeRO,专门用来减少 LLM 训练中的冗余显存占用。
数据并行可以提高吞吐量,但也带来了额外的内存浪费:所有 DP 副本都要存储同样的优化器状态、梯度和参数,形成重复。ZeRO 的思路正是通过沿数据并行维度对上述对象(优化器状态、梯度、参数)进行切分,来消除冗余。这会要求一定的额外通信操作,以便在需要时重建完整的参数,但相比内存收益常常很值得。
这种方法分为ZeRO的三个可能的优化阶段:
你可能会注意到,我们没有把激活值包含在可分片的对象中。因为模型的每个DP副本接收到的微批次不同,每个DP节点上的激活值也各不相同,所以它们不会被复制,也就无法进行分片!
让我们从内存需求角度看看不同阶段的 ZeRO 到底能省多少。
你可能还记得在上一节中关于标准训练过程中优化器状态、梯度和参数的内存使用情况。我们用
如果我们只关心不带 FP32 梯度累加的混合精度训练,其总占用就是
如果我们不以fp32累积梯度,总内存消耗为
ZeRO的理念是将这些对象分片存储到各个数据并行(DP)节点上,每个节点只存储部分数据,在需要时再重构这些数据,从而使内存使用量降低为数据并行度
这里
让我们通过探讨每个ZeRO阶段的工作原理来解释这张图及其数值。我们将从ZeRO-1开始。
【校对到这里,2025年2月21日】
在最基础的 DP 里,每个副本都会在反向传播结束后得到完整梯度,然后都进行一次优化器更新,这似乎有很多重复计算。ZeRO-1 把优化器状态在 DP 进程间做分割。也就是说,每个副本只存全部优化器状态的 1/N_d,做优化器更新时也只更新这部分 fp32 权重。
可是在 forward 时,每个副本都需要完整的(BF16)权重进行推理,所以在做完 optimizer update 后,需要一次 all-gather(新的通信原语)来把更新过的权重合并。
这样就解释了下图中的公式
下面用更直观的图示来理解。设想一个两层的模型分到两个 GPU 上,但在这个场景里模型参数是完整复制的,只是优化器状态分片:
通信上,和常规 DP 相比,ZeRO-1 主要变化是:把梯度的 all-reduce 换成 reduce-scatter,并多了一个用于同步权重的 all-gather:
注意,在常规 DP 里我们可以把 all-reduce 与后向传播进行重叠。ZeRO-1 里依然可以这样做,并且新增加的全量权重 all-gather 也可能尝试和更新下一层的操作重叠,当然这需要额外的调度。
ZeRO-1 只对优化器状态做分片,可是梯度在每个副本依然是完整的。若我们并不真正需要所有梯度,那是否可以进一步节省?这就引出了 ZeRO-2。
事实上,在 ZeRO-1 中,每个副本只更新 1/N_d 参数,所以它实际上并不需要其他参数对应的梯度。那为什么还要得到完整梯度?完全没必要,于是就可以把梯度也进行分片。这样在反向传播时不再进行 all-reduce,而是 reduce-scatter。同样可显著节省显存。
就此我们可以写出它的内存公式
与 ZeRO-1 的区别只是把梯度也进行了分片;在通信操作上还是 reduce-scatter + all-gather。
那似乎还能再往下走,甚至可以把模型参数本身也分片起来?这就是 ZeRO-3!
ZeRO-3(又称 Fully-Sharded Data Parallelism,FSDP)在前两者的基础上,进一步把模型权重也切分到各个 DP 副本上。
📝 注意
PyTorch 原生的 FSDP 与 ZeRO-3 基本是同一个概念,这里两者我们就混用了。
这样一来,当我们需要在前向或后向中访问参数时,就需要先把它们从各个副本“拿”过来。也就是在需要该层参数的时刻进行 all-gather,前向和后向完毕后释放这部分显存。示意图:
也就是说,每计算一层就要进行一次 all-gather,然后再释放参数。这带来了不少额外通信,但好处是:并不需要存所有层的参数。对于一个有 40 层的模型,仅在计算每层前向和后向时才收集它的权重,然后立刻释放,如此只在单卡上保留部分权重。
从通信角度看:相较于 ZeRO-2,ZeRO-3 每一层在 forward、backward 里多做了 2 次 all-gather(分别是正向和反向),以及一次 reduce-scatter 用于梯度,故总共
听起来通信开销很大,但其实如果我们能将这些 all-gather 与层的计算相互重叠(prefetch),这部分开销常能被隐藏。比如在前向时,算第 n 层时去预取 n+1 层的参数,以此类推。不过当 DP 规模非常大时,通信还是可能成为瓶颈。
内存方面,现在进化到了
总结一下:使用 DP+ZeRO,能让训练大模型时,每块 GPU 仅需存储模型的一个碎片(ZeRO-3),同时梯度和优化器状态也分别按需存储,大大降低单卡显存占用。
然而,ZeRO-3 仍然依赖单卡能放得下某个层,以及它无法削减大序列情况下的激活值开销——因为激活值在 DP 内并不存在重复,需要用到别的办法。正如我们之前说的,如果是长序列导致激活值占用过大,那就可以考虑张量并行(Tensor Parallel)来分割激活值。下面来深入了解一下。
让我们进入第二种并行轴——张量并行(TP)!和 ZeRO-3 不同,TP 不用反复 all-gather 参数,而是从运算角度直接对矩阵乘法进行分割,让权重、激活都自然分布在多卡上。
若想让阅读体验更像播客,可播放此音频,收听 NotebookLM 主播对以下章节的讨论。
我们已经看到,ZeRO-3 可以分片权重、梯度和优化器状态,但对于激活值则束手无策,同时 ZeRO-3 的通信模式也需要比较高的带宽才能隐藏。张量并行(TP)则可以将权重、梯度,以及激活值也拆分给多块 GPU,并且不需要在计算前先收集所有权重。但它会在计算层面引入更多通信原语,需要在算子内部进行通信。
让我们先从简单的矩阵乘法示例开始,理解张量并行的基本原理。
我们知道矩阵乘法
第一条表示可以在列维度上拆分乘法,第二条表示可以在行维度上拆分乘法。应用到神经网络中,乘法通常写作
在张量并行中,会把矩阵按照某个维度分割成 N 份,分别分配到 N 张 GPU 上做并行计算。举例来说,可以对 W 的列方向做切分,也叫 column-linear。每张 GPU 会获得 W 的部分列,同时 X 在每张卡上都要有副本,需要通过 broadcast 同步过去。计算结束后,再用 all-gather 把结果拼起来:
下面是一个具体代码示例:
另一个方式是沿行方向分割 W(即 row-linear),这时 X 需要切分给各 GPU(scatter),GPU 计算得到部分输出后,需要 all-reduce 来汇总:
对应的实现:
以上只是单次矩阵运算。在 Transformer 中,会有多层注意力和多层 MLP,需要把它们结合在一起实现真正的张量并行。
让我们把玩具示例扩展到一个完整 Transformer 层,其中主要包括 前馈网络(MLP)和多头注意力(MHA)。两者都可以用前面提到的 row-linear 或 column-linear 进行拆分。
先看 MLP,把它拆成 “列并行+行并行” 两步,如下图。先对输入做广播,再做列切分进行乘法,最后行切分合并输出。
这样比先行后列要更有效率,因为在第一步可省去一次 all-reduce。
对于多头注意力,其 Q、K、V 矩阵可列并行,每张卡只存一部分头;输出投影则行并行。如果是多查询注意力(MQA)或分组注意力(GQA)也可以有类似思路。注意头数不能少于并行度,否则会增加额外通信。
张量并行可以把中间激活值(包括 MLP 内部)也拆分到不同卡上,从而减少单卡激活。唯一问题在层归一化(layer norm)等操作需要完整向量,所以需要一次 all-gather 通信。
所以张量并行虽然能很好地拆分模型与激活,但中间仍存在一些全局性操作(如 LN)需要全量通信,难以完全隐藏通信。
让我们看看实践中的折衷:通信成本与内存开销的平衡。比如在单节点内(NVLink 带宽较高)做 TP 往往更高效,跨节点时带宽变慢,效率会明显下降。我们用基准测评了不同 TP 并行度的吞吐量和显存:
图左显示,随着 TP 并行度增加,通信开销变得突出,吞吐量下降。尤其过了 8,跨节点后受带宽限制更严重。图右则表明使用更多 TP 可以减少批大小上限中的显存占用。
我们再进一步看看 70B 模型时的显存表现:
增大 TP 确实能减少权重、梯度、优化器状态和部分激活的显存使用,让 70B 模型在单节点 8 卡上也可放下。
但你或许会注意到对于 LayerNorm 和 Dropout 等操作,TP 并没有分割激活,需要一次性 gather;于是顺理成章地出现了一个补充:序列并行(Sequence Parallelism,SP),用来解决这些操作造成的“剩余”激活开销。
序列并行(SP)的思路是对于列并行无法处理的那部分操作(比如 LayerNorm 和 Dropout),换个思路,绕到序列维度上去切分,这样可以进一步分担激活显存。
📝 注意
“Sequence Parallelism” 这个词容易混淆:在本节中,它是指和张量并行配合使用的 SP,用来处理那些需要完整隐层维度的操作(如 LN、Dropout)。而后面会提到“Context Parallelism”也 sometimes 也被称作 Sequence Parallelism,但我们会在本书里把它叫做“Context Parallelism”,以示区分。
LayerNorm 需要在隐层维度上计算均值和方差
实现上,需要在不同阶段进行不同的通信操作。例如,在完成 row-linear 后可能要做 reduce-scatter 或 all-gather 去切换到下一个并行方式。看起来有点复杂,但可以概括如下:
在前向阶段:
在后向阶段:
类似地,“g” 和 “g*” 对应 all-gather 和 reduce-scatter,用来在进入张量并行部分或返回序列并行部分时进行形状转换。
相比纯粹的 TP,TP+SP 可以显著减少激活内存,让我们可以进一步增大 batch size 或序列长度。例如,仍以 70B 模型为例:
我们发现利用 TP/SP=16,可以在单卡显存里放进 16k 序列长度,这比单纯的 TP 要好不少(虽然 TP=16 在跨节点时通信会慢,但若只需要单节点内,也还可以接受)。
那么,TP+SP 是否会比单纯 TP 带来更多通信?从操作数量看,TP 里我们每层需要 2 次 all-reduce,而 TP+SP 里会变成 2 次 all-gather + 2 次 reduce-scatter。但从带宽角度看,一次 all-reduce 等价于一次 all-gather + 一次 reduce-scatter,故理论上的通信量相近。不过操作次数增多,也会带来一些基准延迟开销。
下图展示了我们对一个 3B、4096 序列长度的模型在不同 TP 并行度下的吞吐量和最大可承受 batch size 情况:
再次可见:随并行度升高,单卡吞吐量显著下降,但可支持的最大 batch size 提升。尤其在跨节点时,通信带宽成瓶颈。
总结:TP 能并行拆分激活和权重,以缩减显存;SP 则针对 LN/Dropout 等仍需完整激活的操作,用序列维度补充切分进一步节省显存。
📝 注意
当在 TP 区域里对 LN 等操作进行 SP 切分时,需要留意 LN 的权重在反向中也会产生日不一样的梯度,需要多一次 all-reduce 保持同步。不过因为 LN 参数远小于全连接层,所以通信代价也相对较小。
不过,如果我们有极长的序列(如 128k+),即使用了 TP+SP,也还是会有一部分注意力计算需要全序列,消耗依然巨大。于是我们就再往前迈一步,引入 Context Parallelism。
通过张量并行和序列并行,我们可以显著降低每个GPU的内存需求,因为模型权重和激活值均分布在各个GPU上。然而,当训练的序列越来越长(例如当每个序列扩展到128k个token甚至更多时),我们仍可能超出单节点可用内存,因为在TP区域内我们仍需处理完整的序列长度。
此外,即使我们采用完全重新计算激活值的方法(这会带来约30%的沉重计算开销),我们仍需在内存中保留部分层边界的激活值,而这些激活值随序列长度呈线性增长。让我们来看看上下文并行如何帮助我们:
上下文并行的核心思想是将序列并行的方法(也就是沿序列长度进行拆分)的思路应用到已经采用张量并行的模块上。我们将对这些模块沿两个维度进行拆分,从而也减少序列长度带来的影响。经过前面所讨论的内容,你会发现这种方法非常直观,但……这里有一个技巧,所以请保持警惕!
对于上下文并行,就像序列并行一样,我们将沿序列维度拆分输入,但这次我们对整个模型进行拆分,而不仅仅是对之前Tensor+Sequence并行中涉及的部分模型。
拆分序列不会影响大多数模块,如MLP和LayerNorm,因为它们对每个token的处理是独立的。它也不像TP那样需要昂贵的通信,因为只拆分了输入而非权重矩阵。就像数据并行一样,在计算梯度后,会启动一次all-reduce操作以在上下文并行组内同步梯度。
不过,有一个重要例外需要特别注意,那就是注意力模块(呵呵……双关语来啦 :D)。在注意力模块中,每个token需要访问来自所有其他序列token的键/值对,或者在因果注意力的情况下,至少需要关注每个前面的token。
由于上下文并行是沿序列维度将输入分布到各个GPU上,注意力模块将需要各个GPU之间进行充分通信,以交换必要的键/值数据。
如果我们采用简单的方法,这听起来会非常昂贵。但有没有办法能更高效、更快速地完成这一操作呢?幸运的是,有一种核心技术可以高效地处理键/值对的通信,叫做环形注意力。
📝 注
上下文并行与Flash Attention在概念上有一些相似之处(更多细节稍后会提到)——两者都依赖在线Softmax计算以减少内存使用。虽然Flash Attention侧重于在单个GPU上优化注意力计算本身,但上下文并行通过将序列分布到多个GPU上来实现内存降低。
在这种注意力机制的实现中,每个GPU首先发起一个异步通信操作,将其键/值对发送给其他GPU。在等待其他GPU数据的同时,它会计算当前已在内存中的那部分数据的注意力得分。理想情况下,在这次计算结束前,下一个来自其他GPU的键/值对就已经接收完毕,使得该GPU在完成第一轮计算后能够立即开始下一轮计算。
让我们来说明这一点。假设我们有4个GPU和4个token的输入。最初,输入序列沿序列维度均匀拆分,因此每个GPU仅拥有一个token及其对应的Q/K/V值。假设Q1、K1和V1分别表示第一个token的查询、键和值,并且它们位于第1个GPU上。注意力计算需要4个时间步来完成。在每个时间步中,每个GPU依次执行以下三个操作:
我们将这三个步骤执行四次以完成注意力计算。
整个过程在4个GPU上的表现如下面的动画所示:
从这个动画中,你应该能明显看出作者为何选择将这种方法称为环形注意力。
不过有一个大问题,那就是环形注意力的简单实现会导致因果注意力矩阵形状造成的GPU间工作不平衡。让我们通过考虑带有因果注意力掩码的注意力得分矩阵来观察Softmax计算:
Softmax是按行计算的,这意味着每当一个GPU收到一整行的所有token时,就可以进行计算。我们看到GPU1可以立即计算,因为它一开始就拥有token 1-4,而GPU1实际上不需要从其他GPU接收任何信息。然而,GPU2需要等待第二轮,才能收到token 1-4,从而获得token 1-8的所有值。同时,GPU1的工作量明显比其他GPU要少得多。
让我们看看是否能更好地平衡计算负载:
我们需要一种更好的方式来分配输入序列。这可以通过不将token纯粹顺序地分配给各个GPU,而是稍微混合一下顺序,从而使每个GPU上都有较早和较晚的token。这种方法被称为之字形注意力
与此同时,我们也会看到,为了完成所有行的计算,每个GPU都需要来自其他所有GPU的信息。
我们有两种常见方式来重叠计算和通信:一种是通过执行一次通用的all-gather操作,同时在每个GPU上重新组合所有KV(类似于Zero-3的方式);另一种是根据需要从每个GPU逐个收集KV对:
这两种实现方式的关键区别在于它们的通信模式和内存使用:
1. AllGather实现:
2. 全对全(环形)实现:
全对全方法通常在内存效率上更优,但其通信模式稍显复杂;而AllGather方法则更简单,但在注意力计算过程中需要更多的临时内存。
到目前为止,我们已经看到如何通过TP在单个节点上拆分模型以驯服大模型,以及如何利用CP应对长序列带来的激活值爆炸问题。
然而,我们也知道TP在跨节点扩展时并不理想,那么如果模型权重难以容纳在单个节点上,我们该怎么办?这时,另一种并行度——流水线并行,将派上用场!
如想让阅读体验更像播客,可播放此音频,收听 NotebookLM 主播对以下章节的讨论。
在张量并行部分,我们提到若模型非常大(如 70B+),一台节点(8 块 GPU)放不下就得跨节点。可是 TP 在跨节点时通信巨大,常导致效率大跌。流水线并行(PP) 就是另一种可将模型切分到多节点的方法。
思路很直接:将模型沿层数方向分成多段,每段放到不同 GPU。这样单卡只需要存一部分层的权重,减少对同一节点显存的压力。我们来测了一下一个 8B 模型在不同 PP 度数下的显存:
可以看到,流水线并行有效削减了模型参数在单卡的占用,但激活值并未减少——因为每个微批还是完整地跑完每段,需要把整个 batch 的激活值一直传给下一卡,前面计算的激活值还要等到反向传播时使用,导致激活在这块卡上也得保留。
更麻烦的是流水线天生带有串行依赖:第一张卡先执行前向,然后交给第二张卡,再交给第三张……直到全部层前向都结束才能返回梯度。显然这是极其低效的:只有一张卡在工作,其他都在等。
接下来我们一起来看看如何设计流水线的调度来减少这种低效,让 GPU 都忙起来!
最朴素做法:把模型分成多段,1~4 层在 GPU0,5~8 层在 GPU1……前向传播时就顺次把数据传下去;反向传播再把梯度传回来。可视化如下:
一个 16 层模型分到 4 块 GPU 的示意图,数字代表层号。
灰色部分表示 GPU 空闲等待,称为“气泡(bubble)”。为了量化这个气泡带来的低效,可以简单假设一次前向耗时
在实践中,我们会引入多次 micro-batch,让 GPU1 在算 micro-batch1 的反向时,GPU0 可以算 micro-batch2 的前向,以减少气泡。
其具体实现叫做 all-forward-all-backward (AFAB) 调度:先把所有 micro-batch 的前向都算完,再统一反向:
仍会有一个初始和结尾的气泡:开始时只能有 GPU0 工作,结束时只有 GPUp-1 工作。这个气泡占用时间约
可见通过增大 micro-batch(梯度累加)能降低气泡浪费。但这样也会引入额外内存需求——需要同时存储所有 micro-batch 的激活值,可能造成内存爆炸。所以下面的一种改进调度——1F1B——可以缓解这一点。
在上面 AFAB 中,前向与后向是分开的。1F1B 做法是尽早开始后向,让模型不必保留太多 micro-batch 的激活值。形象表示如下:
此时,每块 GPU 不再同步地进行前向或后向,而是可以按照一定顺序单独地开始后向。这样激活可更早释放。
但也能看到,bubble 大小并没有显著下降,只是说我们不再需要存储所有的激活;如果再增加足够多 micro-batch(梯度累加),bubble 会在一定程度上减少。
因此虽然在调度中更灵活,但实现复杂度更高。Picotron 里有一个 1F1B 的实现示例:
让我们看看这种调度在实际集群上扩展的表现:
左图里如果梯度累加步骤
然而相比张量并行在跨节点时的剧烈性能下降,流水线并行的通信量更小(只传激活值到下一节点),所以更擅长在 512+ GPU 大规模场景保持还算不错的效率。
但我们还想进一步减少这个 bubble 的浪费。有一些新近工作能将 bubble 减到几乎为零。比如 DeepSeek-V3/R1 采用的 DualPipe,就极大减少了流水线气泡。让我们简单了解一下这些高阶调度。
除了 1F1B,近期又出现了一些更先进的流水线调度方法,可以把 bubble 压缩到几乎为 0。例如 ZeroBubble
ZeroBubble 的核心观察:在反向传播的全连接中,梯度对输入(B)和梯度对权重(W)的计算是分离的;对输入梯度要马上回传,但对权重梯度只需在最后更新前进行即可。也就是说可以把后向再细分为 B 和 W 两部分,然后在流水线里灵活地穿插 W,以填补 bubble 时间。
DeepSeek 的 DualPipe 还把流水线分为上下两个流,并结合这个 B/W 分解的方法,采用 ILP 算法生成几乎零气泡的调度。
实现这类高级调度往往需要对模型和硬件做精确测量,并结合启发式或整数线性规划进行自动调度。细节可参见
至此,我们就把流水线并行的前世今生梳理了一遍。
现在,终于轮到最后一个并行策略:专家并行(Expert Parallelism,EP)。
这是我们要介绍的最后一种并行技术。如果你对 Mixture-of-Experts (MoE) 不熟,可先看我们以前写的一篇博客进行快速了解。
近来各种稀疏专家模型越来越受到关注,如 GPT-4、Mixtral
示例图来自 Switch Transformers 论文
因为专家之间互不干扰,所以可以把这些专家层自然地分配到不同的 GPU 上,这就是 专家并行(EP)。相比张量并行要拆分矩阵,这里不需要在操作内部进行行列切分,所以通信主要发生在 token 与专家之间的路由上——当 token 被分配给某个专家时,就要发送到那张 GPU 处理。
实现时,一般会在 DP 基础上再加上 EP,因为专家层只是部分层(取代了原本的 MLP),其余 attention 等模块还是可以 DP 来并行。
参考:A Survey on Mixture of Experts
DeepSeek-V3 中会进一步约束路由,让每个 token 最多只发往 4 个专家,以减少通信量。当模型拥有上百个专家时,如 256 专家,EP 并行就能显著提升可扩展性。
未来我们可能会在 picotron/nanotron 中补充一个 EP 的完整示例,敬请期待。
到这里,我们已经看完了 5 个可以帮助我们扩展 LLM 训练的并行策略:
以及与数据并行结合的三种 ZeRO 阶段:
也许你会问,这么多并行策略之间是怎么组合、取舍的?该怎么一起用?是不是都能叠加?下面我们给出更深入的对比与分析。
流水线并行 vs. ZeRO-3:两者都可以沿模型层数方向分拆,但在实现和通信模式上有区别:
ZeRO-3 | 流水线并行 (PP) | |
---|---|---|
每个并行单元存储 | 仅该层参数的 1/DP,需通信重组 | 该部分层的完整参数 |
通信对象 | 权重碎片 | 激活值 |
对模型的依赖程度 | 无强模型耦合(深度学习框架层面即可) | 也无强耦合,但要对层结构做拆分 |
实现难度 | 对模型分片和通信调度较复杂 | 对 pipeline 气泡和调度算法复杂 |
扩展时考虑 | 需要较大 micro batch 或较长序列,以覆盖通信 | 需要较大 gradient accumulation 以减少气泡 |
因此 ZeRO-3 和 PP 在本质上都属于“按层或参数的维度”进行拆分,但一个主要传权重,一个主要传激活。它们也可以组合,但常见做法是 ZeRO-1/2 与 PP 结合,而 ZeRO-3 与 PP 结合会导致通信更复杂,需要更多调度来隐藏开销。
张量并行(TP) 是与 ZeRO 或 PP 互补的,它更依赖对算子的实现进行拆分,如分块矩阵乘法、分块注意力等,尤其需要高带宽环境。实际生产中,为了让 TP 部分只在单节点内通信,我们会把 8 张卡作为一个 TP 分组,再与 DP 或 PP 结合,用多节点扩展。
上下文并行(CP) 和 专家并行(EP) 也是与 TP 互补的思路,分别针对长序列场景和 MoE 专家分片。
下表简要对比:
张量 + 序列并行(TP+SP) | 上下文并行(CP) | 专家并行(EP) |
---|---|---|
在隐藏/序列维度拆分权重和激活 | 在序列维度拆分激活 | 在专家维度拆分 MoE 参数 |
大量 row/column linear 通信 | 注意力中交换 KV | 路由 token 给专家,需要 all-to-all |
需深度改写网络实现 | 相对对模型无强耦合,但需自定义注意力通信 | 只在 MoE 层特殊处理 |
倾向在节点内(高速 NVLink)使用 | 适合长序列场景 | 需要 MoE 架构 |
最后,为了让你对如何结合五种并行有更直观印象,我们做了一张大图,展示了一个包含 MoE 的 Transformers 层在进行五维并行时各自要通信的地方:
我们也画了一张表格,对这些方法在激活、参数、通信上所起到的作用做了总览:
最后,用一个简短表格总括:
方法 | 主要解决的显存瓶颈 | 并行/分片的维度 | 缺点或额外开销 |
---|---|---|---|
DP | 激活值(可减少单卡 batch) | batch | 受制于最大 batch size |
PP | 模型权重 | 模型层数 | 气泡;调度复杂 |
TP/SP | 模型权重和部分激活 | 隐藏维度 / 序列维度 | 通信多,依赖大带宽 |
CP | 激活值 | 序列维度 | 注意力通信复杂,需特殊实现 |
EP | MoE 专家参数 | 专家数 | 需要 MoE;路由通信消耗 |
ZeRO-1 | 优化器状态 | DP 范围内分片 | 需参数通信 |
ZeRO-2 | 优化器状态 + 梯度 | DP 范围内分片 | 需参数通信 |
ZeRO-3 | 优化器状态 + 梯度 + 模型参数 | DP 范围内分片 | 需参数通信 |
没有一种方法是完全通用的银弹,实际生产中往往会组合使用它们。具体组合策略视模型大小、目标 batch size、硬件带宽和 GPU/节点数而定。接下来,我们给出一些经验或规则供你参考。
我们已经了解了各种并行方法的原理与实现细节。下面给出一个简单的思路,用于在实际场景中做初步决策,然后再通过实验微调:
先确定如何把一个模型副本放到 GPU 上;有两种情况:
如果 GPU 资源相对充裕 🤑:
如果 GPU 资源有限 😭:
接着,我们还要配合全局 batch size。
根据第一步得到的单卡 batch size,如果还不够大,可以逐步增加 DP 或梯度累加;如果过大,可以减少 DP 或引入其他并行方式。
对于超长序列,还可用 CP 把激活值拆分给更多 GPU,以支撑更大的序列长度。
在确定能放进显存、满足 batch size 的前提下,我们就要尽量提高训练速度:
实践里,为找最佳配置通常需要进行大量分布式实验。我们在 nanotron 中准备了相应脚本,可自动遍历各种并行组合,并跑基准测试。
我们自己就做了几千次测试来绘制之前的图表,用了 1~64 台节点(每节点 8×H100)。下面是我们对不同模型规模、节点数量下最佳配置所对应的效率热力图:
热力图展示对不同模型尺寸与节点数(8卡/节点)找到的最优训练配置。DP、TP、PP、梯度累加(GAS)、微批大小(MBS)、ZeRO 阶段等信息都列出。颜色亮度代表模型 FLOPs 利用率 (MFU),越亮表示效率越高。
可以看到:
首先,小模型在大并行度下效率急剧下降,因为通信和等待带来的代价相对更高。即使想用更大 batch size 来弥补,也难免增加优化步数和通信。
对于大模型,在小节点数时要么放不下,要么勉强放下后算力利用也不佳。
我们还发现许多性能细节取决于实现优化。比如开始时我们写的 TP 实现比 PP 更快,后来在 PP 上做了性能优化后 PP 又领先,再然后我们改进 TP 又赶上或反超。
想做所有组合的网格搜索看似只需简单地提交一个大任务,但在实践中会遇到许多麻烦,比如:
我们为此写了不少自动脚本和监控,甚至深入阅读 NCCL 日志和 CUDA 内存分配器的源码,才把如此大规模的实验跑完。希望通过开放源码(nanotron、picotron)能帮助更多人复现和摸索这些分布式技巧。
好了,分布式并行的主要算法方法到这里就告一段落了。但别忘了还有一个前提:我们假设在 GPU 上的计算与通信可以 100% 并行且不会互相争抢资源。实际上 GPU 里计算和通信经常会争夺同样的 SM 资源,需要深入到 GPU 硬件层面进行优化。下面我们就来聊聊 GPU 内部的那些事,包括 Flash Attention 是如何把注意力计算大幅加速并节省显存的,还有如何写高效的 GPU kernel,以及如何利用混合精度(尤其是 FP8)进一步提升速度等。
如想让阅读体验更像播客,可播放此音频,收听 NotebookLM 主播对以下章节的讨论。
到目前为止,我们主要关注宏观的并行组织与资源调度,即如何把模型或数据分给多个 GPU 并行。现在我们要把视角拉到更底层的 GPU 架构上,去看如何用高效的 kernel、合适的线程/内存布局来获得更佳性能。
首先简要回顾 GPU 的结构。
通常 GPU 由多个“流式多处理器(SM)”组成,每个 SM 内含一定数量的核心(如 H100 有 132 个 SM、每 SM 下有 128 个核心,总计 16896 个核心)。各 SM 之间共享一部分缓存(L2),也都有自己的寄存器和共享内存。下图可见其大致示意:
来源:https://blog.codingconfessions.com/p/gpu-computing
内存层级方面有寄存器、共享内存(SM 内部,速度快但容量小),以及全局内存(HBM,容量大但延迟高)。要想高效,就要尽量让数据在寄存器或共享内存中复用,减少访问全局内存的频次。
在编程模型上,我们通常用 CUDA 或 Triton 写 kernel,在 CPU 端(host 端)分配显存并启动 kernel。
示例自 https://docs.nvidia.com/cuda/cuda-c-programming-guide/
示例自 https://docs.nvidia.com/cuda/cuda-c-programming-guide/
在 CUDA 编程模型里,会把线程分成 warp(32 个线程),再组织成 blocks,每个 block 会分配到一个 SM 上执行;SM 可以并发地运行多个 block,具体由硬件调度。
了解这些细节有助于写更高效的 kernel,比如让内存访问对齐、减少分支发散等。若你只用 PyTorch 最高层接口,也可以用 @torch.compile
或 Triton 来自动优化,但理解原理有助于做更深入的微调。
最简单的方式是直接让 PyTorch 自己调用已实现好的 kernel,或用 torch.compile
自动生成优化版本。如果还不够快,就可以考虑 Triton,若仍需更高级别控制,可直接写 CUDA。大概梯度如下:
一个常见的优化点是减少全局显存(HBM)的访问,因为延迟和带宽都有限。若能用 shared memory 复用或块内合作加载,可以提升性能。另一个常见优化是将多个操作“融合(fuse)”到一个 kernel 中,减少 CPU-GPU 之间的往返调度。
GPU 上 kernel 与 kernel 之间的切换也有不小开销,而且每次算完都写回全局显存再读入下一个 kernel 非常浪费。若几个操作能一起在一个 kernel 内做,就能大大减少冗余。比如点积激活函数都放在一次 kernel 内完成,就叫“fusion”。
Horace He 在其博客中用以下图示解释得很好:
一系列算子若分开执行,每一步都在全局显存读写
进行 kernel fuse 后,不用反复把中间结果写回 global memory
Transformer 中也大量使用这种融合,例如 LayerNorm 层常常融合多步运算到一个 kernel 中。另外 FlashAttention 也可以视作融合了注意力的多个算子。
FlashAttention 系列由 Tri Dao 首创,通过定制 kernel,将注意力计算中最大的中间产物(S 矩阵)仅在更快速的 SM 内存中存放并及时消费,不写回全局显存,大幅减少了内存访问。
传统注意力做法会显式计算出 S 并存到显存,再算 Softmax 和乘 V,而 FlashAttention 避免了这一步,把计算进一步分块并保存在共享内存或寄存器里。示意对比如下:
而 FlashAttention 的改进示意:
来自 FlashAttention 论文
这样做既减少了显存开销,又跳过了对大矩阵 S 的写回操作,速度和空间上都有极大优势。FlashAttention 成为当前大多数 Transformer 实现的默认注意力方法。
后续还有 FlashAttention 2、3,主要进一步优化了对 GPU 硬件的适配(如 Hopper 上的 FP8 TensorCore 支持),但核心理念还是“分块 + 避免显存中间存储 + 融合算子”。
前文多次提到低精度,如 BF16、FP8 等大幅减少存储和加速计算。本节就来具体看看如何做到既用低精度,又保证数值稳定。
我们先从 16 位(BF16/FP16)混合精度开始,再介绍如今正热门的 FP8。
若我们单纯把所有张量都设为 fp16,往往会发生训练发散。早期论文
这套做法在实操中非常有效,可在保持与 FP32 一致的收敛性能的情况下,加快计算、节省存储。对于当下硬件而言,这几乎已成标准做法。
再往下走一步,FP8 能否也适用于大模型预训练呢?
NVIDIA H100 GPU 对 FP8 的理论峰值 FLOPS 是 BF16 的两倍,这让很多人对 FP8 非常感兴趣。但降低到 8 位会带来更严重的数值稳定问题,可能导致训练发散。尤其是大模型在早期会有更大的梯度变化,极易出现下溢或溢出。
虽然各大团队都在积极研究,但直到最近 DeepSeek-V3 才首度在超大规模预训练中稳定采用了 FP8。其报告
在实际实现中,需要实时统计激活的数值范围,并将其规范化到 FP8 可表示的范围内。这样也会多一些通信和计算开销,但相比把所有算子都固定在 FP16/BF16 上,潜在的加速更高。现在社区对 FP8 的研究仍在快速迭代中。若想跟进,可见 nanotron 的 FP8 PR。
不远的将来(比如 NVIDIA Blackwell),甚至会支持 FP4 等更低精度,这些都在探索之中。
本节我们从底层视角介绍了写高效 kernel、做算子融合、以及混合精度如何带来性能与内存上的优势。到这里,我们对如何在大规模分布式集群上高效训练 LLM 已有了比较完整的图景。
恭喜你读到这里!我们从在一块 GPU 上训练一个简单模型开始,一步步了解了如何扩展到数十、数百、上千张 GPU,掌握了当前 LLM 训练中最常用的并行技术:数据并行、流水线并行、张量并行、上下文并行、专家并行,以及 ZeRO、重计算、Flash Attention、FP8 等配套工具。读完后,你应该能够比较轻松地看懂像 Llama-3、DeepSeek-V3 这种庞大模型的多维并行结构了:
把数千张 GPU 协同起来训练大型模型绝非易事,需要对计算、通信、内存管理等多方面进行巧妙设计。本书从宏观的并行划分到微观的 kernel 优化,希望能帮助你在构建或研究大规模训练系统时减少踩坑,也希望能进一步激发你的灵感!
或许过去这些知识似乎只对预训练 LLM 的极少数人有用,但随着模型规模继续增长,以及企业和社区对大模型微调和部署需求的增加,这些分布式训练技巧也会被越来越多人使用。
这是一次漫长的学习之旅,对我们编写本书的人来说同样如此。在完成数千次分布式实验、调试了无数奇怪的 bug、读了数不尽的代码和文献后,我们也从中收获良多。
如果你想更深入下去,可考虑以下几个方向:
期待你也能运用这些知识,训练出下一代强大的开源模型!
最后的话:在本书刚发布时,我们打算印制少量实体版本赠给前 50 位读者。若有兴趣,可在这里填表:google form,等整理完毕后会邮件联系你。
无论你是第一时间读到本书,还是在未来某个时刻,感谢你的阅读,希望本书能帮助你更好地理解与实践大规模分布式训练!
我们感谢 Elie 细致地审稿,并使用 NotebookLM 生成了音频组件;感谢 Hynek 优化了前端性能;同时也感谢 Simon 为本书在 Hugging Face Hub 上提供的帮助。
如果想讨论本书内容,提出问题或反馈,欢迎到 讨论区发帖!
提出了张量并行和高效模型并行技术,用于训练大型语言模型。
结合 DeepSpeed 与 Megatron-LM,在 530B 参数规模上训练大语言模型。
Google 的 Pathways Language Model,在数百项语言任务和推理上表现优异。
Google 的多模态模型架构,可处理文本、图像、音频和视频。
The Llama 3 Herd of Models
DeepSeek 关于 V3 模型的技术报告。
我们用于训练大型语言模型的生产框架,涵盖多种并行策略。
NVIDIA 的大型模型训练框架,支持多种并行技术。
Microsoft 的深度学习优化库,包含 ZeRO 等多种扩展策略。
提供多个并行和优化方法的 PyTorch 扩展库。
大规模模型训练解决方案,包含多种并行和优化策略。
PyTorch 原生库,用于大型模型训练。
EleutherAI 的大模型训练框架,用于 GPT-NeoX-20B。
Lightning AI 提供的开源 LLM 实现,注重可复现性。
分布式大语言模型训练与推理。
PyTorch 版的 GPipe 流水线并行实现。
OSLO:Open Source for Large-scale Optimization。
PyTorch 官方对 profiler 的使用教程。
官方博客,深入讲解如何理解和优化 GPU 显存使用。
我们对一个简单案例做显存剖析的博客。
如何在 TensorBoard 中可视化 PyTorch 的性能数据。
关于数据并行在深度学习中的全面解读。
Zero Redundancy Optimizer 在大模型训练中的应用。
Fully Sharded Data Parallel 在 PyTorch 中的实现。
结合多种并行方法来高效训练大模型。
NVIDIA 对大模型流水线并行的介绍。
进一步讨论流水线调度的论文。
讲解分布式训练中 ring all-reduce 算法的细节。
将 ring attention 与 flash attention 结合的实现。
关于 ring attention 思路与实现的介绍。
DeepSpeed 里的关于 ZeRO 与 3D 并行互补关系的说明。
较早系统讨论混合精度训练方法的论文。
从更高维度介绍并行策略的通信。
DeepSeek 关于 1 万 PCI GPU 集群的报告。
Meta 对其 AI 超大规模基础设施的介绍。
对超大规模 H100 集群的剖析。
更易读的 CUDA 教程。
一份涵盖 LLM 训练各方面的全面手册。
记载 BLOOM 模型训练过程与挑战的详细日志。
Meta 关于 OPT-175B 训练过程的日志文档。
分析模型规模与训练开销之间的关系。
从数据和代价层面探讨长上下文训练。
一个社区式 GPU 阅读组。
ML Scalability & Performance 阅读讨论组。
关于如何扩展模型的一本书。
仅 500 行左右的最小化 FSDP 实现示例。
Horace He 的一些博客 - 讲解 GPU 如何跑得飞快。
通俗易懂地解读了 Flash Attention。
基于 PyTorch 的大规模语言模型教程。
在本书中,我们反复提到扩展到 N 块 GPU,需要对权重、梯度或激活值进行各种同步或通信,如 broadcast、gather、reduce-scatter 等。这里简要介绍这些“集体通信”操作。
先假设我们有 N 个节点(CPU 核、GPU、或其他设备),每个节点都能和其他节点通信。
最常见的几个模式:
把某个节点(root)上的数据广播给其他所有节点:
PyTorch 里就是 dist.broadcast(tensor, src=0)
。只有 src 节点的 tensor
会被同步给其他节点。
对 N 个节点分别持有的数据做一次归约(加和、平均等),将结果放到单个节点上(reduce),或者放到所有节点上(all-reduce):
在 PyTorch 里 dist.reduce(tensor, dst=0)
或 dist.all_reduce(tensor)
。reduce 只影响 dst 节点上的张量,而 all_reduce 把结果同步给所有节点。
如果每个节点有一块互不相同的数据,我们要把它们合并到一块,就可以用 gather(收集到一个节点)或 all-gather(收集到所有节点):
对应 PyTorch 中的 dist.gather(tensor, gather_list, dst=0)
或 dist.all_gather(gather_list, tensor)
。
Scatter 是 gather 的反操作,reduce-scatter 则在 scatter 前后再加一个 reduce 步骤:
PyTorch 里是 dist.scatter(tensor, scatter_list, src=0)
、dist.reduce_scatter(output_tensor, input_list)
。
最后提一下 dist.barrier()
用于同步进程,所有进程必须都到达此处才能继续。这在某些情况下很有用。
NVIDIA Collective Communications Library,简称 NCCL,是专门针对 GPU-GPU 通信的高效实现,PyTorch 在 GPU 训练时默认为它做底层通信。
若你想了解更多可参见 PyTorch 分布式文档。
若我们的算子已经由 PyTorch 内置实现,最简单的方式就是用 torch.cuda.Event
或 torch.profiler
。下面给出一个简单示例:
(示例代码略,保持原样)
若要分析自定义的 CUDA kernel,也可以用 PyTorch cpp_extension
来编译并加载 C++/CUDA 源码。再用 torch.profiler
或 Nsight Compute 等工具分析即可。
在 LLM 训练中,往往会遇到如下数量级:
以下是一些常见并行策略的“通信 vs. 计算”分析简要公式,帮助判断何时通信能被隐藏。
(略,保持原文公式即可)
如果在学术场景引用本书:
Tazi et al., "The Ultra-Scale Playbook: Training LLMs on GPU Clusters", 2025.
BibTeX citation:
@misc{ultrascale_playbook, title={The Ultra-Scale Playbook: Training LLMs on GPU Clusters}, author={Nouamane Tazi, Ferdinand Mom, Haojun Zhao, Phuc Nguyen, Mohamed Mekkouri, Leandro Werra, Thomas Wolf}, year={2025}, }