NVIDIA cuEquivariance 文档#

cuEquivariance 是一个 Python 库,旨在帮助构建使用分段张量积的高性能等变神经网络。cuEquivariance 提供了一个全面的 API,用于描述分段张量积和优化的 CUDA 内核以执行它们。此外,cuEquivariance 为 PyTorch 和 JAX 都提供了绑定,确保了广泛的兼容性和易于集成。

等变性是对“尊重对称性”概念的数学形式化。稳健的物理模型在三维空间中表现出对旋转和平移的等变性。结合等变性的人工智能模型通常具有更高的数据效率。

关于群表示的介绍可以在页面 群和表示 中找到。

开源#

cuEquivariance 前端是开源的,并在 GitHub 上以 Apache 2.0 许可证提供。

安装#

安装 cuEquivariance 最简单的方法是从 PyPi 使用 pip

# Choose the frontend you want to use
pip install cuequivariance-jax
pip install cuequivariance-torch
pip install cuequivariance  # Installs only the core non-ML components

# CUDA kernels for different CUDA versions
pip install cuequivariance-ops-torch-cu11
pip install cuequivariance-ops-torch-cu12

要求#

cuequivariance-ops-torch-* 软件包仅适用于 Linux x86_64,并且需要 PyTorch 2.4.0 或更高版本。

组织结构#

cuEquivariance 分为三个软件包

import cuequivariance as cue
# All the non-ML components

import cuequivariance_jax as cuex
# For the JAX implementations

import cuequivariance_torch as cuet
# For the PyTorch implementations
Main components of cuEquivariance

大多数张量积是使用 cue.EquivariantTensorProduct 类定义的,该类封装了每个输入和输出的 cue.Irrepscue.IrrepsLayout。它还包括一个或多个 cue.SegmentedTensorProduct 实例,这些实例定义了张量积运算。然后,此描述符用于创建一个 cuet.EquivariantTensorProduct 模块,该模块可以在 PyTorch 模型中使用。或者用于在 JAX 中使用 cuex.equivariant_tensor_product 执行张量积运算。

教程#

API 参考#

最新动态#