加载器
模块:polygraphy.tools.args
- class TrtOnnxFlagArgs[source]
基类:
BaseArgs
ONNX-TRT 解析器标志:设置 TensorRT ONNX 解析器的标志
依赖于
TrtConfigArgs:如果在 VC/HC 模式下应自动启用 NATIVE_INSTANCENORM
- class TrtLoadNetworkArgs(allow_custom_outputs: bool | None = None, allow_onnx_loading: bool | None = None, allow_tensor_formats: bool | None = None)[source]
基类:
BaseArgs
TensorRT 网络加载:加载 TensorRT 网络。
依赖于
ModelArgs
TrtLoadPluginsArgs
OnnxLoadArgs:如果 allow_onnx_loading == True
TrtOnnxFlagArgs
- 参数:
allow_custom_outputs (bool) – 是否允许标记自定义输出张量。默认为 True。
allow_onnx_loading (bool) – 是否允许从 ONNX 模型解析网络。默认为 True。
allow_tensor_formats (bool) – 是否允许设置张量格式和相关选项。默认为 False。
- parse_impl(args)[source]
解析命令行参数并填充以下属性
- outputs
输出张量的名称。
- 类型:
List[str]
- exclude_outputs
应取消标记为输出的张量的名称。
- 类型:
List[str]
- trt_network_func_name
自定义网络脚本中创建网络的函数的名称。
- 类型:
str
- layer_precisions
层名称映射到其所需的计算精度,以字符串形式。
- 类型:
Dict[str, str]
- tensor_datatypes
张量名称映射到其所需的数据类型,以字符串形式。
- 类型:
Dict[str, str]
- tensor_formats
张量名称映射到其所需的格式,以字符串形式。
- 类型:
Dict[str, List[str]]
- postprocess_scripts
指定网络后处理脚本的路径和后处理函数名称的元组列表。
- 类型:
List[Tuple[str, str]]
- strongly_typed
是否将网络标记为强类型。
- 类型:
bool
- mark_debug
应标记为调试张量的张量的名称。
- 类型:
List[str]
- class TrtSaveEngineBytesArgs(output_opt: str | None = None, output_short_opt: str | None = None)[source]
基类:
BaseArgs
TensorRT 引擎保存:保存 TensorRT 引擎。
保存序列化的引擎。由于从 TensorRT 8.6 开始,版本兼容引擎在初始反序列化后无法重新序列化,因此应优先选择此方法而不是 TrtSaveEngineArgs()。
- 参数:
output_opt (str) – 输出路径选项的名称。默认为“output”。使用
False
值禁用该选项。output_short_opt (str) – 用于输出路径的短选项。默认为“-o”。使用
False
值禁用短选项。
- TrtSaveEngineArgs
别名
Deprecated
- class TrtLoadEngineBytesArgs(allow_saving: bool | None = None)[source]
基类:
BaseArgs
TensorRT 引擎:加载或构建 TensorRT 引擎。
依赖于
ModelArgs
TrtLoadPluginsArgs
TrtLoadNetworkArgs:如果需要支持构建引擎
TrtConfigArgs:如果需要支持构建引擎
TrtSaveEngineBytesArgs:如果 allow_saving == True
- 参数:
allow_saving (bool) – 是否允许保存加载的模型。默认为 False。
- class TrtLoadEngineArgs[source]
基类:
BaseArgs
TensorRT 引擎:加载 TensorRT 引擎。
依赖于
TrtLoadEngineBytesArgs
TrtLoadPluginsArgs
- class TrtConfigArgs(precision_constraints_default: bool | None = None, allow_random_data_calib_warning: bool | None = None, allow_custom_input_shapes: bool | None = None, allow_engine_capability: bool | None = None, allow_tensor_formats: bool | None = None)[source]
基类:
BaseArgs
TensorRT 构建器配置:创建 TensorRT BuilderConfig。
依赖于
DataLoaderArgs
ModelArgs:如果 allow_custom_input_shapes == True
- 参数:
precision_constraints_default (str) – 用于精度约束选项的默认值。默认为“none”。
allow_random_data_calib_warning (bool) – 当随机生成的数据用于校准时是否发出警告。默认为 True。
allow_custom_input_shapes (bool) – 当随机生成数据时是否允许自定义输入形状。默认为 True。
allow_engine_capability (bool) – 是否允许指定引擎能力。默认为 False。
allow_tensor_formats (bool) – 是否允许设置张量格式和相关选项。默认为 False。
- parse_impl(args)[source]
解析命令行参数并填充以下属性
- profile_dicts
配置文件列表,其中每个配置文件都是一个字典,将输入名称映射到 (min, opt, max) 形状的元组。
- 类型:
List[OrderedDict[str, Tuple[Shape]]]
- tf32
是否启用 TF32。
- 类型:
bool
- fp16
是否启用 FP16。
- 类型:
bool
- bf16
是否启用 BF16。
- 类型:
bool
- fp8
是否启用 FP8。
- 类型:
bool
- int8
是否启用 INT8。
- 类型:
bool
- precision_constraints
要应用的精度约束。
- 类型:
str
- restricted
是否在构建器中启用安全范围检查。
- 类型:
bool
- calibration_cache
校准缓存的路径。
- 类型:
str
- calibration_base_class
用于校准器的基类的名称。
- 类型:
str
- sparse_weights
是否启用稀疏权重。
- 类型:
bool
- load_timing_cache
从中加载时序缓存的路径。
- 类型:
str
- load_tactics
从中加载策略重放文件的路径。
- 类型:
str
- save_tactics
保存策略重放文件的路径。
- 类型:
str
- tactic_sources
表示要启用的策略源的枚举值的字符串。
- 类型:
List[str]
- trt_config_script
自定义 TensorRT 配置脚本的路径。
- 类型:
str
- trt_config_func_name
自定义配置脚本中创建配置的函数的名称。
- 类型:
str
- trt_config_postprocess_script
TensorRT 配置后处理脚本的路径。
- 类型:
str
- trt_config_postprocess_func_name
配置后处理脚本中应用后处理的函数的名称。
- 类型:
str
- use_dla
是否启用 DLA。
- 类型:
bool
- allow_gpu_fallback
启用 DLA 时是否允许 GPU 回退。
- 类型:
bool
- memory_pool_limits
表示内存池枚举值的字符串到内存限制(以字节为单位)的映射。
- 类型:
Dict[str, int]
- engine_capability
所需的引擎能力。
- 类型:
str
- direct_io
是否禁止在具有用户指定格式的网络输入/输出张量处重新格式化层。
- 类型:
bool
- preview_features
要启用的预览功能的名称。
- 类型:
List[str]
- refittable
引擎是否应该是可重构的。
- 类型:
bool
- strip_plan
构建引擎时是否应剥离可重构权重。
- 类型:
bool
- builder_optimization_level
构建器优化级别。
- 类型:
int
- hardware_compatibility_level
表示硬件兼容性级别枚举值的字符串。
- 类型:
str
- profiling_verbosity
一个字符串,表示性能分析详细程度枚举值。
- 类型:
str
- max_aux_streams
TensorRT 允许使用的最大辅助流数量。
- 类型:
int
- version_compatible
是否构建向前兼容 TensorRT 版本。
- 类型:
bool
- exclude_lean_runtime
是否从版本兼容计划中排除精简运行时。
- 类型:
bool
- quantization_flags
要启用的量化标志的名称。
- 类型:
List[str]
- error_on_timing_cache_miss
当正被计时的策略在时序缓存中不存在时,是否发出错误。
- 类型:
bool
- disable_compilation_cache
是否禁用缓存 JIT 编译的代码。
- 类型:
bool
- weight_streaming
是否为 TensorRT 引擎启用权重流式传输。
- 类型:
bool
- runtime_platform
一个字符串,表示目标运行时平台枚举值。
- 类型:
str
- tiling_optimization_level
平铺优化级别。
- 类型:
str