Jax
逻辑轴的预定义变量
变量在 transformer_engine.jax.sharding 中可用。
BATCH_AXES:批次维度的逻辑轴。它通常在 Mesh 上沿 DP + FSDP 分片。
SEQLEN_AXES:序列长度维度的逻辑轴。它通常不分片。
SEQLEN_TP_AXES:序列长度维度的逻辑轴。它通常在 Mesh 上沿 TP 分片。
HEAD_AXES:MHA 的头维度逻辑轴。它通常在 Mesh 上沿 TP 分片。
HIDDEN_AXES:隐藏维度的逻辑轴。它通常不分片。
HIDDEN_TP_AXES:隐藏维度的逻辑轴。它通常在 Mesh 上沿 TP 分片。
JOINED_AXES:未定义维度的逻辑轴。它通常不分片。
模块
- class transformer_engine.jax.flax.TransformerLayerType(*args, **kwds)
TransformerLayerType 是一个 Enum 类,用于指定 TransformerLayer 的类型
- 值:
ENCODER – TransformerLayer 的编码器类型。
DECODER – TransformerLayer 的解码器类型。
- class transformer_engine.jax.MeshResource
一个数据容器,用于指示 Mesh 中哪个轴用于数据并行,哪个轴用于张量并行。
- 参数:
dp_resource (str, default = None) – Mesh 中用于沿批次分片的轴名称。如果为 None,则禁用数据并行。
tp_resource (str, default = None) – Mesh 中用于沿隐藏维度拆分的轴名称。如果为 None,则禁用张量并行。
fsdp_resource (str, default = None) – Mesh 中用于沿批次和权重拆分的轴名称。如果为 None,则禁用完全分片数据并行。
pp_resource (str, default = None) – Mesh 中用于沿模型层拆分的轴名称。如果为 None,则禁用流水线并行。
cp_resource (str, default = None) – Mesh 中用于在注意力中沿序列(上下文)维度拆分的轴名称。如果为 None,则禁用上下文并行。
- transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: transformer_engine.common.recipe.DelayedScaling | None = None, mesh_resource: transformer_engine.jax.sharding.MeshResource | None = None) None
FP8 使用的上下文管理器。
mesh_shape = (4, 2) dp_mesh_axis_name = 'data_parallel' tp_mesh_axis_name = 'tensor_parallel' devices = np.asarray(jax.devices()).reshape(*mesh_shape) with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) with fp8_autocast(enabled=True, mesh_resource=mesh_resource): rules = extend_logical_axis_rules(tuple()) transformer = TransformerLayer() with partitioning.axis_rules(rules): pjit(transformer.init, ...)(...)
注意
我们目前仅支持
margin
,fp8_format
,amax_history_len
, 和amax_compute_algo
(值为 ‘max’ 和 ‘most_recent’) 在 recipe.DelayedScaling 中。recipe.DelayedScaling 中的其他参数将触发断言。- 参数:
enabled (bool, default = False) – 是否启用 fp8
fp8_recipe (recipe.DelayedScaling, default = None) – 用于 FP8 训练的配方。
mesh_resource (MeshResource, default = None) – 指定用于数据和张量并行的 mesh 轴以进行分片。如果设置为 None,则不使用数据或张量并行。
- transformer_engine.jax.update_collections(new: Collection, original: Collection) flax.core.frozen_dict.FrozenDict
用于更新 Flax 的 Collection 的辅助函数。
Collection = [dict, flax.core.frozen_dict.FrozenDict]
- 参数:
new (Collection) – 包含新数据的集合。
original (Collection) – 基础集合。
- 返回:
outputs – 更新后的集合。
- 返回类型:
Collection
- class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
对小批量输入应用层归一化。此模块支持两种类型的归一化,常规层归一化和均方根层归一化。
常规层归一化如论文 Layer Normalization 中所述
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]\(\gamma\) 和 \(\beta\) 是每个输入样本大小的可学习仿射变换参数。
均方根层归一化 (RMSNorm) 如论文 Root Mean Square Layer Normalization 中所述
\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]\[RMS = \sqrt{\mathrm{E}[x^2]}\]\(\gamma\) 是每个输入样本大小的可学习仿射变换参数。
- 参数:
epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。
zero_centered_gamma (bool, default = False) –
如果设置为 True,则 LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init。
scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。
scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当
layernorm_type='layernorm'
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。bias_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对移位因子 \(\beta\) 进行分片的轴名称。 仅当
layernorm_type='layernorm'
时使用。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, default = False) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(x: jax.numpy.ndarray) jax.numpy.ndarray
将层归一化应用于输入
inputs
。- 参数:
inputs (jax.numpy.ndarray) – 输入张量。
- 返回:
outputs – 输出张量。
- 返回类型:
jax.numpy.ndarray
- class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
对传入数据应用线性变换 \(y = xA^T + b\)。
- 参数:
features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。
kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。
kernel_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对权重进行分片的轴名称。
use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当
use_bias=True
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。bias_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对偏置进行分片的轴名称,仅当
use_bias=True
时使用。enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。
low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当
enable_low_rank_adaptation=True
时使用low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。
axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(inputs: Array) Array
将线性变换应用于输入。
- 参数:
inputs (jax.numpy.ndarray) – 输入张量。
- 返回:
outputs – 输出张量。
- 返回类型:
jax.numpy.ndarray
- class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
对传入数据应用层归一化,然后进行线性变换。
- 参数:
features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。
enable_layernorm (bool, default = True) – 指示是否在线性变换之前启用层归一化。
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。
epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。
zero_centered_gamma (bool, default = False) –
如果设置为 True,则 LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init
scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。
scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅当
enable_layernorm=True
时使用。ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当
enable_layernorm=True
且layernorm_type='layernorm'
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。ln_bias_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对移位因子 \(\beta\) 进行分片的轴名称。 仅当
enable_layernorm=True
且layernorm_type='layernorm'
时使用。kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。
kernel_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对权重进行分片的轴名称。
use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当
use_bias=True
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。bias_axes (Tuple[str, ...], default = ()) – 用于使用相应的 mesh 对偏置进行分片的轴名称,仅当
use_bias=True
时使用。return_layernorm_output (bool, default = True) – 指示是否返回层归一化的输出。 如果设置为 False,则在输出中将 None 作为第二个张量返回。
enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。
low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当
enable_low_rank_adaptation=True
时使用low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。
axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。
layernorm_input_axes (Tuple[str, ...], default = None) – 指示层归一化输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。
dot_input_axes (Tuple[str, ...], default = None) – 指示点积输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
depth_scaling (float, default = None) – 用于缩放来自 DenseGeneral 的输出的因子。 它应该是一个浮点值或 None。 当设置为 None 时,不应用缩放。
- __call__(inputs: Array) Array
将层归一化应用于输入,然后进行线性变换。
- 参数:
inputs (jax.numpy.ndarray) – 输入张量。
- 返回:
outputs (jax.numpy.ndarray) – 输出张量。
ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。 如果
return_layernorm_output=False
,则这将为 None。
- class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
对输入应用层归一化,然后应用 MLP 模块,该模块由 2 个连续的线性变换组成,并由给定的激活函数分隔。
- 参数:
intermediate_dim (int, default = 2048) – 输入样本投影到的中间大小。
enable_layernorm (bool, default = True) – 指示是否在线性变换之前启用层归一化。
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。
epsilon (float, default = 1e-6) – 添加到层归一化分母的值,以提高数值稳定性。
zero_centered_gamma (bool, default = False) –
如果设置为 True,则 LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也将更改。 请参阅 scale_init。
scale_init (Initializer, default = None) – 用于初始化比例因子 \(\gamma\)。 如果提供 None,则根据 zero_centered_gamma 的值设置 scale_init。 如果 zero_centered_gamma 设置为 True,则 scale_init 为 flax.linen.initializers.zeros。 否则,scale_init 为 flax.linen.initializers.ones。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。
scale_axes (Tuple[str, ...], default = ('embed', )) – 用于使用相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅当
enable_layernorm=True
时使用。ln_bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化移位因子 \(\beta\),仅当
enable_layernorm=True
且layernorm_type='layernorm'
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。ln_bias_axes (Tuple[str, ...], default = ('embed', )) – 用于分片移位因子 \(\beta\) 以及对应网格的轴的名称。仅当
enable_layernorm=True
且layernorm_type='layernorm'
时使用。kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化线性变换的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。
kernel_axes_1 (Tuple[str, ...], default = ('embed', 'act', 'mlp')) – 用于分片第一个线性变换权重的轴的名称,并与对应的网格关联。
kernel_axes_2 (Tuple[str, ...], default = ('mlp', 'embed')) – 用于分片第二个线性变换权重的轴的名称,并与对应的网格关联。
use_bias (bool, default = False) – 指示是否启用偏置移位。 如果设置为 False,则该层不会学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化偏置,仅当
use_bias=True
时使用。 它应该是一个带有三个参数(jax.random.PRNGKey、shape、dtype)的可调用对象。bias_axes_1 (Tuple[str, ...], default = ('mlp',)) – 用于分片第一个线性变换权重的偏置的轴的名称,并与对应的网格关联。仅当
use_bias=True
时使用。bias_axes_2 (Tuple[str, ...], default = ('embed',)) – 用于分片第二个线性变换权重的偏置的轴的名称,并与对应的网格关联。仅当
use_bias=True
时使用。return_layernorm_output (bool, default = True) – 指示是否返回层归一化的输出。 如果设置为 False,则在输出中将 None 作为第二个张量返回。
activations (Sequence[Union[str, Callable]], default = ('relu',)) – 在第一个线性变换之后应用的一系列激活函数。每个激活函数都有自己的变换层。
intermediate_dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于生成 Dropout 掩码的键名。
intermediate_dropout_rate (float, default = 0.1) – 在
activations
之后 dropout 操作的 dropout 概率。intermediate_hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同 dropout 掩码的隐藏层维度
enable_low_rank_adaptation (bool, default = False) – 指示是否为每个线性层启用低秩自适应。
low_rank_adaptation_dim (int, default = 32) – 低秩适配的维度,仅当
enable_low_rank_adaptation=True
时使用。low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。
axis (Union[Iterable[int], int], default = -1) – 要在其上应用变换的整数元组轴。
layernorm_input_axes (Tuple[str, ...], default = None) – 指示层归一化输入的分片约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。 默认为 None,表示不插入分片约束。
dot_1_input_axes (Tuple[str, ...], default = None) – 指示第一个点积运算输入的 sharding 约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。默认为 None,表示不插入 sharding 约束。
dot_2_input_axes (Tuple[str, ...], default = None) – 指示第二个点积运算输入的 sharding 约束的逻辑轴,例如 (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)。默认为 None,表示不插入 sharding 约束。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, default = True) – 指示输入张量是否切换了批次轴和序列长度维度。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(inputs: Array, deterministic: bool = False) Array
对输入应用层归一化,然后接上前馈网络 (MLP Block)。
- 参数:
inputs (jax.numpy.ndarray) – 输入张量。
deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 操作。
- 返回:
outputs (jax.numpy.ndarray) – 输出张量。
ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。 如果
return_layernorm_output=False
,则这将为 None。
- class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
T5 风格的相对位置嵌入,用于注意力 logits。
- 参数:
num_buckets (int) – 将键和查询位置之间的距离分桶的数量。
max_distance (int) – 最大距离,超过此距离的所有距离都将被归入最后一个距离桶。
num_attention_heads (int) – Transformer 层中的注意力头的数量。
embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – 用于初始化相对嵌入表。
embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – 用于分片嵌入注意力偏置的轴的名称,并与对应的网格关联。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
- __call__(q_seqlen, k_seqlen, bidirectional=True)
生成相对位置嵌入注意力偏置。
- 参数:
q_seqlen (int) – 查询的序列长度。
k_seqlen (int) – 键的序列长度。
bidirectional (bool, default = True) – 指示是否允许正向的 memory-query 相对位置嵌入。
- 返回:
output – 形状为 (1, num_attention_heads, q_seqlen, k_seqlen) 的注意力偏置。
- 返回类型:
jax.numpy.ndarray
- class transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
点积注意力 (DPA)。允许模型共同关注来自不同表示子空间的信息,如论文 Attention Is All You Need 中所述。
注意
DotProductAttention 模块支持两种后端:非融合和融合注意力机制。非融合注意力使用 JAX 原生操作实现,提供广泛的兼容性和灵活性。相比之下,融合注意力使用 cuDNN 融合注意力,以在支持的硬件上获得更高的性能和更低的内存使用率。用户可以通过
NVTE_FUSED_ATTN
环境变量在两种后端之间进行选择设置
NVTE_FUSED_ATTN=0
以使用非融合注意力(默认)。设置
NVTE_FUSED_ATTN=1
以使用融合注意力。如果系统上没有所需的 cuDNN 融合注意力内核,则会发出警告,并且模块将自动回退到非融合后端。
注意
DotProductAttention 默认设置启用非确定性内核,以减少工作空间需求并加快计算速度。用户可以通过
NVTE_ALLOW_NONDETERMINISTIC_ALGO
环境变量禁用非确定性内核NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
仅允许确定性内核。NVTE_ALLOW_NONDETERMINISTIC_ALGO=1
允许非确定性内核(默认)。
- 参数:
head_dim (int) – 每个注意力头的隐藏维度。
num_attention_heads (int) – 注意力头的数量。
num_gqa_groups (int, default = None) – GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads。
attention_dropout (float, default = 0.0) – softmax 之后 dropout 操作的 dropout 概率。
attn_mask_type (str, default = 'causal') –
此参数指定在 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}
每个选项描述如下
no_mask: 不应用注意力掩码。这意味着注意力将考虑整个序列,没有任何限制。
padding: 指示每个序列末尾存在填充。用户必须在
__call__
方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。
causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。
注意
mask
在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。attn_bias_type (Optional[str], default = None) – 传入注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为
post_scale_bias
。否则使用no_bias
。dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在核心注意力中生成 Dropout 掩码的键名。
float32_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。
qkv_layout (str, default = 'bshd_bshd_bshd') –
指定 __call__() 中 query、key 和 value 张量的维度布局格式。它指示如何处理输入。可用选项:{‘bs3hd’, ‘bshd_bs2hd’, ‘bshd_bshd_bshd’}。其中
bs3hd: query 张量被视为 qkvpacked 张量,形状为 [b, s, 3, h, d]。
__call__()
中的 key 和 value 参数在此布局中被忽略。bshd_bs2hd: query 张量形状为 [b, s, h, d]。 key 张量被视为 kvpacked 张量,形状为 [b, s, 2, h, d]。
__call__()
中的 value 参数被忽略。bshd_bshd_bshd: query、key 和 value 是分开的,形状为 [b, s, h, d]。
符号解释
b: 批大小
s: 序列长度
h: num_attention_heads 或 num_gqa_groups
d: head 维度
scale_factor (Optional[float], default = None) – 应用于 query 的缩放因子。当为
None
时,缩放因子等于 \(\frac{1}{\sqrt{head\_dim}}\)。这对于像 T5X 这样的模型很有用,它不需要对 query 应用缩放,即设置为scale_factor=1.
。transpose_batch_sequence (bool, default = True) – 指示输入张量是否已交换批次和序列长度维度的轴。如果设置为 True,则输入张量应为 (seqlen, batch, …),否则为 (batch, seqlen, …)。
window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。默认值是没有滑动窗口。
(bool) (context_parallel_causal_load_balanced) – 指示在运行上下文并行时,序列是否针对因果掩码负载均衡进行排序。
(str) (context_parallel_axis)
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
- __call__(query: Array, key: Array, value: Array, mask: Array | None = None, bias: Array | None = None, *, deterministic: bool = False) Array
- 参数:
query (jax.numpy.ndarray) – query 张量表示的详细信息在
qkv_layout
中描述。key (jax.numpy.ndarrary) – kery 张量表示的详细信息在
qkv_layout
中描述。value (jax.numpy.ndarrary) – value 张量表示的详细信息在
qkv_layout
中描述。mask (jax.numpy.ndarray, default = None) – 用于屏蔽注意力 softmax 输入的布尔张量。
True
表示屏蔽对应的值。当self.attn_mask_type
为 ‘no_mask’ 或 ‘causal’ 时被忽略。bias (jax.numpy.ndarray, default = None) – 用于移动注意力 softmax 输入的张量。
* – 以下参数仅为关键字参数
deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。
- 返回:
outputs – 输出张量。
- 返回类型:
jax.numpy.ndarray
- class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
多头注意力 (MHA),包括 Query、Key、Value 和 Output 投影。
- 参数:
head_dim (int) – 每个注意力头的隐藏维度。
num_attention_heads (int) – 注意力头的数量。
num_gqa_groups (int, default = None) –
GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads。
attention_dropout (float, default = 0.0) – softmax 之后 dropout 操作的 dropout 概率。
attn_mask_type (str, default = 'causal') –
此参数指定在 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}
每个选项描述如下
no_mask: 不应用注意力掩码。这意味着注意力将考虑整个序列,没有任何限制。
padding: 指示每个序列末尾存在填充。用户必须在
__call__
方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。
causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。
注意
mask
在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。attn_bias_type (Optional[str], default = None) – 传入注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为 post_scale_bias。否则使用 no_bias。
dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在核心注意力中生成 Dropout 掩码的键名。
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。
layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的值,以提高数值稳定性。
zero_centered_gamma (bool, default = False) –
如果设置为 True,则 LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。
kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。
use_bias (bool, default = False) – 指示是否为 QKV 和输出投影启用偏置移位。如果设置为 False,则该层将不学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影的偏置,仅当
use_bias=True
时使用。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。input_layernorm (bool, default = True) – 如果设置为 False,则不对输入应用层归一化。
return_layernorm_output (bool, default = False) – 如果设置为 True,则层归一化的输出将与线性变换的输出一起从前向传播中返回。示例用例:Transformer 模块的残差连接在层归一化之后获取。
enable_rotary_pos_emb (bool, default = False) – 是否为投影的 query 和 key 启用旋转位置嵌入。
rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – 指示旋转位置嵌入的最小和最大时间尺度,仅当
enable_rotary_pos_emb=True
时使用rotary_pos_emb_group_method (str, default = 'consecutive') – 指示耦合坐标的方法。它应该是 [‘consecutive’, ‘alternate’] 之一。 ‘alternate’ 是将索引 \(i\) 与 \(i + d/2\) 配对,d 是隐藏维度。 ‘consecutive’ 将索引 \(i\) 与 \(i + 1\) 配对。
low_rank_adaptation_scope (str, default = 'none') – 指示应用低秩适配的范围。它应该是 [‘none’, ‘all’, ‘qkv_proj’, ‘output_proj’, ‘exclude_qkv_proj’, ‘exclude_output_proj’] 之一
low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当
enable_low_rank_adaptation=True
时使用low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha。 \(\frac{alpha}{rank} * lora_output\)。 None 表示不缩放。
enable_sequence_parallel (bool, default = False) – 是否对点积以外的操作启用序列并行。
num_heads (int, default = None) – 已弃用。请参考 num_attention_heads。
dropout_rate (float, default = None) – 已弃用。请参考 attention_dropout。
output_layernorm (bool, default = None) – 已弃用。请参考 input_layernorm
apply_residual_connection_post_layernorm (bool, default = None) – 已弃用。请参考 return_layernorm_output。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
fuse_qkv_params (bool, default = True) – 如果设置为 True,则此模块为自注意力公开一个用于 query-key-value 的融合参数,为交叉注意力公开一个用于 key-value 的融合参数。
transpose_batch_sequence (bool, default = True) – 指示输入张量是否已交换批次和序列长度维度的轴。如果设置为 True,则输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则为 \(\frac{Q}{\sqrt{head\_dim}*K}\),否则为 \(Q*K\)
scaled_query_init (bool, default = True) – 是否在初始化时按 \(\frac{1}{\sqrt{head\_dim}}\) 缩放 WQ
float32_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。
fuse_qkv (bool, default = None) – 已弃用。请参考 fuse_qkv_params
window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。默认值是没有滑动窗口。
- __call__(inputs_q: Array, inputs_kv: Array, mask: Array | None = None, bias: Array | None = None, *, decode: bool = False, deterministic: bool = False) Array
MultiHeadAttention 层:[Query, Key, Value 投影] -> 点积注意力 -> 输出投影。
- 参数:
inputs_q (jax.numpy.ndarray) – 用于 query 投影的输入张量。
inputs_kv (jax.numpy.ndarray) – 用于 key/value 投影的输入张量。
mask (jax.numpy.ndarray, default = None) – 用于屏蔽注意力 softmax 输入的布尔张量。
True
表示屏蔽对应的值。当self.attn_mask_type
为 ‘no_mask’ 或 ‘causal’ 时被忽略。bias (jax.numpy.ndarray, default = None) – 用于移动注意力 softmax 输入的张量。
*
decode (bool, default = False) – 指示是否准备和使用自回归缓存。
deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。
- 返回:
outputs – 输出张量。
- 返回类型:
jax.numpy.ndarray
- class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
TransformerLayer 由相对嵌入、注意力块和前馈网络 (MLP) 组成。这个标准层基于论文 “Attention Is All You Need”。
- 参数:
hidden_size (int, default = 512) – 每个输入样本的隐藏大小。
mlp_hidden_size (int, default = 2048) – 输入样本被投影到的中间大小。
num_attention_heads (int, default = 8) – Transformer 层中的注意力头的数量。
num_gqa_groups (int, default = None) –
GQA 组的数量。当为 None 时,它等于 num_attention_heads。分组查询注意力在 本文 中描述。这仅影响键和值,不影响查询。GQA-1 等效于多查询注意力 (MQA),而 GQA-H 等效于 MHA,即 num_gqa_groups = num_attention_heads。
layernorm_type ({'layernorm', 'rmsnorm'}, default = 'layernorm') – 指示层归一化的类型。
layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的值,以提高数值稳定性。
zero_centered_gamma (bool, default = False) –
如果设置为 True,则 LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。
hidden_dropout (float, default = 0.1) – FC2 层之后 dropout 操作的 dropout 概率。
hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同 dropout 掩码的隐藏层维度
attention_dropout (float, default = 0.1) – 多头注意力期间 dropout 操作的 dropout 概率。
intermediate_dropout (float, default = 0.1) – FC1 层之后 dropout 操作的 dropout 概率。
intermediate_dropout_dims (Sequence[int], default = ()) – FC1 层之后将共享相同 dropout 掩码的隐藏层维度。
dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 传入的 RNGs 中,用于在多头注意力中生成 Dropout 掩码的键名。
mha_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。
mlp_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化 FC1 和 FC2 层的权重。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。
mlp_activations (Sequence[str], default = ('relu', )) – 在第一个线性变换之后应用的一系列激活函数。每个激活函数都有自己的变换层。
use_bias (bool, default = False) – 指示是否为 QKVO 投影、FC1 和 FC2 启用偏置移位。如果设置为 False,则该层将不学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影、FC1 和 FC2 的偏置。它仅在
use_bias=True
时使用。它应该是一个可调用对象,带有三个参数 (jax.random.PRNGKey, shape, dtype)。apply_residual_connection_post_layernorm (bool, default = False) – 如果设置为 True,则残差连接取自层归一化的输出(默认取自层归一化的输入)
output_layernorm (bool, default = False) – 如果设置为 True,则在最终 dropout-add 之后,在输出侧应用层归一化。默认行为是在输入侧、QKV 变换之前应用层归一化。
float32_attention_logits (bool, default = False) – 是否在 float32 中计算非融合注意力后端的注意力 logits。对于融合注意力后端,累积始终为 float32,而没有性能开销。
layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – 如果设置为 TransformerLayerType.DECODER,则在自注意力之后添加一个额外的交叉注意力块。这可以与 TransformerLayerType.ENCODER 选项结合使用,用于像 T5 Transformer 这样的结构。
self_attn_mask_type (str, default = 'causal') –
此参数指定在自注意力中的 softmax 操作期间应用的注意力掩码的类型。可用选项为 {‘no_mask’, ‘padding’, ‘causal’, ‘causal_padding’, ‘padding_causal’}
每个选项描述如下
no_mask: 不应用注意力掩码。这意味着自注意力将考虑整个序列,没有任何限制。
padding: 指示每个序列末尾存在填充。用户必须在
__call__
方法中提供形状为 [batch, 1, max_seqlen_q, max_seqlen_kv] 的掩码,以指定填充位置。causal: 将上三角掩码应用于 softmax 输入,确保对某个位置的预测仅依赖于来自其之前位置的已知输出。
causal_padding / padding_causal: 因果掩码和填充掩码的组合。“causal_padding” 和 “padding_causal” 都是可接受的,并且效果相同。
注意
attention_mask
在 ‘no_mask’ 和 ‘causal’ 情况下被忽略。self_attn_bias_type (Optional[str], default = None) – 传入自注意力的注意力偏置类型。可用选项:{‘no_bias’, ‘pre_scale_bias’, ‘post_scale_bias’}。当默认值存在时,类型由 MHA 的偏置参数自动决定。如果有偏置,则为 post_scale_bias。否则使用 no_bias。
enable_relative_embedding (bool, default = True) – 是否启用相对嵌入作为注意力 logits 的移位。
relative_embedding (flax.linen.Module, default = None) – 用于执行相对嵌入的模块,仅当
enable_relative_embedding=True
时使用。 默认为 None,如果enable_relative_embedding=True
,则将创建 RelativePositionBiases 的实例。 默认值:RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, ‘fan_avg’, ‘uniform’), name=’relpos_bias’)enable_rotary_pos_emb (bool, default = False) – 是否在 MHA 中为 projected query 和 key 启用 rotary position embedding。
rotary_pos_emb_windows (Tuple[int, int], default = (1, 10000)) – 指示旋转位置嵌入的最小和最大时间尺度,仅当
enable_rotary_pos_emb=True
时使用rotary_pos_emb_group_method (str, default = 'consecutive') – 指示耦合坐标的方法。 它应该是 [‘consecutive’, ‘alternate’] 之一。 ‘alternate’ 是将索引 \(i\) 与 \(i + d/2\) 配对,其中 \(d\) 是隐藏维度。 ‘consecutive’ 将索引 \(i\) 与 \(i + 1\) 配对。
low_rank_adaptation_scope (str, default = 'none') – 指示应用低秩自适应的范围。 它应该是 [‘none’, ‘all’, ‘qkv_proj’, ‘output_proj’, ‘mlp’, ‘exclude_qkv_proj’, ‘exclude_output_proj’, ‘exclude_mlp’] 之一
low_rank_adaptation_dim (int, default = 32) – 低秩自适应的维度,仅当
enable_low_rank_adaptation=True
时使用low_rank_adaptation_alpha (float, default = None) – 用于计算 LoRA 输出的缩放因子的 alpha 值。 \(\frac{alpha}{rank} * lora\_output\)。 None 表示不进行缩放。
enable_sequence_parallel (bool, default = False) – 是否对点积以外的操作启用序列并行。
window_size (Optional[Tuple[int, int]], default = None) – 滑动窗口大小。 默认值是没有滑动窗口。
- 优化参数:
dtype (jax.numpy.dtype, default = jax.numpy.float32) – 用于分配初始参数的数据类型。
drop_path (float, default = 0.0) – 当 > 0.0 时,在残差块的主路径中对每个样本应用随机深度。
fuse_qkv_params (bool, default = True) – 如果设置为 True,则 TransformerLayer 模块为自注意力机制的 query-key-value 和交叉注意力机制的 key-value 公开一个融合的参数。
transpose_batch_sequence (bool, default = False) – 指示输入张量是否已切换批次和序列长度维度的轴。 如果设置为 True,则输入张量应为 (seqlen, batch, hidden) 格式,否则为 (batch, seqlen, hidden) 格式。
scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。 如果设置为 True,则为 \(\frac{Q}{\sqrt{head_dim}*K}\),否则为 \(Q*K\)
scaled_query_init (bool, default = True) – 是否在初始化时按 \(\sqrt{head_dim}\) 缩放 WQ
- __call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)
Transformer Layer:注意力块和前馈网络 (MLP)
- 参数:
inputs (jax.numpy.ndarray) – 输入张量。
encoded (jax.numpy.ndarray, default = None) – 编码器块的输出张量,如果使用
layer_type=TransformerLayerType.DECODER
,则将其馈送到解码器块。attention_mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。
True
表示屏蔽相应的值。 当self.self_attn_mask_type
为 ‘no_mask’ 或 ‘causal’ 时,将被忽略。encoder_decoder_mask (jax.numpy.ndarray, default = None) – 当
layer_type=TransformerLayerType.DECODER
时,用于屏蔽交叉注意力 softmax 输入的布尔张量。True
表示屏蔽相应的值。deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。
decode (bool, default = False) – 指示是否在多头注意力 (MHA) 中准备和使用自回归缓存。
max_decode_length (bool, default = None) – 当
layer_type=TransformerLayerType.DECODER
且enable_relative_embedding=True
时,生成相对嵌入偏差的最大长度。
- 返回:
outputs – 输出张量。
- 返回类型:
jax.numpy.ndarray
- transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules) LogicalRules
使用预定义的 TransformerLayer 的逻辑轴规则扩展给定的 Flax 逻辑轴规则。
注意
我们目前仅支持单 GPU 训练、数据并行训练和 1D 分片张量并行训练的逻辑轴规则。 有关 1D 分片张量并行性的信息,请参阅 Figure 3 in Megatron-LM tensor parallel。
警告
请确保在调用此函数之前通过 fp8_autocast 设置 ShardingResource。
注意
仅当使用 TransformerLayer 时才需要此功能。 对于其他模块(例如 DenseGeneral),请正确设置内核和偏差的轴。
- 参数:
rules (Sequence[Tuple[str, Union[str, None]]]) – 要扩展的基础 Flax 逻辑轴规则。
- 返回:
extended_rules – 扩展的 Flax 逻辑轴规则。
- 返回类型:
Sequence[Tuple[str, Union[str, None]]]