等变张量积#

子模块 cuequivariance.descriptors 包含由类 cuequivariance.EquivariantTensorProduct 表示的等变张量积的许多描述符。

示例#

线性层#

import cuequivariance as cue

cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
EquivariantTensorProduct(2048x0e x 32x0e+32x1o -> 16x0e+48x1o)

在此示例中,第一个操作数是权重,它们始终是标量。 有 32 * 16 = 512 个权重将 0e 连接在一起,以及 32 * 48 = 1536 个权重将 1o 连接在一起。 这总共给出 2048 个权重。

球谐函数#

cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3)

球谐函数是输入向量的多项式。 此描述符指定 0、1、2 和 3 次多项式。

旋转#

cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
EquivariantTensorProduct(3x0e x 3x0e x 3x0e x 32x0e+32x1o -> 32x0e+32x1o)

这种情况有点特殊,它是输入按角度旋转,角度编码为 \(sin(\theta)\)\(cos(\theta)\)。 有关更多详细信息,请参阅函数 cuet.encode_rotation_angle

在 JAX 上执行#

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])

cuex.equivariant_tensor_product(e, w, x)
{0: 8x0e+4x1o}
[-0.56710666  0.29934764  1.438811   -1.0761446  -0.16420852  0.6024779
 -1.7548201   0.3891445   0.03765802 -0.03795518 -0.750764   -3.2584484
  0.6283557   0.09663387 -0.42426407 -0.8612407   0.4686108  -0.9862214
 -0.31201616  0.8071706 ]

函数 cuex.randn 生成随机 cuex.RepArray 对象。 函数 cuex.equivariant_tensor_product 执行张量积。 输出是一个 cuex.RepArray 对象。

在 PyTorch 上执行#

我们可以使用 PyTorch 执行 cuequivariance.EquivariantTensorProduct

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)

w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)

module(w, x)
tensor([[-0.5272,  0.3613,  0.9547,  1.7503,  0.2967,  0.2417,  1.0393,  1.4447,
          0.0804,  0.5824, -1.9556,  1.4756, -0.1092, -0.5326,  1.3033, -2.3201,
         -0.3170,  0.9616, -0.6593,  1.0125]])

请注意,您必须指定布局。 如果指定的布局与描述符中的布局不同,则模块将转置输入/输出以匹配布局。