| m4_changequote(`<!', `!>') |
| |
| import re |
| import time |
| import random |
| from distutils.util import strtobool |
| |
| if __name__ != "__main__": |
| import plpy |
| |
| |
def _assert(condition, msg):
    """
    @brief Report an error through plpy when the given condition does not hold
    @param condition  value expected to be truthy
    @param msg  the error message handed to plpy.error when condition is falsy
    """
    if condition:
        # Nothing to report
        return
    plpy.error(msg)
| # ======================================================================== |
| |
| |
def warn(condition, msg):
    """
    @brief Emit a warning through plpy when the given condition does not hold
    @param condition  value expected to be truthy
    @param msg  the warning message handed to plpy.warning when condition is falsy
    """
    if condition:
        # Nothing to warn about
        return
    plpy.warning(msg)
| |
| |
def get_seg_number():
    """ Find out how many primary segments exist in the distribution
        Might be useful for partitioning data.

    Returns:
        The constant 1 when the file is preprocessed for a PostgreSQL build;
        otherwise the 'count' column of the first row returned by the
        gp_segment_configuration query below.
    """
    # The m4 conditional below keeps exactly one of the two return statements
    # when this file is preprocessed: the constant for a PostgreSQL build and
    # the catalog query for every other build.
    m4_ifdef(<!__POSTGRESQL__!>, <!return 1!>, <!
    return plpy.execute(
        """
        SELECT count(*) from gp_segment_configuration
        WHERE content >= 0 and role = 'p'
        """)[0]['count']
    !>)
| # ------------------------------------------------------------ |
| |
| |
def unique_string(**kwargs):
    """
    Generate a random temporary name for temp tables and other objects.
    It has a SQL interface so both SQL and Python functions can call it.

    Any keyword arguments are accepted and ignored.

    Returns:
        str. A name of the form '__madlib_temp_<r1>_<r2>_<r3>__', where r1 is
        a random integer, r2 is the current epoch time in seconds and r3 is
        the epoch time modulo a second random integer.
    """
    random_part = random.randint(1, 100000000)
    epoch_part = int(time.time())
    modulo_part = int(time.time()) % random.randint(1, 100000000)
    pieces = [str(random_part), str(epoch_part), str(modulo_part)]
    return "__madlib_temp_" + "_".join(pieces) + "__"
| # ======================================================================== |
| |
| |
def add_postfix(quoted_string, postfix):
    """ Append a suffix to a (possibly double-quoted) database name.

    Surrounding whitespace is stripped first. When the name is wrapped in
    double quotes the suffix is inserted before the closing quote so that it
    stays inside the quoted identifier.

    Arguments:
        @param quoted_string: str. A string representing a database quoted string
        @param postfix: str. A string to add as a suffix to quoted_string.
                        ** This is assumed to not contain any quotes **

    Returns:
        str. The name with the suffix applied.
    """
    name = quoted_string.strip()
    is_quoted = name.startswith('"') and name.endswith('"')
    if not is_quoted:
        return name + postfix
    # Drop the closing quote, append the suffix, then close the quote again
    return name[:-1] + postfix + '"'
| # ------------------------------------------------------------------------- |
| |
| |
def is_orca():
    """ Report whether the 'optimizer' setting of the current session is 'on'.

    Returns:
        Boolean. In builds where the m4 condition below holds, True only if
        "show optimizer" returns 'on'; in every other build always False.
    """
    # NOTE(review): the name suggests this detects the GPORCA optimizer;
    # the code itself only inspects the 'optimizer' setting - confirm.
    m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!
    optimizer = plpy.execute("show optimizer")[0]["optimizer"]
    return True if optimizer == 'on' else False
    !>, <!
    return False
    !>)
| # ------------------------------------------------------------------------- |
| |
| |
def is_psql_numeric_type(arg, exclude=None):
    """
    Check whether a type name is one of the PostgreSQL numeric types.

    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, Type names that must not be treated as
                        numeric even though they normally would be

    Returns:
        Boolean. True if 'arg' is a numeric type name that is not excluded
    """
    excluded = set() if exclude is None else set(exclude)
    numeric_types = frozenset((
        'smallint', 'integer', 'bigint', 'decimal', 'numeric',
        'real', 'double precision', 'serial', 'bigserial'))
    return arg in (numeric_types - excluded)
| # ------------------------------------------------------------------------- |
| |
| |
def _string_to_array(s):
    """
    Split a comma/whitespace separated string into a list of strings.
    Whitespace around the elements is dropped and double quotes at either
    end of an element are removed.

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.
    """
    # An element is either a double-quoted run (which may contain \") or a
    # run of characters that are not quotes, commas or whitespace
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1).strip("\"")
            for found in re.finditer(element_pattern, s)]
| |
| |
| # ======================================================================== |
| |
def _string_to_array_with_quotes(s):
    """
    Split a comma/whitespace separated string into a list of strings,
    exactly like _string_to_array, except that double quotes surrounding an
    element are kept in the result.
    """
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1) for found in re.finditer(element_pattern, s)]
| |
| # ======================================================================== |
| |
| |
def py_list_to_sql_string(array, array_type=None):
    """Convert a Python list to a SQL array expression with a type cast.

    Args:
        @param array: list, Elements to render; each is converted with str()
        @param array_type: str, SQL element or array type for the cast.
                           '[]' is appended when missing. Defaults to
                           'double precision[]' when not given.

    Returns:
        str. "ARRAY[ ... ]::type" for most types, or the quoted literal form
        "'{ ... }'::type" when the type name contains text, varchar or
        character varying. An empty or missing list always yields
        "'{ }'::type".
    """
    text_type_markers = ("text", "varchar", "character varying")
    if array_type:
        is_text_type = any(marker in array_type for marker in text_type_markers)
        sql_type = array_type.strip()
        if not sql_type.endswith("[]"):
            sql_type = sql_type + "[]"
    else:
        is_text_type = False
        sql_type = "double precision[]"

    if not array:
        return "'{{ }}'::{0}".format(sql_type)

    joined_values = ','.join(str(value) for value in array)
    if is_text_type:
        template = "'{{ {0} }}'::{1}"
    else:
        template = "ARRAY[ {0} ]::{1}"
    return template.format(joined_values, sql_type)
| # ======================================================================== |
| |
| |
def _array_to_string(origin):
    """
    Render an iterable as a brace-delimited, comma-separated string,
    e.g. [1, 2, 3] becomes "{1,2,3}". Each element is converted with str().
    """
    body = ",".join(str(item) for item in origin)
    return "{{{0}}}".format(body)
| # ======================================================================== |
| |
| |
def set_client_min_messages(new_level):
    """
    Change the client_min_messages setting in Postgres, which controls the
    messages sent to the psql client, and report the previous value.

    Args:
        @param new_level: string, New level to set the client_min_message

    Returns:
        old_msg_level: string, The old client_min_message level before changing it.

    Raise:
        ValueError: if the argument new_level is not a valid message level
    """
    valid_levels = ('debug5', 'debug4', 'debug3', 'debug2',
                    'debug1', 'log', 'notice', 'warning', 'error',
                    'fatal', 'panic')
    # Validate before touching the database; this also guarantees that only
    # a known keyword is interpolated into the SET statement below
    if new_level.lower() not in valid_levels:
        raise ValueError("Not a valid message level sent to client")

    current_setting = plpy.execute(""" SELECT setting
                                       FROM pg_settings
                                       WHERE name='client_min_messages'
                                   """)
    old_msg_level = current_setting[0]['setting']
    plpy.execute("SET client_min_messages TO {0}".format(new_level))
    return old_msg_level
