# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Graph Methods

# Please refer to the graph.sql_in file for the documentation

"""
@file graph.py_in

@namespace graph
"""
import plpy
from utilities.utilities import _assert, add_postfix
from utilities.utilities import get_filtered_cols_subquery_str
from utilities.utilities import extract_keyvalue_params
from utilities.validate_args import get_cols
from utilities.validate_args import unquote_ident
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_cols_and_types

def validate_output_and_summary_tables(model_out_table, module_name,
                                       out_table=None):
    """
    Validate a model output table and its associated summary table.

    The summary table name is obtained by calling add_postfix() on
    model_out_table with the "_summary" postfix. Both tables must exist and
    must not be empty. Optionally, the absence of 'out_table' (the table that
    is about to be created) is checked as well.

    Args:
        @param model_out_table  Name of the existing model output table.
        @param module_name      Name of the calling graph module; it is used
                                as the prefix of every error message.
        @param out_table        (optional) Name of the table to be created.

    Results:
        A failure is reported through _assert() if model_out_table is not a
        usable name, if model_out_table or its summary table is missing or
        empty, or if out_table (when specified) already exists.
    """
    _assert(model_out_table and model_out_table.strip().lower() not in ('null', ''),
            "Graph {0}: Invalid {0} table name.".format(module_name))
    _assert(table_exists(model_out_table),
            "Graph {0}: {0} table ({1}) is missing.".format(module_name, model_out_table))
    _assert(not table_is_empty(model_out_table),
            "Graph {0}: {0} table ({1}) is empty.".format(module_name, model_out_table))

    summary = add_postfix(model_out_table, "_summary")
    _assert(table_exists(summary),
            "Graph {0}: {0} summary table ({1}) is missing.".format(module_name, summary))
    _assert(not table_is_empty(summary),
            "Graph {0}: {0} summary table ({1}) is empty.".format(module_name, summary))

    if out_table:
        # Use the caller-supplied module name here too (instead of a
        # hard-coded module), so the message matches the ones above.
        _assert(not table_exists(out_table),
                "Graph {0}: Output table {1} already exists.".format(
                    module_name, out_table))


def validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
                          out_table, func_name, **kwargs):
    """
    Validates graph tables (vertex and edge) as well as the output table.

    The checks run in this order, each reported through _assert():
      1. out_table is a usable name and does not exist yet;
      2. vertex_table is a usable name, exists and is not empty;
      3. edge_table is a usable name, exists and is not empty;
      4. vertex_id is one of the columns of vertex_table;
      5. every value of edge_params is a column of edge_table.

    Args:
        @param vertex_table  Name of the vertex table.
        @param vertex_id     Name of the vertex id column in vertex_table.
        @param edge_table    Name of the edge table.
        @param edge_params   Dict whose values are edge table column names.
        @param out_table     Name of the output table to be created.
        @param func_name     Function name used as prefix in error messages.
        @param kwargs        Accepted but not used by this function.

    Returns:
        None
    """
    def _usable(name):
        # A name is unusable when it is missing/empty or, ignoring case and
        # surrounding whitespace, spells 'null'.
        return bool(name) and name.strip().lower() not in ('null', '')

    # Values substituted into the message templates below.
    msg_args = dict(func_name=func_name,
                    vertex_table=vertex_table,
                    vertex_id=vertex_id,
                    edge_table=edge_table)

    # --- output table ---
    _assert(_usable(out_table),
            "Graph {func_name}: Invalid output table name!".format(**msg_args))
    _assert(not table_exists(out_table),
            "Graph {func_name}: Output table already exists!".format(**msg_args))

    # --- vertex table ---
    _assert(_usable(vertex_table),
            "Graph {func_name}: Invalid vertex table name!".format(**msg_args))
    _assert(table_exists(vertex_table),
            "Graph {func_name}: Vertex table ({vertex_table}) is missing!".format(
                **msg_args))
    _assert(not table_is_empty(vertex_table),
            "Graph {func_name}: Vertex table ({vertex_table}) is empty!".format(
                **msg_args))

    # --- edge table ---
    _assert(_usable(edge_table),
            "Graph {func_name}: Invalid edge table name!".format(**msg_args))
    _assert(table_exists(edge_table),
            "Graph {func_name}: Edge table ({edge_table}) is missing!".format(
                **msg_args))
    _assert(not table_is_empty(edge_table),
            "Graph {func_name}: Edge table ({edge_table}) is empty!".format(
                **msg_args))

    # --- columns ---
    # Compare unquoted identifiers on both sides of the membership test.
    vertex_cols = [unquote_ident(col) for col in get_cols(vertex_table)]
    _assert(unquote_ident(vertex_id) in vertex_cols,
            """Graph {func_name}: The vertex column {vertex_id} is not present in vertex table ({vertex_table}) """.
            format(**msg_args))

    edge_cols = edge_params.values()
    _assert(columns_exist_in_table(edge_table, edge_cols),
            """Graph {func_name}: Not all columns from {cols} are present in edge table ({edge_table})""".
            format(cols=edge_cols, **msg_args))

