断言

当输入张量包含 false 值时触发断言。如果构建器可以证明任何输入张量值在构建时为 false,则会报告构建时错误。否则,将在运行时评估输入张量值,如果任何元素为 false,则会报告运行时错误。

属性

message 断言失败时要打印的消息。

输入

condition: 类型为 T 的张量

数据类型

T: bool

形状信息

condition 是一个秩为 \(0 \leq n \leq 1\) 的张量。

体积限制

condition 最多可以有 64 个元素。

示例

断言未触发
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 4, 4))
shape = network.add_shape(in1)
identity = network.add_identity(in1)
cond = network.add_elementwise(shape.get_output(0), shape.get_output(0), op=trt.ElementWiseOperation.EQUAL)
assertion = network.add_assertion(cond.get_output(0), message="Shouldn't fail")
network.mark_output(identity.get_output(0))

inputs[in1.name] = np.zeros(shape=(2, 4))
outputs[identity.get_output(0).name] = identity.get_output(0).shape
构建时触发断言
# This test should fail during build stage
in1 = network.add_input("input1", dtype=trt.float32, shape=(3, 4, 4))
shape1 = network.add_shape(in1)
in2 = network.add_input("input2", dtype=trt.float32, shape=(3, 3, 4))
shape2 = network.add_shape(in2)
identity = network.add_identity(in1)
cond = network.add_elementwise(shape1.get_output(0), shape2.get_output(0), op=trt.ElementWiseOperation.EQUAL)
assertion = network.add_assertion(cond.get_output(0), message="Should fail")
network.mark_output(identity.get_output(0))

inputs[in1.name] = np.zeros(shape=(2, 4))
outputs[identity.get_output(0).name] = identity.get_output(0).shape

C++ API

有关 C++ IAssertionLayer 算子的更多信息,请参阅 C++ IAssertionLayer 文档

Python API

有关 Python IAssertionLayer 算子的更多信息,请参阅 Python IAssertionLayer 文档