分布式优化器
分布式优化器的动机是通过在数据并行 ranks 之间均匀分布优化器状态 (https://arxiv.org/abs/1910.02054) 来节省内存,而不是在数据并行 ranks 之间复制优化器状态的朴素方法。
理论内存节省取决于模型参数的数据类型 (param_dtype
) 和跨数据并行副本累积的主梯度 (grad_dtype
) 的组合。我们始终对优化器步骤使用 fp32
主参数。在当前的实现中,每个参数的理论字节数为(其中 d 是数据并行大小)
非分布式优化 | 分布式优化 | |
---|---|---|
fp16 参数,fp16 梯度 |
20 | 4 + 16/d |
bf16 参数,fp32 梯度 |
18 | 6 + 12/d |
fp32 参数,fp32 梯度 |
16 | 8 + 8/d |
我们分布式优化器的实现对参数和主梯度使用连续缓冲区;模型梯度一旦完全计算出来,就会被复制到主梯度中。
下图说明了分布式优化器的分片方案,以及分布式优化器参数更新的关键步骤
(注意:使用上面的插图,假设 bf16
模型权重,bf16
模型梯度(由反向传播计算得出)和 fp32
主梯度(也用于优化器步骤);我们始终对优化器步骤使用 fp32
主权重)
反向传播完成(梯度缓冲区保存 16 个
fp32
梯度元素)。在每个 DP rank 上调用 reduce-scatter。
现在每个 DP rank 在梯度缓冲区中都有 4 个完全 reduce 的元素(剩余的 12 个元素是垃圾)。
DP rank 0 具有元素 [0:4] 的梯度值。
DP rank 1 具有元素 [4:8] 的梯度值。
DP rank 2 具有元素 [8:12] 的梯度值。
DP rank 3 具有元素 [12:16] 的梯度值。
Optimizer.step()。
每个 DP rank 将其 4 个
fp32
主参数元素复制到相应的bf16
参数缓冲区中(每个元素都从 fp32 转换为 fp16)。在每个 DP rank 上调用 all-gather。
现在参数缓冲区包含所有 16 个完全更新的
bf16
模型参数元素。PyTorch 模块中的参数已经指向此参数缓冲区中的适当位置,因此在 all-gather 完成后,正向传播就可以运行了。此时,梯度缓冲区也已准备好为下一次迭代归零。