
import collections
from datetime import datetime
import re
import time
import random
from distutils.util import strtobool
# import numpy as np

from validate_args import _get_table_schema_names
from validate_args import cols_in_tbl_valid
from validate_args import does_exclude_reserved
from validate_args import explicit_bool_to_text
from validate_args import get_cols
from validate_args import get_first_schema
from validate_args import input_tbl_valid
from validate_args import is_var_valid
from validate_args import output_tbl_valid
from validate_args import quote_ident
from validate_args import unquote_ident
from validate_args import drop_tables
import plpy

m4_changequote(`<!', `!>')


def has_function_properties():
    """ __HAS_FUNCTION_PROPERTIES__ variable defined during configure

    The return expression is resolved by the m4 preprocessor when this
    file is built, so at run time this function returns a constant:
    True if the variable above was defined, False otherwise.
    """
    return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)


def is_platform_pg():
    """ __POSTGRESQL__ variable defined during configure

    The return expression is resolved by the m4 preprocessor when this
    file is built, so at run time this function returns a constant:
    True if the variable above was defined, False otherwise.
    """
    return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)

def is_platform_gp6_or_up():
    """ Return True when the platform is not PostgreSQL and the database
        version is not a Greenplum version below 6.0.
    """
    # The version wrapper is built first because it queries the database.
    gp_version = __mad_version()
    if is_platform_pg():
        return False
    return not gp_version.is_gp_version_less_than('6.0')

# ------------------------------------------------------------------------------


def get_seg_number():
    """ Count the primary segments (master excluded) of the database.

    Returns:
        int. Always 1 on PostgreSQL; otherwise the number of primary,
        non-master rows in gp_segment_configuration, never less than 1.
    """
    if is_platform_pg():
        return 1
    rows = plpy.execute("""
            SELECT count(*) from gp_segment_configuration
            WHERE role = 'p' and content != -1
            """)
    # Guard against unusual configurations that would report zero primaries.
    return max(1, rows[0]['count'])
# ------------------------------------------------------------------------------

def get_segments_per_host():
    """ Count the primary segments (master excluded) located on one host.

    Every host is assumed to carry the same number of segments, so only the
    count of the first host group returned by the query is used.

    Returns:
        int. Always 1 on PostgreSQL; otherwise the per-host primary segment
        count, never less than 1.
    """
    if is_platform_pg():
        return 1
    rows = plpy.execute("""
            SELECT count(*) from gp_segment_configuration
            WHERE role = 'p' and content != -1
            GROUP BY hostname
            LIMIT 1
            """)
    # Guard against unusual configurations that would report zero primaries.
    return max(1, rows[0]['count'])
# ------------------------------------------------------------------------------

def is_orca():
    """ Return True when function properties are available and the
        "optimizer" setting of the current session is 'on'.
    """
    if not has_function_properties():
        return False
    return plpy.execute("show optimizer")[0]["optimizer"] == 'on'
# ------------------------------------------------------------------------------


def _assert_equal(o1, o2, msg):
    """
    @brief report an error through plpy when the two objects differ
    @param o1           left-hand object of the comparison
    @param o2           right-hand object of the comparison
    @param msg          text of the error reported on mismatch
    """
    if o1 == o2:
        return
    plpy.error(msg)
# ------------------------------------------------------------------------------


def _assert(condition, msg):
    """
    @brief report an error through plpy when the condition does not hold
    @param condition    value whose truthiness is checked
    @param msg          text of the error reported on failure
    """
    if condition:
        return
    plpy.error(msg)
# ------------------------------------------------------------------------------


def warn(condition, msg):
    """
    @brief emit a warning through plpy when the condition does not hold
    @param condition    value whose truthiness is checked
    @param msg          text of the warning emitted on failure
    """
    if condition:
        return
    plpy.warning(msg)
# ------------------------------------------------------------------------------


def get_distributed_by(source_table):
    """ Return the distribution clause that describes how source_table is
    distributed.

    Args:
        @param source_table: str. Name of the table, optionally schema
                             qualified.

    Returns:
        str. Either 'distributed by ("col1","col2",...)' or
        'distributed randomly' when the database is Greenplum older than
        6.0; otherwise whatever pg_get_table_distributedby() reports for
        the table.
    """
    _, table_name = _get_table_schema_names(source_table)
    schema_name = get_first_schema(source_table)

    # GPDB 6 has pg_get_table_distributedby(<oid>) function to get the
    # DISTRIBUTED BY clause of a table. In older version, we have to
    # dig the column names based on gp_distribution_policy catalog.
    version_wrapper = __mad_version()
    if version_wrapper.is_gp_version_less_than("6.0"):
        dist_attr = plpy.execute("""
        SELECT array_agg(pga.attname) as dist_attr
        FROM (
            SELECT gdp.localoid,
                     CASE
                         WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
                            UNNEST(gdp.attrnums)
                         ELSE NULL
                     END AS attnum
                FROM gp_distribution_policy gdp
            ) AS distkey
            INNER JOIN pg_class AS pgc
            ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
            INNER JOIN pg_namespace pgn
            ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
            LEFT OUTER JOIN pg_attribute pga
            ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
        """.format(table_name=table_name, schema_name=schema_name))[0]["dist_attr"]
        # array_agg returns NULL (None here) when no row matches, and the
        # CASE ... ELSE NULL branch combined with the LEFT OUTER JOIN can
        # put NULL entries into the array. Drop the NULLs before deciding,
        # so an array holding nothing but NULLs does not produce the
        # malformed clause 'distributed by ()'.
        dist_cols = [i for i in (dist_attr or []) if i is not None]
        if dist_cols:
            dist_str = ('distributed by (' +
                        ','.join('"%s"' % i for i in dist_cols) + ')')
        else:
            dist_str = 'distributed randomly'
    else:
        dist_str = plpy.execute("""
        SELECT pg_catalog.pg_get_table_distributedby(pgc.oid) as distributedby
        FROM pg_class AS pgc
        INNER JOIN pg_namespace pgn ON pgc.relnamespace = pgn.oid
        WHERE pgc.relname = '{table_name}' AND pgn.nspname = '{schema_name}'
        """.format(table_name=table_name, schema_name=schema_name))[0]["distributedby"]

    return dist_str

# ------------------------------------------------------------------------------


def num_features(source_table, independent_varname):
    """ Return array_upper(independent_varname, 1) evaluated on one row of
        source_table, i.e. the length of that array expression.
    """
    query = ("SELECT array_upper({0}, 1) AS dim FROM {1} LIMIT 1"
             .format(independent_varname, source_table))
    first_row = plpy.execute(query)[0]
    return first_row['dim']
