blob: a33a30a9d18ba7feb398a5cd4028d91f07f9e06f [file] [log] [blame]
package io.prediction.e2.engine
import io.prediction.e2.fixture.{MarkovChainFixture, SharedSparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix
import org.scalatest.{FlatSpec, Matchers}
import scala.language.reflectiveCalls
/**
 * Tests for [[MarkovChain]] training and prediction, run against the
 * fixture matrices from [[MarkovChainFixture]] on a shared SparkContext.
 */
class MarkovChainTest extends FlatSpec with Matchers with SharedSparkContext
with MarkovChainFixture {

  /** Absolute tolerance used when comparing computed probabilities. */
  private val Epsilon = 1e-9

  "Markov chain training" should "produce a model" in {
    val matrix =
      new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
    // The second argument to `train` is the top-N cutoff, exposed as `model.n`.
    val model = MarkovChain.train(matrix, 2)

    model.n should be(2)
    model.transitionVectors.collect() should contain theSameElementsAs Seq(
      (0, Vectors.sparse(2, Array(0, 1), Array(0.3, 0.7))),
      (1, Vectors.sparse(2, Array(0, 1), Array(0.5, 0.5)))
    )
  }

  it should "contains probabilities of the top N only" in {
    val matrix =
      new CoordinateMatrix(sc.parallelize(fiveByFiveMatrix.matrixEntries))
    val model = MarkovChain.train(matrix, 2)

    model.n should be(2)
    // Each row keeps only its top-N (here 2) entries. The kept values are not
    // renormalized after truncation, e.g. row 1 keeps 9/25 + 8/25 < 1.
    model.transitionVectors.collect() should contain theSameElementsAs Seq(
      (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4))),
      (1, Vectors.sparse(5, Array(2, 4), Array(9.0 / 25, 8.0 / 25))),
      (2, Vectors.sparse(5, Array(1, 4), Array(10.0 / 28, 10.0 / 28))),
      (3, Vectors.sparse(5, Array(3, 4), Array(3.0 / 9, 4.0 / 9))),
      (4, Vectors.sparse(5, Array(3, 4), Array(8.0 / 25, 0.4)))
    )
  }

  "Model predict" should "calculate the probablities of new states" in {
    val matrix =
      new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
    val model = MarkovChain.train(matrix, 2)
    val nextState = model.predict(Seq(0.4, 0.6))

    // Expected values follow from the trained transition rows above:
    //   0.4 * 0.3 + 0.6 * 0.5 = 0.42
    //   0.4 * 0.7 + 0.6 * 0.5 = 0.58
    // These are sums of floating-point products, so exact equality would
    // depend on rounding order; compare element-wise within a tolerance.
    val expected = Seq(0.42, 0.58)
    nextState should have length expected.length
    nextState.zip(expected).foreach { case (actual, wanted) =>
      actual should be(wanted +- Epsilon)
    }
  }
}