分布式优化器
分布式优化器的动机是通过在数据并行 ranks 之间均匀分布优化器状态 (https://arxiv.org/abs/1910.02054) 来节省内存,而不是在数据并行 ranks 之间复制优化器状态的朴素方法。
理论内存节省取决于模型参数的数据类型 (param_dtype) 和跨数据并行副本累积的主梯度 (grad_dtype) 的组合。我们始终对优化器步骤使用 fp32 主参数。在当前的实现中,每个参数的理论字节数为(其中 d 是数据并行大小)
非分布式优化 | 分布式优化 | |
|---|---|---|
fp16 参数,fp16 梯度 |
20 | 4 + 16/d |
bf16 参数,fp32 梯度 |
18 | 6 + 12/d |
fp32 参数,fp32 梯度 |
16 | 8 + 8/d |
我们分布式优化器的实现对参数和主梯度使用连续缓冲区;模型梯度一旦完全计算出来,就会被复制到主梯度中。
下图说明了分布式优化器的分片方案,以及分布式优化器参数更新的关键步骤
(注意:使用上面的插图,假设 bf16 模型权重,bf16 模型梯度(由反向传播计算得出)和 fp32 主梯度(也用于优化器步骤);我们始终对优化器步骤使用 fp32 主权重)
反向传播完成(梯度缓冲区保存 16 个
fp32梯度元素)。在每个 DP rank 上调用 reduce-scatter。
现在每个 DP rank 在梯度缓冲区中都有 4 个完全 reduce 的元素(剩余的 12 个元素是垃圾)。
DP rank 0 具有元素 [0:4] 的梯度值。
DP rank 1 具有元素 [4:8] 的梯度值。
DP rank 2 具有元素 [8:12] 的梯度值。
DP rank 3 具有元素 [12:16] 的梯度值。
Optimizer.step()。
每个 DP rank 将其 4 个
fp32主参数元素复制到相应的bf16参数缓冲区中(每个元素都从 fp32 转换为 fp16)。在每个 DP rank 上调用 all-gather。
现在参数缓冲区包含所有 16 个完全更新的
bf16模型参数元素。PyTorch 模块中的参数已经指向此参数缓冲区中的适当位置,因此在 all-gather 完成后,正向传播就可以运行了。此时,梯度缓冲区也已准备好为下一次迭代归零。