重要提示
您正在查看 NeMo 2.0 文档。此版本对 API 和新库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
S3 检查点#
S3CheckpointIO#
此 checkpoint_io 用于将文件保存和加载到 S3 以及从 S3 加载文件。初始化此 checkpoint_io 需要 dirpath 是一个 S3 dirpath。
使用示例
async_checkpointing = self.cfg.s3_checkpointing.get('enable_async_checkpointing', False)
chunk_size_MB = self.cfg.s3_checkpointing.get('chunk_size_MB')
max_read_concurrency = self.cfg.s3_checkpointing.get('max_read_concurrency')
max_write_concurrency = self.cfg.s3_checkpointing.get('max_write_concurrency')
dirpath = self.cfg.exp_manager.checkpoint_callback_params.get('dirpath')
s3_checkpoint_io = S3CheckpointIO(dirpath=dirpath, chunk_size_MB=chunk_size_MB, max_read_concurrency=max_read_concurrency, max_write_concurrency=max_write_concurrency, async_checkpointing=async_checkpointing)
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True,
checkpoint_io=s3_checkpoint_io,
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
sharp=self.cfg.model.get('sharp', False),
)
配置更改
checkpoint_callback_params:
dirpath: s3://mstar-eks-dev-us-east-2/alxzhang/nemo123/1n/checkpoints
...
s3_checkpointing:
# write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
max_write_concurrency: 10
# read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
max_read_concurrency: 15
chunk_size_MB: 64
# enables asynchronous checkpoint writing to S3
enable_async_checkpointing: False
异步 默认情况下,S3CheckpointIO 类同步运行。异步功能目前不检查之前的异步保存是否完成,因此即使当前保存失败,也可能会删除旧的检查点。为了防止这种情况,此功能旨在与保存前 k 个检查点结合使用。
S3Utils 和依赖项#
S3CheckpoinIO 和 exp_manager 使用此实用程序类来执行 S3 相关操作。它依赖于
boto3[crt]
s3fs==0.4.2
tenacity
如果缺少其中任何一个,则无法使用此类。
s3_dirpath_utils#
用于通过检查字符串是否为 S3 dirpath 或将存储桶和密钥转换为 s3 dirpath 来操作字符串。这不依赖于 S3Utils 实用程序类,并且可以在没有任何新依赖项的情况下使用。
大规模运行时 S3 需求和 ExpManager 详细信息#
通常,在 ExpManager 中,每个 rank 都会查找要从中加载的检查点文件。在大规模情况下,可能有数千个 rank 查询 S3 以获取 dirpath,这可能会导致速度减慢或限制错误。
为了避免在从检查点恢复时 S3 过载,只有 rank 0 需要识别检查点路径并找到正确的恢复文件。Rank 0 会将检查点路径广播到其他 rank。
trainer._checkpoint_connector = NeMoCheckpointConnector(trainer)
NeMoModelCheckpoint setup() 方法将自动广播检查点路径。
NeMoCheckpointConnector 在 exp_manager.py 文件中定义,并在从现有检查点恢复训练时,在所有 rank 上使用 rank 0 找到的广播检查点路径。
trainer._checkpoint_connector 的设置需要在 ExpManager 调用之前进行,因为 ExpManager 会更新 trainer 的检查点连接器。