Linear#

class cuequivariance_jax.flax_linen.Linear(irreps_out: ~cuequivariance.irreps_array.irreps.Irreps | str, layout: ~cuequivariance.irreps_array.irreps_layout.IrrepsLayout | None = None, force: bool = False, kernel_init: ~jax.nn.initializers.Initializer | ~collections.abc.Callable[[...], ~typing.Any] = <function normal>, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)#

等变线性层。

参数:
  • irreps_out (Irreps) – 输出不可约表示。(输入不可约表示从输入推断。)

  • layout (IrrepsLayout) – 输出不可约表示的布局。

  • force (bool) – 如果为 False,则输出不可约表示将被过滤,仅包含来自输入的可达不可约表示。

kernel_init(
shape: ~collections.abc.Sequence[int] = (),
dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType = <class 'float'>,
) Array#

使用给定的形状和浮点 dtype 采样标准正态随机值。

这些值根据概率密度函数返回

\[f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}\]

在域 \(-\infty < x < \infty\)

参数:
  • key – 用作随机密钥的 PRNG 密钥。

  • shape – 可选,表示结果形状的非负整数元组。 默认值 ()。

  • dtype – 可选,返回值的浮点 dtype(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。

返回值:

具有指定形状和 dtype 的随机数组。