// blob: 7db1bb9111a00360cd34f8ed9c0da435bbf6c948 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.linalg;
import com.github.fommil.netlib.F2jBLAS;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
public class VectorizedBLAS extends F2jBLAS {
// Maximal-shape float species (FloatVector.SPECIES_MAX); shared by the single-precision kernels.
private static final VectorSpecies<Float> FMAX = FloatVector.SPECIES_MAX;
// Maximal-shape double species (DoubleVector.SPECIES_MAX); shared by the double-precision kernels.
private static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;
/**
 * Computes {@code y += alpha * x} over the first {@code n} elements.
 *
 * <p>The vectorized path is taken only for unit strides and non-null arrays holding at least
 * {@code n} elements; every other call is forwarded to the parent implementation.
 */
@Override
public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy) {
  final boolean fastPath = n >= 0
      && x != null && x.length >= n && incx == 1
      && y != null && y.length >= n && incy == 1;
  if (!fastPath) {
    super.daxpy(n, alpha, x, incx, y, incy);
    return;
  }
  if (alpha == 0.) {
    // Adding zero times x leaves y untouched.
    return;
  }
  final int lanes = DMAX.length();
  final int vectorEnd = DMAX.loopBound(n);
  final DoubleVector scale = DoubleVector.broadcast(DMAX, alpha);
  int idx = 0;
  // Full vectors: y[idx..] = x[idx..] * alpha + y[idx..]
  while (idx < vectorEnd) {
    DoubleVector xs = DoubleVector.fromArray(DMAX, x, idx);
    DoubleVector ys = DoubleVector.fromArray(DMAX, y, idx);
    xs.fma(scale, ys).intoArray(y, idx);
    idx += lanes;
  }
  // Scalar tail for the elements that do not fill a whole vector.
  while (idx < n) {
    y[idx] += alpha * x[idx];
    idx++;
  }
}
/**
 * Returns the single-precision dot product {@code sum(x[i] * y[i])} for {@code i < n}.
 *
 * <p>The vectorized path is taken only for unit strides and non-null arrays holding at least
 * {@code n} elements; every other call is forwarded to the parent implementation.
 */
@Override
public float sdot(int n, float[] x, int incx, float[] y, int incy) {
  final boolean fastPath = n >= 0
      && x != null && x.length >= n && incx == 1
      && y != null && y.length >= n && incy == 1;
  if (!fastPath) {
    return super.sdot(n, x, incx, y, incy);
  }
  final int lanes = FMAX.length();
  final int vectorEnd = FMAX.loopBound(n);
  FloatVector partials = FloatVector.zero(FMAX);
  int idx = 0;
  // Accumulate lane-wise partial sums with fused multiply-add.
  while (idx < vectorEnd) {
    FloatVector xs = FloatVector.fromArray(FMAX, x, idx);
    FloatVector ys = FloatVector.fromArray(FMAX, y, idx);
    partials = xs.fma(ys, partials);
    idx += lanes;
  }
  // Collapse the lanes, then finish the remaining elements one by one.
  float total = 0.0f + partials.reduceLanes(VectorOperators.ADD);
  while (idx < n) {
    total += x[idx] * y[idx];
    idx++;
  }
  return total;
}
/**
 * Returns the double-precision dot product {@code sum(x[i] * y[i])} for {@code i < n}.
 *
 * <p>The vectorized path is taken only for unit strides and non-null arrays holding at least
 * {@code n} elements; every other call is forwarded to the parent implementation.
 */
@Override
public double ddot(int n, double[] x, int incx, double[] y, int incy) {
  final boolean fastPath = n >= 0
      && x != null && x.length >= n && incx == 1
      && y != null && y.length >= n && incy == 1;
  if (!fastPath) {
    return super.ddot(n, x, incx, y, incy);
  }
  final int lanes = DMAX.length();
  final int vectorEnd = DMAX.loopBound(n);
  DoubleVector partials = DoubleVector.zero(DMAX);
  int idx = 0;
  // Accumulate lane-wise partial sums with fused multiply-add.
  while (idx < vectorEnd) {
    DoubleVector xs = DoubleVector.fromArray(DMAX, x, idx);
    DoubleVector ys = DoubleVector.fromArray(DMAX, y, idx);
    partials = xs.fma(ys, partials);
    idx += lanes;
  }
  // Collapse the lanes, then finish the remaining elements one by one.
  double total = 0. + partials.reduceLanes(VectorOperators.ADD);
  while (idx < n) {
    total += x[idx] * y[idx];
    idx++;
  }
  return total;
}
/**
 * Offset-free overload: forwards to the five-argument {@code dscal} with a zero start offset.
 */
