SphericalHarmonics#

class cuequivariance_torch.SphericalHarmonics(
ls: list[int],
normalize: bool = True,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)#

计算作为torch模块的输入向量的球面调和函数。

前向传播

forward(vectors: Tensor) Tensor#
参数:

vectors (torch.Tensor) – 形状为 (batch, 3) 的输入向量。

返回:

输入向量的球面调和函数,形状为 (batch, dim),其中 dim 是 ls 中所有 l 的 2*l+1 的总和。

返回类型:

torch.Tensor