// blob: af81132059955c8f6aa0cf31118e18a0a1774813 [file] [log] [blame]
/*
* Copyright 2017, Yahoo! Inc.
* Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root
* for terms.
*/
package com.yahoo.sketches.matrix;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.COMPACT_FLAG_MASK;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractFamilyID;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractFlags;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumColumns;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumColumnsUsed;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumRows;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumRowsUsed;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractPreLongs;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractSerVer;
import org.ojalgo.matrix.store.PrimitiveDenseStore;
import com.yahoo.memory.Memory;
import com.yahoo.memory.WritableMemory;
import com.yahoo.sketches.MatrixFamily;
import com.yahoo.sketches.SketchesArgumentException;
/**
 * A {@link Matrix} implementation backed by an ojAlgo {@link PrimitiveDenseStore}.
 *
 * <p>Serialized images consist of a preamble followed by matrix elements stored as doubles in
 * column-major order. A full image is written with a 2-long preamble and holds every element;
 * a compact image is written with a 3-long preamble that additionally records how many rows
 * and columns were actually written.</p>
 */
public final class MatrixImplOjAlgo extends Matrix {
  /** Preamble size, in longs, written by {@link #toByteArray()}. */
  private static final int FULL_PRE_LONGS = 2;

  /** Preamble size, in longs, written for a compact image. */
  private static final int COMPACT_PRE_LONGS = 3;

  private final PrimitiveDenseStore mtx_;

  /**
   * Creates a zero-filled matrix of the given dimensions.
   *
   * @param numRows number of rows
   * @param numCols number of columns
   */
  private MatrixImplOjAlgo(final int numRows, final int numCols) {
    mtx_ = PrimitiveDenseStore.FACTORY.makeZero(numRows, numCols);
    numRows_ = numRows;
    numCols_ = numCols;
  }

  /**
   * Creates a matrix that uses the given store directly; the store is not copied.
   *
   * @param mtx the store to use as backing data
   */
  private MatrixImplOjAlgo(final PrimitiveDenseStore mtx) {
    mtx_ = mtx;
    numRows_ = (int) mtx.countRows();
    numCols_ = (int) mtx.countColumns();
  }

  /**
   * Returns a new zero-filled matrix.
   *
   * @param numRows number of rows
   * @param numCols number of columns
   * @return the new matrix
   */
  static Matrix newInstance(final int numRows, final int numCols) {
    return new MatrixImplOjAlgo(numRows, numCols);
  }

  /**
   * Reconstructs a matrix from a serialized image.
   *
   * <p>The result always has the full dimensions recorded in the preamble. For a compact image
   * only the leading "used" rows and columns are read from the image; all other elements stay
   * zero.</p>
   *
   * @param srcMem Memory holding an image as produced by {@link #toByteArray()} or
   *     {@link #toCompactByteArray(int, int)}
   * @return the reconstructed matrix
   * @throws SketchesArgumentException if the Memory is too small, the serialization version or
   *     family ID is wrong, or the preamble describes sizes that are negative, mutually
   *     inconsistent, or larger than the Memory can hold
   */
  static Matrix heapifyInstance(final Memory srcMem) {
    final int minPreLongs = MatrixFamily.MATRIX.getMinPreLongs();
    final int minBytes = minPreLongs * Long.BYTES;
    final long memCapBytes = srcMem.getCapacity();
    if (memCapBytes < minBytes) {
      throw new SketchesArgumentException("Source Memory too small: " + memCapBytes
          + " < " + minBytes);
    }

    final int preLongs = extractPreLongs(srcMem);
    final int serVer = extractSerVer(srcMem);
    final int familyID = extractFamilyID(srcMem);

    if (serVer != 1) {
      throw new SketchesArgumentException("Invalid SerVer reading srcMem. Expected 1, found: "
          + serVer);
    }
    if (familyID != MatrixFamily.MATRIX.getID()) {
      throw new SketchesArgumentException("srcMem does not point to a Matrix");
    }

    // The element data starts right after the preamble, so the declared preamble must be at
    // least the family minimum and must fit inside the Memory.
    final long preBytes = (long) preLongs * Long.BYTES;
    if ((preLongs < minPreLongs) || (preBytes > memCapBytes)) {
      throw new SketchesArgumentException("Invalid preamble size reading srcMem: " + preLongs
          + " longs with Memory capacity of " + memCapBytes + " bytes");
    }

    final int flags = extractFlags(srcMem);
    final boolean isCompact = (flags & COMPACT_FLAG_MASK) != 0;

    final int nRows = extractNumRows(srcMem);
    final int nCols = extractNumColumns(srcMem);
    if ((nRows < 0) || (nCols < 0)) {
      throw new SketchesArgumentException("Invalid matrix dimensions reading srcMem: "
          + nRows + " rows, " + nCols + " columns");
    }

    // Number of rows/columns actually present in the image.
    int rowsStored = nRows;
    int colsStored = nCols;
    if (isCompact) {
      rowsStored = extractNumRowsUsed(srcMem);
      colsStored = extractNumColumnsUsed(srcMem);
      if ((rowsStored < 0) || (rowsStored > nRows) || (colsStored < 0) || (colsStored > nCols)) {
        throw new SketchesArgumentException("Invalid used dimensions reading srcMem: "
            + rowsStored + " of " + nRows + " rows, " + colsStored + " of " + nCols + " columns");
      }
    }

    // Verify the data region up front instead of depending on how Memory reacts to
    // out-of-range reads. Dividing the available bytes avoids overflow in the comparison.
    final long numStored = (long) rowsStored * colsStored;
    final long maxStored = (memCapBytes - preBytes) / Double.BYTES;
    if (numStored > maxStored) {
      throw new SketchesArgumentException("Source Memory too small for matrix data: need "
          + numStored + " elements, room for " + maxStored);
    }

    final MatrixImplOjAlgo matrix = new MatrixImplOjAlgo(nRows, nCols);

    // elements are stored in column-major order
    long memOffset = preBytes;
    for (int c = 0; c < colsStored; ++c) {
      for (int r = 0; r < rowsStored; ++r) {
        matrix.mtx_.set(r, c, srcMem.getDouble(memOffset));
        memOffset += Double.BYTES;
      }
    }

    return matrix;
  }

