在多 GPU 上运行 JAX 增强#
之前的教程 介绍了在 DALI 中使用 JAX 函数作为增强的简单用法。
在本教程中,我们将展示在编写数据处理增强时如何从 JAX 的多 GPU 分片中获益。我们将创建一个分布式数据加载器,该加载器可以访问来自不同分片中的样本。使用相同的设置,您可以扩展此示例以使用其他 JAX 分布式操作,例如跨全局批次的归约。
设置分布式运行#
在这个简化的示例中,我们将假设数据加载分布在两个进程中。每个进程中的数据加载器将在其自己的 GPU 上处理其自身的数据部分:一个将处理狗的图像,另一个将处理猫的图像。
请注意,在实际场景中,此示例的代码将由一些设置脚本启动,该脚本将确保将其启动为两个独立的进程,每个进程都将其唯一的进程 ID 作为启动参数接收。在此,为了演示目的,我们在本笔记本中逐步运行 process_id=0
,而另一个进程 process_id=1
在后台运行。
[1]:
import subprocess
subprocess.Popen(["python3.10", "jax_operator_multi_gpu_process_1.py"])
[1]:
<Popen: returncode: None args: ['python3.10', 'jax_operator_multi_gpu_proces...>
我们首先初始化 JAX 分布式工作流程。
[2]:
import os
import jax
process_id = 0 # the other process is launched with process_id = 1
os.environ["CUDA_VISIBLE_DEVICES"] = str(process_id)
jax.distributed.initialize(
coordinator_address="localhost:12321",
num_processes=2,
process_id=process_id,
)
现在我们创建网格描述 - 全局批次在两个进程之间拆分。
[3]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding
assert len(jax.devices()) == 2
assert len(jax.local_devices()) == 1
mesh = Mesh(jax.devices(), axis_names=("batch"))
sharding = NamedSharding(mesh, PartitionSpec("batch"))
带有分片的迭代器#
接下来,我们定义迭代器并传递上面定义的分片对象。基于此,iterator_function
将接收 shard_id
,我们使用它来决定仅读取狗或猫的图像。我们读取、解码和调整图像大小。最后,我们调用在下一个单元格中定义的 global_mixup
JAX 操作。
[4]:
import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator
dogs = [f"../data/images/dog/dog_{i}.jpg" for i in range(1, 9)]
kittens = [f"../data/images/kitten/cat_{i}.jpg" for i in range(1, 9)]
@data_iterator(
output_map=["images"],
sharding=sharding,
)
def iterator_function(shard_id, num_shards):
assert num_shards == 2
jpegs, _ = fn.readers.file(
files=dogs if shard_id == 0 else kittens, name="image_reader"
)
images = fn.decoders.image(jpegs, device="mixed")
images = fn.resize(images, size=(244, 244))
# mixup images between shards
images = global_mixup(images)
return images
带有分片的 JAX 增强#
global_mixup
函数以样本方式混合来自两个并发运行的进程的两个批次。
我们在 data_iterator
调用中传递的 sharding
对象也传递给 jax_function
。这提示 DALI 如何构建 JAX 全局数组。请注意,目前,jax_function
中对 sharding
的支持仅限于全局分片,即当每个数据加载进程看到单个(每个进程中不同的)GPU 设备时。
jax.experimental.shard_map
规范允许我们使用 jax.lax.pshuffle
。
[5]:
from functools import partial
from jax.experimental.shard_map import shard_map
from nvidia.dali.plugin.jax.fn import jax_function
@jax_function(sharding=sharding)
@jax.jit
@partial(
shard_map,
mesh=sharding.mesh,
in_specs=PartitionSpec("batch"),
out_specs=PartitionSpec("batch"),
)
@jax.vmap
def global_mixup(sample):
mixed_up = 0.5 * sample + 0.5 * jax.lax.pshuffle(sample, "batch", [1, 0])
mixed_up = jax.numpy.clip(mixed_up, 0, 255)
return jax.numpy.array(mixed_up, dtype=jax.numpy.uint8)
现在,我们准备测试数据加载器。
[6]:
local_batch_size = 8
num_shards = 2
iterator = iterator_function(
batch_size=num_shards * local_batch_size, num_threads=4
)
batch = next(iterator)
让我们定义一个简单的辅助函数来呈现生成的批次。
[7]:
import matplotlib.pyplot as plt
from matplotlib import gridspec
def show_image(images, columns=4, fig_size=24):
rows = (len(images) + columns - 1) // columns
plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
gs = gridspec.GridSpec(rows, columns)
for j in range(rows * columns):
plt.subplot(gs[j])
plt.axis("off")
plt.imshow(images[j])
[8]:
local_batches = [x.data for x in batch["images"].addressable_shards]
assert len(local_batches) == 1
show_image(local_batches[0])
