跳到内容

基础插值器

插值器

基类:ABC

表示插值器的抽象基类。

此类作为创建可用于各种应用中的插值器的基础,为插值相关操作提供基本结构和接口。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 106–241 行）
class Interpolant(ABC):
    """An abstract base class representing an Interpolant.

    This class serves as a foundation for creating interpolants that can be used
    in various applications, providing a basic structure and interface for
    interpolation-related operations.

    Subclasses must implement :meth:`interpolate` and :meth:`step`. The base class
    provides sampling from the configured prior/time distributions, device
    handling, name-based dispatch to ``step*`` methods (:meth:`general_step`) and a
    masking/centering helper (:meth:`clean_mask_center`).
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: PriorDistribution,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initializes the Interpolant class.

        Args:
            time_distribution (TimeDistribution): The distribution of time steps.
            prior_distribution (PriorDistribution): The prior distribution of the variable.
            device (Union[str, torch.device], optional): The device on which to operate. Defaults to "cpu".
            rng_generator (Optional[torch.Generator], optional): Generator forwarded to the prior and time
                distributions by default, for reproducible sampling. Defaults to None.
        """
        self.time_distribution = time_distribution
        self.prior_distribution = prior_distribution
        self.device = device
        self.rng_generator = rng_generator

    @abstractmethod
    def interpolate(self, *args, **kwargs) -> Tensor:
        """Get x(t) with given time t from noise and data.

        Interpolate between x0 and x1 at the given time t.
        """
        pass

    @abstractmethod
    def step(self, *args, **kwargs) -> Tensor:
        """Do one step integration."""
        pass

    def general_step(self, method_name: str, kwargs: dict):
        """Calls a step method of the class by its name, passing the provided keyword arguments.

        Args:
            method_name (str): The name of the step method to call. Must start with ``"step"``.
            kwargs (dict): Keyword arguments to pass to the step method.

        Returns:
            The result of the step method call.

        Raises:
            ValueError: If the provided method name does not start with 'step'.
            Exception: If looking up or calling the step method fails. An exception of the same
                type is raised whose message also lists the available step methods; the original
                exception is attached as ``__cause__``. If that exception type cannot be rebuilt
                from a single message argument, the original exception is re-raised unchanged.

        Note:
            This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
        """
        if not method_name.startswith("step"):
            raise ValueError(f"Method name '{method_name}' does not start with 'step'")

        try:
            func = getattr(self, method_name)
            return func(**kwargs)
        except Exception as e:
            available_methods = "\n".join(f"  - {attr}" for attr in dir(self) if attr.startswith("step"))
            error_message = f"Error calling method '{method_name}': {e}\nAvailable step methods:\n{available_methods}"
            # Some exception classes need more than one constructor argument; rebuilding them
            # with only a message would mask the real error with an unrelated TypeError.
            try:
                enriched = type(e)(error_message)
            except Exception:
                enriched = None
            if enriched is None:
                raise  # re-raise the original exception `e` untouched
            # Explicit chaining reports the original failure as the direct cause.
            raise enriched from e

    def sample_prior(self, *args, **kwargs) -> Tensor:
        """Sample from prior distribution.

        This method generates a sample from the prior distribution specified by the
        `prior_distribution` attribute. ``device`` defaults to ``self.device`` and
        ``rng_generator`` defaults to ``self.rng_generator`` when not passed explicitly.

        Returns:
            Tensor: The generated sample from the prior distribution.
        """
        kwargs.setdefault("device", self.device)
        # Respect a caller-supplied generator instead of silently overwriting it.
        kwargs.setdefault("rng_generator", self.rng_generator)
        return self.prior_distribution.sample(*args, **kwargs)

    def sample_time(self, *args, **kwargs) -> Tensor:
        """Sample from time distribution.

        ``device`` defaults to ``self.device`` and ``rng_generator`` defaults to
        ``self.rng_generator`` when not passed explicitly.

        Returns:
            Tensor: The sampled time values.
        """
        kwargs.setdefault("device", self.device)
        # Respect a caller-supplied generator instead of silently overwriting it.
        kwargs.setdefault("rng_generator", self.rng_generator)
        return self.time_distribution.sample(*args, **kwargs)

    def to_device(self, device: str):
        """Moves all internal tensors to the specified device and updates the `self.device` attribute.

        Every attribute whose name starts with an underscore and whose value is a
        :class:`torch.Tensor` is replaced by ``value.to(device)``. Attributes that are
        not tensors (including ``self.rng_generator``) are left as they are.

        Args:
            device (str): The device to move the tensors to (e.g. "cpu", "cuda:0").

        Returns:
            The interpolant itself, to allow call chaining.
        """
        self.device = device
        for attr_name in dir(self):
            if not attr_name.startswith("_"):
                continue
            value = getattr(self, attr_name)
            if isinstance(value, torch.Tensor):
                setattr(self, attr_name, value.to(device))
        return self

    def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: bool = False) -> Tensor:
        """Returns a clean tensor that has been masked and/or centered based on the function arguments.

        Args:
            data: The input data with shape (..., nodes, features).
            mask: An optional mask with shape (..., nodes) that is multiplied into the data.
                If provided, only the unmasked nodes are counted when computing the CoM. Defaults to None.
            center: Whether to subtract the center of mass (CoM) over the node axis. Defaults to False.

        Returns:
            The data with shape (..., nodes, features), centered around the CoM if `center` is True,
            otherwise only masked. Note that after centering, masked-out nodes equal ``-CoM``; they are
            not re-zeroed.
        """
        if mask is not None:
            # Broadcast the per-node mask over the feature axis.
            data = data * mask.unsqueeze(-1)
        if not center:
            return data
        if mask is None:
            # The node axis is the second-to-last one for any number of leading batch dims.
            num_nodes = torch.tensor(data.shape[-2], device=data.device)
        else:
            num_nodes = torch.clamp(mask.sum(dim=-1), min=1)  # clamp used to prevent divide by 0
        com = data.sum(dim=-2) / num_nodes.unsqueeze(-1)
        return data - com.unsqueeze(-2)

