使用 DALI 和 Flax 训练神经网络#

这个简单的例子展示了如何使用 DALI pipelines 训练在 Flax 中实现的神经网络。如果您想了解更多关于使用 Flax 训练神经网络的信息,请查看 Flax 入门 示例。

DALI 设置与使用纯 JAX 的训练示例非常相似。唯一的区别是在返回的图像中添加了尾部维度,以使其与 Flax 卷积兼容。如果您不熟悉如何在 JAX 中使用 DALI,您可以在DALI 和 JAX 入门示例中了解更多信息。

我们使用来自 DALI_extra 的 Caffe2 格式的 MNIST。

[1]:
import os

training_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)

第一步是创建一个迭代器定义函数,该函数稍后将用于创建 DALI 迭代器的实例。它定义了预处理的所有步骤。在这个简单的例子中,我们有 fn.readers.caffe2 用于读取 Caffe2 格式的数据,fn.decoders.image 用于图像解码,fn.crop_mirror_normalize 用于标准化图像,以及 fn.reshape 用于调整输出张量的形状。我们还使用 labels.gpu() 将标签从 CPU 移动到 GPU 内存,并应用 one-hot 编码以用于使用 fn.one_hot 进行训练。

本示例重点介绍如何在 JAX 中使用 DALI pipeline。有关 DALI 迭代器的更多信息,请查看DALI 和 JAX 入门pipeline 文档

[2]:
import nvidia.dali.fn as fn
import nvidia.dali.types as types

from nvidia.dali.plugin.jax import data_iterator


batch_size = 50
image_size = 28
num_classes = 10


@data_iterator(
    output_map=["images", "labels"], reader_name="mnist_caffe2_reader"
)
def mnist_iterator(data_path, is_training):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=is_training, name="mnist_caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0])
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    if is_training:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

使用迭代器定义函数,我们现在可以创建 DALI 迭代器。

[3]:
print("Creating iterators")
training_iterator = mnist_iterator(
    data_path=training_data_path, is_training=True, batch_size=batch_size
)
validation_iterator = mnist_iterator(
    data_path=validation_data_path, is_training=False, batch_size=batch_size
)

print(training_iterator)
print(validation_iterator)

print(f"Number of batches in training iterator = {len(training_iterator)}")
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc240f4e50>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc1c78e020>
Number of batches in training iterator = 1200
Number of batches in validation iterator = 200

通过上述设置,DALI 迭代器已准备好进行训练。

现在我们需要设置模型和训练实用程序。本笔记本的目标不是解释 Flax 概念。我们想展示如何使用 DALI 作为数据加载和预处理库来训练在 Flax 中实现的模型。我们使用标准的 Flax 工具来定义简单的神经网络。我们有函数来创建此网络的实例,在其上运行一个训练步骤,并在每个 epoch 结束时计算模型的准确率。

如果您想了解更多关于 Flax 的信息,并更好地理解下面的代码,请查看 Flax 文档

[4]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=784)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


def create_model_state(rng, learning_rate, momentum):
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([784]))["params"]
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx
    )


@jax.jit
def train_step(model_state, batch):
    def loss_fn(params):
        logits = model_state.apply_fn({"params": params}, batch["images"])
        loss = optax.softmax_cross_entropy(
            logits=logits, labels=batch["labels"]
        ).mean()
        return loss

    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state


def accuracy(model_state, iterator):
    correct_predictions = 0
    for batch in iterator:
        logits = model_state.apply_fn(
            {"params": model_state.params}, batch["images"]
        )
        correct_predictions = correct_predictions + jnp.sum(
            batch["labels"].ravel() == jnp.argmax(logits, axis=-1)
        )

    return correct_predictions / iterator.size

使用上面定义的实用程序,我们可以创建我们想要训练的模型的实例。

[5]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = create_model_state(init_rng, learning_rate, momentum)

此时,一切都已准备好运行训练。

[6]:
print("Starting training")

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = train_step(model_state, batch)

    acc = accuracy(model_state, validation_iterator)
    print(f"Accuracy = {acc}")
Starting training
Epoch 0
Accuracy = 0.9551000595092773
Epoch 1
Accuracy = 0.9691000580787659
Epoch 2
Accuracy = 0.9738000631332397
Epoch 3
Accuracy = 0.9622000455856323
Epoch 4
Accuracy = 0.9604000449180603

带有 DALI 和 FLAX 的多 GPU#

本节介绍如何扩展上面的示例以使用多个 GPU。

同样,我们从创建一个迭代器定义函数开始。它是我们之前看到的函数的略微修改版本。

注意传递给 fn.readers.caffe2 的新参数,num_shardsshard_id。它们用于控制分片

  • num_shards 设置分片总数

  • shard_id 告诉 pipeline 它负责训练中的哪个分片。

我们将 devices 参数添加到装饰器,以指定我们要使用的设备。这里我们使用机器上 JAX 可用的所有 GPU。

[7]:
batch_size = 200
image_size = 28
num_classes = 10


@data_iterator(
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    devices=jax.devices(),
)
def mnist_sharded_iterator(data_path, is_training, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=is_training,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    if is_training:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

使用迭代器定义函数,我们现在可以创建用于在多个 GPU 上进行训练的 DALI 迭代器。此迭代器将返回与 pmapped JAX 函数兼容的输出。

[8]:
print("Creating training iterator")
training_iterator = mnist_sharded_iterator(
    data_path=training_data_path, is_training=True, batch_size=batch_size
)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300

为了简单起见,我们将在一个 GPU 上运行验证。我们可以重用来自单 GPU 示例的验证迭代器。唯一的区别是我们需要将模型拉到同一个 GPU。在现实场景中,这可能会很昂贵,但对于这个玩具教育示例来说已经足够了。

为了使模型与 pmap 风格的多 GPU 训练兼容,我们需要复制它。如果您想了解更多关于使用 pmap 在多个 GPU 上进行训练的信息,您可以查看 JAX 文档中的 JAX 中的并行评估 和 Flax 文档中的 在多个设备上集成

[9]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = jax.pmap(create_model_state, static_broadcasted_argnums=(1, 2))(
    jax.random.split(init_rng, jax.device_count()), learning_rate, momentum
)

由于我们想在单个 GPU 上运行验证,我们只提取模型的一个副本并将其传递给 accuracy 函数。

现在,我们准备好使用 DALI 作为数据源在多个 GPU 上训练 Flax 模型。

[10]:
import flax

parallel_train_step = jax.pmap(train_step)

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = parallel_train_step(model_state, batch)

    acc = accuracy(flax.jax_utils.unreplicate(model_state), validation_iterator)
    print(f"Accuracy = {acc}")
Epoch 0
Accuracy = 0.9509000182151794
Epoch 1
Accuracy = 0.9643000364303589
Epoch 2
Accuracy = 0.9724000692367554
Epoch 3
Accuracy = 0.9701000452041626
Epoch 4
Accuracy = 0.9758000373840332