使用多 GPU 进行训练#

在这里,我们将展示如何在多个 GPU 上运行来自 使用 DALI 和 JAX 训练神经网络 的训练。我们将使用相同的网络和相同的数据 pipeline。唯一的区别是我们将在多个 GPU 上运行它。为了更好地理解以下内容,建议首先阅读 使用 DALI 和 JAX 训练神经网络

要了解如何在多个 GPU 上运行 DALI 迭代器,请参阅 JAX 和 DALI 入门部分关于多 GPU 支持。它解释了如何在多个 GPU 上运行 DALI 迭代器。以下示例建立在该知识之上。

使用自动并行化进行训练#

在本节中,我们希望使用 JAX 的自动并行化机制将训练分散到多个 GPU 上。为此,我们需要定义要应用于计算的 sharding

要了解有关分片的更多信息,请参阅 JAX 文档中关于分布式数组和自动并行化的部分

[1]:
import jax
from jax.sharding import PositionalSharding, Mesh
from jax.experimental import mesh_utils


mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
sharding = PositionalSharding(mesh)

print(sharding)
PositionalSharding([[{GPU 0}]
                    [{GPU 1}]])

接下来,我们创建 DALI 迭代器函数。我们基于 使用 DALI 和 JAX 训练神经网络 示例中的函数,并添加了对带有 sharding 和相关参数的多个 GPU 的支持。

[2]:
from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types


image_size = 28
num_classes = 10


@data_iterator(
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    sharding=sharding,
)
def mnist_training_iterator(data_path, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=True,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()
    labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

为了简单起见,在本教程中,我们在单个 GPU 上运行验证。我们为验证数据创建适当的 DALI 迭代器函数。

[3]:
@data_iterator(
    output_map=["images", "labels"], reader_name="mnist_caffe2_reader"
)
def mnist_validation_iterator(data_path):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=False, name="mnist_caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

    return images, labels

我们定义一些训练参数并创建迭代器实例。

[4]:
import os

training_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)

batch_size = 200
num_epochs = 5


training_iterator = mnist_training_iterator(
    batch_size=batch_size, data_path=training_data_path
)
print(f"Number of batches in training iterator = {len(training_iterator)}")

validation_iterator = mnist_validation_iterator(
    batch_size=batch_size, data_path=validation_data_path
)
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in training iterator = 300
Number of batches in validation iterator = 50

准备好所有这些设置后,我们可以开始实际训练。我们从 使用 DALI 和 JAX 训练神经网络 示例中导入模型相关实用程序,并使用它们来训练模型。

训练循环中的每个 batch 都包含根据 sharding 参数分片的 imageslabels

请注意,对于验证,我们如何将模型拉到一个 GPU。如前所述,这样做是为了简单起见。在实际场景中,您可以在所有 GPU 上运行验证并平均结果。

[5]:
from model import init_model, accuracy
from model import update

model = init_model()

for epoch in range(num_epochs):
    for it, batch in enumerate(training_iterator):
        model = update(model, batch)

    model_on_one_device = jax.tree_map(
        lambda x: jax.device_put(x, jax.devices()[0]), model
    )
    test_acc = accuracy(model_on_one_device, validation_iterator)

    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.6739000082015991
Epoch 1 sec
Test set accuracy 0.7844000458717346
Epoch 2 sec
Test set accuracy 0.8244000673294067
Epoch 3 sec
Test set accuracy 0.8455000519752502
Epoch 4 sec
Test set accuracy 0.860200047492981

使用 pmapped 迭代器进行训练#

JAX 提供了另一种机制来跨多个设备分散计算:pmap 函数。DALI 也可以支持这种并行化方式。

要了解有关 pmap 的更多信息,请查看 JAX 文档

在 DALI 中,要以与 pmapped 函数兼容的方式配置迭代器,我们传递 devices 参数而不是 sharding。这里我们使用所有可用的 GPU。迭代器将返回跨所有 GPU 分片的 batch

sharding 一样,在底层,迭代器将创建 DALI pipeline 的多个实例,并且每个实例将被分配给一个 GPU。当请求输出时,DALI 将同步实例并将结果作为单个 batch 返回。

[6]:
@data_iterator(
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    devices=jax.devices(),
)
def mnist_training_iterator(data_path, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=True,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()
    labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

我们以与之前相同的方式创建迭代器实例

[7]:
print("Creating training iterator")
training_iterator = mnist_training_iterator(
    batch_size=batch_size, data_path=training_data_path
)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300

对于验证,我们将使用与之前相同的迭代器。由于我们在单个 GPU 上运行它,因此我们无需更改任何内容。我们可以再次将模型拉到一个 GPU 并运行验证。

[8]:
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in validation iterator = 50

为了使模型与 pmap 样式的多 GPU 训练兼容,我们需要复制它。如果您想了解有关使用 pmap 在多个 GPU 上进行训练的更多信息,可以查看 JAX 文档中的 JAX 中的并行评估

[9]:
import jax.numpy as jnp
from model import init_model, accuracy


model = init_model()
model = jax.tree_map(lambda x: jnp.array([x] * jax.device_count()), model)

对于多 GPU 训练,我们导入 update_parallel 函数。它与 update 函数相同,只是增加了跨设备的梯度同步。这将确保来自不同设备的模型副本保持相同。

由于我们希望在单个 GPU 上运行验证,因此我们仅提取一个模型副本并将其传递给 accuracy 函数。

[10]:
from model import update_parallel


for epoch in range(num_epochs):
    for it, batch in enumerate(training_iterator):
        model = update_parallel(model, batch)

    test_acc = accuracy(
        jax.tree_map(lambda x: x[0], model), validation_iterator
    )

    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.6885000467300415
Epoch 1 sec
Test set accuracy 0.7829000353813171
Epoch 2 sec
Test set accuracy 0.8222000598907471
Epoch 3 sec
Test set accuracy 0.8438000679016113
Epoch 4 sec
Test set accuracy 0.8580000400543213