def validate_params_for_link_analysis(schema_madlib, func_name,
                                            threshold, max_iter,
                                            edge_table=None,
                                            grouping_cols_list=None):
    """
        Validate the parameters common to the link analysis functions.

        Args:
            @param schema_madlib       Passed on to columns_exist_in_table()
                                       when grouping columns are checked.
            @param func_name           Name used as prefix in error messages.
            @param threshold           A falsy value (e.g. None or 0) is
                                       accepted as is; any other value must
                                       lie in the closed interval [0, 1].
            @param max_iter            Must be greater than 0.
            @param edge_table          (optional) Table in which the grouping
                                       columns are looked up.
            @param grouping_cols_list  (optional) Grouping column names; only
                                       checked when non-empty.

        Results:
            A failure is reported through _assert() when any check fails.
    """
    threshold_is_valid = (not threshold) or (0.0 <= threshold <= 1.0)
    _assert(threshold_is_valid,
            "{0}: Invalid threshold value ({1}), must be between 0 and 1.".
            format(func_name, threshold))

    max_iter_is_valid = max_iter > 0
    _assert(max_iter_is_valid,
            """{0}: Invalid max_iter value ({1}), must be a positive integer.""".
            format(func_name, max_iter))

    if not grouping_cols_list:
        # Nothing else to validate without grouping columns.
        return

    # Only plain column names of the edge_table are currently supported as
    # grouping columns; expressions are not.
    grouping_cols_present = columns_exist_in_table(
        edge_table, grouping_cols_list, schema_madlib)
    _assert(grouping_cols_present,
            "{0} error: One or more grouping columns specified do not exist!".
            format(func_name))

def update_output_grouping_tables_for_link_analysis(temp_summary_table,
                                                    iter_num,
                                                    summary_table,
                                                    out_table,
                                                    res_table,
                                                    grouping_cols_list,
                                                    cur_unconv,
                                                    message_unconv=None):
    """
        This function updates the summary and output tables only for those
        groups that have converged. This is found out by looking at groups
        that appear in cur_unconv but not in message_unconv: message_unconv
        consists of groups that have not converged in the current iteration,
        while cur_unconv contains groups that had not converged in the
        previous iterations. The entries in cur_unconv is a superset of the
        entries in message_unconv. So the difference in the groups across
        the two tables represents the groups that converged in the current
        iteration.

        Args:
            @param temp_summary_table  Scratch table; truncated and then
                                       filled with the groups to record.
            @param iter_num            Iteration counter; iter_num+1 is stored
                                       as __iteration__ in summary_table.
            @param summary_table       Receives the rows of
                                       temp_summary_table plus __iteration__.
            @param out_table           Receives the rows of res_table that
                                       belong to the recorded groups.
            @param res_table           Table holding the current results.
            @param grouping_cols_list  Grouping column names used to join
                                       res_table with temp_summary_table.
            @param cur_unconv          Table of groups that had not converged
                                       before this call.
            @param message_unconv      (optional) Table of groups that have
                                       not converged in the current iteration.
                                       When None, every group in cur_unconv
                                       is recorded.
    """
    plpy.execute("TRUNCATE TABLE {0}".format(temp_summary_table))
    if message_unconv is None:
        # If this function is called after max_iter is completed, without
        # convergence, all the unconverged groups from cur_unconv is used
        # (note that message_unconv is renamed to cur_unconv before checking
        # for unconverged==0 in the pagerank function's for loop)
        plpy.execute("""
            INSERT INTO {temp_summary_table}
            SELECT * FROM {cur_unconv}
            """.format(temp_summary_table=temp_summary_table,
                       cur_unconv=cur_unconv))
    else:
        # NOTE(review): presumably this predicate keeps the rows of
        # cur_unconv whose group is absent from message_unconv -- confirm
        # against get_filtered_cols_subquery_str.
        converged_filter = get_filtered_cols_subquery_str(
            cur_unconv, message_unconv, grouping_cols_list)
        plpy.execute("""
            INSERT INTO {temp_summary_table}
            SELECT {cur_unconv}.*
            FROM {cur_unconv}
            WHERE {join_condition}
            """.format(temp_summary_table=temp_summary_table,
                       cur_unconv=cur_unconv,
                       join_condition=converged_filter))

    plpy.execute("""
        INSERT INTO {summary_table}
        SELECT *, {iter_num}+1 AS __iteration__
        FROM {temp_summary_table}
        """.format(summary_table=summary_table,
                   iter_num=iter_num,
                   temp_summary_table=temp_summary_table))

    # Build the join predicate with explicit arguments. Calling locals()
    # inside a comprehension is not portable: on Python 3 before 3.12 the
    # comprehension has its own scope, so the function's variables would be
    # missing from the mapping and format() would raise a KeyError.
    group_join_condition = ' AND '.join(
        ["{res_table}.{col}={temp_summary_table}.{col}".format(
            res_table=res_table,
            temp_summary_table=temp_summary_table,
            col=col)
         for col in grouping_cols_list])
    plpy.execute("""
        INSERT INTO {out_table}
        SELECT {res_table}.*
        FROM {res_table}
        INNER JOIN {temp_summary_table}
        ON {join_condition}
        """.format(out_table=out_table,
                   res_table=res_table,
                   temp_summary_table=temp_summary_table,
                   join_condition=group_join_condition))

