| |
| import collections |
| import re |
| import time |
| import random |
| from distutils.util import strtobool |
| |
| if __name__ != "__main__": |
| from validate_args import _get_table_schema_names |
| from validate_args import get_first_schema |
| import plpy |
| |
| |
| m4_changequote(`<!', `!>') |
| |
| |
def get_seg_number():
    """ Find out how many primary segments exist in the distribution
        Might be useful for partitioning data.

        The body is chosen when this file is preprocessed: if the PostgreSQL
        platform macro is defined only "return 1" is kept, otherwise the
        query below is kept.

        Returns:
            1 when the PostgreSQL platform macro is defined; otherwise the
            'count' field of the first row of a count(*) over the
            gp_segment_configuration rows whose role is 'p'.
    """
    m4_ifdef(<!__POSTGRESQL__!>, <!return 1!>, <!
    return plpy.execute(
        """
        SELECT count(*) from gp_segment_configuration
        WHERE role = 'p'
        """)[0]['count']
    !>)
| # ------------------------------------------------------------------------------ |
| |
| |
def is_orca():
    """ Report whether the database setting 'optimizer' is 'on'.

        The body is chosen when this file is preprocessed: if the
        function-properties macro is defined the setting is read with
        "show optimizer", otherwise the body is simply "return False".

        NOTE(review): presumably optimizer = on means the ORCA query
        optimizer is in use -- confirm against the platform documentation.

        Returns:
            Boolean.
    """
    m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!
    optimizer = plpy.execute("show optimizer")[0]["optimizer"]
    return True if optimizer == 'on' else False
    !>, <!
    return False
    !>)
| # ------------------------------------------------------------------------------ |
| |
| |
def is_platform_pg():
    """ Return True when the PostgreSQL platform macro is defined, else False.

        The literal is substituted when this file is preprocessed, so no
        query is executed at run time.
    """
    return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)
| # ------------------------------------------------------------------------------ |
| |
| |
def _assert(condition, msg):
    """
    @brief Report a database error unless the given condition holds
    @param condition value tested for truthiness; nothing happens when truthy
    @param msg text passed to plpy.error when the condition is falsy
    """
    if condition:
        return
    plpy.error(msg)
| # ------------------------------------------------------------------------------ |
| |
| |
def warn(condition, msg):
    """
    @brief Emit a database warning unless the given condition holds
    @param condition value tested for truthiness; nothing happens when truthy
    @param msg text passed to plpy.warning when the condition is falsy
    """
    if condition:
        return
    plpy.warning(msg)
| # ------------------------------------------------------------------------------ |
| |
| |
def get_distribution_policy(source_table):
    """ Look up the distribution-key columns of source_table
    Args:
        @param source_table: str, name of the table to inspect. The table name
                             comes from _get_table_schema_names and the schema
                             name from get_first_schema.

    Returns:
        The 'dist_attr' value of the first result row: the array_agg of the
        attribute names joined to the table's gp_distribution_policy entry.
    """
    # Resolve the two identifiers in the same order as before: table, then schema
    name_parts = _get_table_schema_names(source_table)
    table_name = name_parts[1]
    schema_name = get_first_schema(source_table)

    policy_query = """
        SELECT array_agg(pga.attname) as dist_attr
        FROM (
            SELECT gdp.localoid,
                CASE
                    WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
                        UNNEST(gdp.attrnums)
                    ELSE NULL
                END AS attnum
            FROM gp_distribution_policy gdp
        ) AS distkey
        INNER JOIN pg_class AS pgc
        ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
        INNER JOIN pg_namespace pgn
        ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
        LEFT OUTER JOIN pg_attribute pga
        ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
        """.format(table_name=table_name, schema_name=schema_name)

    first_row = plpy.execute(policy_query)[0]
    return first_row["dist_attr"]
