检查点#
检查点是 DALI 中的一项功能,允许您将 pipeline 的当前状态保存到文件中。然后,您可以从已保存的检查点恢复 pipeline,新的 pipeline 将产生与旧 pipeline 完全相同的输出。这对于可能被中断的长时间运行的训练作业特别有用。
DALI pipeline 的检查点包含有关 pipeline 中使用的所有随机数生成器的状态以及每个 reader 的进度的信息。
检查点 API#
启用检查点#
要启用检查点,在创建 pipeline 时设置 enable_checkpointing=True
。启用此选项后,DALI 将跟踪每个操作符的状态,允许您按需保存它。启用检查点不应影响性能。
@pipeline_def(..., enable_checkpointing=True)
def pipeline():
...
p = pipeline()
注意
如果启用了检查点,shuffle_after_epoch=True
的 Readers 可能会以不同的方式打乱样本。
保存检查点#
要保存检查点,您需要调用 Pipeline.checkpoint()
方法,该方法将返回序列化的检查点作为字符串。可选地,您可以将文件名作为参数传递,DALI 将检查点保存在那里。
for _ in range(iters):
output = p.run()
# Write the checkpoint to file:
checkpoint = p.checkpoint()
open('checkpoint_file.cpt', 'wb')
# Or simply:
checkpoint = p.checkpoint('checkpoint_file.cpt')
注意
调用 Pipeline.checkpoint()
方法可能会引入可观察到的开销。我们建议您不要过于频繁地调用它。
从检查点恢复#
您可以稍后从已保存的检查点恢复 pipeline 状态。为此,在构造时将 checkpoint 参数传递给 Pipeline
。这样的 pipeline 应该返回与原始 pipeline 完全相同的输出。
checkpoint = open('checkpoint_file.cpt', 'rb').read()
p_restored = pipeline(checkpoint=checkpoint)
警告
确保您要恢复的 pipeline 与原始 pipeline 相同,即包含相同的操作符和相同的参数。从使用不同 pipeline 创建的检查点恢复将导致未定义的行为。
External source 检查点#
fn.external_source()
操作符仅部分支持检查点。
仅当 source
是接受批次索引、BatchInfo
或 SampleInfo
的单参数可调用对象时,才支持检查点。对于此类 sources
,查询将从检查点中保存的点继续。
其他类型的 source
不支持检查点。它们的状态不会保存在检查点中,并且从检查点恢复后,它们将从头开始。如果您想使用检查点,我们建议您将 source 重写为受支持的可调用对象。
TensorFlow 插件中的检查点#
plugin.tf.DALIDataset
与 TensorFlow 的 tf.train.checkpoint
集成。有关更多详细信息,请参阅 TensorFlow 检查点文档页面。
警告
当前 plugin.tf.experimental.DALIDatasetWithInputs
不支持检查点。
警告
当前 GPU 数据集不支持检查点。