使用 DALI 和 Flax 训练神经网络#
这个简单的例子展示了如何使用 DALI pipelines 训练在 Flax 中实现的神经网络。如果您想了解更多关于使用 Flax 训练神经网络的信息,请查看 Flax 入门 示例。
DALI 设置与使用纯 JAX 的训练示例非常相似。唯一的区别是在返回的图像中添加了尾部维度,以使其与 Flax 卷积兼容。如果您不熟悉如何在 JAX 中使用 DALI,您可以在DALI 和 JAX 入门示例中了解更多信息。
我们使用来自 DALI_extra 的 Caffe2 格式的 MNIST。
[1]:
import os
training_data_path = os.path.join(
os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)
第一步是创建一个迭代器定义函数,该函数稍后将用于创建 DALI 迭代器的实例。它定义了预处理的所有步骤。在这个简单的例子中,我们有 fn.readers.caffe2
用于读取 Caffe2 格式的数据,fn.decoders.image
用于图像解码,fn.crop_mirror_normalize
用于标准化图像,以及 fn.reshape
用于调整输出张量的形状。我们还使用 labels.gpu()
将标签从 CPU 移动到 GPU 内存,并应用 one-hot 编码以用于使用 fn.one_hot
进行训练。
本示例重点介绍如何在 JAX 中使用 DALI pipeline。有关 DALI 迭代器的更多信息,请查看DALI 和 JAX 入门和pipeline 文档
[2]:
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import data_iterator
batch_size = 50
image_size = 28
num_classes = 10
@data_iterator(
output_map=["images", "labels"], reader_name="mnist_caffe2_reader"
)
def mnist_iterator(data_path, is_training):
jpegs, labels = fn.readers.caffe2(
path=data_path, random_shuffle=is_training, name="mnist_caffe2_reader"
)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0])
images = fn.reshape(images, shape=[-1]) # Flatten the output image
labels = labels.gpu()
if is_training:
labels = fn.one_hot(labels, num_classes=num_classes)
return images, labels
使用迭代器定义函数,我们现在可以创建 DALI 迭代器。
[3]:
print("Creating iterators")
training_iterator = mnist_iterator(
data_path=training_data_path, is_training=True, batch_size=batch_size
)
validation_iterator = mnist_iterator(
data_path=validation_data_path, is_training=False, batch_size=batch_size
)
print(training_iterator)
print(validation_iterator)
print(f"Number of batches in training iterator = {len(training_iterator)}")
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc240f4e50>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc1c78e020>
Number of batches in training iterator = 1200
Number of batches in validation iterator = 200
通过上述设置,DALI 迭代器已准备好进行训练。
现在我们需要设置模型和训练实用程序。本笔记本的目标不是解释 Flax 概念。我们想展示如何使用 DALI 作为数据加载和预处理库来训练在 Flax 中实现的模型。我们使用标准的 Flax 工具来定义简单的神经网络。我们有函数来创建此网络的实例,在其上运行一个训练步骤,并在每个 epoch 结束时计算模型的准确率。
如果您想了解更多关于 Flax 的信息,并更好地理解下面的代码,请查看 Flax 文档。
[4]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=784)(x)
x = nn.relu(x)
x = nn.Dense(features=1024)(x)
x = nn.relu(x)
x = nn.Dense(features=1024)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
def create_model_state(rng, learning_rate, momentum):
cnn = CNN()
params = cnn.init(rng, jnp.ones([784]))["params"]
tx = optax.sgd(learning_rate, momentum)
return train_state.TrainState.create(
apply_fn=cnn.apply, params=params, tx=tx
)
@jax.jit
def train_step(model_state, batch):
def loss_fn(params):
logits = model_state.apply_fn({"params": params}, batch["images"])
loss = optax.softmax_cross_entropy(
logits=logits, labels=batch["labels"]
).mean()
return loss
grad_fn = jax.grad(loss_fn)
grads = grad_fn(model_state.params)
model_state = model_state.apply_gradients(grads=grads)
return model_state
def accuracy(model_state, iterator):
correct_predictions = 0
for batch in iterator:
logits = model_state.apply_fn(
{"params": model_state.params}, batch["images"]
)
correct_predictions = correct_predictions + jnp.sum(
batch["labels"].ravel() == jnp.argmax(logits, axis=-1)
)
return correct_predictions / iterator.size
使用上面定义的实用程序,我们可以创建我们想要训练的模型的实例。
[5]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
learning_rate = 0.1
momentum = 0.9
model_state = create_model_state(init_rng, learning_rate, momentum)
此时,一切都已准备好运行训练。
[6]:
print("Starting training")
num_epochs = 5
for epoch in range(num_epochs):
print(f"Epoch {epoch}")
for batch in training_iterator:
model_state = train_step(model_state, batch)
acc = accuracy(model_state, validation_iterator)
print(f"Accuracy = {acc}")
Starting training
Epoch 0
Accuracy = 0.9551000595092773
Epoch 1
Accuracy = 0.9691000580787659
Epoch 2
Accuracy = 0.9738000631332397
Epoch 3
Accuracy = 0.9622000455856323
Epoch 4
Accuracy = 0.9604000449180603
带有 DALI 和 FLAX 的多 GPU#
本节介绍如何扩展上面的示例以使用多个 GPU。
同样,我们从创建一个迭代器定义函数开始。它是我们之前看到的函数的略微修改版本。
注意传递给 fn.readers.caffe2
的新参数,num_shards
和 shard_id
。它们用于控制分片
num_shards
设置分片总数shard_id
告诉 pipeline 它负责训练中的哪个分片。
我们将 devices
参数添加到装饰器,以指定我们要使用的设备。这里我们使用机器上 JAX 可用的所有 GPU。
[7]:
batch_size = 200
image_size = 28
num_classes = 10
@data_iterator(
output_map=["images", "labels"],
reader_name="mnist_caffe2_reader",
devices=jax.devices(),
)
def mnist_sharded_iterator(data_path, is_training, num_shards, shard_id):
jpegs, labels = fn.readers.caffe2(
path=data_path,
random_shuffle=is_training,
name="mnist_caffe2_reader",
num_shards=num_shards,
shard_id=shard_id,
)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
images = fn.crop_mirror_normalize(
images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
)
images = fn.reshape(images, shape=[-1]) # Flatten the output image
labels = labels.gpu()
if is_training:
labels = fn.one_hot(labels, num_classes=num_classes)
return images, labels
使用迭代器定义函数,我们现在可以创建用于在多个 GPU 上进行训练的 DALI 迭代器。此迭代器将返回与 pmapped
JAX 函数兼容的输出。
[8]:
print("Creating training iterator")
training_iterator = mnist_sharded_iterator(
data_path=training_data_path, is_training=True, batch_size=batch_size
)
print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300
为了简单起见,我们将在一个 GPU 上运行验证。我们可以重用来自单 GPU 示例的验证迭代器。唯一的区别是我们需要将模型拉到同一个 GPU。在现实场景中,这可能会很昂贵,但对于这个玩具教育示例来说已经足够了。
为了使模型与 pmap 风格的多 GPU 训练兼容,我们需要复制它。如果您想了解更多关于使用 pmap
在多个 GPU 上进行训练的信息,您可以查看 JAX 文档中的 JAX 中的并行评估 和 Flax 文档中的 在多个设备上集成。
[9]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
learning_rate = 0.1
momentum = 0.9
model_state = jax.pmap(create_model_state, static_broadcasted_argnums=(1, 2))(
jax.random.split(init_rng, jax.device_count()), learning_rate, momentum
)
由于我们想在单个 GPU 上运行验证,我们只提取模型的一个副本并将其传递给 accuracy
函数。
现在,我们准备好使用 DALI 作为数据源在多个 GPU 上训练 Flax 模型。
[10]:
import flax
parallel_train_step = jax.pmap(train_step)
num_epochs = 5
for epoch in range(num_epochs):
print(f"Epoch {epoch}")
for batch in training_iterator:
model_state = parallel_train_step(model_state, batch)
acc = accuracy(flax.jax_utils.unreplicate(model_state), validation_iterator)
print(f"Accuracy = {acc}")
Epoch 0
Accuracy = 0.9509000182151794
Epoch 1
Accuracy = 0.9643000364303589
Epoch 2
Accuracy = 0.9724000692367554
Epoch 3
Accuracy = 0.9701000452041626
Epoch 4
Accuracy = 0.9758000373840332