/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.feature
import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
* Entry in vocabulary
*/
private case class VocabWord(
var word: String, // the word itself
var cn: Int, // frequency of the word in the corpus
var point: Array[Int], // indices of the Huffman tree inner nodes on the path to this word
var code: Array[Int], // Huffman binary code of this word
var codeLen: Int // number of valid entries in code and point
)
/**
* :: Experimental ::
* Word2Vec creates vector representations of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
* and then learns vector representations of the words in the vocabulary.
* The vector representation can be used as features in
* natural language processing and machine learning algorithms.
*
* This implementation uses the skip-gram model and hierarchical softmax
* to train the model. The variable names in the implementation
* match those in the original C implementation.
*
* For the original C implementation, see https://code.google.com/p/word2vec/
* For the research papers, see
* "Efficient Estimation of Word Representations in Vector Space"
* and
* "Distributed Representations of Words and Phrases and their Compositionality".
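*
* A minimal usage sketch; `input` is an assumed RDD[Seq[String]] of tokenized sentences:
* {{{
*   val word2vec = new Word2Vec().setVectorSize(100).setSeed(42L)
*   val model = word2vec.fit(input)
*   // the query word must occur at least minCount times in the corpus
*   val synonyms = model.findSynonyms("day", 5)
* }}}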
*/
@Experimental
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
/**
* Sets vector size (default: 100).
*/
def setVectorSize(vectorSize: Int): this.type = {
this.vectorSize = vectorSize
this
}
/**
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
this.learningRate = learningRate
this
}
/**
* Sets number of partitions (default: 1). Use a small number for accuracy; more partitions
* speed up training but can reduce the quality of the resulting vectors.
*/
def setNumPartitions(numPartitions: Int): this.type = {
require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
this.numPartitions = numPartitions
this
}
/**
* Sets number of iterations (default: 1), which should be smaller than or equal to the
* number of partitions.
*/
def setNumIterations(numIterations: Int): this.type = {
this.numIterations = numIterations
this
}
/**
* Sets random seed (default: a random long integer).
*/
def setSeed(seed: Long): this.type = {
this.seed = seed
this
}
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */
private val window = 5
/** minimum frequency to consider a vocabulary word */
private val minCount = 5
private var trainWordsCount = 0
private var vocabSize = 0
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
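/**
* Builds the vocabulary from the input words: counts word frequencies, keeps only words
* occurring at least minCount times, and sorts them by descending frequency.
*/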
private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
vocabSize = vocab.length
var a = 0
while (a < vocabSize) {
vocabHash += vocab(a).word -> a
trainWordsCount += vocab(a).cn
a += 1
}
logInfo("trainWordsCount = " + trainWordsCount)
}
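/**
* Precomputes the sigmoid function on EXP_TABLE_SIZE evenly spaced points over
* [-MAX_EXP, MAX_EXP], so the training loop can look up output activations instead of
* calling math.exp for every inner node.
*/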
private def createExpTable(): Array[Float] = {
val expTable = new Array[Float](EXP_TABLE_SIZE)
var i = 0
while (i < EXP_TABLE_SIZE) {
val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
expTable(i) = (tmp / (tmp + 1.0)).toFloat
i += 1
}
expTable
}
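/**
* Builds a Huffman tree over the vocabulary using word frequencies and assigns each word
* its binary code and the indices of the inner nodes on its path from the root (point).
* Frequent words get shorter codes; the codes and paths drive hierarchical softmax.
*/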
private def createBinaryTree(): Unit = {
val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
val code = new Array[Int](MAX_CODE_LENGTH)
val point = new Array[Int](MAX_CODE_LENGTH)
var a = 0
while (a < vocabSize) {
count(a) = vocab(a).cn
a += 1
}
while (a < 2 * vocabSize) {
count(a) = 1e9.toInt
a += 1
}
var pos1 = vocabSize - 1
var pos2 = vocabSize
var min1i = 0
var min2i = 0
a = 0
while (a < vocabSize - 1) {
if (pos1 >= 0) {
if (count(pos1) < count(pos2)) {
min1i = pos1
pos1 -= 1
} else {
min1i = pos2
pos2 += 1
}
} else {
min1i = pos2
pos2 += 1
}
if (pos1 >= 0) {
if (count(pos1) < count(pos2)) {
min2i = pos1
pos1 -= 1
} else {
min2i = pos2
pos2 += 1
}
} else {
min2i = pos2
pos2 += 1
}
count(vocabSize + a) = count(min1i) + count(min2i)
parentNode(min1i) = vocabSize + a
parentNode(min2i) = vocabSize + a
binary(min2i) = 1
a += 1
}
// Now assign binary code to each vocabulary word
var i = 0
a = 0
while (a < vocabSize) {
var b = a
i = 0
while (b != vocabSize * 2 - 2) {
code(i) = binary(b)
point(i) = b
i += 1
b = parentNode(b)
}
vocab(a).codeLen = i
vocab(a).point(0) = vocabSize - 2
b = 0
while (b < i) {
vocab(a).code(i - b - 1) = code(b)
vocab(a).point(i - b) = point(b) - vocabSize
b += 1
}
a += 1
}
}
/**
* Computes the vector representation of each word in the vocabulary.
* @param dataset an RDD of sentences, where each sentence is an iterable collection of words
* @return a Word2VecModel
*/
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
val words = dataset.flatMap(x => x)
learnVocab(words)
createBinaryTree()
val sc = dataset.context
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
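// Convert the stream of words into sentences of vocabulary indices, cutting each sentence
// off at MAX_SENTENCE_LENGTH words; words not in the vocabulary are skipped.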
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
def next(): Array[Int] = {
val sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next())
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
}
}
sentence.toArray
}
}
}
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
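// syn0 holds the word vectors, randomly initialized; syn1 holds the vectors of the Huffman
// tree inner nodes used by hierarchical softmax, initialized to zero.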
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
// TODO: discount by iteration?
alpha =
learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
if (a != window) {
val c = pos - window + a
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
val l1 = lastWord * vectorSize
val neu1e = new Array[Float](vectorSize)
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val inner = bcVocab.value(word).point(d)
val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
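// Gradient scaled by the learning rate: label (1 - code(d)) minus predicted probability f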
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
syn0Modify(lastWord) += 1
}
}
a += 1
}
pos += 1
}
(syn0, syn1, lwc, wc)
}
val syn0Local = model._1
val syn1Local = model._2
// Only output modified vectors.
Iterator.tabulate(vocabSize) { index =>
if (syn0Modify(index) > 0) {
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
None
}
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
if (syn1Modify(index) > 0) {
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
None
}
}.flatten
}
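// Combine the per-partition copies of each modified vector by element-wise addition.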
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
var i = 0
while (i < synAgg.length) {
val index = synAgg(i)._1
if (index < vocabSize) {
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
} else {
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
i += 1
}
}
newSentences.unpersist()
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](vectorSize)
Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector
i += 1
}
new Word2VecModel(word2VecMap.toMap)
}
/**
* Computes the vector representation of each word in the vocabulary (Java version).
* @param dataset a JavaRDD of sentences, where each sentence is an iterable collection of words
* @return a Word2VecModel
*/
def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
fit(dataset.rdd.map(_.asScala))
}
}
/**
* :: Experimental ::
* Word2Vec model
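*
* A minimal usage sketch; `model` is assumed to be a Word2VecModel returned by Word2Vec.fit:
* {{{
*   val vector = model.transform("day")         // dense vector for a single word
*   val synonyms = model.findSynonyms("day", 5) // top (word, cosine similarity) pairs
* }}}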
*/
@Experimental
class Word2VecModel private[mllib] (
private val model: Map[String, Array[Float]]) extends Serializable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.snrm2(n, v1, 1)
val norm2 = blas.snrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
}
/**
* Transforms a word to its vector representation
* @param word a word
* @return vector representation of word
*/
def transform(word: String): Vector = {
model.get(word) match {
case Some(vec) =>
Vectors.dense(vec.map(_.toDouble))
case None =>
throw new IllegalStateException(s"$word not in vocabulary")
}
}
/**
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
findSynonyms(vector, num)
}
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should be > 0")
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
model.mapValues(vec => cosineSimilarity(fVector, vec))
.toSeq
.sortBy(- _._2)
.take(num + 1)
.tail
.toArray
}
/**
* Returns a map of words to their vector representations.
*/
def getVectors: Map[String, Array[Float]] = {
model
}
}