EquivariantTensorProduct#

class cuequivariance_torch.EquivariantTensorProduct(
e: EquivariantTensorProduct,
*,
layout: IrrepsLayout | None = None,
layout_in: IrrepsLayout | tuple[IrrepsLayout | None, ...] | None = None,
layout_out: IrrepsLayout | None = None,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)#

等变张量积。

参数:
  • e (cuequivariance.EquivariantTensorProduct) – 等变张量积。

  • layout (IrrepsLayout) – 输入和输出的布局。

  • layout_in (IrrepsLayout) – 输入的布局。

  • layout_out (IrrepsLayout) – 输出的布局。

  • device (torch.device) – 模块的设备。

  • math_dtype (torch.dtype) – 内部计算的数据类型。

  • use_fallback (bool, 可选) – 确定计算方法。 如果为 None (默认),则在可用时使用 CUDA 内核。 如果为 False,将使用 CUDA 内核,如果不可用则会引发异常。 如果为 True,则无论 CUDA 内核是否可用,都使用 PyTorch 后备方法。

引发:

RuntimeError – 如果 use_fallbackFalse 且没有 CUDA 内核可用。

示例

>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> e = cue.descriptors.fully_connected_tensor_product(
...    cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
... )
>>> w = torch.ones(1, e.inputs[0].dim, device=device)
>>> x1 = torch.ones(17, e.inputs[1].dim, device=device)
>>> x2 = torch.ones(17, e.inputs[2].dim, device=device)
>>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
>>> tp(w, x1, x2)
tensor([[0., 0., 0., 0., 0., 0.],...)

您可以选择性地索引第一个输入张量

>>> w = torch.ones(3, e.inputs[0].dim, device=device)
>>> indices = torch.randint(3, (17,))
>>> tp(w, x1, x2, indices=indices)
tensor([[0., 0., 0., 0., 0., 0.],...)

前向传播

forward(
x0: Tensor,
x1: Tensor | None = None,
x2: Tensor | None = None,
x3: Tensor | None = None,
indices: Tensor | None = None,
) Tensor#

如果 indices 不是 None,则第一个输入将通过 indices 进行索引。