PyTorch DALI 代理#
概述#
DALI 代理是一个旨在将 NVIDIA DALI pipeline 与 PyTorch 数据工作器集成的工具,同时保持 PyTorch 数据集逻辑的简洁性。DALI 代理的关键特性包括
高效 GPU 利用率:DALI 代理确保 GPU 数据处理在运行主循环的进程中进行。这避免了由于同一 GPU 的多个 CUDA 上下文导致的性能下降。
选择性卸载:用户可以将数据处理 pipeline 的部分卸载到 DALI,同时保留 PyTorch Dataset 逻辑,使其成为多模态应用的理想选择。
本教程将解释 PyTorch 中 DALI 代理的关键组件、工作流程和用法。
注意
免责声明:目前,由 DALI 代理生成的数据不能在 Dataset 中进一步处理。它必须按原样传递到主循环。如果需要在 DALI 之外进行后处理,则应仅在迭代器生成数据后进行。
DALI 代理工作流程#
主要组件
DALI Pipeline 用户定义的 DALI pipeline 处理输入数据。
DALI Server 服务器运行一个后台线程来异步执行 DALI pipeline。
DALI 代理 PyTorch 数据工作器和 DALI Server 之间的可调用接口。
PyTorch Dataset 和 DataLoader Dataset 保持与 DALI 内部结构无关,并使用 Proxy 进行预处理。
工作流程摘要
定义一个 DALI pipeline 并连接到 DALI Server,后者在后台线程中执行 pipeline。
DALI 代理为 PyTorch 数据工作器提供了一个异步请求 DALI 处理的接口。
每个数据工作器调用代理,代理返回对未来处理样本的引用。
在批处理整理期间,代理将数据分组到一个批次中,并将其发送到服务器以执行。
服务器异步处理批次并将实际数据输出到输出队列。
PyTorch DataLoader 检索已处理的数据或对挂起 pipeline 运行的引用。然后,挂起的 pipeline 运行引用将替换为实际数据,并在必要时等待数据。
API#
- class nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer(pipeline, deterministic=False)#
- __enter__()#
启动 DALI pipeline 线程
- __exit__(exc_type, exc_value, tb)#
停止 DALI pipeline 线程
- __init__(pipeline, deterministic=False)#
初始化一个新的 DALI 服务器实例。
- 参数:
示例 1 - 通过 DALI 代理 DataLoader 与 PyTorch 完全集成
@pipeline_def def rn50_train_pipe(): rng = fn.random.coin_flip(probability=0.5) filepaths = fn.external_source(name="images", no_copy=True) jpegs = fn.io.file.read(filepaths) images = fn.decoders.image_random_crop( jpegs, device="mixed", output_type=types.RGB, random_aspect_ratio=[0.75, 4.0 / 3.0], random_area=[0.08, 1.0], ) images = fn.resize( images, size=[224, 224], interp_type=types.INTERP_LINEAR, antialias=False, ) output = fn.crop_mirror_normalize( images, dtype=types.FLOAT, output_layout="CHW", crop=(224, 224), mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255], mirror=rng, ) return output def read_filepath(path): return np.frombuffer(path.encode(), dtype=np.int8) nworkers = 8 pipe = rn50_train_pipe( batch_size=16, num_threads=3, device_id=0, prefetch_queue_depth=2*nworkers) # The scope makes sure the server starts and stops at enter/exit with dali_proxy.DALIServer(pipe) as dali_server: # DALI proxy instance can be used as a transform callable dataset = torchvision.datasets.ImageFolder( jpeg, transform=dali_server.proxy, loader=read_filepath) # Same interface as torch DataLoader, but takes a dali_server as first argument loader = nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader( dali_server, dataset, batch_size=batch_size, num_workers=nworkers, drop_last=True, ) for data, target in loader: # consume it
示例 2 - 使用 DALI 代理/DALI 服务器和 PyTorch 的 default_collate 手动执行
@pipeline_def def my_pipe(): a = fn.external_source(name="a", no_copy=True) b = fn.external_source(name="b", no_copy=True) return a + b, a - b with dali_proxy.DALIServer( my_pipe(device='cpu', batch_size=batch_size, num_threads=3, device_id=None)) as dali_server: outs = [] for _ in range(batch_size): a = np.array(np.random.rand(3, 3), dtype=np.float32) b = np.array(np.random.rand(3, 3), dtype=np.float32) out0, out1 = dali_server.proxy(a=a, b=b) outs.append((a, b, out0, out1)) outs = torch.utils.data.dataloader.default_collate(outs) a, b, a_plus_b, a_minus_b = dali_server.produce_data(outs)
示例 3 - 与 PyTorch 完全集成,但使用原始 PyTorch DataLoader
pipe = rn50_train_pipe(...) with dali_proxy.DALIServer(pipe) as dali_server: dataset = torchvision.datasets.ImageFolder( jpeg, transform=dali_server.proxy, loader=read_filepath) # Using PyTorch DataLoader directly loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, num_workers=nworkers, drop_last=True, ) for data, target in loader: # replaces the output reference with actual data data = dali_server.produce_data(data) ...
- produce_data(obj)#
一个通用函数,用于递归访问嵌套结构中的所有元素,并将 DALIOutputBatchRef 的实例替换为 DALI 服务器提供的实际数据。有关完整示例,请参见
nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer
。- 参数:
obj¶ – 要映射的对象(可以是任何类的实例)。
- 返回:
一个新对象,其中 DALIOutputBatchRef 的任何实例都已替换为实际数据。
- start_thread()#
启动 DALI pipeline 线程。注意:首选使用作用域的 __enter__/__exit__
- stop_thread()#
停止 DALI pipeline 线程。注意:首选使用作用域的 __enter__/__exit__
- class nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader(*args, **kwargs)#
DALI 数据加载器,用于主循环,它将 pipeline 运行引用替换为 DALI 服务器生成的实际数据。有关完整示例,请参见
nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer
。- __init__(dali_server, *args, **kwargs)#
与 PyTorch 的 DataLoader 相同的接口,除了额外的 DALIServer 参数
示例用法#
DALI 代理简述#
from torchvision import datasets, transforms
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
# Step 1: Define a DALI pipeline
@pipeline_def
def my_dali_pipeline():
images = fn.external_source(name="images", no_copy=True)
images = fn.resize(images, size=[224, 224])
return fn.crop_mirror_normalize(
images, dtype=types.FLOAT, output_layout="CHW",
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
)
# Step 2: Initialize DALI server. The scope makes sure to start and stop the background thread
with dali_proxy.DALIServer(my_dali_pipeline(batch_size=64, num_threads=3, device_id=0)) as dali_server:
# Step 3: Define a PyTorch Dataset using the DALI proxy
dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy)
# Step 4: Use DALI proxy DataLoader
loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=64, num_workers=8, drop_last=True)
# Step 5: Consume data
for data, target in loader:
print(data.shape) # Processed data ready
工作原理#
1. DALI Pipeline
DALI pipeline 定义数据处理步骤。输入数据使用 external_source()
馈送。
from nvidia.dali import pipeline_def, fn, types
@pipeline_def
def example_pipeline():
images = fn.external_source(name="images", no_copy=True)
images = fn.io.file.read(images)
images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
return fn.resize(images, size=[224, 224])
pipeline = example_pipeline(batch_size=32, num_threads=2, device_id=0)
2. DALI Server 和代理
nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer
管理 pipeline 的执行。代理充当 PyTorch 数据工作器的接口。请注意,DALI pipeline 应至少包含一个输入(external_source()
实例),并且这些节点的名称随后将成为 DALI 代理可调用对象的输入。
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
with dali_proxy.DALIServer(pipeline) as dali_server:
future_samples = [dali_server.proxy(image) for image in images]
对于多个输入,我们可以选择使用位置参数、关键字参数
import numpy as np
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
@pipeline_def
def example_pipeline2(device):
a = fn.external_source(name="a", no_copy=True)
b = fn.external_source(name="b", no_copy=True)
return a + b, b - a
with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
a = np.array(...)
b = np.array(...)
# Option 1: positional arguments
a_plus_b, b_minus_a = dali_server.proxy(a, b)
# Option 2: named arguments
a_plus_b, b_minus_a = dali_server.proxy(b=b, a=a)
也可以显式启动和停止服务器
dali_server = dali_proxy.DALIServer(example_pipeline2(...))
dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy)
loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=64, num_workers=8, drop_last=True)
# Optional, it will be started on first attempt to get data from the loader anyway
dali_server.start_thread()
for data in loader:
...
# This is needed to make sure we have stopped the thread
dali_server.stop_thread()
在可能的情况下,使用 with
作用域。
3. 与 PyTorch DataLoader 集成
DALI 代理提供的 nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader
包装器简化了集成过程。
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
with dali_proxy.DALIServer(pipeline) as dali_server:
dataset = CustomDataset(dali_server.proxy, data=images)
loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=32, num_workers=4)
for data, _ in loader:
print(data.shape) # Ready-to-use processed batch
如果使用自定义 nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader
,请显式调用 DALI 服务器
with dali_proxy.DALIServer(pipeline) as dali_server:
dataset = CustomDataset(dali_server.proxy, data=images)
loader = MyCustomDataloader(...)
for data, _ in loader:
# Replaces instances of ``DALIOutputBatchRef`` with actual data
processed_data = dali_server.produce_data(data)
print(processed_data.shape) # data is now ready
4. 与 PyTorch Dataset 集成
PyTorch Dataset 可以直接使用代理作为转换函数。请注意,我们可以选择仅将部分处理卸载到 DALI,同时保留一些原始数据的完整性。
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, transform_fn, data):
self.data = data
self.transform_fn = transform_fn
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
filename, label = self.data[idx]
return self.transform_fn(filename), label # Returns processed sample and the original label
5. 数据整理和执行
此步骤通常在 PyTorch DataLoader 内部被抽象出来,用户无需显式处理。default_collate
函数将处理后的样本组合成一个批次。当批次被整理时,DALI 异步执行 pipeline。
from torch.utils.data.dataloader import default_collate as default_collate
with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
outs = []
for _ in range(10):
a = np.array(np.random.rand(3, 3), dtype=np.float32)
b = np.array(np.random.rand(3, 3), dtype=np.float32)
a_plus_b, b_minus_a = dali_server.proxy(a, b)
outs.append((a_plus_b, b_minus_a))
# Collate into a single batch run reference
outs = default_collate(outs)
# And we can now replace the run reference with actual data
outs = dali_server.produce_data(outs)
总结#
DALI 代理提供了一种简洁高效的方式来将 NVIDIA DALI 与 PyTorch 集成。通过将计算密集型任务卸载到 DALI,同时保持 PyTorch 的 Dataset 和 DataLoader 接口完整,它可以确保灵活性和最大性能。这种方法在大型数据 pipeline 和多模态工作流程中尤其强大。