跳到内容

Vdm

VDM

基类:Interpolant

变分扩散模型 (VDM) 插值器。


示例

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.vdm import VDM
>>> from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule


vdm = VDM(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    noise_schedule = CosineSNRTransform(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = vdm.sample_time(batch_size)
    noise = vdm.sample_prior(data.shape)
    xt = vdm.interpolate(data, time, noise)

    x_pred = model(xt, time)
    loss = vdm.loss(x_pred, data, time).mean()  # vdm.loss returns one value per batch element
    loss.backward()

# Generation
x_pred = vdm.sample_prior(data.shape)
for t in LinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    dt = ...  # time-step increment; the final step must satisfy t - dt == 0
    x_hat = model(x_pred, time)
    x_pred = vdm.step(x_hat, time, x_pred, dt)
return x_pred

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
class VDM(Interpolant):
    """A Variational Diffusion Models (VDM) interpolant.

    Examples:
    ```python
    >>> import torch
    >>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
    >>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
    >>> from bionemo.moco.interpolants.continuous_time.continuous.vdm import VDM
    >>> from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform
    >>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule


    vdm = VDM(
        time_distribution = UniformTimeDistribution(...),
        prior_distribution = GaussianPrior(...),
        noise_schedule = CosineSNRTransform(...),
        )
    model = Model(...)

    # Training
    for epoch in range(1000):
        data = data_loader.get(...)
        time = vdm.sample_time(batch_size)
        noise = vdm.sample_prior(data.shape)
        xt = vdm.interpolate(data, time, noise)

        x_pred = model(xt, time)
        loss = vdm.loss(x_pred, data, time).mean()  # `loss` returns one value per batch element
        loss.backward()

    # Generation
    x_pred = vdm.sample_prior(data.shape)
    for t in LinearInferenceSchedule(...).generate_schedule():
        time = torch.full((batch_size,), t)
        dt = ...  # time-step increment; the final step must satisfy t - dt == 0
        x_hat = model(x_pred, time)
        x_pred = vdm.step(x_hat, time, x_pred, dt)
    return x_pred

    ```
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: PriorDistribution,
        noise_schedule: ContinuousSNRTransform,
        prediction_type: Union[PredictionType, str] = PredictionType.DATA,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initializes the VDM interpolant.

        Args:
            time_distribution (TimeDistribution): The distribution of time steps, used to sample time points for the diffusion process.
            prior_distribution (PriorDistribution): The prior distribution of the variable, used as the starting point for the
                diffusion process. A warning is emitted if it is not a ``GaussianPrior``.
            noise_schedule (ContinuousSNRTransform): The noise schedule; maps time to log-SNR, which is then converted to the
                data/noise scales (alpha, sigma) used by every method of this class.
            prediction_type (Union[PredictionType, str], optional): What the model predicts: ``PredictionType.DATA``,
                ``PredictionType.NOISE`` or ``PredictionType.VELOCITY``. Strings are converted via ``string_to_enum``.
                Defaults to ``PredictionType.DATA``.
            device (Union[str, torch.device], optional): The device on which to run the interpolant, e.g. "cpu" or "cuda:0". Defaults to "cpu".
            rng_generator (Optional[torch.Generator], optional): An optional generator for reproducible sampling, forwarded to the
                base class. Defaults to None.
        """
        super().__init__(time_distribution, prior_distribution, device, rng_generator)
        if not isinstance(prior_distribution, GaussianPrior):
            warnings.warn("Prior distribution is not a GaussianPrior, unexpected behavior may occur")
        self.noise_schedule = noise_schedule
        self.prediction_type = string_to_enum(prediction_type, PredictionType)
        # Element-wise squared error; the reduction over non-batch dims is done in `loss`.
        self._loss_function = nn.MSELoss(reduction="none")

    def interpolate(self, data: Tensor, t: Tensor, noise: Tensor):
        """Get x(t) = alpha_t * data + sigma_t * noise at the given time t.

        Args:
            data (Tensor): target (clean) data.
            t (Tensor): time.
            noise (Tensor): noise from prior().

        Returns:
            Tensor: the noised sample x(t), with the same shape as ``data``.
        """
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        psi, omega = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        psi = pad_like(psi, data)
        omega = pad_like(omega, data)
        x_t = data * psi + noise * omega
        return x_t

    def forward_process(self, data: Tensor, t: Tensor, noise: Optional[Tensor] = None):
        """Get x(t) at the given time t from data, sampling the noise from the prior if it is not given.

        Args:
            data (Tensor): target (clean) data.
            t (Tensor): time.
            noise (Tensor, optional): noise from prior(). If None, it is drawn with ``sample_prior(data.shape)``. Defaults to None.

        Returns:
            Tensor: the noised sample x(t).
        """
        if noise is None:
            noise = self.sample_prior(data.shape)
        return self.interpolate(data, t, noise)

    def process_data_prediction(self, model_output: Tensor, sample, t):
        """Converts the model output to a data prediction based on the prediction type.

        This conversion stems from Progressive Distillation for Fast Sampling of Diffusion Models https://arxiv.org/pdf/2202.00512.
        Given the model output and the sample, we convert the output to a data prediction based on the prediction type.
        The conversion formulas are as follows:
        - For "noise" prediction type: `pred_data = (sample - noise_scale * model_output) / data_scale`
        - For "data" prediction type: `pred_data = model_output`
        - For "v_prediction" prediction type: `pred_data = data_scale * sample - noise_scale * model_output`

        Args:
            model_output (Tensor): The output of the model.
            sample (Tensor): The input sample x(t).
            t (Tensor): The time step.

        Returns:
            Tensor: The data prediction based on the prediction type.

        Raises:
            ValueError: If the prediction type is not one of NOISE, DATA, or VELOCITY.
        """
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        data_scale, noise_scale = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        data_scale = pad_like(data_scale, model_output)
        noise_scale = pad_like(noise_scale, model_output)
        if self.prediction_type == PredictionType.NOISE:
            pred_data = (sample - noise_scale * model_output) / data_scale
        elif self.prediction_type == PredictionType.DATA:
            pred_data = model_output
        elif self.prediction_type == PredictionType.VELOCITY:
            pred_data = data_scale * sample - noise_scale * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of PredictionType.NOISE, PredictionType.DATA or"
                f" PredictionType.VELOCITY for vdm."
            )
        return pred_data

    def process_noise_prediction(self, model_output: Tensor, sample: Tensor, t: Tensor):
        """Converts the model output to a noise prediction based on the prediction type.

        This is the noise counterpart of `process_data_prediction`:
        - For "noise" prediction type: `pred_noise = model_output`
        - For "data" prediction type: `pred_noise = (sample - data_scale * model_output) / noise_scale`
        - For "v_prediction" prediction type: the data prediction
          `data_scale * sample - noise_scale * model_output` is formed first, then converted to noise.

        Args:
            model_output (Tensor): The output of the model.
            sample (Tensor): The input sample x(t).
            t (Tensor): The time step.

        Returns:
            Tensor: The noise prediction based on the prediction type.

        Raises:
            ValueError: If the prediction type is not one of NOISE, DATA, or VELOCITY.
        """
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        data_scale, noise_scale = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        data_scale = pad_like(data_scale, model_output)
        noise_scale = pad_like(noise_scale, model_output)
        if self.prediction_type == PredictionType.NOISE:
            pred_noise = model_output
        elif self.prediction_type == PredictionType.DATA:
            pred_noise = (sample - data_scale * model_output) / noise_scale
        elif self.prediction_type == PredictionType.VELOCITY:
            pred_data = data_scale * sample - noise_scale * model_output
            pred_noise = (sample - data_scale * pred_data) / noise_scale
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of `noise`, `data` or"
                " `v_prediction`  for vdm."
            )
        return pred_noise

    def step(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        dt: Tensor,
        mask: Optional[Tensor] = None,
        center: Bool = False,
        temperature: Float = 1.0,
    ):
        """Do one ancestral sampling step from time t to s = t - dt.

        Args:
            model_out (Tensor): The output of the model.
            t (Tensor): The current time step.
            xt (Tensor): The current data point.
            dt (Tensor): The time step increment.
            mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
            center (bool): Whether to center the data. Defaults to False.
            temperature (Float): Multiplier on the stochastic term (std * eps) for low temperature sampling. Defaults to 1.0.

        Returns:
            Tensor: The sample at time t - dt.

        Raises:
            ValueError: If any entry of t - dt is negative.

        Note:
            The temperature parameter controls the trade off between diversity and sample quality.
            Decreasing the temperature sharpens the sampling distribution to focus on more likely samples.
            The impact of low temperature sampling should be ablated empirically.
        """
        # Validate the schedule before doing any model-output processing.
        if (t - dt < 0).any():
            raise ValueError(
                "Error in inference schedule: t - dt < 0. Please ensure that your inference time schedule has shape T with the final t = dt to make s = 0"
            )
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        x_hat = self.process_data_prediction(model_out, xt, t)

        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        alpha_t, sigma_t = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)

        log_snr_s = self.noise_schedule.calculate_log_snr(t - dt, device=self.device)
        alpha_s, sigma_s = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr_s)
        sigma_s_2 = sigma_s * sigma_s
        sigma_t_2 = sigma_t * sigma_t
        alpha_t_s = alpha_t / alpha_s
        # sigma^2_{t|s} = 1 - alpha_t^2 / alpha_s^2, computed with expm1/softplus for numerical stability.
        sigma_2_t_s = -torch.expm1(F.softplus(-log_snr_s) - F.softplus(-log_snr))  # Equation 63

        omega_r = alpha_t_s * sigma_s_2 / sigma_t_2  # Equation 28
        psi_r = alpha_s * sigma_2_t_s / sigma_t_2
        std = sigma_2_t_s.sqrt() * sigma_s / sigma_t
        nonzero_mask = (
            t > 0
        ).float()  # based on the time this is always just ones. can leave for now to see if ever want to take extra step and only grab mean

        psi_r = pad_like(psi_r, x_hat)
        omega_r = pad_like(omega_r, x_hat)
        std = pad_like(std, x_hat)
        nonzero_mask = pad_like(nonzero_mask, x_hat)

        mean = psi_r * x_hat + omega_r * xt
        eps = torch.randn_like(mean).to(model_out.device)
        x_next = mean + nonzero_mask * std * eps * temperature
        x_next = self.clean_mask_center(x_next, mask, center)
        return x_next

    def score(self, x_hat: Tensor, xt: Tensor, t: Tensor):
        """Converts the data prediction to the estimated score function.

        Computes (alpha_t * x_hat - xt) / sigma_t^2.

        Args:
            x_hat (Tensor): The predicted data point.
            xt (Tensor): The current data point.
            t (Tensor): The time step.

        Returns:
            Tensor: The estimated score function.
        """
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        psi, omega = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        psi = pad_like(psi, x_hat)
        omega = pad_like(omega, x_hat)
        score = psi * x_hat - xt
        score = score / (omega * omega)
        return score

    def step_ddim(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        dt: Tensor,
        mask: Optional[Tensor] = None,
        eta: Float = 0.0,
        center: Bool = False,
    ):
        """Do one step of DDIM sampling from time t to t - dt.

        From the ddpm equations alpha_bar = alpha**2 and  1 - alpha**2 = sigma**2

        Args:
            model_out (Tensor): output of the model
            t (Tensor): current time step
            xt (Tensor): current data point
            dt (Tensor): The time step increment.
            mask (Optional[Tensor], optional): mask for the data point. Defaults to None.
            eta (Float, optional): DDIM sampling parameter; 0.0 gives deterministic DDIM. Defaults to 0.0.
            center (Bool, optional): whether to center the data point. Defaults to False.

        Returns:
            Tensor: The sample at time t - dt.

        Raises:
            ValueError: If any entry of t - dt is negative (same check as `step`).
        """
        if (t - dt < 0).any():
            raise ValueError(
                "Error in inference schedule: t - dt < 0. Please ensure that your inference time schedule has shape T with the final t = dt to make s = 0"
            )
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        data_pred = self.process_data_prediction(model_out, xt, t)
        noise_pred = self.process_noise_prediction(model_out, xt, t)
        eps = torch.randn_like(data_pred).to(model_out.device)
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        squared_alpha = log_snr.sigmoid()
        squared_sigma = (-log_snr).sigmoid()
        log_snr_prev = self.noise_schedule.calculate_log_snr(t - dt, device=self.device)
        squared_alpha_prev = log_snr_prev.sigmoid()
        squared_sigma_prev = (-log_snr_prev).sigmoid()
        # DDIM variance before the eta^2 factor; eta is applied below.
        sigma_t_2 = squared_sigma_prev / squared_sigma * (1 - squared_alpha / squared_alpha_prev)
        psi_r = torch.sqrt(squared_alpha_prev)
        omega_r = torch.sqrt(1 - squared_alpha_prev - eta * eta * sigma_t_2)

        sigma_t_2 = pad_like(sigma_t_2, model_out)
        psi_r = pad_like(psi_r, model_out)
        omega_r = pad_like(omega_r, model_out)

        mean = data_pred * psi_r + omega_r * noise_pred
        x_next = mean + eta * sigma_t_2.sqrt() * eps
        x_next = self.clean_mask_center(x_next, mask, center)
        return x_next

    def set_loss_weight_fn(self, fn: Callable):
        """Replaces the `loss_weight` method of this instance with the given function.

        Args:
            fn: The function to use as `loss_weight`. `loss` calls it as ``fn(raw_loss, t, weight_type, dt)``,
                so it must accept these four positional arguments (raw_loss, t, weight_type, dt) and return
                a tensor that broadcasts against ``raw_loss``.
        """
        self.loss_weight = fn

    def loss_weight(self, raw_loss: Tensor, t: Tensor, weight_type: str, dt: Float = 0.001) -> Tensor:
        """Calculates the weight for the loss based on the given weight type.

        This function computes the loss weight according to the specified `weight_type`.
        The available weight types are:
        - "ones": uniform weight of 1.0
        - "data_to_noise": derived from Equation (9) of https://arxiv.org/pdf/2202.00512
        - "variational_objective": based on the variational objective, see https://arxiv.org/pdf/2202.00512

        Args:
            raw_loss (Tensor): The raw loss calculated from the model prediction and target.
            t (Tensor): The time step.
            weight_type (str): The type of weight to use. Can be "ones", "data_to_noise", or "variational_objective".
            dt (Float, optional): The time step increment. Defaults to 0.001.

        Returns:
            Tensor: The weight for the loss. For "ones" it has the shape of ``raw_loss``; otherwise it is the
            per-time weight with trailing singleton dims appended so it broadcasts against ``raw_loss``.

        Raises:
            ValueError: If the weight type is not recognized.
        """
        if weight_type == "ones":
            # ones_like already allocates on raw_loss's device.
            schedule = torch.ones_like(raw_loss)
        elif weight_type == "data_to_noise":
            log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
            psi, omega = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
            schedule = (psi**2) / (omega**2)
            for _ in range(raw_loss.ndim - 1):
                schedule = schedule.unsqueeze(-1)
        elif weight_type == "variational_objective":
            # (1-SNR(t-1)/SNR(t)),
            # NOTE(review): `step` assumes log-SNR decreases with t (sigma^2_{t|s} = 1 - alpha_t^2 / alpha_s^2 must be
            # >= 0 for s = t - dt). Under that convention snr_m1 >= snr and this weight is <= 0, whereas the VDM
            # discrete-time weight is SNR(s) / SNR(t) - 1. Confirm the intended sign before relying on this option.
            snr = torch.exp(self.noise_schedule.calculate_log_snr(t, device=self.device))
            snr_m1 = torch.exp(self.noise_schedule.calculate_log_snr(t - dt, device=self.device))
            schedule = 1 - snr_m1 / snr
            for _ in range(raw_loss.ndim - 1):
                schedule = schedule.unsqueeze(-1)
        else:
            raise ValueError("Invalid loss weight keyword")
        return schedule

    def loss(
        self,
        model_pred: Tensor,
        target: Tensor,
        t: Tensor,
        dt: Optional[Float] = 0.001,
        mask: Optional[Tensor] = None,
        weight_type: str = "ones",
    ):
        """Calculates the weighted MSE loss given the model prediction, target, and time.

        Args:
            model_pred (Tensor): The predicted output from the model.
            target (Tensor): The target output for the model prediction.
            t (Tensor): The time at which the loss is calculated.
            dt (Optional[Float], optional): The time step increment, forwarded to `loss_weight`. Defaults to 0.001.
            mask (Optional[Tensor], optional): The mask for the data point; it is broadcast with ``unsqueeze(-1)``,
                so it presumably has one fewer dim than ``model_pred`` (TODO confirm). Defaults to None.
            weight_type (str, optional): The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective". Defaults to "ones".

        Returns:
            Tensor: One loss value per entry along dim 0. Each value is the weighted squared error summed over
            the remaining dims, divided by ``mask.sum(-1)`` if a mask is given, else by ``model_pred.size(1)``.
        """
        raw_loss = self._loss_function(model_pred, target)
        update_weight = self.loss_weight(raw_loss, t, weight_type, dt)
        loss = raw_loss * update_weight
        if mask is not None:
            loss = loss * mask.unsqueeze(-1)
            n_elem = torch.sum(mask, dim=-1)
            loss = torch.sum(loss, dim=tuple(range(1, raw_loss.ndim))) / n_elem
        else:
            loss = torch.sum(loss, dim=tuple(range(1, raw_loss.ndim))) / model_pred.size(1)
        return loss

    def step_hybrid_sde(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        dt: Tensor,
        mask: Optional[Tensor] = None,
        center: Bool = False,
        temperature: Float = 1.0,
        equilibrium_rate: Float = 0.0,
    ) -> Tensor:
        """Do one step integration of Hybrid Langevin-Reverse Time SDE.

        See section B.3 page 37 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
        and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

        Args:
            model_out (Tensor): The output of the model.
            t (Tensor): The current time step.
            xt (Tensor): The current data point.
            dt (Tensor): The time step increment.
            mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
            center (bool, optional): Whether to center the data. Defaults to False.
            temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.
            equilibrium_rate (Float, optional): The rate of Langevin equilibration.  Scales the amount of Langevin dynamics per unit time. Best values are in the range [1.0, 5.0]. Defaults to 0.0.

        Returns:
            Tensor: The sample after one integration step.

        Note:
            For all step functions that use the SDE formulation it is important to note that we are moving backwards in time,
            which corresponds to an apparent sign change.
            A clear example can be seen in slide 29 https://ernestryu.com/courses/FM/diffusion1.pdf.
        """
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        x_hat = self.process_data_prediction(model_out, xt, t)
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        alpha, sigma = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        # Schedule coefficients
        beta = self.noise_schedule.calculate_beta(t)
        inverse_temperature = 1 / temperature  # lambda_0
        langevin_factor = equilibrium_rate
        # Temperature coefficients
        lambda_t = (
            inverse_temperature * (sigma.pow(2) + alpha.pow(2)) / (inverse_temperature * sigma.pow(2) + alpha.pow(2))
        )
        # langevin_isothermal = True
        lambda_langevin = inverse_temperature  # if langevin_isothermal else lambda_t

        score_scale_t = lambda_t + lambda_langevin * langevin_factor / 2.0

        eps = torch.randn_like(x_hat).to(model_out.device)
        score = self.score(x_hat, xt, t)
        beta = pad_like(beta, model_out)
        score_scale_t = pad_like(score_scale_t, model_out)

        gT = beta * ((-1 / 2) * xt - score_scale_t * score)
        # abs() because the sign of beta encodes the (backward) time direction; the diffusion magnitude must be >= 0.
        gW = torch.sqrt((1.0 + langevin_factor) * beta.abs()) * eps

        x_next = xt + dt * gT + dt.sqrt() * gW
        x_next = self.clean_mask_center(x_next, mask, center)
        return x_next

    def step_ode(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        dt: Tensor,
        mask: Optional[Tensor] = None,
        center: Bool = False,
        temperature: Float = 1.0,
    ) -> Tensor:
        """Do one step integration of ODE.

        See section B page 36 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf.
        and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

        Args:
            model_out (Tensor): The output of the model.
            t (Tensor): The current time step.
            xt (Tensor): The current data point.
            dt (Tensor): The time step increment.
            mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
            center (bool, optional): Whether to center the data. Defaults to False.
            temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.

        Returns:
            Tensor: The sample after one deterministic (Euler) integration step.
        """
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        x_hat = self.process_data_prediction(model_out, xt, t)
        log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
        alpha, sigma = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
        # Schedule coefficients
        beta = self.noise_schedule.calculate_beta(t)
        inverse_temperature = 1 / temperature
        # Temperature coefficients
        lambda_t = (
            inverse_temperature * (sigma.pow(2) + alpha.pow(2)) / (inverse_temperature * sigma.pow(2) + alpha.pow(2))
        )

        score = self.score(x_hat, xt, t)
        beta = pad_like(beta, model_out)
        lambda_t = pad_like(lambda_t, model_out)

        gT = (-1 / 2) * beta * (xt + lambda_t * score)

        x_next = xt + gT * dt
        x_next = self.clean_mask_center(x_next, mask, center)
        return x_next

