重要提示

您正在查看 NeMo 2.0 文档。此版本对 API 和新的库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 中的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档

蒸馏#

NeMo 2.0 提供了一种易于启用的知识蒸馏 (KD) 训练设置。以下部分将介绍如何使用它。

知识蒸馏#

KD 涉及使用来自现有训练模型的信息来训练第二个(通常更小、更快)模型,从而将知识从一个模型“蒸馏”到另一个模型。

蒸馏有两个主要好处:比传统训练更快的收敛速度和更高的最终准确率。

在 NeMo 中,蒸馏由 NVIDIA TensorRT 模型优化器 (ModelOpt) 库启用——该库用于优化深度学习模型以在 GPU 上进行推理。

Logits 蒸馏过程#

logits 蒸馏过程包括以下步骤

  1. 加载检查点:加载学生模型和教师模型检查点。它们都必须支持相同的并行策略。

  2. 替换损失函数:将标准损失函数替换为输出 logits 之间的 KL 散度。

  3. 训练模型:对两个模型运行前向传播,但仅对学生模型执行反向传播。

  4. 保存检查点:仅保存学生模型检查点,使其可以像以前一样在以后使用。

局限性#

  • 仅支持基于 GPT 的 NeMo 2.0 检查点。

  • 目前仅启用 logits 对蒸馏。

示例#

以下示例展示了如何在给定任何 NeMo 2.0 检查点的情况下运行蒸馏脚本。

使用 NeMo-Run Recipes#

注意

先决条件:在继续之前,请按照 NeMo-Run 快速入门 中的示例操作,以首先熟悉 NeMo-Run。

import nemo_run as run
from nemo.collections.llm.distillation.recipe import distillation_recipe

recipe = distillation_recipe(
    student_model_path="path/to/student/nemo2-checkpoint/",
    teacher_model_path="path/to/teacher/nemo2-checkpoint/",
    dir="./distill_logs",  # Path to store logs and checkpoints
    name="distill_testrun",
    num_nodes=1,
    num_gpus_per_node=8,
)

# Override the configuration with desired components:
# recipe.data = run.Config(...)
# recipe.trainer = run.Config(...)
...

run.run(recipe)

将蒸馏脚本与 torchrun 或 Slurm 一起使用#

或者,您可以运行具有更精细自定义程度的传统脚本。

STUDENT_CKPT="path/to/student/nemo2-checkpoint/"
TEACHER_CKPT="path/to/teacher/nemo2-checkpoint/"

DATA_PATHS="1.0 path/to/tokenized/data"
SEQUENCE_LEN=8192
MICRO_BATCHSIZE=1
GLOBAL_BATCHSIZE=4
STEPS=100

TP=8
CP=1
PP=1
DP=1
NUM_NODES=1
DEVICES_PER_NODE=8

NAME="distill_testrun"
LOG_DIR="./distill_logs/"


launch_cmd="torchrun --nproc_per_node=$(($TP * $CP * $PP * $DP))"

${launch_cmd} scripts/llm/gpt_distillation.py \
    --name ${NAME} \
    --student_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}