TensorRT 10.8.0
NvInferRuntimePlugin.h
前往此文件的文档。
1/*
2 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * 根据 Apache 许可证 2.0 版本(“许可证”)获得许可;
6 * 除非遵守许可证,否则您不得使用此文件。
7 * 您可以在以下位置获取许可证副本:
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * 除非适用法律要求或书面同意,否则根据许可证
12 * 分发的软件是按“现状”基础分发的,
13 * 不附带任何形式的明示或暗示的保证或条件。
14 * 有关许可证下权限和限制的
15 * 具体条款,请参阅许可证。
16 */
17
18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
20
21#define NV_INFER_INTERNAL_INCLUDE 1
22#include "NvInferPluginBase.h"
23#undef NV_INFER_INTERNAL_INCLUDE
24
33
39namespace nvinfer1
40{
41
42enum class TensorFormat : int32_t;
43namespace v_1_0
44{
45class IGpuAllocator;
46}
48
55
59static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT = 0x40;
60
73{
81 float scale;
82};
83
91enum class PluginVersion : uint8_t
92{
94 kV2 = 0,
96 kV2_EXT = 1,
98 kV2_IOEXT = 2,
100 kV2_DYNAMICEXT = 3,
102 kV2_DYNAMICEXT_PYTHON = kPLUGIN_VERSION_PYTHON_BIT | 3
103};
104
110enum class PluginCreatorVersion : int32_t
111{
113 kV1 = 0,
115 kV1_PYTHON = kPLUGIN_VERSION_PYTHON_BIT
116};
117
133{
134public
147 virtual int32_t getTensorRTVersion() const noexcept
148 {
149 return NV_TENSORRT_VERSION;
150 }
151
165 virtual AsciiChar const* getPluginType() const noexcept = 0;
166
180 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
181
195 virtual int32_t getNbOutputs() const noexcept = 0;
196
220 virtual Dims getOutputDimensions(int32_t index, Dims const* inputs, int32_t nbInputDims) noexcept = 0;
221
245 virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0;
246
279 virtual void configureWithFormat(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
280 DataType type, PluginFormat format, int32_t maxBatchSize) noexcept
281 = 0;
282
294 virtual int32_t initialize() noexcept = 0;
295
309 virtual void terminate() noexcept = 0;
310
328 virtual size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
329
351 virtual int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void* workspace,
352 cudaStream_t stream) noexcept
353 = 0;
354
365 virtual size_t getSerializationSize() const noexcept = 0;
366
380 virtual void serialize(void* buffer) const noexcept = 0;
381
390 virtual void destroy() noexcept = 0;
391
409 virtual IPluginV2* clone() const noexcept = 0;
410
425 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
426
438 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
439
440 // @cond SuppressDoxyWarnings
441 IPluginV2() = default;
442 virtual ~IPluginV2() noexcept = default;
443// @endcond
444
445protected
446// @cond SuppressDoxyWarnings
447 IPluginV2(IPluginV2 const&) = default;
448 IPluginV2(IPluginV2&&) = default;
449 IPluginV2& operator=(IPluginV2 const&) & = default;
450 IPluginV2& operator=(IPluginV2&&) & = default;
451// @endcond
452};
453
468{
469public
494 int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept
495 = 0;
496
518 int32_t outputIndex, bool const* inputIsBroadcasted, int32_t nbInputs) const noexcept
519 = 0;
520
546 TRT_DEPRECATED virtual bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept = 0;
547
585 virtual void configurePlugin(Dims const* inputDims, int32_t nbInputs, Dims const* outputDims, int32_t nbOutputs,
586 DataType const* inputTypes, DataType const* outputTypes, bool const* inputIsBroadcast,
587 bool const* outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept
588 = 0;
589
590 IPluginV2Ext() = default;
591 ~IPluginV2Ext() override = default;
592
627 virtual void attachToContext(
628 cudnnContext* /*cudnn*/, cublasContext* /*cublas*/, IGpuAllocator* /*allocator*/) noexcept
629 {
630 }
631
645 virtual void detachFromContext() noexcept {}
646
661 IPluginV2Ext* clone() const noexcept override = 0;
662
663protected
664 // @cond SuppressDoxyWarnings
665 IPluginV2Ext(IPluginV2Ext const&) = default;
666 IPluginV2Ext(IPluginV2Ext&&) = default;
667 IPluginV2Ext& operator=(IPluginV2Ext const&) & = default;
668 IPluginV2Ext& operator=(IPluginV2Ext&&) & = default;
669// @endcond
670
686 int32_t getTensorRTVersion() const noexcept override
687 {
688 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_EXT) << 24U)
689 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
690 }
691
698 void configureWithFormat(Dims const* /*inputDims*/, int32_t /*nbInputs*/, Dims const* /*outputDims*/,
699 int32_t /*nbOutputs*/, DataType /*type*/, PluginFormat /*format*/, int32_t /*maxBatchSize*/) noexcept override
700 {
701 }
702};
703
717{
718public
736 virtual void configurePlugin(
737 PluginTensorDesc const* in, int32_t nbInput, PluginTensorDesc const* out, int32_t nbOutput) noexcept
738 = 0;
739
778 int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept
779 = 0;
780
781 // @cond SuppressDoxyWarnings
782 IPluginV2IOExt() = default;
783 ~IPluginV2IOExt() override = default;
784// @endcond
785
786protected
787// @cond SuppressDoxyWarnings
788 IPluginV2IOExt(IPluginV2IOExt const&) = default;
789 IPluginV2IOExt(IPluginV2IOExt&&) = default;
790 IPluginV2IOExt& operator=(IPluginV2IOExt const&) & = default;
791 IPluginV2IOExt& operator=(IPluginV2IOExt&&) & = default;
792// @endcond
793
805 int32_t getTensorRTVersion() const noexcept override
806 {
807 return static_cast<int32_t>((static_cast<uint32_t>(PluginVersion::kV2_IOEXT) << 24U)
808 | (static_cast<uint32_t>(NV_TENSORRT_VERSION) & 0xFFFFFFU));
809 }
810
811private
812 // 以下是过时的基类方法,不应实现或使用。
813
817 void configurePlugin(Dims const*, int32_t, Dims const*, int32_t, DataType const*, DataType const*, bool const*,
818 bool const*, PluginFormat, int32_t) noexcept final
819 {
820 }
821
825 bool supportsFormat(DataType, PluginFormat) const noexcept final
826 {
827 return false;
828 }
829};
830
831namespace v_1_0
832{
834{
835public
848 virtual AsciiChar const* getPluginName() const noexcept = 0;
849
862 virtual AsciiChar const* getPluginVersion() const noexcept = 0;
863
875 virtual PluginFieldCollection const* getFieldNames() noexcept = 0;
876
889 virtual IPluginV2* createPlugin(AsciiChar const* name, PluginFieldCollection const* fc) noexcept = 0;
890
906 virtual IPluginV2* deserializePlugin(AsciiChar const* name, void const* serialData, size_t serialLength) noexcept
907 = 0;
908
923 virtual void setPluginNamespace(AsciiChar const* pluginNamespace) noexcept = 0;
924
937 virtual AsciiChar const* getPluginNamespace() const noexcept = 0;
938
939 IPluginCreator() = default;
940 ~IPluginCreator() override = default;
941
942protected
943 // @cond SuppressDoxyWarnings
944 IPluginCreator(IPluginCreator const&) = default;
945 IPluginCreator(IPluginCreator&&) = default;
946 IPluginCreator& operator=(IPluginCreator const&) & = default;
947 IPluginCreator& operator=(IPluginCreator&&) & = default;
948 // @endcond
949public
953 InterfaceInfo getInterfaceInfo() const noexcept override
954 {
955 return InterfaceInfo{"PLUGIN CREATOR_V1", 1, 0};
956 }
957};
958} // namespace v_1_0
959
971
972} // namespace nvinfer1
973
974#endif // NV_INFER_RUNTIME_PLUGIN_H
#define NV_TENSORRT_VERSION
定义: NvInferRuntimeBase.h:91
#define TRT_DEPRECATED
定义: NvInferRuntimeBase.h:45
应用程序实现的类,用于控制 GPU 上的分配。
定义: NvInferRuntimeBase.h:203
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:468
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
如果插件可以使用跨批次广播而无需复制的输入张量,则返回 true。
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
派生类不得实现此方法。在 C++11 API 中,这将是 override final。
定义: NvInferRuntimePlugin.h:698
IPluginV2Ext * clone() const noexcept override=0
克隆插件对象。这也会复制内部插件参数并返回一个新插件...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
使用输入和输出数据类型配置层。
virtual void detachFromContext() noexcept
将插件对象从其执行上下文中分离。
定义: NvInferRuntimePlugin.h:645
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
如果输出张量跨批次广播,则返回 true。
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
将插件对象附加到执行上下文,并授予插件访问某些上下文资源的权限...
定义: NvInferRuntimePlugin.h:627
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
返回请求索引处插件输出的 DataType。
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:133
virtual AsciiChar const * getPluginType() const noexcept=0
返回插件类型。应与相应插件创建器返回的插件名称匹配。
virtual int32_t getTensorRTVersion() const noexcept
返回构建此插件的 API 版本。
定义: NvInferRuntimePlugin.h:147
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:717
int32_t getTensorRTVersion() const noexcept override
返回构建此插件的 API 版本。高位字节由 TensorRT 保留,并且...
定义: NvInferRuntimePlugin.h:805
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
配置层。
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
如果插件支持 pos 索引的输入/输出的格式和数据类型,则返回 true。
与 TRT 接口关联的版本信息。
定义: NvInferRuntimeBase.h:228
定义: NvInferRuntime.h:1608
定义: NvInferRuntimePlugin.h:834
virtual AsciiChar const * getPluginName() const noexcept=0
返回插件名称。
定义: NvInferPluginBase.h:193
TensorRT API 版本 1 的命名空间。
PluginCreatorVersion
用于标识插件创建器版本的枚举。
定义: NvInferRuntimePlugin.h:111
@ kV1_PYTHON
基于 IPluginCreator 的 Python 插件创建器。
v_1_0::IPluginCreator IPluginCreator
定义: NvInferRuntimePlugin.h:970
v_1_0::IGpuAllocator IGpuAllocator
定义: NvInferRuntime.h:1807
char_t AsciiChar
定义: NvInferRuntimeBase.h:105
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_EXT
IPluginV2Ext.
@ kV2_DYNAMICEXT_PYTHON
基于 IPluginV2DynamicExt 的 Python 插件。
DataType
权重和张量的类型。
定义: NvInferRuntimeBase.h:133
TensorFormat PluginFormat
PluginFormat 保留用于向后兼容。
定义: NvInferRuntimePlugin.h:54
TensorFormat
输入/输出张量的格式。
定义: NvInferRuntime.h:1382
插件版本的定义。
插件字段集合结构体。
定义: NvInferPluginBase.h:103
插件可能看到的输入或输出字段。
定义: NvInferRuntimePlugin.h:73
DataType type
定义: NvInferRuntimePlugin.h:77
Dims dims
维度。
定义: NvInferRuntimePlugin.h:75
TensorFormat format
张量格式。
定义: NvInferRuntimePlugin.h:79
float scale
INT8 数据类型的缩放比例。
定义: NvInferRuntimePlugin.h:81

  Copyright © 2024 NVIDIA Corporation
  Privacy Policy | Manage My Privacy | Do Not Sell or Share My Data | Terms of Service | Accessibility | Corporate Policies | Product Security | Contact