Shuffle¶
将输入张量重塑和转置为输出张量。
该层首先转置张量,然后重塑它,最后,它对张量应用另一个转置。
属性¶
first_transpose
第一个转置操作应用的排列。默认值:单位排列。
reshape_dims
重塑后的维度。维度的乘积必须等于输入维度的乘积。两个特殊值可以用作维度值
0
从输入复制相应的维度。如果重塑维度的数量小于输入,则通过对齐最重要的输入维度来解析 0。另请参见:zero_is_placeholder
。-1
通过查看输入和其余的重塑维度来推断该特定维度。只允许将一个维度指定为 -1。如果输入可以具有零体积,并且任何其他重塑维度可以为零(在解决 0 的特殊处理之后),则避免使用 -1,因为 -1 的解变得不确定,TensorRT 将报告错误。
second_transpose
第二个转置操作应用的排列。默认值:单位排列。
zero_is_placeholder
重塑维度中 0 的含义。如果 true
,则重塑维度中的 0 表示从第一个输入张量复制相应的维度。如果 false
,则重塑维度中的 0 表示零长度维度。默认值:true。
输入¶
input0:类型为 T
的张量。
input1:类型为 Int32
或 Int64
的可选张量,带有 reshape_dims
。
输出¶
output:类型为 T
的张量。
数据类型¶
T:bool
, int4
, int8
, uint8
, int32
, float8
, float16
, float32
, bfloat16
形状信息¶
output 是秩为 n
的张量。
input1 的形状为 \([n]\)。
示例¶
Shuffle (洗牌)
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 4))
layer = network.add_shuffle(in1)
layer.first_transpose = trt.Permutation([1, 0])
layer.reshape_dims = trt.Dims([2, 6])
network.mark_output(layer.get_output(0))
inputs[in1.name] = np.array(
[
[1.0, 2.0, 3.0, 4.0],
[10.0, 20.0, 30.0, 40.0],
[100.0, 200.0, 300.0, 400.0],
]
)
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = np.array(
[[1.0, 10.0, 100.0, 2.0, 20.0, 200.0], [3.0, 30.0, 300.0, 4.0, 40.0, 400.0]]
)
Shuffle 重塑推断
in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 3, 4))
layer = network.add_shuffle(in1)
layer.first_transpose = trt.Permutation([1, 0, 2])
layer.reshape_dims = trt.Dims([2, -1, 3])
network.mark_output(layer.get_output(0))
inputs[in1.name] = np.array(
[
[
[1.0, 2.0, 3.0, 4.0],
[10.0, 20.0, 30.0, 40.0],
[100.0, 200.0, 300.0, 400.0],
],
[
[5.0, 6.0, 7.0, 8.0],
[50.0, 60.0, 70.0, 80.0],
[500.0, 600.0, 700.0, 800.0],
],
]
)
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = np.array(
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0], [20.0, 30.0, 40.0]],
[[50.0, 60.0, 70.0], [80.0, 100.0, 200.0], [300.0, 400.0, 500.0], [600.0, 700.0, 800.0]],
]
)
C++ API¶
有关 C++ IShuffleLayer 算子的更多信息,请参阅 C++ IShuffleLayer 文档。
Python API¶
有关 Python IShuffleLayer 算子的更多信息,请参阅 Python IShuffleLayer 文档。