blob: 4a75aac2799acae7fd730c8c42959bad3be4c136 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.transform.Transform;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Pass for Common Subexpression Elimination optimization. It finds operator vertices that would redundantly
 * compute the same result and keeps a single one of them, so the computation runs once instead of multiple times.
 * We consider such vertices as 'common' when they hold the same transform and have incoming edges from an
 * identical set of vertices.
 * Refer to CommonSubexpressionEliminationPassTest for such cases.
 */
@Requires(CommunicationPatternProperty.class)
public final class CommonSubexpressionEliminationPass extends ReshapingPass {
  /**
   * Default constructor.
   */
  public CommonSubexpressionEliminationPass() {
    super(CommonSubexpressionEliminationPass.class);
  }

  /**
   * Rebuilds the DAG of {@code inputDAG} with common operator vertices merged.
   *
   * <p>Non-operator vertices are copied as-is. Operator vertices are grouped by transform and then by their exact
   * set of source vertices. Within each group only the first vertex (in topological order) is kept. The outgoing
   * edges of the other vertices in the group are redirected to leave the kept vertex instead.
   *
   * @param inputDAG the DAG to reshape (modified in place via {@code reshapeUnsafely}).
   * @return {@code inputDAG}, after reshaping.
   */
  @Override
  public IRDAG apply(final IRDAG inputDAG) {
    inputDAG.reshapeUnsafely(dag -> {
      final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
      // Operator vertices keyed by transform, each list in topological order.
      // NOTE(review): vertices share a key only if their Transform objects are equal per equals/hashCode.
      final Map<Transform, List<OperatorVertex>> operatorVerticesToBeMerged = new LinkedHashMap<>();
      final Map<OperatorVertex, Set<IREdge>> inEdges = new HashMap<>();
      final Map<OperatorVertex, Set<IREdge>> outEdges = new HashMap<>();

      // 1. Copy non-operator vertices (and edges between them) directly.
      //    Record every edge touching an operator vertex for later reconnection.
      dag.topologicalDo(irVertex -> {
        if (irVertex instanceof OperatorVertex) {
          final OperatorVertex operatorVertex = (OperatorVertex) irVertex;
          operatorVerticesToBeMerged
            .computeIfAbsent(operatorVertex.getTransform(), t -> new ArrayList<>())
            .add(operatorVertex);
          dag.getIncomingEdgesOf(operatorVertex).forEach(irEdge -> {
            inEdges.computeIfAbsent(operatorVertex, v -> new HashSet<>()).add(irEdge);
            if (irEdge.getSrc() instanceof OperatorVertex) {
              outEdges.computeIfAbsent((OperatorVertex) irEdge.getSrc(), v -> new HashSet<>()).add(irEdge);
            }
          });
        } else {
          builder.addVertex(irVertex, dag);
          dag.getIncomingEdgesOf(irVertex).forEach(irEdge -> {
            if (irEdge.getSrc() instanceof OperatorVertex) {
              outEdges.computeIfAbsent((OperatorVertex) irEdge.getSrc(), v -> new HashSet<>()).add(irEdge);
            } else {
              // Topological order guarantees the non-operator source was already added to the builder.
              builder.connectVertices(irEdge);
            }
          });
        }
      });

      // 2. Merge operator vertices that share a transform and an identical source set, adding survivors to the builder.
      operatorVerticesToBeMerged.values().forEach(operatorVertices ->
        groupByIdenticalSources(operatorVertices, dag).forEach(ovs ->
          mergeAndAddToBuilder(ovs, builder, dag, outEdges)));

      // 3. Reconnect edges touching surviving operator vertices; edges of eliminated vertices are skipped.
      operatorVerticesToBeMerged.values().forEach(operatorVertices ->
        operatorVertices.forEach(operatorVertex -> {
          if (!builder.contains(operatorVertex)) {
            return; // eliminated by a merge
          }
          inEdges.getOrDefault(operatorVertex, Collections.emptySet()).forEach(e -> {
            if (builder.contains(e.getSrc())) {
              builder.connectVertices(e);
            }
          });
          outEdges.getOrDefault(operatorVertex, Collections.emptySet()).forEach(e -> {
            if (builder.contains(e.getDst())) {
              builder.connectVertices(e);
            }
          });
        }));
      return builder.build();
    });
    return inputDAG;
  }

  /**
   * Groups operator vertices by the exact set of vertices feeding their incoming edges.
   *
   * @param operatorVertices operator vertices to group, in topological order.
   * @param dag dag to read incoming edges from.
   * @return the groups; each group keeps the relative order of {@code operatorVertices}.
   */
  private static Collection<List<OperatorVertex>> groupByIdenticalSources(final List<OperatorVertex> operatorVertices,
                                                                         final DAG<IRVertex, IREdge> dag) {
    // Set#equals/hashCode compare contents, so vertices with identical source sets land under the same key.
    final Map<Set<IRVertex>, List<OperatorVertex>> groups = new LinkedHashMap<>();
    for (final OperatorVertex operatorVertex : operatorVertices) {
      final Set<IRVertex> incomingVertices = dag.getIncomingEdgesOf(operatorVertex).stream()
        .map(IREdge::getSrc)
        .collect(Collectors.toSet());
      groups.computeIfAbsent(incomingVertices, k -> new ArrayList<>()).add(operatorVertex);
    }
    return groups.values();
  }

  /**
   * Merges equivalent operator vertices and adds the surviving ones to the provided builder.
   *
   * <p>The first vertex of {@code ovs} is kept. Every other vertex that it does not reach is eliminated, and its
   * outgoing edges in {@code outEdges} are replaced by copies that leave the kept vertex. Vertices that the kept
   * vertex does reach are merged among themselves recursively.
   *
   * @param ovs operator vertices that are to be merged (if there are no dependencies between them).
   * @param builder builder to add the merged vertices to.
   * @param dag dag to observe while adding them.
   * @param outEdges outgoing edges information; updated in place for eliminated vertices.
   */
  private static void mergeAndAddToBuilder(final List<OperatorVertex> ovs, final DAGBuilder<IRVertex, IREdge> builder,
                                           final DAG<IRVertex, IREdge> dag,
                                           final Map<OperatorVertex, Set<IREdge>> outEdges) {
    if (ovs.isEmpty()) {
      return;
    }
    final OperatorVertex operatorVertexToUse = ovs.get(0);
    final List<OperatorVertex> dependencyFailedOperatorVertices = new ArrayList<>();
    builder.addVertex(operatorVertexToUse);
    for (final OperatorVertex ov : ovs) {
      if (ov.equals(operatorVertexToUse)) {
        continue;
      }
      // Safeguard: in an acyclic DAG, vertices with identical source sets cannot reach each other.
      if (dag.pathExistsBetween(operatorVertexToUse, ov)) {
        dependencyFailedOperatorVertices.add(ov);
        continue;
      }
      // Incoming edges need no handling: they come from the same source vertices as operatorVertexToUse's.
      // Build the redirected edges into a fresh set rather than mutating the set being iterated
      // (fail-fast iterator). getOrDefault avoids an NPE when ov has no recorded outgoing edges.
      final Set<IREdge> redirectedEdges = new HashSet<>();
      for (final IREdge e : outEdges.getOrDefault(ov, Collections.emptySet())) {
        redirectedEdges.add(redirectSource(e, operatorVertexToUse));
      }
      outEdges.remove(ov);
      outEdges.computeIfAbsent(operatorVertexToUse, v -> new HashSet<>()).addAll(redirectedEdges);
    }
    mergeAndAddToBuilder(dependencyFailedOperatorVertices, builder, dag, outEdges);
  }

  /**
   * Creates a copy of {@code edge} that leaves {@code newSrc} instead of the original source.
   *
   * <p>NOTE(review): only the communication pattern, encoder and decoder properties are copied; any other execution
   * property set on {@code edge} is not carried over — confirm this is intended.
   *
   * @param edge edge to copy.
   * @param newSrc new source vertex.
   * @return the new edge.
   */
  private static IREdge redirectSource(final IREdge edge, final OperatorVertex newSrc) {
    // Presence of the communication pattern is required by @Requires(CommunicationPatternProperty.class).
    final IREdge newIrEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
      newSrc, edge.getDst());
    final Optional<EncoderFactory> encoderProperty = edge.getPropertyValue(EncoderProperty.class);
    encoderProperty.ifPresent(encoder -> newIrEdge.setProperty(EncoderProperty.of(encoder)));
    final Optional<DecoderFactory> decoderProperty = edge.getPropertyValue(DecoderProperty.class);
    decoderProperty.ifPresent(decoder -> newIrEdge.setProperty(DecoderProperty.of(decoder)));
    return newIrEdge;
  }
}