分段张量积#
在此示例中,我们将展示如何创建一个自定义张量积描述符并执行它。首先我们需要导入必要的模块。
import itertools
import torch
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance.segmented_tensor_product as stp
import cuequivariance_torch as cuet # to execute the tensor product with PyTorch
import cuequivariance_jax as cuex # to execute the tensor product with JAX
现在,我们将创建一个自定义张量积描述符,它表示两个表示的张量积。有关 irreps 的更多信息,请参阅群和表示。
irreps1 = cue.Irreps("O3", "32x0e + 32x1o")
irreps2 = cue.Irreps("O3", "16x0e + 48x1o")
张量积描述符是逐步创建的。首先,我们根据其下标创建一个空描述符。在线性层的情况下,我们有 3 个操作数:权重、输入和输出。此张量积的下标为 “uv,iu,iv”,其中 “uv” 表示权重的模式,“iu” 表示输入的模式,“iv” 表示输出的模式。
d = stp.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
d
uv,iu,iv sizes=0,0,0 num_segments=0,0,0 num_paths=0
张量积描述符的每个操作数都有一个段列表。我们可以使用 add_segment
方法将段添加到描述符。我们可以将输入和输出表示的段添加到描述符。
for mul, ir in irreps1:
d.add_segment(1, (ir.dim, mul))
for mul, ir in irreps2:
d.add_segment(2, (ir.dim, mul))
d
uv,iu,iv sizes=0,128,160 num_segments=0,2,2 num_paths=0 i={1, 3} u=32 v={16, 48}
现在我们可以枚举所有可能的 irreps 对,并在 irreps 相同时在它们之间添加权重段和路径。
for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product(
enumerate(irreps1), enumerate(irreps2)
):
if ir1 == ir2:
d.add_path(None, i1, i2, c=1.0)
d
uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48}
我们可以看到我们添加的两条路径
d.paths
[op0[0]*op1[0]*op2[0]*1., op0[1]*op1[1]*op2[1]*1.]
最后,我们可以对最后一个操作数的路径进行归一化,以使输出归一化为方差 1。
d = d.normalize_paths_for_operand(-1)
d.paths
[op0[0]*op1[0]*op2[0]*0.18, op0[1]*op1[1]*op2[1]*0.18]
正如我们所见,路径系数已被归一化。
现在我们可以从描述符创建一个张量积并执行它。在 PyTorch 中,我们可以使用 cuet.TensorProduct
类。
linear_torch = cuet.TensorProduct(d, use_fallback=True)
linear_torch
TensorProduct(uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48} (without CUDA kernel))
在 JAX 中,我们可以使用 cuex.tensor_product
函数。
linear_jax = cuex.tensor_product(d)
linear_jax
<function cuequivariance_jax.primitives.tensor_product.tensor_product.<locals>._partial(*remaining_inputs: jax.Array) -> jax.Array>
现在我们可以使用随机输入和权重张量执行线性层。
w = torch.randn(1, d.operands[0].size)
x1 = torch.randn(3000, irreps1.dim)
x2 = linear_torch(w, x1)
assert x2.shape == (3000, irreps2.dim)
现在我们可以验证输出已正确归一化。
x2.var()
tensor(1.0330)
最后是 JAX 版本。
w = jax.random.normal(jax.random.key(0), (d.operands[0].size,))
x1 = jax.random.normal(jax.random.key(1), (3000, irreps1.dim))
x2 = linear_jax(w, x1)
assert x2.shape == (3000, irreps2.dim)
x2.var()
Array(1.0266904, dtype=float32)