/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.wayang.core.optimizer.channels;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.optimizer.DefaultOptimizationContext;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.optimizer.OptimizationUtils;
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate;
import org.apache.wayang.core.optimizer.costs.TimeEstimate;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.core.plan.executionplan.ExecutionTask;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.plan.wayangplan.InputSlot;
import org.apache.wayang.core.plan.wayangplan.LoopHeadOperator;
import org.apache.wayang.core.plan.wayangplan.LoopSubplan;
import org.apache.wayang.core.plan.wayangplan.OutputSlot;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.Junction;
import org.apache.wayang.core.util.Bitmask;
import org.apache.wayang.core.util.OneTimeExecutable;
import org.apache.wayang.core.util.ReflectionUtils;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.core.util.WayangCollections;

/**
 * This graph contains a set of {@link ChannelConversion}s.
 */
public class ChannelConversionGraph {

    /**
     * Keeps track of the {@link ChannelConversion}s.
     */
    private final Map<ChannelDescriptor, List<ChannelConversion>> conversions = new HashMap<>();

    /**
     * Squashes a {@link ProbabilisticDoubleInterval} cost estimate into a single {@code double} value.
     */
    private final ToDoubleFunction<ProbabilisticDoubleInterval> costSquasher;

    /**
     * In the face of competing {@link Tree}s, picks the one to use for converting {@link Channel}s.
     */
    private final TreeSelectionStrategy treeSelectionStrategy;

    private static final Logger logger = LogManager.getLogger(ChannelConversionGraph.class);

    /**
     * Creates a new instance.
     * <p>The cost squasher and all {@link ChannelConversion}s are obtained from the providers of the
     * {@code configuration}. The {@link TreeSelectionStrategy} is instantiated reflectively from the class name
     * stored under the property {@code wayang.core.optimizer.channels.selection}; the binary name of
     * {@link CostbasedTreeSelectionStrategy} is passed as the fallback value for that lookup.</p>
     *
     * @param configuration describes how to configure the new instance
     */
    public ChannelConversionGraph(Configuration configuration) {
        this.costSquasher = configuration.getCostSquasherProvider().provide();
        configuration.getChannelConversionProvider().provideAll().forEach(this::add);
        // Take the binary name straight from the nested class. Deriving it from this.getClass() would name a
        // non-existent class for subclasses (and yield "null$..." for anonymous or local subclasses).
        final String treeSelectionStrategyClassName = configuration.getStringProperty(
                "wayang.core.optimizer.channels.selection",
                CostbasedTreeSelectionStrategy.class.getName()
        );
        this.treeSelectionStrategy = ReflectionUtils.instantiateDefault(treeSelectionStrategyClassName);
    }

    /**
     * Registers the given {@code channelConversion} in this instance. In graph terms, this adds an edge that
     * leaves the source {@link ChannelDescriptor} of the {@link ChannelConversion}.
     *
     * @param channelConversion the {@link ChannelConversion} to register
     */
    public void add(ChannelConversion channelConversion) {
        final ChannelDescriptor source = channelConversion.getSourceChannelDescriptor();
        this.getOrCreateChannelConversions(source).add(channelConversion);
    }

    /**
     * Provides the (mutable) list of registered {@link ChannelConversion}s that convert the given
     * {@code channelDescriptor}. If no such list exists yet, an empty one is created and registered first.
     *
     * @param channelDescriptor should be converted
     * @return the {@link ChannelConversion}s; never {@code null}
     */
    private List<ChannelConversion> getOrCreateChannelConversions(ChannelDescriptor channelDescriptor) {
        List<ChannelConversion> edges = this.conversions.get(channelDescriptor);
        if (edges == null) {
            edges = new ArrayList<>();
            this.conversions.put(channelDescriptor, edges);
        }
        return edges;
    }

    /**
     * Finds the minimum tree {@link Junction} (w.r.t. {@link TimeEstimate}s) that connects the given
     * {@link OutputSlot} to the {@code destInputSlots}. No already existing {@link Channel}s are taken into account.
     *
     * @param output              {@link OutputSlot} of an {@link ExecutionOperator} that should be consumed
     * @param destInputSlots      {@link InputSlot}s of {@link ExecutionOperator}s that should receive data from the {@code output}
     * @param optimizationContext describes the above mentioned {@link ExecutionOperator} key figures
     * @param isRequestBreakpoint whether a breakpoint-capable {@link Channel} should be inserted if possible
     * @return a {@link Junction} or {@code null} if none could be found
     */
    public Junction findMinimumCostJunction(OutputSlot<?> output,
                                            List<InputSlot<?>> destInputSlots,
                                            OptimizationContext optimizationContext,
                                            boolean isRequestBreakpoint) {
        // There are no open channels in this variant, hence the null argument.
        final ShortestTreeSearcher searcher =
                new ShortestTreeSearcher(output, null, destInputSlots, optimizationContext, isRequestBreakpoint);
        return searcher.getJunction();
    }

    /**
     * Finds the minimum tree {@link Junction} (w.r.t. {@link TimeEstimate}s) that connects the given
     * {@link OutputSlot} to the {@code destInputSlots}, taking up the search at already existing {@link Channel}s.
     * This variant never requests a breakpoint.
     *
     * @param output              {@link OutputSlot} of an {@link ExecutionOperator} that should be consumed
     * @param openChannels        existing {@link Channel}s that must be part of the tree or {@code null}
     * @param destInputSlots      {@link InputSlot}s of {@link ExecutionOperator}s that should receive data from the {@code output}
     * @param optimizationContext describes the above mentioned {@link ExecutionOperator} key figures
     * @return a {@link Junction} or {@code null} if none could be found
     */
    public Junction findMinimumCostJunction(OutputSlot<?> output,
                                            Collection<Channel> openChannels,
                                            List<InputSlot<?>> destInputSlots,
                                            OptimizationContext optimizationContext) {
        final ShortestTreeSearcher searcher =
                new ShortestTreeSearcher(output, openChannels, destInputSlots, optimizationContext, false);
        return searcher.getJunction();
    }

    /**
     * Given two {@link Tree}s, selects the preferred one. A {@code null} argument always loses against the other
     * argument; only if both are present, the {@link #treeSelectionStrategy} is consulted.
     *
     * @param t1 the first {@link Tree} or {@code null}
     * @param t2 the second {@link Tree} or {@code null}
     * @return the preferred {@link Tree}; {@code null} only if both arguments are {@code null}
     */
    private Tree selectPreferredTree(Tree t1, Tree t2) {
        if (t1 != null && t2 != null) {
            return this.treeSelectionStrategy.select(t1, t2);
        }
        // At least one candidate is missing: hand back the other one (which might be null as well).
        return t1 == null ? t2 : t1;
    }

