检查点#

检查点是 DALI 中的一项功能,允许您将 pipeline 的当前状态保存到文件中。然后,您可以从保存的检查点恢复 pipeline,新的 pipeline 将产生与旧 pipeline 完全相同的输出。对于可能被中断的长时间运行的训练作业,它特别有用。

DALI pipeline 的检查点包含有关 pipeline 中使用的所有随机数生成器的状态以及每个读取器的进度的信息。

检查点 API#

启用检查点#

要启用检查点,请在创建 pipeline 时设置 enable_checkpointing=True。启用此选项后,DALI 将跟踪每个操作符的状态,允许您按需保存它。启用检查点不应影响性能。

@pipeline_def(..., enable_checkpointing=True)
def pipeline():
    ...

p = pipeline()

注意

如果启用了检查点,shuffle_after_epoch=True 的读取器可能会以不同的方式打乱样本。

保存检查点#

要保存检查点,您需要调用 Pipeline.checkpoint() 方法,该方法将返回序列化的检查点作为字符串。可选地,您可以将文件名作为参数传递,DALI 会将检查点保存在那里。

for _ in range(iters):
    output = p.run()

# Write the checkpoint to file:
checkpoint = p.checkpoint()
open('checkpoint_file.cpt', 'wb')

# Or simply:
checkpoint = p.checkpoint('checkpoint_file.cpt')

注意

调用 Pipeline.checkpoint() 方法可能会引入可观察到的开销。我们建议您不要过于频繁地调用它。

从检查点恢复#

您可以稍后从保存的检查点恢复 pipeline 状态。为此,请在构造 Pipeline 时将 checkpoint 参数传递给它。这样的 pipeline 应该返回与原始 pipeline 完全相同的输出。

checkpoint = open('checkpoint_file.cpt', 'rb').read()
p_restored = pipeline(checkpoint=checkpoint)

警告

请确保您要恢复的 pipeline 与原始 pipeline 相同,即包含相同的操作符和相同的参数。从使用不同 pipeline 创建的检查点恢复将导致未定义的行为。

外部源检查点#

fn.external_source() 操作符仅部分支持检查点。

仅当 source 是接受批次索引、BatchInfoSampleInfo 的单参数可调用对象时,才支持检查点。对于此类 sources,查询将从检查点中保存的点继续。

其他类型的 source 不支持检查点。它们的状态不会保存在检查点中,并且从检查点恢复后,它们将从头开始。如果您想使用检查点,我们建议您将您的源重写为受支持的可调用对象。

TensorFlow 插件中的检查点#

plugin.tf.DALIDataset 与 TensorFlow 的 tf.train.checkpoint 集成。有关更多详细信息,请参阅 TensorFlow 检查点文档页面

警告

目前不支持 plugin.tf.experimental.DALIDatasetWithInputs 的检查点。

警告

目前不支持 GPU 数据集的检查点。