__init__(time_distribution, prior_distribution, noise_schedule, prediction_type=PredictionType.DATA, device='cpu', rng_generator=None)

初始化 VDM 插值器。

参数

名称 类型 描述 默认值
time_distribution TimeDistribution

时间步的分布,用于为扩散过程采样时间点。

必需
prior_distribution PriorDistribution

变量的先验分布,用作扩散过程的起点。

必需
noise_schedule ContinuousSNRTransform

噪声的调度,定义在每个时间步添加的噪声量。

必需
prediction_type PredictionType

预测的类型,可以是“data”或其他类型。默认为“data”。

DATA
device str

运行插值器的设备,可以是“cpu”或 CUDA 设备(例如“cuda:0”)。默认为“cpu”。

'cpu'
rng_generator Optional[Generator]

一个可选的 `torch.Generator`,用于可重复的采样。默认为 None。

None
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: PriorDistribution,
    noise_schedule: ContinuousSNRTransform,
    prediction_type: Union[PredictionType, str] = PredictionType.DATA,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
):
    """Initializes the VDM interpolant.

    Args:
        time_distribution (TimeDistribution): The distribution of time steps, used to sample time points for the diffusion process.
        prior_distribution (PriorDistribution): The prior distribution of the variable, used as the starting point for the
            diffusion process. A warning is emitted if it is not a ``GaussianPrior``.
        noise_schedule (ContinuousSNRTransform): The schedule of noise, defining the amount of noise added at each time step.
        prediction_type (Union[PredictionType, str], optional): What the model predicts (e.g. ``PredictionType.DATA``).
            Strings are converted via ``string_to_enum``. Defaults to ``PredictionType.DATA``.
        device (Union[str, torch.device], optional): The device on which to run the interpolant, e.g. "cpu" or "cuda:0". Defaults to "cpu".
        rng_generator (Optional[torch.Generator], optional): An optional generator for reproducible sampling, forwarded to the
            base class. Defaults to None.
    """
    super().__init__(time_distribution, prior_distribution, device, rng_generator)
    if not isinstance(prior_distribution, GaussianPrior):
        warnings.warn("Prior distribution is not a GaussianPrior, unexpected behavior may occur")
    self.noise_schedule = noise_schedule
    self.prediction_type = string_to_enum(prediction_type, PredictionType)
    # Element-wise squared error; the reduction over non-batch dims is done in `loss`.
    self._loss_function = nn.MSELoss(reduction="none")

