D3PM

Bases: Interpolant

A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant.

Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
class D3PM(Interpolant):
    """A Discrete Denoising Diffusion Probabilistic Model (D3PM) interpolant."""

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: DiscretePriorDistribution,
        noise_schedule: DiscreteNoiseSchedule,
        device: str = "cpu",
        last_time_idx: int = 0,
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initializes the D3PM interpolant.

        Args:
            time_distribution (TimeDistribution): The distribution of time steps, used to sample time points for the diffusion process.
            prior_distribution (PriorDistribution): The prior distribution of the variable, used as the starting point for the diffusion process.
            noise_schedule (DiscreteNoiseSchedule): The schedule of noise, defining the amount of noise added at each time step.
            device (str, optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
            last_time_idx (int, optional): The last time index to consider in the interpolation process. Defaults to 0.
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
        """
        # We initialize with CPU due to numerical precision issues on A100 that are not observed on A6000
        super().__init__(time_distribution, prior_distribution, "cpu", rng_generator)
        self.noise_schedule = noise_schedule
        self._loss_function = nn.CrossEntropyLoss(reduction="none")
        self.timesteps = noise_schedule.nsteps
        self.num_classes = prior_distribution.num_classes
        self.terminal_distribution = prior_distribution.prior_dist.to(self.device)
        self._initialize_schedules(self.device)
        self.last_time_idx = last_time_idx
        self.to_device(device)

    def _get_Qt(self, alphas: Tensor) -> Tensor:
        """Calculate the transition matrix Qt based on the terminal distribution.

        The transition matrix Qt represents the probabilities of transitioning from one state to another at a given time step.
        It is calculated based on the terminal distribution, which can be either uniform, a mask, or a custom distribution.
        See Appendix A.2 D3PM https://arxiv.org/pdf/2107.03006 which shows what happens for various prior distributions.

        The terminal distribution can be:
        - Uniform: a uniform distribution over all states.
        - Mask: a mask where the last dimension is 1 and the rest are 0.
        - Custom: a custom distribution provided by the user.

        Args:
            alphas (Tensor): A tensor of probabilities, where each alpha represents the probability of staying in a state at a given time step.

        Returns:
            Tensor: The transition matrix Qt.
        """
        QT = []
        for alpha_t in alphas:
            stay_prob = torch.eye(len(self.terminal_distribution), device=self.device) * alpha_t
            diffuse_prob = (1.0 - alpha_t) * (
                torch.ones(1, len(self.terminal_distribution), device=self.device)
                * (self.terminal_distribution.unsqueeze(0))
            )
            QT.append(stay_prob + diffuse_prob)
        return torch.stack(QT, dim=0)

    def _calculate_transition_matrix(self, alphas: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Calculates the rate transition matrix `Qt`, its cumulative variant `Qt_bar`, and the cumulative variant of the previous time step `Qt_bar_prev`.

        Args:
            alphas (Tensor): A tensor of probabilities, where each alpha represents the probability of staying in a state at a given time step.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple containing the rate transition matrix `Qt`, its cumulative variant `Qt_bar`, and the cumulative variant of the previous time step `Qt_bar_prev`.
        """
        Qt = self._get_Qt(alphas)
        Qt_prev = torch.eye(self.num_classes, device=self.device)
        Qt_bar = []
        for i in range(len(alphas)):
            Qtb = Qt_prev @ Qt[i]
            if torch.any((Qtb.sum(-1) - 1.0).abs() > 1e-4):
                raise ValueError(f"Invalid Distribution for Qt_bar at step {i}")
            Qt_bar.append(Qtb)
            Qt_prev = Qtb
        Qt_bar = torch.stack(Qt_bar)
        Qt_bar_prev = Qt_bar[:-1]
        Qt_prev_pad = torch.eye(self.num_classes, device=self.device)
        Qt_bar_prev = torch.concat([Qt_prev_pad.unsqueeze(0), Qt_bar_prev], dim=0)
        return Qt, Qt_bar, Qt_bar_prev

    def _initialize_schedules(self, device):
        """Initializes the transition matrices for the discrete diffusion process.

        This method computes the rate transition matrix `Qt` and its cumulative variants `Qt_bar` and `Qt_prev_bar`
        based on the provided noise schedule.

        Note:
            `Qt` represents the rate transition matrix, where `Qt[t]` is the transition matrix at time step `t`.
            `Qt_bar` and `Qt_prev_bar` are the cumulative variants of `Qt`, where `Qt_bar[t]` represents the cumulative
            transition matrix from time step `0` to `t`, and `Qt_prev_bar[t]` represents the cumulative transition matrix
            from time step `0` to `t-1`.

        Args:
            device (str): The device on which to compute the transition matrices.
        """
        if self.noise_schedule is None:
            raise ValueError("noise_schedule cannot be None for D3PM")
        alphas = self.noise_schedule.generate_schedule(device=device)
        log_alpha = torch.log(alphas)
        log_alpha_bar = torch.cumsum(log_alpha, dim=0)
        self._alpha_bar = torch.exp(log_alpha_bar)
        #! Note to users: the traditional cosine schedule makes alpha converge very quickly. Pay close attention to the scheduler here.
        Qt, Qt_bar, Qt_prev_bar = self._calculate_transition_matrix(alphas)
        self._Qt = Qt[-self.timesteps :]
        self._Qt_transposed = self._Qt.transpose(1, 2)
        self._Qt_bar = Qt_bar[-self.timesteps :]
        self._Qt_prev_bar = Qt_prev_bar[-self.timesteps :]

    def interpolate(self, data: Tensor, t: Tensor):
        """Interpolate using discrete interpolation method.

        This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
        calculates the interpolated discrete state `xt` at time `t` given the input data and noise
        via q(xt|x0) = Cat(xt; p = x0*Qt_bar).

        Args:
            data (Tensor): The input data to be interpolated.
            t (Tensor): The time step at which to interpolate.

        Returns:
            Tensor: The interpolated discrete state `xt` at time `t`.
        """
        if not _is_one_hot(data, self.num_classes):
            x1_hot = F.one_hot(data, self.num_classes)
        else:
            x1_hot = data
        ford = safe_index(self._Qt_bar, t - self.last_time_idx, data.device)
        if x1_hot.ndim > 3:  # einsum precision issues on A100 not A6000 for 2D inputs
            ford_prep = ford
            for _ in range(x1_hot.ndim - 2):
                ford_prep = ford_prep.unsqueeze(1)
            probs = (x1_hot.float().unsqueeze(-2) * ford_prep).sum(dim=(-2))
        else:
            probs = torch.einsum("b...j, bji -> b...i", [x1_hot.float(), ford])
        if torch.any((probs.sum(-1) - 1.0).abs() > 1e-4):
            raise ValueError(
                f"**INVALID BEHAVIOR** Probability Distribution does not sum to 1.0 for time {t}. "
                f"**INVESTIGATE YOUR DEVICE PRECISION**: This error has been triggered before on A100 by initializing the Qt terms on gpu. "
                f"Normalized to ensure validity. Original sums: {probs.sum(-1)}",
            )
        xt = self._sample_categorical(torch.log(probs) + 1.0e-6)
        return xt

    def forward_process(self, data: Tensor, t: Tensor) -> Tensor:
        """Apply the forward process to the data at time t.

        Args:
            data (Tensor): target discrete ids
            t (Tensor): time

        Returns:
            Tensor: x(t) after applying the forward process
        """
        return self.interpolate(data, t)

    def _sample_categorical(self, logits, mask: Optional[Tensor] = None, temperature: Float = 1.0) -> Tensor:
        """Sample a categorical distribution using the Gumbel-Softmax trick.

        This method samples a categorical distribution from the given logits,
        optionally applying a mask and using a specified temperature.

        Args:
            logits (Tensor): The logits of the categorical distribution.
            mask (Optional[Tensor], optional): An optional mask to apply to the noise added to logits. Defaults to None.
            temperature (float, optional): The temperature to use for the Gumbel-Softmax trick. Defaults to 1.0.

        Returns:
            Tensor: A sample from the categorical distribution.
        """
        noise = torch.rand_like(logits)
        noise = torch.clip(noise, 1.0e-6, 1.0)
        gumbel_noise = -torch.log(-torch.log(noise))
        if mask is not None:
            sample = torch.argmax((logits / temperature) + gumbel_noise * mask, dim=-1)
        else:
            sample = torch.argmax((logits / temperature) + gumbel_noise, dim=-1)
        return sample

    def _q_posterior_logits(
        self, model_out: Tensor, t: Tensor, xt: Tensor, model_out_is_logits: bool = True
    ) -> Tensor:
        """Calculate the q-posterior logits using the predicted x0 and the current state xt at time t.

        This method implements Equation 3 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which calculates the q-posterior
        distribution over the previous state x0 given the current state xt and the model output.

        Args:
            model_out (Tensor): The output of the model at the current time step.
            t (Tensor): The current time step.
            xt (Tensor): The current discrete state at time t.
            model_out_is_logits (bool, optional): A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.

        Returns:
            Tensor: The q-posterior logits.
        """
        if not model_out_is_logits:  # model_out.dtype == torch.int64 or model_out.dtype == torch.int32:
            # Convert model output to logits if it's a categorical distribution
            x0_logits = torch.log(torch.nn.functional.one_hot(model_out, self.num_classes).float() + 1.0e-6)
        else:
            # Otherwise, assume model output is already logits
            x0_logits = model_out.clone()

        # Calculate xt_guess: the predicted probability of xt given x0 and t
        xt_guess = torch.einsum(
            "b...j, bji -> b...i",
            [
                torch.nn.functional.one_hot(xt, self.num_classes).float(),
                safe_index(self._Qt_transposed, t - self.last_time_idx, model_out.device),
            ],
        )

        # Calculate softmaxed x0_logits
        softmaxed = torch.softmax(x0_logits, dim=-1)  # bs, ..., num_classes

        # Calculate x0_guess: the predicted probability of x0 given xt and t-1
        x0_guess = torch.einsum(
            "b...c,bcd->b...d",
            softmaxed,
            safe_index(self._Qt_prev_bar, t - self.last_time_idx, model_out.device),
        )

        # Calculate q-posterior logits
        out = torch.log(xt_guess + 1.0e-6) + torch.log(x0_guess + 1.0e-6)
        t_broadcast = t.reshape((t.shape[0], *[1] * (xt.dim())))
        q_posterior_logits = torch.where(t_broadcast == self.last_time_idx, x0_logits, out)
        return q_posterior_logits

    def step(
        self,
        model_out: Tensor,
        t: Tensor,
        xt: Tensor,
        mask: Optional[Tensor] = None,
        temperature: Float = 1.0,
        model_out_is_logits: bool = True,
    ):
        """Perform a single step in the discrete interpolant method, transitioning from the current discrete state `xt` at time `t` to the next state.

        This step involves:

        1. Computing the predicted q-posterior logits using the model output `model_out` and the current state `xt` at time `t`.
        2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

        Args:
            model_out (Tensor): The output of the model at the current time step, which is used to compute the predicted q-posterior logits.
            t (Tensor): The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.
            xt (Tensor): The current discrete state at time `t`, which is used to compute the predicted q-posterior logits and sample the next state.
            mask (Optional[Tensor], optional): An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
            temperature (Float, optional): The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
            model_out_is_logits (bool, optional): A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.

        Returns:
            Tensor: The next discrete state at time `t-1`.
        """
        pred_q_posterior_logits = self._q_posterior_logits(model_out, t, xt, model_out_is_logits)
        nonzero_mask = (t != self.last_time_idx).to(xt.dtype).reshape(xt.shape[0], *([1] * (len(xt.shape))))
        x_next = self._sample_categorical(pred_q_posterior_logits, nonzero_mask, temperature=temperature)
        # # Apply mask if provided
        if mask is not None:
            x_next = x_next * mask
        return x_next

    def loss(
        self,
        logits: Tensor,
        target: Tensor,
        xt: Tensor,
        time: Tensor,
        mask: Optional[Tensor] = None,
        vb_scale: Float = 0.0,
    ):
        """Calculate the cross-entropy loss between the model prediction and the target output.

        The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
        calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
        calculated and added to the total loss.

        Args:
            logits (Tensor): The predicted output from the model, with shape batch x node x class.
            target (Tensor): The target output for the model prediction, with shape batch x node.
            xt (Tensor): The current data point.
            time (Tensor): The time at which the loss is calculated.
            mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
            vb_scale (Float, optional): The scale factor for the variational lower bound loss. Defaults to 0.0.

        Returns:
            Tensor: The calculated loss tensor. If aggregate is True, the loss and variational lower bound loss are aggregated and
            returned as a single tensor. Otherwise, the loss and variational lower bound loss are returned as separate tensors.
        """
        assert target.ndim + 1 == logits.ndim
        loss = self._loss_function(logits.transpose(-1, 1), target.long())
        if mask is not None:
            loss = loss * mask
            num_non_masked_elements = torch.sum(mask, dim=-1)
            loss = torch.sum(loss, dim=(-1)) / num_non_masked_elements
        else:
            loss = torch.sum(loss, dim=(-1)) / logits.size(1)
        if vb_scale > 0:
            target = F.one_hot(target, num_classes=self.num_classes).float()
            true_q_posterior_logits = self._q_posterior_logits(target, time, xt)
            pred_q_posterior_logits = self._q_posterior_logits(logits, time, xt)
            vb_loss = self._variational_lower_bound(true_q_posterior_logits, pred_q_posterior_logits)
            vb_loss = vb_scale * vb_loss
        else:
            vb_loss = 0
        if vb_scale > 0:
            loss += vb_loss
        return loss

    def _variational_lower_bound(self, dist1: Tensor, dist2: Tensor) -> Tensor:
        """Calculate the variational lower bound (VLB) between two distributions.

        The VLB measures the difference between the true and approximate posterior distributions.
        It is used to regularize the model and encourage it to produce more accurate predictions.

        Args:
            dist1 (Tensor): The true posterior distribution.
            dist2 (Tensor): The approximate posterior distribution.

        Returns:
            Tensor: The variational lower bound loss.
        """
        # Flatten dist1 and dist2 to simplify calculations
        dist1 = dist1.flatten(start_dim=0, end_dim=-2)
        dist2 = dist2.flatten(start_dim=0, end_dim=-2)

        # Calculate the VLB
        out = torch.softmax(dist1 + 1.0e-6, dim=-1) * (
            torch.log_softmax(dist1 + 1.0e-6, dim=-1) - torch.log_softmax(dist2 + 1.0e-6, dim=-1)
        )
        # Return the mean of the VLB across all elements
        return out.sum(dim=-1).mean()

__init__(time_distribution, prior_distribution, noise_schedule, device='cpu', last_time_idx=0, rng_generator=None)

Initializes the D3PM interpolant.

Parameters:

    time_distribution (TimeDistribution): The distribution of time steps, used to sample time points for the diffusion process. Required.
    prior_distribution (PriorDistribution): The prior distribution of the variable, used as the starting point for the diffusion process. Required.
    noise_schedule (DiscreteNoiseSchedule): The schedule of noise, defining the amount of noise added at each time step. Required.
    device (str): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
    last_time_idx (int): The last time index to consider in the interpolation process. Defaults to 0.
    rng_generator (Optional[torch.Generator]): An optional torch.Generator for reproducible sampling. Defaults to None.
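
Below is a minimal construction sketch. The concrete time distribution, prior, and noise-schedule classes used here (UniformTimeDistribution, DiscreteUniformPrior, DiscreteCosineNoiseSchedule) and their import paths and constructor arguments are assumptions for illustration only; substitute whatever TimeDistribution, DiscretePriorDistribution, and DiscreteNoiseSchedule implementations your installation provides.

import torch

# Assumed import locations; adjust to your installation.
from bionemo.moco.distributions.prior.discrete.uniform import DiscreteUniformPrior
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
from bionemo.moco.interpolants.discrete_time.discrete.d3pm import D3PM

d3pm = D3PM(
    time_distribution=UniformTimeDistribution(discrete_time=True, nsteps=1000),
    prior_distribution=DiscreteUniformPrior(num_classes=20),
    noise_schedule=DiscreteCosineNoiseSchedule(nsteps=1000),
    device="cuda:0",  # the Qt matrices are still built on CPU first, then moved
)
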
Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: DiscretePriorDistribution,
    noise_schedule: DiscreteNoiseSchedule,
    device: str = "cpu",
    last_time_idx: int = 0,
    rng_generator: Optional[torch.Generator] = None,
):
    """Initializes the D3PM interpolant.

    Args:
        time_distribution (TimeDistribution): The distribution of time steps, used to sample time points for the diffusion process.
        prior_distribution (PriorDistribution): The prior distribution of the variable, used as the starting point for the diffusion process.
        noise_schedule (DiscreteNoiseSchedule): The schedule of noise, defining the amount of noise added at each time step.
        device (str, optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
        last_time_idx (int, optional): The last time index to consider in the interpolation process. Defaults to 0.
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
    """
    # We initialize with CPU due to numerical precision issues on A100 that are not observed on A6000
    super().__init__(time_distribution, prior_distribution, "cpu", rng_generator)
    self.noise_schedule = noise_schedule
    self._loss_function = nn.CrossEntropyLoss(reduction="none")
    self.timesteps = noise_schedule.nsteps
    self.num_classes = prior_distribution.num_classes
    self.terminal_distribution = prior_distribution.prior_dist.to(self.device)
    self._initialize_schedules(self.device)
    self.last_time_idx = last_time_idx
    self.to_device(device)

forward_process(data, t)

Apply the forward process to the data at time t.

Parameters:

    data (Tensor): Target discrete ids. Required.
    t (Tensor): Time. Required.

Returns:

    Tensor: x(t) after applying the forward process.
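
A minimal noising sketch, reusing the d3pm object from the construction example above; the tensor names and shapes are illustrative only.

device = "cuda:0"  # same device the interpolant was moved to
x0 = torch.randint(0, d3pm.num_classes, (8, 128), device=device)  # clean discrete ids: batch x node
t = torch.randint(0, d3pm.timesteps, (8,), device=device)         # one integer time step per batch element
xt = d3pm.forward_process(x0, t)                                   # corrupted ids x(t), same shape as x0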

Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
def forward_process(self, data: Tensor, t: Tensor) -> Tensor:
    """Apply the forward process to the data at time t.

    Args:
        data (Tensor): target discrete ids
        t (Tensor): time

    Returns:
        Tensor: x(t) after applying the forward process
    """
    return self.interpolate(data, t)

interpolate(data, t)

Interpolate using the discrete interpolation method.

This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which calculates the interpolated discrete state xt at time t given the input data and noise via q(xt|x0) = Cat(xt; p = x0*Qt_bar).

Parameters:

    data (Tensor): The input data to be interpolated. Required.
    t (Tensor): The time step at which to interpolate. Required.

Returns:

    Tensor: The interpolated discrete state xt at time t.
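
Equation 2 says the categorical parameters for xt are simply the x0 rows of the cumulative transition matrix. The sketch below, reusing the illustrative x0 and t from the forward_process example, reads the private _Qt_bar buffer purely for inspection; it is not part of the public API.

import torch.nn.functional as F

one_hot_x0 = F.one_hot(x0, d3pm.num_classes).float()
probs = torch.einsum("b...j,bji->b...i", one_hot_x0, d3pm._Qt_bar[t - d3pm.last_time_idx])
# Each row of probs is q(xt | x0) = Cat(xt; p = x0 * Qt_bar) and should sum to one.
assert torch.allclose(probs.sum(-1), torch.ones_like(probs.sum(-1)), atol=1e-4)
xt = d3pm.interpolate(x0, t)  # draws one sample per position from exactly these probabilities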

Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
def interpolate(self, data: Tensor, t: Tensor):
    """Interpolate using discrete interpolation method.

    This method implements Equation 2 from the D3PM paper (https://arxiv.org/pdf/2107.03006), which
    calculates the interpolated discrete state `xt` at time `t` given the input data and noise
    via q(xt|x0) = Cat(xt; p = x0*Qt_bar).

    Args:
        data (Tensor): The input data to be interpolated.
        t (Tensor): The time step at which to interpolate.

    Returns:
        Tensor: The interpolated discrete state `xt` at time `t`.
    """
    if not _is_one_hot(data, self.num_classes):
        x1_hot = F.one_hot(data, self.num_classes)
    else:
        x1_hot = data
    ford = safe_index(self._Qt_bar, t - self.last_time_idx, data.device)
    if x1_hot.ndim > 3:  # einsum precision issues on A100 not A6000 for 2D inputs
        ford_prep = ford
        for _ in range(x1_hot.ndim - 2):
            ford_prep = ford_prep.unsqueeze(1)
        probs = (x1_hot.float().unsqueeze(-2) * ford_prep).sum(dim=(-2))
    else:
        probs = torch.einsum("b...j, bji -> b...i", [x1_hot.float(), ford])
    if torch.any((probs.sum(-1) - 1.0).abs() > 1e-4):
        raise ValueError(
            f"**INVALID BEHAVIOR** Probability Distribution does not sum to 1.0 for time {t}. "
            f"**INVESTIGATE YOUR DEVICE PRECISION**: This error has been triggered before on A100 by initializing the Qt terms on gpu. "
            f"Normalized to ensure validity. Original sums: {probs.sum(-1)}",
        )
    xt = self._sample_categorical(torch.log(probs) + 1.0e-6)
    return xt

loss(logits, target, xt, time, mask=None, vb_scale=0.0)

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is calculated and added to the total loss.

Parameters:

    logits (Tensor): The predicted output from the model, with shape batch x node x class. Required.
    target (Tensor): The target output for the model prediction, with shape batch x node. Required.
    xt (Tensor): The current data point. Required.
    time (Tensor): The time at which the loss is calculated. Required.
    mask (Optional[Tensor]): The mask for the data point. Defaults to None.
    vb_scale (Float): The scale factor for the variational lower bound loss. Defaults to 0.0.

Returns:

    Tensor: The calculated loss tensor. If vb_scale is greater than 0, the scaled variational lower bound loss is added to the cross-entropy loss and the combined tensor is returned; otherwise only the cross-entropy loss is returned.
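
A hypothetical training-step sketch, reusing the tensors from the forward_process example; model stands for any network producing batch x node x num_classes logits, and the 0.05 weight is arbitrary.

logits = model(xt, t)                                        # batch x node x num_classes
per_sample = d3pm.loss(logits, target=x0, xt=xt, time=t, vb_scale=0.05)
per_sample.mean().backward()                                 # reduce the per-sample losses yourself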

Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
def loss(
    self,
    logits: Tensor,
    target: Tensor,
    xt: Tensor,
    time: Tensor,
    mask: Optional[Tensor] = None,
    vb_scale: Float = 0.0,
):
    """Calculate the cross-entropy loss between the model prediction and the target output.

    The loss is calculated between the batch x node x class logits and the target batch x node. If a mask is provided, the loss is
    calculated only for the non-masked elements. Additionally, if vb_scale is greater than 0, the variational lower bound loss is
    calculated and added to the total loss.

    Args:
        logits (Tensor): The predicted output from the model, with shape batch x node x class.
        target (Tensor): The target output for the model prediction, with shape batch x node.
        xt (Tensor): The current data point.
        time (Tensor): The time at which the loss is calculated.
        mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
        vb_scale (Float, optional): The scale factor for the variational lower bound loss. Defaults to 0.0.

    Returns:
        Tensor: The calculated loss tensor. If aggregate is True, the loss and variational lower bound loss are aggregated and
        returned as a single tensor. Otherwise, the loss and variational lower bound loss are returned as separate tensors.
    """
    assert target.ndim + 1 == logits.ndim
    loss = self._loss_function(logits.transpose(-1, 1), target.long())
    if mask is not None:
        loss = loss * mask
        num_non_masked_elements = torch.sum(mask, dim=-1)
        loss = torch.sum(loss, dim=(-1)) / num_non_masked_elements
    else:
        loss = torch.sum(loss, dim=(-1)) / logits.size(1)
    if vb_scale > 0:
        target = F.one_hot(target, num_classes=self.num_classes).float()
        true_q_posterior_logits = self._q_posterior_logits(target, time, xt)
        pred_q_posterior_logits = self._q_posterior_logits(logits, time, xt)
        vb_loss = self._variational_lower_bound(true_q_posterior_logits, pred_q_posterior_logits)
        vb_loss = vb_scale * vb_loss
    else:
        vb_loss = 0
    if vb_scale > 0:
        loss += vb_loss
    return loss

step(model_out, t, xt, mask=None, temperature=1.0, model_out_is_logits=True)

Perform a single step in the discrete interpolant method, transitioning from the current discrete state xt at time t to the next state.

This step involves:

  1. Computing the predicted q-posterior logits using the model output model_out and the current state xt at time t.
  2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

Parameters:

    model_out (Tensor): The output of the model at the current time step, used to compute the predicted q-posterior logits. Required.
    t (Tensor): The current time step, used to index into the transition matrices and compute the predicted q-posterior logits. Required.
    xt (Tensor): The current discrete state at time t, used to compute the predicted q-posterior logits and to sample the next state. Required.
    mask (Optional[Tensor]): An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
    temperature (Float): The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
    model_out_is_logits (bool): A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.

Returns:

    Tensor: The next discrete state at time t-1.
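
A hypothetical reverse-diffusion sampling loop built from step, assuming a uniform prior so the chain can start from uniform random ids; with a mask or other custom prior, draw the initial state from prior_distribution instead.

device = "cuda:0"
xt = torch.randint(0, d3pm.num_classes, (8, 128), device=device)  # start from pure noise
for t_step in reversed(range(d3pm.timesteps)):
    t = torch.full((xt.shape[0],), t_step, device=device, dtype=torch.long)
    logits = model(xt, t)                           # hypothetical denoising network
    xt = d3pm.step(logits, t, xt, temperature=1.0)
# At t == last_time_idx the Gumbel noise is masked out, so the final call returns the argmax of the x0 logits.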

Source code in bionemo/moco/interpolants/discrete_time/discrete/d3pm.py
def step(
    self,
    model_out: Tensor,
    t: Tensor,
    xt: Tensor,
    mask: Optional[Tensor] = None,
    temperature: Float = 1.0,
    model_out_is_logits: bool = True,
):
    """Perform a single step in the discrete interpolant method, transitioning from the current discrete state `xt` at time `t` to the next state.

    This step involves:

    1. Computing the predicted q-posterior logits using the model output `model_out` and the current state `xt` at time `t`.
    2. Sampling the next state from the predicted q-posterior distribution using the Gumbel-Softmax trick.

    Args:
        model_out (Tensor): The output of the model at the current time step, which is used to compute the predicted q-posterior logits.
        t (Tensor): The current time step, which is used to index into the transition matrices and compute the predicted q-posterior logits.
        xt (Tensor): The current discrete state at time `t`, which is used to compute the predicted q-posterior logits and sample the next state.
        mask (Optional[Tensor], optional): An optional mask to apply to the next state, which can be used to mask out certain tokens or regions. Defaults to None.
        temperature (Float, optional): The temperature to use for the Gumbel-Softmax trick, which controls the randomness of the sampling process. Defaults to 1.0.
        model_out_is_logits (bool, optional): A flag indicating whether the model output is already in logits form. If True, the output is assumed to be logits; otherwise, it is converted to logits. Defaults to True.

    Returns:
        Tensor: The next discrete state at time `t-1`.
    """
    pred_q_posterior_logits = self._q_posterior_logits(model_out, t, xt, model_out_is_logits)
    nonzero_mask = (t != self.last_time_idx).to(xt.dtype).reshape(xt.shape[0], *([1] * (len(xt.shape))))
    x_next = self._sample_categorical(pred_q_posterior_logits, nonzero_mask, temperature=temperature)
    # # Apply mask if provided
    if mask is not None:
        x_next = x_next * mask
    return x_next