数据加载: Webdataset#
概述#
此示例向您展示如何将以 Webdataset 格式存储的数据与 DALI 一起使用。
使用 readers.webdataset 操作器#
以 WebDataset 格式存储的数据可以使用 readers.webdataset
操作器读取。该操作器接受以下参数
paths
tar 归档文件的路径(或路径列表),其中包含 webdatasetindex_paths
相应索引文件的路径(或路径列表),其中包含有关 tar 文件具体信息的数据,使用wds2idx
创建 - DALI 附带的实用程序。有关使用详情,请参阅wds2idx -h
。如果未提供,将从 tar 文件自动推断,尽管对于大型数据集来说,这将花费相当长的时间。ext
扩展名集(或扩展名列表),以“;”分隔,指定操作器的输出以及哪些样本组件将馈送到操作器的特定输出missing_component_behavior
reader 在遇到无法为特定输出返回任何组件的样本时的行为。有 3 个选项empty
(默认)为该输出返回一个空张量skip
跳过缺少组件的样本error
引发错误
dtypes
操作器输出的数据类型。如果输出组件的大小无法被类型的大小整除,则会引发错误。如果未提供,则数据将以 UINT8 形式返回。除了这些参数外,该操作器还接受所有 reader 通用的参数,这些参数配置随机数生成器的种子、shuffle、分片以及在 epoch 结束时处理不完整批次。
创建索引#
索引文件(路径在参数 index_paths 中传递)可以使用 DALI 捆绑的工具 wds2idx 生成。
注意: DALI_EXTRA_PATH
环境变量应指向从 DALI extra 存储库 下载的数据的位置。
重要提示: 确保您检出与已安装 DALI 版本对应的正确发布标签。
[1]:
from subprocess import call
import os.path
test_data_root = os.environ["DALI_EXTRA_PATH"]
wds_data = os.path.join(test_data_root, "db", "webdataset", "train.tar")
batch_size = 16
定义和运行 Pipeline#
定义一个简单的 Pipeline,该 Pipeline 接受以 Webdataset 格式存储的图像并对其进行解码。
在此示例中,我们通过裁剪、归一化和
HWC
->CHW
转换过程来处理图像。
[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np
@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def wds_pipeline(wds_data=wds_data):
img_raw, cls = fn.readers.webdataset(
paths=wds_data, ext=["jpg", "cls"], missing_component_behavior="error"
)
img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
resized = fn.resize(img, device="gpu", resize_shorter=256.0)
output = fn.crop_mirror_normalize(
resized,
dtype=types.FLOAT,
crop=(224, 224),
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
)
return output, cls
构建并运行 Pipeline
[3]:
pipe = wds_pipeline()
pipe.build()
pipe_out = pipe.run()
要可视化结果,请使用
matplotlib
库,该库期望图像采用HWC
格式,但 Pipeline 的输出采用CHW
格式。为了可视化目的,将图像转置回
HWC
布局。
[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
def show_images(image_batch, labels):
columns = 4
rows = (batch_size + 1) // (columns)
fig = plt.figure(figsize=(32, (32 // columns) * rows))
gs = gridspec.GridSpec(rows, columns)
for j in range(rows * columns):
plt.subplot(gs[j])
plt.axis("off")
ascii = labels.at(j)
plt.title(
"".join([chr(item) for item in ascii]), fontdict={"fontsize": 25}
)
img_chw = image_batch.at(j)
img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
plt.imshow(img_hwc)
[5]:
images, labels = pipe_out
show_images(images.as_cpu(), labels)