forward_process(data, t, noise=None)

从噪声和数据中获取给定时间 t 的 x(t)。

参数

名称 类型 描述 默认值
data Tensor

目标

必需
t Tensor

时间

必需
noise Tensor

来自 prior() 的噪声。默认为 None

None
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
118
119
120
121
122
123
124
125
126
127
128
def forward_process(self, data: Tensor, t: Tensor, noise: Optional[Tensor] = None):
    """Noise ``data`` to time ``t``, drawing fresh prior noise when none is supplied.

    Args:
        data (Tensor): target (clean) data.
        t (Tensor): time.
        noise (Tensor, optional): noise from prior(). When None, ``self.sample_prior(data.shape)`` is used. Defaults to None.

    Returns:
        Whatever ``self.interpolate(data, t, noise)`` returns, i.e. the noised sample x(t).
    """
    prior_noise = self.sample_prior(data.shape) if noise is None else noise
    return self.interpolate(data, t, prior_noise)

interpolate(data, t, noise)

从噪声和数据中获取给定时间 t 的 x(t)。

参数

名称 类型 描述 默认值
data Tensor

目标

必需
t Tensor

时间

必需
noise Tensor

来自 prior() 的噪声

必需
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def interpolate(self, data: Tensor, t: Tensor, noise: Tensor):
    """Mix data and noise at time ``t`` as x(t) = alpha_t * data + sigma_t * noise.

    Args:
        data (Tensor): target (clean) data.
        t (Tensor): time.
        noise (Tensor): noise from prior().

    Returns:
        Tensor: the noised sample x(t).
    """
    schedule = self.noise_schedule
    alpha_t, sigma_t = schedule.log_snr_to_alphas_sigmas(schedule.calculate_log_snr(t, device=self.device))
    # Broadcast the per-sample scales against the data's trailing dimensions.
    return data * pad_like(alpha_t, data) + noise * pad_like(sigma_t, data)

