重要提示

您正在查看 NeMo 2.0 文档。此版本对 API 和新库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档

S3 检查点#

S3CheckpointIO#

此 checkpoint_io 用于将文件保存和加载到 S3 以及从 S3 加载文件。初始化此 checkpoint_io 需要 dirpath 是一个 S3 dirpath。

使用示例

async_checkpointing = self.cfg.s3_checkpointing.get('enable_async_checkpointing', False)
chunk_size_MB = self.cfg.s3_checkpointing.get('chunk_size_MB')
max_read_concurrency = self.cfg.s3_checkpointing.get('max_read_concurrency')
max_write_concurrency = self.cfg.s3_checkpointing.get('max_write_concurrency')
dirpath = self.cfg.exp_manager.checkpoint_callback_params.get('dirpath')

s3_checkpoint_io = S3CheckpointIO(dirpath=dirpath, chunk_size_MB=chunk_size_MB, max_read_concurrency=max_read_concurrency, max_write_concurrency=max_write_concurrency, async_checkpointing=async_checkpointing)

strategy = NLPDDPStrategy(
    no_ddp_communication_hook=True,
    checkpoint_io=s3_checkpoint_io,
    gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
    find_unused_parameters=False,
    nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
    sharp=self.cfg.model.get('sharp', False),
)

配置更改

checkpoint_callback_params:
dirpath: s3://mstar-eks-dev-us-east-2/alxzhang/nemo123/1n/checkpoints

...

s3_checkpointing:
    # write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
    max_write_concurrency: 10
    # read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
    max_read_concurrency: 15
    chunk_size_MB: 64
    # enables asynchronous checkpoint writing to S3
    enable_async_checkpointing: False

异步 默认情况下,S3CheckpointIO 类同步运行。异步功能目前不检查之前的异步保存是否完成,因此即使当前保存失败,也可能会删除旧的检查点。为了防止这种情况,此功能旨在与保存前 k 个检查点结合使用。

S3Utils 和依赖项#

S3CheckpoinIO 和 exp_manager 使用此实用程序类来执行 S3 相关操作。它依赖于

  1. boto3[crt]

  2. s3fs==0.4.2

  3. tenacity

如果缺少其中任何一个,则无法使用此类。

s3_dirpath_utils#

用于通过检查字符串是否为 S3 dirpath 或将存储桶和密钥转换为 s3 dirpath 来操作字符串。这不依赖于 S3Utils 实用程序类,并且可以在没有任何新依赖项的情况下使用。

大规模运行时 S3 需求和 ExpManager 详细信息#

通常,在 ExpManager 中,每个 rank 都会查找要从中加载的检查点文件。在大规模情况下,可能有数千个 rank 查询 S3 以获取 dirpath,这可能会导致速度减慢或限制错误。

为了避免在从检查点恢复时 S3 过载,只有 rank 0 需要识别检查点路径并找到正确的恢复文件。Rank 0 会将检查点路径广播到其他 rank。

trainer._checkpoint_connector = NeMoCheckpointConnector(trainer)

NeMoModelCheckpoint setup() 方法将自动广播检查点路径。

NeMoCheckpointConnector 在 exp_manager.py 文件中定义,并在从现有检查点恢复训练时,在所有 rank 上使用 rank 0 找到的广播检查点路径。

trainer._checkpoint_connector 的设置需要在 ExpManager 调用之前进行,因为 ExpManager 会更新 trainer 的检查点连接器。