FullyConnectedTensorProductConv#
- class cuequivariance_torch.layers.FullyConnectedTensorProductConv(
- in_irreps: Irreps,
- sh_irreps: Irreps,
- out_irreps: Irreps,
- batch_norm: bool = True,
- mlp_channels: Sequence[int] | None = None,
- mlp_activation: Module | Sequence[Module] | None = GELU(approximate='none'),
- layout: IrrepsLayout = None,
- use_fallback: bool | None = None,
用于 DiffDock 类似架构中张量积的消息传递层。张量积的左操作数是节点特征;右操作数由边向量的球谐函数组成。
数学公式
\[\sum_{b \in \mathcal{N}_a} \mathbf{h}_b \otimes_{\psi_{a b}} Y\left(\hat{r}_{a b}\right)\]其中路径权重 \(\psi_{a b}\) 可以使用 MLP 从边嵌入和标量特征构建
\[\psi_{a b} = \operatorname{MLP} \left(e_{a b}, \mathbf{h}_a^0, \mathbf{h}_b^0\right)\]用户可以选择直接输入权重,或者提供 MLP 参数以及来自边和节点的标量特征。
- 参数:
in_irreps (Irreps) – 输入节点特征的 Irreps。
sh_irreps (Irreps) – 边向量的球谐表示的 Irreps。
out_irreps (Irreps) – 输出的 Irreps。
batch_norm (bool, 可选) – 如果为真,则应用批归一化。默认为 True。
mlp_channels (Sequence of int, 可选) – 定义 MLP 中每个层(输出层之前)神经元数量的整数序列。如果为 None,则不添加 MLP。输入层包含边嵌入和节点标量特征。默认为 None。
mlp_activation (
nn.Module
或nn.Module
的 Sequence, 可选) – 要在 MLP 中的线性层之间应用的一系列函数,例如,nn.Sequential(nn.ReLU(), nn.Dropout(0.4))
。默认为nn.GELU()
。layout (IrrepsLayout, 可选) – 输入和输出 irreps 的布局。默认为
cue.mul_ir
,这是对应于 e3nn 的布局。use_fallback (bool, 可选) – 如果为 None (默认),则在 CUDA 内核可用时使用它。如果为 False,则将使用 CUDA 内核,如果它不可用,则会引发异常。如果为 True,则无论 CUDA 内核是否可用,都将使用 PyTorch 回退方法。
示例
>>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o") >>> sh_irreps = cue.Irreps("O3", "0e + 1o") >>> out_irreps = cue.Irreps("O3", "4x0e + 4x1o")
案例 1:MLP 的输入层具有 6 个通道,2 个隐藏层具有 16 个通道。edge_emb.size(1) 必须与输入层的大小匹配:6
>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul) >>> conv1 FullyConnectedTensorProductConv(...) >>> # out = conv1(src_features, edge_sh, edge_emb, graph)
案例 2:如果 edge_emb 是通过连接来自边、源和目标的标量特征构建的,如 DiffDock 中那样,则该层可以分别接受每个标量组件
>>> # out = conv1(src_features, edge_sh, edge_emb, graph, src_scalars, dst_scalars)
这允许在第一个 MLP 层中进行更小的 GEMM,方法是在索引之前对每个组件执行 GEMM。第一层权重按顺序拆分为边、源和目标的部分。这等效于
>>> # src, dst = graph.edge_index >>> # edge_emb = torch.hstack((edge_scalars, src_scalars[src], dst_scalars[dst])) >>> # out = conv1(src_features, edge_sh, edge_emb, graph)
案例 3:没有 MLP,edge_emb 将直接用作张量积权重
>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, ... mlp_channels=None, layout=cue.ir_mul) >>> # out = conv3(src_features, edge_sh, edge_emb, graph)
前向传播
- forward(
- src_features: Tensor,
- edge_sh: Tensor,
- edge_emb: Tensor,
- graph: tuple[Tensor, tuple[int, int]],
- src_scalars: Tensor | None = None,
- dst_scalars: Tensor | None = None,
- reduce: str = 'mean',
- edge_envelope: Tensor | None = None,
前向传播。
- 参数:
src_features (torch.Tensor) – 源节点特征。形状:(num_src_nodes, in_irreps.dim)
edge_sh (torch.Tensor) – 边向量的球谐表示。形状:(num_edges, sh_irreps.dim)
edge_emb (torch.Tensor) –
馈送到 MLP 以生成张量积权重的边嵌入。形状:(num_edges, dim),其中 dim 应为
tp.weight_numel 当该层不包含 MLP 时。
num_edge_scalars,当分别传入来自边、源和目标的标量特征时。
graph (tuple) – 存储图信息的元组,第一个元素是 COO 格式的邻接矩阵,第二个元素是其形状:(num_src_nodes, num_dst_nodes)
src_scalars (torch.Tensor, 可选) – 源节点的标量特征。请参阅示例以了解用法。形状:(num_src_nodes, num_src_scalars)
dst_scalars (torch.Tensor, 可选) – 目标节点的标量特征。请参阅示例以了解用法。形状:(num_dst_nodes, num_dst_scalars)
reduce (str, 可选) – 归约运算符。在 “mean” 和 “sum” 之间选择。默认为 “mean”。
edge_envelope (torch.Tensor, 可选) – 通常用作衰减因子,以逐渐减弱来自靠近用于创建图的截止距离的节点的消息。这对于使模型平滑地适应节点坐标的变化非常重要。形状:(num_edges,)
- 返回值:
输出节点特征。形状:(num_dst_nodes, out_irreps.dim)
- 返回类型: