| |
| import collections |
| from datetime import datetime |
| import re |
| import time |
| import random |
| from distutils.util import strtobool |
| # import numpy as np |
| |
| from validate_args import _get_table_schema_names |
| from validate_args import cols_in_tbl_valid |
| from validate_args import does_exclude_reserved |
| from validate_args import explicit_bool_to_text |
| from validate_args import get_cols |
| from validate_args import get_first_schema |
| from validate_args import input_tbl_valid |
| from validate_args import is_var_valid |
| from validate_args import output_tbl_valid |
| from validate_args import quote_ident |
| from validate_args import unquote_ident |
| from validate_args import drop_tables |
| import plpy |
| |
| m4_changequote(`<!', `!>') |
| |
| |
def has_function_properties():
    """ __HAS_FUNCTION_PROPERTIES__ variable defined during configure """
    # m4 rewrites the line below to a literal True or False when this file
    # is preprocessed at build/install time.
    return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
| |
| |
def is_platform_pg():
    """ __POSTGRESQL__ variable defined during configure """
    # True on plain PostgreSQL, False on Greenplum; resolved by m4 at
    # build/install time, not at runtime.
    return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)
| |
def is_platform_gp6_or_up():
    """ Return True when running on Greenplum 6.0 or newer (False on
    plain PostgreSQL). """
    if is_platform_pg():
        return False
    return not __mad_version().is_gp_version_less_than('6.0')
| |
| # ------------------------------------------------------------------------------ |
| |
| |
def get_seg_number():
    """ Return the number of primary segments (master segment excluded)
    in the distribution. Plain PostgreSQL counts as a single segment.
    Might be useful for partitioning data.
    """
    if is_platform_pg():
        return 1
    rows = plpy.execute("""
            SELECT count(*) from gp_segment_configuration
            WHERE role = 'p' and content != -1
            """)
    # Guard against odd gpdb configurations: never report fewer than
    # one primary segment.
    return max(1, rows[0]['count'])
| # ------------------------------------------------------------------------------ |
| |
def get_segments_per_host():
    """ Return the number of primary segments (master excluded) per host.

    Assumes every host carries the same number of segments, so only the
    first host's count is returned. Plain PostgreSQL counts as one.
    """
    if is_platform_pg():
        return 1
    rows = plpy.execute("""
            SELECT count(*) from gp_segment_configuration
            WHERE role = 'p' and content != -1
            GROUP BY hostname
            LIMIT 1
            """)
    # Guard against odd gpdb configurations: never report fewer than
    # one primary segment.
    return max(1, rows[0]['count'])
| # ------------------------------------------------------------------------------ |
| |
def is_orca():
    """ Return True when the ORCA query optimizer is enabled. """
    if not has_function_properties():
        return False
    optimizer_setting = plpy.execute("show optimizer")[0]["optimizer"]
    return optimizer_setting == 'on'
| # ------------------------------------------------------------------------------ |
| |
| |
def _assert_equal(o1, o2, msg):
    """
    @brief if the given objects are not equal, then raise an error with the message
    @param o1 the first object
    @param o2 the second object
    @param msg the error message to be reported
    """
    # Use the direct inequality operator instead of negating == (in
    # Python 3 `!=` defaults to the negation of `==`, so behavior is
    # unchanged and the intent is clearer).
    if o1 != o2:
        plpy.error(msg)
| # ------------------------------------------------------------------------------ |
| |
| |
def _assert(condition, msg):
    """
    @brief raise a database error carrying ``msg`` when ``condition`` is falsy
    @param condition the condition to be asserted
    @param msg the error message to be reported
    """
    if condition:
        return
    plpy.error(msg)
| # ------------------------------------------------------------------------------ |
| |
| |
def warn(condition, msg):
    """
    @brief raise a database warning carrying ``msg`` when ``condition`` is falsy
    @param condition the condition to be asserted
    @param msg the warning message to be reported
    """
    if condition:
        return
    plpy.warning(msg)
| # ------------------------------------------------------------------------------ |
| |
| |
def get_distributed_by(source_table):
    """ Return a "distributed by (...)" clause that defines distribution policy of source_table
    Args:
        @param source_table: str, possibly schema-qualified table name

    Returns:
        str. Either 'distributed by ("col", ...)' or 'distributed randomly'.
    """
    _, table_name = _get_table_schema_names(source_table)
    schema_name = get_first_schema(source_table)

    # GPDB 6 has pg_get_table_distributedby(<oid>) function to get the
    # DISTRIBUTED BY clause of a table. In older version, we have to
    # dig the column names based on gp_distribution_policy catalog.
    version_wrapper = __mad_version()
    if version_wrapper.is_gp_version_less_than("6.0"):
        # gp_distribution_policy.attrnums lists the attribute numbers of
        # the distribution key; UNNEST yields one row per key column so
        # each can be joined back to pg_attribute for its name.
        dist_attr = plpy.execute("""
            SELECT array_agg(pga.attname) as dist_attr
            FROM (
                SELECT gdp.localoid,
                       CASE
                          WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
                              UNNEST(gdp.attrnums)
                          ELSE NULL
                       END AS attnum
                FROM gp_distribution_policy gdp
            ) AS distkey
            INNER JOIN pg_class AS pgc
                ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
            INNER JOIN pg_namespace pgn
                ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
            LEFT OUTER JOIN pg_attribute pga
                ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
            """.format(table_name=table_name, schema_name=schema_name))[0]["dist_attr"]
        # A randomly-distributed table produces only NULL attnames, which
        # are filtered out of the join below.
        if len(dist_attr) > 0:
            dist_str = 'distributed by (' + ','.join(['"%s"' % i
                                                      for i in dist_attr
                                                      if i is not None]) + ')'
        else:
            dist_str = 'distributed randomly'
    else:
        dist_str = plpy.execute("""
            SELECT pg_catalog.pg_get_table_distributedby(pgc.oid) as distributedby
            FROM pg_class AS pgc
            INNER JOIN pg_namespace pgn ON pgc.relnamespace = pgn.oid
            WHERE pgc.relname = '{table_name}' AND pgn.nspname = '{schema_name}'
            """.format(table_name=table_name, schema_name=schema_name))[0]["distributedby"]

    return dist_str
| |
| # ------------------------------------------------------------------------------ |
| |
| |
def num_features(source_table, independent_varname):
    """ Return the array length of ``independent_varname`` in the first
    row of ``source_table``. """
    query = ("SELECT array_upper({0}, 1) AS dim "
             "FROM {1} LIMIT 1".format(independent_varname, source_table))
    return plpy.execute(query)[0]['dim']
