cuBLAS矩阵乘法性能分析(附代码示例)

使用教程

矩阵乘法是神经网络中最基础、最重要的一个运算。在用CUDA实现矩阵乘法时,不需要我们手动写,cuBLAS库提供了现成的矩阵乘法算子,例如cublasGemmExcublasLtMatmul。其中后者是轻量级版本,API调用更灵活。例如对于整数乘法,cublasLtMatmul支持int8的输入输出,而cublasGemmEx只支持int8输入,int32输出。

今天我只给大家讲解cublasGemmEx,主要使用起来相对更简洁一点。

官方文档地址:
https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx

经过翻阅网上各种教程,我找到了一篇我认为写的最好的博客。例子举得非常好,写的很详细。地址如下:
https://www.cnblogs.com/cuancuancuanhao/p/7763256.html

具体的使用方法可以参见上面这篇博客,我这里就不再赘述了。

今天我主要给大家演示一下,不同数据类型的矩阵乘法,速度和结果上到底有多大的差异?

测试代码

我写了一个简单的测试代码:

#include <sys/time.h>
#include <cuda_profiler_api.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdio.h>

int8_t float2int8(float f, float scale) {
    int8_t i = int8_t(f * scale);
    if (i < -127) i = -127;
    if (i > 127) i = 127;
    return i;
}

template <typename T, typename S>
void allocate_memory(int m, int n, int k, T **A, T **B, S **C) {
    cudaMallocManaged(A, m * k * sizeof(T));
    cudaMallocManaged(B, k * n * sizeof(T));
    cudaMallocManaged(C, m * n * sizeof(S));
}

template <typename T, typename S>
void free_memory(T *A, T *B, S *C) {
    cudaFree(A);
    cudaFree(B);
    cudaFree(C);
}

template <typename T, typename S>
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transA, cublasOperation_t transB,
                   int m, int n, int k, T *A, T *B, S *C, int lda, int ldb, int ldc,
                   S *alpha, S *beta, int algo) {
    cudaDataType_t AType, BType, CType, ComputeType;
    if (std::is_same<T, float>::value) {
        AType = BType = CType = ComputeType = CUDA_R_32F;
    } else if (std::is_same<T, __half>::value) {
        AType = BType = CType = ComputeType = CUDA_R_16F;
    } else if (std::is_same<T, int8_t>::value) {
        AType = BType = CUDA_R_8I;
        CType = ComputeType = CUDA_R_32I;
    } else {
        printf("Not supported data type.");
        return -1;
    }
    cublasStatus_t status;
    status = cublasGemmEx(handle,
                          transA,
                          transB,
                          m,
                          n,
                          k,
                          alpha,
                          A,
                          AType,
                          lda,
                          B,
                          BType,
                          ldb,
                          beta,
                          C,
                          CType,
                          ldc,
                          ComputeType,
                          static_cast<cublasGemmAlgo_t>(algo));

    if (status == CUBLAS_STATUS_SUCCESS)
        return 1;
    else
        return -1;
}

template <typename T, typename S>
void test_gemm(cublasHandle_t handle, int m, int n, int k, T *A, T *B, S *C,
               S *alpha, S *beta, int algo, int iteration) {
    float total_time = 0;
    for (int i = 0; i < iteration; ++i) {
        struct timeval start, end;
        cudaDeviceSynchronize();
        cudaProfilerStart();
        gettimeofday(&start, NULL);
        int success = cublas_gemm_ex(handle,
                                     CUBLAS_OP_N,
                                     CUBLAS_OP_N,
                                     n,
                                     m,
                                     k,
                                     B,
                                     A,
                                     C,
                                     n,
                                     k,
                                     n,
                                     alpha,
                                     beta,
                                     static_cast<cublasGemmAlgo_t>(algo));
        cudaDeviceSynchronize();
        gettimeofday(&end, NULL);
        cudaProfilerStop();
        if (success > 0 && i > 0)
            total_time += (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001;
    }
    if (total_time > 0)
        printf("algo %d: %.3f ms\n", algo, total_time / (iteration - 1));
}

int main() {
    int m = 4096, n = 8192, k = 1024;
    printf("shape: (%d, %d) x (%d, %d)\n", m, k, k, n);
    int start_algo = CUBLAS_GEMM_DEFAULT;
    int end_algo = CUBLAS_GEMM_ALGO23;
    int start_algo_t_op = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
    int end_algo_t_op = CUBLAS_GEMM_ALGO15_TENSOR_OP;
    int iteration = 10;

    float *fA, *fB, *fC;
    __half *hA, *hB, *hC;
    int8_t *iA, *iB; int32_t *iC;
    float f_alpha = 1, f_beta = 0;
    __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
    int32_t i_alpha = 1, i_beta = 0;
    allocate_memory(m, n, k, &fA, &fB, &fC);
    allocate_memory(m, n, k, &hA, &hB, &hC);
    allocate_memory(m, n, k, &iA, &iB, &iC);
    for (int i = 0; i < m * k; ++i) {
        fA[i] = float(i % 255 - 127) / 127;
        hA[i] = __float2half_rn(fA[i]);
        iA[i] = float2int8(fA[i], 127);
    } 
    for (int i = 0; i < k * n; ++i) {
        fB[i] = float(i % 255 - 127) / 127;
        hB[i] = __float2half_rn(fB[i]);
        iB[i] = float2int8(fB[i], 127);
    }
    cublasHandle_t handle;
    cublasCreate(&handle);

    printf(">>>>>>>>>>>>>>>>> test fp32 >>>>>>>>>>>>>>>>>\n");
    for (int algo = start_algo; algo <= end_algo; ++algo)
        test_gemm(handle, m, n, k, fA, fB, fC, &f_alpha, &f_beta, algo, iteration);
    for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo)
        test_gemm(handle, m, n, k, fA, fB, fC, &f_alpha, &f_beta, algo, iteration);


    printf(">>>>>>>>>>>>>>>>> test fp16 >>>>>>>>>>>>>>>>>\n");
    for (int algo = start_algo; algo <= end_algo; ++algo)
        test_gemm(handle, m, n, k, hA, hB, hC, &h_alpha, &h_beta, algo, iteration);
    for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo)
        test_gemm(handle, m, n, k, hA, hB, hC, &h_alpha, &h_beta, algo, iteration);

    printf(">>>>>>>>>>>>>>>>> test int8 >>>>>>>>>>>>>>>>>\n");
    for (int algo = start_algo; algo <= end_algo; ++algo)
        test_gemm(handle, m, n, k, iA, iB, iC, &i_alpha, &i_beta, algo, iteration);
    for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo)
        test_gemm(handle, m, n, k, iA, iB, iC, &i_alpha, &i_beta, algo, iteration);

    printf(">>>>>>>>>>>>>>>>> compare result >>>>>>>>>>>>>>>>>\n");
    printf("fp32: ");
    for (int i = 0; i < 10; ++i)
        printf("%.5f%c", fC[i], " \n"[i==9]);
    printf("fp16: ");
    for (int i = 0; i < 10; ++i)
        printf("%.5f%c", float(hC[i]), " \n"[i==9]);
    printf("int8: ");
    for (int i = 0; i < 10; ++i)
        printf("%.5f%c", float(iC[i])/127/127, " \n"[i==9]);

    free_memory(iA, iB, iC);
    free_memory(fA, fB, fC);
    free_memory(hA, hB, hC);
    return 0;
}

代码保存为test_gemm.cpp,然后执行下面命令进行编译:

nvcc test_gemm.cpp -o test_gemm -L/usr/local/cuda/lib64 -lcudart -lcuda -lcublas

最后执行./test_gemm运行就行了。

这里计算的是$C = A \cdot B$,其中$A$的维度是$(m, k)$,$B$的维度是$(k, n)$,$C$的维度是$(m, n)$。由于在C++和Python中新建的数组默认都是行优先存储,而cuBLAS计算矩阵乘法是默认是列优先存储。所以你新建的矩阵送到cuBLAS矩阵乘法算子后,它默认识别成了列优先存储。因此需要调整一下运算顺序,或者对矩阵进行转置。

你需要记住一点,行优先存储的矩阵送到cuBLAS后,相当于做了一次转置,同样计算得到的矩阵$C$也是列优先存储的,你需要转置后再用行优先存储来正常读取。而根据矩阵的运算法则,我们有:
$$
C^{\top} = (A \cdot B)^{\top} = B^{\top} \cdot A^{\top}
$$
所以三个转置后的矩阵就不需要经过任何处理了,直接送到cuBLAS里计算就行了。

运行结果

我对比了三种数据类型:fp32fp16int8,测试环境是V100显卡、CUDA 10.1。由于V100显卡没有int8的tensor core,所以速度并不能达到最快。要想全速进行int8的矩阵乘法,推荐使用sm75及以上的显卡,例如T4、A100等等。此外我还对比了不同的GEMM算法的效果。

执行上面的运行命令后,会输出如下的结果:

shape: (4096, 1024) x (1024, 8192)
>>>>>>>>>>>>>>>>> test fp32 >>>>>>>>>>>>>>>>>
algo -1: 4.831 ms
algo 2: 5.293 ms
algo 3: 5.406 ms
algo 4: 5.297 ms
algo 5: 5.098 ms
algo 6: 4.874 ms
algo 11: 4.870 ms
algo 18: 7.219 ms
algo 19: 6.061 ms
algo 20: 5.631 ms
algo 99: 1.110 ms
algo 100: 1.159 ms
algo 101: 1.688 ms
algo 102: 4.944 ms
algo 103: 4.744 ms
algo 104: 4.700 ms
algo 105: 4.679 ms
algo 106: 4.679 ms
algo 107: 4.675 ms
algo 108: 4.676 ms
algo 109: 4.677 ms
algo 110: 4.676 ms
algo 111: 4.676 ms
algo 112: 4.678 ms
algo 113: 4.675 ms
algo 114: 4.676 ms
algo 115: 4.689 ms
>>>>>>>>>>>>>>>>> test fp16 >>>>>>>>>>>>>>>>>
algo -1: 2.423 ms
algo 1: 2.460 ms
algo 2: 2.565 ms
algo 3: 2.518 ms
algo 5: 2.398 ms
algo 6: 2.416 ms
algo 99: 0.737 ms
algo 100: 1.581 ms
algo 101: 1.032 ms
algo 102: 0.978 ms
algo 103: 0.767 ms
algo 104: 0.790 ms
algo 105: 0.803 ms
algo 106: 0.774 ms
algo 107: 2.656 ms
algo 108: 2.577 ms
algo 109: 2.518 ms
algo 110: 0.925 ms
algo 111: 0.951 ms
algo 112: 0.935 ms
algo 113: 0.909 ms
algo 114: 2.549 ms
algo 115: 2.532 ms
>>>>>>>>>>>>>>>>> test int8 >>>>>>>>>>>>>>>>>
algo -1: 1.232 ms
algo 0: 7.544 ms
algo 1: 1.217 ms
algo 2: 1.294 ms
algo 3: 2.362 ms
algo 99: 1.243 ms
algo 100: 1.244 ms
algo 101: 1.237 ms
algo 102: 1.232 ms
algo 103: 1.230 ms
algo 104: 1.224 ms
algo 105: 1.222 ms
algo 106: 1.224 ms
algo 107: 1.225 ms
algo 108: 1.224 ms
algo 109: 1.218 ms
algo 110: 1.217 ms
algo 111: 1.217 ms
algo 112: 1.218 ms
algo 113: 1.218 ms
algo 114: 1.216 ms
algo 115: 1.217 ms
>>>>>>>>>>>>>>>>> compare result >>>>>>>>>>>>>>>>>
fp32: 52.38629 44.76633 37.65229 31.04420 24.94203 19.34578 14.25543 9.67102 5.59253 2.01996
fp16: 52.46875 44.84375 37.40625 31.21875 24.95312 19.39062 14.28125 9.69531 5.61328 2.05078
int8: 52.38626 44.76632 37.65230 31.04421 24.94203 19.34577 14.25544 9.67103 5.59254 2.01996

这里简单解释一下,algo -1到23表示不使用tensor core算法的结果,algo 99到115表示使用tensor core算法的结果。

可以看到图中缺失了一部分算法的结果,因为那些算法可能不适用于当前的矩阵乘法,因此报错了。

汇总一下各自最快的结果(不使用vs使用tensor core):

  • fp32: 4.83 1.11
  • fp16: 2.41 0.73
  • int8: 1.21 1.21

由于V100显卡没有int8的tensor core,所以int8的两个结果是相同的。结果也符合我们的预期,速度上fp32慢于fp16慢于int8。所以在实际的深度学习应用中,流行使用混合精度,也就是用fp16来进行训练和推理。

而int8是速度最快的,所以如果训练和推理也都能使用int8的话,速度上将会迈上一个新的台阶。

那么一个浮点数的矩阵乘法怎么转变为整数的矩阵乘法呢?这里我不会详细讲,后续会出一个详细的量化教程。

简单来说,对于一个浮点数$f$,假设范围在$[-1, 1]$之间,那我们可以将它表示成一个$[-127, 127]$之间的8位整数$i$,转换关系为:
$$
f = i / 127
$$
那么浮点数矩阵乘法$f_3 = f_1 \cdot f_2$就可以表示为:
$$
f_3 = f_1 \cdot f_2 = i_1 \cdot i_2 / 127^2
$$
所以只需要计算int8矩阵乘法$i_1 \cdot i_2$,然后得到int32类型的输出结果之后,除以$127^2$就可以得到原始的浮点数结果了。

那么由于这里有个类型转换的操作,所以会产生误差。但是在我们的样例中,int8的误差竟然比fp16还要小很多,结果和fp32几乎一模一样。这主要由于是我构造的矩阵数据分布非常均匀有规律,因此计算误差会很小,实际深度网络中int8的误差会较大。

结语

int8甚至更低比特的量化的实际收益非常大,提速可以达到将近2倍。虽然现在有很多现成的自动量化工具,但是效果上或多或少都有一定的损失,速度上也没有达到极致。因此今后量化是一个不错的方向,值得一试。


   转载规则


《cuBLAS矩阵乘法性能分析(附代码示例)》 韦阳 采用 知识共享署名 4.0 国际许可协议 进行许可。
 上一篇
昨晚学妹参加了B站秋招笔试,还想考考我? 昨晚学妹参加了B站秋招笔试,还想考考我?
学妹昨晚参加了B站的2022届秋招算法笔试,做完给我发来了一道题,想考考我,说挺难的。 我看了两分钟,给她发去了我的思路。然后学妹一眼就看懂了,立马秒过。 那么这道题到底是怎么做的呢? 题目要求将$n$个数切分成$k$块,求每块的序号乘上该
2021-08-26
下一篇 
最全攻略:利用LightSeq加速你的深度学习模型 最全攻略:利用LightSeq加速你的深度学习模型
前言LightSeq是字节跳动火山翻译团队开源的一款Transformer系列模型加速引擎,分为训练和推理两个部分。其中推理加速引擎早在2019年12月就已经开源,而训练加速引擎也在2021年6月开源。项目地址:https://github
2021-08-24
  目录