重要提示
您正在查看 NeMo 2.0 文档。此版本对 API 和新的库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。 有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
NeMo SSL 配置文件#
本页介绍 NeMo 配置文件设置,该设置特定于语音自监督预训练集合中的模型。 有关如何设置和运行所有 NeMo 模型通用的实验(例如,实验管理器和 PyTorch Lightning 训练器参数)的常规信息,请参阅 NeMo 模型 页面。
数据集配置#
自监督模型的数据集配置与标准 ASR 训练基本相同,此处已介绍。 主要区别在于,为了执行对比损失,我们需要为批次中所有 utterance 掩蔽等量的 patches。 这意味着我们希望避免单个批次内 duration 变化过大。 在 NeMo 中,您可以通过几种方式实现此目的
1) 最简单的方法是在数据集配置中使用 min_duration
参数,这将简单地丢弃所有低于指定长度的 utterance。 如果删除这些 utterance 不会显着影响数据集的总小时数,那么这是一个可行的选择。
2) 如果您的数据集包含许多长度各异的长 utterance(超过约 16 秒),那么您可能需要改用 random_segment
扰动,它将在运行时从完整样本中采样特定长度的 segment(低于提供的 segment 长度的样本将被填充)。 您可以通过将以下内容添加到数据集配置来启用此功能
augmentor:
random_segment:
prob: 1.0
duration_sec: 16 # specify the duration you want
3) 您还可以使用 bucketing 来确保批次内 utterance 长度相似。 请参阅 Bucketing 文档。
SSL 训练和验证配置的示例如下所示
model:
train_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 8
pin_memory: false
use_start_end_token: true
trim_silence: false
max_duration: 16.7
min_duration: 8.0
# tarred datasets
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "synced_randomized"
bucketing_batch_size: null
validation_ds:
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 8
pin_memory: true
use_start_end_token: false
min_duration: 8.0
预处理器配置#
预处理器有助于计算 MFCC 或梅尔频谱图特征,这些特征作为模型的输入。 有关如何编写此部分的详细信息,请参阅 预处理器配置
增强配置#
对于自监督预训练,我们建议使用 MaskedPatchAugmentation
类进行频谱图掩蔽。 这种增强将 utterance 分成固定大小的 patches,然后掩蔽固定数量/比例的 patches。 您还可以添加 freq_masks
和 freq_width
以将掩蔽应用于频带。
如果您使用对比损失,并且负样本仅从同一 utterance 中的掩蔽步骤中采样,请确保每个 utterance 中的掩蔽步骤总数足够大,以满足采样的负样本数量。 例如,如果您使用 4 倍步幅并想要采样 100 个负样本,那么您将需要超过 400 个掩蔽步骤。 如果您使用默认的 patch_size
48,那么这意味着您需要将 mask_patches
设置为至少 9。 当使用总 patches 数量的一部分而不是固定数量时,您需要确保样本的最小 duration 足够大,以满足要采样的负样本数量。
spec_augment:
_target_: nemo.collections.asr.modules.MaskedPatchAugmentation
patch_size: 48 # size of a single patch
mask_patches: 0.5 # fraction of patches to mask (can be fixed int amount instead)
freq_masks: 3 # Cut three frequency bands
freq_width: 20 # ... of width 20 at maximum
模型架构配置#
每个配置文件都应描述实验中使用的模型架构。 对于自监督预训练,我们通常会训练模型的编码器,然后在微调中重复使用它,因此编码器的配置方式可以与 ASR 模型相同。 请注意,任何 ASR 模型编码器都可以与任何可用的预训练方法一起使用,但是,在模型大小相同的情况下,我们发现使用 Conformer 时可以获得最佳的下游结果。
与编码器不同,解码器和相应的损失函数将特定于自监督预训练,并且足够小,您可以在将模型转移到微调时丢弃它们。
我们可以使用的最基本的预训练方法是让模型解决对比任务(这是 wav2vec 2.0 [SSL-MODELS1] 中使用的方法)。 我们可以为步幅为 4 倍的编码器按以下方式定义相应的解码器和损失配置。
decoder_out: 128
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
feat_out: ${model.decoder_out}
stride_layers: 0
# if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
# be added to the decoder (here stride and combine_time_steps are both 4)
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 4 # how many spectrogram time steps are used for one target/representation for contrastive task
quantized_targets: true # should quantizer or linear layer be used
codebook_size: 300 # size of a single codebook for quantizer
num_groups: 2 # number of codebooks to use for quantizer
num_negatives: 100 # number of sampled negatives for each target
sample_from_same_utterance_only: true # should negatives be sampled only from the same utterance
sample_from_non_masked: false # should negatives be sampled from non-masked steps
请注意,在上面的示例中,我们将输入频谱图中的 4 个步骤组合成一个用于损失的“token”,这对应于编码器步幅 4 倍。 我们可能希望为 “combine_time_steps” 和编码器步幅使用不同的值。 在这种情况下,我们将需要在解码器中添加步幅层以匹配步幅。 我们可以为步幅为 8 倍的 Citrinet 编码器使用以下示例配置。 为了从步幅 8 倍变为 4 倍,我们在解码器中使用单个 stride_layer
,并将 stride_transpose
设置为 True。
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.model_defaults.enc_final}
feat_hidden: 128
feat_out: ${model.model_defaults.decoder_out_channels}
stride_layers: 1
#if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
#be added to the decoder (here stride is 8 and combine_time_steps is 4, so 1 stride layer is added)
non_stride_layers: 0
stride_tranpose: true # whether to use transposed convolution for stride layers or not
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: *n_mels
proj_dim: ${model.model_defaults.decoder_out_channels}
combine_time_steps: 4 #how many spectrogram time steps are used for one target/representation for contrastive task
quantized_targets: false #should quantizer or linear layer be used
sample_from_same_utterance_only: true #should negatives be sampled only from the same utterance
sample_from_non_masked: false #should negatives be sampled from non-masked steps
将对比损失与其他损失(例如掩蔽语言建模 (mlm) 损失(类似于 W2V-Bert [SSL-MODELS2] 的方法))结合起来可能是有益的。 为了做到这一点,我们可以指定一个 loss_list
,而不是在配置中指定单个 decoder
和 loss
,loss_list
可以包含任意数量的相应解码器和损失。 对于每个解码器-损失对,我们可以指定一个单独的命名子配置,其中包含以下字段
decoder
- 解码器配置,指定目标类和参数。loss
- 相应的损失配置,指定目标类和参数。loss_alpha
- 此损失的乘数(默认为 1.0)。targets_from_loss
- 此参数指定我们应该从中提取标签的对比损失。 如果标签未在您的清单中显示,则对于任何需要标签的损失,此参数都是必需的。transpose_encoded
- 此参数用于在将编码特征传递到此损失之前,选择性地转置编码特征。start_step
- 我们应该开始使用此解码器+损失的训练步骤。output_from_layer
- 此参数可用于指定我们应该从中提取编码特征以传递到此解码器的层的名称。 如果未指定或设置为 null,则使用最终编码器层。
以下是对比损失+mlm 损失组合的 loss_list 示例,其中 mlm 损失使用来自对比损失的量化模块的目标。
decoder_out: 128
loss_list:
contrastive:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
# features in hidden layer of decoder
feat_out: ${model.decoder_out}
stride_layers: 0
# if loss.combine_time_steps is less than the encoder stride, then a corresponding amount of stride_layers needs to
# be added to the decoder (here stride and combine_time_steps are both 4)
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 4 # how many spectrogram time steps are used for one target/representation for contrastive task
quantized_targets: true # should quantizer or linear layer be used
# (quantizer is required to extract pseudo-labels for other losses)
codebook_size: 300
num_groups: 2
sample_from_same_utterance_only: true # should negatives be sampled only from the same utterance
sample_from_non_masked: false # should negatives be sampled from non-masked steps
mlm:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
# set this to be equal to codebook_size^groups in the contrastive loss
loss:
_target_: nemo.collections.asr.losses.MLMLoss
combine_time_steps: 4
targets_from_loss: "contrastive"
# since this loss requires targets, we can either get them from a manifest or from a quantized contrastive loss
loss_alpha: 1000.
# multiplier applied to this loss relative to others
transpose_encoded: false
# transposing input may be necessary depending on which layer is used as input to decoder
start_step: 0
# determines what global step this loss starts being used at;
# this can be set to a higher number if your training is long enough,
# which may increase early training stability
output_from_layer: null
# if we wanted to use outputs from non-final encoder layer as input to this decoder,
# the layer name should be specified here
我们还可以使用其他需要标签而不是 mlm 的损失,例如 ctc 或 rnnt 损失。 由于这些损失与 mlm 不同,不需要我们的目标与我们的步骤直接对齐,我们可能还希望将对比损失的 reduce_ids
参数设置为 true,以将任何连续等效 id 序列转换为该 id 的单个实例。
由对比损失+ctc 损失组成的 loss_list
示例如下所示
decoder_out: 128
loss_list:
contr:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
feat_out: ${model.decoder_out}
stride_layers: 0
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 4
quantized_targets: true
codebook_size: 300
num_groups: 2
sample_from_same_utterance_only: true
sample_from_non_masked: false
reduce_ids: true
ctc:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
loss:
_target_: nemo.collections.asr.losses.CTCLossForSSL
num_classes: 90000
targets_from_loss: "contr"
start_step: 3000
对比损失+rnnt 的示例如下所示
decoder_out: 128
loss_list:
contr:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
feat_out: ${model.decoder_out}
stride_layers: 0
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 4
quantized_targets: true
codebook_size: 24
sample_from_same_utterance_only: true
sample_from_non_masked: false
reduce_ids: true
rnnt:
decoder:
_target_: nemo.collections.asr.modules.RNNTDecoderJointSSL
decoder:
_target_: nemo.collections.asr.modules.RNNTDecoder
normalization_mode: null # Currently only null is supported for export.
random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
vocab_size: 576
prednet:
pred_hidden: 640
pred_rnn_layers: 1
t_max: null
dropout: 0.1
joint:
_target_: nemo.collections.asr.modules.RNNTJoint
log_softmax: null # 'null' would set it automatically according to CPU/GPU device
preserve_memory: false # dramatically slows down training, but might preserve some memory
experimental_fuse_loss_wer: false
jointnet:
encoder_hidden: 512
pred_hidden: 640
joint_hidden: 640
activation: "relu"
dropout: 0.1
num_classes: 576
loss:
_target_: nemo.collections.asr.losses.RNNTLossForSSL
num_classes: 576
targets_from_loss: "contr"
start_step: 1000
我们还可以使用多个损失函数,这些损失函数使用来自编码器不同中间层的特征作为输入 [SSL-MODELS3]。 在以下配置示例中,我们使用对比损失 + 三个不同的 mlm 损失,它们分别使用来自第 6 层、第 12 层和最后一层的编码器输出。
decoder_out: 128
loss_list:
contr:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoderReconstruction
feat_in: ${model.encoder.d_model}
feat_hidden: 128
feat_out: ${model.decoder_out}
stride_layers: 0
non_stride_layers: 0
loss:
_target_: nemo.collections.asr.losses.ContrastiveLoss
in_dim: ${model.preprocessor.features}
proj_dim: ${model.decoder_out}
combine_time_steps: 4
quantized_targets: true
codebook_size: 300
sample_from_same_utterance_only: true
sample_from_non_masked: false
loss_alpha: 5.
mlm:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
loss:
_target_: nemo.collections.asr.losses.MLMLoss
combine_time_steps: 4
targets_from_loss: "contr"
loss_alpha: 1000.
mlm_2:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
loss:
_target_: nemo.collections.asr.losses.MLMLoss
combine_time_steps: 4
targets_from_loss: "contr"
loss_alpha: 300.
output_from_layer: "layers.5"
transpose_encoded: true
mlm_3:
decoder:
_target_: nemo.collections.asr.modules.ConvASRDecoder
feat_in: ${model.encoder.d_model}
num_classes: 90000
loss:
_target_: nemo.collections.asr.losses.MLMLoss
combine_time_steps: 4
targets_from_loss: "contr"
loss_alpha: 300.
output_from_layer: "layers.11"
transpose_encoded: true
参考#
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed 和 Michael Auli。 Wav2vec 2.0:语音表示自监督学习框架。 2020. URL: https://arxiv.org/abs/2006.11477, doi:10.48550/ARXIV.2006.11477。
Yu-An Chung, Yu Zhang, Wei Han, Chung-Cheng Chiu, James Qin, Ruoming Pang 和 Yonghui Wu。 W2v-bert:结合对比学习和掩蔽语言建模进行自监督语音预训练。 2021. URL: https://arxiv.org/abs/2108.06209, doi:10.48550/ARXIV.2108.06209。
Chengyi Wang, Yu Wu, Sanyuan Chen, Shujie Liu, Jinyu Li, Yao Qian 和 Zhenglu Yang。 用于中间层监督的语音识别自监督学习。 2021. URL: https://arxiv.org/abs/2112.08778, doi:10.48550/ARXIV.2112.08778。