Shuffle

将输入张量重塑和转置为输出张量。

该层首先转置张量,然后重塑它,最后,它对张量应用另一个转置。

属性

first_transpose 第一个转置操作应用的排列。默认值:单位排列。

reshape_dims 重塑后的维度。维度的乘积必须等于输入维度的乘积。两个特殊值可以用作维度值

  • 0 从输入复制相应的维度。如果重塑维度的数量小于输入,则通过对齐最重要的输入维度来解析 0。另请参见:zero_is_placeholder

  • -1 通过查看输入和其余的重塑维度来推断该特定维度。只允许将一个维度指定为 -1。如果输入可以具有零体积,并且任何其他重塑维度可以为零(在解决 0 的特殊处理之后),则避免使用 -1,因为 -1 的解变得不确定,TensorRT 将报告错误。

second_transpose 第二个转置操作应用的排列。默认值:单位排列。

zero_is_placeholder 重塑维度中 0 的含义。如果 true,则重塑维度中的 0 表示从第一个输入张量复制相应的维度。如果 false,则重塑维度中的 0 表示零长度维度。默认值:true。

输入

input0:类型为 T 的张量。

input1:类型为 Int32Int64 的可选张量,带有 reshape_dims

输出

output:类型为 T 的张量。

数据类型

Tbool, int4, int8, uint8, int32, float8, float16, float32, bfloat16

形状信息

output 是秩为 n 的张量。

input1 的形状为 \([n]\)

示例

Shuffle (洗牌)
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 4))
layer = network.add_shuffle(in1)
layer.first_transpose = trt.Permutation([1, 0])
layer.reshape_dims = trt.Dims([2, 6])
network.mark_output(layer.get_output(0))

inputs[in1.name] = np.array(
    [
        [1.0, 2.0, 3.0, 4.0],
        [10.0, 20.0, 30.0, 40.0],
        [100.0, 200.0, 300.0, 400.0],
    ]
)

outputs[layer.get_output(0).name] = layer.get_output(0).shape

expected[layer.get_output(0).name] = np.array(
    [[1.0, 10.0, 100.0, 2.0, 20.0, 200.0], [3.0, 30.0, 300.0, 4.0, 40.0, 400.0]]
)
Shuffle 重塑推断
in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 3, 4))
layer = network.add_shuffle(in1)
layer.first_transpose = trt.Permutation([1, 0, 2])
layer.reshape_dims = trt.Dims([2, -1, 3])
network.mark_output(layer.get_output(0))

inputs[in1.name] = np.array(
    [
        [
            [1.0, 2.0, 3.0, 4.0],
            [10.0, 20.0, 30.0, 40.0],
            [100.0, 200.0, 300.0, 400.0],
        ],
        [
            [5.0, 6.0, 7.0, 8.0],
            [50.0, 60.0, 70.0, 80.0],
            [500.0, 600.0, 700.0, 800.0],
        ],
    ]
)

outputs[layer.get_output(0).name] = layer.get_output(0).shape

expected[layer.get_output(0).name] = np.array(
    [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0], [20.0, 30.0, 40.0]],
        [[50.0, 60.0, 70.0], [80.0, 100.0, 200.0], [300.0, 400.0, 500.0], [600.0, 700.0, 800.0]],
    ]
)

C++ API

有关 C++ IShuffleLayer 算子的更多信息,请参阅 C++ IShuffleLayer 文档

Python API

有关 Python IShuffleLayer 算子的更多信息,请参阅 Python IShuffleLayer 文档