nvidia.dali.plugin.jax.fn.jax_function#
- nvidia.dali.plugin.jax.fn.jax_function(function=None, num_outputs=1, output_layouts=None, sharding=None, device=None, preserve=True)#
转换 Python 函数
function
,该函数处理jax.Array
对象到 DALI 操作符,该操作符可在 DALI pipeline 定义或 JAX 插件迭代器定义中使用。转换后的函数接受并返回与原始function
相同数量的输入和输出,但更改了它们的类型:从jax.Array
到 DALI 跟踪的DataNodes
。结果函数可与其他 DALI 操作符互操作。例如,我们可以在 JAX 中实现水平翻转操作,如下所示
import jax from nvidia.dali import pipeline_def, fn, types from nvidia.dali.plugin import jax as dax @dax.fn.jax_function def flip_horizontal(image_batch: jax.Array): return image_batch[:, :, ::-1, :] # batch of HWC images @pipeline_def(batch_size=4, device_id=0, num_threads=4) def pipeline(): image, _ = fn.readers.file(file_root=jpeg_path_dali_extra) image = fn.decoders.image(image, device="mixed", output_type=types.RGB) image = fn.resize(image, size=(244, 244)) flipped = flip_horizontal(image) return image, flipped
可以使用常用的 JAX 转换来转换
function
,例如,我们可以利用 JAX 的即时编译和向量化,在上面的示例中添加适当的装饰器@dax.fn.jax_function @jax.jit @jax.vmap def flip_horizontal(image: jax.Array): return image[:, ::-1, :] # HWC image
如果结果函数与 DALI GPU 批次一起运行,则内部 DALI 和 JAX 流将同步。JAX 操作无需用户进一步同步。
在
function
完成后,不得访问传递给function
的jax.Arrays
(例如,不应将其存储在某些非局部作用域中)。注意
这是一个实验性 API,未来版本可能会更改。
注意
jax_function 需要 JAX 版本 0.4.16 或更高版本,并具有 GPU 支持。JAX 0.4.16 需要 Python 3.9 或更高版本。
- 参数:
function¶ (JaxCallback) – 接受并返回零个或多个 jax.Array 对象的 Python 回调。该函数将接收 DALI 处理的批次,作为 jax.Array 张量(最左侧的范围对应于 DALI 批次)。因此,转换后的函数只能接收包含形状一致的样本的 DALI 批次。
num_outputs¶ (int, default=1) –
由
function
返回的输出数量。函数可以不返回任何输出,在这种情况下,
num_outputs
必须设置为 0。如果num_outputs
为 1(默认值),则回调应返回单个 JAX 数组;对于num_outputs
> 1,回调应返回 JAX 数组的元组。output_layouts¶ (Union[str, Tuple[str]], optional) –
返回张量的布局。
它可以是
num_outputs
各个输出的所有字符串列表,也可以是设置为所有输出的单个字符串。请注意,在 DALI 中,最外层的批次范围是隐式的,布局应仅考虑样本维度。
如果未指定参数,并且
function
的第 i 个输出与第 i 个输入具有相同的维度,则布局将从输入传播到相应的输出。sharding¶ (jax.sharding.Sharding, optional) –
JAX 分片对象(
PositionalSharding
或NamedSharding
)。如果指定,则传递给function
的jax.Arrays
将是全局jax.Array
,了解分片。注意
目前,仅支持全局分片,即给定进程中的本地设备数量必须正好为一个。
device¶ (str, optional) – “cpu”、“gpu” 或 None 之一。转换后的函数的所有 DALI 输入和输出将放置的设备类型。如果未指定,将根据传递给结果函数的 DALI 输入推断设备。目前,所有输入和输出的设备类型必须相同。
preserve¶ (bool, default=True) – 如果设置为 False,则返回的 DALI 函数可能会从 DALI pipeline 中优化掉,如果它不返回任何输出或函数输出均不贡献于 pipeline 的输出。
- 返回:
处理 DALI 跟踪批次 (DataNodes) 的转换函数。
- 返回类型:
DaliCallback