Linear#
- class cuequivariance_jax.flax_linen.Linear(irreps_out: ~cuequivariance.irreps_array.irreps.Irreps | str, layout: ~cuequivariance.irreps_array.irreps_layout.IrrepsLayout | None = None, force: bool = False, kernel_init: ~jax.nn.initializers.Initializer | ~collections.abc.Callable[[...], ~typing.Any] = <function normal>, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)#
等变线性层。
- 参数:
irreps_out (Irreps) – 输出不可约表示。(输入不可约表示从输入推断。)
layout (IrrepsLayout) – 输出不可约表示的布局。
force (bool) – 如果为 False,则输出不可约表示将被过滤,仅包含来自输入的可达不可约表示。
- kernel_init(
- shape: ~collections.abc.Sequence[int] = (),
- dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType = <class 'float'>,
使用给定的形状和浮点 dtype 采样标准正态随机值。
这些值根据概率密度函数返回
\[f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}\]在域 \(-\infty < x < \infty\) 上
- 参数:
key – 用作随机密钥的 PRNG 密钥。
shape – 可选,表示结果形状的非负整数元组。 默认值 ()。
dtype – 可选,返回值的浮点 dtype(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。
- 返回值:
具有指定形状和 dtype 的随机数组。