在多 GPU 上运行 JAX 增强#

之前的教程 介绍了在 DALI 中使用 JAX 函数作为增强的简单用法。

在本教程中,我们将展示在编写数据处理增强时如何从 JAX 的多 GPU 分片中获益。我们将创建一个分布式数据加载器,该加载器可以访问来自不同分片中的样本。使用相同的设置,您可以扩展此示例以使用其他 JAX 分布式操作,例如跨全局批次的归约。

设置分布式运行#

在这个简化的示例中,我们将假设数据加载分布在两个进程中。每个进程中的数据加载器将在其自己的 GPU 上处理其自身的数据部分:一个将处理狗的图像,另一个将处理猫的图像。

请注意,在实际场景中,此示例的代码将由一些设置脚本启动,该脚本将确保将其启动为两个独立的进程,每个进程都将其唯一的进程 ID 作为启动参数接收。在此,为了演示目的,我们在本笔记本中逐步运行 process_id=0,而另一个进程 process_id=1 在后台运行。

[1]:
import subprocess

subprocess.Popen(["python3.10", "jax_operator_multi_gpu_process_1.py"])
[1]:
<Popen: returncode: None args: ['python3.10', 'jax_operator_multi_gpu_proces...>

我们首先初始化 JAX 分布式工作流程。

[2]:
import os
import jax

process_id = 0  # the other process is launched with process_id = 1

os.environ["CUDA_VISIBLE_DEVICES"] = str(process_id)

jax.distributed.initialize(
    coordinator_address="localhost:12321",
    num_processes=2,
    process_id=process_id,
)

现在我们创建网格描述 - 全局批次在两个进程之间拆分。

[3]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding

assert len(jax.devices()) == 2
assert len(jax.local_devices()) == 1

mesh = Mesh(jax.devices(), axis_names=("batch"))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

带有分片的迭代器#

接下来,我们定义迭代器并传递上面定义的分片对象。基于此,iterator_function 将接收 shard_id,我们使用它来决定仅读取狗或猫的图像。我们读取、解码和调整图像大小。最后,我们调用在下一个单元格中定义的 global_mixup JAX 操作。

[4]:
import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator

dogs = [f"../data/images/dog/dog_{i}.jpg" for i in range(1, 9)]
kittens = [f"../data/images/kitten/cat_{i}.jpg" for i in range(1, 9)]


@data_iterator(
    output_map=["images"],
    sharding=sharding,
)
def iterator_function(shard_id, num_shards):
    assert num_shards == 2
    jpegs, _ = fn.readers.file(
        files=dogs if shard_id == 0 else kittens, name="image_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(244, 244))

    # mixup images between shards
    images = global_mixup(images)
    return images

带有分片的 JAX 增强#

global_mixup 函数以样本方式混合来自两个并发运行的进程的两个批次。

我们在 data_iterator 调用中传递的 sharding 对象也传递给 jax_function。这提示 DALI 如何构建 JAX 全局数组。请注意,目前,jax_function 中对 sharding 的支持仅限于全局分片,即当每个数据加载进程看到单个(每个进程中不同的)GPU 设备时。

jax.experimental.shard_map 规范允许我们使用 jax.lax.pshuffle

[5]:
from functools import partial
from jax.experimental.shard_map import shard_map

from nvidia.dali.plugin.jax.fn import jax_function


@jax_function(sharding=sharding)
@jax.jit
@partial(
    shard_map,
    mesh=sharding.mesh,
    in_specs=PartitionSpec("batch"),
    out_specs=PartitionSpec("batch"),
)
@jax.vmap
def global_mixup(sample):
    mixed_up = 0.5 * sample + 0.5 * jax.lax.pshuffle(sample, "batch", [1, 0])
    mixed_up = jax.numpy.clip(mixed_up, 0, 255)
    return jax.numpy.array(mixed_up, dtype=jax.numpy.uint8)

现在,我们准备测试数据加载器。

[6]:
local_batch_size = 8
num_shards = 2

iterator = iterator_function(
    batch_size=num_shards * local_batch_size, num_threads=4
)
batch = next(iterator)

让我们定义一个简单的辅助函数来呈现生成的批次。

[7]:
import matplotlib.pyplot as plt
from matplotlib import gridspec


def show_image(images, columns=4, fig_size=24):
    rows = (len(images) + columns - 1) // columns
    plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(images[j])
[8]:
local_batches = [x.data for x in batch["images"].addressable_shards]
assert len(local_batches) == 1

show_image(local_batches[0])
../../_images/examples_custom_operations_jax_operator_multi_gpu_16_0.png