如何训练、评估和微调 n-gram 语言模型
目录

如何训练、评估和微调 n-gram 语言模型#
语言建模返回单词序列上的概率分布。除了为单词序列分配概率外,语言模型还为给定单词(或单词序列)在单词序列之后出现的可能性分配概率。
句子:all of a sudden I notice three guys standing on the sidewalk 将比句子:on guys all I of notice sidewalk three a sudden standing the 获得更高的语言模型评分。
正如最近的研究表明,在大型语料库(即大型数据集)上训练的语言模型可以显着提高 ASR 系统的准确性。
n-gram 语言模型#
主要有两种类型的语言模型
n-gram 语言模型:这些模型使用 n-gram 的频率来学习单词上的概率分布。n-gram 语言模型的两个优点是简单性和可扩展性 – 随着更大的
n
,模型可以存储更多上下文,并具有明确定义的空间-时间权衡,从而使小型实验能够有效地扩展。神经语言模型:这些模型使用不同类型的神经网络来建模单词上的概率分布,并且在语言建模能力方面已经超过了 n-gram 语言模型,但通常评估速度较慢。
在本教程中,我们将展示如何利用 NeMo 训练、评估和可选地微调 n-gram 语言模型。
先决条件#
确保满足以下先决条件。
您已访问并登录 NVIDIA NGC。有关分步说明,请参阅 NGC 入门指南。
您已安装 Kaggle API。有关分步说明,请参阅此安装和验证 Kaggle API。
使用 KenLM 和 NeMo 训练和微调 LM#
安装和设置 NeMo#
从源代码克隆和安装 NeMo。
## Install NeMo
BRANCH = 'main'
!git clone https://github.com/NVIDIA/NeMo.git
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]
安装和设置 KenLM#
从源代码安装 KenLM。
!apt install libeigen3-dev
!git clone https://github.com/kpu/kenlm.git
!cd kenlm && mkdir build && cd build && cmake .. && make -j
!pip3 install git+https://github.com/kpu/kenlm.git
!pip3 install git+https://github.com/flashlight/text.git
准备数据集#
LibriSpeech LM 标准化数据集#
在本教程中,我们使用 LibriSpeech LM 数据集的标准化版本来训练我们的 n-gram 语言模型。LibriSpeech LM 数据集的标准化版本可在此处获得。
训练数据在此处公开提供,可以直接下载。
下载数据集#
# Set the path to a folder where you want your data and results to be saved.
DATA_DOWNLOAD_DIR="content/datasets"
RESULTS_DIR="content/results"
MODELS_DIR="content/models"
!mkdir -p $DATA_DOWNLOAD_DIR $RESULTS_DIR $MODELS_DIR
# Note: Ensure that wget and unzip utilities are available. If not, install them.
!wget 'https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz' -P $DATA_DOWNLOAD_DIR
# Extract the data
!gzip -dk $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt.gz
为了减少本教程所需的时间,我们减少了训练数据集的行数。可以随意修改使用的行数。
# Use a random 100,000 lines for training
!shuf -n 100000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt > $DATA_DOWNLOAD_DIR/reduced_training.txt
下载评估数据集#
#Note: This data can be used only with NVIDIA’s products or services for evaluation and benchmarking purposes.
!source ~/.bash_profile && ngc registry resource download-version --dest $DATA_DOWNLOAD_DIR nvidia/riva/healthcare_eval_set:1.0
生成基础语言模型#
KENLM_BASE="kenlm/build/bin/"
$KENLM_BASE/lmplz
必需参数
-o
:要估计的语言模型阶数。
可选参数
-S
:要使用的内存。这是一个数字,后跟单字符后缀:%
表示物理内存的百分比(在测量此项的平台上),b
表示字节,K
表示千字节,M
表示兆字节,依此类推,直到G
和T
。如果未给出后缀,则假定为千字节,以便与 GNU sort 兼容。未使用 sort 程序;命令行仅设计为兼容。-T
:临时文件位置。--discount_fallback
:Kneser-Ney 平滑折扣是从计数的计数估计的,包括单例的数量。
!$KENLM_BASE/lmplz -o 4 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/base_lm.arpa
$KENLM_BASE/build_binary
参数
-q
:概率的量化标志。例如,-q 8
存储 8 位概率-b
:回退权重的量化标志。例如,-b 7
存储 7 位回退-a
:要删除的最大位数,并使用偏移量表隐式存储,以减少内存占用。
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/base_lm.arpa $RESULTS_DIR/base_lm.bin
加载 ASR 模型#
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
从 NGC 下载英语 Conformer 模型#
!source ~/.bash_profile && ngc registry model download-version "nvidia/riva/speechtotext_en_us_conformer:trainable_v3.1" --dest $MODELS_DIR
为 Flashlight 解码器创建词汇表文件#
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/base_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
更新解码器类型、语言模型和词汇表#
device = torch.device('cuda')
asr_model = ASRModel.restore_from(f"{MODELS_DIR}/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo").to(device)
decoding_cfg = CTCDecodingConfig()
decoding_cfg.strategy = "flashlight"
decoding_cfg.beam.search_type = "flashlight"
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/base_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/base_lm.lexicon'
decoding_cfg.beam.beam_size = 32
decoding_cfg.beam.beam_alpha = 0.2
decoding_cfg.beam.beam_beta = 0.2
decoding_cfg.beam.flashlight_cfg.beam_size_token = 32
decoding_cfg.beam.flashlight_cfg.beam_threshold = 25.0
asr_model.change_decoding_strategy(decoding_cfg)
评估#
import json
import os
def transcribe_json(asr_model, json_path, output_json):
dataset_root = os.path.split(json_path)[0]
with open(json_path) as fin, open(output_json, 'w') as fout:
manifest = []
audios = []
for line in fin:
dt = json.loads(line.strip())
manifest.append(dt)
audios.append(dt['audio_filepath'].replace("/data", dataset_root))
transcripts = asr_model.transcribe(paths2audio_files=audios)
for i in range(len(transcripts)):
dt = {
'audio_filepath': manifest[i]['audio_filepath'],
'text': transcripts[i]
}
fout.write(json.dumps(dt)+"\n")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json")
计算词错误率#
!pip install jiwer
from jiwer import wer
import json
def calculate_wer(ground_truth_manifest, asr_transcript):
data ={}
ground_truths = []
predictions = []
with open(ground_truth_manifest) as file:
for line in file:
dt = json.loads(line)
data[dt['audio_filepath']] = dt['text']
with open(asr_transcript) as file:
for line in file:
dt = json.loads(line)
if dt['audio_filepath'] in data:
ground_truths.append(data[dt['audio_filepath']])
predictions.append(dt['text'])
return round(100*wer(ground_truths, predictions), 2)
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
微调和插值#
微调过程将继续使用先前训练的模型进行训练,方法是在新的领域数据上训练第二个模型,并将其与原始模型进行插值。微调要求原始模型在训练期间启用中间层。微调后的模型不能再次用于微调。
下载和处理领域数据(医疗保健)以进行 LM 微调#
为了在医疗保健领域进行微调,我们可以使用 Kaggle 数据集 PubMed 200k RCT:用于医学摘要中顺序句子分类的数据集。
此数据集可在此处获得。
按照说明安装和验证 Kaggle API。
注意:每位用户都有责任检查数据集的内容和适用的许可证,并确定它们是否适合预期用途。
!kaggle datasets download -d anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification
!unzip -d $DATA_DOWNLOAD_DIR 200000-abstracts-for-seq-sentence-classification.zip
# Perform basic text cleaning and generate domain data
import string,re
def clean_text(text):
text = re.sub(r"[^a-z' ]+", "", text.lower().strip())
text = ' '.join(text.split())
if len(text.split())> 5:
return text.strip()
# Using dev file since we want a small amount of finetuning data. For better text Normalization use NeMo [https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing]
with open(f'{DATA_DOWNLOAD_DIR}/20k_abstracts_numbers_with_@/dev.txt') as file, open(f'{DATA_DOWNLOAD_DIR}/domain_data_all.txt', 'w') as outfile:
for line in file:
if line.startswith("###") or not line.strip():
continue
_, text = line.strip().split('\t')
text = clean_text(text)
if text:
outfile.write(text+'\n')
# Picking top 10000 lines from dataset
!head -10000 $DATA_DOWNLOAD_DIR/domain_data_all.txt > $DATA_DOWNLOAD_DIR/domain_data.txt
微调过程将继续使用先前训练的模型进行训练,方法是在新的领域数据上训练第二个模型,并将其与原始模型进行插值。微调要求原始模型在训练期间启用中间层。微调后的模型不能再次用于微调。
为了使用 KenLM 微调 n-gram 语言模型,请执行以下步骤
生成中间 ARPA#
# Base LM
!mkdir base_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate base_intermediate/inter < $DATA_DOWNLOAD_DIR/reduced_training.txt
# Healthcare LM
!mkdir healthcare_intermediate
!$KENLM_BASE/lmplz -o 4 --intermediate healthcare_intermediate/inter < $DATA_DOWNLOAD_DIR/domain_data.txt
插值#
插值权重可以通过以下方式传递
-w 0.6 0.4
这里,60% 的权重分配给基础 LM,40% 分配给领域。
!$KENLM_BASE/interpolate -w 0.6 0.4 -m base_intermediate/inter healthcare_intermediate/inter > $RESULTS_DIR/interpolated_lm_60-40.arpa
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/interpolated_lm_60-40.arpa $RESULTS_DIR/interpolated_lm_60-40.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/interpolated_lm_60-40.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
评估#
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/interpolated_lm_60-40.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/interpolated_lm_60-40.lexicon'
asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json")
计算 WER#
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Domain model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_interpolated_lm_60-40.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Domain model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_interpolated_lm_60-40.json"))
剪枝#
通过简单地将文本语料库传递给 kenLM 而生成的 LM 包含一些不常见的 n-gram(在语料库中),因此概率非常低。可以通过 pruning
删除此类 n-gram。
剪枝需要一些阈值,这些阈值可以通过 --prune
参数传递,后跟空格分隔的阈值,这些阈值指定生成 ARPA 时每个阶数的计数阈值
--prune 0 1 7 9
所有频率小于或等于指定阈值的 n-gram 都将被消除。
此处,频率 <= 1 的 2-gram、频率 <= 7 的 3-gram 和频率 <= 9 的 4-gram 将被消除。
剪枝程度和准确性之间存在权衡。高剪枝参数会减小语言模型的大小,但会牺牲模型准确性!
*注意:#
不支持 1-gram 的剪枝,1-gram 的阈值应始终为 0
!kenlm/build/bin/lmplz -o 4 --prune 0 1 7 9 < $DATA_DOWNLOAD_DIR/reduced_training.txt > $RESULTS_DIR/pruned_lm.arpa
!$KENLM_BASE/build_binary trie -q 8 -b 7 -a 256 $RESULTS_DIR/pruned_lm.arpa $RESULTS_DIR/pruned_lm.bin
!python NeMo/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py --arpa $RESULTS_DIR/pruned_lm.arpa --model $MODELS_DIR/speechtotext_en_us_conformer_vtrainable_v3.1/Conformer-CTC-L_spe-1024_en-US_Riva-ASR-SET-3.1.nemo --lower --dst $RESULTS_DIR
# Lets check the size of original LM and Pruned LM
!echo "Size of unpruned ARPA: $(du -h $RESULTS_DIR/base_lm.arpa | cut -f 1)"
!echo "Size of pruned ARPA: $(du -h $RESULTS_DIR/pruned_lm.arpa | cut -f 1)"
评估#
decoding_cfg.beam.kenlm_path = f'{RESULTS_DIR}/pruned_lm.bin'
decoding_cfg.beam.flashlight_cfg.lexicon_path=f'{RESULTS_DIR}/base_lm.lexicon'
asr_model.change_decoding_strategy(decoding_cfg)
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json")
transcribe_json(asr_model, f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json")
计算 WER#
print( "WER of base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_base_lm.json"))
print( "WER of Pruned base model on generic domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/general.json", f"{RESULTS_DIR}/general_pruned_lm.json"))
print("WER of base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_base_lm.json"))
print("WER of Pruned base model on Healthcare domain data", calculate_wer(f"{DATA_DOWNLOAD_DIR}/healthcare_eval_set_v1.0/healthcare.json", f"{RESULTS_DIR}/healthcare_pruned_lm.json"))