dist_checkpointing.strategies 软件包
定义不同检查点格式(后端)和保存/加载算法(策略)的软件包。
策略可用于实现新的检查点格式,或实现现有格式的新的(针对给定用例更优化的)保存/加载方式。策略被传递给 dist_checkpointing.load 和 dist_checkpointing.save 函数,并控制实际的保存/加载过程。
策略基础接口。
- class core.dist_checkpointing.strategies.base.AsyncSaveShardedStrategy(backend: str, version: int)
基类:
core.dist_checkpointing.strategies.base.SaveShardedStrategy
适用于异步保存的保存策略。
- abstract async_save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path) → core.dist_checkpointing.strategies.async_utils.AsyncRequest
执行准备工作并返回 AsyncRequest 给外部调用者。
- 参数
sharded_state_dict (ShardedStateDict) – 要保存的分片状态字典
checkpoint_dir (Path) – 检查点目标目录
- 返回
- 表示异步保存函数和最终化函数。
调用者有责任实际调度异步保存。
- 返回类型
AsyncRequest
- save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
每个异步策略都可以简单地用作同步策略。
- class core.dist_checkpointing.strategies.base.LoadCommonStrategy
基类:
core.dist_checkpointing.strategies.base.LoadStrategyBase
用于通用(非分片)对象的加载策略
- abstract load_common(checkpoint_dir: pathlib.Path)
加载检查点的通用部分。
- load_sharded_metadata(checkpoint_dir: pathlib.Path) → Dict[str, Any]
仅从检查点加载元数据。
- abstract load_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
从检查点加载分片对象。
- class core.dist_checkpointing.strategies.base.LoadShardedStrategy
基类:
core.dist_checkpointing.strategies.base.LoadStrategyBase
用于分片张量的加载策略
- abstract load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
加载检查点的分片部分。
- load_sharded_metadata(checkpoint_dir: pathlib.Path)
从检查点加载 ShardedTensors 和 ShardedObjects 的分片元数据。
返回类似于分片状态字典的字典,但请注意,字典键只是分片键(与实际的分片状态字典相反,后者的键对应于状态字典键)。
字典值是不包含任何数据和分片的 ShardedTensors 或 ShardedObjects。
- abstract load_tensors_metadata(checkpoint_dir: pathlib.Path)
从检查点加载 ShardedTensors 的张量元数据。
返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。
字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
- remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
删除所有键以 key_prefix 开头的张量
- class core.dist_checkpointing.strategies.base.LoadStrategyBase
基类:
abc.ABC
加载策略的基类。需要实现与给定检查点版本兼容性的检查。
- property can_handle_sharded_objects
返回此策略是否可以处理加载 ShardedObjects。
- abstract check_backend_compatibility(loaded_backend)
验证此策略是否与 loaded_backend 兼容。
- abstract check_version_compatibility(loaded_version)
验证此策略是否与 loaded_version 兼容。
- class core.dist_checkpointing.strategies.base.SaveCommonStrategy(backend: str, version: int)
基类:
core.dist_checkpointing.strategies.base.SaveStrategyBase
用于通用(非分片)对象的保存策略
- abstract save_common(common_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
保存状态字典的通用部分。
- save_sharded_objects(sharded_objects_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
从状态字典保存分片对象。
- class core.dist_checkpointing.strategies.base.SaveShardedStrategy(backend: str, version: int)
基类:
core.dist_checkpointing.strategies.base.SaveStrategyBase
用于分片张量的保存策略
- abstract save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
保存状态字典的分片部分。
- class core.dist_checkpointing.strategies.base.SaveStrategyBase(backend: str, version: int)
基类:
abc.ABC
保存策略的基类。需要定义保存格式的后端类型和版本。
- property can_handle_sharded_objects
返回此策略是否可以处理保存 ShardedObjects。
- class core.dist_checkpointing.strategies.base.StrategyAction(value)
基类:
enum.Enum
指定保存与加载以及分片与通用操作。
- LOAD_COMMON = 'load_common'
- LOAD_SHARDED = 'load_sharded'
- SAVE_COMMON = 'save_common'
- SAVE_SHARDED = 'save_sharded'
- core.dist_checkpointing.strategies.base.get_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int)
检索给定操作、后端和版本的默认策略。
- core.dist_checkpointing.strategies.base.register_default_strategy(action: core.dist_checkpointing.strategies.base.StrategyAction, backend: str, version: int, strategy: Union[core.dist_checkpointing.strategies.base.SaveStrategyBase, core.dist_checkpointing.strategies.base.LoadStrategyBase])
将给定策略添加到默认策略的注册表中。
- 参数
action (StrategyAction) – 指定保存/加载和分片/通用
backend (str) – 策略成为默认策略的后端
version (int) – 策略成为默认策略的版本
strategy (SaveStrategyBase, LoadStrategyBase) – 要注册的策略
使用 TensorStore 加载和保存 Zarr 数组的策略。
- class core.dist_checkpointing.strategies.tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device: bool = False)
基类:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
用于 Zarr 后端的加载策略,使用 tensorstore 进行加载。
- check_backend_compatibility(loaded_version)
验证此策略是否与 loaded_backend 兼容。
- check_version_compatibility(loaded_version)
验证此策略是否与 loaded_version 兼容。
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
加载检查点的分片部分。
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
从检查点加载 ShardedTensors 的张量元数据。
返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。
字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
- core.dist_checkpointing.strategies.tensorstore.merge_global_slice_with_shape(global_slice, actual_shape, key)
将全局切片与实际形状相交(防止溢出)。
- core.dist_checkpointing.strategies.tensorstore.open_ts_array(arr_path: pathlib.Path)
使用 Tensorstore 和基本设置打开 Zarr 文件数组。
- 参数
arr_path (Path) – Zarr (Tensorstore) 数组的路径
- core.dist_checkpointing.strategies.tensorstore.register_default_tensorstore_strategies()
注册利用 tensorstore 的默认策略。
两阶段检查点加载。
- class core.dist_checkpointing.strategies.two_stage.TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
基类:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
从存储加载一个检查点副本并广播到其他节点。
此策略从最少数量的节点上的存储加载检查点,并使用 torch.distributed 将检查点分发到其他节点。加载使用 tensorstore 执行。
步骤: 0. (可选)创建 Gloo 分布式组 1. 在所有节点之间交换 ShardedTensors 元数据 2. 在 DP 组内对齐所需的张量 3. 对于每个全局唯一的张量: 3.a) 在其中一个 rank 上,从存储加载到 CPU 并移动到 CUDA 3.b) 在其他 rank 上分配 CUDA 张量 3.c) 在 DP 组内广播 3.d) 将张量内容复制到模型参数位置 3.e) 从 a) 和 b) 释放张量缓冲区
注意: 1. 加载和广播是顺序完成的,以避免主机和设备 OOM 2. 为每个张量完成的所有三个步骤之间存在大量重叠潜力: 2.a) 从存储加载到 numpy 2.b) 将 CPU 张量移动到 CUDA 2.c) 广播
- check_backend_compatibility(loaded_version)
验证此策略是否与 loaded_backend 兼容。
- check_version_compatibility(loaded_version)
验证此策略是否与 loaded_version 兼容。
- deduplicate_chunks(ten_metas: List[core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata])
按 chunk 对张量进行分组,然后选择 rank 最小的张量。
- 注意:通过适当的加载重叠,从随机 rank 加载
(而不是最小的 rank)可能在此处有利。
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
加载检查点的分片部分。
- load_tensor_from_storage(checkpoint_dir, ten_meta: core.dist_checkpointing.strategies.two_stage._ShardedTensorMetadata)
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
从检查点加载 ShardedTensors 的张量元数据。
返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。
字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
- maybe_init_gloo_group()
- summarize_load_times()
- core.dist_checkpointing.strategies.two_stage.sharded_tensor_chunk_id(sharded_tensor: core.dist_checkpointing.mapping.ShardedTensor)
- core.dist_checkpointing.strategies.two_stage.timed(verbose=True)
使用 Zarr 作为底层格式的策略。
- class core.dist_checkpointing.strategies.zarr.ZarrLoadShardedStrategy
基类:
core.dist_checkpointing.strategies.base.LoadShardedStrategy
用于 Zarr 后端的加载策略。
- check_backend_compatibility(loaded_version)
验证此策略是否与 loaded_backend 兼容。
- check_version_compatibility(loaded_version)
验证此策略是否与 loaded_version 兼容。
- load(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
加载检查点的分片部分。
- load_tensors_metadata(checkpoint_dir: pathlib.Path)
从检查点加载 ShardedTensors 的张量元数据。
返回类似于分片状态字典的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片状态字典相反,后者的键对应于状态字典键)。
字典值是不包含任何数据和分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
- class core.dist_checkpointing.strategies.zarr.ZarrSaveShardedStrategy(backend: str, version: int)
基类:
core.dist_checkpointing.strategies.base.SaveShardedStrategy
用于 Zarr 后端的保存策略。
- save(sharded_state_dict: Dict[str, Any], checkpoint_dir: pathlib.Path)
保存状态字典的分片部分。
- core.dist_checkpointing.strategies.zarr.flatten_range(sharded_tensor, x)
将展平范围应用于张量。
- core.dist_checkpointing.strategies.zarr.load_zarr_based_sharded_metadata(checkpoint_dir: pathlib.Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], numpy.dtype]]) → Dict[str, Any]
加载 Zarr 数组的元数据。
- 参数
checkpoint_dir (str) – 检查点根目录
get_shape_dtype_fn (str -> ((int, ...), np.dtype)) – 一个函数,为给定的 Zarr 数组路径返回数组形状和 dtype
- core.dist_checkpointing.strategies.zarr.pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: core.dist_checkpointing.mapping.ShardedTensor)
将张量填充到预期形状。
- core.dist_checkpointing.strategies.zarr.postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True)
将 numpy 数组转换为 torch 张量。
- core.dist_checkpointing.strategies.zarr.register_default_zarr_strategies()
注册与 Zarr 后端相关的默认策略。
各种加载和保存策略