跳到内容

ESM-2 预训练

本教程演示了如何使用 UniProt 序列从头开始预训练 ESM2

ESM-2 模型是一种基于 Transformer 的蛋白质语言模型,它在掩码语言模型 (MLM) 任务上进行了预训练。目标是从蛋白质序列的其余部分恢复扰动位置的原始氨基酸类型。通过预训练,ESM-2 学习蛋白质序列中的进化信息,类似于保守性分析和 Pott 模型,并预测任何给定蛋白质序列上的最佳突变。

设置和假设

在本教程中,我们将演示如何创建 ESM-2 预训练数据模块,以及创建和训练 ESM-2 模型。

所有命令都应在 BioNeMo Docker 容器内执行,该容器已预装所有 ESM-2 依赖项。BioNeMo 框架容器可以在 brev.dev 启动器中运行: 点击此处部署。。将此笔记本部署为 Launchable 大约需要 10 分钟。截至撰写本文时,我们正在开发免费层,因此可能需要信用卡。您可以联系您的 NVIDIA 代表以获取积分。启动实例后,在 Jupyter Lab UI 中启动终端会话。(注意:此链接指向每晚发布版本,可能与这些文档不同步。)

或者,有关如何在本地构建或拉取 BioNeMo2 容器的更多信息,请参阅初始化指南

本教程假设工作站或服务器上存在 BioNeMo 框架存储库的副本,并且已将其挂载到容器内的 /workspace/bionemo2

注意

如果您使用的是 VSCode Dev Container,则此 WORKDIR 可能是 /workspaces/bionemo-framework

与 PyTorch Lightning 类似,我们必须定义一些关键类

  1. MegatronStrategy - 启动和设置 NeMoMegatron-LM 的并行性。
  2. Trainer - 配置训练配置和日志记录。
  3. ESMDataModule - 加载预训练训练和验证数据,并将映射的 UniRef90 序列映射到 UniRef50 集群。
  4. ESM2Config - 将 ESM-2 模型配置为 BionemoLightningModule

1 - MegatronStrategy

BioNeMo2 支持数据并行 (DP)、张量并行 (TP) 和流水线并行 (PP) 以训练大型模型。我们使用 MegatronStrategy 来启动和设置 NeMo 和 Megatron-LM 的并行性,而不是 PyTorch Lightning 中的 DDPStrategy

from nemo import lightning as nl
from bionemo.llm.utils.datamodule_utils import infer_global_batch_size

micro_batch_size = 2
num_nodes = 1
devices = 2
accumulate_grad_batches = 1
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 1

global_batch_size = infer_global_batch_size(
    micro_batch_size=micro_batch_size,
    num_nodes=num_nodes,
    devices=devices,
    accumulate_grad_batches=accumulate_grad_batches,
    tensor_model_parallel_size=tensor_model_parallel_size,
    pipeline_model_parallel_size=pipeline_model_parallel_size,
)

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=tensor_model_parallel_size,
    pipeline_model_parallel_size=pipeline_model_parallel_size,
    ddp="megatron",
    find_unused_parameters=True,
    ckpt_include_optimizer=True,
)

2 - Trainer

BioNeMo2 训练器与 PyTorch Lightning 训练器非常相似。我们可以配置训练配置和日志记录。

from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
from bionemo.llm.lightning import PerplexityLoggingCallback

num_steps = 20
limit_val_batches = 2  # limit the validation epoch to 2 batches
val_check_interval = 10  # validation epoch every 10 steps
precision = "bf16-mixed"  # use bf16-mixed precision

trainer = nl.Trainer(
    devices=devices,
    max_steps=num_steps,
    accelerator="gpu",
    strategy=strategy,
    limit_val_batches=limit_val_batches,
    val_check_interval=val_check_interval,
    num_nodes=num_nodes,
    callbacks=[
        PerplexityLoggingCallback(),
        RichModelSummary(max_depth=4),
        LearningRateMonitor(),
    ],
    plugins=nl.MegatronMixedPrecision(precision=precision),  # precision is handled through plugins in BioNeMo2
)

以下是其他可能的配置示例。

from bionemo.core.utils.dtypes import PrecisionTypes

limit_val_batches_all_data = 1.  # validate on 100% of the validation dataset
limit_val_batches_half_data = 0.5  # validate on 50% of the validation dataset
limit_val_batches_one_batch = 1  # validate on 1 batch

print(PrecisionTypes)  # show all possible precision types

3 - ESMDataModule

在使用数据模块实例化之前,我们可以先使用 download_bionemo_data 下载测试 ESM-2 预训练数据。命令行将下载数据(如果我们尚未下载),并将返回测试数据的路径,这是实例化 ESMDataModule 所需的。

download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source ngc  # test data
# download_bionemo_data esm2/fulldata_esm2_pretrain:2.0 --source ngc  # full data (~80GB)

