在 DALI 中运行自定义 JAX 增强#
本教程展示了如何使用 plugin.jax.fn.jax_function 在 DALI pipeline 或迭代器内部运行 JAX 函数。通过这种方式,您可以使用 JAX 编写自定义增强,并使其与其他 DALI 操作互操作。
设置示例#
我们将从一个简单的图像处理 DALI 迭代器开始。您可以在 DALI 和 JAX 入门指南 中阅读更多关于如何为 JAX 使用 DALI 定义迭代器的信息。
[1]:
import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator
image_dir = "../data/images"
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def baseline_iterator_fn():
jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(300, 300))
return images, labels
baseline_iterator = baseline_iterator_fn(batch_size=4)
baseline_batch = next(baseline_iterator)
让我们定义一个简单的辅助函数来展示生成的批次,我们稍后将使用它。
[2]:
import matplotlib.pyplot as plt
from matplotlib import gridspec
def show_image(images, columns=4, fig_size=24):
rows = (len(images) + columns - 1) // columns
plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
gs = gridspec.GridSpec(rows, columns)
for j in range(rows * columns):
plt.subplot(gs[j])
plt.axis("off")
plt.imshow(images[j])
添加使用 JAX 定义的增强#
现在,让我们为图片添加一些 JAX 处理。作为一个简单的例子,我们将使用 jax.numpy
数组索引水平翻转图像。
我们导入 jax
并编写一个函数,该函数期望一个 4D 数组 - 一批 HWC
图像。类似地,该函数返回 4D 数组,只是 W
维度被反转。
[3]:
import jax
def horz_flip(images: jax.Array):
return images[:, :, ::-1, :]
要将 horz_flip
插入迭代器,我们需要使用 jax_function
转换该函数。
[4]:
from nvidia.dali.plugin.jax.fn import jax_function
@jax_function
def horz_flip(images: jax.Array):
return images[:, :, ::-1, :]
就是这样,我们可以在迭代器定义中将该函数作为常规 DALI 操作调用。
[5]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(300, 300))
images = horz_flip(images)
return images, labels
iterator = iterator_fn(batch_size=4)
batch = next(iterator)
让我们比较一下基线迭代器和使用 horz_flip
的迭代器的输出。
[6]:
show_image(
[
image
for pair in zip(baseline_batch["images"], batch["images"])
for image in pair
]
)

JAX 函数变换#
jax_function
可以与常见的 JAX 变换结合使用。例如,我们可以使用 jax.vmap
来向量化沿批次维度的处理,使用 jax.jit
来获得 JAX 的即时编译的好处,或者两者都使用。
需要注意的一点是 jax_function
必须是最外层的变换。
[7]:
@jax_function
@jax.jit
def horz_flip(images: jax.Array):
return images[:, :, ::-1, :] # batch of HWC images
@jax_function
@jax.vmap
def horz_flip(image: jax.Array):
# single HWC image (batch is implicit thanks to jax.vmap)
return image[:, ::-1, :]
@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array):
# single HWC image (batch is implicit thanks to jax.vmap)
return image[:, ::-1, :]
多个输入和输出#
接下来,让我们为 horz_flip
添加另一个参数,该参数将控制给定的图像是否应翻转或保持不变。我们将根据 DALI 的 fn.random.coin_flip()
的输出翻转图像。
[8]:
@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array, should_flip: jax.Array):
return jax.lax.cond(
should_flip, lambda x: x[:, ::-1, :], lambda x: x, image
)
[9]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(300, 300))
should_flip = fn.random.coin_flip(seed=45)
# note, currently all the inputs must reside on the same backend type,
# as images are in GPU memory, we need to move should_flip there as well.
images = horz_flip(images, should_flip.gpu())
return images, labels
iterator = iterator_fn(batch_size=8)
batch = next(iterator) # batch of data ready to be used by JAX
[10]:
show_image(
[
image
for pair in zip(baseline_batch["images"], batch["images"])
for image in pair
],
columns=4,
)

正如预期的那样,一些图像保持不变。
我们刚刚看到处理函数可以接受多个输入。类似地,它可以返回多个输出。为此,但是,我们需要提示 DALI 它应该期望多少个输出。我们可以通过将 num_outputs
传递给 jax_function
来做到这一点。
[11]:
@jax_function(num_outputs=2)
@jax.jit
@jax.vmap
def flip(image: jax.Array):
horz_flip = image[:, ::-1, :]
vert_flip = image[::-1, :, :]
return horz_flip, vert_flip
[12]:
@data_iterator(
output_map=["horz", "vert", "labels"], reader_name="image_reader"
)
def iterator_fn():
jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(300, 300))
horz_flipped, vert_flipped = flip(images)
return horz_flipped, vert_flipped, labels
iterator = iterator_fn(batch_size=2)
batch = next(iterator) # batch of data ready to be used by JAX
[13]:
show_image(
[
image
for triple in zip(
baseline_batch["images"], batch["horz"], batch["vert"]
)
for image in triple
],
columns=3,
)

常规 pipeline 中的 JAX 增强#
JAX 增强不限于 JAX 迭代器,它们也可以与常规 DALI pipeline 一起使用。
[14]:
from nvidia.dali import pipeline_def
@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(300, 300))
should_flip = fn.random.coin_flip(seed=44)
flipped_images = horz_flip(images, should_flip.gpu())
return images, flipped_images, labels
p = pipeline()
p.build()
images, flipped_images, labels = p.run()
[15]:
show_image(
[
image
for pair in zip(images.as_cpu(), flipped_images.as_cpu())
for image in pair
],
columns=4,
)
