10. mcSPARSE 额外函数参考

本章介绍了用于操作稀疏矩阵的额外例程。

10.1. mcsparse<t>csrgeam2()

mcsparseStatus_t
mcsparseScsrgeam2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 const float*             alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const float*             csrSortedValA,
                                 const int*               csrSortedRowPtrA,
                                 const int*               csrSortedColIndA,
                                 const float*             beta,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const float*             csrSortedValB,
                                 const int*               csrSortedRowPtrB,
                                 const int*               csrSortedColIndB,
                                 const mcsparseMatDescr_t descrC,
                                 const float*             csrSortedValC,
                                 const int*               csrSortedRowPtrC,
                                 const int*               csrSortedColIndC,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseDcsrgeam2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 const double*            alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const double*            csrSortedValA,
                                 const int*               csrSortedRowPtrA,
                                 const int*               csrSortedColIndA,
                                 const double*            beta,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const double*            csrSortedValB,
                                 const int*               csrSortedRowPtrB,
                                 const int*               csrSortedColIndB,
                                 const mcsparseMatDescr_t descrC,
                                 const double*            csrSortedValC,
                                 const int*               csrSortedRowPtrC,
                                 const int*               csrSortedColIndC,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseCcsrgeam2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 const mcComplex*         alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const mcComplex*         csrSortedValA,
                                 const int*               csrSortedRowPtrA,
                                 const int*               csrSortedColIndA,
                                 const mcComplex*         beta,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const mcComplex*         csrSortedValB,
                                 const int*               csrSortedRowPtrB,
                                 const int*               csrSortedColIndB,
                                 const mcsparseMatDescr_t descrC,
                                 const mcComplex*         csrSortedValC,
                                 const int*               csrSortedRowPtrC,
                                 const int*               csrSortedColIndC,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseZcsrgeam2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 const mcDoubleComplex*   alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const mcDoubleComplex*   csrSortedValA,
                                 const int*               csrSortedRowPtrA,
                                 const int*               csrSortedColIndA,
                                 const mcDoubleComplex*   beta,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const mcDoubleComplex*   csrSortedValB,
                                 const int*               csrSortedRowPtrB,
                                 const int*               csrSortedColIndB,
                                 const mcsparseMatDescr_t descrC,
                                 const mcDoubleComplex*   csrSortedValC,
                                 const int*               csrSortedRowPtrC,
                                 const int*               csrSortedColIndC,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseXcsrgeam2Nnz(mcsparseHandle_t         handle,
                     int                      m,
                     int                      n,
                     const mcsparseMatDescr_t descrA,
                     int                      nnzA,
                     const int*               csrSortedRowPtrA,
                     const int*               csrSortedColIndA,
                     const mcsparseMatDescr_t descrB,
                     int                      nnzB,
                     const int*               csrSortedRowPtrB,
                     const int*               csrSortedColIndB,
                     const mcsparseMatDescr_t descrC,
                     int*                     csrSortedRowPtrC,
                     int*                     nnzTotalDevHostPtr,
                     void*                    workspace)
mcsparseStatus_t
mcsparseScsrgeam2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  const float*             alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const float*             csrSortedValA,
                  const int*               csrSortedRowPtrA,
                  const int*               csrSortedColIndA,
                  const float*             beta,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const float*             csrSortedValB,
                  const int*               csrSortedRowPtrB,
                  const int*               csrSortedColIndB,
                  const mcsparseMatDescr_t descrC,
                  float*                   csrSortedValC,
                  int*                     csrSortedRowPtrC,
                  int*                     csrSortedColIndC,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseDcsrgeam2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  const double*            alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const double*            csrSortedValA,
                  const int*               csrSortedRowPtrA,
                  const int*               csrSortedColIndA,
                  const double*            beta,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const double*            csrSortedValB,
                  const int*               csrSortedRowPtrB,
                  const int*               csrSortedColIndB,
                  const mcsparseMatDescr_t descrC,
                  double*                  csrSortedValC,
                  int*                     csrSortedRowPtrC,
                  int*                     csrSortedColIndC,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseCcsrgeam2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  const mcComplex*         alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const mcComplex*         csrSortedValA,
                  const int*               csrSortedRowPtrA,
                  const int*               csrSortedColIndA,
                  const mcComplex*         beta,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const mcComplex*         csrSortedValB,
                  const int*               csrSortedRowPtrB,
                  const int*               csrSortedColIndB,
                  const mcsparseMatDescr_t descrC,
                  mcComplex*               csrSortedValC,
                  int*                     csrSortedRowPtrC,
                  int*                     csrSortedColIndC,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseZcsrgeam2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  const mcDoubleComplex*   alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const mcDoubleComplex*   csrSortedValA,
                  const int*               csrSortedRowPtrA,
                  const int*               csrSortedColIndA,
                  const mcDoubleComplex*   beta,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const mcDoubleComplex*   csrSortedValB,
                  const int*               csrSortedRowPtrB,
                  const int*               csrSortedColIndB,
                  const mcsparseMatDescr_t descrC,
                  mcDoubleComplex*         csrSortedValC,
                  int*                     csrSortedRowPtrC,
                  int*                     csrSortedColIndC,
                  void*                    pBuffer)