除了数据目录的路径之外,BioNeMo2 数据模块还需要全局和微批量大小,以确保在模型并行 ranks 中正确初始化输入张量(请参阅Megatron 数据集注意事项)。

from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer

data_path = __your_downloaded_test_data_path__  # fill your path from the command line output

train_cluster_path = f"{data_path}/2024_03_sanity/train_clusters_sanity.parquet"
train_database_path = f"{data_path}/2024_03_sanity/train_sanity.db"
valid_cluster_path = f"{data_path}/2024_03_sanity/valid_clusters.parquet"
valid_database_path = f"{data_path}/2024_03_sanity/validation.db"

min_seq_length = None  # optional; filter sequences by minimum length if given
max_seq_length = 128  # required; default to 1024

num_dataset_workers = 1
random_mask_strategy = RandomMaskStrategy.ALL_TOKENS  # default in BioNemo2 and HuggingFace implementation

data = ESMDataModule(
    train_cluster_path=train_cluster_path,  # UniRef50 training cluster centers
    train_database_path=train_database_path,  # UniRef90 training sequences
    valid_cluster_path=valid_cluster_path,  # UniRef50 validation cluster centers
    valid_database_path=valid_database_path,  # UniRef90 validation sequences
    global_batch_size=global_batch_size,
    micro_batch_size=micro_batch_size,
    min_seq_length=min_seq_length,
    max_seq_length=max_seq_length,
    num_workers=num_dataset_workers,
    random_mask_strategy=random_mask_strategy,
)

RandomMaskStrategy

当在 MLM 目标上训练时,损失函数随机包含 15% 的令牌,其中 80% 被掩码,10% 被替换为随机令牌,10% 保持不变。由于词汇表包括氨基酸以及特殊令牌,因此部分蛋白质序列可能会被特殊令牌替换。这是 BioNeMo2 和 HuggingFace ESM-2 实现中的默认设置。

为了强制仅进行氨基酸替换,用户可以将 random_mask_strategy=RandomMaskStrategy.AMINO_ACID_ONLY 传递给 ESMDataModule

4. ESM2Config

分片模型不是在每个 rank 上初始化整个模型,而是在配置对象的帮助下在目标 rank 上延迟创建。ESM2Config 是一个数据类,它封装了架构参数(例如 num_layers)以及 Transformer 中每个 torch 模块 (ModuleSpec) 的规范,这些模块在 TransformerEngine 中使用 flash 和 fused attention 加速。虽然我们可以从 ESM2Config 初始化模型,但其设置仅在 trainer.setup 下完成,该设置在各个设备上调用。

from megatron.core.optimizer import OptimizerConfig
from nemo.lightning.pytorch.optim import MegatronOptimizerModule

from bionemo.core.utils.dtypes import get_autocast_dtype
from bionemo.esm2.api import ESM2Config
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption

# ESM-2 650M config
num_layers = 33
hidden_size = 1280
num_attention_heads = 20
ffn_hidden_size = 4 * hidden_size

nemo1_init_path = None  # initialize from nemo1 checkpoint
restore_from_checkpoint_path = None  # initialize from nemo2 checkpoint
need_megatron_variable_seq_lengths_reductions: bool = (
    pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length
)  # essential for pipeline/tensor parallel
biobert_spec_option = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec  # accelerated esm2 with transformer engine

warmup_steps = 2000
lr = 1e-4

# Create model config
esm2_config = ESM2Config(
    seq_length=max_seq_length,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    ffn_hidden_size=ffn_hidden_size,
    params_dtype=get_autocast_dtype(precision),
    pipeline_dtype=get_autocast_dtype(precision),
    autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
    biobert_spec_option=biobert_spec_option,
    nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None,
    initial_ckpt_path=str(restore_from_checkpoint_path) if restore_from_checkpoint_path is not None else None,
    variable_seq_lengths=need_megatron_variable_seq_lengths_reductions,
)

# Create model instance
tokenizer = get_tokenizer()

model: BionemoLightningModule = biobert_lightning_module(
    esm2_config,
    tokenizer=tokenizer,
    optimizer=MegatronOptimizerModule(
        config=OptimizerConfig(
            lr=lr,
            optimizer="adam",
            use_distributed_optimizer=True,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.98,
        ),
        lr_scheduler=WarmupAnnealDecayHoldScheduler(
            warmup_steps=warmup_steps, max_steps=num_steps, max_lr=lr, min_lr=lr / 10.0, anneal_percentage=0.10
        ),
    ),
)

ModuleSpec

ModelSpec 决定了 Transformer 层中使用哪些 torch 模块。默认情况下,BioNeMo2 使用 TransformerEngine 层加速 ESM-2 架构。用户可以为自定义 Transformer 层定义自己的 ModelSpec。请参阅get_biobert_spec

BionemoLightningModule

由于模型在目标 rank 中延迟初始化,因此用于调试目的的断点应在 trainer.setup 之后添加。

模型预训练

为了完成循环,用户可以利用 NeMo 中的 llm.train 开始训练。

