重要
您正在查看 NeMo 2.0 文档。此版本对 API 和一个新库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
检查点#
NeMo 中提供的预训练 SSL 检查点需要在下游任务上进一步微调。在 NeMo 中加载预训练检查点主要有两种方法
使用
restore_from()
方法加载本地检查点文件 (.nemo
),或者使用
from_pretrained()
方法从 NGC 下载并设置检查点。
请参阅以下部分,获取每种方法的说明和示例。
请注意,这些说明用于微调。要恢复未完成的训练实验,请使用实验管理器,并将 resume_if_exists
标志设置为 True
。
加载本地检查点#
NeMo 会自动保存以 .nemo
格式训练的模型的检查点。或者,要在任何时候手动保存模型,请执行 model.save_to(<checkpoint_path>.nemo)
。
如果存在您想要加载的本地 .nemo
检查点,请使用 restore_from()
方法
import nemo.collections.asr as nemo_asr
ssl_model = nemo_asr.models.<MODEL_BASE_CLASS>.restore_from(restore_path="<path/to/checkpoint/file.nemo>")
其中模型基类是原始检查点的 ASR 模型类,或通用的 ASRModel
类。
加载 NGC 预训练检查点#
SSL 集合包含在各种数据集上训练的多个模型的检查点。这些检查点可以通过 NGC NeMo 自动语音识别集合 获取。NGC 上的模型卡包含有关每个可用检查点的更多信息。
本页末尾的表格列出了可从 NGC 获取的 SSL 模型。这些模型可以通过 ASR 模型类中的 from_pretrained()
方法访问。通常,您可以使用以下格式的代码加载任何这些模型
import nemo.collections.asr as nemo_asr
ssl_model = nemo_asr.models.ASRModel.from_pretrained(model_name="<MODEL_NAME>")
其中 model_name
是下表中 “Model Name” 条目下的值。
例如,要加载 conformer Large SSL 检查点,请运行
ssl_model = nemo_asr.models.ASRModel.from_pretrained(model_name="ssl_en_conformer_large")
如果您需要访问特定的模型功能,也可以从特定的模型类(例如 Conformer 的 SpeechEncDecSelfSupervisedModel
)调用 from_pretrained()
。
如果您想以编程方式列出特定基类可用的模型,可以使用 list_available_models()
方法。
nemo_asr.models.<MODEL_BASE_CLASS>.list_available_models()
将 SSL 检查点加载到下游模型#
如上所示加载 SSL 检查点后,需要将其 state_dict
复制到下游模型以进行微调。
例如,要使用 EncDecRNNTBPEModel
为 ASR 下游任务加载 SSL 检查点,请运行
# define down-stream model
asr_model = nemo_asr.models.EncDecRNNTBPEModel(cfg=cfg.model, trainer=trainer)
# load ssl checkpoint
asr_model.load_state_dict(ssl_model.state_dict(), strict=False)
# discard ssl model
del ssl model
请参阅 SSL configs 以通过配置文件自动执行此操作。
在下游数据集上微调#
将 SSL 检查点加载到下游模型后,请参阅 教程 部分提供的多个 ASR 教程。其中大多数教程解释了如何在某些数据集上进行微调作为演示。
推理执行流程图#
在下游微调后准备您自己的推理脚本时,请按照推理执行流程图顺序进行正确的推理,该图可在 ASR 集合的 examples 目录 中找到。
SSL 模型#
以下是在 NeMo 中可用的所有 SSL 模型的列表。
模型名称 |
模型基类 |
模型卡 |
---|---|---|
ssl_en_conformer_large |
SpeechEncDecSelfSupervisedModel |
https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large |
ssl_en_conformer_xlarge |
SpeechEncDecSelfSupervisedModel |
https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge |