blob: c0c909a6fe6641e6ae3548ab3e05bb3386e81e70 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "libmatrixmult.h"
#include "omp.h"
#include <cmath>
#include <cstdlib>
#ifdef USE_OPEN_BLAS
#include <cblas.h>
#else
#include <mkl_service.h>
#endif
// Thread count most recently handed to the BLAS library by
// setNumThreadsForBLAS. Starts at -1, so the first call with any other value
// triggers the library call.
static int SYSDS_CURRENT_NUM_THREADS = -1;
// Configures how many threads the linked BLAS library uses: OpenBLAS when
// USE_OPEN_BLAS is defined, MKL otherwise. The value last applied is kept in
// SYSDS_CURRENT_NUM_THREADS, and a call that repeats that value returns
// without calling into the library.
//
// numThreads: thread count to pass to the BLAS library
void setNumThreadsForBLAS(int numThreads) {
  // Requested count is already in effect; skip the library call.
  if (SYSDS_CURRENT_NUM_THREADS == numThreads) {
    return;
  }

#ifdef USE_OPEN_BLAS
  openblas_set_num_threads(numThreads);
#else
  mkl_set_num_threads(numThreads);
#endif

  // Remember what was applied so the next identical request is a no-op.
  SYSDS_CURRENT_NUM_THREADS = numThreads;
}
// Double-precision dense product ret = m1 * m2 on row-major buffers.
//
// m1Ptr:      left operand, m x k (leading dimension k)
// m2Ptr:      right operand, k x n (leading dimension n)
// retPtr:     output, m x n; overwritten (beta is 0, prior contents are ignored)
// numThreads: thread count forwarded to setNumThreadsForBLAS
//
// The BLAS routine is picked from the operand shapes:
//   m == 1 and n == 1 -> ddot  (vector-vector)
//   n == 1            -> dgemv (matrix-vector)
//   anything else     -> dgemm (matrix-matrix)
// Note carried over from the original author: cblas_dgemv with CblasColMajor
// is not used for matrix-vector because it was generally slower than dgemm.
void dmatmult(double *m1Ptr, double *m2Ptr, double *retPtr, int m, int k, int n,
              int numThreads) {
  setNumThreadsForBLAS(numThreads);

  const bool singleRow = (m == 1);
  const bool singleCol = (n == 1);

  if (singleRow && singleCol) {
    // Inner product of two length-k vectors; result is a single scalar.
    retPtr[0] = cblas_ddot(k, m1Ptr, 1, m2Ptr, 1);
    return;
  }

  if (singleCol) {
    // m x k matrix times a length-k vector.
    cblas_dgemv(CblasRowMajor, CblasNoTrans, m, k, 1.0, m1Ptr, k, m2Ptr, 1,
                0.0, retPtr, 1);
    return;
  }

  // General matrix-matrix case.
  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0, m1Ptr,
              k, m2Ptr, n, 0.0, retPtr, n);
}
// Single-precision dense product ret = m1 * m2 on row-major buffers.
//
// m1Ptr:      left operand, m x k (leading dimension k)
// m2Ptr:      right operand, k x n (leading dimension n)
// retPtr:     output, m x n; overwritten (beta is 0, prior contents are ignored)
// numThreads: thread count forwarded to setNumThreadsForBLAS
//
// The BLAS routine is picked from the operand shapes:
//   m == 1 and n == 1 -> sdot  (vector-vector)
//   n == 1            -> sgemv (matrix-vector)
//   anything else     -> sgemm (matrix-matrix)
// Note carried over from the original author: cblas_sgemv with CblasColMajor
// is not used for matrix-vector because it was generally slower than sgemm.
void smatmult(float *m1Ptr, float *m2Ptr, float *retPtr, int m, int k, int n,
              int numThreads) {
  setNumThreadsForBLAS(numThreads);

  const bool singleRow = (m == 1);
  const bool singleCol = (n == 1);

  if (singleRow && singleCol) {
    // Inner product of two length-k vectors; result is a single scalar.
    retPtr[0] = cblas_sdot(k, m1Ptr, 1, m2Ptr, 1);
    return;
  }

  if (singleCol) {
    // m x k matrix times a length-k vector.
    cblas_sgemv(CblasRowMajor, CblasNoTrans, m, k, 1.0f, m1Ptr, k, m2Ptr, 1,
                0.0f, retPtr, 1);
    return;
  }

  // General matrix-matrix case.
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, m1Ptr,
              k, m2Ptr, n, 0.0f, retPtr, n);
}
// Transpose-self matrix multiply on a row-major double-precision matrix.
//
// m1Ptr:      input matrix, m1rlen x m1clen, row-major (leading dimension
//             m1clen)
// retPtr:     output buffer, n x n with n = m1clen when leftTrans is true and
//             n = m1rlen otherwise. In the general case only the upper
//             triangle is written (CblasUpper); the lower triangle is left
//             untouched by this function.
// m1rlen:     number of rows of m1
// m1clen:     number of columns of m1
// leftTrans:  true computes t(m1) * m1, false computes m1 * t(m1)
// numThreads: thread count forwarded to setNumThreadsForBLAS
void tsmm(double *m1Ptr, double *retPtr, int m1rlen, int m1clen, bool leftTrans,
          int numThreads) {
  setNumThreadsForBLAS(numThreads);
  if ((leftTrans && m1clen == 1) || (!leftTrans && m1rlen == 1)) {
    // The result is 1 x 1: the dot product of the single column/row with
    // itself.
    retPtr[0] = cblas_ddot(leftTrans ? m1rlen : m1clen, m1Ptr, 1, m1Ptr, 1);
  } else { // general case
    int n = leftTrans ? m1clen : m1rlen; // order of the symmetric result
    int k = leftTrans ? m1rlen : m1clen; // length of the contracted dimension
    // The leading dimension of a row-major m1rlen x m1clen buffer is m1clen
    // regardless of the transpose flag. Passing n here (as was done before)
    // is only correct when leftTrans is true or the input is square; for
    // m1 * t(m1) on a non-square input it yields a wrong row stride or an
    // illegal-lda error from BLAS.
    cblas_dsyrk(CblasRowMajor, CblasUpper,
                leftTrans ? CblasTrans : CblasNoTrans, n, k, 1, m1Ptr, m1clen,
                0, retPtr, n);
  }
}