@Override
public void dscal(int n, double alpha, double[] x, int incx) {
  final int start = 0;
  this.dscal(n, alpha, x, start, incx);
}
/**
 * Computes {@code x = alpha * x} over the {@code n} elements starting at {@code offsetx}.
 *
 * <p>The vectorized path is taken only for a unit stride and a non-null array with
 * {@code x.length >= offsetx + n}; every other call is forwarded to the parent implementation.
 */
@Override
public void dscal(int n, double alpha, double[] x, int offsetx, int incx) {
  final boolean fastPath = n >= 0 && x != null && x.length >= offsetx + n && incx == 1;
  if (!fastPath) {
    super.dscal(n, alpha, x, offsetx, incx);
    return;
  }
  if (alpha == 1.) {
    // Scaling by one changes nothing.
    return;
  }
  final int lanes = DMAX.length();
  final int vectorEnd = DMAX.loopBound(n);
  final DoubleVector scale = DoubleVector.broadcast(DMAX, alpha);
  int done = 0;
  while (done < vectorEnd) {
    final int pos = offsetx + done;
    DoubleVector.fromArray(DMAX, x, pos).mul(scale).intoArray(x, pos);
    done += lanes;
  }
  // Scalar tail for the elements that do not fill a whole vector.
  while (done < n) {
    x[offsetx + done] *= alpha;
    done++;
  }
}
/**
 * Offset-free overload: forwards to the five-argument {@code sscal} with a zero start offset.
 */
@Override
public void sscal(int n, float alpha, float[] x, int incx) {
  final int start = 0;
  this.sscal(n, alpha, x, start, incx);
}
/**
 * Computes {@code x = alpha * x} over the {@code n} elements starting at {@code offsetx}.
 *
 * <p>The vectorized path is taken only for a unit stride and a non-null array with
 * {@code x.length >= offsetx + n}; every other call is forwarded to the parent implementation.
 */
@Override
public void sscal(int n, float alpha, float[] x, int offsetx, int incx) {
  final boolean fastPath = n >= 0 && x != null && x.length >= offsetx + n && incx == 1;
  if (!fastPath) {
    super.sscal(n, alpha, x, offsetx, incx);
    return;
  }
  if (alpha == 1.) {
    // Scaling by one changes nothing.
    return;
  }
  final int lanes = FMAX.length();
  final int vectorEnd = FMAX.loopBound(n);
  final FloatVector scale = FloatVector.broadcast(FMAX, alpha);
  int done = 0;
  while (done < vectorEnd) {
    final int pos = offsetx + done;
    FloatVector.fromArray(FMAX, x, pos).mul(scale).intoArray(x, pos);
    done += lanes;
  }
  // Scalar tail for the elements that do not fill a whole vector.
  while (done < n) {
    x[offsetx + done] *= alpha;
    done++;
  }
}
/**
 * Computes {@code y = alpha * A * x + beta * y} for a symmetric {@code n x n} matrix {@code A}
 * whose upper triangle is stored column by column in the packed array {@code a}.
 *
 * <p>The vectorized path handles only {@code uplo == "U"}, unit strides and non-null arrays of
 * sufficient length; every other call is forwarded to the parent implementation.
 *
 * <p>When {@code beta} is zero, {@code y} is overwritten with zeros instead of being multiplied,
 * so NaN or infinite values already present in {@code y} do not leak into the result. This
 * follows the reference BLAS contract, under which {@code y} need not be set on input in that
 * case.
 */