| # ------------------------------------------------------------------------------ |
| |
| |
def num_features(source_table, independent_varname):
    """ Return array_upper(independent_varname, 1) taken from one row of
    source_table (the query uses LIMIT 1).

    Args:
        @param source_table: str, table to read from
        @param independent_varname: str, array-valued column or expression
    """
    dim_query = ("SELECT array_upper({0}, 1) AS dim "
                 "FROM {1} LIMIT 1").format(independent_varname, source_table)
    first_row = plpy.execute(dim_query)[0]
    return first_row['dim']
| # ------------------------------------------------------------------------------ |
| |
| |
def num_samples(source_table):
    """ Return the row count (count(*)) of source_table.

    Args:
        @param source_table: str, table to count
    """
    count_query = "SELECT count(*) AS n FROM {0}".format(source_table)
    first_row = plpy.execute(count_query)[0]
    return first_row['n']
| # ------------------------------------------------------------------------------ |
| |
| |
def unique_string(desp='', **kwargs):
    """
    Generate a random temporary name for temp tables and other objects.
    It has a SQL interface so both SQL and Python functions can call it.

    Args:
        @param desp: str, text inserted right after the '__madlib_temp_' prefix
        @param kwargs: accepted and ignored

    Returns:
        str of the form '__madlib_temp_<desp><r1>_<r2>_<r3>__' where r1 is a
        random integer, r2 the current epoch second and r3 the epoch second
        modulo another random integer.
    """
    random_part = random.randint(1, 100000000)
    clock_part = int(time.time())
    folded_part = int(time.time()) % random.randint(1, 100000000)
    pieces = ["__madlib_temp_", desp,
              str(random_part), "_", str(clock_part), "_", str(folded_part),
              "__"]
    return "".join(pieces)
| # ------------------------------------------------------------------------------ |
| |
| |
def add_postfix(quoted_string, postfix):
    """ Attach a suffix to a (possibly double-quoted) database name.
    Surrounding whitespace is stripped first. When the name starts and ends
    with a double quote the suffix is placed before the closing quote so that
    it stays inside the quoted identifier.

    Arguments:
        @param quoted_string: str. A string representing a database quoted string
        @param postfix: str. A string to add as a suffix to quoted_string.
                        ** This is assumed to not contain any quotes **

    Returns:
        str. The name with the suffix attached.
    """
    name = quoted_string.strip()
    is_quoted = name.startswith('"') and name.endswith('"')
    if not is_quoted:
        return name + postfix
    # drop the closing quote, add the suffix, then close the quote again
    return name[:-1] + postfix + '"'
| # ------------------------------------------------------------------------- |
| |
| |
def is_psql_numeric_type(arg, exclude=None):
    """
    Tell whether a type name is one of the PostgreSQL numeric type names
    Args:
        @param arg: string, Type name to check (compared verbatim)
        @param exclude: iterable, Type names that must not count as numeric

    Returns:
        Boolean. True if 'arg' is a numeric type name that is not excluded
    """
    excluded = set() if exclude is None else set(exclude)
    numeric_types = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
                         'real', 'double precision', 'serial', 'bigserial'])
    return arg in numeric_types and arg not in excluded
| # ------------------------------------------------------------------------- |
| |
| |
def _string_to_array(s):
    """
    Break a comma/whitespace separated string into a list of strings,
    dropping the double quotes found at the ends of each element.

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.
    """
    # an element is either a double-quoted run (\" allowed inside) or a run
    # of characters that are not quotes, commas or whitespace
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1).strip("\"")
            for found in re.finditer(element_pattern, s)]
