TensorRT 10.8.0
nvinfer1::IScatterLayer 类参考

网络定义中的一个 scatter 层。支持多种 scatter 类型。 更多...

#include <NvInfer.h>

nvinfer1::IScatterLayer 的继承关系图
nvinfer1::ILayer nvinfer1::INoCopy

公共成员函数

void setMode (ScatterMode mode) noexcept
 设置 scatter 模式。 更多...
 
ScatterMode getMode () const noexcept
 获取 scatter 模式。 更多...
 
void setAxis (int32_t axis) noexcept
 设置 ScatterMode::kELEMENTS 使用的轴。 更多...
 
int32_t getAxis () const noexcept
 获取轴。 更多...
 
- 继承自 nvinfer1::ILayer 的公共成员函数
LayerType getType () const noexcept
 返回层的类型。 更多...
 
void setName (char const *name) noexcept
 设置层的名称。 更多...
 
char const * getName () const noexcept
 返回层的名称。 更多...
 
int32_t getNbInputs () const noexcept
 获取层的输入数量。 更多...
 
ITensorgetInput (int32_t index) const noexcept
 获取与给定索引对应的层输入。 更多...
 
int32_t getNbOutputs () const noexcept
 获取层的输出数量。 更多...
 
ITensorgetOutput (int32_t index) const noexcept
 获取与给定索引对应的层输出。 更多...
 
void setInput (int32_t index, ITensor &tensor) noexcept
 使用特定的张量替换此层的输入。 更多...
 
void setPrecision (DataType dataType) noexcept
 在弱类型网络中,设置此层的首选或必需的计算精度。 更多...
 
DataType getPrecision () const noexcept
 获取此层的计算精度 更多...
 
bool precisionIsSet () const noexcept
 是否已为此层设置计算精度 更多...
 
void resetPrecision () noexcept
 重置此层的计算精度 更多...
 
void setOutputType (int32_t index, DataType dataType) noexcept
 在弱类型网络中,设置此层的输出类型。 更多...
 
DataType getOutputType (int32_t index) const noexcept
 获取此层的输出类型 更多...
 
bool outputTypeIsSet (int32_t index) const noexcept
 是否已为此层设置输出类型 更多...
 
void resetOutputType (int32_t index) noexcept
 重置此层的输出类型 更多...
 
void setMetadata (char const *metadata) noexcept
 设置此层的元数据。 更多...
 
char const * getMetadata () const noexcept
 获取层的元数据。 更多...
 

保护成员函数

virtual ~IScatterLayer () noexcept=default
 
- 继承自 nvinfer1::ILayer 的保护成员函数
virtual ~ILayer () noexcept=default
 
- 继承自 nvinfer1::INoCopy 的保护成员函数
 INoCopy ()=default
 
virtual ~INoCopy ()=default
 
 INoCopy (INoCopy const &other)=delete
 
INoCopyoperator= (INoCopy const &other)=delete
 
 INoCopy (INoCopy &&other)=delete
 
INoCopyoperator= (INoCopy &&other)=delete
 

保护属性

apiv::VScatterLayer * mImpl
 
- 继承自 nvinfer1::ILayer 的保护属性
apiv::VLayer * mLayer
 

详细描述

网络定义中的一个 scatter 层。支持多种 scatter 类型。

Scatter 层有三个输入张量:Data、Indices 和 Updates,一个输出张量 Output,以及一个 scatter 模式。当使用 kELEMENT 模式时,可以使用可选的轴参数。

  • Data 是一个秩 r >= 1 的张量,存储要在 Output 中复制的值。
  • Indices 是一个秩 q 的张量,确定在 Output 中写入新值的位置。秩 q 的约束取决于模式: ScatterMode::kND: q >= 1 ScatterMode::kELEMENT: q 必须与 r 相同
  • Updates 是一个秩 s >= 1 的张量,提供要写入 Output 的数据,由 Indices 中对应的位置指定。Updates 秩的约束取决于模式: ScatterMode::kND: s = r + q - shape(Indices)[-1] - 1 Scattermode::kELEMENT: s = q = r
  • Output 是一个与 Data 具有相同维度的张量,存储转换的结果值。它不能是形状张量。Data、Update 和 Output 的类型应相同,Indices 应为 DataType::kINT32DataType::kINT64 类型。

输出通过复制数据计算得出,然后根据索引更新其元素。Indices 如何解释取决于 ScatterMode。

ScatterMode::kND

The indices are interpreted as a tensor of rank q-1 of indexing tuples.
The axis parameter is ignored.

Given that data dims are {d_0,...,d_{r-1}} and indices dims are {i_0,...,i_{q-1}},
define k = indices[q-1], it follows that updates dims are {i_0,...,i_{q-2},d_k,...,d_{r-1}}
The updating can be computed by:
    foreach slice in indices[i_0,...,i_{q-2}]
        output[indices[slice]] = updates[slice]

ScatterMode::kELEMENT

Here "axis" denotes the result of getAxis().

For each element X of indices:
    Let J denote a sequence for the subscripts of X
    Let K = sequence J with element [axis] replaced by X
    output[K] = updates[J]

For example, if indices has dimensions [N,C,H,W] and axis is 2, then the updates happen as:

    for n in [0,n)
        for c in [0,n)
            for h in [0,n)
                for w in [0,n)
                    output[n,c,indices[n,c,h,w],w] = updates[n,c,h,w]

写入到相同的输出元素会导致未定义的行为。

警告
请勿从此类继承,因为这样做会破坏 API 和 ABI 的前向兼容性。

构造函数 & 析构函数文档

◆ ~IScatterLayer()

virtual nvinfer1::IScatterLayer::~IScatterLayer ( )
protectedvirtualdefaultnoexcept

成员函数文档

◆ getAxis()

int32_t nvinfer1::IScatterLayer::getAxis ( ) const
inlinenoexcept

获取轴。

◆ getMode()

ScatterMode nvinfer1::IScatterLayer::getMode ( ) const
inlinenoexcept

获取 scatter 模式。

另请参阅
setMode()

◆ setAxis()

void nvinfer1::IScatterLayer::setAxis ( int32_t  axis)
inlinenoexcept

设置 ScatterMode::kELEMENTS 使用的轴。

轴默认为 0。

◆ setMode()

void nvinfer1::IScatterLayer::setMode ( ScatterMode  mode)
inlinenoexcept

设置 scatter 模式。

另请参阅
getMode()

成员数据文档

◆ mImpl

apiv::VScatterLayer* nvinfer1::IScatterLayer::mImpl
protected

此类的文档由以下文件生成

  版权所有 © 2024 英伟达公司
  隐私政策 | 管理我的隐私 | 请勿出售或分享我的数据 | 服务条款 | 无障碍访问 | 公司政策 | 产品安全 | 联系我们