/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <vector>
#include <stdexcept>

#include <nanobind/nanobind.h>
#include <nanobind/make_iterator.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include "py_object_lt.hpp"
#include "py_object_ostream.hpp"
#include "quantile_conditional.hpp"
#include "req_sketch.hpp"

namespace nb = nanobind;

/**
 * Registers a Python class wrapping datasketches::req_sketch<T, C> in module m.
 *
 * @tparam T item type stored in the sketch
 * @tparam C comparator type used to order items of type T
 * @param m    nanobind module that receives the new class
 * @param name Python-visible name of the class
 *
 * After the class is defined, add_serialization<T>() and add_vector_update<T>()
 * (declared outside this file) are invoked on it to attach further methods.
 */
template<typename T, typename C>
void bind_req_sketch(nb::module_ &m, const char* name) {
  using namespace datasketches;
  using sketch_t = req_sketch<T, C>;

  auto req_class = nb::class_<sketch_t>(m, name)
    .def(nb::init<uint16_t, bool>(), nb::arg("k")=12, nb::arg("is_hra")=true,
         "Creates an REQ sketch instance with the given value of k.\n\n"
         ":param k: Controls the size/accuracy trade-off of the sketch. Default is 12.\n"
         ":type k: int, optional\n"
         ":param is_hra: Specifies whether the sketch has High Rank Accuracy (True) or Low Rank Accuracy. Default True.\n"
         ":type is_hra: bool, optional"
    )
    .def("__copy__", [](const sketch_t& sk){ return sketch_t(sk); })
    // static_cast selects the const T& flavour of update()/merge()
    .def("update", static_cast<void (sketch_t::*)(const T&)>(&sketch_t::update), nb::arg("item"),
        "Updates the sketch with the given value")
    .def("merge", static_cast<void (sketch_t::*)(const sketch_t&)>(&sketch_t::merge), nb::arg("sketch"),
        "Merges the provided sketch into this one")
    // __str__ is invoked by Python with no arguments besides self, so it must not
    // expose to_string()'s two bool parameters as required positional arguments.
    .def("__str__", [](const sketch_t& sk) { return sk.to_string(); },
        "Produces a string summary of the sketch")
    .def("to_string", &sketch_t::to_string, nb::arg("print_levels")=false, nb::arg("print_items")=false,
        "Produces a string summary of the sketch")
    .def("is_hra", &sketch_t::is_HRA,
        "Returns True if the sketch is in High Rank Accuracy mode, otherwise False")
    .def("is_empty", &sketch_t::is_empty,
        "Returns True if the sketch is empty, otherwise False")
    .def_prop_ro("k", &sketch_t::get_k,
        "The configured parameter k")
    .def_prop_ro("n", &sketch_t::get_n,
        "The length of the input stream")
    .def_prop_ro("num_retained", &sketch_t::get_num_retained,
        "The number of retained items (samples) in the sketch")
    .def("is_estimation_mode", &sketch_t::is_estimation_mode,
        "Returns True if the sketch is in estimation mode, otherwise False")
    .def("get_min_value", &sketch_t::get_min_item,
        "Returns the minimum value from the stream. If empty, req_floats_sketch returns nan; req_ints_sketch throws a RuntimeError")
    .def("get_max_value", &sketch_t::get_max_item,
        "Returns the maximum value from the stream. If empty, req_floats_sketch returns nan; req_ints_sketch throws a RuntimeError")
    .def("get_quantile", &sketch_t::get_quantile, nb::arg("rank"), nb::arg("inclusive")=false,
        "Returns an approximation to the data value "
        "associated with the given normalized rank in a hypothetical sorted "
        "version of the input stream so far.\n"
        "For req_floats_sketch: if the sketch is empty this returns nan. "
        "For req_ints_sketch: if the sketch is empty this throws a RuntimeError.")
    .def(
        "get_quantiles",
        [](const sketch_t& sk, const std::vector<double>& ranks, bool inclusive) {
          std::vector<T> quantiles;
          // An empty sketch yields an empty result instead of querying get_quantile().
          if (!sk.is_empty()) {
            quantiles.reserve(ranks.size());
            for (const double rank : ranks) quantiles.push_back(sk.get_quantile(rank, inclusive));
          }
          return quantiles;
        },
        nb::arg("ranks"), nb::arg("inclusive")=false,
        "This returns an array that could have been generated by using get_quantile() for each "
        "normalized rank separately.\n"
        "If the sketch is empty this returns an empty vector."
    )
    .def("get_rank", &sketch_t::get_rank, nb::arg("value"), nb::arg("inclusive")=false,
        "Returns an approximation to the normalized rank of the given value from 0 to 1, inclusive.\n"
        "The resulting approximation has a probabilistic guarantee that can be obtained from the "
        "get_normalized_rank_error(False) function.\n"
        "With the parameter inclusive=true the weight of the given value is included into the rank. "
        "Otherwise the rank equals the sum of the weights of values less than the given value.\n"
        "If the sketch is empty this returns nan.")
    .def(
        "get_pmf",
        // Adapts a Python sequence (converted to std::vector) to the pointer + size C++ API.
        [](const sketch_t& sk, const std::vector<T>& split_points, bool inclusive) {
          return sk.get_PMF(split_points.data(), split_points.size(), inclusive);
        },
        nb::arg("split_points"), nb::arg("inclusive")=false,
        "Returns an approximation to the Probability Mass Function (PMF) of the input stream "
        "given a set of split points (values).\n"
        "The resulting approximations have a probabilistic guarantee that can be obtained from the "
        "get_normalized_rank_error(True) function.\n"
        "If the sketch is empty this returns an empty vector.\n"
        "split_points is an array of m unique, monotonically increasing float values "
        "that divide the real number line into m+1 consecutive disjoint intervals.\n"
        "If the parameter inclusive=false, the definition of an 'interval' is inclusive of the left split point (or minimum value) and "
        "exclusive of the right split point, with the exception that the last interval will include "
        "the maximum value.\n"
        "If the parameter inclusive=true, the definition of an 'interval' is exclusive of the left split point (or minimum value) and "
        "inclusive of the right split point.\n"
        "It is not necessary to include either the min or max values in these split points."
    )
    .def(
        "get_cdf",
        // Adapts a Python sequence (converted to std::vector) to the pointer + size C++ API.
        [](const sketch_t& sk, const std::vector<T>& split_points, bool inclusive) {
          return sk.get_CDF(split_points.data(), split_points.size(), inclusive);
        },
        nb::arg("split_points"), nb::arg("inclusive")=false,
        "Returns an approximation to the Cumulative Distribution Function (CDF), which is the "
        "cumulative analog of the PMF, of the input stream given a set of split points (values).\n"
        "The resulting approximations have a probabilistic guarantee that can be obtained from the "
        "get_normalized_rank_error(True) function.\n"
        "If the sketch is empty this returns an empty vector.\n"
        "split_points is an array of m unique, monotonically increasing float values "
        "that divide the real number line into m+1 consecutive disjoint intervals.\n"
        "If the parameter inclusive=false, the definition of an 'interval' is inclusive of the left split point (or minimum value) and "
        "exclusive of the right split point, with the exception that the last interval will include "
        "the maximum value.\n"
        "If the parameter inclusive=true, the definition of an 'interval' is exclusive of the left split point (or minimum value) and "
        "inclusive of the right split point.\n"
        "It is not necessary to include either the min or max values in these split points."
    )
    .def("get_rank_lower_bound", &sketch_t::get_rank_lower_bound, nb::arg("rank"), nb::arg("num_std_dev"),
        "Returns an approximate lower bound on the given normalized rank.\n"
        "Normalized rank must be a value between 0.0 and 1.0 (inclusive); "
        "the number of standard deviations must be 1, 2, or 3.")
    .def("get_rank_upper_bound", &sketch_t::get_rank_upper_bound, nb::arg("rank"), nb::arg("num_std_dev"),
        "Returns an approximate upper bound on the given normalized rank.\n"
        "Normalized rank must be a value between 0.0 and 1.0 (inclusive); "
        "the number of standard deviations must be 1, 2, or 3.")
    .def_static("get_RSE", &sketch_t::get_RSE,
        nb::arg("k"), nb::arg("rank"), nb::arg("is_hra"), nb::arg("n"),
        "Returns an a priori estimate of relative standard error (RSE, expressed as a number in [0,1]). "
        "Derived from Lemma 12 in http://arxiv.org/abs/2004.01668v2, but the constant factors have been "
        "modified based on empirical measurements, for a given value of parameter k.\n"
        "Normalized rank must be a value between 0.0 and 1.0 (inclusive). If is_hra is True, uses high "
        "rank accuracy mode, else low rank accuracy. N is an estimate of the total number of points "
        "provided to the sketch.")
    .def("__iter__",
        [](const sketch_t& s) {
          return nb::make_iterator(nb::type<sketch_t>(),
                                   "req_iterator",
                                   s.begin(),
                                   s.end());
        },
        // Tie the sketch's lifetime (argument 1) to the returned iterator (0),
        // since the iterator is built from the sketch's own begin()/end().
        nb::keep_alive<0, 1>()
    )
    ;

  add_serialization<T>(req_class);
  add_vector_update<T>(req_class);
}

// Registers the three concrete REQ sketch classes on module m:
// one for int items, one for float items, and one for arbitrary Python objects.
void init_req(nb::module_ &m) {
  using int_less = std::less<int>;
  using float_less = std::less<float>;

  // Native numeric item types use the standard ordering.
  bind_req_sketch<int, int_less>(m, "req_ints_sketch");
  bind_req_sketch<float, float_less>(m, "req_floats_sketch");

  // Python objects are ordered by the project-provided py_object_lt comparator.
  bind_req_sketch<nb::object, py_object_lt>(m, "req_items_sketch");
}
