dist_checkpointing.strategies 软件包
定义不同检查点格式(后端)和保存/加载算法(策略)的软件包。
策略可用于实现新的检查点格式,或实现现有格式的新的(针对给定用例更优化的)保存/加载方式。策略被传递给 dist_checkpointing.load 和 dist_checkpointing.save 函数,并控制实际的保存/加载过程。
策略基础接口。
- class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)
- 基类: - core.dist_checkpointing.strategies.base.SaveShardedStrategy- 适用于异步保存的保存策略。 - abstract async_save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path) → core.dist_checkpointing.strategies.async_utils.AsyncRequest
- 执行准备工作并返回 AsyncRequest 给外部调用者。 - 参数
- sharded_state_dict (ShardedStateDict) – 要保存的分片状态字典 
- checkpoint_dir (Path) – 检查点目标目录 
 
- 返回
- 表示异步保存函数和最终化函数。
- 调用者有责任实际调度异步保存。 
 
- 返回类型
- AsyncRequest 
 
 - save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 每个异步策略都可以简单地用作同步策略。 
 
- class core.dist_checkpointing.strategies.base.LoadCommonStrategy
- 基类: - core.dist_checkpointing.strategies.base.LoadStrategyBase- 用于通用(非分片)对象的加载策略 - abstract load_common(checkpoint_dir: pathlib.Path)
- 加载检查点的通用部分。 
 - load_sharded_metadata(checkpoint_dir: pathlib.Path) → Dict[str, Any]
- 仅从检查点加载元数据。 
 - abstract load_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 从检查点加载分片对象。 
 
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy
- 基类: - core.dist_checkpointing.strategies.base.LoadStrategyBase- 用于分片张量的加载策略 - abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 加载检查点的分片部分。 
 - load_sharded_metadata(checkpoint_dir: pathlib.Path)
- 从检查点加载 ShardedTensors 和 ShardedObjects 的分片元数据。 - 返回类似于分片状态字典的字典,但请注意,字典键只是分片键(与实际的分片状态字典相反,后者的键对应于状态字典键)。 - 字典值是不包含任何数据和分片的 ShardedTensors 或 ShardedObjects。 
 - abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)
- 从检查点加载 ShardedTensors 的张量元数据。 - 返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。 - 字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。 
 - remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
- 删除所有键以 key_prefix 开头的张量 
 
- class core.dist_checkpointing.strategies.base.LoadStrategyBase
- 基类: - abc.ABC- 加载策略的基类。需要实现与给定检查点版本兼容性的检查。 - property can_handle_sharded_objects
- 返回此策略是否可以处理加载 ShardedObjects。 
 - abstract check_backend_compatibility(loaded_backend)
- 验证此策略是否与 loaded_backend 兼容。 
 - abstract check_version_compatibility(loaded_version)
- 验证此策略是否与 loaded_version 兼容。 
 
- class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)
- 基类: - core.dist_checkpointing.strategies.base.SaveStrategyBase- 用于通用(非分片)对象的保存策略 - abstract save_common(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 保存状态字典的通用部分。 
 - save_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 从状态字典保存分片对象。 
 
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)
- 基类: - core.dist_checkpointing.strategies.base.SaveStrategyBase- 用于分片张量的保存策略 - abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 保存状态字典的分片部分。 
 
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)
- 基类: - abc.ABC- 保存策略的基类。需要定义保存格式的后端类型和版本。 - property can_handle_sharded_objects
- 返回此策略是否可以处理保存 ShardedObjects。 
 
- class core.dist_checkpointing.strategies.base.StrategyAction(value)
- 基类: - enum.Enum- 指定保存与加载以及分片与通用操作。 - LOAD_COMMON = 'load_common'
 - LOAD_SHARDED = 'load_sharded'
 - SAVE_COMMON = 'save_common'
 - SAVE_SHARDED = 'save_sharded'
 
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)
- 检索给定操作、后端和版本的默认策略。 
- core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[core.dist_checkpointing.strategies.base.SaveStrategyBase, core.dist_checkpointing.strategies.base.LoadStrategyBase])
- 将给定策略添加到默认策略的注册表中。 - 参数
- action (StrategyAction) – 指定保存/加载和分片/通用 
- backend (str) – 策略成为默认策略的后端 
- version (int) – 策略成为默认策略的版本 
- strategy (SaveStrategyBase, LoadStrategyBase) – 要注册的策略 
 
 
使用 TensorStore 加载和保存 Zarr 数组的策略。
- class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)
- 基类: - core.dist_checkpointing.strategies.base.LoadShardedStrategy- 用于 Zarr 后端的加载策略,使用 tensorstore 进行加载。 - check_backend_compatibility(loaded_version)
- 验证此策略是否与 loaded_backend 兼容。 
 - check_version_compatibility(loaded_version)
