数据加载: Webdataset#

概述#

此示例向您展示如何将以 Webdataset 格式存储的数据与 DALI 一起使用。

使用 readers.webdataset 操作器#

以 WebDataset 格式存储的数据可以使用 readers.webdataset 操作器读取。该操作器接受以下参数

  • paths tar 归档文件的路径(或路径列表),其中包含 webdataset

  • index_paths 相应索引文件的路径(或路径列表),其中包含有关 tar 文件具体信息的数据,使用 wds2idx 创建 - DALI 附带的实用程序。有关使用详情,请参阅 wds2idx -h。如果未提供,将从 tar 文件自动推断,尽管对于大型数据集来说,这将花费相当长的时间。

  • ext 扩展名集(或扩展名列表),以“;”分隔,指定操作器的输出以及哪些样本组件将馈送到操作器的特定输出

  • missing_component_behavior reader 在遇到无法为特定输出返回任何组件的样本时的行为。有 3 个选项

    • empty(默认)为该输出返回一个空张量

    • skip 跳过缺少组件的样本

    • error 引发错误

  • dtypes 操作器输出的数据类型。如果输出组件的大小无法被类型的大小整除,则会引发错误。如果未提供,则数据将以 UINT8 形式返回。除了这些参数外,该操作器还接受所有 reader 通用的参数,这些参数配置随机数生成器的种子、shuffle、分片以及在 epoch 结束时处理不完整批次。

创建索引#

索引文件(路径在参数 index_paths 中传递)可以使用 DALI 捆绑的工具 wds2idx 生成。

注意: DALI_EXTRA_PATH 环境变量应指向从 DALI extra 存储库 下载的数据的位置。

重要提示: 确保您检出与已安装 DALI 版本对应的正确发布标签。

[1]:
from subprocess import call
import os.path

test_data_root = os.environ["DALI_EXTRA_PATH"]
wds_data = os.path.join(test_data_root, "db", "webdataset", "train.tar")
batch_size = 16

定义和运行 Pipeline#

  1. 定义一个简单的 Pipeline,该 Pipeline 接受以 Webdataset 格式存储的图像并对其进行解码。

    在此示例中,我们通过裁剪、归一化和 HWC -> CHW 转换过程来处理图像。

[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np


@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def wds_pipeline(wds_data=wds_data):
    img_raw, cls = fn.readers.webdataset(
        paths=wds_data, ext=["jpg", "cls"], missing_component_behavior="error"
    )
    img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
    resized = fn.resize(img, device="gpu", resize_shorter=256.0)
    output = fn.crop_mirror_normalize(
        resized,
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
    )
    return output, cls
  1. 构建并运行 Pipeline

[3]:
pipe = wds_pipeline()
pipe.build()
pipe_out = pipe.run()
  1. 要可视化结果,请使用 matplotlib 库,该库期望图像采用 HWC 格式,但 Pipeline 的输出采用 CHW 格式。

  2. 为了可视化目的,将图像转置回 HWC 布局。

[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch, labels):
    columns = 4
    rows = (batch_size + 1) // (columns)
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        ascii = labels.at(j)
        plt.title(
            "".join([chr(item) for item in ascii]), fontdict={"fontsize": 25}
        )
        img_chw = image_batch.at(j)
        img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
        plt.imshow(img_hwc)
[5]:
images, labels = pipe_out
show_images(images.as_cpu(), labels)
../../../_images/examples_general_data_loading_dataloading_webdataset_8_0.png