比较函数
模块: polygraphy.comparator
- class OutputCompareResult(passed, max_absdiff, max_reldiff, mean_absdiff, mean_reldiff, median_absdiff, median_reldiff, quantile_absdiff, quantile_reldiff)[source]
基类:
object
表示在两个运行器之间单次迭代的单个输出的比较结果。
记录比较期间收集的所需容差和其他统计信息。
- 参数:
passed (bool) – 错误是否在可接受的范围内。
max_absdiff (float) – 考虑输出等效所需的最小绝对容差。
max_reldiff (float) – 考虑输出等效所需的最小相对容差。
mean_absdiff (float) – 输出之间的平均绝对误差。
mean_reldiff (float) – 输出之间的平均相对误差。
median_absdiff (float) – 输出之间的中位数绝对误差。
median_reldiff (float) – 输出之间的中位数相对误差。
quantile_absdiff (float) – 输出之间的 q 分位数绝对误差。
quantile_reldiff (float) – 输出之间的 q 分位数相对误差。
- class CompareFunc[source]
基类:
object
提供可用于比较两个 IterationResult 的函数。
- static simple(check_shapes=None, rtol=None, atol=None, fail_fast=None, find_output_func=None, check_error_stat=None, infinities_compare_equal=None, save_heatmaps=None, show_heatmaps=None, save_error_metrics_plot=None, show_error_metrics_plot=None, error_quantile=None)[source]
创建一个比较两个 IterationResult 的函数,可用作
Comparator.compare_accuracy
中的 compare_func 参数。- 参数:
check_shapes (bool) – 形状是否必须完全匹配。 如果为 False,则此函数可能会在比较之前置换或重塑输出。 默认为 True。
rtol (Union[float, Dict[str, float]]) –
检查准确性时使用的相对容差。 这表示为第二组输出值的百分比。 例如,值 0.01 将检查第一组输出是否在第二组输出的 1% 之内。
可以基于每个输出来使用字典提供此值。 在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输出的默认容差。 默认为 1e-5。
atol (Union[float, Dict[str, float]]) – 检查准确性时使用的绝对容差。 可以基于每个输出来使用字典提供此值。 在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输出的默认容差。 默认为 1e-5。
fail_fast (bool) – 函数是否应在第一次失败后立即退出。 默认为 False。
find_output_func (Callable(str, int, IterationResult) -> List[str]) – 一个回调函数,给定另一个 IterationResult 的输出名称和索引,它从提供的 IterationResult 返回要比较的输出名称列表。 比较函数将始终迭代第一个 IterationResult 的输出名称,并期望来自第二个 IterationResult 的名称。 返回值 [] 或 None 表示应跳过该输出。
check_error_stat (Union[str, Dict[str, str]]) –
要检查的错误统计量。 可能的值包括
- ”elemwise”: 检查输出中的每个元素,以确定它是否超过指定的两个容差。
在此模式下显示的最小所需容差仅在仅设置一种类型的容差时适用。 由于检查的性质,当同时指定绝对/相对容差时,所需的最小容差可能会更低。
”max”: 检查最大绝对/相对误差是否超出各自的容差。 这是最严格的检查。
”mean” 检查平均绝对/相对误差是否超出各自的容差。
”median”: 检查中位数绝对/相对误差是否超出各自的容差。
”quantile”: 检查分位数绝对/相对误差是否超出各自的容差。
可以基于每个输出来使用字典提供此值。 在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输出的默认错误统计量。 默认为 “elemwise”。
infinities_compare_equal (bool) – 如果为 True,则输出中匹配的 +-inf 值具有 0 的 absdiff。 如果为 False,则输出中匹配的 +-inf 值具有 NaN 的 absdiff。 默认为 False。
save_heatmaps (str) – [实验性] 用于保存绝对误差和相对误差热图的图形的目录路径。 默认为 None。
show_heatmaps (bool) – [实验性] 是否显示绝对误差和相对误差的热图。 默认为 False。
save_error_metrics_plot (str) – [实验性] 用于保存错误指标图的目录路径。 默认为 None。
show_error_metrics_plot (bool) – [实验性] 是否显示错误指标图。
error_quantile (Union[float, Dict[str, float]]) – 检查准确性时要计算的分位数误差。 这表示为 [0, 1] 范围内的浮点数。 例如,error_quantile=0.5 是中位数。 默认为 0.99。
- 返回值:
一个可调用对象,它返回输出名称到 OutputCompareResult 的映射,指示相应的输出是否匹配。
- 返回类型:
Callable(IterationResult, IterationResult) -> OrderedDict[str, OutputCompareResult]
- static indices(index_tolerance=None, fail_fast=None)[source]
创建一个比较包含索引的两个 IterationResult 的函数,可用作
Comparator.compare_accuracy
中的 compare_func 参数。 这对于比较 Top-K 操作的输出非常有用。具有多个维度的输出被视为多批值。 例如,形状为 (3, 4, 5, 10) 的输出将被视为 60 批(3 x 4 x 5)的每批 10 个值。
- 参数:
index_tolerance (Union[int, Dict[str, int]]) –
比较索引时使用的容差。 这是一个整数,指示值之间的最大距离,超过此距离则被视为不匹配。 例如,考虑两个输出
output0 = [0, 1, 2] output1 = [1, 0, 2]
在索引容差为 0 的情况下,这将视为不匹配,因为 0 和 1 的位置在两个输出之间翻转。 但是,在索引容差为 1 的情况下,它将通过,因为不匹配的值仅相差 1 个位置。 如果输出改为
output0 = [0, 1, 2] output1 = [1, 2, 0]
那么我们将需要索引容差为 2,因为两个输出中的 0 值相差 2 个位置。
设置此值后,每批将忽略最后 ‘index_tolerance’ 个值。 例如,在索引容差为 1 的情况下,不考虑最后一个元素中的不匹配项。 如果与 Top-K 输出一起使用,您可以通过改为使用 Top-(K + index_tolerance) 来补偿这一点。
可以基于每个输出来使用字典提供此值。 在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输出的默认容差。
fail_fast (bool) – 函数是否应在第一次失败后立即退出。 默认为 False。
- 返回值:
一个可调用对象,它返回输出名称到 bool 的映射,指示相应的输出是否匹配。
- 返回类型:
Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]