Numba 函数 - 运行编译后的 C 回调函数#

本教程展示了如何运行以 Numba JIT 编译函数形式编写的自定义操作。

该运算符将用户定义的函数转换为可从外部 C 代码调用的编译函数,从而消除 Python 解释器的开销。

让我们首先导入 DALI 和一些实用程序。

[1]:
from nvidia.dali import pipeline_def
import nvidia.dali as dali
import nvidia.dali.fn as fn
from nvidia.dali.plugin.numba.fn.experimental import numba_function
import nvidia.dali.types as dali_types

image_dir = "../data/images"
max_batch_size = 8

作为一个示例,我们将编写一个 pipeline,该 pipeline 将图像旋转 90 度。由于输出形状将与输入形状不同,因此我们还需要实现一个 setup 函数。在 setup 函数中,我们根据输入形状定义输出数据形状。请注意,setup 函数对整个批次调用一次。如果未提供 setup 函数,则运算符假定输出的类型和形状与输入相同。

定义交换宽度和高度的形状函数#

警告

当 pipeline 启用条件执行时,必须采取其他步骤来防止 run_fnsetup_fn 函数被 AutoGraph 重写。 有两种方法可以实现这一点

  1. 在全局范围内定义函数(即在 pipeline_def 范围之外)。

  2. 如果函数是另一个“工厂”函数的结果,则工厂函数必须在 pipeline 定义函数之外定义,并使用 <nvidia.dali.pipeline.do_not_convert> 进行装饰。

更多详细信息可以在 <nvidia.dali.pipeline.do_not_convert> 文档中找到。

[2]:
def rot_image_setup(outs, ins):
    for i in range(len(outs)):
        for sample_idx in range(len(outs[i])):
            outs[i][sample_idx][0] = ins[i][sample_idx][1]
            outs[i][sample_idx][1] = ins[i][sample_idx][0]
            outs[i][sample_idx][2] = ins[i][sample_idx][2]

setup 函数计算输出的形状。 它接受两个参数:要填充的输出形状列表和输入形状列表。 应用于 outs/ins 的外部索引是相应输出/输入的索引; 内部索引指示批次中的样本。

outs[i] - 第 i 个输出

outs[i][j] - 获取第 i 个输出的第 j 个样本。

此函数已编译。 它应该能够在 nopython 模式下工作。

定义基于输入样本填充输出样本的处理函数#

[3]:
def rot_image(out0, in0):
    for i in range(out0.shape[0]):
        for j in range(out0.shape[1]):
            out0[i][j] = in0[j][out0.shape[0] - i - 1]

run 函数可以有多个输入或输出。 具有 n 个输出和 m 个输入的函数应具有以下签名

def run_fn(out_0, out_1, ..., out_n, in_0, in_1, ..., in_m):

默认情况下,numba 函数运算符在每个样本的基础上工作。 也可以通过指定参数 batch_processing=True 来指定一次处理整个批次的函数。 在这种情况下,函数接收的输出和输入包含一个前导维度,表示样本索引。

定义 Pipeline#

要定义 pipeline,我们将 run 函数和 shape 函数都传递给 DALI numba 运算符。 我们需要传递输入和输出的 DALI 类型及其维度。

[4]:
@pipeline_def
def rotate_image_pipe(
    run_fn=None,
    setup_fn=None,
    out_types=None,
    in_types=None,
    outs_ndim=None,
    ins_ndim=None,
):
    files, labels = fn.readers.file(file_root=image_dir)
    images_in = dali.fn.decoders.image(files, device="cpu")
    return images_in, numba_function(
        images_in,
        run_fn=run_fn,
        setup_fn=setup_fn,
        out_types=out_types,
        in_types=in_types,
        outs_ndim=outs_ndim,
        ins_ndim=ins_ndim,
    )

有关读取器和解码器的更多信息,请参阅 入门指南 notebook。

下一步是构建和运行我们的 pipeline。

