Jax

逻辑轴的预定义变量

变量在 transformer_engine.jax.sharding 中可用。

  • BATCH_AXES:批次维度的逻辑轴。它通常在 Mesh 上沿 DP + FSDP 分片。

  • SEQLEN_AXES:序列长度维度的逻辑轴。它通常不分片。

  • SEQLEN_TP_AXES:序列长度维度的逻辑轴。它通常在 Mesh 上沿 TP 分片。

  • HEAD_AXES:MHA 的头维度逻辑轴。它通常在 Mesh 上沿 TP 分片。

  • HIDDEN_AXES:隐藏维度的逻辑轴。它通常不分片。

  • HIDDEN_TP_AXES:隐藏维度的逻辑轴。它通常在 Mesh 上沿 TP 分片。

  • JOINED_AXES:未定义维度的逻辑轴。它通常不分片。

模块

class transformer_engine.jax.flax.TransformerLayerType(*args, **kwds)

TransformerLayerType 是一个 Enum 类,用于指定 TransformerLayer 的类型

:
  • ENCODER – TransformerLayer 的编码器类型。

  • DECODER – TransformerLayer 的解码器类型。

class transformer_engine.jax.MeshResource

一个数据容器,用于指示 Mesh 中哪个轴用于数据并行,哪个轴用于张量并行。

参数:
  • dp_resource (str, default = None) – Mesh 中用于沿批次分片的轴名称。如果为 None,则禁用数据并行。

  • tp_resource (str, default = None) – Mesh 中用于沿隐藏维度拆分的轴名称。如果为 None,则禁用张量并行。

  • fsdp_resource (str, default = None) – Mesh 中用于沿批次和权重拆分的轴名称。如果为 None,则禁用完全分片数据并行。

  • pp_resource (str, default = None) – Mesh 中用于沿模型层拆分的轴名称。如果为 None,则禁用流水线并行。

  • cp_resource (str, default = None) – Mesh 中用于在注意力中沿序列(上下文)维度拆分的轴名称。如果为 None,则禁用上下文并行。

transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: transformer_engine.common.recipe.DelayedScaling | None = None, mesh_resource: transformer_engine.jax.sharding.MeshResource | None = None) None

FP8 使用的上下文管理器。

mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)

with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
    mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

    with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
        rules = extend_logical_axis_rules(tuple())
        transformer = TransformerLayer()

        with partitioning.axis_rules(rules):
            pjit(transformer.init, ...)(...)

注意

我们目前仅支持 margin, fp8_format, amax_history_len, 和 amax_compute_algo (值为 ‘max’ 和 ‘most_recent’) 在 recipe.DelayedScaling 中。recipe.DelayedScaling 中的其他参数将触发断言。

参数:
  • enabled (bool, default = False) – 是否启用 fp8

  • fp8_recipe (recipe.DelayedScaling, default = None) – 用于 FP8 训练的配方。

  • mesh_resource (MeshResource, default = None) – 指定用于数据和张量并行的 mesh 轴以进行分片。如果设置为 None,则不使用数据或张量并行。

transformer_engine.jax.update_collections(new: Collection, original: Collection) flax.core.frozen_dict.FrozenDict

用于更新 Flax 的 Collection 的辅助函数。

Collection = [dict, flax.core.frozen_dict.FrozenDict]

参数:
  • new (Collection) – 包含新数据的集合。

  • original (Collection) – 基础集合。

返回:

outputs – 更新后的集合。

返回类型:

Collection

class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)

对小批量输入应用层归一化。此模块支持两种类型的归一化,常规层归一化和均方根层归一化。

常规层归一化如论文 Layer Normalization 中所述

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

\(\gamma\)\(\beta\) 是每个输入样本大小的可学习仿射变换参数。

均方根层归一化 (RMSNorm) 如论文 Root Mean Square Layer Normalization 中所述

\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]
\[RMS = \sqrt{\mathrm{E}[x^2]}\]

\(\gamma\) 是每个输入样本大小的可学习仿射变换参数。

