SymmetricContraction#

class cuequivariance_torch.SymmetricContraction(
irreps_in: Irreps,
irreps_out: Irreps,
contraction_degree: int,
num_elements: int,
*,
layout: IrrepsLayout | None = None,
layout_in: IrrepsLayout | None = None,
layout_out: IrrepsLayout | None = None,
device: device | None = None,
dtype: dtype | None = None,
math_dtype: dtype | None = None,
original_mace: bool = False,
use_fallback: bool | None = None,
)#

对称收缩操作的加速实现,该操作在 https://arxiv.org/abs/2206.07697 中介绍。

参数:
  • irreps_in (Irreps) – 输入 irreps。irreps 内的所有重数 (mul) 必须相同,表示每个 irrep 出现的次数相同。

  • irreps_out (Irreps) – 输出 irreps。与 irreps_in 类似,所有重数必须相同。

  • contraction_degree (int) – 对称收缩的度数,指定对称收缩中多项式的最大度数。

  • num_elements (int) – 权重张量的元素数量。

  • layout (IrrepsLayout, 可选) – 输入和输出 irreps 的布局。如果未提供,则使用默认布局。

  • math_dtype (torch.dtype, 可选) – 数学运算的数据类型。如果未指定,则使用来自 torch 环境的默认数据类型。

  • use_fallback (bool, 可选) – 如果为 None (默认),则在可用时使用 CUDA 内核。 如果为 False,将使用 CUDA 内核,如果不可用,则会引发异常。 如果为 True,则无论 CUDA 内核是否可用,都使用 PyTorch 回退方法。

示例

>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
>>> irreps_out = cue.Irreps("O3", "32x0e")
>>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32, device=device)

现在 layer 可以用作 PyTorch 模型的一部分。

参数 original_mace 可以设置为 True 以模拟原始 MACE 实现。

>>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
>>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o")
>>> # OLD FUNCTION DEFINITION:
>>> # symmetric_contractions_old = SymmetricContraction(
>>> #     irreps_in=feats_irreps,
>>> #     irreps_out=target_irreps,
>>> #     correlation=3,
>>> #     num_elements=10,
>>> # )
>>> # NEW FUNCTION DEFINITION:
>>> symmetric_contractions_new = cuet.SymmetricContraction(
...     irreps_in=feats_irreps,
...     irreps_out=target_irreps,
...     contraction_degree=3,
...     num_elements=10,
...     layout_in=cue.ir_mul,
...     layout_out=cue.mul_ir,
...     original_mace=True,
...     dtype=torch.float64,
...     device=device,
... )

然后执行如下

>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device)
>>> # with node_attrs_index being the index version of node_attrs, sth like:
>>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()
>>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device)
>>> # OLD CALL:
>>> # symmetric_contractions_old(node_feats, node_attrs)
>>> # NEW CALL:
>>> node_feats = torch.transpose(node_feats, 1, 2).flatten(1)
>>> symmetric_contractions_new(node_feats, node_attrs_index)
tensor([[...)

注意

术语“mul”指的是 irrep 的重数,表示它在表示中出现的次数。 此层要求所有输入和输出 irreps 具有相同的重数,以便对称收缩操作能够良好定义。

前向传播

forward(
x: Tensor,
indices: Tensor,
) Tensor#

执行对称收缩操作的前向传播。

参数:
  • x (torch.Tensor) – 输入张量。它应该具有形状 (batch, irreps_in.dim)。

  • indices (torch.Tensor) – 用于每个批次元素的权重的索引。它应该具有形状 (batch, )。

返回值:

输出张量。它具有形状 (batch, irreps_out.dim)。

返回类型:

torch.Tensor