__init__(time_distribution, prior_distribution, device='cpu', rng_generator=None)

初始化 Interpolant 类。

参数

名称 类型 描述 默认值
time_distribution TimeDistribution

时间步长的分布。

必需
prior_distribution PriorDistribution

变量的先验分布。

必需
device Union[str, device]

在其上操作的设备。默认为 "cpu"。

'cpu'
rng_generator Optional[Generator]

用于可重现采样的可选 torch.Generator。默认为 None。

None
源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 114–132 行）
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: PriorDistribution,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
):
    """Store the sampling distributions, the target device and the optional RNG on the instance.

    Args:
        time_distribution (TimeDistribution): Distribution the time steps are drawn from.
        prior_distribution (PriorDistribution): Prior distribution of the variable.
        device (Union[str, torch.device], optional): Device to operate on. Defaults to "cpu".
        rng_generator (Optional[torch.Generator], optional): Generator used for reproducible
            sampling. Defaults to None.
    """
    # The two distributions are stored in the same order as the original assignments.
    self.time_distribution, self.prior_distribution = time_distribution, prior_distribution
    self.device = device
    self.rng_generator = rng_generator

clean_mask_center(data, mask=None, center=False)

返回一个干净的张量,该张量已根据函数参数进行掩码和/或居中。

参数

名称 类型 描述 默认值
data Tensor

形状为 (..., 节点, 特征) 的输入数据。

必需
mask Optional[Tensor]

可选的掩码,形状为 (..., 节点),会与数据相乘。如果提供,则只用未被掩码的节点计算质心 (CoM)。默认为 None。

None
center Bool

一个布尔值,指示是否将数据围绕计算出的 CoM 居中。默认为 False。

False

返回值

类型 描述
Tensor

形状为 (..., 节点, 特征) 的数据,如果 center 为 True,则围绕 CoM 居中;如果 center 为 False,则保持不变。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 221–241 行）
def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: bool = False) -> Tensor:
    """Returns a clean tensor that has been masked and/or centered based on the function arguments.

    Args:
        data: The input data with shape (..., nodes, features).
        mask: An optional mask with shape (..., nodes) that is multiplied into the data.
            If provided, only the unmasked nodes are counted when computing the CoM. Defaults to None.
        center: Whether to subtract the center of mass (CoM) over the node axis. Defaults to False.

    Returns:
        The data with shape (..., nodes, features), centered around the CoM if `center` is True,
        otherwise only masked. Note that after centering, masked-out nodes equal ``-CoM``; they are
        not re-zeroed.
    """
    if mask is not None:
        # Broadcast the per-node mask over the feature axis.
        data = data * mask.unsqueeze(-1)
    if not center:
        return data
    if mask is None:
        # The node axis is the second-to-last one for any number of leading batch dims
        # (using shape[1] here was only correct for exactly 3-D input).
        num_nodes = torch.tensor(data.shape[-2], device=data.device)
    else:
        num_nodes = torch.clamp(mask.sum(dim=-1), min=1)  # clamp used to prevent divide by 0
    com = data.sum(dim=-2) / num_nodes.unsqueeze(-1)
    return data - com.unsqueeze(-2)

general_step(method_name, kwargs)

通过名称调用类的 step 方法,并传递提供的关键字参数。

参数

名称 类型 描述 默认值
method_name str

要调用的 step 方法的名称。

必需
kwargs dict

要传递给 step 方法的关键字参数。

必需

返回值

类型 描述

step 方法调用的结果。

异常

