FullyConnectedTensorProductConv#

class cuequivariance_torch.layers.FullyConnectedTensorProductConv(
in_irreps: Irreps,
sh_irreps: Irreps,
out_irreps: Irreps,
batch_norm: bool = True,
mlp_channels: Sequence[int] | None = None,
mlp_activation: Module | Sequence[Module] | None = GELU(approximate='none'),
layout: IrrepsLayout = None,
use_fallback: bool | None = None,
)#

用于 DiffDock 类似架构中张量积的消息传递层。张量积的左操作数是节点特征;右操作数由边向量的球谐函数组成。

数学公式

\[\sum_{b \in \mathcal{N}_a} \mathbf{h}_b \otimes_{\psi_{a b}} Y\left(\hat{r}_{a b}\right)\]

其中路径权重 \(\psi_{a b}\) 可以使用 MLP 从边嵌入和标量特征构建

\[\psi_{a b} = \operatorname{MLP} \left(e_{a b}, \mathbf{h}_a^0, \mathbf{h}_b^0\right)\]

用户可以选择直接输入权重,或者提供 MLP 参数以及来自边和节点的标量特征。

参数:
  • in_irreps (Irreps) – 输入节点特征的 Irreps。

  • sh_irreps (Irreps) – 边向量的球谐表示的 Irreps。

  • out_irreps (Irreps) – 输出的 Irreps。

  • batch_norm (bool, 可选) – 如果为真,则应用批归一化。默认为 True。

  • mlp_channels (Sequence of int, 可选) – 定义 MLP 中每个层(输出层之前)神经元数量的整数序列。如果为 None,则不添加 MLP。输入层包含边嵌入和节点标量特征。默认为 None。

  • mlp_activation (nn.Modulenn.Module 的 Sequence, 可选) – 要在 MLP 中的线性层之间应用的一系列函数,例如,nn.Sequential(nn.ReLU(), nn.Dropout(0.4))。默认为 nn.GELU()

  • layout (IrrepsLayout, 可选) – 输入和输出 irreps 的布局。默认为 cue.mul_ir,这是对应于 e3nn 的布局。

  • use_fallback (bool, 可选) – 如果为 None (默认),则在 CUDA 内核可用时使用它。如果为 False,则将使用 CUDA 内核,如果它不可用,则会引发异常。如果为 True,则无论 CUDA 内核是否可用,都将使用 PyTorch 回退方法。

示例

>>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o")
>>> sh_irreps = cue.Irreps("O3", "0e + 1o")
>>> out_irreps = cue.Irreps("O3", "4x0e + 4x1o")

案例 1:MLP 的输入层具有 6 个通道,2 个隐藏层具有 16 个通道。edge_emb.size(1) 必须与输入层的大小匹配:6

>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
...     mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul)
>>> conv1
FullyConnectedTensorProductConv(...)
>>> # out = conv1(src_features, edge_sh, edge_emb, graph)

案例 2:如果 edge_emb 是通过连接来自边、源和目标的标量特征构建的,如 DiffDock 中那样,则该层可以分别接受每个标量组件

>>> # out = conv1(src_features, edge_sh, edge_emb, graph, src_scalars, dst_scalars)

这允许在第一个 MLP 层中进行更小的 GEMM,方法是在索引之前对每个组件执行 GEMM。第一层权重按顺序拆分为边、源和目标的部分。这等效于

>>> # src, dst = graph.edge_index
>>> # edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst]))
>>> # out = conv1(src_features, edge_sh, edge_emb, graph)

案例 3:没有 MLP,edge_emb 将直接用作张量积权重

>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
...     mlp_channels=None, layout=cue.ir_mul)
>>> # out = conv3(src_features, edge_sh, edge_emb, graph)

前向传播

forward(
src_features: Tensor,
edge_sh: Tensor,
edge_emb: Tensor,
graph: tuple[Tensor, tuple[int, int]],
src_scalars: Tensor | None = None,
dst_scalars: Tensor | None = None,
reduce: str = 'mean',
edge_envelope: Tensor | None = None,
) Tensor#

前向传播。

参数:
  • src_features (torch.Tensor) – 源节点特征。形状:(num_src_nodes, in_irreps.dim)

  • edge_sh (torch.Tensor) – 边向量的球谐表示。形状:(num_edges, sh_irreps.dim)

  • edge_emb (torch.Tensor) –

    馈送到 MLP 以生成张量积权重的边嵌入。形状:(num_edges, dim),其中 dim 应为

    • tp.weight_numel 当该层不包含 MLP 时。

    • num_edge_scalars,当分别传入来自边、源和目标的标量特征时。

  • graph (tuple) – 存储图信息的元组,第一个元素是 COO 格式的邻接矩阵,第二个元素是其形状:(num_src_nodes, num_dst_nodes)

  • src_scalars (torch.Tensor, 可选) – 源节点的标量特征。请参阅示例以了解用法。形状:(num_src_nodes, num_src_scalars)

  • dst_scalars (torch.Tensor, 可选) – 目标节点的标量特征。请参阅示例以了解用法。形状:(num_dst_nodes, num_dst_scalars)

  • reduce (str, 可选) – 归约运算符。在 “mean” 和 “sum” 之间选择。默认为 “mean”。

  • edge_envelope (torch.Tensor, 可选) – 通常用作衰减因子,以逐渐减弱来自靠近用于创建图的截止距离的节点的消息。这对于使模型平滑地适应节点坐标的变化非常重要。形状:(num_edges,)

返回值:

输出节点特征。形状:(num_dst_nodes, out_irreps.dim)

返回类型:

torch.Tensor