// blob: 63898062efdb5003aa1f51b6bc904e6be4548fa2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.scalabindings
import org.apache.mahout.math._
import org.apache.mahout.math.flavor.{BackEnum, TraversingStructureEnum}
import org.apache.mahout.math.function.Functions
import RLikeOps._
import org.apache.mahout.logging._
import scala.collection.JavaConversions._
object MMul extends MMBinaryFunc {
// Logger bound to this object. It is declared implicit; presumably the logging helpers used
// below (e.g. `debug`) pick it up implicitly -- confirm against org.apache.mahout.logging.
private final implicit val log = getLog(MMul.getClass)
/**
 * Multiplies `a` by `b`, selecting a kernel from the operands' flavors.
 *
 * The kernel is chosen from the backing of both operands, their traversal structure and the
 * result of `sparsityAnalysis` on each of them. Judging by the case groupings below, a `true`
 * flag marks an operand that is treated as dense and `false` one that is treated as sparse.
 *
 * Only the JVMMEM x JVMMEM backing combination is handled here; any other combination falls
 * out of the outer match and raises `scala.MatchError`.
 *
 * @param a left operand
 * @param b right operand; `b.nrow` must equal `a.ncol`
 * @param r optional result matrix, handed through unchanged to the selected kernel
 * @return the value produced by the selected kernel for `(a, b, r)`
 * @throws IllegalArgumentException if `a.ncol != b.nrow`
 */
override def apply(a: Matrix, b: Matrix, r: Option[Matrix]): Matrix = {
  require(a.ncol == b.nrow, "Incompatible matrix sizes in matrix multiplication.")
  val (af, bf) = (a.getFlavor, b.getFlavor)
  val backs = (af.getBacking, bf.getBacking)
  // (structure of a, analysis flag of a, structure of b, analysis flag of b)
  val sd = (af.getStructure, sparsityAnalysis(a), bf.getStructure, sparsityAnalysis(b))
  val alg: MMulAlg = backs match {
    // Both operands are jvm memory backs.
    case (BackEnum.JVMMEM, BackEnum.JVMMEM) ⇒
      sd match {
        // Multiplication cases by a diagonal matrix.
        case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.COLWISE, _)
          if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW
        case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSECOLWISE, _)
          if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW
        case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.ROWWISE, _)
          if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW
        case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSEROWWISE, _)
          if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW
        case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.VECTORBACKED, _)
          if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag
        case (TraversingStructureEnum.SPARSECOLWISE, _, TraversingStructureEnum.VECTORBACKED, _)
          if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag
        case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.VECTORBACKED, _)
          if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag
        case (TraversingStructureEnum.SPARSEROWWISE, _, TraversingStructureEnum.VECTORBACKED, _)
          if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag

        // Dense-dense cases. The `eq` guards detect A * A' by reference identity of one operand
        // with the transpose of the other.
        case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a eq b.t ⇒ jvmDRWAAt
        case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a.t eq b ⇒ jvmDRWAAt
        case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ jvmRWCW
        case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmRWRW
        case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ jvmCWCW
        case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a eq b.t ⇒ jvmDCWAAt
        case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a.t eq b ⇒ jvmDCWAAt
        case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmCWRW

        // Sparse row matrix x sparse row matrix (array of vectors)
        case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ jvmSparseRWRW
        case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseRWCW
        case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ jvmSparseCWRW
        case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseCWCW

        // Sparse matrix x sparse matrix (hashtable of vectors)
        case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒
          jvmSparseRowRWRW
        case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒
          jvmSparseRowRWCW
        case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒
          jvmSparseRowCWRW
        case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒
          jvmSparseRowCWCW

        // Sparse matrix x non-like
        case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseRowRWRW
        case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRowRWCW
        case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseRowCWRW
        case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW
        case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ jvmSparseRWRW
        case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRWCW
        case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ jvmSparseCWRW
        case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRowCWCW

        // Everything else including at least one sparse LHS or RHS argument
        case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseRWRW
        case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRWCW
        case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseCWRW
        case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW2flips

        // Sparse methods are only effective if the first argument is sparse, so we need to do a swap:
        // A * B is computed as (B' * A')' by re-entering `apply` with the transposed operands.
        case (_, _, _, false) ⇒
          (lhs, rhs, res) ⇒ apply(rhs.t, lhs.t, res.map {_.t}).t

        // Default jvm-jvm case.
        case _ ⇒ jvmRWCW
      }
  }
  alg(a, b, r)
}
// Alias for the type of the multiplication kernel that `apply` selects and then invokes.
type MMulAlg = MMBinaryFunc
/**
 * Row-wise `a` times column-wise `b`: every result cell is the dot product of a row of `a`
 * and a column of `b`. Every cell of the result is overwritten.
 *
 * @param a left operand
 * @param b right operand
 * @param r optional result matrix; when given it must be `a.nrow` x `b.ncol`
 * @return the supplied `r`, or a new matrix created via `like` from `a` or `b` depending on
 *         `sparsityAnalysis(a)`
 */
