Masking

BertMaskConfig dataclass

Configuration for masking tokens in a BERT-style model.

Attributes

Name               Type       Description
tokenizer          Tokenizer  The tokenizer whose mask token id and special token ids are used.
random_tokens      range      The range of token ids to draw from when replacing a masked token with a random token.
mask_prob          float      Probability of masking a token.
mask_token_prob    float      Probability of replacing a masked token with the mask token.
random_token_prob  float      Probability of replacing a masked token with a random token.
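
With the defaults (mask_prob=0.15, mask_token_prob=0.8, random_token_prob=0.1), each token is replaced by the mask token with probability 0.15 × 0.8 = 0.12, replaced by a random token with probability 0.15 × 0.1 = 0.015, and selected but left unchanged (identity mask) with probability 0.15 × 0.1 = 0.015.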

Source code in bionemo/llm/data/masking.py
@dataclass(frozen=True)
class BertMaskConfig:
    """Configuration for masking tokens in a BERT-style model.

    Attributes:
        mask_prob: Probability of masking a token.
        mask_token_prob: Probability of replacing a masked token with the mask token.
        random_token_prob: Probability of replacing a masked token with a random token.
    """

    tokenizer: Tokenizer
    random_tokens: range
    mask_prob: float = 0.15
    mask_token_prob: float = 0.8
    random_token_prob: float = 0.1

    def __post_init__(self) -> None:
        """Check that the sum of `mask_token_prob` and `random_token_prob` is less than or equal to 1.0.

        Raises:
            ValueError: If the sum of `mask_token_prob` and `random_token_prob` is greater than 1.0.
        """
        if self.random_token_prob + self.mask_token_prob > 1.0:
            raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")
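
A minimal usage sketch; toy_tokenizer is a hypothetical stand-in, since this module only ever reads mask_token_id and all_special_ids from the tokenizer it is given:

from types import SimpleNamespace

# Hypothetical stand-in tokenizer with mask token id 1 and special ids 0-2.
toy_tokenizer = SimpleNamespace(mask_token_id=1, all_special_ids=[0, 1, 2])

config = BertMaskConfig(tokenizer=toy_tokenizer, random_tokens=range(5, 25))
print(config.mask_prob)  # 0.15 (default)
# config.mask_prob = 0.2  # would raise dataclasses.FrozenInstanceError (frozen=True)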

__post_init__()

Check that the sum of mask_token_prob and random_token_prob is less than or equal to 1.0.

Raises

Type        Description
ValueError  If the sum of mask_token_prob and random_token_prob is greater than 1.0.

Source code in bionemo/llm/data/masking.py
def __post_init__(self) -> None:
    """Check that the sum of `mask_token_prob` and `random_token_prob` is less than or equal to 1.0.

    Raises:
        ValueError: If the sum of `mask_token_prob` and `random_token_prob` is greater than 1.0.
    """
    if self.random_token_prob + self.mask_token_prob > 1.0:
        raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")
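
The check runs at construction time; a sketch of the failure mode, reusing the hypothetical toy_tokenizer from the sketch above:

try:
    BertMaskConfig(
        tokenizer=toy_tokenizer,
        random_tokens=range(5, 25),
        mask_token_prob=0.8,
        random_token_prob=0.3,  # 0.8 + 0.3 > 1.0
    )
except ValueError as err:
    print(err)  # Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.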

add_cls_and_eos_tokens(sequence, labels, loss_mask, cls_token=None, eos_token=None)

Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels.

These labels should never be masked, so this is done after the masking step.

Parameters

Name       Type        Description                                                                Default
sequence   Tensor      The input (likely masked) sequence.                                        required
labels     Tensor      The true values of the input sequence at the mask positions.              required
loss_mask  Tensor      A boolean tensor indicating which tokens should be included in the loss.  required
cls_token  int | None  The token to use for the CLS token. If None, no CLS token is added.       None
eos_token  int | None  The token to use for the EOS token. If None, no EOS token is added.       None

Returns

Type                           Description
tuple[Tensor, Tensor, Tensor]  The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly.

Source code in bionemo/llm/data/masking.py
def add_cls_and_eos_tokens(
    sequence: torch.Tensor,
    labels: torch.Tensor,
    loss_mask: torch.Tensor,
    cls_token: int | None = None,
    eos_token: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels.

    These labels should never be masked, so this is done after the masking step.

    Args:
        sequence: The input (likely masked) sequence.
        labels: The true values of the input sequence at the mask positions.
        loss_mask: A boolean tensor indicating which tokens should be included in the loss.
        cls_token: The token to use for the CLS token. If None, no CLS token is added.
        eos_token: The token to use for the EOS token. If None, no EOS token is added.

    Returns:
        The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly.
    """
    # Prepend the CLS token and append the EOS token, and update the loss mask and labels accordingly.
    sequence = torch.cat(
        [
            torch.tensor([cls_token], dtype=sequence.dtype)
            if cls_token is not None
            else torch.tensor([], dtype=sequence.dtype),
            sequence,
            torch.tensor([eos_token], dtype=sequence.dtype)
            if eos_token is not None
            else torch.tensor([], dtype=sequence.dtype),
        ]
    )

    labels = torch.cat(
        [
            torch.tensor([-1], dtype=labels.dtype) if cls_token is not None else torch.tensor([], dtype=labels.dtype),
            labels,
            torch.tensor([-1], dtype=labels.dtype) if eos_token is not None else torch.tensor([], dtype=labels.dtype),
        ]
    )

    loss_mask = torch.cat(
        [
            torch.tensor([False]) if cls_token is not None else torch.tensor([], dtype=loss_mask.dtype),
            loss_mask,
            torch.tensor([False]) if eos_token is not None else torch.tensor([], dtype=loss_mask.dtype),
        ]
    )

    return sequence, labels, loss_mask
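
A minimal sketch with hypothetical token ids 0 (CLS) and 2 (EOS), where only position 1 of the input was masked:

import torch

sequence = torch.tensor([7, 3, 9, 4])
labels = torch.tensor([-1, 5, -1, -1])
loss_mask = torch.tensor([False, True, False, False])

seq, lab, mask = add_cls_and_eos_tokens(sequence, labels, loss_mask, cls_token=0, eos_token=2)
# seq  -> tensor([0, 7, 3, 9, 4, 2])   CLS prepended, EOS appended
# lab  -> tensor([-1, -1, 5, -1, -1, -1])
# mask -> tensor([False, False, True, False, False, False])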

apply_bert_pretraining_mask(tokenized_sequence, random_seed, mask_config)

Applies the pretraining mask to a tokenized sequence.

Parameters

Name                Type            Description                                               Default
tokenized_sequence  Tensor          Tokenized protein sequence.                               required
random_seed         int             Random seed for reproducibility.                          required
mask_config         BertMaskConfig  Configuration for masking tokens in a BERT-style model.  required

Returns

Name             Type    Description
masked_sequence  Tensor  The tokenized sequence with some tokens masked.
labels           Tensor  A tensor the same shape as masked_sequence containing labels for the masked tokens, with -1 for non-masked tokens.
loss_mask        Tensor  A boolean tensor the same shape as masked_sequence, where True indicates which tokens should be included in the loss.

Source code in bionemo/llm/data/masking.py
def apply_bert_pretraining_mask(
    tokenized_sequence: torch.Tensor, random_seed: int, mask_config: BertMaskConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Applies the pretraining mask to a tokenized sequence.

    Args:
        tokenized_sequence: Tokenized protein sequence.
        random_seed: Random seed for reproducibility.
        mask_config: Configuration for masking tokens in a BERT-style model.

    Returns:
        masked_sequence:
            The tokenized sequence with some tokens masked.
        labels:
            A tensor the same shape as `masked_sequence` containing labels for the masked tokens, with -1 for non-masked
            tokens.
        loss_mask:
            A boolean tensor the same shape as `masked_sequence`, where 'True' indicates which tokens should be included
            in the loss.
    """
    if mask_config.tokenizer.mask_token_id is None:
        raise ValueError("Tokenizer must have a mask token.")

    if mask_config.random_token_prob + mask_config.mask_token_prob > 1.0:
        raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")

    # Set the seed so that __getitem__(idx) is always deterministic.
    # This is required by Megatron-LM's parallel strategies.
    generator = torch.Generator().manual_seed(random_seed)

    mask_stop_1 = mask_config.mask_prob * mask_config.mask_token_prob
    mask_stop_2 = mask_config.mask_prob * (mask_config.mask_token_prob + mask_config.random_token_prob)

    random_draws = torch.rand(tokenized_sequence.shape, generator=generator)  # Random draws for each token in [0, 1).

    # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
    # (identity). We don't want to mask special tokens.
    loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(mask_config.tokenizer.all_special_ids))
    loss_mask &= random_draws < mask_config.mask_prob

    # The first `mask_token_prob` fraction of the `mask_prob` tokens are replaced with the mask token.
    mask_token_mask = (random_draws < mask_stop_1) & loss_mask

    # The next `random_token_prob` fraction of the `mask_prob` tokens are replaced with a random token.
    random_token_mask = ((random_draws >= mask_stop_1) & (random_draws < mask_stop_2)) & loss_mask

    # The remaining tokens are implicitly left as-is, representing an identity mask.

    # Mask the tokens.
    masked_sequence = tokenized_sequence.clone()
    masked_sequence[mask_token_mask] = mask_config.tokenizer.mask_token_id
    num_random_tokens: int = random_token_mask.sum().item()  # type: ignore[assignment]
    masked_sequence[random_token_mask] = torch.randint(
        low=mask_config.random_tokens.start,
        high=mask_config.random_tokens.stop,
        size=(num_random_tokens,),
        dtype=masked_sequence.dtype,
        generator=generator,
    )

    # Create the labels for the masked tokens.
    labels = tokenized_sequence.clone()
    labels[~loss_mask] = -1  # Ignore loss for non-masked tokens.

    return masked_sequence, labels, loss_mask
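
An end-to-end sketch under assumed ids: a hypothetical 30-token vocabulary where ids 0-4 are special tokens and ids 5-29 are ordinary tokens eligible as random replacements:

import torch
from types import SimpleNamespace

toy_tokenizer = SimpleNamespace(mask_token_id=1, all_special_ids=[0, 1, 2, 3, 4])
config = BertMaskConfig(tokenizer=toy_tokenizer, random_tokens=range(5, 30))

tokens = torch.randint(5, 30, (128,))  # hypothetical tokenized sequence
masked, labels, loss_mask = apply_bert_pretraining_mask(tokens, random_seed=42, mask_config=config)

# About 15% of positions are selected; of those, ~80% become the mask token,
# ~10% become random ids drawn from random_tokens, and ~10% are left unchanged.
assert masked.shape == labels.shape == loss_mask.shape == tokens.shape
assert (labels[~loss_mask] == -1).all()  # non-selected positions carry the ignore label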