WebDataset 集成使用外部源#

本笔记本展示了如何将 webdataset 与 DALI pipeline 结合使用,通过外部源操作符。

简介#

数据表示#

Web Dataset 是一种数据集表示形式,它极大地优化了网络访问存储性能。最简单的情况下,它将整个数据集存储在一个 tarball 文件中,其中每个样本由一个或多个具有相同名称但不同扩展名的条目表示。这种方法改进了 RAM 中的驱动器访问缓存,因为数据是顺序表示的。

分片#

为了提高分布式存储访问和网络数据传输,webdataset 采用了一种称为分片 的策略。在这种方法中,保存数据的 tarball 被分成几个较小的 tarball,称为分片,这允许一次从多个存储驱动器获取,并减少了必须通过网络传输的数据包大小。

示例实现#

首先,让我们导入必要的模块并定义稍后将需要的数据集的位置。

DALI_EXTRA_PATH 环境变量应指向从 DALI extra 仓库 下载数据的位置。请确保检出了正确的发布标签。

tar_dataset_paths 保存将要加载的分片的路径,同时展示和测试 webdataset 加载器。

batch_size 是两个加载器的通用批次大小

[1]:
import nvidia.dali.fn as fn
import nvidia.dali as dali
import nvidia.dali.types as types
import webdataset as wds
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import random
import tempfile
import tarfile

root_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db", "webdataset",
                         "MNIST")
tar_dataset_paths = [os.path.join(root_path, data_file)
                        for data_file in ["devel-0.tar", "devel-1.tar",
                                          "devel-2.tar"]]
batch_size = 16

接下来,让我们提取稍后将用于将文件读取器与我们的自定义读取器进行比较的文件。

folder_dataset_files 保存文件的路径

[2]:
folder_dataset_root_dir = tempfile.TemporaryDirectory()
folder_dataset_dirs = [tempfile.TemporaryDirectory(dir=folder_dataset_root_dir.name)
                     for dataset in tar_dataset_paths]
folder_dataset_tars = [tarfile.open(dataset) for dataset in tar_dataset_paths]

for folder_dataset_tar, folder_dataset_subdir in zip(folder_dataset_tars,
                                                     folder_dataset_dirs):
    folder_dataset_tar.extractall(path=folder_dataset_subdir.name)

folder_dataset_files = [
    filepath
    for folder_dataset_subdir in folder_dataset_dirs
    for filepath in sorted(
        glob.glob(os.path.join(folder_dataset_subdir.name, "*.jpg")),
        key=lambda s: int(s[s.rfind('/') + 1:s.rfind(".jpg")])
    )
]

下面的函数用于稍后随机化数据集的输出。样本首先存储在预取缓冲区中,然后它们在生成器中随机产生,并被新样本替换。

[3]:
def buffered_shuffle(generator_factory, initial_fill, seed):
    def buffered_shuffle_generator():
        nonlocal generator_factory, initial_fill, seed
        generator = generator_factory()
        # The buffer size must be positive
        assert(initial_fill > 0)

        # The buffer that will hold the randomized samples
        buffer = []

        # The random context for preventing side effects
        random_context = random.Random(seed)

        try:
            while len(buffer) < initial_fill: # Fills in the random buffer
                buffer.append(next(generator))

            # Selects a random sample from the buffer and then fills it back
            # in with a new one
            while True:
                idx = random_context.randint(0, initial_fill-1)

                yield buffer[idx]
                buffer[idx] = None
                buffer[idx] = next(generator)

        # When the generator runs out of the samples flushes our the buffer
        except StopIteration:
            random_context.shuffle(buffer)

            while buffer:
                # Prevents the one sample that was not filled from being duplicated
                if buffer[-1] != None:
                    yield buffer[-1]
                buffer.pop()
    return buffered_shuffle_generator

下一个函数用于用最后一个样本填充最后一个批次,以使其与所有其他批次大小相同。

[4]:
def last_batch_padding(generator_factory, batch_size):
    def last_batch_padding_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        in_batch_idx = 0
        last_item = None
        try:
            # Keeps track of the last sample and the sample number mod batch_size
            while True:
                if in_batch_idx >= batch_size:
                    in_batch_idx -= batch_size
                last_item = next(generator)
                in_batch_idx += 1
                yield last_item
        # Repeats the last sample the necessary number of times
        except StopIteration:
            while in_batch_idx < batch_size:
                yield last_item
                in_batch_idx += 1
    return last_batch_padding_generator

最后一个函数将所有数据收集到批次中,以便能够为最后一个样本提供可变长度的批次

[5]:
def collect_batches(generator_factory, batch_size):
    def collect_batches_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        batch = []
        try:
            while True:
                batch.append(next(generator))
                if len(batch) == batch_size:
                    # Converts tuples of samples into tuples of batches of samples
                    yield tuple(map(list, zip(*batch)))
                    batch = []
        except StopIteration:
            if batch is not []:
                # Converts tuples of samples into tuples of batches of samples
                yield tuple(map(list, zip(*batch)))
    return collect_batches_generator

最后是数据加载器,它配置并返回一个 ExternalSource 节点。

关键字参数:#

paths:描述包含 webdataset 的文件/文件的路径,并且可以格式化为 WebDataset 接受的任何数据

extensions:描述包含要通过数据集输出的数据的扩展名。默认情况下,使用 WebDataset 支持的所有图像格式扩展名

random_shuffle:描述是否打乱 WebDataset 读取的数据

initial_fill:如果 random_shuffle 为 True,则描述数据混洗器的缓冲区大小。默认设置为 256。