参数:
  • epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。

  • zero_centered_gamma (bool, default = False) –

    如果设置为 True,则 LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init

  • scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当 layernorm_type='layernorm' 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • bias_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对移位因子 \(\beta\) 进行分片的轴名称。 仅当 layernorm_type='layernorm' 时使用。

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, default = False) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(x: jax.numpy.ndarray) jax.numpy.ndarray

将层归一化应用于输入 inputs

参数:

inputs (jax.numpy.ndarray) – 输入张量。

返回:

outputs – 输出张量。

返回类型:

jax.numpy.ndarray

class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)

对传入数据应用线性变换 \(y = xA^T + b\)

参数:
  • features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。

  • kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • kernel_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对权重进行分片的轴名称。

  • use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当 use_bias=True 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • bias_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对偏置进行分片的轴名称,仅当 use_bias=True 时使用。

  • enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。

  • low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当 enable_low_rank_adaptation=True 时使用

  • low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。

  • axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(inputs: Array) Array

将线性变换应用于输入。

参数:

inputs (jax.numpy.ndarray) – 输入张量。

返回:

outputs – 输出张量。

返回类型:

jax.numpy.ndarray

class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

对传入数据应用层归一化,然后进行线性变换。

参数:
  • features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。

  • enable_layernorm (bool, default = True) – 指示是否在线性变换之前启用层归一化。

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。

  • epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。

  • zero_centered_gamma (bool, default = False) –

    如果设置为 True,则 LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init

  • scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅当 enable_layernorm=True 时使用。

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当 enable_layernorm=Truelayernorm_type='layernorm' 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对移位因子 \(\beta\) 进行分片的轴名称。 仅当 enable_layernorm=Truelayernorm_type='layernorm' 时使用。

  • kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • kernel_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对权重进行分片的轴名称。

  • use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当 use_bias=True 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • bias_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对偏置进行分片的轴名称,仅当 use_bias=True 时使用。

  • return_layernorm_output (bool, default = True) – 指示是否返回层归一化的输出。 如果设置为 False,则在输出中将 None 作为第二个张量返回。

  • enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。

  • low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当 enable_low_rank_adaptation=True 时使用

  • low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。

  • axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。

  • layernorm_input_axes (Tuple[str, ...], default = None) – 指示层归一化输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。

  • dot_input_axes (Tuple[str, ...], default = None) – 指示点积输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

  • depth_scaling (float, default = None) – 用于缩放来自 DenseGeneral 的输出的因子。 它应该是一个浮点值或 None。 当设置为 None 时,不应用缩放。

__call__(inputs: Array) Array

将层归一化应用于输入,然后进行线性变换。

参数:

inputs (jax.numpy.ndarray) – 输入张量。

返回:

  • outputs (jax.numpy.ndarray) – 输出张量。

  • ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。 如果 return_layernorm_output=False,则这将为 None。

class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

对输入应用层归一化,然后应用 MLP 模块,该模块由 2 个连续的线性变换组成,并由给定的激活函数分隔。

