Using Machine Learning for query optimization in Apache Wayang (incubating)

Apache Wayang (incubating) can be customized with concrete implementations of the EstimatableCost interface in order to optimize for a desired metric. The implementation can be enabled by providing it to a Configuration.

public class CustomEstimatableCost implements EstimatableCost {
    /* Provide concrete implementations to match desired cost function(s)
     * by implementing the interface in this class.
     * Once registered via Configuration#setCostModel, the Wayang optimizer
     * will consult this implementation when comparing candidate plans.
     */
}
public class WordCount {
    public static void main(String[] args) {
        /* Build a configuration and register the custom cost model so the
         * optimizer uses it instead of the default cost estimates. */
        Configuration configuration = new Configuration();
        configuration.setCostModel(new CustomEstimatableCost());
        /* Create the Wayang context and declare which platforms the
         * optimizer may choose between. */
        WayangContext context = new WayangContext(configuration)
                .withPlugin(Java.basicPlugin())
                .withPlugin(Spark.basicPlugin());
        /*... omitted */
    }
}

In combination with an encoding scheme and a third-party package to load ML models, the following example shows how to predict the runtimes of query execution plans in Apache Wayang (incubating):

import org.apache.wayang.core.optimizer.costs.EstimatableCost;
import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory;
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
import org.apache.wayang.core.optimizer.enumeration.LoopImplementation;
import org.apache.wayang.core.optimizer.enumeration.PlanImplementation;
import org.apache.wayang.core.platform.Junction;
import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
import org.apache.wayang.core.plan.executionplan.ExecutionStage;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.ml.encoding.OneHotEncoder;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.ml.OrtMLModel;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Set;

/**
 * An {@link EstimatableCost} implementation that predicts plan cost with a
 * pre-trained ML model: the plan is one-hot encoded and fed to an ONNX model
 * loaded via {@link OrtMLModel}. Any failure during prediction falls back to
 * a zero estimate so that plan enumeration can proceed (best-effort design).
 */
public class MLCost implements EstimatableCost {

    @Override public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    /** Factory so the optimizer can create fresh cost instances on demand. */
    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    /**
     * Runs the ML model on the encoded plan.
     *
     * @param plan the plan implementation to score
     * @return the predicted cost, or 0 if encoding/inference fails
     */
    private double predictOrZero(PlanImplementation plan) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            // Best-effort: a failed prediction must not abort enumeration.
            return 0;
        }
    }

    /**
     * Runs the ML model on the encoded plan, wrapped as an exact interval.
     *
     * @param plan the plan implementation to score
     * @return an exact interval around the prediction, or
     *         {@link ProbabilisticDoubleInterval#zero} if inference fails
     */
    private ProbabilisticDoubleInterval predictIntervalOrZero(PlanImplementation plan) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return ProbabilisticDoubleInterval.ofExactly(
                model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** {@inheritDoc} The ML model yields a single prediction, so the
     *  overhead flag does not change the result. */
    @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return predictIntervalOrZero(plan);
    }

    /** {@inheritDoc} Same model prediction as {@link #getEstimate}. */
    @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return predictIntervalOrZero(plan);
    }

    /** Returns a squashed cost estimate. */
    @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return predictOrZero(plan);
    }

    /** Returns a squashed parallel cost estimate (same model prediction). */
    @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return predictOrZero(plan);
    }

    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
        // NOTE(review): the `operator` argument is ignored — the model scores
        // the whole plan; confirm this matches the interface contract.
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));

        return new Tuple<>(intervalList, doubleList);
    }

    /**
     * Picks the plan with the lowest squashed cost estimate.
     *
     * @throws WayangException if {@code executionPlans} is empty
     */
    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        return executionPlans.stream()
                .min(Comparator.comparingDouble(PlanImplementation::getSquashedCostEstimate))
                .orElseThrow(() -> new WayangException("Could not find an execution plan."));
    }
}

Third-party packages such as OnnxRuntime can be used to load pre-trained .onnx files that contain desired ML models.

import org.apache.wayang.core.api.Configuration;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import java.util.Vector;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.Map;
import java.util.function.BiFunction;

/**
 * Thin singleton wrapper around an ONNX Runtime session that serves cost
 * predictions from a pre-trained model. The model file path is read from the
 * {@code wayang.ml.model.file} configuration property.
 *
 * <p>Thread-safety: {@link #getInstance} is synchronized so concurrent
 * callers cannot race and create (and leak) multiple sessions.</p>
 */
public class OrtMLModel {

    private static OrtMLModel INSTANCE;

    private OrtSession session;
    private OrtEnvironment env;

    // Reused across calls; cleared after each run so no closed tensor lingers.
    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
    private final Set<String> requestedOutputs = new HashSet<>();

    /**
     * Returns the lazily-created singleton. Synchronized to prevent two
     * threads from each constructing a session for the same model.
     */
    public static synchronized OrtMLModel getInstance(Configuration configuration) throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }

        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    /**
     * Creates the ORT environment and session if not already present;
     * idempotent for an already-loaded model.
     */
    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }

        if (this.session == null) {
            this.session = env.createSession(filePath, new OrtSession.SessionOptions());
        }
    }

    /**
     * Closes the session and environment. Handles are nulled so a later
     * {@link #loadModel} can recreate them instead of using closed objects.
     */
    public void closeSession() throws OrtException {
        if (this.session != null) {
            this.session.close();
            this.session = null;
        }
        if (this.env != null) {
            this.env.close();
            this.env = null;
        }
    }

    /**
     * Runs the model on a one-hot-encoded plan vector.
     *
     * @param encodedVector the encoded plan features (model input "input")
     * @return the predicted cost, or {@code Double.NaN} if the model output
     *         cannot be unwrapped
     * @throws OrtException if tensor creation or inference itself fails
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        // try-with-resources: OnnxTensor holds native memory and must be
        // closed; the original implementation leaked it on every call.
        try (OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector)) {
            this.inputMap.put("input", tensor);
            this.requestedOutputs.add("output");

            try (Result r = session.run(inputMap, requestedOutputs)) {
                try {
                    return ((double[]) r.get("output").get().getValue())[0];
                } catch (OrtException e) {
                    // Unwrap failure only — inference errors still propagate.
                    return Double.NaN;
                }
            } finally {
                // Drop the soon-to-be-closed tensor so the reused map never
                // offers a stale native handle to a subsequent run.
                this.inputMap.clear();
            }
        }
    }
}