如何训练、评估和微调 n-gram 语言模型#

语言建模返回单词序列上的概率分布。除了为单词序列分配概率外,语言模型还为给定单词(或单词序列)在单词序列之后出现的可能性分配概率。

句子:all of a sudden I notice three guys standing on the sidewalk 将比句子:on guys all I of notice sidewalk three a sudden standing the 获得更高的语言模型评分。

正如最近的研究表明,在大型语料库(即大型数据集)上训练的语言模型可以显着提高 ASR 系统的准确性。

n-gram 语言模型#

主要有两种类型的语言模型

  • n-gram 语言模型:这些模型使用 n-gram 的频率来学习单词上的概率分布。n-gram 语言模型的两个优点是简单性和可扩展性 – 随着更大的 n,模型可以存储更多上下文,并具有明确定义的空间-时间权衡,从而使小型实验能够有效地扩展。

  • 神经语言模型:这些模型使用不同类型的神经网络来建模单词上的概率分布,并且在语言建模能力方面已经超过了 n-gram 语言模型,但通常评估速度较慢。

在本教程中,我们将展示如何利用 NeMo 训练、评估和可选地微调 n-gram 语言模型

先决条件#

确保满足以下先决条件。

  1. 您已访问并登录 NVIDIA NGC。有关分步说明,请参阅 NGC 入门指南

  2. 您已安装 Kaggle API。有关分步说明,请参阅此安装和验证 Kaggle API


使用 KenLM 和 NeMo 训练和微调 LM#

安装和设置 NeMo#

从源代码克隆和安装 NeMo。

## Install NeMo
BRANCH = 'main'
!git clone https://github.com/NVIDIA/NeMo.git
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

安装和设置 KenLM#

从源代码安装 KenLM。

!apt install libeigen3-dev
!git clone https://github.com/kpu/kenlm.git
!cd kenlm && mkdir build && cd build && cmake .. && make -j

!pip3 install git+https://github.com/kpu/kenlm.git
!pip3 install git+https://github.com/flashlight/text.git

安装和设置 NGC CLI#

要安装和设置 NGC CLI,请按照此处的说明进行操作。


准备数据集#

LibriSpeech LM 标准化数据集#

在本教程中,我们使用 LibriSpeech LM 数据集的标准化版本来训练我们的 n-gram 语言模型。LibriSpeech LM 数据集的标准化版本可在此处获得。
训练数据在此处公开提供,可以直接下载。

下载数据集#

# Set the path to a folder where you want your data and results to be saved.
DATA_DOWNLOAD_DIR="content/datasets"
RESULTS_DIR="content/results"
MODELS_DIR="content/models"

!mkdir -p $DATA_DOWNLOAD_DIR $RESULTS_DIR $MODELS_DIR
# Note: Ensure that wget and unzip utilities are available. If not, install them.
!wget 'https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz' -P $DATA_DOWNLOAD_DIR

# Extract the data
!gzip -dk $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt.gz

为了减少本教程所需的时间,我们减少了训练数据集的行数。可以随意修改使用的行数。

# Use a random 100,000 lines for training
!shuf -n 100000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt  > $DATA_DOWNLOAD_DIR/reduced_training.txt

下载评估数据集#

#Note: This data can be used only with NVIDIA’s products or services for evaluation and benchmarking purposes.
!source ~/.bash_profile && ngc registry resource  download-version --dest $DATA_DOWNLOAD_DIR nvidia/riva/healthcare_eval_set:1.0

生成基础语言模型#

KENLM_BASE="kenlm/build/bin/"

$KENLM_BASE/lmplz
必需参数

  • -o:要估计的语言模型阶数。

可选参数

  • -S:要使用的内存。这是一个数字,后跟单字符后缀:% 表示物理内存的百分比(在测量此项的平台上),b 表示字节,K 表示千字节,M 表示兆字节,依此类推,直到 GT。如果未给出后缀,则假定为千字节,以便与 GNU sort 兼容。未使用 sort 程序;命令行仅设计为兼容。

  • -T:临时文件位置。

  • --discount_fallback:Kneser-Ney 平滑折扣是从计数的计数估计的,包括单例的数量。

