数据加载:Webdataset#

概述#

此示例演示了如何在 DALI 中使用以 Webdataset 格式存储的数据。

使用 readers.webdataset 运算符#

以 WebDataset 格式存储的数据可以使用 readers.webdataset 运算符读取。该运算符接受以下参数

  • paths 指向包含 webdataset 的 tar 存档的路径(或路径列表)

  • index_paths 指向各自索引文件的路径(或路径列表),这些索引文件包含有关 tar 文件具体信息的数据,使用 wds2idx 创建 - DALI 附带的实用程序。有关使用详情,请参阅 wds2idx -h。如果未提供,则将自动从 tar 文件中推断,但这对于大型数据集将花费相当长的时间。

  • ext 扩展名集(或列表),以“;”分隔,用于指定运算符的输出以及哪些样本组件将馈入运算符的特定输出

  • missing_component_behavior 运算符在遇到无法为特定输出返回任何组件的样本时的行为。有 3 个选项

    • empty (默认)为该输出返回一个空张量

    • skip 跳过缺少组件的样本

    • error 引发错误

  • dtypes 运算符输出的数据类型。如果输出组件的大小不能被类型的大小整除,则会引发错误。如果未提供,则数据将作为 UINT8 返回。除了这些参数外,该运算符还接受所有读取器通用的参数,这些参数配置随机数生成器的种子、洗牌、分片以及在 epoch 结束时处理不完整批次。

创建索引#

索引文件(路径在参数 index_paths 中传递)可以使用 DALI 捆绑的工具 wds2idx 生成。

注意DALI_EXTRA_PATH 环境变量应指向从 DALI extra repository 下载数据的位置。

重要提示:确保您检出与已安装的 DALI 版本对应的正确发布标签。

[1]:
from subprocess import call
import os.path

test_data_root = os.environ["DALI_EXTRA_PATH"]
wds_data = os.path.join(test_data_root, "db", "webdataset", "train.tar")
batch_size = 16

定义和运行 pipeline#

  1. 定义一个简单的 pipeline,它接受以 Webdataset 格式存储的图像并对其进行解码。

    在此示例中,我们通过裁剪、归一化和 HWC -> CHW 转换过程来处理图像。

[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np


@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def wds_pipeline(wds_data=wds_data):
    img_raw, cls = fn.readers.webdataset(
        paths=wds_data, ext=["jpg", "cls"], missing_component_behavior="error"
    )
    img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
    resized = fn.resize(img, device="gpu", resize_shorter=256.0)
    output = fn.crop_mirror_normalize(
        resized,
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
    )
    return output, cls
  1. 构建并运行 pipeline

[3]:
pipe = wds_pipeline()
pipe.build()
pipe_out = pipe.run()
  1. 要可视化结果,请使用 matplotlib 库,该库期望图像采用 HWC 格式,但 pipeline 的输出采用 CHW 格式。

  2. 为了可视化目的,将图像转置回 HWC 布局。

[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch, labels):
    columns = 4
    rows = (batch_size + 1) // (columns)
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        ascii = labels.at(j)
        plt.title(
            "".join([chr(item) for item in ascii]), fontdict={"fontsize": 25}
        )
        img_chw = image_batch.at(j)
        img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
        plt.imshow(img_hwc)
[5]:
images, labels = pipe_out
show_images(images.as_cpu(), labels)
../../../_images/examples_general_data_loading_dataloading_webdataset_8_0.png