// blob: 669bb692bd8463fc40ecfe12536e54d68985978f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spot.lda
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDAModel, LocalLDAModel}
import org.apache.spark.sql.SparkSession
/**
* Spot LDAModel: the common interface of the Spot wrappers around Spark MLlib LDA models.
*
* The trait is sealed, so all of its implementations are declared in this file.
*/
sealed trait SpotLDAModel {
/**
* Save the model at the given location.
*
* @param sparkSession Spark session made available to the implementation performing the save
* @param location destination the model is written to
*/
def save(sparkSession: SparkSession, location: String): Unit
/**
* Predict topicDistributions and get topicsMatrix along with results formatted for Apache Spot scoring.
*
* @param helper Spot LDA helper handed to the implementation
* @return the prediction results as a SpotLDAResult
*/
def predict(helper: SpotLDAHelper): SpotLDAResult
}
/**
  * Spot wrapper around a Spark LocalLDAModel.
  *
  * @param ldaModel Spark LDA model being wrapped; `predict` casts it to LocalLDAModel
  */
class SpotLocalLDAModel(final val ldaModel: LDAModel) extends SpotLDAModel {

  /**
    * Save the wrapped model at the given location, using the Spark context of the session.
    *
    * @param sparkSession the Spark session whose SparkContext performs the save
    * @param location     the location the model is written to
    */
  override def save(sparkSession: SparkSession, location: String): Unit =
    ldaModel.save(sparkSession.sparkContext, location)

  /**
    * Compute topic distributions for the helper's formatted corpus and package them, together with the
    * model's topics matrix, into a SpotLDAResult.
    *
    * The corpus is taken from spotLDAHelper, so it may be the documents used for training or a new set.
    *
    * @param spotLDAHelper Spot LDA helper supplying the corpus; either the training instance or a new one
    * @return SpotLDAResult built from the helper, the topic distributions and the topics matrix
    */
  override def predict(spotLDAHelper: SpotLDAHelper): SpotLDAResult = {
    // The cast is kept as-is: a non-local model fails here with ClassCastException.
    val local = ldaModel.asInstanceOf[LocalLDAModel]
    SpotLDAResult(
      spotLDAHelper,
      local.topicDistributions(spotLDAHelper.formattedCorpus),
      local.topicsMatrix
    )
  }
}
/**
  * Spot wrapper around a Spark DistributedLDAModel.
  * Ideally, this model should be used only for batch processing.
  *
  * @param ldaModel Spark LDA model being wrapped; `predict` casts it to DistributedLDAModel
  */
class SpotDistributedLDAModel(final val ldaModel: LDAModel) extends SpotLDAModel {

  /**
    * Save the wrapped model at the given location, using the Spark context of the session.
    *
    * @param sparkSession the Spark session whose SparkContext performs the save
    * @param location     the location the model is written to
    */
  override def save(sparkSession: SparkSession, location: String): Unit =
    ldaModel.save(sparkSession.sparkContext, location)

  /**
    * Package the model's own topic distributions and topics matrix into a SpotLDAResult.
    *
    * SpotDistributedLDAModel.predict does not read a corpus from the helper: the distributions come straight
    * from the wrapped model, so it cannot predict on new documents. Pass the same spotLDAHelper object that
    * was used for training.
    *
    * @param spotLDAHelper Spot LDA helper object used for training
    * @return SpotLDAResult built from the helper, the topic distributions and the topics matrix
    */
  override def predict(spotLDAHelper: SpotLDAHelper): SpotLDAResult = {
    // The cast is kept as-is: a non-distributed model fails here with ClassCastException.
    val distributed = ldaModel.asInstanceOf[DistributedLDAModel]
    SpotLDAResult(spotLDAHelper, distributed.topicDistributions, distributed.topicsMatrix)
  }
}
object SpotLDAModel {

  /**
    * Factory method: wraps a Spark LDA model in the Spot wrapper matching its runtime type.
    *
    * @param ldaModel      Spark LDA model to wrap
    * @param spotLDAHelper not used by this method; retained, with its null default, so existing call sites
    *                      keep compiling
    * @return a SpotDistributedLDAModel when ldaModel is a DistributedLDAModel, a SpotLocalLDAModel when it
    *         is a LocalLDAModel
    * @throws IllegalArgumentException if ldaModel is null or is neither of the two supported model types
    */
  def apply(ldaModel: LDAModel, spotLDAHelper: SpotLDAHelper = null): SpotLDAModel = {
    ldaModel match {
      case distributed: DistributedLDAModel => new SpotDistributedLDAModel(distributed)
      case local: LocalLDAModel => new SpotLocalLDAModel(local)
      // Type patterns never match null, so without the cases below the caller would only see a MatchError.
      case null =>
        throw new IllegalArgumentException("ldaModel must not be null")
      case other =>
        throw new IllegalArgumentException(s"Unsupported LDA model type: ${other.getClass.getName}")
    }
  }
}