from typing import Optional

from nemo.collections import llm
from nemo.lightning import resume
from nemo.lightning.pytorch import callbacks as nl_callbacks

from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger


# WANDB logging
wandb_options: Optional[WandbLoggerOptions] = (
    None
    if wandb_project is None
    else WandbLoggerOptions(
        offline=False,
        project=__your_wandb_project__,
        entity=__your_wandb_entity__,
        tags=None,
        group=None,
        id=None,
        anonymous=False,
        log_model=False,
    )
)

checkpoint_callback = nl_callbacks.ModelCheckpoint(
    save_last=True,
    monitor="val_loss",
    save_top_k=1,
    always_save_context=True,
)

nemo_logger = setup_nemo_lightning_logger(
    root_dir=__your_result_dir__,
    name=__your_experiment_name__,
    initialize_tensorboard_logger=True,
    wandb_kwargs=wandb_options,
    ckpt_callback=checkpoint_callback,
)

llm.train(
    model=model,
    data=data,
    trainer=trainer,
    log=nemo_logger,
    resume=resume.AutoResume(
        resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
        resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
    ),
)

或者直接使用位于 $WORKDIR/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py 中的 ESM2 预训练。可以通过直接使用 python 或安装的可执行文件 train_esm2 来调用此脚本

# Enable fused attention in transformer engine for speed-up
DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source ngc)

train_esm2 \
    --train-cluster-path ${DATA_DIR}/2024_03_sanity/train_clusters_sanity.parquet \
    --train-database-path ${DATA_DIR}/2024_03_sanity/train_sanity.db \
    --valid-cluster-path ${DATA_DIR}/2024_03_sanity/valid_clusters.parquet \
    --valid-database-path ${DATA_DIR}/2024_03_sanity/validation.db \
    --precision="bf16-mixed" \
    --num-gpus 1 \
    --num-nodes 1 \
    --num-steps 100 \
    --val-check-interval 25 \
    --max-seq-length 1024 \
    --limit-val-batches 2 \
    --micro-batch-size 2 \
    --num-layers 33 \
    --hidden-size 1280 \
    --num-attention-head 20 \
    --ffn-hidden-size 5120 \
    --tensor-model-parallel-size 1 \
    --create-tensorboard-logger \
    --wandb_project=__your_wandb_project__ \
    --experiment-name=__your_wandb_experiment_name

此脚本将自动创建 ./results 并在 esm2 下存储检查点。当 --resume-if-exists 设置为 True 时,将自动处理自动预训练恢复,如果用户想要从特定路径恢复,则可以使用 --restore-from-checkpoint-path

Weight And Biases

如果打算使用 --wandb_project,用户应登录 Weight and Biases 或选择导出环境变量 WANDB_API_KEY。如果未提供,则将禁用记录器。

命令行运行中的非关键警告

用户可能会遇到来自 Megatron-LM 的 torch._dynamo.convert_frame 警告消息和关于 async_grad_allreduce 的弃用警告。用户可以安全地忽略它们,并且它们对于预训练来说是非关键的。

我们在以下模型大小1上基准测试了我们的实现。这些参数由以下因素处理

模型大小 # 层数 隐藏层大小 # 注意力头数 FFN 隐藏层大小
8M 8 320 20 1280
650M 33 1280 20 5120
3B 36 2560 40 10240
15B 48 5120 40 20480

在我们当前的基准测试中,我们建议在 A100 80GB GPU 上使用以下训练和设备配置,以匹配已发布的 2M 令牌全局批量大小。

模型大小 # GPU 数量 微批量大小 张量模型并行大小
8M 32 64 1
650M 64 32 1
3B 128 16 1
15B 3120 2 2

关于微批量大小的其他说明

虽然上述微批量大小以 2^n 选择,以达到 2,097,152 个令牌的全局批量大小,但用户应通过将尽可能大的微批量大小拟合到设备上而不发生 OOM 来观察性能提升。当前最大的批量大小列在下面。

模型大小 最大微批量大小 张量模型并行大小
8M 70 1
650M 48 1
3B 16 1
15B 3 2

唯一的例外是 15B 模型,作者报告了 3.2M 令牌的全局批量大小。我们在 390 个 A100 节点上达到了 3,194,880 个令牌。

这些模型大小的最大微批量大小在 2 个 A100 80GB GPU 节点上进行了测试。

来自分布式优化器的内存分配

默认情况下启用分布式优化器以改进内存分配。用户可能会观察到,在多设备预训练中使用的相同微批量大小会导致在单个设备上发生 OOM。如果需要额外的优化,我们建议在与生产运行中相同数量的设备上运行简短的基准测试。


  1. Lin, Zeming, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Nikita Smetanin, 等。“Evolutionary-Scale Prediction of Atomic-Level Protein Structure with a Language Model.” Science 379, no. 6637 (2023 年 3 月 17 日): 1123–30. https://doi.org/10.1126/science.ade2574