重要提示
您正在查看 NeMo 2.0 文档。此版本为 API 和新库 NeMo Run 引入了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
指标#
- class nemo.collections.common.metrics.Perplexity(*args: Any, **kwargs: Any)#
基类:
Metric
此类计算输入最后一维分布的平均困惑度。它是 torch.distributions.Categorical.perplexity 方法的包装器。您必须向
update()
方法提供probs
或logits
。此类计算传递给update()
方法的probs
或logits
参数中的分布的困惑度,并对困惑度进行平均。所有工作进程之间的结果缩减通过 SUM 操作完成。有关指标使用说明,请参阅 PyTorch Lightning Metrics。- 参数:
dist_sync_on_step – 在每一步的
forward()
返回值之前,跨进程同步指标状态。process_group –
- 指定在其上调用同步的进程组。默认值:
None
(选择整个 世界)
- 指定在其上调用同步的进程组。默认值:
validate_args – 如果为
True
,则检查update()
方法参数的值。logits
不得包含 NaN,并且probs
的最后一维必须是有效的概率分布。
- compute()#
返回所有工作进程的困惑度,并将
perplexities_sum
和num_distributions
重置为 0。
- full_state_update = True#
- update(probs=None, logits=None)#
更新
perplexities_sum
和num_distributions
。 :param probs: 一个torch.Tensor
,其最内层维度是有效的概率分布。 :param logits: 一个不包含 NaN 的torch.Tensor
。