SymmetricContraction#
- class cuequivariance_torch.SymmetricContraction(
- irreps_in: Irreps,
- irreps_out: Irreps,
- contraction_degree: int,
- num_elements: int,
- *,
- layout: IrrepsLayout | None = None,
- layout_in: IrrepsLayout | None = None,
- layout_out: IrrepsLayout | None = None,
- device: device | None = None,
- dtype: dtype | None = None,
- math_dtype: dtype | None = None,
- original_mace: bool = False,
- use_fallback: bool | None = None,
对称收缩操作的加速实现,该操作在 https://arxiv.org/abs/2206.07697 中介绍。
- 参数:
irreps_in (Irreps) – 输入 irreps。irreps 内的所有重数 (mul) 必须相同,表示每个 irrep 出现的次数相同。
irreps_out (Irreps) – 输出 irreps。与 irreps_in 类似,所有重数必须相同。
contraction_degree (int) – 对称收缩的度数,指定对称收缩中多项式的最大度数。
num_elements (int) – 权重张量的元素数量。
layout (IrrepsLayout, 可选) – 输入和输出 irreps 的布局。如果未提供,则使用默认布局。
math_dtype (torch.dtype, 可选) – 数学运算的数据类型。如果未指定,则使用来自 torch 环境的默认数据类型。
use_fallback (bool, 可选) – 如果为 None (默认),则在可用时使用 CUDA 内核。 如果为 False,将使用 CUDA 内核,如果不可用,则会引发异常。 如果为 True,则无论 CUDA 内核是否可用,都使用 PyTorch 回退方法。
示例
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o") >>> irreps_out = cue.Irreps("O3", "32x0e") >>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32, device=device)
现在 layer 可以用作 PyTorch 模型的一部分。
参数 original_mace 可以设置为 True 以模拟原始 MACE 实现。
>>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e") >>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o") >>> # OLD FUNCTION DEFINITION: >>> # symmetric_contractions_old = SymmetricContraction( >>> # irreps_in=feats_irreps, >>> # irreps_out=target_irreps, >>> # correlation=3, >>> # num_elements=10, >>> # ) >>> # NEW FUNCTION DEFINITION: >>> symmetric_contractions_new = cuet.SymmetricContraction( ... irreps_in=feats_irreps, ... irreps_out=target_irreps, ... contraction_degree=3, ... num_elements=10, ... layout_in=cue.ir_mul, ... layout_out=cue.mul_ir, ... original_mace=True, ... dtype=torch.float64, ... device=device, ... )
然后执行如下
>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device) >>> # with node_attrs_index being the index version of node_attrs, sth like: >>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int() >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device) >>> # OLD CALL: >>> # symmetric_contractions_old(node_feats, node_attrs) >>> # NEW CALL: >>> node_feats = torch.transpose(node_feats, 1, 2).flatten(1) >>> symmetric_contractions_new(node_feats, node_attrs_index) tensor([[...)
注意
术语“mul”指的是 irrep 的重数,表示它在表示中出现的次数。 此层要求所有输入和输出 irreps 具有相同的重数,以便对称收缩操作能够良好定义。
前向传播
- forward( ) Tensor #
执行对称收缩操作的前向传播。
- 参数:
x (torch.Tensor) – 输入张量。它应该具有形状 (batch, irreps_in.dim)。
indices (torch.Tensor) – 用于每个批次元素的权重的索引。它应该具有形状 (batch, )。
- 返回值:
输出张量。它具有形状 (batch, irreps_out.dim)。
- 返回类型: