/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spot.lda

import org.apache.log4j.Logger
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{LDAModel, _}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

/**
 * Spark LDA implementation.
 * Contains routines for LDA using the Spark implementation from org.apache.spark.mllib.clustering:
 * 1. Creates lists of unique documents and unique words, and a model based on those two
 * 2. Processes the model using Spark LDA
 * 3. Reads Spark LDA results: topic distributions per document (docTopicDist) and word distributions per topic (wordTopicMat)
 * 4. Converts wordTopicMat (Matrix) and docTopicDist (RDD) to the appropriate formats (maps) originally specified in LDA-C
 */
object SpotLDAWrapper {
  /**
   * Runs Spark LDA and returns a new model.
   *
   * @param topicCount         number of topics to find
   * @param logger             application logger
   * @param ldaSeed            optional LDA seed
   * @param ldaAlpha           document concentration (alpha)
   * @param ldaBeta            topic concentration (beta)
   * @param ldaOptimizerOption LDA optimizer, "em" or "online"
   * @param maxIterations      maximum number of iterations for the optimizer
   * @param helper             SpotLDAHelper supplying the corpus formatted for Spark LDA
   * @return SpotLDAModel wrapping the trained Spark LDAModel
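   *
   * A minimal usage sketch with illustrative values (assumes a SpotLDAHelper
   * built from the corpus elsewhere):
   *
   * {{{
   * val model: SpotLDAModel = SpotLDAWrapper.run(
   *   topicCount = 20,
   *   logger = logger,
   *   ldaSeed = Some(0xdeadbeefL),
   *   ldaAlpha = 1.02,
   *   ldaBeta = 1.001,
   *   ldaOptimizerOption = "em",
   *   maxIterations = 20,
   *   helper = helper)
   * }}}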
   */
  def run(topicCount: Int,
          logger: Logger,
          ldaSeed: Option[Long],
          ldaAlpha: Double,
          ldaBeta: Double,
          ldaOptimizerOption: String,
          maxIterations: Int,
          helper: SpotLDAHelper): SpotLDAModel = {
    // Structure the corpus so that the index is the docID and the values are
    // the vectors of word occurrences in that document
    val ldaCorpus: RDD[(Long, Vector)] = helper.formattedCorpus

    // Instantiate the optimizer based on input
    val ldaOptimizer = ldaOptimizerOption match {
      case "em" => new EMLDAOptimizer
      case "online" => new OnlineLDAOptimizer()
        .setOptimizeDocConcentration(true)
        .setMiniBatchFraction({
          // Sample a hair over one document per mini-batch on average
          // (e.g. corpusSize = 1000 gives a fraction of 0.00105);
          // for degenerate corpora of fewer than 2 documents, sample 75%
          val corpusSize = ldaCorpus.count()
          if (corpusSize < 2) 0.75
          else (0.05 + 1) / corpusSize
        })
      case _ => throw new IllegalArgumentException(
        s"Invalid LDA optimizer $ldaOptimizerOption")
    }
logger.info(s"Running Spark LDA with params alpha = $ldaAlpha beta = $ldaBeta " +
s"Max iterations = $maxIterations Optimizer = $ldaOptimizerOption")
    // Set LDA params from input args
    val lda =
      new LDA()
        .setK(topicCount)
        .setMaxIterations(maxIterations)
        .setAlpha(ldaAlpha)
        .setBeta(ldaBeta)
        .setOptimizer(ldaOptimizer)
    // If the caller does not provide a seed (i.e. ldaSeed is empty),
    // LDA seeds itself automatically with the hash value of the class name
    ldaSeed.foreach(lda.setSeed)
    val model: LDAModel = lda.run(ldaCorpus)

    SpotLDAModel(model)
  }

  /**
   * Loads an existing model from an HDFS location.
   *
   * @param sparkSession       the Spark session.
   * @param location           the HDFS location of the model.
   * @param ldaOptimizerOption LDA optimizer the model was trained with, "em" or "online".
   * @return SpotLDAModel
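   *
   * A minimal usage sketch (the HDFS path is illustrative):
   *
   * {{{
   * val model: SpotLDAModel = SpotLDAWrapper.load(sparkSession, "hdfs:///user/spot/lda-model", "em")
   * }}}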
   */
  def load(sparkSession: SparkSession, location: String, ldaOptimizerOption: String): SpotLDAModel = {
    val sparkContext: SparkContext = sparkSession.sparkContext

    // EM training yields a DistributedLDAModel; online training yields a LocalLDAModel
    val model = ldaOptimizerOption match {
      case "em" => DistributedLDAModel.load(sparkContext, location)
      case "online" => LocalLDAModel.load(sparkContext, location)
      case _ => throw new IllegalArgumentException(
        s"Invalid LDA optimizer $ldaOptimizerOption")
    }

    SpotLDAModel(model)
  }
}