blob: 2f0bbf214bfcf85df1e4bb1f787a0d8a911abb89 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gearpump.experiments.pagerank
import akka.actor.Actor.Receive
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.experiments.pagerank.PageRankApplication.NodeWithTaskId
import org.apache.gearpump.experiments.pagerank.PageRankController.Tick
import org.apache.gearpump.experiments.pagerank.PageRankWorker.{LatestWeight, UpdateWeight}
import org.apache.gearpump.streaming.task.{Task, TaskContext, TaskId, TaskWrapper}
import org.apache.gearpump.util.Graph
class PageRankWorker(taskContext: TaskContext, conf: UserConfig) extends Task(taskContext, conf) {
import taskContext.taskId
private var weight: Double = 1.0
private var upstreamWeights = Map.empty[TaskId, Double]
val taskCount = conf.getInt(PageRankApplication.COUNT).get
lazy val allTasks = (0 until taskCount).toList.map(TaskId(processorId = 1, _))
private val graph = conf.getValue[Graph[NodeWithTaskId[_], AnyRef]](PageRankApplication.DAG).get
private val node = graph.getVertices.find { node =>
node.taskId == taskContext.taskId.index
}.get
private val downstream = graph.outgoingEdgesOf(node).map(_._3)
.map(id => taskId.copy(index = id.taskId)).toSeq
private val upstreamCount = graph.incomingEdgesOf(node).map(_._1).size
LOG.info(s"downstream nodes: $downstream")
private var tick = 0
private def output(msg: AnyRef, tasks: TaskId*): Unit = {
taskContext.asInstanceOf[TaskWrapper].outputUnManaged(msg, tasks: _*)
}
override def receiveUnManagedMessage: Receive = {
case Tick(tick) =>
this.tick = tick
if (downstream.length == 0) {
// If there is no downstream, we will evenly distribute our page rank to
// every node in the graph
val update = UpdateWeight(taskId, weight / taskCount)
output(update, allTasks: _*)
} else {
val update = UpdateWeight(taskId, weight / downstream.length)
output(update, downstream: _*)
}
case update@UpdateWeight(upstreamTaskId, weight) =>
upstreamWeights += upstreamTaskId -> weight
if (upstreamWeights.size == upstreamCount) {
val nextWeight = upstreamWeights.foldLeft(0.0) { (sum, item) => sum + item._2 }
this.upstreamWeights = Map.empty[TaskId, Double]
this.weight = nextWeight
output(LatestWeight(taskId, weight, tick), TaskId(0, 0))
}
}
}
object PageRankWorker {
case class UpdateWeight(taskId: TaskId, weight: Double)
case class LatestWeight(taskId: TaskId, weight: Double, tick: Int)
}