@inline
private def jvmRWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  require(r.forall(supplied ⇒ supplied.nrow == a.nrow && supplied.ncol == b.ncol))
  val nRows = a.nrow
  val nCols = b.ncol
  val product = r.getOrElse {
    if (sparsityAnalysis(a)) a.like(nRows, nCols) else b.like(nRows, nCols)
  }
  for (i ← 0 until product.nrow; j ← 0 until product.ncol)
    // this vector-vector should be sort of optimized, right?
    product(i, j) = a(i, ::) dot b(::, j)
  product
}
/**
 * Row-wise `a` times row-wise `b`: copies `b` into a transposed-view container so that it can
 * be traversed column-wise, then delegates to `jvmRWCW`.
 *
 * @param a left operand
 * @param b right operand
 * @param r optional result matrix, passed on to `jvmRWCW`
 * @return the result of `jvmRWCW` on `a` and the re-oriented copy of `b`
 * @throws IllegalArgumentException if the copy does not report a (sparse) column-wise structure
 */
@inline
private def jvmRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  // A bit hackish: currently, this relies a bit on the fact that like produces RW(?), so the
  // transposed view of a `like` with swapped dimensions is expected to be column-wise.
  val bColWise = b.like(b.ncol, b.nrow).t
  b.foreach { bRow ⇒
    bColWise(bRow.index(), ::) := bRow
  }
  val copyStructure = bColWise.getFlavor.getStructure
  require(
    copyStructure == TraversingStructureEnum.COLWISE || copyStructure == TraversingStructureEnum.SPARSECOLWISE,
    "COL wise conversion assumption of RHS is wrong, do over this code.")
  jvmRWCW(a, bColWise, r)
}
/**
 * Column-wise `a` times column-wise `b`, computed as (B' A')' via `jvmRWRW`.
 *
 * @param r optional result matrix; its transpose is what `jvmRWRW` receives
 */
private def jvmCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  val transposedProduct = jvmRWRW(b.t, a.t, r.map(_.t))
  transposedProduct.t
}
/**
 * Column-wise `a` times row-wise `b`.
 *
 * Clones `a` (expecting the clone to come out row-wise) and delegates to `jvmRWRW`.
 *
 * @param a left operand
 * @param b right operand
 * @param r optional result matrix, passed on to `jvmRWRW`
 * @return the result of `jvmRWRW` on the clone of `a` and `b`
 * @throws IllegalArgumentException if the clone of `a` does not report a (sparse) row-wise
 *                                  structure
 */
private def jvmCWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  // This is a primary contender with Outer Prod sum algo.
  // Here, we force-reorient the left operand; jvmRWRW then reorients the right one and runs RWCW.
  // A bit hackish: currently, this relies a bit on the fact that clone always produces RW(?)
  val aclone = a.cloned
  val cloneStructure = aclone.getFlavor.getStructure
  // The checked matrix is the clone of the LEFT operand, so the message names the LHS.
  require(
    cloneStructure == TraversingStructureEnum.ROWWISE || cloneStructure == TraversingStructureEnum.SPARSEROWWISE,
    "Row wise conversion assumption of LHS is wrong, do over this code.")
  jvmRWRW(aclone, b, r)
}
/**
 * Sparse row-wise `a` times row-wise `b`.
 *
 * For every non-zero element of every row of `a`, the matching row of `b` is folded into the
 * corresponding result row with `Functions.plusMult` of that element. Result rows are updated
 * in place, so a supplied `r` contributes whatever it already holds.
 *
 * @param r optional result matrix; defaults to `b.like(a.nrow, b.ncol)`
 * @return the result matrix
 */