| # ------------------------------------------------------------------------ |
| |
| |
def _string_to_array_with_quotes(s):
    """
    Break a comma/whitespace separated string into a list of strings in the
    same way as _string_to_array, but leave the double quotes around quoted
    elements in place.
    """
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1) for found in re.finditer(element_pattern, s)]
| # ------------------------------------------------------------------------ |
| |
| |
def py_list_to_sql_string(array, array_type=None):
    """Render a Python list as a typed SQL array expression

    Args:
        @param array: list, elements are converted with str()
        @param array_type: str, SQL type of the array; '[]' is appended when
                           missing. Defaults to 'double precision[]'.

    Returns:
        str. "'{ }'::<type>" for an empty list, the quoted literal form
        "'{ a,b }'::<type>" when the type mentions text/varchar/character
        varying, and "ARRAY[ a,b ]::<type>" otherwise.
    """
    text_markers = ("text", "varchar", "character varying")
    is_text_type = bool(array_type) and any(marker in array_type
                                            for marker in text_markers)
    if array_type:
        sql_type = array_type.strip()
        if not sql_type.endswith("[]"):
            sql_type += "[]"
    else:
        sql_type = "double precision[]"

    if not array:
        return "'{{ }}'::{0}".format(sql_type)

    joined = ','.join(str(item) for item in array)
    if is_text_type:
        template = "'{{ {0} }}'::{1}"
    else:
        template = "ARRAY[ {0} ]::{1}"
    return template.format(joined, sql_type)
| # ------------------------------------------------------------------------ |
| |
| |
def _array_to_string(origin):
    """
    Render an iterable as '{e1,e2,...}', converting each element with str()
    """
    body = ",".join([str(item) for item in origin])
    return "{%s}" % body
| # ------------------------------------------------------------------------ |
| |
| |
def _cast_if_null(input, alias=''):
    """
    Return str(input) for a truthy input; for a falsy input return the SQL
    text 'NULL::text', followed by ' as <alias>' when an alias is given.
    """
    if not input:
        typed_null = "NULL::text"
        if alias:
            return typed_null + " as " + alias
        return typed_null
    return str(input)
| # ------------------------------------------------------------------------ |
| |
| |
def set_client_min_messages(new_level):
    """
    Change the client_min_messages setting in Postgres, which controls the
    messages sent to the psql client, and report the previous value.

    Args:
        @param new_level: string, New level for client_min_messages. The
                          check against the allowed levels ignores case.

    Returns:
        old_msg_level: string, The client_min_messages level in effect
                       before the change.

    Raise:
        ValueError: if the argument new_level is not a valid message level
    """
    valid_levels = ('debug5', 'debug4', 'debug3', 'debug2',
                    'debug1', 'log', 'notice', 'warning', 'error',
                    'fatal', 'panic')
    if new_level.lower() not in valid_levels:
        raise ValueError("Not a valid message level sent to client")

    # remember the current level before overwriting it
    settings_row = plpy.execute(""" SELECT setting
                                    FROM pg_settings
                                    WHERE name='client_min_messages'
                                """)[0]
    previous_level = settings_row['setting']
    plpy.execute("SET client_min_messages TO {0}".format(new_level))
    return previous_level
