If

生成网络子图的有条件执行。truefalse 子图不是显式用于定义运算符,而是表示输入和输出张量的集合。

另请参阅

使用条件语句

输入

condition 类型为 T1 的张量

inputs 类型为 T2 的张量

输出

outputs 类型为 T2 的张量

数据类型

T1: bool

T2: bool, int32, float16, float32, bfloat16

形状信息

condition 是标量(零维张量)。

inputs 输入张量的数量及其形状对于每个子图可以不同。

outputs 输出张量的数量必须相同。对于每对对应的输出,它们的形状必须相等,除非条件是构建时常量。

示例

If (条件)
condition = network.add_input(name="condition", shape=(), dtype=trt.bool)
true_inp = network.add_input(name="true_input", shape=(1, 1), dtype=trt.float32)
false_inp = network.add_input(name="false_input", shape=(1, 1), dtype=trt.float32)
conditional = network.add_if_conditional()
conditional.set_condition(condition)

true_cond_inp = conditional.add_input(true_inp)
false_cond_inp = conditional.add_input(false_inp)
output = conditional.add_output(true_cond_inp.get_output(0), false_cond_inp.get_output(0))
network.mark_output(output.get_output(0))

inputs[condition.name] = np.array(True)
inputs[true_inp.name] = np.array([5.0])
inputs[false_inp.name] = np.array([0.0])

outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([5.0])

带有 ElementWise 子图的 If
condition = network.add_input("condition", dtype=trt.bool, shape=())
in1 = network.add_input(name="input1", shape=(2, 2), dtype=trt.float32)
in2 = network.add_input(name="input2", shape=(1, 2), dtype=trt.float32)

conditional = network.add_if_conditional()
conditional.set_condition(condition)

cond_inp1 = conditional.add_input(in1)
cond_inp2 = conditional.add_input(in2)

true_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.PROD)
false_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.SUM)

output = conditional.add_output(true_elemwise.get_output(0), false_elemwise.get_output(0))
network.mark_output(output.get_output(0))

inputs[condition.name] = np.array(False)
inputs[in1.name] = np.array(
    [
        [5.0, 7.8],
        [-3.2, 4.6],
    ]
)
inputs[in2.name] = np.array(
    [
        [1.0, -1.0],
    ]
)

outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([[6.0, 6.8], [-2.2, 3.6]])

C++ API

有关 C++ IConditionalLayer 运算符的更多信息,请参阅 C++ IConditionalLayer 文档

Python API

有关 Python IConditionalLayer 运算符的更多信息,请参阅 Python IConditionalLayer 文档