# ------------------------------------------------------------------------------


def num_samples(source_table):
    """ Return the number of rows in source_table. """
    query = "SELECT count(*) AS n FROM {0}".format(source_table)
    first_row = plpy.execute(query)[0]
    return first_row['n']
# ------------------------------------------------------------------------------


def unique_string(desp='', **kwargs):
    """
    Generate a random temporary name, usable for temp tables and other
    database objects.
    It has a SQL interface so both SQL and Python functions can call it.

    The result has the form
        __madlib_temp_ + desp + N1 + _ + N2 + _ + N3 + __
    where N1 is a random integer, N2 the current epoch time in seconds and
    N3 the epoch time modulo another random integer.
    """
    rand_part = random.randint(1, 100000000)
    time_part = int(time.time())
    mixed_part = int(time.time()) % random.randint(1, 100000000)
    numbers = "_".join(str(part)
                       for part in (rand_part, time_part, mixed_part))
    return "__madlib_temp_" + desp + numbers + "__"
# ------------------------------------------------------------------------------


def add_postfix(quoted_string, postfix):
    """ Attach a suffix to a (possibly double-quoted) database name.

    Surrounding whitespace is stripped first. When the name both starts and
    ends with a double quote, the suffix is placed before the closing quote
    so that it stays inside the quoted identifier.

    Arguments:
        @param quoted_string: str. A string representing a database quoted string
        @param postfix: str. Suffix to attach to quoted_string.
                            ** This is assumed to not contain any quotes **
    """
    name = quoted_string.strip()
    is_quoted = name.startswith('"') and name.endswith('"')
    if not is_quoted:
        return name + postfix
    return name[:-1] + postfix + '"'
# -------------------------------------------------------------------------


# Sets of PostgreSQL type names grouped by category. They are meant to be
# combined with the set-union operator and passed as the valid_types
# argument of is_valid_psql_type().
NUMERIC = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
               'real', 'double precision', 'float', 'serial', 'bigserial'])
INTEGER = set(['smallint', 'integer', 'bigint'])
TEXT = set(['text', 'varchar', 'character varying', 'char', 'character'])
BOOLEAN = set(['boolean'])
# Flag sets: each holds one string generated by unique_string() at import
# time rather than a type name (presumably so that a flag can never collide
# with a real type name). is_valid_psql_type() tests for them with a subset
# check to decide how array types are treated.
INCLUDE_ARRAY = set([unique_string('__include_array__')])
ONLY_ARRAY = set([unique_string('__only_array__')])
ANY_ARRAY = set([unique_string('__any_array__')])


def is_valid_psql_type(arg, valid_types):
    """ Check a Postgres type name against a set of accepted types.

    Args:
        @param arg: str. Name of the Postgres type to validate
                    (compared case-insensitively)
        @param valid_types: set. Accepted type names, usually a union of the
                            module-level type sets. Three flag sets may be
                            included to control array handling
                            (listed in descending order of precedence):
                                - ANY_ARRAY: accept any type ending in '[]',
                                    whatever the element type
                                - ONLY_ARRAY: accept only the array form of
                                    the listed types
                                - INCLUDE_ARRAY: accept both the scalar and
                                    the array form of the listed types
                Examples: 1. valid_types = BOOLEAN | INTEGER | TEXT
                          2. valid_types = BOOLEAN | INTEGER | ONLY_ARRAY
                          3. valid_types = NUMERIC | INCLUDE_ARRAY

    Returns:
        Boolean. False when either argument is empty or None.
    """
    if not (arg and valid_types):
        return False
    type_name = arg.lower()
    looks_like_array = type_name.rstrip().endswith('[]')
    if ANY_ARRAY <= valid_types:
        return looks_like_array
    # Stripping the characters '[', ']' and ' ' from the right removes the
    # array suffix together with any trailing white space.
    element_type = type_name.rstrip('[] ')
    if ONLY_ARRAY <= valid_types:
        return looks_like_array and element_type in valid_types
    if INCLUDE_ARRAY <= valid_types:
        return element_type in valid_types
    return type_name in valid_types
# ------------------------------------------------------------------------------


def is_psql_numeric_type(arg, exclude=None):
    """
    Tell whether a type name is one of the PostgreSQL numeric types
    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, Type names that must not count as numeric

    Returns:
        Boolean. True if 'arg' is in NUMERIC and not listed in 'exclude'
    """
    excluded = set() if exclude is None else set(exclude)
    return arg in (NUMERIC - excluded)
# -------------------------------------------------------------------------


def is_psql_int_type(arg, exclude=None):
    """
    Tell whether a type name is one of the PostgreSQL integer types
    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, Type names that must not count as integer

    Returns:
        Boolean. True if 'arg' is in INTEGER and not listed in 'exclude'
    """
    allowed = INTEGER if exclude is None else INTEGER - set(exclude)
    return arg in allowed
# -------------------------------------------------------------------------


def is_psql_char_type(arg, exclude_list=None):
    """
    Check whether the given type name is one of the predefined postgres
    character types.

    :param arg: str, type name to check
    :param exclude_list: optional type name, or list/tuple/set of type
                         names, to leave out of the comparison. A single
                         value that is not a collection is treated as a
                         one-element list.
    :return: True if it is one of the character types, else False.
    """
    if exclude_list is None:
        # None (the default) excludes nothing; avoids a mutable default.
        excluded = set()
    elif isinstance(exclude_list, (list, tuple, set, frozenset)):
        excluded = set(exclude_list)
    else:
        excluded = set([exclude_list])
    return arg in TEXT - excluded


def is_psql_boolean_type(arg):
    """
    Check whether the given type name is the postgres boolean type.
    The comparison is an exact, case-sensitive match against 'boolean'.
    :param arg: type name to check
    :return: True if it is boolean, else False.
    """
    is_boolean = (arg == 'boolean')
    return is_boolean


def is_string_formatted_as_array_expression(string_to_match):
    """
    Test whether the string starts with array[<something>]
    (case-insensitive).
    :param string_to_match: str, expression to test
    :return: the regular expression match object (truthy) when the string
             matches, else None. Group 1 of the match holds the text
             between the brackets.
    """
    array_expr = re.compile(r"(?i)^array\[(.*)\]")
    return array_expr.match(string_to_match)


def _string_to_array(s):
    """
    Break a comma/space separated string into a list of its elements.
    White space around the elements is dropped and double quotes at
    either end of an element are removed.

    Requirement: every element must be a valid Postgres name, i.e. an
    element that contains spaces or commas has to be enclosed in a pair
    of double quotes.

    This is usually satisfied automatically: in older versions of GPDB
    an array passed from SQL to Python arrives as a string in which
    elements with spaces or commas are already double quoted.

    Use this function only when all elements are known to be valid
    Postgres names.
    """
    tokens = re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s)
    return [token.group(1).strip("\"") for token in tokens]
# ------------------------------------------------------------------------


def _string_to_array_with_quotes(s):
    """
    Break a comma/space separated string into a list of its elements,
    like _string_to_array, but leave the double quotes around quoted
    elements in place.
    """
    tokens = re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s)
    return [token.group(1) for token in tokens]
# ------------------------------------------------------------------------

def get_col_name_type_sql_string(colnames, coltypes):
    """
    Pair column names with column types as 'name type,name type,...'.

    Returns None when either list is empty/None or the lengths differ.
    """
    if not colnames or not coltypes or len(colnames) != len(coltypes):
        return None
    definitions = [' '.join(pair) for pair in zip(colnames, coltypes)]
    return ','.join(definitions)
# ------------------------------------------------------------------------

def py_list_to_sql_string(array, array_type=None, long_format=None):
    """Render a python list as a SQL array expression with a type cast.

    Two output formats exist:
        long format:  ARRAY[ v1,v2,... ]::<type>[]
        short format: a dollar-quoted '{ v1,v2,... }' literal cast to
                      <type>[]

    When long_format is None, the short format is chosen if array_type
    starts with "text", "varchar" or "character varying"; otherwise the
    long format is used. A missing array_type defaults to
    "double precision[]", and "[]" is appended to array_type when absent.
    An empty or None array gives the literal '{ }' cast to the array type.

    @note: The long format is recommended with the input values quoted
    appropriately by 'quote_literal'.

    The short format does not escape the values, so values containing
    unescaped curly brackets or double quotes will break the literal;
    callers must escape those characters themselves when using it.
    """
    if long_format is None:
        text_prefixes = ("text", "varchar", "character varying")
        is_text_type = (array_type is not None and
                        array_type.startswith(text_prefixes))
        long_format = not is_text_type

    if array_type:
        array_type = array_type.strip()
        if not array_type.endswith("[]"):
            array_type += "[]"
    else:
        array_type = "double precision[]"

    if not array:
        return "'{{ }}'::{0}".format(array_type)

    if long_format:
        template = "ARRAY[ {val} ]"
    else:
        # The dollar-quote delimiter replaces single quotes so that values
        # may themselves contain single quotes without escaping.
        template = "{qd}{{ {val} }}{qd}"
    values = ','.join(str(item) for item in array)
    return (template + "::{array_type}").format(
        val=values,
        array_type=array_type,
        qd="$__MADLIB_OUTER__$")
# ------------------------------------------------------------------------

def create_cols_from_array_sql_string(py_list, sql_array_col, colname,
                                      coltype, has_one_ele,
                                      module_name='Input Error'):
    """
    Build two SQL fragments for turning an array column into output
    column(s): the SELECT-list expression and the matching column
    definition list for CREATE TABLE.
    @args:
        @param: py_list, python list of values. If empty or None, the
                            array column itself is exposed under colname.
                            The list may hold at most one None element,
                            which is rendered as SQL NULL.
        @param: sql_array_col, str naming a table column that contains an
                               array.
        @param: colname, name of the output column (used as a prefix for
                         several columns if has_one_ele=False)
        @param: coltype, type of the column(s) to create
        @param: has_one_ele, bool. If True, sql_array_col is taken to hold
                    a single element which is used as an index into
                    py_list, and one output column is produced. If False,
                    one output column is produced per element of py_list,
                    each taking its value from the corresponding position
                    of sql_array_col.
        @param: module_name, str used as prefix of error messages.
    @examples (identifier quoting omitted for brevity):
        1) py_list = ['cat', 'dog'], sql_array_col = sqlcol,
           colname = prob, coltype = TEXT, has_one_ele = False
            columns:  prob_cat TEXT, prob_dog TEXT
            select:   CAST(sqlcol[1] AS TEXT) AS prob_cat,
                      CAST(sqlcol[2] AS TEXT) AS prob_dog
        2) same inputs with colname = estimated_pred, has_one_ele = True
            columns:  estimated_pred TEXT
            select:   (<py_list as SQL array>)[sqlcol[1]+1]::TEXT
                      AS estimated_pred

    @NOTE:
        A None element contributes the text NULL, while a string equal to
        'null' in any letter case is first wrapped in double quotes, so
        that the two can be told apart in the generated SQL and column
        names. Column names are passed through quote_ident and then
        stripped of surrounding spaces and double quotes before being
        re-quoted, which can make such names look odd.

    @returns:
        @param, str, SELECT-list fragment usable in a SQL query.
        @param, str, column definitions usable in CREATE TABLE.

    """
    _assert(sql_array_col,
            "{0}: sql_array_col should be a valid string.".format(
                module_name))
    _assert(colname,
            "{0}: colname should be a valid string.".format(module_name))

    if not py_list:
        # Nothing to map onto: expose the array column (or its first
        # element plus one) directly under colname.
        if has_one_ele:
            return ('{0}[1]+1 AS {1}'.format(sql_array_col, colname),
                    '{0} {1}'.format(colname, coltype))
        return ('{0} AS {1}'.format(sql_array_col, colname),
                '{0} {1}'.format(colname, coltype + '[]'))

    _assert(is_valid_psql_type(coltype, BOOLEAN | NUMERIC | TEXT),
            "{0}: Invalid coltype parameter {1}.".format(
                module_name, coltype))
    _assert(py_list.count(None) <= 1,
            "{0}: Input list should contain at most 1 None element.".format(
                module_name))

    def _sql_value(value):
        """
            Map python None to the text NULL and wrap the string 'null'
            (any letter case) in double quotes; other values pass through
            unchanged.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, str) and value.lower() == 'null':
            return '"{0}"'.format(value)
        return value

    values = [_sql_value(item) for item in py_list]

    if has_one_ele:
        # The first element of sql_array_col is a zero-based index; the
        # value found at that position of py_list becomes the value of
        # colname.
        values_sql = py_list_to_sql_string(values, coltype + '[]')
        select_clause = "({0})[{1}[1]+1]::{2} AS {3}".format(
            values_sql, sql_array_col, coltype, colname)
        create_columns = "{0} {1}".format(colname, coltype)
        return select_clause, create_columns

    # One output column per element of py_list, named <colname>_<element>.
    # SQL quote_ident cannot be applied to an alias, so the python
    # implementation is used; the strip() afterwards is required for the
    # resulting query to be valid.
    output_names = [
        quote_ident("{0}_{1}".format(colname, str(item))).strip(' "')
        for item in values]
    select_clause = ', '.join(
        'CAST({sql_array_col}[{j}] AS {coltype}) AS "{final_colname}"'.format(
            j=position,
            final_colname=name,
            sql_array_col=sql_array_col,
            coltype=coltype)
        for position, name in enumerate(output_names, 1))
    create_columns = ', '.join(
        '"{final_colname}" {coltype}'.format(final_colname=name,
                                             coltype=coltype)
        for name in output_names)
    return select_clause, create_columns
# ------------------------------------------------------------------------

def _array_to_string(origin):
    """
    Render a sequence as a '{e1,e2,...}' string, prefixing every double
    quote inside an element with a backslash.
    """
    escaped = [str(item).replace('"', '\\"') for item in origin]
    return "{" + ",".join(escaped) + "}"
# ------------------------------------------------------------------------


def _cast_if_null(input, alias=''):
    """
    Return str(input) for a truthy input; otherwise return the SQL text
    NULL::text, followed by ' as <alias>' when an alias is given.
    """
    if input:
        return str(input)
    if alias:
        return "NULL::text" + " as " + alias
    return "NULL::text"
# ------------------------------------------------------------------------


def set_client_min_messages(new_level):
    """
    Change the client_min_messages setting of the session, which controls
    the messages sent to the psql client, and report the previous value.

    Args:
        @param new_level: string, New level for client_min_messages
                          (checked case-insensitively)

    Returns:
        old_msg_level: string, The level that was in effect before the change.

    Raise:
        ValueError: if the argument new_level is not a valid message level
    """
    valid_levels = ('debug5', 'debug4', 'debug3', 'debug2', 'debug1',
                    'log', 'notice', 'warning', 'error', 'fatal', 'panic')
    if new_level.lower() not in valid_levels:
        raise ValueError("Not a valid message level sent to client")

    # Read the current level before overwriting it.
    current = plpy.execute(""" SELECT setting
                                  FROM pg_settings
                                  WHERE name='client_min_messages'
                                  """)
    old_msg_level = current[0]['setting']
    plpy.execute("SET client_min_messages TO {0}".format(new_level))
    return old_msg_level
# -------------------------------------------------------------------------

def is_pg_major_version_less_than(schema_madlib, compare_version, **kwargs):
    """ Compare the PostgreSQL major version of the database with a number.

    The major version is the first run of digits that follows the word
    "PostgreSQL" in the output of "select version()".

    Args:
        @param schema_madlib: str, not used by this function
        @param compare_version: int, major version to compare against

    Returns:
        Boolean. True if a PostgreSQL major version was found and it is
        smaller than compare_version, False otherwise.
    """
    version = plpy.execute("select version()")[0]["version"]
    # Capture only the leading digits. Requiring further version characters
    # after them would make a banner without a minor part (such as
    # "13devel") backtrack and report its major version as 1.
    found = re.search(r'PostgreSQL\s*([0-9]+)', version, re.IGNORECASE)
    return found is not None and int(found.group(1)) < compare_version

# Deal with earlier versions of PG or GPDB
class __mad_version:
    """
    Wrapper around the output of "select version()".

    The version banner is read once at construction; the methods below
    answer platform and release questions (PostgreSQL vs Greenplum,
    version thresholds) by matching regular expressions against it, and
    pick helper functions that work around limitations of old releases.
    """

    def __init__(self):
        # Full version banner of the database this code runs in.
        self.version = plpy.execute("select version()")[0]["version"]

    @staticmethod
    def _parse_version(version_str):
        """
        Turn a dotted version string such as '9.3.5' into a list of ints.
        Components that are not made of digits only are dropped.
        """
        return [int(i) for i in version_str.split('.') if i.isdigit()]

    def _pg_major(self):
        """
        Return the leading integer following "PostgreSQL" in the banner,
        or None when the banner has no such part.
        """
        found = re.search(r'PostgreSQL\s*([0-9]+)', self.version,
                          re.IGNORECASE)
        return int(found.group(1)) if found else None

    def _gp_major_minor(self):
        """
        Return the one-digit major and minor version following
        "Greenplum Database" in the banner as a float (e.g. 4.1), or None
        when the banner has no such part.
        """
        found = re.search(
            r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build',
            self.version, re.IGNORECASE)
        return float(found.group(1)) if found else None

    def select_vecfunc(self):
        """
        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG.

        Returns a function that converts a vector received from SQL into
        a python list on those versions, and a pass-through otherwise.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        else:
            return self.__identity

    def __extract(self, origin, text=True):
        """
        Extract vector elements from a string with {}
        as the brackets. Elements are converted to float when text is
        False. Returns None for a None input; the input is otherwise
        expected to be wrapped in a pair of curly brackets.
        """
        if origin is None:
            return None
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            elm = [float(e) for e in elm]
        return elm

    def __identity(self, origin, text=True):
        """ Return the input unchanged; text is accepted and ignored. """
        return origin

    def select_vec_return(self):
        """
        Special care is needed if one needs to return
        vector from Python to SQL

        Returns a function that converts a python sequence into a
        '{...}' string on the old versions, and a pass-through otherwise.
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        else:
            return self.__identity

    def __condense(self, origin):
        """
        Convert the original vector into a string which some
        old versions of SQL system can recognize
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """
        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        else:
            return "array_agg"

    def is_pg(self):
        """
        Return True if the banner mentions PostgreSQL and does not
        mention Greenplum Database.
        """
        return bool(re.search(r"PostgreSQL", self.version) and
                    not re.search(r"Greenplum\s*Database", self.version))

    def is_gp43(self):
        """ Return True if the banner reports Greenplum Database 4.3. """
        return bool(re.search(r"Greenplum\s+Database\s+4\.3", self.version))

    def is_less_than_pg90(self):
        """
        Return True if this is PostgreSQL with a major version below 9.
        """
        major = self._pg_major()
        return major is not None and self.is_pg() and major < 9

    def is_less_than_gp41(self):
        """ Return True if this is Greenplum with a version below 4.1. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.1

    def is_less_than_gp42(self):
        """ Return True if this is Greenplum with a version below 4.2. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.2

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
        self.version is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                    format is Semantic Versioning.
                                    examples of versions: 1.0, 2.3, 9.3.5

        Versions are compared component by component as lists of numbers.
        """
        found = re.search(r'PostgreSQL\s*([0-9.]+)', self.version,
                          re.IGNORECASE)
        if found and self.is_pg():
            return (self._parse_version(found.group(1)) <
                    self._parse_version(compare_version))
        return False

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
        is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                    format is Semantic Versioning.
                                    examples of versions: 1.0, 2.3, 9.3.5

        Versions are compared component by component as lists of numbers.
        """
        found = re.search(r'Greenplum\s+Database\s*([0-9.]+)', self.version,
                          re.IGNORECASE)
        if found:
            return (self._parse_version(found.group(1)) <
                    self._parse_version(compare_version))
        return False


def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Break a comma/space separated string into its elements and return
    them in the form the database version expects for an array returned
    from Python. White space around the elements is dropped and double
    quotes at either end of an element are removed.

    Requirement: every element must be a valid Postgres name, i.e. an
    element that contains spaces or commas has to be enclosed in a pair
    of double quotes.

    This is usually satisfied automatically: in older versions of GPDB
    an array passed from SQL to Python arrives as a string in which
    elements with spaces or commas are already double quoted.

    Use this function only when all elements are known to be valid
    Postgres names.
    """
    # The version wrapper decides how a vector is handed back to SQL:
    # old GPDB / PG releases need it condensed into a string.
    to_sql_vector = __mad_version().select_vec_return()

    tokens = re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s)
    elements = [token.group(1).strip("\"") for token in tokens]
    return to_sql_vector(elements)
# ------------------------------------------------------------------------


def current_user():
    """Return the name of the current database user."""
    first_row = plpy.execute("SELECT current_user")[0]
    return first_row['current_user']
# ------------------------------------------------------------------------

def is_superuser(user):
    """Return the rolsuper flag stored in pg_roles for the given role name.

    The role name is placed into the query text as given; a row for the
    role is expected to exist.
    """
    query = ("SELECT rolsuper FROM pg_catalog.pg_roles "
             "WHERE rolname = '{0}'").format(user)
    first_row = plpy.execute(query)[0]
    return first_row['rolsuper']

def madlib_version(schema_madlib):
    """Return the MADlib version string.

    The version is the last space-separated word of the text that precedes
    the first comma in the output of <schema_madlib>.version().
    """
    query = """
            SELECT {schema_madlib}.version()
            """.format(schema_madlib=schema_madlib)
    raw = plpy.execute(query)[0]['version']
    before_comma = raw.split(',')[0]
    return before_comma.split(' ')[-1]
# ------------------------------------------------------------------------


def preprocess_keyvalue_params(input_params, split_char='='):
    """
    Find all 'key <split_char> value' items in the input_params string.

    @param input_params: str, Comma-separated list of parameters
        Each parameter has the form key = value, where
            key is a string of alphanumeric characters (hyphen and colon
                are also accepted)
            value is either a bracketed array, i.e. text enclosed in
                (), {} or [] without nested brackets, or a string of
                alphanumeric characters, spaces, dots, hyphens and percent
                signs that may be enclosed in double quotes
    @param split_char: str, character that splits the key and value elements.
                        Default set to '='

    @return list of str, one 'key<split_char>value' item per parameter
        found, in order of appearance, with surrounding white space
        removed.
    """
    # key is any alphanumeric character (including - and :) string
    key_part = r"([-:\w]+\s*"
    # key and value are separated by split_char; the value part is written
    # as a verbose regular expression and carries its own comments
    value_part = r"""
                \s*([\(\{\[]      # value can be string or array
                [^\[\]\(\)\{\}]*  #  if value is array then accept anything inside
                [\)\}\]]          #  and then match closing braces of array

                   |              # either array (above) or string (below)

                (?P<quote>\"?)[\w\s\-\%.]+(?P=quote)
                                 #  if value is string, it can be alphanumeric
                                 #    character string with a decimal dot,
                                 #    hyphen, or percent
                                 #    optionally quoted by `quote_char`
                  )
               )"""
    matcher = re.compile(key_part + split_char + value_part, re.VERBOSE)
    found = []
    for hit in matcher.finditer(input_params):
        found.append(hit.group(1).strip())
    return found
# ------------------------------------------------------------------------


def extract_keyvalue_params(input_params,
                            input_param_types=None,
                            default_values=None,
                            split_char='=',
                            usage_str='',
                            ignore_invalid=False,
                            allow_duplicates=True,
                            lower_case_names=True):
    """ Extract key value pairs from input parameters or set the default values

    Args:
        @param input_params: string, Format of
                    'key1=value1, key2=value2,...', assuming default split_char.
                    The order does not matter. If a parameter is missing, then
                    the default value is used. If input_params is None or '',
                    then all default values are returned. This function also
                    validates the values of these parameters.

        @param input_param_types: dict, The type of each allowed parameter
                                            name. Currently supports one of
                                            (bool, int, float, str) or an
                                            iterable type such as list. If not
                                            given, every value is returned as
                                            a string.
        @param default_values: dict, Default values for each allowed parameter.
                                This dictionary is not modified.
        @param split_char: str, The character used to split key and value.
                            Default set to '='
        @param usage_str: str, An optional usage string to print with error message.
        @param ignore_invalid: bool, If True, a parameter whose name is not in
                                input_param_types is ignored silently
        @param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
                                where the last occurrence will be reflected
                                in the output. If False, then a ValueError is
                                raised.
        @param lower_case_names: bool, Convert parameter names to lower case.

    Returns:
        Dict. New dictionary of input parameter values with key as parameter
        name and value as the parameter value, starting from a copy of
        default_values.

    Throws:
        KeyError - If an item does not split into exactly a key and a value,
            or the key is not in input_param_types (and ignore_invalid is
            False).
        ValueError - If a key is repeated while allow_duplicates is False, or
            a value cannot be converted to its declared type.
        TypeError - If the declared type of a parameter is not supported.
        plpy.error - If a parameter name is empty, 'none' or 'null'.
    """
    # Start from a copy so that the caller's default_values dictionary is
    # never modified by the values parsed below.
    parameter_dict = default_values.copy() if default_values else {}
    if not input_params:
        return parameter_dict

    # The collections.Iterable alias was removed in Python 3.10; Python 2 only
    # provides the class in the collections module itself.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    seen_params = set()

    for s in preprocess_keyvalue_params(input_params, split_char=split_char):
        items = split_quoted_delimited_str(s, delimiter=split_char)
        if (len(items) != 2):
            raise KeyError("Input parameter list has incorrect format "
                           "{0}".format(usage_str))

        param_name = items[0].strip(" \"")
        if lower_case_names:
            param_name = param_name.lower()
        param_value = items[1].strip()

        if not allow_duplicates and param_name in seen_params:
            raise ValueError("Invalid input: {0} duplicated in the param list".
                             format(param_name))

        if not param_name or param_name in ('none', 'null'):
            plpy.error("Invalid input param name: {0} \n"
                       "{1}".format(param_name, usage_str))
        if input_param_types:
            try:
                param_type = input_param_types[param_name]
            except KeyError:
                if not ignore_invalid:
                    raise KeyError("Invalid input: {0} is not a valid parameter "
                                   "{1}".format(param_name, usage_str))
                else:
                    continue
            try:
                if param_type == bool:  # bool is not subclassable
                    #  True values are y, yes, t, true, on and 1;
                    #  False values are n, no, f, false, off and 0.
                    #  Raises ValueError if anything else.
                    parameter_dict[param_name] = bool(strtobool(param_value))
                elif param_type in (int, str, float):
                    parameter_dict[param_name] = param_type(param_value)
                elif issubclass(param_type, Iterable):
                    # array values arrive wrapped in brackets; drop them and
                    # split the content on commas
                    parameter_dict[param_name] = split_quoted_delimited_str(
                        param_value.strip('[](){} '))
                else:
                    raise TypeError("Invalid input: {0} has unsupported type "
                                    "{1}".format(param_name, usage_str))
            except ValueError:
                raise ValueError("Invalid input: {0} must be {1} \n"
                                 "{2}".format(param_name, param_type, usage_str))
        else:
            # if input parameter types not provided then just return all as string
            parameter_dict[param_name] = str(param_value)
        seen_params.add(param_name)
    return parameter_dict
# -------------------------------------------------------------------------


def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
    """ Parse a delimited-string to return a list of individual tokens taking
        quotes into account.

    Args:
        @param input_str: str, Delimited input string
        @param delimiter: str, The field delimiter character that separates
                                the tokens in input_str. Default = ','
        @param quote: str, One-character string used to quote fields containing
                                special characters, such as the field delimiter
                                Default = '"'

    Returns:
        List. List of delimited strings with surrounding whitespace removed.
        Quoted fields keep their quote characters. An empty list is returned
        if input_str, delimiter or quote is empty or None.

    Throws:
        ValueError - If the input cannot be split (the underlying error is
        reported through plpy.warning).
    """
    if not input_str or not delimiter or not quote:
        return []
    try:
        # Escape both characters so that they are always matched literally,
        # including alphanumeric ones where a bare backslash prefix would
        # create a regex escape such as \d.
        escaped_delimiter = re.escape(delimiter)
        escaped_quote = re.escape(quote)
        # A token is a run of characters that are neither delimiter nor quote,
        # or complete quoted sections (which may contain the delimiter).
        delimiter_reg = re.compile('((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.
                                   format(d=escaped_delimiter, q=escaped_quote))
        # The pattern captures the tokens themselves, so re.split() returns
        # them at the odd positions of its result.
        return [i.strip() for i in
                delimiter_reg.split(input_str.strip())[1::2]]
    except Exception as e:
        plpy.warning(str(e))
        raise ValueError("Invalid string input for splitting")
# ------------------------------------------------------------------------------


def strip_end_quotes(input_str, quote='"'):
    """ Remove the quote string from the start and end if it is present
    at both ends. Whitespace at start and end are not ignored.

    Args:
        @param input_str: Value to strip. Anything that is not a non-empty
                          str is returned unchanged.
        @param quote: str, The quote string to remove. Default = '"'

    Returns:
        str. Original string without the quotes at start and end. The input
        is returned unchanged if it is not enclosed by two separate
        occurrences of quote (e.g. a string holding only a single quote
        character).
    """
    if not input_str or not quote or not isinstance(input_str, str):
        return input_str
    quote_len = len(quote)
    # The length check ensures that the opening and closing quotes are
    # distinct characters and do not overlap.
    is_enclosed = (len(input_str) >= 2 * quote_len and
                   input_str.startswith(quote) and
                   input_str.endswith(quote))
    if is_enclosed:
        return input_str[quote_len:-quote_len]
    return input_str
# ------------------------------------------------------------------------------


def _grp_null_checks(grp_list):
    """
    Helper function for generating NULL checks for grouping columns
    to be used within a WHERE clause
    Args:
        @param grp_list   The list of grouping columns
    """
    return ' AND '.join([" {i} IS NOT NULL ".format(**locals())
                         for i in grp_list])
# ------------------------------------------------------------------------------


def _check_groups(tbl1, tbl2, grp_list):
    """
    Helper function for joining tables with groups.
    Args:
            @param tbl1       Name of the first table
            @param tbl2       Name of the second table
            @param grp_list   The list of grouping columns
    """

    return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(**locals())
                         for i in grp_list])
# ------------------------------------------------------------------------------


def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
                                   filter_cols_list):
    """
    Build a '(...) NOT IN (SELECT ...)' condition that keeps the value
    combinations of filter_cols_list found in include_from_table but not in
    exclude_from_table.
    :param include_from_table: table whose (qualified) columns form the left
                               side of the NOT IN condition
    :param exclude_from_table: table the subquery on the right side selects
                               from
    :param filter_cols_list: list of column names
    :return: query string with relevant column names
    """
    subquery_template = """({included_cols}) NOT IN
                (SELECT {cols}
                 FROM {exclude_from_table})
           """
    qualified_cols = get_table_qualified_col_str(include_from_table,
                                                 filter_cols_list)
    plain_cols = ', '.join(filter_cols_list)
    return subquery_template.format(included_cols=qualified_cols,
                                    cols=plain_cols,
                                    exclude_from_table=exclude_from_table)
# ------------------------------------------------------------------------------


def get_table_qualified_col_str(tbl_name, col_list):
    """
    Helper function for selecting columns of a table
    Args:
            @param tbl_name   Name of the table
            @param col_list   The list of column names to qualify
    Returns:
            str. One ' <tbl_name>.<col> ' entry per column joined by ' , '
            (an empty string for an empty list)
    """
    # The values are passed to format() explicitly: on Python 3 versions
    # where a comprehension has its own scope, locals() evaluated inside it
    # does not contain tbl_name.
    return ' , '.join(" {tbl_name}.{col} ".format(tbl_name=tbl_name, col=col)
                      for col in col_list)
# ------------------------------------------------------------------------------


def get_grouping_col_str(schema_madlib, module_name, reserved_cols,
                         source_table, grouping_col):
    """
    Validate the grouping columns and build a text-cast column list from them.

    Args:
        @param schema_madlib: Passed on to explicit_bool_to_text
        @param module_name: Passed on to cols_in_tbl_valid
        @param reserved_cols: Passed on to does_exclude_reserved together with
                              the parsed grouping columns
        @param source_table: Table the grouping columns are checked against
        @param grouping_col: str, Grouping column list. An empty/None value or
                             the text 'null' (any case) means no grouping.

    Returns:
        Tuple (grouping_str, grouping_col).
        Without grouping: ("Null", None).
        Otherwise: grouping_str holds each entry returned by
        explicit_bool_to_text with '::text' appended, joined by ', ', and
        grouping_col is returned as given.
    """
    # guard clause: nothing to validate when no grouping was requested
    if not grouping_col or grouping_col.lower() == 'null':
        return "Null", None

    grp_cols = _string_to_array_with_quotes(grouping_col)
    cols_in_tbl_valid(source_table, grp_cols, module_name)
    does_exclude_reserved(grp_cols, reserved_cols)
    casted_cols = explicit_bool_to_text(source_table, grp_cols, schema_madlib)
    text_exprs = [col + "::text" for col in casted_cols]
    return ', '.join(text_exprs), grouping_col
# ------------------------------------------------------------------------------


def collate_plpy_result(plpy_result_rows):
    """
    Turn a sequence of row mappings into a mapping of column -> list of values.

    Args:
        @param plpy_result_rows: Sequence of rows, each indexable by column
                                 name. The keys of the first row define the
                                 columns that are collected.

    Returns:
        An empty dict if there are no rows, otherwise a
        collections.defaultdict(list) holding, for every key of the first row,
        the values of that key in row order.
    """
    if not plpy_result_rows:
        return {}

    collated = collections.defaultdict(list)
    for column_name in plpy_result_rows[0].keys():
        collated[column_name].extend(row[column_name]
                                     for row in plpy_result_rows)
    return collated
# ------------------------------------------------------------------------------


def validate_module_input_params(source_table, output_table, independent_varname,
                                 dependent_varname, module_name,
                                 grouping_cols=None, other_output_tables=None):
    """
    Common parameter validation for supervised-learning style modules
    (e.g. linear regression, mlp), which all take a source table, an output
    table and dependent/independent variable expressions.
    :param source_table: Checked with input_tbl_valid
    :param output_table: Checked with output_tbl_valid
    :param independent_varname: Must be accepted by is_var_valid for the
                                source table
    :param dependent_varname: Must be accepted by is_var_valid for the source
                              table
    :param module_name: Name of the module to be printed with the error messages
    :param grouping_cols: Optional. When given, must be accepted by
                          is_var_valid for the source table
    :param other_output_tables: Optional list of additional output tables,
                                each checked with output_tbl_valid
    """
    input_tbl_valid(source_table, module_name)

    # the main output table is checked first, then any additional ones
    output_tbl_valid(output_table, module_name)
    for extra_tbl in (other_output_tables or []):
        output_tbl_valid(extra_tbl, module_name)

    indep_ok = is_var_valid(source_table, independent_varname)
    indep_msg = ("{module_name} error: invalid independent_varname "
                 "('{independent_varname}') for source_table "
                 "({source_table})!").format(
                     module_name=module_name,
                     independent_varname=independent_varname,
                     source_table=source_table)
    _assert(indep_ok, indep_msg)

    dep_ok = is_var_valid(source_table, dependent_varname)
    dep_msg = ("{module_name} error: invalid dependent_varname "
               "('{dependent_varname}') for source_table "
               "({source_table})!").format(
                   module_name=module_name,
                   dependent_varname=dependent_varname,
                   source_table=source_table)
    _assert(dep_ok, dep_msg)

    # grouping columns are optional; skip the check when none were given
    if not grouping_cols:
        return
    grp_ok = is_var_valid(source_table, grouping_cols)
    grp_msg = ("{module_name} error: invalid grouping_cols "
               "('{grouping_cols}') for source_table "
               "({source_table})!").format(
                   module_name=module_name,
                   grouping_cols=grouping_cols,
                   source_table=source_table)
    _assert(grp_ok, grp_msg)
# ------------------------------------------------------------------------


def create_table_drop_cols(source_table, out_table, cols_to_drop, **kwargs):
    """ Copy a table into a new table, leaving out some of its columns
    Args:
        @param source_table str. Name of the source table
        @param out_table str. Name of the output table
        @param cols_to_drop str. Comma-separated list of columns to drop
        @param kwargs Accepted but not used by this function
    """
    input_tbl_valid(source_table, 'Utilities')
    output_tbl_valid(out_table, 'Utilities')
    has_cols_to_drop = cols_to_drop and cols_to_drop.strip()
    _assert(has_cols_to_drop,
            "Utilities error: cols_to_drop cannot be empty or NULL")

    all_cols = get_cols(source_table)
    drop_list = split_quoted_delimited_str(cols_to_drop)

    # every column requested for dropping has to exist in the source table
    unknown_cols = set(drop_list).difference(all_cols)
    _assert(not unknown_cols,
            "Utilities error: Some column(s) in cols_to_drop are not present "
            "in source table")

    # keep the remaining columns in the order reported for the source table
    kept_cols = [col_name for col_name in all_cols
                 if col_name not in drop_list]
    _assert(kept_cols,
            "Utilities error: No valid columns for the output table")

    create_sql = """
        CREATE TABLE {out_table} AS
        SELECT {cols}
        FROM {source_table}
        """.format(out_table=out_table,
                   cols=', '.join(kept_cols),
                   source_table=source_table)
    plpy.execute(create_sql)
# ------------------------------------------------------------------------------


def rotate(l, n):
    """Summary
    Rotate the list l to right(the index increasing direction) for n elements.
    A negative n rotates to the left. n may exceed the length of the list, in
    which case the rotation wraps around (rotating by len(l) is a no-op).
    Args:
        l (list): The input list to rotate
        n (integer): The number of elements to rotate

    Returns:
        list: The rotated list (a new list, the input is not modified)
    """
    size = len(l)
    if size:
        # Reduce n to [0, size): the slices below only produce a rotation
        # when the offset lies within the list.
        n %= size
    return l[-n:] + l[:-n]
# ------------------------------------------------------------------------------

def rename_table(schema_madlib, orig_name, new_name):
    """
    Rename a table, given possibly schema-qualified old and new names, so that
    the schema qualification ends up as requested.

    Args:
        @param schema_madlib: Not used by this function
        @param orig_name: string, Original name of the table
                          (must be schema qualified if table schema is not in search path)
        @param new_name: string, New name of the table
                          (can be schema qualified. If it is not then the original
                           schema is maintained)
    Returns:
        String. The new table name as returned by __do_rename_and_get_new_name

    Throws:
        AssertionError - If either name contains more than one '.', or the
        schema of an unqualified orig_name cannot be determined.
    """
    # More than one '.' means more than <schema>.<table>
    if new_name.count(".") > 1:
        raise AssertionError("Invalid table name")
    schema_part, dot, new_table_name = new_name.rpartition(".")
    new_table_schema = schema_part if dot else None

    if orig_name.count(".") > 1:
        raise AssertionError("Invalid table name")

    if "." in orig_name:
        orig_table_schema = orig_name.partition(".")[0]
    else:
        # The original table is not schema qualified, so look its schema up.
        # This makes sure the schema of the correct table is changed in case
        # there are multiple tables with the same name.
        orig_table_schema = get_first_schema(orig_name)

    if orig_table_schema is None:
        raise AssertionError("Relation {0} not found during rename".
                             format(orig_name))
    return __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
                                        new_table_schema, new_table_name)
# ------------------------------------------------------------------------------
def __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
                                 new_table_schema, new_table_name):
    """
    Internal private function that runs the actual rename statements once all
    the validation checks have been done.

    Args:
        @param orig_name: string, Original table name as given by the caller
        @param new_name: string, New table name as given by the caller
        @param orig_table_schema: string, Schema of the original table
        @param new_table_schema: string or None, Schema part of new_name
        @param new_table_name: string, Table part of new_name

    Returns:
        String. new_name, except for a rename within the original schema
        (CASE 2) where '<orig_table_schema>.<new_table_name>' is returned.
    """
    same_schema = (new_table_schema == orig_table_schema)
    rename_in_place_sql = "ALTER TABLE {orig_name} RENAME TO {new_table_name}".format(
        orig_name=orig_name, new_table_name=new_table_name)

    # CASE 1
    # The output schema looks like a temp schema ('pg_temp' in its name).
    # Table schemas cannot be altered from/to temp schemas, so stay safe and
    # use create/drop.
    #     Test cases
    #     foo.bar to pg_temp.bar
    #     foo.bar to pg_temp.bar2
    #     foo to pg_temp.bar
    #     pg_temp.foo to pg_temp.bar
    if new_table_schema and 'pg_temp' in new_table_schema:
        if same_schema:
            # Both tables live in the same temp schema: a plain rename is
            # used instead of CTAS. Without this, pca dev-check fails on
            # gpdb5/6 (but not on pg).
            plpy.execute(rename_in_place_sql)
        else:
            copy_sql = """CREATE TABLE {new_name} AS SELECT * FROM {orig_name}""".format(
                new_name=new_name, orig_name=orig_name)
            plpy.info(copy_sql)
            plpy.execute(copy_sql)
            drop_tables([orig_name])
        return new_name

    # CASE 2
    # Direct rename if the new name has no schema or its schema is the same
    # as the schema of the original table.
    #     Test cases
    #     rename foo to bar
    #     rename foo.bar to foo.bar2
    #     rename foo.bar to bar2
    if not new_table_schema or same_schema:
        plpy.execute(rename_in_place_sql)
        return orig_table_schema + "." + new_table_name

    # CASE 3
    # The new name is qualified with a different schema:
    #     1. rename the original table to an interim temp name
    #     2. set the new schema on that interim table
    #     3. rename interim table to the new table name
    #     Test cases
    #     foo.bar to foo2.bar2
    #     foo.bar to foo2.bar
    interim_temp_name = unique_string("rename_table_helper")
    move_statements = [
        "ALTER TABLE {orig_name} RENAME to {interim_temp_name}".format(
            orig_name=orig_name, interim_temp_name=interim_temp_name),
        """ALTER TABLE {interim_temp_name} SET SCHEMA {new_table_schema}""".format(
            interim_temp_name=interim_temp_name,
            new_table_schema=new_table_schema),
        """ALTER TABLE {new_table_schema}.{interim_temp_name} RENAME to {new_table_name}"""
        .format(new_table_schema=new_table_schema,
                interim_temp_name=interim_temp_name,
                new_table_name=new_table_name),
    ]
    for statement in move_statements:
        plpy.execute(statement)
    return new_name
# ------------------------------------------------------------------------------

def is_platform_gp6_or_up():
    """ Return True when the platform is not PostgreSQL and the version
        wrapper reports a Greenplum version that is not less than 6.0.

        NOTE: an identical definition of this function appears near the top
        of this file.
    """
    gp_version = __mad_version()
    if is_platform_pg():
        return False
    return not gp_version.is_gp_version_less_than('6.0')

def get_psql_type(py_type):
    """ Map a Python value to the name of a matching SQL type.

    Args:
        @param py_type: The value (not the type object) whose exact Python
                        type decides the result

    Returns:
        str. 'integer' for int, 'double precision' for float, 'boolean' for
        bool and 'varchar' for str. The exact type is compared, subclasses
        are not considered. plpy.error is called for any other type.
    """
    sql_type_by_py_type = {
        int: 'integer',
        float: 'double precision',
        bool: 'boolean',
        str: 'varchar',
    }
    sql_type = sql_type_by_py_type.get(type(py_type))
    if sql_type is None:
        plpy.error("Cannot determine the type of {0}".format(py_type))
    return sql_type


def get_schema(tbl_str):
    """ Extract the schema part of a possibly schema-qualified table name.

    Args:
        @param tbl_str: str, Table name, either '<table name>' or
                        '<schema name>.<table name>'

    Returns:
        The schema part passed through unquote_ident, or None if tbl_str is
        not schema qualified.

    Throws:
        TypeError - If tbl_str contains more than one '.'
    """
    names = tbl_str.split('.')

    if not names or len(names) > 2:
        # report the offending input itself in the message
        raise TypeError("Incorrect table name ({0}) provided! Table name should be "
                        "of the form: <schema name>.<table name>".format(tbl_str))
    if len(names) == 2:
        return unquote_ident(names[0])
    return None
# -------------------------------------------------------------------------------

def get_current_timestamp(format):
    """Return the current local time rendered with the strftime pattern
    given in `format`."""
    seconds_since_epoch = time.time()
    current_moment = datetime.fromtimestamp(seconds_since_epoch)
    return current_moment.strftime(format)
