跳到内容

连续流匹配

ContinuousFlowMatcher

Bases: Interpolant

连续流匹配插值器。


示例

>>> import torch
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

flow_matcher = ContinuousFlowMatcher(
    time_distribution = UniformTimeDistribution(...),
    prior_distribution = GaussianPrior(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = flow_matcher.sample_time(batch_size)
    noise = flow_matcher.sample_prior(data.shape)
    data, time, noise = flow_matcher.apply_ot(noise, data) # Optional, only for OT
    xt = flow_matcher.interpolate(data, time, noise)
    flow = flow_matcher.calculate_target(data, noise)

    u_pred = model(xt, time)
    loss = flow_matcher.loss(u_pred, flow)
    loss.backward()

# Generation
x_pred = flow_matcher.sample_prior(data.shape)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
    time = inference_sched.pad_time(x_pred.shape[0], t)
    u_hat = model(x_pred, time)
    x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
class ContinuousFlowMatcher(Interpolant):
    """A Continuous Flow Matching interpolant.

    Builds intermediate states ``x_t = t * data + (1 - t) * noise`` (see ``interpolate``),
    provides the training target and loss for the configured prediction type, and
    advances samples with deterministic (``step``) or stochastic (``step_score_stochastic``)
    update rules.

    Examples:
    ```python
    >>> import torch
    >>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
    >>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
    >>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
    >>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

    flow_matcher = ContinuousFlowMatcher(
        time_distribution = UniformTimeDistribution(...),
        prior_distribution = GaussianPrior(...),
        )
    model = Model(...)

    # Training
    for epoch in range(1000):
        data = data_loader.get(...)
        time = flow_matcher.sample_time(batch_size)
        noise = flow_matcher.sample_prior(data.shape)
        # NOTE(review): apply_ot documents a 2-tuple return; confirm what the OT sampler actually returns.
        data, time, noise = flow_matcher.apply_ot(noise, data) # Optional, only for OT
        xt = flow_matcher.interpolate(data, time, noise)
        flow = flow_matcher.calculate_target(data, noise)

        u_pred = model(xt, time)
        # NOTE(review): with the default target_type (DATA), loss() raises ValueError when t is None,
        # and it returns one value per batch element; confirm whether t=time and a reduction are intended here.
        loss = flow_matcher.loss(u_pred, flow)
        loss.backward()

    # Generation
    x_pred = flow_matcher.sample_prior(data.shape)
    inference_sched = LinearInferenceSchedule(...)
    for t in inference_sched.generate_schedule():
        time = inference_sched.pad_time(x_pred.shape[0], t)
        u_hat = model(x_pred, time)
        # NOTE(review): step() is declared as step(model_out, xt, dt, t=None, ...), so the third
        # positional argument is the step size dt; confirm the intended call.
        x_pred = flow_matcher.step(u_hat, x_pred, time)
    return x_pred

    ```
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: PriorDistribution,
        prediction_type: Union[PredictionType, str] = PredictionType.DATA,
        sigma: Float = 0,
        ot_type: Optional[Union[OptimalTransportType, str]] = None,
        ot_num_threads: int = 1,
        data_scale: Float = 1.0,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
        eps: Float = 1e-5,
    ):
        """Initializes the Continuous Flow Matching interpolant.

        Args:
            time_distribution (TimeDistribution): The distribution used to sample time points.
            prior_distribution (PriorDistribution): The prior distribution used as the starting point (noise).
            prediction_type (Union[PredictionType, str], optional): What the model predicts. Strings are converted
                with ``string_to_enum``. The methods of this class handle ``PredictionType.DATA`` and
                ``PredictionType.VELOCITY``. Defaults to PredictionType.DATA.
            sigma (Float, optional): The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.
            ot_type (Optional[Union[OptimalTransportType, str]], optional): The type of optimal transport, if applicable.
                When None, no OT sampler is built. Defaults to None.
            ot_num_threads (int): Number of threads forwarded to the OT sampler. Defaults to 1.
            data_scale (Float, optional): The scale factor for the data; must be > 0. Defaults to 1.0.
            device (Union[str, torch.device], optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
            eps: Small float to prevent divide by zero.

        Raises:
            ValueError: If ``data_scale`` is not strictly positive.
        """
        super().__init__(time_distribution, prior_distribution, device, rng_generator)
        self.prediction_type = string_to_enum(prediction_type, PredictionType)
        self.sigma = sigma
        self.data_scale = data_scale
        self.eps = eps
        if data_scale <= 0:
            raise ValueError("Data Scale must be > 0")
        # Always define both OT attributes: `apply_ot` tests `ot_sampler is None`, which
        # would raise AttributeError (not the intended ValueError) if it were never set.
        self.ot_type = None
        self.ot_sampler = None
        if ot_type is not None:
            self.ot_type = string_to_enum(ot_type, OptimalTransportType)
            self.ot_sampler = self._build_ot_sampler(method_type=self.ot_type, num_threads=ot_num_threads)
        # Element-wise squared error; the reduction is done in `loss`.
        self._loss_function = nn.MSELoss(reduction="none")

    def _build_ot_sampler(self, method_type: OptimalTransportType, num_threads: int = 1):
        """Create the optimal transport sampler for the requested OT type.

        Args:
            method_type (OptimalTransportType): Which OT variant to create.
            num_threads (int): Thread count handed to ``BatchAugmentation``. Defaults to 1.

        Returns:
            The object produced by ``BatchAugmentation(...).create(method_type)``.
        """
        augmentation_factory = BatchAugmentation(self.device, num_threads)
        sampler = augmentation_factory.create(method_type)
        return sampler

    def apply_ot(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None, **kwargs) -> tuple:
        """Sample and apply the optimal transport plan between batched (and masked) x0 and x1.

        Args:
            x0 (Tensor): shape (bs, *dim), noise from source minibatch.
            x1 (Tensor): shape (bs, *dim), data from source minibatch.
            mask (Optional[Tensor], optional): mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
            **kwargs: Additional keyword arguments forwarded to ``self.ot_sampler.apply_ot``.

        Returns:
            Tuple: whatever ``self.ot_sampler.apply_ot`` returns; documented as the noise and data
            samples following OT plan pi.

        Raises:
            ValueError: If no optimal transport sampler was configured.
        """
        # Use getattr: the attribute may be absent when the interpolant was built without
        # an OT type, and a missing sampler must surface as ValueError, not AttributeError.
        ot_sampler = getattr(self, "ot_sampler", None)
        if ot_sampler is None:
            raise ValueError("Optimal Transport Sampler is not defined")
        return ot_sampler.apply_ot(x0, x1, mask=mask, **kwargs)

    def undo_scale_data(self, data: Tensor) -> Tensor:
        """Divide out the data scale factor (inverse of ``scale_data``).

        Args:
            data (Tensor): The scaled data.

        Returns:
            ``data`` multiplied by ``1 / self.data_scale``.
        """
        # Compute the reciprocal first, then multiply, exactly as the scaling is defined.
        inverse_scale = 1 / self.data_scale
        return inverse_scale * data

    def scale_data(self, data: Tensor) -> Tensor:
        """Multiply the input by the data scale factor.

        Args:
            data (Tensor): The unscaled data.

        Returns:
            ``self.data_scale * data``.
        """
        scaled = self.data_scale * data
        return scaled

    def interpolate(self, data: Tensor, t: Tensor, noise: Tensor) -> Tensor:
        """Build the intermediate state x_t between noise (x_0) and data (x_1).

        The path is the straight line ``x_t = t * data + (1 - t) * noise`` used in:
            1. Rectified flow: https://arxiv.org/abs/2209.03003.
            2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).

        When ``self.sigma > 0``, Gaussian noise scaled by ``self.sigma`` is added to x_t,
        drawn with ``self.rng_generator``.

        Args:
            data (Tensor): target, shape (batchsize, nodes, features)
            t (Tensor): time, shape (batchsize)
            noise (Tensor): noise from prior(), shape (batchsize, nodes, features)

        Returns:
            Tensor: x_t, with the same shape as ``data``.
        """
        assert data.size() == noise.size()
        # Broadcast the per-sample time over the trailing dimensions of data.
        time = pad_like(t, data)
        x_t = data * time + noise * (1.0 - time)
        if not self.sigma > 0:
            return x_t
        # In-place add keeps x_t's dtype regardless of the dtype of the sampled noise.
        perturbation = torch.randn(x_t.shape, device=x_t.device, generator=self.rng_generator)
        x_t += self.sigma * perturbation
        return x_t

    def calculate_target(self, data: Tensor, noise: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Return the regression target for the configured prediction type.

        For ``PredictionType.VELOCITY`` the target is the constant vector field ``data - noise``
        of the linear path; for ``PredictionType.DATA`` the target is ``data`` itself.

        Args:
            data (Tensor): target, shape (batchsize, nodes, features)
            noise (Tensor): noise from prior(), shape (batchsize, nodes, features)
            mask (Optional[Tensor], optional): mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.

        Returns:
            Tensor: The (optionally masked) target.

        Raises:
            ValueError: If the prediction type is neither VELOCITY nor DATA.
        """
        assert data.size() == noise.size()
        kind = self.prediction_type
        if kind == PredictionType.DATA:
            target = data
        elif kind == PredictionType.VELOCITY:
            # u_t(x_t | x_1) = x_1 - x_0 because the path is linear in t over [0, 1].
            target = data - noise
        else:
            raise ValueError(
                f"Given prediction_type {self.prediction_type} is not supproted for Continuous Flow Matching."
            )
        return target if mask is None else target * mask.unsqueeze(-1)

    def process_vector_field_prediction(
        self,
        model_output: Tensor,
        xt: Optional[Tensor] = None,
        t: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        """Convert the raw model output into a vector field prediction.

        A VELOCITY model output is used as is. A DATA model output x1_hat is converted with
        ``(x1_hat - xt) / (1 - t + eps)``.

        Args:
            model_output (Tensor): The output of the model.
            xt (Tensor): The input sample; required for DATA prediction.
            t (Tensor): The time step; required for DATA prediction.
            mask (Optional[Tensor], optional): An optional mask to apply to the result. Defaults to None.

        Returns:
            The (optionally masked) vector field prediction.

        Raises:
            ValueError: If the prediction type is not VELOCITY or DATA, or if ``xt``/``t`` is
                missing for DATA prediction.
        """
        kind = self.prediction_type
        if kind == PredictionType.VELOCITY:
            vector_field = model_output
        elif kind == PredictionType.DATA:
            if xt is None or t is None:
                raise ValueError("Xt and Time cannpt be None for vector field conversion")
            time = pad_like(t, model_output)
            # eps keeps the denominator away from zero as t approaches 1.
            remaining = 1 - time + self.eps
            vector_field = (model_output - xt) / remaining
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be `flow` or `data` "
                "for Continuous Flow Matching."
            )
        if mask is None:
            return vector_field
        return vector_field * mask.unsqueeze(-1)

    def process_data_prediction(
        self,
        model_output: Tensor,
        xt: Optional[Tensor] = None,
        t: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        """Convert the raw model output into a clean-data prediction.

        A DATA model output is used as is. A VELOCITY model output v is converted with
        ``xt + (1 - t) * v``.

        Args:
            model_output (Tensor): The output of the model.
            xt (Tensor): The input sample; required for VELOCITY prediction.
            t (Tensor): The time step; required for VELOCITY prediction.
            mask (Optional[Tensor], optional): An optional mask to apply to the result. Defaults to None.

        Returns:
            The (optionally masked) data prediction.

        Raises:
            ValueError: If the prediction type is not VELOCITY or DATA, or if ``xt``/``t`` is
                missing for VELOCITY prediction.
        """
        kind = self.prediction_type
        if kind == PredictionType.DATA:
            clean = model_output
        elif kind == PredictionType.VELOCITY:
            if xt is None or t is None:
                raise ValueError("Xt and time cannot be None")
            time = pad_like(t, model_output)
            # Follow the predicted velocity for the remaining time (1 - t) to reach x_1.
            clean = xt + (1 - time) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be `flow` " "for Continuous Flow Matching."
            )
        if mask is None:
            return clean
        return clean * mask.unsqueeze(-1)

    def step(
        self,
        model_out: Tensor,
        xt: Tensor,
        dt: Tensor,
        t: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        center: Bool = False,
    ):
        """Advance the state by one explicit Euler step of the ODE: ``x_next = xt + v_t * dt``.

        Args:
            model_out (Tensor): The output of the model at the current time step.
            xt (Tensor): The current intermediate state.
            dt (Tensor): The time step size.
            t (Tensor, optional): The current time. Needed when the prediction type is DATA,
                because the data-to-vector-field conversion uses it. Defaults to None.
            mask (Optional[Tensor], optional): A mask to apply to the model output. Defaults to None.
            center (Bool, optional): Whether to center the output. Defaults to False.

        Returns:
            x_next (Tensor): The updated state of the system after the single step, x_(t+dt).

        Notes:
        - If a mask is provided, it is applied element-wise to the model output before it is
          converted to a vector field.
        - ``clean_mask_center`` is called on the updated state before it is returned.
        """
        masked_out = model_out if mask is None else model_out * mask.unsqueeze(-1)
        velocity = self.process_vector_field_prediction(masked_out, xt, t, mask)
        step_size = pad_like(dt, masked_out)
        advanced = xt + velocity * step_size
        return self.clean_mask_center(advanced, mask, center)

    def step_score_stochastic(
        self,
        model_out: Tensor,
        xt: Tensor,
        dt: Tensor,
        t: Tensor,
        mask: Optional[Tensor] = None,
        gt_mode: str = "tan",
        gt_p: Float = 1.0,
        gt_clamp: Optional[Float] = None,
        score_temperature: Float = 1.0,
        noise_temperature: Float = 1.0,
        t_lim_ode: Float = 0.99,
        center: Bool = False,
    ):
        r"""Perform a single stochastic integration step that adds a score term and noise.

        d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

        The vector field v itself is not rescaled. Once every entry of ``t`` is at or beyond
        ``t_lim_ode`` the method falls back to the deterministic ``step``.

        Args:
            model_out (Tensor): The output of the model at the current time step.
            xt (Tensor): The current intermediate state.
            dt (Tensor): The time step size.
            t (Tensor): The current time (required).
            mask (Optional[Tensor], optional): A mask to apply to the model output. Defaults to None.
            gt_mode (str, optional): The mode for the gt function. Defaults to "tan".
            gt_p (Float, optional): The parameter for the gt function. Defaults to 1.0.
            gt_clamp: (Float, optional): Upper limit of gt term. Defaults to None.
            score_temperature (Float, optional): The temperature for the score part of the step. Defaults to 1.0.
            noise_temperature (Float, optional): The temperature for the stochastic part of the step. Defaults to 1.0.
            t_lim_ode (Float, optional): Time from which the plain ODE step is used. Defaults to 0.99.
            center (Bool, optional): Whether to center the output. Defaults to False.

        Returns:
            x_next (Tensor): The updated state of the system after the single step, x_(t+dt).

        Raises:
            ValueError: If optimal transport is enabled or the prior is not a ``GaussianPrior``.

        Notes:
            - If a mask is provided, it is applied element-wise to the model output before it is
              converted to a vector field.
            - ``clean_mask_center`` is called on the updated state before it is returned.
        """
        if self.ot_type is not None:
            raise ValueError("Optimal Transport violates the vector field to score conversion")
        if not isinstance(self.prior_distribution, GaussianPrior):
            raise ValueError(
                "Prior distribution must be an instance of GaussianPrior to learn a proper score function"
            )
        if t.min() >= t_lim_ode:
            return self.step(model_out, xt, dt, t, mask, center)

        masked_out = model_out if mask is None else model_out * mask.unsqueeze(-1)
        velocity = self.process_vector_field_prediction(masked_out, xt, t, mask)
        step_size = pad_like(dt, masked_out)
        time = pad_like(t, masked_out)
        score = self.vf_to_score(xt, velocity, time)
        g_t = self.get_gt(time, gt_mode, gt_p, gt_clamp)
        gaussian = torch.randn(xt.shape, dtype=xt.dtype, device=xt.device, generator=self.rng_generator)
        # Standard deviation of the noise injected over this step.
        noise_scale = torch.sqrt(2 * g_t * noise_temperature * step_size)
        drift = (velocity + g_t * score * score_temperature) * step_size
        x_next = xt + (drift + noise_scale * gaussian)
        return self.clean_mask_center(x_next, mask, center)

    def loss(
        self,
        model_pred: Tensor,
        target: Tensor,
        t: Optional[Tensor] = None,
        xt: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        target_type: Union[PredictionType, str] = PredictionType.DATA,
    ):
        """Calculate the per-sample loss given the model prediction, target, time, and mask.

        If target_type is DATA, the model output is converted to a data prediction and
        loss = ||x1_hat - x1||**2 * 1 / ((1 - t)**2 + eps). The weight comes from the target
        vector field x1 - x0 = (x1 - xt) / (1 - t), where xt = t * x1 + (1 - t) * x0.
        For any other target_type, the model output is converted to a vector field and
        loss = ||v_hat - target||**2 with no time weighting.
        Any combination of prediction_type and target_type in {DATA, VELOCITY} is supported.

        The squared error is summed over all non-batch dimensions and divided by the number of
        unmasked elements per sample (``mask.sum(-1)``), or by ``model_pred.size(1)`` without a mask.

        Args:
            model_pred (Tensor): The predicted output from the model.
            target (Tensor): The target output for the model prediction.
            t (Optional[Tensor], optional): The time for the model prediction. Required when
                target_type is DATA. Defaults to None.
            xt (Optional[Tensor], optional): The interpolated data. Defaults to None.
            mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
            target_type (PredictionType, optional): The type of the target output. Defaults to PredictionType.DATA.

        Returns:
            Tensor: The calculated loss batch tensor (one value per sample).

        Raises:
            ValueError: If target_type is DATA and ``t`` is None.
        """
        target_type = string_to_enum(target_type, PredictionType)
        if target_type == PredictionType.DATA:
            model_pred = self.process_data_prediction(model_pred, xt, t, mask)
        else:
            model_pred = self.process_vector_field_prediction(model_pred, xt, t, mask)
        raw_loss = self._loss_function(model_pred, target)

        reduce_dims = tuple(range(1, raw_loss.ndim))
        if mask is not None:
            n_elem = torch.sum(mask, dim=-1)
            # A fully masked sample has n_elem == 0 and a masked-out numerator, so dividing
            # would give 0/0 = NaN and poison the batch loss. Replace exact zeros only, which
            # leaves every other sample's normalization untouched.
            n_elem = torch.where(n_elem == 0, torch.ones_like(n_elem), n_elem)
            loss = torch.sum(raw_loss * mask.unsqueeze(-1), dim=reduce_dims) / n_elem
        else:
            loss = torch.sum(raw_loss, dim=reduce_dims) / model_pred.size(1)
        if target_type == PredictionType.DATA:
            if t is None:
                raise ValueError("Time cannot be None when using a time-based weighting")
            # eps bounds the weight as t approaches 1.
            loss_weight = 1.0 / ((1.0 - t) ** 2 + self.eps)
            loss = loss_weight * loss
        return loss

    def vf_to_score(
        self,
        x_t: Tensor,
        v: Tensor,
        t: Tensor,
    ) -> Tensor:
        """From Geffner et al. Convert a flow-matching vector field into the score of the noisy density.

        With the linear interpolation used here the two are related by

        v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

        which rearranges to

        s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)),

        and this implementation uses scale_ref = 1.

        Args:
            x_t: Noisy sample, shape [*, dim]
            v: Vector field, shape [*, dim]
            t: Interpolation time, shape [*] (must be < 1)

        Returns:
            Score of intermediate density, shape [*, dim].
        """
        assert torch.all(t < 1.0), "vf_to_score requires t < 1 (strict)"
        time = pad_like(t, v)
        numerator = time * v - x_t
        denominator = 1.0 - time
        return numerator / denominator

    def get_gt(
        self,
        t: Tensor,
        mode: str = "tan",
        param: float = 1.0,
        clamp_val: Optional[float] = None,
        eps: float = 1e-2,
    ) -> Tensor:
        """From Geffner et al. Evaluate the schedule g(t) for the requested mode.

        Args:
            t: times where we'll evaluate, covers [0, 1), shape [nsteps]
            mode: one of "us", "tan", "1/t", "1/t2", "1/t1p5", "2/t", "2/t2", "2/t1p5",
                "1mt", "t" or "ones"
            param: exponent of the sigmoid-space transformation applied to g(t); 1.0 disables it
            clamp_val: upper bound applied to g(t), no upper clamping if None
            eps: small constant added to the denominators; usually left at its default

        Returns:
            g(t), clamped from below at 0 and from above at ``clamp_val``.

        Raises:
            NotImplementedError: If ``mode`` is not one of the supported modes.
        """
        # Numerical reasons for some schedules: keep t inside [0, 1 - self.eps].
        t = torch.clamp(t, 0, 1 - self.eps)

        # Modes of the form numerator / (t**power + eps); a power of None means plain t.
        inverse_power_modes = {
            "1/t": (1.0, None),
            "1/t2": (1.0, 2),
            "1/t1p5": (1.0, 1.5),
            "2/t": (2.0, None),
            "2/t2": (2.0, 2),
            "2/t1p5": (2.0, 1.5),
        }

        if mode in inverse_power_modes:
            numerator, power = inverse_power_modes[mode]
            denominator = t if power is None else t**power
            gt = numerator / (denominator + eps)
        elif mode == "us":
            gt = (1.0 - t) / (t + eps)
        elif mode == "tan":
            angle = (1.0 - t) * torch.pi / 2.0
            gt = (torch.pi / 2.0) * torch.sin(angle) / (torch.cos(angle) + eps)
        elif mode == "1mt":
            gt = 1 - t
        elif mode == "t":
            gt = t
        elif mode == "ones":
            gt = 0 * t + 1
        else:
            raise NotImplementedError(f"gt not implemented {mode}")

        # Optional reshaping of the schedule; 1.0 means no transformation.
        if param != 1.0:
            # Roughly normalize to (0, 1): center log(g) and squash it with a sigmoid.
            log_gt = torch.log(gt)
            log_mean = torch.mean(log_gt)
            squashed = torch.sigmoid(log_gt - log_mean)
            # Apply the power in normalized space, then undo the normalization.
            squashed = squashed**param
            gt = torch.exp(torch.logit(squashed, eps=1e-6) + log_mean)

        # A max of None leaves the upper side unclamped.
        return torch.clamp(gt, 0, clamp_val)

__init__(time_distribution, prior_distribution, prediction_type=PredictionType.DATA, sigma=0, ot_type=None, ot_num_threads=1, data_scale=1.0, device='cpu', rng_generator=None, eps=1e-05)

初始化连续流匹配插值器。

参数

名称 类型 描述 默认值
time_distribution TimeDistribution

时间步长的分布,用于为扩散过程采样时间点。

必需
prior_distribution PriorDistribution

变量的先验分布,用作扩散过程的起点。

必需
prediction_type PredictionType

预测类型,可以是“flow”或其他类型。默认为 PredictionType.DATA。

DATA
sigma Float

添加到插值数据的高斯噪声的标准偏差。默认为 0。

0
ot_type Optional[Union[OptimalTransportType, str]]

最优传输的类型(如果适用)。默认为 None。

None
ot_num_threads int

用于 OT 求解器的线程数。如果为“max”,则使用最大线程数。默认为 1。

1
data_scale Float

数据的比例因子。默认为 1.0。

1.0
device Union[str, device]

运行插值器的设备,可以是“cpu”或 CUDA 设备(例如“cuda:0”)。默认为“cpu”。

'cpu'
rng_generator Optional[Generator]

可选的 :class:torch.Generator 用于可重复的采样。默认为 None。

None
eps Float

防止被零除的小浮点数

1e-05
源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: PriorDistribution,
    prediction_type: Union[PredictionType, str] = PredictionType.DATA,
    sigma: Float = 0,
    ot_type: Optional[Union[OptimalTransportType, str]] = None,
    ot_num_threads: int = 1,
    data_scale: Float = 1.0,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
    eps: Float = 1e-5,
):
    """Initializes the Continuous Flow Matching interpolant.

    Args:
        time_distribution (TimeDistribution): The distribution used to sample time points.
        prior_distribution (PriorDistribution): The prior distribution used as the starting point (noise).
        prediction_type (Union[PredictionType, str], optional): What the model predicts. Strings are converted
            with ``string_to_enum``. Defaults to PredictionType.DATA.
        sigma (Float, optional): The standard deviation of the Gaussian noise added to the interpolated data. Defaults to 0.
        ot_type (Optional[Union[OptimalTransportType, str]], optional): The type of optimal transport, if applicable.
            When None, no OT sampler is built. Defaults to None.
        ot_num_threads (int): Number of threads forwarded to the OT sampler. Defaults to 1.
        data_scale (Float, optional): The scale factor for the data; must be > 0. Defaults to 1.0.
        device (Union[str, torch.device], optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
        eps: Small float to prevent divide by zero.

    Raises:
        ValueError: If ``data_scale`` is not strictly positive.
    """
    super().__init__(time_distribution, prior_distribution, device, rng_generator)
    self.prediction_type = string_to_enum(prediction_type, PredictionType)
    self.sigma = sigma
    self.data_scale = data_scale
    self.eps = eps
    if data_scale <= 0:
        raise ValueError("Data Scale must be > 0")
    # Always define both OT attributes: a later `self.ot_sampler is None` check would
    # raise AttributeError instead of a clear error if the attribute were never set.
    self.ot_type = None
    self.ot_sampler = None
    if ot_type is not None:
        self.ot_type = string_to_enum(ot_type, OptimalTransportType)
        self.ot_sampler = self._build_ot_sampler(method_type=self.ot_type, num_threads=ot_num_threads)
    # Element-wise squared error; any reduction is left to the caller of this loss function.
    self._loss_function = nn.MSELoss(reduction="none")

apply_ot(x0, x1, mask=None, **kwargs)

在批处理(和掩码)x0 和 x1 之间采样并应用最优传输方案。

参数

名称 类型 描述 默认值
x0 Tensor

形状 (bs, *dim),来自源小批量的噪声。

必需
x1 Tensor

形状 (bs, *dim),来自源小批量的数据。

必需
mask Optional[Tensor]

应用于输出的掩码,形状 (batchsize, nodes),如果未提供,则不应用掩码。默认为 None。

None
**kwargs

要传递给 self.ot_sampler.apply_ot 或在此方法中处理的其他关键字参数。

{}

返回

名称 类型 描述
Tuple tuple

2 个张量的元组,表示遵循 OT 方案 pi 的噪声和数据样本。

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def apply_ot(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None, **kwargs) -> tuple:
    """Sample and apply the optimal transport plan between batched (and masked) x0 and x1.

    Args:
        x0 (Tensor): shape (bs, *dim), noise from source minibatch.
        x1 (Tensor): shape (bs, *dim), data from source minibatch.
        mask (Optional[Tensor], optional): mask to apply to the output, shape (batchsize, nodes), if not provided no mask is applied. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to ``self.ot_sampler.apply_ot``.

    Returns:
        Tuple: whatever ``self.ot_sampler.apply_ot`` returns; documented as the noise and data
        samples following OT plan pi.

    Raises:
        ValueError: If no optimal transport sampler was configured.
    """
    # Use getattr: the attribute may be absent when the interpolant was built without
    # an OT type, and a missing sampler must surface as ValueError, not AttributeError.
    ot_sampler = getattr(self, "ot_sampler", None)
    if ot_sampler is None:
        raise ValueError("Optimal Transport Sampler is not defined")
    return ot_sampler.apply_ot(x0, x1, mask=mask, **kwargs)

calculate_target(data, noise, mask=None)

获取时间 t 的目标向量场。

参数

名称 类型 描述 默认值
noise Tensor

来自 prior() 的噪声,形状 (batchsize, nodes, features)

必需
data Tensor

目标,形状 (batchsize, nodes, features)

必需
mask Optional[Tensor]

应用于输出的掩码,形状 (batchsize, nodes),如果未提供,则不应用掩码。默认为 None。

None

返回

名称 类型 描述
Tensor Tensor

时间 t 的目标向量场。

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def calculate_target(self, data: Tensor, noise: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Build the regression target that matches the configured prediction type.

    For ``PredictionType.VELOCITY`` the target is the vector field ``data - noise``;
    for ``PredictionType.DATA`` the target is ``data`` itself.

    Args:
        data (Tensor): target sample, documented shape (batchsize, nodes, features).
        noise (Tensor): noise from prior(), documented shape (batchsize, nodes, features).
        mask (Optional[Tensor], optional): mask to apply to the output, documented shape
            (batchsize, nodes). If not provided no mask is applied. Defaults to None.

    Returns:
        Tensor: the target, multiplied by ``mask.unsqueeze(-1)`` when a mask is given.

    Raises:
        AssertionError: if ``data`` and ``noise`` do not have the same size.
        ValueError: if ``self.prediction_type`` is neither VELOCITY nor DATA.
    """
    assert data.size() == noise.size(), (
        f"data and noise must have the same size, got {tuple(data.size())} and {tuple(noise.size())}"
    )
    if self.prediction_type == PredictionType.VELOCITY:
        # Target vector field u_t(x_t | x_1): plain difference between data and noise.
        u_t = data - noise
    elif self.prediction_type == PredictionType.DATA:
        u_t = data
    else:
        raise ValueError(
            f"Given prediction_type {self.prediction_type} is not supported for Continuous Flow Matching."
        )
    if mask is not None:
        # The mask has no feature axis; add a trailing one so it broadcasts over features.
        u_t = u_t * mask.unsqueeze(-1)
    return u_t

get_gt(t, mode='tan', param=1.0, clamp_val=None, eps=0.01)

来自 Geffner 等人。计算不同模式的 gt。

参数

名称 类型 描述 默认值
t Tensor

我们将评估的时间,涵盖 [0, 1),形状 [nsteps]

必需
mode str

“us”或“tan”

'tan'
param float

参数化变换

1.0
clamp_val Optional[float]

用于钳制 gt 的值,如果为 None 则不钳制

None
eps float

加到分母上的小常数,用于保证数值稳定;一般保持默认值即可

0.01
源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
def get_gt(
    self,
    t: Tensor,
    mode: str = "tan",
    param: float = 1.0,
    clamp_val: Optional[float] = None,
    eps: float = 1e-2,
) -> Tensor:
    """Evaluate the g(t) schedule (from Geffner et al.) for the requested mode.

    ``t`` is first clamped to ``[0, 1 - self.eps]``. Supported modes:

    * ``"us"``: ``(1 - t) / (t + eps)``
    * ``"tan"``: ``(pi / 2) * sin((1 - t) * pi / 2) / (cos((1 - t) * pi / 2) + eps)``
    * ``"1/t"``, ``"1/t2"``, ``"1/t1p5"``, ``"2/t"``, ``"2/t2"``, ``"2/t1p5"``:
      ``c / (t**p + eps)`` with ``c`` in {1, 2} and ``p`` in {1, 2, 1.5}
    * ``"1mt"``: ``1 - t``
    * ``"t"``: ``t``
    * ``"ones"``: a tensor of ones shaped like ``t``

    Args:
        t: times at which to evaluate, documented as covering [0, 1), shape [nsteps].
        mode: name of the schedule, one of the modes listed above.
        param: exponent of the optional reshaping transform; 1.0 disables it.
        clamp_val: upper bound applied to the result, no upper bound if None.
        eps: small constant added to the denominator of the ratio-based modes.

    Returns:
        Tensor: g(t), clamped below at 0 and above at ``clamp_val``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported names.
    """
    # Modes of the form c / (t**p + eps): name -> (c, p). p=None means "use t itself".
    inverse_power_modes = {
        "1/t": (1.0, None),
        "1/t2": (1.0, 2),
        "1/t1p5": (1.0, 1.5),
        "2/t": (2.0, None),
        "2/t2": (2.0, 2),
        "2/t1p5": (2.0, 1.5),
    }

    # Keep t away from 1 for numerical reasons in some schedules.
    t = torch.clamp(t, 0, 1 - self.eps)

    if mode == "us":
        gt = (1.0 - t) / (t + eps)
    elif mode == "tan":
        angle = (1.0 - t) * torch.pi / 2.0
        gt = (torch.pi / 2.0) * torch.sin(angle) / (torch.cos(angle) + eps)
    elif mode in inverse_power_modes:
        numerator, power = inverse_power_modes[mode]
        denominator = t if power is None else t**power
        gt = numerator / (denominator + eps)
    elif mode == "1mt":
        gt = 1 - t
    elif mode == "t":
        gt = t
    elif mode == "ones":
        gt = 0 * t + 1
    else:
        raise NotImplementedError(f"gt not implemented {mode}")

    if param != 1.0:
        # Reshape the schedule: centre log(g) on its mean, squash to (0, 1) with a
        # sigmoid, raise to `param`, then undo the squashing and the centring.
        log_gt = torch.log(gt)
        log_mean = torch.mean(log_gt)
        squashed = torch.sigmoid(log_gt - log_mean) ** param
        gt = torch.exp(torch.logit(squashed, eps=1e-6) + log_mean)

    # A max of None leaves the upper side unclamped.
    return torch.clamp(gt, 0, clamp_val)  # [s]

interpolate(data, t, noise)

从噪声 (x_0) 和数据 (x_1) 获取给定时间 t 的 x_t。

目前,我们使用线性插值,如以下文献中所定义:1. Rectified flow: https://arxiv.org/abs/2209.03003。2. Conditional flow matching: https://arxiv.org/abs/2210.02747(称为条件最优传输)。

参数

名称 类型 描述 默认值
noise Tensor

来自 prior() 的噪声,形状 (batchsize, nodes, features)

必需
t Tensor

time,形状 (batchsize)

必需
data Tensor

目标,形状 (batchsize, nodes, features)

必需
源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def interpolate(self, data: Tensor, t: Tensor, noise: Tensor) -> Tensor:
    """Blend noise (x_0) and data (x_1) linearly to obtain x_t.

    Computes ``x_t = t * data + (1 - t) * noise``, the linear path used by:
        1. Rectified flow: https://arxiv.org/abs/2209.03003.
        2. Conditional flow matching: https://arxiv.org/abs/2210.02747 (called conditional optimal transport).

    When ``self.sigma > 0``, Gaussian noise scaled by ``self.sigma`` is added on top,
    drawn with ``self.rng_generator``.

    Args:
        data (Tensor): target, documented shape (batchsize, nodes, features)
        t (Tensor): time, documented shape (batchsize)
        noise (Tensor): noise from prior(), documented shape (batchsize, nodes, features)

    Returns:
        Tensor: the interpolated sample x_t.

    Raises:
        AssertionError: if ``data`` and ``noise`` differ in size.
    """
    assert data.size() == noise.size()
    # pad_like presumably gives t trailing singleton axes so it broadcasts against data.
    weight = pad_like(t, data)
    blended = data * weight + noise * (1.0 - weight)
    if self.sigma > 0:
        jitter = torch.randn(blended.shape, device=blended.device, generator=self.rng_generator)
        blended += self.sigma * jitter
    return blended

loss(model_pred, target, t=None, xt=None, mask=None, target_type=PredictionType.DATA)

计算给定模型预测、数据样本、时间和掩码的损失。

如果 target_type 为 FLOW,则 loss = ||v_hat - (x1 - x0)||**2。如果 target_type 为 DATA,则 loss = ||x1_hat - x1||**2 * 1 / (1 - t)**2,因为目标向量场 = x1 - x0 = (x1 - xt) / (1 - t),其中 xt = t * x1 + (1 - t) * x0。此函数支持 {DATA, FLOW} 中 prediction_type 和 target_type 的任何组合。

参数

名称 类型 描述 默认值
model_pred Tensor

来自模型的预测输出。

必需
target Tensor

模型预测的目标输出。

必需
t Optional[Tensor]

模型预测的时间。默认为 None。

None
xt Optional[Tensor]

插值数据。默认为 None。

None
mask Optional[Tensor]

数据点的掩码。默认为 None。

None
target_type PredictionType

目标输出的类型。默认为 PredictionType.DATA。

DATA

返回

名称 类型 描述
Tensor

计算出的损失批张量。

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
def loss(
    self,
    model_pred: Tensor,
    target: Tensor,
    t: Optional[Tensor] = None,
    xt: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    target_type: Union[PredictionType, str] = PredictionType.DATA,
):
    """Compute the per-sample loss between the model prediction and the target.

    The model output is first converted to the space named by ``target_type``
    (clean data for DATA, vector field otherwise), so any combination of
    ``prediction_type`` and ``target_type`` is supported.

    If target_type is a vector field: loss = ||v_hat - (x1 - x0)||**2.
    If target_type is DATA: loss = ||x1_hat - x1||**2 / ((1 - t)**2 + eps), because with
    xt = t * x1 + (1 - t) * x0 the target vector field is x1 - x0 = (x1 - xt) / (1 - t).

    The element-wise loss is summed over every non-batch axis and divided by the number
    of unmasked nodes (or by ``model_pred.size(1)`` when no mask is given).

    Args:
        model_pred (Tensor): The predicted output from the model.
        target (Tensor): The target output for the model prediction.
        t (Optional[Tensor], optional): The time for the model prediction. Defaults to None.
        xt (Optional[Tensor], optional): The interpolated data. Defaults to None.
        mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
        target_type (PredictionType, optional): The type of the target output. Defaults to PredictionType.DATA.

    Returns:
        Tensor: The calculated loss batch tensor.

    Raises:
        ValueError: if ``target_type`` is DATA and ``t`` is None.
    """
    target_type = string_to_enum(target_type, PredictionType)
    data_target = target_type == PredictionType.DATA

    convert = self.process_data_prediction if data_target else self.process_vector_field_prediction
    model_pred = convert(model_pred, xt, t, mask)
    raw_loss = self._loss_function(model_pred, target)

    non_batch_axes = tuple(range(1, raw_loss.ndim))
    if mask is None:
        loss = torch.sum(raw_loss, dim=non_batch_axes) / model_pred.size(1)
    else:
        masked_loss = raw_loss * mask.unsqueeze(-1)
        # NOTE(review): a sample whose mask sums to zero divides by zero here.
        loss = torch.sum(masked_loss, dim=non_batch_axes) / torch.sum(mask, dim=-1)

    if data_target:
        if t is None:
            raise ValueError("Time cannot be None when using a time-based weighting")
        # Time weighting 1 / (1 - t)**2; self.eps keeps it finite as t approaches 1.
        loss = (1.0 / ((1.0 - t) ** 2 + self.eps)) * loss
    return loss

process_data_prediction(model_output, xt=None, t=None, mask=None)

根据预测类型处理模型输出以生成干净数据。

参数

名称 类型 描述 默认值
model_output Tensor

模型的输出。

必需
xt Tensor

输入样本。

None
t Tensor

时间步长。

None
mask Optional[Tensor]

应用于模型输出的可选掩码。默认为 None。

None

返回

类型 描述

基于预测类型的数据预测。

引发

类型 描述
ValueError

如果预测类型不受支持(既不是 VELOCITY 也不是 DATA)。

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def process_data_prediction(
    self,
    model_output: Tensor,
    xt: Optional[Tensor] = None,
    t: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
):
    """Convert the raw model output into a clean-data (x1) prediction.

    For ``PredictionType.VELOCITY`` the data estimate is ``xt + (1 - t) * model_output``;
    for ``PredictionType.DATA`` the model output already is the data estimate.

    Args:
        model_output (Tensor): The output of the model.
        xt (Optional[Tensor], optional): The input sample. Required for VELOCITY. Defaults to None.
        t (Optional[Tensor], optional): The time step. Required for VELOCITY. Defaults to None.
        mask (Optional[Tensor], optional): An optional mask to apply to the model output. Defaults to None.

    Returns:
        The data prediction, multiplied by ``mask.unsqueeze(-1)`` when a mask is given.

    Raises:
        ValueError: If ``xt`` or ``t`` is None for a VELOCITY prediction, or if the
            prediction type is neither VELOCITY ("flow") nor DATA.
    """
    if self.prediction_type == PredictionType.VELOCITY:
        if xt is None or t is None:
            raise ValueError("Xt and time cannot be None")
        t = pad_like(t, model_output)
        # Follow the predicted velocity for the remaining time (1 - t) to reach x1.
        pred_data = xt + (1 - t) * model_output
    elif self.prediction_type == PredictionType.DATA:
        pred_data = model_output
    else:
        # DATA is handled above, so the message names both accepted types.
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be `flow` or `data` "
            "for Continuous Flow Matching."
        )
    if mask is not None:
        pred_data = pred_data * mask.unsqueeze(-1)
    return pred_data

ValueError

基于预测类型的数据预测。

参数

名称 类型 描述 默认值
model_output Tensor

模型的输出。

必需
xt Tensor

xt

None
t Tensor

输入样本。

None
mask Optional[Tensor]

t

None

返回

类型 描述

Raises

mask

类型 描述
应用于模型输出的可选掩码。默认为 None。

ValueError

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
def process_vector_field_prediction(
    self,
    model_output: Tensor,
    xt: Optional[Tensor] = None,
    t: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
):
    """Convert the raw model output into a vector field prediction.

    For ``PredictionType.VELOCITY`` the model output already is the vector field;
    for ``PredictionType.DATA`` it is ``(model_output - xt) / (1 - t + self.eps)``.

    Args:
        model_output (Tensor): The output of the model.
        xt (Optional[Tensor], optional): The input sample. Required for DATA. Defaults to None.
        t (Optional[Tensor], optional): The time step. Required for DATA. Defaults to None.
        mask (Optional[Tensor], optional): An optional mask to apply to the model output. Defaults to None.

    Returns:
        The vector field prediction, multiplied by ``mask.unsqueeze(-1)`` when a mask is given.

    Raises:
        ValueError: If ``xt`` or ``t`` is None for a DATA prediction, or if the
            prediction type is not "flow" or "data".
    """
    if self.prediction_type == PredictionType.VELOCITY:
        pred_vector_field = model_output
    elif self.prediction_type == PredictionType.DATA:
        if xt is None or t is None:
            raise ValueError("Xt and Time cannot be None for vector field conversion")
        t = pad_like(t, model_output)
        # self.eps keeps the division finite as t approaches 1.
        pred_vector_field = (model_output - xt) / (1 - t + self.eps)
    else:
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be `flow` or `data` "
            "for Continuous Flow Matching."
        )
    if mask is not None:
        pred_vector_field = pred_vector_field * mask.unsqueeze(-1)
    return pred_vector_field

如果预测类型不是“flow”。

process_vector_field_prediction(model_output, xt=None, t=None, mask=None)

参数

名称 类型 描述 默认值
data Tensor

根据预测类型处理模型输出以计算向量场。

必需

返回

类型 描述
Tensor

model_output

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
156
157
158
159
160
161
162
163
164
165
def scale_data(self, data: Tensor) -> Tensor:
    """Multiply the input by ``self.data_scale``.

    Args:
        data (Tensor): The input data to upscale.

    Returns:
        ``self.data_scale * data``.
    """
    factor = self.data_scale
    return factor * data

模型的输出。

xt

参数

名称 类型 描述 默认值
输入样本。 Tensor

t

必需
xt Tensor

时间步长。

必需
mask Tensor

应用于模型输出的可选掩码。默认为 None。

必需
t Tensor

返回

None
mask Optional[Tensor]

Tensor

None
基于预测类型的向量场预测。 Raises

ValueError

如果预测类型不是“flow”或“data”。

返回

名称 类型 描述
scale_data(data) Tensor

按数据比例因子放大输入数据。

data

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def step(
    self,
    model_out: Tensor,
    xt: Tensor,
    dt: Tensor,
    t: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    center: Bool = False,
):
    """Advance the state by one Euler step of the ODE: ``x_next = xt + v_t * dt``.

    Args:
        model_out (Tensor): The output of the model at the current time step.
        xt (Tensor): The current intermediate state.
        dt (Tensor): The time step size.
        t (Tensor, optional): The current time. Defaults to None.
        mask (Optional[Tensor], optional): A mask to apply to the model output. Defaults to None.
        center (Bool, optional): Whether to center the output. Defaults to False.

    Returns:
        x_next (Tensor): The updated state of the system after the single step, x_(t+dt).

    Notes:
    - If a mask is provided, it is applied element-wise to the model output before it is
      converted to a vector field.
    - ``self.clean_mask_center`` is called on the updated state before it is returned.
    """
    masked_out = model_out if mask is None else model_out * mask.unsqueeze(-1)
    velocity = self.process_vector_field_prediction(masked_out, xt, t, mask)
    step_size = pad_like(dt, masked_out)
    advanced = xt + velocity * step_size
    return self.clean_mask_center(advanced, mask, center)

要放大的输入数据。

xt

返回

Tensor

参数

名称 类型 描述 默认值
输入样本。 Tensor

t

必需
xt Tensor

时间步长。

必需
mask Tensor

应用于模型输出的可选掩码。默认为 None。

必需
t Tensor

返回

必需
mask Optional[Tensor]

Tensor

None
放大的数据。 str

step(model_out, xt, dt, t=None, mask=None, center=False)

'tan'
使用 Euler 方法执行单个 ODE 步长积分。 Float

model_out

1.0
模型在当前时间步长的输出。 xt

当前中间状态。

None
dt Float

时间步长大小。

1.0
t Float

当前时间。默认为 None。

1.0
mask Float

应用于模型输出的掩码。默认为 None。

0.99
基于预测类型的向量场预测。 Raises

ValueError

如果预测类型不是“flow”或“data”。

返回

名称 类型 描述
scale_data(data) Tensor

按数据比例因子放大输入数据。

center
  • Bool
  • 是否居中输出。默认为 False。
源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def step_score_stochastic(
    self,
    model_out: Tensor,
    xt: Tensor,
    dt: Tensor,
    t: Tensor,
    mask: Optional[Tensor] = None,
    gt_mode: str = "tan",
    gt_p: Float = 1.0,
    gt_clamp: Optional[Float] = None,
    score_temperature: Float = 1.0,
    noise_temperature: Float = 1.0,
    t_lim_ode: Float = 0.99,
    center: Bool = False,
):
    r"""Advance the state by one Euler step of the score-based SDE.

    d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

    The vector field v itself is not rescaled. Once every time in the batch has reached
    ``t_lim_ode`` the method falls back to the deterministic ``self.step``.

    Args:
        model_out (Tensor): The output of the model at the current time step.
        xt (Tensor): The current intermediate state.
        dt (Tensor): The time step size.
        t (Tensor): The current time.
        mask (Optional[Tensor], optional): A mask to apply to the model output. Defaults to None.
        gt_mode (str, optional): The mode for the gt function. Defaults to "tan".
        gt_p (Float, optional): The parameter for the gt function. Defaults to 1.0.
        gt_clamp: (Float, optional): Upper limit of gt term. Defaults to None.
        score_temperature (Float, optional): The temperature for the score part of the step. Defaults to 1.0.
        noise_temperature (Float, optional): The temperature for the stochastic part of the step. Defaults to 1.0.
        t_lim_ode (Float, optional): The time limit for the ODE step. Defaults to 0.99.
        center (Bool, optional): Whether to center the output. Defaults to False.

    Returns:
        x_next (Tensor): The updated state of the system after the single step, x_(t+dt).

    Raises:
        ValueError: if ``self.ot_type`` is set, or if ``self.prior_distribution`` is not a
            ``GaussianPrior``.

    Notes:
        - If a mask is provided, it is applied element-wise to the model output before it
          is converted to a vector field.
        - ``self.clean_mask_center`` is called on the updated state before it is returned.
    """
    if self.ot_type is not None:
        raise ValueError("Optimal Transport violates the vector field to score conversion")
    if not isinstance(self.prior_distribution, GaussianPrior):
        raise ValueError(
            "Prior distribution must be an instance of GaussianPrior to learn a proper score function"
        )
    if t.min() >= t_lim_ode:
        # Every sample is at or past the limit: take the plain ODE step instead.
        return self.step(model_out, xt, dt, t, mask, center)

    masked_out = model_out if mask is None else model_out * mask.unsqueeze(-1)
    velocity = self.process_vector_field_prediction(masked_out, xt, t, mask)
    dt_padded = pad_like(dt, masked_out)
    t_padded = pad_like(t, masked_out)
    score = self.vf_to_score(xt, velocity, t_padded)
    g_t = self.get_gt(t_padded, gt_mode, gt_p, gt_clamp)
    white_noise = torch.randn(xt.shape, dtype=xt.dtype, device=xt.device, generator=self.rng_generator)
    noise_scale = torch.sqrt(2 * g_t * noise_temperature * dt_padded)
    drift = velocity + g_t * score * score_temperature
    increment = drift * dt_padded + noise_scale * white_noise
    return self.clean_mask_center(xt + increment, mask, center)

默认值

False

参数

名称 类型 描述 默认值
data Tensor

返回

必需

返回

类型 描述
Tensor

Tensor

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
145
146
147
148
149
150
151
152
153
154
def undo_scale_data(self, data: Tensor) -> Tensor:
    """Multiply the input by the reciprocal of ``self.data_scale``.

    Args:
        data (Tensor): The input data to downscale.

    Returns:
        ``(1 / self.data_scale) * data``.
    """
    inverse_factor = 1 / self.data_scale
    return inverse_factor * data

x_next

单步后的系统更新状态,x_(t+dt)。

注意:(1)如果提供了掩码,则在缩放之前将其逐元素应用于模型输出。(2)在返回更新状态之前,将对其调用 clean_mask_center 方法。

step_score_stochastic(model_out, xt, dt, t, mask=None, gt_mode='tan', gt_p=1.0, gt_clamp=None, score_temperature=1.0, noise_temperature=1.0, t_lim_ode=0.99, center=False)

d x_t = [v(x_t, t) + g(t) * s(x_t, t) * sc_score_scale] dt + \sqrt{2 * g(t) * temperature} dw_t。

目前我们不缩放向量场 v,但可以使用 sc_score_scale 添加此项。

gt_mode

参数

名称 类型 描述 默认值
gt 函数的模式。默认为“tan”。 Tensor

gt_p

必需
gt 函数的参数。默认为 1.0。 Tensor

gt_clamp

必需
t Tensor

Optional[Float]

必需

返回

类型 描述
Tensor

(Float, optional): gt 项的上限。默认为 None。

源代码在 bionemo/moco/interpolants/continuous_time/continuous/continuous_flow_matching.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def vf_to_score(
    self,
    x_t: Tensor,
    v: Tensor,
    t: Tensor,
) -> Tensor:
    """Recover the score of the noisy density from a flow-matching vector field (Geffner et al.).

    With our interpolation scheme, and scale_ref = 1, the two are related by

    v(x_t, t) = (1 / t) (x_t + scale_ref ** 2 * (1 - t) * s(x_t, t)),

    which rearranges to

    s(x_t, t) = (t * v(x_t, t) - x_t) / (scale_ref ** 2 * (1 - t)).

    Args:
        x_t: Noisy sample, shape [*, dim]
        v: Vector field, shape [*, dim]
        t: Interpolation time, shape [*] (must be < 1)

    Returns:
        Score of intermediate density, shape [*, dim].

    Raises:
        AssertionError: if any entry of ``t`` is not strictly below 1.
    """
    assert torch.all(t < 1.0), "vf_to_score requires t < 1 (strict)"
    time = pad_like(t, v)
    numerator = time * v - x_t  # [*, dim]
    denominator = 1.0 - time  # [*, 1]
    return numerator / denominator  # [*, dim]