@Override
public void dspmv(String uplo, int n, double alpha, double[] a,
    double[] x, int incx, double beta, double[] y, int incy) {
  final boolean fastPath = "U".equals(uplo)
      && n >= 0
      && a != null && a.length >= n * (n + 1) / 2
      && x != null && x.length >= n && incx == 1
      && y != null && y.length >= n && incy == 1;
  if (!fastPath) {
    super.dspmv(uplo, n, alpha, a, x, incx, beta, y, incy);
    return;
  }
  // y = beta * y
  if (beta == 0.) {
    java.util.Arrays.fill(y, 0, n, 0.);
  } else {
    dscal(n, beta, y, 1);
  }
  if (alpha == 0.) {
    return;
  }
  // y += alpha * A * x
  final int lanes = DMAX.length();
  final DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha);
  for (int row = 0; row < n; row += 1) {
    // Start of packed column `row`; it holds the entries (0..row, row) of the upper triangle.
    final int colStart = row * (row + 1) / 2;
    final double alphaXrow = alpha * x[row];
    final DoubleVector valphaxrow = DoubleVector.broadcast(DMAX, alphaXrow);
    final int vectorEnd = DMAX.loopBound(row);
    DoubleVector vyrowsum = DoubleVector.zero(DMAX);
    int col = 0;
    // Each off-diagonal entry contributes twice: once to y[row] and, by symmetry, to y[col].
    for (; col < vectorEnd; col += lanes) {
      DoubleVector vx = DoubleVector.fromArray(DMAX, x, col);
      DoubleVector vy = DoubleVector.fromArray(DMAX, y, col);
      DoubleVector va = DoubleVector.fromArray(DMAX, a, colStart + col);
      vyrowsum = valpha.mul(vx).fma(va, vyrowsum);
      valphaxrow.fma(va, vy).intoArray(y, col);
    }
    y[row] += vyrowsum.reduceLanes(VectorOperators.ADD);
    for (; col < row; col += 1) {
      final double entry = a[colStart + col];
      y[row] += alpha * x[col] * entry;
      y[col] += alphaXrow * entry;
    }
    // Diagonal entry (col == row here) contributes only once.
    y[row] += alphaXrow * a[colStart + row];
  }
}
/**
 * Symmetric rank-1 update {@code A += alpha * x * x^T} on the packed upper triangle {@code a}.
 *
 * <p>The vectorized path handles only {@code uplo == "U"}, a unit stride and non-null arrays of
 * sufficient length; every other call is forwarded to the parent implementation.
 */
@Override
public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a) {
  final boolean fastPath = "U".equals(uplo)
      && n >= 0
      && x != null && x.length >= n && incx == 1
      && a != null && a.length >= n * (n + 1) / 2;
  if (!fastPath) {
    super.dspr(uplo, n, alpha, x, incx, a);
    return;
  }
  if (alpha == 0.) {
    // A zero-weighted update leaves the matrix untouched.
    return;
  }
  final int lanes = DMAX.length();
  for (int j = 0; j < n; j++) {
    // Packed column j starts here and holds j + 1 entries (diagonal included).
    final int base = j * (j + 1) / 2;
    final int count = j + 1;
    final int vectorEnd = DMAX.loopBound(count);
    final DoubleVector scaledXj = DoubleVector.broadcast(DMAX, alpha * x[j]);
    int i = 0;
    while (i < vectorEnd) {
      DoubleVector xs = DoubleVector.fromArray(DMAX, x, i);
      DoubleVector as = DoubleVector.fromArray(DMAX, a, i + base);
      xs.fma(scaledXj, as).intoArray(a, i + base);
      i += lanes;
    }
    // Scalar tail of the column.
    while (i < count) {
      a[i + base] += alpha * x[j] * x[i];
      i++;
    }
  }
}
/**
 * Symmetric rank-1 update {@code A += alpha * x * x^T} on the upper triangle of the dense
 * matrix {@code a}.
 *
 * <p>The vectorized path handles only {@code uplo == "U"}, a unit stride, {@code lda == n} and
 * non-null arrays of sufficient length; every other call is forwarded to the parent
 * implementation.
 */
