fused_rope.h
函数
-
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
将旋转位置嵌入应用于输入张量。
- 参数:
input – [in] 融合 rope 的输入张量。
freqs – [in] freqs 张量。
output – [out] 输出张量。
s – [in] 输入的 s 维度长度。
b – [in] 输入的 b 维度长度。
h – [in] 输入的 h 维度长度。
d – [in] 输入的 d 维度长度。
d2 – [in] freqs 的 d 维度长度。
stride_s – [in] 输入的 s 维度步幅。
stride_b – [in] 输入的 b 维度步幅。
stride_h – [in] 输入的 h 维度步幅。
stride_d – [in] 输入的 d 维度步幅。
o_stride_s – [in] 输出的 s 维度步幅。
o_stride_b – [in] 输出的 b 维度步幅。
o_stride_h – [in] 输出的 h 维度步幅。
o_stride_d – [in] 输出的 d 维度步幅。
stream – [in] 用于操作的 CUDA 流。
-
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, NVTETensor input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
计算融合 rope 的反向传播。
- 参数:
output_grads – [in] 反向传播的传入梯度张量。
freqs – [in] freqs 张量。
input_grads – [out] 要计算的输入梯度张量。
s – [in] output_grads 的 s 维度长度。
b – [in] output_grads 的 b 维度长度。
h – [in] output_grads 的 h 维度长度。
d – [in] output_grads 的 d 维度长度。
d2 – [in] freqs 的 d 维度长度。
stride_s – [in] output_grads 的 s 维度步幅。
stride_b – [in] output_grads 的 b 维度步幅。
stride_h – [in] output_grads 的 h 维度步幅。
stride_d – [in] output_grads 的 d 维度步幅。
o_stride_s – [in] input_grads 的 s 维度步幅。
o_stride_b – [in] input_grads 的 b 维度步幅。
o_stride_h – [in] input_grads 的 h 维度步幅。
o_stride_d – [in] input_grads 的 d 维度步幅。
stream – [in] 用于操作的 CUDA 流。
-
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
以 thd 格式将旋转位置嵌入应用于输入张量。
- 参数:
input – [in] 融合 rope 的输入张量。
cu_seqlens – [in] 序列长度张量的累积和。
freqs – [in] freqs 张量。
output – [out] 输出张量。
cp_size – [in] 上下文并行世界大小。
cp_rank – [in] 上下文并行等级。
max_s – [in] 最大序列长度。
b – [in] 批次大小。
h – [in] 输入的 h 维度长度。
d – [in] 输入的 d 维度长度。
d2 – [in] freqs 的 d 维度长度。
stride_t – [in] 输入的 t 维度步幅。
stride_h – [in] 输入的 h 维度步幅。
stride_d – [in] 输入的 d 维度步幅。
o_stride_t – [in] 输出的 t 维度步幅。
o_stride_h – [in] 输出的 h 维度步幅。
o_stride_d – [in] 输出的 d 维度步幅。
stream – [in] 用于操作的 CUDA 流。
-
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, const NVTETensor freqs, NVTETensor input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream)
以 thd 格式计算融合 rope 的反向传播。
- 参数:
output_grads – [in] 反向传播的传入梯度张量。
cu_seqlens – [in] 序列长度张量的累积和。
freqs – [in] freqs 张量。
input_grads – [out] 要计算的输入梯度。
cp_size – [in] 上下文并行世界大小。
cp_rank – [in] 上下文并行等级。
max_s – [in] 最大序列长度。
b – [in] 批次大小。
h – [in] output_grads 的 h 维度长度。
d – [in] output_grads 的 d 维度长度。
d2 – [in] freqs 的 d 维度长度。
stride_t – [in] output_grads 的 t 维度步幅。
stride_h – [in] output_grads 的 h 维度步幅。
stride_d – [in] output_grads 的 d 维度步幅。
o_stride_t – [in] input_grads 的 t 维度步幅。
o_stride_h – [in] input_grads 的 h 维度步幅。
o_stride_d – [in] input_grads 的 d 维度步幅。
stream – [in] 用于操作的 CUDA 流。