Attention Is All You Need!
Transformer 模型背后的核心思想是注意力机制 [1]。它识别词语之间的相关性,选择句子中最重要部分进行关注,并捕获数据中Meaningful的模式和依赖关系。图 1 显示了一个典型的注意力机制,其中 pre-softmax 操作可以是缩放、偏置和掩码的组合,而 post-softmax 操作通常只是 dropout。
图 1: 点积注意力。
Transformer Engine 支持在两个框架中计算点积注意力,PyTorch 和 JAX。 每个框架的 API 是
1. 注意力后端
Transformer Engine 为每个支持的框架提供了多个注意力后端。框架原生后端提供了强大的基线,而融合的、GPU 优化的实现提供了更高的性能。例如,PyTorch 中的 flash-attention 和 cuDNN 注意力后端。 框架原生后端通常以 “unfused” 命名,而更优化的后端则为 “fused” 或 “flash”。
框架 | 后端 (模块名称) | 模块位置 |
---|---|---|
PyTorch | cuDNN 注意力 ( | |
flash-attention ( | ||
PyTorch 原生注意力 ( | ||
JAX | cuDNN 注意力 ( | |
JAX 原生注意力 ( |
1.1 Flash 与 Non-Flash
注意力计算具有相对于序列长度的二次计算和内存复杂度。当序列长度加倍时,其运行时和内存需求将翻两番。 这对扩展 Transformer 模型以适应更长的上下文,从而获得更高的模型质量提出了重大挑战。
与标准的 non-flash 算法相比,flash 算法 [2] 被提出将内存缩放降低到线性,并通过优化的内存访问提高计算效率。它采用了以下两种独特的技术。
平铺 (Tiling): non-flash 算法尝试在一个步骤中处理 query、key、value 张量,这需要大量的全局内存,并导致全局内存和共享内存之间的大量读/写操作。 flash 算法根据可用的共享内存和寄存器大小将输入分解为多个平铺块,并一次计算一个平铺块的 softmax。
重计算 (Recomputation): non-flash 算法将 softmax 矩阵(相对于序列长度的二次方)存储到全局内存中以进行反向传播,而 flash 算法仅保存 softmax 归一化因子(相对于序列长度的线性)。这减少了所需的内存量以及全局内存和共享内存之间的带宽利用率。 即使为了在反向传播中重新计算注意力而产生了额外的计算,带宽节省仍然为效率的提高提供了显着的改进。
注意
Transformer Engine 的 flash-attention 后端(在 PyTorch 中可用)和 cuDNN 注意力后端(子后端 1 和 2,在 PyTorch 和 JAX 中可用)都基于 flash 算法。
1.2 flash-attention
flash-attention 后端仅在 PyTorch 中可用,它是围绕公共 flash-attn
包 [3] 封装的模块。
flash-attention 后端支持 flash-attn
的功能以及一些额外的功能,以方便使用 flash-attn
,例如将 attention_mask
转换为累积序列长度 cu_seqlens
以用于 padding
掩码用例。 有关详细信息,请参阅 transformer_engine.pytorch.attention.FlashAttention
。
flash-attn
依赖项在 Transformer Engine 中定期更新。 截至 v2.0,Transformer Engine 支持 flash-attn
2.0.6+(请参阅 setup.py)。
要了解 flash-attn
的性能,请参阅他们的基准测试 此处。
1.3 cuDNN 注意力
cuDNN 注意力后端在 PyTorch 和 JAX 中均可用,为注意力计算提供了另一种高性能解决方案。 它需要运行 cuDNN,并具有多个子后端以支持不同的精度和序列长度。
子后端 | 算法 | 精度 | 序列长度 | 架构 | 附加信息 |
---|---|---|---|---|---|
0 | Non-Flash | BF16/FP16 | ≤512 | sm80, 90 | |
1 | Flash | BF16/FP16 | 任意 | sm80+ | |
2 | Flash | FP8 | cuDNN pre-9.0: ≤512 | cuDNN pre-9.0: sm90 | |
cuDNN 9.0+: 任意 | cuDNN 9.0+: sm90+ | cuDNN 9.0+: cudnn-frontend |
cuDNN 注意力后端和 flash-attention 后端有几个显着的差异。 截至 Transformer Engine 2.0、cuDNN 9.3 和 flash-attn
2.4.2,
flash-attention 仅支持 PyTorch 框架,而 cuDNN 注意力支持 PyTorch 和 JAX。
flash-attention 支持 BF16、FP16 精度,而 cuDNN 注意力也支持 FP8(通过其子后端 2)。
flash-attention 支持
bshd
、thd
输入格式(无需任何转置)和sbhd
格式(需要转置),而 cuDNN 注意力支持所有三种格式(无需转置)(有关更多详细信息,请参阅第 3.1 节)。flash-attention 不支持
post_scale_bias
,而 cuDNN 注意力支持。flash-attention 支持 KV 缓存和分页注意力,而 cuDNN 注意力不支持。
flash-attention 在交叉注意力中对
causal
掩码使用右下对角线(请参阅 更改日志),而 cuDNN 注意力同时支持左上和右下。根据我们对许多常用模型配置的基准测试,flash-attention 在 Ampere 架构上的性能优于 cuDNN 注意力,而 cuDNN 注意力在 Hopper 架构上具有 20-50% 的优势。
为了比较 cuDNN 注意力和 flash-attention,用户可以修改 benchmarks/attention/benchmark_attention.py 中的 model_configs
字典来收集性能数据。 该脚本对 model_configs
中的每个条目运行 num_iters
次,每次运行都包含一次前向传播和一次反向传播。 两种后端都会尝试,如果某个后端不支持特定的用户输入,则最终表格中的运行时和加速比将为 0。
[ ]:
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
[1]:
!cd ../../../benchmarks/attention/ && python benchmark_attention.py
Device 0: NVIDIA H100 80GB HBM3 GPU, sm90 compute capability, 79.1GB memory
Running test_0 with cuDNN attention and flash-attention...
Running test_1 with cuDNN attention and flash-attention...
Running test_2 with cuDNN attention...
Running test_3 with cuDNN attention and flash-attention...
cuDNN fwd+bwd (ms) flash-attn fwd+bwd (ms) cuDNN vs flash speedup
test_0 0.0340 0.0468 1.3786
test_1 0.3664 0.5850 1.5968
test_2 0.9332 0.0000 0.0000
test_3 7.4875 11.8879 1.5877
2. 后端选择
鉴于各种注意力后端,Transformer Engine 具有选择逻辑,可以为特定用户输入和运行时环境选择最合适的后端。 选择逻辑基于后端可用性和后端性能。
后端可用性由模型配置、训练超参数、软件版本和所讨论的 GPU 架构等因素决定。 例如,一些考虑因素包括序列长度、注意力头数、头大小、注意力掩码类型、注意力偏置类型、训练或推理模式、自注意力或交叉注意力、MHA 或 MQA/GQA、flash-attn
/cuDNN 库版本以及 GPU 的计算能力。
当有多个后端可用时,Transformer Engine 会根据性能进行后端选择。 通常,我们的选择逻辑中遵循以下几个规则(见下表)。 随着我们监控不同后端的性能,选择逻辑可能会发生变化。
框架 | 选择顺序 |
---|---|
PyTorch | sm90: cuDNN 注意力 > flash-attention > PyTorch 原生注意力 |
sm80: flash-attention > cuDNN 注意力 > PyTorch 原生注意力 | |
cuDNN 注意力: 子后端 1 > 子后端 0 | |
JAX | cuDNN 注意力 > JAX 原生注意力 |
2.1 调试信息
为了找出运行时正在使用的后端,我们有以下两个调试标志。 日志记录是通过使用 logging
包完成的。
NVTE_DEBUG = 0/1 # disables/enables debugging
NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages
注意
截至 Transformer Engine 2.0,这些标志仅在 PyTorch 中受支持。 预计将来会添加 JAX 支持。
示例脚本 example_attention.py 运行一个非常基本的模型,其中包含两个注意力后端:cuDNN 注意力和 flash-attention。 在这里,NVTE_DEBUG_LEVEL=1
允许我们找出运行时使用的后端/子后端。
[24]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py
Run cuDNN attention...
[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
Run flash-attention...
[INFO | DotProductAttention]: Running with FlashAttention backend
Test passed.
NVTE_DEBUG_LEVEL=2
允许我们了解有关后端选择逻辑的更多信息。 如果用户想提交错误报告,我们鼓励他们仔细检查 config
并将其提供给 Transformer Engine 团队。
[23]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py
Run cuDNN attention...
[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}
[DEBUG | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0
[DEBUG | DotProductAttention]: Available backends = {FlashAttention=False, FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}
[DEBUG | DotProductAttention]: Selected backend = FusedAttention (sub-backend 1)
[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
Run flash-attention...
[DEBUG | DotProductAttention]: Running with config={'transformer_engine_version': '1.10.0.dev0+ee85a91', 'compute_capability': 'sm90', 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.3.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'bshd_bshd_bshd', 'batch_size': 2, 'num_heads': 16, 'num_gqa_groups': 16, 'max_seqlen_q': 512, 'max_seqlen_kv': 512, 'head_dim_qk': 64, 'head_dim_v': 64, 'attn_mask_type': 'no_mask', 'window_size': (-1, -1), 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None, 'recipe': margin=0, format=HYBRID, amax_history_len=1024, wgrad_override=False, fp8_dpa=False, fp8_mha=False}}
[DEBUG | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0
[DEBUG | DotProductAttention]: Available backends = {FlashAttention=True, FusedAttention=False, UnfusedDotProductAttention=True}
[DEBUG | DotProductAttention]: Selected backend = FlashAttention
[INFO | DotProductAttention]: Running with FlashAttention backend
Test passed.
2.2 用户控制
用户通常不需要担心后端选择。 但是,如果遇到收敛或性能问题,Transformer Engine 提供了其他一些环境变量供用户试验不同的后端。
flash-attention 或 cuDNN 注意力: 用户可以通过 PyTorch 中的以下两个环境变量启用/禁用 flash-attention 后端或 cuDNN 注意力后端。
NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1
NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1
cuDNN 注意力子后端: 此环境变量允许用户表达他们对 cuDNN 注意力子后端的偏好。 但是,只有在选定的子后端符合条件时(即,如果它支持提供的输入和运行时环境),才会使用它。
NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend
cuDNN 子后端 1 的执行路径: cuDNN 注意力子后端 1 还提供两个执行路径:工作区优化路径和非工作区优化路径。 工作区优化路径需要更大的全局内存量,提供确定性,并提供偏置梯度支持。 在 cuDNN 9.0 之前,它也比非工作区优化路径具有 20-30% 的性能优势。 但在 cuDNN 9.0 之后,它比非工作区优化路径慢 20-30%。
用户可以通过以下环境变量试验这两种路径。 但是,请注意可能的内存不足风险。
Before cuDNN 9.0:
NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path
NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path
After cuDNN 9.0:
NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path
NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path
注意
环境变量 NVTE_FLASH_ATTN、NVTE_FUSED_ATTN、NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT 和 NVTE_ALLOW_NONDETERMINISTIC_ALGO 仅在 PyTorch 中受支持,将来将添加到 JAX 中。
2.3 示例测试
我们的 单元测试 演示了 Transformer Engine 点积注意力 API 的使用。 我们鼓励用户在将 Transformer Engine 集成到他们的 ML 工作流程中时,将它们用作模板。
例如,在 PyTorch 中,test_dot_product_attention 提供了 pytorch.DotProductAttention
的各种用例,从数据类型、模型配置、检查点到 QKV 布局。
3. 后端支持
Transformer Engine 支持常用功能,例如自注意力、交叉注意力、FP16/BF16 精度、dropout 和检查点。 但它还提供了一系列其他功能。 截至 v2.0,Transformer Engine 的注意力后端具有以下支持矩阵。
注意力后端 |
精度 |
架构 |
滑动窗口注意力 |
MQA/GQA |
多潜在注意力 |
上下文并行 |
可能具有确定性 |
---|---|---|---|---|---|---|---|
cuDNN 注意力 (所有框架) |
BF16, FP16, FP8 (仅限 PyTorch) |
sm80+ |
否 |
是 |
是 |
是 ( |
是 |
flash-attention (PyTorch) |
BF16, FP16 |
sm80+ |
是 |
是 |
否 |
是 ( |
是 |
框架原生注意力 |
BF16, FP16, FP32 |
任意 |
否,除非用作掩码 |
是 |
是 (仅限 PyTorch) |
否 |
是 |
提供了一些单元测试,作为将此类功能集成到用户模型中的起点。 例如,- 滑动窗口注意力: test_dpa_swa - MQA/GQA: test_te_layer_mqa_gqa - 多潜在注意力: test_dpa_mla - 上下文并行: test_cp_with_fused_attention, test_cp_with_flash_attention
3.1 QKV 布局
Transformer Engine 支持查询 q
、键 k
、值 v
张量的各种布局。 它定义了 15 种 QKV 布局,这些布局分为 3 种 QKV 格式和 5 种 QKV 布局组,以帮助跨不同布局进行类似的内存/计算操作。 这些布局和组的映射关系是,
|
|
|
|
|
|
---|---|---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
符号系统是 b
代表批大小,s
序列长度,h
注意力头数,d
头维度,t
批次中的 token 总数,即 t = sum(s_i) for i in 0,...,b-1
。 以下是一些布局及其说明示例,以帮助澄清定义。
qkv_layout=sb3hd: q
、k
、v
是序列优先,即 s
是每个张量中的前导维度。 它们是张量 qkv
的不同切片:q, k, v = [qkv[:,:,i,:,:] for i in range(3)]
。 它们在 h * d
维度上交错排列。
qkv_layout=bshd_bsh2d: q
、k
、v
是批次优先,即 b
是每个张量中的前导维度。 q
是连续的,k
、v
是张量 kv
的不同切片:k, v = [kv[:,:,:,i,:] for i in range(2)]
。 k
、v
在 d
维度上交错排列。
bsh2d
中的 s
和 h
是 k
、v
的最大序列长度和头数,这可能与 q
的 bshd
中的 s
和 h
不同。 出于简洁的原因,我们将它们表示为相同。 Transformer Engine 确实会区分它们的值以进行实际执行。
qkv_layout=thd_thd_thd: q
、k
、v
在批次中具有可变的序列长度。 它们都是连续的,没有交错排列。
截至 v2.0,Transformer Engine 具有以下支持矩阵。
后端 | 支持的 QKV 格式 | 注释 |
---|---|---|
flash-attention |
| PyTorch: 3 种格式,即 15 种布局 |
cuDNN 注意力 |
| PyTorch: 3 种格式,即 15 种布局 |
JAX: | ||
框架原生注意力 |
| PyTorch, JAX: 2 种格式,即 10 种布局 |
不同布局的一些示例用法可以在 test_dpa_qkv_layout 和 test_dpa_qkv_layout_thd 中找到。 Transformer Engine 还提供了一个实用程序函数 transformer_engine.pytorch.attention.get_qkv_layout,以帮助确定一组 q
、k
、v
张量具有哪种布局(仅限 PyTorch)。
注意
当使用 RoPE 时,qkv_layout 可能会在 Transformer Engine PyTorch 中通过 get_qkv_layout 更改。 这是由于我们的 RoPE 实现的就地性质。 我们将 q
、k
、v
张量从其初始布局转换为相应的 hd_hd_hd 布局。 例如,从 RoPE 之前的 pytorch.MultiHeadAttention 中的 sbh3d 转换为 RoPE 之后的 pytorch.DotProductAttention 中的 sbhd_sbhd_sbhd。
3.2 注意力掩码
Transformer Engine 支持 7 种掩码类型,所有掩码都定义为 True
掩盖相应元素,False
包括注意力计算中的相应元素。
no_mask
,padding
,causal
,causal_bottom_right
,padding_causal
,padding_causal_bottom_right
,arbitrary
不同的后端为注意力掩码提供不同的支持。 截至 Transformer Engine 2.0,
后端 | 支持的掩码类型 | 需要 |
---|---|---|
flash-attention |
|
|
cuDNN 注意力 |
| |
框架原生注意力 | 全部 (PyTorch)
| |
Padding 掩码: 对于 padding
, padding_causal
, padding_causal_bottom_right
掩码类型,用户需要提供序列长度信息,以帮助 Transformer Engine 确定批次中每个序列的结束位置。截至 Transformer Engine 2.0,在 PyTorch 中有两种选择,在 JAX 中有一种选择。
PyTorch: 当用户同时提供这两个选项时,首选
cu_seqlens
,因为无需额外的转换。cu_seqlens
: 用户可以为q
和k
/v
向 flash-attention 或 cuDNN attention 后端提供累积序列长度张量cu_seqlens_q
和cu_seqlens_kv
。cu_seqlens
的一个示例是[0, 2, 6, 7]
,对应于 3 个批次的[aa000, bbbb0, c0000]
。attention_mask
: 用户也可以提供attention_mask
作为替代方案,它将被转换为cu_seqlens
。对于自注意力机制,attention_mask
应该是一个形状为[batch_size, 1, 1, seqlen_q]
的单个张量;对于交叉注意力机制,attention_mask
应该是一个包含两个张量的列表,形状分别为[batch_size, 1, 1, seqlen_q]
和[batch_size, 1, 1, seqlen_kv]
。
JAX: 用户应提供形状为
[batch_size, 1, seqlen_q, seqlen_kv]
的attention_mask
张量。
qkv_format=thd: 如果未提供 max_seqlen_q
和 max_seqlen_kv
,Transformer Engine 会从 q
、k
、v
中提取最大序列长度信息。这需要 GPU-CPU 复制和同步操作。出于性能原因,对于 thd
QKV 格式,请将 max_seqlen_q
和 max_seqlen_kv
设置为其适当的值。
Arbitrary 掩码: 截至 v9.3,cuDNN 不支持 Arbitrary
掩码类型。但是,用户可以将掩码转换为常规的 post_scale_bias
偏差,并实现相同的功能。有关此转换的示例脚本,请参阅 arbitrary_mask_to_post_scale_bias.py。
[33]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python arbitrary_mask_to_post_scale_bias.py
Run with post_scale_bias:
[INFO | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
Run with arbitrary mask:
[INFO | DotProductAttention]: Running with UnfusedDotProductAttention backend
Test passed!
有关使用不同注意力掩码运行 Transformer Engine 的更多示例,请参见 test_dpa_mask。
3.3 注意力偏差
Transformer Engine 支持 4 种注意力偏差类型:no_bias
、pre_scale_bias
、post_scale_bias
和 ALiBi
(带/不带自定义斜率)。截至 Transformer Engine 2.0,其支持矩阵如下。
后端 | 偏差类型 | 偏差形状 | 偏差数据类型 | 架构 |
---|---|---|---|---|
flash-attention |
| N/A | ALiBi 斜率: FP32 | sm80+ |
cuDNN 注意力 | PyTorch: |
|
| cuDNN 8.9.6+: sm90 |
JAX: | ALiBi 斜率: FP32 | cuDNN 9.0+: sm80+ | ||
框架原生注意力 |
|
|
| sm80+ |
flash-attention 后端通过要求用户传入 alibi_slopes
张量来启用 ALiBi
,该张量可以是 vanilla ALiBi 的默认斜率,也可以是用户定义的斜率。另一方面,cuDNN attention 通过接收一个 Boolean
标志来支持 ALiBi
,并且截至 cuDNN 9.0,它仅支持 vanilla ALiBi。
框架原生后端不显式支持 ALiBi
,但用户可以将 ALiBi
转换为常规的 post_scale_bias
偏差以达到相同的效果。在 PyTorch 中,可以使用实用程序函数 transformer_engine.pytorch.attention.get_alibi
来帮助进行转换。
有关如何使用各种注意力偏差的更多示例,请参见 test_dpa_bias。
3.4 FP8 注意力
Transformer Engine 的一个独特功能是其 FP8 支持,不仅适用于 Linear
层,也适用于点积注意力。Transformer Engine 的 FP8 注意力支持通过其 cuDNN attention 子后端 2 实现。回想一下图 1:两个 MatMul
操作在 FP8 中执行以提高计算效率,而 SoftMax
操作在 FP32 中执行以提高数值精度。
截至 v2.0,Transformer Engine 通过其 C API 和 PyTorch API 支持 FP8 注意力。其 PyTorch API 提供两个选项,均通过 FP8 配方定义 transformer_engine.common.recipe.DelayedScaling
进行控制。
DelayedScaling.fp8_dpa=True (default=False)
: 当 cuDNN attention 子后端 2 支持提供的用户输入时,这将启用它。FusedAttention
模块(用于 cuDNN attention)将 FP16 或 BF16 张量作为输入,在 FP8 中执行点积注意力,并返回 FP16 或 BF16(与输入类型相同)的注意力 logits。需要进行类型转换操作,以在模块开始时将张量转换为 FP8,并在模块结束时转换回 FP16/BF16。DelayedScaling.fp8_mha=True (default=False)
: 此选项在fp8_dpa=True
的基础上,移除了FusedAttention
模块开始和结束时的类型转换操作。此功能是实验性的。
有关使用这两个功能的示例,请参见 test_dpa_fp8_vs_f16 和 test_mha_fp8_vs_f16。要禁用 FP8 注意力的反向传播,仅在前向传播中使用它,用户还可以设置 NVTE_FP8_DPA_BWD=0 (default=1)
。