/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.twitter.distributedlog.service.balancer;

import static com.google.common.base.Preconditions.checkArgument;

import java.io.Serializable;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;

/**
 * A stream chooser based on number of streams.
 *
 * <p>The hosts are kept in a list sorted by their stream counts in descending order. Streams are
 * chosen one at a time, round-robin, from the hosts located before a <i>pivot</i> index until the
 * stream counts of those hosts drop to the stream count of the pivot host. The pivot then advances
 * to the next host that owns strictly fewer streams and the process repeats.
 *
 * <p>The streams of each host are shuffled at construction time, so the order in which streams of
 * a single host are chosen is random.
 */
class CountBasedStreamChooser implements StreamChooser, Serializable,
        Comparator<Pair<SocketAddress, LinkedList<String>>> {

    private static final long serialVersionUID = 4664153397369979203L;

    // hosts paired with their remaining (shuffled) streams, sorted by the initial stream
    // counts in descending order.
    final List<Pair<SocketAddress, LinkedList<String>>> streamsDistribution;

    // pivot index in the list of hosts. the chooser will remove streams from the hosts before
    // pivot, which will reduce their stream counts to make them equal to the stream count of the pivot.
    // it equals the number of hosts once there is no host left with fewer streams.
    int pivot;
    // stream count of the host at the pivot index; 0 when the pivot moved past the last host.
    int pivotCount;

    // next index in the list of hosts to choose stream from.
    int next;

    /**
     * Build a chooser over the given streams distribution.
     *
     * @param streams
     *          the streams owned by each host; it must contain at least one host.
     * @throws IllegalArgumentException if <i>streams</i> contains no host.
     */
    CountBasedStreamChooser(Map<SocketAddress, Set<String>> streams) {
        checkArgument(!streams.isEmpty(), "Only support non-empty streams distribution");
        streamsDistribution = new ArrayList<Pair<SocketAddress, LinkedList<String>>>(streams.size());
        for (Map.Entry<SocketAddress, Set<String>> entry : streams.entrySet()) {
            // copy the streams so that choosing never modifies the caller's sets
            LinkedList<String> randomizedStreams = new LinkedList<String>(entry.getValue());
            Collections.shuffle(randomizedStreams);
            streamsDistribution.add(Pair.of(entry.getKey(), randomizedStreams));
        }
        // sort the hosts by the number of streams in descending order
        Collections.sort(streamsDistribution, this);
        // start from the most loaded host and look for the first host owning fewer streams
        pivot = 0;
        pivotCount = streamsDistribution.get(0).getValue().size();
        findNextPivot();
        next = 0;
    }

    /**
     * Advance the pivot to the next host whose stream count is strictly less than the current
     * pivot count. If there is no such host, the pivot is set to the number of hosts and the
     * pivot count is set to 0.
     */
    private void findNextPivot() {
        int prevPivotCount = pivotCount;
        while (++pivot < streamsDistribution.size()) {
            pivotCount = streamsDistribution.get(pivot).getValue().size();
            if (pivotCount < prevPivotCount) {
                return;
            }
        }
        // ran past the last host: every host is a candidate to choose streams from
        pivot = streamsDistribution.size();
        pivotCount = 0;
    }

    /**
     * Choose a stream from the next host in turn and remove it from that host.
     *
     * @return the chosen stream, or <i>null</i> if there is no stream left to choose.
     */
    @Override
    public synchronized String choose() {
        // reach the pivot
        if (next == pivot) {
            if (streamsDistribution.get(next - 1).getRight().size() > pivotCount) {
                // the hosts before the pivot still own more streams than the pivot host,
                // so start another round over them
                next = 0;
            } else if (pivotCount == 0) { // the streams are empty now
                return null;
            } else {
                // the hosts before the pivot are leveled with the pivot host, move the pivot forward
                findNextPivot();
                next = 0;
            }
        }

        // get the remaining streams of the next host to choose from
        LinkedList<String> nextStreams = streamsDistribution.get(next).getRight();
        if (nextStreams.isEmpty()) {
            return null;
        }

        String chosenStream = nextStreams.remove();
        ++next;
        return chosenStream;
    }

    /**
     * Order hosts by their stream counts in descending order.
     *
     * @return a negative value if <i>o1</i> owns more streams than <i>o2</i>, a positive value
     *         if it owns fewer streams, and 0 if both own the same number of streams.
     */
    @Override
    public int compare(Pair<SocketAddress, LinkedList<String>> o1,
                       Pair<SocketAddress, LinkedList<String>> o2) {
        // arguments are swapped on purpose to get the descending order
        return Integer.compare(o2.getValue().size(), o1.getValue().size());
    }
}
