nvidia.dali.plugin.pytorch.fn.torch_python_function#
- nvidia.dali.plugin.pytorch.fn.torch_python_function(*input, batch_processing=True, bytes_per_sample_hint=[0], function, num_outputs=1, output_layouts=None, preserve=False, device=None, name=None)#
执行一个操作 Torch 张量的函数。
此类类似于
nvidia.dali.fn.python_function()
,但张量数据被处理为 PyTorch 张量。此运算符允许序列输入并支持体积数据。
此运算符将**不会**从图中优化掉。
- 支持的后端
‘cpu’
‘gpu’
- 参数:
__input_¶[0..255] (TensorList, 可选) – 此函数最多接受 256 个可选的位置输入
- 关键字参数:
batch_processing¶ (bool, 可选, 默认值 = True) – 确定函数是否将整个批次作为输入。
bytes_per_sample_hint¶ (int 或 list of int, 可选, 默认值 = [0]) –
每个样本的输出大小提示(以字节为单位)。
如果指定,则将预先分配驻留在 GPU 或分页锁定主机内存中的运算符输出,以适应此大小的样本批次。
function¶ (object) –
定义运算符功能的可调用对象。
警告
该函数不得持有对其所用 Pipeline 的引用。如果持有,则将形成对 Pipeline 的循环引用,并且永远不会释放 Pipeline。
num_outputs¶ (int, 可选, 默认值 = 1) – 输出数量。
output_layouts¶ (layout str 或 list of layout str, 可选) –
输出的张量数据布局。
此参数可以是列表,其中包含每个输出的不同布局。如果列表的元素少于 num_outputs,则仅设置前几个输出的布局,其余输出不分配布局。
preserve¶ (bool, 可选, 默认值 = False) – 即使运算符的输出未使用,也阻止将其从图中删除。