重要提示

您正在查看 NeMo 2.0 文档。此版本对 API 和新库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚未提供的功能的文档,请参阅 NeMo 24.07 文档

序列化#

NeMo 2.0 提供了捕获实验的训练器、模型和数据加载器的初始化参数的选项。此功能能够精确重建这些对象,从而轻松实现实验的可重复性。

IOMixin#

序列化使用 IOMixin 类执行。此类捕获传递给类的 __init__ 方法的参数,这允许从给定的实验精确恢复训练器、模型和数据模块。以下是一个简单的示例

from nemo.lightning import io

ckpt = io.TrainerContext(model, trainer, extra={"datamodule": data})
## dump the current state
ckpt.io_dump(save_dir)

## restore the serialized state
loaded = io.load_context(save_dir)
## model, trainer and dataloader will be reinitialized using the same args as before
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]

可以通过 ModelCheckpointenable_nemo_ckpt_io 参数自动完成保存这些初始化状态。如果 enable_nemo_ckpt_io=True,则将调用 IOMixinio_dump 功能来保存训练器、模型和数据加载器初始化状态。然后可以使用 io.load_context 函数恢复这些状态。请注意,此功能独立于检查点加载;一旦对象被实例化,如果您想使用先前运行中的权重,它们仍然需要从检查点恢复。示例如下

首先,运行一些训练并保存检查点

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

trainer = nl.Trainer(...)
model = llm.GPTModel(...)
datamodule = llm.PreTrainingDataModule(...)
optim = nl.MegatronOptimizerModule(...)
checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume(
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)

llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=opt,
)

在上面的示例中,ModelCheckpointNeMoLoggerAutoResume 负责设置日志记录和检查点目录,并确定何时保存和恢复检查点。有关这些类的更多信息,请参阅 日志记录和检查点文档

一旦初始化状态被保存,我们可以从序列化路径恢复训练器、模型和数据模块。请注意,未被 io_dump 捕获的所有内容(例如,检查点回调、记录器和恢复)都应重新初始化。这样做可确保正确设置日志记录和检查点目录。它还可以确保在重新初始化后恢复适当的模型权重。

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

loaded = io.load_context("explicit_dir_test/<PATH TO LATEST CHECKPOINT>")
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
optim = nl.MegatronOptimizerModule(...) ## optimizer needs to be reinitialized

checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume( ## handles resuming of the latest checkpoint in `explicit_dir_test`
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)

llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=opt,
)