常见问题 (FAQ)

FP8 检查点兼容性

Transformer Engine 从 1.6 版本开始支持 FP8 注意力机制。它将 FP8 元数据,即缩放因子和 amax 历史记录,存储在检查点中的 ._extra_state 键下。随着 FP8 注意力支持从一个后端扩展到多个后端,._extra_state 键的位置也发生了变化。

在此,我们以 MultiheadAttention 模块为例。其在 Transformer Engine 1.11 版本中的 FP8 注意力元数据存储为 core_attention._extra_state,如下所示。

>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
...     mha = MultiheadAttention(
...         hidden_size=1024,
...         num_attention_heads=16,
...         bias=True,
...         params_dtype=torch.bfloat16,
...         input_layernorm=False,
...         fuse_qkv_params=True,
...         attention_type="self",
...         qkv_weight_interleaved=True,
...     ).to(dtype=torch.bfloat16, device="cuda")
...
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

以下是所有 Transformer Engine 版本的检查点保存/加载行为的完整列表。

版本:<= 1.5

  • 由于不支持 FP8 注意力,因此不保存 FP8 元数据

  • 以下版本创建的检查点的加载行为

    <= 1.5:

    不加载 FP8 元数据

    > 1.5:

    错误:意外的键

版本:1.6, 1.7

  • 将 FP8 元数据保存到 core_attention.fused_attention._extra_state

  • 以下版本创建的检查点的加载行为

    <= 1.5:

    将 FP8 元数据初始化为默认值,即缩放因子为 1,amax 为 0

    1.6, 1.7:

    从检查点加载 FP8 元数据

    >= 1.8:

    错误:意外的键

版本:>=1.8, <= 1.11

  • 将 FP8 元数据保存到 core_attention._extra_state

  • 以下版本创建的检查点的加载行为

    <= 1.5:

    将 FP8 元数据初始化为默认值,即缩放因子为 1,amax 为 0

    1.6, 1.7:

    此保存/加载组合依赖于用户将 1.6/1.7 键映射到 1.8-1.11 键。否则,它会将 FP8 元数据初始化为默认值,即缩放因子为 1,amax 为 0。在这种 MultiheadAttention 示例中,可以通过以下方式完成映射:

    >>> state_dict["core_attention._extra_state"] = \
            state_dict["core_attention.fused_attention._extra_state"]
    >>> del state_dict["core_attention.fused_attention._extra_state"]
    
    >= 1.8:

    从检查点加载 FP8 元数据

版本:>=1.12

  • 将 FP8 元数据保存到 core_attention._extra_state

  • 以下版本创建的检查点的加载行为

    <= 1.5:

    将 FP8 元数据初始化为默认值,即缩放因子为 1,amax 为 0

    >= 1.6:

    从检查点加载 FP8 元数据