tensor_product#

cuequivariance_jax.tensor_product(
d: SegmentedTensorProduct,
*inputs: Array,
dtype_output: dtype | None = None,
dtype_math: dtype | None = None,
precision: Precision = Precision.HIGHEST,
algorithm: str = 'sliced',
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
) Array#

计算 SegmentedTensorProduct 的最后一个操作数。

参数:
  • d (SegmentedTensorProduct) – 操作的描述符。

  • *inputs (jax.Array) – 除了最后一个操作数之外,每个操作数的输入数组。

  • dtype_output (jnp.dtype, 可选) – 输出的数据类型。

  • dtype_math (jnp.dtype, 可选) – 数学运算的数据类型。

  • precision (jax.lax.Precision, 可选) – 计算的精度。默认为 jax.lax.Precision.HIGHEST

  • algorithm (str, 可选) – 用于计算的算法。默认为 “sliced”。 请参阅下表了解可用的算法。

  • use_custom_primitive (bool, 可选) – 是否使用自定义 JVP/转置规则。

  • use_custom_kernels (bool, 可选) – 是否使用自定义内核。

返回:

张量积的结果。SegmentedTensorProduct 的最后一个操作数。

返回类型:

jax.Array

张量积的可用算法#

算法

需要相同的分段

编译时间

执行时间

sliced

几分钟

取决于情况

stacked

几分钟

取决于情况

compact_stacked

几秒

取决于情况

indexed_compact

几秒

取决于情况

indexed_vmap

几秒

可能第二慢

indexed_for_loop

几秒

可能最慢