  /**
   * Wraps an existing store as a Matrix. The store is used directly, not copied.
   *
   * @param mtx the store to wrap
   * @return a Matrix backed by {@code mtx}
   */
  static Matrix wrap(final PrimitiveDenseStore mtx) {
    return new MatrixImplOjAlgo(mtx);
  }

  /**
   * Returns the backing {@link PrimitiveDenseStore} itself (not a copy).
   *
   * @return the backing store
   */
  public Object getRawObject() {
    return mtx_;
  }

  /**
   * Serializes the entire matrix: a 2-long preamble followed by the contents of the store's
   * backing {@code data} array.
   *
   * <p>NOTE(review): {@link #heapifyInstance(Memory)} reads elements back in column-major
   * order, so this relies on the store keeping {@code data} in that order.</p>
   *
   * @return the serialized image
   * @throws SketchesArgumentException if the image would not fit in a byte array
   */
  @Override
  public byte[] toByteArray() {
    final long numElements = mtx_.count();
    assert numElements == (mtx_.countColumns() * mtx_.countRows());

    final byte[] outByteArr = new byte[serializedSizeBytes(FULL_PRE_LONGS, numElements)];
    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
    insertCommonPreamble(memOut, FULL_PRE_LONGS, 0);

    memOut.putDoubleArray((long) FULL_PRE_LONGS * Long.BYTES, mtx_.data, 0, (int) numElements);

    return outByteArr;
  }

  /**
   * Serializes only the leading {@code numRows} rows and {@code numCols} columns.
   *
   * <p>If the request covers at least as many elements as the matrix holds, the result is the
   * full image from {@link #toByteArray()}. Otherwise a compact image is written: a 3-long
   * preamble recording both the full and the used dimensions, followed by the selected
   * elements in column-major order.</p>
   *
   * @param numRows number of leading rows to write
   * @param numCols number of leading columns to write
   * @return the serialized image
   * @throws SketchesArgumentException if either count is negative, or if a compact image is
   *     needed and a count exceeds the corresponding matrix dimension
   */
  @Override
  public byte[] toCompactByteArray(final int numRows, final int numCols) {
    if ((numRows < 0) || (numCols < 0)) {
      throw new SketchesArgumentException("Number of rows and columns must be non-negative. "
          + "Found " + numRows + " rows, " + numCols + " columns");
    }

    // for non-compact we can do an array copy, so save as non-compact if using the entire matrix
    final long numElements = (long) numRows * numCols;
    if (numElements >= mtx_.count()) {
      return toByteArray();
    }

    // A sub-block can still be out of range in one dimension (e.g. too many rows, few columns).
    final long maxRows = mtx_.countRows();
    final long maxCols = mtx_.countColumns();
    if ((numRows > maxRows) || (numCols > maxCols)) {
      throw new SketchesArgumentException("Requested " + numRows + " rows, " + numCols
          + " columns from a matrix with " + maxRows + " rows, " + maxCols + " columns");
    }

    // TODO: consider also setting an empty flag when numRows or numCols is 0

    final byte[] outByteArr = new byte[serializedSizeBytes(COMPACT_PRE_LONGS, numElements)];
    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
    insertCommonPreamble(memOut, COMPACT_PRE_LONGS, COMPACT_FLAG_MASK);

    final Object memObj = memOut.getArray();
    final long memAddr = memOut.getCumulativeOffset(0L);
    MatrixPreambleUtil.insertNumRowsUsed(memObj, memAddr, numRows);
    MatrixPreambleUtil.insertNumColumnsUsed(memObj, memAddr, numCols);

    // write elements in column-major order
    long offsetBytes = (long) COMPACT_PRE_LONGS * Long.BYTES;
    for (int c = 0; c < numCols; ++c) {
      for (int r = 0; r < numRows; ++r) {
        memOut.putDouble(offsetBytes, mtx_.get(r, c));
        offsetBytes += Double.BYTES;
      }
    }

    return outByteArr;
  }