    /**
     * Merges several {@link Tree}s, which should have the same root but should be otherwise disjoint.
     * <p>Disjointness is verified by counting: if the union of the settled destination indices (or of the employed
     * {@link ChannelDescriptor}s, not counting the shared root) is smaller than the sum of the individual sizes,
     * then some element occurs in more than one {@link Tree} and the merge is rejected.</p>
     *
     * @param trees at least two {@link Tree}s to be merged
     * @return the merged {@link Tree} or {@code null} if the input {@code trees} could not be merged
     */
    private Tree mergeTrees(Collection<Tree> trees) {
        assert trees.size() >= 2;

        // Seed all accumulators with the first tree.
        final Iterator<Tree> treeIterator = trees.iterator();
        final Tree seedTree = treeIterator.next();
        final Bitmask mergedSettledIndices = new Bitmask(seedTree.settledDestinationIndices);
        int expectedSettledIndices = mergedSettledIndices.cardinality();
        final Set<ChannelDescriptor> mergedChannelDescriptors = new HashSet<>(seedTree.employedChannelDescriptors);
        int expectedChannelDescriptors = mergedChannelDescriptors.size();
        double mergedCosts = seedTree.costs;
        final TreeVertex mergedRoot = new TreeVertex(seedTree.root.channelDescriptor, seedTree.root.settledIndices);
        mergedRoot.copyEdgesFrom(seedTree.root);

        // Fold in the remaining trees one by one, bailing out as soon as an overlap is detected.
        while (treeIterator.hasNext()) {
            final Tree nextTree = treeIterator.next();

            // The settled destinations must not overlap.
            mergedSettledIndices.orInPlace(nextTree.settledDestinationIndices);
            expectedSettledIndices += nextTree.settledDestinationIndices.cardinality();
            if (expectedSettledIndices > mergedSettledIndices.cardinality()) {
                return null;
            }

            // Neither must the employed channel descriptors, except for the common root (hence the -1).
            mergedChannelDescriptors.addAll(nextTree.employedChannelDescriptors);
            expectedChannelDescriptors += nextTree.employedChannelDescriptors.size() - 1;
            if (expectedChannelDescriptors > mergedChannelDescriptors.size()) {
                return null;
            }

            mergedCosts += nextTree.costs;
            mergedRoot.copyEdgesFrom(nextTree.root);
        }

        // All checks passed, so assemble the combined tree.
        final Tree result = new Tree(mergedRoot, mergedSettledIndices);
        result.costs = mergedCosts;
        result.employedChannelDescriptors.addAll(mergedChannelDescriptors);
        return result;
    }

    /**
     * Designates a strategy of which conversion {@link Tree} to prefer when there are competing {@link Tree}s.
     */
    private interface TreeSelectionStrategy {

        /**
         * Select the preferred {@link Tree}.
         *
         * @param t1 the first {@link Tree} (not {@code null})
         * @param t2 the second {@link Tree} (not {@code null})
         * @return the preferred {@link Tree}
         */
        Tree select(Tree t1, Tree t2);

    }

    /**
     * Prefers the {@link Tree} with the lower costs. On equal costs, the first {@link Tree} wins.
     */
    public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrategy {

        @Override
        public Tree select(Tree t1, Tree t2) {
            if (t1.costs <= t2.costs) {
                return t1;
            }
            return t2;
        }

    }

    /**
     * Picks one of the two competing {@link Tree}s at random, regardless of their costs.
     */
    public static class RandomTreeSelectionStrategy implements TreeSelectionStrategy {

        /**
         * Source of randomness for the coin flip in {@link #select(Tree, Tree)}.
         */
        private final Random random = new Random();

        @Override
        public Tree select(Tree t1, Tree t2) {
            if (this.random.nextBoolean()) {
                return t1;
            }
            return t2;
        }

    }

    /**
     * Finds the shortest tree between the {@link #sourceChannelDescriptor} and the {@link #destChannelDescriptorSets}.
     */
    private class ShortestTreeSearcher extends OneTimeExecutable {

        /**
         * The {@link OutputSlot} that should be converted.
         */
        private final OutputSlot<?> sourceOutput;

        /**
         * {@link ChannelDescriptor} for the {@link Channel} produced by the {@link #sourceOutput}.
         */
        private final ChannelDescriptor sourceChannelDescriptor;

        /**
         * Describes the number of data quanta that are presumably converted.
         */
        private final CardinalityEstimate cardinality;

        /**
         * How often the conversion is presumably performed.
         */
        private final int numExecutions;

        /**
         * {@link ChannelDescriptor}s of {@link #existingChannels} that are allowed to have new consumers.
         */
        private final Set<ChannelDescriptor> openChannelDescriptors;

        /**
         * Maps {@link ChannelDescriptor}s to already existing {@link Channel}s.
         */
        private final Map<ChannelDescriptor, Channel> existingChannels;

        /**
         * Maps destination {@link InputSlot}s to {@link #existingChannels} that are their destination.
         */
        private final Map<InputSlot<?>, Channel> existingDestinationChannels;

        private final Bitmask existingDestinationChannelIndices,
                absentDestinationChannelIndices,
                allDestinationChannelIndices = new Bitmask();

        private final Map<ChannelDescriptor, Bitmask> reachableExistingDestinationChannelIndices;

        /**
         * {@link InputSlot}s that should be served by the {@link #sourceOutput}.
         */
        private final List<InputSlot<?>> destInputs;

        /**
         * Supported {@link ChannelDescriptor}s for each of the {@link #destInputs}.
         */
        private final List<Set<ChannelDescriptor>> destChannelDescriptorSets;

        /**
         * The input {@link OptimizationContext} (and a write-copy, respectively).
         */
        private final OptimizationContext optimizationContext, optimizationContextCopy;

        /**
         * Whether the {@link #result} should contain a breakpoint-capable {@link Channel}.
         */
        private final boolean isRequestBreakpoint;

        /**
         * Maps kernelized {@link Set}s of possible input {@link ChannelDescriptor}s to destination {@link InputSlot}s via
         * their respective indices in {@link #destInputs}.
         */
        private Map<Set<ChannelDescriptor>, Bitmask> kernelDestChannelDescriptorSetsToIndices;

        /**
         * Maps specific input {@link ChannelDescriptor}s to applicable destination {@link InputSlot}s via
         * their respective indices in {@link #destInputs}.
         */
        private Map<ChannelDescriptor, Bitmask> kernelDestChannelDescriptorsToIndices;

        /**
         * Caches cost estimates for {@link ChannelConversion}s.
         */
        private Map<ChannelConversion, Double> conversionCostCache = new HashMap<>();

        /**
         * Caches whether {@link ChannelConversion}s should be filtered.
         */
        private Map<ChannelConversion, Boolean> conversionFilterCache = new HashMap<>();

        /**
         * Caches the result of {@link #getJunction()}.
         */
        private Junction result = null;

