评估 TTS 流水线#

在本教程中,我们将使用自动语音识别 (ASR) 从 TTS 合成数据生成转录,并使用字符错误率 (CER) 和词错误率 (WER) 将生成的转录与 groundtruth 进行比较。

通过将 ASR 生成的转录与 ground truth 转录进行比较,这些指标有助于发现音频、转录对之间的任何不一致之处。

本教程将包括

  • 下载 5 分钟的 hifiTTS 音频转录对。

  • 使用预训练的 NeMo ASR 模型为音频生成转录。

  • 计算 ground truth 转录和 ASR 生成的转录之间的字符错误率和词错误率。

下载数据#

在本教程中,我们将使用 Hi-Fi 多说话人英语 TTS (Hi-Fi TTS) 数据集的一小部分。您可以在此处阅读有关数据集的更多信息。我们将使用说话人 6097 作为目标说话人,并且本评估示例仅使用 5 分钟的音频子集。

!wget https://nemo-public.s3.us-east-2.amazonaws.com/6097_5_mins.tar.gz  # Contains 10MB of data
!tar -xzf 6097_5_mins.tar.gz
manifest_file = "6097_5_mins/manifest.json"
asr_pred = "6097_5_mins/asr_pred.json"


## Fix audiopaths in manifest.json
!sed -i 's,audio/,6097_5_mins/audio/,g' {manifest_file}

查看 manifest.json,我们看到一个标准的 NeMo json,其中包含文件路径、文本和持续时间。请确保 manifest.json 包含相对路径。

manifest 文件应如下所示

{"audio_filepath": "6097_5_mins/audio/presentpictureofnsw_02_mann_0532.wav", "text": "not to stop more than ten minutes by the way", "duration": 2.6, "text_no_preprocessing": "not to stop more than ten minutes by the way,", "text_normalized": "not to stop more than ten minutes by the way,"}
## Print the first line of manifest file.
!head -n 1 {manifest_file}

从 asr 合成文本。#

我们将需要 nemo toolkittranscribe_speech.py 来为我们的音频样本生成转录。

让我们安装 nemo toolkit

## Clone the latest NeMo.
!pip install nemo_toolkit['all']
!pip install --upgrade protobuf==3.20.0

现在下载 transcribe_speech.py

!wget https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/asr/transcribe_speech.py

使用 nemo 和 transcribe_speech.py 转录音频样本。这将稍后用于计算字符错误率和词错误率。

使用的模型是英语预训练的 conformer CTC ASR 模型

# Generate transcriptions
!python transcribe_speech.py \
    pretrained_name=stt_en_conformer_ctc_large \
    dataset_manifest={manifest_file} \
    output_filename={asr_pred} \
    batch_size=32 ++compute_langs=False cuda=0 amp=True

让我们看一下 asr_pred 文件,并确保我们有一个 text 字段和一个 pred_text 字段。asr_pred 文件应如下所示

{"audio_filepath": "6097_5_mins/audio/presentpictureofnsw_02_mann_0532.wav", "text": "not to stop more than ten minutes by the way", "duration": 2.6, "text_no_preprocessing": "not to stop more than ten minutes by the way,", "text_normalized": "not to stop more than ten minutes by the way,", "pred_text": "not to stop more than ten minutes by the way"}
!head -2 {asr_pred}

计算字符错误率 (CER)。#

编辑距离或 Levenshtein 距离是衡量两个字符串相似度的指标。该指标考虑了 ground truth 中的任何添加、删除或替换,以获得评估字符串。

使用 Levenshtein 距离来测量生成的转录和 ground truth 转录之间的 编辑距离字符错误率

字符错误率 是每个 ground truth 单词的编辑距离。它也可以解释为归一化编辑距离。

\(error\ rate = \frac{edit\ distance}{no\ of\ words\ in\ ground\ truth}\)

## Install the edit distance package
!pip install editdistance
## Install ndjson to read the asr_pred file
!pip install ndjson
import editdistance
import ndjson
import string

设置编辑距离和错误率的阈值。任何超过这些阈值的发音都需要调查。这些值可以进行微调。

distance_threshold = 5
cer_threshold = 0.5

由于 ASR 转录不包含任何标点符号,请在计算编辑距离之前从原始转录中删除标点符号。

## Punctuation translation dictionary.
punct_dict = str.maketrans('', '', string.punctuation)

f = open(asr_pred)
manifest = ndjson.load(f)
f.close()

计算编辑距离并打印所有发音,其中

  • error_rate > error_threshold

  • distance > distance_threshold

for line in manifest:
    transcript = line["text"].lower().translate(punct_dict)
    pred_text = line["pred_text"]
    try:
        distance = editdistance.eval(transcript, pred_text)
        cer = distance / len(transcript)
    except Exception as e:
        print(f"Got error: {e} for line: {line}")
        distance = 0
        cer = 0
    if distance > distance_threshold or cer > cer_threshold:
        print(f"Low confidence for {line}")

计算 WER(词错误率)#

现在我们已经列出了所有字符错误率高的句子,我们将列出所有词错误率高的句子。

顾名思义,词错误率衡量的是单词级别的错误,而不是 字符错误率 中的字符级别。此指标考虑了参考文本中的单词替换、单词插入和单词删除的数量。

计算公式为: $\( WER=\frac{S+I+D}{N} \)$ S = 替换次数
I = 插入次数
D = 删除次数
N = 参考文本中的单词总数

我们将使用 python 包 jiwer

## Install python package to calculate word error rate.
!pip install jiwer
from jiwer import wer

设置词错误率的阈值。任何 WER 大于此值的发音都需要调查。此值可以进行微调。

wer_threshold = 0.8 #Can be finetuned.

计算词错误率并打印所有词错误率高的发音

for line in manifest:
    transcript = line["text"].lower().translate(punct_dict)
    pred_text = line["pred_text"]
    try:
        error_rate = wer(transcript, pred_text)
    except Exception as e:
        print(f"Got error: {e} for line: {line}")
        error_rate = 0
    if error_rate > wer_threshold:
        print(f"Low confidence for file: {line['audio_filepath']} --- Transcript: {transcript} --- Predicted text: {pred_text} --- Word error rate: {error_rate}")

结论#

在本教程中,我们学习了如何计算编辑距离、字符错误率和词错误率。我们还学习了如何应用这些指标来评估音频、转录对的质量。

这些类型的指标可以用作冒烟测试和选择候选模型。但最终,衡量 TTS 模型质量的唯一方法是使用主观方法来评估和比较模型,例如 MOS 和 CMOS。