loss(model_pred, target, t, dt=0.001, mask=None, weight_type='ones')

计算给定模型预测、目标和时间的损失。

参数

名称 类型 描述 默认值
model_pred Tensor

来自模型的预测输出。

必需
target Tensor

模型预测的目标输出。

必需
t Tensor

计算损失的时间。

必需
dt Optional[Float]

时间步增量。默认为 0.001。

0.001
mask Optional[Tensor]

数据点的掩码。默认为 None。

None
weight_type str

用于损失的权重类型。可以是“ones”、“data_to_noise”或“variational_objective”。默认为“ones”。

'ones'

返回

名称 类型 描述
Tensor

计算出的损失批张量。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
def loss(
    self,
    model_pred: Tensor,
    target: Tensor,
    t: Tensor,
    dt: Optional[Float] = 0.001,
    mask: Optional[Tensor] = None,
    weight_type: str = "ones",
):
    """Compute a weighted, per-sample MSE between ``model_pred`` and ``target``.

    Args:
        model_pred (Tensor): The predicted output from the model.
        target (Tensor): The target output for the model prediction.
        t (Tensor): The time at which the loss is calculated.
        dt (Optional[Float], optional): The time step increment, forwarded to `loss_weight`. Defaults to 0.001.
        mask (Optional[Tensor], optional): The mask for the data point; broadcast with ``unsqueeze(-1)``. Defaults to None.
        weight_type (str, optional): The type of weight to use for the loss. Can be "ones", "data_to_noise", or "variational_objective". Defaults to "ones".

    Returns:
        Tensor: One loss value per entry along dim 0. Each value is the weighted squared error summed over the
        remaining dims, normalized by ``mask.sum(-1)`` when a mask is given, else by ``model_pred.size(1)``.
    """
    squared_error = self._loss_function(model_pred, target)
    weighted = squared_error * self.loss_weight(squared_error, t, weight_type, dt)
    non_batch_dims = tuple(range(1, squared_error.ndim))
    if mask is None:
        return torch.sum(weighted, dim=non_batch_dims) / model_pred.size(1)
    weighted = weighted * mask.unsqueeze(-1)
    return torch.sum(weighted, dim=non_batch_dims) / torch.sum(mask, dim=-1)

