blob: 0dc9fe9b30d40cf4ae48ad2336e73b336d7d8b80 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.optimizer;
import net.jcip.annotations.NotThreadSafe;
import org.apache.nemo.common.Util;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.exception.CompileTimeOptimizationException;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CacheIDProperty;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.vertex.CachedSourceVertex;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.compiler.optimizer.pass.runtime.Message;
import org.apache.nemo.compiler.optimizer.policy.Policy;
import org.apache.nemo.compiler.optimizer.policy.XGBoostPolicy;
import org.apache.nemo.conf.JobConf;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.common.message.ClientRPC;
import org.apache.reef.tang.annotations.Parameter;
import javax.inject.Inject;
import java.util.*;
import java.util.stream.Collectors;
/**
* The {@link Optimizer} implementation that manages the optimization of submitted IR DAGs through {@link Policy}s.
* An instance of this class resides in the driver.
*/
@NotThreadSafe
public final class NemoOptimizer implements Optimizer {
private final String dagDirectory;
private final Policy optimizationPolicy;
private final String environmentTypeStr;
private final String executorInfoContents;
private final ClientRPC clientRPC;
private final Map<UUID, Integer> cacheIdToParallelism = new HashMap<>();
private int irDagCount = 0;
/**
* @param dagDirectory to store JSON representation of intermediate DAGs.
* @param policyName the name of the optimization policy.
* @param environmentTypeStr the environment type of the workload to optimize the DAG for.
* @param executorInfoContents the string of the information of the executors provided.
* @param clientRPC the RPC channel to communicate with the client.
*/
@Inject
private NemoOptimizer(@Parameter(JobConf.DAGDirectory.class) final String dagDirectory,
@Parameter(JobConf.OptimizationPolicy.class) final String policyName,
@Parameter(JobConf.EnvironmentType.class) final String environmentTypeStr,
@Parameter(JobConf.ExecutorJSONContents.class) final String executorInfoContents,
final ClientRPC clientRPC) {
this.dagDirectory = dagDirectory;
this.environmentTypeStr = OptimizerUtils.filterEnvironmentTypeString(environmentTypeStr);
this.executorInfoContents = executorInfoContents;
this.clientRPC = clientRPC;
try {
optimizationPolicy = (Policy) Class.forName(policyName).newInstance();
if (policyName == null) {
throw new CompileTimeOptimizationException("A policy name should be specified.");
}
} catch (final Exception e) {
throw new CompileTimeOptimizationException(e);
}
}
@Override
public IRDAG optimizeAtCompileTime(final IRDAG dag) {
final String irDagId = "ir-" + irDagCount++ + "-";
dag.storeJSON(dagDirectory, irDagId, "IR before optimization");
final IRDAG optimizedDAG;
final Map<UUID, IREdge> cacheIdToEdge = new HashMap<>();
// Handle caching first.
final IRDAG cacheFilteredDag = handleCaching(dag, cacheIdToEdge);
if (!cacheIdToEdge.isEmpty()) {
cacheFilteredDag.storeJSON(dagDirectory, irDagId + "FilterCache",
"IR after cache filtering");
}
// Conduct compile-time optimization.
beforeCompileTimeOptimization(dag, optimizationPolicy);
optimizedDAG = optimizationPolicy.runCompileTimeOptimization(cacheFilteredDag, dagDirectory);
optimizedDAG
.storeJSON(dagDirectory, irDagId + optimizationPolicy.getClass().getSimpleName(),
"IR optimized for " + optimizationPolicy.getClass().getSimpleName());
// Update cached list.
// TODO #191: Report the actual state of cached data to optimizer.
// Now we assume that the optimized dag always run properly.
cacheIdToEdge.forEach((cacheId, edge) -> {
if (!cacheIdToParallelism.containsKey(cacheId)) {
cacheIdToParallelism.put(
cacheId, optimizedDAG
.getVertexById(edge.getDst().getId()).getPropertyValue(ParallelismProperty.class)
.orElseThrow(() -> new RuntimeException("No parallelism on an IR vertex.")));
}
});
// Return optimized dag
return optimizedDAG;
}
@Override
public IRDAG optimizeAtRunTime(final IRDAG dag, final Message message) {
return optimizationPolicy.runRunTimeOptimizations(dag, message);
}
/**
* Operations to be done prior to the Compile-Time Optimizations.
* TODO #371: This part can be reduced by not using the client RPC and sending the python script to the driver
* itself later on.
*
* @param dag the DAG to process.
* @param policy the optimization policy to optimize the DAG with.
*/
private void beforeCompileTimeOptimization(final IRDAG dag, final Policy policy) {
dag.recordExecutorInfo(Util.parseResourceSpecificationString(this.executorInfoContents));
if (policy instanceof XGBoostPolicy) {
clientRPC.send(ControlMessage.DriverToClientMessage.newBuilder()
.setType(ControlMessage.DriverToClientMessageType.LaunchOptimization)
.setOptimizationType(ControlMessage.OptimizationType.XGBoost)
.setDataCollected(ControlMessage.DataCollectMessage.newBuilder()
.setData(dag.irDAGSummary() + this.environmentTypeStr)
.build())
.build());
}
}
/**
* Handle data caching.
* At first, it search the edges having cache ID from the given dag and update them to the given map.
* Then, if some edge of a submitted dag is annotated as "cached" and the data was produced already,
* the part of the submitted dag which produces the cached data will be cropped and the last vertex
* before the cached edge will be replaced with a cached data source vertex.
* This cached edge will be detected and appended to the original dag in scheduler.
*
* @param dag the dag to handle.
* @param cacheIdToEdge the map from cache ID to edge to update.
* @return the cropped dag regarding to caching.
*/
private IRDAG handleCaching(final IRDAG dag, final Map<UUID, IREdge> cacheIdToEdge) {
dag.topologicalDo(irVertex ->
dag.getIncomingEdgesOf(irVertex).forEach(
edge -> edge.getPropertyValue(CacheIDProperty.class).
ifPresent(cacheId -> cacheIdToEdge.put(cacheId, edge))
));
if (cacheIdToEdge.isEmpty()) {
return dag;
} else {
final DAGBuilder<IRVertex, IREdge> filteredDagBuilder = new DAGBuilder<>();
final List<IRVertex> sinkVertices = dag.getVertices().stream()
.filter(irVertex -> dag.getOutgoingEdgesOf(irVertex).isEmpty())
.collect(Collectors.toList());
sinkVertices.forEach(filteredDagBuilder::addVertex); // Sink vertex cannot be cached already.
sinkVertices.forEach(sinkVtx -> addNonCachedVerticesAndEdges(dag, sinkVtx, filteredDagBuilder));
return new IRDAG(filteredDagBuilder.buildWithoutSourceCheck());
}
}
/**
* Recursively add vertices and edges after cached edges to the dag builder in reversed order.
*
* @param dag the original dag to filter.
* @param irVertex the ir vertex to consider to add.
* @param builder the filtered dag builder.
*/
private void addNonCachedVerticesAndEdges(final IRDAG dag,
final IRVertex irVertex,
final DAGBuilder<IRVertex, IREdge> builder) {
if (irVertex.getPropertyValue(IgnoreSchedulingTempDataReceiverProperty.class).orElse(false)
&& dag.getIncomingEdgesOf(irVertex).stream()
.filter(irEdge -> irEdge.getPropertyValue(CacheIDProperty.class).isPresent())
.anyMatch(irEdge -> cacheIdToParallelism
.containsKey(irEdge.getPropertyValue(CacheIDProperty.class).get()))) {
builder.removeVertex(irVertex); // Ignore ghost vertex which was cached once.
return;
}
dag.getIncomingEdgesOf(irVertex).stream()
.forEach(edge -> {
final Optional<UUID> cacheId = dag.getOutgoingEdgesOf(edge.getSrc()).stream()
.filter(edgeToFilter -> edgeToFilter.getPropertyValue(CacheIDProperty.class).isPresent())
.map(edgeToMap -> edgeToMap.getPropertyValue(CacheIDProperty.class).get())
.findFirst();
if (cacheId.isPresent() && cacheIdToParallelism.get(cacheId.get()) != null) { // Cached already.
// Replace the vertex emitting cached edge with a cached source vertex.
final IRVertex cachedDataRelayVertex = new CachedSourceVertex(cacheIdToParallelism.get(cacheId.get()));
cachedDataRelayVertex.setPropertyPermanently(
ParallelismProperty.of(cacheIdToParallelism.get(cacheId.get())));
builder.addVertex(cachedDataRelayVertex);
final IREdge newEdge = new IREdge(
edge.getPropertyValue(CommunicationPatternProperty.class)
.orElseThrow(() -> new RuntimeException("No communication pattern on an ir edge")),
cachedDataRelayVertex,
irVertex);
edge.copyExecutionPropertiesTo(newEdge);
newEdge.setProperty(CacheIDProperty.of(cacheId.get()));
builder.connectVertices(newEdge);
// Stop the recursion for this vertex.
} else {
final IRVertex srcVtx = edge.getSrc();
builder.addVertex(srcVtx);
builder.connectVertices(edge);
addNonCachedVerticesAndEdges(dag, srcVtx, builder);
}
});
}
}