跳到内容

ESM-2 微调

本自述文件用作实现 ESM-2 微调模块、运行回归示例以及使用模型进行推理的演示。

ESM-2 模型是一种基于 Transformer 的蛋白质语言模型,在各种蛋白质相关任务中取得了最先进的结果。在微调 ESM2 时,任务头起着至关重要的作用。任务头是指在预训练模型(如基于 Transformer 的 ESM-2 蛋白质语言模型)之上添加的额外层或层组,以使其适应特定的下游任务。作为迁移学习的一部分,预训练模型通常用于从大规模数据集中学习通用特征。但是,这些特征可能无法直接应用于手头的特定任务。通过合并包含可学习参数的任务头,模型可以适应并专门用于目标任务。任务头充当灵活且可适应的组件,通过利用预训练的特征作为基础来学习特定于任务的表示。通过微调,任务头使模型能够学习和提取特定于任务的模式,从而提高性能并解决下游任务的细微差别。它充当预训练模型和特定任务之间的关键桥梁,从而实现知识的有效高效转移。

设置和假设

在本教程中,我们将演示如何创建微调模块、训练回归任务头以及使用微调模型进行推理。

所有命令都应在 BioNeMo Docker 容器内执行,该容器已预先安装了所有 ESM-2 依赖项。本教程假设 BioNeMo 框架仓库的副本存在于工作站或服务器上,并且已挂载在容器内的 /workspace/bionemo2。(注意:如果您使用的是 VSCode Dev Container,则此 WORKDIR 可能是 /workspaces/bionemo-framework。)有关如何构建或拉取 BioNeMo2 容器的更多信息,请参阅访问和启动

为了成功完成此操作,我们需要定义一些关键类

  1. 损失缩减方法 - 用于计算监督微调损失。
  2. 微调模型头 - 下游任务头模型。
  3. 微调模型 - 将 ESM-2 与任务头模型相结合的模型。
  4. 微调配置 - 配置微调模型和损失,以在训练和推理框架中使用。
  5. 数据集 - ESM2 的训练和推理数据集。

1 - 损失缩减类

一个用于计算来自目标的微调模型的监督损失的类。我们从 Megatron Bert 掩码语言模型损失 (BERTMLMLossWithReduction) 继承,并覆盖 forward() 传递以计算微批次内回归头的 MSE 损失。reduce() 方法用于计算微批次的平均值,仅用于日志记录。

class RegressorLossReduction(BERTMLMLossWithReduction):
    def forward(
        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Union[PerTokenLossDict, SameSizeLossDict]]:

        targets = batch["labels"]  # [b, 1]
        regression_output = forward_out
        loss = torch.nn.functional.mse_loss(regression_output, targets)
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> torch.Tensor:
        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return losses.mean()

2 - 微调模型头

用于序列级回归的 MLP 类。此类继承 MegatronModule 并使用微调配置 (TransformerConfig) 为微调 ESM-2 模型配置回归头。

class MegatronMLPHead(MegatronModule):
    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        layer_sizes = [config.hidden_size, 256, 1]
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
        )
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=config.ft_dropout)

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        ...

3 - 微调模型

用于令牌分类任务的微调 ESM-2 模型类。此类从 ESM2Model 类继承,并添加了我们在上一步中创建的自定义回归头 MegatronMLPHead。可以选择通过解析模型构造函数中的模型参数来冻结编码器的全部或部分。

class ESM2FineTuneSeqModel(ESM2Model):
    def __init__(self, config, *args, post_process: bool = True, return_embeddings: bool = False, **kwargs):
        super().__init__(config, *args, post_process=post_process, return_embeddings=True, **kwargs)

        # freeze encoder parameters
        if config.encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False

        if post_process:
            self.regression_head = MegatronMLPHead(config)

    def forward(self, *args, **kwargs,):
        output = super().forward(*args, **kwargs)
        ...
        regression_output = self.regression_head(embeddings)
        return regression_output

4 - 微调配置

