通用 API

class transformer_engine.common.recipe.Format(*args, **kwds)

支持的 FP8 格式。

:
  • E4M3 – 所有 FP8 张量均为 e4m3 格式

  • E5M2 – 所有 FP8 张量均为 e5m2 格式

  • HYBRID – 前向传播中的 FP8 张量为 e4m3 格式,反向传播中的 FP8 张量为 e5m2 格式

class transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo='max', scaling_factor_compute_algo=None)

使用延迟缩放因子策略。使用前一次迭代的缩放因子,并记录 amax_history_len 步的 amax 历史记录。

参数:
  • margin (int, 默认 = 0) – 缩放因子计算的边距。

  • fp8_format ({Format.E4M3, Format.HYBRID}, 默认 = Format.HYBRID) – 控制前向和反向传播期间使用的 FP8 数据格式。

  • amax_history_len (int, 默认 = 1024) – 用于缩放因子计算的 amax 历史窗口的长度。

  • amax_compute_algo ({'max', 'most_recent', Callable}, 默认 = 'max') –

    用于选择缩放因子计算的 amax 值的算法。 有 2 个预定义选项:max 选择历史窗口中最大的 amax,而 most_recent 始终选择最近看到的值。 或者,可以传递一个具有以下签名的函数

    def amax_compute(amax_history: Tensor) -> Tensor
    

    其中 Tensor 是框架张量类型。

  • scaling_factor_compute_algo (Callable, 默认 = None) –

    用于基于 amax 值计算新缩放因子的算法。 它应该是一个具有以下签名的函数

    def scaling_factor_compute(amax: Tensor,
                               old_scaling_factor: Tensor,
                               fp8_max: Tensor,
                               recipe: DelayedScaling) -> Tensor
    

    其中 Tensor 是框架张量类型。

  • reduce_amax (bool, 默认 = True) – 默认情况下,如果 torch.distributed 已初始化,则 FP8 张量的 amax 值将在 fp8_group(在 fp8_autocast 调用中指定)中减少。 这使 amax 和缩放因子在给定的分布式组中保持同步。 如果设置为 False,则跳过此减少,并且每个 GPU 都维护本地 amax 和缩放因子。 为了确保在这种情况下跨检查点边界的结果在数值上相同,所有 rank 都必须进行检查点以存储本地张量。

  • fp8_dpa (bool, 默认 = False) – 是否启用 FP8 点积注意力 (DPA)。 当模型放置在 fp8_autocast(enabled=True) 区域中且 fp8_dpa 设置为 True 时,DPA 会将输入从更高精度转换为 FP8,在 FP8 中执行注意力,并将张量转换回更高精度作为输出。 FP8 DPA 目前仅在 FusedAttention 后端中受支持。

  • fp8_mha (bool, 默认 = False) – 是否启用 FP8 多头注意力 (MHA)。 当为 True 时,它会删除上述在 DPA 边界处的转换操作。 目前,此功能仅支持标准 MHA 模块,即 LayerNormLinear/Linear + DPA + Linear。 当 fp8_mha = False, fp8_dpa = True 时,典型的 MHA 模块的工作方式为 LayerNormLinear(BF16 输出)->(转换为 FP8)FP8 DPA(转换为 BF16)-> Linear。 当 fp8_mha = True, fp8_dpa = True 时,它变为 LayerNormLinear(FP8 输出)-> FP8 DPA -> Linear

注意

  • 默认情况下(当 scaling_factor_compute_algo 保留为 None 时),缩放因子是使用以下公式从最终 amax 值计算得出的

    FP8_MAX = maximum_representable_value(fp8_format)
    new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
    
  • fp8_dpafp8_mha 是 Beta 功能,它们的 API 和功能可能会在未来的 Transformer Engine 版本中发生更改。

class transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)

使用 MXFP8 缩放因子策略。

在此策略中,张量以分块方式缩放。 每组 32 个连续值使用自己的缩放因子一起缩放。 缩放因子的类型为 E8M0(8 位指数,0 位尾数),相当于按 2 的幂进行缩放。

由于缩放发生在特定方向(行式或列式)上,因此在此配方中,量化张量及其转置在数值上并不等效。 因此,当 Transformer Engine 需要 MXFP8 张量及其转置(例如,计算前向和反向传播)时,在量化期间,两个版本都从高精度输入计算得出,以避免双重量化错误。

参数:

fp8_format ({Format.E4M3, Format.HYBRID}, 默认 = Format.E4M3) – 控制前向和反向传播期间使用的 FP8 数据格式。