| # ------------------------------------------------------------------------- |
| |
| |
| # Deal with earlier versions of PG or GPDB |
class __mad_version:
    """
    Wrapper around the string returned by "select version()", used to choose
    version-specific behavior for PostgreSQL, Greenplum and HAWQ.
    """
    def __init__(self):
        # every check below parses this cached version banner
        self.version = plpy.execute("select version()")[0]["version"]

    def select_vecfunc(self):
        """
        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG.

        Returns a callable(origin, text=True) that turns a value received
        from SQL into a Python list on those old versions and is the
        identity otherwise.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        return self.__identity

    def __extract(self, origin, text=True):
        """
        Extract vector elements from a string with {} as the brackets.
        Elements are converted to float only when text is exactly False.
        None is passed through unchanged.
        """
        if origin is None:
            return None
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            elm = [float(item) for item in elm]
        return elm

    def __identity(self, origin, text=True):
        """ Return origin untouched; text is accepted and ignored. """
        return origin

    def select_vec_return(self):
        """
        Special care is needed if one needs to return
        vector from Python to SQL.

        Returns a callable that converts a Python list into a '{...}' string
        on the old versions and is the identity otherwise.
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        return self.__identity

    def __condense(self, origin):
        """
        Convert the original vector into a string which some
        old versions of SQL system can recognize
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """
        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions

        @param schema_madlib: str, schema that holds the MADlib array_agg
        Returns the (possibly schema-qualified) aggregate name as a str.
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        return "array_agg"

    def is_pg(self):
        """ True if the banner mentions PostgreSQL but not Greenplum Database. """
        return bool(re.search(r"PostgreSQL", self.version) and
                    not re.search(r"Greenplum\s*Database", self.version))

    def is_gp43(self):
        """ True if the banner reports Greenplum Database 4.3. """
        return bool(re.search(r"Greenplum\s+Database\s+4\.3", self.version))

    def is_less_than_pg90(self):
        """
        True if this is PostgreSQL with a major version below 9.

        The whole leading number is parsed, so two-digit majors such as
        PostgreSQL 10 and later are not mistaken for version 1.
        """
        found = re.search(r'PostgreSQL\s*([0-9]+)', self.version, re.IGNORECASE)
        return bool(found) and self.is_pg() and int(found.group(1)) < 9

    def _gp_major_minor(self):
        """
        Return the 'X.Y' prefix of the Greenplum version as a float, or None
        if the banner does not match 'Greenplum Database X.Y... build'.
        """
        found = re.search(
            r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build',
            self.version, re.IGNORECASE)
        if found is None:
            return None
        return float(found.group(1))

    def is_less_than_gp41(self):
        """ True if the banner reports a Greenplum version below 4.1. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.1

    def is_less_than_gp42(self):
        """ True if the banner reports a Greenplum version below 4.2. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.2

    def is_hawq(self):
        """ True if the banner contains 'HAWQ' followed by a version number. """
        return bool(re.search(r"HAWQ\s+[0-9.]+", self.version))

    def _version_key(self, version_str):
        """ Turn '9.3.5' into [9, 3, 5]; non-numeric components are dropped. """
        return [int(part) for part in version_str.split('.') if part.isdigit()]

    def _is_version_less_than(self, product_pattern, compare_version):
        """
        Compare the version captured by product_pattern (group 1, matched
        case-insensitively against the banner) with compare_version,
        component by component. Returns False if the pattern does not match.
        """
        found = re.search(product_pattern, self.version, re.IGNORECASE)
        if found is None:
            return False
        return (self._version_key(found.group(1)) <
                self._version_key(compare_version))

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
            self.version is less than compare_version

            @param compare_version: str, String form of the comparison version. Expected
                                    format is Semantic Versioning.
                                    examples of versions: 1.0, 2.3, 9.3.5

        """
        if not self.is_pg():
            return False
        return self._is_version_less_than(r'PostgreSQL\s*([0-9.]+)',
                                          compare_version)

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
            is less than compare_version

            @param compare_version: str, String form of the comparison version. Expected
                                    format is Semantic Versioning.
                                    examples of versions: 1.0, 2.3, 9.3.5
        """
        # banners that report HAWQ are excluded from the Greenplum check
        if self.is_hawq():
            return False
        return self._is_version_less_than(r'Greenplum\s+Database\s*([0-9.]+)',
                                          compare_version)

    def is_hq_version_less_than(self, compare_version):
        """ Return True if self is a HAWQ database and self.version
            is less than compare_version

            @param compare_version: str, String form of the comparison version.
                                    Expected format is Semantic Versioning.
                                    examples of versions: 1.0, 2.3, 9.3.5
        """
        return self._is_version_less_than(r'HAWQ\s*([0-9.]+)', compare_version)