@Override
public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda) {
  final boolean fastPath = "U".equals(uplo)
      && n >= 0
      && x != null && x.length >= n && incx == 1
      && a != null && a.length >= n * n && lda == n;
  if (!fastPath) {
    super.dsyr(uplo, n, alpha, x, incx, a, lda);
    return;
  }
  if (alpha == 0.) {
    // A zero-weighted update leaves the matrix untouched.
    return;
  }
  final int lanes = DMAX.length();
  for (int j = 0; j < n; j++) {
    // Column j starts at j * n; only its first j + 1 entries are updated.
    final int base = j * n;
    final int count = j + 1;
    final int vectorEnd = DMAX.loopBound(count);
    final DoubleVector scaledXj = DoubleVector.broadcast(DMAX, alpha * x[j]);
    int i = 0;
    while (i < vectorEnd) {
      DoubleVector xs = DoubleVector.fromArray(DMAX, x, i);
      DoubleVector as = DoubleVector.fromArray(DMAX, a, i + base);
      xs.fma(scaledXj, as).intoArray(a, i + base);
      i += lanes;
    }
    // Scalar tail of the column.
    while (i < count) {
      a[i + base] += alpha * x[j] * x[i];
      i++;
    }
  }
}
/**
 * Offset-free overload: forwards to the offset-aware {@code dgemv} with all three array offsets
 * set to zero.
 */
@Override
public void dgemv(String trans, int m, int n,
    double alpha, double[] a, int lda, double[] x, int incx,
    double beta, double[] y, int incy) {
  final int noOffset = 0;
  this.dgemv(trans, m, n, alpha, a, noOffset, lda, x, noOffset, incx, beta, y, noOffset, incy);
}
/**
 * Computes {@code y = alpha * op(A) * x + beta * y}, where {@code op(A)} is {@code A} for
 * {@code trans == "N"} and {@code A^T} for {@code trans == "T"}.
 *
 * <p>{@code A} is read as an {@code m x n} column-major matrix starting at {@code offseta}. The
 * vectorized paths require {@code lda == m}, unit strides and non-null arrays of sufficient
 * length; every other call is forwarded to the parent implementation.
 *
 * <p>Following the reference BLAS contract, {@code y} is overwritten (not multiplied) when
 * {@code beta} is zero, and {@code A} and {@code x} are not read when {@code alpha} is zero, so
 * NaN or infinite values in operands that carry zero weight do not leak into the result.
 */
@Override
public void dgemv(String trans, int m, int n,
    double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx,
    double beta, double[] y, int offsety, int incy) {
  if ("N".equals(trans)
      && m >= 0 && n >= 0
      && a != null && a.length >= offseta + m * n && lda == m
      && x != null && x.length >= offsetx + n && incx == 1
      && y != null && y.length >= offsety + m && incy == 1) {
    // y = beta * y
    if (beta == 0.) {
      java.util.Arrays.fill(y, offsety, offsety + m, 0.);
    } else {
      dscal(m, beta, y, offsety, 1);
    }
    // y += alpha * A * x, one column of A at a time
    if (alpha != 0.) {
      final int lanes = DMAX.length();
      final int rowBound = DMAX.loopBound(m);
      for (int col = 0; col < n; col += 1) {
        final int colStart = offseta + col * m;
        // The weight of this column is constant across rows: broadcast it once.
        final double alphaX = alpha * x[offsetx + col];
        final DoubleVector valphax = DoubleVector.broadcast(DMAX, alphaX);
        int row = 0;
        for (; row < rowBound; row += lanes) {
          DoubleVector va = DoubleVector.fromArray(DMAX, a, colStart + row);
          DoubleVector vy = DoubleVector.fromArray(DMAX, y, offsety + row);
          valphax.fma(va, vy).intoArray(y, offsety + row);
        }
        for (; row < m; row += 1) {
          y[offsety + row] += alphaX * a[colStart + row];
        }
      }
    }
  } else if ("T".equals(trans)
      && m >= 0 && n >= 0
      && a != null && a.length >= offseta + m * n && lda == m
      && x != null && x.length >= offsetx + m && incx == 1
      && y != null && y.length >= offsety + n && incy == 1) {
    if (alpha == 0.) {
      // Only the beta * y term remains; A and x are not read.
      if (beta == 0.) {
        java.util.Arrays.fill(y, offsety, offsety + n, 0.);
      } else {
        dscal(n, beta, y, offsety, 1);
      }
    } else {
      final int lanes = DMAX.length();
      final int rowBound = DMAX.loopBound(m);
      for (int col = 0; col < n; col += 1) {
        final int colStart = offseta + col * m;
        // Dot product of column `col` of A with x.
        double sum = 0.;
        int row = 0;
        DoubleVector vsum = DoubleVector.zero(DMAX);
        for (; row < rowBound; row += lanes) {
          DoubleVector va = DoubleVector.fromArray(DMAX, a, colStart + row);
          DoubleVector vx = DoubleVector.fromArray(DMAX, x, offsetx + row);
          vsum = va.fma(vx, vsum);
        }
        sum += vsum.reduceLanes(VectorOperators.ADD);
        for (; row < m; row += 1) {
          sum += x[offsetx + row] * a[colStart + row];
        }
        final int yi = offsety + col;
        if (beta != 0.) {
          y[yi] = alpha * sum + beta * y[yi];
        } else {
          // Do not read y when beta is zero.
          y[yi] = alpha * sum;
        }
      }
    }
  } else {
    super.dgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
  }
}
/**
 * Offset-free overload: forwards to the offset-aware {@code sgemv} with all three array offsets
 * set to zero.
 */
