tensor_product#
- cuequivariance_jax.tensor_product(
- d: SegmentedTensorProduct,
- *inputs: Array,
- dtype_output: dtype | None = None,
- dtype_math: dtype | None = None,
- precision: Precision = Precision.HIGHEST,
- algorithm: str = 'sliced',
- use_custom_primitive: bool = True,
- use_custom_kernels: bool = False,
计算 SegmentedTensorProduct 的最后一个操作数。
- 参数:
d (SegmentedTensorProduct) – 操作的描述符。
*inputs (jax.Array) – 除了最后一个操作数之外,每个操作数的输入数组。
dtype_output (jnp.dtype, 可选) – 输出的数据类型。
dtype_math (jnp.dtype, 可选) – 数学运算的数据类型。
precision (jax.lax.Precision, 可选) – 计算的精度。默认为
jax.lax.Precision.HIGHEST
。algorithm (str, 可选) – 用于计算的算法。默认为 “sliced”。 请参阅下表了解可用的算法。
use_custom_primitive (bool, 可选) – 是否使用自定义 JVP/转置规则。
use_custom_kernels (bool, 可选) – 是否使用自定义内核。
- 返回:
张量积的结果。SegmentedTensorProduct 的最后一个操作数。
- 返回类型:
张量积的可用算法# 算法
需要相同的分段
编译时间
执行时间
sliced
否
几分钟
取决于情况
stacked
是
几分钟
取决于情况
compact_stacked
是
几秒
取决于情况
indexed_compact
是
几秒
取决于情况
indexed_vmap
是
几秒
可能第二慢
indexed_for_loop
是
几秒
可能最慢