blob: 137c64c1afcac4aa05d8e880b84787ca024ac4bb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.math;
import java.util.Set;
import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.F2jBLAS;
import org.apache.ignite.ml.math.exceptions.math.CardinalityException;
import org.apache.ignite.ml.math.exceptions.math.MathIllegalArgumentException;
import org.apache.ignite.ml.math.exceptions.math.NonSquareMatrixException;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
import org.apache.ignite.ml.math.util.MatrixUtil;
/**
* Useful subset of BLAS operations.
* This class is based on 'BLAS' class from Apache Spark MLlib.
*/
public class Blas {
/** F2J implementation of BLAS. */
private static transient BLAS f2jBlas = new F2jBLAS();
/**
* Native implementation of BLAS. F2J implementation will be used as fallback if no native implementation is found.
*/
private static transient BLAS nativeBlas = BLAS.getInstance();
/**
* Performs y += a * x
*
* @param a Scalar a.
* @param x Vector x.
* @param y Vector y.
*/
public static void axpy(Double a, Vector x, Vector y) {
if (x.size() != y.size())
throw new CardinalityException(x.size(), y.size());
if (x.isArrayBased() && y.isArrayBased())
axpy(a, x.getStorage().data(), y.getStorage().data());
else if (x instanceof SparseVector && y.isArrayBased())
axpy(a, (SparseVector)x, y.getStorage().data());
else
throw new MathIllegalArgumentException("Operation 'axpy' doesn't support this combination of parameters [x="
+ x.getClass().getName() + ", y=" + y.getClass().getName() + "].");
}
/** */
private static void axpy(Double a, double[] x, double[] y) {
f2jBlas.daxpy(x.length, a, x, 1, y, 1);
}
/** */
private static void axpy(Double a, SparseVector x, double[] y) {
int xSize = x.size();
if (a == 1.0) {
int k = 0;
while (k < xSize) {
y[k] += x.getX(k);
k++;
}
}
else {
int k = 0;
while (k < xSize) {
y[k] += a * x.getX(k);
k++;
}
}
}
/**
* Returns dot product of vectors x and y.
*
* @param x Vector x.
* @param y Vector y.
* @return Dot product of x and y.
**/
public static Double dot(Vector x, Vector y) {
return x.dot(y);
}
/**
* Copies Vector x into Vector y. (y = x)
*
* @param x Vector x.
* @param y Vector y.
*/
public void copy(Vector x, Vector y) {
int n = y.size();
if (x.size() != n)
throw new CardinalityException(x.size(), n);
if (y.isArrayBased()) {
double[] yData = y.getStorage().data();
if (x.isArrayBased())
System.arraycopy(x.getStorage().data(), 0, y.getStorage().data(), 0, n);
else {
if (y instanceof SparseVector) {
for (int i = 0; i < n; i++)
yData[i] = x.getX(i);
}
}
}
else
throw new IllegalArgumentException("Vector y must be array based in copy.");
}
/**
* Performs in-place multiplication of vector x by a real scalar a. (x = a * x)
*
* @param a Scalar a.
* @param x Vector x.
**/
public static void scal(Double a, Vector x) {
if (x.isArrayBased())
f2jBlas.dscal(x.size(), a, x.getStorage().data(), 1);
else if (x instanceof SparseVector) {
Set<Integer> indexes = ((SparseVector)x).indexes();
for (Integer i : indexes)
x.compute(i, (ind, v) -> v * a);
}
else
throw new IllegalArgumentException();
}
/**
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
* @param u the upper triangular part of the matrix in a [[DenseVector]](column major)
*/
public static void spr(Double alpha, DenseVector v, DenseVector u) {
nativeBlas.dspr("U", v.size(), alpha, v.getStorage().data(), 1, u.getStorage().data());
}
/** */
public static void spr(Double alpha, SparseVector v, DenseVector u) {
int prevNonDfltInd = 0;
int startInd = 0;
double av;
double[] uData = u.getStorage().data();
for (Integer nonDefaultInd : v.indexes()) {
startInd += (nonDefaultInd - prevNonDfltInd) * (nonDefaultInd + prevNonDfltInd + 1) / 2;
av = alpha * v.get(nonDefaultInd);
for (Integer i : v.indexes())
if (i <= nonDefaultInd)
uData[startInd + i] += av * v.getX(i);
prevNonDfltInd = nonDefaultInd;
}
}
/**
* A := alpha * x * x^T + A.
*
* @param alpha a real scalar that will be multiplied to x * x^T^.
* @param x the vector x that contains the n elements.
* @param a the symmetric matrix A. Size of n x n.
*/
void syr(Double alpha, Vector x, DenseMatrix a) {
int mA = a.rowSize();
int nA = a.columnSize();
if (mA != nA)
throw new NonSquareMatrixException(mA, nA);
if (mA != x.size())
throw new CardinalityException(x.size(), mA);
// TODO: IGNITE-5535, Process DenseLocalOffHeapVector
if (x instanceof DenseVector)
syr(alpha, x, a);
else if (x instanceof SparseVector)
syr(alpha, x, a);
else
throw new IllegalArgumentException("Operation 'syr' does not support vector [class="
+ x.getClass().getName() + "].");
}
/** TODO: IGNTIE-5770, add description for a */
static void syr(Double alpha, DenseVector x, DenseMatrix a) {
int nA = a.rowSize();
int mA = a.columnSize();
nativeBlas.dsyr("U", x.size(), alpha, x.getStorage().data(), 1, a.getStorage().data(), nA);
// Fill lower triangular part of A
int i = 0;
while (i < mA) {
int j = i + 1;
while (j < nA) {
a.setX(j, i, a.getX(i, j));
j++;
}
i++;
}
}
/** */
public static void syr(Double alpha, SparseVector x, DenseMatrix a) {
int mA = a.columnSize();
for (Integer i : x.indexes()) {
double mult = alpha * x.getX(i);
for (Integer j : x.indexes())
a.getStorage().data()[mA * i + j] += mult * x.getX(j);
}
}
/**
* For the moment we have no flags indicating if matrix is transposed or not. Therefore all dgemm parameters for
* transposition are equal to 'N'.
*/
public static void gemm(double alpha, Matrix a, Matrix b, double beta, Matrix c) {
if (alpha == 0.0 && beta == 1.0)
return;
else if (alpha == 0.0)
scal(c, beta);
else {
double[] fA = a.getStorage().data();
double[] fB = b.getStorage().data();
double[] fC = c.getStorage().data();
assert fA != null;
nativeBlas.dgemm("N", "N", a.rowSize(), b.columnSize(), a.columnSize(), alpha, fA,
a.rowSize(), fB, b.rowSize(), beta, fC, c.rowSize());
if (c instanceof SparseMatrix)
MatrixUtil.unflatten(fC, c);
}
}
/**
* y := alpha * A * x + beta * y.
*
* @param alpha Alpha.
* @param a Matrix a.
* @param x Vector x.
* @param beta Beta.
* @param y Vector y.
*/
public static void gemv(double alpha, Matrix a, Vector x, double beta, Vector y) {
checkCardinality(a, x);
if (a.rowSize() != y.size())
throw new CardinalityException(a.columnSize(), y.size());
if (alpha == 0.0 && beta == 1.0)
return;
if (alpha == 0.0) {
scal(y, beta);
return;
}
double[] fA = a.getStorage().data();
double[] fX = x.getStorage().data();
double[] fY = y.getStorage().data();
nativeBlas.dgemv("N", a.rowSize(), a.columnSize(), alpha, fA, a.rowSize(), fX, 1, beta, fY, 1);
if (y instanceof SparseVector)
y.assign(fY);
}
/**
* M := alpha * M.
*
* @param m Matrix M.
* @param alpha Alpha.
*/
private static void scal(Matrix m, double alpha) {
if (alpha != 1.0)
for (int i = 0; i < m.rowSize(); i++)
for (int j = 0; j < m.columnSize(); j++)
m.setX(i, j, m.getX(i, j) * alpha);
}
/**
* v := alpha * v.
*
* @param v Vector v.
* @param alpha Aplha.
*/
private static void scal(Vector v, double alpha) {
if (alpha != 1.0)
for (int i = 0; i < v.size(); i++)
v.compute(i, (ind, val) -> val * alpha);
}
/**
* Checks if Matrix A can be multiplied by vector v, if not CardinalityException is thrown.
*
* @param a Matrix A.
* @param v Vector v.
*/
public static void checkCardinality(Matrix a, Vector v) throws CardinalityException {
if (a.columnSize() != v.size())
throw new CardinalityException(a.columnSize(), v.size());
}
}