使用条件语句#

NVIDIA TensorRT 支持条件 if-then-else 控制流。TensorRT 条件语句用于实现网络子图的有条件执行。

定义条件语句#

条件边界层定义一个 if 条件语句

  • IConditionLayer 表示谓词,并指定条件语句应执行真分支(then-branch)还是假分支(else-branch)。

  • IIfConditionalInputLayer 指定条件分支之一的输入。

  • IIfConditionalOutputLayer 指定来自条件语句的输出。

每个边界层都继承自 IIfConditionalBoundaryLayer 类,该类具有一个方法 getConditional(),用于获取与其关联的 IIfConditionalIIfConditional 实例标识条件语句。具有相同 IIfConditional 的所有条件边界层都属于该条件语句。

一个条件语句必须正好有一个 IConditionLayer 实例,零个或多个 IIfConditionalInputLayer 实例,以及至少一个 IIfConditionalOutputLayer 实例。

IIfConditional 实现 if-then-else 控制流结构,该结构基于动态布尔输入提供网络子图的有条件执行。它由布尔标量谓词 condition 和两个分支子图定义:trueSubgraph,当 condition 的计算结果为 true 时执行;以及 falseSubgraph,当 condition 的计算结果为 false 时执行。

If condition is true then:
    output = trueSubgraph(trueInputs);
Else
    output = falseSubgraph(falseInputs);
Emit output

true 分支和 false 分支都必须以类似于许多编程语言中的三元运算符的方式定义。

要定义 if 条件语句,请使用 INetworkDefinition::addIfConditional 创建 IIfConditional 实例,然后添加边界层和分支层。

IIfConditional* simpleIf = network->addIfConditional();

IIfConditional::setCondition 方法接受一个参数:条件张量。这个 0D 布尔张量(标量)可以由网络中较早的层动态计算。它用于决定执行哪个分支。IConditionLayer 具有单个输入(条件)且没有输出,因为它在条件语句实现中内部使用。

// Create a condition predicate that is also a network input.
auto cond = network->addInput("cond", DataType::kBOOL, Dims{0});
IConditionLayer* condition = simpleIf->setCondition(*cond);

TensorRT 不支持用于实现条件分支的子图抽象,而是使用 IIfConditionalInputLayerIIfConditionalOutputLayer 来定义条件语句的边界。

  • IIfConditionalInputLayer 抽象化 IIfConditional 的一个或两个分支子图的单个输入。特定 IIfConditionalInputLayer 的输出可以同时馈送到两个分支。

    // Create an if-conditional input.
    // x is some arbitrary Network tensor.
    IIfConditionalInputLayer* inputX = simpleIf->addInput(*x);
    

then-branch 和 else-branch 的输入不必是相同的类型和形状。每个分支可以独立包含零个或多个输入。

IIfConditionalInputLayer 是可选的,用于控制哪些层将成为分支的一部分(请参阅 条件执行)。如果分支的所有输出都不依赖于 IIfConditionalInputLayer 实例,则该分支为空。当条件为 false 时没有要评估的层,并且网络评估应在条件语句之后继续时,空 else-branch 可能很有用(请参阅 条件示例)。

  • IIfConditionalOutputLayer 抽象化 if 条件语句的单个输出。它有两个输入:来自 trueSubgraph 的输出(输入索引 0)和来自 falseSubgraph 的输出(输入索引 1)。IIfConditionalOutputLayer 的输出可以被视为将在运行时确定的最终输出的占位符。

IIfConditionalOutputLayer 的作用类似于传统 SSA 控制流图中的 Φ (Phi) 函数节点。其语义是:选择 trueSubgraphfalseSubgraph 的输出。

// trueSubgraph and falseSubgraph represent network subgraphs
IIfConditionalOutputLayer* outputLayer = simpleIf->addOutput(
    *trueSubgraph->getOutput(0),
    *falseSubgraph->getOutput(0));

IIfConditional 的所有输出都必须源自 IIfConditionalOutputLayer 实例。

没有输出的 if 条件语句不会影响网络的其余部分。因此,它被认为是格式错误的。每个分支(子图)也必须至少有一个输出。if 条件语句的输出可以标记为网络的输出,除非该 if 条件语句嵌套在另一个 if 条件语句或循环内。

An if-conditional construct abstract model

条件执行#

网络层的条件执行是一种网络评估策略,其中仅当需要分支输出的值时才执行分支层(属于条件子图的层)。在条件执行中,真分支或假分支之一被执行并允许更改网络状态。

相反,在谓词执行中,真分支和假分支都执行,并且仅允许其中一个更改网络评估状态,具体取决于条件谓词的值(即,只有子图之一的输出被馈送到后续层)。

条件执行有时称为延迟评估,而谓词执行有时称为及早评估

IIfConditionalInputLayer 的实例可用于指定哪些层被及早调用,哪些层被延迟调用。这是通过向后跟踪网络层来完成的,从每个条件输出开始。数据依赖于至少一个 IIfConditionalInputLayer 输出的层被认为是条件语句的内部层,因此会被延迟评估。在极端情况下,如果没有向条件语句添加 IIfConditionalInputLayer 实例,则所有层都将被及早执行,类似于 ISelectLayer

以下三个图表描述了 IIfConditionalInputLayer 放置的选择如何控制执行调度。

Controlling conditional-execution using IIfConditionalInputLayer placement

在图 A 中,真分支包含三个层(T1、T2、T3)。当条件计算结果为 true 时,这些层会被延迟执行。

在图 B 中,输入层 I1 放置在层 T1 之后,这会将 T1 移出真分支。层 T1 在评估 if 构造之前及早执行。

在图 C 中,输入层 I1 被移除,这会将 T3 移出条件语句。T2 的输入被重新配置以创建合法的网络,并且 T2 也移出真分支。当条件计算结果为 true 时,条件语句不计算任何内容,因为输出已经及早计算出来(但它确实将其条件相关输入复制到其输出)。

嵌套和循环#

条件分支可以嵌套其他条件语句,也可以嵌套循环。循环可以嵌套条件语句。与循环嵌套一样,TensorRT 从数据流中推断条件语句和循环的嵌套。例如,如果条件语句 B 使用在循环 A 内部定义的值,则 B 被认为嵌套在 A 内部。

真分支到假分支层之间,反之亦然,不能有交叉边。换句话说,一个分支的输出不能依赖于另一个分支中的层。

有关如何指定嵌套的示例,请参阅 条件示例 部分。

局限性#

真/假子图分支中的输出张量数量必须相同。每个分支输出张量的类型和形状必须相同。

请注意,这比 ONNX 规范更受约束,ONNX 规范要求真/假子图具有相同数量的输出并使用相同的输出类型,但允许不同的输出形状。

条件示例#

简单 If 条件语句#

以下示例展示了如何实现一个简单的条件语句,该条件语句有条件地对两个张量执行算术运算。

1condition = true
2If condition is true:
3        output = x + y
4Else:
5        output = x - y
 1ITensor* addCondition(INetworkDefinition& n, bool predicate)
 2{
 3    // The condition value is a constant int32 input that is cast to boolean because TensorRT doesn't support boolean constant layers.
 4
 5    static const Dims scalarDims = Dims{0, {}};
 6    static float constexpr zero{0};
 7    static float constexpr one{1};
 8
 9    float* const val = predicate ? &one : &zero;
10
11    ITensor* cond =
12        n.addConstant(scalarDims, DataType::kINT32, val, 1})->getOutput(0);
13
14    auto* cast = n.addIdentity(cond);
15    cast->setOutputType(0, DataType::kBOOL);
16    cast->getOutput(0)->setType(DataType::kBOOL);
17
18    return cast->getOutput(0);
19}
20
21IBuilder* builder = createInferBuilder(gLogger);
22INetworkDefinition& n = *builder->createNetworkV2(0U);
23auto x = n.addInput("x", DataType::kFLOAT, Dims{1, {5}});
24auto y = n.addInput("y", DataType::kFLOAT, Dims{1, {5}});
25ITensor* cond = addCondition(n, true);
26
27auto* simpleIf = n.addIfConditional();
28simpleIf->setCondition(*cond);
29
30// Add input layers to demarcate entry into true/false branches.
31x = simpleIf->addInput(*x)->getOutput(0);
32y = simpleIf->addInput(*y)->getOutput(0);
33
34auto* trueSubgraph = n.addElementWise(*x, *y, ElementWiseOperation::kSUM)->getOutput(0);
35auto* falseSubgraph = n.addElementWise(*x, *y, ElementWiseOperation::kSUB)->getOutput(0);
36
37auto* output = simpleIf->addOutput(*trueSubgraph, *falseSubgraph)->getOutput(0);
38n.markOutput(*output);

从 PyTorch 导出#

以下示例展示了如何将脚本化的 PyTorch 代码导出到 ONNX。函数 sum_even 中的代码执行嵌套在循环中的 if 条件语句。

import torch.onnx
import torch
import tensorrt as trt
import numpy as np

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

@torch.jit.script
def sum_even(items):
    s = torch.zeros(1, dtype=torch.float)
    for c in items:
        if c % 2 == 0:
            s += c
    return s

class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, items):
        return sum_even(items)

def build_engine(model_file):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(model_file, 'rb') as model:
        assert parser.parse(model.read())
        return builder.build_engine(network, config)

def export_to_onnx():
    items = torch.zeros(4, dtype=torch.float)
    example = ExampleModel()
    torch.onnx.export(example, (items), "example.onnx", verbose=False, opset_version=13, enable_onnx_checker=False, do_constant_folding=True)

export_to_onnx()
build_engine("example.onnx")