Comparator
模块: polygraphy.comparator
- class Comparator[源代码]
基类:
object
比较推理输出。
- static run(runners, data_loader=None, warm_up=None, use_subprocess=None, subprocess_timeout=None, subprocess_polling_interval=None, save_inputs_path=None)[源代码]
顺序运行提供的 runners。
- 参数:
runners (List[BaseRunner]) – 要运行的 runners 列表。
data_loader (Sequence[OrderedDict[str, numpy.ndarray]]) –
一个生成器或可迭代对象,它产生一个字典,该字典将输入名称映射到输入 numpy 缓冲区。在最简单的情况下,这可以是一个 List[Dict[str, numpy.ndarray]] 。
如果您不提前了解有关输入的详细信息,则可以在数据加载器中访问 input_metadata 属性,该属性将由此函数设置为 TensorMetadata 实例。请注意,这不适用于生成器或列表。
此函数运行的迭代次数由数据加载器提供的项目数控制。
默认为 DataLoader 的实例。
warm_up (int) – 在计时之前为每个 runner 执行的预热运行次数。默认为 0。
use_subprocess (bool) – 是否应在子进程中运行每个 runner。这允许每个 runner 独占访问 GPU。当使用子进程时,runners 和加载器永远不会被修改。
subprocess_timeout (int) – 子进程自动终止之前的超时时间。这对于处理永不终止的进程很有用。None 值禁用超时。默认为 None。
subprocess_polling_interval (int) – 轮询间隔(以秒为单位),用于检查子进程是否已完成或崩溃。在极少数情况下,当启用子进程时省略此参数可能会导致此函数在子进程崩溃时无限期挂起。值 0 禁用轮询。默认为 30 秒。
save_inputs_path (str) – 用于保存推理期间使用的输入的路径。这将包括由提供的数据加载器生成的所有输入,并将保存为 JSON List[Dict[str, numpy.ndarray]]。
- 返回:
runner 名称到其推理结果的映射。runners 的顺序在此映射中保留。
- 返回类型:
- static postprocess(run_results, postprocess_func)[源代码]
将后处理应用于提供的运行结果中的所有输出。这是一个方便的函数,避免了手动迭代 run_results 字典的需要。
- 参数:
run_results (RunResults) – Comparator.run() 的结果。
postprocess_func (Callable(IterationResult) -> IterationResult) – 要应用于每个
IterationResult
的函数。
- 返回:
更新后的运行结果。
- 返回类型:
- static compare_accuracy(run_results, fail_fast=False, comparisons=None, compare_func=None)[源代码]
- 参数:
run_results (RunResults) – Comparator.run() 的结果
fail_fast (bool) – 是否在第一次失败后退出
comparisons (List[Tuple[int, int]]) – 要执行的比较,由 runner 索引指定。例如,[(0, 1), (1, 2)] 将比较第一个 runner 和第二个 runner,以及第二个 runner 和第三个 runner。默认情况下,这会将每个结果与后续结果进行比较。
compare_func (Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]) – 一个接受两个 IterationResults 的函数,并返回一个字典,该字典将输出名称映射到一个布尔值(或任何可转换为布尔值的值),指示输出是否匹配。保证此函数的参数顺序与 comparisons 中包含的元组的顺序相同。
- 返回:
比较结果的摘要。键的顺序(即 runner 对)保证与 comparisons 的顺序相同。有关更多详细信息,请参阅 AccuracyResult 文档字符串(例如,help(AccuracyResult))。
- 返回类型:
- static validate(run_results, check_inf=None, check_nan=None, fail_fast=None)[源代码]
检查输出有效性。
- 参数:
run_results (Dict[str, List[IterationResult]]) – Comparator.run() 的结果。
check_inf (bool) – 是否在遇到 Infs 时失败。默认为 False。
check_nan (bool) – 是否在遇到 NaNs 时失败。默认为 True。
fail_fast (bool) – 是否在第一个无效值后失败。默认为 False。
- 返回:
如果所有输出都有效,则为 True,否则为 False。
- 返回类型:
bool