nvidia.dali.plugin.jax.fn.jax_function#

nvidia.dali.plugin.jax.fn.jax_function(function=None, num_outputs=1, output_layouts=None, sharding=None, device=None, preserve=True)#

转换 Python 函数 function,该函数处理 jax.Array 对象到可以在 DALI pipeline 定义或 JAX 插件迭代器定义中使用的 DALI 操作符。转换后的函数接受并返回与原始 function 相同数量的输入和输出,但更改了它们的类型:从 jax.Array 到 DALI 跟踪的 DataNodes。结果函数可与其他 DALI 操作符互操作。

例如,我们可以在 JAX 中实现水平翻转操作,如下所示

import jax
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin import jax as dax

@dax.fn.jax_function
def flip_horizontal(image_batch: jax.Array):
    return image_batch[:, :, ::-1, :]  # batch of HWC images

@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
    image, _ = fn.readers.file(file_root=jpeg_path_dali_extra)
    image = fn.decoders.image(image, device="mixed", output_type=types.RGB)
    image = fn.resize(image, size=(244, 244))
    flipped = flip_horizontal(image)
    return image, flipped

可以使用常用的 JAX 转换来转换 function,例如,我们可以利用 JAX 的即时编译和向量化,在上面的示例中添加适当的装饰器

@dax.fn.jax_function
@jax.jit
@jax.vmap
def flip_horizontal(image: jax.Array):
    return image[:, ::-1, :]  # HWC image

如果结果函数与 DALI GPU 批次一起运行,则内部 DALI 和 JAX 流将同步。JAX 操作不需要用户进一步同步。

function 完成后,不得访问传递给 functionjax.Arrays(例如,它们不应存储在某些非局部作用域中)。

注意

这是实验性 API,将来版本可能会更改。

注意

jax_function 需要 JAX 版本 0.4.16 或更高版本,并具有 GPU 支持。JAX 0.4.16 需要 Python 3.9 或更高版本。

参数:
  • function (JaxCallback) – 接受并返回零个或多个 jax.Array 对象的 Python 回调。该函数将接收 DALI 处理的批次,作为 jax.Array 张量(最左边的范围对应于 DALI 批次)。因此,转换后的函数只能接收包含形状统一的样本的 DALI 批次。

  • num_outputs (int, default=1) –

    function 返回的输出数量。

    函数可以不返回任何输出,在这种情况下,num_outputs 必须设置为 0。如果 num_outputs 为 1(默认值),则回调应返回单个 JAX 数组;对于 num_outputs > 1,回调应返回 JAX 数组的元组。

  • output_layouts (Union[str, Tuple[str]], optional) –

    返回的张量的布局。

    它可以是所有 num_outputs 各自输出的字符串列表,也可以是设置为所有输出的单个字符串。

    请注意,在 DALI 中,最外层的批次范围是隐式的,布局应仅考虑样本维度。

    如果未指定参数,并且 function 的第 i 个输出与第 i 个输入具有相同的维度,则布局将从输入传播到相应的输出。

  • sharding (jax.sharding.Sharding, optional) –

    JAX 分片对象(PositionalShardingNamedSharding)。如果指定,则传递给 functionjax.Arrays 将是全局 jax.Array,能够感知分片。

    注意

    目前,仅支持全局分片,即给定进程中的本地设备数量必须恰好为一个。

  • device (str, optional) – “cpu”、“gpu” 或 None 之一。转换后的函数的所有 DALI 输入和输出将放置在其上的设备类型。如果未指定,将根据传递给结果函数的 DALI 输入推断设备。目前,所有输入和输出的设备类型必须相同。

  • preserve (bool, default=True) – 如果设置为 False,则返回的 DALI 函数可能会从 DALI pipeline 中优化掉,如果它不返回任何输出,或者函数输出均未对 pipeline 的输出做出贡献。

返回:

处理 DALI 跟踪批次 (DataNodes) 的转换函数。

返回类型:

DaliCallback