Rotation#

class cuequivariance_torch.Rotation(
irreps: Irreps,
*,
layout: IrrepsLayout | None = None,
layout_in: IrrepsLayout | None = None,
layout_out: IrrepsLayout | None = None,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)#

表示 SO3 或 O3 表示的旋转层的类。

Parameters:
  • irreps (Irreps) – 要旋转的张量的不可约表示。

  • layout (IrrepsLayout, optional) – 张量的内存布局,首选 cue.ir_mul

Forward Pass

forward(
gamma: Tensor,
beta: Tensor,
alpha: Tensor,
x: Tensor,
) Tensor#

旋转层的前向传播。

Parameters:
  • gamma (torch.Tensor) – Gamma 角。 绕 y 轴的第一次旋转。

  • beta (torch.Tensor) – Beta 角。 绕 x 轴的第二次旋转。

  • alpha (torch.Tensor) – Alpha 角。 绕 y 轴的第三次旋转。

  • x (torch.Tensor) – 输入张量。

Returns:

旋转后的张量。

Return type:

torch.Tensor