跳到内容

Ddpm

DDPM

基类:Interpolant

降噪扩散概率模型 (DDPM) 插值器。


示例

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
>>> from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
>>> from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule


ddpm = DDPM(
    time_distribution = UniformTimeDistribution(discrete_time = True,...),
    prior_distribution = GaussianPrior(...),
    noise_schedule = DiscreteCosineNoiseSchedule(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = ddpm.sample_time(batch_size)
    noise = ddpm.sample_prior(data.shape)
    xt = ddpm.interpolate(data, time, noise)

    x_pred = model(xt, time)
    loss = ddpm.loss(x_pred, data, time)
    loss.backward()

# Generation
x_pred = ddpm.sample_prior(data.shape)
for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
    time = torch.full((batch_size,), t)
    x_hat = model(x_pred, time)
    x_pred = ddpm.step(x_hat, time, x_pred)
return x_pred

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
class DDPM(Interpolant):
    """A Denoising Diffusion Probabilistic Model (DDPM) interpolant.

     -------

    Examples:
    ```python
    >>> import torch
    >>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
    >>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
    >>> from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
    >>> from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
    >>> from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule


    ddpm = DDPM(
        time_distribution = UniformTimeDistribution(discrete_time = True,...),
        prior_distribution = GaussianPrior(...),
        noise_schedule = DiscreteCosineNoiseSchedule(...),
        )
    model = Model(...)

    # Training
    for epoch in range(1000):
        data = data_loader.get(...)
        time = ddpm.sample_time(batch_size)
        noise = ddpm.sample_prior(data.shape)
        xt = ddpm.interpolate(data, time, noise)

        x_pred = model(xt, time)
        loss = ddpm.loss(x_pred, data, time)
        loss.backward()

    # Generation
    x_pred = ddpm.sample_prior(data.shape)
    for t in DiscreteLinearInferenceSchedule(...).generate_schedule():
        time = torch.full((batch_size,), t)
        x_hat = model(x_pred, time)
        x_pred = ddpm.step(x_hat, time, x_pred)
    return x_pred

    ```
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: PriorDistribution,
        noise_schedule: DiscreteNoiseSchedule,
        prediction_type: Union[PredictionType, str] = PredictionType.DATA,
        device: Union[str, torch.device] = "cpu",
        last_time_idx: int = 0,
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Create the DDPM interpolant and precompute its diffusion schedules.

        Args:
            time_distribution (TimeDistribution): Distribution used to sample time steps for the diffusion process.
            prior_distribution (PriorDistribution): Prior used as the starting point of the diffusion process.
                A warning is emitted when it is not a ``GaussianPrior``.
            noise_schedule (DiscreteNoiseSchedule): Schedule whose ``generate_schedule`` output is used as the
                per-step alphas.
            prediction_type (Union[PredictionType, str]): What the model output represents. Strings are converted
                with ``string_to_enum``. Defaults to ``PredictionType.DATA``.
            device (Union[str, torch.device]): Device forwarded to the parent class and used when generating the
                schedules. Defaults to "cpu".
            last_time_idx (int, optional): The last time index for discrete time. Set to 0 if discrete time is
                T-1, ..., 0 or 1 if T, ..., 1. Defaults to 0.
            rng_generator (Optional[torch.Generator]): Optional generator forwarded to the parent class.
                Defaults to None.
        """
        super().__init__(time_distribution, prior_distribution, device, rng_generator)
        prior_is_gaussian = isinstance(prior_distribution, GaussianPrior)
        if not prior_is_gaussian:
            warnings.warn("Prior distribution is not a GaussianPrior, unexpected behavior may occur")
        # Plain attribute assignments first; they do not depend on the schedules.
        self.last_time_idx = last_time_idx
        self._loss_function = nn.MSELoss(reduction="none")
        self.noise_schedule = noise_schedule
        self._initialize_schedules(device)
        self.prediction_type = string_to_enum(prediction_type, PredictionType)
    def _initialize_schedules(self, device: Union[str, torch.device] = "cpu"):
        """Sets up the Denoising Diffusion Probabilistic Model (DDPM) equations.

        This method initializes the schedules for the forward and reverse processes of the DDPM. It calculates the
        alphas, betas, and log variances required for the diffusion process.

        Specifically, it computes:

        * `alpha_bar`: the cumulative product of `alpha_t`
        * `alpha_bar_prev`: the previous cumulative product of `alpha_t`
        * `posterior_variance`: the variance of the posterior distribution
        * `posterior_mean_c0_coef` and `posterior_mean_ct_coef`: the coefficients for the posterior mean
        * `log_var`: the log variance of the posterior distribution

        These values are then used to set up the forward and reverse schedules for the DDPM.
        Specifically this is equation (6) (7) from https://arxiv.org/pdf/2006.11239
        """
        if self.noise_schedule is None:
            raise ValueError("noise_schedule cannot be None for DDPM")
        alphas = self.noise_schedule.generate_schedule(device=device)
        betas = 1 - alphas
        log_alpha = torch.log(alphas)
        log_alpha_bar = torch.cumsum(log_alpha, dim=0)
        alpha_bar = alphas_cumprod = torch.exp(log_alpha_bar)
        alpha_bar_prev = alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        posterior_variance = betas * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
        posterior_mean_c0_coef = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alpha_bar)
        posterior_mean_ct_coef = (1.0 - alpha_bar_prev) * torch.sqrt(alphas) / (1.0 - alpha_bar)
        # log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        posterior_logvar = torch.log(
            torch.nn.functional.pad(posterior_variance[:-1], (1, 0), value=posterior_variance[0].item())
        )
        self._forward_data_schedule = torch.sqrt(alpha_bar)
        self._forward_noise_schedule = torch.sqrt(1 - alpha_bar)
        self._reverse_data_schedule = posterior_mean_c0_coef
        self._reverse_noise_schedule = posterior_mean_ct_coef
        self._log_var = posterior_logvar
        self._alpha_bar = alpha_bar
        self._alpha_bar_prev = alpha_bar_prev
        self._betas = betas
        self._posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    @property
    def forward_data_schedule(self) -> torch.Tensor:
        """Returns the forward data schedule.

        This is the tensor stored by `_initialize_schedules` as `sqrt(alpha_bar)`, i.e. the coefficient
        applied to the data in `interpolate`.
        """
        return self._forward_data_schedule

    @property
    def forward_noise_schedule(self) -> torch.Tensor:
        """Returns the forward noise schedule.

        This is the tensor stored by `_initialize_schedules` as `sqrt(1 - alpha_bar)`, i.e. the coefficient
        applied to the noise in `interpolate`.
        """
        return self._forward_noise_schedule

    @property
    def reverse_data_schedule(self) -> torch.Tensor:
        """Returns the reverse data schedule.

        This is the posterior-mean coefficient `betas * sqrt(alpha_bar_prev) / (1 - alpha_bar)` computed in
        `_initialize_schedules`; `step` multiplies it with the data prediction.
        """
        return self._reverse_data_schedule

    @property
    def reverse_noise_schedule(self) -> torch.Tensor:
        """Returns the reverse noise schedule.

        This is the posterior-mean coefficient `(1 - alpha_bar_prev) * sqrt(alphas) / (1 - alpha_bar)` computed
        in `_initialize_schedules`; `step` multiplies it with the current sample `xt`.
        """
        return self._reverse_noise_schedule

    @property
    def log_var(self) -> torch.Tensor:
        """Returns the log variance.

        This is the log of the padded posterior variance computed in `_initialize_schedules`; `step` uses
        `exp(0.5 * log_var)` as the scale of the injected noise.
        """
        return self._log_var

    @property
    def alpha_bar(self) -> torch.Tensor:
        """Returns the alpha bar values, i.e. the cumulative product of the per-step alphas."""
        return self._alpha_bar

    @property
    def alpha_bar_prev(self) -> torch.Tensor:
        """Returns the previous alpha bar values.

        This is `alpha_bar` shifted right by one step, with 1.0 padded into the first position.
        """
        return self._alpha_bar_prev

    def interpolate(self, data: Tensor, t: Tensor, noise: Tensor):
        """Build the noised sample x(t) from clean data and prior noise.

        The result is ``data * forward_data_schedule[t] + noise * forward_noise_schedule[t]``, where both
        schedules are looked up at ``t - last_time_idx`` and padded against ``data`` with ``pad_like``.

        Args:
            data (Tensor): target
            t (Tensor): time
            noise (Tensor): noise from prior()

        Returns:
            Tensor: The interpolated sample x(t).
        """
        step_idx = t - self.last_time_idx
        data_coef = pad_like(safe_index(self._forward_data_schedule, step_idx, data.device), data)
        noise_coef = pad_like(safe_index(self._forward_noise_schedule, step_idx, data.device), data)
        return data * data_coef + noise * noise_coef

    def forward_process(self, data: Tensor, t: Tensor, noise: Optional[Tensor] = None):
        """Get x(t) with given time t from noise and data.

        Args:
            data (Tensor): target
            t (Tensor): time
            noise (Tensor, optional): noise from prior(). Defaults to None.
        """
        if noise is None:
            noise = self.sample_prior(data.shape)
        return self.interpolate(data, t, noise)

    def process_data_prediction(self, model_output: Tensor, sample: Tensor, t: Tensor):
        """Turn the raw model output into a prediction of the clean data.

        The conversions follow Progressive Distillation for Fast Sampling of Diffusion Models
        (https://arxiv.org/pdf/2202.00512). With ``data_scale`` and ``noise_scale`` taken from the forward
        data and noise schedules at ``t - last_time_idx``:

        - ``PredictionType.NOISE``: ``pred_data = (sample - noise_scale * model_output) / data_scale``
        - ``PredictionType.DATA``: ``pred_data = model_output``
        - ``PredictionType.VELOCITY``: ``pred_data = data_scale * sample - noise_scale * model_output``

        Args:
            model_output (Tensor): The output of the model.
            sample (Tensor): The input sample.
            t (Tensor): The time step.

        Returns:
            The data prediction based on the prediction type.

        Raises:
            ValueError: If the prediction type is not NOISE, DATA or VELOCITY.
        """
        step_idx = t - self.last_time_idx
        device = model_output.device
        # The scales are looked up before dispatching, exactly as the schedules are needed by two branches.
        data_scale = pad_like(safe_index(self._forward_data_schedule, step_idx, device), model_output)
        noise_scale = pad_like(safe_index(self._forward_noise_schedule, step_idx, device), model_output)

        kind = self.prediction_type
        if kind == PredictionType.DATA:
            return model_output
        if kind == PredictionType.NOISE:
            return (sample - noise_scale * model_output) / data_scale
        if kind == PredictionType.VELOCITY:
            return data_scale * sample - noise_scale * model_output
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be one of PredictionType.NOISE, PredictionType.DATA or"
            f" PredictionType.VELOCITY for DDPM."
        )

    def process_noise_prediction(self, model_output, sample, t):
        """Turn the raw model output into a prediction of the noise.

        This mirrors ``process_data_prediction`` but targets the noise. With ``data_scale`` and
        ``noise_scale`` taken from the forward data and noise schedules at ``t - last_time_idx``:

        - ``PredictionType.NOISE``: ``pred_noise = model_output``
        - ``PredictionType.DATA``: ``pred_noise = (sample - data_scale * model_output) / noise_scale``
        - ``PredictionType.VELOCITY``: the data is first recovered as
          ``data_scale * sample - noise_scale * model_output`` and then converted as in the DATA case.

        Args:
            model_output: The output of the model.
            sample: The input sample.
            t: The time step.

        Returns:
            The noise prediction based on the prediction type.

        Raises:
            ValueError: If the prediction type is not NOISE, DATA or VELOCITY.
        """
        step_idx = t - self.last_time_idx
        device = model_output.device
        data_scale = pad_like(safe_index(self._forward_data_schedule, step_idx, device), model_output)
        noise_scale = pad_like(safe_index(self._forward_noise_schedule, step_idx, device), model_output)

        kind = self.prediction_type
        if kind == PredictionType.NOISE:
            return model_output
        if kind == PredictionType.DATA:
            return (sample - data_scale * model_output) / noise_scale
        if kind == PredictionType.VELOCITY:
            recovered_data = data_scale * sample - noise_scale * model_output
            return (sample - data_scale * recovered_data) / noise_scale
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be one of `noise`, `data` or"
            " `v_prediction`  for DDPM."
        )

    def calculate_velocity(self, data: Tensor, t: Tensor, noise: Tensor) -> Tensor:
        """Compute the velocity target ``data_scale * noise - noise_scale * data``.

        ``data_scale`` and ``noise_scale`` are the forward data and noise schedules looked up at
        ``t - last_time_idx`` and padded against ``data`` with ``pad_like``.

        Args:
            data (Tensor): The input data.
            t (Tensor): The current time step.
            noise (Tensor): The noise term.

        Returns:
            Tensor: The calculated velocity term.
        """
        step_idx = t - self.last_time_idx
        data_scale = pad_like(safe_index(self._forward_data_schedule, step_idx, data.device), data)
        noise_scale = pad_like(safe_index(self._forward_noise_schedule, step_idx, data.device), data)
        return data_scale * noise - noise_scale * data

    @torch.no_grad()
    def step(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        mask: Optional[Tensor] = None,
        center: Bool = False,
        temperature: Float = 1.0,
    ):
        """Take one reverse-diffusion step from ``xt`` using the posterior mean and variance.

        The model output is converted to a data prediction, combined with ``xt`` through the reverse data
        and noise schedules to form the posterior mean, and Gaussian noise scaled by
        ``exp(0.5 * log_var) * temperature`` is added for every entry with ``t > last_time_idx``.

        Args:
            model_out (Tensor): The output of the model.
            t (Tensor): The current time step.
            xt (Tensor): The current data point.
            mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
            center (bool, optional): Whether to center the data. Defaults to False.
            temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.

        Returns:
            The next sample, after ``clean_mask_center`` has been applied.

        Note:
            The temperature parameter controls the level of randomness in the sampling process. A temperature
            of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result
            in less random and more deterministic samples.

            For discrete time we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled as well; the
            non-zero mask removes the noise term at that last step.
        """
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        x_hat = self.process_data_prediction(model_out, xt, t)

        step_idx = t - self.last_time_idx
        device = x_hat.device
        data_coef = pad_like(safe_index(self._reverse_data_schedule, step_idx, device), x_hat)
        state_coef = pad_like(safe_index(self._reverse_noise_schedule, step_idx, device), x_hat)
        log_var = pad_like(safe_index(self._log_var, step_idx, device), x_hat)
        # 1.0 where noise should still be injected, 0.0 at the final time index.
        noise_gate = pad_like((t > self.last_time_idx).float(), x_hat)

        posterior_mean = data_coef * x_hat + state_coef * xt
        gaussian = torch.randn_like(posterior_mean).to(model_out.device)
        posterior_std = (0.5 * log_var).exp()

        x_next = posterior_mean + noise_gate * posterior_std * gaussian * temperature
        return self.clean_mask_center(x_next, mask, center)

    def step_noise(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        mask: Optional[Tensor] = None,
        center: Bool = False,
        temperature: Float = 1.0,
    ):
        """Take one reverse-diffusion step from ``xt`` using the noise-prediction form of the update.

        The model output is converted to a noise prediction ``eps_hat`` and the next sample is

            ``x_next = (xt - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(1 - beta_t)
                       + nonzero_mask * sqrt(posterior_variance_t) * z * temperature``

        with ``z`` standard Gaussian noise and all schedule values looked up at ``t - last_time_idx``.

        Args:
            model_out (Tensor): The output of the model.
            t (Tensor): The current time step.
            xt (Tensor): The current data point.
            mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
            center (bool, optional): Whether to center the data. Defaults to False.
            temperature (Float, optional): The temperature parameter for low temperature sampling. It scales
                the injected noise, matching ``step``. Defaults to 1.0.

        Returns:
            The next sample, after ``clean_mask_center`` has been applied.

        Note:
            The temperature parameter controls the level of randomness in the sampling process. A temperature
            of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result
            in less random and more deterministic samples.

            For discrete time we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled as well; the
            non-zero mask removes the noise term at that last step.
        """
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        eps_hat = self.process_noise_prediction(model_out, xt, t)

        step_idx = t - self.last_time_idx
        device = model_out.device
        # Look each schedule up once; beta_t feeds both the 1/sqrt(alpha_t) factor and the eps factor.
        beta_t = safe_index(self._betas, step_idx, device)
        alpha_bar_t = safe_index(self._alpha_bar, step_idx, device)
        var = safe_index(self._posterior_variance, step_idx, device)

        recip_sqrt_alpha_t = pad_like(torch.sqrt(1 / (1 - beta_t)), xt)
        eps_factor = pad_like(beta_t / (1 - alpha_bar_t).sqrt(), xt)
        var = pad_like(var, xt)
        # 1.0 where noise should still be injected, 0.0 at the final time index.
        nonzero_mask = pad_like((t > self.last_time_idx).float(), model_out)

        mean = recip_sqrt_alpha_t * (xt - eps_factor * eps_hat)
        noise = torch.randn_like(eps_hat).to(model_out.device)
        # The temperature scales only the stochastic term; 1.0 reproduces the unscaled update.
        x_next = mean + nonzero_mask * var.sqrt() * noise * temperature
        x_next = self.clean_mask_center(x_next, mask, center)
        return x_next

    def score(self, x_hat: Tensor, xt: Tensor, t: Tensor):
        """Convert a data prediction into the estimated score ``(alpha * x_hat - xt) / beta**2``.

        ``alpha`` and ``beta`` are the forward data and forward noise schedules looked up at
        ``t - last_time_idx`` and padded against ``x_hat`` with ``pad_like``.

        Args:
            x_hat (Tensor): The predicted data point.
            xt (Tensor): The current data point.
            t (Tensor): The time step.

        Returns:
            The estimated score function.
        """
        step_idx = t - self.last_time_idx
        data_scale = pad_like(safe_index(self._forward_data_schedule, step_idx, x_hat.device), x_hat)
        noise_scale = pad_like(safe_index(self._forward_noise_schedule, step_idx, x_hat.device), x_hat)
        return (data_scale * x_hat - xt) / (noise_scale * noise_scale)

    def step_ddim(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        mask: Optional[Tensor] = None,
        eta: Float = 0.0,
        center: Bool = False,
    ):
        """Take one DDIM sampling step.

        The model output is converted to both a data and a noise prediction. The next sample is
        ``sqrt(alpha_bar_prev) * data_pred + sqrt(1 - alpha_bar_prev - sigma**2) * noise_pred + sigma * z``
        where ``sigma = eta * sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * sqrt(1 - alpha_bar / alpha_bar_prev)``
        and ``z`` is standard Gaussian noise. With ``eta = 0`` the ``sigma`` term vanishes.

        Args:
            model_out (Tensor): output of the model
            t (Tensor): current time step
            xt (Tensor): current data point
            mask (Optional[Tensor], optional): mask for the data point. Defaults to None.
            eta (Float, optional): DDIM sampling parameter. Defaults to 0.0.
            center (Bool, optional): whether to center the data point. Defaults to False.

        Returns:
            The next sample, after ``clean_mask_center`` has been applied.
        """
        if mask is not None:
            model_out = model_out * mask.unsqueeze(-1)
        data_pred = self.process_data_prediction(model_out, xt, t)
        noise_pred = self.process_noise_prediction(model_out, xt, t)
        gaussian = torch.randn_like(data_pred).to(model_out.device)

        step_idx = t - self.last_time_idx
        device = model_out.device
        ab = self._alpha_bar
        ab_prev = self._alpha_bar_prev
        # Full-length sigma schedule; the per-sample value is gathered below.
        sigma = eta * torch.sqrt((1 - ab_prev) / (1 - ab)) * torch.sqrt(1 - ab / ab_prev)

        sigma_t = pad_like(safe_index(sigma, step_idx, device), model_out)
        data_coef = pad_like(safe_index(torch.sqrt(ab_prev), step_idx, device), model_out)
        noise_coef = pad_like(safe_index(torch.sqrt(1 - ab_prev - sigma**2), step_idx, device), model_out)

        x_next = data_pred * data_coef + noise_coef * noise_pred + sigma_t * gaussian
        return self.clean_mask_center(x_next, mask, center)

    def set_loss_weight_fn(self, fn):
        """Sets the loss_weight attribute of the instance to the given function.

        The assignment is made on the instance, so it shadows the class-level `loss_weight` method for this
        object only. `loss` calls it as `self.loss_weight(raw_loss, t, weight_type)`; because `fn` is stored
        as a plain instance attribute it is invoked with exactly those three arguments.

        Args:
            fn: The function to set as the loss_weight attribute. This function should take three arguments: raw_loss, t, and weight_type.
        """
        self.loss_weight = fn

    def loss_weight(self, raw_loss: Tensor, t: Optional[Tensor], weight_type: str) -> Tensor:
        """Calculates the weight for the loss based on the given weight type.

        These data_to_noise loss weights is derived in Equation (9) of https://arxiv.org/pdf/2202.00512.

        Args:
            raw_loss (Tensor): The raw loss calculated from the model prediction and target.
            t (Tensor): The time step.
            weight_type (str): The type of weight to use. Can be "ones" or "data_to_noise" or "noise_to_data".

        Returns:
            Tensor: The weight for the loss.

        Raises:
            ValueError: If the weight type is not recognized.
        """
        if weight_type == "ones":
            schedule = torch.ones_like(raw_loss).to(raw_loss.device)
        elif weight_type == "data_to_noise":
            if t is None:
                raise ValueError("Time cannot be None when using the data_to_noise loss weight")
            schedule = (safe_index(self._forward_data_schedule, t - self.last_time_idx, raw_loss.device) ** 2) / (
                safe_index(self._forward_noise_schedule, t - self.last_time_idx, raw_loss.device) ** 2
            )
            schedule = pad_like(schedule, raw_loss)
        elif weight_type == "noise_to_data":
            if t is None:
                raise ValueError("Time cannot be None when using the data_to_noise loss weight")
            schedule = (safe_index(self._forward_noise_schedule, t - self.last_time_idx, raw_loss.device) ** 2) / (
                safe_index(self._forward_data_schedule, t - self.last_time_idx, raw_loss.device) ** 2
            )
            schedule = pad_like(schedule, raw_loss)
        else:
            raise ValueError("Invalid loss weight keyword")
        return schedule

    def loss(
        self,
        model_pred: Tensor,
        target: Tensor,
        t: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        weight_type: str = "ones",
    ):
        """Calculate the loss given the model prediction, data sample, and time.

        Args:
            model_pred (Tensor): The predicted output from the model.
            target (Tensor): The target output for the model prediction.
            t (Tensor): The time at which the loss is calculated.
            mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
            weight_type (str, optional): The type of weight to use for the loss. Defaults to "ones".

        Returns:
            Tensor: The calculated loss batch tensor.
        """
        raw_loss = self._loss_function(model_pred, target)
        update_weight = self.loss_weight(raw_loss, t, weight_type)
        loss = raw_loss * update_weight
        if mask is not None:
            loss = loss * mask.unsqueeze(-1)
            n_elem = torch.sum(mask, dim=-1)
            loss = torch.sum(loss, dim=tuple(range(1, raw_loss.ndim))) / n_elem
        else:
            loss = torch.sum(loss, dim=tuple(range(1, raw_loss.ndim))) / model_pred.size(1)
        return loss

alpha_bar: torch.Tensor property

返回 alpha bar 值。

alpha_bar_prev: torch.Tensor property

返回之前的 alpha bar 值。

forward_data_schedule: torch.Tensor property

返回前向数据计划。

forward_noise_schedule: torch.Tensor property

返回前向噪声计划。

log_var: torch.Tensor property

返回对数方差。

reverse_data_schedule: torch.Tensor property

返回反向数据计划。

reverse_noise_schedule: torch.Tensor property

返回反向噪声计划。

__init__(time_distribution, prior_distribution, noise_schedule, prediction_type=PredictionType.DATA, device='cpu', last_time_idx=0, rng_generator=None)

初始化 DDPM 插值器。

参数

名称 类型 描述 默认值
time_distribution TimeDistribution

时间步长的分布,用于为扩散过程采样时间点。

必需
prior_distribution PriorDistribution

变量的先验分布,用作扩散过程的起点。

必需
noise_schedule DiscreteNoiseSchedule

噪声计划,定义在每个时间步长添加的噪声量。

必需
prediction_type PredictionType

预测类型,可以是“data”或另一种类型。默认为“data”。

DATA
device str

运行插值器的设备,可以是“cpu”或 CUDA 设备(例如“cuda:0”)。默认为“cpu”。

'cpu'
last_time_idx int

离散时间的最后一个时间索引。如果离散时间为 T-1, ..., 0,则设置为 0;如果为 T, ..., 1,则设置为 1。默认为 0。

0
rng_generator Optional[Generator]

可选的 :class:torch.Generator,用于可重复的采样。默认为 None。

None
源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: PriorDistribution,
    noise_schedule: DiscreteNoiseSchedule,
    prediction_type: Union[PredictionType, str] = PredictionType.DATA,
    device: Union[str, torch.device] = "cpu",
    last_time_idx: int = 0,
    rng_generator: Optional[torch.Generator] = None,
):
    """Create the DDPM interpolant and precompute its diffusion schedules.

    Args:
        time_distribution (TimeDistribution): Distribution used to sample time steps for the diffusion process.
        prior_distribution (PriorDistribution): Prior used as the starting point of the diffusion process.
            A warning is emitted when it is not a ``GaussianPrior``.
        noise_schedule (DiscreteNoiseSchedule): The schedule of noise, defining the amount of noise added at
            each time step.
        prediction_type (Union[PredictionType, str]): What the model output represents. Strings are converted
            with ``string_to_enum``. Defaults to ``PredictionType.DATA``.
        device (Union[str, torch.device]): Device forwarded to the parent class and to
            ``_initialize_schedules``. Defaults to "cpu".
        last_time_idx (int, optional): The last time index for discrete time. Set to 0 if discrete time is
            T-1, ..., 0 or 1 if T, ..., 1. Defaults to 0.
        rng_generator (Optional[torch.Generator]): Optional generator forwarded to the parent class.
            Defaults to None.
    """
    super().__init__(time_distribution, prior_distribution, device, rng_generator)
    prior_is_gaussian = isinstance(prior_distribution, GaussianPrior)
    if not prior_is_gaussian:
        warnings.warn("Prior distribution is not a GaussianPrior, unexpected behavior may occur")
    # Plain attribute assignments first; they do not depend on the schedules.
    self.last_time_idx = last_time_idx
    self._loss_function = nn.MSELoss(reduction="none")
    self.noise_schedule = noise_schedule
    self._initialize_schedules(device)
    self.prediction_type = string_to_enum(prediction_type, PredictionType)

calculate_velocity(data, t, noise)

计算给定数据、时间步长和噪声的速度项。

参数

名称 类型 描述 默认值
data Tensor

输入数据。

必需
t Tensor

当前时间步长。

必需
noise Tensor

噪声项。

必需

返回

名称 类型 描述
Tensor Tensor

计算出的速度项。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def calculate_velocity(self, data: Tensor, t: Tensor, noise: Tensor) -> Tensor:
    """Compute the velocity target ``data_scale * noise - noise_scale * data``.

    ``data_scale`` and ``noise_scale`` are the forward data and noise schedules looked up at
    ``t - last_time_idx`` and padded against ``data`` with ``pad_like``.

    Args:
        data (Tensor): The input data.
        t (Tensor): The current time step.
        noise (Tensor): The noise term.

    Returns:
        Tensor: The calculated velocity term.
    """
    step_idx = t - self.last_time_idx
    data_scale = pad_like(safe_index(self._forward_data_schedule, step_idx, data.device), data)
    noise_scale = pad_like(safe_index(self._forward_noise_schedule, step_idx, data.device), data)
    return data_scale * noise - noise_scale * data

forward_process(data, t, noise=None)

从噪声和数据中获取给定时间 t 的 x(t)。

参数

名称 类型 描述 默认值
data Tensor

目标

必需
t Tensor

时间

必需
noise Tensor

来自 prior() 的噪声。默认为 None。

None
源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
199
200
201
202
203
204
205
206
207
208
209
def forward_process(self, data: Tensor, t: Tensor, noise: Optional[Tensor] = None):
    """Noise ``data`` to time ``t``, drawing the noise from the prior when none is supplied.

    Args:
        data (Tensor): target
        t (Tensor): time
        noise (Tensor, optional): noise from prior(). When None, ``self.sample_prior(data.shape)`` is
            used. Defaults to None.

    Returns:
        The result of ``self.interpolate(data, t, noise)``.
    """
    prior_noise = self.sample_prior(data.shape) if noise is None else noise
    return self.interpolate(data, t, prior_noise)

interpolate(data, t, noise)

从噪声和数据中获取给定时间 t 的 x(t)。

参数

名称 类型 描述 默认值
data Tensor

目标

必需
t Tensor

时间

必需
noise Tensor

来自 prior() 的噪声

必需
源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def interpolate(self, data: Tensor, t: Tensor, noise: Tensor):
    """Build the noised sample x(t) from clean data and prior noise.

    The result is ``data * forward_data_schedule[t] + noise * forward_noise_schedule[t]``, where both
    schedules are looked up at ``t - last_time_idx`` and padded against ``data`` with ``pad_like``.

    Args:
        data (Tensor): target
        t (Tensor): time
        noise (Tensor): noise from prior()

    Returns:
        Tensor: The interpolated sample x(t).
    """
    step_idx = t - self.last_time_idx
    data_coef = pad_like(safe_index(self._forward_data_schedule, step_idx, data.device), data)
    noise_coef = pad_like(safe_index(self._forward_noise_schedule, step_idx, data.device), data)
    return data * data_coef + noise * noise_coef

loss(model_pred, target, t=None, mask=None, weight_type='ones')

计算给定模型预测、数据样本和时间的损失。

参数

名称 类型 描述 默认值
model_pred Tensor

模型预测的输出。

必需
target Tensor

模型预测的目标输出。

必需
t Tensor

计算损失的时间。

None
mask Optional[Tensor]

数据点的掩码。默认为 None。

None
weight_type str

用于损失的权重类型。默认为“ones”。

'ones'

返回

名称 类型 描述
Tensor Tensor

计算出的损失批张量。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
def loss(
    self,
    model_pred: Tensor,
    target: Tensor,
    t: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    weight_type: str = "ones",
):
    """Calculate the per-sample loss given the model prediction, data sample, and time.

    The element-wise loss from ``self._loss_function`` is multiplied by ``self.loss_weight`` and then
    summed over every dimension except the first. Without a mask the sum is divided by
    ``model_pred.size(1)``; with a mask it is divided by the per-sample mask sum.

    Args:
        model_pred (Tensor): The predicted output from the model.
        target (Tensor): The target output for the model prediction.
        t (Optional[Tensor], optional): The time at which the loss is calculated. Only needed by
            time-dependent weight types. Defaults to None.
        mask (Optional[Tensor], optional): The mask for the data point. It is expanded with
            ``unsqueeze(-1)`` before being applied (assumes the loss has one more trailing dim than
            the mask; confirm against callers). Defaults to None.
        weight_type (str, optional): The type of weight to use for the loss. Defaults to "ones".

    Returns:
        Tensor: The calculated loss batch tensor. A sample whose mask sums to zero yields 0 rather
        than NaN.
    """
    raw_loss = self._loss_function(model_pred, target)
    weighted = raw_loss * self.loss_weight(raw_loss, t, weight_type)
    reduce_dims = tuple(range(1, raw_loss.ndim))
    if mask is None:
        return torch.sum(weighted, dim=reduce_dims) / model_pred.size(1)
    n_elem = torch.sum(mask, dim=-1)
    # A fully masked sample has a zero numerator and a zero count; dividing would give NaN and
    # poison any batch reduction. Only exactly-zero counts are replaced, so other masks are untouched.
    n_elem = n_elem.masked_fill(n_elem == 0, 1)
    return torch.sum(weighted * mask.unsqueeze(-1), dim=reduce_dims) / n_elem

loss_weight(raw_loss, t, weight_type)

根据给定的权重类型计算损失的权重。

这些 data_to_noise 损失权重来源于 https://arxiv.org/pdf/2202.00512 的公式 (9)。

参数

名称 类型 描述 默认值
raw_loss Tensor

从模型预测和目标计算的原始损失。

必需
t Optional[Tensor]

时间步长。

必需
weight_type str

要使用的权重类型。可以是“ones”、“data_to_noise”或“noise_to_data”。

必需

返回

类型 描述
Tensor

损失的权重。

引发

类型 描述
ValueError

如果权重类型无法识别。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
def loss_weight(self, raw_loss: Tensor, t: Optional[Tensor], weight_type: str) -> Tensor:
    """Calculates the weight for the loss based on the given weight type.

    The data_to_noise loss weights are derived in Equation (9) of https://arxiv.org/pdf/2202.00512.

    - "ones": a tensor of ones shaped like ``raw_loss``.
    - "data_to_noise": ``data_schedule(t) ** 2 / noise_schedule(t) ** 2``.
    - "noise_to_data": ``noise_schedule(t) ** 2 / data_schedule(t) ** 2``.

    Both schedules are the forward schedules, read at index ``t - self.last_time_idx``.

    Args:
        raw_loss (Tensor): The raw loss calculated from the model prediction and target.
        t (Optional[Tensor]): The time step. May be None only for the "ones" weight type.
        weight_type (str): The type of weight to use. Can be "ones" or "data_to_noise" or "noise_to_data".

    Returns:
        Tensor: The weight for the loss.

    Raises:
        ValueError: If the weight type is not recognized, or if ``t`` is None for a time-dependent weight type.
    """
    if weight_type == "ones":
        return torch.ones_like(raw_loss)
    if weight_type not in ("data_to_noise", "noise_to_data"):
        raise ValueError("Invalid loss weight keyword")
    if t is None:
        # Name the weight type actually requested; the noise_to_data branch used to report data_to_noise.
        raise ValueError(f"Time cannot be None when using the {weight_type} loss weight")
    schedule_idx = t - self.last_time_idx
    data_sq = safe_index(self._forward_data_schedule, schedule_idx, raw_loss.device) ** 2
    noise_sq = safe_index(self._forward_noise_schedule, schedule_idx, raw_loss.device) ** 2
    ratio = data_sq / noise_sq if weight_type == "data_to_noise" else noise_sq / data_sq
    return pad_like(ratio, raw_loss)

process_data_prediction(model_output, sample, t)

根据预测类型将模型输出转换为数据预测。

此转换源于《扩散模型快速采样的渐进式蒸馏》https://arxiv.org/pdf/2202.00512。给定模型输出和样本,我们根据预测类型将输出转换为数据预测。转换公式如下:

- 对于“noise”预测类型:pred_data = (sample - noise_scale * model_output) / data_scale
- 对于“data”预测类型:pred_data = model_output
- 对于“v_prediction”预测类型:pred_data = data_scale * sample - noise_scale * model_output

参数

名称 类型 描述 默认值
model_output Tensor

模型的输出。

必需
sample Tensor

输入样本。

必需
t Tensor

时间步长。

必需

返回

基于预测类型的数据预测。

引发

类型 描述
ValueError

如果预测类型不是“noise”、“data”或“v_prediction”之一。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def process_data_prediction(self, model_output: Tensor, sample: Tensor, t: Tensor):
    """Turn the raw model output into a prediction of the clean data.

    The conversions follow Progressive Distillation for Fast Sampling of Diffusion Models
    (https://arxiv.org/pdf/2202.00512) and depend on ``self.prediction_type``:

    - ``PredictionType.NOISE``: `pred_data = (sample - noise_scale * model_output) / data_scale`
    - ``PredictionType.DATA``: `pred_data = model_output`
    - ``PredictionType.VELOCITY``: `pred_data = data_scale * sample - noise_scale * model_output`

    ``data_scale`` and ``noise_scale`` are read from the forward data/noise schedules at index
    ``t - self.last_time_idx``.

    Args:
        model_output (Tensor): The output of the model.
        sample (Tensor): The input sample.
        t (Tensor): The time step.

    Returns:
        The data prediction based on the prediction type.

    Raises:
        ValueError: If the prediction type is not one of the three listed above.
    """
    device = model_output.device
    data_coeff = safe_index(self._forward_data_schedule, t - self.last_time_idx, device)
    noise_coeff = safe_index(self._forward_noise_schedule, t - self.last_time_idx, device)
    data_coeff = pad_like(data_coeff, model_output)
    noise_coeff = pad_like(noise_coeff, model_output)

    if self.prediction_type == PredictionType.NOISE:
        return (sample - noise_coeff * model_output) / data_coeff
    if self.prediction_type == PredictionType.DATA:
        return model_output
    if self.prediction_type == PredictionType.VELOCITY:
        return data_coeff * sample - noise_coeff * model_output
    raise ValueError(
        f"prediction_type given as {self.prediction_type} must be one of PredictionType.NOISE, PredictionType.DATA or"
        f" PredictionType.VELOCITY for DDPM."
    )

process_noise_prediction(model_output, sample, t)

执行与 process_data_prediction 相同的操作,但获取模型输出并转换为噪声。

参数

名称 描述 默认值
model_output

模型的输出。

必需
sample

输入样本。

必需
t

时间步长。

必需

返回

由模型输出转换得到的噪声预测;如果预测类型为“noise”,则直接返回模型输出。

引发

类型 描述
ValueError

如果预测类型不是“noise”、“data”或“v_prediction”之一。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def process_noise_prediction(self, model_output, sample, t):
    """Turn the raw model output into a prediction of the noise.

    Counterpart of ``process_data_prediction``. Depending on ``self.prediction_type``:

    - ``PredictionType.NOISE``: the model output is returned unchanged.
    - ``PredictionType.DATA``: `pred_noise = (sample - data_scale * model_output) / noise_scale`
    - ``PredictionType.VELOCITY``: the data prediction
      `data_scale * sample - noise_scale * model_output` is formed first and then converted with the
      same formula as the data case.

    Args:
        model_output: The output of the model.
        sample: The input sample.
        t: The time step.

    Returns:
        The noise prediction derived from the model output.

    Raises:
        ValueError: If the prediction type is none of the three listed above.
    """
    device = model_output.device
    data_coeff = safe_index(self._forward_data_schedule, t - self.last_time_idx, device)
    noise_coeff = safe_index(self._forward_noise_schedule, t - self.last_time_idx, device)
    data_coeff = pad_like(data_coeff, model_output)
    noise_coeff = pad_like(noise_coeff, model_output)

    if self.prediction_type == PredictionType.NOISE:
        return model_output
    if self.prediction_type == PredictionType.DATA:
        return (sample - data_coeff * model_output) / noise_coeff
    if self.prediction_type == PredictionType.VELOCITY:
        data_estimate = data_coeff * sample - noise_coeff * model_output
        return (sample - data_coeff * data_estimate) / noise_coeff
    raise ValueError(
        f"prediction_type given as {self.prediction_type} must be one of `noise`, `data` or"
        " `v_prediction`  for DDPM."
    )

score(x_hat, xt, t)

将数据预测转换为估计的得分函数。

参数

名称 类型 描述 默认值
x_hat Tensor

预测的数据点。

必需
xt Tensor

当前数据点。

必需
t Tensor

时间步长。

必需

返回

估计的得分函数。

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
def score(self, x_hat: Tensor, xt: Tensor, t: Tensor):
    """Convert a data prediction into the estimated score function.

    Computes ``(alpha * x_hat - xt) / beta**2``, with ``alpha`` and ``beta`` read from the forward data and
    noise schedules at index ``t - self.last_time_idx``.

    Args:
        x_hat (Tensor): The predicted data point.
        xt (Tensor): The current data point.
        t (Tensor): The time step.

    Returns:
        The estimated score function.
    """
    schedule_idx = t - self.last_time_idx
    data_coeff = safe_index(self._forward_data_schedule, schedule_idx, x_hat.device)
    noise_coeff = safe_index(self._forward_noise_schedule, schedule_idx, x_hat.device)
    data_coeff = pad_like(data_coeff, x_hat)
    noise_coeff = pad_like(noise_coeff, x_hat)
    return (data_coeff * x_hat - xt) / (noise_coeff * noise_coeff)

set_loss_weight_fn(fn)

将实例的 loss_weight 属性设置为给定的函数。

参数

名称 描述 默认值
fn

要设置为 loss_weight 属性的函数。此函数应接受三个参数:raw_loss、t 和 weight_type。

必需
源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
453
454
455
456
457
458
459
def set_loss_weight_fn(self, fn):
    """Replace this instance's ``loss_weight`` with a caller-supplied function.

    Args:
        fn: The replacement weighting function. It is stored as a plain instance attribute, so it is
            called as ``fn(raw_loss, t, weight_type)`` and must accept exactly those three arguments.
    """
    self.loss_weight = fn

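step(model_out, t, xt, mask=None, center=False, temperature=1.0)

执行一步积分。

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
t Tensor

当前时间步长。

必需
xt Tensor

当前数据点。

必需
mask Optional[Tensor]

应用于数据的可选掩码。默认为 None。

None
center Bool

是否居中数据。默认为 False。

False
temperature Float

低温采样的温度参数。默认为 1.0。

1.0

注意:温度参数控制采样过程中随机性的程度。温度 1.0 对应于标准扩散采样,而较低的温度(例如 0.5、0.2)会导致更少的随机性和更具确定性的样本。这对于需要更多地控制生成过程的任务可能很有用。

注意,对于离散时间,我们从 [T-1, ..., 1, 0] 中采样 T 步,因此也会采样 t = 0,所以需要掩码来在该步不再添加噪声。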

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
@torch.no_grad()
def step(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    mask: Optional[Tensor] = None,
    center: Bool = False,
    temperature: Float = 1.0,
):
    """Advance the sample by one reverse-diffusion step.

    The model output is first converted to a data prediction ``x_hat`` with ``process_data_prediction``.
    The next sample is ``psi_r * x_hat + omega_r * xt`` plus Gaussian noise with standard deviation
    ``exp(0.5 * log_var)`` scaled by ``temperature``.

    Args:
        model_out (Tensor): The output of the model.
        t (Tensor): The current time step.
        xt (Tensor): The current data point.
        mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
        center (bool, optional): Whether to center the data. Defaults to False.
        temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.

    Returns:
        The next sample, after ``self.clean_mask_center`` has been applied with ``mask`` and ``center``.

    Note:
        The temperature parameter controls the level of randomness in the sampling process. A temperature of
        1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less
        random and more deterministic samples. This can be useful for tasks that require more control over the
        generation process.

        For discrete time we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled as well. No noise is
        added for batch entries whose ``t`` is not greater than ``self.last_time_idx``.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x_hat = self.process_data_prediction(model_out, xt, t)

    device = x_hat.device
    data_coeff = safe_index(self._reverse_data_schedule, t - self.last_time_idx, device)
    sample_coeff = safe_index(self._reverse_noise_schedule, t - self.last_time_idx, device)
    log_variance = safe_index(self._log_var, t - self.last_time_idx, device)
    # 1.0 where noise should still be injected, 0.0 for entries already at the last time index.
    add_noise = (t > self.last_time_idx).float()

    data_coeff = pad_like(data_coeff, x_hat)
    sample_coeff = pad_like(sample_coeff, x_hat)
    log_variance = pad_like(log_variance, x_hat)
    add_noise = pad_like(add_noise, x_hat)

    posterior_mean = data_coeff * x_hat + sample_coeff * xt
    gaussian = torch.randn_like(posterior_mean).to(model_out.device)
    std = (0.5 * log_variance).exp()

    stepped = posterior_mean + add_noise * std * gaussian * temperature
    return self.clean_mask_center(stepped, mask, center)

step_ddim(model_out, t, xt, mask=None, eta=0.0, center=False)

执行一步 DDIM 采样。

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
t Tensor

当前时间步长。

必需
xt Tensor

当前数据点。

必需
mask Optional[Tensor]

数据点的掩码。默认为 None。

None
eta Float

DDIM 采样参数。默认为 0.0。

0.0
center Bool

是否将数据点居中。默认为 False。

False
源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
def step_ddim(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    mask: Optional[Tensor] = None,
    eta: Float = 0.0,
    center: Bool = False,
):
    """Do one step of DDIM sampling.

    The model output is converted to both a data prediction and a noise prediction. The next sample is
    ``sqrt(alpha_bar_prev) * data_pred + sqrt(1 - alpha_bar_prev - sigma**2) * noise_pred + sigma * eps``,
    where ``sigma`` is scaled by ``eta`` (so ``eta = 0`` removes the random term).

    Args:
        model_out (Tensor): output of the model
        t (Tensor): current time step
        xt (Tensor): current data point
        mask (Optional[Tensor], optional): mask for the data point. Defaults to None.
        eta (Float, optional): DDIM sampling parameter. Defaults to 0.0.
        center (Bool, optional): whether to center the data point. Defaults to False.

    Returns:
        The next sample, after ``self.clean_mask_center`` has been applied with ``mask`` and ``center``.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    x0_hat = self.process_data_prediction(model_out, xt, t)
    eps_hat = self.process_noise_prediction(model_out, xt, t)
    fresh_noise = torch.randn_like(x0_hat).to(model_out.device)

    device = model_out.device
    abar = self._alpha_bar
    abar_prev = self._alpha_bar_prev
    # sigma is computed for the whole schedule and then indexed per batch entry.
    sigma = eta * torch.sqrt((1 - abar_prev) / (1 - abar)) * torch.sqrt(1 - abar / abar_prev)

    sigma_t = safe_index(sigma, t - self.last_time_idx, device)
    data_coeff = safe_index(torch.sqrt(abar_prev), t - self.last_time_idx, device)
    noise_coeff = safe_index(torch.sqrt(1 - abar_prev - sigma**2), t - self.last_time_idx, device)
    sigma_t = pad_like(sigma_t, model_out)
    data_coeff = pad_like(data_coeff, model_out)
    noise_coeff = pad_like(noise_coeff, model_out)

    deterministic_part = x0_hat * data_coeff + noise_coeff * eps_hat
    stepped = deterministic_part + sigma_t * fresh_noise
    return self.clean_mask_center(stepped, mask, center)

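step_noise(model_out, t, xt, mask=None, center=False, temperature=1.0)

执行一步积分。

参数

名称 类型 描述 默认值
model_out Tensor

模型的输出。

必需
t Tensor

当前时间步长。

必需
xt Tensor

当前数据点。

必需
mask Optional[Tensor]

应用于数据的可选掩码。默认为 None。

None
center Bool

是否居中数据。默认为 False。

False
temperature Float

低温采样的温度参数。默认为 1.0。

1.0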

源代码位于 bionemo/moco/interpolants/discrete_time/continuous/ddpm.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def step_noise(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    mask: Optional[Tensor] = None,
    center: Bool = False,
    temperature: Float = 1.0,
):
    """Do one reverse-diffusion step using the noise parameterization.

    The model output is converted to a noise prediction ``eps_hat`` with ``process_noise_prediction``.
    The next sample is
    ``(xt - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(1 - beta_t)`` plus Gaussian noise with
    variance ``self._posterior_variance[t]`` scaled by ``temperature``.

    Args:
        model_out (Tensor): The output of the model.
        t (Tensor): The current time step.
        xt (Tensor): The current data point.
        mask (Optional[Tensor], optional): An optional mask to apply to the data. Defaults to None.
        center (bool, optional): Whether to center the data. Defaults to False.
        temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.

    Returns:
        The next sample, after ``self.clean_mask_center`` has been applied with ``mask`` and ``center``.

    Note:
        The temperature parameter controls the level of randomness in the sampling process. A temperature of
        1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less
        random and more deterministic samples. This can be useful for tasks that require more control over the
        generation process.

        For discrete time we sample from [T-1, ..., 1, 0] for T steps, so t = 0 is sampled as well. No noise is
        added for batch entries whose ``t`` is not greater than ``self.last_time_idx``.
    """
    if mask is not None:
        model_out = model_out * mask.unsqueeze(-1)
    eps_hat = self.process_noise_prediction(model_out, xt, t)

    device = model_out.device
    schedule_idx = t - self.last_time_idx
    beta_t = safe_index(self._betas, schedule_idx, device)
    alpha_bar_t = safe_index(self._alpha_bar, schedule_idx, device)
    var = safe_index(self._posterior_variance, schedule_idx, device)

    recip_sqrt_alpha_t = torch.sqrt(1 / (1 - beta_t))
    eps_factor = beta_t / (1 - alpha_bar_t).sqrt()

    # 1.0 where noise should still be injected, 0.0 for entries already at the last time index.
    nonzero_mask = (t > self.last_time_idx).float()
    nonzero_mask = pad_like(nonzero_mask, model_out)
    eps_factor = pad_like(eps_factor, xt)
    recip_sqrt_alpha_t = pad_like(recip_sqrt_alpha_t, xt)
    var = pad_like(var, xt)

    noise = torch.randn_like(eps_hat).to(device)
    # temperature was previously accepted but ignored; scale the injected noise exactly as step() does.
    x_next = recip_sqrt_alpha_t * (xt - eps_factor * eps_hat) + nonzero_mask * var.sqrt() * noise * temperature
    x_next = self.clean_mask_center(x_next, mask, center)
    return x_next