JAX 插件 API 参考#

class nvidia.dali.plugin.jax.DALIGenericIterator(pipelines, output_map, size=-1, reader_name=None, auto_reset=False, last_batch_padded=False, last_batch_policy=LastBatchPolicy.FILL, prepare_first_batch=True, sharding=None)#

用于 JAX 的通用 DALI 迭代器。它可以从 DALI pipeline 返回任意数量的输出,形式为 JAX 数组。

参数:
  • pipelines (list of Pipeline) – 要使用的 pipeline 列表

  • output_map (list of str) – 字符串列表,将 DALI pipeline 的连续输出映射到用户指定的名称。输出将以这些名称的字典形式从迭代器返回。每个名称应distinct

  • size (int, default = -1) – 包装的 pipeline 的分片中的样本数(如果有多于一个,则是总和)提供 -1 意味着迭代器将工作到从 iter_setup() 内部引发 StopIteration 为止。选项 last_batch_policylast_batch_padded 在这种情况下不起作用。它仅适用于迭代器内部的一个 pipeline。与 reader_name 参数互斥

  • reader_name (str, default = None) – 读取器的名称,将查询该读取器以获取分片大小、分片数和所有其他属性,这些属性对于正确计数迭代器需要处理的相关样本和填充样本的数量是必要的。它自动设置 last_batch_padded 以匹配读取器的配置。

  • auto_reset (string or bool, optional, default = False) –

    迭代器是否为下一个 epoch 自动重置自身,还是需要显式调用 reset()。

    它可以是以下值之一

    • "no", FalseNone - 在 epoch 结束时引发 StopIteration

    并且需要调用 reset() * "yes""True"- 在 epoch 结束时引发 StopIteration,但 reset() 在内部自动调用。

  • last_batch_policy (optional, default = LastBatchPolicy.FILL) – 当 epoch 中没有足够的样本来完全填充最后一个批次时,该如何处理。请参阅 nvidia.dali.plugin.base_iterator.LastBatchPolicy()。JAX 迭代器不支持 LastBatchPolicy.PARTIAL

  • last_batch_padded (bool, optional, default = False) – DALI 提供的最后一个批次是否用最后一个样本填充,或者只是包装起来。与 last_batch_policy 结合使用时,它会告知迭代器返回的最后一个批次是否仅部分填充了当前 epoch 的数据,是丢弃填充样本还是来自下一个 epoch 的样本。如果设置为 False,则下一个 epoch 将提前结束,因为来自它的数据已被消耗但被丢弃。如果设置为 True,则下一个 epoch 的长度将与第一个 epoch 相同。为此,还需要将读取器中的选项 pad_last_batch 设置为 True。当提供 reader_name 参数时,它将被覆盖

  • prepare_first_batch (bool, optional, default = True) – DALI 是否应在创建迭代器后立即缓冲第一个批次,以便在提示迭代器获取数据时,已经准备好一个批次

  • sharding (jax.sharding.Sharding) – jax.sharding.Sharding 兼容对象,如果存在,将用于为每个类别构建输出 jax.Array。如果 None,则如果提供了多个 pipeline,则迭代器返回与 pmapped JAX 函数兼容的值。

示例

对于数据集 [1,2,3,4,5,6,7] 和批次大小 2

last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True -> 最后一个批次 = [7, 7], 下一次迭代将返回 [1, 2]

last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False -> 最后一个批次 = [7, 1], 下一次迭代将返回 [2, 3]

last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True -> 最后一个批次 = [5, 6], 下一次迭代将返回 [1, 2]

last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False -> 最后一个批次 = [5, 6], 下一次迭代将返回 [2, 3]

注意

JAX 迭代器不支持 LastBatchPolicy.PARTIAL。

checkpoints()#

返回 pipeline 的当前检查点。

next()#

返回下一批数据。

reset()#

在完整 epoch 后重置迭代器。DALI 迭代器不支持在 epoch 结束前重置,并将忽略此类请求。

property size#
nvidia.dali.plugin.jax.data_iterator(pipeline_fn=None, output_map=[], size=-1, reader_name=None, auto_reset=False, last_batch_padded=False, last_batch_policy=LastBatchPolicy.FILL, prepare_first_batch=True, sharding=None, devices=None)#

用于 JAX 的 DALI 迭代器的装饰器。当调用装饰函数时,返回用于 JAX 的 DALI 迭代器。

