重要提示
您正在查看 NeMo 2.0 文档。此版本引入了对 API 的重大更改和一个新的库 NeMo Run。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
NeMo SSL 集合 API#
模型类#
Mixin#
- class nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin
基类:
ASRAdapterModelMixin
ASRModuleMixin 是添加到 ASR 模型中的 mixin 类,目的是为了添加特定于 ASRModel 内部模块特定实例化的方法。
每个方法应首先检查子类中是否存在该模块,并在存在相应模块时支持其他功能。
- change_attention_model(
- self_attention_model: str | None = None,
- att_context_size: List[int] | None = None,
- update_config: bool = True,
如果编码器中提供了该功能,则更新 self_attention_model。
- 参数:
self_attention_model (str) –
注意力层和位置编码的类型
- “rel_pos”
相对位置嵌入和 Transformer-XL
- “rel_pos_local_attn”
相对位置嵌入和 Transformer-XL,使用重叠窗口的局部注意力。注意力上下文由 att_context_size 参数确定。
- “abs_pos”
绝对位置嵌入和 Transformer
如果提供 None,则 self_attention_model 不会更改。默认为 None。
att_context_size (List[int]) – 2 个整数的列表,分别对应左侧和右侧的注意力上下文大小,如果为 None,则保持不变。默认为 None。
update_config (bool) – 是否使用新的注意力模型更新配置。默认为 True。
- change_conv_asr_se_context_window(
- context_window: int,
- update_config: bool = True,
如果提供的模型包含 encoder,且该编码器是 ConvASREncoder 的实例,则更新 SqueezeExcitation 模块的上下文窗口。
- 参数:
context_window –
一个整数,表示用于计算上下文的输入时间帧数。每个时间帧对应于 STFT 特征的单个窗口步幅。
假设 window_stride = 0.01 秒,则上下文窗口 128 表示 128 * 0.01 秒的上下文来计算 Squeeze 步骤。
update_config – 是否使用新的上下文窗口更新配置。
- change_subsampling_conv_chunking_factor(
- subsampling_conv_chunking_factor: int,
- update_config: bool = True,
如果编码器中提供了该功能,则更新 conv_chunking_factor(整数)。默认为 1(自动)。如果 conv 子采样层中出现 OOM,则将其设置为 -1(禁用)或特定值(2 的幂)。
- 参数:
conv_chunking_factor (int)
- conformer_stream_step(
- processed_signal: torch.Tensor,
- processed_signal_length: torch.Tensor | None = None,
- cache_last_channel: torch.Tensor | None = None,
- cache_last_time: torch.Tensor | None = None,
- cache_last_channel_len: torch.Tensor | None = None,
- keep_all_outputs: bool = True,
- previous_hypotheses: List[Hypothesis] | None = None,
- previous_pred_out: torch.Tensor | None = None,
- drop_extra_pre_encoded: int | None = None,
- return_transcription: bool = True,
- return_log_probs: bool = False,
它模拟了具有缓存的前向步骤,用于流式传输目的。它支持编码器支持流式传输的 ASR 模型,如 Conformer。 :param processed_signal: 输入音频信号 :param processed_signal_length: 音频的长度 :param cache_last_channel: 最后一个通道层(如 MHA)的缓存张量 :param cache_last_channel_len: cache_last_channel 的长度 :param cache_last_time: 最后一个时间层(如卷积)的缓存张量 :param keep_all_outputs: 如果设置为 True,则不会删除编码器 streaming_cfg.valid_out_len 指定的额外输出 :param previous_hypotheses: RNNT 模型上一步的假设 :param previous_pred_out: CTC 模型上一步的预测输出 :param drop_extra_pre_encoded: 从下采样模块之后的输出开头删除的步数。如果在输入的左侧添加了额外的填充,则可以使用此参数。 :param return_transcription: 是否解码并返回转录。对于 Transducer 模型,无法禁用此参数。 :param return_log_probs: 是否返回对数概率,仅对 ctc 模型有效
- 返回:
来自解码器的贪婪预测 all_hyp_or_transcribed_texts:Transducer 模型的解码器假设和 CTC 模型的转录 cache_last_channel_next:已更新的最后一个通道层张量缓存,用于下一个流式传输步骤 cache_last_time_next:已更新的最后一个时间层张量缓存,用于下一个流式传输步骤 cache_last_channel_next_len:cache_last_channel 的已更新长度 best_hyp:Transducer 模型的最佳假设 log_probs:当前流式传输块的 logits 张量,仅在 return_log_probs=True 时返回 encoded_len:输出 log_probs + 历史块 log_probs 的长度,仅在 return_log_probs=True 时返回
- 返回类型:
greedy_predictions
- transcribe_simulate_cache_aware_streaming(
- paths2audio_files: List[str],
- batch_size: int = 4,
- logprobs: bool = False,
- return_hypotheses: bool = False,
- online_normalization: bool = False,
- 参数:
paths2audio_files – 音频文件路径的(列表)。
batch_size – (整数)推理期间使用的批大小。越大将导致更好的吞吐量性能,但会使用更多内存。
logprobs – (布尔值)传递 True 以获取对数概率而不是转录。
return_hypotheses – (布尔值)返回假设或文本。使用假设可以进行一些后处理,例如获取时间戳或重新评分
online_normalization – (布尔值)对每个块进行在线归一化。
- 返回:
与 paths2audio_files 顺序相同的转录列表(如果 logprobs 为 True,则为原始对数概率)
- class nemo.core.classes.mixins.access_mixins.AccessMixin
基类:
ABC
允许访问模型的中间层的输出
- property access_cfg
返回:跨所有访问 mixin 模块共享的全局访问配置。
- classmethod get_module_registry(module: torch.nn.Module)
从命名的子模块中提取所有注册表,返回字典,其中键是展平的模块名称,值是每个此类模块的内部注册表。
- register_accessible_tensor(name, tensor)
注册张量以供后续使用。
- reset_registry(registry_key: str | None = None)
重置所有命名子模块的注册表