Squeeze

通过移除 axes 指定的维度来调整输入张量的形状。对应的维度长度必须为 1。

输入

input0:类型为 T 的张量。

input1:类型为 Int32Int64 的张量

输出

output:类型为 T 的张量。

数据类型

T: bool, int4, int8, int32, int64, float8, float16, float32, bfloat16

形状信息

input1 的形状为 \([n]\)

output 是一个秩为 \(rank(input) - n\) 的张量。

示例

Squeeze
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 1, 4, 1))
axes_weights = trt.Weights(np.array([1, -1], dtype=np.int64))
axes_layer = network.add_constant((2,), axes_weights)
axes_tensor = axes_layer.get_output(0)
layer = network.add_squeeze(in1, axes_tensor)
network.mark_output(layer.get_output(0))

test_data = np.array(
    [
        [1.0, 2.0, 3.0, 4.0],
        [10.0, 20.0, 30.0, 40.0],
        [100.0, 200.0, 300.0, 400.0],
    ]
)

inputs[in1.name] = test_data.reshape(3, 1, 4, 1)

outputs[layer.get_output(0).name] = layer.get_output(0).shape

expected[layer.get_output(0).name] = test_data

C++ API

有关 C++ ISqueezeLayer 运算符的更多信息,请参阅 C++ ISqueezeLayer

Python API

有关 Python ISqueezeLayer 运算符的更多信息,请参阅 Python ISqueezeLayer 文档