loss_weight(raw_loss, t, weight_type, dt=0.001)

根据给定的权重类型计算损失的权重。

此函数根据指定的 weight_type 计算损失权重。可用的权重类型有:

- “ones”:均匀权重 1.0
- “data_to_noise”:从 https://arxiv.org/pdf/2202.00512 的公式 (9) 导出
- “variational_objective”:基于变分目标,参见 https://arxiv.org/pdf/2202.00512

参数

名称 类型 描述 默认值
raw_loss Tensor

从模型预测和目标计算的原始损失。

必需
t Tensor

时间步。

必需
weight_type str

要使用的权重类型。可以是“ones”、“data_to_noise”或“variational_objective”。

必需
dt Float

时间步增量。默认为 0.001。

0.001

返回

名称 类型 描述
Tensor Tensor

损失的权重。

抛出

类型 描述
ValueError

如果无法识别权重类型。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
def loss_weight(self, raw_loss: Tensor, t: Tensor, weight_type: str, dt: Float = 0.001) -> Tensor:
    """Return the weight that multiplies the raw loss for the requested weighting scheme.

    Supported ``weight_type`` values:
    - "ones": a tensor of ones with the same shape as ``raw_loss``.
    - "data_to_noise": ``alpha(t)**2 / sigma(t)**2`` from the noise schedule, following
      Equation (9) of https://arxiv.org/pdf/2202.00512.
    - "variational_objective": ``1 - SNR(t - dt) / SNR(t)``, see https://arxiv.org/pdf/2202.00512.
      NOTE(review): if SNR decreases as t grows, SNR(t - dt) > SNR(t) and this weight is
      negative; confirm the intended sign against the schedule's time direction.

    For the two schedule-based options, the weight computed from ``t`` gets
    ``raw_loss.ndim - 1`` trailing singleton dimensions appended, so that it broadcasts
    against ``raw_loss``.

    Args:
        raw_loss (Tensor): The raw loss calculated from the model prediction and target.
        t (Tensor): The time step.
        weight_type (str): One of "ones", "data_to_noise" or "variational_objective".
        dt (Float, optional): Offset used for the previous time ``t - dt`` in the variational
            weight. Defaults to 0.001.

    Returns:
        Tensor: The weight for the loss.

    Raises:
        ValueError: If ``weight_type`` is not one of the supported options.
    """
    if weight_type == "ones":
        return torch.ones_like(raw_loss).to(raw_loss.device)

    schedule = self.noise_schedule
    if weight_type == "data_to_noise":
        alpha_t, sigma_t = schedule.log_snr_to_alphas_sigmas(schedule.calculate_log_snr(t, device=self.device))
        weight = (alpha_t**2) / (sigma_t**2)
    elif weight_type == "variational_objective":
        # (1 - SNR(t - dt) / SNR(t))
        snr_now = torch.exp(schedule.calculate_log_snr(t, device=self.device))
        snr_prev = torch.exp(schedule.calculate_log_snr(t - dt, device=self.device))
        weight = 1 - snr_prev / snr_now
    else:
        raise ValueError("Invalid loss weight keyword")

    # Append trailing singleton axes so the weight broadcasts over raw_loss.
    extra_dims = raw_loss.ndim - 1
    return weight[(...,) + (None,) * extra_dims]

process_data_prediction(model_output, sample, t)

根据预测类型将模型输出转换为数据预测。