private def jvmSparseRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  val product = r.getOrElse(b.like(a.nrow, b.ncol))
  // This is basically almost the algorithm from SparseMatrix.times
  a.foreach { aRow ⇒
    aRow.nonZeroes.foreach { aEl ⇒
      product(aRow.index(), ::).assign(b(aEl.index, ::), Functions.plusMult(aEl))
    }
  }
  product
}
/**
 * Same accumulation scheme as `jvmSparseRWRW`, but only the rows yielded by
 * `a.iterateNonEmpty()` are visited.
 *
 * @param r optional result matrix; defaults to `b.like(a.nrow, b.ncol)`
 * @return the result matrix, updated in place
 */
private def jvmSparseRowRWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  val product = r.getOrElse(b.like(a.nrow, b.ncol))
  a.iterateNonEmpty().foreach { aRow ⇒
    aRow.vector.nonZeroes.foreach { aEl ⇒
      product(aRow.index(), ::).assign(b(aEl.index, ::), Functions.plusMult(aEl))
    }
  }
  product
}
/** Column-wise x column-wise, computed as (B' A')' through `jvmSparseRowRWRW`. */
private def jvmSparseRowCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val transposedProduct = jvmSparseRowRWRW(b.t, a.t, r.map(_.t))
  transposedProduct.t
}
/**
 * Column-wise x column-wise variant that clones both operands and then runs
 * `jvmSparseRowRWRW` on the clones (presumably relying on `cloned` yielding row-wise copies).
 */
private def jvmSparseRowCWCW2flips(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val aCopy = a.cloned
  val bCopy = b.cloned
  jvmSparseRowRWRW(aCopy, bCopy, r)
}
/** Row-wise x column-wise: clones `b` and delegates to `jvmSparseRowRWRW`. */
private def jvmSparseRowRWCW(a: Matrix, b: Matrix, r: Option[Matrix]) = {
  val bCopy = b.cloned
  jvmSparseRowRWRW(a, bCopy, r)
}
/** Column-wise x row-wise: clones `a` and delegates to `jvmSparseRowRWRW`. */
private def jvmSparseRowCWRW(a: Matrix, b: Matrix, r: Option[Matrix]) = {
  val aCopy = a.cloned
  jvmSparseRowRWRW(aCopy, b, r)
}
/** Row-wise x column-wise: clones `b` and delegates to `jvmSparseRWRW`. */
private def jvmSparseRWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val bCopy = b.cloned
  jvmSparseRWRW(a, bCopy, r)
}
/** Column-wise x row-wise: clones `a` and delegates to `jvmSparseRWRW`. */
private def jvmSparseCWRW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val aCopy = a.cloned
  jvmSparseRWRW(aCopy, b, r)
}
/** Column-wise x column-wise, computed as (B' A')' through `jvmSparseRWRW`. */
private def jvmSparseCWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val transposedProduct = jvmSparseRWRW(b.t, a.t, r.map(_.t))
  transposedProduct.t
}
/**
 * Column-wise x column-wise variant that clones both operands and then runs `jvmSparseRWRW`
 * on the clones (presumably relying on `cloned` yielding row-wise copies).
 */
private def jvmSparseCWCW2flips(a: Matrix, b: Matrix, r: Option[Matrix] = None) = {
  val aCopy = a.cloned
  val bCopy = b.cloned
  jvmSparseRWRW(aCopy, bCopy, r)
}
/**
 * Diagonal matrix times row-wise `b`.
 *
 * For each non-zero diagonal element, the matching row of `b` is folded into the same row of
 * the result with `Functions.plusMult` of that element. Rows whose diagonal element is zero
 * are not touched.
 *
 * @param diagm the diagonal left operand (its diagonal is read through `diagv`)
 * @param b     right operand
 * @param r     optional result matrix; defaults to `b.like(diagm.nrow, b.ncol)`
 * @return the result matrix, updated in place
 */
private def jvmDiagRW(diagm:Matrix, b:Matrix, r:Option[Matrix] = None):Matrix = {
  val product = r.getOrElse(b.like(diagm.nrow, b.ncol))
  diagm.diagv.nonZeroes().foreach { dEl ⇒
    product(dEl.index, ::).assign(b(dEl.index, ::), Functions.plusMult(dEl))
  }
  product
}
/**
 * Diagonal matrix times column-wise `b`: every column of `b` is multiplied by the diagonal
 * vector and assigned to the same column of the result.
 *
 * @param diagm the diagonal left operand (its diagonal is read through `diagv`)
 * @param b     right operand, traversed via the rows of `b.t`
 * @param r     optional result matrix; defaults to `b.like(diagm.nrow, b.ncol)`
 * @return the result matrix
 */
