If¶
生成网络子图的有条件执行。true
和 false
子图不是显式用于定义运算符,而是表示输入和输出张量的集合。
另请参阅
输入¶
condition 类型为 T1
的张量
inputs 类型为 T2
的张量
输出¶
outputs 类型为 T2
的张量
数据类型¶
T1: bool
T2: bool
, int32
, float16
, float32
, bfloat16
形状信息¶
condition 是标量(零维张量)。
inputs 输入张量的数量及其形状对于每个子图可以不同。
outputs 输出张量的数量必须相同。对于每对对应的输出,它们的形状必须相等,除非条件是构建时常量。
示例¶
If (条件)
condition = network.add_input(name="condition", shape=(), dtype=trt.bool)
true_inp = network.add_input(name="true_input", shape=(1, 1), dtype=trt.float32)
false_inp = network.add_input(name="false_input", shape=(1, 1), dtype=trt.float32)
conditional = network.add_if_conditional()
conditional.set_condition(condition)
true_cond_inp = conditional.add_input(true_inp)
false_cond_inp = conditional.add_input(false_inp)
output = conditional.add_output(true_cond_inp.get_output(0), false_cond_inp.get_output(0))
network.mark_output(output.get_output(0))
inputs[condition.name] = np.array(True)
inputs[true_inp.name] = np.array([5.0])
inputs[false_inp.name] = np.array([0.0])
outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([5.0])
带有 ElementWise 子图的 If
condition = network.add_input("condition", dtype=trt.bool, shape=())
in1 = network.add_input(name="input1", shape=(2, 2), dtype=trt.float32)
in2 = network.add_input(name="input2", shape=(1, 2), dtype=trt.float32)
conditional = network.add_if_conditional()
conditional.set_condition(condition)
cond_inp1 = conditional.add_input(in1)
cond_inp2 = conditional.add_input(in2)
true_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.PROD)
false_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.SUM)
output = conditional.add_output(true_elemwise.get_output(0), false_elemwise.get_output(0))
network.mark_output(output.get_output(0))
inputs[condition.name] = np.array(False)
inputs[in1.name] = np.array(
[
[5.0, 7.8],
[-3.2, 4.6],
]
)
inputs[in2.name] = np.array(
[
[1.0, -1.0],
]
)
outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([[6.0, 6.8], [-2.2, 3.6]])
C++ API¶
有关 C++ IConditionalLayer 运算符的更多信息,请参阅 C++ IConditionalLayer 文档。
Python API¶
有关 Python IConditionalLayer 运算符的更多信息,请参阅 Python IConditionalLayer 文档。