Attention Is All You Need!

Transformer 模型背后的核心思想是注意力机制 [1]。它识别词语之间的相关性,选择句子中最重要部分进行关注,并捕获数据中Meaningful的模式和依赖关系。图 1 显示了一个典型的注意力机制,其中 pre-softmax 操作可以是缩放、偏置和掩码的组合,而 post-softmax 操作通常只是 dropout。

4dadbf0322ba48f6a755ed1fc36ba926

图 1: 点积注意力。

Transformer Engine 支持在两个框架中计算点积注意力,PyTorchJAX。 每个框架的 API 是

1. 注意力后端

Transformer Engine 为每个支持的框架提供了多个注意力后端。框架原生后端提供了强大的基线,而融合的、GPU 优化的实现提供了更高的性能。例如,PyTorch 中的 flash-attention 和 cuDNN 注意力后端。 框架原生后端通常以 “unfused” 命名,而更优化的后端则为 “fused” 或 “flash”。

框架

后端 (模块名称)

模块位置

PyTorch

cuDNN 注意力 (FusedAttention)

transformer_engine.pytorch.attention

flash-attention (FlashAttention)

PyTorch 原生注意力 (UnfusedDotProductAttention)

JAX

cuDNN 注意力 (_FusedDotProductAttention)

transformer_engine.jax.flax.transformer

JAX 原生注意力 (_UnfusedDotProductAttention)

1.1 Flash 与 Non-Flash

注意力计算具有相对于序列长度的二次计算和内存复杂度。当序列长度加倍时,其运行时和内存需求将翻两番。 这对扩展 Transformer 模型以适应更长的上下文,从而获得更高的模型质量提出了重大挑战。

与标准的 non-flash 算法相比,flash 算法 [2] 被提出将内存缩放降低到线性,并通过优化的内存访问提高计算效率。它采用了以下两种独特的技术。

  • 平铺 (Tiling): non-flash 算法尝试在一个步骤中处理 query、key、value 张量,这需要大量的全局内存,并导致全局内存和共享内存之间的大量读/写操作。 flash 算法根据可用的共享内存和寄存器大小将输入分解为多个平铺块,并一次计算一个平铺块的 softmax。

  • 重计算 (Recomputation): non-flash 算法将 softmax 矩阵(相对于序列长度的二次方)存储到全局内存中以进行反向传播,而 flash 算法仅保存 softmax 归一化因子(相对于序列长度的线性)。这减少了所需的内存量以及全局内存和共享内存之间的带宽利用率。 即使为了在反向传播中重新计算注意力而产生了额外的计算,带宽节省仍然为效率的提高提供了显着的改进。

注意

Transformer Engine 的 flash-attention 后端(在 PyTorch 中可用)和 cuDNN 注意力后端(子后端 1 和 2,在 PyTorch 和 JAX 中可用)都基于 flash 算法。

1.2 flash-attention

flash-attention 后端仅在 PyTorch 中可用,它是围绕公共 flash-attn[3] 封装的模块。

flash-attention 后端支持 flash-attn 的功能以及一些额外的功能,以方便使用 flash-attn,例如将 attention_mask 转换为累积序列长度 cu_seqlens 以用于 padding 掩码用例。 有关详细信息,请参阅 transformer_engine.pytorch.attention.FlashAttention

flash-attn 依赖项在 Transformer Engine 中定期更新。 截至 v2.0,Transformer Engine 支持 flash-attn 2.0.6+(请参阅 setup.py)。

要了解 flash-attn 的性能,请参阅他们的基准测试 此处

1.3 cuDNN 注意力

cuDNN 注意力后端在 PyTorch 和 JAX 中均可用,为注意力计算提供了另一种高性能解决方案。 它需要运行 cuDNN,并具有多个子后端以支持不同的精度和序列长度。

子后端

算法

精度

序列长度

架构

附加信息

0

Non-Flash

BF16/FP16

≤512

sm80, 90

cuDNN

1

Flash

BF16/FP16

任意

sm80+

cuDNN, cudnn-frontend

2

Flash

FP8

cuDNN pre-9.0: ≤512

cuDNN pre-9.0: sm90

cuDNN 9.0+: 任意

cuDNN 9.0+: sm90+

cuDNN 9.0+: cudnn-frontend