此转换源自《Progressive Distillation for Fast Sampling of Diffusion Models》(https://arxiv.org/pdf/2202.00512)。给定模型输出和样本,我们根据预测类型将输出转换为数据预测。转换公式如下:

- 对于“noise”预测类型:pred_data = (sample - noise_scale * model_output) / data_scale
- 对于“data”预测类型:pred_data = model_output
- 对于“v_prediction”预测类型:pred_data = data_scale * sample - noise_scale * model_output

参数

名称 类型 描述 默认值
model_output Tensor

模型的输出。

必需
sample Tensor

输入样本。

必需
t Tensor

时间步。

必需

返回

类型 描述

基于预测类型的数据预测。

抛出

类型 描述
ValueError

如果预测类型不是“noise”、“data”或“v_prediction”之一。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def process_data_prediction(self, model_output: Tensor, sample, t):
    """Map the raw model output to an estimate of the clean data.

    Follows the parameterizations of "Progressive Distillation for Fast Sampling of Diffusion
    Models" (https://arxiv.org/pdf/2202.00512). With ``data_scale`` (alpha) and ``noise_scale``
    (sigma) taken from the noise schedule at time ``t`` and padded to ``model_output``'s rank:
    - PredictionType.NOISE: ``(sample - noise_scale * model_output) / data_scale``
    - PredictionType.DATA: ``model_output`` returned unchanged
    - PredictionType.VELOCITY: ``data_scale * sample - noise_scale * model_output``

    Args:
        model_output (Tensor): The output of the model.
        sample (Tensor): The input sample.
        t (Tensor): The time step.

    Returns:
        The data prediction implied by ``self.prediction_type``.

    Raises:
        ValueError: If the prediction type is not NOISE, DATA or VELOCITY.
    """
    schedule = self.noise_schedule
    alpha, sigma = schedule.log_snr_to_alphas_sigmas(schedule.calculate_log_snr(t, device=self.device))
    alpha = pad_like(alpha, model_output)
    sigma = pad_like(sigma, model_output)

    kind = self.prediction_type
    if kind == PredictionType.NOISE:
        return (sample - sigma * model_output) / alpha
    if kind == PredictionType.DATA:
        return model_output
    if kind == PredictionType.VELOCITY:
        return alpha * sample - sigma * model_output
    raise ValueError(
        f"prediction_type given as {self.prediction_type} must be one of PredictionType.NOISE, PredictionType.DATA or"
        f" PredictionType.VELOCITY for vdm."
    )

process_noise_prediction(model_output, sample, t)

与 process_data_prediction 类似,但将模型输出转换为噪声预测。

参数

名称 类型 描述 默认值
model_output Tensor

模型的输出。

必需
sample Tensor

输入样本。

必需
t Tensor

时间步。

必需

返回

类型 描述

根据预测类型得到的噪声预测(预测类型为“noise”时直接返回模型输出)。

抛出

类型 描述
ValueError

如果预测类型不是“noise”、“data”或“v_prediction”之一。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def process_noise_prediction(self, model_output: Tensor, sample: Tensor, t: Tensor):
    """Convert the model output into a prediction of the noise (epsilon).

    Counterpart of ``process_data_prediction``. With ``data_scale`` (alpha) and ``noise_scale``
    (sigma) taken from the noise schedule at time ``t`` and padded to ``model_output``'s rank:
    - PredictionType.NOISE: ``pred_noise = model_output``
    - PredictionType.DATA: ``pred_noise = (sample - data_scale * model_output) / noise_scale``
    - PredictionType.VELOCITY: ``pred_noise = noise_scale * sample + data_scale * model_output``

    Args:
        model_output (Tensor): The output of the model.
        sample (Tensor): The input sample.
        t (Tensor): The time step.

    Returns:
        Tensor: The noise prediction implied by ``self.prediction_type``.

    Raises:
        ValueError: If the prediction type is not one of PredictionType.NOISE,
            PredictionType.DATA or PredictionType.VELOCITY.
    """
    log_snr = self.noise_schedule.calculate_log_snr(t, device=self.device)
    data_scale, noise_scale = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr)
    data_scale = pad_like(data_scale, model_output)
    noise_scale = pad_like(noise_scale, model_output)
    if self.prediction_type == PredictionType.NOISE:
        pred_noise = model_output
    elif self.prediction_type == PredictionType.DATA:
        pred_noise = (sample - data_scale * model_output) / noise_scale
    elif self.prediction_type == PredictionType.VELOCITY:
        # Converting v -> data -> noise requires dividing by noise_scale, which becomes 0/0 as
        # noise_scale -> 0. For a variance-preserving schedule (data_scale**2 + noise_scale**2 == 1)
        # that route simplifies exactly to eps = sigma * z + alpha * v, which needs no division.
        # NOTE(review): assumes log_snr_to_alphas_sigmas returns a variance-preserving pair
        # (alpha**2 = sigmoid(log_snr), sigma**2 = sigmoid(-log_snr)); confirm.
        pred_noise = noise_scale * sample + data_scale * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be one of `noise`, `data` or"
            " `v_prediction`  for vdm."
        )
    return pred_noise

score(x_hat, xt, t)

将数据预测转换为估计的分数函数。

参数

名称 类型 描述 默认值
x_hat tensor

预测的数据点。

必需
xt Tensor

当前数据点。

必需
t Tensor

时间步。

必需

返回

类型 描述

估计的分数函数。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def score(self, x_hat: Tensor, xt: Tensor, t: Tensor):
    """Turn a clean-data prediction into an estimate of the score at time ``t``.

    Computes ``(alpha_t * x_hat - xt) / sigma_t**2``, where ``alpha_t`` and ``sigma_t`` come
    from the noise schedule at ``t`` and are padded to ``x_hat``'s rank.

    Args:
        x_hat (Tensor): The predicted data point.
        xt (Tensor): The current data point.
        t (Tensor): The time step.

    Returns:
        The estimated score function.
    """
    alpha_t, sigma_t = self.noise_schedule.log_snr_to_alphas_sigmas(
        self.noise_schedule.calculate_log_snr(t, device=self.device)
    )
    alpha_t = pad_like(alpha_t, x_hat)
    sigma_t = pad_like(sigma_t, x_hat)
    residual = alpha_t * x_hat - xt
    return residual / (sigma_t * sigma_t)

set_loss_weight_fn(fn)

将实例的 loss_weight 属性设置为给定的函数。

参数

名称 类型 描述 默认值
fn Callable

要设置为 loss_weight 属性的函数。loss 方法会以 (raw_loss, t, weight_type, dt) 四个位置参数调用它,因此该函数应接受这四个参数。

必需
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
329
330
331
332
333
334
335
def set_loss_weight_fn(self, fn: Callable):
    """Replace the loss-weighting function used by this instance.

    ``fn`` is stored on the instance as ``self.loss_weight``, shadowing the ``loss_weight``
    method. ``loss`` invokes it as ``self.loss_weight(raw_loss, t, weight_type, dt)`` (four
    positional arguments, no ``self``) and multiplies the raw loss by the result. ``fn`` must
    therefore accept ``raw_loss, t, weight_type, dt``. It must return a value that broadcasts
    against ``raw_loss``.

    Args:
        fn (Callable): The replacement weighting function.

    Raises:
        TypeError: If ``fn`` is not callable. Checking here surfaces the mistake at
            configuration time instead of inside the training loss.
    """
    if not callable(fn):
        raise TypeError(f"loss weight function must be callable, got {type(fn).__name__}")
    self.loss_weight = fn

step(model_out, t, xt, dt, mask=None, center=False, temperature=1.0)

执行一步积分。

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
xt Tensor

当前数据点。

必需
t Tensor

当前时间步。

必需
dt Tensor

时间步增量。

必需
mask Optional[Tensor]

应用于数据的可选掩码。默认为 None。

None
center bool

是否居中数据。默认为 False。

False
temperature Float

低温采样的温度参数。默认为 1.0。

1.0
注意

温度参数控制样本多样性与样本质量之间的权衡。降低温度会使采样分布更尖锐,从而集中于概率更高的样本。低温采样的实际影响需要通过消融实验加以分析。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def step(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    dt: Tensor,
    mask: Optional[Tensor] = None,
    center: Bool = False,
    temperature: Float = 1.0,
):
    """Take one ancestral sampling step from time ``t`` to ``s = t - dt``.

    The next point is drawn from a Gaussian whose parameters are computed from the clean-data
    prediction ``x_hat`` and the current point ``xt``:
        mean = (alpha_s * sigma_{t|s}^2 / sigma_t^2) * x_hat + ((alpha_t / alpha_s) * sigma_s^2 / sigma_t^2) * xt
        std  = sqrt(sigma_{t|s}^2) * sigma_s / sigma_t
    with ``sigma_{t|s}^2 = -expm1(softplus(-log_snr_s) - softplus(-log_snr_t))`` (Equation 63).
    The Gaussian noise term is scaled by ``temperature``.

    Args:
        model_out (Tensor): The output of the model.
        t (Tensor): The current time step.
        xt (Tensor): The current data point.
        dt (Tensor): The time step increment.
        mask (Optional[Tensor], optional): Mask applied to the model output and the result. Defaults to None.
        center (bool): Whether to center the data. Defaults to False.
        temperature (Float): Scale of the injected noise (low temperature sampling). Defaults to 1.0.

    Raises:
        ValueError: If any entry of ``t - dt`` is negative.

    Note:
        The temperature parameter controls the trade off between diversity and sample quality.
        Decreasing the temperature sharpens the sampling distribution to focus on more likely samples.
        The impact of low temperature sampling should be assessed with ablations.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x_hat = self.process_data_prediction(model_out, xt, t)

    log_snr_t = self.noise_schedule.calculate_log_snr(t, device=self.device)
    alpha_t, sigma_t = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr_t)

    s = t - dt
    if (s < 0).any():
        raise ValueError(
            "Error in inference schedule: t - dt < 0. Please ensure that your inference time schedule has shape T with the final t = dt to make s = 0"
        )

    log_snr_s = self.noise_schedule.calculate_log_snr(s, device=self.device)
    alpha_s, sigma_s = self.noise_schedule.log_snr_to_alphas_sigmas(log_snr_s)

    var_s = sigma_s * sigma_s
    var_t = sigma_t * sigma_t
    # Equation 63: sigma_{t|s}^2 via expm1/softplus.
    var_t_given_s = -torch.expm1(F.softplus(-log_snr_s) - F.softplus(-log_snr_t))

    # Equation 28: posterior coefficients for xt and x_hat, and the posterior std.
    coef_xt = (alpha_t / alpha_s) * var_s / var_t
    coef_xhat = alpha_s * var_t_given_s / var_t
    noise_std = var_t_given_s.sqrt() * sigma_s / sigma_t
    # Zeroes the noise wherever t == 0.
    # NOTE(review): the original comment says this is always ones for the schedules used; confirm.
    keep_noise = (t > 0).float()

    coef_xhat, coef_xt, noise_std, keep_noise = (
        pad_like(value, x_hat) for value in (coef_xhat, coef_xt, noise_std, keep_noise)
    )

    mean = coef_xhat * x_hat + coef_xt * xt
    noise = torch.randn_like(mean).to(model_out.device)
    x_next = mean + keep_noise * noise_std * noise * temperature
    return self.clean_mask_center(x_next, mask, center)

step_ddim(model_out, t, xt, dt, mask=None, eta=0.0, center=False)

执行一步 DDIM 采样。

依据 DDPM 方程:$\bar{\alpha} = \alpha^2$,且 $1 - \alpha^2 = \sigma^2$。

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出

必需
t Tensor

当前时间步

必需
xt Tensor

当前数据点

必需
dt Tensor

时间步增量。

必需
mask Optional[Tensor]

数据点的掩码。默认为 None。

None
eta Float

DDIM 采样参数。默认为 0.0。

0.0
center Bool

是否居中数据点。默认为 False。

False
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def step_ddim(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    dt: Tensor,
    mask: Optional[Tensor] = None,
    eta: Float = 0.0,
    center: Bool = False,
):
    """Take one DDIM step from time ``t`` to ``t - dt``.

    Uses the DDPM correspondences alpha_bar = alpha**2 and 1 - alpha**2 = sigma**2, with
    ``alpha_bar = sigmoid(log_snr)`` and ``1 - alpha_bar = sigmoid(-log_snr)``. The update is
        x_next = sqrt(alpha_bar_s) * x0_hat
                 + sqrt(1 - alpha_bar_s - eta**2 * var) * eps_hat
                 + eta * sqrt(var) * z
    where ``var = (1 - alpha_bar_s) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_s)`` and
    ``z`` is fresh standard Gaussian noise. ``eta = 0`` gives deterministic DDIM.

    Args:
        model_out (Tensor): output of the model
        t (Tensor): current time step
        xt (Tensor): current data point
        dt (Tensor): The time step increment.
        mask (Optional[Tensor], optional): mask for the data point. Defaults to None.
        eta (Float, optional): DDIM stochasticity parameter. Defaults to 0.0.
        center (Bool, optional): whether to center the data point. Defaults to False.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x0_hat = self.process_data_prediction(model_out, xt, t)
    eps_hat = self.process_noise_prediction(model_out, xt, t)
    z = torch.randn_like(x0_hat).to(model_out.device)

    log_snr_t = self.noise_schedule.calculate_log_snr(t, device=self.device)
    alpha_bar_t = torch.sigmoid(log_snr_t)
    one_minus_alpha_bar_t = torch.sigmoid(-log_snr_t)
    log_snr_s = self.noise_schedule.calculate_log_snr(t - dt, device=self.device)
    alpha_bar_s = torch.sigmoid(log_snr_s)
    one_minus_alpha_bar_s = torch.sigmoid(-log_snr_s)

    # DDIM variance before scaling by eta**2.
    ddim_var = one_minus_alpha_bar_s / one_minus_alpha_bar_t * (1 - alpha_bar_t / alpha_bar_s)
    x0_coef = alpha_bar_s.sqrt()
    eps_coef = (1 - alpha_bar_s - eta * eta * ddim_var).sqrt()

    ddim_var, x0_coef, eps_coef = (pad_like(value, model_out) for value in (ddim_var, x0_coef, eps_coef))

    x_next = x0_hat * x0_coef + eps_coef * eps_hat + eta * ddim_var.sqrt() * z
    return self.clean_mask_center(x_next, mask, center)

step_hybrid_sde(model_out, t, xt, dt, mask=None, center=False, temperature=1.0, equilibrium_rate=0.0)

执行一步混合 Langevin-逆时 SDE 积分。

参见 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf 第 37 页 B.3 节,以及 https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
xt Tensor

当前数据点。

必需
t Tensor

当前时间步。

必需
dt Tensor

时间步增量。

必需
mask Optional[Tensor]

应用于数据的可选掩码。默认为 None。

None
center bool

是否居中数据。默认为 False。

False
temperature Float

低温采样的温度参数。默认为 1.0。

1.0
equilibrium_rate Float

Langevin 平衡的速率。缩放单位时间内 Langevin 动力学的量。最佳值范围为 [1.0, 5.0]。默认为 0.0。

0.0

注意:对于所有使用 SDE 公式的步进函数,需要注意我们是在时间上反向推进的,这在公式中表现为一个表观上的符号变化。幻灯片 29(https://ernestryu.com/courses/FM/diffusion1.pdf)中给出了一个清晰的例子。

源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def step_hybrid_sde(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    dt: Tensor,
    mask: Optional[Tensor] = None,
    center: Bool = False,
    temperature: Float = 1.0,
    equilibrium_rate: Float = 0.0,
) -> Tensor:
    """Take one Euler-Maruyama step of the hybrid Langevin / reverse-time SDE.

    See section B.3 page 37 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf
    and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

    With ``lambda_0 = 1 / temperature`` and ``psi = equilibrium_rate``:
        lambda_t    = lambda_0 * (sigma^2 + alpha^2) / (lambda_0 * sigma^2 + alpha^2)
        score_scale = lambda_t + lambda_0 * psi / 2
        drift       = beta * (-xt / 2 - score_scale * score)
        x_next      = xt + dt * drift + sqrt(dt) * sqrt((1 + psi) * |beta|) * z

    Args:
        model_out (Tensor): The output of the model.
        t (Tensor): The current time step.
        xt (Tensor): The current data point.
        dt (Tensor): The time step increment.
        mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
        center (bool, optional): Whether to center the data. Defaults to False.
        temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.
        equilibrium_rate (Float, optional): The rate of Langevin equilibration. Scales the amount of Langevin dynamics per unit time. Best values are in the range [1.0, 5.0]. Defaults to 0.0.

    Note:
    For all step functions that use the SDE formulation it is important to note that we are moving backwards in time, which corresponds to an apparent sign change.
    A clear example can be seen in slide 29 https://ernestryu.com/courses/FM/diffusion1.pdf.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x_hat = self.process_data_prediction(model_out, xt, t)
    alpha, sigma = self.noise_schedule.log_snr_to_alphas_sigmas(
        self.noise_schedule.calculate_log_snr(t, device=self.device)
    )
    beta = self.noise_schedule.calculate_beta(t)

    lambda_0 = 1 / temperature  # inverse temperature
    langevin_rate = equilibrium_rate
    var_signal = alpha.pow(2)
    var_noise = sigma.pow(2)
    lambda_t = lambda_0 * (var_noise + var_signal) / (lambda_0 * var_noise + var_signal)
    # Isothermal Langevin: the Langevin contribution is scaled by lambda_0, not lambda_t.
    score_scale = lambda_t + lambda_0 * langevin_rate / 2.0

    z = torch.randn_like(x_hat).to(model_out.device)
    score = self.score(x_hat, xt, t)
    beta = pad_like(beta, model_out)
    score_scale = pad_like(score_scale, model_out)

    drift = beta * ((-1 / 2) * xt - score_scale * score)
    diffusion = torch.sqrt((1.0 + langevin_rate) * beta.abs()) * z
    x_next = xt + dt * drift + dt.sqrt() * diffusion
    return self.clean_mask_center(x_next, mask, center)

step_ode(model_out, t, xt, dt, mask=None, center=False, temperature=1.0)

执行一步 ODE 积分。

参见 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf 第 36 页 B 节,以及 https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
xt Tensor

当前数据点。

必需
t Tensor

当前时间步。

必需
dt Tensor

时间步增量。

必需
mask Optional[Tensor]

应用于数据的可选掩码。默认为 None。

None
center bool

是否居中数据。默认为 False。

False
temperature Float

低温采样的温度参数。默认为 1.0。

1.0
源代码位于 bionemo/moco/interpolants/continuous_time/continuous/vdm.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
def step_ode(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    dt: Tensor,
    mask: Optional[Tensor] = None,
    center: Bool = False,
    temperature: Float = 1.0,
) -> Tensor:
    """Take one Euler step of the temperature-scaled probability-flow ODE.

    See section B page 36 https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf
    and https://github.com/generatebio/chroma/blob/929407c605013613941803c6113adefdccaad679/chroma/layers/structure/diffusion.py#L730

    With ``lambda_0 = 1 / temperature``:
        lambda_t = lambda_0 * (sigma^2 + alpha^2) / (lambda_0 * sigma^2 + alpha^2)
        x_next   = xt + (-1/2) * beta * (xt + lambda_t * score) * dt

    Args:
        model_out (Tensor): The output of the model.
        t (Tensor): The current time step.
        xt (Tensor): The current data point.
        dt (Tensor): The time step increment.
        mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
        center (bool, optional): Whether to center the data. Defaults to False.
        temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x_hat = self.process_data_prediction(model_out, xt, t)

    schedule = self.noise_schedule
    alpha, sigma = schedule.log_snr_to_alphas_sigmas(schedule.calculate_log_snr(t, device=self.device))
    beta = schedule.calculate_beta(t)

    lambda_0 = 1 / temperature
    var_signal = alpha.pow(2)
    var_noise = sigma.pow(2)
    lambda_t = lambda_0 * (var_noise + var_signal) / (lambda_0 * var_noise + var_signal)

    score = self.score(x_hat, xt, t)
    beta = pad_like(beta, model_out)
    lambda_t = pad_like(lambda_t, model_out)

    velocity = (-1 / 2) * beta * (xt + lambda_t * score)
    x_next = xt + velocity * dt
    return self.clean_mask_center(x_next, mask, center)