校准器
模块: polygraphy.backend.trt
- Calibrator(data_loader, cache=None, BaseClass=None, batch_size=None, quantile=None, regression_cutoff=None, algo=None)[source]
为 TensorRT 提供校准数据,以校准网络进行 INT8 推理。
- 参数:
data_loader (Sequence[OrderedDict[str, Union[numpy.ndarray, DeviceView, torch.Tensor, int]]]) –
一个生成器或可迭代对象,它产生一个字典,该字典将输入名称映射到 NumPy 数组、Polygraphy DeviceViews、PyTorch 张量或 GPU 指针。如果提供 NumPy 数组、DeviceViews 或 PyTorch 张量,校准器将尽可能检查数据类型和形状,以确保它们与模型预期的类型和形状匹配。
如果您事先不知道有关输入的详细信息,则可以访问数据加载器中的 input_metadata 属性,该属性将由 Polygraphy API(如
CreateConfig
和EngineFromNetwork
)设置为TensorMetadata
实例。请注意,这不适用于生成器或列表。校准批次的数量由数据加载器提供的项目数量控制。
cache (Union[str, file-like]) – 用于保存/加载校准缓存的路径或类似文件的对象。默认情况下,不保存校准缓存。
BaseClass (type) – 要继承的校准器的类型。默认为
trt.IInt8EntropyCalibrator2
。batch_size (int) – [已弃用] 数据加载器提供的每个批次的大小。
quantile (float) – 用于
trt.IInt8LegacyCalibrator
的分位数。对其他校准器类型无效。默认为 0.5。regression_cutoff (float) – 用于
trt.IInt8LegacyCalibrator
的回归截止值。对其他校准器类型无效。默认为 0.5。algo (trt.CalibrationAlgoType) – 用于
trt.IInt8Calibrator
的校准算法。对其他校准器类型无效。默认为trt.CalibrationAlgoType.MINMAX_CALIBRATION
。