通用程序如下:

int baseC, nnzC;
/* alpha, nnzTotalDevHostPtr 指向主机内存 */
size_t BufferSizeInBytes;
char *buffer = NULL;
int *nnzTotalDevHostPtr = &nnzC;
mcsparseSetPointerMode(handle, MCSPARSE_POINTER_MODE_HOST);
mcMalloc((void**)&csrRowPtrC, sizeof(int)*(m+1));
/* 准备缓冲区 */
mcsparseScsrgeam2_bufferSizeExt(handle, m, n,
      alpha,
      descrA, nnzA,
      csrValA, csrRowPtrA, csrColIndA,
      beta,
      descrB, nnzB,
      csrValB, csrRowPtrB, csrColIndB,
      descrC,
      csrValC, csrRowPtrC, csrColIndC
      &bufferSizeInBytes
      );
mcMalloc((void**)&buffer, sizeof(char)*bufferSizeInBytes);
mcsparseXcsrgeam2Nnz(handle, m, n,
         descrA, nnzA, csrRowPtrA, csrColIndA,
         descrB, nnzB, csrRowPtrB, csrColIndB,
         descrC, csrRowPtrC, nnzTotalDevHostPtr,
         buffer);
if (NULL != nnzTotalDevHostPtr){
      nnzC = *nnzTotalDevHostPtr;
}else{
      mcMemcpy(&nnzC, csrRowPtrC+m, sizeof(int), mcMemcpyDeviceToHost);
      mcMemcpy(&baseC, csrRowPtrC, sizeof(int), mcMemcpyDeviceToHost);
      nnzC -= baseC;
}
mcMalloc((void**)&csrColIndC, sizeof(int)*nnzC);
mcMalloc((void**)&csrValC, sizeof(float)*nnzC);
mcsparseScsrgeam2(handle, m, n,
         alpha,
         descrA, nnzA,
         csrValA, csrRowPtrA, csrColIndA,
         beta,
         descrB, nnzB,
         csrValB, csrRowPtrB, csrColIndB,
         descrC,
         csrValC, csrRowPtrC, csrColIndC
         buffer);

对于 csrgeam2() 函数,有几点需要注意:

  • 另外的三种组合:NT,TN和TT不受mcSPARSE支持。 如果要执行这三种组合之一,用户应使用例程 csr2csc() 进行转换。

  • 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL 类型的矩阵。 如果 AB 是对称或共轭转置的,则用户必须将矩阵扩展为完整矩阵,并将描述符的 MatrixType 字段重新配置为 MCSPARSE_MATRIX_TYPE_GENERAL

  • 如果已知矩阵 C 的稀疏模式,用户可以跳过调用 mcsparseXcsrgeam2Nnz() 函数。 例如,假设用户有一个迭代算法,会迭代更新 AB,但保持稀疏模式不变。 用户可以调用 mcsparseXcsrgeam2Nnz() 函数一次来设置 C 的稀疏模式,然后仅在每次迭代中调用 mcsparse[S|D|C|Z]geam() 函数。

  • 指针 alphabeta 必须有效。

  • alphabeta 为零时,并不被 mcSPARSE 视为特殊情况。 C 的稀疏模式与 alphabeta 的值无关。

  • csrgeam2()csrgeam() 类似,只是 csrgeam2() 需要显式缓冲区,而 csrgeam() 在内部分配缓冲区。

  • 此函数需要内部分配的临时额外存储空间。

  • 如果流式有序内存分配器可用,则该例程支持异步执行

    输入

    handle

    处理mcSPARSE库上下文的句柄。

    m

    稀疏矩阵 ABC 的行数。

    n

    稀疏矩阵 ABC 的列数。

    alpha

    用于乘法的<type>标量。

    descrA

    矩阵 A 的描述符。 支持的矩阵类型仅为 MCSPARSE_MATRIX_TYPE_GENERAL

    nnzA

    矩阵 A 的非零元素数量

    csrValA

    矩阵 A 的非零元素值数组, 长度为 nnzA

    csrRowPtrA

    整型数组,包含m+1个元素, 其中每个元素表示每一行的起始位置 和最后一行的结束位置加一。

    csrColIndA

    整型数组,包含 nnzA 个元素, 表示矩阵A非零元素的列索引。 elements of matrix A

    beta

    用于乘法操作的标量值。 如果beta为零,则y不需要是一个 有效输入。

    descrB

    矩阵 B 的描述符。 支持的矩阵类型仅为 MCSPARSE_MATRIX_TYPE_GENERAL

    nnzB

    矩阵 B 的非零元素数量。

    csrValB

    矩阵 B 的非零元素值数组, 长度为 nnzB

    csrRowPtrB

    整型数组, 包含m+1个元素, 其中每个元素表示每一行的起始位置和 最后一行的结束位置加一。

    csrColIndB

    整型数组,包含 nnzB 个元素, 表示矩阵 B 非零元素的列索引。

    descrC

    矩阵 C 的描述符。 支持的矩阵类型仅为 MCSPARSE_MATRIX_TYPE_GENERAL

    输出

    csrValC

    矩阵C的非零元素值数组,长度为 nnzC

    csrRowPtrC

    整型数组,包含m+1个元素, 其中每个元素表示每一行的起始位置和 最后一行的结束位置加一。

    csrColIndC

    整型数组,包含 nnzC 个元素, 表示矩阵 C 非零元素的列索引。

    nnzTotalDevHostPtr

    在设备或主机内存中非零元素的总数。 它等于 (csrRowPtrC(m)-csrRowPtrC(0))

有关返回状态的描述,请参见 4.2 mcsparseStatus_t

10.2. mcsparse<t>csrgemm2()

mcsparseStatus_t
mcsparseScsrgemm2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 int                      k,
                                 const float*             alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const int*               csrRowPtrA,
                                 const int*               csrColIndA,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const int*               csrRowPtrB,
                                 const int*               csrColIndB,
                                 const float*             beta,
                                 const mcsparseMatDescr_t descrD,
                                 int                      nnzD,
                                 const int*               csrRowPtrD,
                                 const int*               csrColIndD,
                                 csrgemm2Info_t           info,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseDcsrgemm2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 int                      k,
                                 const double*            alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const int*               csrRowPtrA,
                                 const int*               csrColIndA,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const int*               csrRowPtrB,
                                 const int*               csrColIndB,
                                 const double*            beta,
                                 const mcsparseMatDescr_t descrD,
                                 int                      nnzD,
                                 const int*               csrRowPtrD,
                                 const int*               csrColIndD,
                                 csrgemm2Info_t           info,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseCcsrgemm2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 int                      k,
                                 const mcComplex*         alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const int*               csrRowPtrA,
                                 const int*               csrColIndA,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const int*               csrRowPtrB,
                                 const int*               csrColIndB,
                                 const mcComplex*         beta,
                                 const mcsparseMatDescr_t descrD,
                                 int                      nnzD,
                                 const int*               csrRowPtrD,
                                 const int*               csrColIndD,
                                 csrgemm2Info_t           info,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseZcsrgemm2_bufferSizeExt(mcsparseHandle_t         handle,
                                 int                      m,
                                 int                      n,
                                 int                      k,
                                 const mcDoubleComplex*   alpha,
                                 const mcsparseMatDescr_t descrA,
                                 int                      nnzA,
                                 const int*               csrRowPtrA,
                                 const int*               csrColIndA,
                                 const mcsparseMatDescr_t descrB,
                                 int                      nnzB,
                                 const int*               csrRowPtrB,
                                 const int*               csrColIndB,
                                 const mcDoubleComplex*   beta,
                                 const mcsparseMatDescr_t descrD,
                                 int                      nnzD,
                                 const int*               csrRowPtrD,
                                 const int*               csrColIndD,
                                 csrgemm2Info_t           info,
                                 size_t*                  pBufferSizeInBytes)

mcsparseStatus_t
mcsparseXcsrgemm2Nnz(mcsparseHandle_t        handle,
                     int                      m,
                     int                      n,
                     int                      k,
                     const mcsparseMatDescr_t descrA,
                     int                      nnzA,
                     const int*               csrRowPtrA,
                     const int*               csrColIndA,
                     const mcsparseMatDescr_t descrB,
                     int                      nnzB,
                     const int*               csrRowPtrB,
                     const int*               csrColIndB,
                     const mcsparseMatDescr_t descrD,
                     int                      nnzD,
                     const int*               csrRowPtrD,
                     const int*               csrColIndD,
                     const mcsparseMatDescr_t descrC,
                     int*                     csrRowPtrC,
                     int*                     nnzTotalDevHostPtr,
                     const csrgemm2Info_t     info,
                     void*                    pBuffer)
mcsparseStatus_t
mcsparseScsrgemm2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  int                      k,
                  const float*             alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const float*             csrValA,
                  const int*               csrRowPtrA,
                  const int*               csrColIndA,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const float*             csrValB,
                  const int*               csrRowPtrB,
                  const int*               csrColIndB,
                  const float*             beta,
                  const mcsparseMatDescr_t descrD,
                  int                      nnzD,
                  const float*             csrValD,
                  const int*               csrRowPtrD,
                  const int*               csrColIndD,
                  const mcsparseMatDescr_t descrC,
                  float*                   csrValC,
                  const int*               csrRowPtrC,
                  int*                     csrColIndC,
                  const csrgemm2Info_t     info,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseDcsrgemm2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  int                      k,
                  const double*            alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const double*            csrValA,
                  const int*               csrRowPtrA,
                  const int*               csrColIndA,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const double*            csrValB,
                  const int*               csrRowPtrB,
                  const int*               csrColIndB,
                  const double*            beta,
                  const mcsparseMatDescr_t descrD,
                  int                      nnzD,
                  const double*            csrValD,
                  const int*               csrRowPtrD,
                  const int*               csrColIndD,
                  const mcsparseMatDescr_t descrC,
                  double*                  csrValC,
                  const int*               csrRowPtrC,
                  int*                     csrColIndC,
                  const csrgemm2Info_t     info,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseCcsrgemm2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  int                      k,
                  const mcComplex*         alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const mcComplex*         csrValA,
                  const int*               csrRowPtrA,
                  const int*               csrColIndA,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const mcComplex*         csrValB,
                  const int*               csrRowPtrB,
                  const int*               csrColIndB,
                  const mcComplex*         beta,
                  const mcsparseMatDescr_t descrD,
                  int                      nnzD,
                  const mcComplex*         csrValD,
                  const int*               csrRowPtrD,
                  const int*               csrColIndD,
                  const mcsparseMatDescr_t descrC,
                  mcComplex*               csrValC,
                  const int*               csrRowPtrC,
                  int*                     csrColIndC,
                  const csrgemm2Info_t     info,
                  void*                    pBuffer)

mcsparseStatus_t
mcsparseZcsrgemm2(mcsparseHandle_t         handle,
                  int                      m,
                  int                      n,
                  int                      k,
                  const mcDoubleComplex*   alpha,
                  const mcsparseMatDescr_t descrA,
                  int                      nnzA,
                  const mcDoubleComplex*   csrValA,
                  const int*               csrRowPtrA,
                  const int*               csrColIndA,
                  const mcsparseMatDescr_t descrB,
                  int                      nnzB,
                  const mcDoubleComplex*   csrValB,
                  const int*               csrRowPtrB,
                  const int*               csrColIndB,
                  const mcDoubleComplex*   beta,
                  const mcsparseMatDescr_t descrD,
                  int                      nnzD,
                  const mcDoubleComplex*   csrValD,
                  const int*               csrRowPtrD,
                  const int*               csrColIndD,
                  const mcsparseMatDescr_t descrC,
                  mcDoubleComplex*         csrValC,
                  const int*               csrRowPtrC,
                  int*                     csrColIndC,
                  const csrgemm2Info_t     info,
                  void*                    pBuffer)

