使用 Tensorflow DALI 插件与稀疏张量#

概述#

将我们的 DALI 数据加载和增强 pipeline 与 Tensorflow 结合使用非常简单。

但是,有时从 pipeline 中提取的数据批次不能表示为密集张量。在这种情况下,DALI op 利用 TensorFlow SparseTensor。请记住,SparseTensor 仅支持基于 CPU 的 pipeline。

定义数据加载 Pipeline#

首先,我们从定义一些简单的 pipeline 开始,它将以稀疏张量形式返回数据。为了实现这一点,我们将使用众所周知的 COCO 数据集。每张图像可能具有 0 个或多个边界框,其中包含描述其中存在的对象的标签。我们希望以标准化的方式返回图像,而标签和边界框将表示为稀疏张量。首先,让我们定义一些全局参数

DALI_EXTRA_PATH 环境变量应指向 DALI extra repository 中的数据下载位置。请确保检出正确的发布标签。

[1]:
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os.path

test_data_root = os.environ["DALI_EXTRA_PATH"]

BATCH_SIZE = 32
test_data_root = os.environ["DALI_EXTRA_PATH"]
file_root = os.path.join(test_data_root, "db", "coco", "images")
annotations_file = os.path.join(test_data_root, "db", "coco", "instances.json")

创建了带有 COCO 读取器的 Pipeline。请注意,在处理图像时,来自 COCO ara 的其他数据会通过。

[2]:
@pipeline_def
def coco_pipeline():
    jpegs, bboxes, labels, im_ids = fn.readers.coco(
        file_root=file_root,
        annotations_file=annotations_file,
        ratio=False,
        image_ids=True,
    )
    images = fn.decoders.image(jpegs, device="cpu")
    images = fn.resize(
        images,
        resize_shorter=fn.random.uniform(range=(256.0, 480.0)),
        interp_type=types.INTERP_LINEAR,
    )
    images = fn.crop_mirror_normalize(
        images,
        crop_pos_x=fn.random.uniform(range=(0.0, 1.0)),
        crop_pos_y=fn.random.uniform(range=(0.0, 1.0)),
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[128.0, 128.0, 128.0],
        std=[1.0, 1.0, 1.0],
    )
    images = fn.cast(images, dtype=types.INT32)

    return images, bboxes, labels, im_ids

接下来,我们使用正确的参数实例化 pipeline。我们将为每个 GPU 创建一个 pipeline,方法是为每个 pipeline 指定正确的 device_id

不同之处在于,我们将 pipeline 对象传递给 TensorFlow 运算符,而不是调用 pipeline.build 并使用它。

[3]:
pipe = coco_pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0)

使用 DALI TensorFlow 插件#

让我们首先导入 Tensorflow 和 DALI Tensorflow 插件作为 dali_tf

[4]:
import tensorflow as tf
import nvidia.dali.plugin.tf as dali_tf
import time
from tensorflow.compat.v1 import GPUOptions
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import Session
from tensorflow.compat.v1 import placeholder

tf.compat.v1.disable_eager_execution()

我们现在可以使用 nvidia.dali.plugin.tf.DALIIterator() 方法来获取 TensorFlow Op,它将生成我们将在 Tensorflow 图中使用的张量。

对于每个 DALI pipeline,我们使用 daliop,它返回一个 TensorFlow 张量元组,我们将存储在 image、 bouding boxes、 labels image ids 中。要启用稀疏张量生成,需要为将表示为稀疏张量的输出元素填充 True 参数。

[5]:
daliop = dali_tf.DALIIterator()

images = []
bboxes = []
labels = []
image_ids = []

with tf.device("/cpu"):
    image, bbox, label, id = daliop(
        pipeline=pipe,
        shapes=[(BATCH_SIZE, 3, 224, 224), (), (), ()],
        dtypes=[tf.int32, tf.float32, tf.int32, tf.int32],
        sparse=[False, True, True],
    )

    images.append(image)
    bboxes.append(bbox)
    labels.append(label)
    image_ids.append(id)

在简单 Tensorflow 图中使用张量#

我们将在 Tensorflow 图定义中使用 imagesbboxeslabelsimage_ids 张量列表。然后运行一个非常简单的单操作图会话,它将输出数据批次。然后我们将打印边界框、标签和图像 ID。

[6]:
with Session() as sess:
    all_img_per_sec = []
    total_batch_size = BATCH_SIZE

    start_time = time.time()

    # The actual run with our dali_tf tensors
    res_cpu = sess.run([images, bboxes, labels, image_ids])

