安装

前提条件

  1. Linux x86_64

  2. CUDA 12.1+ (12.8+ 支持 Blackwell)

  3. NVIDIA 驱动程序 支持 CUDA 12.1 或更高版本。

  4. cuDNN 9.3 或更高版本。

如果 CUDA Toolkit 头文件在运行时在标准安装路径(例如 CUDA_HOME 内)中不可用,请在环境中设置 NVTE_CUDA_INCLUDE_PATH

NGC 容器中的 Transformer Engine

Transformer Engine 库已预装在 NVIDIA GPU Cloud ( NVIDIA GPU Cloud ) 版本 22.09 及更高版本的 PyTorch 容器中。

pip - 来自 PyPI

Transformer Engine 可以直接从 我们的 PyPI 安装,例如:

pip install transformer_engine[pytorch]

要获取 Transformer Engine 所需的 Python 绑定,必须在逗号分隔的列表中将所需的框架显式指定为额外的依赖项(例如 [jax,pytorch])。Transformer Engine 为核心库提供 wheels 包。源代码分发包用于 JAX 和 PyTorch 扩展。

pip - 来自 GitHub

其他前提条件

  1. [对于 PyTorch 支持] 支持 GPU 的 PyTorch

  2. [对于 JAX 支持] 支持 GPU 的 JAX,版本 >= 0.4.7。

安装 (稳定版本)

执行以下命令安装最新稳定版本的 Transformer Engine

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

这将自动检测是否安装了任何受支持的深度学习框架,并为其构建 Transformer Engine 支持。要显式指定框架,请将环境变量 NVTE_FRAMEWORK 设置为逗号分隔的列表(例如 NVTE_FRAMEWORK=jax,pytorch)。

安装 (开发构建版本)

警告

虽然 Transformer Engine 的开发构建版本可能包含官方版本尚未提供的新功能,但它不受支持,因此不建议将其用于通用用途。

执行以下命令安装最新开发构建版本的 Transformer Engine

pip install git+https://github.com/NVIDIA/TransformerEngine.git@main

这将自动检测是否安装了任何受支持的深度学习框架,并为其构建 Transformer Engine 支持。要显式指定框架,请将环境变量 NVTE_FRAMEWORK 设置为逗号分隔的列表(例如 NVTE_FRAMEWORK=jax,pytorch)。要仅构建框架无关的 C++ API,请设置 NVTE_FRAMEWORK=none

为了安装特定的 PR,请执行(将 NNN 更改为 PR 编号后)

pip install git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge

安装 (从源代码)

执行以下命令从源代码安装 Transformer Engine

# Clone repository, checkout stable branch, clone submodules
git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git

cd TransformerEngine
export NVTE_FRAMEWORK=pytorch   # Optionally set framework
pip install .                   # Build and install

如果 Git 仓库已经克隆,请确保同时克隆子模块

git submodule update --init --recursive

可以通过设置 “test” 选项安装用于测试的额外依赖项

pip install .[test]

要使用调试符号构建 C++ 扩展,例如使用 -g 标志

pip install . --global-option=--debug