| # ------------------------------------------------------------------------- |
| |
| |
| # Deal with earlier versions of PG or GPDB |
class __mad_version:
    """
    Helper for detecting and comparing the version of the database the code
    runs in (PostgreSQL, Greenplum or HAWQ). All predicates parse the string
    returned by "select version()", which is fetched once at construction.
    """

    def __init__(self):
        # Raw version banner; every method below parses this string
        self.version = plpy.execute("select version()")[0]["version"]

    def select_vecfunc(self):
        """
        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG.

        Returns a callable taking (origin, text=True): a string parser for
        the old versions, a pass-through for everything else.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        else:
            return self.__identity

    def __extract(self, origin, text=True):
        """
        Extract vector elements from a string with {}
        as the brackets

        @param origin: str or None, Vector rendered as '{e1,e2,...}'
        @param text: bool, When exactly False the elements are converted
                     to float; otherwise they are returned as strings

        Returns None when origin is None, else a list of elements.
        """
        if origin is None:
            return None
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            elm = [float(e) for e in elm]
        return elm

    def __identity(self, origin, text=True):
        """Return origin unchanged ('text' is accepted and ignored)."""
        return origin

    def select_vec_return(self):
        """
        Special care is needed if one needs to return
        vector from Python to SQL

        Returns a callable that converts a Python vector into the form the
        running database version accepts.
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        else:
            return self.__identity

    def __condense(self, origin):
        """
        Convert the original vector into a string which some
        old versions of SQL system can recognize
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """
        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions

        @param schema_madlib: str, Name of the MADlib schema

        Returns the (possibly schema-qualified) aggregate name to use.
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        else:
            return "array_agg"

    def is_pg(self):
        """Return True if the banner names PostgreSQL but not Greenplum Database."""
        if (re.search(r"PostgreSQL", self.version) and
                not re.search(r"Greenplum\s*Database", self.version)):
            return True
        return False

    def is_gp43(self):
        """Return True if the banner names Greenplum Database 4.3."""
        if re.search(r"Greenplum\s+Database\s+4\.3", self.version):
            return True
        return False

    def is_less_than_pg90(self):
        """Return True if this is PostgreSQL with a major version below 9."""
        # Capture the whole major number: reading only its first digit would
        # classify PostgreSQL 10 and later ("1" < 9) as older than 9.0
        match = re.search(r'PostgreSQL\s*([0-9]+)', self.version, re.IGNORECASE)
        if match and self.is_pg() and int(match.group(1)) < 9:
            return True
        else:
            return False

    def _gp_major_minor(self):
        """
        Return the Greenplum 'major.minor' version as a float, or None when
        the banner has no 'Greenplum Database X.Y... build' part.
        """
        match = re.search(
            r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build',
            self.version, re.IGNORECASE)
        return float(match.group(1)) if match else None

    def is_less_than_gp41(self):
        """Return True if this is Greenplum with major.minor below 4.1."""
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.1

    def is_less_than_gp42(self):
        """Return True if this is Greenplum with major.minor below 4.2."""
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.2

    def is_hawq(self):
        """Return True if the banner contains a HAWQ version."""
        if re.search(r"HAWQ\s+[0-9.]+", self.version):
            return True
        return False

    @staticmethod
    def _version_parts(version_str):
        """Split a dotted version string into a list of its numeric components."""
        return [int(part) for part in version_str.split('.') if part.isdigit()]

    def _version_less_than(self, product_pattern, compare_version):
        """
        Compare the version captured by product_pattern (group 1) in the
        banner against compare_version, component by component.
        Returns False when the pattern does not match the banner.
        """
        match = re.search(product_pattern, self.version, re.IGNORECASE)
        if not match:
            return False
        return (self._version_parts(match.group(1)) <
                self._version_parts(compare_version))

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
            self.version is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5

        """
        if not self.is_pg():
            return False
        return self._version_less_than(r'PostgreSQL\s*([0-9.]+)',
                                       compare_version)

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
            is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5
        """
        # The HAWQ banner is excluded here; use is_hq_version_less_than for it
        if self.is_hawq():
            return False
        return self._version_less_than(r'Greenplum\s+Database\s*([0-9.]+)',
                                       compare_version)

    def is_hq_version_less_than(self, compare_version):
        """ Return True if self is a HAWQ database and self.version
            is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5
        """
        return self._version_less_than(r'HAWQ\s*([0-9.]+)', compare_version)
| |
| |
def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Split a comma/whitespace separated string into its elements and return
    them in the vector form the running database version accepts.
    Whitespace around the elements is dropped and double quotes at either
    end of an element are removed.

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.

    schema_madlib and any extra keyword arguments are accepted but not used.
    """
    # Older database versions need the result condensed into a '{...}'
    # string; the version wrapper picks the right converter
    to_sql_vector = __mad_version().select_vec_return()

    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    elements = [found.group(1).strip("\"")
                for found in re.finditer(element_pattern, s)]
    return to_sql_vector(elements)
| # ======================================================================== |
| |
| |
def current_user():
    """Return the name of the current database user as reported by SQL current_user."""
    rows = plpy.execute("SELECT current_user")
    first_row = rows[0]
    return first_row['current_user']
| # ======================================================================== |
| |
| |
def madlib_version(schema_madlib):
    """Return the MADlib version number.

    Runs the version() function of the given MADlib schema and returns the
    last space-separated word of the text before the first comma.

    @param schema_madlib: str, Name of the MADlib schema
    """
    query = """
        SELECT {schema_madlib}.version()
        """.format(schema_madlib=schema_madlib)
    version_text = plpy.execute(query)[0]['version']
    leading_clause = version_text.split(',')[0]
    return leading_clause.split(' ')[-1]
| # ======================================================================== |
| |
| |
def preprocess_keyvalue_params(input_params):
    """
    Parse the input_params string and split it into 'key = value' items.

    @param input_params: string, Comma-separated list of parameters
                            The parameter can be any key = value, where
                            key is a string of alphanumeric characters and
                            value is either an array enclosed in (), {} or []
                            or a string, optionally in double quotes, that
                            may contain a decimal dot

    Returns:
        List of strings. Each well-formed 'key=value' item with surrounding
        whitespace removed; anything that does not match is skipped.
    """
    pattern = re.compile(r"""
        (\w+\s*                 # key is any alphanumeric character string
         =                      # key and value are separated by =
         \s*([\(\{\[]           # value can be string or array
                [^\[\]\(\)\{\}]*    # if value is array then accept anything inside array
                [\)\}\]]            # and then match closing braces of array

             |                  # either array (above) or string (below)

             (?P<quote>\"?)[\w\s\-\%]+\.?[\w\s\-\%]*(?P=quote)
                                # if value is string, it can be alphanumeric
                                # character string with a decimal dot,
                                # optionally quoted by double quotes
            )
        )""", re.VERBOSE)
    items = []
    for found in pattern.finditer(input_params):
        # Group 1 is the complete 'key = value' item
        items.append(found.group(1).strip())
    return items
| # ======================================================================== |
| |
| |
def extract_keyvalue_params(input_params,
                            input_param_types,
                            default_values=None):
    """ Extract key value pairs from input parameters or set the default values

    Args:
        @param input_params: string, Format of
                                'key1=value1, key2=value2,...'. The order
            does not matter. If a parameter is missing, then the default
            value is used. If input_params is None or '',
            then all default values are returned.
            This function also validates the values of these parameters.
        @param input_param_types: dict, The type of each allowed parameter
                                        name. Currently supports one of
                                        (int, float, str, list, bool)
        @param default_values: dict, Default values for each allowed parameter.
                               This dictionary is never modified; the result
                               is always a separate dictionary.

    Returns:
        Dict. Dictionary of input parameter values with key as parameter name
        and value as the parameter value

    Throws:
        plpy.error - If a parameter name is empty, 'none' or 'null'.
        KeyError - If an item is not of the form key=value or the parameter
                   name is not present in input_param_types.
        ValueError - If a value cannot be converted to the declared type.
        TypeError - If the declared type is not one of the supported types.
    """
    # Start from a copy of the defaults: assigning parsed values straight
    # into default_values would silently change the caller's dictionary
    # (and leak values between successive calls sharing the same defaults).
    parameter_dict = dict(default_values) if default_values else {}
    if not input_params:
        return parameter_dict

    for s in preprocess_keyvalue_params(input_params):
        items = s.split("=")
        if len(items) != 2:
            raise KeyError("Input parameter list has incorrect format")

        # Names are case-insensitive and may be quoted; values keep their case
        param_name = items[0].strip(" \"").lower()
        param_value = items[1].strip()
        if not param_name or param_name in ('none', 'null'):
            plpy.error("Invalid input param name: {0} ".format(str(param_name)))

        try:
            param_type = input_param_types[param_name]
        except KeyError:
            raise KeyError("Invalid input: {0} is not a valid parameter".
                           format(param_name))
        try:
            if param_type in (int, str, float):
                parameter_dict[param_name] = param_type(param_value)
            elif param_type == list:
                # Arrays may be written with any of (), [] or {}
                parameter_dict[param_name] = param_value.strip('[](){} ').split(',')
            elif param_type == bool:
                # True values are y, yes, t, true, on and 1;
                # False values are n, no, f, false, off and 0.
                # strtobool raises ValueError for anything else.
                parameter_dict[param_name] = bool(strtobool(param_value))
            else:
                raise TypeError("Invalid input: {0} has unsupported type".
                                format(param_name))
        except ValueError:
            raise ValueError("Invalid input: {0} must be {1}".
                             format(param_name, str(param_type)))
    return parameter_dict
| # ------------------------------------------------------------------------- |
| |
| |
def split_quoted_delimited_str(input_str, delimiter=',', quote=None):
    """ Parse a delimited-string to return a list of individual tokens taking
    quotes into account.

    Args:
        @param input_str: str, Delimited input string
        @param delimiter: str, The field delimiter character that separates
                                the tokens in input_str. Default = ','
                                (if more than one character is given, each
                                character acts as a delimiter)
        @param quote: str, One-character string used to quote fields containing
                            special characters, such as the field delimiter
                            Default = '"'

    Returns:
        List. List of delimited strings, each stripped of surrounding
        whitespace. Quote characters are kept in the tokens. An empty list
        is returned if input_str or delimiter is empty.

    Raises:
        ValueError - If the input cannot be split (e.g. it is not a string).
    """
    if not quote:
        quote = '"'

    if not input_str or not delimiter:
        return []
    try:
        # Escape both characters so they are always taken literally; a plain
        # backslash prefix would turn delimiters such as 'd', 's' or 'w'
        # into the regex classes \d, \s and \w.
        d = re.escape(delimiter)
        q = re.escape(quote)
        # A token is a run of non-delimiter, non-quote characters and/or
        # complete quoted sections (which may contain the delimiter)
        token_reg = re.compile(
            r'((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.format(d=d, q=q))
        # split() on a pattern with a capturing group alternates
        # separators and captured tokens; the tokens sit at odd indices
        return [i.strip() for i in
                token_reg.split(input_str.strip())[1::2]]
    except Exception as e:
        plpy.warning(str(e))
        raise ValueError("Invalid string input for splitting")
| # ------------------------------------------------------------------------------ |
| |
| |
def strip_end_quotes(input_str, quote='"'):
    """ Drop one character from each end of the string when it both starts
    and ends with the quote character. Whitespace at start and end is not
    ignored, so a padded quoted string is returned unchanged.

    Args:
        @param input_str: Value to process; anything that is empty or not a
                          str is returned as-is
        @param quote: str, Quote character to look for. Default = '"'

    Returns:
        str. Original string without the quotes at start and end
    """
    if not input_str or not quote or not isinstance(input_str, str):
        return input_str
    quoted_at_both_ends = (input_str.startswith(quote) and
                           input_str.endswith(quote))
    return input_str[1:-1] if quoted_at_both_ends else input_str
| # ------------------------------------------------------------------------------ |
| |
| |
| import unittest |
| |
| |
class UtilitiesTestCase(unittest.TestCase):
    """
    Unit tests for the key=value parameter parsing and the quoted-string
    splitting helpers of this module.

    To run them outside the database, comment "import plpy" and replace
    plpy.error calls with appropriate Python Exceptions.
    """
    def setUp(self):
        # Well-formed input: int, quoted string and float in exponent notation
        self.optimizer_params1 = 'max_iter=10, optimizer="irls", precision=1e-4'
        # max_iter is not a valid int here, so extraction is expected to fail
        self.optimizer_params2 = 'max_iter=2.01, optimizer=newton-irls, precision=1e-5'
        # Contains a stray value ('10') and a key without a value ('optimizer=');
        # both are expected to be skipped by the preprocessing
        self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,2,3,4}'
        # 'precision' has two decimal dots; only '0.02' is expected to be kept
        self.optimizer_params4 = 'max_iter=10, optimizer="irls", \
                                precision=0.02.01, lambda={1,2,3,4}'
        # Declared type for each parameter name accepted by the tests
        self.optimizer_types = {'max_iter': int, 'optimizer': str,
                                'lambda': list, 'precision': float}

    def tearDown(self):
        pass

    def test_preprocess_optimizer(self):
        # Checks the raw 'key=value' items produced for each fixture
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params1),
                         ['max_iter=10', 'optimizer="irls"', 'precision=1e-4'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params2),
                         ['max_iter=2.01', 'optimizer=newton-irls', 'precision=1e-5'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params3),
                         ['max_iter=10', 'lambda={1,2,3,4}'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params4),
                         ['max_iter=10', 'optimizer="irls"', 'precision=0.02', 'lambda={1,2,3,4}'])

    def test_extract_optimizers(self):
        # Checks the typed dictionaries built from the fixtures
        self.assertEqual({'max_iter': 10, 'optimizer': "irls", 'precision': 0.0001},
                         extract_keyvalue_params(self.optimizer_params1, self.optimizer_types))
        self.assertEqual({'max_iter': 10, 'lambda': ['1', '2', '3', '4']},
                         extract_keyvalue_params(self.optimizer_params3, self.optimizer_types))
        self.assertEqual({'max_iter': 10, 'optimizer': "irls", 'precision': 0.02,
                          'lambda': ['1', '2', '3', '4']},
                         extract_keyvalue_params(self.optimizer_params4, self.optimizer_types))
        # 'max_iter=2.01' cannot be converted to int
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params2, self.optimizer_types)

    def test_split_delimited_string(self):
        # Quoted sections keep their quote characters and may contain the delimiter
        self.assertEqual(['a', 'b', 'c'], split_quoted_delimited_str('a, b, c', quote='|'))
        self.assertEqual(['a', '|b, c|'], split_quoted_delimited_str('a, |b, c|', quote='|'))
        self.assertEqual(['a', '"b, c"'], split_quoted_delimited_str('a, "b, c"'))
        self.assertEqual(['"a^5,6"', 'b', 'c'], split_quoted_delimited_str('"a^5,6", b, c', quote='"'))
        self.assertEqual(['"A""^5,6"', 'b', 'c'], split_quoted_delimited_str('"A""^5,6", b, c', quote='"'))
| |
| |
# Run the unit tests above when this file is executed directly as a script
if __name__ == '__main__':
    unittest.main()