dist_checkpointing 包
一个用于保存和加载分布式检查点的库。“分布式检查点”可以有各种底层格式(当前默认格式基于 Zarr),但具有一个独特的属性 - 在一种并行配置(张量/流水线/数据并行)中保存的检查点可以在不同的并行配置中加载。
使用该库需要使用来自mapping和optimizer模块的函数定义分片 state_dict 字典。这些 state dict 可以使用来自strategies模块的策略,通过serialization模块进行保存或加载。
用于保存和加载分布式检查点的入口点。
函数 load 和 save 等价于 torch.load 和 torch.save,但期望 torch.Tensors 用 mapping module 中的类进行包装。此外,load 期望分片 state dict 参数作为加载分片张量的指导。
- core.dist_checkpointing.serialization.get_default_load_sharded_strategy(checkpoint_dir: str) → core.dist_checkpointing.strategies.base.LoadShardedStrategy
获取默认的加载分片策略。
- core.dist_checkpointing.serialization.get_default_save_common_strategy(backend: str = 'torch', version: int = 1) → core.dist_checkpointing.strategies.base.SaveCommonStrategy
获取默认的保存通用策略。
- core.dist_checkpointing.serialization.get_default_save_sharded_strategy(backend: str = 'torch_dist', version: int = 1) → core.dist_checkpointing.strategies.base.SaveShardedStrategy
获取默认的保存分片策略。
- core.dist_checkpointing.serialization.load(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, strict: Union[str, core.dist_checkpointing.validation.StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED) → Union[Dict[str, Any], Tuple[Dict[str, Any], Set[str], Set[str]]]
加载入口点。
在以下步骤中,以下动词指代相应的对象:- load = 从检查点加载 - extract = 从 sharded_state_dict 提取 - add = 添加到最终的 state dict 步骤:1. 加载通用 state dict 并形成结果 state dict 的基础 2. 将工厂应用于 sharded_state_dict 3. 提取 LocalNonPersistentObject 并添加 4. (可选)提取 ShardedObjects,加载并添加 5. 提取 ShardedBase,加载,应用工厂合并并添加
- 参数
sharded_state_dict (ShardedStateDict) – 现有模型的 state dict,其中填充了 ShardedTensors。用作映射,以确定应加载检查点中存储的全局张量的哪些部分。
checkpoint_dir (str) – 包含检查点的目录
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – 配置分片张量的加载行为
common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – 配置通用数据的加载行为
validate_access_integrity (bool default = True) – 检查每个张量分片是否被某个进程精确访问一次(作为主副本)
strict (StrictHandling, str, optional) – 确定在请求的分片 state dict 和检查点之间存在不匹配时的行为。有关更多详细信息,请参阅 StrictHandling 文档。某些值会影响此函数的返回值(返回缺少和意外的键)。默认为 True (StrictHandling.ASSUME_OK_UNEXPECTED),这不会产生任何性能开销。其他建议的值包括:False (StrictHandling.LOG_UNEXPECTED),它仅记录意外的键;或 StrictHandling.RETURN_ALL,它返回所有不匹配的键。
- 返回值
- 在大多数情况下仅返回
加载的 state dict。如果 strict 标志设置为
- 返回类型
StateDict 或 Tuple[StateDict, Set[str], Set[str]]
- core.dist_checkpointing.serialization.load_common_state_dict(checkpoint_dir: pathlib.Path) → Dict[str, Any]
从检查点加载通用(非分片)对象 state dict。
- 参数
checkpoint_dir (Path) – 检查点目录
- 返回值
包含来自检查点的非分片对象的 state dict
- 返回类型
StateDict
- core.dist_checkpointing.serialization.load_plain_tensors(checkpoint_dir: str) → Dict[str, Any]
加载检查点张量,不带任何分片和普通结构。
注意:不包含通用 state dict。
- 参数
checkpoint_dir (str) – 要从中加载张量的检查点目录。
- 返回值
检查点 state dict,仅包含 torch.Tensors。
- 返回类型
StateDict
- core.dist_checkpointing.serialization.load_sharded_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None, common_strategy: Optional[core.dist_checkpointing.strategies.base.LoadCommonStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]
从检查点加载分片元数据。
类似于 load_tensors_metadata,但也包括 ShardedObjects。
返回一个类似于分片 state dict 的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片 state dict 不同,在实际的分片 state dict 中,键对应于 state dict 键)。
Dict 值是不带任何分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
具体的实现取决于加载策略。如果未给出策略,则使用给定后端的默认策略。
- 参数
checkpoint_dir (str) – 要从中加载的检查点目录
sharded_strategy (LoadShardedStrategy, optional) – 用于加载元数据的分片策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。
common_strategy (LoadCommonStrategy, optional) – 用于加载元数据的通用策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。除非 sharded_strategy 无法处理 ShardedObjects,否则不会使用此策略
- 返回值
- 不包含描述 ShardedTensors 数据的扁平 state dict
以及检查点中的 ShardedObjects
- 返回类型
CkptShardedMetadata
- core.dist_checkpointing.serialization.load_tensors_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]
从检查点加载张量元数据。
返回一个类似于分片 state dict 的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片 state dict 不同,在实际的分片 state dict 中,键对应于 state dict 键)。
Dict 值是不带任何分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。
具体的实现取决于加载策略。如果未给出策略,则使用给定后端的默认策略。
- 参数
checkpoint_dir (str) – 要从中加载的检查点目录
sharded_strategy (LoadShardedStrategy, optional) – 用于加载元数据的分片策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。
- 返回值
- 不包含描述 ShardedTensors 数据的扁平 state dict
在检查点中
- 返回类型
CkptShardedMetadata
- core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)
确定适当的分片策略,并将删除操作委托给分片策略
- core.dist_checkpointing.serialization.save(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, async_sharded_save: bool = False, preprocess_common_before_consistancy_check: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None) → Optional[core.dist_checkpointing.strategies.async_utils.AsyncRequest]
保存入口点。
从给定的 state dict 中提取 ShardedTensors。Rank 0 将检查点的“常规”部分保存到通用的 torch 文件。ShardedTensors 根据配置指定的策略进行保存。
步骤:1. 应用工厂 2. 提取并丢弃 LocalNonPersistentObject 3. 提取所有 ShardedBase 对象 4. 将所有其他对象保存到 common.pt 5. (可选)提取并保存 ShardedObjects 6. 保存所有 ShardedBase 对象 7. 写入包含后端和版本元数据的 metadata.json 文件。
步骤 (6) 可以异步执行(参见 async_sharded_save),在这种情况下,实际的保存体现在返回的异步请求中,并且可以由外部调用者调度。对于异步请求,步骤 (7) 作为最终确定函数之一添加,以便仅在检查点完成时才写入 metadata.json。
- 参数
sharded_state_dict (ShardedStateDict) – 填充了 ShardedTensors 的 state dict。用作映射,以确定应如何将本地张量保存为检查点中的全局张量。
checkpoint_dir (str) – 要将检查点保存到的目录
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – 配置分片张量保存行为和后端
common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – 配置通用数据保存行为和后端
validate_access_integrity (bool default = True) – 检查每个张量分片是否被某个进程精确访问一次(作为主副本)。它还确保通用 state dict 在所有 rank 中是一致的
async_sharded_save (bool, optional) – 如果为 True,则对于分片 state dict 部分,将调用异步保存实现,并将 AsyncRequest 返回给调用者。请注意,实际调度异步保存是调用者的责任。默认为 False。
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None) – 一个可调用函数,它将预处理通用 state dict(即可用于删除我们期望在 state dict 中不同的键)。该函数不得修改原始 state dict
- 返回值
- 如果 async_sharded_save 为 True,则返回
应由此函数的调用者调度的异步请求。否则为 None。
- 返回类型
AsyncRequest (可选)
用于表示张量和对象分片的核心库类。
主要预期用法是使用 ShardedTensor 类(主要是使用 ShardedTensor.from_rank_offsets 类方法)将 torch.Tensors 包装在 state dict 中。
- class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)
基类:
object
不应存储在检查点中,而应在本地还原的对象。
使用 LocalNonpersistentObject 包装 state dict 内的任何对象将导致:- 在保存期间,此对象将不会存储在检查点中 - 在加载期间,此对象的本地版本将放置在 state dict 中
- unwrap()
返回原始对象。
- core.dist_checkpointing.mapping.LocalNonpersitentObject
- class core.dist_checkpointing.mapping.ShardedBase
基类:
abc.ABC
ShardedTensor 和 ShardedStateDict 的基类。
- data: object
- key: str
- replica_id: Union[int, Tuple[int, ...]]
- abstract validate_metadata_integrity()
编纂关于元数据属性的约束。
- abstract without_data() → core.dist_checkpointing.mapping.ShardedBase
返回一个新的 ShardedBase 实例,其中 data=None。
- class core.dist_checkpointing.mapping.ShardedObject(key: str, data: object, global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: Union[int, Tuple[int, ...]] = 0)
基类:
core.dist_checkpointing.mapping.ShardedBase
表示本地对象和全局对象之间的映射。
全局对象被假定为由分布在不同进程之间的许多本地对象组成。
注意:与 ShardedTensor 相反,无法更改全局对象分片。从概念上讲,ShardedObject 是一个完全分片的 ShardedTensor,具有原子任意类型的元素。
- 参数
key – 全局张量的唯一标识符
data – 本地对象数据。仅在一致性验证时可以为 None
global_shape – 全局对象形状
global_offset – 本地对象在全局对象中的偏移量,以分片数量指定
replica_id – 指示本地对象相对于不同进程中的本地对象的复制
- data: object
- classmethod empty_from_unique_key(unique_key, replica_id: Union[int, Tuple[int, ...]] = 0) → core.dist_checkpointing.mapping.ShardedObject
从唯一键实例化 ShardedObject。
- 参数
unique_key – 格式为 <key>/shard_<global_offset>_<global_shape> 的字符串
replica_id – 指示本地对象相对于不同进程中的本地对象的复制
- 返回值
一个 data=None 的 ShardedObject
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- key: str
- replica_id: Union[int, Tuple[int, ...]] = 0
- property unique_key
返回此对象的唯一键
- validate_metadata_integrity()
编纂关于元数据属性的约束。
- without_data()
返回一个新的 ShardedBase 实例,其中 data=None。
- class core.dist_checkpointing.mapping.ShardedTensor(key: str, data: Optional[torch.Tensor], dtype: torch.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], axis_fragmentations: Optional[Tuple[int, ...]], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False, flattened_range: Optional[slice] = None)
基类:
core.dist_checkpointing.mapping.ShardedBase
表示本地张量和全局张量之间的映射。
全局张量被假定为由分布在不同进程之间的许多本地张量组成。
- 参数
key – 全局张量的唯一标识符
data – 本地张量数据。仅在一致性验证时可以为 None
dtype – 张量 dtype
local_shape – 本地张量形状
global_shape – 全局张量形状
global_offset – 本地张量在全局张量中的偏移量,以张量元素数量指定
axis_fragmentations – 每个轴的全局张量分片
replica_id – 指示给定本地张量相对于不同进程中的本地张量的复制
prepend_axis_num – 预先添加到本地张量的轴数,以反映全局张量形状。该行为类似于对本地张量进行unsqueeze操作。
allow_shape_mismatch – 如果为 True,则在加载期间,存储的张量的全局形状不必与预期的全局形状匹配。对于表示具有灵活形状的张量(例如,填充张量)很有用。
flattened_range – 指定应应用于具有 local_shape 的扁平张量的切片,以便获得存储为 data 的张量
- allow_shape_mismatch: bool = False
- axis_fragmentations: Optional[Tuple[int, ...]]
- data: Optional[torch.Tensor]
- dtype: torch.dtype
- flattened_range: Optional[slice] = None
- classmethod from_rank_offsets(key: str, data: torch.Tensor, *rank_offsets: Tuple[int, int, int], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, flattened_range: None = None, **init_kwargs)
允许构造 ShardedTensor,给定在进程 rank 中指定的偏移量。
- 参数
key (str) – 唯一键
data (torch.Tensor) – 本地张量数据
rank_offsets (Tuple[int, int, int]) – 每个元组 (axis, axis_rank_offset, axis_fragm) 表示,如果全局张量沿 axis 轴划分为 axis_fragm 片段,则本地张量数据对应于 axis_rank_offset 块。
replica_id (ReplicaId) – 请参阅 ShardedTensor
prepend_axis_num (int) – 请参阅 ShardedTensor
flattened_range (None) – 使用此构造函数时必须为 None
init_kwargs – 传递给 ShardedTensor.__init__
- classmethod from_rank_offsets_flat(key: str, data: torch.Tensor, non_flat_local_shape: Tuple[int, ...], *args, flattened_range: Optional[slice] = None, **kwargs)
允许构造扁平化的 ShardedTensor,给定在进程 rank 中指定的偏移量。
- 参数
key (str) –
data (torch.Tensor) – 这应该是一个扁平化的数据张量
non_flat_local_shape (Tuple[int, ...]) – 非扁平块的预期本地形状
*args – 不变地传递给 from_rank_offsets 构造函数
flattened_range (slice) – 请参阅 ShardedTensor。默认为 None,但必须设置为非 None 切片。
**kwargs –
- 返回值
构造的 ShardedTensor 实例
- 返回类型
- global_coordinates() → Tuple[numpy.ndarray, ...]
返回一个 np.ndarrays 元组,表示此 ShardedTensor 对应的全局张量的坐标。
- global_offset: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- global_slice() → Tuple[Union[int, slice], ...]
返回一个 int 和 slice 对象元组,表示此 ShardedTensor 对应的全局张量的切片。
- init_data(device: Union[str, torch.device], init_fn=torch.empty)
初始化此 ShardedTensor 的张量数据。
仅当 data 属性为 None 时调用。
- 参数
device (Union[str, torch.device]) – 张量所在的设备
init_fn (Callable, optional) – 用于初始化张量的函数。默认为 torch.empty。
- key: str
- local_chunk_offset_in_global() → Tuple[int, ...]
全局 chunk 数组中本地 chunk 的偏移量。
- 返回值
整个本地 chunk 在全局 chunk 数组中的偏移量。
- 返回类型
Tuple[int, …]
- local_coordinates() → Tuple[numpy.ndarray, ...]
返回表示此 ShardedTensor 对应的本地张量坐标的 np.ndarrays 元组。
- local_shape: Tuple[int, ...]
- max_allowed_chunks() → Tuple[int, ...]
返回此 ShardedTensor 允许的最大 chunk 数。
- narrow(dim: int, start: int, length: int) → List[core.dist_checkpointing.mapping.ShardedTensor]
这是 ShardedTensor 的 torch.narrow 的类似物。
Narrowing 假设我们在每个 rank 上 narrowing 一个本地张量。 这对 local_shape、global_shape、global_offset 等有影响。
- 参数
dim (int) – 要 narrowing 的维度。不包括 prepended 轴。
start (int) – 起始元素
length (int) – 切片的长度
- 返回值
- narrowed ShardedTensors。对于非扁平张量,
列表将始终有 1 个元素。对于扁平 ShardedTensors,元素数量取决于 dim 和 overlap,因为扁平张量必须是连续的。 特别是,列表可以为空。
- 返回类型
List[ShardedTensor]
- prepend_axis_num: int = 0
- replica_id: Union[int, Tuple[int, ...]] = 0
- validate_metadata_integrity() → None
编纂关于元数据属性的约束。
当使用 from_rank_offsets 或 from_rank_offsets_flat 构造函数实例化 ShardedTensor 类时,可以保证满足这些约束。
- 返回值
None
- without_data()
返回一个新的 ShardedBase 实例,其中 data=None。
- class core.dist_checkpointing.mapping.ShardedTensorFactory(key: str, data: torch.Tensor, build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]], merge_fn: Callable[[Dict[str, Any]], torch.Tensor], replica_id: Union[int, Tuple[int, ...]] = 0, flattened_range: Optional[slice] = None)
基类:
core.dist_checkpointing.mapping.ShardedBase
允许在序列化之前/之后对张量应用转换。
这些转换的本质是,它们可以像应用于模型参数一样应用于优化器状态。 具有分片张量的最终状态字典必须在功能上取决于 build_fn 参数 (key, data, replica_id, flattened_range),这些参数将由优化器提供。
构建器在保存之前从张量创建一个子状态字典,合并器在加载后合并相应的状态字典。
- 参数
key (str) – 工厂的唯一标识符
data (torch.Tensor) – 将由此工厂进一步转换的原始模型参数
build_fn (callable) – 将原始张量转换为分片状态字典的函数
merge_fn (callable) – 将加载的子树转换回单个张量的函数 ( build_fn 的逆函数)
replica_id (ReplicaId) – 指示工厂相对于不同进程中的工厂的复制
flattened_range (slice, optional) – 指示应用于工厂生成 ShardedTensors 的附加扁平化
- build()
从原始张量构建 ShardedStateDict
- build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]]
- data: torch.Tensor
- flattened_range: Optional[slice] = None
- key: str
- merge_fn: Callable[[Dict[str, Any]], torch.Tensor]
- replica_id: Union[int, Tuple[int, ...]] = 0
- validate_metadata_integrity()
无法应用合理的检查
- without_data()
返回一个新的 ShardedBase 实例,其中 data=None。
- core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])
就地将 ShardedTensorFactories 转换为 ShardedTensors。
- 参数
sharded_state_dict (ShardedStateDict) – 可能包含 ShardedTensorFactory 对象的状态字典
- 返回值
状态字典已就地修改
- 返回类型
None
- core.dist_checkpointing.mapping.apply_factory_merges(x1: Dict[str, Any], x2: Dict[str, Any], key: Tuple[str, ...] = ()) → Dict[str, Any]
就地应用由 ShardedTensorFactories 定义的合并。
- 参数
x1 (StateDict) – 从 checkpoint 加载的状态字典
x2 (ShardedStateDict) – x1 的子集(在字典键方面),其中 ShardedTensorFactory 作为(可能嵌套的)值,定义如何合并来自 x1 状态字典的对象
key (Tuple[str, ...]) – 递归调用中的当前键。仅用于报告有意义的错误
- 返回值
x1 已就地修改
- 返回类型
StateDict
- core.dist_checkpointing.mapping.is_main_replica(replica_id: Union[int, Tuple[int, ...]])
检查给定的 replica_id 是否被视为主副本。
“主”副本是:- 整数 0 - 或所有元素均为 0 的可迭代对象
应用程序有责任为分片张量设置正确的副本。
- 参数
replica_id (Union[int, Tuple[int, ...]]) – 副本 ID
- 返回值
对于“主”副本为 True
- 返回类型
(bool)
用于基于模型参数的现有分片为优化器状态定义分片的助手函数。
- core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, int]
生成从优化器参数到优化器状态 ID 的映射。
- core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: Dict[str, Any], optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]
生成从优化器状态 ID 到模型分片参数的映射。
- 参数
model_sharded_state_dict – 包含所有模型分片张量的分片状态字典(可以有任何结构)
optim_params_iter – 迭代优化器跟踪的模型参数的可迭代对象。迭代顺序必须与优化器参数中的顺序相同。
- 返回值
- 从优化器状态 ID 的映射
到模型分片参数。
- 返回类型
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
- core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]
基于模型参数为优化器参数构建 ShardedTensor 或 ShardedTensorFactory
- 参数
model_param (Union[ShardedTensor, ShardedTensorFactory]) – 模型参数
optim_param (torch.Tensor) – 对应的优化器参数
prefix (str) – ShardedTensor 或 ShardedTensorFactory 的优化器前缀
- 返回值
包装后的优化器参数
- 返回类型
Union[ShardedTensor, ShardedTensorFactory]
- core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: Dict[str, Any], id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())
就地将优化器状态字典转换为基于模型状态字典的分片状态字典。
可用于向最常见的优化器状态字典添加分片信息。 为 optim_state_dict[‘state’] 中的每个键创建单独的 ShardedTensor(例如,对于 torch.optim.Adam,将为 exp_avg 和 exp_avg_sq 创建单独的张量)
- 参数
optim_state_dict (StateDict) – 优化器状态字典,其中状态参数位于 state 键下,组超参数位于 param_groups -> params 键下。
id_to_sharded_param_map (Dict[int, ShardedTensor]) – 从优化器参数 ID 到模型分片张量的映射。 可以使用 get_param_id_to_sharded_param_map 函数生成。
exclude_keys (Tuple[str]) – 从最终状态字典中排除的优化器状态键。
- 返回值
状态字典已就地修改
- 返回类型
None
用于管理分布式 checkpoint 元数据的模块。
- class core.dist_checkpointing.core.CheckpointingConfig(sharded_backend: str, sharded_backend_version: int = 1, common_backend: str = 'torch', common_backend_version: int = 1)
基类:
object
记录 checkpoint 中使用的后端。
Checkpoint 配置跟踪用于存储分片张量 (sharded_backend) 和其他对象 (common_backend) 的格式。
请注意,版本控制不是针对 checkpoint 内容(这是应用程序特定的),而是针对 checkpoint 格式本身。
- common_backend: str = 'torch'
- common_backend_version: int = 1
- sharded_backend: str
- sharded_backend_version: int = 1
- exception core.dist_checkpointing.core.CheckpointingException
Bases:
Exception
与 checkpoint 相关的基本异常
- core.dist_checkpointing.core.check_is_distributed_checkpoint(checkpoint_dir)
检查 metadata.json 是否存在于 checkpoint 中并且是有效的配置。
- 参数
checkpoint_dir – checkpoint 目录
- 返回值
如果 metadata.json 存在于 checkpoint 中并且是有效的配置,则为 True。
- 返回类型
bool
- core.dist_checkpointing.core.maybe_load_config(checkpoint_dir: str) → Optional[core.dist_checkpointing.core.CheckpointingConfig]
如果 checkpoint_dir 是分布式 checkpoint,则返回 checkpoint 配置,否则返回 None
- 参数
checkpoint_dir – checkpoint 目录
- 返回值
如果 checkpoint 不是有效的分布式 checkpoint,则为 None
- 返回类型
CheckpointingConfig (可选)
- core.dist_checkpointing.core.save_config(config: core.dist_checkpointing.core.CheckpointingConfig, checkpoint_dir: str)
将给定配置保存到 checkpoint 目录。
- 参数
config – checkpoint 配置
checkpoint_dir – checkpoint 目录
- 返回值
None
用于操作字典和列表的实用工具。
此模块中的所有函数都处理字典和列表的嵌套。 其他对象(例如元组)被视为不能遍历的原子叶类型。
- core.dist_checkpointing.dict_utils.dict_list_map_inplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U])
使用给定函数就地映射字典和列表。
- core.dist_checkpointing.dict_utils.dict_list_map_outplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U]) → Union[Dict, List, core.dist_checkpointing.dict_utils.V]
使用给定函数异地映射字典和列表。
- core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)
字典的 map 等效项。
- core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)
字典的 map 等效项,函数接受元组 (key, value)。
- core.dist_checkpointing.dict_utils.diff(x1: Any, x2: Any, prefix: Tuple = ()) → Tuple[list, list, list]
字典的递归 diff。
- 参数
x1 (object) – 左侧字典
x2 (object) – 右侧字典
prefix (tuple) – 跟踪递归调用。用于报告不同的键。
- 返回值
- 元组,包含:
only_left:仅在左侧字典中存在的前缀
only_right:仅在右侧字典中存在的前缀
- mismatch:在两个字典中都存在但跨字典不相等的值。
对于张量,检查所有元素的相等性。 每个元素都是一个元组(前缀,左侧值的类型,右侧值的类型)。
- 返回类型
Tuple[list, list, list]
- core.dist_checkpointing.dict_utils.extract_matching_values(x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False) → Tuple[Union[dict, list], Union[dict, list]]
返回匹配和不匹配的值。保持层次结构。
- 参数
x (Union[dict, list]) – 要处理的状态字典。顶级参数必须是字典或列表
predicate (object -> bool) – 确定匹配值
return_lists_as_dicts (bool) – 如果为 True,则匹配的列表将转换为字典,键指示原始元素的索引。 用于重建原始层次结构。
- core.dist_checkpointing.dict_utils.inspect_types(x: Any, prefix: Tuple = (), indent: int = 4)
帮助打印(嵌套)字典值的类型。
- core.dist_checkpointing.dict_utils.map_reduce(xs: typing.Iterable, key_fn: typing.Callable = <function <lambda>>, value_fn: typing.Callable = <function <lambda>>, reduce_fn: typing.Callable = <function <lambda>>) → dict
简单的 map-reduce 实现,遵循 more_itertools.map_reduce 接口。
- core.dist_checkpointing.dict_utils.merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ())
递归地合并字典和列表。
- core.dist_checkpointing.dict_utils.nested_items_iter(x: Union[dict, list])
返回给定字典或列表的(嵌套)元组 (container, key, value) 的迭代器。
- core.dist_checkpointing.dict_utils.nested_values(x: Union[dict, list])
返回给定字典或列表的(嵌套)值的迭代器。
用于操作分片张量和分片状态字典的助手函数。
- core.dist_checkpointing.utils.add_prefix_for_sharding(sharded_state_dict: Dict[str, Any], prefix: str)
就地将给定前缀添加到给定状态字典中的所有 ShardedBase 对象。
- 参数
sharded_state_dict (ShardedStateDict) – 分片状态字典
prefix (str) – 要添加的前缀
- 返回值
状态字典已就地修改
- 返回类型
None
- core.dist_checkpointing.utils.apply_prefix_mapping(sharded_state_dict: Dict[str, Any], prefix_map: Dict[str, str])
仅在与映射中的前缀之一匹配的键中替换前缀。
- 参数
sharded_state_dict (ShardedStateDict) – 要替换键的分片状态字典
prefix_map (Dict[str, str]) – 旧前缀 -> 新前缀的映射。 使用每个键的第一个匹配前缀
- 返回值
状态字典已就地修改
- 返回类型
None
- core.dist_checkpointing.utils.extract_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
从给定的状态字典中提取仅由 LocalNonpersistentObjects 组成的字典。
- 参数
sharded_state_dict – 可能包含 LocalNonpersistentObjects 的状态字典
- 返回值
- 元组,包含:
包含所有 LocalNonpersistentObjects 的状态字典(保持原始状态字典结构)
包含所有其他对象的状态字典(保持原始状态字典结构)
- 返回类型
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_base(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
从包含任何对象的给定状态字典中提取仅由 ShardedBase 组成的字典。
- 参数
sharded_state_dict – 可能包含 ShardedBase 对象的状态字典
- 返回值
- 元组,包含:
包含所有 ShardedBase 对象的状态字典(保持原始状态字典结构)
包含所有其他对象的状态字典(保持原始状态字典结构)
- 返回类型
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
从包含任何对象的给定状态字典中提取仅由 ShardedTensor 对象组成的字典。
- 参数
sharded_state_dict – 可能包含 ShardedTensor 对象的状态字典
- 返回值
- 元组,包含:
包含所有 ShardedTensor 的状态字典(保持原始状态字典结构)
包含 ShardedTensor 以外的所有对象的状态字典(保持原始状态字典结构)
- 返回类型
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
从包含任何对象的给定状态字典中提取仅由 ShardedTensor 和 ShardedTensorFactory 对象组成的字典。
- 参数
sharded_state_dict – 可能包含 ShardedTensor 和 ShardedTensorFactory 对象的状态字典
- 返回值
- 元组,包含:
包含所有 ShardedTensor 和 ShardedTensorFactory 的状态字典(保持原始状态字典结构)
包含所有其他对象的状态字典(保持原始状态字典结构)
- 返回类型
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]
从包含任何对象的给定状态字典中提取仅由 ShardedTensor、ShardedTensorFactory 和 LocalNonpersistentObject 对象组成的字典。
- 参数
sharded_state_dict – 可能包含 ShardedTensor、ShardedTensorFactory
objects (and LocalNonpersistentObject) –
- 返回值
- 元组,包含:
包含所有 ShardedTensor、ShardedTensorFactory 和 LocalNonpersistentObject 的状态字典(保持原始状态字典结构)
包含所有其他对象的状态字典(保持原始状态字典结构)
- 返回类型
Tuple[ShardedStateDict, StateDict]
- core.dist_checkpointing.utils.replace_prefix_for_sharding(sharded_state_dict: Dict[str, Any], old_prefix: str, new_prefix: str)
在给定状态字典的所有分片键中替换给定前缀。
如果某些键不以给定前缀开头,则会报错。
- 参数
sharded_state_dict (ShardedStateDict) – 要替换键的分片状态字典
old_prefix (str) – 要在每个键中替换的前缀
new_prefix (str) – 新前缀
- 返回值
状态字典已就地修改
- 返回类型
None