package org.apache.spark.mllib.feature
import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
 * Entry in vocabulary.
 *
 * @param word the vocabulary word
 * @param cn number of occurrences of `word` in the training corpus
 * @param point inner-node indices (offset by -vocabSize) on the tree path from the root to
 *              this word; filled in by createBinaryTree
 * @param code 0/1 branch codes along that same path, root first
 * @param codeLen number of valid entries in `code` (path length)
 */
private case class VocabWord(
  var word: String,
  var cn: Int,
  var point: Array[Int],
  var code: Array[Int],
  var codeLen: Int
)
/**
 * :: Experimental ::
 * Word2Vec creates vector representations of words in a text corpus.
 * The algorithm first constructs a vocabulary from the corpus
 * and then learns vector representations of the words in the vocabulary.
 * The vector representations can be used as features in
 * natural language processing and machine learning algorithms.
 *
 * We use the skip-gram model in our implementation and the hierarchical softmax
 * method to train the model. The variable names in the implementation
 * match the original C implementation.
 *
 * For the original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
 * Efficient Estimation of Word Representations in Vector Space (https://arxiv.org/abs/1301.3781)
 * and
 * Distributed Representations of Words and Phrases and their Compositionality
 * (https://arxiv.org/abs/1310.4546).
 */
class Word2Vec extends Serializable with Logging {
// Length of every learned word vector; settable through setVectorSize.
private var vectorSize: Int = 100
// Initial learning rate (alpha) used at the start of each training pass.
private var learningRate: Double = 0.025
// Number of partitions the training sentences are redistributed into.
private var numPartitions: Int = 1
// Number of training passes over the data.
private var numIterations: Int = 1
// Seed for weight initialization and per-task window sampling.
private var seed: Long = Utils.random.nextLong()
/**
 * Sets vector size (default: 100).
 *
 * @param vectorSize length of each learned word vector; must be positive
 * @return this instance, for call chaining
 * @throws IllegalArgumentException if `vectorSize` is not positive
 */
def setVectorSize(vectorSize: Int): this.type = {
  require(vectorSize > 0, s"vectorSize must be greater than 0 but got $vectorSize")
  this.vectorSize = vectorSize
  this
}
/**
 * Sets initial learning rate (default: 0.025).
 *
 * @param learningRate initial step size; must be positive
 * @return this instance, for call chaining
 * @throws IllegalArgumentException if `learningRate` is not positive
 */
def setLearningRate(learningRate: Double): this.type = {
  require(learningRate > 0, s"learningRate must be greater than 0 but got $learningRate")
  this.learningRate = learningRate
  this
}
/**
 * Sets number of partitions (default: 1). Use a small number for accuracy.
 *
 * @param numPartitions number of partitions used for training; must be positive
 * @return this instance, for call chaining
 * @throws IllegalArgumentException if `numPartitions` is not positive
 */
def setNumPartitions(numPartitions: Int): this.type = {
  require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
  this.numPartitions = numPartitions
  this
}
/**
 * Sets number of iterations (default: 1), which should be smaller than or equal to number of
 * partitions.
 *
 * @param numIterations number of training passes; must be non-negative
 * @return this instance, for call chaining
 * @throws IllegalArgumentException if `numIterations` is negative
 */
def setNumIterations(numIterations: Int): this.type = {
  // A negative count would silently skip training (`1 to n` is empty), so reject it.
  require(numIterations >= 0,
    s"numIterations must be greater than or equal to 0 but got $numIterations")
  this.numIterations = numIterations
  this
}
/**
 * Sets random seed (default: a random long integer).
 *
 * @param seed seed for weight initialization and per-task context-window sampling
 * @return this instance, for call chaining
 */
def setSeed(seed: Long): this.type = {
  this.seed = seed
  this
}
// Number of entries in the precomputed sigmoid table (see createExpTable).
private val EXP_TABLE_SIZE = 1000
// The sigmoid table spans dot products in [-MAX_EXP, MAX_EXP); during training, values
// outside (-MAX_EXP, MAX_EXP) skip the update for that inner node.
private val MAX_EXP = 6
// Capacity of each VocabWord's `point` and `code` arrays.
private val MAX_CODE_LENGTH = 40
// Maximum number of in-vocabulary words grouped into one training sentence.
private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */
private val window = 5
/** minimum frequency to consider a vocabulary word */
private val minCount = 5
// Total occurrences of all vocabulary words; drives the learning-rate decay in `fit`.
// Kept as a Long: corpora with more than Int.MaxValue (~2.1 billion) tokens are common,
// and an Int would silently wrap to a negative total and break the decay schedule.
private var trainWordsCount = 0L
// Number of entries in `vocab`.
private var vocabSize = 0
// Vocabulary built by learnVocab; createBinaryTree expects it sorted by descending `cn`.
@transient private var vocab: Array[VocabWord] = null
// Word -> index into `vocab`.
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
/**
 * Builds the vocabulary from `words`.
 *
 * The method counts every word and drops words seen fewer than `minCount` times. It sorts
 * the survivors by descending count, which createBinaryTree relies on. It then rebuilds
 * `vocabHash` (word -> index into `vocab`), `vocabSize` and `trainWordsCount`.
 *
 * The derived state is reset first. Otherwise a second call (e.g. a second `fit` on the
 * same instance) would keep accumulating counts and stale words from an earlier corpus.
 *
 * @param words the flattened word stream of the training corpus
 */
private def learnVocab(words: RDD[String]): Unit = {
  vocab = words.map(w => (w, 1))
    .reduceByKey(_ + _)
    .map { case (word, cn) =>
      VocabWord(
        word,
        cn,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0)
    }
    .filter(_.cn >= minCount)
    .collect()
    .sortWith((a, b) => a.cn > b.cn)
  vocabSize = vocab.length

  // Reset derived state so repeated calls do not accumulate across corpora.
  vocabHash = mutable.HashMap.empty[String, Int]
  trainWordsCount = 0
  var a = 0
  while (a < vocabSize) {
    vocabHash += vocab(a).word -> a
    trainWordsCount += vocab(a).cn
    a += 1
  }
  logInfo("trainWordsCount = " + trainWordsCount)
}
/**
 * Precomputes the logistic sigmoid exp(x) / (exp(x) + 1) on a uniform grid.
 *
 * Entry `i` holds the value at x = (2 * i / EXP_TABLE_SIZE - 1) * MAX_EXP, so the table
 * spans [-MAX_EXP, MAX_EXP).
 *
 * @return a table of EXP_TABLE_SIZE sigmoid values
 */
private def createExpTable(): Array[Float] = {
  Array.tabulate(EXP_TABLE_SIZE) { i =>
    val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
    (tmp / (tmp + 1.0)).toFloat
  }
}
/**
 * Builds a Huffman tree over `vocab` and records, for every word, the tree path used by
 * hierarchical softmax.
 *
 * Assumes `vocab` is sorted by descending `cn`. Nodes `0 until vocabSize` are leaves
 * (words) and nodes `vocabSize until 2 * vocabSize - 1` are inner nodes. Node
 * `2 * vocabSize - 2` is the root. Each merge combines the two smallest remaining nodes,
 * found with two cursors: `pos1` walks leaves from the rarest word, and `pos2` walks inner
 * nodes in creation order.
 *
 * Afterwards each VocabWord holds:
 *  - `codeLen`: its path length;
 *  - `code`: its 0/1 branch codes, root first;
 *  - `point`: the inner nodes on the path, root first, as indices offset by -vocabSize.
 */
private def createBinaryTree(): Unit = {
  val count = new Array[Long](vocabSize * 2 + 1)
  val binary = new Array[Int](vocabSize * 2 + 1)
  val parentNode = new Array[Int](vocabSize * 2 + 1)
  val code = new Array[Int](MAX_CODE_LENGTH)
  val point = new Array[Int](MAX_CODE_LENGTH)
  var a = 0
  while (a < vocabSize) {
    count(a) = vocab(a).cn
    a += 1
  }
  // Sentinel count for inner nodes that have not been created yet. It must exceed every
  // real count (a single word's `cn` can reach Int.MaxValue). Otherwise an unbuilt node can
  // be chosen as a child, which yields a self-parented node and an endless loop below.
  while (a < 2 * vocabSize) {
    count(a) = 1e15.toLong
    a += 1
  }
  var pos1 = vocabSize - 1
  var pos2 = vocabSize

  // Takes the smaller of the rarest remaining leaf and the oldest unused inner node.
  def popMin(): Int = {
    if (pos1 >= 0 && count(pos1) < count(pos2)) {
      val leaf = pos1
      pos1 -= 1
      leaf
    } else {
      val inner = pos2
      pos2 += 1
      inner
    }
  }

  a = 0
  while (a < vocabSize - 1) {
    val min1i = popMin()
    val min2i = popMin()
    count(vocabSize + a) = count(min1i) + count(min2i)
    parentNode(min1i) = vocabSize + a
    parentNode(min2i) = vocabSize + a
    binary(min2i) = 1
    a += 1
  }

  // Now assign binary code to each vocabulary word: walk from its leaf up to the root,
  // then store the collected path reversed (root first).
  a = 0
  while (a < vocabSize) {
    var b = a
    var i = 0
    while (b != vocabSize * 2 - 2) {
      code(i) = binary(b)
      point(i) = b
      i += 1
      b = parentNode(b)
    }
    vocab(a).codeLen = i
    vocab(a).point(0) = vocabSize - 2
    b = 0
    while (b < i) {
      vocab(a).code(i - b - 1) = code(b)
      vocab(a).point(i - b) = point(b) - vocabSize
      b += 1
    }
    a += 1
  }
}
/**
 * Computes the vector representation of each word in vocabulary.
 * @param dataset an RDD of word sequences (e.g. sentences), each an iterable of words
 * @return a Word2VecModel
 */
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
// Purpose: train skip-gram Word2Vec with hierarchical softmax on `dataset` and return a
// model holding the final syn0 vector of every vocabulary word.
// Visible steps:
//   (1) flatten the input into a word stream and broadcast the sigmoid table, `vocab`
//       and `vocabHash`;
//   (2) group vocabulary indices into sentences of at most MAX_SENTENCE_LENGTH words,
//       skipping out-of-vocabulary words;
//   (3) for `numIterations` passes, let each partition train its own copy of the weights
//       and emit only the rows it touched; sum those rows per key and copy them back into
//       the global arrays;
//   (4) copy each word's syn0 row into the Word2VecModel.
val words = dataset.flatMap(x => x)
// NOTE(review): nothing visible here calls `learnVocab(words)` / `createBinaryTree()`, yet
// `vocab` (initially null), `vocabHash` and `vocabSize` are read below. Presumably those
// two calls belong right here; confirm against the full source.
val sc = dataset.context
// Sigmoid lookup table plus the vocabulary structures, broadcast to the executors.
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
// Re-chunk each partition's word stream into arrays of vocabulary indices.
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
def next(): Array[Int] = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
// Only in-vocabulary words count toward MAX_SENTENCE_LENGTH.
// NOTE(review): the lookup below is truncated in this view; presumably it reads
// `get(iter.next())`. The method's result (expected: `sentence.toArray`) and the
// closing braces are also not visible. Confirm.
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
// Redistribute sentences into `numPartitions` partitions and cache them, because every
// training pass below re-reads them.
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
// syn0: one row of vectorSize floats per word (these rows become the model),
// initialized uniformly in [-0.5 / vectorSize, 0.5 / vectorSize).
// syn1: one row per inner tree node, used by hierarchical softmax, initialized to zero.
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = learningRate
for (k <- 1 to numIterations) {
// Each partition starts from the current global weights and trains over its sentences.
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
// Seed varies with partition index and iteration, so window sampling differs per task.
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
// How many times each syn0 / syn1 row was updated; only touched rows are emitted.
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
// Fold state: (syn0, syn1, word count at the last alpha update, words seen so far).
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
// Every ~10000 words, decay alpha linearly with this partition's progress (scaled by
// numPartitions). It is floored at 0.01% of the initial rate.
// NOTE(review): `alpha` is a driver-side var captured by this task closure. Under
// Spark's closure serialization the reassignment stays local to the task.
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
// TODO: discount by iteration?
alpha =
learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
wc += sentence.size
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
// Randomly shrunk window: offsets a in [b, 2 * window - b], i.e. context positions
// pos - window + b .. pos + window - b, excluding pos itself (a == window).
val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
if (a != window) {
val c = pos - window + a
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
val l1 = lastWord * vectorSize
// Gradient accumulated for lastWord's syn0 row; applied after all inner nodes.
val neu1e = new Array[Float](vectorSize)
// Hierarchical softmax
var d = 0
// Visit the codeLen inner nodes on `word`'s tree path (built by createBinaryTree).
while (d < bcVocab.value(word).codeLen) {
val inner = bcVocab.value(word).point(d)
val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
// Dot products outside (-MAX_EXP, MAX_EXP) get no update for this node.
if (f > -MAX_EXP && f < MAX_EXP) {
// Map f onto the sigmoid table. `EXP_TABLE_SIZE / MAX_EXP` is Int division
// (166 with the defaults) before the division by 2.0.
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
// Step size for this node: (1 - code(d) - f) * alpha.
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
// neu1e += g * syn1[inner]; then syn1[inner] += g * syn0[lastWord].
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
syn1Modify(inner) += 1
// NOTE(review): closing braces are missing in this view. `d += 1` must sit outside the
// `if` above; otherwise an out-of-range `f` would loop forever.
d += 1
// syn0[lastWord] += neu1e
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
syn0Modify(lastWord) += 1
a += 1
pos += 1
(syn0, syn1, lwc, wc)
val syn0Local = model._1
val syn1Local = model._2
// Only output modified vectors.
// Keys: `index` for syn0 row `index`, `index + vocabSize` for syn1 row `index`.
// NOTE(review): the `else` branches below have no visible value (presumably `None`).
Iterator.tabulate(vocabSize) { index =>
if (syn0Modify(index) > 0) {
Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
}.flatten ++ Iterator.tabulate(vocabSize) { index =>
if (syn1Modify(index) > 0) {
Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
} else {
// Combine each row's copies from different partitions by element-wise sum (v1 += v2),
// not by averaging.
// NOTE(review): the lambda's result (presumably `v1`) and the step that brings the pairs
// to the driver (presumably `.collect()`) are not visible, yet `synAgg` is indexed like
// an array below. Confirm.
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
// Write the aggregated rows back into syn0Global / syn1Global for the next pass.
var i = 0
while (i < synAgg.length) {
val index = synAgg(i)._1
if (index < vocabSize) {
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
} else {
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
i += 1
// Build the word -> vector map from the final syn0 rows.
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](vectorSize)
Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector
i += 1
new Word2VecModel(word2VecMap.toMap)
/**
 * Computes the vector representation of each word in vocabulary (Java version).
 *
 * Converts each Java iterable to a Scala one and delegates to the Scala `fit`.
 *
 * @param dataset a JavaRDD of word sequences, each a java.lang.Iterable of words
 * @return a Word2VecModel
 */
def fit[S <: JavaIterable[String]](
    dataset: org.apache.spark.api.java.JavaRDD[S]): Word2VecModel = {
  // JavaRDD is fully qualified because it is not among this file's visible imports.
  fit(dataset.rdd.map(_.asScala))
}
/**
 * :: Experimental ::
 * Word2Vec model: maps each vocabulary word to its learned vector.
 */
class Word2VecModel private[mllib] (
private val model: Map[String, Array[Float]]) extends Serializable {
/**
 * Cosine similarity of two equal-length vectors.
 *
 * @return dot(v1, v2) / (|v1| * |v2|), or 0.0 when either vector has zero norm
 * @throws IllegalArgumentException if the lengths differ
 */
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
  require(v1.length == v2.length, "Vectors should have the same length")
  val n = v1.length
  val norm1 = blas.snrm2(n, v1, 1)
  val norm2 = blas.snrm2(n, v2, 1)
  if (norm1 == 0 || norm2 == 0) {
    0.0
  } else {
    // Divide in Double rather than Float to limit rounding in the final ratio.
    blas.sdot(n, v1, 1, v2, 1).toDouble / norm1 / norm2
  }
}
/**
 * Transforms a word to its vector representation
 * @param word a word
 * @return vector representation of word, widened from Float to Double
 * @throws IllegalStateException if `word` is not in the vocabulary
 */
def transform(word: String): Vector = {
  model.get(word) match {
    case Some(vec) =>
      Vectors.dense(vec.map(_.toDouble))
    case None =>
      throw new IllegalStateException(s"$word not in vocabulary")
  }
}
/**
 * Find synonyms of a word
 * @param word a word
 * @param num number of synonyms to find
 * @return array of (word, cosineSimilarity), most similar first; `word` itself is excluded
 * @throws IllegalStateException if `word` is not in the vocabulary
 */
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
  val vector = transform(word)
  findSynonymsExcluding(vector, num, Some(word))
}

/**
 * Find synonyms of the vector representation of a word
 * @param vector vector representation of a word
 * @param num number of synonyms to find
 * @return array of (word, cosineSimilarity), most similar first
 */
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
  findSynonymsExcluding(vector, num, None)
}

/**
 * Returns the `num` vocabulary words whose vectors are most cosine-similar to `vector`,
 * skipping `excluded` when given.
 *
 * The query word is excluded by name instead of blindly dropping the first hit. That way a
 * query by an arbitrary vector (not a vocabulary word's own) keeps its true best match.
 */
private def findSynonymsExcluding(
    vector: Vector,
    num: Int,
    excluded: Option[String]): Array[(String, Double)] = {
  require(num > 0, "Number of similar words should > 0")
  // TODO: optimize top-k
  val fVector = vector.toArray.map(_.toFloat)
  model.toSeq
    .filter { case (w, _) => !excluded.exists(_ == w) }
    .map { case (w, vec) => (w, cosineSimilarity(fVector, vec)) }
    .sortBy(- _._2)
    .take(num)
    .toArray
}
/**
 * Returns a map of words to their vector representations.
 */
def getVectors: Map[String, Array[Float]] = {