参数:
  • intermediate_dim (int, default = 2048) – 输入样本投影到的中间大小。

  • enable_layernorm (bool, default = True) – 指示是否在线性变换之前启用层归一化。

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。

  • epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。

  • zero_centered_gamma (bool, default = False) –

    如果设置为 True,则 LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init

  • scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅当 enable_layernorm=True 时使用。

  • ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当 enable_layernorm=Truelayernorm_type='layernorm' 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • ln_bias_axes (Tuple[str, ...], default = ('embed', )) – 用于分片移位因子 \(\beta\) 以及对应网格的轴的名称。仅当 enable_layernorm=Truelayernorm_type='layernorm' 时使用。

  • kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化线性变换的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • kernel_axes_1 (Tuple[str, ...], default = ('embed', 'act', 'mlp')) – 用于分片第一个线性变换权重的轴的名称,并与对应的网格关联。

  • kernel_axes_2 (Tuple[str, ...], default = ('mlp', 'embed')) – 用于分片第二个线性变换权重的轴的名称,并与对应的网格关联。

  • use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当 use_bias=True 时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。

  • bias_axes_1 (Tuple[str, ...], default = ('mlp',)) – 用于分片第一个线性变换权重的偏置的轴的名称,并与对应的网格关联。仅当 use_bias=True 时使用。

  • bias_axes_2 (Tuple[str, ...], default = ('embed',)) – 用于分片第二个线性变换权重的偏置的轴的名称,并与对应的网格关联。仅当 use_bias=True 时使用。

  • return_layernorm_output (bool, default = True) – 指示是否返回层归一化的输出。 如果设置为 False,则在输出中将 None 作为第二个张量返回。

  • activations (Sequence[Union[str, Callable]], default = ('relu',)) – 在第一个线性变换之后应用的一系列激活函数。每个激活函数都有自己的变换层。

  • intermediate_dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于生成 Dropout 掩码的键名。

  • intermediate_dropout_rate (float, default = 0.1) – 在 activations 之后 dropout 操作的 dropout 概率。

  • intermediate_hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同 dropout 掩码的隐藏层维度

  • enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。

  • low_rank_adaptation_dim (int, default = 32) – 低秩适配的维度,仅当 enable_low_rank_adaptation=True 时使用。

  • low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。

  • axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。

  • layernorm_input_axes (Tuple[str, ...], default = None) – 指示层归一化输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。

  • dot_1_input_axes (Tuple[str, ...], default = None) – 指示第一个点积运算输入的 sharding 约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。默认为 None,表示不插入 sharding 约束。

  • dot_2_input_axes (Tuple[str, ...], default = None) – 指示第二个点积运算输入的 sharding 约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。默认为 None,表示不插入 sharding 约束。

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(inputs: Array, deterministic: bool = False) Array

对输入应用层归一化,然后接上前馈网络 (MLP Block)。

参数:
  • inputs (jax.numpy.ndarray) – 输入张量。

  • deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 操作。

返回:

  • outputs (jax.numpy.ndarray) – 输出张量。

  • ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。 如果 return_layernorm_output=False,则这将为 None。

class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)

T5 风格的相对位置嵌入,用于注意力 logits。

参数:
  • num_buckets (int) – 将键和查询位置之间的距离分桶的数量。

  • max_distance (int) – 最大距离,超过此距离的所有距离都将被归入最后一个距离桶。

  • num_attention_heads (int) – Transformer 层中的注意力头的数量。

  • embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – 用于初始化相对嵌入表。

  • embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – 用于分片嵌入注意力偏置的轴的名称,并与对应的网格关联。

优化参数:

dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

__call__(q_seqlen, k_seqlen, bidirectional=True)

生成相对位置嵌入注意力偏置。

参数:
  • q_seqlen (int) – 查询的序列长度。

  • k_seqlen (int) – 键的序列长度。

  • bidirectional (bool, default = True) – 指示是否允许正向的 memory-query 相对位置嵌入。

返回:

output – 形状为 (1, num_attention_heads, q_seqlen, k_seqlen) 的注意力偏置。

返回类型:

jax.numpy.ndarray

class transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)

点积注意力 (DPA)。允许模型共同关注来自不同表示子空间的信息,如论文 Attention Is All You Need 中所述。

注意

DotProductAttention 模块支持两种后端:非融合和融合注意力机制。非融合注意力使用 JAX 原生操作实现,提供广泛的兼容性和灵活性。相比之下,融合注意力使用 cuDNN 融合注意力,以在支持的硬件上获得更高的性能和更低的内存使用率。用户可以通过 NVTE_FUSED_ATTN 环境变量在两种后端之间进行选择

  • 设置 NVTE_FUSED_ATTN=0 以使用非融合注意力(默认)。

  • 设置 NVTE_FUSED_ATTN=1 以使用融合注意力。如果系统上没有所需的 cuDNN 融合注意力内核,则会发出警告,并且模块将自动回退到非融合后端。

注意

