使用 cuBLASDx 的通用矩阵乘法#
在本简介中,我们将使用 cuBLASDx 库执行通用矩阵乘法。公开了此操作的三个变体
Shared memory API
: \(\mathbf{C}_{m\times n} = {\alpha} \times \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + {\beta} \times \mathbf{C}_{m\times n}\)Register API
With accumulator
: \(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + \mathbf{C}_{m\times n}\)Without accumulator
: \(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n}\)
本节基于 cuBLASDx 附带的 introduction_example.cu 示例。请参阅 示例 部分以查看其他 cuBLASDx 示例。
定义 GEMM 操作#
第一步是定义我们要执行的 GEMM。这通过将 cuBLASDx 运算符组合在一起来创建 GEMM 描述来完成。每次添加新运算符时,都会在编译时评估此类型的正确性。一个定义良好的 cuBLASDx GEMM 例程描述必须包括两部分
选定的线性代数例程。在本例中,它是矩阵乘法:
cublasdx::function::MM
。输入和输出的有效且充分的描述:矩阵的维度(
m
、n
、k
),精度(半精度、单精度、双精度等),数据类型(实数或复数)以及矩阵的数据排列(行优先或列优先)。
要获取由下式描述的任何操作的描述符
\(\mathbf{C}_{m\times n} = \left[ {\alpha} \ \times \ \right] \ \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} \ \left[\ + {\beta} \times \mathbf{C}_{m\times n} \right]\)
其中 m = n = k = 32
,我们只需要编写以下几行
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32 /* m */, 32 /* n */, 32 /* k */>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major /* A */, cublasdx::col_major /* B */>());
为了编码操作属性,cuBLASDx 提供了运算符 Size、Precision、Type、Function 和 Arrangement,它们可以使用标准加法运算符 (+
) 组合。
可选地,用户可以使用 Alignment 和 LeadingDimension 分别为每个矩阵设置对齐方式和前导维度。当使用与计算类型不同的自定义输入时,必须将 对齐运算符 设置为适当的值。
对于前导维度,也可以在执行期间动态设置它们,但是,值得注意的是,这可能会对性能产生影响。
提示
cuBLASDx 还支持无法简单地通过行优先或列优先和前导维度表示的矩阵。请参阅 simple_gemm_custom_layout.cu 示例。
要获得在 CUDA 块级别执行 GEMM 的完全可用的操作,我们需要提供至少两个附加信息
第一个是 SM 运算符,它指示我们要在其上运行 GEMM 的目标 CUDA 架构。每个 GPU 架构都不同,因此每个架构都可以使用不同的实现,并且可能需要不同的 CUDA 块大小才能获得最佳性能。在 introduction_example.cu 示例中,这作为模板参数传递,但在这里我们可以假设我们以 Volta GPU(
SM<700>()
)为目标。最后,我们使用 Block 运算符 来表明 BLAS 例程将由单个 CUDA 块中的多个线程执行。此时,cuBLASDx 执行额外的验证,以确保提供的描述有效,并且可以在请求的架构上执行它。
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
用户还可以指定将执行 GEMM 的布局和线程数。这可以通过 BlockDim 运算符 完成。添加 BlockDim<X, Y, Z>
意味着只有在使用块维度 dim3(X1, Y1, Z1)
启动内核时,GEMM 才能正常工作,其中 X1 >= X
,Y1 >= Y
和 Z1 >= Z
。详细要求可以在专用于 BlockDim 运算符的部分中找到。如果未使用 BlockDim
运算符,则 cuBLASDx 将选择首选的块大小,该大小可以使用 GEMM::block_dim
获得。
提示
如果没有必要设置自定义块维度,建议不要使用 BlockDim
运算符,而依赖于 GEMM::block_dim
。有关更多详细信息,请参阅 Block Execute Method 部分、BlockDim 运算符 和 Suggested Block Dim Trait。
对于本示例,我们假设我们要使用具有 256 个线程的 1D CUDA 线程块。
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block()
+ BlockDim<256>());
执行 GEMM#
描述矩阵乘法的类 GEMM
可以实例化为对象(或多个对象)。形成对象没有计算成本,应视为句柄。函数描述符对象提供计算方法 execute(...)
,用于执行请求的函数。
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
__global__ void gemm_kernel(double alpha, double *a, double *b, double beta, double *c) {
// Execute GEMM
GEMM().execute(/* What are the arguments? */);
}
从 cuBLASDx 0.2.0 开始,execute 方法采用张量 (cublasdx::tensor
) 作为输入和输出。cublasdx::tensor
是 CuTe 张量 (cute::Tensor) 的别名,它是多维数组的表示形式,其中包含
任何类型的内存中的数据,包括全局内存、共享内存和寄存器内存,以及
CuTe 布局 (cute::Layout),描述了元素的组织方式。
张量创建#
张量分区#
从 cuBLASDx 0.3.0
和寄存器片段 API 开始,库提供了新的接口,允许在参与 GEMM 的线程之间高效地划分全局和共享内存张量,以及随后修改这些寄存器张量。这些操作的入口点是与特定 GEMM 实例绑定的 Partitioner 对象
auto partitioner = GEMM::get_partitioner(); auto partitioner = GEMM::suggest_partitioner();
这样的对象允许
为此 GEMM 创建寄存器片段累加器
将片段索引映射到全局张量索引
分区其他张量,如
C
以获取其子张量将谓词应用于越界元素和线程
有关更多详细信息,请参阅 分区器和寄存器片段张量。
警告
寄存器片段只能与用于创建它的 GEMM 用作累加器
寄存器片段累加器#
寄存器片段累加器是存储在线程本地寄存器文件 (RF) 内存中的数组,包装在具有描述内部 GEMM 执行的不透明布局的 cublasdx::tensor
中。与全局内存和共享内存张量相反,此布局可能不是任意的,只能从分区器对象获得(请参阅 分区器和寄存器片段张量)。
注意
寄存器片段是一个不透明的分层张量,公开了一个 1D 张量接口
任何寄存器片段的特定布局的实现细节都与 GEMM 执行相关联,但是可以使用从 0 到 cublasdx::size(register_fragment)
范围内的 1D 索引访问每个片段
每个寄存器片段累加器表示全局或共享内存矩阵的片段,由保存它的线程的索引和从中创建它的 GEMM
实例确定。cuBLASDx 公开了两种将内存从线程本地索引空间映射到整个张量索引空间的方法
通过分区器对象的手动索引映射实用程序
通过具有收集/分散语义的自动复制功能。
要获取 GEMM 实例的寄存器片段,只需获取分区器并使用它来创建未初始化的累加器即可
auto partitioner = BLAS::get_partitioner();
auto c_fragment_accumulator = partitioner.make_accumulator_fragment();
// Now you can access it as a regular 1D tensor:
auto val_0 = c_fragment_accumulator(0);
复制张量#
复制寄存器片段#
为了复制具有 GEMM 结果的寄存器片段累加器,cuBLASDx 提供了一个辅助函数 cublasdx::copy_fragment(...)
,负责在本地张量片段和全局/共享张量中的适当位置之间执行加载和存储。
该函数考虑了给定的对齐方式,并尝试在可能的情况下向量化加载和存储。
- 此复制是每个线程的操作,全局/共享数据分区基于
包含适当 GEMM 执行细节的分区器对象
线程索引(包含在分区器对象中)
Partitioner 对象提供了许多辅助 API,从而在数据操作中具有很大的灵活性。有关更多详细信息,请参阅 分区器和寄存器片段张量。
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
auto partitioner = GEMM::get_partitioner();
auto c_fragment_accumulator = partitioner.make_accumulator_fragment();
// Load data from global to registers
cublasdx::copy_fragment<alignment::a>(c_global_tensor, c_fragment_accumulator, partitioner);
// Load data from shared to registers
cublasdx::copy_fragment<alignment::a>(c_shared_tensor, c_fragment_accumulator, partitioner);
// Store data from registers to global
cublasdx::copy_fragment<alignment::a>(c_fragment_accumulator, c_global_tensor, partitioner);
// Store data from registers to shared
cublasdx::copy_fragment<alignment::a>(c_fragment_accumulator, c_shared_tensor, partitioner);
累加器寄存器 GEMM API#
register Accumulation API GEMM
内核的典型结构如下
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
// Type <a/b/c>_value_type is defined based on the GEMM description. Precision operator defines its numerical
// precision, and via Type operator user specifies if it is complex or real.
//
// In this case, a/b/c_value_type are all double since set precision is double, and type is real.
using a_value_type = typename GEMM::a_value_type;
using b_value_type = typename GEMM::b_value_type;
using c_value_type = typename GEMM::c_value_type;
__global__ void gemm_kernel_registers_accumulation(a_value_type *a, b_value_type *b, c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Create global memory tensor
// a_global_tensor = (from a)
// b_global_tensor = (from b)
// c_global_tensor = (from c)
// Make shared memory tensor
// a_shared_tensor = (from smem)
// b_shared_tensor = (from smem + ...)
// Load data from global memory tensor to shared memory tensor
// a_shared_tensor <-- a_global_tensor
// b_shared_tensor <-- b_global_tensor
// Make C register Accumulator fragment
// c_register_accumulator = (from GEMM)
// Load appropriate data from global memory tensor to register fragment tensor
// c_register_accumulator <- c_global_tensor
// Execute GEMM
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_accumulator);
__syncthreads();
// Store data from shared memory tensor to global memory tensor
// c_global_tensor <-- c_register_accumulator
}
此 API 更加复杂,为 C 累加器添加了额外的步骤
创建全局和共享内存张量(请参阅 张量创建)。
将数据从全局内存张量复制到共享内存张量(请参阅 复制张量)。
创建寄存器内存
C
累加器张量将全局输入张量
C
的适当部分复制到寄存器内存中(主要步骤)使用张量 API 执行
GEMM
。将数据从寄存器累加器张量复制到全局内存张量中的适当位置(请参阅 复制张量)。
在用张量创建和复制代码填充所有这些步骤后,我们得到
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Get default data partitioner
auto partitioner = GEMM::get_partitioner();
// Create register fragment Accumulator
auto c_register_fragment = partitioner.make_accumulator_fragment();
// Partition Global C for GEMM and load appropriate elements into register fragment
cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_register_fragment, partitioner);
// Execute GEMM with accumulation
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_fragment);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
返回值寄存器 GEMM API#
Return Value register API GEMM
内核的典型结构如下
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
// Type <a/b/c>_value_type is defined based on the GEMM description. Precision operator defines its numerical
// precision, and via Type operator user specifies if it is complex or real.
//
// In this case, a/b/c_value_type are all double since set precision is double, and type is real.
using a_value_type = typename GEMM::a_value_type;
using b_value_type = typename GEMM::b_value_type;
using c_value_type = typename GEMM::c_value_type;
__global__ void gemm_kernel(c_value_type alpha, a_value_type *a, b_value_type *b, c_value_type beta, c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Create global memory tensor
// a_global_tensor = (from a)
// b_global_tensor = (from b)
// c_global_tensor = (from c)
// Make shared memory tensor
// a_shared_tensor = (from smem)
// b_shared_tensor = (from smem + ...)
// Load data from global memory tensor to shared memory tensor
// a_shared_tensor <-- a_global_tensor
// b_shared_tensor <-- b_global_tensor
// Execute GEMM
auto [c_register_fragment, ...] =
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
// c_global_tensor <-- c_register_fragment
}
此 API 不会预先期望寄存器片段,而是将其作为结果返回
创建全局和共享内存张量(请参阅 张量创建)。
将数据从全局内存张量复制到共享内存张量(请参阅 复制张量)。
(主要步骤)使用张量 API 执行
GEMM
,以寄存器片段的形式获取结果。将数据从寄存器累加器张量复制到全局内存张量中的适当位置(请参阅 复制张量)。
在用张量创建和复制代码填充所有这些步骤后,我们得到
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel_registers(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM and get register fragment results and data partitioner in return
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
启动 GEMM 内核#
要启动执行定义的 GEMM
的内核,我们需要知道所需的块维度和所有三个矩阵(A
、B
、C
)所需的共享内存量。矩阵 A
中的元素应采用行优先格式,而矩阵 B
和 C
应采用列优先格式,并考虑前导维度。
#include <cublasdx.hpp>
using namespace cublasdx;
// Kernels are unfolded in their appropriate sections above
template<class GEMM>
__global__ void gemm_kernel_shared(GEMM::c_value_type alpha, GEMM::a_value_type *a, GEMM::b_value_type *b, GEMM::c_value_type beta, GEMM::c_value_type *c)
{
...
}
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(GEMM::a_value_type *a, GEMM::b_value_type *b, GEMM::c_value_type *c);
{
...
}
template<class GEMM>
__global__ void gemm_kernel_registers(GEMM::a_value_type *a, GEMM::b_value_type *b, GEMM::c_value_type *c);
{
...
}
// CUDA_CHECK_AND_EXIT - marco checks if function returns cudaSuccess; if not it prints the error code and exits the program
void introduction_example(value_type alpha, value_type *a, value_type *b, value_type beta, value_type *c) {
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ Function<function::MM>());
+ SM<700>()
+ Block());
// Shared memory API: C = alpha * A * B + beta * C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_shared<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size<GEMM>()>>>(1.0, a, b, 1.0, c);
// Register fragment Accumulation API: C = A * B + C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers_accumulation<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
// Register fragment API: C = A * B
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());
CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
}
所需的共享内存可以使用 cublasdx::get_shared_storage_size<GEMM>()
和 cublasdx::get_shared_storage_size_ab<GEMM>()
获得。它考虑了使用 LeadingDimension 运算符 声明的任何填充以及由 Alignment 运算符 产生的填充。
为简单起见,在示例中,我们为设备矩阵分配托管内存,假设使用 Volta 架构,并且不检查 CUDA API 函数返回的 CUDA 错误代码。请查看完整的 introduction_example.cu 示例以及 cuBLASDx 附带的其他示例,以获取更详细的代码。
#include <iostream>
#include <vector>
#include <cuda_runtime_api.h>
#include <cublasdx.hpp>
#include "common.hpp"
#include "reference.hpp"
template<class GEMM>
__global__ void gemm_kernel_shared(const typename GEMM::c_value_type alpha,
const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
const typename GEMM::c_value_type beta,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Store data from shared memory tensor to global memory tensor
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
}
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Get default partitioner
auto partitioner = GEMM::get_partitioner();
// Create register fragment Accumulator
auto c_register_fragment = partitioner.make_accumulator_fragment();
// Partition Global C for GEMM and load appropriate elements into register fragment
cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_register_fragment, partitioner);
// Execute GEMM with accumulation
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_fragment);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
template<class GEMM>
__global__ void gemm_kernel_registers(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM and get register fragment results and data partitioner in return
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
template<unsigned int Arch>
int introduction_example() {
using GEMM = decltype(cublasdx::Size<32, 32, 32>()
+ cublasdx::Precision<double>()
+ cublasdx::Type<cublasdx::type::real>()
+ cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ cublasdx::Function<cublasdx::function::MM>()
+ cublasdx::SM<700>()
+ cublasdx::Block()
+ cublasdx::BlockDim<256>());
using value_type = typename example::uniform_value_type_t<GEMM>;
constexpr auto global_a_size = example::global_memory_size_of<GEMM>::a_size;
constexpr auto global_b_size = example::global_memory_size_of<GEMM>::b_size;
constexpr auto global_c_size = example::global_memory_size_of<GEMM>::c_size;
// Allocate managed memory for A, B, C matrices in one go
value_type* abc;
auto size = global_a_size + global_b_size + global_c_size;
auto size_bytes = size * sizeof(value_type);
CUDA_CHECK_AND_EXIT(cudaMallocManaged(&abc, size_bytes));
// Generate data
for (size_t i = 0; i < size; i++) {
abc[i] = double(i / size);
}
value_type* a = abc;
value_type* b = abc + global_a_size;
value_type* c = abc + global_a_size + global_b_size;
// Shared memory API: C = alpha * A * B + beta * C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_shared<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size<GEMM>()>>>(1.0, a, b, 0.5, c);
// Register fragment Accumulation API: C = A * B + C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers_accumulation<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
// Register fragment API: C = A * B
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());
CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
CUDA_CHECK_AND_EXIT(cudaFree(abc));
std::cout << "Success" << std::endl;
return 0;
}
struct introduction_example_functor {
template<int Arch>
int operator()(std::integral_constant<int, Arch>) {
return introduction_example<Arch>();
}
};
int main(int, char**) {
return example::sm_runner(introduction_example_functor{});
}
重要的是要注意,与 cuBLAS 库不同,cuBLASDx 不需要在执行 BLAS 操作后将数据移回全局内存。它也不要求从全局内存加载输入数据。对于某些用例,这些属性可能是主要的性能优势。可能的优化列表包括但不限于
将 BLAS 例程与自定义预处理和后处理融合。
将多个 BLAS 操作融合在一起。
将 BLAS 和 FFT 操作(使用 cuFFTDx)融合在一起。
生成输入矩阵或其部分。
编译#
有关如何使用 cuBLASDx 编译程序的说明,请参阅 快速安装指南。