!$KENLM_BASE/lmplz -o 4 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/base_lm.arpa

$KENLM_BASE/build_binary
参数

  • -q:概率的量化标志。例如,-q 8 存储 8 位概率

  • -b:回退权重的量化标志。例如,-b 7 存储 7 位回退

  • -a:要删除的最大位数,并使用偏移量表隐式存储,以减少内存占用。

!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/base_lm.arpa $RESULTS_DIR/base_lm.bin

加载 ASR 模型#

import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.metrics.wer import CTCDecodingConfig

从 NGC 下载英语 Conformer 模型#

!source ~/.bash_profile && ngc registry model download-version "nvidia/riva/speechtotext_en_us_conformer:trainable_v3.1" --dest $MODELS_DIR

为 Flashlight 解码器创建词汇表文件#

!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/base_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR

更新解码器类型、语言模型和词汇表#

device = torch.device('cuda')
asr_model = ASRModel.restore_from(f"{MODELS_DIR}/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo").to(device)


decoding_cfg = CTCDecodingConfig()

decoding_cfg.strategy = "flashlight"
decoding_cfg.beam.search_type = "flashlight"
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/base_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/base_lm.lexicon'
decoding_cfg.beam.beam_size = 32
decoding_cfg.beam.beam_alpha = 0.2
decoding_cfg.beam.beam_beta = 0.2
decoding_cfg.beam.flashlight_cfg.beam_size_token = 32
decoding_cfg.beam.flashlight_cfg.beam_threshold = 25.0

asr_model.change_decoding_strategy(decoding_cfg)

评估#

import json
import os

def transcribe_json(asr_model, json_path, output_json):
    dataset_root = os.path.split(json_path)[0]
    with open(json_path) as fin, open(output_json, 'w') as fout:
        manifest = []
        audios = []
        for line in fin:
            dt = json.loads(line.strip())
            manifest.append(dt)
            audios.append(dt['audio_filepath'].replace("/data", dataset_root))
        transcripts = asr_model.transcribe(paths2audio_files=audios)
        for i in range(len(transcripts)):
            dt = {
                'audio_filepath': manifest[i]['audio_filepath'],
                'text': transcripts[i]
            }
            fout.write(json.dumps(dt)+"\n")

transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json")

计算词错误率#

!pip install jiwer
from jiwer import wer
import json

def calculate_wer(ground_truth_manifest, asr_transcript):
    data ={}
    ground_truths = []
    predictions = []
    with open(ground_truth_manifest) as file:
        for line in file:
            dt = json.loads(line)
            data[dt['audio_filepath']] = dt['text']
    with open(asr_transcript) as file:
        for line in file:
            dt = json.loads(line)
            if dt['audio_filepath'] in data:
                ground_truths.append(data[dt['audio_filepath']])
                predictions.append(dt['text'])
    return round(100*wer(ground_truths, predictions), 2)

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))

微调和插值#

微调过程将继续使用先前训练的模型进行训练,方法是在新的领域数据上训练第二个模型,并将其与原始模型进行插值。微调要求原始模型在训练期间启用中间层。微调后的模型不能再次用于微调。

下载和处理领域数据(医疗保健)以进行 LM 微调#

为了在医疗保健领域进行微调,我们可以使用 Kaggle 数据集 PubMed 200k RCT:用于医学摘要中顺序句子分类的数据集
此数据集可在此处获得。
按照说明安装和验证 Kaggle API
注意:每位用户都有责任检查数据集的内容和适用的许可证,并确定它们是否适合预期用途。

!kaggle datasets download -d anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification
!unzip -d $DATA_DOWNLOAD_DIR 200000-abstracts-for-seq-sentence-classification.zip
# Perform basic text cleaning and generate domain data
import string,re
def clean_text(text):
    text = re.sub(r"[^a-z' ]+", "", text.lower().strip())
    text = ' '.join(text.split())
    if len(text.split())> 5:
        return text.strip()
    
