Tensor Joining

This section demonstrates the following methods of joining tensors:

  • Stacking

  • Concatenation

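These operations take multiple inputs and produce their output by joining the input tensors. The difference between them is that concatenation joins the tensors along an existing axis, so the number of dimensions stays the same, whereas stacking inserts a new axis, so the number of dimensions grows by one.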
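The difference in output shapes can be illustrated with NumPy. This is a CPU-only sketch that does not use DALI; it assumes that `np.concatenate` and `np.stack` follow the same axis semantics as DALI's `fn.cat` and `fn.stack` (the outputs printed later in this section are consistent with that assumption).

```python
import numpy as np

# Two inputs with identical shape (2, 3).
a = np.zeros((2, 3))
b = np.ones((2, 3))

# Concatenation joins along an existing axis: the result is still 2-D,
# and the joined axis has the summed length of the inputs.
cat0 = np.concatenate([a, b], axis=0)
print(cat0.shape)  # (4, 3)

# Stacking inserts a new axis whose length equals the number of inputs,
# so the result is 3-D.
st0 = np.stack([a, b], axis=0)
print(st0.shape)  # (2, 2, 3)
```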

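Stacking can be used, for example, to combine separate coordinates into vectors or to combine color planes into a color image. Concatenation can be used, for example, to join tiles into a larger image or to append lists.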
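The use cases above can be sketched with NumPy. The array sizes, the variable names (`xs`, `ys`, `red`, `green`, `blue`, `left_tile`, `right_tile`) and the HWC (height, width, channels) layout are illustrative assumptions, not part of DALI.

```python
import numpy as np

# Separate x and y coordinates stacked into 2-element vectors.
xs = np.array([0.0, 1.0, 2.0])
ys = np.array([10.0, 11.0, 12.0])
points = np.stack([xs, ys], axis=-1)
print(points.shape)  # (3, 2)

height, width = 4, 6  # hypothetical image size

# Three separate color planes, each of shape (height, width).
red = np.full((height, width), 255, dtype=np.uint8)
green = np.zeros((height, width), dtype=np.uint8)
blue = np.zeros((height, width), dtype=np.uint8)

# Stacking the planes along a new innermost axis yields an HWC color image:
# each pixel becomes an (R, G, B) triple.
image = np.stack([red, green, blue], axis=-1)
print(image.shape)  # (4, 6, 3)

# Concatenating two tiles along the width axis (axis=1) places them
# side by side; the height and channel count are unchanged.
left_tile = image
right_tile = np.zeros_like(image)
wide = np.concatenate([left_tile, right_tile], axis=1)
print(wide.shape)  # (4, 12, 3)
```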

Concatenation

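In this section, we show how to concatenate along different axes. Concatenation requires the inputs to have the same number of dimensions and to match in every dimension except the concatenation axis. Because the following examples concatenate the same tensors along each axis in turn, the tensors must have identical shapes.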
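The shape rule can be written as a small pure-Python helper. `concat_output_shape` is a hypothetical name used only for illustration and is not part of DALI; it encodes the requirement stated above, namely that the inputs agree in every dimension except `axis`.

```python
def concat_output_shape(shapes, axis):
    """Return the shape produced by concatenating tensors of the given shapes.

    Raises ValueError if the inputs cannot be concatenated along `axis`.
    """
    ndim = len(shapes[0])
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D inputs")
    for shape in shapes:
        if len(shape) != ndim:
            raise ValueError("all inputs must have the same number of dimensions")
        for d in range(ndim):
            # Every dimension except the concatenation axis must match.
            if d != axis and shape[d] != shapes[0][d]:
                raise ValueError(f"dimension {d} differs: {shape} vs {shapes[0]}")
    out = list(shapes[0])
    # The concatenation axis grows to the sum of the input extents.
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)


# Three (2, 3, 4) inputs, as in the example below.
for axis in range(3):
    print(axis, concat_output_shape([(2, 3, 4)] * 3, axis))
# 0 (6, 3, 4)
# 1 (2, 9, 4)
# 2 (2, 3, 12)

# Inputs that differ only along axis 1 can still be concatenated along axis 1.
print(concat_output_shape([(2, 3, 4), (2, 5, 4)], axis=1))  # (2, 8, 4)
```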

[1]:
import nvidia.dali as dali
import nvidia.dali.fn as fn
import numpy as np

np.random.seed(1234)

# A 2 x 3 x 4 input tensor.
arr = np.array(
    [
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
    ]
)

# Three constant inputs of identical shape. The offsets of 100 and 200 make it
# easy to tell which input each output value came from.
src1 = dali.types.Constant(arr)
src2 = dali.types.Constant(arr + 100)
src3 = dali.types.Constant(arr + 200)

pipe_cat = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_cat:
    # Join the inputs along the outermost (0), middle (1) and innermost (2) axis.
    cat_outer = fn.cat(src1, src2, src3, axis=0)
    cat_middle = fn.cat(src1, src2, src3, axis=1)
    cat_inner = fn.cat(src1, src2, src3, axis=2)
    pipe_cat.set_outputs(cat_outer, cat_middle, cat_inner)

pipe_cat.build()
o = pipe_cat.run()
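The pipeline above is created with `device_id=0`, so running it assumes a GPU-enabled DALI installation. To reason about the expected results on the CPU, they can be reproduced with NumPy, assuming `np.concatenate` behaves like `fn.cat` for these inputs (the printed outputs below agree with that).

```python
import numpy as np

# Same input tensor as in the DALI cell above.
arr = np.array(
    [
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
    ]
)
inputs = [arr, arr + 100, arr + 200]

# Reference results for axis=0, 1 and 2.
ref = [np.concatenate(inputs, axis=axis) for axis in range(3)]
for axis, r in enumerate(ref):
    print(axis, r.shape)
# 0 (6, 3, 4)
# 1 (2, 9, 4)
# 2 (2, 3, 12)

# If the DALI outputs `o` from the cell above are available, they could be
# compared with, e.g.: np.testing.assert_array_equal(o[0].at(0), ref[0])
```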
[2]:
print("Concatenation along outer axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Concatenation along outer axis:
[[[  1   2   3   4]
  [  5   6   7   8]
  [  9  10  11  12]]

 [[ 13  14  15  16]
  [ 17  18  19  20]
  [ 21  22  23  24]]

 [[101 102 103 104]
  [105 106 107 108]
  [109 110 111 112]]

 [[113 114 115 116]
  [117 118 119 120]
  [121 122 123 124]]

 [[201 202 203 204]
  [205 206 207 208]
  [209 210 211 212]]

 [[213 214 215 216]
  [217 218 219 220]
  [221 222 223 224]]]
Shape:  (6, 3, 4)
[3]:
print("Concatenation along middle axis:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Concatenation along middle axis:
[[[  1   2   3   4]
  [  5   6   7   8]
  [  9  10  11  12]
  [101 102 103 104]
  [105 106 107 108]
  [109 110 111 112]
  [201 202 203 204]
  [205 206 207 208]
  [209 210 211 212]]

 [[ 13  14  15  16]
  [ 17  18  19  20]
  [ 21  22  23  24]
  [113 114 115 116]
  [117 118 119 120]
  [121 122 123 124]
  [213 214 215 216]
  [217 218 219 220]
  [221 222 223 224]]]
Shape:  (2, 9, 4)
[4]:
print("Concatenation along inner axis:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Concatenation along inner axis:
[[[  1   2   3   4 101 102 103 104 201 202 203 204]
  [  5   6   7   8 105 106 107 108 205 206 207 208]
  [  9  10  11  12 109 110 111 112 209 210 211 212]]

 [[ 13  14  15  16 113 114 115 116 213 214 215 216]
  [ 17  18  19  20 117 118 119 120 217 218 219 220]
  [ 21  22  23  24 121 122 123 124 221 222 223 224]]]
Shape:  (2, 3, 12)

Stacking

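Stacking inserts a new axis. The new axis can be placed at any position, even after the innermost axis, in which case the values from the input tensors are interleaved.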
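One way to see how stacking relates to concatenation: stacking along `axis` gives the same result as first inserting a length-1 axis at position `axis` into every input and then concatenating along it. The NumPy sketch below checks this identity for every valid position (0 through 3 for 3-D inputs). It illustrates the semantics only and is not a description of how DALI implements `fn.stack`.

```python
import numpy as np

# Same values as `arr` in the concatenation example.
arr = np.arange(1, 25).reshape(2, 3, 4)
inputs = [arr, arr + 100, arr + 200]

for axis in range(arr.ndim + 1):  # 0..3 for 3-D inputs
    stacked = np.stack(inputs, axis=axis)
    # Insert a length-1 axis into each input, then concatenate along it.
    via_cat = np.concatenate([np.expand_dims(x, axis) for x in inputs], axis=axis)
    assert np.array_equal(stacked, via_cat)
    print(axis, stacked.shape)

# With the new axis innermost, the values of the inputs are interleaved.
print(np.stack(inputs, axis=3)[0, 0, 0])  # [  1 101 201]
```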

We apply stacking to the same inputs that were used for concatenation.

[5]:
pipe_stack = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_stack:
    # Insert the new axis before axis 0 (outermost), before axis 1, before axis 2,
    # and after the innermost axis (axis=3 for 3-D inputs).
    st_outermost = fn.stack(src1, src2, src3, axis=0)
    st_1 = fn.stack(src1, src2, src3, axis=1)
    st_2 = fn.stack(src1, src2, src3, axis=2)
    st_new_inner = fn.stack(src1, src2, src3, axis=3)
    pipe_stack.set_outputs(st_outermost, st_1, st_2, st_new_inner)

pipe_stack.build()
o = pipe_stack.run()
[6]:
print("Stacking - insert outermost axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Stacking - insert outermost axis:
[[[[  1   2   3   4]
   [  5   6   7   8]
   [  9  10  11  12]]

  [[ 13  14  15  16]
   [ 17  18  19  20]
   [ 21  22  23  24]]]


 [[[101 102 103 104]
   [105 106 107 108]
   [109 110 111 112]]

  [[113 114 115 116]
   [117 118 119 120]
   [121 122 123 124]]]


 [[[201 202 203 204]
   [205 206 207 208]
   [209 210 211 212]]

  [[213 214 215 216]
   [217 218 219 220]
   [221 222 223 224]]]]
Shape:  (3, 2, 3, 4)
[7]:
print("Stacking - new axis before 1:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Stacking - new axis before 1:
[[[[  1   2   3   4]
   [  5   6   7   8]
   [  9  10  11  12]]

  [[101 102 103 104]
   [105 106 107 108]
   [109 110 111 112]]

  [[201 202 203 204]
   [205 206 207 208]
   [209 210 211 212]]]


 [[[ 13  14  15  16]
   [ 17  18  19  20]
   [ 21  22  23  24]]

  [[113 114 115 116]
   [117 118 119 120]
   [121 122 123 124]]

  [[213 214 215 216]
   [217 218 219 220]
   [221 222 223 224]]]]
Shape:  (2, 3, 3, 4)
[8]:
print("Stacking - new axis before 2:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Stacking - new axis before 2:
[[[[  1   2   3   4]
   [101 102 103 104]
   [201 202 203 204]]

  [[  5   6   7   8]
   [105 106 107 108]
   [205 206 207 208]]

  [[  9  10  11  12]
   [109 110 111 112]
   [209 210 211 212]]]


 [[[ 13  14  15  16]
   [113 114 115 116]
   [213 214 215 216]]

  [[ 17  18  19  20]
   [117 118 119 120]
   [217 218 219 220]]

  [[ 21  22  23  24]
   [121 122 123 124]
   [221 222 223 224]]]]
Shape:  (2, 3, 3, 4)
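Stacking with `axis=1` and with `axis=2` produces the same shape, (2, 3, 3, 4), only because the number of inputs (3) happens to equal the extent of the original axis 1; the element order differs, as the two printouts show. The shape rule can be written as a small helper. `stack_output_shape` is a hypothetical name for illustration, and the valid range 0 to ndim for `axis` is taken from the examples in this section.

```python
def stack_output_shape(input_shape, num_inputs, axis):
    """Shape produced by stacking `num_inputs` tensors of `input_shape` at `axis`."""
    ndim = len(input_shape)
    # The new axis may go anywhere from before the outermost axis (0)
    # to after the innermost one (ndim).
    if not 0 <= axis <= ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D inputs")
    shape = tuple(input_shape)
    # Splice the new axis, whose length is the number of inputs, into the shape.
    return shape[:axis] + (num_inputs,) + shape[axis:]


for axis in range(4):
    print(axis, stack_output_shape((2, 3, 4), 3, axis))
# 0 (3, 2, 3, 4)
# 1 (2, 3, 3, 4)
# 2 (2, 3, 3, 4)
# 3 (2, 3, 4, 3)
```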
[9]:
print("Stacking - new innermost axis:")
print(o[3].at(0))
print("Shape: ", o[3].at(0).shape)
Stacking - new innermost axis:
[[[[  1 101 201]
   [  2 102 202]
   [  3 103 203]
   [  4 104 204]]

  [[  5 105 205]
   [  6 106 206]
   [  7 107 207]
   [  8 108 208]]

  [[  9 109 209]
   [ 10 110 210]
   [ 11 111 211]
   [ 12 112 212]]]


 [[[ 13 113 213]
   [ 14 114 214]
   [ 15 115 215]
   [ 16 116 216]]

  [[ 17 117 217]
   [ 18 118 218]
   [ 19 119 219]
   [ 20 120 220]]

  [[ 21 121 221]
   [ 22 122 222]
   [ 23 123 223]
   [ 24 124 224]]]]
Shape:  (2, 3, 4, 3)