| # ------------------------------------------------------------------------------ |
| |
| |
def num_samples(source_table):
    """ Return the total row count of ``source_table``. """
    result = plpy.execute("SELECT count(*) AS n FROM {0}".format(source_table))
    return result[0]['n']
| # ------------------------------------------------------------------------------ |
| |
| |
def unique_string(desp='', **kwargs):
    """
    Generate a random temporary name, e.g. for temp tables.
    It has a SQL interface so both SQL and Python functions can call it.
    """
    part1 = random.randint(1, 100000000)
    part2 = int(time.time())
    part3 = int(time.time()) % random.randint(1, 100000000)
    return "__madlib_temp_{0}{1}_{2}_{3}__".format(desp, part1, part2, part3)
| # ------------------------------------------------------------------------------ |
| |
| |
def add_postfix(quoted_string, postfix):
    """ Append ``postfix`` to the end of a table name.

    When the input is wrapped in double quotes, the postfix is inserted
    inside the quotes.

    Arguments:
        @param quoted_string: str. A (possibly double-quoted) database name
        @param postfix: str. Suffix to append; assumed to contain no quotes
    """
    name = quoted_string.strip()
    is_double_quoted = name.startswith('"') and name.endswith('"')
    if is_double_quoted:
        return name[:-1] + postfix + '"'
    return name + postfix
| # ------------------------------------------------------------------------- |
| |
| |
# Canonical PostgreSQL type-name groups used by the is_psql_*_type and
# is_valid_psql_type helpers below. Set literals replace the older
# set([...]) spelling.
NUMERIC = {'smallint', 'integer', 'bigint', 'decimal', 'numeric',
           'real', 'double precision', 'float', 'serial', 'bigserial'}
INTEGER = {'smallint', 'integer', 'bigint'}
TEXT = {'text', 'varchar', 'character varying', 'char', 'character'}
BOOLEAN = {'boolean'}
# Sentinel singleton sets: these are not real type names but flags that
# can be OR-ed into a valid_types set to control array handling in
# is_valid_psql_type(). Random contents guarantee they never collide
# with an actual type name.
INCLUDE_ARRAY = {unique_string('__include_array__')}
ONLY_ARRAY = {unique_string('__only_array__')}
ANY_ARRAY = {unique_string('__any_array__')}
| |
| |
def is_valid_psql_type(arg, valid_types):
    """ Verify that ``arg`` names a type allowed by ``valid_types``.

    Args:
        @param arg: str. Name of the Postgres type to validate
        @param valid_types: set. Set of valid type names to search,
                            typically built from the globals in this
                            module. Three non-type flags may be OR-ed in
                            (in descending order of precedence):
                            - ANY_ARRAY: accept any array type at all
                            - ONLY_ARRAY: accept only the array forms of
                              the listed types
                            - INCLUDE_ARRAY: accept both scalar and array
                              forms of the listed types
    Examples:   1. valid_types = BOOLEAN | INTEGER | TEXT
                2. valid_types = BOOLEAN | INTEGER | ONLY_ARRAY
                3. valid_types = NUMERIC | INCLUDE_ARRAY
    """
    if not arg or not valid_types:
        return False
    type_name = arg.lower()
    trimmed = type_name.rstrip()
    # Flag precedence: ANY_ARRAY first, then ONLY_ARRAY, then INCLUDE_ARRAY.
    if ANY_ARRAY <= valid_types:
        return trimmed.endswith('[]')
    if ONLY_ARRAY <= valid_types:
        if not trimmed.endswith('[]'):
            return False
        return type_name.rstrip('[] ') in valid_types
    if INCLUDE_ARRAY <= valid_types:
        # Drop a trailing '[]' (plus whitespace) so both scalar and array
        # spellings resolve to the scalar name.
        type_name = type_name.rstrip('[] ')
    return type_name in valid_types
| # ------------------------------------------------------------------------------ |
| |
| |
def is_psql_numeric_type(arg, exclude=None):
    """
    Checks if argument is one of the various numeric types in PostgreSQL
    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, List of types to exclude from checking

    Returns:
        Boolean. Returns if 'arg' is one of the numeric types
    """
    excluded = set() if exclude is None else set(exclude)
    return arg in (NUMERIC - excluded)
| # ------------------------------------------------------------------------- |
| |
| |
def is_psql_int_type(arg, exclude=None):
    """
    Checks if argument is one of the various integer types in PostgreSQL
    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, List of types to exclude from checking

    Returns:
        Boolean. Returns if 'arg' is one of the integer types
    """
    # Mirror is_psql_numeric_type: treat a missing exclude list as empty
    # (the docstring previously said "numeric types" — a copy-paste slip).
    if exclude is None:
        exclude = []
    to_check_types = INTEGER - set(exclude)
    return (arg in to_check_types)
| # ------------------------------------------------------------------------- |
| |
| |
def is_psql_char_type(arg, exclude_list=None):
    """
    This function checks if the given arg is one of the predefined postgres
    character types
    :param arg: str, type name to check
    :param exclude_list: optional type name or list of names to exclude
                         from the comparison
    :return: True if it is one of the (non-excluded) character types,
             else False.
    """
    # A mutable default argument ([]) is shared across calls; use None as
    # the default instead. Callers passing nothing see the same behavior.
    if exclude_list is None:
        exclude_list = []
    elif not isinstance(exclude_list, list):
        exclude_list = [exclude_list]
    return arg in TEXT - set(exclude_list)
| |
| |
def is_psql_boolean_type(arg):
    """
    Check whether the given arg names the Postgres boolean type.
    :param arg: str, type name to test
    :return: True only for the exact string 'boolean', else False.
    """
    return 'boolean' == arg
| |
| |
def is_string_formatted_as_array_expression(string_to_match):
    """
    Return True if the string is formatted as array[<something>], else False.

    The check is case-insensitive and anchored at the start of the string.
    :param string_to_match: str, expression to test
    """
    # Return an explicit bool rather than the internal re match object /
    # None, as the docstring promises a true/false result. Truthiness for
    # existing callers is unchanged.
    return bool(re.match(r"(?i)^array\[(.*)\]", string_to_match))