此函数执行以下矩阵乘法操作:

C=alpha*A*B+beta*D

其中, ABDC 分别为 m×kk×nm×nm×n 的稀疏矩阵(由数组 csrValA|csrValB|csrValD|csrValC 在CSR存储格式中定义)。

请注意,新的API mcsparseSpGEMM 要求 D 必须具有与 C 相同的稀疏模式。

csrgemm2 使用 alphabeta 来支持以下操作:

alpha

beta

operation

NULL

NULL

无效

NULL

!NULL

C = beta*D ,不使用 AB

!NULL

NULL

C = alpha*A*B ,不使用 D

!NULL

!NULL

C = alpha*A*B + beta*D

alphabeta 的数值只影响 C 的数值,而不影响其稀疏模式。 例如,如果 alphabeta 非零,则 C 的稀疏模式是 A*BD 的并集,与 alphabeta 的数值无关。

下表根据 mnk 的值显示了不同的操作:

m,n,k

操作

m<0 or n <0 or k<0

无效

m is 0 or n is 0

不执行任何操作

m >0 and n >0 and k is 0

beta 为零,则无效; 若 beta 不为零, 则 C = beta*D

m >0 and n >0 and k >0

alpha 为零,则 C = beta*D ; 若 beta 为零,则 C = alpha*A*B ; 若 alphabeta 都不为零, 则 C = alpha*A*B + beta*D

此函数需要由 csrgemm2_bufferSizeExt() 返回缓冲区大小。 pBuffer 的地址必须是128字节的倍数。如果不是,将返回 MCSPARSE_STATUS_INVALID_VALUE

mcSPARSE库采用两步方法来完成稀疏矩阵计算。

在第一步中,用户分配具有m+1个元素的 csrRowPtrC,并使用函数 mcsparseXcsrgemm2Nnz() 来确定 csrRowPtrC 和非零元素的总数。 在第二步中,用户从 (nnzC=*nnzTotalDevHostPtr)(nnzC=csrRowPtrC(m)-csrRowPtrC(0)) 中获取 nnzC (矩阵C的非零元素数),然后分别分配 nnzC 个元素的 csrValCcsrColIndC,最后调用函数 mcsparse[S|D|C|Z]csrgemm2() 对矩阵 C 进行求值。

C=-A*B+D 的通用程序如下:

// 假设矩阵A,B和D已经准备好。
int baseC, nnzC;
csrgemm2Info_t info = NULL;
size_t bufferSize;
void *buffer = NULL;
//nnzTotalDevHostPtr指向主机内存
int *nnzTotalDevHostPtr = &nnzC;
double alpha = -1.0;
double beta  =  1.0;
mcsparseSetPointerMode(handle, MCSPARSE_POINTER_MODE_HOST);

// 步骤 1: 创建一个不透明的结构
mcsparseCreateCsrgemm2Info(&info);

// 步骤 2:为csrgemm2Nnz和csrgemm2分配缓冲区
mcsparseDcsrgemm2_bufferSizeExt(handle, m, n, k, &alpha,
      descrA, nnzA, csrRowPtrA, csrColIndA,
      descrB, nnzB, csrRowPtrB, csrColIndB,
      &beta,
      descrD, nnzD, csrRowPtrD, csrColIndD,
      info,
      &bufferSize);
mcMalloc(&buffer, bufferSize);

// 步骤 3: 计算csrRowPtrC
mcMalloc((void**)&csrRowPtrC, sizeof(int)*(m+1));
mcsparseXcsrgemm2Nnz(handle, m, n, k,
         descrA, nnzA, csrRowPtrA, csrColIndA,
         descrB, nnzB, csrRowPtrB, csrColIndB,
         descrD, nnzD, csrRowPtrD, csrColIndD,
         descrC, csrRowPtrC, nnzTotalDevHostPtr,
         info, buffer );
if (NULL != nnzTotalDevHostPtr){
      nnzC = *nnzTotalDevHostPtr;
}else{
      mcMemcpy(&nnzC, csrRowPtrC+m, sizeof(int), mcMemcpyDeviceToHost);
      mcMemcpy(&baseC, csrRowPtrC, sizeof(int), mcMemcpyDeviceToHost);
      nnzC -= baseC;
}

