nvidia.dali.plugin.pytorch.fn.torch_python_function#

nvidia.dali.plugin.pytorch.fn.torch_python_function(*input, batch_processing=True, bytes_per_sample_hint=[0], function, num_outputs=1, output_layouts=None, preserve=False, device=None, name=None)#

执行在 Torch 张量上运行的函数。

此类类似于 nvidia.dali.fn.python_function(),但张量数据被处理为 PyTorch 张量。

此操作符允许序列输入并支持体积数据。

此操作符将不会从图中优化掉。

支持的后端
  • ‘cpu’

  • ‘gpu’

参数:

__input_[0..255] (TensorList, optional) – 此函数最多接受 256 个可选的位置输入

关键词参数:
  • batch_processing (bool, optional, default = True) – 确定函数是否将整个批次作为输入。

  • bytes_per_sample_hint (int or list of int, optional, default = [0]) –

    每个样本的输出大小提示(以字节为单位)。

    如果指定,则驻留在 GPU 或分页锁定主机内存中的操作符输出将被预先分配,以容纳此大小的样本批次。

  • function (object) –

    定义操作符功能的可调用对象。

    警告

    该函数不得持有对其使用的 Pipeline 的引用。如果持有引用,则会形成对 Pipeline 的循环引用,并且永远不会释放该 Pipeline。

  • num_outputs (int, optional, default = 1) – 输出数量。

  • output_layouts (layout str or list of layout str, optional) –

    输出的张量数据布局。

    此参数可以是列表,其中包含每个输出的不同布局。如果列表的元素少于 num_outputs,则仅前几个输出设置了布局,其余输出未分配布局。

  • preserve (bool, optional, default = False) – 即使操作符的输出未使用,也防止将其从图中删除。