等变张量积#
子模块 cuequivariance.descriptors
包含由类 cuequivariance.EquivariantTensorProduct
表示的等变张量积的许多描述符。
示例#
线性层#
import cuequivariance as cue
cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
EquivariantTensorProduct(2048x0e x 32x0e+32x1o -> 16x0e+48x1o)
在此示例中,第一个操作数是权重,它们始终是标量。 有 32 * 16 = 512
个权重将 0e
连接在一起,以及 32 * 48 = 1536
个权重将 1o
连接在一起。 这总共给出 2048
个权重。
球谐函数#
cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3)
球谐函数是输入向量的多项式。 此描述符指定 0、1、2 和 3 次多项式。
旋转#
cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
EquivariantTensorProduct(3x0e x 3x0e x 3x0e x 32x0e+32x1o -> 32x0e+32x1o)
这种情况有点特殊,它是输入按角度旋转,角度编码为 \(sin(\theta)\) 和 \(cos(\theta)\)。 有关更多详细信息,请参阅函数 cuet.encode_rotation_angle
。
在 JAX 上执行#
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex
e = cue.descriptors.linear(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "8x0e + 4x1o")
)
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])
cuex.equivariant_tensor_product(e, w, x)
{0: 8x0e+4x1o}
[-0.56710666 0.29934764 1.438811 -1.0761446 -0.16420852 0.6024779
-1.7548201 0.3891445 0.03765802 -0.03795518 -0.750764 -3.2584484
0.6283557 0.09663387 -0.42426407 -0.8612407 0.4686108 -0.9862214
-0.31201616 0.8071706 ]
函数 cuex.randn
生成随机 cuex.RepArray
对象。 函数 cuex.equivariant_tensor_product
执行张量积。 输出是一个 cuex.RepArray
对象。
在 PyTorch 上执行#
我们可以使用 PyTorch 执行 cuequivariance.EquivariantTensorProduct
。
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
e = cue.descriptors.linear(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "8x0e + 4x1o")
)
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)
module(w, x)
tensor([[-0.5272, 0.3613, 0.9547, 1.7503, 0.2967, 0.2417, 1.0393, 1.4447,
0.0804, 0.5824, -1.9556, 1.4756, -0.1092, -0.5326, 1.3033, -2.3201,
-0.3170, 0.9616, -0.6593, 1.0125]])
请注意,您必须指定布局。 如果指定的布局与描述符中的布局不同,则模块将转置输入/输出以匹配布局。