重要

您正在查看 NeMo 2.0 文档。此版本引入了 API 的重大更改和一个新库 NeMo Run。我们目前正在移植 NeMo 1.0 中的所有功能到 2.0 版本。有关先前版本或 2.0 版本中尚未提供的功能的文档,请参阅 NeMo 24.07 文档

使用预训练检查点的持续学习#

持续学习使 LLM 能够获得新技能并与快速发展的人类知识领域保持同步。在本指南中,我们将探讨如何使用 NeMo 2.0 和现有的预训练检查点进行持续学习。此过程适用于各种模型,包括 Llama 1、Llama 2、Llama 3、Gemma、Mistral、Mixtral 等。这里我们以 Llama 3.1 8B 为例,说明运行持续学习的工作流程。

获取预训练检查点#

在开始持续学习之前,请确保您已下载检查点。您可以从 Hugging Face 自动下载 Llama 3.1 8B 检查点,并使用以下脚本将其转换为 NeMo

from pathlib import Path
from nemo.collections import llm

if __name__ == "__main__":
    llm.import_ckpt(
        model=llm.Llama31Config8b(),
        source="hf://meta-llama/Meta-Llama-3.1-8B",
    )

配置持续学习#

要启用持续学习,请从所需模型的预训练配方开始。

recipe = llm.recipes.llama31_8b.pretrain_recipe(dir="path/to/save", name="llama3_continual_learning", num_nodes=1, num_gpus_per_node=8)

您可以修改预定义的 pretrain_recipe 的恢复组件,而不是从头开始预训练,以从预训练检查点恢复持续学习

from nemo import lightning as nl
import nemo_run as run

recipe.resume = run.Config(
    nl.AutoResume,
    restore_config=run.Config(nl.RestoreConfig, path="nemo://meta-llama/Meta-Llama-3.1-8B"),
    resume_if_exists=True,
)

调整训练配置#

当进行持续学习时,修改各种训练配置通常是有益的。例如

  • 模型并行性:根据可用的计算资源,您可以调整模型的并行性设置。

  • 数据混合:您可以更改数据集或修改训练期间数据的混合方式,以更好地适应新的训练目标。

  • 学习率调度器:调整学习率计划有助于优化新条件下的训练。

以上所有内容都可以通过 NeMo 的配方轻松配置

# Modify Model Parallelism if needed
# These are recommended number for Llama3.1 8B, and for larger models,
# Use more parallelism parameters to fit the model as needed.
recipe.trainer.strategy.tensor_model_parallel_size = 1
recipe.trainer.strategy.pipeline_model_parallel_size = 1
recipe.trainer.strategy.context_parallel_size = 2
# Modify Data Blend if needed
new_paths = [.3, "path/to/data1", .7, "path/to/data2"]
recipe.data.paths = new_paths
# Or you can directly swap the data module if needed
new_data_module = run.Config(
  llm.PreTrainingDataModule,
  paths = new_paths,
  seq_length = seq_length,
  global_batch_size = gbs,
  micro_batch_size = mbs,
)
# Modify Learning Rate Scheduler if needed
recipe.optim.lr_scheduler.warmup_steps = warmup_steps
recipe.optim.lr_scheduler.min_lr = min_lr
recipe.optim.config.lr = max_lr

执行持续学习#

有关执行训练的各种方法,请参阅 nemo2-quickstart-nemo-run