        /**
         * Create a new instance.
         * <p>Besides storing the parameters, this determines the source {@link ChannelDescriptor}, registers any
         * already existing {@link Channel}s (if non-empty {@code openChannels} are given), resolves and kernelizes the
         * {@link ChannelDescriptor}s supported by the {@code destInputs}, and finally separates the destinations
         * into those that are already served by an existing {@link Channel} and those that are still absent.</p>
         *
         * @param sourceOutput        provides a {@link Channel} that should be converted
         * @param openChannels        existing {@link Channel} derived from {@code sourceOutput} and where we must take up
         *                            the search; or {@code null}
         * @param destInputs          that consume the converted {@link Channel}(s)
         * @param optimizationContext provides optimization info
         * @param isRequestBreakpoint whether a breakpoint-capable {@link Channel} is requested; this is only honored
         *                            if {@code openChannels} is {@code null}
         */
        private ShortestTreeSearcher(OutputSlot<?> sourceOutput,
                                     Collection<Channel> openChannels,
                                     List<InputSlot<?>> destInputs,
                                     OptimizationContext optimizationContext,
                                     boolean isRequestBreakpoint) {

            // Store relevant variables.
            this.isRequestBreakpoint = isRequestBreakpoint && openChannels == null; // Breakpoints are not requestable when open channels are passed.
            this.optimizationContext = optimizationContext;
            // Keep a separate context derived from the input one (the "write-copy").
            this.optimizationContextCopy = new DefaultOptimizationContext(this.optimizationContext);
            this.sourceOutput = sourceOutput;
            this.destInputs = destInputs;
            // NB: An empty (but non-null) collection is treated like absent open channels from here on.
            final boolean isOpenChannelsPresent = openChannels != null && !openChannels.isEmpty();

            // Figure out the optimization info via the sourceOutput.
            final ExecutionOperator outputOperator = (ExecutionOperator) this.sourceOutput.getOwner();
            final OptimizationContext.OperatorContext operatorContext = optimizationContext.getOperatorContext(outputOperator);
            assert operatorContext != null : String.format("Optimization info for %s missing.", outputOperator);
            this.cardinality = operatorContext.getOutputCardinality(this.sourceOutput.getIndex());
            this.numExecutions = operatorContext.getNumExecutions();

            // Figure out, if a part of the conversion is already in place and initialize accordingly.

            if (isOpenChannelsPresent) {
                // Take any open channel and trace it back to the source channel, i.e., the one whose producer slot
                // is (identical to) the sourceOutput.
                Channel existingChannel = WayangCollections.getAny(openChannels);
                while (existingChannel.getProducerSlot() != sourceOutput) {
                    existingChannel = OptimizationUtils.getPredecessorChannel(existingChannel);
                }
                final Channel sourceChannel = existingChannel;
                this.sourceChannelDescriptor = sourceChannel.getDescriptor();

                // Now traverse down-stream to find already reached destinations.
                // NB: These collections must be initialized before #collectExistingChannels fills them.
                this.existingChannels = new HashMap<>();
                this.existingDestinationChannels = new HashMap<>(4);
                this.existingDestinationChannelIndices = new Bitmask();

                this.collectExistingChannels(sourceChannel);
                // Remember the descriptors of those channels that may take on new consumers.
                this.openChannelDescriptors = new HashSet<>(openChannels.size());
                for (Channel openChannel : openChannels) {
                    this.openChannelDescriptors.add(openChannel.getDescriptor());
                }
            } else {
                // Nothing exists yet: the source channel descriptor is the one declared by the producing operator.
                this.sourceChannelDescriptor = outputOperator.getOutputChannelDescriptor(this.sourceOutput.getIndex());
                this.existingChannels = Collections.emptyMap();
                this.existingDestinationChannels = Collections.emptyMap();
                this.existingDestinationChannelIndices = new Bitmask();
                this.openChannelDescriptors = Collections.emptySet();
            }


            // Set up the destinations.
            // NB: #resolveSupportedChannels reads #existingDestinationChannels, which therefore is set up above.
            this.destChannelDescriptorSets = WayangCollections.map(destInputs, this::resolveSupportedChannels);
            assert this.destChannelDescriptorSets.stream().noneMatch(Collection::isEmpty);
            this.kernelizeChannelRequests();

            if (isOpenChannelsPresent) {
                // For each already reached destination channel, walk upstream to the source channel and record for
                // every channel descriptor on that path which existing destinations are reachable through it.
                this.reachableExistingDestinationChannelIndices = new HashMap<>();
                for (Channel existingDestinationChannel : this.existingDestinationChannels.values()) {
                    // Destination indices that accept this channel's descriptor and are already served.
                    final Bitmask channelIndices = this.kernelDestChannelDescriptorsToIndices
                            .get(existingDestinationChannel.getDescriptor())
                            .and(this.existingDestinationChannelIndices);
                    while (true) {
                        this.reachableExistingDestinationChannelIndices.compute(
                                existingDestinationChannel.getDescriptor(),
                                (k, v) -> v == null ? new Bitmask(channelIndices) : v.orInPlace(channelIndices)
                        );
                        // Stop once the source channel descriptor has been recorded, too.
                        if (existingDestinationChannel.getDescriptor().equals(this.sourceChannelDescriptor)) break;
                        existingDestinationChannel = OptimizationUtils.getPredecessorChannel(existingDestinationChannel);
                    }
                }
            } else {
                this.reachableExistingDestinationChannelIndices = Collections.emptyMap();
            }

            // The absent destinations are all destinations that are not yet served by an existing channel.
            // NB: #allDestinationChannelIndices has been populated by #kernelizeChannelRequests.
            this.absentDestinationChannelIndices = this.allDestinationChannelIndices
                    .andNot(this.existingDestinationChannelIndices);
        }

        /**
         * Traverses the given {@link Channel} downstream and registers every visited {@link Channel} by its
         * {@link ChannelDescriptor}. Whenever a {@link Channel} feeds an {@link InputSlot} of a non-auxiliary
         * {@link ExecutionOperator}, that {@link InputSlot} is recorded as an already reached destination. The
         * output {@link Channel}s of auxiliary {@link ExecutionOperator}s are followed recursively.
         *
         * @param channel the {@link Channel} to traverse from
         * @see #existingChannels
         * @see #existingDestinationChannels
         */
        private void collectExistingChannels(Channel channel) {
            this.existingChannels.put(channel.getDescriptor(), channel);
            for (ExecutionTask consumerTask : channel.getConsumers()) {
                if (consumerTask.getOperator().isAuxiliary()) {
                    // Auxiliary operator: keep following its (non-null) output channels.
                    for (Channel downstreamChannel : consumerTask.getOutputChannels()) {
                        if (downstreamChannel != null) {
                            this.collectExistingChannels(downstreamChannel);
                        }
                    }
                    continue;
                }

                // Regular consumer: its input slot is a destination that is already served by the channel.
                final InputSlot<?> servedInput = consumerTask.getInputSlotFor(channel);
                this.existingDestinationChannels.put(servedInput, channel);

                // Locate the input slot among the destinations (by identity) and flag its index.
                int position = 0;
                while (this.destInputs.get(position) != servedInput) {
                    position++;
                }
                this.existingDestinationChannelIndices.set(position);
            }
        }

        /**
         * Provides the {@link Junction} according to the initialization parameters. The search is triggered via
         * {@link #tryExecute()} and its outcome is kept in {@link #result}.
         *
         * @return the {@link Junction} or {@code null} if none could be found
         */
        public Junction getJunction() {
            // NOTE(review): presumably OneTimeExecutable runs #doExecute only on the first call -- confirm.
            this.tryExecute();
            return this.result;
        }

