Megatron Core 用户指南

dist_checkpointing 包

一个用于保存和加载分布式检查点的库。“分布式检查点”可以有各种底层格式(当前默认格式基于 Zarr),但具有一个独特的属性 - 在一种并行配置(张量/流水线/数据并行)中保存的检查点可以在不同的并行配置中加载。

使用该库需要使用来自mappingoptimizer模块的函数定义分片 state_dict 字典。这些 state dict 可以使用来自strategies模块的策略,通过serialization模块进行保存或加载。

用于保存和加载分布式检查点的入口点。

函数 loadsave 等价于 torch.loadtorch.save,但期望 torch.Tensors 用 mapping module 中的类进行包装。此外,load 期望分片 state dict 参数作为加载分片张量的指导。

core.dist_checkpointing.serialization.get_default_load_sharded_strategy(checkpoint_dir: str)core.dist_checkpointing.strategies.base.LoadShardedStrategy

获取默认的加载分片策略。

core.dist_checkpointing.serialization.get_default_save_common_strategy(backend: str = 'torch', version: int = 1)core.dist_checkpointing.strategies.base.SaveCommonStrategy

获取默认的保存通用策略。

core.dist_checkpointing.serialization.get_default_save_sharded_strategy(backend: str = 'torch_dist', version: int = 1)core.dist_checkpointing.strategies.base.SaveShardedStrategy

获取默认的保存分片策略。

core.dist_checkpointing.serialization.load(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, strict: Union[str, core.dist_checkpointing.validation.StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED) → Union[Dict[str, Any], Tuple[Dict[str, Any], Set[str], Set[str]]]

加载入口点。

在以下步骤中,以下动词指代相应的对象:- load = 从检查点加载 - extract = 从 sharded_state_dict 提取 - add = 添加到最终的 state dict 步骤:1. 加载通用 state dict 并形成结果 state dict 的基础 2. 将工厂应用于 sharded_state_dict 3. 提取 LocalNonPersistentObject 并添加 4. (可选)提取 ShardedObjects,加载并添加 5. 提取 ShardedBase,加载,应用工厂合并并添加

参数
  • sharded_state_dict (ShardedStateDict) – 现有模型的 state dict,其中填充了 ShardedTensors。用作映射,以确定应加载检查点中存储的全局张量的哪些部分。

  • checkpoint_dir (str) – 包含检查点的目录

  • sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – 配置分片张量的加载行为

  • common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – 配置通用数据的加载行为

  • validate_access_integrity (bool default = True) – 检查每个张量分片是否被某个进程精确访问一次(作为主副本)

  • strict (StrictHandling, str, optional) – 确定在请求的分片 state dict 和检查点之间存在不匹配时的行为。有关更多详细信息,请参阅 StrictHandling 文档。某些值会影响此函数的返回值(返回缺少和意外的键)。默认为 True (StrictHandling.ASSUME_OK_UNEXPECTED),这不会产生任何性能开销。其他建议的值包括:False (StrictHandling.LOG_UNEXPECTED),它仅记录意外的键;或 StrictHandling.RETURN_ALL,它返回所有不匹配的键。

返回值

在大多数情况下仅返回

加载的 state dict。如果 strict 标志设置为

返回类型

StateDict 或 Tuple[StateDict, Set[str], Set[str]]

core.dist_checkpointing.serialization.load_common_state_dict(checkpoint_dir: pathlib.Path) → Dict[str, Any]

从检查点加载通用(非分片)对象 state dict。

参数

checkpoint_dir (Path) – 检查点目录

返回值

包含来自检查点的非分片对象的 state dict

返回类型

StateDict

core.dist_checkpointing.serialization.load_plain_tensors(checkpoint_dir: str) → Dict[str, Any]

加载检查点张量,不带任何分片和普通结构。

注意:不包含通用 state dict。

参数

checkpoint_dir (str) – 要从中加载张量的检查点目录。

返回值

检查点 state dict,仅包含 torch.Tensors。

返回类型

StateDict

core.dist_checkpointing.serialization.load_sharded_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None, common_strategy: Optional[core.dist_checkpointing.strategies.base.LoadCommonStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]

从检查点加载分片元数据。

类似于 load_tensors_metadata,但也包括 ShardedObjects。

返回一个类似于分片 state dict 的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片 state dict 不同,在实际的分片 state dict 中,键对应于 state dict 键)。