DotProductAttention 默认设置启用非确定性内核,以减少工作空间需求并加快计算速度。用户可以通过 NVTE_ALLOW_NONDETERMINISTIC_ALGO 环境变量禁用非确定性内核

  • NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 仅允许确定性内核。

  • NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 允许非确定性内核(默认)。

参数:
  • head_dim (int) – 每个注意力头的隐藏维度。

  • num_attention_heads (int) – 注意力头的数量。

  • num_gqa_groups (int, default = None) – GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads

  • attention_dropout (float, default = 0.0) – softmax 之后 dropout 操作的 dropout 概率。

  • attn_mask_type (str, default = 'causal') –

    此参数指定在 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}

    每个选项描述如下

    • no_mask: 不应用注意力掩码。这意味着注意力将考虑整个序列,没有任何限制。

    • padding: 指示每个序列末尾存在填充。用户必须在 __call__ 方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。

    • causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。

    • causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。

    注意

    mask 在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。

  • attn_bias_type (Optional[str], default = None) – 传入注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为 post_scale_bias。否则使用 no_bias

  • dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在核心注意力中生成 Dropout 掩码的键名。

  • float32_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。

  • qkv_layout (str, default = 'bshd_bshd_bshd') –

    指定 __call__() 中 query、key 和 value 张量的维度布局格式。它指示如何处理输入。可用选项:{‘bs3hd’, ‘bshd_bs2hd’, ‘bshd_bshd_bshd’}。其中

    • bs3hd: query 张量被视为 qkvpacked 张量,形状为 [b, s, 3, h, d]。 __call__() 中的 key 和 value 参数在此布局中被忽略。

    • bshd_bs2hd: query 张量形状为 [b, s, h, d]。 key 张量被视为 kvpacked 张量,形状为 [b, s, 2, h, d]。 __call__() 中的 value 参数被忽略。

    • bshd_bshd_bshd: query、key 和 value 是分开的,形状为 [b, s, h, d]。

    符号解释

    • b: 批大小

    • s: 序列长度

    • h: num_attention_heads 或 num_gqa_groups

    • d: head 维度

  • scale_factor (Optional[float], default = None) – 应用于 query 的缩放因子。当为 None 时,缩放因子等于 \(\frac{1}{\sqrt{head\_dim}}\)。这对于像 T5X 这样的模型很有用,它不需要对 query 应用缩放,即设置为 scale_factor=1.

  • transpose_batch_sequence (bool, default = True) – 指示输入张量是否已交换批次和序列长度维度的轴。如果设置为 True,则输入张量应为 (seqlen, batch, …),否则为 (batch, seqlen, …)。

  • window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。默认值是没有滑动窗口。

  • (bool) (context_parallel_causal_load_balanced) – 指示在运行上下文并行时,序列是否针对因果掩码负载均衡进行排序。

  • (str) (context_parallel_axis)

优化参数:

dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

__call__(query: Array, key: Array, value: Array, mask: Array | None = None, bias: Array | None = None, *, deterministic: bool = False) Array
参数:
  • query (jax.numpy.ndarray) – query 张量表示的详细信息在 qkv_layout 中描述。

  • key (jax.numpy.ndarrary) – kery 张量表示的详细信息在 qkv_layout 中描述。

  • value (jax.numpy.ndarrary) – value 张量表示的详细信息在 qkv_layout 中描述。

  • mask (jax.numpy.ndarray, default = None) – 用于屏蔽注意力 softmax 输入的布尔张量。True 表示屏蔽对应的值。当 self.attn_mask_type 为 ‘no_mask’ 或 ‘causal’ 时被忽略。

  • bias (jax.numpy.ndarray, default = None) – 用于移动注意力 softmax 输入的张量。

  • * – 以下参数仅为关键字参数

  • deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。

返回:

outputs – 输出张量。

返回类型:

jax.numpy.ndarray

class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)

多头注意力 (MHA),包括 Query、Key、Value 和 Output 投影。

