Megatron Core 用户指南

分布式优化器

分布式优化器的动机是通过在数据并行 ranks 之间均匀分布优化器状态 (https://arxiv.org/abs/1910.02054) 来节省内存,而不是在数据并行 ranks 之间复制优化器状态的朴素方法。

理论内存节省取决于模型参数的数据类型 (param_dtype) 和跨数据并行副本累积的主梯度 (grad_dtype) 的组合。我们始终对优化器步骤使用 fp32 主参数。在当前的实现中,每个参数的理论字节数为(其中 d 是数据并行大小)

非分布式优化

分布式优化

fp16 参数,fp16 梯度 20 4 + 16/d
bf16 参数,fp32 梯度 18 6 + 12/d
fp32 参数,fp32 梯度 16 8 + 8/d

我们分布式优化器的实现对参数和主梯度使用连续缓冲区;模型梯度一旦完全计算出来,就会被复制到主梯度中。

下图说明了分布式优化器的分片方案,以及分布式优化器参数更新的关键步骤

(注意:使用上面的插图,假设 bf16 模型权重,bf16 模型梯度(由反向传播计算得出)和 fp32 主梯度(也用于优化器步骤);我们始终对优化器步骤使用 fp32 主权重)

  • 反向传播完成(梯度缓冲区保存 16 个 fp32 梯度元素)。

  • 在每个 DP rank 上调用 reduce-scatter。

  • 现在每个 DP rank 在梯度缓冲区中都有 4 个完全 reduce 的元素(剩余的 12 个元素是垃圾)。

    • DP rank 0 具有元素 [0:4] 的梯度值。

    • DP rank 1 具有元素 [4:8] 的梯度值。

    • DP rank 2 具有元素 [8:12] 的梯度值。

    • DP rank 3 具有元素 [12:16] 的梯度值。

  • Optimizer.step()。

  • 每个 DP rank 将其 4 个 fp32 主参数元素复制到相应的 bf16 参数缓冲区中(每个元素都从 fp32 转换为 fp16)。

  • 在每个 DP rank 上调用 all-gather。

  • 现在参数缓冲区包含所有 16 个完全更新的 bf16 模型参数元素。PyTorch 模块中的参数已经指向此参数缓冲区中的适当位置,因此在 all-gather 完成后,正向传播就可以运行了。

  • 此时,梯度缓冲区也已准备好为下一次迭代归零。

上一篇 dist_checkpointing.strategies package
下一篇 distributed package
© 版权所有 2022-2025, NVIDIA。 最后更新于 2025 年 1 月 14 日。