Dict 值是不带任何分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

具体的实现取决于加载策略。如果未给出策略,则使用给定后端的默认策略。

参数
  • checkpoint_dir (str) – 要从中加载的检查点目录

  • sharded_strategy (LoadShardedStrategy, optional) – 用于加载元数据的分片策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。

  • common_strategy (LoadCommonStrategy, optional) – 用于加载元数据的通用策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。除非 sharded_strategy 无法处理 ShardedObjects,否则不会使用此策略

返回值

不包含描述 ShardedTensors 数据的扁平 state dict

以及检查点中的 ShardedObjects

返回类型

CkptShardedMetadata

core.dist_checkpointing.serialization.load_tensors_metadata(checkpoint_dir: str, sharded_strategy: Optional[core.dist_checkpointing.strategies.base.LoadShardedStrategy] = None) → Dict[str, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedObject]]

从检查点加载张量元数据。

返回一个类似于分片 state dict 的字典,但请注意,字典键只是 ShardedTensor 键(与实际的分片 state dict 不同,在实际的分片 state dict 中,键对应于 state dict 键)。

Dict 值是不带任何分片的 ShardedTensors(因此,唯一有用的信息是张量的全局形状和 dtype)。

具体的实现取决于加载策略。如果未给出策略,则使用给定后端的默认策略。

参数
  • checkpoint_dir (str) – 要从中加载的检查点目录

  • sharded_strategy (LoadShardedStrategy, optional) – 用于加载元数据的分片策略。默认为 None - 在这种情况下,将使用给定检查点类型的默认加载策略。

返回值

不包含描述 ShardedTensors 数据的扁平 state dict

在检查点中

返回类型

CkptShardedMetadata

core.dist_checkpointing.serialization.remove_sharded_tensors(checkpoint_dir: str, key_prefix: str)

确定适当的分片策略,并将删除操作委托给分片策略

core.dist_checkpointing.serialization.save(sharded_state_dict: Dict[str, Any], checkpoint_dir: str, sharded_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveShardedStrategy, Tuple[str, int]]] = None, common_strategy: Optional[Union[core.dist_checkpointing.strategies.base.SaveCommonStrategy, Tuple[str, int]]] = None, validate_access_integrity: bool = True, async_sharded_save: bool = False, preprocess_common_before_consistancy_check: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None) → Optional[core.dist_checkpointing.strategies.async_utils.AsyncRequest]

保存入口点。

从给定的 state dict 中提取 ShardedTensors。Rank 0 将检查点的“常规”部分保存到通用的 torch 文件。ShardedTensors 根据配置指定的策略进行保存。

步骤:1. 应用工厂 2. 提取并丢弃 LocalNonPersistentObject 3. 提取所有 ShardedBase 对象 4. 将所有其他对象保存到 common.pt 5. (可选)提取并保存 ShardedObjects 6. 保存所有 ShardedBase 对象 7. 写入包含后端和版本元数据的 metadata.json 文件。

步骤 (6) 可以异步执行(参见 async_sharded_save),在这种情况下,实际的保存体现在返回的异步请求中,并且可以由外部调用者调度。对于异步请求,步骤 (7) 作为最终确定函数之一添加,以便仅在检查点完成时才写入 metadata.json。

