TensorRT 10.8.0
NvOnnxParser.h
前往此文件的文档。
1/*
2 * 版权所有 (c) 1993-2024, NVIDIA CORPORATION。保留所有权利。
3 *
4 * 特此授权,免费授予任何获得
5 * 本软件和相关文档文件(以下简称“软件”)副本的人员
处理本软件不受限制,包括但不限于
使用、复制、修改、合并、发布、分发、再许可的权利,
和/或出售本软件副本的权利,并允许向其
提供本软件的人员这样做,但须遵守以下条件:
10 *
* 上述版权声明和本许可声明应包含在
* 本软件的所有副本或主要部分中。
13 *
* 本软件按“现状”提供,不作任何明示或暗示的担保,包括但不限于
* 对适销性、特定用途的适用性和非侵权性的保证。
* 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他
* 责任负责,无论是在合同、侵权行为或其他方面引起的诉讼中,
* 因本软件或本软件的使用或其他
* 与本软件相关的交易而产生。
20 * DEALINGS IN THE SOFTWARE.
21 */
22
23#ifndef NV_ONNX_PARSER_H
24#define NV_ONNX_PARSER_H
25
26#include "NvInfer.h"
27#include <stddef.h>
28#include <string>
29#include <vector>
30
36
37#define NV_ONNX_PARSER_MAJOR 0
38#define NV_ONNX_PARSER_MINOR 1
39#define NV_ONNX_PARSER_PATCH 0
40
41static constexpr int32_t NV_ONNX_PARSER_VERSION
43
50typedef std::pair<std::vector<size_t>, bool> SubGraph_t;
51
58typedef std::vector<SubGraph_t> SubGraphCollection_t;
59
65namespace nvonnxparser
66{
67
68template <typename T>
69constexpr inline int32_t EnumMax() noexcept;
70
76enum class ErrorCode : int
77{
kSUCCESS = 0, // 成功
kINTERNAL_ERROR = 1, // 内部错误
kMEM_ALLOC_FAILED = 2, // 内存分配失败
kMODEL_DESERIALIZE_FAILED = 3, // 模型反序列化失败
kINVALID_VALUE = 4, // 无效值
kINVALID_GRAPH = 5, // 无效图
kINVALID_NODE = 6, // 无效节点
kUNSUPPORTED_GRAPH = 7, // 不支持的图
kUNSUPPORTED_NODE = 8, // 不支持的节点
kUNSUPPORTED_NODE_ATTR = 9, // 不支持的节点属性
kUNSUPPORTED_NODE_INPUT = 10, // 不支持的节点输入
kUNSUPPORTED_NODE_DATATYPE = 11, // 不支持的节点数据类型
kUNSUPPORTED_NODE_DYNAMIC = 12, // 不支持的节点动态性
kUNSUPPORTED_NODE_SHAPE = 13, // 不支持的节点形状
kREFIT_FAILED = 14 // 重拟合失败
93};
94
100template <>
101constexpr inline int32_t EnumMax<ErrorCode>() noexcept
102{
103 return 14;
104}
105
112using OnnxParserFlags = uint32_t;
113
114enum class OnnxParserFlag : int32_t
115{
kNATIVE_INSTANCENORM = 0 // 本地 InstanceNorm
121};
122
128template <>
129constexpr inline int32_t EnumMax<OnnxParserFlag>() noexcept
130{
131 return 1;
132}
133
140{
141public
145 virtual ErrorCode code() const = 0;
149 virtual char const* desc() const = 0;
153 virtual char const* file() const = 0;
157 virtual int line() const = 0;
161 virtual char const* func() const = 0;
165 virtual int node() const = 0;
169 virtual char const* nodeName() const = 0;
173 virtual char const* nodeOperator() const = 0;
179 virtual char const* const* localFunctionStack() const = 0;
// 堆栈大小为 0。
185 virtual int32_t localFunctionStackSize() const = 0;
186
187protected
188 virtual ~IParserError() {}
189};
190
202{
203public
218 virtual bool parse(
219 void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) noexcept
220 = 0;
221
232 virtual bool parseFromFile(const char* onnxModelFile, int verbosity) noexcept = 0;
233
247 TRT_DEPRECATED virtual bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
248 SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) noexcept = 0;
249
261 void const* serialized_onnx_model, size_t serialized_onnx_model_size) noexcept
262 = 0;
263
273 virtual bool supportsOperator(const char* op_name) const noexcept = 0;
274
281 virtual int getNbErrors() const noexcept = 0;
282
288 virtual IParserError const* getError(int index) const noexcept = 0;
289
295 virtual void clearErrors() noexcept = 0;
296
297 virtual ~IParser() noexcept = default;
298
315 virtual char const* const* getUsedVCPluginLibraries(int64_t& nbPluginLibs) const noexcept = 0;
316
328 virtual void setFlags(OnnxParserFlags onnxParserFlags) noexcept = 0;
329
337 virtual OnnxParserFlags getFlags() const noexcept = 0;
338
346 virtual void clearFlag(OnnxParserFlag onnxParserFlag) noexcept = 0;
347
355 virtual void setFlag(OnnxParserFlag onnxParserFlag) noexcept = 0;
356
364 virtual bool getFlag(OnnxParserFlag onnxParserFlag) const noexcept = 0;
365
378 virtual nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i) noexcept = 0;
379
392 virtual bool supportsModelV2(
393 void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath = nullptr) noexcept = 0;
394
402 virtual int64_t getNbSubgraphs() noexcept = 0;
403
412 virtual bool isSubgraphSupported(int64_t const index) noexcept = 0;
413
424 virtual int64_t* getSubgraphNodes(int64_t const index, int64_t& subgraphLength) noexcept = 0;
425};
426
435{
436public
449 virtual bool refitFromBytes(
450 void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath = nullptr) noexcept
451 = 0;
452
463 virtual bool refitFromFile(char const* onnxModelFile) noexcept = 0;
464
470 virtual int32_t getNbErrors() const noexcept = 0;
471
477 virtual IParserError const* getError(int32_t index) const noexcept = 0;
478
484 virtual void clearErrors() = 0;
485
486 virtual ~IParserRefitter() noexcept = default;
487};
488
} // namespace nvonnxparser
490
491extern "C" TENSORRTAPI void* createNvOnnxParser_INTERNAL(void* network, void* logger, int version) noexcept;
493 void* refitter, void* logger, int32_t version) noexcept;
494extern "C" TENSORRTAPI int getNvOnnxParserVersion() noexcept;
495
496namespace nvonnxparser
497{
498
499namespace
500{
501
518{
519 try
520 {
521 return static_cast<IParser*>(createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION));
522 }
523 catch (std::exception& e)
524 {
526 }
527
528 return nullptr;
529}
530
541{
542 try
543 {
544 return static_cast<IParserRefitter*>(
545 createNvOnnxParserRefitter_INTERNAL(&refitter, &logger, NV_ONNX_PARSER_VERSION));
546 }
547 catch (std::exception& e)
548 {
550 }
551
552 return nullptr;
553}
554
} // namespace
556
} // namespace nvonnxparser
558
559#endif // NV_ONNX_PARSER_H
#define TENSORRTAPI
定义: NvInferRuntimeBase.h:59
#define TRT_DEPRECATED
定义: NvInferRuntimeBase.h:45
TENSORRTAPI void * createNvOnnxParserRefitter_INTERNAL(void *refitter, void *logger, int32_t version) noexcept
std::vector< SubGraph_t > SubGraphCollection_t
包含从 ONNX 图中划分出的所有 SubGraph_t 的数据结构。
定义: NvOnnxParser.h:58
TENSORRTAPI void * createNvOnnxParser_INTERNAL(void *network, void *logger, int version) noexcept
TENSORRTAPI int getNvOnnxParserVersion() noexcept
#define NV_ONNX_PARSER_PATCH
定义: NvOnnxParser.h:39
#define NV_ONNX_PARSER_MINOR
定义: NvOnnxParser.h:38
std::pair< std::vector< size_t >, bool > SubGraph_t
包含 ONNX 图中一组节点的解析能力的数据结构。
定义: NvOnnxParser.h:50
#define NV_ONNX_PARSER_MAJOR
定义: NvOnnxParser.h:37
构建器、重拟合器和运行时的应用程序实现的日志记录接口。
定义: NvInferRuntime.h:1540
@ kINTERNAL_ERROR
发生内部错误。执行无法恢复。
用于输入到构建器的网络定义。
定义: NvInfer.h:6628
更新引擎中的权重。
定义: NvInferRuntime.h:2136
包含有关错误信息的对象
定义: NvOnnxParser.h:140
virtual char const * nodeOperator() const =0
发生错误的节点操作的名称。
virtual char const * func() const =0
发生错误的源函数。
virtual int line() const =0
发生错误的源代码行。
virtual char const * desc() const =0
错误描述。
virtual ~IParserError()
定义: NvOnnxParser.h:188
virtual ErrorCode code() const =0
错误代码。
virtual char const * nodeName() const =0
发生错误的节点的名称。
virtual char const *const * localFunctionStack() const =0
本地函数名称列表,从顶层向下,构成当前堆栈跟踪,位于...
virtual int node() const =0
发生错误的 ONNX 模型节点索引。
virtual char const * file() const =0
发生错误的源文件。
virtual int32_t localFunctionStackSize() const =0
发生错误时本地函数堆栈的大小。顶级节点...
用于将 ONNX 模型解析为 TensorRT 网络定义的对象
定义: NvOnnxParser.h:202
virtual IParserError const * getError(int index) const noexcept=0
获取先前调用 parse 期间发生的错误。
virtual char const *const * getUsedVCPluginLibraries(int64_t &nbPluginLibs) const noexcept=0
查询插件库,这些插件库是实现版本兼容的分析器所用操作所必需的 ...
virtual void setFlag(OnnxParserFlag onnxParserFlag) noexcept=0
设置单个解析器标志。
virtual bool parseWithWeightDescriptors(void const *serialized_onnx_model, size_t serialized_onnx_model_size) noexcept=0
将序列化的 ONNX 模型解析到 TensorRT 网络中,并考虑用户提供的权重。
virtual bool supportsOperator(const char *op_name) const noexcept=0
返回解析器是否可能支持指定的运算符。
virtual bool supportsModelV2(void const *serializedOnnxModel, size_t serializedOnnxModelSize, char const *modelPath=nullptr) noexcept=0
检查 TensorRT 是否支持特定的 ONNX 模型。如果函数返回 True,...
virtual void clearErrors() noexcept=0
清除先前解析调用的错误。
virtual TRT_DEPRECATED bool supportsModel(void const *serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t &sub_graph_collection, const char *model_path=nullptr) noexcept=0
检查 TensorRT 是否支持特定的 ONNX 模型。如果函数返回 True,...
virtual bool getFlag(OnnxParserFlag onnxParserFlag) const noexcept=0
如果设置了解析器标志,则返回 true。
virtual void clearFlag(OnnxParserFlag onnxParserFlag) noexcept=0
清除解析器标志。
virtual OnnxParserFlags getFlags() const noexcept=0
获取解析器标志。默认为 0。
virtual bool parseFromFile(const char *onnxModelFile, int verbosity) noexcept=0
解析 onnx 模型文件,该文件可以是二进制 protobuf 或文本 onnx 模型,调用内部解析方法...
virtual bool isSubgraphSupported(int64_t const index) noexcept=0
返回是否支持子图。在调用 supportsModelV2 之前调用此函数会导...
virtual int64_t getNbSubgraphs() noexcept=0
获取子图的数量。在调用 supportsModelV2 之前调用此函数会导致未定义...
virtual int64_t * getSubgraphNodes(int64_t const index, int64_t &subgraphLength) noexcept=0
获取指定子图的节点。在调用 supportsModelV2 之前调用此函数会导致...
virtual void setFlags(OnnxParserFlags onnxParserFlags) noexcept=0
设置解析器标志。
virtual bool parse(void const *serialized_onnx_model, size_t serialized_onnx_model_size, const char *model_path=nullptr) noexcept=0
将序列化的 ONNX 模型解析到 TensorRT 网络中。此方法具有非常有限的诊断功能....
virtual int getNbErrors() const noexcept=0
获取先前解析调用期间发生的错误数。
virtual nvinfer1::ITensor const * getLayerOutputTensor(char const *name, int64_t i) noexcept=0
返回 ONNX 层 “name” 的第 i 个输出 ITensor 对象。
一个旨在从 ONNX 模型重新拟合权重的接口。
Definition: NvOnnxParser.h:435
virtual bool refitFromBytes(void const *serializedOnnxModel, size_t serializedOnnxModelSize, char const *modelPath=nullptr) noexcept=0
从内存加载序列化的 ONNX 模型并执行权重重新拟合。
TensorRT API 版本 1 命名空间。
IParser * createParser(nvinfer1::INetworkDefinition &network, nvinfer1::ILogger &logger) noexcept
创建一个新的解析器对象。
Definition: NvOnnxParser.h:517
IParserRefitter * createParserRefitter(nvinfer1::IRefitter &refitter, nvinfer1::ILogger &logger) noexcept
创建一个新的 ONNX refitter 对象。
Definition: NvOnnxParser.h:540
TensorRT ONNX 解析器 API 命名空间。
Definition: NvOnnxConfig.h:24
uint32_t OnnxParserFlags
表示一个或多个使用二进制 OR 运算的 OnnxParserFlag 值,例如,1U << OnnxParserFlag::...
Definition: NvOnnxParser.h:112
constexpr int32_t EnumMax() noexcept
ErrorCode
解析器或 refitter 可能返回的错误类型。
Definition: NvOnnxParser.h:77
OnnxParserFlag
Definition: NvOnnxParser.h:115
constexpr int32_t EnumMax< ErrorCode >() noexcept
Definition: NvOnnxParser.h:101
constexpr int32_t EnumMax< OnnxParserFlag >() noexcept
Definition: NvOnnxParser.h:129

  版权所有 © 2024 NVIDIA Corporation
  隐私政策 | 管理我的隐私 | 请勿出售或分享我的数据 | 服务条款 | 辅助功能 | 公司政策 | 产品安全 | 联系方式