def get_default_threshold_for_link_analysis(n_vertices):
    """
        Return a default threshold that shrinks as the graph grows.

        A fixed threshold value, of say 1e-5, might not work well when the
        number of vertices is a billion, since the initial score
        (PageRank/Authority/Hub etc.) value of all nodes would then be
        around 1e-9 (i.e. 1/1e9). The default is therefore derived from the
        number of vertices: 1 / (n_vertices * 1000).
        NOTE: The heuristic below is not based on any scientific evidence.

        Args:
            @param n_vertices  Number of vertices; must be greater than 0.

        Returns:
            The threshold value 1.0 / (n_vertices * 1000).
    """
    _assert(n_vertices > 0, """Number of vertices must be greater than 0""")
    scaled_vertex_count = n_vertices * 1000
    return 1.0 / scaled_vertex_count

def get_graph_usage(schema_madlib, func_name, other_text):
    """
    Build the generic usage message shared by the graph functions.

    Args:
        @param schema_madlib  Schema name placed in front of the function name.
        @param func_name      Name of the graph function being described.
        @param other_text     Text describing the function-specific arguments
                              that follow edge_args, or None when there are
                              none. When None, no comma is printed after
                              edge_args and nothing is inserted in place of
                              the extra arguments.

    Returns:
        The formatted usage string.
    """
    has_other_text = other_text is not None

    usage = """
    ----------------------------------------------------------------------------
                                USAGE
    ----------------------------------------------------------------------------
     SELECT {schema_madlib}.{func_name}(
        vertex_table  TEXT, -- Name of the table that contains the vertex data.
        vertex_id     TEXT, -- Name of the column containing the vertex ids.
        edge_table    TEXT, -- Name of the table that contains the edge data.
        edge_args     TEXT{comma} -- A comma-delimited string containing multiple
                            -- named arguments of the form "name=value".
        {other_text}
    );

    The following parameters are supported for edge table arguments
    ('edge_args' above):

    src (default = 'src'): Name of the column containing the source
                           vertex ids in the edge table.
    dest (default = 'dest'): Name of the column containing the destination
                            vertex ids in the edge table.
    weight (default = 'weight'): Name of the column containing the weight of
                                 edges in the edge table.
    """.format(schema_madlib=schema_madlib,
               func_name=func_name,
               # Formatting None directly would print the literal word
               # "None" inside the argument list, so substitute ''.
               other_text=other_text if has_other_text else '',
               comma=',' if has_other_text else ' ')
    return usage


def get_edge_params(schema_madlib, table, edge_args):
    """
    Parse the edge arguments and pick the integer type of the table columns.

    Args:
        @param schema_madlib  Not used by this function.
        @param table          Name of the table whose column types are
                              inspected.
        @param edge_args      String of "name=value" arguments. The supported
                              names are 'src', 'dest' and 'weight'; each
                              defaults to a column of the same name.

    Returns:
        A tuple (edge_params, final_type). edge_params is the dict produced
        by extract_keyvalue_params(). final_type is 'bigint' when any column
        of the table other than the weight column is of type bigint, and
        'integer' otherwise.
    """
    params_types = {'src': str, 'dest': str, 'weight': str}
    default_args = {'src': 'src', 'dest': 'dest', 'weight': 'weight'}
    edge_params = extract_keyvalue_params(edge_args,
                                          params_types,
                                          default_args)
    weight = edge_params["weight"]

    # Find the appropriate column type for correct concat operations.
    # The weight column is excluded from the check.
    has_bigint_column = any(
        col_type == 'bigint' and col_name != weight
        for col_name, col_type in get_cols_and_types(table))
    final_type = 'bigint' if has_bigint_column else 'integer'

    return edge_params, final_type