参数
  • sharded_state_dict (ShardedStateDict) – 填充了 ShardedTensors 的 state dict。用作映射,以确定应如何将本地张量保存为检查点中的全局张量。

  • checkpoint_dir (str) – 要将检查点保存到的目录

  • sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional) – 配置分片张量保存行为和后端

  • common_strategy (SaveCommonStrategy, Tuple[str, int], optional) – 配置通用数据保存行为和后端

  • validate_access_integrity (bool default = True) – 检查每个张量分片是否被某个进程精确访问一次(作为主副本)。它还确保通用 state dict 在所有 rank 中是一致的

  • async_sharded_save (bool, optional) – 如果为 True,则对于分片 state dict 部分,将调用异步保存实现,并将 AsyncRequest 返回给调用者。请注意,实际调度异步保存是调用者的责任。默认为 False。

  • preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None) – 一个可调用函数,它将预处理通用 state dict(即可用于删除我们期望在 state dict 中不同的键)。该函数不得修改原始 state dict

返回值

如果 async_sharded_save 为 True,则返回

应由此函数的调用者调度的异步请求。否则为 None。

返回类型

AsyncRequest (可选)

用于表示张量和对象分片的核心库类。

主要预期用法是使用 ShardedTensor 类(主要是使用 ShardedTensor.from_rank_offsets 类方法)将 torch.Tensors 包装在 state dict 中。

class core.dist_checkpointing.mapping.LocalNonpersistentObject(obj)

基类:object

不应存储在检查点中,而应在本地还原的对象。

使用 LocalNonpersistentObject 包装 state dict 内的任何对象将导致:- 在保存期间,此对象将不会存储在检查点中 - 在加载期间,此对象的本地版本将放置在 state dict 中

unwrap()

返回原始对象。

core.dist_checkpointing.mapping.LocalNonpersitentObject

别名:core.dist_checkpointing.mapping.LocalNonpersistentObject

class core.dist_checkpointing.mapping.ShardedBase

基类:abc.ABC

ShardedTensor 和 ShardedStateDict 的基类。

data: object
key: str
replica_id: Union[int, Tuple[int, ...]]
abstract validate_metadata_integrity()

编纂关于元数据属性的约束。

abstract without_data()core.dist_checkpointing.mapping.ShardedBase

返回一个新的 ShardedBase 实例,其中 data=None。

class core.dist_checkpointing.mapping.ShardedObject(key: str, data: object, global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], replica_id: Union[int, Tuple[int, ...]] = 0)

基类:core.dist_checkpointing.mapping.ShardedBase

表示本地对象和全局对象之间的映射。

全局对象被假定为由分布在不同进程之间的许多本地对象组成。

注意:与 ShardedTensor 相反,无法更改全局对象分片。从概念上讲,ShardedObject 是一个完全分片的 ShardedTensor,具有原子任意类型的元素。

参数
  • key – 全局张量的唯一标识符

  • data – 本地对象数据。仅在一致性验证时可以为 None

  • global_shape – 全局对象形状

  • global_offset – 本地对象在全局对象中的偏移量,以分片数量指定

  • replica_id – 指示本地对象相对于不同进程中的本地对象的复制

data: object
classmethod empty_from_unique_key(unique_key, replica_id: Union[int, Tuple[int, ...]] = 0)core.dist_checkpointing.mapping.ShardedObject

从唯一键实例化 ShardedObject。

参数
  • unique_key – 格式为 <key>/shard_<global_offset>_<global_shape> 的字符串

  • replica_id – 指示本地对象相对于不同进程中的本地对象的复制

返回值

一个 data=None 的 ShardedObject

global_offset: Tuple[int, ...]
global_shape: Tuple[int, ...]
key: str
replica_id: Union[int, Tuple[int, ...]] = 0
property unique_key

返回此对象的唯一键

validate_metadata_integrity()

编纂关于元数据属性的约束。

without_data()

返回一个新的 ShardedBase 实例,其中 data=None。

class core.dist_checkpointing.mapping.ShardedTensor(key: str, data: Optional[torch.Tensor], dtype: torch.dtype, local_shape: Tuple[int, ...], global_shape: Tuple[int, ...], global_offset: Tuple[int, ...], axis_fragmentations: Optional[Tuple[int, ...]], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, allow_shape_mismatch: bool = False, flattened_range: Optional[slice] = None)

