Slice#
slice 操作提取张量的一部分
\( Y = X[start_0:end_0, start_1:end_1, ..., start_n:end_n] \)
其中 \(X\) 是输入张量,\(Y\) 是输出张量,并且 \(start_i\) 和 \(end_i\) 是第 \(i\) 个维度的起始和结束索引。
该操作允许在任意数量的维度上进行灵活切片,支持 Python 样式的切片语法,包括 start、stop 和 step 参数。
C++ API#
std::shared_ptr<Tensor_attributes>
Slice(std::shared_ptr<Tensor_attributes> input, Slice_attributes);
Slice 属性是一个带有 setter 的轻量级结构
Slice_attributes&
set_slices(std::vector<std::pair<int64_t, int64_t>> const value)
Slice_attributes&
set_name(std::string const&)
Slice_attributes&
set_compute_data_type(DataType_t value)
Python API:#
slice
input
要切片的输入张量
slices
Python slice 对象列表,每个维度一个
name
操作的可选名称
compute_data_type
操作的可选计算数据类型
使用示例
# Create an input tensor
input_tensor = graph.tensor(dims = [4, 8, 16])
# Perform slicing
sliced_tensor = graph.slice(input_tensor,
slices=[slice(1, 3), slice(2, 6), slice(0, 16)],
name="my_slice",
compute_data_type=cudnn.float32)