/*
 * Copyright 2017, Yahoo! Inc.
 * Licensed under the terms of the Apache License 2.0. See LICENSE file at the project root
 * for terms.
 */

package com.yahoo.sketches.matrix;

import static com.yahoo.sketches.matrix.MatrixPreambleUtil.COMPACT_FLAG_MASK;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractFamilyID;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractFlags;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumColumns;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumColumnsUsed;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumRows;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractNumRowsUsed;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractPreLongs;
import static com.yahoo.sketches.matrix.MatrixPreambleUtil.extractSerVer;

import org.ojalgo.matrix.store.PrimitiveDenseStore;

import com.yahoo.memory.Memory;
import com.yahoo.memory.WritableMemory;
import com.yahoo.sketches.MatrixFamily;
import com.yahoo.sketches.SketchesArgumentException;

/**
 * {@link Matrix} implementation backed by an ojAlgo {@link PrimitiveDenseStore}.
 *
 * <p>Serialized images store elements in column-major order after the preamble. A compact image
 * stores only the upper-left {@code numRowsUsed x numColumnsUsed} sub-matrix.</p>
 */
public final class MatrixImplOjAlgo extends Matrix {
  // Backing ojAlgo store; toByteArray() copies its raw data array directly.
  private PrimitiveDenseStore mtx_;

  /**
   * Creates a zero-filled matrix.
   *
   * @param numRows number of rows
   * @param numCols number of columns
   */
  private MatrixImplOjAlgo(final int numRows, final int numCols) {
    mtx_ = PrimitiveDenseStore.FACTORY.makeZero(numRows, numCols);
    numRows_ = numRows;
    numCols_ = numCols;
  }

  /**
   * Wraps an existing ojAlgo store without copying it.
   *
   * @param mtx the store to wrap; writes through this Matrix modify {@code mtx}
   */
  private MatrixImplOjAlgo(final PrimitiveDenseStore mtx) {
    mtx_ = mtx;
    numRows_ = (int) mtx.countRows();
    numCols_ = (int) mtx.countColumns();
  }

  /**
   * Creates a new zero-filled matrix.
   *
   * @param numRows number of rows
   * @param numCols number of columns
   * @return a new Matrix of the requested size
   */
  static Matrix newInstance(final int numRows, final int numCols) {
    return new MatrixImplOjAlgo(numRows, numCols);
  }

  /**
   * Reconstructs a matrix from an image produced by {@link #toByteArray()} or
   * {@link #toCompactByteArray(int, int)}.
   *
   * <p>For a compact image, the full-size matrix is allocated and only the stored upper-left
   * block is filled; the remaining elements are zero.</p>
   *
   * @param srcMem memory holding the serialized matrix
   * @return the heapified matrix
   * @throws SketchesArgumentException if the image is too small, has the wrong serialization
   *         version or family, or declares inconsistent dimensions
   */
  static Matrix heapifyInstance(final Memory srcMem) {
    final int minBytes = MatrixFamily.MATRIX.getMinPreLongs() * Long.BYTES;
    final long memCapBytes = srcMem.getCapacity();
    if (memCapBytes < minBytes) {
      throw new SketchesArgumentException("Source Memory too small: " + memCapBytes
              + " < " + minBytes);
    }

    final int preLongs = extractPreLongs(srcMem);
    final int serVer = extractSerVer(srcMem);
    final int familyID = extractFamilyID(srcMem);

    if (serVer != 1) {
      throw new SketchesArgumentException("Invalid SerVer reading srcMem. Expected 1, found: "
              + serVer);
    }
    if (familyID != MatrixFamily.MATRIX.getID()) {
      throw new SketchesArgumentException("srcMem does not point to a Matrix");
    }

    final int flags = extractFlags(srcMem);
    final boolean isCompact = (flags & COMPACT_FLAG_MASK) > 0;

    final int numRows = extractNumRows(srcMem);
    final int numCols = extractNumColumns(srcMem);
    if ((numRows < 0) || (numCols < 0)) {
      throw new SketchesArgumentException("Invalid matrix dimensions in srcMem: "
              + numRows + "x" + numCols);
    }

    // A compact image stores only the upper-left block; a full image stores every element.
    final int rowsStored;
    final int colsStored;
    if (isCompact) {
      rowsStored = extractNumRowsUsed(srcMem);
      colsStored = extractNumColumnsUsed(srcMem);
      if ((rowsStored < 0) || (rowsStored > numRows) || (colsStored < 0) || (colsStored > numCols)) {
        throw new SketchesArgumentException("Invalid compact dimensions in srcMem: "
                + rowsStored + "x" + colsStored + " for matrix of size " + numRows + "x" + numCols);
      }
    } else {
      rowsStored = numRows;
      colsStored = numCols;
    }

    // Verify the payload really is present before reading: Memory read bounds checks may be
    // assertion-only, so an unchecked truncated image could read past the region.
    final long dataOffset = (long) preLongs * Long.BYTES;
    final long requiredBytes = dataOffset + ((long) rowsStored * colsStored * Double.BYTES);
    if (memCapBytes < requiredBytes) {
      throw new SketchesArgumentException("Source Memory too small: " + memCapBytes
              + " < " + requiredBytes);
    }

    final MatrixImplOjAlgo matrix = new MatrixImplOjAlgo(numRows, numCols);

    // Elements are stored column-major; long offset avoids overflow for very large images.
    long memOffset = dataOffset;
    for (int c = 0; c < colsStored; ++c) {
      for (int r = 0; r < rowsStored; ++r) {
        matrix.mtx_.set(r, c, srcMem.getDouble(memOffset));
        memOffset += Double.BYTES;
      }
    }

    return matrix;
  }

  /**
   * Wraps an existing ojAlgo store without copying it.
   *
   * @param mtx the store to wrap
   * @return a Matrix view over {@code mtx}
   */
  static Matrix wrap(final PrimitiveDenseStore mtx) {
    return new MatrixImplOjAlgo(mtx);
  }

  /**
   * Returns the underlying ojAlgo store (not a copy).
   *
   * @return the backing {@link PrimitiveDenseStore}
   */
  public Object getRawObject() {
    return mtx_;
  }

  /**
   * Serializes the full matrix: a 2-long preamble followed by all elements.
   *
   * @return the serialized image
   * @throws SketchesArgumentException if the image would not fit in a Java byte array
   */
  @Override
  public byte[] toByteArray() {
    final int preLongs = 2;
    final long numElements = mtx_.count();
    assert numElements == mtx_.countColumns() * mtx_.countRows();

    final int outBytes = checkedByteCount((preLongs * Long.BYTES) + (numElements * Double.BYTES));
    final byte[] outByteArr = new byte[outBytes];
    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
    final Object memObj = memOut.getArray();
    final long memAddr = memOut.getCumulativeOffset(0L);

    MatrixPreambleUtil.insertPreLongs(memObj, memAddr, preLongs);
    MatrixPreambleUtil.insertSerVer(memObj, memAddr, MatrixPreambleUtil.SER_VER);
    MatrixPreambleUtil.insertFamilyID(memObj, memAddr, MatrixFamily.MATRIX.getID());
    MatrixPreambleUtil.insertFlags(memObj, memAddr, 0);
    MatrixPreambleUtil.insertNumRows(memObj, memAddr, (int) mtx_.countRows());
    MatrixPreambleUtil.insertNumColumns(memObj, memAddr, (int) mtx_.countColumns());
    // NOTE(review): bulk copy assumes mtx_.data is column-major (ojAlgo's dense layout), the
    // order heapifyInstance reads back.
    memOut.putDoubleArray(preLongs << 3, mtx_.data, 0, (int) numElements);

    return outByteArr;
  }