        /**
         * Finds the supported {@link ChannelDescriptor}s for the given {@link InputSlot}.
         * <ul>
         * <li>If the {@link InputSlot} is already served by an existing {@link Channel}, only the
         * {@link ChannelDescriptor} of that {@link Channel} is eligible.</li>
         * <li>If it is a "loop invariant" {@link InputSlot}, only reusable {@link ChannelDescriptor}s are
         * eligible.</li>
         * <li>Otherwise, all {@link ChannelDescriptor}s supported by the owner of the {@link InputSlot} are
         * eligible.</li>
         * </ul>
         *
         * @param input for which supported {@link ChannelDescriptor}s are requested
         * @return all eligible {@link ChannelDescriptor}s
         */
        private Set<ChannelDescriptor> resolveSupportedChannels(final InputSlot<?> input) {
            // An already served input is pinned to the descriptor of the channel that serves it.
            final Channel servingChannel = this.existingDestinationChannels.get(input);
            if (servingChannel != null) {
                return Collections.singleton(servingChannel.getDescriptor());
            }

            final ExecutionOperator consumer = (ExecutionOperator) input.getOwner();
            final List<ChannelDescriptor> candidates = consumer.getSupportedInputChannels(input.getIndex());
            if (!input.isLoopInvariant()) {
                return WayangCollections.asSet(candidates);
            }

            // Loop input is needed in several iterations and must therefore be reusable.
            return candidates.stream()
                    .filter(ChannelDescriptor::isReusable)
                    .collect(Collectors.toSet());
        }

        @Override
        protected void doExecute() {
            // Search for a conversion tree, starting from the root vertex.
            final Tree conversionTree = this.searchTree();
            if (conversionTree == null) {
                logger.debug("Could not connect {} with {}.", this.sourceOutput, this.destInputs);
                return;
            }
            this.createJunction(conversionTree);
        }