- 验证此策略是否与 loaded_version 兼容。 
 - load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 加载检查点的分片部分。 
 - load_tensors_metadata(checkpoint_dir: pathlib.Path)
- 从检查点加载 ShardedTensors 的张量元数据。 - 返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。 - 字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。 
 
- core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)
- 将全局切片与实际形状相交(防止溢出)。 
- core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)
- 使用 Tensorstore 和基本设置打开 Zarr 文件数组。 - 参数
- arr_path (Path) – Zarr (Tensorstore) 数组的路径 
 
- core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()
- 注册利用 tensorstore 的默认策略。 
两阶段检查点加载。
- class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
- 基类: - core.dist_checkpointing.strategies.base.LoadShardedStrategy- 从存储加载一个检查点副本并广播到其他节点。 - 此策略从最少数量的节点上的存储加载检查点,并使用 torch.distributed 将检查点分发到其他节点。加载使用 tensorstore 执行。 - 步骤: 0. (可选)创建 Gloo 分布式组 1. 在所有节点之间交换 ShardedTensors 元数据 2. 在 DP 组内对齐所需的张量 3. 对于每个全局唯一的张量: 3.a) 在其中一个 rank 上,从存储加载到 CPU 并移动到 CUDA 3.b) 在其他 rank 上分配 CUDA 张量 3.c) 在 DP 组内广播 3.d) 将张量内容复制到模型参数位置 3.e) 从 a) 和 b) 释放张量缓冲区 - 注意: 1. 加载和广播是顺序完成的,以避免主机和设备 OOM 2. 为每个张量完成的所有三个步骤之间存在大量重叠潜力: 2.a) 从存储加载到 numpy 2.b) 将 CPU 张量移动到 CUDA 2.c) 广播 - check_backend_compatibility(loaded_version)
- 验证此策略是否与 loaded_backend 兼容。 
 - check_version_compatibility(loaded_version)
- 验证此策略是否与 loaded_version 兼容。 
 - deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])
- 按 chunk 对张量进行分组,然后选择 rank 最小的张量。 - 注意:通过适当的加载重叠,从随机 rank 加载
- (而不是最小的 rank)可能在此处有利。 
 
 - load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 加载检查点的分片部分。 
 - load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)
 - load_tensors_metadata(checkpoint_dir: pathlib.Path)
- 从检查点加载 ShardedTensors 的张量元数据。 - 返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。 - 字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。 
 - maybe_init_gloo_group()
 - summarize_load_times()
 
- core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)
- core.dist_checkpointing.strategies.two_stage.timed(verbose=True)
使用 Zarr 作为底层格式的策略。
- class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy
- 基类: - core.dist_checkpointing.strategies.base.LoadShardedStrategy- 用于 Zarr 后端的加载策略。 - check_backend_compatibility(loaded_version)
- 验证此策略是否与 loaded_backend 兼容。 
 - check_version_compatibility(loaded_version)
- 验证此策略是否与 loaded_version 兼容。 
 - load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 加载检查点的分片部分。 
 - load_tensors_metadata(checkpoint_dir: pathlib.Path)
- 从检查点加载 ShardedTensors 的张量元数据。 - 返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。 - 字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。 
 
- class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)
- 基类: - core.dist_checkpointing.strategies.base.SaveShardedStrategy- 用于 Zarr 后端的保存策略。 - save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
- 保存状态字典的分片部分。 
 
- core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)
- 将展平范围应用于张量。 
- core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]
- 加载 Zarr 数组的元数据。 - 参数
- checkpoint_dir (str) – 检查点根目录 
- get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – 一个函数,为给定的 Zarr 数组路径返回数组形状和 dtype 
 
 
- core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)
- 将张量填充到预期形状。 
- core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)
- 将 numpy 数组转换为 torch 张量。 
- core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()
- 注册与 Zarr 后端相关的默认策略。 
各种加载和保存策略