使用 Python 算子处理 GPU 数据#
本示例演示如何在 GPU 上使用 PythonFunction
算子。有关 Python 算子系列的介绍和一般信息,请参阅Python 算子 部分。
尽管 Python 算子并非旨在追求速度,但在 GPU 上运行它们可能很有用,例如,当我们想将自定义操作引入到现有的 GPU pipeline 中时。为此,PythonFunction
系列中的所有算子都有其 GPU 变体。
对于 TorchPythonFunction
和 DLTensorPythonFunction
算子,它们操作的数据格式与 CPU 上的格式保持一致,前者为 PyTorch 张量,后者为 DLPack 张量。对于 GPU PythonFunction
,实现函数的输入和输出是 CuPy 数组。
CuPy 操作#
由于 CuPy 数组 API 与 NumPy 中的 API 类似,因此我们可以实现与 CPU 示例中定义的几乎相同的操作,而无需进行任何代码更改。
[1]:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy
import cupy
def edit_images(image1, image2):
assert image1.shape == image2.shape
h, w, c = image1.shape
y, x = cupy.ogrid[0:h, 0:w]
mask = (x - w / 2) ** 2 + (y - h / 2) ** 2 > h * w / 9
result1 = cupy.copy(image1)
result1[mask] = image2[mask]
result2 = cupy.copy(image2)
result2[mask] = image1[mask]
return result1, result2
使用 CuPy 定义 GPU 函数的另一种方法是编写 CUDA kernel。在这里,我们展示了一个简单的 kernel,它可以交错两个图像的通道。有关更多信息,请参阅 CuPy 文档。
[2]:
mix_channels_kernel = cupy.ElementwiseKernel(
"uint8 x, uint8 y", "uint8 z", "z = (i % 3) ? x : y", "mix_channels"
)
警告
当 pipeline 启用条件执行时,必须采取额外的步骤来防止 function
被 AutoGraph 重写。 有两种方法可以实现这一点
在全局作用域(即
pipeline_def
作用域之外)定义函数。如果函数是另一个“工厂”函数的结果,则工厂函数必须在 pipeline 定义函数之外定义,并使用
<nvidia.dali.pipeline.do_not_convert>
修饰。
更多详细信息可以在 nvidia.dali.pipeline.do_not_convert
文档中找到。
定义 Pipeline#
我们定义一个类似于 Python 算子 部分中使用的 pipeline。要将执行从 CPU 移动到 GPU,我们只需要更改算子的设备参数。这也是 PythonFunction
算子的唯一用法差异。
[3]:
image_dir = "../data/images"
batch_size = 4
python_function_pipe = Pipeline(
batch_size=batch_size,
num_threads=4,
device_id=0,
exec_async=False,
exec_pipelined=False,
seed=99,
)
with python_function_pipe:
input1, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
input2, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
im1, im2 = fn.decoders.image(
[input1, input2], device="mixed", output_type=types.RGB
)
res1, res2 = fn.resize([im1, im2], device="gpu", resize_x=300, resize_y=300)
out1, out2 = fn.python_function(
res1, res2, device="gpu", function=edit_images, num_outputs=2
)
out3 = fn.python_function(
res1, res2, device="gpu", function=mix_channels_kernel
)
python_function_pipe.set_outputs(out1, out2, out3)
运行 Pipeline 并可视化结果#
我们可以运行 pipeline 并以类似于 CPU 示例的方式显示结果。
注意: 在尝试绘制它们之前,请记住将输出批次移动到主机内存。
[4]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm
%matplotlib inline
batch_size = 4
def show_images(image_batch):
columns = 4
rows = (batch_size + 1) // columns
fig = plt.figure(figsize=(32, (32 // columns) * rows))
gs = gridspec.GridSpec(rows, columns)
for j in range(rows * columns):
plt.subplot(gs[j])
plt.axis("off")
plt.imshow(image_batch.at(j))
python_function_pipe.build()
ims1, ims2, ims3 = python_function_pipe.run()
show_images(ims1.as_cpu())
show_images(ims2.as_cpu())
show_images(ims3.as_cpu())
data:image/s3,"s3://crabby-images/45e12/45e121e7f4a9f0e5a542458d7839b79c51d4a37c" alt="../../_images/examples_custom_operations_gpu_python_operator_10_0.png"
data:image/s3,"s3://crabby-images/bcd19/bcd1963e9f8820b67b225bb42729e36d869685ec" alt="../../_images/examples_custom_operations_gpu_python_operator_10_1.png"
data:image/s3,"s3://crabby-images/d0ed9/d0ed9ea1f37f74d2879136335d5ea8605951222d" alt="../../_images/examples_custom_operations_gpu_python_operator_10_2.png"
高级:DLTensorPythonFunction 中的设备同步#
当使用 PythonFunction
或 TorchPythonFunction
时,我们不必将 GPU 代码与 DALI pipeline 的其余部分同步,因为同步由算子处理。DLTensorPythonFunction
算子另一方面,将设备同步留给用户。
注意: 不同框架和库的同步过程可能有所不同。
例如,我们将围绕先前实现的 mix_channels_kernel
编写一个包装器,该包装器将 DLPack 张量转换为 CuPy 数组并处理流同步。
[5]:
def mix_channels_wrapper(tensor1, tensor2):
array1 = cupy.fromDlpack(tensor1)
array2 = cupy.fromDlpack(tensor2)
result = mix_channels_kernel(array1, array2)
cupy.cuda.get_current_stream().synchronize()
return result.toDlpack()
dltensor_function_pipe = Pipeline(
batch_size=batch_size,
num_threads=4,
device_id=0,
exec_async=False,
exec_pipelined=False,
seed=99,
)
with dltensor_function_pipe:
input1, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
input2, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
im1, im2 = fn.decoders.image(
[input1, input2], device="mixed", output_type=types.RGB
)
res1, res2 = fn.resize([im1, im2], device="gpu", resize_x=300, resize_y=300)
out = fn.dl_tensor_python_function(
res1,
res2,
device="gpu",
function=mix_channels_wrapper,
synchronize_stream=True,
batch_processing=False,
)
dltensor_function_pipe.set_outputs(out)
dltensor_function_pipe.build()
(ims,) = dltensor_function_pipe.run()
show_images(ims.as_cpu())
data:image/s3,"s3://crabby-images/ce16d/ce16d5ff4b4549934fde94f2bdce973b1d6a3aab" alt="../../_images/examples_custom_operations_gpu_python_operator_12_0.png"
结果与使用 PythonFunction
运行 mix_channels_kernel
后的结果相同。为了在 DLTensorPythonFunction
中正确同步设备代码,请确保满足以下条件
在提供的函数开始之前,所有先前的 DALI GPU 工作都已完成。
在我们返回结果之前,我们在提供的函数中调度的任务已完成。
第一个条件由 synchronize_stream=True
标志(默认设置为 True
)保证。用户负责提供第二部分。在上面的示例中,通过添加 cupy.cuda.get_current_stream().synchronize()
行来实现同步。