Squeeze¶
通过移除 axes 指定的维度来调整输入张量的形状。对应的维度长度必须为 1。
输入¶
input0:类型为 T
的张量。
input1:类型为 Int32
或 Int64
的张量
输出¶
output:类型为 T
的张量。
数据类型¶
T: bool
, int4
, int8
, int32
, int64
, float8
, float16
, float32
, bfloat16
形状信息¶
input1 的形状为 \([n]\)。
output 是一个秩为 \(rank(input) - n\) 的张量。
示例¶
Squeeze
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 1, 4, 1))
axes_weights = trt.Weights(np.array([1, -1], dtype=np.int64))
axes_layer = network.add_constant((2,), axes_weights)
axes_tensor = axes_layer.get_output(0)
layer = network.add_squeeze(in1, axes_tensor)
network.mark_output(layer.get_output(0))
test_data = np.array(
[
[1.0, 2.0, 3.0, 4.0],
[10.0, 20.0, 30.0, 40.0],
[100.0, 200.0, 300.0, 400.0],
]
)
inputs[in1.name] = test_data.reshape(3, 1, 4, 1)
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = test_data
C++ API¶
有关 C++ ISqueezeLayer 运算符的更多信息,请参阅 C++ ISqueezeLayer。
Python API¶
有关 Python ISqueezeLayer 运算符的更多信息,请参阅 Python ISqueezeLayer 文档。