Merge branch 'develop'
diff --git a/README.md b/README.md
index bc53895..cdade24 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,15 @@
 
 ## Versions
 
+### v0.3.1
+
+- Add CooccurrenceAlgorithm.
+  To use this algorithm, replace engine.json with engine-cooccurrence.json,
+  or pass the `--variant engine-cooccurrence.json` parameter to both `pio train` **and** `pio deploy`, for example:
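+
+      pio train --variant engine-cooccurrence.json
+      pio deploy --variant engine-cooccurrence.json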
+
 ### v0.3.0
 
 - update for PredictionIO 0.9.2, including:
diff --git a/build.sbt b/build.sbt
index e04ba88..3260f8e 100644
--- a/build.sbt
+++ b/build.sbt
@@ -6,7 +6,14 @@
 
 organization := "io.prediction"
 
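+// run test suites serially; each suite starts its own local SparkContext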
+parallelExecution in Test := false
+
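+// skip running tests during sbt assembly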
+test in assembly := {}
+
 libraryDependencies ++= Seq(
   "io.prediction"    %% "core"          % pioVersion.value % "provided",
   "org.apache.spark" %% "spark-core"    % "1.3.0" % "provided",
-  "org.apache.spark" %% "spark-mllib"   % "1.3.0" % "provided")
+  "org.apache.spark" %% "spark-mllib"   % "1.3.0" % "provided",
+  "org.scalatest"    %% "scalatest"     % "2.2.1" % "test")
diff --git a/engine-cooccurrence.json b/engine-cooccurrence.json
new file mode 100644
index 0000000..ba1a603
--- /dev/null
+++ b/engine-cooccurrence.json
@@ -0,0 +1,18 @@
+{
+  "id": "default",
+  "description": "Default settings",
+  "engineFactory": "org.template.similarproduct.SimilarProductEngine",
+  "datasource": {
+    "params" : {
+      "appName": "INVALID_APP_NAME"
+    }
+  },
+  "algorithms": [
+    {
+      "name": "cooccurrence",
+      "params": {
+        "n": 20
+      }
+    }
+  ]
+}
diff --git a/src/main/scala/CooccurrenceAlgorithm.scala b/src/main/scala/CooccurrenceAlgorithm.scala
new file mode 100644
index 0000000..f94dd7e
--- /dev/null
+++ b/src/main/scala/CooccurrenceAlgorithm.scala
@@ -0,0 +1,159 @@
+package org.template.similarproduct
+
+import io.prediction.controller.P2LAlgorithm
+import io.prediction.controller.Params
+import io.prediction.data.storage.BiMap
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+
+case class CooccurrenceAlgorithmParams(
+  n: Int // number of top co-occurring items to keep per item
+) extends Params
+
+class CooccurrenceModel(
+  val topCooccurrences: Map[Int, Array[(Int, Int)]],
+  val itemStringIntMap: BiMap[String, Int],
+  val items: Map[Int, Item]
+) extends Serializable {
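+  // reverse mapping to convert Int indices back to item ID strings in predict()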
+  @transient lazy val itemIntStringMap = itemStringIntMap.inverse
+
+  override def toString(): String = {
+    val s = topCooccurrences.mapValues { v => v.mkString(",") }
+    s.toString
+  }
+}
+
+class CooccurrenceAlgorithm(val ap: CooccurrenceAlgorithmParams)
+  extends P2LAlgorithm[PreparedData, CooccurrenceModel, Query, PredictedResult] {
+
+  def train(sc: SparkContext, data: PreparedData): CooccurrenceModel = {
+
+    val itemStringIntMap = BiMap.stringInt(data.items.keys)
+
+    val topCooccurrences = trainCooccurrence(
+      events = data.viewEvents,
+      n = ap.n,
+      itemStringIntMap = itemStringIntMap
+    )
+
+    // collect items into a Map, converting item ID strings to Int indices
+    val items: Map[Int, Item] = data.items.map { case (id, item) =>
+      (itemStringIntMap(id), item)
+    }.collectAsMap.toMap
+
+    new CooccurrenceModel(
+      topCooccurrences = topCooccurrences,
+      itemStringIntMap = itemStringIntMap,
+      items = items
+    )
+
+  }
+
+  /* given the user-item view events, find the top n co-occurring items for each item */
+  def trainCooccurrence(
+    events: RDD[ViewEvent],
+    n: Int,
+    itemStringIntMap: BiMap[String, Int]): Map[Int, Array[(Int, Int)]] = {
+
+    val userItem = events
+      // map item ID from string to integer index; unmatched items become -1
+      .map ( v => (v.user, itemStringIntMap.getOrElse(v.item, -1)) )
+      .filter { case (user, item) => item != -1 }
+      // if a user views the same item multiple times, count it only once
+      .distinct()
+      .cache()
+
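+    // self-join on user ID to generate all pairs of items viewed by the same user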
+    val cooccurrences: RDD[((Int, Int), Int)] = userItem.join(userItem)
+      // remove duplicate pairs in reversed order for each user, e.g. (a,b) vs. (b,a)
+      .filter { case (user, (item1, item2)) => item1 < item2 }
+      .map { case (user, (item1, item2)) => ((item1, item2), 1) }
+      .reduceByKey{ (a: Int, b: Int) => a + b }
+
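+    // emit each pair under both of its items, then keep the n highest counts per item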
+    val topCooccurrences = cooccurrences
+      .flatMap{ case (pair, count) =>
+        Seq((pair._1, (pair._2, count)), (pair._2, (pair._1, count)))
+      }
+      .groupByKey
+      .map { case (item, itemCounts) =>
+        (item, itemCounts.toArray.sortBy(_._2)(Ordering.Int.reverse).take(n))
+      }
+      .collectAsMap.toMap
+
+    topCooccurrences
+  }
+
+  def predict(model: CooccurrenceModel, query: Query): PredictedResult = {
+
+    // convert items to Int index
+    val queryList: Set[Int] = query.items
+      .flatMap(model.itemStringIntMap.get(_))
+      .toSet
+
+    val whiteList: Option[Set[Int]] = query.whiteList.map( set =>
+      set.map(model.itemStringIntMap.get(_)).flatten
+    )
+
+    val blackList: Option[Set[Int]] = query.blackList.map ( set =>
+      set.map(model.itemStringIntMap.get(_)).flatten
+    )
+
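+    // sum co-occurrence counts of each candidate item across all query items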
+    val counts: Array[(Int, Int)] = queryList.toVector
+      .flatMap { q =>
+        model.topCooccurrences.getOrElse(q, Array())
+      }
+      .groupBy { case (index, count) => index }
+      .map { case (index, indexCounts) => (index, indexCounts.map(_._2).sum) }
+      .toArray
+
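+    // drop invalid candidates, rank by summed count, and keep the top query.num items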
+    val itemScores = counts
+      .filter { case (i, v) =>
+        isCandidateItem(
+          i = i,
+          items = model.items,
+          categories = query.categories,
+          queryList = queryList,
+          whiteList = whiteList,
+          blackList = blackList
+        )
+      }
+      .sortBy(_._2)(Ordering.Int.reverse)
+      .take(query.num)
+      .map { case (index, count) =>
+        ItemScore(
+          item = model.itemIntStringMap(index),
+          score = count
+        )
+      }
+
+    new PredictedResult(itemScores)
+
+  }
+
+  private def isCandidateItem(
+    i: Int,
+    items: Map[Int, Item],
+    categories: Option[Set[String]],
+    queryList: Set[Int],
+    whiteList: Option[Set[Int]],
+    blackList: Option[Set[Int]]
+  ): Boolean = {
+    whiteList.map(_.contains(i)).getOrElse(true) &&
+    blackList.map(!_.contains(i)).getOrElse(true) &&
+    // discard items in query as well
+    (!queryList.contains(i)) &&
+    // filter categories
+    categories.map { cat =>
+      items(i).categories.map { itemCat =>
+        // keep this item if it has overlapping categories with the query
+        itemCat.toSet.intersect(cat).nonEmpty
+      }.getOrElse(false) // discard this item if it has no categories
+    }.getOrElse(true)
+  }
+
+}
diff --git a/src/main/scala/Engine.scala b/src/main/scala/Engine.scala
index 9815ebf..766f7d8 100644
--- a/src/main/scala/Engine.scala
+++ b/src/main/scala/Engine.scala
@@ -13,7 +13,9 @@
 
 case class PredictedResult(
   itemScores: Array[ItemScore]
-) extends Serializable
+) extends Serializable {
+  override def toString: String = itemScores.mkString(",")
+}
 
 case class ItemScore(
   item: String,
@@ -25,7 +27,10 @@
     new Engine(
       classOf[DataSource],
       classOf[Preparator],
-      Map("als" -> classOf[ALSAlgorithm]),
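+      // the "name" field in engine.json's "algorithms" entry selects which class to run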
+      Map(
+        "als" -> classOf[ALSAlgorithm],
+        "cooccurrence" -> classOf[CooccurrenceAlgorithm]),
       classOf[Serving])
   }
 }
diff --git a/src/test/scala/CooccurrenceAlgorithmTest.scala b/src/test/scala/CooccurrenceAlgorithmTest.scala
new file mode 100644
index 0000000..572844e
--- /dev/null
+++ b/src/test/scala/CooccurrenceAlgorithmTest.scala
@@ -0,0 +1,149 @@
+package org.template.similarproduct
+
+import io.prediction.data.storage.BiMap
+
+import org.scalatest.FlatSpec
+import org.scalatest.Matchers
+
+class CooccurrenceAlgorithmTest
+  extends FlatSpec with EngineTestSparkContext with Matchers {
+
+  val params = CooccurrenceAlgorithmParams(n = 10)
+  val algorithm = new CooccurrenceAlgorithm(params)
+
+  val itemStringIntMap = BiMap(Map(
+    "i0" -> 0,
+    "i1" -> 1,
+    "i2" -> 2,
+    "i3" -> 3
+  ))
+
+  val viewSeq = Seq(
+    ViewEvent("u0", "i0", 1000010),
+    ViewEvent("u0", "i1", 1000020),
+    ViewEvent("u0", "i1", 1000020),
+    ViewEvent("u1", "i1", 1000030),
+    ViewEvent("u1", "i2", 1000040),
+    ViewEvent("u1", "i3", 1000040),
+    ViewEvent("u2", "i2", 1000040),
+    ViewEvent("u2", "i1", 1000040),
+    ViewEvent("u3", "i1", 1000040),
+    ViewEvent("u3", "i2", 1000040),
+    ViewEvent("u3", "i0", 1000040),
+    ViewEvent("u4", "i2", 1000040),
+    ViewEvent("u4", "i3", 1000040),
+    ViewEvent("u5", "i0", 1000040),
+    ViewEvent("u5", "i1", 1000040),
+    ViewEvent("u6", "i0", 1000040),
+    ViewEvent("u6", "i1", 1000040)
+  )
+
+  "trainCooccurrence" should "return top 10 correctly" in {
+
+    val viewEvents = sc.parallelize(viewSeq)
+
+    val topCooccurrences = algorithm.trainCooccurrence(viewEvents, 10, itemStringIntMap)
+
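+    // e.g. i0 and i1 are both viewed by u0, u3, u5, and u6, hence the count 4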
+    val expected = Map(
+      0 -> Array((1, 4), (2, 1)),
+      1 -> Array((0, 4), (2, 3), (3, 1)),
+      2 -> Array((1, 3), (3, 2), (0, 1)),
+      3 -> Array((2, 2), (1, 1))
+    )
+
+    topCooccurrences(0) should be (expected(0))
+    topCooccurrences(1) should be (expected(1))
+    topCooccurrences(2) should be (expected(2))
+    topCooccurrences(3) should be (expected(3))
+
+  }
+
+  "trainCooccurrence" should "return top 1 correctly" in {
+
+    val viewEvents = sc.parallelize(viewSeq)
+
+    val topCooccurrences = algorithm.trainCooccurrence(viewEvents, 1, itemStringIntMap)
+
+    val expected = Map(
+      0 -> Array((1, 4)),
+      1 -> Array((0, 4)),
+      2 -> Array((1, 3)),
+      3 -> Array((2, 2))
+    )
+
+    topCooccurrences(0) should be (expected(0))
+    topCooccurrences(1) should be (expected(1))
+    topCooccurrences(2) should be (expected(2))
+    topCooccurrences(3) should be (expected(3))
+
+  }
+
+  val model = new CooccurrenceModel(
+    // model built by hand so predict() can be tested without running train()
+    topCooccurrences = Map(
+      0 -> Array((1, 4), (2, 1)),
+      1 -> Array((0, 4), (2, 3), (3, 1)),
+      2 -> Array((1, 3), (3, 2), (0, 1)),
+      3 -> Array((2, 2), (1, 1))
+    ),
+    itemStringIntMap = BiMap(Map(
+      "i0" -> 0,
+      "i1" -> 1,
+      "i2" -> 2,
+      "i3" -> 3
+    )),
+    items = Map(
+      0 -> Item(categories = Some(List("c0", "c1"))),
+      1 -> Item(categories = None),
+      2 -> Item(categories = Some(List("c0", "c2"))),
+      3 -> Item(categories = Some(List("c0", "c2", "c3")))
+    )
+  )
+
+  // very basic test only
+  "predict top 10 items" should "return PredictedResult correctly" in {
+
+    val query = Query(
+      items = List("i1"),
+      num = 10,
+      categories = None,
+      whiteList = None,
+      blackList = None
+    )
+
+    val predictedResult = algorithm.predict(model, query)
+
+    val expected = PredictedResult(
+      Array(ItemScore("i0", 4.0), ItemScore("i2", 3.0), ItemScore("i3", 1.0))
+    )
+
+    // ScalaTest can't compare arrays with `equal` when they are wrapped inside a
+    // case class, so compare the itemScores arrays directly as a workaround.
+    predictedResult.itemScores should equal (expected.itemScores)
+
+  }
+
+  "predict top 2 items" should "return PredictedResult correctly" in {
+
+    val query = Query(
+      items = List("i1", "i2"),
+      num = 10,
+      categories = None,
+      whiteList = None,
+      blackList = None
+    )
+
+    val predictedResult = algorithm.predict(model, query)
+
+    val expected = PredictedResult(
+      Array(ItemScore("i0", 5.0), ItemScore("i3", 3.0))
+    )
+
+    // ScalaTest can't compare arrays with `equal` when they are wrapped inside a
+    // case class, so compare the itemScores arrays directly as a workaround.
+    predictedResult.itemScores should equal (expected.itemScores)
+
+  }
+
+}
diff --git a/src/test/scala/EngineTestSparkContext.scala b/src/test/scala/EngineTestSparkContext.scala
new file mode 100644
index 0000000..f509c2f
--- /dev/null
+++ b/src/test/scala/EngineTestSparkContext.scala
@@ -0,0 +1,37 @@
+package org.template.similarproduct
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+trait EngineTestSparkContext extends BeforeAndAfterAll {
+  self: Suite =>
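+  // shares one local SparkContext across a suite; created before and stopped after all tests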
+  @transient private var _sc: SparkContext = _
+
+  def sc: SparkContext = _sc
+
+  val conf = new SparkConf(false)
+
+  override def beforeAll() {
+    _sc = new SparkContext("local", "test", conf)
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    LocalSparkContext.stop(_sc)
+
+    _sc = null
+    super.afterAll()
+  }
+}
+
+object LocalSparkContext {
+  def stop(sc: SparkContext) {
+    if (sc != null) {
+      sc.stop()
+    }
+    // To avoid Akka rebinding to the same port, since it doesn't unbind
+    // immediately on shutdown
+    System.clearProperty("spark.driver.port")
+  }
+}