配置微调 ESM-2 模型的 dataclass。在此示例中,ESM2FineTuneSeqConfigESM2GenericConfig 继承,并添加了自定义参数来设置微调模型。此 dataclassconfigure_model() 方法在 Lightning 模块内调用,以使用 dataclass 参数调用模型构造函数。

不同微调任务之间的常见参数是

  • model_cls:微调模型类 (ESM2FineTuneSeqModel)
  • initial_ckpt_path:BioNeMo 2.0 ESM-2 预训练检查点
  • initial_ckpt_skip_keys_with_these_prefixes:从检查点加载参数时跳过键。在这里,我们不应在预训练检查点中查找 regression_head
  • get_loss_reduction_class():实现适当的 MegatronLossReduction 类的选择,例如 bionemo.esm2.model.finetune.finetune_regressor.RegressorLossReduction
@dataclass
class ESM2FineTuneSeqConfig(ESM2GenericConfig[ESM2FineTuneSeqModel], iom.IOMixinWithGettersSetters):
    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
    # self.override_parent_fields will be loaded from the checkpoint and override those values here.
    initial_ckpt_path: str | None = None
    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
    # that has this new head and want to keep using these weights, please drop this next line or set to []
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])

    encoder_frozen: bool = True  # freeze encoder parameters
    ft_dropout: float = 0.25  # MLP layer dropout

    def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:
        return RegressorLossReduction

5 - 数据集

我们将使用示例数据集进行演示。通过从 torch.utils.data.Dataset 扩展来创建数据集类。对于本演示的目的,我们将假设数据集由一小部分蛋白质序列组成,目标值为 len(sequence) / 100.0 作为其标签。

data = [
    ("MVLSPADKTNVKAAWGKVGAHAGEYGAEALERH", 0.33),
    ...
]

因此,自定义 BioNeMo 数据集类将是合适的(在 bionemo.esm2.model.finetune.finetune_regressor.InMemorySingleValueDataset 中找到),因为它有助于预测单个值。下面显示了该类的一个摘录。此示例数据集期望一个 Tuple 序列,其中包含 (sequence, target) 值。但是,可以简单地以类似的方式扩展 InMemorySingleValueDataset 类,以自定义您的类以适应您的数据。

class InMemorySingleValueDataset(Dataset):
    def __init__(
        self,
        data: Sequence[Tuple[str, float]],
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
        seed: int = np.random.SeedSequence().entropy,
    ):

对于任何任意数据文件格式,用户可以将数据处理成包含(序列,标签)的元组列表,并使用此数据集类。或覆盖数据集类以加载其自定义数据文件。

为了协调从您的数据创建训练、验证和测试数据集,我们需要使用 datamodule 类。为此,我们可以直接使用或扩展 ESM2FineTuneDataModule 类(位于 bionemo.esm2.model.finetune.datamodule.ESM2FineTuneDataModule),该类定义了使用您的数据集类的有用的抽象方法。

dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(
    train_dataset=train_dataset,
    valid_dataset=valid_dataset
    micro_batch_size=4,   # size of a batch to be processed in a device
    global_batch_size=8,  # size of batch across all devices. Should be multiple of micro_batch_size
)

微调 ESM2 的回归器任务头