装饰函数应返回 DALI pipeline 定义函数。装饰器接受 nvidia.dali.plugin.base_iterator.DALIGenericIterator.__init__() 的所有参数,并将它们传递给迭代器构造函数。如果未将 device_id 参数传递给装饰函数,则假定第一个设备是我们想要使用的设备,并且 device_id 设置为 0。如果相同的参数传递给装饰器和装饰函数,则会引发异常。

参数:
  • function (pipeline_fn) – 要装饰的函数。它应与 nvidia.dali.pipeline.pipeline_def() 装饰器兼容。对于多 GPU 支持,它应接受 device_idshard_idnum_shards 参数。

  • output_map (list of str) – 字符串列表,将 DALI pipeline 的连续输出映射到用户指定的名称。输出将以这些名称的字典形式从迭代器返回。每个名称应distinct

  • size (int, default = -1) – 包装的 pipeline 的分片中的样本数(如果有多于一个,则是总和)提供 -1 意味着迭代器将工作到从 iter_setup() 内部引发 StopIteration 为止。选项 last_batch_policylast_batch_padded 在这种情况下不起作用。它仅适用于迭代器内部的一个 pipeline。与 reader_name 参数互斥

  • reader_name (str, default = None) – 读取器的名称,将查询该读取器以获取分片大小、分片数和所有其他属性,这些属性对于正确计数迭代器需要处理的相关样本和填充样本的数量是必要的。它自动设置 last_batch_padded 以匹配读取器的配置。

  • auto_reset (string or bool, optional, default = False) –

    迭代器是否为下一个 epoch 自动重置自身,还是需要显式调用 reset()。

    它可以是以下值之一

    • "no", FalseNone - 在 epoch 结束时引发 StopIteration

    并且需要调用 reset() * "yes""True"- 在 epoch 结束时引发 StopIteration,但 reset() 在内部自动调用。

  • last_batch_policy (optional, default = LastBatchPolicy.FILL) – 当 epoch 中没有足够的样本来完全填充最后一个批次时,该如何处理。请参阅 nvidia.dali.plugin.base_iterator.LastBatchPolicy()。JAX 迭代器不支持 LastBatchPolicy.PARTIAL

  • last_batch_padded (bool, optional, default = False) – DALI 提供的最后一个批次是否用最后一个样本填充,或者只是包装起来。与 last_batch_policy 结合使用时,它会告知迭代器返回的最后一个批次是否仅部分填充了当前 epoch 的数据,是丢弃填充样本还是来自下一个 epoch 的样本。如果设置为 False,则下一个 epoch 将提前结束,因为来自它的数据已被消耗但被丢弃。如果设置为 True,则下一个 epoch 的长度将与第一个 epoch 相同。为此,还需要将读取器中的选项 pad_last_batch 设置为 True。当提供 reader_name 参数时,它将被覆盖

  • prepare_first_batch (bool, optional, default = True) – DALI 是否应在创建迭代器后立即缓冲第一个批次,以便在提示迭代器获取数据时,已经准备好一个批次

  • sharding (jax.sharding.Sharding) – jax.sharding.Sharding 兼容对象,如果存在,将用于为每个类别构建输出 jax.Array。迭代器将返回与 JAX 中的自动并行化兼容的输出。此参数与 devices 参数互斥。如果提供了 devices,则应将 sharding 设置为 None。

  • devices (list of jax.Device) – 用于并行运行 pipeline 的 JAX 设备列表。迭代器将返回与 pmapped JAX 函数兼容的输出。此参数与 sharding 参数互斥。如果提供了 sharding,则应将 devices 设置为 None。

  • checkpoints (list of str, optional, default = None) – 使用迭代器的 .checkpoints() 方法获得的检查点。如果提供,它们将用于恢复 pipeline 的状态。

示例

对于数据集 [1,2,3,4,5,6,7] 和批次大小 2

last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True -> 最后一个批次 = [7, 7], 下一次迭代将返回 [1, 2]

last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False -> 最后一个批次 = [7, 1], 下一次迭代将返回 [2, 3]

last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True -> 最后一个批次 = [5, 6], 下一次迭代将返回 [1, 2]

last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False -> 最后一个批次 = [5, 6], 下一次迭代将返回 [2, 3]

注意

JAX 迭代器不支持 LastBatchPolicy.PARTIAL。