使用带有稀疏张量的 Tensorflow DALI 插件#
概述#
将我们的 DALI 数据加载和增强 pipeline 与 Tensorflow 结合使用非常简单。
然而,有时希望从 pipeline 中提取的一批数据无法表示为密集张量。在这种情况下,DALI op 会使用 TensorFlow SparseTensor。请注意,SparseTensor 仅在基于 CPU 的 pipeline 中受支持。
定义数据加载 Pipeline#
首先,我们从定义一些简单的 pipeline 开始,这些 pipeline 将以稀疏张量的形式返回数据。为了实现这一点,我们将使用著名的 COCO 数据集。每张图像可能包含 0 个或多个边界框,其中包含描述其中物体的标签。我们希望以标准化的方式返回图像,而标签和边界框将表示为稀疏张量。首先,让我们定义一些全局参数
DALI_EXTRA_PATH
环境变量应指向从 DALI extra repository 下载数据的位置。请确保已检出正确的发布标记。
[1]:
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os.path
test_data_root = os.environ["DALI_EXTRA_PATH"]
BATCH_SIZE = 32
test_data_root = os.environ["DALI_EXTRA_PATH"]
file_root = os.path.join(test_data_root, "db", "coco", "images")
annotations_file = os.path.join(test_data_root, "db", "coco", "instances.json")
创建了带有 COCO 读取器的 Pipeline。请注意,在处理图像时,来自 COCO ara 的其他数据也会通过。
[2]:
@pipeline_def
def coco_pipeline():
jpegs, bboxes, labels, im_ids = fn.readers.coco(
file_root=file_root,
annotations_file=annotations_file,
ratio=False,
image_ids=True,
)
images = fn.decoders.image(jpegs, device="cpu")
images = fn.resize(
images,
resize_shorter=fn.random.uniform(range=(256.0, 480.0)),
interp_type=types.INTERP_LINEAR,
)
images = fn.crop_mirror_normalize(
images,
crop_pos_x=fn.random.uniform(range=(0.0, 1.0)),
crop_pos_y=fn.random.uniform(range=(0.0, 1.0)),
dtype=types.FLOAT,
crop=(224, 224),
mean=[128.0, 128.0, 128.0],
std=[1.0, 1.0, 1.0],
)
images = fn.cast(images, dtype=types.INT32)
return images, bboxes, labels, im_ids
接下来,我们使用正确的参数实例化 pipeline。我们将为每个 GPU 创建一个 pipeline,方法是为每个 pipeline 指定正确的 device_id
。
不同之处在于,我们将 pipeline 对象传递给 TensorFlow 运算符,而不是调用 pipeline.build
并使用它。
[3]:
pipe = coco_pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0)
使用 DALI TensorFlow 插件#
首先,让我们导入 Tensorflow 和 DALI Tensorflow 插件,并将其命名为 dali_tf
。
[4]:
import tensorflow as tf
import nvidia.dali.plugin.tf as dali_tf
import time
from tensorflow.compat.v1 import GPUOptions
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import Session
from tensorflow.compat.v1 import placeholder
tf.compat.v1.disable_eager_execution()
我们现在可以使用 nvidia.dali.plugin.tf.DALIIterator()
方法来获取 Tensorflow Op,它将生成我们将在 Tensorflow 图形中使用的张量。
对于每个 DALI pipeline,我们使用 daliop
,它返回一个 Tensorflow 张量元组,我们将将其存储在 image, bouding boxes, labels and image ids
中。要启用稀疏张量生成,需要为将表示为稀疏张量的输出元素填充 True
参数。
[5]:
daliop = dali_tf.DALIIterator()
images = []
bboxes = []
labels = []
image_ids = []
with tf.device("/cpu"):
image, bbox, label, id = daliop(
pipeline=pipe,
shapes=[(BATCH_SIZE, 3, 224, 224), (), (), ()],
dtypes=[tf.int32, tf.float32, tf.int32, tf.int32],
sparse=[False, True, True],
)
images.append(image)
bboxes.append(bbox)
labels.append(label)
image_ids.append(id)
在简单的 Tensorflow 图形中使用张量#
我们将在 Tensorflow 图形定义中使用 images
、bboxes
、labels
和 image_ids
张量列表。然后运行一个非常简单的*单操作图形*会话,它将输出批量数据。然后我们将打印边界框、标签和 image_ids。
[6]:
with Session() as sess:
all_img_per_sec = []
total_batch_size = BATCH_SIZE
start_time = time.time()
# The actual run with our dali_tf tensors
res_cpu = sess.run([images, bboxes, labels, image_ids])
print(res_cpu[1])
print(res_cpu[2])
print(res_cpu[3])
[SparseTensorValue(indices=array([[ 0, 0, 0],
[ 0, 0, 1],
[ 0, 0, 2],
[ 0, 0, 3],
[ 1, 0, 0],
[ 1, 0, 1],
[ 1, 0, 2],
[ 1, 0, 3],
[ 2, 0, 0],
[ 2, 0, 1],
[ 2, 0, 2],
[ 2, 0, 3],
[ 3, 0, 0],
[ 3, 0, 1],
[ 3, 0, 2],
[ 3, 0, 3],
[ 3, 1, 0],
[ 3, 1, 1],
[ 3, 1, 2],
[ 3, 1, 3],
[ 4, 0, 0],
[ 4, 0, 1],
[ 4, 0, 2],
[ 4, 0, 3],
[ 5, 0, 0],
[ 5, 0, 1],
[ 5, 0, 2],
[ 5, 0, 3],
[ 6, 0, 0],
[ 6, 0, 1],
[ 6, 0, 2],
[ 6, 0, 3],
[ 7, 0, 0],
[ 7, 0, 1],
[ 7, 0, 2],
[ 7, 0, 3],
[ 8, 0, 0],
[ 8, 0, 1],
[ 8, 0, 2],
[ 8, 0, 3],
[ 9, 0, 0],
[ 9, 0, 1],
[ 9, 0, 2],
[ 9, 0, 3],
[ 9, 1, 0],
[ 9, 1, 1],
[ 9, 1, 2],
[ 9, 1, 3],
[10, 0, 0],
[10, 0, 1],
[10, 0, 2],
[10, 0, 3],
[10, 1, 0],
[10, 1, 1],
[10, 1, 2],
[10, 1, 3],
[10, 2, 0],
[10, 2, 1],
[10, 2, 2],
[10, 2, 3],
[10, 3, 0],
[10, 3, 1],
[10, 3, 2],
[10, 3, 3],
[10, 4, 0],
[10, 4, 1],
[10, 4, 2],
[10, 4, 3],
[10, 5, 0],
[10, 5, 1],
[10, 5, 2],
[10, 5, 3],
[11, 0, 0],
[11, 0, 1],
[11, 0, 2],
[11, 0, 3],
[12, 0, 0],
[12, 0, 1],
[12, 0, 2],
[12, 0, 3],
[13, 0, 0],
[13, 0, 1],
[13, 0, 2],
[13, 0, 3],
[13, 1, 0],
[13, 1, 1],
[13, 1, 2],
[13, 1, 3],
[14, 0, 0],
[14, 0, 1],
[14, 0, 2],
[14, 0, 3],
[15, 0, 0],
[15, 0, 1],
[15, 0, 2],
[15, 0, 3],
[16, 0, 0],
[16, 0, 1],
[16, 0, 2],
[16, 0, 3],
[16, 1, 0],
[16, 1, 1],
[16, 1, 2],
[16, 1, 3],
[16, 2, 0],
[16, 2, 1],
[16, 2, 2],
[16, 2, 3],
[17, 0, 0],
[17, 0, 1],
[17, 0, 2],
[17, 0, 3],
[18, 0, 0],
[18, 0, 1],
[18, 0, 2],
[18, 0, 3],
[18, 1, 0],
[18, 1, 1],
[18, 1, 2],
[18, 1, 3],
[19, 0, 0],
[19, 0, 1],
[19, 0, 2],
[19, 0, 3],
[20, 0, 0],
[20, 0, 1],
[20, 0, 2],
[20, 0, 3],
[21, 0, 0],
[21, 0, 1],
[21, 0, 2],
[21, 0, 3],
[22, 0, 0],
[22, 0, 1],
[22, 0, 2],
[22, 0, 3],
[23, 0, 0],
[23, 0, 1],
[23, 0, 2],
[23, 0, 3],
[23, 1, 0],
[23, 1, 1],
[23, 1, 2],
[23, 1, 3],
[23, 2, 0],
[23, 2, 1],
[23, 2, 2],
[23, 2, 3],
[24, 0, 0],
[24, 0, 1],
[24, 0, 2],
[24, 0, 3],
[25, 0, 0],
[25, 0, 1],
[25, 0, 2],
[25, 0, 3],
[26, 0, 0],
[26, 0, 1],
[26, 0, 2],
[26, 0, 3],
[27, 0, 0],
[27, 0, 1],
[27, 0, 2],
[27, 0, 3],
[27, 1, 0],
[27, 1, 1],
[27, 1, 2],
[27, 1, 3],
[27, 2, 0],
[27, 2, 1],
[27, 2, 2],
[27, 2, 3],
[28, 0, 0],
[28, 0, 1],
[28, 0, 2],
[28, 0, 3],
[29, 0, 0],
[29, 0, 1],
[29, 0, 2],
[29, 0, 3],
[30, 0, 0],
[30, 0, 1],
[30, 0, 2],
[30, 0, 3],
[31, 0, 0],
[31, 0, 1],
[31, 0, 2],
[31, 0, 3]]), values=array([ 604., 120., 78., 563., 294., 411., 669., 345., 206.,
19., 887., 664., 70., 239., 580., 655., 604., 192.,
624., 726., 160., 152., 413., 397., 521., 36., 136.,
443., 732., 390., 181., 48., 69., 216., 1129., 437.,
377., 24., 512., 652., 316., 52., 476., 428., 572.,
442., 98., 403., 172., 181., 932., 466., 446., 191.,
728., 608., 347., 645., 187., 83., 143., 569., 204.,
88., 110., 145., 894., 363., 528., 120., 448., 273.,
253., 283., 816., 518., 85., 518., 639., 389., 221.,
188., 495., 220., 297., 486., 413., 211., 175., 44.,
1103., 916., 624., 241., 526., 474., 219., 222., 453.,
237., 553., 157., 366., 305., 727., 208., 465., 255.,
290., 269., 967., 467., 614., 30., 529., 787., 613.,
23., 527., 793., 331., 160., 600., 539., 55., 148.,
989., 512., 405., 74., 753., 496., 60., 497., 905.,
246., 432., 110., 252., 540., 528., 105., 643., 491.,
566., 79., 667., 439., 185., 28., 903., 785., 195.,
337., 820., 459., 10., 65., 978., 1214., 999., 312.,
138., 171., 853., 259., 167., 234., 897., 285., 182.,
299., 173., 55., 767., 1079., 539., 448., 556., 323.,
0., 77., 1036., 775., 72., 54., 1207., 797.],
dtype=float32), dense_shape=array([32, 6, 4]))]
[SparseTensorValue(indices=array([[ 0, 0],
[ 1, 0],
[ 2, 0],
[ 3, 0],
[ 3, 1],
[ 4, 0],
[ 5, 0],
[ 6, 0],
[ 7, 0],
[ 8, 0],
[ 9, 0],
[ 9, 1],
[10, 0],
[10, 1],
[10, 2],
[10, 3],
[10, 4],
[10, 5],
[11, 0],
[12, 0],
[13, 0],
[13, 1],
[14, 0],
[15, 0],
[16, 0],
[16, 1],
[16, 2],
[17, 0],
[18, 0],
[18, 1],
[19, 0],
[20, 0],
[21, 0],
[22, 0],
[23, 0],
[23, 1],
[23, 2],
[24, 0],
[25, 0],
[26, 0],
[27, 0],
[27, 1],
[27, 2],
[28, 0],
[29, 0],
[30, 0],
[31, 0]]), values=array([17, 2, 14, 12, 12, 1, 17, 8, 6, 8, 10, 17, 3, 3, 3, 3, 3,
3, 2, 4, 13, 14, 9, 1, 12, 12, 12, 6, 8, 10, 8, 14, 13, 16,
3, 3, 3, 15, 15, 9, 13, 13, 13, 7, 4, 12, 7], dtype=int32), dense_shape=array([32, 6]))]
[array([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15],
[16],
[17],
[18],
[19],
[20],
[21],
[22],
[23],
[24],
[25],
[26],
[27],
[28],
[29],
[30],
[31]], dtype=int32)]
让我们检查一下带有增强的输出图像!Tensorflow 输出 numpy 数组,因此我们可以使用 matplotlib
轻松地可视化它们。
我们定义一个 show_images
辅助函数,它将显示我们批次的样本。
批次布局是 NCHW,因此我们使用转置来获取 HWC 图像,matplotlib
可以显示这些图像。
[7]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
def show_images(image_batch, nb_images):
columns = 4
rows = (nb_images + 1) // (columns)
fig = plt.figure(figsize=(32, (32 // columns) * rows))
gs = gridspec.GridSpec(rows, columns)
for j in range(nb_images):
plt.subplot(gs[j])
plt.axis("off")
img = image_batch[0][j].transpose((1, 2, 0)) + 128
plt.imshow(img.astype("uint8"))
show_images(res_cpu[0], 8)
data:image/s3,"s3://crabby-images/0b7d0/0b7d0d420678e72043c9d157ca4c404eaef7ae7d" alt="../../../_images/examples_frameworks_tensorflow_tensorflow-plugin-sparse-tensor_14_0.png"
[ ]: