18#ifndef NV_INFER_RUNTIME_PLUGIN_H
19#define NV_INFER_RUNTIME_PLUGIN_H
21#define NV_INFER_INTERNAL_INCLUDE 1
23#undef NV_INFER_INTERNAL_INCLUDE
59static constexpr int32_t kPLUGIN_VERSION_PYTHON_BIT = 0x40;
180 virtual
AsciiChar const* getPluginVersion() const noexcept = 0;
195 virtual int32_t getNbOutputs() const noexcept = 0;
220 virtual
Dims getOutputDimensions(int32_t index,
Dims const* inputs, int32_t nbInputDims) noexcept = 0;
279 virtual
void configureWithFormat(
Dims const* inputDims, int32_t nbInputs,
Dims const* outputDims, int32_t nbOutputs,
294 virtual int32_t initialize() noexcept = 0;
309 virtual
void terminate() noexcept = 0;
328 virtual
size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept = 0;
351 virtual int32_t enqueue(int32_t batchSize,
void const* const* inputs,
void* const* outputs,
void* workspace,
352 cudaStream_t stream) noexcept
365 virtual
size_t getSerializationSize() const noexcept = 0;
380 virtual
void serialize(
void* buffer) const noexcept = 0;
390 virtual
void destroy() noexcept = 0;
425 virtual
void setPluginNamespace(
AsciiChar const* pluginNamespace) noexcept = 0;
438 virtual
AsciiChar const* getPluginNamespace() const noexcept = 0;
518 int32_t outputIndex,
bool const* inputIsBroadcasted, int32_t nbInputs)
const noexcept
586 DataType const* inputTypes,
DataType const* outputTypes,
bool const* inputIsBroadcast,
587 bool const* outputIsBroadcast,
PluginFormat floatFormat, int32_t maxBatchSize)
noexcept
686 int32_t getTensorRTVersion() const noexcept
override
778 int32_t pos,
PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs)
const noexcept
862 virtual
AsciiChar const* getPluginVersion() const noexcept = 0;
906 virtual
IPluginV2* deserializePlugin(
AsciiChar const* name,
void const* serialData,
size_t serialLength) noexcept
923 virtual
void setPluginNamespace(
AsciiChar const* pluginNamespace) noexcept = 0;
937 virtual
AsciiChar const* getPluginNamespace() const noexcept = 0;
#define NV_TENSORRT_VERSION
定义: NvInferRuntimeBase.h:91
#define TRT_DEPRECATED
定义: NvInferRuntimeBase.h:45
定义: NvInferRuntimeBase.h:203
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:468
virtual TRT_DEPRECATED bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept=0
如果插件可以使用跨批次广播而无需复制的输入张量,则返回 true。
~IPluginV2Ext() override=default
void configureWithFormat(Dims const *, int32_t, Dims const *, int32_t, DataType, PluginFormat, int32_t) noexcept override
派生类不得实现此方法。在 C++11 API 中,这将是 override final。
定义: NvInferRuntimePlugin.h:698
IPluginV2Ext * clone() const noexcept override=0
克隆插件对象。这也会复制内部插件参数并返回一个新插件...
virtual void configurePlugin(Dims const *inputDims, int32_t nbInputs, Dims const *outputDims, int32_t nbOutputs, DataType const *inputTypes, DataType const *outputTypes, bool const *inputIsBroadcast, bool const *outputIsBroadcast, PluginFormat floatFormat, int32_t maxBatchSize) noexcept=0
使用输入和输出数据类型配置层。
virtual void detachFromContext() noexcept
将插件对象从其执行上下文中分离。
定义: NvInferRuntimePlugin.h:645
virtual TRT_DEPRECATED bool isOutputBroadcastAcrossBatch(int32_t outputIndex, bool const *inputIsBroadcasted, int32_t nbInputs) const noexcept=0
如果输出张量跨批次广播,则返回 true。
virtual void attachToContext(cudnnContext *, cublasContext *, IGpuAllocator *) noexcept
将插件对象附加到执行上下文,并授予插件访问某些上下文资源的权限...
定义: NvInferRuntimePlugin.h:627
virtual nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs) const noexcept=0
返回请求索引处插件输出的 DataType。
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:133
virtual AsciiChar const * getPluginType() const noexcept=0
返回插件类型。应与相应插件创建器返回的插件名称匹配。
virtual int32_t getTensorRTVersion() const noexcept
返回构建此插件的 API 版本。
定义: NvInferRuntimePlugin.h:147
用户实现的层的插件类。
定义: NvInferRuntimePlugin.h:717
int32_t getTensorRTVersion() const noexcept override
返回构建此插件的 API 版本。高位字节由 TensorRT 保留,并且...
定义: NvInferRuntimePlugin.h:805
virtual void configurePlugin(PluginTensorDesc const *in, int32_t nbInput, PluginTensorDesc const *out, int32_t nbOutput) noexcept=0
配置层。
virtual bool supportsFormatCombination(int32_t pos, PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs) const noexcept=0
如果插件支持 pos 索引的输入/输出的格式和数据类型,则返回 true。
与 TRT 接口关联的版本信息。
定义: NvInferRuntimeBase.h:228
定义: NvInferRuntime.h:1608
定义: NvInferRuntimePlugin.h:834
virtual AsciiChar const * getPluginName() const noexcept=0
返回插件名称。
定义: NvInferPluginBase.h:193
PluginCreatorVersion
用于标识插件创建器版本的枚举。
定义: NvInferRuntimePlugin.h:111
@ kV1_PYTHON
基于 IPluginCreator 的 Python 插件创建器。
v_1_0::IPluginCreator IPluginCreator
定义: NvInferRuntimePlugin.h:970
v_1_0::IGpuAllocator IGpuAllocator
定义: NvInferRuntime.h:1807
char_t AsciiChar
定义: NvInferRuntimeBase.h:105
@ kV2_DYNAMICEXT
IPluginV2DynamicExt.
@ kV2_IOEXT
IPluginV2IOExt.
@ kV2_DYNAMICEXT_PYTHON
基于 IPluginV2DynamicExt 的 Python 插件。
DataType
权重和张量的类型。
定义: NvInferRuntimeBase.h:133
TensorFormat PluginFormat
PluginFormat 保留用于向后兼容。
定义: NvInferRuntimePlugin.h:54
TensorFormat
输入/输出张量的格式。
定义: NvInferRuntime.h:1382
插件字段集合结构体。
定义: NvInferPluginBase.h:103
插件可能看到的输入或输出字段。
定义: NvInferRuntimePlugin.h:73
DataType type
定义: NvInferRuntimePlugin.h:77
Dims dims
维度。
定义: NvInferRuntimePlugin.h:75
TensorFormat format
张量格式。
定义: NvInferRuntimePlugin.h:79
float scale
INT8 数据类型的缩放比例。
定义: NvInferRuntimePlugin.h:81