# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# HITS

# Please refer to the hits.sql_in file for the documentation

"""
@file hits.py_in

@namespace graph
"""

import math
import plpy
import sys
from graph_utils import get_graph_usage
from graph_utils import get_default_threshold_for_link_analysis
from graph_utils import update_output_grouping_tables_for_link_analysis
from graph_utils import validate_graph_coding
from graph_utils import validate_params_for_link_analysis

from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import _check_groups
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import add_postfix
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string, split_quoted_delimited_str
from utilities.utilities import is_platform_pg

from utilities.validate_args import columns_exist_in_table, drop_tables
from utilities.validate_args import get_cols_and_types, table_exists
from utilities.utilities import rename_table

def validate_hits_args(schema_madlib, vertex_table, vertex_id, edge_table,
                       edge_params, out_table, max_iter, threshold,
                       grouping_cols_list=None):
    """
    Validate the input parameters of HITS.

    The checks are delegated to the shared graph helpers:
    validate_graph_coding receives the vertex/edge/output arguments, and
    validate_params_for_link_analysis receives the iteration-control
    arguments (threshold, max_iter) together with the edge table and the
    list of grouping columns.

    Args:
        @param schema_madlib: Name of the MADlib schema.
        @param vertex_table: Name of the vertex table.
        @param vertex_id: Name of the id column in the vertex table.
        @param edge_table: Name of the edge table.
        @param edge_params: Dict of edge column names (as built by the
                            caller from the edge_args string).
        @param out_table: Name of the output table.
        @param max_iter: Maximum number of iterations.
        @param threshold: Convergence threshold.
        @param grouping_cols_list: List of grouping column names, or None.
    """
    algorithm = 'HITS'

    # Graph-related arguments (tables, id/edge columns, output table)
    validate_graph_coding(
        vertex_table, vertex_id, edge_table, edge_params, out_table,
        algorithm)

    # Link-analysis arguments such as threshold and max_iter
    validate_params_for_link_analysis(
        schema_madlib, algorithm, threshold, max_iter, edge_table,
        grouping_cols_list)


def hits(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         out_table, max_iter, threshold, grouping_cols, **kwargs):
    """
    Function that computes the HITS scores

    Iteratively computes an authority and a hub score for the vertices that
    appear in the edge table (optionally per group), and writes them to
    out_table. A second table, named by add_postfix(out_table, "_summary"),
    stores the number of iterations in a column called __iterations__.

    NOTE: this function formats its SQL with ``**locals()`` and passes
    ``**locals()`` to calculate_authority_and_hub_scores and
    check_for_convergence. The names of the local variables below are
    therefore part of the contract with those helpers; renaming or removing
    one can break a format placeholder elsewhere in this file.

    Args:
        @param schema_madlib: Name of the MADlib schema.
        @param vertex_table: Name of the table holding the vertices.
        @param vertex_id: Name of the id column in vertex_table. Defaults to
                          "id" when empty or None.
        @param edge_table: Name of the table holding the edges.
        @param edge_args: Key-value string naming the 'src' and 'dest'
                          columns of edge_table (defaults: 'src', 'dest').
        @param out_table: Name of the output table.
        @param max_iter: Maximum number of iterations. Defaults to 100 when
                         None.
        @param threshold: Convergence threshold. When None, the value is
                          taken from
                          get_default_threshold_for_link_analysis(n_vertices).
                          When 0, the convergence check is skipped and all
                          max_iter iterations are run.
        @param grouping_cols: Comma-delimited grouping column names, or
                              empty/None for no grouping.
        @param kwargs: Not used by this function.
    """
    with MinWarning('warning'):
        # Parse the edge_args string into a dict with 'src' and 'dest' keys.
        params_types = {'src': str, 'dest': str}
        default_args = {'src': 'src', 'dest': 'dest'}
        edge_params = extract_keyvalue_params(
            edge_args, params_types, default_args)

        # populate default values for optional params if null
        if max_iter is None:
            max_iter = 100
        if not vertex_id:
            vertex_id = "id"
        if not grouping_cols:
            grouping_cols = ''

        grouping_cols_list = split_quoted_delimited_str(grouping_cols)
        validate_hits_args(schema_madlib, vertex_table, vertex_id, edge_table,
                           edge_params, out_table, max_iter, threshold,
                           grouping_cols_list)

        # The summary table is always created by this function, so it must
        # not exist beforehand.
        summary_table = add_postfix(out_table, "_summary")
        _assert(not table_exists(summary_table),
                """Graph HITS: Output summary table ({summary_table}) already
                exists.""".format(**locals()))

        src = edge_params["src"]
        dest = edge_params["dest"]
        # Number of rows in vertex_table with a non-NULL vertex_id; only used
        # to derive the default threshold below.
        n_vertices = plpy.execute("""
                            SELECT COUNT({0}) AS cnt
                            FROM {1}
                        """.format(vertex_id, vertex_table))[0]["cnt"]

        if threshold is None:
            threshold = get_default_threshold_for_link_analysis(n_vertices)

        # table/column names used when grouping_cols is set.
        # edge_temp_table is a temp copy of the edge table; on non-Postgres
        # platforms it is distributed by the grouping columns and dest.
        edge_temp_table = unique_string(desp='temp_edge')
        grouping_cols_comma = grouping_cols + ',' if grouping_cols else ''
        distribution = ('' if is_platform_pg() else
                        "DISTRIBUTED BY ({0}{1})".format(
                            grouping_cols_comma, dest))
        drop_tables([edge_temp_table])
        plpy.execute("""
                CREATE TEMP TABLE {edge_temp_table} AS
                SELECT * FROM {edge_table}
                {distribution}
            """.format(**locals()))

        ######################################################################
        # Set initial authority_norm and hub_norm as 1, so that later the final
        # norm should be positive number
        authority_init_value = 1.0
        hub_init_value = 1.0

        subquery1 = unique_string(desp='subquery1')

        # SQL fragments that stay empty when no grouping is requested, so the
        # same query templates serve both the grouped and ungrouped case.
        distinct_grp_table = ''
        select_grouping_cols_comma = ''
        select_subquery1_grouping_cols_comma = ''
        group_by_clause = ''
        grouping_cols_for_create_table = ''
        grouping_cols_for_create_table_comma = ''
        # This table is created only when grouping is used.
        temp_summary_table = None
        if grouping_cols:
            # Distinct group values found in the edge table; used by
            # update_final_results when threshold == 0.
            distinct_grp_table = unique_string(desp='grp')
            drop_tables([distinct_grp_table])
            plpy.execute("""CREATE TEMP TABLE {distinct_grp_table} AS
                    SELECT DISTINCT {grouping_cols} FROM {edge_temp_table}
                """.format(**locals()))
            # group_by_clause is used here as a scratch variable first (the
            # subquery1-qualified grouping columns) and is then overwritten
            # with the real GROUP BY clause two statements below.
            group_by_clause = get_table_qualified_col_str(subquery1, grouping_cols_list)
            select_grouping_cols_comma = group_by_clause + ','
            select_subquery1_grouping_cols_comma = grouping_cols + ','
            group_by_clause = 'GROUP BY ' + grouping_cols + ',' + vertex_id
            # "<name> <type>" column definitions for the grouping columns,
            # taken from the edge table's columns.
            cols_names_types = get_cols_and_types(edge_table)
            grouping_cols_for_create_table = ', '.join([c_name + " " + c_type
                                                        for (c_name, c_type)
                                                        in cols_names_types
                                                        if c_name in
                                                        grouping_cols_list])
            grouping_cols_for_create_table_comma = \
                grouping_cols_for_create_table + ','
            # Empty temp table with only the grouping columns; it is handed to
            # update_output_grouping_tables_for_link_analysis.
            temp_summary_table = unique_string(desp='temp_summary')
            drop_tables([temp_summary_table])
            plpy.execute("""
                CREATE TEMP TABLE {temp_summary_table} (
                    {grouping_cols_for_create_table}
                )
            """.format(**locals()))
            # Create output table. This will be updated whenever a group converges
            # Note that vertex_id is assumed to be an integer (as described in
            # documentation)
            plpy.execute("""
                    CREATE TABLE {out_table} (
                        {grouping_cols_for_create_table_comma}
                        {vertex_id} BIGINT,
                        authority DOUBLE PRECISION,
                        hub DOUBLE PRECISION
                    )
                """.format(**locals()))

        # Intermediate tables required.
        cur = unique_string(desp='cur')
        message = unique_string(desp='message')
        # curalias and msgalias are used as aliases for current and
        # message tables respectively during self joins
        curalias = unique_string(desp='curalias')
        msgalias = unique_string(desp='msgalias')
        message_unconv = unique_string(desp='message_unconv')
        subquery2 = unique_string(desp='subquery2')

        # GPDB has distributed by clauses to help them with indexing.
        # For Postgres we add the index explicitly.
        if is_platform_pg():
            plpy.execute("CREATE INDEX ON {0}({1})".format(
                edge_temp_table, dest))

        # NOTE(review): cnts_distribution is not referenced by any SQL
        # template visible in this file.
        if is_platform_pg():
            cur_distribution = cnts_distribution = ''
        else:
            cur_distribution = cnts_distribution = \
                "DISTRIBUTED BY ({0}{1})".format(
                    grouping_cols_comma, vertex_id)
        # Join conditions used when computing authority scores: the current
        # vertex is matched on the edge's dest, the aliased copy on its src.
        cur_join_clause = " {cur}.{vertex_id} = {edge_temp_table}.{dest}"\
            .format(**locals())
        curalias_join_clause = "{curalias}.{vertex_id} = {edge_temp_table}.{src}"\
            .format(**locals())
        # Initialize cur with every vertex of vertex_table that appears as a
        # src or dest in the edge table (per group when grouping is used),
        # with authority and hub both set to their initial value.
        drop_tables([cur])
        plpy.execute("""
            CREATE TEMP TABLE {cur} AS
            SELECT {select_grouping_cols_comma} {subquery1}.{vertex_id},
                    {authority_init_value}::DOUBLE PRECISION AS authority,
                    {hub_init_value}::DOUBLE PRECISION AS hub
            FROM (
                SELECT {select_subquery1_grouping_cols_comma} {vertex_table}.{vertex_id}
                FROM {edge_temp_table} JOIN {vertex_table}
                ON {edge_temp_table}.{src}={vertex_table}.{vertex_id}
                UNION
                SELECT {select_subquery1_grouping_cols_comma} {vertex_table}.{vertex_id}
                FROM {edge_temp_table} JOIN {vertex_table}
                ON {edge_temp_table}.{dest}={vertex_table}.{vertex_id}
            ) {subquery1}
            {group_by_clause}
            {cur_distribution}
        """.format(**locals()))

        # The summary table contains the total number of iterations for each
        # group
        plpy.execute("""
                CREATE TABLE {summary_table} (
                    {grouping_cols_for_create_table_comma}
                    __iterations__ INTEGER
                )
            """.format(**locals()))

        # The fragments below are consumed (through **locals()) by
        # calculate_authority_and_hub_scores and check_for_convergence.
        # They default to '' so the templates also work without grouping.

        # create message table
        cur_grouping_cols_comma = ''
        grouping_cols_join_condition_cur_and_edge = ''
        grouping_cols_join_condition_cur_and_curalias = ''

        # update message table
        msg_grouping_cols_comma = ''
        grouping_cols_where_clause_msg_subquery2 = ''
        grouping_cols_join_condition_msg_and_edge = ''
        grouping_cols_join_condition_msg_and_msgalias = ''
        grouping_cols_where_condition_msg_subquery1 = ''
        group_by_msg_grouping_cols = ''

        # check for convergence
        select_distinct_unconverged_rows = '{0}.{1}'.format(cur, vertex_id)
        grouping_cols_join_condition_cur_and_msg = ''

        if grouping_cols:
            cur_grouping_cols = get_table_qualified_col_str(cur, grouping_cols_list)
            cur_grouping_cols_comma = cur_grouping_cols + ','
            grouping_cols_join_condition_cur_and_edge = ' AND ' + \
                _check_groups(cur, edge_temp_table, grouping_cols_list)
            grouping_cols_join_condition_cur_and_curalias = ' AND ' + \
                _check_groups(cur, curalias, grouping_cols_list)

            msg_grouping_cols = get_table_qualified_col_str(message, grouping_cols_list)
            msg_grouping_cols_comma = msg_grouping_cols + ','
            grouping_cols_where_clause_msg_subquery2 = ' AND ' + \
                _check_groups(message, subquery2, grouping_cols_list)
            grouping_cols_join_condition_msg_and_edge = ' AND ' + \
                _check_groups(message, edge_temp_table, grouping_cols_list)
            grouping_cols_join_condition_msg_and_msgalias = ' AND ' + \
                _check_groups(message, msgalias, grouping_cols_list)

            group_by_msg_grouping_cols = ' GROUP BY ' + msg_grouping_cols

            grouping_cols_where_condition_msg_subquery1 = ' WHERE ' + \
                _check_groups(message, subquery1, grouping_cols_list)

            # this is used for the message_unconv table to find out how many groups
            # have converged
            select_distinct_unconverged_rows = cur_grouping_cols
            grouping_cols_join_condition_cur_and_msg = ' AND ' + \
                _check_groups(cur, message, grouping_cols_list)

        # Variables common to both grouping and non-grouping cases
        # Join conditions used when computing hub scores: the message vertex
        # is matched on the edge's src, the aliased copy on its dest.
        message_join_clause = "{message}.{vertex_id} = \
        {edge_temp_table}.{src}".format(**locals())
        msgalias_join_clause = "{msgalias}.{vertex_id} = {edge_temp_table}.{dest}".format(
            **locals())
        sum_norm_square_root = unique_string(desp='sum_norm_sqr_root')
        cur_unconv = unique_string(desp='cur_unconv')

        converged = False
        iteration_num = 0
        """
        We need to calculate the hub and authority scores for each iteration
        and pass the current iteration values to the next iteration.
        To achieve this, we use the following tables
        1. cur -> This gets initialized with default values for both authority
        and hub.
        2. message -> This gets created for each iteration with newly
        computed scores based on cur. At the end of each iteration, we rename
        message to cur as a way to pass authority and hub values to the next
        iteration.
        This convention is similar to message passing paradigm in
        distributed systems such as spark.
        """
        for iteration_num in range(max_iter):

            # Creates the message table (new scores) and the
            # sum_norm_square_root table for this iteration.
            calculate_authority_and_hub_scores(**locals())

            # Check for convergence only if threshold != 0.
            if threshold != 0:

                # Creates message_unconv; True when it has no rows.
                converged = check_for_convergence(**locals())

                # cur_unconv only exists from the second iteration on, hence
                # the iteration_num > 0 guard.
                if iteration_num > 0 and grouping_cols:
                    # Update result and summary tables for groups that have
                    # converged
                    # since the last iteration.
                    update_output_grouping_tables_for_link_analysis(temp_summary_table,
                                                                          iteration_num,
                                                                          summary_table,
                                                                          out_table,
                                                                          message,
                                                                          grouping_cols_list,
                                                                          cur_unconv,
                                                                          message_unconv)

                # This iteration's unconverged set becomes the "previous" one
                # for the next iteration.
                drop_tables([cur_unconv])
                plpy.execute("""ALTER TABLE {message_unconv} RENAME TO
                    {cur_unconv} """.format(**locals()))

            # Promote the newly computed scores to be the current scores.
            drop_tables([cur])
            plpy.execute("""ALTER TABLE {message} RENAME TO {cur}
                    """.format(**locals()))
            drop_tables([sum_norm_square_root])

            if converged:
                break

        # iteration_num here is the zero-based index of the last iteration run.
        update_final_results(schema_madlib, converged, threshold, cur, temp_summary_table,
                             iteration_num, summary_table, out_table,
                             grouping_cols_list, cur_unconv, distinct_grp_table)

        # Cleanup All the intermediate tables
        drop_tables([cur, message, cur_unconv, message_unconv, edge_temp_table])

        if grouping_cols:
            drop_tables([distinct_grp_table, temp_summary_table])


def update_final_results(schema_madlib, converged, threshold, cur, temp_summary_table,
                         iteration_num, summary_table, out_table,
                         grouping_cols_list, cur_unconv, distinct_grp_table):
    """
        Write the results that are still pending after the iteration loop.

        Without grouping columns, the `cur` table is renamed to `out_table`
        and a single row holding `iteration_num` + 1 is inserted into
        `summary_table`.

        With grouping columns, nothing is done when `converged` is true.
        Otherwise the output and summary tables are updated from `cur` for
        the groups listed in:
          - `cur_unconv` when threshold != 0 (max_iter was reached while
            some groups were still unconverged), or
          - `distinct_grp_table` when threshold == 0 (no convergence check
            was made, so this is the list of all group values).
    """
    if not grouping_cols_list:
        # No grouping: the final `cur` table is the result itself.
        rename_table(schema_madlib, cur, out_table)
        insert_iterations_sql = """
                INSERT INTO {summary_table} VALUES
                ({iteration_num}+1)
                """.format(summary_table=summary_table,
                           iteration_num=iteration_num)
        plpy.execute(insert_iterations_sql)
        return

    if converged:
        return

    # Pick the table that lists the groups whose results are still missing.
    if threshold != 0:
        pending_groups_table = cur_unconv
    else:
        pending_groups_table = distinct_grp_table

    update_output_grouping_tables_for_link_analysis(
        temp_summary_table, iteration_num, summary_table, out_table, cur,
        grouping_cols_list, pending_groups_table)


def check_for_convergence(**kwargs):
    """
    Build the table of unconverged rows and report whether it is empty.

    Creates the table named by kwargs['message_unconv'] from the join of the
    message and cur tables on the vertex id (plus the grouping-column join
    condition, if any), keeping the distinct values of
    kwargs['select_distinct_unconverged_rows'] for rows whose authority or
    hub score changed by more than kwargs['threshold'] between cur and
    message.

    :param kwargs: dict of locals() of the calling function.
    :return: True when the created table has no rows, False otherwise.
    """
    create_unconverged_sql = """
                    CREATE TABLE {message_unconv} AS
                    SELECT DISTINCT {select_distinct_unconverged_rows}
                    FROM {message}
                    INNER JOIN {cur}
                    ON {cur}.{vertex_id}={message}.{vertex_id}
                    {grouping_cols_join_condition_cur_and_msg}
                    WHERE ABS({cur}.authority-{message}.authority) > {threshold}
                    OR ABS({cur}.hub-{message}.hub) > {threshold}
                """.format(**kwargs)
    plpy.execute(create_unconverged_sql)

    # Count what is left over; zero rows means everything has converged.
    count_unconverged_sql = """
                        SELECT COUNT(*) AS cnt FROM {message_unconv}
                    """.format(**kwargs)
    count_result = plpy.execute(count_unconverged_sql)
    remaining = count_result[0]["cnt"]
    return remaining == 0


def calculate_authority_and_hub_scores(**kwargs):
    """
    This function is responsible for calculating the authority and hub scores.
    This is done in a three-step process:
    1. create message table and compute authority score.
    2. Use authority scores computed to update the hub score.
    3. Normalize both scores by their L2 norm (per group when grouping
       columns are used).

    Side effects: creates the table named by kwargs['message'] and the temp
    table named by kwargs['sum_norm_square_root']; neither is dropped here.

    :param kwargs: dict of locals() of the calling function. Every
                   placeholder in the SQL templates below must be a key of
                   this dict.
    :return: None. Calls plpy.error if the L2 norm of the authority or the
             hub scores is zero (for any group).
    """
    ###################################################################
    # HITS scores for nodes in a graph at any given iteration 'i' is
    # calculated as following:
    # authority_i(A) = hub_i(B) + hub_i(C) + ..., where B, C are nodes
    # that have edges that point to node A
    # After calculating authority scores for all nodes, hub scores are
    # calculated as following:
    # hub_i(A) = authority_i(D) + authority_i(E) + ..., where D, E are
    # nodes that A points to
    # At the end of each iteration, a normalization will
    # be done for all authority scores and hub scores using L2 distance
    ###################################################################

    ###################################################################
    # calculate authority
    # if there is no node that point to A, authority_i(A) = 0
    ###################################################################
    plpy.execute("""
                    CREATE TABLE {message} AS
                    SELECT {cur_grouping_cols_comma} {cur}.{vertex_id} AS {vertex_id},
                            COALESCE(SUM({curalias}.hub), 0.0) AS authority,
                            {cur}.hub AS hub
                    FROM {cur}
                        LEFT JOIN {edge_temp_table} ON
                        {cur_join_clause} {grouping_cols_join_condition_cur_and_edge}
                        LEFT JOIN {cur} AS {curalias} ON
                        {curalias_join_clause} {grouping_cols_join_condition_cur_and_curalias}
                    GROUP BY {cur_grouping_cols_comma} {cur}.{vertex_id}, {cur}.hub
                    {cur_distribution}
                """.format(**kwargs))
    ###################################################################
    # calculate hub
    # if node A doesn't point to any node, hub_i(A) = 0
    ###################################################################
    plpy.execute("""
                    UPDATE {message}
                    SET hub = {subquery2}.hub
                    FROM
                    (SELECT {msg_grouping_cols_comma} {message}.{vertex_id} AS {vertex_id},
                            COALESCE(SUM({msgalias}.authority), 0) AS hub
                    FROM {message}
                        LEFT JOIN {edge_temp_table} ON
                        {message_join_clause} {grouping_cols_join_condition_msg_and_edge}
                        LEFT JOIN {message} AS {msgalias} ON
                        {msgalias_join_clause} {grouping_cols_join_condition_msg_and_msgalias}
                    GROUP BY {msg_grouping_cols_comma} {message}.{vertex_id}) AS {subquery2}
                    WHERE {subquery2}.{vertex_id} = {message}.{vertex_id}
                    {grouping_cols_where_clause_msg_subquery2}
                """.format(**kwargs))
    # normalize authority and hub score with L2 distance
    plpy.execute("""
                    CREATE TEMP TABLE {sum_norm_square_root} AS
                        SELECT {msg_grouping_cols_comma}
                            SQRT(SUM(POWER(authority, 2))) AS auth_sum_norm_square_root,
                            SQRT(SUM(POWER(hub, 2))) AS hub_sum_norm_square_root
                        FROM {message}
                    {group_by_msg_grouping_cols}
                """.format(**kwargs))

    # A zero norm would cause a division by zero in the normalization
    # below, so it is reported as an error up front. The check covers both
    # the authority and the hub norm.
    num_zero_sum_norm_square_root = plpy.execute("""
                                                    SELECT COUNT(*) AS cnt
                                                    FROM {sum_norm_square_root}
                                                    WHERE auth_sum_norm_square_root = 0
                                                    OR hub_sum_norm_square_root = 0
                                                """.format(**kwargs))[0]["cnt"]

    if num_zero_sum_norm_square_root > 0:
        # Adjacent literals are used instead of a backslash continuation
        # inside the string, which would embed the source indentation in
        # the message shown to the user.
        plpy.error("Error while normalizing authority score, please "
                   "make sure your graph is a directed graph")
    plpy.execute("""
                    UPDATE {message}
                        SET authority = {message}.authority/{subquery1}.auth_sum_norm_square_root,
                            hub = {message}.hub/{subquery1}.hub_sum_norm_square_root
                    from (SELECT {grouping_cols_comma}
                            auth_sum_norm_square_root,
                            hub_sum_norm_square_root
                        FROM {sum_norm_square_root}) {subquery1}
                    {grouping_cols_where_condition_msg_subquery1}
                """.format(**kwargs))


def hits_help(schema_madlib, message, **kwargs):
    """
    Help function for hits

    Args:
        @param schema_madlib: Name of the MADlib schema; substituted for the
                              {schema_madlib} placeholder in the help text.
        @param message: string, Help message string. "usage", "help" or "?"
                        (case-insensitive) selects the usage text; any other
                        value, including None, selects the short summary.
        @param kwargs: Not used by this function.

    Returns:
        String. Help/usage information
    """
    wants_usage = message is not None and \
        message.lower() in ("usage", "help", "?")

    if not wants_usage:
        summary_text = """
----------------------------------------------------------------------------
                                SUMMARY
----------------------------------------------------------------------------
Given a directed graph, hits algorithm finds the authority and hub scores of
all the vertices in the graph.
--
For an overview on usage, run:
SELECT {schema_madlib}.hits('usage');
"""
        return summary_text.format(schema_madlib=schema_madlib)

    # HITS-specific parameter descriptions handed to the shared usage
    # generator for graph functions.
    hits_params = """out_table     TEXT, -- Name of the output table for HITS
        max_iter      INTEGER, -- Maximum iteration number (DEFAULT = 100)
        threshold     DOUBLE PRECISION, -- Stopping criteria (DEFAULT = 1/(N*1000),
                                        -- N is number of vertices in the graph)
        grouping_cols TEXT -- Comma separated column names to group on
                           -- (DEFAULT = NULL, no grouping)
"""
    usage_text = get_graph_usage(schema_madlib, 'HITS', hits_params) + """

A summary table is also created that contains information regarding the
number of iterations required for convergence. It is named by adding the
suffix '_summary' to the 'out_table' parameter.
"""
    return usage_text.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
