tensorrt.plugin.register

tensorrt.plugin.register(plugin_id: str, lazy_register: bool = False) Callable

包装一个函数以注册和描述 TensorRT 插件的 IO 特性。此外,一个完整的插件至少需要注册一个 trt.plugin.impl 函数。

此 API 仅旨在用作装饰器。被装饰的函数必须对所有输入以及返回值都有类型提示。

(inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, ...) -> Union[TensorDesc, Tuple[TensorDesc]]
  • 输入张量首先被声明,每个张量都由张量描述符 TensorDesc 描述。

  • 插件属性接下来被声明。“SupportedAttrType” 必须是以下类型之一
    • 支持的内置类型:int、float、str、bool、bytes (注意:不支持这些类型的列表/元组)

    • 以下类型的一维 Numpy 数组:int8、int16、int32、int64、float16、float32、float64、bool。这些必须使用 'numpy.typing.NDArray[dtype]' 进行注释,其中 'dtype' 是预期的 numpy dtype。

  • 如果插件只有一个输出,则返回注释可以是 TensorDesc。Tuple[TensorDesc] 可以用于任意数量的输出。

默认情况下,插件将立即在 TRT 插件注册表中注册。使用 lazy_register 参数来更改此行为。

参数:
  • plugin_id – 插件的 ID,格式为“{namespace}::{name}”,例如“my_project::add_plugin”。命名空间用于避免冲突,因此建议使用您的产品/项目名称。

  • lazy_register – 在插件开发期间/构建引擎时,可以使用延迟注册来延迟插件注册,直到使用 trt.plugin.op.ns.plugin_name(…) 显式实例化插件为止

逐元素插件的注册(输出与输入具有相同的特性)
1import tensorrt.plugin as trtp
2
3@trtp.register("my::add_plugin")
4def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
5    return inp0.like()