@Override
public void sgemv(String trans, int m, int n,
    float alpha, float[] a, int lda, float[] x, int incx,
    float beta, float[] y, int incy) {
  final int noOffset = 0;
  this.sgemv(trans, m, n, alpha, a, noOffset, lda, x, noOffset, incx, beta, y, noOffset, incy);
}
/**
 * Computes {@code y = alpha * op(A) * x + beta * y} in single precision, where {@code op(A)} is
 * {@code A} for {@code trans == "N"} and {@code A^T} for {@code trans == "T"}.
 *
 * <p>{@code A} is read as an {@code m x n} column-major matrix starting at {@code offseta}. The
 * vectorized paths require {@code lda == m}, unit strides and non-null arrays of sufficient
 * length; every other call is forwarded to the parent implementation.
 *
 * <p>Following the reference BLAS contract, {@code y} is overwritten (not multiplied) when
 * {@code beta} is zero, and {@code A} and {@code x} are not read when {@code alpha} is zero, so
 * NaN or infinite values in operands that carry zero weight do not leak into the result.
 */
@Override
public void sgemv(String trans, int m, int n,
    float alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx,
    float beta, float[] y, int offsety, int incy) {
  if ("N".equals(trans)
      && m >= 0 && n >= 0
      && a != null && a.length >= offseta + m * n && lda == m
      && x != null && x.length >= offsetx + n && incx == 1
      && y != null && y.length >= offsety + m && incy == 1) {
    // y = beta * y
    if (beta == 0.f) {
      java.util.Arrays.fill(y, offsety, offsety + m, 0.f);
    } else {
      sscal(m, beta, y, offsety, 1);
    }
    // y += alpha * A * x, one column of A at a time
    if (alpha != 0.f) {
      final int lanes = FMAX.length();
      final int rowBound = FMAX.loopBound(m);
      for (int col = 0; col < n; col += 1) {
        final int colStart = offseta + col * m;
        // The weight of this column is constant across rows: broadcast it once.
        final float alphaX = alpha * x[offsetx + col];
        final FloatVector valphax = FloatVector.broadcast(FMAX, alphaX);
        int row = 0;
        for (; row < rowBound; row += lanes) {
          FloatVector va = FloatVector.fromArray(FMAX, a, colStart + row);
          FloatVector vy = FloatVector.fromArray(FMAX, y, offsety + row);
          valphax.fma(va, vy).intoArray(y, offsety + row);
        }
        for (; row < m; row += 1) {
          y[offsety + row] += alphaX * a[colStart + row];
        }
      }
    }
  } else if ("T".equals(trans)
      && m >= 0 && n >= 0
      && a != null && a.length >= offseta + m * n && lda == m
      && x != null && x.length >= offsetx + m && incx == 1
      && y != null && y.length >= offsety + n && incy == 1) {
    if (alpha == 0.f) {
      // Only the beta * y term remains; A and x are not read.
      if (beta == 0.f) {
        java.util.Arrays.fill(y, offsety, offsety + n, 0.f);
      } else {
        sscal(n, beta, y, offsety, 1);
      }
    } else {
      final int lanes = FMAX.length();
      final int rowBound = FMAX.loopBound(m);
      for (int col = 0; col < n; col += 1) {
        final int colStart = offseta + col * m;
        // Dot product of column `col` of A with x.
        float sum = 0.f;
        int row = 0;
        FloatVector vsum = FloatVector.zero(FMAX);
        for (; row < rowBound; row += lanes) {
          FloatVector va = FloatVector.fromArray(FMAX, a, colStart + row);
          FloatVector vx = FloatVector.fromArray(FMAX, x, offsetx + row);
          vsum = va.fma(vx, vsum);
        }
        sum += vsum.reduceLanes(VectorOperators.ADD);
        for (; row < m; row += 1) {
          sum += x[offsetx + row] * a[colStart + row];
        }
        final int yi = offsety + col;
        if (beta != 0.f) {
          y[yi] = alpha * sum + beta * y[yi];
        } else {
          // Do not read y when beta is zero.
          y[yi] = alpha * sum;
        }
      }
    }
  } else {
    super.sgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy);
  }
}
/**
 * Offset-free overload: forwards to the offset-aware {@code dgemm} with all three array offsets
 * set to zero.
 */