# Using dev file since we want a small amount of finetuning data. For better text Normalization use NeMo [https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing]
with open(f'{DATA_DOWNLOAD_DIR}/20k_abstracts_numbers_with_@/dev.txt') as file, open(f'{DATA_DOWNLOAD_DIR}/domain_data_all.txt', 'w') as outfile:
    for line in file:
        if line.startswith("###") or not line.strip():
            continue
        _, text = line.strip().split('\t')
        text = clean_text(text)
        if text:
            outfile.write(text+'\n')
            
# Picking top 10000 lines from dataset
!head -10000 $DATA_DOWNLOAD_DIR/domain_data_all.txt > $DATA_DOWNLOAD_DIR/domain_data.txt

微调过程将继续使用先前训练的模型进行训练,方法是在新的领域数据上训练第二个模型,并将其与原始模型进行插值。微调要求原始模型在训练期间启用中间层。微调后的模型不能再次用于微调。

为了使用 KenLM 微调 n-gram 语言模型,请执行以下步骤

  1. 为基础 LM 和领域 LM 生成中间 ARPA 文件.

  2. 使用合适的权重插值基础 LM 和领域 LM.

生成中间 ARPA#

# Base LM
!mkdir base_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate base_intermediate/inter < $DATA_DOWNLOAD_DIR/reduced_training.txt

# Healthcare LM
!mkdir healthcare_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate healthcare_intermediate/inter < $DATA_DOWNLOAD_DIR/domain_data.txt

插值#

插值权重可以通过以下方式传递

 -w 0.6 0.4

这里,60% 的权重分配给基础 LM,40% 分配给领域。

!$KENLM_BASE/interpolate -w 0.6 0.4 -m base_intermediate/inter healthcare_intermediate/inter > $RESULTS_DIR/interpolated_lm_60-40.arpa
!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/interpolated_lm_60-40.arpa $RESULTS_DIR/interpolated_lm_60-40.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/interpolated_lm_60-40.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR

评估#

decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/interpolated_lm_60-40.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/interpolated_lm_60-40.lexicon'

asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json")

计算 WER#

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Domain model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Domain model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json"))

剪枝#

通过简单地将文本语料库传递给 kenLM 而生成的 LM 包含一些不常见的 n-gram(在语料库中),因此概率非常低。可以通过 pruning 删除此类 n-gram。
剪枝需要一些阈值,这些阈值可以通过 --prune 参数传递,后跟空格分隔的阈值,这些阈值指定生成 ARPA 时每个阶数的计数阈值

 --prune 0 1 7 9

所有频率小于或等于指定阈值的 n-gram 都将被消除。
此处,频率 <= 1 的 2-gram、频率 <= 7 的 3-gram 和频率 <= 9 的 4-gram 将被消除。
剪枝程度和准确性之间存在权衡。高剪枝参数会减小语言模型的大小,但会牺牲模型准确性!

*注意:#

不支持 1-gram 的剪枝,1-gram 的阈值应始终为 0

!kenlm/build/bin/lmplz -o 4 --prune 0 1 7 9  < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/pruned_lm.arpa
!$KENLM_BASE/build_binary  trie -q 8 -b 7 -a 256 $RESULTS_DIR/pruned_lm.arpa $RESULTS_DIR/pruned_lm.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/pruned_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
# Lets check the size of original LM and Pruned LM
!echo "Size of unpruned ARPA: $(du -h $RESULTS_DIR/base_lm.arpa | cut -f 1)"
!echo "Size of pruned ARPA: $(du -h $RESULTS_DIR/pruned_lm.arpa | cut -f 1)"

评估#

decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/pruned_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/base_lm.lexicon'

asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json")

计算 WER#

print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Pruned base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Pruned base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json"))