在 DALI 中运行自定义 JAX 增强#

本教程展示了如何使用 plugin.jax.fn.jax_function 在 DALI pipeline 或迭代器中运行 JAX 函数。通过这种方式,你可以使用 JAX 编写自定义增强,并使其与其他 DALI 运算互操作。

设置示例#

我们将从一个简单的图像处理 DALI 迭代器开始。你可以在 DALI 和 JAX 入门 中阅读更多关于如何为 JAX 和 DALI 定义迭代器的信息。

[1]:
import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator

image_dir = "../data/images"


@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def baseline_iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    return images, labels


baseline_iterator = baseline_iterator_fn(batch_size=4)

baseline_batch = next(baseline_iterator)

让我们定义一个简单的辅助函数来呈现生成的批次,我们稍后将使用它。

[2]:
import matplotlib.pyplot as plt
from matplotlib import gridspec


def show_image(images, columns=4, fig_size=24):
    rows = (len(images) + columns - 1) // columns
    plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(images[j])

添加使用 JAX 定义的增强#

现在,让我们为图片添加一些 JAX 处理。作为一个简单的示例,我们将使用 jax.numpy 数组索引水平翻转图像。

我们导入 jax 并编写一个函数,该函数期望一个 4D 数组 - 一批 HWC 图像。类似地,该函数返回 4D 数组,只是 W 维度被反转。

[3]:
import jax


def horz_flip(images: jax.Array):
    return images[:, :, ::-1, :]

为了将 horz_flip 插入到迭代器中,我们需要使用 jax_function 转换该函数。

[4]:
from nvidia.dali.plugin.jax.fn import jax_function


@jax_function
def horz_flip(images: jax.Array):
    return images[:, :, ::-1, :]

就是这样,我们可以在迭代器定义中像调用常规 DALI 运算一样调用该函数。

[5]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))

    images = horz_flip(images)

    return images, labels


iterator = iterator_fn(batch_size=4)
batch = next(iterator)

让我们比较基线迭代器和使用 horz_flip 的迭代器的输出。

[6]:
show_image(
    [
        image
        for pair in zip(baseline_batch["images"], batch["images"])
        for image in pair
    ]
)
../../_images/examples_custom_operations_jax_operator_basic_12_0.png

JAX 函数转换#

jax_function 可以与常见的 JAX 转换结合使用。例如,我们可以使用 jax.vmap 来向量化沿批次维度的处理,使用 jax.jit 来获得 JAX 即时编译的好处,或者两者都使用。

需要注意的一点是,jax_function 必须是最外层的转换。

[7]:
@jax_function
@jax.jit
def horz_flip(images: jax.Array):
    return images[:, :, ::-1, :]  # batch of HWC images


@jax_function
@jax.vmap
def horz_flip(image: jax.Array):
    # single HWC image (batch is implicit thanks to jax.vmap)
    return image[:, ::-1, :]


@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array):
    # single HWC image (batch is implicit thanks to jax.vmap)
    return image[:, ::-1, :]

多个输入和输出#

接下来,让我们为 horz_flip 添加另一个参数,该参数将控制给定的图像是否应该翻转或保持不变。我们将根据 DALI 的 fn.random.coin_flip() 的输出翻转图像。

[8]:
@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array, should_flip: jax.Array):
    return jax.lax.cond(
        should_flip, lambda x: x[:, ::-1, :], lambda x: x, image
    )
[9]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    should_flip = fn.random.coin_flip(seed=45)
    # note, currently all the inputs must reside on the same backend type,
    # as images are in GPU memory, we need to move should_flip there as well.
    images = horz_flip(images, should_flip.gpu())
    return images, labels


iterator = iterator_fn(batch_size=8)
batch = next(iterator)  # batch of data ready to be used by JAX
[10]:
show_image(
    [
        image
        for pair in zip(baseline_batch["images"], batch["images"])
        for image in pair
    ],
    columns=4,
)
../../_images/examples_custom_operations_jax_operator_basic_18_0.png

正如预期的那样,一些图像保持不变。

我们刚刚看到处理函数可以接受多个输入。类似地,它可以返回多个输出。然而,为此,我们需要提示 DALI 它应该期望多少个输出。我们可以通过将 num_outputs 传递给 jax_function 来做到这一点。

[11]:
@jax_function(num_outputs=2)
@jax.jit
@jax.vmap
def flip(image: jax.Array):
    horz_flip = image[:, ::-1, :]
    vert_flip = image[::-1, :, :]
    return horz_flip, vert_flip
[12]:
@data_iterator(
    output_map=["horz", "vert", "labels"], reader_name="image_reader"
)
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    horz_flipped, vert_flipped = flip(images)
    return horz_flipped, vert_flipped, labels


iterator = iterator_fn(batch_size=2)
batch = next(iterator)  # batch of data ready to be used by JAX
[13]:
show_image(
    [
        image
        for triple in zip(
            baseline_batch["images"], batch["horz"], batch["vert"]
        )
        for image in triple
    ],
    columns=3,
)
../../_images/examples_custom_operations_jax_operator_basic_22_0.png

常规 pipeline 中的 JAX 增强#

JAX 增强不限于 JAX 迭代器,它们也可以与常规 DALI pipeline 一起使用。

[14]:
from nvidia.dali import pipeline_def


@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    should_flip = fn.random.coin_flip(seed=44)
    flipped_images = horz_flip(images, should_flip.gpu())
    return images, flipped_images, labels


p = pipeline()
p.build()
images, flipped_images, labels = p.run()
[15]:
show_image(
    [
        image
        for pair in zip(images.as_cpu(), flipped_images.as_cpu())
        for image in pair
    ],
    columns=4,
)
../../_images/examples_custom_operations_jax_operator_basic_25_0.png