
Fine-Tune Token Classifier

ClassifierInput

Bases: TypedDict

Used as input in the forward method of ClassifierLossReduction.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
class ClassifierInput(TypedDict):
    """Used as input in the ClassifierLossReduction's forward method."""

    labels: Tensor
    loss_mask: Tensor

ClassifierLossReduction

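Bases: BERTMLMLossWithReduction

A class for calculating the cross-entropy loss of the classification output.

This class is used to calculate the loss and to log the reduced loss across micro-batches.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py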
class ClassifierLossReduction(BERTMLMLossWithReduction):
    """A class for calculating the cross entropy loss of classification output.

    This class is used for calculating the loss, and for logging the reduced loss across micro batches.
    """

    def forward(
        self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside classification head.

        Returns:
            A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
            the reduce method, which currently only works for logging.
        """
        targets = batch["labels"]  # [b, s]
        # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
        classification_output = forward_out["classification_output"].permute(0, 2, 1)
        loss_mask = batch["loss_mask"]  # [b, s]

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")
            # losses may contain NaNs at masked locations. We use masked_select to filter out these NaNs
            masked_loss = torch.masked_select(losses, loss_mask)
            loss = masked_loss.sum() / loss_mask.sum()
        else:  # TODO: support CP with masked_token_loss_context_parallel
            raise NotImplementedError("Context Parallel support is not implemented for this loss")

        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return losses.mean()

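forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters

Name Type Description Default
batch ClassifierInput

A batch of data that gets passed to the original forward inside LitAutoEncoder.

required
forward_out Esm2FineTuneTokenOutput

The output of the forward method inside the classification head.

required

Returns

Type Description
Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]

A tuple where the loss tensor is used for backpropagation and the dict is passed to the reduce method, which is currently used only for logging.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py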
def forward(
    self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside classification head.

    Returns:
        A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
        the reduce method, which currently only works for logging.
    """
    targets = batch["labels"]  # [b, s]
    # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
    classification_output = forward_out["classification_output"].permute(0, 2, 1)
    loss_mask = batch["loss_mask"]  # [b, s]

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")
        # losses may contain NaNs at masked locations. We use masked_select to filter out these NaNs
        masked_loss = torch.masked_select(losses, loss_mask)
        loss = masked_loss.sum() / loss_mask.sum()
    else:  # TODO: support CP with masked_token_loss_context_parallel
        raise NotImplementedError("Context Parallel support is not implemented for this loss")

    return loss, {"avg": loss}
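
The computation in forward can be traced by hand on a tiny example. The following is a minimal sketch in plain Python (no PyTorch) of per-token cross-entropy followed by a masked mean. The helper names token_cross_entropy and masked_mean_loss are hypothetical and not part of BioNeMo, and nested lists stand in for tensors.

```python
import math
from typing import List


def token_cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy of one token: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target]


def masked_mean_loss(
    logits: List[List[List[float]]],  # [b, s, num_class]
    labels: List[List[int]],  # [b, s]
    loss_mask: List[List[bool]],  # [b, s]
) -> float:
    """Sum the per-token losses where loss_mask is True and divide by the number of such tokens."""
    total, count = 0.0, 0
    for seq_logits, seq_labels, seq_mask in zip(logits, labels, loss_mask):
        for tok_logits, label, keep in zip(seq_logits, seq_labels, seq_mask):
            if keep:  # masked-out positions (e.g. CLS/EOS) never contribute
                total += token_cross_entropy(tok_logits, label)
                count += 1
    return total / count


# One sequence of four positions with three classes; the first and last are special tokens.
logits = [[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
labels = [[-100, 0, 2, -100]]
loss_mask = [[False, True, True, False]]
loss = masked_mean_loss(logits, labels, loss_mask)
```

Only the two unmasked positions enter the average, so the denominator is 2 rather than the sequence length of 4. As in the original code, a batch with no unmasked tokens would divide by zero.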

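reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging; this loss is not used for backpropagation.

Parameters

Name Type Description Default
losses_reduced_per_micro_batch Sequence[SameSizeLossDict]

A list of the outputs of forward.

required

Returns

Type Description
Tensor

A tensor that is the mean of the losses (used for logging).

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py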
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return losses.mean()
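
Because reduce stacks the per-micro-batch "avg" values and takes their mean, every micro-batch counts equally, regardless of how many unmasked tokens it contains. The sketch below contrasts that unweighted mean with a token-weighted mean in plain Python. The function names and the token counts are hypothetical, chosen only to show that the two can differ.

```python
from typing import Dict, List, Sequence


def reduce_mean_of_means(losses_per_micro_batch: Sequence[Dict[str, float]]) -> float:
    """Mirror of the reduce logic: an unweighted mean of the per-micro-batch "avg" values."""
    values = [loss["avg"] for loss in losses_per_micro_batch]
    return sum(values) / len(values)


def token_weighted_mean(avgs: List[float], token_counts: List[int]) -> float:
    """Hypothetical alternative: weight each micro-batch by its number of unmasked tokens."""
    total = sum(avg * n for avg, n in zip(avgs, token_counts))
    return total / sum(token_counts)


avgs = [1.0, 3.0]  # per-micro-batch average losses
token_counts = [10, 30]  # unmasked tokens in each micro-batch (assumed values)

logged = reduce_mean_of_means([{"avg": a} for a in avgs])
weighted = token_weighted_mean(avgs, token_counts)
```

Here the logged value is 2.0 while the token-weighted value is 2.5. Since the reduced value is used only for logging, the difference affects what is reported, not the gradients.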

ESM2FineTuneTokenConfig dataclass

Bases: ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], IOMixinWithGettersSetters

ESM2FineTuneTokenConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
@dataclass
class ESM2FineTuneTokenConfig(
    ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], iom.IOMixinWithGettersSetters
):
    """ESM2FineTuneTokenConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ESM2FineTuneTokenModel] = ESM2FineTuneTokenModel
    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
    # that has this new head and want to keep using these weights, please drop this next line or set to []
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["classification_head"])

    encoder_frozen: bool = True  # freeze encoder parameters
    cnn_num_classes: int = 3  # number of classes in each label
    cnn_dropout: float = 0.25
    cnn_hidden_dim: int = 32  # The number of output channels in the bottleneck layer of the convolution.

    def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
        """The loss function type."""
        return ClassifierLossReduction
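
To illustrate the intent behind initial_ckpt_skip_keys_with_these_prefixes, the sketch below filters a checkpoint-like dictionary by key prefix. It is an assumption about how such a prefix list might be applied, not BioNeMo's checkpoint-loading code. FineTuneOptions and filter_checkpoint_keys are hypothetical names, and the checkpoint keys are made up.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FineTuneOptions:
    """Hypothetical stand-in that mirrors a few fields of ESM2FineTuneTokenConfig."""

    # A mutable default must go through default_factory so that each instance gets its own list.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(
        default_factory=lambda: ["classification_head"]
    )
    encoder_frozen: bool = True
    cnn_num_classes: int = 3


def filter_checkpoint_keys(state: Dict[str, float], skip_prefixes: List[str]) -> Dict[str, float]:
    """Drop every entry whose key starts with one of the given prefixes."""
    return {k: v for k, v in state.items() if not any(k.startswith(p) for p in skip_prefixes)}


options = FineTuneOptions()
checkpoint = {  # made-up keys for illustration
    "encoder.layer0.weight": 0.1,
    "classification_head.class_heads.weight": 0.2,
}
kept = filter_checkpoint_keys(checkpoint, options.initial_ckpt_skip_keys_with_these_prefixes)
kept_all = filter_checkpoint_keys(checkpoint, [])  # an empty list keeps the head weights too
```

This matches the comment in the config: skipping the prefix suits a base checkpoint that lacks the head, while an empty list suits a checkpoint that already contains trained head weights.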

get_loss_reduction_class()

The loss function type.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
    """The loss function type."""
    return ClassifierLossReduction

ESM2FineTuneTokenModel

Bases: ESM2Model

An ESM2 model that is suitable for fine-tuning.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
class ESM2FineTuneTokenModel(ESM2Model):
    """An ESM2 model that is suitable for fine tuning."""

    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
        """Constructor."""
        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)

        # freeze encoder parameters
        if config.encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False

        self.include_hiddens_finetuning = (
            include_hiddens  # this include_hiddens is for the final output of fine-tuning
        )
        # If post_process is True that means that we are at the last megatron parallelism stage and we can
        #   apply the head.
        if post_process:
            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
            self.classification_head = MegatronConvNetHead(config)

    def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
        """Inference."""
        output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
        if not self.post_process:
            return output  # we are not at the last pipeline stage so just return what the parent has
        # Double check that the output from the parent has everything we need to do prediction in this head.
        if not isinstance(output, dict) or "hidden_states" not in output:
            raise ValueError(
                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
                "Make sure include_hiddens=True in the call to super().__init__"
            )
        # Get the per-token hidden states from the parent output; token classification uses every position
        hidden_states: Tensor = output["hidden_states"]
        # Predict per-token class logits with the classification head
        classification_output = self.classification_head(hidden_states)
        if not self.include_hiddens_finetuning:
            del output["hidden_states"]
        output["classification_output"] = classification_output
        return output

__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)

Constructor.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
    """Constructor."""
    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)

    # freeze encoder parameters
    if config.encoder_frozen:
        for _, param in self.named_parameters():
            param.requires_grad = False

    self.include_hiddens_finetuning = (
        include_hiddens  # this include_hiddens is for the final output of fine-tuning
    )
    # If post_process is True that means that we are at the last megatron parallelism stage and we can
    #   apply the head.
    if post_process:
        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
        self.classification_head = MegatronConvNetHead(config)
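
One consequence of the ordering in this constructor: the freezing loop runs before classification_head is created, so only parameters that exist at that point are frozen and the head stays trainable. The toy sketch below imitates that ordering in plain Python. Param and ToyModel are hypothetical stand-ins, not PyTorch or Megatron classes.

```python
from typing import Dict, List


class Param:
    """Hypothetical stand-in for a trainable parameter."""

    def __init__(self) -> None:
        self.requires_grad = True


class ToyModel:
    """Imitates the constructor ordering: build the encoder, freeze, then add the head."""

    def __init__(self, encoder_frozen: bool = True, post_process: bool = True) -> None:
        self.params: Dict[str, Param] = {"encoder.weight": Param(), "encoder.bias": Param()}

        if encoder_frozen:
            # Only the parameters that exist at this point are frozen.
            for param in self.params.values():
                param.requires_grad = False

        if post_process:
            # The head is registered after the freezing loop, so it remains trainable.
            self.params["classification_head.weight"] = Param()

    def trainable(self) -> List[str]:
        """Names of the parameters that would still receive gradients."""
        return sorted(name for name, p in self.params.items() if p.requires_grad)


frozen_model = ToyModel(encoder_frozen=True)
unfrozen_model = ToyModel(encoder_frozen=False)
```

With encoder_frozen=True only the head parameter is left trainable, and with encoder_frozen=False all three are.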

forward(*args, **kwargs)

Inference.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
    """Inference."""
    output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
    if not self.post_process:
        return output  # we are not at the last pipeline stage so just return what the parent has
    # Double check that the output from the parent has everything we need to do prediction in this head.
    if not isinstance(output, dict) or "hidden_states" not in output:
        raise ValueError(
            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
            "Make sure include_hiddens=True in the call to super().__init__"
        )
    # Get the per-token hidden states from the parent output; token classification uses every position
    hidden_states: Tensor = output["hidden_states"]
    # Predict per-token class logits with the classification head
    classification_output = self.classification_head(hidden_states)
    if not self.include_hiddens_finetuning:
        del output["hidden_states"]
    output["classification_output"] = classification_output
    return output
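
The post-processing at the last pipeline stage can be sketched with ordinary dictionaries. This is a simplified illustration, not the model code: postprocess_output and toy_head are hypothetical names, the token_logits key is a placeholder, and lists of numbers stand in for tensors.

```python
from typing import Any, Callable, Dict, List


def postprocess_output(
    output: Any,
    head: Callable[[List[List[float]]], List[List[float]]],
    include_hiddens: bool = False,
) -> Dict[str, Any]:
    """Validate the parent output, apply the head, and optionally drop the hidden states."""
    if not isinstance(output, dict) or "hidden_states" not in output:
        raise ValueError("Expected a dict-like output containing 'hidden_states'")
    classification_output = head(output["hidden_states"])
    if not include_hiddens:
        del output["hidden_states"]  # note: this mutates the dict that was passed in
    output["classification_output"] = classification_output
    return output


def toy_head(hidden_states: List[List[float]]) -> List[List[float]]:
    """Hypothetical head: maps each token vector to two 'logits' (sum and negated sum)."""
    return [[sum(vec), -sum(vec)] for vec in hidden_states]


result = postprocess_output({"token_logits": None, "hidden_states": [[1.0, 2.0], [3.0, 4.0]]}, toy_head)
with_hiddens = postprocess_output({"hidden_states": [[1.0, 2.0]]}, toy_head, include_hiddens=True)
```

The head always consumes the hidden states. The include_hiddens flag only decides whether they are also returned to the caller.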

Esm2FineTuneTokenOutput

Bases: BioBertOutput

Inference output from ESM2FineTuneTokenModel.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
class Esm2FineTuneTokenOutput(BioBertOutput):
    """Inference output from ESM2FineTuneTokenModel."""

    classification_output: Tensor

InMemoryPerTokenValueDataset

Bases: Dataset

An in-memory dataset of labeled strings, which are tokenized on demand.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
class InMemoryPerTokenValueDataset(Dataset):
    """An in-memory dataset of labeled strings, which are tokenized on demand."""

    def __init__(
        self,
        data: Sequence[Tuple[str, str]],
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
        seed: int = np.random.SeedSequence().entropy,  # type: ignore
    ):
        """Initializes a dataset for per-token classification fine-tuning.

        This is an in-memory dataset that does not apply masking to the sequence.

        Args:
            data: A sequence of tuples containing the sequence and target data.
            tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
                ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
                seed is generated.
        """
        self.data = data
        self.seed = seed
        self._len = len(self.data)
        self.tokenizer = tokenizer
        label_tokenizer = Label2IDTokenizer()
        self.label_tokenizer = label_tokenizer.build_vocab("CHE")
        self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX

    def __len__(self) -> int:
        """Length of dataset."""
        return self._len

    def __getitem__(self, index: int) -> BertSample:
        """Gets a BertSample associated to the supplied index."""
        sequence, target = self.data[index]
        tokenized_sequence = self._tokenize(sequence)
        # Loss mask: True for regular residue tokens, False for special tokens, which are excluded from the loss
        loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
        labels = self._tokenize_labels(target)

        return {
            "text": tokenized_sequence,
            "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
            "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
            "labels": labels,
            "loss_mask": loss_mask,
            "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
        }

    def _tokenize_labels(self, labels_sequence: str) -> Tensor:
        label_ids = torch.tensor(self.label_tokenizer.text_to_ids(labels_sequence))

        # # for multi-label classification with BCEWithLogitsLoss
        # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
        # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)

        # for multi-class (mutually exclusive) classification with CrossEntropyLoss
        tokenized_labels = label_ids
        cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)

        # add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
        labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
        return labels

    def _tokenize(self, sequence: str) -> Tensor:
        """Tokenize a protein sequence.

        Args:
            sequence: The protein sequence.

        Returns:
            The tokenized sequence.
        """
        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
        return tensor.flatten()  # type: ignore
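
Since the labels are wrapped with one CLS and one EOS placeholder to match the tokenized sequence, each label string needs exactly one character per residue. The dataset shown here does not check this itself. The helper below is a hypothetical pre-check, assuming one token per residue and the three-letter label alphabet "CHE" used above.

```python
from typing import List, Sequence, Tuple


def find_invalid_pairs(data: Sequence[Tuple[str, str]], alphabet: str = "CHE") -> List[int]:
    """Return the indices of (sequence, labels) pairs that would not line up token-for-token."""
    allowed = set(alphabet)
    bad = []
    for index, (sequence, labels) in enumerate(data):
        same_length = len(sequence) == len(labels)
        known_labels = set(labels) <= allowed
        if not (same_length and known_labels):
            bad.append(index)
    return bad


data = [
    ("MKTAYIAK", "CCHHHHEC"),  # valid: 8 residues, 8 labels
    ("MKTA", "CCH"),  # invalid: one label is missing
    ("MKTA", "CCHX"),  # invalid: "X" is not in the label alphabet
]
invalid = find_invalid_pairs(data)
```

Running such a check before constructing the dataset surfaces misaligned pairs early, rather than as a shape mismatch inside the loss.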

__getitem__(index)

Gets the BertSample associated with the supplied index.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def __getitem__(self, index: int) -> BertSample:
    """Gets a BertSample associated to the supplied index."""
    sequence, target = self.data[index]
    tokenized_sequence = self._tokenize(sequence)
    # Loss mask: True for regular residue tokens, False for special tokens, which are excluded from the loss
    loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
    labels = self._tokenize_labels(target)

    return {
        "text": tokenized_sequence,
        "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
        "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
        "labels": labels,
        "loss_mask": loss_mask,
        "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
    }
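
To make the alignment of labels and loss_mask concrete, here is a small plain-Python sketch. It assumes that build_vocab("CHE") assigns ids in order of first appearance (C=0, H=1, E=2) and that the tokenizer adds one CLS token at the start and one EOS token at the end. Both are assumptions about the tokenizers, and build_label_vocab and encode_labels are hypothetical helpers. The value -100 is the padding value named in the source comment.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # padding value used for the CLS/EOS label positions


def build_label_vocab(alphabet: str) -> Dict[str, int]:
    """Assign ids in order of first appearance (assumed behavior of the label tokenizer)."""
    vocab: Dict[str, int] = {}
    for char in alphabet:
        vocab.setdefault(char, len(vocab))
    return vocab


def encode_labels(labels: str, vocab: Dict[str, int]) -> List[int]:
    """Map each label character to its id and wrap the result with CLS/EOS placeholders."""
    return [IGNORE_INDEX] + [vocab[char] for char in labels] + [IGNORE_INDEX]


vocab = build_label_vocab("CHE")
labels = encode_labels("CHHE", vocab)
# The loss mask is False exactly at the special-token positions (first and last here).
loss_mask = [False] + [True] * len("CHHE") + [False]
```

The label list and the loss mask have the same length as the tokenized sequence, and the positions holding -100 are the same positions where the mask is False.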

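__init__(data, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy)

Initializes a dataset for per-token classification fine-tuning.

This is an in-memory dataset that does not apply masking to the sequence.

Parameters

Name Type Description Default
data Sequence[Tuple[str, str]]

A sequence of tuples containing the sequence and target data.

required
tokenizer BioNeMoESMTokenizer

The tokenizer to use. Defaults to tokenizer.get_tokenizer().

get_tokenizer()
seed int

Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is generated.

entropy

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py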
def __init__(
    self,
    data: Sequence[Tuple[str, str]],
    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
    seed: int = np.random.SeedSequence().entropy,  # type: ignore
):
    """Initializes a dataset for per-token classification fine-tuning.

    This is an in-memory dataset that does not apply masking to the sequence.

    Args:
        data: A sequence of tuples containing the sequence and target data.
        tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
            ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
            seed is generated.
    """
    self.data = data
    self.seed = seed
    self._len = len(self.data)
    self.tokenizer = tokenizer
    label_tokenizer = Label2IDTokenizer()
    self.label_tokenizer = label_tokenizer.build_vocab("CHE")
    self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX
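
One Python detail worth keeping in mind with this signature: default argument values are evaluated once, when the def statement runs, not on each call. The default seed is therefore drawn a single time per process, and instances created without an explicit seed share it. The sketch below demonstrates the language behavior with a hypothetical make_seed helper and the standard library's random module rather than NumPy.

```python
import random
from typing import Optional


def make_seed() -> int:
    """Hypothetical helper that draws a fresh random seed."""
    return random.getrandbits(64)


def dataset_seed(seed: int = make_seed()) -> int:
    """The default is computed once, at definition time, and reused by every call."""
    return seed


def dataset_seed_per_call(seed: Optional[int] = None) -> int:
    """A common alternative: use None as a sentinel and draw the seed inside the function."""
    return make_seed() if seed is None else seed


first, second = dataset_seed(), dataset_seed()
explicit = dataset_seed_per_call(seed=42)
```

Both calls to dataset_seed return the same value. Passing an explicit seed, as in the last line, is the straightforward way to get reproducible behavior across runs.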

__len__()

Length of dataset.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def __len__(self) -> int:
    """Length of dataset."""
    return self._len

MegatronConvNetHead

Bases: MegatronModule

A convolutional neural network class for residue-level classification.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
class MegatronConvNetHead(MegatronModule):
    """A convolutional neural network class for residue-level classification."""

    def __init__(self, config: TransformerConfig):
        """Constructor."""
        super().__init__(config)

        self.finetune_model = torch.nn.Sequential(
            torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
            torch.nn.ReLU(),
            torch.nn.Dropout(config.cnn_dropout),
        )
        # class_heads: a single convolutional layer that maps the bottleneck features to per-class logits
        # of size `cnn_num_classes`. Its input channels must match the bottleneck's output channels.
        self.class_heads = torch.nn.Conv2d(
            config.cnn_hidden_dim, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
        )

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Inference."""
        # [b, s, h] -> [b, h, s, 1]
        hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
        hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_dim, s, 1]
        output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)  # [b, s, cnn_num_classes]
        return output

__init__(config)

Constructor.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def __init__(self, config: TransformerConfig):
    """Constructor."""
    super().__init__(config)

    self.finetune_model = torch.nn.Sequential(
        torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
        torch.nn.ReLU(),
        torch.nn.Dropout(config.cnn_dropout),
    )
    # class_heads: a single convolutional layer that maps the bottleneck features to per-class logits
    # of size `cnn_num_classes`. Its input channels must match the bottleneck's output channels.
    self.class_heads = torch.nn.Conv2d(
        config.cnn_hidden_dim, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
    )
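
For a sense of scale, the number of trainable parameters in the head follows from the Conv2d shapes: each layer has out_channels * in_channels * kernel_height * kernel_width weights plus out_channels biases. The sketch below evaluates this for the defaults cnn_hidden_dim=32 and cnn_num_classes=3. The hidden size of 1280 is an assumed example value, and conv2d_param_count is a hypothetical helper.

```python
def conv2d_param_count(in_channels: int, out_channels: int, kernel_h: int, kernel_w: int) -> int:
    """Weights plus biases of a standard (groups=1, bias=True) 2D convolution."""
    return out_channels * in_channels * kernel_h * kernel_w + out_channels


hidden_size = 1280  # assumed encoder width, for illustration only
cnn_hidden_dim = 32  # default from ESM2FineTuneTokenConfig
cnn_num_classes = 3  # default from ESM2FineTuneTokenConfig

bottleneck = conv2d_param_count(hidden_size, cnn_hidden_dim, 7, 1)
class_heads = conv2d_param_count(cnn_hidden_dim, cnn_num_classes, 7, 1)
total = bottleneck + class_heads
```

Under these assumptions the head has 287,427 parameters, nearly all of them in the bottleneck layer that reads the encoder's hidden states.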

forward(hidden_states)

Inference.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
def forward(self, hidden_states: Tensor) -> Tensor:
    """Inference."""
    # [b, s, h] -> [b, h, s, 1]
    hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
    hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_dim, s, 1]
    output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)  # [b, s, cnn_num_classes]
    return output
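
The head keeps one output per residue because a kernel of height 7 with padding 3 (stride 1) preserves the sequence length: s + 2*3 - 7 + 1 = s. The plain-Python sketch below checks this with a single-channel 1D convolution. conv_output_length and conv1d_same are hypothetical helpers, not the Conv2d layers used above.

```python
from typing import List


def conv_output_length(length: int, kernel_size: int, padding: int, stride: int = 1) -> int:
    """Standard output-length formula for a convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def conv1d_same(signal: List[float], kernel: List[float], padding: int) -> List[float]:
    """Single-channel cross-correlation with zero padding on both sides."""
    padded = [0.0] * padding + list(signal) + [0.0] * padding
    k = len(kernel)
    return [
        sum(padded[i + j] * kernel[j] for j in range(k))
        for i in range(len(padded) - k + 1)
    ]


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
kernel = [1.0] * 7  # a box filter: each output sums the 7-wide window around a position
output = conv1d_same(signal, kernel, padding=3)
```

Each output position sees up to three neighbors on either side, so a residue's logits depend on a local window of seven residues in each convolution, with zero padding at the sequence ends.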