
MDLM

Bases: Interpolant

A Masked Discrete Diffusion Language Model (MDLM) interpolant.


Examples

>>> import torch
>>> from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
>>> from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
>>> from bionemo.moco.schedules.inference_time_schedules import LinearTimeSchedule


mdlm = MDLM(
    time_distribution = UniformTimeDistribution(discrete_time = False,...),
    prior_distribution = DiscreteMaskedPrior(...),
    noise_schedule = CosineExpNoiseTransform(...),
    )
model = Model(...)

# Training
for epoch in range(1000):
    data = data_loader.get(...)
    time = mdlm.sample_time(batch_size)
    xt = mdlm.interpolate(data, time)

    logits = model(xt, time)
    loss = mdlm.loss(logits, data, xt, time)
    loss.backward()

# Generation
x_pred = mdlm.sample_prior(data.shape)
schedule = LinearTimeSchedule(...)
inference_time = schedule.generate_schedule()
dts = schedule.discretize()
for t, dt in zip(inference_time, dts):
    time = torch.full((batch_size,), t)
    logits = model(x_pred, time)
    x_pred = mdlm.step(logits, time, x_pred, dt)
return x_pred

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 28-355
class MDLM(Interpolant):
    """A Masked discrete Diffusion Language Model (MDLM) interpolant.

     -------

    Examples:
    ```python
    >>> import torch
    >>> from bionemo.moco.distributions.prior.discrete.mask import DiscreteMaskedPrior
    >>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
    >>> from bionemo.moco.interpolants.continuous_time.discrete.mdlm import MDLM
    >>> from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
    >>> from bionemo.moco.schedules.inference_time_schedules import LinearTimeSchedule


    mdlm = MDLM(
        time_distribution = UniformTimeDistribution(discrete_time = False,...),
        prior_distribution = DiscreteMaskedPrior(...),
        noise_schedule = CosineExpNoiseTransform(...),
        )
    model = Model(...)

    # Training
    for epoch in range(1000):
        data = data_loader.get(...)
        time = mdlm.sample_time(batch_size)
        xt = mdlm.interpolate(data, time)

        logits = model(xt, time)
        loss = mdlm.loss(logits, data, xt, time)
        loss.backward()

    # Generation
    x_pred = mdlm.sample_prior(data.shape)
    schedule = LinearTimeSchedule(...)
    inference_time = schedule.generate_schedule()
    dts = schedule.discretize()
    for t, dt in zip(inference_time, dts):
        time = torch.full((batch_size,), t)
        logits = model(x_pred, time)
        x_pred = mdlm.step(logits, time, x_pred, dt)
    return x_pred

    ```
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: DiscreteMaskedPrior,
        noise_schedule: ContinuousExpNoiseTransform,
        device: str = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initialize the Masked Discrete Language Model (MDLM) interpolant.

        Args:
            time_distribution (TimeDistribution): The distribution governing the time variable in the diffusion process.
            prior_distribution (DiscreteMaskedPrior): The prior distribution over the discrete token space, including masked tokens.
            noise_schedule (ContinuousExpNoiseTransform): The noise schedule defining the noise intensity as a function of time.
            device (str, optional): The device to use for computations. Defaults to "cpu".
            rng_generator (Optional[torch.Generator], optional): The random number generator for reproducibility. Defaults to None.
        """
        super().__init__(time_distribution, prior_distribution, device, rng_generator)
        if not isinstance(prior_distribution, DiscreteMaskedPrior):
            raise ValueError("DiscreteMaskedPrior required for MDLM")
        if not isinstance(noise_schedule, ContinuousExpNoiseTransform):
            raise ValueError("ContinuousExpNoiseTransform required for MDLM")
        self.noise_schedule = noise_schedule
        self.num_classes = prior_distribution.num_classes
        self.mask_index = prior_distribution.mask_dim
        # Gumbel used for confidence sampling. Note rng_generator not compatible with torch.Distribution.
        # self.gumbel_dist = torch.distributions.Gumbel(torch.tensor(0.0), torch.tensor(1.0))

    def interpolate(self, data: Tensor, t: Tensor):
        """Get x(t) with given time t from noise and data.

        Args:
            data (Tensor): target discrete ids
            t (Tensor): time
        """
        if data.dtype == torch.float and data.ndim > 2:
            x0 = data.argmax(-1)
        else:
            x0 = data
        sigma = self.noise_schedule.calculate_sigma(t, data.device)
        alpha = self.noise_schedule.sigma_to_alpha(sigma)
        p_mask = 1 - alpha
        p_mask = pad_like(p_mask, x0)
        mask_indices = torch.rand(*x0.shape, device=x0.device, generator=self.rng_generator) < p_mask
        xt = torch.where(mask_indices, self.mask_index, x0)
        return xt

    def forward_process(self, data: Tensor, t: Tensor) -> Tensor:
        """Apply the forward process to the data at time t.

        Args:
            data (Tensor): target discrete ids
            t (Tensor): time

        Returns:
            Tensor: x(t) after applying the forward process
        """
        return self.interpolate(data, t)

    def loss(
        self,
        logits: Tensor,
        target: Tensor,
        xt: Tensor,
        time: Tensor,
        mask: Optional[Tensor] = None,
        use_weight=True,
    ):
        """Calculate the cross-entropy loss between the model prediction and the target output.

        The loss is calculated between the batch x node x class logits and the target batch x node,
        considering the current state of the discrete sequence `xt` at time `time`.

        If `use_weight` is True, the loss is weighted by the reduced form of the MDLM time weight for continuous NELBO,
        as specified in equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
        of the noise schedule with respect to time, and is used to emphasize the importance of accurate predictions at
        certain times in the diffusion process.

        Args:
            logits (Tensor): The predicted output from the model, with shape batch x node x class.
            target (Tensor): The target output for the model prediction, with shape batch x node.
            xt (Tensor): The current state of the discrete sequence, with shape batch x node.
            time (Tensor): The time at which the loss is calculated.
            mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
            use_weight (bool, optional): Whether to use the MDLM time weight for the loss. Defaults to True.

        Returns:
            Tensor: The calculated loss batch tensor.
        """
        logprobs = self._subs_parameterization(logits, xt)
        log_p_theta = torch.gather(input=logprobs, dim=-1, index=target[..., None]).squeeze(-1)

        sigma = self.noise_schedule.calculate_sigma(time, target.device)
        dsigma = self.noise_schedule.d_dt_sigma(time, target.device)  # type: ignore
        loss = -log_p_theta
        if use_weight:
            loss = loss * (dsigma / torch.expm1(sigma))[:, None]

        if mask is not None:
            loss = loss * mask
            num_non_masked_elements = torch.sum(mask, dim=-1)
            loss = torch.sum(loss, dim=(-1)) / num_non_masked_elements
        else:
            loss = torch.sum(loss, dim=(-1)) / logits.size(1)
        return loss

    def _subs_parameterization(self, logits: Tensor, xt: Tensor) -> Tensor:
        """Apply subsititution parameterization to the logits.

        This function enforces that the model can never predict a mask token by lowering the mask logits.
        Then, for all unmasked tokens, it copies over from xt to enable carry over unmasked.
        Once a token is unmasked, it stays the same.
        See Sec. 3.2.3 https://arxiv.org/pdf/2406.07524.

        Note that recent work has shown that allowing the model to rethink
        carry over unmasking is beneficial https://arxiv.org/abs/2410.06264.

        Args:
            logits (Tensor): The logits tensor with shape batch x node x class.
            xt (Tensor): The tensor of unmasked tokens with shape batch x node.

        Returns:
            Tensor: The modified logits tensor with substitution parameterization applied.
        """
        logits[..., self.mask_index] += -1000000.0  # clean input is never masked
        logprobs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # normalize
        unmasked_indices = xt != self.mask_index
        logprobs[unmasked_indices] = -1000000.0
        logprobs[unmasked_indices, xt[unmasked_indices]] = 0  # Unmasked token remains unchanged
        return logprobs

    def step(self, logits, t, xt, dt) -> Tensor:
        """Perform a single step of MDLM DDPM step.

        Parameters:
        logits (Tensor): The input logits.
        t (float): The current time step.
        xt (Tensor): The current state.
        dt (float): The time step increment.

        Returns:
        Tensor: The updated state.
        """
        sigma_t = self.noise_schedule.calculate_sigma(t, logits.device)
        sigma_s = self.noise_schedule.calculate_sigma(t - dt, logits.device)
        alpha_t = torch.exp(-sigma_t)
        alpha_s = torch.exp(-sigma_s)
        p_mask_s = 1 - alpha_s
        alpha_t = pad_like(alpha_t, logits)
        alpha_s = pad_like(alpha_s, logits)
        p_mask_s = pad_like(p_mask_s, logits)
        # Apply subs parameterization
        log_p_x0 = self._subs_parameterization(logits, xt)
        if p_mask_s.ndim != log_p_x0.ndim:
            raise ValueError(f"Dimension Mistmatch {p_mask_s.shape} {log_p_x0.shape}")
        # Equation 6 from MDLM
        prob_s_given_t = log_p_x0.exp() * (alpha_s - alpha_t)  # righthand side (alpha_s - alpha_t)*x
        prob_s_given_t[..., self.mask_index] = p_mask_s[..., 0]  # lefthand side (1 - alpha_s)*M
        sampled_x = self._sample_categorical(prob_s_given_t)
        carry_over_unmask = (xt != self.mask_index).to(xt.dtype)
        return carry_over_unmask * xt + (1 - carry_over_unmask) * sampled_x

    def _sample_categorical(self, categorical_probs: Tensor) -> Tensor:
        """Sample from a categorical distribution using the Gumbel trick.

        Args:
            categorical_probs (Tensor): The probabilities of each category, shape batch x node x class.

        Returns:
            Tensor: The sampled category indices, shape batch x node.
        """
        gumbel_norm = (
            1e-10
            - (
                torch.rand(*categorical_probs.shape, device=categorical_probs.device, generator=self.rng_generator)
                + 1e-10
            ).log()
        )
        scaled_probability = categorical_probs / gumbel_norm
        return scaled_probability.argmax(dim=-1)

    def step_confidence(
        self,
        logits: Tensor,
        xt: Tensor,
        curr_step: int,
        num_steps: int,
        logit_temperature: float = 1.0,
        randomness: float = 1.0,
        confidence_temperature: float = 1.0,
    ) -> Tensor:
        """Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

        Method taken from GenMol Seul et al.

        Args:
            logits: Predicted logits
            xt: Input sequence
            curr_step: Current step
            num_steps: Total number of steps
            logit_temperature: Temperature for softmax over logits
            randomness: Scale for Gumbel noise
            confidence_temperature: Temperature for Gumbel confidence

        Returns:
            Updated input sequence xt
        """
        if xt.ndim > 3:
            raise NotImplementedError(
                "step_confidence is implemented for Batch x Sequence x State Space shaped tensors."
            )
        xt = xt.clone()
        log_p_x0 = self._subs_parameterization(logits, xt)
        # sample the code from the softmax prediction
        probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)
        preds = torch.distributions.Categorical(probs=probs).sample()

        confidence = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
        # add Gumbel noise decreasing over the sampling process
        ratio = curr_step / (num_steps - 1)
        # Using manual definition of 0,1 Gumbel to pass in generator
        gumbel_sample = -torch.log(-torch.log(torch.rand(xt.shape, generator=self.rng_generator))).to(logits.device)
        # gumbel_sample = self.gumbel_dist.sample(xt.shape).to(logits.device)
        gumbel_noise = gumbel_sample * randomness * (1 - ratio)  # type: ignore
        confidence = (
            (torch.log(confidence) + gumbel_noise) / confidence_temperature
        )  # stems from tau of https://pytorch.ac.cn/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

        # do not predict on already predicted tokens
        mask = xt == self.mask_index
        confidence[~mask] = -torch.inf

        # choose the predicted token with the highest confidence
        confidence_threshold, idx_mask = torch.topk(confidence, k=1, dim=-1)
        confidence_threshold = confidence_threshold[:, -1].unsqueeze(-1)

        # replace the chosen tokens
        to_replace = confidence >= confidence_threshold
        to_replace = (mask.float() * to_replace.float()).bool()
        xt[to_replace] = preds[to_replace]
        return xt

    def step_argmax(self, model_out: Tensor):
        """Returns the index of the maximum value in the last dimension of the model output.

        Args:
            model_out (Tensor): The output of the model.

        Returns:
            Tensor: The index of the maximum value in the last dimension of the model output.
        """
        return model_out.argmax(dim=-1)

    def calculate_score(self, logits, x, t):
        """Returns score of the given sample x at time t with the corresponding model output logits.

        Args:
            logits (Tensor): The output of the model.
            x (Tensor): The current data point.
            t (Tensor): The current time.

        Returns:
            Tensor: The score defined in Appendix C.3 Equation 76 of MDLM.
        """
        sigma_t = self.noise_schedule.calculate_sigma(t, logits.device)
        log_ratio = -torch.log(
            torch.expm1(sigma_t)
        )  # log ( exp(-sigma) / (1 - exp(-sigma))) = log(1/ (exp(sigma) - 1))

        # Create masked and unmasked log scores
        masked_log_score = logits + pad_like(log_ratio, logits)  # xt is masked and prediction is not
        masked_log_score[..., self.mask_index] = 0  # xt and prediction are mask

        unmasked_log_score = torch.full_like(logits, -1000000.0)
        unmasked_log_score.scatter_(-1, x[..., None], 0)  # place zeros where current predictions are
        unmasked_log_score[..., self.mask_index] = -pad_like(log_ratio, logits[..., 0])

        # Combine masked and unmasked log scores
        masked_indices = (x == self.mask_index).to(logits.dtype)[..., None]
        log_score = masked_log_score * masked_indices + unmasked_log_score * (1 - masked_indices)

        return log_score.exp()

__init__(time_distribution, prior_distribution, noise_schedule, device='cpu', rng_generator=None)

Initialize the Masked Discrete Language Model (MDLM) interpolant.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `time_distribution` | `TimeDistribution` | The distribution governing the time variable in the diffusion process. | required |
| `prior_distribution` | `DiscreteMaskedPrior` | The prior distribution over the discrete token space, including masked tokens. | required |
| `noise_schedule` | `ContinuousExpNoiseTransform` | The noise schedule defining the noise intensity as a function of time. | required |
| `device` | `str` | The device to use for computations. Defaults to "cpu". | `'cpu'` |
| `rng_generator` | `Optional[torch.Generator]` | The random number generator for reproducibility. Defaults to None. | `None` |
Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 74-98
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: DiscreteMaskedPrior,
    noise_schedule: ContinuousExpNoiseTransform,
    device: str = "cpu",
    rng_generator: Optional[torch.Generator] = None,
):
    """Initialize the Masked Discrete Language Model (MDLM) interpolant.

    Args:
        time_distribution (TimeDistribution): The distribution governing the time variable in the diffusion process.
        prior_distribution (DiscreteMaskedPrior): The prior distribution over the discrete token space, including masked tokens.
        noise_schedule (ContinuousExpNoiseTransform): The noise schedule defining the noise intensity as a function of time.
        device (str, optional): The device to use for computations. Defaults to "cpu".
        rng_generator (Optional[torch.Generator], optional): The random number generator for reproducibility. Defaults to None.
    """
    super().__init__(time_distribution, prior_distribution, device, rng_generator)
    if not isinstance(prior_distribution, DiscreteMaskedPrior):
        raise ValueError("DiscreteMaskedPrior required for MDLM")
    if not isinstance(noise_schedule, ContinuousExpNoiseTransform):
        raise ValueError("ContinuousExpNoiseTransform required for MDLM")
    self.noise_schedule = noise_schedule
    self.num_classes = prior_distribution.num_classes
    self.mask_index = prior_distribution.mask_dim

calculate_score(logits, x, t)

Returns the score of the given sample x at time t with the corresponding model output logits.
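The noise schedule enters only through the log-ratio of the masked to the unmasked marginal probability at time t; this is a compact restatement of the `log_ratio` term computed in the source below, which is added to the logits at masked positions and negated on the mask class at unmasked positions:

$$
\log r(t) \;=\; \log\frac{e^{-\sigma(t)}}{1 - e^{-\sigma(t)}} \;=\; -\log\!\bigl(e^{\sigma(t)} - 1\bigr)
$$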

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `Tensor` | The output of the model. | required |
| `x` | `Tensor` | The current data point. | required |
| `t` | `Tensor` | The current time. | required |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | The score defined in Appendix C.3 Equation 76 of MDLM. |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 327-355
def calculate_score(self, logits, x, t):
    """Returns score of the given sample x at time t with the corresponding model output logits.

    Args:
        logits (Tensor): The output of the model.
        x (Tensor): The current data point.
        t (Tensor): The current time.

    Returns:
        Tensor: The score defined in Appendix C.3 Equation 76 of MDLM.
    """
    sigma_t = self.noise_schedule.calculate_sigma(t, logits.device)
    log_ratio = -torch.log(
        torch.expm1(sigma_t)
    )  # log ( exp(-sigma) / (1 - exp(-sigma))) = log(1/ (exp(sigma) - 1))

    # Create masked and unmasked log scores
    masked_log_score = logits + pad_like(log_ratio, logits)  # xt is masked and prediction is not
    masked_log_score[..., self.mask_index] = 0  # xt and prediction are mask

    unmasked_log_score = torch.full_like(logits, -1000000.0)
    unmasked_log_score.scatter_(-1, x[..., None], 0)  # place zeros where current predictions are
    unmasked_log_score[..., self.mask_index] = -pad_like(log_ratio, logits[..., 0])

    # Combine masked and unmasked log scores
    masked_indices = (x == self.mask_index).to(logits.dtype)[..., None]
    log_score = masked_log_score * masked_indices + unmasked_log_score * (1 - masked_indices)

    return log_score.exp()

forward_process(data, t)

Apply the forward process to the data at time t.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `Tensor` | target discrete ids | required |
| `t` | `Tensor` | time | required |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | x(t) after applying the forward process |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 121-131
def forward_process(self, data: Tensor, t: Tensor) -> Tensor:
    """Apply the forward process to the data at time t.

    Args:
        data (Tensor): target discrete ids
        t (Tensor): time

    Returns:
        Tensor: x(t) after applying the forward process
    """
    return self.interpolate(data, t)

interpolate(data, t)

Get x(t) at the given time t from noise and data.
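Each position of the clean sequence is independently replaced by the mask token m with probability 1 − α(t), where α(t) = e^{−σ(t)} comes from the noise schedule; this is the `p_mask = 1 - alpha` computation in the source below:

$$
q\!\left(x_t^{i} \mid x_0^{i}\right) \;=\; \bigl(1 - \alpha(t)\bigr)\,\delta\!\left(x_t^{i}, m\right) \;+\; \alpha(t)\,\delta\!\left(x_t^{i}, x_0^{i}\right),
\qquad \alpha(t) = e^{-\sigma(t)}
$$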

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `Tensor` | target discrete ids | required |
| `t` | `Tensor` | time | required |
Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 102-119
def interpolate(self, data: Tensor, t: Tensor):
    """Get x(t) with given time t from noise and data.

    Args:
        data (Tensor): target discrete ids
        t (Tensor): time
    """
    if data.dtype == torch.float and data.ndim > 2:
        x0 = data.argmax(-1)
    else:
        x0 = data
    sigma = self.noise_schedule.calculate_sigma(t, data.device)
    alpha = self.noise_schedule.sigma_to_alpha(sigma)
    p_mask = 1 - alpha
    p_mask = pad_like(p_mask, x0)
    mask_indices = torch.rand(*x0.shape, device=x0.device, generator=self.rng_generator) < p_mask
    xt = torch.where(mask_indices, self.mask_index, x0)
    return xt

loss(logits, target, xt, time, mask=None, use_weight=True)

Calculate the cross-entropy loss between the model prediction and the target output.

The loss is calculated between the batch x node x class logits and the target batch x node, considering the current state of the discrete sequence `xt` at time `time`.

If `use_weight` is True, the loss is weighted by the reduced form of the MDLM time weight for the continuous NELBO, as specified in Equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative of the noise schedule with respect to time and emphasizes accurate predictions at certain times in the diffusion process.
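Concretely, the per-token negative log-likelihood is scaled by the weight below, which is the `dsigma / torch.expm1(sigma)` factor in the source:

$$
w(t) \;=\; \frac{\sigma'(t)}{e^{\sigma(t)} - 1},
\qquad
\mathcal{L}(t) \;=\; -\,w(t)\,\log p_\theta\!\left(x_0 \mid x_t\right)
$$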

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `Tensor` | The predicted output from the model, with shape batch x node x class. | required |
| `target` | `Tensor` | The target output for the model prediction, with shape batch x node. | required |
| `xt` | `Tensor` | The current state of the discrete sequence, with shape batch x node. | required |
| `time` | `Tensor` | The time at which the loss is calculated. | required |
| `mask` | `Optional[Tensor]` | The mask for the data point. Defaults to None. | `None` |
| `use_weight` | `bool` | Whether to use the MDLM time weight for the loss. Defaults to True. | `True` |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | The calculated loss batch tensor. |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 133-178
def loss(
    self,
    logits: Tensor,
    target: Tensor,
    xt: Tensor,
    time: Tensor,
    mask: Optional[Tensor] = None,
    use_weight=True,
):
    """Calculate the cross-entropy loss between the model prediction and the target output.

    The loss is calculated between the batch x node x class logits and the target batch x node,
    considering the current state of the discrete sequence `xt` at time `time`.

    If `use_weight` is True, the loss is weighted by the reduced form of the MDLM time weight for continuous NELBO,
    as specified in equation 11 of https://arxiv.org/pdf/2406.07524. This weight is proportional to the derivative
    of the noise schedule with respect to time, and is used to emphasize the importance of accurate predictions at
    certain times in the diffusion process.

    Args:
        logits (Tensor): The predicted output from the model, with shape batch x node x class.
        target (Tensor): The target output for the model prediction, with shape batch x node.
        xt (Tensor): The current state of the discrete sequence, with shape batch x node.
        time (Tensor): The time at which the loss is calculated.
        mask (Optional[Tensor], optional): The mask for the data point. Defaults to None.
        use_weight (bool, optional): Whether to use the MDLM time weight for the loss. Defaults to True.

    Returns:
        Tensor: The calculated loss batch tensor.
    """
    logprobs = self._subs_parameterization(logits, xt)
    log_p_theta = torch.gather(input=logprobs, dim=-1, index=target[..., None]).squeeze(-1)

    sigma = self.noise_schedule.calculate_sigma(time, target.device)
    dsigma = self.noise_schedule.d_dt_sigma(time, target.device)  # type: ignore
    loss = -log_p_theta
    if use_weight:
        loss = loss * (dsigma / torch.expm1(sigma))[:, None]

    if mask is not None:
        loss = loss * mask
        num_non_masked_elements = torch.sum(mask, dim=-1)
        loss = torch.sum(loss, dim=(-1)) / num_non_masked_elements
    else:
        loss = torch.sum(loss, dim=(-1)) / logits.size(1)
    return loss

step(logits, t, xt, dt)

Perform a single MDLM DDPM step.
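For a position that is still masked, the reverse transition from time t to s = t − dt follows the masked posterior referenced as Equation 6 in the source below; already unmasked positions are carried over unchanged. The implementation samples from the unnormalized weights, which is equivalent because the 1 − α_t factor is shared across classes:

$$
p_\theta\!\left(x_s \mid x_t = m\right) \;=\;
\mathrm{Cat}\!\left(x_s;\;
\frac{(1 - \alpha_s)\,\mathbf{m} \;+\; (\alpha_s - \alpha_t)\,p_\theta\!\left(x_0 \mid x_t\right)}{1 - \alpha_t}\right),
\qquad \alpha_u = e^{-\sigma(u)}
$$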

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `Tensor` | The input logits. | required |
| `t` | `float` | The current time step. | required |
| `xt` | `Tensor` | The current state. | required |
| `dt` | `float` | The time step increment. | required |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | The updated state. |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 205-234
def step(self, logits, t, xt, dt) -> Tensor:
    """Perform a single step of MDLM DDPM step.

    Parameters:
    logits (Tensor): The input logits.
    t (float): The current time step.
    xt (Tensor): The current state.
    dt (float): The time step increment.

    Returns:
    Tensor: The updated state.
    """
    sigma_t = self.noise_schedule.calculate_sigma(t, logits.device)
    sigma_s = self.noise_schedule.calculate_sigma(t - dt, logits.device)
    alpha_t = torch.exp(-sigma_t)
    alpha_s = torch.exp(-sigma_s)
    p_mask_s = 1 - alpha_s
    alpha_t = pad_like(alpha_t, logits)
    alpha_s = pad_like(alpha_s, logits)
    p_mask_s = pad_like(p_mask_s, logits)
    # Apply subs parameterization
    log_p_x0 = self._subs_parameterization(logits, xt)
    if p_mask_s.ndim != log_p_x0.ndim:
        raise ValueError(f"Dimension Mistmatch {p_mask_s.shape} {log_p_x0.shape}")
    # Equation 6 from MDLM
    prob_s_given_t = log_p_x0.exp() * (alpha_s - alpha_t)  # righthand side (alpha_s - alpha_t)*x
    prob_s_given_t[..., self.mask_index] = p_mask_s[..., 0]  # lefthand side (1 - alpha_s)*M
    sampled_x = self._sample_categorical(prob_s_given_t)
    carry_over_unmask = (xt != self.mask_index).to(xt.dtype)
    return carry_over_unmask * xt + (1 - carry_over_unmask) * sampled_x

step_argmax(model_out)

Returns the index of the maximum value in the last dimension of the model output.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_out` | `Tensor` | The output of the model. | required |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | The index of the maximum value in the last dimension of the model output. |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 316-325
def step_argmax(self, model_out: Tensor):
    """Returns the index of the maximum value in the last dimension of the model output.

    Args:
        model_out (Tensor): The output of the model.

    Returns:
        Tensor: The index of the maximum value in the last dimension of the model output.
    """
    return model_out.argmax(dim=-1)

step_confidence(logits, xt, curr_step, num_steps, logit_temperature=1.0, randomness=1.0, confidence_temperature=1.0)

Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Method taken from GenMol, Seul et al.
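A minimal sketch of a confidence-based decoding loop built on this method. It assumes an `mdlm` interpolant and a `model` constructed as in the class-level example above; `batch_size`, `seq_len`, and the time conditioning passed to the model are placeholders, not part of this API:

```python
import torch

# Placeholders: mdlm, model, batch_size, and seq_len are assumed to exist,
# constructed as in the class-level example above.
num_steps = 8
xt = mdlm.sample_prior((batch_size, seq_len))  # start from the all-mask prior

for curr_step in range(num_steps):
    time = torch.zeros(batch_size)  # assumed time conditioning; adapt to your model
    logits = model(xt, time)        # batch x sequence x num_classes
    xt = mdlm.step_confidence(
        logits,
        xt,
        curr_step=curr_step,
        num_steps=num_steps,
        logit_temperature=1.0,       # softmax temperature over logits
        randomness=1.0,              # scale of the decaying Gumbel noise
        confidence_temperature=1.0,  # temperature applied to the confidence scores
    )
```

Each call commits the single highest-confidence masked position per sequence, so `num_steps` is commonly chosen relative to the number of masked positions to be filled.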

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `Tensor` | Predicted logits | required |
| `xt` | `Tensor` | Input sequence | required |
| `curr_step` | `int` | Current step | required |
| `num_steps` | `int` | Total number of steps | required |
| `logit_temperature` | `float` | Temperature for softmax over logits | `1.0` |
| `randomness` | `float` | Scale for Gumbel noise | `1.0` |
| `confidence_temperature` | `float` | Temperature for Gumbel confidence | `1.0` |

Returns

| Type | Description |
| --- | --- |
| `Tensor` | Updated input sequence xt |

Source code in bionemo/moco/interpolants/continuous_time/discrete/mdlm.py, lines 255-314
def step_confidence(
    self,
    logits: Tensor,
    xt: Tensor,
    curr_step: int,
    num_steps: int,
    logit_temperature: float = 1.0,
    randomness: float = 1.0,
    confidence_temperature: float = 1.0,
) -> Tensor:
    """Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

    Method taken from GenMol Seul et al.

    Args:
        logits: Predicted logits
        xt: Input sequence
        curr_step: Current step
        num_steps: Total number of steps
        logit_temperature: Temperature for softmax over logits
        randomness: Scale for Gumbel noise
        confidence_temperature: Temperature for Gumbel confidence

    Returns:
        Updated input sequence xt
    """
    if xt.ndim > 3:
        raise NotImplementedError(
            "step_confidence is implemented for Batch x Sequence x State Space shaped tensors."
        )
    xt = xt.clone()
    log_p_x0 = self._subs_parameterization(logits, xt)
    # sample the code from the softmax prediction
    probs = torch.softmax(log_p_x0 / logit_temperature, dim=-1)
    preds = torch.distributions.Categorical(probs=probs).sample()

    confidence = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
    # add Gumbel noise decreasing over the sampling process
    ratio = curr_step / (num_steps - 1)
    # Using manual definition of 0,1 Gumbel to pass in generator
    gumbel_sample = -torch.log(-torch.log(torch.rand(xt.shape, generator=self.rng_generator))).to(logits.device)
    # gumbel_sample = self.gumbel_dist.sample(xt.shape).to(logits.device)
    gumbel_noise = gumbel_sample * randomness * (1 - ratio)  # type: ignore
    confidence = (
        (torch.log(confidence) + gumbel_noise) / confidence_temperature
    )  # stems from tau of https://pytorch.ac.cn/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    # do not predict on already predicted tokens
    mask = xt == self.mask_index
    confidence[~mask] = -torch.inf

    # choose the predicted token with the highest confidence
    confidence_threshold, idx_mask = torch.topk(confidence, k=1, dim=-1)
    confidence_threshold = confidence_threshold[:, -1].unsqueeze(-1)

    # replace the chosen tokens
    to_replace = confidence >= confidence_threshold
    to_replace = (mask.float() * to_replace.float()).bool()
    xt[to_replace] = preds[to_replace]
    return xt