// 步骤 4:完成C的稀疏模式和值。
mcMalloc((void**)&csrColIndC, sizeof(int)*nnzC);
mcMalloc((void**)&csrValC, sizeof(double)*nnzC);
// 注意: 如果只需要稀疏模式,将 csrValC 设置为 null。
mcsparseDcsrgemm2(handle, m, n, k, &alpha,
         descrA, nnzA, csrValA, csrRowPtrA, csrColIndA,
         descrB, nnzB, csrValB, csrRowPtrB, csrColIndB,
         &beta,
         descrD, nnzD, csrValD, csrRowPtrD, csrColIndD,
         descrC, csrValC, csrRowPtrC, csrColIndC,
         info, buffer);

// 步骤 5: 销毁不透明的结构
mcsparseDestroyCsrgemm2Info(info);

对于 csrgemm2() 函数,有几点注意事项:

  • 仅支持NN版本。对于其他模式,用户必须明确地转置矩阵 AB

  • 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL。 如果 AB 是对称的或共轭的,用户必须将矩阵扩展为完整矩阵,并重新配置 MatrixType 字段描述符为 MCSPARSE_MATRIX_TYPE_GENERAL

  • 如果 csrValC 为零,则仅计算 C 的稀疏模式。

如果 pBuffer != NULL,则函数 mcsparseXcsrgeam2Nnz()mcsparse<t>csrgeam2() 具有以下特性:

  • 该例程不需要额外的存储空间。

  • 该例程支持异步执行。

    输入

    handle

    处理mcSPARSE库上下文的句柄。

    m

    稀疏矩阵 ADC 的行数。

    n

    稀疏矩阵 BDC 的列数。

    k

    稀疏矩阵 AB 的列数或行数。

    alpha

    用于乘法的<type>标量。

    descrA

    矩阵 A 的描述符。 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL 类型。

    nnzA

    矩阵 A 的非零元素个数。

    csrValA

    矩阵 A 的非零元素值数组。

    csrRowPtrA

    m+1 个元素组成的整型数组,包含每行的起始位置 和最后一行的结束位置加一。

    csrColIndA

    整型数组,包含 nnzA 个矩阵 A 的非零元素的列 索引。

    descrB

    矩阵B的描述符。 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL 类型。

    nnzB

    矩阵 B 的非零元素个数。

    csrValB

    矩阵 B 的非零元素值数组。

    csrRowPtrB

    k+1 个元素组成的整型数组, 包含每行的起始位置和最后一行的结束位置加一。

    csrColIndB

    整型数组包含矩阵 BnnzB 个非零元素的 列索引。

    descrD

    矩阵 D 的描述符。 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL 类型。

    nnzD

    矩阵 D 的非零元素个数。

    csrValD

    矩阵 D 的非零元素值数组。

    csrRowPtrD

    m+1 个元素组成的整型数组,包含每行的起始位置 和最后一行的结束位置加一。

    csrColIndD

    整型数组包含矩阵 DnnzD 个非零元素的 列索引。

    beta

    用于乘法的<type>标量。

    descrC

    矩阵C的描述符。 仅支持 MCSPARSE_MATRIX_TYPE_GENERAL

    info

    csrgemm2Nnzcsrgemm2 中使用的存储有关 信息的结构体。

    pBuffer

    用户分配的缓冲区,其大小由 csrgemm2_bufferSizeExt 返回。

    输出

    csrValC

    稀疏矩阵 C 的非零元素值数组, 长度为 nnzC

    csrRowPtrC

    m+1 个元素组成的整型数组,用于记录 每行起始位置和最后一行结束位置加一的索引。

    csrColIndC

    整型数组,包含矩阵 CnnzC 个 非零元素的列索引。

    pBufferSizeInBytes

    csrgemm2Nnnzcsrgemm2 中 使用的缓冲区大小,以字节为单位。

    nnzTotalDevHostPtr

    稀疏矩阵 C 在设备或主机内存中的 总非零元素个数,等于 (csrRowPtrC(m)-csrRowPtrC(0))

有关返回状态的描述,请参见 4.2 mcsparseStatus_t