基类:core.dist_checkpointing.mapping.ShardedBase

表示本地张量和全局张量之间的映射。

全局张量被假定为由分布在不同进程之间的许多本地张量组成。

参数
  • key – 全局张量的唯一标识符

  • data – 本地张量数据。仅在一致性验证时可以为 None

  • dtype – 张量 dtype

  • local_shape – 本地张量形状

  • global_shape – 全局张量形状

  • global_offset – 本地张量在全局张量中的偏移量,以张量元素数量指定

  • axis_fragmentations – 每个轴的全局张量分片

  • replica_id – 指示给定本地张量相对于不同进程中的本地张量的复制

  • prepend_axis_num – 预先添加到本地张量的轴数,以反映全局张量形状。该行为类似于对本地张量进行unsqueeze操作。

  • allow_shape_mismatch – 如果为 True,则在加载期间,存储的张量的全局形状不必与预期的全局形状匹配。对于表示具有灵活形状的张量(例如,填充张量)很有用。

  • flattened_range – 指定应应用于具有 local_shape 的扁平张量的切片,以便获得存储为 data 的张量

allow_shape_mismatch: bool = False
axis_fragmentations: Optional[Tuple[int, ...]]
data: Optional[torch.Tensor]
dtype: torch.dtype
flattened_range: Optional[slice] = None
classmethod from_rank_offsets(key: str, data: torch.Tensor, *rank_offsets: Tuple[int, int, int], replica_id: Union[int, Tuple[int, ...]] = 0, prepend_axis_num: int = 0, flattened_range: None = None, **init_kwargs)

允许构造 ShardedTensor,给定在进程 rank 中指定的偏移量。

参数
  • key (str) – 唯一键

  • data (torch.Tensor) – 本地张量数据

  • rank_offsets (Tuple[int, int, int]) – 每个元组 (axis, axis_rank_offset, axis_fragm) 表示,如果全局张量沿 axis 轴划分为 axis_fragm 片段,则本地张量数据对应于 axis_rank_offset 块。

  • replica_id (ReplicaId) – 请参阅 ShardedTensor

  • prepend_axis_num (int) – 请参阅 ShardedTensor

  • flattened_range (None) – 使用此构造函数时必须为 None

  • init_kwargs – 传递给 ShardedTensor.__init__

classmethod from_rank_offsets_flat(key: str, data: torch.Tensor, non_flat_local_shape: Tuple[int, ...], *args, flattened_range: Optional[slice] = None, **kwargs)

允许构造扁平化的 ShardedTensor,给定在进程 rank 中指定的偏移量。

参数
  • key (str) –

  • data (torch.Tensor) – 这应该是一个扁平化的数据张量

  • non_flat_local_shape (Tuple[int, ...]) – 非扁平块的预期本地形状

  • *args – 不变地传递给 from_rank_offsets 构造函数

  • flattened_range (slice) – 请参阅 ShardedTensor。默认为 None,但必须设置为非 None 切片。

  • **kwargs

返回值

构造的 ShardedTensor 实例

返回类型

ShardedTensor

global_coordinates() → Tuple[numpy.ndarray, ...]

返回一个 np.ndarrays 元组,表示此 ShardedTensor 对应的全局张量的坐标。

global_offset: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_slice() → Tuple[Union[int, slice], ...]

返回一个 int 和 slice 对象元组,表示此 ShardedTensor 对应的全局张量的切片。

init_data(device: Union[str, torch.device], init_fn=torch.empty)

初始化此 ShardedTensor 的张量数据。

仅当 data 属性为 None 时调用。

参数
  • device (Union[str, torch.device]) – 张量所在的设备

  • init_fn (Callable, optional) – 用于初始化张量的函数。默认为 torch.empty

key: str
local_chunk_offset_in_global() → Tuple[int, ...]

全局 chunk 数组中本地 chunk 的偏移量。

返回值

整个本地 chunk 在全局 chunk 数组中的偏移量。

返回类型