现在我们可以将这五个要求放在一起,从预训练的 650M ESM-2 模型 (pretrain_ckpt_path) 开始微调回归器任务头。我们可以利用 bionemo.esm2.model.fnetune.train 中的简单训练循环,并使用 `train_model() 函数在下面开始微调过程。

# create a List[Tuple] with (sequence, target) values
artificial_sequence_data = [
    "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
    "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
    "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
    "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
    "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]

data = [(seq, len(seq)/100.0) for seq in artificial_sequence_data]

# we are training and validating on the same dataset for simplicity
dataset = InMemorySingleValueDataset(data)
data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)

experiment_name = "finetune_regressor"
n_steps_train = 50
seed = 42

# To download a 650M pre-trained ESM2 model
pretrain_ckpt_path = load("esm2/650m:2.0")

config = ESM2FineTuneSeqConfig(
    initial_ckpt_path=str(pretrain_ckpt_path)
)

checkpoint, metrics, trainer = train_model(
    experiment_name=experiment_name,
    experiment_dir=Path(experiment_results_dir),  # new checkpoint will land in a subdir of this
    config=config,  # same config as before since we are just continuing training
    data_module=data_module,
    n_steps_train=n_steps_train,
)

此示例已在 bionemo.esm2.model.finetune.train 中完全实现,可以通过以下方式执行

python -m bionemo.esm2.model.finetune.train

注释

  1. 上面的示例正在微调 650M ESM-2 模型。预训练检查点可以使用以下 bash 命令或 bionemo.core.data.load 中的 load 函数从 NGC 资源下载,如上所示。
    download_bionemo_data esm2/650m:2.0
    
    并将输出路径(例如 .../.cache/bionemo/975d29ee980fcb08c97401bbdfdcf8ce-esm2_650M_nemo2.tar.gz.untar)作为参数传递到 initial_ckpt_path,同时设置配置对象
    config = ESM2FineTuneSeqConfig(
        initial_ckpt_path=str(pretrain_ckpt_path)
    )
    
  2. 由于 Megatron 的限制,训练运行生成的日志会迭代步骤/迭代,而不是 epoch。因此,训练 epoch 计数器保持为零值,而 iterationglobal_step 在训练过程中增加(以下示例中)。
    Training epoch 0, iteration <x/max_steps> | ... | global_step: <x> | reduced_train_loss: ... | val_loss: ...
    
    为了在训练时实现相同的基于 epoch 的效果,请选择训练步骤数 (n_steps_train),以便
    n_steps_train * global_batch_size = len(dataset) * desired_num_epochs
    
  3. 在此示例中,我们使用人工序列的小数据集作为微调数据。您可能会遇到过拟合,并且观察不到验证指标的变化。

微调 ESM-2 模型推理

现在我们可以使用 bionemo.esm2.model.finetune.train.infer 对示例预测数据集运行推理。记录在微调运行结束时报告的检查点路径,在执行 python -m bionemo.esm2.model.finetune.train 之后(例如 /tmp/tmp1b5wlnba/finetune_regressor/checkpoints/finetune_regressor--reduced_train_loss=0.0016-epoch=0-last),并将其用作推理脚本的参数 (--checkpoint-path)。

我们为此推理示例下载了一个人工序列的 CSV 示例数据集。有关参数的详细说明以及如何创建您自己的 CSV 文件,请参阅ESM-2 推理教程。

mkdir -p $WORKDIR/esm2_finetune_tutorial

# download sample data CSV for inference
DATA_PATH=$(download_bionemo_data esm2/testdata_esm2_infer:2.0)
RESULTS_PATH=$WORKDIR/esm2_finetune_tutorial/

infer_esm2 --checkpoint-path <finetune checkpoint path> \
           --data-path $DATA_PATH \
           --results-path $RESULTS_PATH \
           --config-class ESM2FineTuneSeqConfig

这将会在 $WORKDIR/esm2_finetune_tutorial/predictions__rank_0.pt 下创建一个结果 .pt 文件,该文件可以通过 Python 环境中的 PyTorch 库加载

import torch

# Set the path to results file e.g. /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
# results_path = /workspace/bionemo2/esm2_finetune_tutorial/predictions__rank_0.pt
results = torch.load(results_path)

# results is a python dict which includes the following result tensors for this example:
# results['regression_output'] is a tensor with shape: torch.Size([10, 1])

注释

  • ESM2 推理模块采用 --checkpoint-path--config-class 参数,通过指向 initial_ckpt_path 中的路径来创建配置对象。由于我们需要从此检查点加载所有参数(并且不跳过头),因此我们在此配置中重置 initial_ckpt_skip_keys_with_these_prefixes

    config = ESM2FineTuneSeqConfig(
        initial_ckpt_path = <finetuned checkpoint>,
        initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
    )