        /**
         * Builds the indices over the requested {@link ChannelDescriptor}s of the destinations:
         * <ol>
         * <li>Equal {@link ChannelDescriptor} sets of different destinations are merged into a single entry that
         * maps to the indices of all these destinations; {@link #allDestinationChannelIndices} is populated
         * along the way.</li>
         * <li>In {@link ChannelDescriptor} sets that are requested by at least two destinations and that contain
         * exactly one non-reusable {@link ChannelDescriptor}, the non-reusable one is ruled out.</li>
         * <li>Finally, the (kernelized) sets are flattened into a mapping from single {@link ChannelDescriptor}s
         * to destination indices.</li>
         * </ol>
         *
         * @see #kernelDestChannelDescriptorSetsToIndices
         * @see #kernelDestChannelDescriptorsToIndices
         */
        private void kernelizeChannelRequests() {
            // Check if the Junction enters a loop "from the side", i.e., across multiple iterations.
            // CHECK: Since we rule out non-reusable Channels in #resolveSupportedChannels, do we really need this?
//            final LoopSubplan outputLoop = this.sourceOutput.getOwner().getInnermostLoop();
//            final int outputLoopDepth = this.sourceOutput.getOwner().getLoopStack().size();
//            boolean isSideEnterLoop = this.destInputs.stream().anyMatch(input ->
//                    !input.getOwner().isLoopHead() &&
//                            (input.getOwner().getLoopStack().size() > outputLoopDepth ||
//                                    (input.getOwner().getLoopStack().size() == outputLoopDepth && input.getOwner().getInnermostLoop() != outputLoop)
//                            )
//            );


            // Index the Channel requests by their InputSlots, thereby merging equal ones: Each distinct set of
            // ChannelDescriptors maps to the indices of all destinations that request exactly this set.
            int index = 0;
            this.kernelDestChannelDescriptorSetsToIndices = new HashMap<>(this.destChannelDescriptorSets.size());
            for (Set<ChannelDescriptor> destChannelDescriptorSet : this.destChannelDescriptorSets) {
                final Bitmask indices = this.kernelDestChannelDescriptorSetsToIndices.computeIfAbsent(
                        destChannelDescriptorSet, key -> new Bitmask(this.destChannelDescriptorSets.size())
                );
                this.allDestinationChannelIndices.set(index);
                indices.set(index++);
            }

            // Strip off the non-reusable, superfluous ChannelDescriptors where applicable.
            // Replacement entries are collected first and added after the iteration, because the map must not be
            // modified (other than via the iterator) while iterating over it.
            Collection<Tuple<Set<ChannelDescriptor>, Bitmask>> kernelDestChannelDescriptorSetsToIndicesUpdates = new LinkedList<>();
            final Iterator<Map.Entry<Set<ChannelDescriptor>, Bitmask>> iterator =
                    this.kernelDestChannelDescriptorSetsToIndices.entrySet().iterator();
            while (iterator.hasNext()) {
                final Map.Entry<Set<ChannelDescriptor>, Bitmask> entry = iterator.next();
                final Bitmask indices = entry.getValue();
                // Don't touch destination channel sets that occur only once.
                if (indices.cardinality() < 2) continue;

                // If exactly one of the requested channels is non-reusable, we remove that non-reusable one, so
                // that only reusable channels remain for the (two or more) requesting destinations.
                Set<ChannelDescriptor> channelDescriptors = entry.getKey();
                int numReusableChannels = (int) channelDescriptors.stream().filter(ChannelDescriptor::isReusable).count();
                // Warn if the only requested channel is a non-reusable one.
                if (numReusableChannels == 0 && channelDescriptors.size() == 1) {
                    logger.warn(
                            "More than two target operators request only the non-reusable channel {}.",
                            WayangCollections.getSingle(channelDescriptors)
                    );
                }
                // NOTE(review): This condition also holds in the case warned about above, in which the replacement
                // set ends up empty -- confirm that this is intended.
                if (channelDescriptors.size() - numReusableChannels == 1) {
                    iterator.remove();
                    // Copy the set before stripping, so that the original set is not altered.
                    channelDescriptors = new HashSet<>(channelDescriptors);
                    channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable());
                    kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices));
                }
            }
            // Add the stripped sets; if a stripped set equals an already present set, their indices are united.
            for (Tuple<Set<ChannelDescriptor>, Bitmask> channelsToIndicesChange : kernelDestChannelDescriptorSetsToIndicesUpdates) {
                this.kernelDestChannelDescriptorSetsToIndices.computeIfAbsent(
                        channelsToIndicesChange.getField0(),
                        key -> new Bitmask(this.destChannelDescriptorSets.size())
                ).orInPlace(channelsToIndicesChange.getField1());
            }

            // Index the single ChannelDescriptors: Each ChannelDescriptor maps to the union of the indices of all
            // (kernelized) sets that contain it.
            this.kernelDestChannelDescriptorsToIndices = new HashMap<>();
            for (Map.Entry<Set<ChannelDescriptor>, Bitmask> entry : this.kernelDestChannelDescriptorSetsToIndices.entrySet()) {
                final Set<ChannelDescriptor> channelDescriptorSet = entry.getKey();
                final Bitmask indices = entry.getValue();

                for (ChannelDescriptor channelDescriptor : channelDescriptorSet) {
                    this.kernelDestChannelDescriptorsToIndices.merge(channelDescriptor, new Bitmask(indices), Bitmask::or);
                }
            }
        }

        /**
         * Starts the actual search: Enumerates the conversion {@link Tree}s that are rooted in the
         * {@link #sourceChannelDescriptor} and picks the one that settles all destinations.
         *
         * @return a {@link Tree} that reaches all {@link #destInputs} or {@code null} if there is no such {@link Tree}
         */
        private Tree searchTree() {
            // Prepare the recursive traversal. The source channel descriptor counts as visited from the start.
            final Set<ChannelDescriptor> visitedChannelDescriptors = new HashSet<>(16);
            visitedChannelDescriptors.add(this.sourceChannelDescriptor);

            // Perform the traversal.
            // NB: #enumerate alters the passed settled indices in place (and also reads Bitmask.EMPTY_BITMASK as a
            // default value), so hand in a fresh, empty Bitmask rather than the shared EMPTY_BITMASK constant.
            final Bitmask settledDestinationIndices = new Bitmask(this.destChannelDescriptorSets.size());
            final Map<Bitmask, Tree> solutions = this.enumerate(
                    visitedChannelDescriptors,
                    this.sourceChannelDescriptor,
                    settledDestinationIndices,
                    this.sourceChannelDescriptor.isSuitableForBreakpoint()
            );

            // Get hold of a comprehensive solution (if it exists), i.e., one that settles every destination index.
            final Bitmask requestedIndices = new Bitmask(this.destChannelDescriptorSets.size());
            requestedIndices.flip(0, this.destChannelDescriptorSets.size());
            return solutions.get(requestedIndices);
        }

        /**
         * Recursive {@link Tree} enumeration strategy.
         *
         * @param visitedChannelDescriptors  previously visited {@link ChannelDescriptor}s (inclusive of {@code channelDescriptor};
         *                                   can be altered but must be in original state before leaving the method
         * @param channelDescriptor          the currently enumerated {@link ChannelDescriptor}
         * @param settledDestinationIndices  indices of destinations that have already been reached via the
         *                                   {@code visitedChannelDescriptors} (w/o {@code channelDescriptor};
         *                                   can be altered but must be in original state before leaving the method
         * @param isVisitedBreakpointChannel whether the {@code visitedChannelDescriptors} contain a breakpoint-capable {@link Channel}
         * @return solutions to the search problem reachable from this node; {@link Tree}s must still be rerooted
         */
        public Map<Bitmask, Tree> enumerate(
                Set<ChannelDescriptor> visitedChannelDescriptors,
                ChannelDescriptor channelDescriptor,
                Bitmask settledDestinationIndices,
                boolean isVisitedBreakpointChannel) {

            // Mapping from settled indices to the cheapest tree settling them. Will be the return value.
            Map<Bitmask, Tree> newSolutions = new HashMap<>(16);
            Tree newSolution;


            // Check if current path is a (new) solution.
            // Exclude existing destinations that are not reached via the current path.
            final Bitmask excludedExistingIndices = this.existingDestinationChannelIndices.andNot(
                    this.reachableExistingDestinationChannelIndices.getOrDefault(channelDescriptor, Bitmask.EMPTY_BITMASK)
            );
            // Exclude missing destinations if the current channel exists but is not open.
            final Bitmask excludedAbsentIndices =
                    (this.existingChannels.containsKey(channelDescriptor) && !openChannelDescriptors.contains(channelDescriptor)) ?
                            this.absentDestinationChannelIndices :
                            Bitmask.EMPTY_BITMASK;
            final Bitmask newSettledIndices = this.kernelDestChannelDescriptorsToIndices
                    .getOrDefault(channelDescriptor, Bitmask.EMPTY_BITMASK)
                    .andNot(settledDestinationIndices)
                    .andNotInPlace(excludedExistingIndices)
                    .andNotInPlace(excludedAbsentIndices);

            if (!newSettledIndices.isEmpty()) {
                if (channelDescriptor.isReusable() || newSettledIndices.cardinality() == 1) {
                    // If the channel is reusable, it is almost safe to say that we either use it for all possible
                    // destinations or for none, because no extra costs can incur.
                    // TODO: Create all possible combinations if required.
                    // The same is for when there is only a single destination reached.
                    newSolution = Tree.singleton(channelDescriptor, newSettledIndices);
                    newSolutions.put(newSolution.settledDestinationIndices, newSolution);

                } else {
                    // Otherwise, create an entry for each settled index.
                    for (int index = newSettledIndices.nextSetBit(0); index != -1; index = newSettledIndices.nextSetBit(index + 1)) {
                        Bitmask newSettledIndicesSubset = new Bitmask(index + 1);
                        newSettledIndicesSubset.set(index);
                        newSolution = Tree.singleton(channelDescriptor, newSettledIndicesSubset);
                        newSolutions.put(newSolution.settledDestinationIndices, newSolution);
                    }
                }

                // Check early stopping criteria:
                // There are no more channels that could be settled.
                if (newSettledIndices.cardinality() == this.destChannelDescriptorSets.size() - excludedExistingIndices.cardinality()
                        && (!this.isRequestBreakpoint || isVisitedBreakpointChannel)) {
                    return newSolutions;
                }
            }

            // For each outgoing edge, explore all combinations of reachable target indices.
            if (channelDescriptor.isReusable()) {
                // When descending, "pick" the newly settled destinations only for reusable ChannelDescriptors.
                settledDestinationIndices.orInPlace(newSettledIndices);
            }
            final List<ChannelConversion> channelConversions =
                    ChannelConversionGraph.this.conversions.getOrDefault(channelDescriptor, Collections.emptyList());
            final List<Collection<Tree>> childSolutionSets = new ArrayList<>(channelConversions.size());
            final Set<ChannelDescriptor> successorChannelDescriptors = this.getSuccessorChannelDescriptors(channelDescriptor);
            for (ChannelConversion channelConversion : channelConversions) {
                final ChannelDescriptor targetChannelDescriptor = channelConversion.getTargetChannelDescriptor();

                // Skip if there are successor channel descriptors that do not contain the target channel descriptor.
                if (successorChannelDescriptors != null && !successorChannelDescriptors.contains(targetChannelDescriptor)) {
                    continue;
                }

                // Check if the channelConversion can be filtered.
                if (successorChannelDescriptors == null && this.isFiltered(channelConversion)) {
                    logger.info("Filtering conversion {} between {} and {}.", channelConversion, this.sourceOutput, this.destInputs);
                    continue;
                }

                if (visitedChannelDescriptors.add(targetChannelDescriptor)) {
                    final Map<Bitmask, Tree> childSolutions = this.enumerate(
                            visitedChannelDescriptors,
                            targetChannelDescriptor,
                            settledDestinationIndices,
                            isVisitedBreakpointChannel || targetChannelDescriptor.isSuitableForBreakpoint()
                    );
                    childSolutions.values().forEach(
                            tree -> tree.reroot(
                                    channelDescriptor,
                                    channelDescriptor.isReusable() ? newSettledIndices : Bitmask.EMPTY_BITMASK,
                                    channelConversion,
                                    this.getCostEstimate(channelConversion)
                            )
                    );
                    if (!childSolutions.isEmpty()) childSolutionSets.add(childSolutions.values());

                    visitedChannelDescriptors.remove(targetChannelDescriptor);
                }
            }
            settledDestinationIndices.andNotInPlace(newSettledIndices);

            // Merge the childSolutionSets into the newSolutions.
            // Each childSolutionSet corresponds to a traversed outgoing ChannelConversion.

            // If a breakpoint is requested, we don't need to merge if there are already comprehensive solutions with
            // breakpoints.
            if (this.isRequestBreakpoint && !isVisitedBreakpointChannel) {
                Tree bestBreakpointSolution = null;
                for (Collection<Tree> trees : childSolutionSets) {
                    for (Tree tree : trees) {
                        if (this.allDestinationChannelIndices.isSubmaskOf(tree.settledDestinationIndices)) {
                            bestBreakpointSolution = selectPreferredTree(bestBreakpointSolution, tree);
                        }
                    }
                }
                if (bestBreakpointSolution != null) {
                    newSolutions.put(bestBreakpointSolution.settledDestinationIndices, bestBreakpointSolution);
                    return newSolutions;
                }
            }

            // At first, consider the childSolutionSet for each outgoing ChannelConversion individually.
            for (Collection<Tree> childSolutionSet : childSolutionSets) {
                // Each childSolutionSet its has a mapping from settled indices to trees.
                for (Tree tree : childSolutionSet) {
                    // Update newSolutions if the current tree is cheaper or settling new indices.
                    newSolutions.merge(tree.settledDestinationIndices, tree, ChannelConversionGraph.this::selectPreferredTree);
                }
            }

            // If the current Channel/vertex is reusable, also detect valid combinations.
            // Check if the combinations yield new solutions.
            if (channelDescriptor.isReusable()
                    && this.kernelDestChannelDescriptorSetsToIndices.size() > 1
                    && childSolutionSets.size() > 1
                    && this.destInputs.size() > newSettledIndices.cardinality() + settledDestinationIndices.cardinality() + 1) {

                // Determine the number of "unreached" destChannelDescriptorSets.
                int numUnreachedDestinationSets = 0;
                for (Bitmask settlableDestinationIndices : this.kernelDestChannelDescriptorSetsToIndices.values()) {
                    if (!settlableDestinationIndices.isSubmaskOf(settledDestinationIndices)) {
                        numUnreachedDestinationSets++;
                    }
                }

                if (numUnreachedDestinationSets >= 2) { // only combine when there is more than one destination left
                    final Collection<List<Collection<Tree>>> childSolutionSetCombinations =
                            WayangCollections.createPowerList(childSolutionSets, numUnreachedDestinationSets);
                    for (List<Collection<Tree>> childSolutionSetCombination : childSolutionSetCombinations) {
                        if (childSolutionSetCombination.size() < 2)
                            continue; // only combine when we have more than on child solution
                        for (List<Tree> solutionCombination : WayangCollections.streamedCrossProduct(childSolutionSetCombination)) {
                            final Tree tree = ChannelConversionGraph.this.mergeTrees(solutionCombination);
                            if (tree != null) {
                                newSolutions.merge(tree.settledDestinationIndices, tree, ChannelConversionGraph.this::selectPreferredTree);
                            }
                        }
                    }
                }
            }

            return newSolutions;
        }

        /**
         * Find {@link ChannelDescriptor}s of {@link Channel}s that can be reached from the current
         * {@link ChannelDescriptor}. This is relevant only when there are {@link #existingChannels}.
         *
         * @param descriptor from which successor {@link ChannelDescriptor}s are requested
         * @return the successor {@link ChannelDescriptor}s or {@code null} if no restrictions apply
         */
        private Set<ChannelDescriptor> getSuccessorChannelDescriptors(ChannelDescriptor descriptor) {
            // No restrictions apply if there is no existing channel for the descriptor or if the descriptor
            // is among the open channel descriptors.
            final Channel existingChannel = this.existingChannels.get(descriptor);
            final boolean isUnrestricted = existingChannel == null || this.openChannelDescriptors.contains(descriptor);
            if (isUnrestricted) {
                return null;
            }

            // Collect the descriptors of all output channels of the auxiliary consumers of the existing channel.
            final Set<ChannelDescriptor> successorDescriptors = new HashSet<>();
            for (ExecutionTask consumerTask : existingChannel.getConsumers()) {
                if (consumerTask.getOperator().isAuxiliary()) {
                    for (Channel outputChannel : consumerTask.getOutputChannels()) {
                        successorDescriptors.add(outputChannel.getDescriptor());
                    }
                }
            }
            return successorDescriptors;
        }

        /**
         * Retrieve a cached or calculate and cache the cost estimate for a given {@link ChannelConversion}
         * w.r.t. the {@link #cardinality}.
         *
         * @param channelConversion whose cost estimate is requested
         * @return the cost estimate
         */
        private double getCostEstimate(ChannelConversion channelConversion) {
            // Only on a cache miss, the conversion cost is estimated and squashed into a single double value.
            final double squashedCost = this.conversionCostCache.computeIfAbsent(
                    channelConversion,
                    conversion -> costSquasher.applyAsDouble(
                            conversion.estimateConversionCost(
                                    this.cardinality, this.numExecutions, this.optimizationContextCopy
                            )
                    )
            );
            return squashedCost;
        }

        /**
         * Determine whether the given {@link ChannelConversion} should be filtered.
         *
         * @param channelConversion which should be checked for filtering
         * @return whether to filter the {@link ChannelConversion}
         */
        private boolean isFiltered(ChannelConversion channelConversion) {
            // Ask the conversion itself only once; afterwards the cached decision is returned.
            final boolean filterDecision = this.conversionFilterCache.computeIfAbsent(
                    channelConversion,
                    conversion -> conversion.isFiltered(this.cardinality, this.numExecutions, this.optimizationContextCopy)
            );
            return filterDecision;
        }

        /**
         * Builds a {@link Junction} from the given {@link Tree} and stores it in {@link #result}.
         *
         * @param tree describes the {@link ChannelConversion}s to be materialized, starting at its root
         */
        private void createJunction(Tree tree) {
            final List<OptimizationContext> forkedContexts = this.forkLocalOptimizationContext();

            // Set up a new Junction; reuse an already existing source Channel, otherwise create one.
            final Junction junction = new Junction(this.sourceOutput, this.destInputs, forkedContexts);
            final Channel existingSourceChannel = this.existingChannels.get(this.sourceChannelDescriptor);
            final Channel sourceChannel = existingSourceChannel != null ?
                    existingSourceChannel :
                    this.sourceChannelDescriptor.createChannel(this.sourceOutput, this.optimizationContext.getConfiguration());
            junction.setSourceChannel(sourceChannel);
            this.createJunctionAux(tree.root, sourceChannel, junction);

            // The newly created ExecutionTasks need a LoopSubplan. It is derived from the producer side of the
            // source Channel, unless that producer is a loop head emitting via a non-feedforward output.
            final OutputSlot<?> producerSlot = sourceChannel.getProducerSlot();
            final ExecutionOperator sourceOperator = (ExecutionOperator) producerSlot.getOwner();
            final boolean isNonFeedforwardLoopHeadOutput =
                    sourceOperator.isLoopHead() && !sourceOperator.isFeedforwardOutput(producerSlot);
            final LoopSubplan sourceLoop = isNonFeedforwardLoopHeadOutput ? null : sourceOperator.getInnermostLoop();

            if (sourceLoop != null) {
                // The destinations are expected to request the very same LoopSubplan as the source.
                for (int destIndex = 0; destIndex < this.destInputs.size(); destIndex++) {
                    assert this.destInputs.get(destIndex).getOwner().getInnermostLoop() == sourceLoop :
                            String.format(
                                    "Expected that %s would belong to %s, just as %s does.",
                                    this.destInputs.get(destIndex), sourceLoop, producerSlot
                            );
                    // Walk from the target Channel back to the source Channel and put every producing
                    // operator on the way into the source's LoopSubplan.
                    Channel cursor = junction.getTargetChannel(destIndex);
                    while (cursor != sourceChannel) {
                        final ExecutionTask glueTask = cursor.getProducer();
                        glueTask.getOperator().setContainer(sourceLoop);
                        assert glueTask.getNumInputChannels() == 1 : String.format(
                                "Glue operator %s was expected to have exactly one input channel.",
                                glueTask
                        );
                        cursor = glueTask.getInputChannel(0);
                    }
                }
            }

            // Every destination settled by the tree must have received a target Channel.
            assert tree.settledDestinationIndices.stream().allMatch(settledIndex -> junction.getTargetChannel(settledIndex) != null) :
                    String.format("Junction from %s to %s has no target channels.",
                            junction.getSourceOutput(),
                            tree.settledDestinationIndices.stream()
                                    .filter(missingIndex -> junction.getTargetChannel(missingIndex) == null)
                                    .mapToObj(missingIndex -> String.format("%s (index=%d)", this.destInputs.get(missingIndex), missingIndex))
                                    .collect(Collectors.joining(" and "))
                    );

            // CHECK: We don't need to worry about entering loops, because in this case #resolveSupportedChannels(...)
            // does all the magic!?

            this.result = junction;
        }

        /**
         * Helper function to create a {@link Junction} from a {@link Tree}.
         *
         * @param vertex      the currently iterated {@link TreeVertex} (start at the root)
         * @param baseChannel the corresponding {@link Channel}
         * @param junction    that is being initialized
         */
        private void createJunctionAux(TreeVertex vertex, Channel baseChannel, Junction junction) {
            // Lazily created copy of the base channel; shared by all new destinations and new conversions.
            Channel baseChannelCopy = null;

            // Register a target channel for every destination settled directly at this vertex.
            int destIndex = vertex.settledIndices.nextSetBit(0);
            while (destIndex >= 0) {
                final Channel targetChannel;
                if (this.existingDestinationChannelIndices.get(destIndex)
                        || !this.openChannelDescriptors.contains(baseChannel.getDescriptor())) {
                    targetChannel = baseChannel;
                } else {
                    // A new destination at an existing, open base channel must be served by a copy of it.
                    if (baseChannelCopy == null) {
                        baseChannelCopy = baseChannel.copy();
                    }
                    targetChannel = baseChannelCopy;
                }
                junction.setTargetChannel(destIndex, targetChannel);
                destIndex = vertex.settledIndices.nextSetBit(destIndex + 1);
            }

            // Descend along the outgoing edges and materialize their conversions.
            for (TreeEdge outEdge : vertex.outEdges) {
                final ChannelConversion conversion = outEdge.channelConversion;

                // Prefer a channel that already exists for the conversion's target descriptor.
                Channel childChannel = this.existingChannels.get(conversion.getTargetChannelDescriptor());

                if (childChannel != null) {
                    conversion.update(
                            baseChannel,
                            childChannel,
                            junction.getOptimizationContexts(),
                            // Hacky: Inject cardinality for cases where we convert a LoopHeadOperator output.
                            junction.getOptimizationContexts().size() == 1 ? this.cardinality : null
                    );
                } else {
                    // No channel exists yet, so convert. If the base channel is open, the conversion must
                    // start from a copy of it rather than from the base channel itself.
                    if (baseChannelCopy == null) {
                        if (this.openChannelDescriptors.contains(baseChannel.getDescriptor())) {
                            baseChannelCopy = baseChannel.copy();
                        } else {
                            baseChannelCopy = baseChannel;
                        }
                    }
                    childChannel = conversion.convert(
                            baseChannelCopy,
                            this.optimizationContext.getConfiguration(),
                            junction.getOptimizationContexts(),
                            // Hacky: Inject cardinality for cases where we convert a LoopHeadOperator output.
                            junction.getOptimizationContexts().size() == 1 ? this.cardinality : null
                    );
                }

                // Name and register the task that produces the child channel, unless it is the base channel.
                if (baseChannel != childChannel) {
                    final ExecutionTask conversionTask = childChannel.getProducer();
                    conversionTask.getOperator().setName(String.format(
                            "convert %s", junction.getSourceOutput()
                    ));
                    junction.register(conversionTask);
                }

                this.createJunctionAux(outEdge.destination, childChannel, junction);
            }
        }

        /**
         * Creates new {@link OptimizationContext}s by forking each default {@link OptimizationContext} of
         * <ul>
         * <li>the parent of the {@link #optimizationContext} if the {@link #sourceOutput} is the final
         * {@link OutputSlot} of a {@link LoopHeadOperator}</li>
         * <li>or else of the {@link #optimizationContext} itself.</li>
         * </ul>
         * We have to do this because in the former case the {@link Junction} {@link ExecutionOperator}s should not
         * reside in a loop {@link OptimizationContext}.
         *
         * @return the forked {@link OptimizationContext}s
         */
        // TODO: Refactor this.
        private List<OptimizationContext> forkLocalOptimizationContext() {
            // Pick the base: the parent context when the source output is a non-feedforward output of a
            // loop head, and the regular context otherwise.
            final boolean isNonFeedforwardLoopHeadOutput =
                    this.sourceOutput.getOwner().isLoopHead() && !this.sourceOutput.isFeedforward();
            final OptimizationContext baseOptimizationContext;
            if (isNonFeedforwardLoopHeadOutput) {
                baseOptimizationContext = this.optimizationContext.getParent();
            } else {
                baseOptimizationContext = this.optimizationContext;
            }

            // Wrap each default OptimizationContext of the base in a new DefaultOptimizationContext.
            return baseOptimizationContext.getDefaultOptimizationContexts().stream()
                    .map(defaultContext -> new DefaultOptimizationContext(defaultContext))
                    .collect(Collectors.toList());
        }


    }

    /**
     * A tree consisting of {@link TreeVertex}es connected by {@link TreeEdge}s.
     */
    private static class Tree {

        /**
         * The topmost {@link TreeVertex}. It is exchanged whenever {@link #reroot} is invoked.
         */
        private TreeVertex root;

        /**
         * All destination indices that are settled by any {@link TreeVertex} of this instance.
         *
         * @see TreeVertex#settledIndices
         */
        private final Bitmask settledDestinationIndices;

        /**
         * Every {@link ChannelDescriptor} that is represented by some {@link TreeVertex} of this instance.
         *
         * @see TreeVertex#channelDescriptor
         */
        private final Set<ChannelDescriptor> employedChannelDescriptors = new HashSet<>();

        /**
         * Accumulated cost estimates of the {@link TreeEdge}s of this instance.
         */
        private double costs = 0d;

        /**
         * Factory method for an instance that consists of nothing but a root {@link TreeVertex}.
         *
         * @param channelDescriptor to be represented by the only {@link TreeVertex}
         * @param settledIndices    destination indices that the {@code channelDescriptor} settles; the new
         *                          instance keeps its own copy for {@link #settledDestinationIndices}
         * @return the single-vertex instance
         */
        static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndices) {
            final TreeVertex onlyVertex = new TreeVertex(channelDescriptor, settledIndices);
            final Bitmask settledIndicesCopy = new Bitmask(settledIndices);
            return new Tree(onlyVertex, settledIndicesCopy);
        }

        /**
         * Creates a new instance around an already existing root.
         *
         * @param root                      the root {@link TreeVertex}
         * @param settledDestinationIndices the destination indices settled by the new instance
         */
        Tree(TreeVertex root, Bitmask settledDestinationIndices) {
            this.settledDestinationIndices = settledDestinationIndices;
            this.root = root;
            this.employedChannelDescriptors.add(root.channelDescriptor);
        }

        /**
         * Installs a new {@link #root} on top of the current one: the current root becomes the child of a new
         * {@link TreeVertex} and the bookkeeping of this instance is adapted accordingly.
         *
         * @param newRootChannelDescriptor    to be represented by the new {@link #root}
         * @param newRootSettledIndices       destination indices settled by the {@code newRootChannelDescriptor}
         * @param newToObsoleteRootConversion forms the {@link TreeEdge} from the new to the old {@link #root}
         * @param costEstimate                of the {@code newToObsoleteRootConversion}
         */
        void reroot(ChannelDescriptor newRootChannelDescriptor,
                    Bitmask newRootSettledIndices,
                    ChannelConversion newToObsoleteRootConversion,
                    double costEstimate) {
            // Put a fresh vertex above the former root.
            final TreeVertex formerRoot = this.root;
            final TreeVertex freshRoot = new TreeVertex(newRootChannelDescriptor, newRootSettledIndices);
            final TreeEdge rootEdge = freshRoot.linkTo(newToObsoleteRootConversion, formerRoot, costEstimate);
            this.root = freshRoot;

            // Account for the new vertex and edge in the aggregated properties.
            this.costs += rootEdge.costEstimate;
            this.settledDestinationIndices.orInPlace(newRootSettledIndices);
            this.employedChannelDescriptors.add(newRootChannelDescriptor);
        }

        @Override
        public String toString() {
            final String className = this.getClass().getSimpleName();
            return String.format("%s[%s, %s]", className, this.costs, this.root.getChildChannelConversions());
        }
    }

    /**
     * Vertex in a {@link Tree}. Corresponds to a {@link ChannelDescriptor}.
     */
    private static class TreeVertex {

        /**
         * The {@link ChannelDescriptor} that this vertex stands for.
         */
        private final ChannelDescriptor channelDescriptor;

        /**
         * Outgoing {@link TreeEdge}s, each leading to a child {@link TreeVertex}.
         */
        private final List<TreeEdge> outEdges;

        /**
         * Indices of the {@link ShortestTreeSearcher#destInputs} that are settled at this vertex.
         */
        private final Bitmask settledIndices;

        /**
         * Creates a new instance without any outgoing {@link TreeEdge}s.
         *
         * @param channelDescriptor that the new instance stands for
         * @param settledIndices    indices of the destinations settled at the new instance; stored as is
         */
        private TreeVertex(ChannelDescriptor channelDescriptor, Bitmask settledIndices) {
            this.outEdges = new ArrayList<>(4);
            this.channelDescriptor = channelDescriptor;
            this.settledIndices = settledIndices;
        }

        /**
         * Appends a new outgoing {@link TreeEdge} to this instance.
         *
         * @param channelConversion represented by the new {@link TreeEdge}
         * @param destination       the child {@link TreeVertex} that the new {@link TreeEdge} points to
         * @param costEstimate      of the {@code channelConversion}
         * @return the newly created {@link TreeEdge}
         */
        private TreeEdge linkTo(ChannelConversion channelConversion, TreeVertex destination, double costEstimate) {
            final TreeEdge newEdge = new TreeEdge(channelConversion, destination, costEstimate);
            this.outEdges.add(newEdge);
            return newEdge;
        }

        /**
         * Appends all outgoing {@link TreeEdge}s of the given instance to this instance. Both instances are
         * expected to represent equal {@link ChannelDescriptor}s.
         *
         * @param that whose {@link TreeEdge}s should be taken over
         */
        private void copyEdgesFrom(TreeVertex that) {
            assert this.channelDescriptor.equals(that.channelDescriptor);
            this.outEdges.addAll(that.outEdges);
        }

        /**
         * Gathers the {@link ChannelConversion}s of all {@link TreeEdge}s that can be reached from this instance,
         * i.e., of the directly outgoing ones and, recursively, of those below the child {@link TreeVertex}es.
         *
         * @return a new {@link Set} with the gathered {@link ChannelConversion}s
         */
        private Set<ChannelConversion> getChildChannelConversions() {
            final Set<ChannelConversion> collectedConversions = new HashSet<>();
            this.outEdges.forEach(outEdge -> {
                // First the conversion of the edge itself, then everything below its destination.
                collectedConversions.add(outEdge.channelConversion);
                collectedConversions.addAll(outEdge.destination.getChildChannelConversions());
            });
            return collectedConversions;
        }

        @Override
        public String toString() {
            final String className = this.getClass().getSimpleName();
            return String.format("%s[%s]", className, this.channelDescriptor);
        }
    }

    /**
     * Edge in a {@link Tree}.
     */
    private static class TreeEdge {

        /**
         * The {@link ChannelConversion} that this edge stands for.
         */
        private final ChannelConversion channelConversion;

        /**
         * The {@link TreeVertex} that this edge points to.
         */
        private final TreeVertex destination;

        /**
         * Estimated cost of the {@link #channelConversion}.
         */
        private final double costEstimate;

        /**
         * Creates a new instance.
         *
         * @param channelConversion represented by the new instance
         * @param destination       the {@link TreeVertex} that the new instance points to
         * @param costEstimate      of the {@code channelConversion}
         */
        private TreeEdge(ChannelConversion channelConversion, TreeVertex destination, double costEstimate) {
            this.costEstimate = costEstimate;
            this.destination = destination;
            this.channelConversion = channelConversion;
        }

        @Override
        public String toString() {
            final String className = this.getClass().getSimpleName();
            return String.format("%s[%s, %s]", className, this.channelConversion, this.costEstimate);
        }
    }
}
