等变张量积#

cuequivariance_jax.equivariant_tensor_product(
e: EquivariantTensorProduct,
*inputs: RepArray | Array,
dtype_output: dtype | None = None,
dtype_math: dtype | None = None,
precision: Precision = Precision.HIGHEST,
algorithm: str = 'sliced',
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
) RepArray#

计算输入数组的等变张量积。

参数:
  • e (cue.EquivariantTensorProduct) – 等变张量积描述符。

  • *inputs (RepArrayjax.Array) – 输入数组。

  • dtype_output (jnp.dtype, 可选) – 输出数组的数据类型。默认为 None。

  • dtype_math (jnp.dtype, 可选) – 计算操作的数据类型。默认为 None。

  • precision (jax.lax.Precision, 可选) – 计算精度。默认为 jax.lax.Precision.HIGHEST

  • algorithm (str, 可选) – “sliced”、“stacked”、“compact_stacked”、“indexed_compact”、“indexed_vmap”、“indexed_for_loop” 之一。 默认为 “sliced”。 有关更多信息,请参阅 cuex.tensor_product

  • use_custom_primitive (bool, 可选) – 是否使用自定义 JVP 规则。 默认为 True。

  • use_custom_kernels (bool, 可选) – 是否使用自定义内核。 默认为 True。

返回:

等变张量积的结果。

返回类型:

RepArray

示例

让我们为 0、1 和 2 阶的球谐函数创建一个描述符。

>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
>>> e
EquivariantTensorProduct((1)^(0..2) -> 0+1+2)

我们需要一些输入数据。

>>> with cue.assume(cue.SO3, cue.ir_mul):
...    x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> x
{0: 1} [0. 1. 0.]

现在我们可以执行等变张量积。

>>> cuex.equivariant_tensor_product(e, x)
{0: 0+1+2}
[1. ... ]