Tuple[int, …]

local_coordinates() → Tuple[numpy.ndarray, ...]

返回表示此 ShardedTensor 对应的本地张量坐标的 np.ndarrays 元组。

local_shape: Tuple[int, ...]
max_allowed_chunks() → Tuple[int, ...]

返回此 ShardedTensor 允许的最大 chunk 数。

narrow(dim: int, start: int, length: int) → List[core.dist_checkpointing.mapping.ShardedTensor]

这是 ShardedTensor 的 torch.narrow 的类似物。

Narrowing 假设我们在每个 rank 上 narrowing 一个本地张量。 这对 local_shape、global_shape、global_offset 等有影响。

参数
  • dim (int) – 要 narrowing 的维度。不包括 prepended 轴。

  • start (int) – 起始元素

  • length (int) – 切片的长度

返回值

narrowed ShardedTensors。对于非扁平张量,

列表将始终有 1 个元素。对于扁平 ShardedTensors,元素数量取决于 dim 和 overlap,因为扁平张量必须是连续的。 特别是,列表可以为空。

返回类型

List[ShardedTensor]

prepend_axis_num: int = 0
replica_id: Union[int, Tuple[int, ...]] = 0
validate_metadata_integrity() → None

编纂关于元数据属性的约束。

当使用 from_rank_offsetsfrom_rank_offsets_flat 构造函数实例化 ShardedTensor 类时,可以保证满足这些约束。

返回值

None

without_data()

返回一个新的 ShardedBase 实例,其中 data=None。

class core.dist_checkpointing.mapping.ShardedTensorFactory(key: str, data: torch.Tensor, build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]], merge_fn: Callable[[Dict[str, Any]], torch.Tensor], replica_id: Union[int, Tuple[int, ...]] = 0, flattened_range: Optional[slice] = None)

基类:core.dist_checkpointing.mapping.ShardedBase

允许在序列化之前/之后对张量应用转换。

这些转换的本质是,它们可以像应用于模型参数一样应用于优化器状态。 具有分片张量的最终状态字典必须在功能上取决于 build_fn 参数 (key, data, replica_id, flattened_range),这些参数将由优化器提供。

构建器在保存之前从张量创建一个子状态字典,合并器在加载后合并相应的状态字典。

参数
  • key (str) – 工厂的唯一标识符

  • data (torch.Tensor) – 将由此工厂进一步转换的原始模型参数

  • build_fn (callable) – 将原始张量转换为分片状态字典的函数

  • merge_fn (callable) – 将加载的子树转换回单个张量的函数 ( build_fn 的逆函数)

  • replica_id (ReplicaId) – 指示工厂相对于不同进程中的工厂的复制

  • flattened_range (slice, optional) – 指示应用于工厂生成 ShardedTensors 的附加扁平化

build()

从原始张量构建 ShardedStateDict

build_fn: Callable[[str, torch.Tensor, Union[int, Tuple[int, ...]], Optional[slice]], Dict[str, Any]]
data: torch.Tensor
flattened_range: Optional[slice] = None
key: str
merge_fn: Callable[[Dict[str, Any]], torch.Tensor]
replica_id: Union[int, Tuple[int, ...]] = 0
validate_metadata_integrity()

无法应用合理的检查

without_data()

返回一个新的 ShardedBase 实例,其中 data=None。

core.dist_checkpointing.mapping.apply_factories(sharded_state_dict: Dict[str, Any])

就地将 ShardedTensorFactories 转换为 ShardedTensors。

参数

sharded_state_dict (ShardedStateDict) – 可能包含 ShardedTensorFactory 对象的状态字典

返回值

状态字典已就地修改

返回类型

None

core.dist_checkpointing.mapping.apply_factory_merges(x1: Dict[str, Any], x2: Dict[str, Any], key: Tuple[str, ...] = ()) → Dict[str, Any]

就地应用由 ShardedTensorFactories 定义的合并。

