使用 DALI 和 Paxml 训练神经网络#

这个简单的例子展示了如何使用 DALI 数据预处理来训练在 Paxml 中实现的神经网络。它基于 Paxml 代码库中的 MNIST 训练示例,可以在这里找到。

我们使用来自 DALI_extra 的 Caffe2 格式的 MNIST 作为数据源。

[1]:
import os

training_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)

第一步是创建迭代器定义函数,该函数稍后将用于创建 DALI 迭代器的实例。它定义了预处理的所有步骤。在这个简单的例子中,我们有 fn.readers.caffe2 用于读取 Caffe2 格式的数据,fn.decoders.image 用于图像解码,fn.crop_mirror_normalize 用于归一化图像,以及 fn.reshape 用于调整输出张量的形状。

此示例重点介绍如何将 DALI pipeline 与 Paxml 一起使用。有关编写 DALI 迭代器的更多信息,请查看 DALI 和 JAX 入门pipeline 文档。要了解有关 Paxml 以及如何使用它编写神经网络的更多信息,请查看 Paxml Github 页面

[2]:
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import data_iterator


@data_iterator(
    output_map=["inputs", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True,
)
def mnist_iterator(data_path, random_shuffle):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader",
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="HWC"
    )

    labels = labels.gpu()
    labels = fn.reshape(labels, shape=[])

    return images, labels

此示例使用 Praxis 中定义的 Pax 数据输入。我们将创建一个简单的包装器,该包装器使用 DALI 迭代器 for JAX 作为数据源。

[3]:
from praxis import base_input
from nvidia.dali.plugin import jax as dax


class MnistDaliInput(base_input.BaseInput):
    def __post_init__(self):
        super().__post_init__()

        data_path = (
            training_data_path if self.is_training else validation_data_path
        )

        training_pipeline = mnist_iterator(
            data_path=data_path,
            random_shuffle=self.is_training,
            batch_size=self.batch_size,
        )
        self._iterator = dax.DALIGenericIterator(
            training_pipeline,
            output_map=["inputs", "labels"],
            reader_name="mnist_caffe2_reader",
            auto_reset=True,
        )

    def get_next(self):
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator.reset()
            return next(self._iterator)

    def reset(self) -> None:
        super().reset()
        self._iterator = self._iterator.reset()

MnistDaliInput 可以在 Pax Experiment 中用作数据源。下面的代码示例展示了如何通过定义 Experiment 类的 datasets 方法来连接这两个类。

def datasets(self) -> list[pax_fiddle.Config[base_input.BaseInput]]:
  return [
      pax_fiddle.Config(
          MnistDaliInput, batch_size=self.BATCH_SIZE, is_training=True
      )
  ]

有关完整的可运行示例,您可以查看 docs/examples/frameworks/jax/pax_examples。可以通过运行以下命令来测试此文件夹中的代码。

[4]:
!python -m paxml.main --job_log_dir=/tmp/dali_pax_logs --exp pax_examples.dali_pax_example.MnistExperiment 2>/dev/null

它在 /tmp/dali_pax_logs 中生成与 tensorboard 兼容的日志。我们使用一个辅助函数从日志中读取训练准确率并在终端中打印出来。

[5]:
from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.framework import tensor_util


def print_logs(path):
    "Helper function to print logs from logs directory created by paxml example"

    def summary_iterator():
        for r in tf_record.tf_record_iterator(path):
            yield event_pb2.Event.FromString(r)

    for summary in summary_iterator():
        for value in summary.summary.value:
            if value.tag == "Metrics/accuracy":
                t = tensor_util.MakeNdarray(value.tensor)
                print(f"Iteration: {summary.step}, accuracy: {t}")

使用此辅助函数,我们可以在 Python 代码内部打印训练的准确率。

[6]:
for file in os.listdir("/tmp/dali_pax_logs/summaries/train/"):
    print_logs(os.path.join("/tmp/dali_pax_logs/summaries/train/", file))
Iteration: 100, accuracy: 0.3935546875
Iteration: 200, accuracy: 0.5634765625
Iteration: 300, accuracy: 0.728515625
Iteration: 400, accuracy: 0.8369140625
Iteration: 500, accuracy: 0.87109375
Iteration: 600, accuracy: 0.87890625
Iteration: 700, accuracy: 0.884765625
Iteration: 800, accuracy: 0.8994140625
Iteration: 900, accuracy: 0.8994140625
Iteration: 1000, accuracy: 0.90625