  /**
   * Serializes only the upper-left {@code numRows x numCols} block of the matrix.
   *
   * <p>If the requested block covers every element, this delegates to {@link #toByteArray()},
   * whose bulk copy is cheaper.</p>
   *
   * @param numRows number of leading rows to keep, in {@code [0, countRows]}
   * @param numCols number of leading columns to keep, in {@code [0, countColumns]}
   * @return the serialized image
   * @throws SketchesArgumentException if the requested block is negative or exceeds the matrix
   */
  @Override
  public byte[] toCompactByteArray(final int numRows, final int numCols) {
    final long totalRows = mtx_.countRows();
    final long totalCols = mtx_.countColumns();
    if ((numRows < 0) || (numRows > totalRows) || (numCols < 0) || (numCols > totalCols)) {
      throw new SketchesArgumentException("Requested " + numRows + "x" + numCols
              + " exceeds matrix dimensions " + totalRows + "x" + totalCols);
    }

    final int preLongs = 3;

    // for non-compact we can do an array copy, so save as non-compact if using the entire matrix
    final long numElements = (long) numRows * numCols;
    if (numElements >= mtx_.count()) {
      return toByteArray();
    }

    // TODO: consider setting an EMPTY flag when numRows or numCols is 0.

    final int outBytes = checkedByteCount((preLongs * Long.BYTES) + (numElements * Double.BYTES));
    final byte[] outByteArr = new byte[outBytes];
    final WritableMemory memOut = WritableMemory.wrap(outByteArr);
    final Object memObj = memOut.getArray();
    final long memAddr = memOut.getCumulativeOffset(0L);

    MatrixPreambleUtil.insertPreLongs(memObj, memAddr, preLongs);
    MatrixPreambleUtil.insertSerVer(memObj, memAddr, MatrixPreambleUtil.SER_VER);
    MatrixPreambleUtil.insertFamilyID(memObj, memAddr, MatrixFamily.MATRIX.getID());
    MatrixPreambleUtil.insertFlags(memObj, memAddr, COMPACT_FLAG_MASK);
    MatrixPreambleUtil.insertNumRows(memObj, memAddr, (int) totalRows);
    MatrixPreambleUtil.insertNumColumns(memObj, memAddr, (int) totalCols);
    MatrixPreambleUtil.insertNumRowsUsed(memObj, memAddr, numRows);
    MatrixPreambleUtil.insertNumColumnsUsed(memObj, memAddr, numCols);

    // write elements in column-major order
    long offsetBytes = preLongs * Long.BYTES;
    for (int c = 0; c < numCols; ++c) {
      for (int r = 0; r < numRows; ++r) {
        memOut.putDouble(offsetBytes, mtx_.get(r, c));
        offsetBytes += Double.BYTES;
      }
    }

    return outByteArr;
  }

  /**
   * Converts a serialized-image size to a byte-array length, rejecting sizes that would
   * otherwise wrap around when cast to {@code int}.
   *
   * @param numBytes total image size in bytes
   * @return {@code numBytes} as an int
   * @throws SketchesArgumentException if {@code numBytes} exceeds {@link Integer#MAX_VALUE}
   */
  private static int checkedByteCount(final long numBytes) {
    if (numBytes > Integer.MAX_VALUE) {
      throw new SketchesArgumentException("Matrix too large to serialize: " + numBytes
              + " bytes");
    }
    return (int) numBytes;
  }

  @Override
  public double getElement(final int row, final int col) {
    return mtx_.get(row, col);
  }

  /**
   * Returns a copy of one row.
   *
   * @param row row index
   * @return a new array with {@code countColumns} elements
   */
  @Override
  public double[] getRow(final int row) {
    final int cols = (int) mtx_.countColumns();
    final double[] result = new double[cols];
    for (int c = 0; c < cols; ++c) {
      result[c] = mtx_.get(row, c);
    }
    return result;
  }

  /**
   * Returns a copy of one column.
   *
   * @param col column index
   * @return a new array with {@code countRows} elements
   */
  @Override
  public double[] getColumn(final int col) {
    final int rows = (int) mtx_.countRows();
    final double[] result = new double[rows];
    for (int r = 0; r < rows; ++r) {
      result[r] = mtx_.get(r, col);
    }
    return result;
  }

  @Override
  public void setElement(final int row, final int col, final double value) {
    mtx_.set(row, col, value);
  }

  /**
   * Overwrites one row.
   *
   * @param row row index
   * @param values new values; length must equal the number of columns
   * @throws SketchesArgumentException if {@code values} has the wrong length
   */
  @Override
  public void setRow(final int row, final double[] values) {
    final int cols = (int) mtx_.countColumns();
    if (values.length != cols) {
      throw new SketchesArgumentException("Invalid number of elements for row. Expected "
              + cols + ", found " + values.length);
    }

    for (int i = 0; i < cols; ++i) {
      mtx_.set(row, i, values[i]);
    }
  }

  /**
   * Overwrites one column.
   *
   * @param column column index
   * @param values new values; length must equal the number of rows
   * @throws SketchesArgumentException if {@code values} has the wrong length
   */
  @Override
  public void setColumn(final int column, final double[] values) {
    final int rows = (int) mtx_.countRows();
    if (values.length != rows) {
      throw new SketchesArgumentException("Invalid number of elements for column. Expected "
              + rows + ", found " + values.length);
    }

    for (int i = 0; i < rows; ++i) {
      mtx_.set(i, column, values[i]);
    }
  }
}
