NVIDIA cuEquivariance 文档#
cuEquivariance 是一个 Python 库,旨在帮助构建使用分段张量积的高性能等变神经网络。cuEquivariance 提供了一个全面的 API,用于描述分段张量积和优化的 CUDA 内核以执行它们。此外,cuEquivariance 为 PyTorch 和 JAX 都提供了绑定,确保了广泛的兼容性和易于集成。
等变性是对“尊重对称性”概念的数学形式化。稳健的物理模型在三维空间中表现出对旋转和平移的等变性。结合等变性的人工智能模型通常具有更高的数据效率。
关于群表示的介绍可以在页面 群和表示 中找到。
开源#
cuEquivariance 前端是开源的,并在 GitHub 上以 Apache 2.0 许可证提供。
安装#
安装 cuEquivariance 最简单的方法是从 PyPi 使用 pip。
# Choose the frontend you want to use
pip install cuequivariance-jax
pip install cuequivariance-torch
pip install cuequivariance # Installs only the core non-ML components
# CUDA kernels for different CUDA versions
pip install cuequivariance-ops-torch-cu11
pip install cuequivariance-ops-torch-cu12
要求#
cuequivariance-ops-torch-*
软件包仅适用于 Linux x86_64,并且需要 PyTorch 2.4.0 或更高版本。
组织结构#
cuEquivariance 分为三个软件包
import cuequivariance as cue
# All the non-ML components
import cuequivariance_jax as cuex
# For the JAX implementations
import cuequivariance_torch as cuet
# For the PyTorch implementations

大多数张量积是使用 cue.EquivariantTensorProduct
类定义的,该类封装了每个输入和输出的 cue.Irreps
和 cue.IrrepsLayout
。它还包括一个或多个 cue.SegmentedTensorProduct
实例,这些实例定义了张量积运算。然后,此描述符用于创建一个 cuet.EquivariantTensorProduct
模块,该模块可以在 PyTorch 模型中使用。或者用于在 JAX 中使用 cuex.equivariant_tensor_product
执行张量积运算。