Transformer Engine 文档

Transformer Engine (TE) 是一个用于加速 NVIDIA GPU 上 Transformer 模型的库,包括在 Hopper、Ada 和 Blackwell GPU 上使用 8 位浮点 (FP8) 精度,从而在训练和推理中提供更好的性能和更低的内存利用率。TE 为流行的 Transformer 架构提供了一系列高度优化的构建块,以及一个自动混合精度类似的 API,可以与您的框架特定代码无缝使用。TE 还包括一个框架无关的 C++ API,可以与其他深度学习库集成,以实现对 Transformer 的 FP8 支持。

随着 Transformer 模型中参数数量的持续增长,BERT、GPT 和 T5 等架构的训练和推理变得非常消耗内存和计算资源。大多数深度学习框架默认使用 FP32 进行训练。然而,对于许多深度学习模型来说,这对于实现完全精度并非必要。使用混合精度训练,在训练模型时将单精度 (FP32) 与较低精度(例如 FP16)格式相结合,与 FP32 训练相比,可以在精度差异最小的情况下显著加速。随着 Hopper GPU 架构引入了 FP8 精度,与 FP16 相比,它在不降低精度的情况下提供了更高的性能。虽然所有主要的深度学习框架都支持 FP16,但 FP8 支持在当今的框架中尚不可用。

TE 通过提供与流行的 LLM(大型语言模型)库集成的 API 来解决 FP8 支持问题。它提供了一个 Python API,其中包含易于构建 Transformer 层的模块,以及一个框架无关的 C++ 库,其中包括 FP8 支持所需的结构和内核。TE 提供的模块在内部维护 FP8 训练所需的缩放因子和其他值,大大简化了用户的混合精度训练。

亮点

  • 易于使用的模块,用于构建具有 FP8 支持的 Transformer 层

  • Transformer 模型的优化(例如,融合内核)

  • 在 NVIDIA Hopper、Ada 和 Blackwell GPU 上支持 FP8

  • 在 NVIDIA Ampere GPU 架构及更高版本上,支持跨所有精度(FP16、BF16)的优化

示例

PyTorch

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()

JAX

Flax

import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
      out = model.apply({'params':params, **other_vars}, inp)
      return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')

    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

    for _ in range(10):
      loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)