类型 描述
ValueError

如果提供的方法名称不以 'step' 开头。

Exception

如果 step 方法调用失败。错误消息包括可用 step 方法的列表。

注意

此方法允许动态调用 step 方法,从而在类的使用中提供灵活性。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 147–178 行）
def general_step(self, method_name: str, kwargs: dict):
    """Calls a step method of the class by its name, passing the provided keyword arguments.

    Args:
        method_name (str): The name of the step method to call. Must start with ``"step"``.
        kwargs (dict): Keyword arguments to pass to the step method.

    Returns:
        The result of the step method call.

    Raises:
        ValueError: If the provided method name does not start with 'step'.
        Exception: If looking up or calling the step method fails. An exception of the same
            type is raised whose message also lists the available step methods; the original
            exception is attached as ``__cause__``. If that exception type cannot be rebuilt
            from a single message argument, the original exception is re-raised unchanged.

    Note:
        This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
    """
    if not method_name.startswith("step"):
        raise ValueError(f"Method name '{method_name}' does not start with 'step'")

    try:
        func = getattr(self, method_name)
        return func(**kwargs)
    except Exception as e:
        available_methods = "\n".join(f"  - {attr}" for attr in dir(self) if attr.startswith("step"))
        error_message = f"Error calling method '{method_name}': {e}\nAvailable step methods:\n{available_methods}"
        # Some exception classes need more than one constructor argument; rebuilding them
        # with only a message would mask the real error with an unrelated TypeError.
        try:
            enriched = type(e)(error_message)
        except Exception:
            enriched = None
        if enriched is None:
            raise  # re-raise the original exception `e` untouched
        # Explicit chaining reports the original failure as the direct cause.
        raise enriched from e

interpolate(*args, **kwargs) abstractmethod

从噪声和数据中获取给定时间 t 的 x(t)。

在给定时间 t 时在 x0 和 x1 之间插值。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 134–140 行）
@abstractmethod
def interpolate(self, *args, **kwargs) -> Tensor:
    """Return x(t), the interpolation between noise x0 and data x1 at time t.

    Abstract hook: concrete interpolants define both the exact arguments and the
    interpolation rule.
    """
    ...

sample_prior(*args, **kwargs)

从先验分布中采样。

此方法从 prior_distribution 属性指定的先验分布生成一个样本。

返回值

名称 类型 描述
Tensor Tensor

从先验分布生成的样本。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 180–194 行）
def sample_prior(self, *args, **kwargs) -> Tensor:
    """Sample from prior distribution.

    This method generates a sample from the prior distribution specified by the
    `prior_distribution` attribute. ``device`` defaults to ``self.device`` and
    ``rng_generator`` defaults to ``self.rng_generator`` when not passed explicitly.

    Returns:
        Tensor: The generated sample from the prior distribution.
    """
    kwargs.setdefault("device", self.device)
    # Respect a caller-supplied generator instead of silently overwriting it
    # (previously it was always replaced by self.rng_generator, unlike `device`).
    kwargs.setdefault("rng_generator", self.rng_generator)
    return self.prior_distribution.sample(*args, **kwargs)

sample_time(*args, **kwargs)

从时间分布中采样。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 196–203 行）
def sample_time(self, *args, **kwargs) -> Tensor:
    """Sample from time distribution.

    ``device`` defaults to ``self.device`` and ``rng_generator`` defaults to
    ``self.rng_generator`` when not passed explicitly.

    Returns:
        Tensor: The sampled time values.
    """
    kwargs.setdefault("device", self.device)
    # Respect a caller-supplied generator instead of silently overwriting it
    # (previously it was always replaced by self.rng_generator, unlike `device`).
    kwargs.setdefault("rng_generator", self.rng_generator)
    return self.time_distribution.sample(*args, **kwargs)

step(*args, **kwargs) abstractmethod

执行一步积分。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 142–145 行）
@abstractmethod
def step(self, *args, **kwargs) -> Tensor:
    """Advance the state by a single integration step.

    Abstract hook: concrete interpolants define the arguments and the update rule.
    """
    ...

to_device(device)

将所有内部张量移动到指定的设备,并更新 self.device 属性。

参数

名称 类型 描述 默认值
device str

要将张量移动到的设备(例如 "cpu", "cuda:0")。

必需
注意

此方法用于将插值器的内部状态传输到不同的设备。它更新 self.device 属性以反映新设备,并将所有名称以下划线开头的内部张量移动到指定的设备。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 205–219 行）
def to_device(self, device: str):
    """Set ``self.device`` and move every underscore-prefixed tensor attribute onto it.

    Every attribute whose name starts with an underscore and whose value is a
    :class:`torch.Tensor` is replaced by ``value.to(device)``. Other attributes,
    including non-tensor ones such as ``rng_generator``, are left as they are.

    Args:
        device (str): Target device, e.g. "cpu" or "cuda:0".

    Returns:
        The interpolant itself, so calls can be chained.
    """
    self.device = device
    for name in dir(self):
        if not name.startswith("_"):
            continue
        value = getattr(self, name)
        if isinstance(value, torch.Tensor):
            setattr(self, name, value.to(device))
    return self

