跳到内容

连续信噪比变换

ContinuousSNRTransform

基类:ABC

连续信噪比调度的基类。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class ContinuousSNRTransform(ABC):
    """A base class for continuous SNR schedules.

    Subclasses implement ``_calculate_log_snr``, which maps continuous time ``t`` in ``[0, 1]``
    to a log signal-to-noise ratio. The helpers below derive alpha/sigma values, time
    derivatives and SDE coefficients from that log SNR.
    """

    def __init__(self, direction: TimeDirection):
        """Initialize the ContinuousSNRTransform.

        Args:
            direction (TimeDirection): Required. The time direction in which the schedule was built;
                the value is normalized with ``string_to_enum``.
        """
        self.direction = string_to_enum(direction, TimeDirection)

    def calculate_log_snr(
        self,
        t: Tensor,
        device: Union[str, torch.device] = "cpu",
        synchronize: Optional[TimeDirection] = None,
    ) -> Tensor:
        """Public wrapper that computes the log signal-to-noise ratio for the given time steps.

        Args:
            t (Tensor): The input tensor representing the time steps, with values ranging from 0 to 1.
            device (Optional[str]): The device to place the schedule on. Defaults to "cpu".
            synchronize (optional[TimeDirection]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
                this parameter allows to flip the direction to match the specified one. Defaults to None.

        Returns:
            Tensor: A tensor representing the log signal-to-noise (SNR) ratio for the given time steps.

        Raises:
            ValueError: If any value of ``t`` lies outside ``[0, 1]``.
        """
        # Validate the documented [0, 1] range; skip the reductions for empty input,
        # where max()/min() would raise.
        if t.numel() > 0:
            t_max = t.max()
            if t_max > 1:
                raise ValueError(f"Invalid value: max continuous time is 1, but got {t_max.item()}")
            t_min = t.min()
            if t_min < 0:
                raise ValueError(f"Invalid value: min continuous time is 0, but got {t_min.item()}")

        if synchronize is not None and self.direction != string_to_enum(synchronize, TimeDirection):
            # The caller uses the opposite time convention: evaluate the schedule at 1 - t.
            t = 1 - t
        return self._calculate_log_snr(t, device)

    @abstractmethod
    def _calculate_log_snr(self, t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor:
        """Generate the log signal-to-noise (SNR) ratio.

        Args:
            t (Tensor): The input tensor representing the time steps.
            device (Optional[str]): The device to place the schedule on. Defaults to "cpu".

        Returns:
            Tensor: A tensor representing the log SNR values for the given time steps.
        """
        pass

    def log_snr_to_alphas_sigmas(self, log_snr: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert log signal-to-noise ratio (SNR) to alpha and sigma values.

        Args:
            log_snr (Tensor): The input log SNR tensor.

        Returns:
            tuple[Tensor, Tensor]: ``(alpha, sigma)`` with ``alpha = sqrt(sigmoid(log_snr))`` and
            ``sigma = sqrt(sigmoid(-log_snr))``, so ``alpha**2 + sigma**2 == 1``.
        """
        squared_alpha = log_snr.sigmoid()
        squared_sigma = (-log_snr).sigmoid()
        return squared_alpha.sqrt(), squared_sigma.sqrt()

    def derivative(self, t: Tensor, func: Callable) -> Tensor:
        """Compute the derivative of ``func`` with respect to ``t``; supports batched single-variable inputs.

        The gradient of ``func(t).sum()`` is taken, which equals the elementwise derivative when each
        output element depends only on the matching element of ``t``. Differentiation happens on a
        detached copy of ``t``. The caller's tensor keeps its ``requires_grad`` flag and graph, and
        non-leaf tensors are accepted. It also works inside ``torch.no_grad()``.

        Args:
            t (Tensor): time variable at which derivatives are taken
            func (Callable): function for derivative calculation

        Returns:
            Tensor: derivative that is detached from the computational graph
        """
        with torch.enable_grad():
            t_local = t.detach().clone().requires_grad_(True)
            (grad,) = torch.autograd.grad(func(t_local).sum(), t_local, create_graph=False)
        return grad.detach()

    def calculate_general_sde_terms(self, t):
        """Compute the general SDE terms for a given time step t.

        Args:
            t (Tensor): The input tensor representing the time step.

        Returns:
            tuple[Tensor, Tensor]: A tuple containing the drift term f_t and the diffusion term g_t_2,
            both detached from the computational graph.

        Notes:
            With ``alpha**2 = sigmoid(log_snr)`` and ``sigma**2 = sigmoid(-log_snr)``:
            ``f_t = d/dt log(alpha)`` and ``g_t_2 = d/dt sigma**2 - 2 * f_t * sigma**2``.
            The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.
        """
        # enable_grad so this also works when called under torch.no_grad() (e.g. during sampling);
        # a detached leaf copy keeps the caller's tensor untouched.
        with torch.enable_grad():
            t = t.detach().clone().requires_grad_(True)

            # Compute log SNR
            log_snr = self.calculate_log_snr(t, device=t.device)

            # Alpha^2 and Sigma^2
            alpha_squared = torch.sigmoid(log_snr)
            sigma_squared = torch.sigmoid(-log_snr)

            # Log Alpha
            log_alpha = 0.5 * torch.log(alpha_squared)

            # Both gradients flow through the shared log_snr subgraph, so it must be kept alive
            # after the first call (otherwise e.g. cosine schedules fail on the second call).
            log_alpha_deriv = torch.autograd.grad(log_alpha.sum(), t, retain_graph=True)[0].detach()
            sigma_squared_deriv = torch.autograd.grad(sigma_squared.sum(), t)[0].detach()

        # Compute drift and diffusion terms
        f_t = log_alpha_deriv  # Drift term
        g_t_2 = sigma_squared_deriv - 2 * log_alpha_deriv * sigma_squared.detach()  # Diffusion term

        return f_t, g_t_2

    def calculate_beta(self, t):
        r"""Compute the drift coefficient for the OU process of the form $dx = -\frac{1}{2} \beta(t) x dt + \sqrt{\beta(t)} dw_t$.

        beta = d/dt log(alpha**2) = 2 * 1/alpha * d/dt(alpha)

        Args:
            t (Union[float, Tensor]): t in [0, 1]; Python floats are converted with ``torch.as_tensor``.

        Returns:
            Tensor: beta(t), detached from the computational graph.
        """
        # Work on a detached copy: accepts floats and never mutates or differentiates through
        # the caller's tensor.
        t = torch.as_tensor(t).detach().clone()
        with torch.no_grad():
            alpha = self.calculate_alpha_log_snr(self.calculate_log_snr(t, device=t.device))
        alpha_deriv_t = self.derivative(t, self.calculate_alpha_t)
        beta = 2.0 * alpha_deriv_t / alpha
        # Chroma has a negative here but when removing the negative we get f = d/dt log (alpha**2) and the step_ode function works as expected
        return beta

    def calculate_alpha_log_snr(self, log_snr: Tensor) -> Tensor:
        """Compute alpha values based on the log SNR.

        Args:
            log_snr (Tensor): The input tensor representing the log signal-to-noise ratio.

        Returns:
            Tensor: A tensor representing the alpha values for the given log SNR.

        Notes:
            This method computes alpha values as the square root of the sigmoid of the log SNR.
        """
        return torch.sigmoid(log_snr).sqrt()

    def calculate_alpha_t(self, t: Tensor) -> Tensor:
        """Compute alpha values based on the log SNR schedule.

        Parameters:
            t (Tensor): The input tensor representing the time steps.

        Returns:
            Tensor: A tensor representing the alpha values for the given time steps.

        Notes:
            This method computes alpha values as the square root of the sigmoid of the log SNR
            returned by ``calculate_log_snr`` (placed on ``t``'s device).
        """
        log_snr = self.calculate_log_snr(t, device=t.device)
        alpha = torch.sigmoid(log_snr).sqrt()
        return alpha

__init__(direction)

初始化 ContinuousSNRTransform。

参数

名称 类型 描述 默认值
direction TimeDirection

必需,定义调度器构建的方向

必需
源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
45
46
47
48
49
50
51
def __init__(self, direction: TimeDirection):
    """Initialize the ContinuousSNRTransform.

    Args:
        direction (TimeDirection): Required. The time direction in which the schedule was built;
            the value is normalized with ``string_to_enum``.
    """
    self.direction = string_to_enum(direction, TimeDirection)

calculate_alpha_log_snr(log_snr)

基于对数信噪比计算 alpha 值。

参数

名称 类型 描述 默认值
log_snr Tensor

表示对数信噪比的输入张量。

必需

返回

名称 类型 描述
Tensor Tensor

表示给定对数信噪比的 alpha 值的张量。

注释

此方法计算 alpha 值,即对数信噪比的 sigmoid 函数的平方根。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
175
176
177
178
179
180
181
182
183
184
185
186
187
def calculate_alpha_log_snr(self, log_snr: Tensor) -> Tensor:
    """Map a log signal-to-noise ratio to the signal scale alpha.

    Args:
        log_snr (Tensor): Log SNR values.

    Returns:
        Tensor: ``sqrt(sigmoid(log_snr))`` evaluated elementwise.
    """
    return log_snr.sigmoid().sqrt()

calculate_alpha_t(t)

基于对数信噪比调度计算 alpha 值。

参数

名称 类型 描述 默认值
t Tensor

表示时间步长的输入张量。

必需

返回

名称 类型 描述
Tensor Tensor

表示给定时间步长的 alpha 值的张量。

注释

此方法计算 alpha 值,即对数信噪比的 sigmoid 函数的平方根。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def calculate_alpha_t(self, t: Tensor) -> Tensor:
    """Evaluate the signal scale alpha directly at time ``t``.

    The log SNR is obtained from ``self.calculate_log_snr`` on ``t``'s device and mapped to
    ``sqrt(sigmoid(log_snr))``.

    Parameters:
        t (Tensor): Time steps.

    Returns:
        Tensor: Alpha values for each time step.
    """
    return torch.sigmoid(self.calculate_log_snr(t, device=t.device)).sqrt()

calculate_beta(t)

计算形式为 $dx = -\frac{1}{2} \beta(t) x \, dt + \sqrt{\beta(t)} \, dw_t$ 的 OU 过程的漂移系数。

$\beta(t) = \frac{d}{dt} \log \alpha(t)^2 = \frac{2}{\alpha(t)} \frac{d\alpha(t)}{dt}$

参数

名称 类型 描述 默认值
t Union[float, Tensor]

t 在 [0, 1] 中

必需

返回

名称 类型 描述
Tensor

beta(t)

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def calculate_beta(self, t):
    r"""Compute the drift coefficient for the OU process of the form $dx = -\frac{1}{2} \beta(t) x dt + \sqrt{\beta(t)} dw_t$.

    beta = d/dt log(alpha**2) = 2 * 1/alpha * d/dt(alpha)

    Args:
        t (Union[float, Tensor]): t in [0, 1]; Python floats are converted with ``torch.as_tensor``.

    Returns:
        Tensor: beta(t), detached from the computational graph.
    """
    # Work on a detached leaf copy: accepts floats, works when the input is a non-leaf tensor,
    # and never changes the caller's tensor.
    t = torch.as_tensor(t).detach().clone()
    with torch.no_grad():
        alpha = self.calculate_alpha_log_snr(self.calculate_log_snr(t, device=t.device))
    alpha_deriv_t = self.derivative(t, self.calculate_alpha_t).detach()
    beta = 2.0 * alpha_deriv_t / alpha
    # Chroma has a negative here but when removing the negative we get f = d/dt log (alpha**2) and the step_ode function works as expected
    return beta

calculate_general_sde_terms(t)

计算给定时间步长 t 的通用 SDE 项。

参数

名称 类型 描述 默认值
t Tensor

表示时间步长的输入张量。

必需

返回

类型 描述

tuple[Tensor, Tensor]:包含漂移项 f_t 和扩散项 g_t_2 的元组。

注释

此方法计算通用 SDE 的漂移项和扩散项,可用于模拟随机过程。漂移项表示过程的确定性部分,而扩散项表示随机部分。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def calculate_general_sde_terms(self, t):
    """Compute the general SDE terms for a given time step t.

    Args:
        t (Tensor): The input tensor representing the time step.

    Returns:
        tuple[Tensor, Tensor]: A tuple containing the drift term f_t and the diffusion term g_t_2,
        both detached from the computational graph.

    Notes:
        With ``alpha**2 = sigmoid(log_snr)`` and ``sigma**2 = sigmoid(-log_snr)``:
        ``f_t = d/dt log(alpha)`` and ``g_t_2 = d/dt sigma**2 - 2 * f_t * sigma**2``.
        The drift term represents the deterministic part of the process, while the diffusion term represents the stochastic part.
    """
    # enable_grad so this also works when called under torch.no_grad() (e.g. during sampling);
    # a detached leaf copy keeps the caller's tensor untouched.
    with torch.enable_grad():
        t = t.detach().clone().requires_grad_(True)

        # Compute log SNR
        log_snr = self.calculate_log_snr(t, device=t.device)

        # Alpha^2 and Sigma^2
        alpha_squared = torch.sigmoid(log_snr)
        sigma_squared = torch.sigmoid(-log_snr)

        # Log Alpha
        log_alpha = 0.5 * torch.log(alpha_squared)

        # Both gradients flow through the shared log_snr subgraph, so it must be kept alive
        # after the first call (otherwise schedules that save tensors, e.g. cos, fail).
        log_alpha_deriv = torch.autograd.grad(log_alpha.sum(), t, retain_graph=True)[0].detach()
        sigma_squared_deriv = torch.autograd.grad(sigma_squared.sum(), t)[0].detach()

    # Compute drift and diffusion terms
    f_t = log_alpha_deriv  # Drift term
    g_t_2 = sigma_squared_deriv - 2 * log_alpha_deriv * sigma_squared.detach()  # Diffusion term

    return f_t, g_t_2

calculate_log_snr(t, device='cpu', synchronize=None)

计算给定时间步对数信噪比 (log SNR) 的公共包装器,结果以张量形式返回。

参数

名称 类型 描述 默认值
t Tensor

表示时间步长的输入张量,值范围为 0 到 1。

必需
device Optional[str]

放置调度的设备。默认为“cpu”。

'cpu'
synchronize optional[TimeDirection]

用于同步调度的时间方向。如果调度是用不同的方向定义的,则此参数允许翻转方向以匹配指定的方向。默认为 None。

None

返回

名称 类型 描述
Tensor Tensor

表示给定时间步长的对数信噪比 (SNR) 的张量。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def calculate_log_snr(
    self,
    t: Tensor,
    device: Union[str, torch.device] = "cpu",
    synchronize: Optional[TimeDirection] = None,
) -> Tensor:
    """Public wrapper that computes the log signal-to-noise ratio for the given time steps.

    Args:
        t (Tensor): The input tensor representing the time steps, with values ranging from 0 to 1.
        device (Optional[str]): The device to place the schedule on. Defaults to "cpu".
        synchronize (optional[TimeDirection]): TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction,
            this parameter allows to flip the direction to match the specified one. Defaults to None.

    Returns:
        Tensor: A tensor representing the log signal-to-noise (SNR) ratio for the given time steps.

    Raises:
        ValueError: If any value of ``t`` lies outside ``[0, 1]``.
    """
    # Validate the documented [0, 1] range; skip the reductions for empty input,
    # where max()/min() would raise.
    if t.numel() > 0:
        t_max = t.max()
        if t_max > 1:
            raise ValueError(f"Invalid value: max continuous time is 1, but got {t_max.item()}")
        t_min = t.min()
        if t_min < 0:
            raise ValueError(f"Invalid value: min continuous time is 0, but got {t_min.item()}")

    if synchronize is not None and self.direction != string_to_enum(synchronize, TimeDirection):
        # The caller uses the opposite time convention: evaluate the schedule at 1 - t.
        t = 1 - t
    return self._calculate_log_snr(t, device)

derivative(t, func)

计算函数的导数,它支持批处理的单变量输入。

参数

名称 类型 描述 默认值
t Tensor

计算导数的时间变量

必需
func Callable

用于导数计算的函数

必需

返回

名称 类型 描述
Tensor Tensor

从计算图中分离的导数

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def derivative(self, t: Tensor, func: Callable) -> Tensor:
    """Compute the derivative of ``func`` with respect to ``t``; supports batched single-variable inputs.

    The gradient of ``func(t).sum()`` is taken, which equals the elementwise derivative when each
    output element depends only on the matching element of ``t``. Differentiation happens on a
    detached copy of ``t``. The caller's tensor keeps its ``requires_grad`` flag and graph, and
    non-leaf tensors are accepted. It also works inside ``torch.no_grad()``.

    Args:
        t (Tensor): time variable at which derivatives are taken
        func (Callable): function for derivative calculation

    Returns:
        Tensor: derivative that is detached from the computational graph
    """
    with torch.enable_grad():
        t_local = t.detach().clone().requires_grad_(True)
        (grad,) = torch.autograd.grad(func(t_local).sum(), t_local, create_graph=False)
    return grad.detach()

log_snr_to_alphas_sigmas(log_snr)

将对数信噪比 (SNR) 转换为 alpha 和 sigma 值。

参数

名称 类型 描述 默认值
log_snr Tensor

输入对数信噪比张量。

必需

返回

类型 描述
Tuple[Tensor, Tensor]

tuple[Tensor, Tensor]:包含 alpha 和 sigma 值的元组,其中 alpha = sqrt(sigmoid(log_snr)),sigma = sqrt(sigmoid(-log_snr))。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def log_snr_to_alphas_sigmas(self, log_snr: Tensor) -> Tuple[Tensor, Tensor]:
    """Split a log signal-to-noise ratio into signal and noise scales.

    Args:
        log_snr (Tensor): Log SNR values.

    Returns:
        tuple[Tensor, Tensor]: ``(alpha, sigma)`` where ``alpha = sqrt(sigmoid(log_snr))`` and
        ``sigma = sqrt(sigmoid(-log_snr))``; elementwise ``alpha**2 + sigma**2 == 1``.
    """
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma

CosineSNRTransform

基类:ContinuousSNRTransform

余弦信噪比调度。

参数

名称 类型 描述 默认值
nu Optional[Float]

余弦调度指数的超参数(默认为 1.0)。

1.0
s Optional[Float]

余弦调度偏移的超参数(默认为 0.008)。

0.008
源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
class CosineSNRTransform(ContinuousSNRTransform):
    """A cosine SNR schedule.

    Args:
        nu (Optional[Float]): Hyperparameter for the cosine schedule exponent applied to ``t`` (default is 1.0).
        s (Optional[Float]): Hyperparameter for the cosine schedule shift (default is 0.008).
    """

    def __init__(self, nu: Float = 1.0, s: Float = 0.008):
        """Initialize the CosineSNRTransform; the schedule direction is fixed to ``TimeDirection.DIFFUSION``."""
        self.direction = TimeDirection.DIFFUSION
        self.nu = nu
        self.s = s

    def _calculate_log_snr(self, t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor:
        """Calculate the log signal-to-noise ratio (SNR) for the cosine noise schedule i.e. -gamma.

        With ``theta = (t**nu + s) / (1 + s) * pi / 2`` this returns ``-log(cos(theta)**-2 - 1)``
        (the argument of the log is clamped from below at 1e-5), so that
        ``1 / (1 + exp(-log_snr)) = cos(theta)**2``, i.e. the cosine schedule's alpha_bar**2.
        The SNR is the equivalent to alpha_bar**2 / (1 - alpha_bar**2) from DDPM.
        Cosine schedule from "Improved Denoising Diffusion Probabilistic Models" (https://arxiv.org/abs/2102.09672);
        continuous-time log-SNR parameterization as in "Variational Diffusion Models" (https://arxiv.org/abs/2107.00630).
        See  https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material and https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

        Args:
            t (Tensor): The input tensor representing the time steps.
            device (str): Device to place the schedule on (default is "cpu").

        Returns:
            Tensor: A tensor representing the log SNR for the given time steps.
        """
        return -log((torch.cos((t**self.nu + self.s) / (1 + self.s) * math.pi * 0.5) ** -2) - 1, eps=1e-5).to(device)

__init__(nu=1.0, s=0.008)

初始化 CosineSNRTransform。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
214
215
216
217
218
def __init__(self, nu: Float = 1.0, s: Float = 0.008):
    """Initialize the CosineSNRTransform; the schedule direction is fixed to ``TimeDirection.DIFFUSION``."""
    self.direction = TimeDirection.DIFFUSION
    self.nu = nu
    self.s = s

LinearLogInterpolatedSNRTransform

基类:ContinuousSNRTransform

线性对数空间插值信噪比调度。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
class LinearLogInterpolatedSNRTransform(ContinuousSNRTransform):
    """A Linear Log space interpolated SNR schedule.

    The log SNR moves linearly from ``max_value`` at ``t = 0`` to ``min_value`` at ``t = 1``.
    """

    def __init__(self, min_value: Float = -7.0, max_value: Float = 13.5):
        """Initialize the Linear log space interpolated SNR Schedule from Chroma.

        The schedule direction is fixed to ``TimeDirection.DIFFUSION``.

        Args:
            min_value (Float): The min log SNR value (reached at ``t = 1``).
            max_value (Float): The max log SNR value (reached at ``t = 0``).
        """
        self.direction = TimeDirection.DIFFUSION
        self.min_value = min_value
        self.max_value = max_value

    def _calculate_log_snr(self, t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor:
        """Calculate the log signal-to-noise ratio (SNR) by linear interpolation in log space, i.e. -gamma.

        ``log_snr(t) = (1 - t) * max_value + t * min_value``.
        See https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L316C23-L316C50

        Args:
            t (Tensor): The input tensor representing the time steps.
            device (Optional[str]): The device to place the schedule on. Defaults to "cpu".

        Returns:
            Tensor: A tensor representing the log SNR for the given time steps.
        """
        log_snr = (1 - t) * self.max_value + t * self.min_value
        return log_snr.to(device)

__init__(min_value=-7.0, max_value=13.5)

初始化来自 Chroma 的线性对数空间插值信噪比调度。

参数

名称 类型 描述 默认值
min_value Float

最小对数信噪比值。

-7.0
max_value Float

最大对数信噪比值。

13.5
源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
270
271
272
273
274
275
276
277
278
279
def __init__(self, min_value: Float = -7.0, max_value: Float = 13.5):
    """Initialize the Linear log space interpolated SNR Schedule from Chroma.

    The schedule direction is fixed to ``TimeDirection.DIFFUSION``.

    Args:
        min_value (Float): The min log SNR value.
        max_value (Float): The max log SNR value.
    """
    self.direction = TimeDirection.DIFFUSION
    self.min_value = min_value
    self.max_value = max_value

LinearSNRTransform

基类:ContinuousSNRTransform

线性信噪比调度。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
class LinearSNRTransform(ContinuousSNRTransform):
    """A Linear SNR schedule with ``log_snr(t) = -log(expm1(min_value + 10 * t**2))``."""

    def __init__(self, min_value: Float = 1.0e-4):
        """Initialize the Linear SNR Transform.

        The schedule direction is fixed to ``TimeDirection.DIFFUSION``.

        Args:
            min_value (Float): Offset added inside ``expm1`` (defaults to 1.e-4). The SNR at ``t = 0``
                is ``1 / expm1(min_value)``, about 1e4 for the default.
        """
        self.direction = TimeDirection.DIFFUSION
        self.min_value = min_value

    def _calculate_log_snr(self, t: Tensor, device: Union[str, torch.device] = "cpu") -> Tensor:
        """Calculate the log signal-to-noise ratio (SNR) for the linear SNR schedule i.e. -gamma.

        Computes ``-log(expm1(min_value + 10 * t**2))``, so the SNR equals ``1 / expm1(min_value + 10 * t**2)``.
        The SNR is the equivalent to alpha_bar**2 / (1 - alpha_bar**2) from DDPM.
        See  https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material and https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

        Args:
            t (Tensor): The input tensor representing the time steps.
            device (Optional[str]): The device to place the schedule on. Defaults to "cpu".

        Returns:
            Tensor: A tensor representing the log SNR for the given time steps.
        """
        # With the default min_value the log SNR goes from about 9.2 at t = 0 to about -10 at t = 1;
        # it is not linear in t (t enters quadratically inside expm1).
        return -log(torch.expm1(self.min_value + 10 * (t**2))).to(device)

__init__(min_value=0.0001)

初始化线性信噪比变换。

参数

名称 类型 描述 默认值
min_value Float

expm1 内部的偏移量,默认为 1.e-4;t = 0 时的信噪比为 1 / expm1(min_value)(默认约为 1e4)。

0.0001
源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
241
242
243
244
245
246
247
248
def __init__(self, min_value: Float = 1.0e-4):
    """Initialize the Linear SNR Transform.

    The schedule direction is fixed to ``TimeDirection.DIFFUSION``.

    Args:
        min_value (Float): Offset added inside ``expm1`` of the log SNR formula (defaults to 1.e-4);
            the SNR at ``t = 0`` is ``1 / expm1(min_value)``.
    """
    self.direction = TimeDirection.DIFFUSION
    self.min_value = min_value

log(t, eps=1e-20)

计算张量的自然对数,钳制值以避免数值不稳定性。

参数

名称 类型 描述 默认值
t Tensor

输入张量。

必需
eps float

钳制输入张量的最小值(默认为 1e-20)。

1e-20

返回

名称 类型 描述
Tensor

输入张量的自然对数。

源代码位于 bionemo/moco/schedules/noise/continuous_snr_transforms.py
29
30
31
32
33
34
35
36
37
38
39
def log(t, eps=1e-20):
    """Return the elementwise natural log of ``t`` after clamping it from below at ``eps``.

    Clamping keeps the result finite for zero or tiny inputs.

    Args:
        t (Tensor): Values to take the logarithm of.
        eps (float, optional): Lower clamp bound applied before the log (default is 1e-20).

    Returns:
        Tensor: ``log(max(t, eps))`` elementwise.
    """
    clamped = t.clamp(min=eps)
    return clamped.log()