@Override
public void dgemm(String transa, String transb, int m, int n, int k,
    double alpha, double[] a, int lda, double[] b, int ldb,
    double beta, double[] c, int ldc) {
  final int noOffset = 0;
  this.dgemm(transa, transb, m, n, k,
      alpha, a, noOffset, lda, b, noOffset, ldb,
      beta, c, noOffset, ldc);
}
/**
 * Computes {@code C = alpha * op(A) * op(B) + beta * C}, where {@code op(X)} is {@code X} for
 * {@code "N"} and {@code X^T} for {@code "T"}; {@code C} is {@code m x n} and the shared inner
 * dimension is {@code k}.
 *
 * <p>All matrices are read column-major from their offsets. The in-class paths require tightly
 * packed operands ({@code ldc == m}; {@code lda}/{@code ldb} equal to the stored row count of
 * the respective operand) and non-null arrays of sufficient length; every other call is
 * forwarded to the parent implementation.
 *
 * <p>Following the reference BLAS contract, {@code C} is overwritten (not multiplied) when
 * {@code beta} is zero, and {@code A} and {@code B} are not read when {@code alpha} is zero, so
 * NaN or infinite values in operands that carry zero weight do not leak into the result.
 */
@Override
public void dgemm(String transa, String transb, int m, int n, int k,
    double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb,
    double beta, double[] c, int offsetc, int ldc) {
  if ("N".equals(transa) && "N".equals(transb)
      && m >= 0 && n >= 0 && k >= 0
      && a != null && a.length >= offseta + m * k && lda == m
      && b != null && b.length >= offsetb + k * n && ldb == k
      && c != null && c.length >= offsetc + m * n && ldc == m) {
    // C = beta * C
    dgemmScaleC(m * n, beta, c, offsetc);
    // C += alpha * A * B
    if (alpha != 0.) {
      final int lanes = DMAX.length();
      final int rowBound = DMAX.loopBound(m);
      for (int col = 0; col < n; col += 1) {
        final int cStart = offsetc + col * m;
        for (int i = 0; i < k; i += 1) {
          final int aStart = offseta + i * m;
          // B(i, col) weights the whole column i of A: broadcast it once.
          final double bElem = b[offsetb + col * k + i];
          final DoubleVector valphab = DoubleVector.broadcast(DMAX, alpha * bElem);
          int row = 0;
          for (; row < rowBound; row += lanes) {
            DoubleVector va = DoubleVector.fromArray(DMAX, a, aStart + row);
            DoubleVector vc = DoubleVector.fromArray(DMAX, c, cStart + row);
            valphab.fma(va, vc).intoArray(c, cStart + row);
          }
          for (; row < m; row += 1) {
            c[cStart + row] += alpha * a[aStart + row] * bElem;
          }
        }
      }
    }
  } else if ("N".equals(transa) && "T".equals(transb)
      && m >= 0 && n >= 0 && k >= 0
      && a != null && a.length >= offseta + m * k && lda == m
      && b != null && b.length >= offsetb + k * n && ldb == n
      && c != null && c.length >= offsetc + m * n && ldc == m) {
    // C = beta * C
    dgemmScaleC(m * n, beta, c, offsetc);
    // C += alpha * A * B^T
    if (alpha != 0.) {
      final int lanes = DMAX.length();
      final int rowBound = DMAX.loopBound(m);
      for (int i = 0; i < k; i += 1) {
        final int aStart = offseta + i * m;
        for (int col = 0; col < n; col += 1) {
          final int cStart = offsetc + col * m;
          // B is stored n x k, so B^T(i, col) lives at col + i * n.
          final double bElem = b[offsetb + col + i * n];
          final DoubleVector valphab = DoubleVector.broadcast(DMAX, alpha * bElem);
          int row = 0;
          for (; row < rowBound; row += lanes) {
            DoubleVector va = DoubleVector.fromArray(DMAX, a, aStart + row);
            DoubleVector vc = DoubleVector.fromArray(DMAX, c, cStart + row);
            valphab.fma(va, vc).intoArray(c, cStart + row);
          }
          for (; row < m; row += 1) {
            c[cStart + row] += alpha * a[aStart + row] * bElem;
          }
        }
      }
    }
  } else if ("T".equals(transa) && "N".equals(transb)
      && m >= 0 && n >= 0 && k >= 0
      && a != null && a.length >= offseta + m * k && lda == k
      && b != null && b.length >= offsetb + k * n && ldb == k
      && c != null && c.length >= offsetc + m * n && ldc == m) {
    if (alpha == 0.) {
      // Only the beta * C term remains; A and B are not read.
      dgemmScaleC(m * n, beta, c, offsetc);
    } else {
      final int lanes = DMAX.length();
      final int innerBound = DMAX.loopBound(k);
      for (int col = 0; col < n; col += 1) {
        final int bStart = offsetb + col * k;
        for (int row = 0; row < m; row += 1) {
          // A is stored k x m, so row `row` of A^T is the contiguous column `row` of A.
          final int aStart = offseta + row * k;
          double sum = 0.;
          int i = 0;
          DoubleVector vsum = DoubleVector.zero(DMAX);
          for (; i < innerBound; i += lanes) {
            DoubleVector va = DoubleVector.fromArray(DMAX, a, aStart + i);
            DoubleVector vb = DoubleVector.fromArray(DMAX, b, bStart + i);
            vsum = va.fma(vb, vsum);
          }
          sum += vsum.reduceLanes(VectorOperators.ADD);
          for (; i < k; i += 1) {
            sum += a[aStart + i] * b[bStart + i];
          }
          final int ci = offsetc + col * m + row;
          if (beta != 0.) {
            c[ci] = alpha * sum + beta * c[ci];
          } else {
            // Do not read C when beta is zero.
            c[ci] = alpha * sum;
          }
        }
      }
    }
  } else if ("T".equals(transa) && "T".equals(transb)
      && m >= 0 && n >= 0 && k >= 0
      && a != null && a.length >= offseta + m * k && lda == k
      && b != null && b.length >= offsetb + k * n && ldb == n
      && c != null && c.length >= offsetc + m * n && ldc == m) {
    if (alpha == 0.) {
      // Only the beta * C term remains; A and B are not read.
      dgemmScaleC(m * n, beta, c, offsetc);
    } else {
      // FIXME: do block by block
      for (int col = 0; col < n; col += 1) {
        for (int row = 0; row < m; row += 1) {
          final int aStart = offseta + row * k;
          double sum = 0.;
          // B^T is strided here (step n), so this case stays scalar.
          for (int i = 0; i < k; i += 1) {
            sum += a[aStart + i] * b[offsetb + col + i * n];
          }
          final int ci = offsetc + col * m + row;
          if (beta != 0.) {
            c[ci] = alpha * sum + beta * c[ci];
          } else {
            // Do not read C when beta is zero.
            c[ci] = alpha * sum;
          }
        }
      }
    }
  } else {
    super.dgemm(transa, transb, m, n, k,
        alpha, a, offseta, lda, b, offsetb, ldb,
        beta, c, offsetc, ldc);
  }
}

/**
 * Applies {@code c[offsetc .. offsetc + len) = beta * c}, storing exact zeros when {@code beta}
 * is zero so that NaN or infinite values already in {@code c} are discarded rather than
 * multiplied.
 */
private void dgemmScaleC(int len, double beta, double[] c, int offsetc) {
  if (beta == 0.) {
    java.util.Arrays.fill(c, offsetc, offsetc + len, 0.);
  } else {
    dscal(len, beta, c, offsetc, 1);
  }
}
}