print(res_cpu[1])
print(res_cpu[2])
print(res_cpu[3])
[SparseTensorValue(indices=array([[ 0,  0,  0],
       [ 0,  0,  1],
       [ 0,  0,  2],
       [ 0,  0,  3],
       [ 1,  0,  0],
       [ 1,  0,  1],
       [ 1,  0,  2],
       [ 1,  0,  3],
       [ 2,  0,  0],
       [ 2,  0,  1],
       [ 2,  0,  2],
       [ 2,  0,  3],
       [ 3,  0,  0],
       [ 3,  0,  1],
       [ 3,  0,  2],
       [ 3,  0,  3],
       [ 3,  1,  0],
       [ 3,  1,  1],
       [ 3,  1,  2],
       [ 3,  1,  3],
       [ 4,  0,  0],
       [ 4,  0,  1],
       [ 4,  0,  2],
       [ 4,  0,  3],
       [ 5,  0,  0],
       [ 5,  0,  1],
       [ 5,  0,  2],
       [ 5,  0,  3],
       [ 6,  0,  0],
       [ 6,  0,  1],
       [ 6,  0,  2],
       [ 6,  0,  3],
       [ 7,  0,  0],
       [ 7,  0,  1],
       [ 7,  0,  2],
       [ 7,  0,  3],
       [ 8,  0,  0],
       [ 8,  0,  1],
       [ 8,  0,  2],
       [ 8,  0,  3],
       [ 9,  0,  0],
       [ 9,  0,  1],
       [ 9,  0,  2],
       [ 9,  0,  3],
       [ 9,  1,  0],
       [ 9,  1,  1],
       [ 9,  1,  2],
       [ 9,  1,  3],
       [10,  0,  0],
       [10,  0,  1],
       [10,  0,  2],
       [10,  0,  3],
       [10,  1,  0],
       [10,  1,  1],
       [10,  1,  2],
       [10,  1,  3],
       [10,  2,  0],
       [10,  2,  1],
       [10,  2,  2],
       [10,  2,  3],
       [10,  3,  0],
       [10,  3,  1],
       [10,  3,  2],
       [10,  3,  3],
       [10,  4,  0],
       [10,  4,  1],
       [10,  4,  2],
       [10,  4,  3],
       [10,  5,  0],
       [10,  5,  1],
       [10,  5,  2],
       [10,  5,  3],
       [11,  0,  0],
       [11,  0,  1],
       [11,  0,  2],
       [11,  0,  3],
       [12,  0,  0],
       [12,  0,  1],
       [12,  0,  2],
       [12,  0,  3],
       [13,  0,  0],
       [13,  0,  1],
       [13,  0,  2],
       [13,  0,  3],
       [13,  1,  0],
       [13,  1,  1],
       [13,  1,  2],
       [13,  1,  3],
       [14,  0,  0],
       [14,  0,  1],
       [14,  0,  2],
       [14,  0,  3],
       [15,  0,  0],
       [15,  0,  1],
       [15,  0,  2],
       [15,  0,  3],
       [16,  0,  0],
       [16,  0,  1],
       [16,  0,  2],
       [16,  0,  3],
       [16,  1,  0],
       [16,  1,  1],
       [16,  1,  2],
       [16,  1,  3],
       [16,  2,  0],
       [16,  2,  1],
       [16,  2,  2],
       [16,  2,  3],
       [17,  0,  0],
       [17,  0,  1],
       [17,  0,  2],
       [17,  0,  3],
       [18,  0,  0],
       [18,  0,  1],
       [18,  0,  2],
       [18,  0,  3],
       [18,  1,  0],
       [18,  1,  1],
       [18,  1,  2],
       [18,  1,  3],
       [19,  0,  0],
       [19,  0,  1],
       [19,  0,  2],
       [19,  0,  3],
       [20,  0,  0],
       [20,  0,  1],
       [20,  0,  2],
       [20,  0,  3],
       [21,  0,  0],
       [21,  0,  1],
       [21,  0,  2],
       [21,  0,  3],
       [22,  0,  0],
       [22,  0,  1],
       [22,  0,  2],
       [22,  0,  3],
       [23,  0,  0],
       [23,  0,  1],
       [23,  0,  2],
       [23,  0,  3],
       [23,  1,  0],
       [23,  1,  1],
       [23,  1,  2],
       [23,  1,  3],
       [23,  2,  0],
       [23,  2,  1],
       [23,  2,  2],
       [23,  2,  3],
       [24,  0,  0],
       [24,  0,  1],
       [24,  0,  2],
       [24,  0,  3],
       [25,  0,  0],
       [25,  0,  1],
       [25,  0,  2],
       [25,  0,  3],
       [26,  0,  0],
       [26,  0,  1],
       [26,  0,  2],
       [26,  0,  3],
       [27,  0,  0],
       [27,  0,  1],
       [27,  0,  2],
       [27,  0,  3],
       [27,  1,  0],
       [27,  1,  1],
       [27,  1,  2],
       [27,  1,  3],
       [27,  2,  0],
       [27,  2,  1],
       [27,  2,  2],
       [27,  2,  3],
       [28,  0,  0],
       [28,  0,  1],
       [28,  0,  2],
       [28,  0,  3],
       [29,  0,  0],
       [29,  0,  1],
       [29,  0,  2],
       [29,  0,  3],
       [30,  0,  0],
       [30,  0,  1],
       [30,  0,  2],
       [30,  0,  3],
       [31,  0,  0],
       [31,  0,  1],
       [31,  0,  2],
       [31,  0,  3]]), values=array([ 604.,  120.,   78.,  563.,  294.,  411.,  669.,  345.,  206.,
         19.,  887.,  664.,   70.,  239.,  580.,  655.,  604.,  192.,
        624.,  726.,  160.,  152.,  413.,  397.,  521.,   36.,  136.,
        443.,  732.,  390.,  181.,   48.,   69.,  216., 1129.,  437.,
        377.,   24.,  512.,  652.,  316.,   52.,  476.,  428.,  572.,
        442.,   98.,  403.,  172.,  181.,  932.,  466.,  446.,  191.,
        728.,  608.,  347.,  645.,  187.,   83.,  143.,  569.,  204.,
         88.,  110.,  145.,  894.,  363.,  528.,  120.,  448.,  273.,
        253.,  283.,  816.,  518.,   85.,  518.,  639.,  389.,  221.,
        188.,  495.,  220.,  297.,  486.,  413.,  211.,  175.,   44.,
       1103.,  916.,  624.,  241.,  526.,  474.,  219.,  222.,  453.,
        237.,  553.,  157.,  366.,  305.,  727.,  208.,  465.,  255.,
        290.,  269.,  967.,  467.,  614.,   30.,  529.,  787.,  613.,
         23.,  527.,  793.,  331.,  160.,  600.,  539.,   55.,  148.,
        989.,  512.,  405.,   74.,  753.,  496.,   60.,  497.,  905.,
        246.,  432.,  110.,  252.,  540.,  528.,  105.,  643.,  491.,
        566.,   79.,  667.,  439.,  185.,   28.,  903.,  785.,  195.,
        337.,  820.,  459.,   10.,   65.,  978., 1214.,  999.,  312.,
        138.,  171.,  853.,  259.,  167.,  234.,  897.,  285.,  182.,
        299.,  173.,   55.,  767., 1079.,  539.,  448.,  556.,  323.,
          0.,   77., 1036.,  775.,   72.,   54., 1207.,  797.],
      dtype=float32), dense_shape=array([32,  6,  4]))]