参数
  • x1 (StateDict) – 从 checkpoint 加载的状态字典

  • x2 (ShardedStateDict) – x1 的子集(在字典键方面),其中 ShardedTensorFactory 作为(可能嵌套的)值,定义如何合并来自 x1 状态字典的对象

  • key (Tuple[str, ...]) – 递归调用中的当前键。仅用于报告有意义的错误

返回值

x1 已就地修改

返回类型

StateDict

core.dist_checkpointing.mapping.is_main_replica(replica_id: Union[int, Tuple[int, ...]])

检查给定的 replica_id 是否被视为主副本。

“主”副本是:- 整数 0 - 或所有元素均为 0 的可迭代对象

应用程序有责任为分片张量设置正确的副本。

参数

replica_id (Union[int, Tuple[int, ...]]) – 副本 ID

返回值

对于“主”副本为 True

返回类型

(bool)

用于基于模型参数的现有分片为优化器状态定义分片的助手函数。

core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, int]

生成从优化器参数到优化器状态 ID 的映射。

core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: Dict[str, Any], optim_params_iter: Iterable[torch.nn.Parameter]) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]

生成从优化器状态 ID 到模型分片参数的映射。

参数
  • model_sharded_state_dict – 包含所有模型分片张量的分片状态字典(可以有任何结构)

  • optim_params_iter – 迭代优化器跟踪的模型参数的可迭代对象。迭代顺序必须与优化器参数中的顺序相同。

返回值

从优化器状态 ID 的映射

到模型分片参数。

返回类型

Dict[int, Union[ShardedTensor, ShardedTensorFactory]]

core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]

基于模型参数为优化器参数构建 ShardedTensor 或 ShardedTensorFactory

参数
  • model_param (Union[ShardedTensor, ShardedTensorFactory]) – 模型参数

  • optim_param (torch.Tensor) – 对应的优化器参数

  • prefix (str) – ShardedTensor 或 ShardedTensorFactory 的优化器前缀

返回值

包装后的优化器参数

返回类型

Union[ShardedTensor, ShardedTensorFactory]

core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: Dict[str, Any], id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())

就地将优化器状态字典转换为基于模型状态字典的分片状态字典。

可用于向最常见的优化器状态字典添加分片信息。 为 optim_state_dict[‘state’] 中的每个键创建单独的 ShardedTensor(例如,对于 torch.optim.Adam,将为 exp_avgexp_avg_sq 创建单独的张量)

参数
  • optim_state_dict (StateDict) – 优化器状态字典,其中状态参数位于 state 键下,组超参数位于 param_groups -> params 键下。

  • id_to_sharded_param_map (Dict[int, ShardedTensor]) – 从优化器参数 ID 到模型分片张量的映射。 可以使用 get_param_id_to_sharded_param_map 函数生成。

  • exclude_keys (Tuple[str]) – 从最终状态字典中排除的优化器状态键。

返回值

状态字典已就地修改

返回类型

None

用于管理分布式 checkpoint 元数据的模块。

class core.dist_checkpointing.core.CheckpointingConfig(sharded_backend: str, sharded_backend_version: int = 1, common_backend: str = 'torch', common_backend_version: int = 1)

基类:object

记录 checkpoint 中使用的后端。

Checkpoint 配置跟踪用于存储分片张量 (sharded_backend) 和其他对象 (common_backend) 的格式。

请注意,版本控制不是针对 checkpoint 内容(这是应用程序特定的),而是针对 checkpoint 格式本身。

common_backend: str = 'torch'
common_backend_version: int = 1
sharded_backend: str
sharded_backend_version: int = 1
exception core.dist_checkpointing.core.CheckpointingException

Bases: Exception

与 checkpoint 相关的基本异常

core.dist_checkpointing.core.check_is_distributed_checkpoint(checkpoint_dir)

检查 metadata.json 是否存在于 checkpoint 中并且是有效的配置。

参数

checkpoint_dir – checkpoint 目录

返回值

如果 metadata.json 存在于 checkpoint 中并且是有效的配置,则为 True。

返回类型