参数:
  • head_dim (int) – 每个注意力头的隐藏维度。

  • num_attention_heads (int) – 注意力头的数量。

  • num_gqa_groups (int, default = None) –

    GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads

  • attention_dropout (float, default = 0.0) – softmax 之后 dropout 操作的 dropout 概率。

  • attn_mask_type (str, default = 'causal') –

    此参数指定在 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}

    每个选项描述如下

    • no_mask: 不应用注意力掩码。这意味着注意力将考虑整个序列,没有任何限制。

    • padding: 指示每个序列末尾存在填充。用户必须在 __call__ 方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。

    • causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。

    • causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。

    注意

    mask 在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。

  • attn_bias_type (Optional[str], default = None) – 传入注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为 post_scale_bias。否则使用 no_bias

  • dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在核心注意力中生成 Dropout 掩码的键名。

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。

  • layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的值,以提高数值稳定性。

  • zero_centered_gamma (bool, default = False) –

    如果设置为 True,则 LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。

  • kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • use_bias (bool, default = False) – 指示是否为 QKV 和输出投影启用偏置移位。如果设置为 False,则该层将不学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影的偏置,仅当 use_bias=True 时使用。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • input_layernorm (bool, default = True) – 如果设置为 False,则不对输入应用层归一化。

  • return_layernorm_output (bool, default = False) – 如果设置为 True,则层归一化的输出将与线性变换的输出一起从前向传播中返回。示例用例:Transformer 模块的残差连接在层归一化之后获取。

  • enable_rotary_pos_emb (bool, default = False) – 是否为投影的 query 和 key 启用旋转位置嵌入。

  • rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – 指示旋转位置嵌入的最小和最大时间尺度,仅当 enable_rotary_pos_emb=True 时使用

  • rotary_pos_emb_group_method (str, default = 'consecutive') – 指示耦合坐标的方法。它应该是 [‘consecutive’, ‘alternate’] 之一。 ‘alternate’ 是将索引 \(i\)\(i + d/2\) 配对,d 是隐藏维度。 ‘consecutive’ 将索引 \(i\)\(i + 1\) 配对。

  • low_rank_adaptation_scope (str, default = 'none') – 指示应用低秩适配的范围。它应该是 [‘none’, ‘all’, ‘qkv_proj’, ‘output_proj’, ‘exclude_qkv_proj’, ‘exclude_output_proj’] 之一

  • low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当 enable_low_rank_adaptation=True 时使用

  • low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。

  • enable_sequence_parallel (bool, default = False) – 是否对点积以外的操作启用序列并行。

  • num_heads (int, default = None) – 已弃用。请参考 num_attention_heads

  • dropout_rate (float, default = None) – 已弃用。请参考 attention_dropout

  • output_layernorm (bool, default = None) – 已弃用。请参考 input_layernorm

  • apply_residual_connection_post_layernorm (bool, default = None) – 已弃用。请参考 return_layernorm_output

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • fuse_qkv_params (bool, default = True) – 如果设置为 True,则此模块为自注意力公开一个用于 query-key-value 的融合参数,为交叉注意力公开一个用于 key-value 的融合参数。

  • transpose_batch_sequence (bool, default = True) – 指示输入张量是否已交换批次和序列长度维度的轴。如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

  • scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则为 \(\frac{Q}{\sqrt{head\_dim}*K}\),否则为 \(Q*K\)

  • scaled_query_init (bool, default = True) – 是否在初始化时按 \(\frac{1}{\sqrt{head\_dim}}\) 缩放 WQ

  • float32_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。

  • fuse_qkv (bool, default = None) – 已弃用。请参考 fuse_qkv_params

  • window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。默认值是没有滑动窗口。

__call__(inputs_q: Array, inputs_kv: Array, mask: Array | None = None, bias: Array | None = None, *, decode: bool = False, deterministic: bool = False) Array

MultiHeadAttention 层:[Query, Key, Value 投影] -> 点积注意力 -> 输出投影。

