校准器

模块: polygraphy.backend.trt

Calibrator(data_loader, cache=None, BaseClass=None, batch_size=None, quantile=None, regression_cutoff=None, algo=None)[source]

为 TensorRT 提供校准数据,以校准网络进行 INT8 推理。

参数:
  • data_loader (Sequence[OrderedDict[str, Union[numpy.ndarray, DeviceView, torch.Tensor, int]]]) –

    一个生成器或可迭代对象,它产生一个字典,该字典将输入名称映射到 NumPy 数组、Polygraphy DeviceViews、PyTorch 张量或 GPU 指针。如果提供 NumPy 数组、DeviceViews 或 PyTorch 张量,校准器将尽可能检查数据类型和形状,以确保它们与模型预期的类型和形状匹配。

    如果您事先不知道有关输入的详细信息,则可以访问数据加载器中的 input_metadata 属性,该属性将由 Polygraphy API(如 CreateConfigEngineFromNetwork)设置为 TensorMetadata 实例。请注意,这不适用于生成器或列表。

    校准批次的数量由数据加载器提供的项目数量控制。

  • cache (Union[str, file-like]) – 用于保存/加载校准缓存的路径或类似文件的对象。默认情况下,不保存校准缓存。

  • BaseClass (type) – 要继承的校准器的类型。默认为 trt.IInt8EntropyCalibrator2

  • batch_size (int) – [已弃用] 数据加载器提供的每个批次的大小。

  • quantile (float) – 用于 trt.IInt8LegacyCalibrator 的分位数。对其他校准器类型无效。默认为 0.5。

  • regression_cutoff (float) – 用于 trt.IInt8LegacyCalibrator 的回归截止值。对其他校准器类型无效。默认为 0.5。

  • algo (trt.CalibrationAlgoType) – 用于 trt.IInt8Calibrator 的校准算法。对其他校准器类型无效。默认为 trt.CalibrationAlgoType.MINMAX_CALIBRATION