
Training

train_model(experiment_name, experiment_dir, config, data_module, n_steps_train, metric_tracker=None, tokenizer=get_tokenizer(), peft=None, _use_rich_model_summary=True)

Trains a BioNeMo ESM2 model using PyTorch Lightning.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `experiment_name` | `str` | The name of the experiment. | required |
| `experiment_dir` | `Path` | The directory where the experiment will be saved. | required |
| `config` | `ESM2GenericConfig` | The configuration for the ESM2 model. | required |
| `data_module` | `LightningDataModule` | The data module for training and validation. | required |
| `n_steps_train` | `int` | The number of training steps. | required |
| `metric_tracker` | `Callback \| None` | Optional callback to track metrics (a minimal sketch follows this table). | `None` |
| `tokenizer` | `BioNeMoESMTokenizer` | The tokenizer to use. Defaults to `get_tokenizer()`. | `get_tokenizer()` |
| `peft` | `PEFT \| None` | The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to `None`. | `None` |
| `_use_rich_model_summary` | `bool` | Whether to use the RichModelSummary callback, omitted in our test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to `True`. | `True` |
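
The `metric_tracker` argument accepts any PyTorch Lightning `Callback`. Below is a minimal, hypothetical sketch of such a callback that records the `reduced_train_loss` values logged during training; the class name is an invention for illustration, and the Lightning import path (`lightning.pytorch` vs. `pytorch_lightning`) may need adjusting for your environment.

```python
import lightning.pytorch as pl  # use `pytorch_lightning` on older stacks


class SimpleMetricTracker(pl.Callback):
    """Hypothetical stand-in for the `metric_tracker` callback."""

    def __init__(self) -> None:
        self.train_losses: list[float] = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # `reduced_train_loss` is the metric monitored by the checkpoint callback in train_model.
        loss = trainer.callback_metrics.get("reduced_train_loss")
        if loss is not None:
            self.train_losses.append(float(loss))
```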

Returns

| Type | Description |
| --- | --- |
| `Tuple[Path, Callback \| None, nl.Trainer]` | A tuple containing the path to the saved checkpoint, the MetricTracker callback (or `None`), and the PyTorch Lightning Trainer object. |
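
As a usage sketch, a call and the unpacking of the returned tuple might look like the following. Here `config` and `data_module` are assumed to be an `ESM2GenericConfig` instance and a `LightningDataModule` prepared elsewhere, and the experiment name, directory, and step count are placeholder values.

```python
from pathlib import Path

from bionemo.esm2.model.finetune.train import train_model  # import path follows the source file below

# `config` and `data_module` are assumed to have been constructed elsewhere.
ckpt_path, metric_tracker, trainer = train_model(
    experiment_name="esm2-finetune-demo",  # placeholder name
    experiment_dir=Path("./results"),      # placeholder directory
    config=config,
    data_module=data_module,
    n_steps_train=100,
    metric_tracker=None,  # or a Callback such as the sketch above
)
# ckpt_path:      path to the last saved checkpoint, with the ".ckpt" suffix stripped
# metric_tracker: the callback that was passed in, or None
# trainer:        the PyTorch Lightning Trainer used for the run
```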

Source code in bionemo/esm2/model/finetune/train.py
def train_model(
    experiment_name: str,
    experiment_dir: Path,
    config: ESM2GenericConfig,
    data_module: pl.LightningDataModule,
    n_steps_train: int,
    metric_tracker: Callback | None = None,
    tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
    peft: PEFT | None = None,
    _use_rich_model_summary: bool = True,
) -> Tuple[Path, Callback | None, nl.Trainer]:
    """Trains a BioNeMo ESM2 model using PyTorch Lightning.

    Parameters:
        experiment_name: The name of the experiment.
        experiment_dir: The directory where the experiment will be saved.
        config: The configuration for the ESM2 model.
        data_module: The data module for training and validation.
        n_steps_train: The number of training steps.
        metric_tracker: Optional callback to track metrics
        tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
        peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.
        _use_rich_model_summary: Whether to use the RichModelSummary callback, omitted in our test suite until
            https://nvbugspro.nvidia.com/bug/4959776 is resolved. Defaults to True.

    Returns:
        A tuple containing the path to the saved checkpoint, a MetricTracker
        object, and the PyTorch Lightning Trainer object.
    """
    checkpoint_callback = nl_callbacks.ModelCheckpoint(
        save_last=True,
        save_on_train_epoch_end=True,
        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
        every_n_train_steps=n_steps_train // 2,
        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
    )

    # Setup the logger and train the model
    nemo_logger = NeMoLogger(
        log_dir=str(experiment_dir),
        name=experiment_name,
        tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name),
        ckpt=checkpoint_callback,
    )
    # Needed so that the trainer can find an output directory for the profiler
    # ckpt_path needs to be a string for SerDe
    optimizer = MegatronOptimizerModule(
        config=OptimizerConfig(
            lr=5e-4,
            optimizer="adam",
            use_distributed_optimizer=True,
            fp16=config.fp16,
            bf16=config.bf16,
        )
    )
    module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        ddp="megatron",
        find_unused_parameters=True,
        enable_nemo_ckpt_io=True,
    )

    if _use_rich_model_summary:
        # RichModelSummary is not used in the test suite until https://nvbugspro.nvidia.com/bug/4959776 is resolved due
        # to errors with serialization / deserialization.
        callbacks: list[Callback] = [RichModelSummary(max_depth=4)]
    else:
        callbacks = []

    if metric_tracker is not None:
        callbacks.append(metric_tracker)
    if peft is not None:
        callbacks.append(
            ModelTransform()
        )  # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft).

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=strategy,
        limit_val_batches=2,
        val_check_interval=n_steps_train // 2,
        max_steps=n_steps_train,
        num_nodes=1,
        log_every_n_steps=n_steps_train // 2,
        callbacks=callbacks,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )
    nllm.train(
        model=module,
        data=data_module,
        trainer=trainer,
        log=nemo_logger,
        resume=resume.AutoResume(
            resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
        ),
    )
    ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
    return ckpt_path, metric_tracker, trainer