[5]:
pipe = rotate_image_pipe(
    batch_size=max_batch_size,
    num_threads=1,
    device_id=0,
    run_fn=rot_image,
    setup_fn=rot_image_setup,
    out_types=[dali_types.UINT8],
    in_types=[dali_types.UINT8],
    outs_ndim=[3],
    ins_ndim=[3],
)
pipe.build()
images_in, images_out = pipe.run()
[6]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch):
    columns = 4
    rows = (max_batch_size + 1) // (columns)
    fig = plt.figure(figsize=(8, (8 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(image_batch.at(j))
[7]:
show_images(images_in)
../../_images/examples_custom_operations_numba_function_12_0.png
[8]:
show_images(images_out)
../../_images/examples_custom_operations_numba_function_13_0.png

多个输入和输出#

numba 函数运算符最多可以接受 6 个输入并产生最多 6 个输出。 例如,让我们编写一个 pipeline,该 pipeline 接收一个图像并将 RGB 通道拆分为 3 个单独的输出

首先,我们需要定义 setup 函数,该函数定义输出的形状。 setup 函数可能如下所示

[9]:
def setup_split_rgb(outs, ins):
    out0 = outs[0]
    out1 = outs[1]
    out2 = outs[2]
    for r_shape, g_shape, b_shape, in_shape in zip(
        outs[0], outs[1], outs[2], ins[0]
    ):
        r_shape[:] = g_shape[:] = b_shape[:] = in_shape[0:2]

第二步是编写 run 函数,该函数处理数据

[10]:
def run_split_rgb(out0_batch, out1_batch, out2_batch, in0_batch):
    for R, G, B, in0 in zip(out0_batch, out1_batch, out2_batch, in0_batch):
        for i in range(in0.shape[0]):
            for j in range(in0.shape[1]):
                R[i][j] = in0[i][j][0]
                G[i][j] = in0[i][j][1]
                B[i][j] = in0[i][j][2]
[11]:
@pipeline_def
def numba_function_split_rgb_pipe(
    run_fn=None,
    out_types=None,
    in_types=None,
    outs_ndim=None,
    ins_ndim=None,
    setup_fn=None,
):
    files, _ = fn.readers.file(file_root=image_dir)
    images_in = dali.fn.decoders.image(files, device="cpu")
    out0, out1, out2 = numba_function(
        images_in,
        run_fn=run_fn,
        setup_fn=setup_fn,
        out_types=out_types,
        in_types=in_types,
        outs_ndim=outs_ndim,
        ins_ndim=ins_ndim,
        batch_processing=True,
    )
    return images_in, out0, out1, out2
[12]:
pipe = numba_function_split_rgb_pipe(
    batch_size=max_batch_size,
    num_threads=3,
    device_id=0,
    run_fn=run_split_rgb,
    setup_fn=setup_split_rgb,
    out_types=[dali_types.UINT8 for _ in range(3)],
    in_types=[dali_types.UINT8],
    outs_ndim=[2, 2, 2],
    ins_ndim=[3],
)
pipe.build()
images_in, R, G, B = pipe.run()
[13]:
show_images(images_in)
../../_images/examples_custom_operations_numba_function_20_0.png
[14]:
import numpy as np


def show_images_rgb_planes(R, G, B):
    columns = 4
    rows = (max_batch_size + 1) // (columns)
    fig = plt.figure(figsize=(8, (8 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        shape = None
        shape = R.at(j).shape if j < len(R) else shape
        shape = G.at(j).shape if j < len(G) else shape
        shape = B.at(j).shape if j < len(B) else shape
        plt.imshow(
            np.stack(
                [
                    R.at(j) if j < len(R) else np.zeros(shape, dtype=np.uint8),
                    G.at(j) if j < len(G) else np.zeros(shape, dtype=np.uint8),
                    B.at(j) if j < len(B) else np.zeros(shape, dtype=np.uint8),
                ],
                axis=2,
            )
        )
[15]:
show_images_rgb_planes(R, G, B)
../../_images/examples_custom_operations_numba_function_22_0.png

我们可以尝试仅根据一个通道显示图像,看看我们的 pipeline 是否正常工作。

[16]:
show_images_rgb_planes(R, [], [])
../../_images/examples_custom_operations_numba_function_24_0.png
[17]:
show_images_rgb_planes([], G, [])
../../_images/examples_custom_operations_numba_function_25_0.png
[18]:
show_images_rgb_planes([], [], B)
../../_images/examples_custom_operations_numba_function_26_0.png