| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.tinkerpop.gremlin.process.computer.ranking.pagerank; |
| |
| import org.apache.commons.configuration.Configuration; |
| import org.apache.tinkerpop.gremlin.process.computer.GraphComputer; |
| import org.apache.tinkerpop.gremlin.process.computer.Memory; |
| import org.apache.tinkerpop.gremlin.process.computer.MessageCombiner; |
| import org.apache.tinkerpop.gremlin.process.computer.MessageScope; |
| import org.apache.tinkerpop.gremlin.process.computer.Messenger; |
| import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalClassFunction; |
| import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalObjectFunction; |
| import org.apache.tinkerpop.gremlin.process.traversal.util.TraversalScriptFunction; |
| import org.apache.tinkerpop.gremlin.process.computer.util.AbstractVertexProgramBuilder; |
| import org.apache.tinkerpop.gremlin.process.computer.util.ConfigurationTraversal; |
| import org.apache.tinkerpop.gremlin.process.computer.util.StaticVertexProgram; |
| import org.apache.tinkerpop.gremlin.process.traversal.Traversal; |
| import org.apache.tinkerpop.gremlin.process.traversal.TraversalSource; |
| import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__; |
| import org.apache.tinkerpop.gremlin.structure.Edge; |
| import org.apache.tinkerpop.gremlin.structure.Graph; |
| import org.apache.tinkerpop.gremlin.structure.Vertex; |
| import org.apache.tinkerpop.gremlin.structure.VertexProperty; |
| import org.apache.tinkerpop.gremlin.structure.util.StringFactory; |
| import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; |
| |
| import java.util.Arrays; |
| import java.util.HashSet; |
| import java.util.Optional; |
| import java.util.Set; |
| import java.util.function.Function; |
| import java.util.function.Supplier; |
| |
| /** |
| * @author Marko A. Rodriguez (http://markorodriguez.com) |
| */ |
| public class PageRankVertexProgram extends StaticVertexProgram<Double> { |
| |
| private MessageScope.Local<Double> incidentMessageScope = MessageScope.Local.of(__::outE); |
| private MessageScope.Local<Double> countMessageScope = MessageScope.Local.of(new MessageScope.Local.ReverseTraversalSupplier(this.incidentMessageScope)); |
| |
| public static final String PAGE_RANK = "gremlin.pageRankVertexProgram.pageRank"; |
| public static final String EDGE_COUNT = "gremlin.pageRankVertexProgram.edgeCount"; |
| |
| private static final String VERTEX_COUNT = "gremlin.pageRankVertexProgram.vertexCount"; |
| private static final String ALPHA = "gremlin.pageRankVertexProgram.alpha"; |
| private static final String TOTAL_ITERATIONS = "gremlin.pageRankVertexProgram.totalIterations"; |
| private static final String TRAVERSAL_SUPPLIER = "gremlin.pageRankVertexProgram.traversalSupplier"; |
| |
| private ConfigurationTraversal<Vertex, Edge> configurationTraversal; |
| private double vertexCountAsDouble = 1.0d; |
| private double alpha = 0.85d; |
| private int totalIterations = 30; |
| |
| private static final Set<String> COMPUTE_KEYS = new HashSet<>(Arrays.asList(PAGE_RANK, EDGE_COUNT)); |
| |
| private PageRankVertexProgram() { |
| |
| } |
| |
| @Override |
| public void loadState(final Graph graph, final Configuration configuration) { |
| if (configuration.containsKey(TRAVERSAL_SUPPLIER)) { |
| this.configurationTraversal = ConfigurationTraversal.loadState(graph, configuration, TRAVERSAL_SUPPLIER); |
| this.incidentMessageScope = MessageScope.Local.of(this.configurationTraversal); |
| this.countMessageScope = MessageScope.Local.of(new MessageScope.Local.ReverseTraversalSupplier(this.incidentMessageScope)); |
| } |
| this.vertexCountAsDouble = configuration.getDouble(VERTEX_COUNT, 1.0d); |
| this.alpha = configuration.getDouble(ALPHA, 0.85d); |
| this.totalIterations = configuration.getInt(TOTAL_ITERATIONS, 30); |
| } |
| |
| @Override |
| public void storeState(final Configuration configuration) { |
| configuration.setProperty(VERTEX_PROGRAM, PageRankVertexProgram.class.getName()); |
| configuration.setProperty(VERTEX_COUNT, this.vertexCountAsDouble); |
| configuration.setProperty(ALPHA, this.alpha); |
| configuration.setProperty(TOTAL_ITERATIONS, this.totalIterations); |
| if (null != this.configurationTraversal) { |
| this.configurationTraversal.storeState(configuration); |
| } |
| } |
| |
| @Override |
| public GraphComputer.ResultGraph getPreferredResultGraph() { |
| return GraphComputer.ResultGraph.NEW; |
| } |
| |
| @Override |
| public GraphComputer.Persist getPreferredPersist() { |
| return GraphComputer.Persist.VERTEX_PROPERTIES; |
| } |
| |
| @Override |
| public Set<String> getElementComputeKeys() { |
| return COMPUTE_KEYS; |
| } |
| |
| @Override |
| public Optional<MessageCombiner<Double>> getMessageCombiner() { |
| return (Optional) PageRankMessageCombiner.instance(); |
| } |
| |
| @Override |
| public Set<MessageScope> getMessageScopes(final Memory memory) { |
| final Set<MessageScope> set = new HashSet<>(); |
| set.add(memory.isInitialIteration() ? this.countMessageScope : this.incidentMessageScope); |
| return set; |
| } |
| |
| @Override |
| public void setup(final Memory memory) { |
| |
| } |
| |
| @Override |
| public void execute(final Vertex vertex, Messenger<Double> messenger, final Memory memory) { |
| if (memory.isInitialIteration()) { |
| messenger.sendMessage(this.countMessageScope, 1.0d); |
| } else if (1 == memory.getIteration()) { |
| double initialPageRank = 1.0d / this.vertexCountAsDouble; |
| double edgeCount = IteratorUtils.reduce(messenger.receiveMessages(), 0.0d, (a, b) -> a + b); |
| vertex.property(VertexProperty.Cardinality.single, PAGE_RANK, initialPageRank); |
| vertex.property(VertexProperty.Cardinality.single, EDGE_COUNT, edgeCount); |
| messenger.sendMessage(this.incidentMessageScope, initialPageRank / edgeCount); |
| } else { |
| double newPageRank = IteratorUtils.reduce(messenger.receiveMessages(), 0.0d, (a, b) -> a + b); |
| newPageRank = (this.alpha * newPageRank) + ((1.0d - this.alpha) / this.vertexCountAsDouble); |
| vertex.property(VertexProperty.Cardinality.single, PAGE_RANK, newPageRank); |
| messenger.sendMessage(this.incidentMessageScope, newPageRank / vertex.<Double>value(EDGE_COUNT)); |
| } |
| } |
| |
| @Override |
| public boolean terminate(final Memory memory) { |
| return memory.getIteration() >= this.totalIterations; |
| } |
| |
| @Override |
| public String toString() { |
| return StringFactory.vertexProgramString(this, "alpha=" + this.alpha + ",iterations=" + this.totalIterations); |
| } |
| |
| ////////////////////////////// |
| |
| public static Builder build() { |
| return new Builder(); |
| } |
| |
| public final static class Builder extends AbstractVertexProgramBuilder<Builder> { |
| |
| private Builder() { |
| super(PageRankVertexProgram.class); |
| } |
| |
| public Builder iterations(final int iterations) { |
| this.configuration.setProperty(TOTAL_ITERATIONS, iterations); |
| return this; |
| } |
| |
| public Builder alpha(final double alpha) { |
| this.configuration.setProperty(ALPHA, alpha); |
| return this; |
| } |
| |
| public Builder traversal(final TraversalSource.Builder builder, final String scriptEngine, final String traversalScript, final Object... bindings) { |
| ConfigurationTraversal.storeState(new TraversalScriptFunction<>(builder, scriptEngine, traversalScript, bindings), this.configuration, TRAVERSAL_SUPPLIER); |
| return this; |
| } |
| |
| public Builder traversal(final Traversal.Admin<Vertex, Edge> traversal) { |
| ConfigurationTraversal.storeState(new TraversalObjectFunction<>(traversal), this.configuration, TRAVERSAL_SUPPLIER); |
| return this; |
| } |
| |
| |
| public Builder traversal(final Class<? extends Supplier<Traversal.Admin<?, ?>>> traversalClass) { |
| ConfigurationTraversal.storeState(new TraversalClassFunction(traversalClass), this.configuration, TRAVERSAL_SUPPLIER); |
| return this; |
| } |
| |
| public Builder vertexCount(final long vertexCount) { |
| this.configuration.setProperty(VERTEX_COUNT, (double) vertexCount); |
| return this; |
| } |
| } |
| |
| //////////////////////////// |
| |
| @Override |
| public Features getFeatures() { |
| return new Features() { |
| @Override |
| public boolean requiresLocalMessageScopes() { |
| return true; |
| } |
| |
| @Override |
| public boolean requiresVertexPropertyAddition() { |
| return true; |
| } |
| }; |
| } |
| } |