private def jvmDiagCW(diagm: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  val product = r.getOrElse(b.like(diagm.nrow, b.ncol))
  b.t.foreach { bCol ⇒
    product(::, bCol.index()) := bCol * diagm.diagv
  }
  product
}
/** Column-wise `a` times a diagonal matrix, computed as (D A')' through `jvmDiagRW`. */
private def jvmCWDiag(a: Matrix, diagm: Matrix, r: Option[Matrix] = None) = {
  val transposedProduct = jvmDiagRW(diagm, a.t, r.map(_.t))
  transposedProduct.t
}
/** Row-wise `a` times a diagonal matrix, computed as (D A')' through `jvmDiagCW`. */
private def jvmRWDiag(a: Matrix, diagm: Matrix, r: Option[Matrix] = None) = {
  val transposedProduct = jvmDiagCW(diagm, a.t, r.map(_.t))
  transposedProduct.t
}
/**
 * Dense column-wise AA'.
 *
 * Clones `a` (the clone is expected to be row-wise) and delegates to `jvmDRWAAt`.
 *
 * @param a the matrix A
 * @param b must be equivalent to `a.t`; it is forwarded as is instead of a `null` placeholder
 * @param r optional result matrix, passed on to `jvmDRWAAt`
 */
private def jvmDCWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = {
  // a.t must be equiv. to b. Cloning must rewrite to row-wise.
  jvmDRWAAt(a.cloned, b, r)
}
/**
 * Dense Row-wise AA'.
 *
 * Only `a` is read; `b` is expected to be equivalent to `a.t` and is not used in the
 * computation. The result is filled symmetrically: each off-diagonal dot product is computed
 * once and stored at both mirrored positions.
 *
 * @param a the matrix A
 * @param b expected to be equivalent to `a.t`; unused
 * @param r optional result matrix; when given it must be `a.nrow` x `a.nrow`
 * @return the result matrix holding A times A'
 */
private def jvmDRWAAt(a:Matrix, b:Matrix, r:Option[Matrix] = None) = {
  // a.t must be equiv to b.
  debug("AAt computation detected.")
  // Validate the shape of a caller-supplied result.
  require(r.forall(supplied ⇒ supplied.nrow == a.nrow && supplied.ncol == a.nrow))
  val gram = r.getOrElse(a.like(a.nrow, a.nrow))
  // Symmetric computation: walk the upper triangle and mirror each value.
  (0 until gram.nrow).foreach { i ⇒
    // Diagonal value: sum of squares of row i.
    gram(i, i) = a(i, ::).aggregate(Functions.PLUS, Functions.SQUARE)
    (i + 1 until gram.ncol).foreach { j ⇒
      // this vector-vector should be sort of optimized, right?
      val dotIJ = a(i, ::) dot a(j, ::)
      gram(i, j) = dotIJ
      gram(j, i) = dotIJ
    }
  }
  gram
}
/**
 * Multiplies `a` by `b` as a sum of outer products of the columns of `a` and the rows of `b`.
 *
 * The result is accumulated with `+=`, so a supplied `r` contributes whatever it already holds.
 * As used here, `sparsityAnalysis` is treated as returning `true` for an operand that should
 * be handled as dense.
 *
 * @param a left operand
 * @param b right operand
 * @param r optional result matrix; when absent one is created via `like` from `a` or `b`
 * @return the result matrix
 */
private def jvmOuterProdSum(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = {
  // This may be already laid out for outer product computation, which may be faster than reorienting
  // both matrices? need to check.
  val (m, n) = (a.nrow, b.ncol)
  // Analyze `a` exactly once so that the orientation decision and the choice of the result
  // prototype below always agree with each other.
  val aDense = sparsityAnalysis(a)
  // Prefer col-wise result iff a is dense and b is sparse. In all other cases default to row-wise.
  val preferColWiseR = aDense && !sparsityAnalysis(b)
  val mxR = r.getOrElse {
    // preferColWiseR implies aDense, so only three combinations can occur.
    if (preferColWiseR) a.like(n, m).t
    else if (aDense) a.like(m, n)
    else b.like(m, n)
  }
  // Loop outer products
  if (preferColWiseR) {
    // this means B is sparse and A is not, so we need to iterate over b values and update R columns with +=
    // one at a time.
    for ((acol, brow) ← a.t.zip(b); bel ← brow.nonZeroes) mxR(::, bel.index()) += bel * acol
  } else {
    for ((acol, brow) ← a.t.zip(b); ael ← acol.nonZeroes()) mxR(ael.index(), ::) += ael * brow
  }
  mxR
}
}