| |
| |
def _string_to_array(s):
    """
    Split a string into a list of strings, removing any space around the
    substrings and stripping surrounding double quotes.

    Requirement: every individual element in the string must be a valid
    Postgres name, which means that elements containing spaces or commas
    must be wrapped in a pair of double quotes.

    Usually this is not a problem: especially in older versions of GPDB,
    when an array is passed from SQL to Python it arrives as a string
    with double quotes automatically added around elements containing
    spaces or commas.

    Use this function only when all elements are valid Postgres names.
    """
    token_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [m.group(1).strip("\"") for m in re.finditer(token_pattern, s)]
| # ------------------------------------------------------------------------ |
| |
| |
def _string_to_array_with_quotes(s):
    """
    Same as _string_to_array except surrounding double quotes are kept.
    """
    return [m.group(1)
            for m in re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s)]
| # ------------------------------------------------------------------------ |
| |
def get_col_name_type_sql_string(colnames, coltypes):
    """ Return a 'name type,name type,...' string usable in CREATE TABLE,
    or None when the inputs are empty or of mismatched length. """
    if not colnames or not coltypes or len(colnames) != len(coltypes):
        return None
    return ','.join('{0} {1}'.format(name, typ)
                    for name, typ in zip(colnames, coltypes))
| # ------------------------------------------------------------------------ |
| |
def py_list_to_sql_string(array, array_type=None, long_format=None):
    """Convert a python list to a SQL array literal string.

    @note: The long format (ARRAY[...]) is recommended, with input values
    quoted appropriately by 'quote_literal'.

    The short format (e.g. '{1,2,3}') does not escape string values, so
    elements containing unescaped curly brackets or double quotes will
    fail; escape them in the input when relying on the short format.
    """
    if long_format is None:
        # Default to the short format only for textual element types.
        is_textual = (array_type is not None and
                      any(array_type.startswith(i)
                          for i in ["text", "varchar", "character varying"]))
        long_format = not is_textual
    if array_type:
        array_type = array_type.strip()
        if not array_type.endswith("[]"):
            array_type += "[]"
    else:
        array_type = "double precision[]"

    if not array:
        return "'{{ }}'::{0}".format(array_type)
    # Dollar-quoting lets single quotes appear inside the values without
    # any escaping.
    quote_delimiter = "$__MADLIB_OUTER__$"
    template = "ARRAY[ {val} ]" if long_format else "{qd}{{ {val} }}{qd}"
    return (template + "::{array_type}").format(
        val=','.join(map(str, array)),
        array_type=array_type,
        qd=quote_delimiter)
| # ------------------------------------------------------------------------ |
| |
def create_cols_from_array_sql_string(py_list, sql_array_col, colname,
                                      coltype, has_one_ele,
                                      module_name='Input Error'):
    """
    Create SQL string to convert array of elements into multiple columns and corresponding
    SQL string of columns for CREATE TABLE.
    @args:
        @param: py_list, python list, if None, return sql_array_col as colname.
                The py_list can at most have one 'None' element that
                is converted to sql 'NULL'
        @param: sql_array_col, str pointing to a column in table containing
                an array.
        @param: colname, name of output column (can be treated as prefix
                to multiple cols if has_one_ele=False)
        @param: coltype, Type of columns to be created
        @param: has_one_ele, bool. if True, assumes sql_array_col has
                an array of exactly one element, which is treated as in
                index to get value from py_list. If False, then a new
                column is created for every element in py_list,
                whose corresponding values are obtained from sql_array_col.
    @examples:
        1) Input:
                py_list = ['cat', 'dog']
                sql_array_col = sqlcol
                colname = prob
                coltype = TEXT
                has_one_ele = FALSE
           Output:
                prob_cat TEXT, prob_dog TEXT
                CAST(sqlcol[1] AS TEXT) AS prob_cat, CAST(sqlcol[2] AS TEXT) AS prob_dog
        2) Input:
                py_list = ['cat', 'dog']
                sql_array_col = sqlcol
                colname = estimated_pred
                coltype = TEXT
                has_one_ele = TRUE
           Output:
                estimated_pred TEXT
                (ARRAY['cat','dog'])[sqlcol[1]+1]::TEXT AS estimated_pred

    @NOTE:
        If py_list is [None, 'cat', 'dog', NULL']:
        then the SQL query string returned would create the following
        column names:
            prob_NULL, prob_cat, 'prob_dog', and 'prob_"NULL"'.
        1. Notice that for None, which represents Postgres' NULL value, the
           column name will be 'prob_NULL',
        2. and to differentiate the column name for a string 'NULL', the
           resulting column name will be 'prob_"NULL"'.

        The weird quoting in this column name is due to calling strip after
        quote_ident in the code below.

    @returns:
        @param, str, select clause that can be used in a SQL query.
        @param, str, column definition list that can be used in CREATE TABLE.

    """
    _assert(sql_array_col, "{0}: sql_array_col should be a valid string.".
            format(module_name))
    _assert(colname, "{0}: colname should be a valid string.".format(
            module_name))
    if py_list:
        _assert(is_valid_psql_type(coltype, BOOLEAN | NUMERIC | TEXT),
                "{0}: Invalid coltype parameter {1}.".format(
                    module_name, coltype))
        _assert(py_list.count(None) <= 1,
                "{0}: Input list should contain at most 1 None element.".
                format(module_name))
        def py_list_str(ele):
            """
            A python None is converted to a SQL NULL.
            String 'NULL' is converted to SQL 'NULL' string by quoting
            it to '"NULL"'. This quoting is necessary for Postgres to
            differentiate between NULL and 'NULL' in the SQL query
            string returned by create_cols_from_array_sql_string.
            """
            if ele is None:
                return 'NULL'
            elif isinstance(ele, str) and ele.lower()=='null':
                return '"{0}"'.format(ele)
            return ele

        py_list = list(map(py_list_str, py_list))
        if has_one_ele:
            # Query to choose the value in the first element of
            # sql_array_col which is the index to access in py_list.
            # The value from that corresponding index in py_list is
            # the value of colname.
            py_list_sql_str = py_list_to_sql_string(py_list, coltype+'[]')
            select_clause = "({0})[{1}[1]+1]::{2} AS {3}".format(
                py_list_sql_str, sql_array_col, coltype, colname)
            create_columns = "{0} {1}".format(
                colname, coltype)
        else:
            # Create as many columns as the length of py_list. The
            # colnames are created based on the elements in py_list,
            # while the value for these colnames is obtained from
            # sql_array_col.

            # we cannot call sql quote_ident on the py_list entries because
            # aliasing does not support quote_ident. Hence calling our
            # python implementation of quote_ident. We must call strip()
            # after quote_ident since the resulting SQL query fails otherwise.
            select_clause = ', '.join(
                ['CAST({sql_array_col}[{j}] AS {coltype}) AS "{final_colname}"'.
                 format(j=i + 1,
                        final_colname=quote_ident("{0}_{1}".
                                                  format(colname, str(suffix))).strip(' "'),
                        sql_array_col=sql_array_col,
                        coltype=coltype)
                 for i, suffix in enumerate(py_list)
                 ])
            create_columns = ', '.join(
                ['"{final_colname}" {coltype}'.
                 format(final_colname=quote_ident("{0}_{1}".
                                                  format(colname, str(suffix))).strip(' "'),
                        coltype=coltype)
                 for i, suffix in enumerate(py_list)
                 ])
    else:
        # No py_list: expose the raw array column (or its first-element
        # index) directly under the requested column name.
        if has_one_ele:
            select_clause = '{0}[1]+1 AS {1}'.format(sql_array_col, colname)
            create_columns = '{0} {1}'.format(colname, coltype)
        else:
            select_clause = '{0} AS {1}'.format(sql_array_col, colname)
            create_columns = '{0} {1}'.format(colname, coltype+'[]')
    return select_clause, create_columns
| # ------------------------------------------------------------------------ |
| |
def _array_to_string(origin):
    """
    Render a python sequence as a Postgres array literal string '{a,b,...}'.
    """
    def _escape(s):
        # Backslash-escape embedded double quotes.
        return re.sub(r'"', r'\"', str(s))
    pieces = [_escape(item) for item in origin]
    return "{" + ",".join(pieces) + "}"
| # ------------------------------------------------------------------------ |
| |
| |
def _cast_if_null(input, alias=''):
    """ Return ``input`` as a string, or a NULL::text expression when falsy.

    When ``alias`` is given, the NULL literal is aliased as that name.
    (The parameter shadows the builtin ``input``; renamed locally only,
    since a rename could break keyword callers.)
    """
    if not input:
        base = "NULL::text"
        return "{0} as {1}".format(base, alias) if alias else base
    return str(input)
| # ------------------------------------------------------------------------ |
| |
| |
def set_client_min_messages(new_level):
    """
    Set the client_min_messages setting in Postgres which controls the
    messages sent to the psql client.

    Args:
        @param new_level: string, New level to set client_min_messages to

    Returns:
        old_msg_level: string, The client_min_messages level in effect
                       before the change.

    Raise:
        ValueError: if new_level is not a valid message level
    """
    valid_levels = ('debug5', 'debug4', 'debug3', 'debug2',
                    'debug1', 'log', 'notice', 'warning', 'error',
                    'fatal', 'panic')
    if new_level.lower() not in valid_levels:
        raise ValueError("Not a valid message level sent to client")

    old_msg_level = plpy.execute(""" SELECT setting
                                     FROM pg_settings
                                     WHERE name='client_min_messages'
                                 """)[0]['setting']
    plpy.execute("SET client_min_messages TO {0}".format(new_level))
    return old_msg_level
| # ------------------------------------------------------------------------- |
| |
def is_pg_major_version_less_than(schema_madlib, compare_version, **kwargs):
    """ Return True when connected to PostgreSQL whose major version is
    below ``compare_version``.

    @param schema_madlib: unused; present for the SQL calling convention
    @param compare_version: int, major version number to compare against
    @return: bool; False when the version string is not recognized
    """
    version = plpy.execute("select version()")[0]["version"]
    # Raw string prevents Python 3 invalid-escape-sequence warnings for
    # \s while matching the exact same pattern.
    regex = re.compile(r'PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
    matches = regex.findall(version)
    return bool(matches) and int(matches[0][0]) < compare_version
| |
| # Deal with earlier versions of PG or GPDB |
class __mad_version:
    """ Version-sniffing helper used to paper over behavioral differences
    between PostgreSQL and Greenplum releases.

    The database version string is fetched once at construction time and
    parsed by the is_* predicates below. All regex literals use raw
    strings to avoid Python 3 invalid-escape-sequence warnings; the
    patterns themselves are unchanged.
    """

    def __init__(self):
        # e.g. "PostgreSQL 8.3.23 (Greenplum Database 5.0.0 build ...)"
        self.version = plpy.execute("select version()")[0]["version"]

    def select_vecfunc(self):
        """
        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        else:
            return self.__identity

    def __extract(self, origin, text=True):
        """
        Extract vector elements from a string with {}
        as the brackets
        """
        if origin is None:
            return None
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            for i in range(len(elm)):
                elm[i] = float(elm[i])
        return elm

    def __identity(self, origin, text=True):
        # Pass-through used on platforms with native array support.
        return origin

    def select_vec_return(self):
        """
        Special care is needed if one needs to return
        vector from Python to SQL
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        else:
            return self.__identity

    def __condense(self, origin):
        """
        Convert the original vector into a string which some
        old versions of SQL system can recognize
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """
        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        else:
            return "array_agg"

    def is_pg(self):
        """ True when connected to plain PostgreSQL (not Greenplum). """
        if (re.search(r"PostgreSQL", self.version) and
                not re.search(r"Greenplum\s*Database", self.version)):
            return True
        return False

    def is_gp43(self):
        """ True when connected to Greenplum 4.3.x. """
        if re.search(r"Greenplum\s+Database\s+4\.3", self.version):
            return True
        return False

    def is_less_than_pg90(self):
        """ True when connected to PostgreSQL with major version < 9. """
        regex = re.compile(r'PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
        version = regex.findall(self.version)
        if len(version) > 0 and self.is_pg() and int(version[0][0]) < 9:
            return True
        else:
            return False

    def is_less_than_gp41(self):
        """ True when connected to Greenplum older than 4.1. """
        # NOTE(review): '[0-9].[0-9]' uses an unescaped dot; it matches
        # real version strings such as '4.0' but is looser than intended.
        regex = re.compile(r'Greenplum\s+Database\s*([0-9].[0-9])[0-9.]+\s+build', re.IGNORECASE)
        version = regex.findall(self.version)
        if len(version) > 0 and float(version[0]) < 4.1:
            return True
        else:
            return False

    def is_less_than_gp42(self):
        """ True when connected to Greenplum older than 4.2. """
        regex = re.compile(r'Greenplum\s+Database\s*([0-9].[0-9])[0-9.]+\s+build', re.IGNORECASE)
        version = regex.findall(self.version)
        if len(version) > 0 and float(version[0]) < 4.2:
            return True
        else:
            return False

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
        self.version is less than compare_version

        @param compare_version: str, String form of the comparison version.
                                Expected format is Semantic Versioning,
                                e.g. 1.0, 2.3, 9.3.5
        """
        regex = re.compile(r'PostgreSQL\s*([0-9.]+)', re.IGNORECASE)
        version = regex.findall(self.version)
        if len(version) > 0 and self.is_pg():
            # Lexicographic comparison of the numeric version components.
            db_ver = [float(i) for i in version[0].split('.') if i.isdigit()]
            cmp_ver = [float(i) for i in compare_version.split('.') if i.isdigit()]
            return db_ver < cmp_ver
        else:
            return False

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
        is less than compare_version

        @param compare_version: str, String form of the comparison version.
                                Expected format is Semantic Versioning,
                                e.g. 1.0, 2.3, 9.3.5
        """
        regex = re.compile(r'Greenplum\s+Database\s*([0-9.]+)', re.IGNORECASE)
        version = regex.findall(self.version)
        if version:
            db_ver = [float(i) for i in version[0].split('.') if i.isdigit()]
            cmp_ver = [float(i) for i in compare_version.split('.') if i.isdigit()]
            return db_ver < cmp_ver
        else:
            return False
| |
| |
def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Split a string into an array of strings, stripping surrounding spaces
    and double quotes, and convert the result to the platform's SQL array
    representation.

    Requirement: every individual element in the string must be a valid
    Postgres name; elements containing spaces or commas must be wrapped
    in a pair of double quotes (older GPDB versions add these quotes
    automatically when passing arrays from SQL to Python).

    Use this function only when all elements are valid Postgres names.
    """
    # GPDB < 4.2 and PG < 9.0 pass vectors back to SQL as brace-delimited
    # strings, so pick the right return converter for this platform.
    array_to_string = __mad_version().select_vec_return()

    tokens = [m.group(1).strip("\"")
              for m in re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s)]
    return array_to_string(tokens)
| # ------------------------------------------------------------------------ |
| |
| |
def current_user():
    """Returns the user name of the current database user."""
    result = plpy.execute("SELECT current_user")
    return result[0]['current_user']
| # ------------------------------------------------------------------------ |
| |
def is_superuser(user):
    """ Return True when ``user`` has the superuser role attribute. """
    rows = plpy.execute("SELECT rolsuper FROM pg_catalog.pg_roles "
                        "WHERE rolname = '{0}'".format(user))
    return rows[0]['rolsuper']
| |
def get_table_owner(schema_table):
    """ Return the owner of a schema-qualified table.

    NOTE(review): assumes ``schema_table`` has the form '<schema>.<table>';
    an unqualified name raises IndexError — confirm with callers.
    """
    parts = schema_table.split(".", 1)
    schema, non_schema_table = parts[0], parts[1]

    q = """SELECT tableowner FROM pg_catalog.pg_tables
        WHERE schemaname='{0}' AND tablename='{1}'
        """.format(schema, non_schema_table)
    return plpy.execute(q)[0]['tableowner']
| |
def madlib_version(schema_madlib):
    """Returns the MADlib version string."""
    raw = plpy.execute("""
        SELECT {schema_madlib}.version()
        """.format(schema_madlib=schema_madlib))[0]['version']
    # The version number is the last space-separated token of the first
    # comma-separated field, e.g. "MADlib version 1.21, ..." -> "1.21".
    first_field = raw.split(',')[0]
    return first_field.split(' ')[-1]
| # ------------------------------------------------------------------------ |
| |
| |
def preprocess_keyvalue_params(input_params, split_char='='):
    """
    Parse the input_params string and split it using the split_char

    @param input_params: str, Comma-separated list of parameters.
                         Each parameter has the form key<split_char>value,
                         where key is an alphanumeric string (hyphen and
                         colon allowed) and value is either a plain or
                         double-quoted string, or an array literal wrapped
                         in (), {} or []
    @param split_char: str, character that splits the key and value elements.
                       Default set to '='

    Returns:
        List of str. One 'key<split_char>value' string per parameter found.
    """
    re_str = (r"([-:\w]+\s*" + # key is any alphanumeric character
                               # (including - and :) string

              split_char +     # key and value are separated by split_char

              r"""
              \s*([\(\{\[]     # value can be string or array
              [^\[\]\(\)\{\}]* # if value is array then accept anything inside
              [\)\}\]]         # and then match closing braces of array

              |                # either array (above) or string (below)

              (?P<quote>\"?)[\w\s\-\%.]+(?P=quote)
                               # if value is string, it can be alphanumeric
                               # character string with a decimal dot,
                               # hyphen, or percent
                               # optionally quoted by `quote_char`
              )
              )"""
              )
    pattern = re.compile(re_str, re.VERBOSE)
    return [m.group(1).strip() for m in pattern.finditer(input_params)]
| # ------------------------------------------------------------------------ |
| |
| |
def _keyvalue_strtobool(param_value):
    """ Parse a boolean-ish string the way distutils.util.strtobool did.

    True values are y, yes, t, true, on and 1; False values are n, no, f,
    false, off and 0 (case-insensitive). Raises ValueError otherwise.
    Inlined because distutils was deprecated by PEP 632 and removed in
    Python 3.12.
    """
    val = param_value.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value {0!r}".format(param_value))


def extract_keyvalue_params(input_params,
                            input_param_types=None,
                            default_values=None,
                            split_char='=',
                            usage_str='',
                            ignore_invalid=False,
                            allow_duplicates=True,
                            lower_case_names=True):
    """ Extract key value pairs from input parameters or set the default values

    Args:
        @param input_params: string, Format of
            'key1=value1, key2=value2,...', assuming default split_char.
            The order does not matter. If a parameter is missing, then
            the default value is used. If input_params is None or '',
            then all default values are returned. This function also
            validates the values of these parameters.

        @param input_param_types: dict, The type of each allowed parameter
                                  name. Currently supports one of
                                  (int, float, str, bool, list)
        @param default_values: dict, Default values for each allowed parameter.
        @param split_char: str, The character used to split key and value.
                           Default set to '='
        @param usage_str: str, An optional usage string to print with error message.
        @param ignore_invalid: bool, If True an invalid param input is ignore silently
        @param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
                                 where the last occurence will be reflected
                                 in the output. If False, then a ValueError is
                                 raised.
        @param lower_case_names: bool, Convert parameter names to lower case.

    Returns:
        Dict. Dictionary of input parameter values with key as parameter name
        and value as the parameter value

    Throws:
        plpy.error - If the parameter is unsupported or the value is
        not valid.
    """
    if not input_params:
        return default_values if default_values is not None else {}

    # Copy the defaults so the caller's dict is not mutated by the
    # assignments below (previously the same dict object was updated
    # in place).
    parameter_dict = dict(default_values) if default_values else {}
    seen_params = set()

    for s in preprocess_keyvalue_params(input_params, split_char=split_char):
        items = split_quoted_delimited_str(s, delimiter=split_char)
        if (len(items) != 2):
            raise KeyError("Input parameter list has incorrect format "
                           "{0}".format(usage_str))

        param_name = items[0].strip(" \"")
        if lower_case_names:
            param_name = param_name.lower()
        param_value = items[1].strip()

        if not allow_duplicates and param_name in seen_params:
            raise ValueError("Invalid input: {0} duplicated in the param list".
                             format(param_name))

        if not param_name or param_name in ('none', 'null'):
            plpy.error("Invalid input param name: {0} \n"
                       "{1}".format(param_name, usage_str))
        if input_param_types:
            try:
                param_type = input_param_types[param_name]
            except KeyError:
                if not ignore_invalid:
                    raise KeyError("Invalid input: {0} is not a valid parameter "
                                   "{1}".format(param_name, usage_str))
                else:
                    continue
            try:
                if param_type == bool:  # bool is not subclassable
                    # True values are y, yes, t, true, on and 1;
                    # False values are n, no, f, false, off and 0.
                    # Raises ValueError if anything else.
                    parameter_dict[param_name] = _keyvalue_strtobool(param_value)
                elif param_type in (int, str, float):
                    parameter_dict[param_name] = param_type(param_value)
                elif issubclass(param_type, collections.abc.Iterable):
                    # collections.Iterable was removed in Python 3.10;
                    # the ABC lives in collections.abc.
                    parameter_dict[param_name] = split_quoted_delimited_str(
                        param_value.strip('[](){} '))
                else:
                    raise TypeError("Invalid input: {0} has unsupported type "
                                    "{1}".format(param_name, usage_str))
            except ValueError:
                raise ValueError("Invalid input: {0} must be {1} \n"
                                 "{2}".format(param_name, param_type, usage_str))
        else:
            # if input parameter types not provided then just return all as string
            parameter_dict[param_name] = str(param_value)
        seen_params.add(param_name)
    return parameter_dict
| # ------------------------------------------------------------------------- |
| |
| |
def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
    """ Parse a delimited string to return a list of individual tokens taking
    quotes into account.

    Delimiters that appear inside a quoted field are not treated as
    separators; the quotes are retained in the returned tokens.

    Args:
        @param input_str: str, Delimited input string
        @param delimiter: str, The field delimiter character that separates
                          the tokens in input_str. Default = ','
        @param quote: str, One-character string used to quote fields containing
                      special characters, such as the field delimiter
                      Default = '"'

    Returns:
        List. List of delimited strings (whitespace-stripped). Empty list if
        input_str, delimiter or quote is empty/None.

    Raises:
        ValueError if the string cannot be split with the given pattern.
    """
    if not input_str or not delimiter or not quote:
        return []
    try:
        # re.escape makes metacharacter delimiters/quotes (e.g. '|', '.')
        # safe; the raw string avoids invalid '\{' escape warnings that the
        # previous hand-built pattern produced on Python 3.6+.
        d = re.escape(delimiter)
        q = re.escape(quote)
        # A token is a run of non-delimiter/non-quote chars or a quoted field;
        # splitting on the capturing group yields tokens at odd indices.
        token_re = re.compile(r'((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.format(d=d, q=q))
        return [tok.strip() for tok in
                token_re.split(input_str.strip())[1::2]]
    except Exception as e:
        plpy.warning(str(e))
        raise ValueError("Invalid string input for splitting")
| # ------------------------------------------------------------------------------ |
| |
| |
def strip_end_quotes(input_str, quote='"'):
    """ Remove the quote string from the start and end if present at BOTH
    ends. Whitespace at start and end is not ignored.

    Args:
        @param input_str: str, Input string (returned unchanged if falsy or
                          not a string)
        @param quote: str, Quote string to strip (may be multi-character)

    Returns:
        str. Original string without the quotes at start and end.
    """
    if not input_str or not quote:
        return input_str
    if not isinstance(input_str, str):
        return input_str
    # Length guard: the string must be long enough to contain two distinct
    # quote occurrences; previously a lone quote char ('"') was stripped to ''.
    if (len(input_str) >= 2 * len(quote) and
            input_str.startswith(quote) and input_str.endswith(quote)):
        # Strip exactly one quote-length from each end (handles multi-char
        # quotes, which previously lost only one character per side).
        return input_str[len(quote):-len(quote)]
    return input_str
| # ------------------------------------------------------------------------------ |
| |
| |
| def _grp_null_checks(grp_list): |
| """ |
| Helper function for generating NULL checks for grouping columns |
| to be used within a WHERE clause |
| Args: |
| @param grp_list The list of grouping columns |
| """ |
| return ' AND '.join([" {i} IS NOT NULL ".format(**locals()) |
| for i in grp_list]) |
| # ------------------------------------------------------------------------------ |
| |
| |
| def _check_groups(tbl1, tbl2, grp_list): |
| """ |
| Helper function for joining tables with groups. |
| Args: |
| @param tbl1 Name of the first table |
| @param tbl2 Name of the second table |
| @param grp_list The list of grouping columns |
| """ |
| |
| return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(**locals()) |
| for i in grp_list]) |
| # ------------------------------------------------------------------------------ |
| |
| |
def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
                                  filter_cols_list):
    """
    This function returns a subquery string with columns in the filter_cols_list
    that appear in include_from_table but NOT IN exclude_from_table.
    :param include_from_table: table from which cols in filter_cols_list should
                               be included
    :param exclude_from_table: table from which cols in filter_cols_list should
                               be excluded
    :param filter_cols_list: list of column names
    :return: query string with relevant column names
    """
    included_cols = get_table_qualified_col_str(include_from_table, filter_cols_list)
    # join directly -- the previous [col for col in ...] was a no-op copy
    cols = ', '.join(filter_cols_list)
    return """({included_cols}) NOT IN
              (SELECT {cols}
              FROM {exclude_from_table})
           """.format(**locals())
| # ------------------------------------------------------------------------------ |
| |
| |
def get_table_qualified_col_str(tbl_name, col_list):
    """ Build a comma-separated list of table-qualified column references.

    Args:
        @param tbl_name Name of the table
        @param col_list The list of column names
    """
    qualified = (" {0}.{1} ".format(tbl_name, col) for col in col_list)
    return ' , '.join(qualified)
| # ------------------------------------------------------------------------------ |
| |
| |
def get_grouping_col_str(schema_madlib, module_name, reserved_cols,
                         source_table, grouping_col):
    """ Validate grouping columns and build a text-cast grouping expression.

    Returns a tuple (grouping_str, grouping_col) where grouping_str is a
    comma-separated list of the columns cast to ::text, or the literal
    "Null" (with grouping_col = None) when no grouping is requested.
    """
    # Treat empty input or the literal string 'null' as "no grouping"
    if not grouping_col or grouping_col.lower() == 'null':
        return "Null", None

    col_array = _string_to_array_with_quotes(grouping_col)
    cols_in_tbl_valid(source_table, col_array, module_name)
    does_exclude_reserved(col_array, reserved_cols)
    # boolean columns need an explicit cast before the ::text cast below
    cast_array = explicit_bool_to_text(source_table, col_array, schema_madlib)
    grouping_str = ', '.join([c + "::text" for c in cast_array])
    return grouping_str, grouping_col
| # ------------------------------------------------------------------------------ |
| |
| |
def collate_plpy_result(plpy_result_rows):
    """ Transpose a list of row-dicts into a dict of column-lists.

    E.g. [{'a': 1}, {'a': 2}] -> {'a': [1, 2]}. Returns {} for empty input.
    Keys are taken from the first row; every row is assumed to share them.
    """
    if not plpy_result_rows:
        return {}
    column_names = plpy_result_rows[0].keys()
    collated = collections.defaultdict(list)
    for row in plpy_result_rows:
        for name in column_names:
            collated[name].append(row[name])
    return collated
| # ------------------------------------------------------------------------------ |
| |
| |
def validate_module_input_params(source_table, output_table, independent_varname,
                                 dependent_varname, module_name,
                                 grouping_cols=None, other_output_tables=None):
    """
    This function is supposed to be used for validating params for
    supervised learning like algos, e.g. linear regression, mlp, etc. since all
    of them need to validate the following 4 parameters.
    :param source_table: This table should exist and not be empty
    :param output_table: This table should not exist
    :param independent_varname: Expression (or list of expressions) that must
                                be valid in the source table
    :param dependent_varname: Expression (or list of expressions) that must
                              be valid in the source table
    :param module_name: Name of the module to be printed with the error messages
    :param grouping_cols: Optional expression of grouping columns to validate
    :param other_output_tables: List of additional output tables to validate.
                                These tables should not exist
    """
    input_tbl_valid(source_table, module_name)
    output_tbl_valid(output_table, module_name)
    if other_output_tables:
        for tbl in other_output_tables:
            output_tbl_valid(tbl, module_name)

    # Normalize single expressions to lists so both forms validate uniformly
    if not isinstance(independent_varname, list):
        independent_varname = [independent_varname]
    if not isinstance(dependent_varname, list):
        dependent_varname = [dependent_varname]

    for var in independent_varname:
        _assert(is_var_valid(source_table, var),
                "{module_name} error: invalid independent_varname "
                "('{independent_varname}') for source_table "
                "({source_table})!".format(module_name=module_name,
                                           independent_varname=var,
                                           source_table=source_table))
    for var in dependent_varname:
        _assert(is_var_valid(source_table, var),
                "{module_name} error: invalid dependent_varname "
                "('{dependent_varname}') for source_table "
                "({source_table})!".format(module_name=module_name,
                                           dependent_varname=var,
                                           source_table=source_table))
    if grouping_cols:
        _assert(is_var_valid(source_table, grouping_cols),
                "{module_name} error: invalid grouping_cols "
                "('{grouping_cols}') for source_table "
                "({source_table})!".format(module_name=module_name,
                                           grouping_cols=grouping_cols,
                                           source_table=source_table))
| # ------------------------------------------------------------------------ |
| |
| |
def create_table_drop_cols(source_table, out_table, cols_to_drop, **kwargs):
    """ Create copy of table while dropping some of the columns
    Args:
        @param source_table str. Name of the source table
        @param out_table str. Name of the output table
        @param cols_to_drop str. Comma-separated list of columns to drop

    Raises (via _assert/plpy):
        error if cols_to_drop is empty, names a column not in source_table,
        or would drop every column.
    """
    input_tbl_valid(source_table, 'Utilities')
    output_tbl_valid(out_table, 'Utilities')
    _assert(cols_to_drop and cols_to_drop.strip(),
            "Utilities error: cols_to_drop cannot be empty or NULL")

    source_table_cols = get_cols(source_table)
    # Set gives O(1) membership tests in the filter below (was O(n) per col)
    cols_to_drop_set = set(split_quoted_delimited_str(cols_to_drop))
    cols_not_in_source = cols_to_drop_set - set(source_table_cols)
    _assert(not cols_not_in_source,
            "Utilities error: Some column(s) in cols_to_drop are not present "
            "in source table")

    # Preserve source column order in the output table
    cols_to_retain = [c for c in source_table_cols if c not in cols_to_drop_set]
    _assert(cols_to_retain,
            "Utilities error: No valid columns for the output table")
    plpy.execute("""
        CREATE TABLE {out_table} AS
        SELECT {cols}
        FROM {source_table}
        """.format(cols=', '.join(cols_to_retain),
                   out_table=out_table,
                   source_table=source_table))
| # ------------------------------------------------------------------------------ |
| |
| |
def rotate(l, n):
    """Rotate the list l to the right (index-increasing direction) by n
    elements, returning a new list.

    Args:
        l (list): The input list to rotate
        n (integer): The number of elements to rotate; may exceed len(l)
                     (reduced modulo the length) or be negative (rotates left)

    Returns:
        list: The rotated list (a copy; input is not mutated)
    """
    if not l:
        return l[:]
    # Reduce n modulo the length: previously n > len(l) produced no rotation
    # at all (l[-n:] was the whole list), e.g. rotate([1,2,3], 4) == [1,2,3].
    n %= len(l)
    return l[-n:] + l[:-n]
| # ------------------------------------------------------------------------------ |
| |
def rename_table(schema_madlib, orig_name, new_name):
    """
    Renames possibly schema qualified table name to a new schema qualified name
    ensuring the schema qualification are changed appropriately

    Args:
        @param orig_name: string, Original name of the table
                          (must be schema qualified if table schema is not in
                          search path)
        @param new_name: string, New name of the table
                         (can be schema qualified. If it is not then the
                         original schema is maintained)
    Returns:
        String. The new table name qualified with the schema name
    Raises:
        AssertionError if either name has more than one '.' or the original
        relation's schema cannot be determined.
    """
    new_parts = new_name.split(".")
    if len(new_parts) > 2:
        raise AssertionError("Invalid table name")
    new_table_name = new_parts[-1]
    new_table_schema = new_parts[0] if len(new_parts) == 2 else None

    orig_parts = orig_name.split(".")
    if len(orig_parts) > 2:
        raise AssertionError("Invalid table name")

    if len(orig_parts) == 2:
        orig_table_schema = orig_parts[0]
    else:
        # Unqualified original name: resolve its schema from the catalog so
        # that, when re-schemaing the new table, we operate on the right
        # relation even if several tables share the new name.
        orig_table_schema = get_first_schema(orig_name)

    if orig_table_schema is None:
        raise AssertionError("Relation {0} not found during rename".
                             format(orig_name))
    return __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
                                        new_table_schema, new_table_name)
| # ------------------------------------------------------------------------------ |
def __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
                                 new_table_schema, new_table_name):
    """
    Internal private function to perform the rename operation after all the
    validation checks.

    Executes the rename via plpy (ALTER TABLE / CTAS+drop depending on the
    schemas involved) and returns the resulting table name. Expects names
    already split/validated by rename_table().
    """

    """
    CASE 1
    If the output table is schema is pg_temp, we cannot alter table schemas from/to
    temp schemas. If it looks like a temp schema, we stay safe and just use
    create/drop
    Test cases
    foo.bar to pg_temp.bar
    foo.bar to pg_temp.bar2
    foo to pg_temp.bar
    pg_temp.foo to pg_temp.bar
    """
    # Substring match ('pg_temp' in ...) also catches backend-specific temp
    # schemas such as pg_temp_NN.
    if new_table_schema and 'pg_temp' in new_table_schema:
        """
        If both new_table_schema and orig_table_schema have pg_temp in it,
        just run an alter statement instead of CTAS. Without this, pca dev-check
        fails on gpdb5/6 (but not on pg)
        """
        if new_table_schema != orig_table_schema:
            # NOTE(review): this plpy.info of the CTAS statement looks like
            # leftover debug logging -- confirm before removing.
            plpy.info("""CREATE TABLE {new_name} AS SELECT * FROM {orig_name}"""
                      .format(**locals()))
            plpy.execute("""CREATE TABLE {new_name} AS SELECT * FROM {orig_name}"""
                         .format(**locals()))
            drop_tables([orig_name])
            return new_name
        else:
            # Both already in the temp schema: a plain rename is sufficient.
            plpy.execute("ALTER TABLE {orig_name} RENAME TO {new_table_name}".
                         format(**locals()))
            return new_name

    """
    CASE 2
    Do direct rename if the new table does not have an output schema or
    if the new table schema is the same as the original table schema
    Test Cases
    rename foo to bar
    rename foo.bar to foo.bar2
    rename foo.bar to bar2
    """
    if not new_table_schema or new_table_schema == orig_table_schema:
        plpy.execute("ALTER TABLE {orig_name} RENAME TO {new_table_name}".
                     format(**locals()))
        # Return fully qualified even if new_name was unqualified.
        return orig_table_schema + "." + new_table_name

    """
    CASE 3
    output table is schema qualified
    1. rename the original table to an interim temp name
    2. set the new schema on that interim table
    3. rename interim table to the new table name
    Test cases
    foo.bar to foo2.bar2
    foo.bar to foo2.bar
    """
    # Interim name avoids a clash if a table named new_table_name already
    # exists in the original schema during the move.
    interim_temp_name = unique_string("rename_table_helper")
    plpy.execute(
        "ALTER TABLE {orig_name} RENAME to {interim_temp_name}".format(
            **locals()))

    plpy.execute(
        """ALTER TABLE {interim_temp_name} SET SCHEMA {new_table_schema}""".format(
            **locals()))

    plpy.execute(
        """ALTER TABLE {new_table_schema}.{interim_temp_name} RENAME to {new_table_name}"""
        .format(**locals()))
    return new_name
| # ------------------------------------------------------------------------------ |
| |
def is_platform_gp6_or_up():
    """ Return True when running on Greenplum 6.0 or newer (False on
    PostgreSQL).

    NOTE(review): this is an exact duplicate of is_platform_gp6_or_up defined
    near the top of this file; the later definition silently wins at import
    time. One of the two copies should be removed.
    """
    version_wrapper = __mad_version()
    return not is_platform_pg() and not version_wrapper.is_gp_version_less_than('6.0')
| |
def get_psql_type(py_type):
    """ Map a Python value (or type object) to the equivalent PostgreSQL
    type name.

    @param py_type: a Python value (5, 1.0, True, 'x') or a type object
                    (int, float, bool, str)

    Returns: str. One of 'integer', 'double precision', 'boolean', 'varchar'.
    Calls plpy.error for anything else.
    """
    # Accept both a type object and an instance. Previously `type(py_type)`
    # was compared to int/float/..., so passing the type object itself (as
    # the parameter name suggests) always fell through to the error branch.
    actual_type = py_type if isinstance(py_type, type) else type(py_type)
    # bool must be checked before int since bool subclasses int
    if actual_type is bool:
        return 'boolean'
    elif actual_type is int:
        return 'integer'
    elif actual_type is float:
        return 'double precision'
    elif actual_type is str:
        return 'varchar'
    else:
        plpy.error("Cannot determine the type of {0}".format(py_type))
| |
| |
def get_schema(tbl_str):
    """ Return the schema portion of a possibly schema-qualified table name.

    @param tbl_str: str, Table name, optionally of the form
                    <schema name>.<table name>

    Returns:
        str. The unquoted schema name if tbl_str is qualified, else None.

    Raises:
        TypeError if the name contains more than one '.'.
    """
    names = tbl_str.split('.')

    if not names or len(names) > 2:
        # BUG FIX: the message previously interpolated an undefined name
        # 'table_name', raising NameError instead of the intended TypeError.
        raise TypeError("Incorrect table name ({0}) provided! Table name should be "
                        "of the form: <schema name>.<table name>".format(tbl_str))
    elif len(names) == 2:
        return unquote_ident(names[0])
    else:
        return None
| # ------------------------------------------------------------------------------- |
| |
def get_current_timestamp(format):
    """Return the current local time rendered with the given strftime format."""
    return datetime.now().strftime(format)