常见问题 (FAQ)
FP8 检查点兼容性
Transformer Engine 从 1.6 版本开始支持 FP8 注意力机制。它将 FP8 元数据,即缩放因子和 amax 历史记录,存储在检查点中的 ._extra_state 键下。随着 FP8 注意力支持从一个后端扩展到多个后端,._extra_state 键的位置也发生了变化。
在此,我们以 MultiheadAttention 模块为例。其在 Transformer Engine 1.11 版本中的 FP8 注意力元数据存储为 core_attention._extra_state,如下所示。
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
... mha = MultiheadAttention(
... hidden_size=1024,
... num_attention_heads=16,
... bias=True,
... params_dtype=torch.bfloat16,
... input_layernorm=False,
... fuse_qkv_params=True,
... attention_type="self",
... qkv_weight_interleaved=True,
... ).to(dtype=torch.bfloat16, device="cuda")
...
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
以下是所有 Transformer Engine 版本的检查点保存/加载行为的完整列表。
版本:<= 1.5
|
版本:1.6, 1.7
|
版本:>=1.8, <= 1.11
|
版本:>=1.12
|