bool

core.dist_checkpointing.core.maybe_load_config(checkpoint_dir: str) → Optional[core.dist_checkpointing.core.CheckpointingConfig]

如果 checkpoint_dir 是分布式 checkpoint,则返回 checkpoint 配置,否则返回 None

参数

checkpoint_dir – checkpoint 目录

返回值

如果 checkpoint 不是有效的分布式 checkpoint,则为 None

返回类型

CheckpointingConfig (可选)

core.dist_checkpointing.core.save_config(config: core.dist_checkpointing.core.CheckpointingConfig, checkpoint_dir: str)

将给定配置保存到 checkpoint 目录。

参数
  • config – checkpoint 配置

  • checkpoint_dir – checkpoint 目录

返回值

None

用于操作字典和列表的实用工具。

此模块中的所有函数都处理字典和列表的嵌套。 其他对象(例如元组)被视为不能遍历的原子叶类型。

core.dist_checkpointing.dict_utils.dict_list_map_inplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U])

使用给定函数就地映射字典和列表。

core.dist_checkpointing.dict_utils.dict_list_map_outplace(f: Callable[[core.dist_checkpointing.dict_utils.U], core.dist_checkpointing.dict_utils.V], x: Union[Dict, List, core.dist_checkpointing.dict_utils.U]) → Union[Dict, List, core.dist_checkpointing.dict_utils.V]

使用给定函数异地映射字典和列表。

core.dist_checkpointing.dict_utils.dict_map(f: Callable, d: dict)

字典的 map 等效项。

core.dist_checkpointing.dict_utils.dict_map_with_key(f: Callable, d: dict)

字典的 map 等效项,函数接受元组 (key, value)。

core.dist_checkpointing.dict_utils.diff(x1: Any, x2: Any, prefix: Tuple = ()) → Tuple[list, list, list]

字典的递归 diff。

参数
  • x1 (object) – 左侧字典

  • x2 (object) – 右侧字典

  • prefix (tuple) – 跟踪递归调用。用于报告不同的键。

返回值

元组,包含:
  • only_left:仅在左侧字典中存在的前缀

  • only_right:仅在右侧字典中存在的前缀

  • mismatch:在两个字典中都存在但跨字典不相等的值。

    对于张量,检查所有元素的相等性。 每个元素都是一个元组(前缀,左侧值的类型,右侧值的类型)。

返回类型

Tuple[list, list, list]

core.dist_checkpointing.dict_utils.extract_matching_values(x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False) → Tuple[Union[dict, list], Union[dict, list]]

返回匹配和不匹配的值。保持层次结构。

参数
  • x (Union[dict, list]) – 要处理的状态字典。顶级参数必须是字典或列表

  • predicate (object -> bool) – 确定匹配值

  • return_lists_as_dicts (bool) – 如果为 True,则匹配的列表将转换为字典,键指示原始元素的索引。 用于重建原始层次结构。

core.dist_checkpointing.dict_utils.inspect_types(x: Any, prefix: Tuple = (), indent: int = 4)

帮助打印(嵌套)字典值的类型。

core.dist_checkpointing.dict_utils.map_reduce(xs: typing.Iterable, key_fn: typing.Callable = <function <lambda>>, value_fn: typing.Callable = <function <lambda>>, reduce_fn: typing.Callable = <function <lambda>>) → dict

简单的 map-reduce 实现,遵循 more_itertools.map_reduce 接口。

core.dist_checkpointing.dict_utils.merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ())

递归地合并字典和列表。

core.dist_checkpointing.dict_utils.nested_items_iter(x: Union[dict, list])

返回给定字典或列表的(嵌套)元组 (container, key, value) 的迭代器。

core.dist_checkpointing.dict_utils.nested_values(x: Union[dict, list])

返回给定字典或列表的(嵌套)值的迭代器。

用于操作分片张量和分片状态字典的助手函数。

core.dist_checkpointing.utils.add_prefix_for_sharding(sharded_state_dict: Dict[str, Any], prefix: str)

