分段张量积#

在此示例中,我们将展示如何创建一个自定义张量积描述符并执行它。首先我们需要导入必要的模块。

import itertools
import torch
import jax
import jax.numpy as jnp

import cuequivariance as cue
import cuequivariance.segmented_tensor_product as stp
import cuequivariance_torch as cuet # to execute the tensor product with PyTorch
import cuequivariance_jax as cuex   # to execute the tensor product with JAX

现在,我们将创建一个自定义张量积描述符,它表示两个表示的张量积。有关 irreps 的更多信息,请参阅群和表示

irreps1 = cue.Irreps("O3", "32x0e + 32x1o")
irreps2 = cue.Irreps("O3", "16x0e + 48x1o")

张量积描述符是逐步创建的。首先,我们根据其下标创建一个空描述符。在线性层的情况下,我们有 3 个操作数:权重、输入和输出。此张量积的下标为 “uv,iu,iv”,其中 “uv” 表示权重的模式,“iu” 表示输入的模式,“iv” 表示输出的模式。

d = stp.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
d
uv,iu,iv sizes=0,0,0 num_segments=0,0,0 num_paths=0

张量积描述符的每个操作数都有一个段列表。我们可以使用 add_segment 方法将段添加到描述符。我们可以将输入和输出表示的段添加到描述符。

for mul, ir in irreps1:
   d.add_segment(1, (ir.dim, mul))
for mul, ir in irreps2:
   d.add_segment(2, (ir.dim, mul))

d
uv,iu,iv sizes=0,128,160 num_segments=0,2,2 num_paths=0 i={1, 3} u=32 v={16, 48}

现在我们可以枚举所有可能的 irreps 对,并在 irreps 相同时在它们之间添加权重段和路径。

for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product(
   enumerate(irreps1), enumerate(irreps2)
):
   if ir1 == ir2:
      d.add_path(None, i1, i2, c=1.0)

d
uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48}

我们可以看到我们添加的两条路径

d.paths
[op0[0]*op1[0]*op2[0]*1., op0[1]*op1[1]*op2[1]*1.]

最后,我们可以对最后一个操作数的路径进行归一化,以使输出归一化为方差 1。

d = d.normalize_paths_for_operand(-1)
d.paths
[op0[0]*op1[0]*op2[0]*0.18, op0[1]*op1[1]*op2[1]*0.18]

正如我们所见,路径系数已被归一化。

现在我们可以从描述符创建一个张量积并执行它。在 PyTorch 中,我们可以使用 cuet.TensorProduct 类。

linear_torch = cuet.TensorProduct(d, use_fallback=True)
linear_torch
TensorProduct(uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48} (without CUDA kernel))

在 JAX 中,我们可以使用 cuex.tensor_product 函数。

linear_jax = cuex.tensor_product(d)
linear_jax
<function cuequivariance_jax.primitives.tensor_product.tensor_product.<locals>._partial(*remaining_inputs: jax.Array) -> jax.Array>

现在我们可以使用随机输入和权重张量执行线性层。

w = torch.randn(1, d.operands[0].size)
x1 = torch.randn(3000, irreps1.dim)

x2 = linear_torch(w, x1)

assert x2.shape == (3000, irreps2.dim)

现在我们可以验证输出已正确归一化。

x2.var()
tensor(1.0330)

最后是 JAX 版本。

w = jax.random.normal(jax.random.key(0), (d.operands[0].size,))
x1 = jax.random.normal(jax.random.key(1), (3000, irreps1.dim))

x2 = linear_jax(w, x1)

assert x2.shape == (3000, irreps2.dim)
x2.var()
Array(1.0266904, dtype=float32)