数据加载:MXNet recordIO#

概述#

这个例子展示了如何使用存储在 MXNet recordIO 格式中的数据与 DALI。

创建索引#

要使用存储在 recordIO 格式中的数据,我们需要使用 readers.mxnet 操作符。除了所有读取器通用的参数(例如 random_shuffle)之外,此操作符还接受 pathindex_path 参数

  • path 是 recordIO 文件路径的列表

  • index_path 是一个列表(大小为 1),其中包含索引文件的路径。当您使用 MXNet 的 im2rec.py 工具时,会自动创建此文件(带有 .idx 扩展名);也可以使用 DALI 附带的 rec2idx 工具从 recordIO 文件中获取。

DALI_EXTRA_PATH 环境变量应指向从 DALI extra repository 下载数据的位置。

重要提示:请确保您检出与已安装 DALI 版本对应的正确发布标签。

[1]:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np
import os.path

test_data_root = os.environ["DALI_EXTRA_PATH"]
base = os.path.join(test_data_root, "db", "recordio")
batch_size = 16

idx_files = [base + "/train.idx"]
rec_files = [base + "/train.rec"]

定义和运行 Pipeline#

  1. 定义一个简单的 pipeline,它接受存储在 recordIO 格式中的图像,解码它们,并为在深度学习框架中摄取做准备。

    图像处理涉及裁剪、归一化和 HWC -> CHW 转换过程。

[2]:
pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=0)
with pipe:
    jpegs, labels = fn.readers.mxnet(path=rec_files, index_path=idx_files)
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    output = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
    )
    pipe.set_outputs(output, labels)
  1. 现在让我们构建并运行 pipeline

[3]:
pipe.build()
pipe_out = pipe.run()
  1. 为了可视化结果,请使用 matplotlib 库,它期望图像为 HWC 格式,但 pipeline 的输出为 CHW 格式。

    注意CHW 是大多数深度学习框架的首选格式。

  2. 为了可视化目的,将图像转置回 HWC 布局。

[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch):
    columns = 4
    rows = (batch_size + 1) // (columns)
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        img_chw = image_batch.at(j)
        img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
        plt.imshow(img_hwc)
[5]:
images, labels = pipe_out
show_images(images.as_cpu())
../../../_images/examples_general_data_loading_dataloading_recordio_8_0.svg