跳到内容

实用工具

remove_center_of_mass(data, mask=None)

计算给定数据的质心 (CoM)。

参数

名称 类型 描述 默认值
data 张量

输入数据,形状为 (..., 节点, 特征)。

必需
mask 可选[张量]

一个可选的二进制掩码,应用于形状为 (..., 节点) 的数据,以屏蔽来自 CoM 计算的交互。默认为 None。

None

返回值:数据的质心 (CoM),形状为 (..., 1, 特征)。

源代码位于 bionemo/moco/distributions/prior/continuous/utils.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def remove_center_of_mass(data: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Calculates the center of mass (CoM) of the given data.

    Args:
        data: The input data with shape (..., nodes, features).
        mask: An optional binary mask to apply to the data with shape (..., nodes) to mask out interaction from CoM calculation. Defaults to None.

    Returns:
    The CoM of the data with shape (..., 1, features).
    """
    if mask is None:
        com = data.mean(dim=-2, keepdim=True)
    else:
        masked_data = data * mask.unsqueeze(-1)
        num_nodes = mask.sum(dim=-1, keepdim=True).unsqueeze(-1)
        com = masked_data.sum(dim=-2, keepdim=True) / num_nodes
    return data - com