参数:
  • inputs_q (jax.numpy.ndarray) – 用于 query 投影的输入张量。

  • inputs_kv (jax.numpy.ndarray) – 用于 key/value 投影的输入张量。

  • mask (jax.numpy.ndarray, default = None) – 用于屏蔽注意力 softmax 输入的布尔张量。True 表示屏蔽对应的值。当 self.attn_mask_type 为 ‘no_mask’ 或 ‘causal’ 时被忽略。

  • bias (jax.numpy.ndarray, default = None) – 用于移动注意力 softmax 输入的张量。

  • *

  • decode (bool, default = False) – 指示是否准备和使用自回归缓存。

  • deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。

返回:

outputs – 输出张量。

返回类型:

jax.numpy.ndarray

class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)

TransformerLayer 由相对嵌入、注意力块和前馈网络 (MLP) 组成。这个标准层基于论文 “Attention Is All You Need”。

参数:
  • hidden_size (int, default = 512) – 每个输入样本的隐藏大小。

  • mlp_hidden_size (int, default = 2048) – 输入样本被投影到的中间大小。

  • num_attention_heads (int, default = 8) – Transformer 层中的注意力头的数量。

  • num_gqa_groups (int, default = None) –

    GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads

  • layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。

  • layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的值,以提高数值稳定性。

  • zero_centered_gamma (bool, default = False) –

    如果设置为 True,则 LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。

  • hidden_dropout (float, default = 0.1) – FC2 层之后 dropout 操作的 dropout 概率。

  • hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同 dropout 掩码的隐藏层维度

  • attention_dropout (float, default = 0.1) – 多头注意力期间 dropout 操作的 dropout 概率。

  • intermediate_dropout (float, default = 0.1) – FC1 层之后 dropout 操作的 dropout 概率。

  • intermediate_dropout_dims (Sequence[int], default = ()) – FC1 层之后将共享相同 dropout 掩码的隐藏层维度。

  • dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在多头注意力中生成 Dropout 掩码的键名。

  • mha_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • mlp_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化 FC1 和 FC2 层的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • mlp_activations (Sequence[str], default = ('relu', )) – 在第一个线性变换之后应用的一系列激活函数。每个激活函数都有自己的变换层。

  • use_bias (bool, default = False) – 指示是否为 QKVO 投影、FC1 和 FC2 启用偏置移位。如果设置为 False,则该层将不学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影、FC1 和 FC2 的偏置。它仅在 use_bias=True 时使用。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。

  • apply_residual_connection_post_layernorm (bool, default = False) – 如果设置为 True,则残差连接取自层归一化的输出(默认取自层归一化的输入)

  • output_layernorm (bool, default = False) – 如果设置为 True,则在最终 dropout-add 之后,在输出侧应用层归一化。默认行为是在输入侧、QKV 变换之前应用层归一化。

  • float32_attention_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。

  • layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – 如果设置为 TransformerLayerType.DECODER,则在自注意力之后添加一个额外的交叉注意力块。这可以与 TransformerLayerType.ENCODER 选项结合使用,用于像 T5 Transformer 这样的结构。

  • self_attn_mask_type (str, default = 'causal') –

    此参数指定在自注意力中的 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}

    每个选项描述如下

    • no_mask: 不应用注意力掩码。这意味着自注意力将考虑整个序列,没有任何限制。

    • padding: 指示每个序列末尾存在填充。用户必须在 __call__ 方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。

    • causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。

    • causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。

    注意

    attention_mask 在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。

  • self_attn_bias_type (Optional[str], default = None) – 传入自注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为 post_scale_bias。否则使用 no_bias

  • enable_relative_embedding (bool, default = True) – 是否启用相对嵌入作为注意力 logits 的移位。

  • relative_embedding (flax.linen.Module, default = None) – 用于执行相对嵌入的模块,仅当 enable_relative_embedding=True 时使用。 默认为 None,如果 enable_relative_embedding=True,则将创建 RelativePositionBiases 的实例。 默认值:RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, ‘fan_avg’, ‘uniform’), name=’relpos_bias’)

  • enable_rotary_pos_emb (bool, default = False) – 是否在 MHA 中为 projected query 和 key 启用 rotary position embedding。

  • rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – 指示旋转位置嵌入的最小和最大时间尺度,仅当 enable_rotary_pos_emb=True 时使用

  • rotary_pos_emb_group_method (str, default = 'consecutive') – 指示耦合坐标的方法。 它应该是 [‘consecutive’, ‘alternate’] 之一。 ‘alternate’ 是将索引 \(i\)\(i + d/2\) 配对,其中 \(d\) 是隐藏维度。 ‘consecutive’ 将索引 \(i\)\(i + 1\) 配对。

  • low_rank_adaptation_scope (str, default = 'none') – 指示应用低秩自适应的范围。 它应该是 [‘none’, ‘all’, ‘qkv_proj’, ‘output_proj’, ‘mlp’, ‘exclude_qkv_proj’, ‘exclude_output_proj’, ‘exclude_mlp’] 之一

  • low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当 enable_low_rank_adaptation=True 时使用

  • low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha 值。 \(\frac{alpha}{rank} * lora\_output\)。 None 表示不进行缩放。

  • enable_sequence_parallel (bool, default = False) – 是否对点积以外的操作启用序列并行。

  • window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。 默认值是没有滑动窗口。

