# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
@file pathing.py_in

@brief Pathing functions

@namespace utilities
"""
import plpy
import shlex
import string
import re

from control import MinWarning
from utilities import unique_string, _assert, add_postfix
from utilities import py_list_to_sql_string
from validate_args import get_cols
from validate_args import input_tbl_valid, output_tbl_valid, is_var_valid
from validate_args import quote_ident
# ------------------------------------------------------------------------

m4_changequote(`<!', `!>')


def path(schema_madlib, source_table, output_table, partition_expr,
         order_expr, pattern_expr, symbol_expr, agg_func,
         persist_rows, overlapping_patterns, **kwargs):
    """
        Perform regular pattern matching over a sequence of rows.

        Each input row is tagged with a symbol (via the CASE branches built by
        _parse_symbol_str), the symbols of a partition are concatenated in
        order, and path_pattern_match is run on that string. Rows belonging to
        a match are collected and, if requested, aggregated per match.

        Args:
        @param schema_madlib: str, Name of the MADlib schema
        @param source_table: str, Name of the input table/view
        @param output_table: str, Name of the table to store result
        @param partition_expr: str, Expression to partition (group) the input data.
            A false/None value is replaced by the constant expression "1 = 1".
        @param order_expr: str, Expression to order the input data
        @param pattern_expr: str, Expression to define the pattern to search for
        @param symbol_expr: str, Definition for each symbol, comma-separated list
        @param agg_func: str, List of the result functions/aggregates to apply on matched patterns
        @param persist_rows: bool or None, Whether to keep the matched rows in
            the table named add_postfix(output_table, "_tuples"). A false/None
            value is replaced by (not bool(agg_func)).
        @param overlapping_patterns: bool or None, Passed to path_pattern_match
            as a boolean; a false/None value becomes False.

        Returns:
            str, Message naming the output table(s) that were created
    """
    with MinWarning("error"):
        # check for both false and None
        if not partition_expr:
            partition_expr = "1 = 1"
        if not persist_rows:
            # persist_rows = None implies no preference
            # persist_rows = False implies do not store the matched rows
            # NOTE(review): this branch is also taken for an explicit False, so
            # when agg_func is empty the rows are persisted even though the
            # caller passed False - confirm this is intended.
            persist_rows = not bool(agg_func)
        if not overlapping_patterns:
            overlapping_patterns = False

        _validate(source_table, output_table, partition_expr, order_expr,
                  pattern_expr, symbol_expr, agg_func, persist_rows)

        # new_pattern_expr: pattern rewritten with single-character symbols
        # long_sym_str/short_sym_str: WHEN ... THEN ... branches that label a
        #   row with the user-supplied symbol / the single-character symbol
        new_pattern_expr, long_sym_str, short_sym_str = _parse_symbol_str(symbol_expr, pattern_expr)

        # build variables for intermediate objects
        # NOTE: the SQL templates below are filled with .format(**locals()), so
        # the local variable names in this function double as placeholder names.
        # NOTE(review): matched_view, matched_partitions and seq_gen are not
        # referenced by any template in this function.
        input_with_id = unique_string('input_with_id')
        matched_view = unique_string('matched_view')
        id_col_name = unique_string('id_col')
        matched_partitions = unique_string('matched_partitions')
        seq_gen = unique_string('seq_gen')
        short_sym_name_str = unique_string('short_sym')
        long_sym_name_str = unique_string('long_sym')
        match_to_row_id = unique_string('match_to_row_id')
        match_id = unique_string('match_id')

        # presumably get_cols returns the column names of source_table - verify
        all_input_cols = [i.strip() for i in get_cols(source_table)]
        all_input_cols_str = ', '.join(all_input_cols)
        if persist_rows:
            # matched rows are kept as a real table next to the output table
            matched_rows = add_postfix(output_table, "_tuples")
            table_or_view = 'TABLE'
        else:
            # matched rows only feed the aggregation; the view is dropped below
            matched_rows = unique_string('matched_rows')
            table_or_view = 'VIEW'


        # build a new input temp table that contains a sequence and partition columns
        # NOTE: the partition expression is split on every comma, so each
        # comma-separated piece gets its own generated column alias.
        split_p_cols = [i.strip() for i in partition_expr.split(',')]
        p_col_names = [unique_string() for i in split_p_cols]
        p_col_as_str = ','.join(
            [i + " AS " + j for i, j in zip(split_p_cols, p_col_names)])
        p_col_name_str = ', '.join(p_col_names)
        # empty on PostgreSQL builds, a DISTRIBUTED BY clause on other builds
        distribution = m4_ifdef(<!__POSTGRESQL__!>, <!''!>,
                                <!"DISTRIBUTED BY ({0})".format(p_col_name_str)!>)
        plpy.execute("""
                     CREATE TEMP TABLE {input_with_id} AS
                         SELECT
                            {p_col_as_str},
                            *,
                            row_number() OVER() AS {id_col_name},
                            CASE
                                {short_sym_str}
                            END AS {short_sym_name_str},
                            CASE
                                {long_sym_str}
                            END AS {long_sym_name_str}
                         FROM {source_table}
                     {distribution}
                    """.format(**locals()))
        # Explanation for computing the path matches:
        #   Match is performed using regular expression pattern matching on a
        #   string produced by concatenating the symbols. The exact rows that
        #   produce the match are identified by correlating the matched string
        #   indices with another array containing row ids.

        # fall back to a prefixed column name when the source table already
        # has a column called "match_id" or "symbol"
        match_id_name = "__madlib_path_match_id__" if "match_id" in all_input_cols else "match_id"
        symbol_name = "__madlib_path_symbol__" if "symbol" in all_input_cols else "symbol"

        # join the id-tagged input with the unnested (match id, row id) pairs
        # returned by path_pattern_match; rows without a symbol are excluded
        # before the symbol string is aggregated
        plpy.execute("""
            CREATE {table_or_view} {matched_rows} AS
            SELECT {all_input_cols_str},
                   {long_sym_name_str} AS {symbol_name},
                   {match_id} AS {match_id_name}
            FROM
                {input_with_id} as source,
                (
                    SELECT
                        unnest((matched).id) as {match_id},
                        unnest((matched).row_id) as {match_to_row_id}
                    FROM
                    (
                        SELECT
                            {m}.path_pattern_match(
                                array_to_string(array_agg({short_sym_name_str} ORDER BY {order_expr}), '')::text,
                                '{new_pattern_expr}'::text,
                                array_agg({id_col_name} ORDER BY {order_expr})::float8[],
                                {overlapping_patterns}::boolean
                            ) as matched
                        FROM {input_with_id}
                        WHERE {short_sym_name_str} is NOT NULL
                        GROUP BY {p_col_name_str}
                    ) q
                ) as matched_rows
            WHERE source.{id_col_name} = matched_rows.{match_to_row_id}
            """.format(m=schema_madlib, **locals()))

        # alias every partition expression with its own quoted text so the
        # output column is named after the expression the caller supplied
        quoted_split_p_cols = [quote_ident(i) for i in split_p_cols]
        p_col_orig_name_str = ','.join(
            [i + " AS " + j for i, j in zip(split_p_cols, quoted_split_p_cols)])
        if agg_func:
            if partition_expr == '1 = 1':
                # no partition
                plpy.execute("""
                    CREATE TABLE {output_table} AS
                       SELECT
                            {match_id_name},
                            {agg_func}
                       FROM {matched_rows}
                       GROUP BY {match_id_name}
                    """.format(**locals()))
            else:
                # aggregate per (partition, match) and expose the partition columns
                plpy.execute("""
                    CREATE TABLE {output_table} AS
                       SELECT
                            {p_col_orig_name_str},
                            {match_id_name},
                            {agg_func}
                       FROM {matched_rows}
                       GROUP BY {partition_expr}, {match_id_name}
                    """.format(**locals()))
            result = "Aggregation result available in table " + output_table
        else:
            result = "No aggregation table created"
        if not persist_rows:
            # the view was only needed as the source of the aggregation
            plpy.execute("DROP VIEW IF EXISTS " + matched_rows)
        else:
            result += "\n Matched tuples can be found in table " + matched_rows
        # the id-tagged copy of the input is no longer needed
        plpy.execute("DROP TABLE IF EXISTS " + input_with_id)
    return result
# ------------------------------------------------------------------------------


def _validate(source_table, output_table, partition_expr, order_expr,
              pattern_expr, symbol_expr, agg_func, persist_rows):
    """ Check the arguments of the path function before any SQL is run.

        The checks run in a fixed order and the first failing one reports
        the error: input table, output table(s), partition/order
        expressions, then pattern and symbol expressions.

        Args:
            @param source_table: str, Name of the input table/view
            @param output_table: str, Name of the table to store result
            @param partition_expr: str, Expression to partition the input data
            @param order_expr: str, Expression to order the input data
            @param pattern_expr: str, Pattern to search for
            @param symbol_expr: str, Comma-separated symbol definitions
            @param agg_func: str, Aggregates to apply (not checked here)
            @param persist_rows: bool, When true the "_tuples" companion of
                the output table is validated as well
    """
    module_name = 'Path'
    input_tbl_valid(source_table, module_name)
    output_tbl_valid(output_table, module_name)
    if persist_rows:
        tuples_table = add_postfix(output_table, "_tuples")
        output_tbl_valid(tuples_table, module_name)

    # neither expression may be None or an empty string
    non_empty_checks = (
        (partition_expr, "Path error: Invalid partition expression"),
        (order_expr, "Path error: Invalid order expression"),
    )
    for expr_value, err_msg in non_empty_checks:
        _assert(expr_value, err_msg)

    # the partition/order expressions must be usable against the source table
    exprs_usable = is_var_valid(source_table, partition_expr, order_expr)
    _assert(exprs_usable,
            "Path error: invalid partition expression or order expression")

    for expr_value, err_msg in (
            (pattern_expr, "Path error: Invalid pattern expression"),
            (symbol_expr, "Path error: Invalid symbol expression")):
        _assert(expr_value, err_msg)
# ----------------------------------------------------------------------


def _parse_symbol_str(symbol_expr, pattern_expr):
    """ Parse symbol definition to build a CASE statement string
        and return a mapping of the definitions.

        Currently only single-character symbols are allowed.
        Postgresql regular expression match functions will be used on a string
        of symbols, where each symbol represents a tuple. Only single-length
        symbols can be used to maintain 1:1 correspondence between symbol and tuple.

        To allow input for multicharacter symbol for user convenience, inputed
        symbols are mapped to a single character in the pattern expression.
        This updated pattern expression is returned back to the caller.

        Further, two case statements are built to mark each tuple in the input table
        with the corresponding original (user-supplied) symbol and the
        new (single-character) symbol.

        Args:
            @param symbol_expr: str, A comma-separated string containing
                symbol definitions of the form: <symbol> := <symbol_definition>
            @param sym_mapping_tbl: str, Name of the table to output the
                correspondence table between symbol
            @param pattern_expr: str, The pattern expression where the original
                symbols are used. The original symbols are replaced by the corresponding
                new symbols.

        Example:
            symbol_expr = ('BEFORE:=start >= \'0:00:00\' and start < \'9:30:00\', '
                           'MARKET:=start >= \'9:30:00\' and start < \'16:00:00\'')
            pattern_expr = "(BEFORE)*(MARKET)*"

            returns  ("a*b*",
                      "CASE
                        WHEN start >= \'0:00:00\' and start < \'9:30:00\' THEN 'BEFORE'
                        WHEN start >= \'9:30:00\' and start < \'16:00:00\' THEN 'MARKET'
                       END",
                      "CASE
                        WHEN start >= \'0:00:00\' and start < \'9:30:00\' THEN 'a'
                        WHEN start >= \'9:30:00\' and start < \'16:00:00\' THEN 'b'
                       END"
        Returns:
            (str, str, str)
    """
    # all_symbols is all valid single-character symbols
    all_symbols = iter(string.ascii_lowercase + string.digits)
    symbol_expr_parser = shlex.shlex(symbol_expr)
    symbol_expr_parser.wordchars = [i for i in string.printable
                                    if i not in (symbol_expr_parser.quotes + ",")]
    symbol_expr_parser.whitespace = ','
    # parse symbol expr to get the strings between commas
    sym_def_parsed = list(symbol_expr_parser)

    orig_symbols_ordered = []
    orig_sym_definitions = {}
    new_sym_definitions = {}
    old_to_new = {}
    for each_sym_def in sym_def_parsed:
        # symbols are defined as a pair: <name> := <definition>
        sym_def_split = each_sym_def.split(":=")
        if len(sym_def_split) == 2:
            orig_sym, sym_def = (i.strip() for i in sym_def_split)
            orig_symbols_ordered.append(orig_sym)
            try:
                next_sym = all_symbols.next()
            except StopIteration:
                plpy.error("Path error: Total symbols in the symbol expression "
                           "exceed maximum number of symbols allowed.")
            # symbols are supposed to be case-insensitive. Use the lower-case
            # version to maintain a mapping from original to new symbol name
            orig_sym_lower = re.escape(orig_sym.lower())
            _assert(orig_sym_lower not in old_to_new,
                    "Path error: Multipe definitions of a symbol")
            old_to_new[orig_sym_lower] = next_sym
            orig_sym_definitions[orig_sym] = sym_def
            new_sym_definitions[next_sym] = sym_def

    # replace each occurence of the original symbol with the new
    # perform this operation in descending order of length to avoid substituting
    # subset of any symbol
    old_symbols_desc = list(sorted(old_to_new.keys(), key=len, reverse=True))
    replace_pattern = re.compile('|'.join(old_symbols_desc), re.IGNORECASE)
    new_pattern_expr = replace_pattern.sub(
        lambda m: old_to_new[re.escape(string.lower(m.group(0)))],
        pattern_expr)

    # build a case statement to search a tuple for each definition and pick the
    # appropriate symbol.
    orig_sym_case_stmt = []
    new_sym_case_stmt = []
    case_stmt = "WHEN {d} THEN '{s}'::text"
    for k in orig_symbols_ordered:
        orig_sym_case_stmt.append(case_stmt.format(s=k, d=orig_sym_definitions[k]))
        new_sym_case_stmt.append(case_stmt.format(s=old_to_new[re.escape(k.lower())],
                                                  d=orig_sym_definitions[k]))
    return (new_pattern_expr, '\n'.join(orig_sym_case_stmt), '\n'.join(new_sym_case_stmt))
# ----------------------------------------------------------------------


def path_help_message(schema_madlib, message, **kwargs):
    """ Return the help text of the path function.

        Args:
            @param schema_madlib: str, Name of the MADlib schema, substituted
                into the example calls
            @param message: str, Help topic. An empty/None value selects the
                summary; 'usage', 'help' or '?' (any letter case) selects the
                usage text; anything else selects a "No such option" note.

        Returns:
            str, The selected help text
    """
    summary_template = """
---------------------------------------------------------------------------
                                SUMMARY
---------------------------------------------------------------------------
Functionality: Path

The goal of the MADlib path function is to perform regular pattern matching
over a sequence of rows, and to extract useful information about the matches.
The useful information could be a simple count of matches or something more
involved like aggregation.

For more details on function usage:
    SELECT {schema_madlib}.path('usage');
    """

    usage_template = """
---------------------------------------------------------------------------
                                USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.path(
    'source_table',    -- Name of the table
    'output_table',    -- Table name to store the path results
    'partition_expr',  -- Partition expression to group the data table
    'order_expr',      -- Order expression to sort the tuples of the data table
    'symbol_def',      -- Definition of various symbols used in the pattern definition
    'pattern_def',     -- Definition of the path pattern to search for
    'agg_func',        -- Aggregate/window functions to be applied on the matched paths
    persist_rows,       -- Boolean indicating whether to output the matched
                        --  rows in an additional table (named <output_table>_tuples)
    overlapping_patterns    -- Boolean indicating whether to find every
                            -- overlapping occurrence of the pattern in the partition
);
    """

    unknown_template = """
No such option. Use "SELECT {schema_madlib}.path()" for help.
        """

    # pick the template first, substitute the schema name once at the end
    if not message:
        chosen_template = summary_template
    elif message.lower() in ('usage', 'help', '?'):
        chosen_template = usage_template
    else:
        chosen_template = unknown_template
    return chosen_template.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------
