使用 DALI 和 JAX 训练神经网络#

这个简单的示例展示了如何使用 DALI pipelines 训练在 JAX 中实现的神经网络。它基于 JAX 代码库中的 MNIST 训练示例,可以在这里找到。

我们将使用来自 DALI_extra 的 Caffe2 格式的 MNIST。

[1]:
import os

training_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)

第一步是创建一个定义函数,该函数稍后将用于创建 DALI 迭代器的实例。它定义了预处理的所有步骤。

在这个简单的示例中,我们有 fn.readers.caffe2 用于读取 Caffe2 格式的数据,fn.decoders.image 用于图像解码,fn.crop_mirror_normalize 用于归一化图像,以及 fn.reshape 用于调整输出张量的形状。我们还使用 labels.gpu() 将标签从 CPU 移动到 GPU 内存。我们的模型期望标签采用 one-hot 编码,因此我们使用 fn.one_hot 来转换它们。

此示例重点介绍如何使用 DALI 来训练在 JAX 中定义的模型。有关 DALI 和 JAX 集成的更多信息,请查看JAX 和 DALI 入门pipeline 文档

[2]:
from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types


batch_size = 200
image_size = 28
num_classes = 10


@data_iterator(output_map=["images", "labels"], reader_name="caffe2_reader")
def mnist_iterator(data_path, random_shuffle):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=random_shuffle, name="caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

接下来,我们使用该函数为训练和验证创建 DALI 迭代器。

[3]:
print("Creating iterators")

training_iterator = mnist_iterator(
    data_path=training_data_path, random_shuffle=True, batch_size=batch_size
)

validation_iterator = mnist_iterator(
    data_path=validation_data_path, random_shuffle=False, batch_size=batch_size
)

print(training_iterator)
print(validation_iterator)
Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>

通过上述设置,DALI 迭代器已准备好进行训练。

最后,我们导入在 JAX 中实现的训练实用程序。init_model 将创建模型实例并初始化其参数。在这个简单的示例中,它是一个具有两个隐藏层的 MLP 模型。update 执行训练的一次迭代。accuracy 是一个辅助函数,用于在每个 epoch 后在测试集上运行验证并获取模型当前的准确率。

[4]:
from model import init_model, update, accuracy

此时,一切都已准备好运行训练。

[5]:
print("Starting training")

model = init_model()
num_epochs = 5

for epoch in range(num_epochs):
    for batch in training_iterator:
        model = update(model, batch)

    test_acc = accuracy(model, validation_iterator)
    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Starting training
Epoch 0 sec
Test set accuracy 0.67330002784729
Epoch 1 sec
Test set accuracy 0.7855000495910645
Epoch 2 sec
Test set accuracy 0.8251000642776489
Epoch 3 sec
Test set accuracy 0.8469000458717346
Epoch 4 sec
Test set accuracy 0.8616000413894653