cuDNN 注意力后端和 flash-attention 后端有几个显着的差异。 截至 Transformer Engine 2.0、cuDNN 9.3 和 flash-attn 2.4.2,

  • flash-attention 仅支持 PyTorch 框架,而 cuDNN 注意力支持 PyTorch 和 JAX。

  • flash-attention 支持 BF16、FP16 精度,而 cuDNN 注意力也支持 FP8(通过其子后端 2)。

  • flash-attention 支持 bshdthd 输入格式(无需任何转置)和 sbhd 格式(需要转置),而 cuDNN 注意力支持所有三种格式(无需转置)(有关更多详细信息,请参阅第 3.1 节)。

  • flash-attention 不支持 post_scale_bias,而 cuDNN 注意力支持。

  • flash-attention 支持 KV 缓存和分页注意力,而 cuDNN 注意力不支持。

  • flash-attention 在交叉注意力中对 causal 掩码使用右下对角线(请参阅 更改日志),而 cuDNN 注意力同时支持左上和右下。

  • 根据我们对许多常用模型配置的基准测试,flash-attention 在 Ampere 架构上的性能优于 cuDNN 注意力,而 cuDNN 注意力在 Hopper 架构上具有 20-50% 的优势。

为了比较 cuDNN 注意力和 flash-attention,用户可以修改 benchmarks/attention/benchmark_attention.py 中的 model_configs 字典来收集性能数据。 该脚本对 model_configs 中的每个条目运行 num_iters 次,每次运行都包含一次前向传播和一次反向传播。 两种后端都会尝试,如果某个后端不支持特定的用户输入,则最终表格中的运行时和加速比将为 0。

[ ]:
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, "no_mask",         "no_bias"), # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal",         "no_bias"), # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal", "post_scale_bias"), # bias
    "test_3": ModelConfig(2, 32,  4, 128, 8192, 8192, 0.0,  "causal",         "no_bias"), # GQA
}
[1]:
!cd ../../../benchmarks/attention/ && python benchmark_attention.py
Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory
Running test_0 with cuDNN attention and flash-attention...
Running test_1 with cuDNN attention and flash-attention...
Running test_2 with cuDNN attention...
Running test_3 with cuDNN attention and flash-attention...

        cuDNN fwd+bwd (ms)  flash-attn fwd+bwd (ms)  cuDNN vs flash speedup
test_0              0.0340                   0.0468                  1.3786
test_1              0.3664                   0.5850                  1.5968
test_2              0.9332                   0.0000                  0.0000
test_3              7.4875                  11.8879                  1.5877

2. 后端选择

鉴于各种注意力后端,Transformer Engine 具有选择逻辑,可以为特定用户输入和运行时环境选择最合适的后端。 选择逻辑基于后端可用性和后端性能。

后端可用性由模型配置、训练超参数、软件版本和所讨论的 GPU 架构等因素决定。 例如,一些考虑因素包括序列长度、注意力头数、头大小、注意力掩码类型、注意力偏置类型、训练或推理模式、自注意力或交叉注意力、MHA 或 MQA/GQA、flash-attn/cuDNN 库版本以及 GPU 的计算能力。

当有多个后端可用时,Transformer Engine 会根据性能进行后端选择。 通常,我们的选择逻辑中遵循以下几个规则(见下表)。 随着我们监控不同后端的性能,选择逻辑可能会发生变化。

框架

选择顺序

PyTorch

sm90: cuDNN 注意力 > flash-attention > PyTorch 原生注意力

sm80: flash-attention > cuDNN 注意力 > PyTorch 原生注意力

cuDNN 注意力: 子后端 1 > 子后端 0

JAX

cuDNN 注意力 > JAX 原生注意力

2.1 调试信息

为了找出运行时正在使用的后端,我们有以下两个调试标志。 日志记录是通过使用 logging 包完成的。

NVTE_DEBUG       = 0/1   # disables/enables debugging
NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages

注意

截至 Transformer Engine 2.0,这些标志仅在 PyTorch 中受支持。 预计将来会添加 JAX 支持。

示例脚本 example_attention.py 运行一个非常基本的模型,其中包含两个注意力后端:cuDNN 注意力和 flash-attention。 在这里,NVTE_DEBUG_LEVEL=1 允许我们找出运行时使用的后端/子后端。

[24]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py
Run cuDNN attention...
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)

Run flash-attention...
[INFO     | DotProductAttention]: Running with FlashAttention backend

Test passed.

NVTE_DEBUG_LEVEL=2 允许我们了解有关后端选择逻辑的更多信息。 如果用户想提交错误报告,我们鼓励他们仔细检查 config 并将其提供给 Transformer Engine 团队。

