重要提示
您正在查看 NeMo 2.0 文档。此版本对 API 和新库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚未提供的功能的文档,请参阅 NeMo 24.07 文档。
序列化#
NeMo 2.0 提供了捕获实验的训练器、模型和数据加载器的初始化参数的选项。此功能能够精确重建这些对象,从而轻松实现实验的可重复性。
IOMixin#
序列化使用 IOMixin
类执行。此类捕获传递给类的 __init__
方法的参数,这允许从给定的实验精确恢复训练器、模型和数据模块。以下是一个简单的示例
from nemo.lightning import io
ckpt = io.TrainerContext(model, trainer, extra={"datamodule": data})
## dump the current state
ckpt.io_dump(save_dir)
## restore the serialized state
loaded = io.load_context(save_dir)
## model, trainer and dataloader will be reinitialized using the same args as before
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
可以通过 ModelCheckpoint
的 enable_nemo_ckpt_io
参数自动完成保存这些初始化状态。如果 enable_nemo_ckpt_io=True
,则将调用 IOMixin
的 io_dump
功能来保存训练器、模型和数据加载器初始化状态。然后可以使用 io.load_context
函数恢复这些状态。请注意,此功能独立于检查点加载;一旦对象被实例化,如果您想使用先前运行中的权重,它们仍然需要从检查点恢复。示例如下
首先,运行一些训练并保存检查点
import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io
trainer = nl.Trainer(...)
model = llm.GPTModel(...)
datamodule = llm.PreTrainingDataModule(...)
optim = nl.MegatronOptimizerModule(...)
checkpoint_callback = nl.ModelCheckpoint(
...
enable_nemo_ckpt_io=True,
...
)
nemo_logger = nl.NeMoLogger(
...
explicit_log_dir='explicit_dir_test',
ckpt=checkpoint_callback,
...
)
resume = nl.AutoResume(
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
)
llm.train(
model=model,
data=datamodule,
trainer=trainer,
log=nemo_logger,
resume=resume,
tokenizer='data',
optim=opt,
)
在上面的示例中,ModelCheckpoint
、NeMoLogger
和 AutoResume
负责设置日志记录和检查点目录,并确定何时保存和恢复检查点。有关这些类的更多信息,请参阅 日志记录和检查点文档。
一旦初始化状态被保存,我们可以从序列化路径恢复训练器、模型和数据模块。请注意,未被 io_dump
捕获的所有内容(例如,检查点回调、记录器和恢复)都应重新初始化。这样做可确保正确设置日志记录和检查点目录。它还可以确保在重新初始化后恢复适当的模型权重。
import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io
loaded = io.load_context("explicit_dir_test/<PATH TO LATEST CHECKPOINT>")
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
optim = nl.MegatronOptimizerModule(...) ## optimizer needs to be reinitialized
checkpoint_callback = nl.ModelCheckpoint(
...
enable_nemo_ckpt_io=True,
...
)
nemo_logger = nl.NeMoLogger(
...
explicit_log_dir='explicit_dir_test',
ckpt=checkpoint_callback,
...
)
resume = nl.AutoResume( ## handles resuming of the latest checkpoint in `explicit_dir_test`
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
)
llm.train(
model=model,
data=datamodule,
trainer=trainer,
log=nemo_logger,
resume=resume,
tokenizer='data',
optim=opt,
)