优化参数:
  • dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • drop_path (float, default = 0.0) – 当 > 0.0 时,在残差块的主路径中对每个样本应用随机深度。

  • fuse_qkv_params (bool, default = True) – 如果设置为 True,则 TransformerLayer 模块为自注意力机制的 query-key-value 和交叉注意力机制的 key-value 公开一个融合的参数。

  • transpose_batch_sequence (bool, default = False) – 指示输入张量是否已切换批次和序列长度维度的轴。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden) 格式,否则为 (batch, seqlen, hidden) 格式。

  • scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。 如果设置为 True,则为 \(\frac{Q}{\sqrt{head_dim}*K}\),否则为 \(Q*K\)

  • scaled_query_init (bool, default = True) – 是否在初始化时按 \(\sqrt{head_dim}\) 缩放 WQ

__call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)

Transformer Layer:注意力块和前馈网络 (MLP)

参数:
  • inputs (jax.numpy.ndarray) – 输入张量。

  • encoded (jax.numpy.ndarray, default = None) – 编码器块的输出张量,如果使用 layer_type=TransformerLayerType.DECODER,则将其馈送到解码器块。

  • attention_mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。 True 表示屏蔽相应的值。 当 self.self_attn_mask_type 为 ‘no_mask’ 或 ‘causal’ 时,将被忽略。

  • encoder_decoder_mask (jax.numpy.ndarray, default = None) – 当 layer_type=TransformerLayerType.DECODER 时,用于屏蔽交叉注意力 softmax 输入的布尔张量。 True 表示屏蔽相应的值。

  • deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。

  • decode (bool, default = False) – 指示是否在多头注意力 (MHA) 中准备和使用自回归缓存。

  • max_decode_length (bool, default = None) – 当 layer_type=TransformerLayerType.DECODERenable_relative_embedding=True 时,生成相对嵌入偏差的最大长度。

返回:

outputs – 输出张量。

返回类型:

jax.numpy.ndarray

transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules) LogicalRules

使用预定义的 TransformerLayer 的逻辑轴规则扩展给定的 Flax 逻辑轴规则。

注意

我们目前仅支持单 GPU 训练、数据并行训练和 1D 分片张量并行训练的逻辑轴规则。 有关 1D 分片张量并行性的信息,请参阅 Figure 3 in Megatron-LM tensor parallel

警告

请确保在调用此函数之前通过 fp8_autocast 设置 ShardingResource。

注意

仅当使用 TransformerLayer 时才需要此功能。 对于其他模块(例如 DenseGeneral),请正确设置内核和偏差的轴。

参数:

rules (Sequence[Tuple[str, Union[str, None]]]) – 要扩展的基础 Flax 逻辑轴规则。

返回:

extended_rules – 扩展的 Flax 逻辑轴规则。

返回类型:

Sequence[Tuple[str, Union[str, None]]]