nvidia.dali.plugin.pytorch.fn.torch_python_function#
- nvidia.dali.plugin.pytorch.fn.torch_python_function(*input, batch_processing=True, bytes_per_sample_hint=[0], function, num_outputs=1, output_layouts=None, preserve=False, device=None, name=None)#
执行在 Torch 张量上运行的函数。
此类类似于
nvidia.dali.fn.python_function()
,但张量数据被处理为 PyTorch 张量。此操作符允许序列输入并支持体积数据。
此操作符将不会从图中优化掉。
- 支持的后端
‘cpu’
‘gpu’
- 参数:
__input_¶[0..255] (TensorList, optional) – 此函数最多接受 256 个可选的位置输入
- 关键词参数:
batch_processing¶ (bool, optional, default = True) – 确定函数是否将整个批次作为输入。
bytes_per_sample_hint¶ (int or list of int, optional, default = [0]) –
每个样本的输出大小提示(以字节为单位)。
如果指定,则驻留在 GPU 或分页锁定主机内存中的操作符输出将被预先分配,以容纳此大小的样本批次。
function¶ (object) –
定义操作符功能的可调用对象。
警告
该函数不得持有对其使用的 Pipeline 的引用。如果持有引用,则会形成对 Pipeline 的循环引用,并且永远不会释放该 Pipeline。
num_outputs¶ (int, optional, default = 1) – 输出数量。
output_layouts¶ (layout str or list of layout str, optional) –
输出的张量数据布局。
此参数可以是列表,其中包含每个输出的不同布局。如果列表的元素少于 num_outputs,则仅前几个输出设置了布局,其余输出未分配布局。
preserve¶ (bool, optional, default = False) – 即使操作符的输出未使用,也防止将其从图中删除。