[SparseTensorValue(indices=array([[ 0,  0],
       [ 1,  0],
       [ 2,  0],
       [ 3,  0],
       [ 3,  1],
       [ 4,  0],
       [ 5,  0],
       [ 6,  0],
       [ 7,  0],
       [ 8,  0],
       [ 9,  0],
       [ 9,  1],
       [10,  0],
       [10,  1],
       [10,  2],
       [10,  3],
       [10,  4],
       [10,  5],
       [11,  0],
       [12,  0],
       [13,  0],
       [13,  1],
       [14,  0],
       [15,  0],
       [16,  0],
       [16,  1],
       [16,  2],
       [17,  0],
       [18,  0],
       [18,  1],
       [19,  0],
       [20,  0],
       [21,  0],
       [22,  0],
       [23,  0],
       [23,  1],
       [23,  2],
       [24,  0],
       [25,  0],
       [26,  0],
       [27,  0],
       [27,  1],
       [27,  2],
       [28,  0],
       [29,  0],
       [30,  0],
       [31,  0]]), values=array([17,  2, 14, 12, 12,  1, 17,  8,  6,  8, 10, 17,  3,  3,  3,  3,  3,
        3,  2,  4, 13, 14,  9,  1, 12, 12, 12,  6,  8, 10,  8, 14, 13, 16,
        3,  3,  3, 15, 15,  9, 13, 13, 13,  7,  4, 12,  7], dtype=int32), dense_shape=array([32,  6]))]
[array([[ 0],
       [ 1],
       [ 2],
       [ 3],
       [ 4],
       [ 5],
       [ 6],
       [ 7],
       [ 8],
       [ 9],
       [10],
       [11],
       [12],
       [13],
       [14],
       [15],
       [16],
       [17],
       [18],
       [19],
       [20],
       [21],
       [22],
       [23],
       [24],
       [25],
       [26],
       [27],
       [28],
       [29],
       [30],
       [31]], dtype=int32)]

让我们检查一下带有增强的输出图像!Tensorflow 输出 numpy 数组,因此我们可以使用 matplotlib 轻松地对其进行可视化。

我们定义一个 show_images 辅助函数,它将显示我们批次的样本。

批次布局是 NCHW,因此我们使用转置来获取 matplotlib 可以显示的 HWC 图像。

[7]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch, nb_images):
    columns = 4
    rows = (nb_images + 1) // (columns)
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(nb_images):
        plt.subplot(gs[j])
        plt.axis("off")
        img = image_batch[0][j].transpose((1, 2, 0)) + 128
        plt.imshow(img.astype("uint8"))


show_images(res_cpu[0], 8)
../../../_images/examples_frameworks_tensorflow_tensorflow-plugin-sparse-tensor_14_0.png
[ ]: