Megatron Core 用户指南

dist_checkpointing.strategies 软件包

定义不同检查点格式(后端)和保存/加载算法(策略)的软件包。

策略可用于实现新的检查点格式,或实现现有格式的新的(针对给定用例更优化的)保存/加载方式。策略被传递给 dist_checkpointing.loaddist_checkpointing.save 函数,并控制实际的保存/加载过程。

策略基础接口。

class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)

基类: core.dist_checkpointing.strategies.base.SaveShardedStrategy

适用于异步保存的保存策略。

abstract async_save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path) → core.dist_checkpointing.strategies.async_utils.AsyncRequest

执行准备工作并返回 AsyncRequest 给外部调用者。

参数
  • sharded_state_dict (ShardedStateDict) – 要保存的分片状态字典

  • checkpoint_dir (Path) – 检查点目标目录

返回

表示异步保存函数和最终化函数。

调用者有责任实际调度异步保存。

返回类型

AsyncRequest

save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

每个异步策略都可以简单地用作同步策略。

class core.dist_checkpointing.strategies.base.LoadCommonStrategy

基类: core.dist_checkpointing.strategies.base.LoadStrategyBase

用于通用(非分片)对象的加载策略

abstract load_common(checkpoint_dir: pathlib.Path)

加载检查点的通用部分。

load_sharded_metadata(checkpoint_dir: pathlib.Path) → Dict[str, Any]

仅从检查点加载元数据。

abstract load_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

从检查点加载分片对象。

class core.dist_checkpointing.strategies.base.LoadShardedStrategy

基类: core.dist_checkpointing.strategies.base.LoadStrategyBase

用于分片张量的加载策略

abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

加载检查点的分片部分。

load_sharded_metadata(checkpoint_dir: pathlib.Path)

从检查点加载 ShardedTensors 和 ShardedObjects 的分片元数据。

返回类似于分片状态字典的字典,但请注意,字典键只是分片键(与实际的分片状态字典相反,后者的键对应于状态字典键)。

字典值是不包含任何数据和分片的 ShardedTensors 或 ShardedObjects。

abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)

从检查点加载 ShardedTensors 的张量元数据。

返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。

字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)

删除所有键以 key_prefix 开头的张量

class core.dist_checkpointing.strategies.base.LoadStrategyBase

基类: abc.ABC

加载策略的基类。需要实现与给定检查点版本兼容性的检查。

property can_handle_sharded_objects

返回此策略是否可以处理加载 ShardedObjects。

abstract check_backend_compatibility(loaded_backend)

验证此策略是否与 loaded_backend 兼容。

abstract check_version_compatibility(loaded_version)

验证此策略是否与 loaded_version 兼容。

class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)

基类: core.dist_checkpointing.strategies.base.SaveStrategyBase

用于通用(非分片)对象的保存策略

abstract save_common(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

保存状态字典的通用部分。

save_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

从状态字典保存分片对象。

class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)

基类: core.dist_checkpointing.strategies.base.SaveStrategyBase

用于分片张量的保存策略

abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

保存状态字典的分片部分。

class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)

基类: abc.ABC

保存策略的基类。需要定义保存格式的后端类型和版本。

property can_handle_sharded_objects

返回此策略是否可以处理保存 ShardedObjects。

class core.dist_checkpointing.strategies.base.StrategyAction(value)

基类: enum.Enum

指定保存与加载以及分片与通用操作。

LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)

检索给定操作、后端和版本的默认策略。

core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[core.dist_checkpointing.strategies.base.SaveStrategyBase, core.dist_checkpointing.strategies.base.LoadStrategyBase])

将给定策略添加到默认策略的注册表中。

参数
  • action (StrategyAction) – 指定保存/加载和分片/通用

  • backend (str) – 策略成为默认策略的后端

  • version (int) – 策略成为默认策略的版本

  • strategy (SaveStrategyBase, LoadStrategyBase) – 要注册的策略

使用 TensorStore 加载和保存 Zarr 数组的策略。

class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)

基类: core.dist_checkpointing.strategies.base.LoadShardedStrategy

用于 Zarr 后端的加载策略,使用 tensorstore 进行加载。

check_backend_compatibility(loaded_version)

验证此策略是否与 loaded_backend 兼容。

check_version_compatibility(loaded_version)

验证此策略是否与 loaded_version 兼容。

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

加载检查点的分片部分。

load_tensors_metadata(checkpoint_dir: pathlib.Path)

从检查点加载 ShardedTensors 的张量元数据。

返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。

字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)

将全局切片与实际形状相交(防止溢出)。

core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)

使用 Tensorstore 和基本设置打开 Zarr 文件数组。

参数

arr_path (Path) – Zarr (Tensorstore) 数组的路径

core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()

注册利用 tensorstore 的默认策略。

两阶段检查点加载。

class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)

基类: core.dist_checkpointing.strategies.base.LoadShardedStrategy

从存储加载一个检查点副本并广播到其他节点。

此策略从最少数量的节点上的存储加载检查点,并使用 torch.distributed 将检查点分发到其他节点。加载使用 tensorstore 执行。

步骤: 0. (可选)创建 Gloo 分布式组 1. 在所有节点之间交换 ShardedTensors 元数据 2. 在 DP 组内对齐所需的张量 3. 对于每个全局唯一的张量: 3.a) 在其中一个 rank 上,从存储加载到 CPU 并移动到 CUDA 3.b) 在其他 rank 上分配 CUDA 张量 3.c) 在 DP 组内广播 3.d) 将张量内容复制到模型参数位置 3.e) 从 a) 和 b) 释放张量缓冲区

注意: 1. 加载和广播是顺序完成的,以避免主机和设备 OOM 2. 为每个张量完成的所有三个步骤之间存在大量重叠潜力: 2.a) 从存储加载到 numpy 2.b) 将 CPU 张量移动到 CUDA 2.c) 广播

check_backend_compatibility(loaded_version)

验证此策略是否与 loaded_backend 兼容。

check_version_compatibility(loaded_version)

验证此策略是否与 loaded_version 兼容。

deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])

按 chunk 对张量进行分组,然后选择 rank 最小的张量。

注意:通过适当的加载重叠,从随机 rank 加载

(而不是最小的 rank)可能在此处有利。

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

加载检查点的分片部分。

load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)
load_tensors_metadata(checkpoint_dir: pathlib.Path)

从检查点加载 ShardedTensors 的张量元数据。

返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。

字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

maybe_init_gloo_group()
summarize_load_times()
core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)
core.dist_checkpointing.strategies.two_stage.timed(verbose=True)

使用 Zarr 作为底层格式的策略。

class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy

基类: core.dist_checkpointing.strategies.base.LoadShardedStrategy

用于 Zarr 后端的加载策略。

check_backend_compatibility(loaded_version)

验证此策略是否与 loaded_backend 兼容。

check_version_compatibility(loaded_version)

验证此策略是否与 loaded_version 兼容。

load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

加载检查点的分片部分。

load_tensors_metadata(checkpoint_dir: pathlib.Path)

从检查点加载 ShardedTensors 的张量元数据。

返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。

字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)

基类: core.dist_checkpointing.strategies.base.SaveShardedStrategy

用于 Zarr 后端的保存策略。

save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)

保存状态字典的分片部分。

core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)

将展平范围应用于张量。

core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]

加载 Zarr 数组的元数据。

参数
  • checkpoint_dir (str) – 检查点根目录

  • get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – 一个函数,为给定的 Zarr 数组路径返回数组形状和 dtype

core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)

将张量填充到预期形状。

core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)

将 numpy 数组转换为 torch 张量。

core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()

注册与 Zarr 后端相关的默认策略。

各种加载和保存策略

上一页 dist_checkpointing 软件包
下一页 分布式优化器
© 版权所有 2022-2025, NVIDIA。 上次更新于 2025 年 1 月 14 日。