EquivariantTensorProduct#
- class cuequivariance_torch.EquivariantTensorProduct(
- e: EquivariantTensorProduct,
- *,
- layout: IrrepsLayout | None = None,
- layout_in: IrrepsLayout | tuple[IrrepsLayout | None, ...] | None = None,
- layout_out: IrrepsLayout | None = None,
- device: device | None = None,
- math_dtype: dtype | None = None,
- use_fallback: bool | None = None,
等变张量积。
- 参数:
e (cuequivariance.EquivariantTensorProduct) – 等变张量积。
layout (IrrepsLayout) – 输入和输出的布局。
layout_in (IrrepsLayout) – 输入的布局。
layout_out (IrrepsLayout) – 输出的布局。
device (torch.device) – 模块的设备。
math_dtype (torch.dtype) – 内部计算的数据类型。
use_fallback (bool, 可选) – 确定计算方法。 如果为 None (默认),则在可用时使用 CUDA 内核。 如果为 False,将使用 CUDA 内核,如果不可用则会引发异常。 如果为 True,则无论 CUDA 内核是否可用,都使用 PyTorch 后备方法。
- 引发:
RuntimeError – 如果 use_fallback 为 False 且没有 CUDA 内核可用。
示例
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... ) >>> w = torch.ones(1, e.inputs[0].dim, device=device) >>> x1 = torch.ones(17, e.inputs[1].dim, device=device) >>> x2 = torch.ones(17, e.inputs[2].dim, device=device) >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) >>> tp(w, x1, x2) tensor([[0., 0., 0., 0., 0., 0.],...)
您可以选择性地索引第一个输入张量
>>> w = torch.ones(3, e.inputs[0].dim, device=device) >>> indices = torch.randint(3, (17,)) >>> tp(w, x1, x2, indices=indices) tensor([[0., 0., 0., 0., 0., 0.],...)
前向传播