PredictionType

基类:Enum

一个枚举,表示降噪扩散概率模型 (DDPM) 可以用于的预测类型。

DDPM 是多功能模型,可用于各种预测任务,包括

  • 数据:从噪声输入预测原始数据分布。
  • 噪声:预测添加到原始数据以获得输入的噪声。
  • 速度:预测数据的速度或变化率,特别适用于建模时间动态。

这些预测类型可用于训练神经网络以执行特定任务,例如降噪、图像合成或时间序列预测。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 84–98 行）
class PredictionType(Enum):
    """Enumeration of the quantities a Denoising Diffusion Probabilistic Model (DDPM) can be trained to predict.

    Members (string value in parentheses):

    - **DATA** (``"data"``): predict the original data from a noisy input.
    - **NOISE** (``"noise"``): predict the noise that was added to the original data to obtain the input.
    - **VELOCITY** (``"velocity"``): predict the velocity, i.e. the rate of change of the data.

    Members can be looked up by their string value, e.g. ``PredictionType("noise")``.
    """

    DATA = "data"
    NOISE = "noise"
    VELOCITY = "velocity"

pad_like(source, target)

填充源张量的维度以匹配目标张量的维度。

参数

名称 类型 描述 默认值
source Tensor

要填充的张量。

必需
target Tensor

源张量应在维度上匹配的张量。

必需

返回值

名称 类型 描述
Tensor Tensor

填充后的源张量。

异常

类型 描述
ValueError

如果源张量的维度多于目标张量。

示例

source = torch.tensor([1, 2, 3])  # 形状: (3,)
target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # 形状: (3, 2)
padded_source = pad_like(source, target)  # 形状: (3, 1)

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 59–81 行）
def pad_like(source: Tensor, target: Tensor) -> Tensor:
    """Append trailing singleton dimensions to ``source`` until it has as many dims as ``target``.

    Only the number of dimensions is matched; existing sizes are kept and the
    result is produced with ``Tensor.view``.

    Args:
        source (Tensor): The tensor to be padded.
        target (Tensor): The tensor whose number of dimensions ``source`` should match.

    Returns:
        Tensor: ``source`` itself when the ranks already agree, otherwise a view of
        ``source`` with trailing size-1 dimensions.

    Raises:
        ValueError: If ``source`` has more dimensions than ``target``.

    Example:
        >>> source = torch.tensor([1, 2, 3])  # shape: (3,)
        >>> target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # shape: (3, 2)
        >>> padded_source = pad_like(source, target)  # shape: (3, 1)
    """
    missing_dims = target.ndim - source.ndim
    if missing_dims < 0:
        raise ValueError(f"Cannot pad {source.shape} to {target.shape}")
    if missing_dims == 0:
        return source
    return source.view(*source.shape, *([1] * missing_dims))

string_to_enum(value, enum_type)

将字符串转换为指定类型的枚举值。如果输入已经是枚举实例,则按原样返回。

参数

名称 类型 描述 默认值
value Union[str, E]

要转换的字符串或现有枚举实例。

必需
enum_type Type[E]

要转换为的枚举类型。

必需

返回值

名称 类型 描述
E AnyEnum

对应的枚举值。

异常

类型 描述
ValueError

如果字符串与任何枚举成员都不对应。

源代码位于 bionemo/moco/interpolants/base_interpolant.py（第 33–56 行）
def string_to_enum(value: Union[str, AnyEnum], enum_type: Type[AnyEnum]) -> AnyEnum:
    """Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

    The string is looked up as a member *value* via ``enum_type(value)``, so the match
    is exact and case-sensitive, e.g. ``string_to_enum("data", PredictionType)``.

    Args:
        value (Union[str, AnyEnum]): The string to convert or an existing instance of ``enum_type``.
        enum_type (Type[AnyEnum]): The enum type to convert to.

    Returns:
        AnyEnum: The corresponding enum member.

    Raises:
        ValueError: If the string does not correspond to any enum member's value. The
            original lookup error is attached as ``__cause__``.
    """
    if isinstance(value, enum_type):
        # Already a member of the requested enum: nothing to convert.
        return value

    try:
        # Exact (case-sensitive) lookup by member value.
        return enum_type(value)
    except ValueError as err:
        valid_values = [member.value for member in enum_type]
        # Chain explicitly so the traceback reads as "direct cause", not as a failure inside the handler.
        raise ValueError(f"Invalid value '{value}'. Expected one of {valid_values}.") from err