Tensor Joining
This section demonstrates the following methods of joining tensors:
Stacking
Concatenation
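These operations take multiple inputs and produce an output by joining the input tensors. The difference between them is that concatenation joins tensors along an existing axis, whereas stacking inserts a new axis.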
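To make the difference concrete, here is a minimal sketch that uses plain Python lists as stand-ins for 1-D tensors. It is an illustration only and does not call DALI. Concatenation lengthens the existing axis, while stacking adds a new axis with one entry per input.

```python
# Plain Python lists stand in for two 1-D tensors of shape (3,); DALI is not used here.
a = [1, 2, 3]
b = [4, 5, 6]

# Concatenation joins along the existing axis 0: (3,) and (3,) -> (6,).
concatenated = a + b

# Stacking along a new outer axis: shape (2, 3), one row per input.
stacked_outer = [a, b]

# Stacking along a new inner axis: shape (3, 2); the inputs' values are interleaved.
stacked_inner = [[x, y] for x, y in zip(a, b)]

print(concatenated)   # [1, 2, 3, 4, 5, 6]
print(stacked_outer)  # [[1, 2, 3], [4, 5, 6]]
print(stacked_inner)  # [[1, 4], [2, 5], [3, 6]]
```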
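Stacking can be used, for example, to combine separate coordinates into vectors or to combine color planes into a color image. Concatenation can be used, for example, to join tiles into a larger image or to append lists.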
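A minimal sketch of two of these use cases, again with nested Python lists standing in for tensors; the plane and tile values are hypothetical. Assuming the same axis semantics as in the examples in this section, the stacking step models what `fn.stack(r, g, b, axis=2)` would produce for 2-D planes, and the tiling step models `fn.cat(left, right, axis=1)`.

```python
# Hypothetical 2 x 2 color planes (H x W), stored as nested lists.
r = [[10, 11], [12, 13]]
g = [[20, 21], [22, 23]]
b = [[30, 31], [32, 33]]
height, width = len(r), len(r[0])

# Stacking the planes along a new innermost axis gives an H x W x 3 image
# in which every pixel is an (R, G, B) vector.
rgb = [[[r[y][x], g[y][x], b[y][x]] for x in range(width)] for y in range(height)]

# Two hypothetical 2 x 2 tiles placed side by side: concatenation along the
# width axis (axis 1) gives an H x 2W image.
left = [[1, 2], [3, 4]]
right = [[5, 6], [7, 8]]
wide = [lrow + rrow for lrow, rrow in zip(left, right)]

print(rgb[0][1])  # [11, 21, 31]
print(wide)       # [[1, 2, 5, 6], [3, 4, 7, 8]]
```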
Concatenation
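In this section, we show how to concatenate along different axes. Concatenated inputs must match in every dimension except the one along which they are joined; because the following examples join the same tensors along each axis in turn, the tensors must have identical shapes.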
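The shape rule can be written as a small helper. This is a sketch of the general concatenation rule, not DALI's own validation code, and the name `cat_shape` is hypothetical.

```python
def cat_shape(shapes, axis):
    """Return the output shape of concatenating inputs with `shapes` along `axis`."""
    ndim = len(shapes[0])
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim}-D inputs")
    for s in shapes:
        if len(s) != ndim:
            raise ValueError("all inputs must have the same number of dimensions")
        # Every dimension except the concatenation axis must match the first input.
        for d in range(ndim):
            if d != axis and s[d] != shapes[0][d]:
                raise ValueError(f"shape {s} does not match {shapes[0]} outside axis {axis}")
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)  # only the joined axis grows
    return tuple(out)


same = [(2, 3, 4)] * 3
print([cat_shape(same, axis) for axis in range(3)])  # [(6, 3, 4), (2, 9, 4), (2, 3, 12)]

# Inputs of different sizes can still be joined along the axis where they differ...
print(cat_shape([(2, 3, 4), (2, 5, 4)], axis=1))  # (2, 8, 4)
# ...but not along any other axis.
try:
    cat_shape([(2, 3, 4), (2, 5, 4)], axis=0)
except ValueError as err:
    print("error:", err)
```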
import nvidia.dali as dali
import nvidia.dali.fn as fn
import numpy as np

np.random.seed(1234)

# A 2 x 3 x 4 tensor; the three inputs differ only by an offset of 0, 100 and 200.
arr = np.array(
    [
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
    ]
)
src1 = dali.types.Constant(arr)
src2 = dali.types.Constant(arr + 100)
src3 = dali.types.Constant(arr + 200)

pipe_cat = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_cat:
    # Join the same three inputs along each of their three existing axes.
    cat_outer = fn.cat(src1, src2, src3, axis=0)   # -> (6, 3, 4)
    cat_middle = fn.cat(src1, src2, src3, axis=1)  # -> (2, 9, 4)
    cat_inner = fn.cat(src1, src2, src3, axis=2)   # -> (2, 3, 12)
    pipe_cat.set_outputs(cat_outer, cat_middle, cat_inner)

pipe_cat.build()
o = pipe_cat.run()
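For reference, the following pure-Python sketch reproduces the same joins on nested lists. The `concat`, `shape` and `offset` helpers are hypothetical and only model the behavior shown here; this is not how `fn.cat` is implemented, and it assumes the inputs already satisfy the shape rule. The names avoid `arr`, `src1` and `o` so that running it in the same notebook does not overwrite the pipeline inputs.

```python
def concat(tensors, axis):
    """Concatenate same-rank nested-list tensors along an existing `axis`."""
    if axis == 0:
        # Joining along the outermost axis is plain list concatenation.
        out = []
        for t in tensors:
            out.extend(t)
        return out
    # Otherwise join the i-th sub-tensors of all inputs along `axis - 1`.
    return [concat([t[i] for t in tensors], axis - 1) for i in range(len(tensors[0]))]


def shape(t):
    """Shape of a regular nested list, read off its first elements."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


def offset(t, k):
    """Add `k` to every element, like `arr + k` in NumPy."""
    return [offset(x, k) for x in t] if isinstance(t, list) else t + k


base = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
inputs = [base, offset(base, 100), offset(base, 200)]

for axis in range(3):
    print(axis, shape(concat(inputs, axis)))
# 0 (6, 3, 4)
# 1 (2, 9, 4)
# 2 (2, 3, 12)
```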
print("Concatenation along outer axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Concatenation along outer axis:
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
[[ 13 14 15 16]
[ 17 18 19 20]
[ 21 22 23 24]]
[[101 102 103 104]
[105 106 107 108]
[109 110 111 112]]
[[113 114 115 116]
[117 118 119 120]
[121 122 123 124]]
[[201 202 203 204]
[205 206 207 208]
[209 210 211 212]]
[[213 214 215 216]
[217 218 219 220]
[221 222 223 224]]]
Shape: (6, 3, 4)
print("Concatenation along middle axis:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Concatenation along middle axis:
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[101 102 103 104]
[105 106 107 108]
[109 110 111 112]
[201 202 203 204]
[205 206 207 208]
[209 210 211 212]]
[[ 13 14 15 16]
[ 17 18 19 20]
[ 21 22 23 24]
[113 114 115 116]
[117 118 119 120]
[121 122 123 124]
[213 214 215 216]
[217 218 219 220]
[221 222 223 224]]]
Shape: (2, 9, 4)
print("Concatenation along inner axis:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Concatenation along inner axis:
[[[ 1 2 3 4 101 102 103 104 201 202 203 204]
[ 5 6 7 8 105 106 107 108 205 206 207 208]
[ 9 10 11 12 109 110 111 112 209 210 211 212]]
[[ 13 14 15 16 113 114 115 116 213 214 215 216]
[ 17 18 19 20 117 118 119 120 217 218 219 220]
[ 21 22 23 24 121 122 123 124 221 222 223 224]]]
Shape: (2, 3, 12)
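The inner-axis result can also be predicted element by element. In this plain-Python sketch, `base` holds the same values as `arr`, and output column `k` comes from input `k // 4` (which adds an offset of 100 per input) at input column `k % 4`.

```python
base = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
n_inputs, width = 3, 4  # three inputs, each 4 wide along the innermost axis

# Output column k comes from input k // width (the inputs are base, base + 100
# and base + 200, so the offset is 100 times that index) at column k % width.
inner = [
    [[base[i][j][k % width] + 100 * (k // width) for k in range(n_inputs * width)]
     for j in range(3)]
    for i in range(2)
]

print(inner[1][2])  # [21, 22, 23, 24, 121, 122, 123, 124, 221, 222, 223, 224]
```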
Stacking
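Stacking inserts a new axis whose length equals the number of inputs; unlike concatenation, it requires all inputs to have the same shape. The new axis can be inserted at any position, even after the innermost axis, in which case the values from the input tensors are interleaved.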
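The output shape follows directly from where the new axis goes. A sketch of the rule, with a hypothetical `stack_shape` helper (not part of DALI), assuming `axis` gives the position of the new axis in the output, as in the `fn.stack` calls in this section:

```python
def stack_shape(shapes, axis):
    """Return the output shape of stacking inputs with identical `shapes` along a new `axis`."""
    first = tuple(shapes[0])
    if any(tuple(s) != first for s in shapes):
        raise ValueError("all inputs must have the same shape for stacking")
    ndim = len(first)
    # The new axis can go before any existing axis (0..ndim-1) or after the last one (ndim).
    if not 0 <= axis <= ndim:
        raise ValueError(f"axis {axis} is out of range [0, {ndim}]")
    return first[:axis] + (len(shapes),) + first[axis:]


for axis in range(4):
    print(axis, stack_shape([(2, 3, 4)] * 3, axis))
# 0 (3, 2, 3, 4)
# 1 (2, 3, 3, 4)
# 2 (2, 3, 3, 4)
# 3 (2, 3, 4, 3)
```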
Here, stacking is applied to the same inputs as in the concatenation examples.
pipe_stack = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_stack:
    # `axis` is the position of the new axis; 3 places it after the innermost input axis.
    st_outermost = fn.stack(src1, src2, src3, axis=0)  # -> (3, 2, 3, 4)
    st_1 = fn.stack(src1, src2, src3, axis=1)          # -> (2, 3, 3, 4)
    st_2 = fn.stack(src1, src2, src3, axis=2)          # -> (2, 3, 3, 4)
    st_new_inner = fn.stack(src1, src2, src3, axis=3)  # -> (2, 3, 4, 3)
    pipe_stack.set_outputs(st_outermost, st_1, st_2, st_new_inner)

pipe_stack.build()
o = pipe_stack.run()
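A pure-Python sketch of the same operation on nested lists. The `stack`, `shape` and `offset` helpers are hypothetical and model only the semantics shown in this section, not DALI's implementation; the sketch assumes all inputs have the same shape.

```python
def stack(tensors, axis):
    """Stack same-shape nested-list tensors along a new `axis` (0 <= axis <= ndim)."""
    if axis == 0:
        # A new outermost axis simply indexes the inputs.
        return list(tensors)
    # Otherwise insert the new axis inside each of the i-th sub-tensors.
    return [stack([t[i] for t in tensors], axis - 1) for i in range(len(tensors[0]))]


def shape(t):
    """Shape of a regular nested list, read off its first elements."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


def offset(t, k):
    """Add `k` to every element, like `arr + k` in NumPy."""
    return [offset(x, k) for x in t] if isinstance(t, list) else t + k


base = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
inputs = [base, offset(base, 100), offset(base, 200)]

for axis in range(4):
    print(axis, shape(stack(inputs, axis)))
# 0 (3, 2, 3, 4)
# 1 (2, 3, 3, 4)
# 2 (2, 3, 3, 4)
# 3 (2, 3, 4, 3)
```

When `axis` equals the number of input dimensions, the recursion bottoms out at scalars, so each innermost list collects one value from every input.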
print("Stacking - insert outermost axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Stacking - insert outermost axis:
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
[[ 13 14 15 16]
[ 17 18 19 20]
[ 21 22 23 24]]]
[[[101 102 103 104]
[105 106 107 108]
[109 110 111 112]]
[[113 114 115 116]
[117 118 119 120]
[121 122 123 124]]]
[[[201 202 203 204]
[205 206 207 208]
[209 210 211 212]]
[[213 214 215 216]
[217 218 219 220]
[221 222 223 224]]]]
Shape: (3, 2, 3, 4)
print("Stacking - new axis before 1:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Stacking - new axis before 1:
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
[[101 102 103 104]
[105 106 107 108]
[109 110 111 112]]
[[201 202 203 204]
[205 206 207 208]
[209 210 211 212]]]
[[[ 13 14 15 16]
[ 17 18 19 20]
[ 21 22 23 24]]
[[113 114 115 116]
[117 118 119 120]
[121 122 123 124]]
[[213 214 215 216]
[217 218 219 220]
[221 222 223 224]]]]
Shape: (2, 3, 3, 4)
print("Stacking - new axis before 2:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Stacking - new axis before 2:
[[[[ 1 2 3 4]
[101 102 103 104]
[201 202 203 204]]
[[ 5 6 7 8]
[105 106 107 108]
[205 206 207 208]]
[[ 9 10 11 12]
[109 110 111 112]
[209 210 211 212]]]
[[[ 13 14 15 16]
[113 114 115 116]
[213 214 215 216]]
[[ 17 18 19 20]
[117 118 119 120]
[217 218 219 220]]
[[ 21 22 23 24]
[121 122 123 124]
[221 222 223 224]]]]
Shape: (2, 3, 3, 4)
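Stacking at `axis=1` and at `axis=2` yields the same shape, (2, 3, 3, 4), but a different arrangement. The index mapping in this plain-Python sketch (with `base` holding the same values as `arr`) shows why: with `axis=1` each 3 x 4 block is a whole slice of one input, while with `axis=2` each block gathers the same row from every input.

```python
base = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
# The three inputs: base, base + 100 and base + 200.
srcs = [[[[v + 100 * n for v in row] for row in plane] for plane in base] for n in range(3)]

# axis=1: out[i][k][j][l] == srcs[k][i][j][l]; each block is slice i of input k.
st_1 = [[srcs[k][i] for k in range(3)] for i in range(2)]
# axis=2: out[i][j][k][l] == srcs[k][i][j][l]; each block collects row j from every input.
st_2 = [[[srcs[k][i][j] for k in range(3)] for j in range(3)] for i in range(2)]

print(st_1[0][1])  # [[101, 102, 103, 104], [105, 106, 107, 108], [109, 110, 111, 112]]
print(st_2[0][1])  # [[5, 6, 7, 8], [105, 106, 107, 108], [205, 206, 207, 208]]
```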
print("Stacking - new innermost axis:")
print(o[3].at(0))
print("Shape: ", o[3].at(0).shape)
Stacking - new innermost axis:
[[[[ 1 101 201]
[ 2 102 202]
[ 3 103 203]
[ 4 104 204]]
[[ 5 105 205]
[ 6 106 206]
[ 7 107 207]
[ 8 108 208]]
[[ 9 109 209]
[ 10 110 210]
[ 11 111 211]
[ 12 112 212]]]
[[[ 13 113 213]
[ 14 114 214]
[ 15 115 215]
[ 16 116 216]]
[[ 17 117 217]
[ 18 118 218]
[ 19 119 219]
[ 20 120 220]]
[[ 21 121 221]
[ 22 122 222]
[ 23 123 223]
[ 24 124 224]]]]
Shape: (2, 3, 4, 3)
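With the new innermost axis, element `[i][j][l][k]` of the output is element `[i][j][l]` of input `k`, so consecutive values in row-major order alternate between the inputs. A plain-Python sketch of that mapping, with `base` holding the same values as `arr`:

```python
base = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]
# The three inputs: base, base + 100 and base + 200.
srcs = [[[[v + 100 * n for v in row] for row in plane] for plane in base] for n in range(3)]

# New innermost axis (axis=3): out[i][j][l][k] == srcs[k][i][j][l].
st_inner = [[[[srcs[k][i][j][l] for k in range(3)] for l in range(4)] for j in range(3)]
            for i in range(2)]

# Reading the result in row-major order shows the inputs' values interleaved.
flat = [v for plane in st_inner for row in plane for vec in row for v in vec]
print(flat[:9])   # [1, 101, 201, 2, 102, 202, 3, 103, 203]
print(len(flat))  # 72
```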