blob: 41e4fa0c97e7977a675dc972832232511766e3a2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.exception.*;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
import org.apache.nemo.common.ir.executionproperty.ExecutionProperty;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.compiler.optimizer.OptimizerUtils;
import org.apache.nemo.runtime.common.metric.MetricUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
/**
 * Pass for applying XGBoost optimizations.
 * <p>
 * 1. The pass first triggers the client to run the XGBoost script, located under the `ml` python package.
 * 2. The client runs the script, which trains a tree model using the metrics collected before, and
 * then predicts the 'knobs' that minimize the JCT based on the weights of the leaves (JCT).
 * 3. The pass receives the results, including the direction in which each of the knobs should be optimized,
 * and reconstructs the execution properties in their tuned form.
 * 4. The newly reconstructed execution properties are injected and the workload runs after the optimization.
 * <p>
 * Within this class, the results arrive as a JSON string handed over through {@link #pushMessage(String)};
 * {@link #apply(IRDAG)} blocks until such a message is available.
 */
@Annotates()
public final class XGBoostPass extends AnnotatingPass {
  private static final Logger LOG = LoggerFactory.getLogger(XGBoostPass.class.getName());
  private static final BlockingQueue<String> MESSAGE_QUEUE = new LinkedBlockingQueue<>();
  // Reused across invocations instead of building a new mapper for every call.
  private static final ObjectMapper MAPPER = new ObjectMapper();
  // Formatted into 9 digits: 0:vertex/edge 1-5:ID 5-9:EP Index.
  private static final int FEATURE_LENGTH = 9;

  /**
   * Default constructor.
   */
  public XGBoostPass() {
    super(XGBoostPass.class);
  }

  /**
   * Waits for a message on the message queue, parses it as a JSON list of string-to-string maps,
   * and applies every entry whose "feature" value has the expected length to the DAG.
   *
   * @param dag the DAG to tune.
   * @return the same DAG instance, with the tuned execution properties applied where possible.
   * @throws CompileTimeOptimizationException wrapping any exception other than
   *                                          {@link InvalidParameterException} raised while tuning.
   */
  @Override
  public IRDAG apply(final IRDAG dag) {
    try {
      final String message = XGBoostPass.takeMessage();
      LOG.info("Received message from the client: {}", message);

      if (message.isEmpty()) {
        LOG.info("No optimization included in the message. Returning the original dag.");
        return dag;
      }

      final List<Map<String, String>> listOfMap =
        MAPPER.readValue(message, new TypeReference<List<Map<String, String>>>() {
        });
      for (final Map<String, String> m : listOfMap) {
        final String feature = m.get("feature");
        // Entries without a feature, or with one of an unexpected length, are skipped.
        if (feature != null && feature.length() == FEATURE_LENGTH) {
          applyTuning(dag, feature, m);
        }
      }
    } catch (final InvalidParameterException e) {
      LOG.warn(e.getMessage());
      return dag;
    } catch (final Exception e) {
      throw new CompileTimeOptimizationException(e);
    }
    return dag;
  }

  /**
   * Applies a single tuning entry to the vertex or edge it refers to.
   *
   * @param dag     the DAG that holds the target vertex or edge.
   * @param feature the "feature" value of the entry.
   * @param m       the entry itself; its "split" and "val" values are parsed as doubles.
   */
  private static void applyTuning(final IRDAG dag, final String feature, final Map<String, String> m) {
    final Pair<String, Integer> idAndEPKey = OptimizerUtils.stringToIdAndEPKeyIndex(feature);
    final String id = idAndEPKey.left();
    LOG.info("Tuning: {} of {} should be {} than {}",
      idAndEPKey.right(), id, m.get("val"), m.get("split"));
    final ExecutionProperty<? extends Serializable> newEP = MetricUtils.keyAndValueToEP(idAndEPKey.right(),
      Double.valueOf(m.get("split")), Double.valueOf(m.get("val")));

    try {
      if (id.startsWith("vertex")) {
        tuneVertex(dag, id, newEP);
      } else if (id.startsWith("edge")) {
        tuneEdge(dag, id, newEP);
      }
    } catch (final IllegalVertexOperationException | IllegalEdgeOperationException e) {
      // Best effort: a knob that cannot be applied is skipped, but the failure is reported.
      LOG.warn("Skipping the tuning of {}.", id, e);
    }
  }

  /**
   * Sets the new property on the vertex, and puts the previous one back if the DAG integrity check fails.
   *
   * @param dag   the DAG that holds the vertex.
   * @param id    the ID of the vertex.
   * @param newEP the property to set; it is cast to a vertex execution property.
   */
  private static void tuneVertex(final IRDAG dag, final String id,
                                 final ExecutionProperty<? extends Serializable> newEP) {
    final IRVertex v = dag.getVertexById(id);
    final VertexExecutionProperty<?> originalEP = v.getExecutionProperties().stream()
      .filter(ep -> ep.getClass().isAssignableFrom(newEP.getClass())).findFirst().orElse(null);
    v.setProperty((VertexExecutionProperty<?>) newEP);
    if (!dag.checkIntegrity().isPassed()) {
      // NOTE(review): originalEP is null when the vertex had no such property before - confirm that
      // setProperty tolerates null, or remove the newly set property here instead.
      v.setProperty(originalEP);
    }
  }

  /**
   * Sets the new property on the edge, and puts the previous one back if the DAG integrity check fails.
   *
   * @param dag   the DAG that holds the edge.
   * @param id    the ID of the edge.
   * @param newEP the property to set; it is cast to an edge execution property.
   */
  private static void tuneEdge(final IRDAG dag, final String id,
                               final ExecutionProperty<? extends Serializable> newEP) {
    final IREdge e = dag.getEdgeById(id);
    final EdgeExecutionProperty<?> originalEP = e.getExecutionProperties().stream()
      .filter(ep -> ep.getClass().isAssignableFrom(newEP.getClass())).findFirst().orElse(null);
    e.setProperty((EdgeExecutionProperty<?>) newEP);
    if (!dag.checkIntegrity().isPassed()) {
      // NOTE(review): originalEP is null when the edge had no such property before - confirm that
      // setProperty tolerates null, or remove the newly set property here instead.
      e.setProperty(originalEP);
    }
  }

  /**
   * Pushes a message to the message queue.
   *
   * @param message the message to push.
   */
  public static void pushMessage(final String message) {
    MESSAGE_QUEUE.add(message);
  }

  /**
   * Takes a message from the blocking queue, waiting until one is available.
   *
   * @return the message from the blocking queue.
   * @throws MetricException if the thread is interrupted while waiting; the interrupt flag is restored first.
   */
  private static String takeMessage() {
    try {
      return MESSAGE_QUEUE.take();
    } catch (InterruptedException e) {
      // Restore interrupted state...
      Thread.currentThread().interrupt();
      throw new MetricException("Interrupted while waiting for message: " + e);
    }
  }
}