等变张量积#
- cuequivariance_jax.equivariant_tensor_product(
- e: EquivariantTensorProduct,
- *inputs: RepArray | Array,
- dtype_output: dtype | None = None,
- dtype_math: dtype | None = None,
- precision: Precision = Precision.HIGHEST,
- algorithm: str = 'sliced',
- use_custom_primitive: bool = True,
- use_custom_kernels: bool = False,
计算输入数组的等变张量积。
- 参数:
e (
cue.EquivariantTensorProduct
) – 等变张量积描述符。dtype_output (jnp.dtype, 可选) – 输出数组的数据类型。默认为 None。
dtype_math (jnp.dtype, 可选) – 计算操作的数据类型。默认为 None。
precision (jax.lax.Precision, 可选) – 计算精度。默认为
jax.lax.Precision.HIGHEST
。algorithm (str, 可选) – “sliced”、“stacked”、“compact_stacked”、“indexed_compact”、“indexed_vmap”、“indexed_for_loop” 之一。 默认为 “sliced”。 有关更多信息,请参阅
cuex.tensor_product
。use_custom_primitive (bool, 可选) – 是否使用自定义 JVP 规则。 默认为 True。
use_custom_kernels (bool, 可选) – 是否使用自定义内核。 默认为 True。
- 返回:
等变张量积的结果。
- 返回类型:
示例
让我们为 0、1 和 2 阶的球谐函数创建一个描述符。
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) >>> e EquivariantTensorProduct((1)^(0..2) -> 0+1+2)
我们需要一些输入数据。
>>> with cue.assume(cue.SO3, cue.ir_mul): ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) >>> x {0: 1} [0. 1. 0.]
现在我们可以执行等变张量积。
>>> cuex.equivariant_tensor_product(e, x) {0: 0+1+2} [1. ... ]