| |
| |
def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Split a comma/whitespace separated string into its elements (quotes at
    the ends of each element removed) and hand the list to the
    version-specific return converter: a '{...}' string on the old
    GPDB / PG versions, the list itself otherwise.

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.

    schema_madlib and any extra keyword arguments are accepted and ignored.
    """
    # pick the converter first; older GPDB / PG need arrays passed as strings
    to_sql_value = __mad_version().select_vec_return()

    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    names = [found.group(1).strip("\"")
             for found in re.finditer(element_pattern, s)]
    return to_sql_value(names)
| # ------------------------------------------------------------------------ |
| |
| |
def current_user():
    """Return the value of SQL current_user for this session."""
    rows = plpy.execute("SELECT current_user")
    return rows[0]['current_user']
| # ------------------------------------------------------------------------ |
| |
| |
def madlib_version(schema_madlib):
    """Return the MADlib version number.

    Calls <schema_madlib>.version() and returns the last space-separated
    word of the text that precedes the first comma of its result.
    """
    version_query = """
        SELECT {schema_madlib}.version()
    """.format(schema_madlib=schema_madlib)
    banner = plpy.execute(version_query)[0]['version']
    leading_clause = banner.split(',')[0]
    return leading_clause.split(' ')[-1]
| # ------------------------------------------------------------------------ |
| |
| |
def preprocess_keyvalue_params(input_params, split_char='='):
    """
    Parse the input_params string and return the list of 'key<split_char>value'
    fragments found in it. Text that does not form a complete pair is skipped.

    @param input_params: str, Comma-separated list of parameters
                                The parameter can be any key = value, where
                                key is a string of alphanumeric characters
                                    (hyphen and colon are also accepted)
                                value is either an array enclosed in (), {}
                                    or [] without nested brackets, or a string
                                    of alphanumeric characters, whitespace,
                                    '.', '-' and '%', optionally enclosed in
                                    double quotes
    @param split_char: str, character that splits the key and value elements.
                                Default set to '='. It is matched literally,
                                even if it is a regular-expression
                                metacharacter.

    Returns:
        List of str. Each matched pair with surrounding whitespace removed;
        an empty list if input_params is None or empty.
    """
    if not input_params:
        return []
    re_str = (r"([-:\w]+\s*" +     # key is any alphanumeric character
                                   # (including - and :) string

              # key and value are separated by split_char; escape it so that
              # characters such as '|' or '.' are not read as regex syntax
              # (and whitespace or '#' survive the VERBOSE flag)
              re.escape(split_char) +

              r"""
                \s*([\(\{\[]        # value can be string or array
                [^\[\]\(\)\{\}]*    # if value is array then accept anything inside
                [\)\}\]]            # and then match closing braces of array

                |                   # either array (above) or string (below)

                (?P<quote>\"?)[\w\s\-\%.]+(?P=quote)
                                    # if value is string, it can be alphanumeric
                                    # character string with a decimal dot,
                                    # hyphen, or percent
                                    # optionally quoted by `quote_char`
                )
            )"""
              )
    pattern = re.compile(re_str, re.VERBOSE)
    return [m.group(1).strip() for m in pattern.finditer(input_params)]
| # ------------------------------------------------------------------------ |
| |
| |
def extract_keyvalue_params(input_params,
                            input_param_types=None,
                            default_values=None,
                            split_char='=',
                            usage_str='',
                            ignore_invalid=False,
                            allow_duplicates=True,
                            lower_case_names=True):
    """ Extract key value pairs from input parameters or set the default values

    Args:
        @param input_params: string, Format of
            'key1=value1, key2=value2,...', assuming default split_char.
            The order does not matter. If a parameter is missing, then
            the default value is used. If input_params is None or '',
            then all default values are returned. This function also
            validates the values of these parameters.

        @param input_param_types: dict, The type of each allowed parameter
                                        name. Currently supports one of
                                        (bool, int, float, str) or an iterable
                                        type such as list. If not given, every
                                        value is returned as a string.
        @param default_values: dict, Default values for each allowed parameter.
                                     The dict itself is not modified; the
                                     result is a new dict.
        @param split_char: str, The character used to split key and value.
                                Default set to '='
        @param usage_str: str, An optional usage string to print with error message.
        @param ignore_invalid: bool, If True a parameter name missing from
                                     input_param_types is skipped silently
        @param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
                                        where the last occurrence will be reflected
                                        in the output. If False, then a ValueError is
                                        raised.
        @param lower_case_names: bool, Convert parameter names to lower case.

    Returns:
        Dict. Dictionary of input parameter values with key as parameter name
        and value as the parameter value

    Throws:
        KeyError - If a key/value pair is malformed, or the parameter name is
                   not in input_param_types and ignore_invalid is False.
        ValueError - If a parameter is duplicated while allow_duplicates is
                     False, or a value cannot be converted to its type.
        TypeError - If the declared type of a parameter is not supported.
        plpy.error - If a parameter name is empty, 'none' or 'null'.
    """
    # Iterable lives in collections.abc on Python 3 (it was removed from
    # collections in 3.10) and in collections on Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    # Start from a copy of the defaults so the caller's dict is never mutated
    parameter_dict = dict(default_values) if default_values else {}
    if not input_params:
        return parameter_dict
    seen_params = set()

    for s in preprocess_keyvalue_params(input_params, split_char=split_char):
        items = split_quoted_delimited_str(s, delimiter=split_char)
        if len(items) != 2:
            raise KeyError("Input parameter list has incorrect format "
                           "{0}".format(usage_str))

        param_name = items[0].strip(" \"")
        if lower_case_names:
            param_name = param_name.lower()
        param_value = items[1].strip()

        if not allow_duplicates and param_name in seen_params:
            raise ValueError("Invalid input: {0} duplicated in the param list".
                             format(param_name))

        if not param_name or param_name in ('none', 'null'):
            plpy.error("Invalid input param name: {0} \n"
                       "{1}".format(param_name, usage_str))
        if input_param_types:
            try:
                param_type = input_param_types[param_name]
            except KeyError:
                if not ignore_invalid:
                    raise KeyError("Invalid input: {0} is not a valid parameter "
                                   "{1}".format(param_name, usage_str))
                else:
                    continue
            try:
                if param_type == bool:  # bool is not subclassable
                    # True values are y, yes, t, true, on and 1;
                    # False values are n, no, f, false, off and 0.
                    # Raises ValueError if anything else.
                    parameter_dict[param_name] = bool(strtobool(param_value))
                elif param_type in (int, str, float):
                    parameter_dict[param_name] = param_type(param_value)
                elif issubclass(param_type, Iterable):
                    # drop the enclosing brackets, then split on commas
                    parameter_dict[param_name] = split_quoted_delimited_str(
                        param_value.strip('[](){} '))
                else:
                    raise TypeError("Invalid input: {0} has unsupported type "
                                    "{1}".format(param_name, usage_str))
            except ValueError:
                raise ValueError("Invalid input: {0} must be {1} \n"
                                 "{2}".format(param_name, param_type, usage_str))
        else:
            # if input parameter types not provided then just return all as string
            parameter_dict[param_name] = str(param_value)
        seen_params.add(param_name)
    return parameter_dict
| # ------------------------------------------------------------------------- |
| |
| |
def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
    """ Parse a delimited-string to return a list of individual tokens taking
    quotes into account.

    Args:
        @param input_str: str, Delimited input string
        @param delimiter: str, The field delimiter character that separates
                                the tokens in input_str. Default = ','
        @param quote: str, One-character string used to quote fields containing
                            special characters, such as the field delimiter
                            Default = '"'

    Returns:
        List. List of delimited strings, each stripped of surrounding
        whitespace; quoted fields keep their quote characters. An empty list
        is returned if input_str, delimiter or quote is empty.

    Throws:
        ValueError - If the string cannot be split (the underlying error is
                     reported through plpy.warning first).
    """
    if not input_str or not delimiter or not quote:
        return []
    try:
        # re.escape keeps both characters literal: letters are left alone
        # (a blind backslash prefix would turn 'n' into a newline escape) and
        # punctuation such as '|', '^' or ']' is escaped
        d = re.escape(delimiter)
        q = re.escape(quote)
        # a token is a run of non-delimiter, non-quote characters and/or
        # complete quoted sections
        token_reg = re.compile('((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.
                               format(d=d, q=q))
        # splitting on a capturing group puts the tokens at the odd positions
        return [i.strip() for i in
                token_reg.split(input_str.strip())[1::2]]
    except Exception as e:
        plpy.warning(str(e))
        raise ValueError("Invalid string input for splitting")
| # ------------------------------------------------------------------------------ |
| |
| |
def strip_end_quotes(input_str, quote='"'):
    """ Remove the quote string from the start and end if it is present
    at both ends. Whitespace at start and end is not ignored, so a string
    with leading or trailing blanks around the quotes is returned unchanged.

    Args:
        @param input_str: value to strip; anything that is not a non-empty
                          str is returned unchanged
        @param quote: str, the quote marker (may be longer than one character)

    Returns:
        str. Original string without the quotes at start and end
    """
    if not input_str or not quote:
        return input_str
    if not isinstance(input_str, str):
        return input_str
    quote_len = len(quote)
    # the string must be long enough to hold two separate quote markers;
    # otherwise a lone quote character would count as both ends
    if (len(input_str) >= 2 * quote_len and
            input_str.startswith(quote) and input_str.endswith(quote)):
        return input_str[quote_len:-quote_len]
    return input_str
| # ------------------------------------------------------------------------------ |
| |
| |
| import unittest |
| |
| |
class UtilitiesTestCase(unittest.TestCase):
    """
    Unit tests for the key/value parsing helpers of this module.

    Comment "import plpy" and replace plpy.error calls with appropriate
    Python Exceptions to successfully run the test cases
    """
    def setUp(self):
        # parameter strings shared by the tests below
        self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4'
        self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5'
        self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}'
        self.optimizer_params4 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02.01, lambda={1,2,3,4}')
        self.optimizer_params5 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02, PRECISION=2., lambda={1,2,3,4}')
        self.optimizer_types = {'max_iter': int, 'optimizer': str, 'optimizer::text': str,
                                'lambda': list, 'precision': float}

    def test_preprocess_optimizer(self):
        # (raw parameter string, expected key=value fragments)
        cases = [
            (self.optimizer_params1,
             ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4']),
            (self.optimizer_params2,
             ['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5']),
            (self.optimizer_params3,
             ['max_iter=10', 'lambda={1,"2,2",3,4}']),
            (self.optimizer_params4,
             ['max_iter=10', 'optimizer="irls"', 'precision=0.02.01', 'lambda={1,2,3,4}']),
        ]
        for raw_params, expected_pairs in cases:
            self.assertEqual(preprocess_keyvalue_params(raw_params), expected_pairs)

    def test_extract_optimizers(self):
        types = self.optimizer_types

        # typed extraction
        typed_first = extract_keyvalue_params(self.optimizer_params1, types)
        self.assertEqual({'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001},
                         typed_first)
        typed_third = extract_keyvalue_params(self.optimizer_params3, types)
        self.assertEqual({'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']},
                         typed_third)

        # untyped extraction keeps every value as a string
        untyped_fourth = extract_keyvalue_params(self.optimizer_params4)
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01',
                          'lambda': '{1,2,3,4}'},
                         untyped_fourth)
        case_sensitive_fifth = extract_keyvalue_params(self.optimizer_params5,
                                                       allow_duplicates=False,
                                                       lower_case_names=False)
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"',
                          'PRECISION': '2.', 'precision': '0.02',
                          'lambda': '{1,2,3,4}'},
                         case_sensitive_fifth)

        # invalid values and forbidden duplicates must raise ValueError
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params2, types)
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params5, allow_duplicates=False)
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params4, types)

    def test_split_delimited_string(self):
        # (expected tokens, input string, keyword arguments)
        cases = [
            (['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'],
             self.optimizer_params1, {'quote': '"'}),
            (['a', 'b', 'c'], 'a, b, c', {'quote': '|'}),
            (['a', '|b, c|'], 'a, |b, c|', {'quote': '|'}),
            (['a', '"b, c"'], 'a, "b, c"', {}),
            (['"a^5,6"', 'b', 'c'], '"a^5,6", b, c', {'quote': '"'}),
            (['"A""^5,6"', 'b', 'c'], '"A""^5,6", b, c', {'quote': '"'}),
        ]
        for expected_tokens, raw_str, options in cases:
            self.assertEqual(expected_tokens,
                             split_quoted_delimited_str(raw_str, **options))
| |
| |
# Run the unit tests defined above when this file is executed directly as a
# script (the plpy / validate_args imports at the top are skipped in that case).
if __name__ == '__main__':
    unittest.main()