seed:描述用于混洗数据的种子。对于获得一致的结果很有用。默认设置为 0

pad_last_batch:描述是否用最后一个样本填充最后一个批次以匹配常规批次大小

read_ahead:描述是否将数据预取到内存中

cycle:可以是 "raise",在这种情况下,数据加载器一旦到达数据末尾将抛出 StopIteration,在这种情况下,用户必须在下一个 epoch 之前调用 pipeline.reset(),或者 "quiet"(默认),在这种情况下,它将保持一遍又一遍地循环数据

[6]:
def read_webdataset(
    paths,
    extensions=None,
    random_shuffle=False,
    initial_fill=256,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet"
):
    # Parsing the input data
    assert(cycle in {"quiet", "raise", "no"})
    if extensions == None:
        # All supported image formats
        extensions = ';'.join(["jpg", "jpeg", "img", "image", "pbm", "pgm", "png"])
    if type(extensions) == str:
        extensions = (extensions,)

    # For later information for batch collection and padding
    max_batch_size = dali.pipeline.Pipeline.current().max_batch_size

    def webdataset_generator():
        bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8),
                           )*len(extensions)
        dataset_instance = (wds.WebDataset(paths)
                            .to_tuple(*extensions)
                            .map_tuple(*bytes_np_mapper))

        for sample in dataset_instance:
            yield sample

    dataset = webdataset_generator

    # Adding the buffered shuffling
    if random_shuffle:
        dataset = buffered_shuffle(dataset, initial_fill, seed)

    # Adding the batch padding
    if pad_last_batch:
        dataset = last_batch_padding(dataset, max_batch_size)

    # Collecting the data into batches (possibly undefull)
    # Handled by a custom function only when `silent_cycle` is False
    if cycle != "quiet":
        dataset = collect_batches(dataset, max_batch_size)

    # Prefetching the data
    if read_ahead:
        dataset=list(dataset())

    return fn.external_source(
        source=dataset,
        num_outputs=len(extensions),
        # If `cycle` is "quiet" then batching is handled by the external source
        batch=(cycle != "quiet"),
        cycle=cycle,
        dtype=types.UINT8
    )

我们还定义了一个示例数据增强函数,该函数解码图像,对其应用抖动并将其大小调整为 244x244。

[7]:
def decode_augment(img, seed=0):
    img = fn.decoders.image(img)
    img = fn.jitter(img.gpu(), seed=seed)
    img = fn.resize(img, size=(224, 224))
    return img

用法演示#

下面我们定义了带有我们基于 external_source 的加载器的示例 webdataset pipeline,它只是将先前定义的读取器和增强函数链接在一起。

[8]:
@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def webdataset_pipeline(
    paths,
    random_shuffle=False,
    initial_fill=256,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet"
):
    img, label = read_webdataset(paths=paths,
                                 extensions=("jpg", "cls"),
                                 random_shuffle=random_shuffle,
                                 initial_fill=initial_fill,
                                 seed=seed,
                                 pad_last_batch=pad_last_batch,
                                 read_ahead=read_ahead,
                                 cycle=cycle)
    return decode_augment(img, seed=seed), label

然后可以使用传递给数据加载器的所需参数构建 pipeline

[9]:
pipeline = webdataset_pipeline(
    tar_dataset_paths,   # Paths for the sharded dataset
    random_shuffle=True, # Random buffered shuffling on
    pad_last_batch=False, # Last batch is filled to the full size
    read_ahead=False,
    cycle="raise")     # All the data is preloaded into the memory
pipeline.build()

并执行,使用 matplotlib 打印示例图像

[10]:
# If StopIteration is raised, use pipeline.reset() to start a new epoch
img, c = pipeline.run()
img = img.as_cpu()
# Conversion from an array of bytes back to bytes and then to int
print(int(bytes(c.as_array()[0])))
plt.imshow(img.as_array()[0])
plt.show()
1
../../_images/examples_use_cases_webdataset-externalsource_22_1.png

检查一致性#

在这里,我们将检查 webdataset 的自定义 pipeline 是否与从解压目录读取文件的等效 pipeline 匹配,使用 fn.readers.file 读取器。

首先,让我们定义要比较的 pipeline。这与 webdataset 的 pipeline 相同,但而是使用 fn.readers.file 读取器。

[11]:
@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def file_pipeline(files):
    img, _ = fn.readers.file(files=files)
    return decode_augment(img)

然后,让我们实例化并构建两个 pipeline

[12]:
webdataset_pipeline_instance = webdataset_pipeline(tar_dataset_paths)
webdataset_pipeline_instance.build()
file_pipeline_instance = file_pipeline(folder_dataset_files)
file_pipeline_instance.build()

并运行比较循环。

[13]:
# The number of batches to sample between the two pipelines
num_batches = 10

for _ in range(num_batches):
    webdataset_pipeline_threw_exception = False
    file_pipeline_threw_exception = False

    # Try running the webdataset pipeline and check if it has run out of
    # the samples
    try:
        web_img, _ = webdataset_pipeline_instance.run()
    except StopIteration:
        webdataset_pipeline_threw_exception = True

    # Try running the file pipeline and check if it has run out of the samples
    try:
        (file_img,) = file_pipeline_instance.run()
    except StopIteration:
        file_pipeline_threw_exception = True

    # In case of different number of batches
    assert(webdataset_pipeline_threw_exception==file_pipeline_threw_exception)

    web_img = web_img.as_cpu().as_array()
    file_img = file_img.as_cpu().as_array()

    # In case the pipelines give different outputs
    np.testing.assert_equal(web_img, file_img)
else:
    print("No difference found!")
No difference found!