使用多 GPU 进行训练#
在这里,我们将展示如何在多个 GPU 上运行来自 使用 DALI 和 JAX 训练神经网络 的训练。我们将使用相同的网络和相同的数据 pipeline。唯一的区别是我们将在多个 GPU 上运行它。为了更好地理解以下内容,建议首先阅读 使用 DALI 和 JAX 训练神经网络。
要了解如何在多个 GPU 上运行 DALI 迭代器,请参阅 JAX 和 DALI 入门部分关于多 GPU 支持。它解释了如何在多个 GPU 上运行 DALI 迭代器。以下示例建立在该知识之上。
使用自动并行化进行训练#
在本节中,我们希望使用 JAX 的自动并行化机制将训练分散到多个 GPU 上。为此,我们需要定义要应用于计算的 sharding
。
要了解有关分片的更多信息,请参阅 JAX 文档中关于分布式数组和自动并行化的部分。
[1]:
import jax
from jax.sharding import PositionalSharding, Mesh
from jax.experimental import mesh_utils
mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
sharding = PositionalSharding(mesh)
print(sharding)
PositionalSharding([[{GPU 0}]
[{GPU 1}]])
接下来,我们创建 DALI 迭代器函数。我们基于 使用 DALI 和 JAX 训练神经网络 示例中的函数,并添加了对带有 sharding
和相关参数的多个 GPU 的支持。
[2]:
from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types
image_size = 28
num_classes = 10
@data_iterator(
output_map=["images", "labels"],
reader_name="mnist_caffe2_reader",
sharding=sharding,
)
def mnist_training_iterator(data_path, num_shards, shard_id):
jpegs, labels = fn.readers.caffe2(
path=data_path,
random_shuffle=True,
name="mnist_caffe2_reader",
num_shards=num_shards,
shard_id=shard_id,
)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
images = fn.crop_mirror_normalize(
images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
)
images = fn.reshape(images, shape=[image_size * image_size])
labels = labels.gpu()
labels = fn.one_hot(labels, num_classes=num_classes)
return images, labels
为了简单起见,在本教程中,我们在单个 GPU 上运行验证。我们为验证数据创建适当的 DALI 迭代器函数。
[3]:
@data_iterator(
output_map=["images", "labels"], reader_name="mnist_caffe2_reader"
)
def mnist_validation_iterator(data_path):
jpegs, labels = fn.readers.caffe2(
path=data_path, random_shuffle=False, name="mnist_caffe2_reader"
)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
images = fn.crop_mirror_normalize(
images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
)
images = fn.reshape(images, shape=[image_size * image_size])
labels = labels.gpu()
return images, labels
我们定义一些训练参数并创建迭代器实例。
[4]:
import os
training_data_path = os.path.join(
os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)
batch_size = 200
num_epochs = 5
training_iterator = mnist_training_iterator(
batch_size=batch_size, data_path=training_data_path
)
print(f"Number of batches in training iterator = {len(training_iterator)}")
validation_iterator = mnist_validation_iterator(
batch_size=batch_size, data_path=validation_data_path
)
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in training iterator = 300
Number of batches in validation iterator = 50
准备好所有这些设置后,我们可以开始实际训练。我们从 使用 DALI 和 JAX 训练神经网络 示例中导入模型相关实用程序,并使用它们来训练模型。
训练循环中的每个 batch
都包含根据 sharding
参数分片的 images
和 labels
。
请注意,对于验证,我们如何将模型拉到一个 GPU。如前所述,这样做是为了简单起见。在实际场景中,您可以在所有 GPU 上运行验证并平均结果。
[5]:
from model import init_model, accuracy
from model import update
model = init_model()
for epoch in range(num_epochs):
for it, batch in enumerate(training_iterator):
model = update(model, batch)
model_on_one_device = jax.tree_map(
lambda x: jax.device_put(x, jax.devices()[0]), model
)
test_acc = accuracy(model_on_one_device, validation_iterator)
print(f"Epoch {epoch} sec")
print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.6739000082015991
Epoch 1 sec
Test set accuracy 0.7844000458717346
Epoch 2 sec
Test set accuracy 0.8244000673294067
Epoch 3 sec
Test set accuracy 0.8455000519752502
Epoch 4 sec
Test set accuracy 0.860200047492981
使用 pmapped
迭代器进行训练#
JAX 提供了另一种机制来跨多个设备分散计算:pmap
函数。DALI 也可以支持这种并行化方式。
要了解有关 pmap
的更多信息,请查看 JAX 文档。
在 DALI 中,要以与 pmapped
函数兼容的方式配置迭代器,我们传递 devices
参数而不是 sharding
。这里我们使用所有可用的 GPU。迭代器将返回跨所有 GPU 分片的 batch
。
与 sharding
一样,在底层,迭代器将创建 DALI pipeline 的多个实例,并且每个实例将被分配给一个 GPU。当请求输出时,DALI 将同步实例并将结果作为单个 batch
返回。
[6]:
@data_iterator(
output_map=["images", "labels"],
reader_name="mnist_caffe2_reader",
devices=jax.devices(),
)
def mnist_training_iterator(data_path, num_shards, shard_id):
jpegs, labels = fn.readers.caffe2(
path=data_path,
random_shuffle=True,
name="mnist_caffe2_reader",
num_shards=num_shards,
shard_id=shard_id,
)
images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
images = fn.crop_mirror_normalize(
images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
)
images = fn.reshape(images, shape=[image_size * image_size])
labels = labels.gpu()
labels = fn.one_hot(labels, num_classes=num_classes)
return images, labels
我们以与之前相同的方式创建迭代器实例
[7]:
print("Creating training iterator")
training_iterator = mnist_training_iterator(
batch_size=batch_size, data_path=training_data_path
)
print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300
对于验证,我们将使用与之前相同的迭代器。由于我们在单个 GPU 上运行它,因此我们无需更改任何内容。我们可以再次将模型拉到一个 GPU 并运行验证。
[8]:
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in validation iterator = 50
为了使模型与 pmap 样式的多 GPU 训练兼容,我们需要复制它。如果您想了解有关使用 pmap
在多个 GPU 上进行训练的更多信息,可以查看 JAX 文档中的 JAX 中的并行评估。
[9]:
import jax.numpy as jnp
from model import init_model, accuracy
model = init_model()
model = jax.tree_map(lambda x: jnp.array([x] * jax.device_count()), model)
对于多 GPU 训练,我们导入 update_parallel
函数。它与 update
函数相同,只是增加了跨设备的梯度同步。这将确保来自不同设备的模型副本保持相同。
由于我们希望在单个 GPU 上运行验证,因此我们仅提取一个模型副本并将其传递给 accuracy
函数。
[10]:
from model import update_parallel
for epoch in range(num_epochs):
for it, batch in enumerate(training_iterator):
model = update_parallel(model, batch)
test_acc = accuracy(
jax.tree_map(lambda x: x[0], model), validation_iterator
)
print(f"Epoch {epoch} sec")
print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.6885000467300415
Epoch 1 sec
Test set accuracy 0.7829000353813171
Epoch 2 sec
Test set accuracy 0.8222000598907471
Epoch 3 sec
Test set accuracy 0.8438000679016113
Epoch 4 sec
Test set accuracy 0.8580000400543213