Slice#

slice 操作提取张量的一部分

\( Y = X[start_0:end_0, start_1:end_1, ..., start_n:end_n] \)

其中 \(X\) 是输入张量,\(Y\) 是输出张量,并且 \(start_i\)\(end_i\) 是第 \(i\) 个维度的起始和结束索引。

该操作允许在任意数量的维度上进行灵活切片,支持 Python 样式的切片语法,包括 start、stop 和 step 参数。

C++ API#

std::shared_ptr<Tensor_attributes>
Slice(std::shared_ptr<Tensor_attributes> input, Slice_attributes);

Slice 属性是一个带有 setter 的轻量级结构

Slice_attributes&
set_slices(std::vector<std::pair<int64_t, int64_t>> const value)

Slice_attributes&
set_name(std::string const&)

Slice_attributes&
set_compute_data_type(DataType_t value)

Python API:#

  • slice

    • input

      • 要切片的输入张量

    • slices

      • Python slice 对象列表,每个维度一个

    • name

      • 操作的可选名称

    • compute_data_type

      • 操作的可选计算数据类型

使用示例

# Create an input tensor

input_tensor = graph.tensor(dims = [4, 8, 16])

# Perform slicing
sliced_tensor = graph.slice(input_tensor, 
                            slices=[slice(1, 3), slice(2, 6), slice(0, 16)],
                            name="my_slice",
                            compute_data_type=cudnn.float32)