  /**
   * Returns a single element.
   *
   * @param row row index
   * @param col column index
   * @return the value stored at ({@code row}, {@code col})
   */
  @Override
  public double getElement(final int row, final int col) {
    return mtx_.get(row, col);
  }

  /**
   * Returns a copy of one row.
   *
   * @param row row index
   * @return a new array with one entry per column
   */
  @Override
  public double[] getRow(final int row) {
    final int cols = (int) mtx_.countColumns();
    final double[] result = new double[cols];
    for (int c = 0; c < cols; ++c) {
      result[c] = mtx_.get(row, c);
    }
    return result;
  }

  /**
   * Returns a copy of one column.
   *
   * @param col column index
   * @return a new array with one entry per row
   */
  @Override
  public double[] getColumn(final int col) {
    final int rows = (int) mtx_.countRows();
    final double[] result = new double[rows];
    for (int r = 0; r < rows; ++r) {
      result[r] = mtx_.get(r, col);
    }
    return result;
  }

  /**
   * Sets a single element.
   *
   * @param row row index
   * @param col column index
   * @param value the value to store
   */
  @Override
  public void setElement(final int row, final int col, final double value) {
    mtx_.set(row, col, value);
  }

  /**
   * Overwrites one row.
   *
   * @param row row index
   * @param values new row contents; length must equal the number of columns
   * @throws SketchesArgumentException if {@code values} has the wrong length
   */
  @Override
  public void setRow(final int row, final double[] values) {
    final long numCols = mtx_.countColumns();
    if (values.length != numCols) {
      throw new SketchesArgumentException("Invalid number of elements for row. Expected "
          + numCols + ", found " + values.length);
    }

    for (int i = 0; i < numCols; ++i) {
      mtx_.set(row, i, values[i]);
    }
  }

  /**
   * Overwrites one column.
   *
   * @param column column index
   * @param values new column contents; length must equal the number of rows
   * @throws SketchesArgumentException if {@code values} has the wrong length
   */
  @Override
  public void setColumn(final int column, final double[] values) {
    final long numRows = mtx_.countRows();
    if (values.length != numRows) {
      throw new SketchesArgumentException("Invalid number of elements for column. Expected "
          + numRows + ", found " + values.length);
    }

    for (int i = 0; i < numRows; ++i) {
      mtx_.set(i, column, values[i]);
    }
  }

  /**
   * Computes the size of a serialized image, rejecting sizes a byte array cannot represent
   * rather than letting an int cast wrap around.
   */
  private static int serializedSizeBytes(final int preLongs, final long numElements) {
    final long totalBytes = ((long) preLongs * Long.BYTES) + (numElements * Double.BYTES);
    if (totalBytes > Integer.MAX_VALUE) {
      throw new SketchesArgumentException("Matrix too large to serialize to a byte array: "
          + totalBytes + " bytes");
    }
    return (int) totalBytes;
  }

  /**
   * Writes the preamble fields shared by full and compact images: preamble size, serialization
   * version, family ID, flags, and the full matrix dimensions.
   */
  private void insertCommonPreamble(final WritableMemory memOut, final int preLongs,
      final int flags) {
    final Object memObj = memOut.getArray();
    final long memAddr = memOut.getCumulativeOffset(0L);

    MatrixPreambleUtil.insertPreLongs(memObj, memAddr, preLongs);
    MatrixPreambleUtil.insertSerVer(memObj, memAddr, MatrixPreambleUtil.SER_VER);
    MatrixPreambleUtil.insertFamilyID(memObj, memAddr, MatrixFamily.MATRIX.getID());
    MatrixPreambleUtil.insertFlags(memObj, memAddr, flags);
    MatrixPreambleUtil.insertNumRows(memObj, memAddr, (int) mtx_.countRows());
    MatrixPreambleUtil.insertNumColumns(memObj, memAddr, (int) mtx_.countColumns());
  }
}