[23]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py
Run cuDNN attention...
[DEBUG    | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}
[DEBUG    | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0
[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}
[DEBUG    | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)

Run flash-attention...
[DEBUG    | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}
[DEBUG    | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0
[DEBUG    | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}
[DEBUG    | DotProductAttention]: Selected backend = FlashAttention
[INFO     | DotProductAttention]: Running with FlashAttention backend

Test passed.

2.2 用户控制

用户通常不需要担心后端选择。 但是,如果遇到收敛或性能问题,Transformer Engine 提供了其他一些环境变量供用户试验不同的后端。

flash-attention 或 cuDNN 注意力: 用户可以通过 PyTorch 中的以下两个环境变量启用/禁用 flash-attention 后端或 cuDNN 注意力后端。

NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1
NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1

cuDNN 注意力子后端: 此环境变量允许用户表达他们对 cuDNN 注意力子后端的偏好。 但是,只有在选定的子后端符合条件时(即,如果它支持提供的输入和运行时环境),才会使用它。

NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend

cuDNN 子后端 1 的执行路径: cuDNN 注意力子后端 1 还提供两个执行路径:工作区优化路径和非工作区优化路径。 工作区优化路径需要更大的全局内存量,提供确定性,并提供偏置梯度支持。 在 cuDNN 9.0 之前,它也比非工作区优化路径具有 20-30% 的性能优势。 但在 cuDNN 9.0 之后,它比非工作区优化路径慢 20-30%。

用户可以通过以下环境变量试验这两种路径。 但是,请注意可能的内存不足风险。

Before cuDNN 9.0:
    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path
    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path

After cuDNN 9.0:
    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path
    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path

注意

环境变量 NVTE_FLASH_ATTN、NVTE_FUSED_ATTN、NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT 和 NVTE_ALLOW_NONDETERMINISTIC_ALGO 仅在 PyTorch 中受支持,将来将添加到 JAX 中。

2.3 示例测试

我们的 单元测试 演示了 Transformer Engine 点积注意力 API 的使用。 我们鼓励用户在将 Transformer Engine 集成到他们的 ML 工作流程中时,将它们用作模板。

例如,在 PyTorch 中,test_dot_product_attention 提供了 pytorch.DotProductAttention 的各种用例,从数据类型、模型配置、检查点到 QKV 布局。

3. 后端支持

Transformer Engine 支持常用功能,例如自注意力、交叉注意力、FP16/BF16 精度、dropout 和检查点。 但它还提供了一系列其他功能。 截至 v2.0,Transformer Engine 的注意力后端具有以下支持矩阵。

注意力后端

精度

架构

滑动窗口注意力

MQA/GQA

多潜在注意力

上下文并行

可能具有确定性

cuDNN 注意力 (所有框架)

BF16, FP16, FP8 (仅限 PyTorch)

sm80+

是 (bshd,sbhd, thd)

flash-attention (PyTorch)

BF16, FP16

sm80+

是 (bshd,thd)

框架原生注意力

BF16, FP16, FP32

任意

否,除非用作掩码

是 (仅限 PyTorch)

提供了一些单元测试,作为将此类功能集成到用户模型中的起点。 例如,- 滑动窗口注意力: test_dpa_swa - MQA/GQA: test_te_layer_mqa_gqa - 多潜在注意力: test_dpa_mla - 上下文并行: test_cp_with_fused_attention, test_cp_with_flash_attention

3.1 QKV 布局

Transformer Engine 支持查询 q、键 k、值 v 张量的各种布局。 它定义了 15 种 QKV 布局,这些布局分为 3 种 QKV 格式和 5 种 QKV 布局组,以帮助跨不同布局进行类似的内存/计算操作。 这些布局和组的映射关系是,

qkv_layout

qkv_layout_group=3hd

h3d

hd_2hd

hd_h2d

hd_hd_hd

qkv_format=sbhd

sb3hd

sbh3d

sbhd_sb2hd

sbhd_sbh2d

sbhd_sbhd_sbhd

bshd

bs3hd

bsh3d

bshd_bs2hd

bshd_bsh2d

bshd_bshd_bshd

thd

t3hd

th3d

thd_t2hd

thd_th2d

thd_thd_thd

符号系统是 b 代表批大小,s 序列长度,h 注意力头数,d 头维度,t 批次中的 token 总数,即 t = sum(s_i) for i in 0,...,b-1。 以下是一些布局及其说明示例,以帮助澄清定义。

qkv_layout=sb3hd: qkv 是序列优先,即 s 是每个张量中的前导维度。 它们是张量 qkv 的不同切片:q, k, v = [qkv[:,:,i,:,:] for i in range(3)]。 它们在 h * d 维度上交错排列。

qkv_layout=bshd_bsh2d: qkv 是批次优先,即 b 是每个张量中的前导维度。 q 是连续的,kv 是张量 kv 的不同切片:k, v = [kv[:,:,:,i,:] for i in range(2)]kvd 维度上交错排列。

bsh2d 中的 shkv 的最大序列长度和头数,这可能与 qbshd 中的 sh 不同。 出于简洁的原因,我们将它们表示为相同。 Transformer Engine 确实会区分它们的值以进行实际执行。

qkv_layout=thd_thd_thd: qkv 在批次中具有可变的序列长度。 它们都是连续的,没有交错排列。

截至 v2.0,Transformer Engine 具有以下支持矩阵。

后端

支持的 QKV 格式

注释

flash-attention

bshd, sbhd, thd

PyTorch: 3 种格式,即 15 种布局

cuDNN 注意力

bshd, sbhd, thd

PyTorch: 3 种格式,即 15 种布局

JAX: bs3hd, bshd_bs2hd, bshd_bshd_bshd 布局

框架原生注意力

bshd, sbhd

PyTorch, JAX: 2 种格式,即 10 种布局

不同布局的一些示例用法可以在 test_dpa_qkv_layouttest_dpa_qkv_layout_thd 中找到。 Transformer Engine 还提供了一个实用程序函数 transformer_engine.pytorch.attention.get_qkv_layout,以帮助确定一组 qkv 张量具有哪种布局(仅限 PyTorch)。

注意

当使用 RoPE 时,qkv_layout 可能会在 Transformer Engine PyTorch 中通过 get_qkv_layout 更改。 这是由于我们的 RoPE 实现的就地性质。 我们将 qkv 张量从其初始布局转换为相应的 hd_hd_hd 布局。 例如,从 RoPE 之前的 pytorch.MultiHeadAttention 中的 sbh3d 转换为 RoPE 之后的 pytorch.DotProductAttention 中的 sbhd_sbhd_sbhd。

3.2 注意力掩码

Transformer Engine 支持 7 种掩码类型,所有掩码都定义为 True 掩盖相应元素,False 包括注意力计算中的相应元素。

  • no_mask, padding, causal, causal_bottom_right, padding_causal, padding_causal_bottom_right, arbitrary

不同的后端为注意力掩码提供不同的支持。 截至 Transformer Engine 2.0,

后端

支持的掩码类型

需要 attention_mask

flash-attention

  • no_mask, causal (自注意力),

  • padding, padding_causal (自注意力),

  • causal_bottom_right, padding_causal_bottom_right

  • no_mask, causal causal_bottom_right: 否

  • padding, padding_causal, padding_causal_bottom_right: 如果未提供 cu_seqlens,则为是

  • arbitrary: 是

  • cuDNN 注意力

  • no_mask, causal,

  • padding, padding_causal,

  • causal_bottom_right, padding_causal_bottom_right

  • 框架原生注意力

  • 全部 (PyTorch)

  • no_mask, causal, padding (Jax)

  • Padding 掩码: 对于 padding, padding_causal, padding_causal_bottom_right 掩码类型,用户需要提供序列长度信息,以帮助 Transformer Engine 确定批次中每个序列的结束位置。截至 Transformer Engine 2.0,在 PyTorch 中有两种选择,在 JAX 中有一种选择。

    • PyTorch: 当用户同时提供这两个选项时,首选 cu_seqlens,因为无需额外的转换。

      • cu_seqlens: 用户可以为 qk/v 向 flash-attention 或 cuDNN attention 后端提供累积序列长度张量 cu_seqlens_qcu_seqlens_kvcu_seqlens 的一个示例是 [0, 2, 6, 7],对应于 3 个批次的 [aa000, bbbb0, c0000]

      • attention_mask: 用户也可以提供 attention_mask 作为替代方案,它将被转换为 cu_seqlens。对于自注意力机制,attention_mask 应该是一个形状为 [batch_size, 1, 1, seqlen_q] 的单个张量;对于交叉注意力机制,attention_mask 应该是一个包含两个张量的列表,形状分别为 [batch_size, 1, 1, seqlen_q][batch_size, 1, 1, seqlen_kv]

    • JAX: 用户应提供形状为 [batch_size, 1, seqlen_q, seqlen_kv]attention_mask 张量。

    qkv_format=thd: 如果未提供 max_seqlen_qmax_seqlen_kv,Transformer Engine 会从 qkv 中提取最大序列长度信息。这需要 GPU-CPU 复制和同步操作。出于性能原因,对于 thd QKV 格式,请将 max_seqlen_qmax_seqlen_kv 设置为其适当的值。

    Arbitrary 掩码: 截至 v9.3,cuDNN 不支持 Arbitrary 掩码类型。但是,用户可以将掩码转换为常规的 post_scale_bias 偏差,并实现相同的功能。有关此转换的示例脚本,请参阅 arbitrary_mask_to_post_scale_bias.py

    [33]:
    
    !NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py
    
    Run with post_scale_bias:
    [INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
    
    Run with arbitrary mask:
    [INFO     | DotProductAttention]: Running with UnfusedDotProductAttention backend
    
    Test passed!
    

    有关使用不同注意力掩码运行 Transformer Engine 的更多示例,请参见 test_dpa_mask

    3.3 注意力偏差

    Transformer Engine 支持 4 种注意力偏差类型:no_biaspre_scale_biaspost_scale_biasALiBi(带/不带自定义斜率)。截至 Transformer Engine 2.0,其支持矩阵如下。

    后端

    偏差类型

    偏差形状

    偏差数据类型

    架构

    flash-attention

    no_bias, ALiBi (带斜率)

    N/A

    ALiBi 斜率: FP32

    sm80+

    cuDNN 注意力

    PyTorch: no_bias, post_scale_bias, ALiBi (不带斜率)

    post_scale_bias: 前向传播为 BHSS、1HSS、B1SS、11SS,反向传播为 1HSS

    post_scale_bias: 与 QKV 类型相同

    cuDNN 8.9.6+: sm90

    JAX: no_bias, post_scale_bias

    ALiBi 斜率: FP32

    cuDNN 9.0+: sm80+

    框架原生注意力

    no_bias, pre_scale_bias, post_scale_bias

    post_scale_bias: BHSS、1HSS、B1SS、11SS

    post_scale_bias: 与 QKV 类型相同

    sm80+

    flash-attention 后端通过要求用户传入 alibi_slopes 张量来启用 ALiBi,该张量可以是 vanilla ALiBi 的默认斜率,也可以是用户定义的斜率。另一方面,cuDNN attention 通过接收一个 Boolean 标志来支持 ALiBi,并且截至 cuDNN 9.0,它仅支持 vanilla ALiBi。

    框架原生后端不显式支持 ALiBi,但用户可以将 ALiBi 转换为常规的 post_scale_bias 偏差以达到相同的效果。在 PyTorch 中,可以使用实用程序函数 transformer_engine.pytorch.attention.get_alibi 来帮助进行转换。

    有关如何使用各种注意力偏差的更多示例,请参见 test_dpa_bias

    3.4 FP8 注意力

    Transformer Engine 的一个独特功能是其 FP8 支持,不仅适用于 Linear 层,也适用于点积注意力。Transformer Engine 的 FP8 注意力支持通过其 cuDNN attention 子后端 2 实现。回想一下图 1:两个 MatMul 操作在 FP8 中执行以提高计算效率,而 SoftMax 操作在 FP32 中执行以提高数值精度。

    截至 v2.0,Transformer Engine 通过其 C APIPyTorch API 支持 FP8 注意力。其 PyTorch API 提供两个选项,均通过 FP8 配方定义 transformer_engine.common.recipe.DelayedScaling 进行控制。

    • DelayedScaling.fp8_dpa=True (default=False): 当 cuDNN attention 子后端 2 支持提供的用户输入时,这将启用它。FusedAttention 模块(用于 cuDNN attention)将 FP16 或 BF16 张量作为输入,在 FP8 中执行点积注意力,并返回 FP16 或 BF16(与输入类型相同)的注意力 logits。需要进行类型转换操作,以在模块开始时将张量转换为 FP8,并在模块结束时转换回 FP16/BF16。

    • DelayedScaling.fp8_mha=True (default=False): 此选项在 fp8_dpa=True 的基础上,移除了 FusedAttention 模块开始和结束时的类型转换操作。此功能是实验性的。

    有关使用这两个功能的示例,请参见 test_dpa_fp8_vs_f16test_mha_fp8_vs_f16。要禁用 FP8 注意力的反向传播,仅在前向传播中使用它,用户还可以设置 NVTE_FP8_DPA_BWD=0 (default=1)