就地将给定前缀添加到给定状态字典中的所有 ShardedBase 对象。

参数
  • sharded_state_dict (ShardedStateDict) – 分片状态字典

  • prefix (str) – 要添加的前缀

返回值

状态字典已就地修改

返回类型

None

core.dist_checkpointing.utils.apply_prefix_mapping(sharded_state_dict: Dict[str, Any], prefix_map: Dict[str, str])

仅在与映射中的前缀之一匹配的键中替换前缀。

参数
  • sharded_state_dict (ShardedStateDict) – 要替换键的分片状态字典

  • prefix_map (Dict[str, str]) – 旧前缀 -> 新前缀的映射。 使用每个键的第一个匹配前缀

返回值

状态字典已就地修改

返回类型

None

core.dist_checkpointing.utils.extract_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

从给定的状态字典中提取仅由 LocalNonpersistentObjects 组成的字典。

参数

sharded_state_dict – 可能包含 LocalNonpersistentObjects 的状态字典

返回值

元组,包含:
  • 包含所有 LocalNonpersistentObjects 的状态字典(保持原始状态字典结构)

  • 包含所有其他对象的状态字典(保持原始状态字典结构)

返回类型

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_base(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

从包含任何对象的给定状态字典中提取仅由 ShardedBase 组成的字典。

参数

sharded_state_dict – 可能包含 ShardedBase 对象的状态字典

返回值

元组,包含:
  • 包含所有 ShardedBase 对象的状态字典(保持原始状态字典结构)

  • 包含所有其他对象的状态字典(保持原始状态字典结构)

返回类型

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

从包含任何对象的给定状态字典中提取仅由 ShardedTensor 对象组成的字典。

参数

sharded_state_dict – 可能包含 ShardedTensor 对象的状态字典

返回值

元组,包含:
  • 包含所有 ShardedTensor 的状态字典(保持原始状态字典结构)

  • 包含 ShardedTensor 以外的所有对象的状态字典(保持原始状态字典结构)

返回类型

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_and_factories(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

从包含任何对象的给定状态字典中提取仅由 ShardedTensor 和 ShardedTensorFactory 对象组成的字典。

参数

sharded_state_dict – 可能包含 ShardedTensor 和 ShardedTensorFactory 对象的状态字典

返回值

元组,包含:
  • 包含所有 ShardedTensor 和 ShardedTensorFactory 的状态字典(保持原始状态字典结构)

  • 包含所有其他对象的状态字典(保持原始状态字典结构)

返回类型

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.extract_sharded_tensors_or_nonpersistent(sharded_state_dict: Dict[str, Any]) → Tuple[Dict[str, Any], Dict[str, Any]]

从包含任何对象的给定状态字典中提取仅由 ShardedTensor、ShardedTensorFactory 和 LocalNonpersistentObject 对象组成的字典。

参数
  • sharded_state_dict – 可能包含 ShardedTensor、ShardedTensorFactory

  • objects (and LocalNonpersistentObject) –

返回值

元组,包含:
  • 包含所有 ShardedTensor、ShardedTensorFactory 和 LocalNonpersistentObject 的状态字典(保持原始状态字典结构)

  • 包含所有其他对象的状态字典(保持原始状态字典结构)

返回类型

Tuple[ShardedStateDict, StateDict]

core.dist_checkpointing.utils.replace_prefix_for_sharding(sharded_state_dict: Dict[str, Any], old_prefix: str, new_prefix: str)

在给定状态字典的所有分片键中替换给定前缀。

如果某些键不以给定前缀开头,则会报错。

参数
  • sharded_state_dict (ShardedStateDict) – 要替换键的分片状态字典

  • old_prefix (str) – 要在每个键中替换的前缀

  • new_prefix (str) – 新前缀

返回值

状态字典已就地修改

返回类型

None

Previous Mixture of Experts package
Next dist_checkpointing.strategies package
© Copyright 2022-2025, NVIDIA. Last updated on Jan 14, 2025.