blob: 133f4aca62fc74df11f2a8cf952ae6969e3c8ffd [file] [log] [blame]
import collections
import re
import time
import random
from distutils.util import strtobool
from validate_args import _get_table_schema_names
from validate_args import get_first_schema
from validate_args import cols_in_tbl_valid
from validate_args import explicit_bool_to_text
from validate_args import input_tbl_valid
from validate_args import is_var_valid
from validate_args import output_tbl_valid
import plpy
# Preprocessor directive, not Python: it switches the m4 quote delimiters
# to the two-character pairs used in the macro calls further down.
# NOTE(review): presumably done so that the Python quote characters in this
# file are not interpreted as m4 quotes -- confirm against the build setup.
m4_changequote(`<!', `!>')
def has_function_properties():
    """ __HAS_FUNCTION_PROPERTIES__ variable defined during configure

    Returns:
        Boolean. The value is substituted by the preprocessor: True when
        the symbol above is defined, False otherwise.
    """
    # The line below is a preprocessor conditional, not a Python call; after
    # preprocessing only a literal True or False remains.
    return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
def is_platform_pg():
    """ __POSTGRESQL__ variable defined during configure

    Returns:
        Boolean. The value is substituted by the preprocessor: True when
        the symbol above is defined, False otherwise.
    """
    # The line below is a preprocessor conditional, not a Python call; after
    # preprocessing only a literal True or False remains.
    return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)
# ------------------------------------------------------------------------------
def is_platform_hawq():
    """ __HAWQ__ variable defined during configure

    Returns:
        Boolean. The value is substituted by the preprocessor: True when
        the symbol above is defined, False otherwise.
    """
    # The line below is a preprocessor conditional, not a Python call; after
    # preprocessing only a literal True or False remains.
    return m4_ifdef(<!__HAWQ__!>, <!True!>, <!False!>)
# ------------------------------------------------------------------------------
def get_seg_number():
    """ Count the primary segments that exist in the distribution.

    Might be useful for partitioning data.

    Returns:
        1 on a plain PostgreSQL platform; otherwise the number of rows of
        gp_segment_configuration whose role is 'p'.
    """
    if not is_platform_pg():
        primary_rows = plpy.execute("""
            SELECT count(*) from gp_segment_configuration
            WHERE role = 'p'
            """)
        return primary_rows[0]['count']
    return 1
# ------------------------------------------------------------------------------
def is_orca():
    """ Tell whether the 'optimizer' setting of the session is 'on'.

    The setting is only queried when has_function_properties() is True;
    otherwise the answer is False without touching the database.

    Returns:
        Boolean.
    """
    if not has_function_properties():
        return False
    optimizer_setting = plpy.execute("show optimizer")[0]["optimizer"]
    return optimizer_setting == 'on'
# ------------------------------------------------------------------------------
def _assert_equal(o1, o2, msg):
    """
    @brief report an error through plpy when the two objects compare unequal
    @param o1 the first object
    @param o2 the second object
    @param msg the error message to be reported
    """
    if o1 == o2:
        # Nothing to report.
        return
    plpy.error(msg)
# ------------------------------------------------------------------------------
def _assert(condition, msg):
    """
    @brief report an error through plpy when the condition is falsy
    @param condition the condition to be asserted
    @param msg the error message to be reported
    """
    if condition:
        # Condition holds; nothing to report.
        return
    plpy.error(msg)
# ------------------------------------------------------------------------------
def warn(condition, msg):
    """
    @brief emit a warning through plpy when the condition is falsy
    @param condition the condition to be checked
    @param msg the warning message to be reported
    """
    if condition:
        # Condition holds; stay silent.
        return
    plpy.warning(msg)
# ------------------------------------------------------------------------------
def get_distribution_policy(source_table):
    """ Return the columns that define the distribution policy of source_table

    The table name comes from _get_table_schema_names() and the schema from
    get_first_schema(); both are substituted into a catalog query over
    gp_distribution_policy, pg_class, pg_namespace and pg_attribute.

    Args:
        @param source_table

    Returns:
        List of str (the aggregated 'dist_attr' value of the single
        result row).
    """
    _, table_name = _get_table_schema_names(source_table)
    schema_name = get_first_schema(source_table)
    policy_query = """
        SELECT array_agg(pga.attname) as dist_attr
        FROM (
            SELECT gdp.localoid,
                   CASE
                       WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
                           UNNEST(gdp.attrnums)
                       ELSE NULL
                   END AS attnum
            FROM gp_distribution_policy gdp
        ) AS distkey
        INNER JOIN pg_class AS pgc
            ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
        INNER JOIN pg_namespace pgn
            ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
        LEFT OUTER JOIN pg_attribute pga
            ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
        """.format(table_name=table_name, schema_name=schema_name)
    policy_rows = plpy.execute(policy_query)
    return policy_rows[0]["dist_attr"]
# ------------------------------------------------------------------------------
def num_features(source_table, independent_varname):
    """ Return array_upper(independent_varname, 1) of one row of source_table.

    Args:
        @param source_table: str, Table (or expression) placed in the FROM clause
        @param independent_varname: str, Array expression to measure

    Returns:
        The 'dim' value of the single row fetched (LIMIT 1).
    """
    dim_query = ("SELECT array_upper({0}, 1) AS dim "
                 "FROM {1} LIMIT 1").format(independent_varname, source_table)
    first_row = plpy.execute(dim_query)[0]
    return first_row['dim']
# ------------------------------------------------------------------------------
def num_samples(source_table):
    """ Return the row count of source_table.

    Args:
        @param source_table: str, Table (or expression) placed in the FROM clause

    Returns:
        The 'n' value of the count(*) query.
    """
    count_query = "SELECT count(*) AS n FROM {0}".format(source_table)
    first_row = plpy.execute(count_query)[0]
    return first_row['n']
# ------------------------------------------------------------------------------
def unique_string(desp='', **kwargs):
    """
    Generate random temporary names for temp tables and other objects.
    It has a SQL interface so both SQL and Python functions can call it.

    Args:
        @param desp: str, Text inserted right after the '__madlib_temp_' prefix
        @param kwargs: accepted but not used by this function

    Returns:
        str of the form __madlib_temp_<desp><rand>_<epoch>_<epoch mod rand>__
    """
    rand_part = random.randint(1, 100000000)
    epoch_part = int(time.time())
    folded_part = int(time.time()) % random.randint(1, 100000000)
    pieces = ["__madlib_temp_", desp, str(rand_part), "_", str(epoch_part),
              "_", str(folded_part), "__"]
    return "".join(pieces)
# ------------------------------------------------------------------------------
def add_postfix(quoted_string, postfix):
    """ Append a suffix to a (possibly double-quoted) database name.

    Surrounding whitespace is stripped first. When the name both starts and
    ends with a double quote, the suffix is placed before the closing quote
    so that it stays inside the quoted identifier.

    Arguments:
        @param quoted_string: str. A string representing a database quoted string
        @param postfix: str. A string to add as a suffix to quoted_string.
                        ** This is assumed to not contain any quotes **
    Returns:
        str. The name with the suffix appended.
    """
    trimmed = quoted_string.strip()
    is_double_quoted = trimmed.startswith('"') and trimmed.endswith('"')
    if not is_double_quoted:
        return trimmed + postfix
    # Drop the closing quote, add the suffix, then close the quote again.
    return trimmed[:-1] + postfix + '"'
# -------------------------------------------------------------------------
def is_psql_numeric_type(arg, exclude=None):
    """
    Checks if argument is one of the various numeric types in PostgreSQL

    Args:
        @param arg: string, Type name to check
        @param exclude: iterable, List of types to exclude from checking
    Returns:
        Boolean. True if 'arg' is a numeric type name that is not excluded
    """
    excluded_types = set() if exclude is None else set(exclude)
    numeric_types = frozenset((
        'smallint', 'integer', 'bigint', 'decimal', 'numeric',
        'real', 'double precision', 'serial', 'bigserial'))
    return arg in numeric_types and arg not in excluded_types
# -------------------------------------------------------------------------
def is_string_formatted_as_array_expression(string_to_match):
"""
Return true if the string is formatted as array[<something>], else false
:param string_to_match:
"""
matched = re.match(r"(?i)^array\[(.*)\]", string_to_match)
return matched
def _string_to_array(s):
    """
    Split a string into a list of element strings.

    Elements are separated by commas and/or whitespace; an element that
    contains spaces or commas must be wrapped in a pair of double quotes.
    Double quotes at either end of an element are removed from the result.

    Requirement: every individual element must be a valid Postgres name.
    Usually this is not a problem. Especially in older versions of GPDB,
    when an array is passed from SQL to Python, it is converted to a
    string, which automatically adds double quotes if there are spaces or
    commas in the element. So use this function if you are sure that all
    elements are valid Postgres names.
    """
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1).strip("\"")
            for found in re.finditer(element_pattern, s)]
# ------------------------------------------------------------------------
def _string_to_array_with_quotes(s):
    """
    Split a string into a list of element strings, keeping double quotes.

    Uses the same tokenizing pattern as _string_to_array, but the double
    quotes surrounding an element are left in place.
    """
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [found.group(1) for found in re.finditer(element_pattern, s)]
# ------------------------------------------------------------------------
def py_list_to_sql_string(array, array_type=None, long_format=None):
    """Convert a list to SQL array string

    Args:
        @param array: list, Elements to render (each is passed through str())
        @param array_type: str, SQL element or array type; '[]' is appended
                           when missing. Default: double precision[]
        @param long_format: bool, True renders ARRAY[ ... ], False renders
                            a '{ ... }' literal. When None, the literal form
                            is chosen for text/varchar/character varying
                            types and the ARRAY form otherwise.
    Returns:
        str. SQL expression for the array, including the type cast.
    """
    if long_format is None:
        text_prefixes = ["text", "varchar", "character varying"]
        is_text_type = (array_type is not None and
                        any(array_type.startswith(prefix)
                            for prefix in text_prefixes))
        long_format = not is_text_type

    if array_type:
        array_type = array_type.strip()
        if not array_type.endswith("[]"):
            array_type += "[]"
    else:
        array_type = "double precision[]"

    if not array:
        return "'{{ }}'::{0}".format(array_type)

    if long_format:
        array_str = "ARRAY[ {0} ]"
    else:
        array_str = "'{{ {0} }}'"
    joined_elements = ','.join(str(element) for element in array)
    return (array_str + "::{1}").format(joined_elements, array_type)
# ------------------------------------------------------------------------
def _array_to_string(origin):
    """
    Convert an iterable into a brace-wrapped, comma-separated string.

    Each element is passed through str() and every double quote in it is
    prefixed with a backslash.
    """
    escaped_elements = [re.sub(r'"', r'\"', str(element))
                        for element in origin]
    return "{" + ",".join(escaped_elements) + "}"
# ------------------------------------------------------------------------
def _cast_if_null(input, alias=''):
    """
    Return str(input), or a typed SQL NULL when input is falsy.

    @param input value to render; any falsy value is treated as missing
    @param alias optional column alias appended to the NULL expression
    @return str(input), 'NULL::text' or 'NULL::text as <alias>'
    """
    if not input:
        null_str = "NULL::text"
        if alias:
            return null_str + " as " + alias
        return null_str
    return str(input)
# ------------------------------------------------------------------------
def set_client_min_messages(new_level):
    """
    Set the client_min_messages setting in Postgres which controls the
    messages sent to the psql client.

    Args:
        @param new_level: string, New level to set client_min_messages to
    Returns:
        old_msg_level: string, The level that was in effect before the change.
    Raise:
        ValueError: if the argument new_level is not a valid message level
    """
    valid_levels = ('debug5', 'debug4', 'debug3', 'debug2',
                    'debug1', 'log', 'notice', 'warning', 'error',
                    'fatal', 'panic')
    if new_level.lower() not in valid_levels:
        raise ValueError("Not a valid message level sent to client")

    # Read the current level first so the caller can restore it later.
    setting_rows = plpy.execute(""" SELECT setting
                                    FROM pg_settings
                                    WHERE name='client_min_messages'
                                """)
    old_msg_level = setting_rows[0]['setting']
    plpy.execute("SET client_min_messages TO {0}".format(new_level))
    return old_msg_level
# -------------------------------------------------------------------------
# Deal with earlier versions of PG or GPDB
class __mad_version:
    """
    Version-dependent helpers for earlier versions of PG, GPDB and HAWQ.

    The string returned by "select version()" is read once when the object
    is created; every is_* predicate below only inspects that string.
    """

    def __init__(self):
        self.version = plpy.execute("select version()")[0]["version"]

    def select_vecfunc(self):
        """
        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG.

        Returns:
            Callable taking (origin, text=True). On the older versions it
            parses the brace-wrapped string into a list; otherwise it
            returns its argument unchanged.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        return self.__identity

    def __extract(self, origin, text=True):
        """
        Extract vector elements from a string with {}
        as the brackets

        @param origin: str or None, Brace-wrapped vector string
        @param text: bool, When exactly False the elements are converted
                     to float; otherwise they stay strings
        """
        if origin is None:
            return None
        # NOTE(review): a string that is not wrapped in braces makes
        # re.match return None and this line raise AttributeError.
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            elm = [float(element) for element in elm]
        return elm

    def __identity(self, origin, text=True):
        """ Return origin unchanged ('text' is accepted for signature parity). """
        return origin

    def select_vec_return(self):
        """
        Special care is needed if one needs to return
        vector from Python to SQL

        Returns:
            Callable taking (origin). On the older versions it converts the
            vector into a brace-wrapped string; otherwise it returns its
            argument unchanged.
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        return self.__identity

    def __condense(self, origin):
        """
        Convert the original vector into a string which some
        old versions of SQL system can recognize
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """
        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions

        @param schema_madlib: str, Schema in which MADlib is installed
        @return str, Name of the array_agg function to call
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        return "array_agg"

    # -- helpers ----------------------------------------------------------
    @staticmethod
    def _version_key(version_str):
        """ Turn '9.3.5' into [9.0, 3.0, 5.0]; non-numeric parts are dropped. """
        return [float(part) for part in version_str.split('.')
                if part.isdigit()]

    def _first_version_match(self, pattern):
        """ Return the first group captured by pattern in self.version
        (case-insensitive), or None when the pattern does not occur. """
        found = re.findall(pattern, self.version, re.IGNORECASE)
        return found[0] if found else None

    def _gp_major_minor(self):
        """ Return the Greenplum 'major.minor' version as a float, or None
        when the version string carries no Greenplum build marker. """
        found = self._first_version_match(
            r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build')
        return None if found is None else float(found)

    # -- platform predicates ------------------------------------------------
    def is_pg(self):
        """ True when the version string names PostgreSQL but not Greenplum. """
        return bool(re.search(r"PostgreSQL", self.version) and
                    not re.search(r"Greenplum\s*Database", self.version))

    def is_gp43(self):
        """ True when the version string names Greenplum Database 4.3. """
        return bool(re.search(r"Greenplum\s+Database\s+4\.3", self.version))

    def is_less_than_pg90(self):
        """ True when this is PostgreSQL with a major version below 9. """
        # Only the leading integer matters. Requiring a version suffix after
        # it would let the regex backtrack and split a two-digit major
        # version such as '11devel' into '1' + '1'.
        major = self._first_version_match(r'PostgreSQL\s*([0-9]+)')
        return major is not None and self.is_pg() and int(major) < 9

    def is_less_than_gp41(self):
        """ True when this is Greenplum with a version below 4.1. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.1

    def is_less_than_gp42(self):
        """ True when this is Greenplum with a version below 4.2. """
        gp_version = self._gp_major_minor()
        return gp_version is not None and gp_version < 4.2

    def is_hawq(self):
        """ True when the version string contains a HAWQ version number. """
        return bool(re.search(r"HAWQ\s+[0-9.]+", self.version))

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
            self.version is less than compare_version
            @param compare_version: str, String form of the comparison version. Expected
            format is Semantic Versioning.
            examples of versions: 1.0, 2.3, 9.3.5
        """
        db_version = self._first_version_match(r'PostgreSQL\s*([0-9.]+)')
        if db_version is None or not self.is_pg():
            return False
        return self._version_key(db_version) < self._version_key(compare_version)

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
            is less than compare_version
            @param compare_version: str, String form of the comparison version. Expected
            format is Semantic Versioning.
            examples of versions: 1.0, 2.3, 9.3.5
        """
        db_version = self._first_version_match(
            r'Greenplum\s+Database\s*([0-9.]+)')
        # A HAWQ version string is excluded here even when it also carries
        # a Greenplum version.
        if db_version is None or self.is_hawq():
            return False
        return self._version_key(db_version) < self._version_key(compare_version)

    def is_hq_version_less_than(self, compare_version):
        """ Return True if self is a HAWQ database and self.version
            is less than compare_version
            @param compare_version: str, String form of the comparison version.
            Expected format is Semantic Versioning.
            examples of versions: 1.0, 2.3, 9.3.5
        """
        db_version = self._first_version_match(r'HAWQ\s*([0-9.]+)')
        if db_version is None:
            return False
        return self._version_key(db_version) < self._version_key(compare_version)
def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Split a string into elements and return them in the form that the
    current database version expects for arrays returned from Python.

    Elements are separated by commas and/or whitespace; double quotes
    around an element are removed. Requirement: every individual element
    must be a valid Postgres name, which means that an element containing
    spaces or commas must be wrapped in a pair of double quotes.

    @param schema_madlib not used by this function
    @param s string holding the elements
    @param kwargs accepted but not used by this function
    """
    # The version wrapper picks how the list must be returned to SQL.
    array_to_string = __mad_version().select_vec_return()
    element_pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    elements = [found.group(1).strip("\"")
                for found in re.finditer(element_pattern, s)]
    return array_to_string(elements)
# ------------------------------------------------------------------------
def current_user():
    """Returns the user name of the current database user."""
    user_rows = plpy.execute("SELECT current_user")
    return user_rows[0]['current_user']
# ------------------------------------------------------------------------
def madlib_version(schema_madlib):
    """Returns the MADlib version string.

    The result of <schema_madlib>.version() is cut at the first comma and
    the last space-separated word of that part is returned.
    """
    version_query = """
        SELECT {schema_madlib}.version()
        """.format(schema_madlib=schema_madlib)
    raw = plpy.execute(version_query)[0]['version']
    leading_part = raw.split(',')[0]
    return leading_part.split(' ')[-1]
# ------------------------------------------------------------------------
def preprocess_keyvalue_params(input_params, split_char='='):
    """
    Parse the input_params string and return its 'key<split_char>value'
    fragments.

    @param input_params: str, Comma-separated list of parameters
                         The parameter can be any key = value, where
                         key is a string of alphanumeric characters
                         (including - and :) and value is either an array
                         wrapped in (), {} or [], or an optionally
                         double-quoted string
    @param split_char: str, character that splits the key and value elements.
                       It is matched literally. Default set to '='
    @return list of str, one stripped 'key<split_char>value' fragment per
            well-formed pair; text that does not form a pair is skipped
    """
    re_str = (r"([-:\w]+\s*" +    # key is any alphanumeric character
                                  # (including - and :) string
              # key and value are separated by split_char; escape it so that
              # regex metacharacters (and, in verbose mode, whitespace or #)
              # are matched literally
              re.escape(split_char) +
              r"""
              \s*([\(\{\[]         # value can be string or array
              [^\[\]\(\)\{\}]*     # if value is array then accept anything inside
              [\)\}\]]             # and then match closing braces of array
              |                    # either array (above) or string (below)
              (?P<quote>\"?)[\w\s\-\%.]+(?P=quote)
                                   # if value is string, it can be alphanumeric
                                   # character string with a decimal dot,
                                   # hyphen, or percent
                                   # optionally quoted by `quote_char`
              )
              )"""
              )
    pattern = re.compile(re_str, re.VERBOSE)
    return [m.group(1).strip() for m in pattern.finditer(input_params)]
# ------------------------------------------------------------------------
def extract_keyvalue_params(input_params,
                            input_param_types=None,
                            default_values=None,
                            split_char='=',
                            usage_str='',
                            ignore_invalid=False,
                            allow_duplicates=True,
                            lower_case_names=True):
    """ Extract key value pairs from input parameters or set the default values

    Args:
        @param input_params: string, Format of
            'key1=value1, key2=value2,...', assuming default split_char.
            The order does not matter. If a parameter is missing, then
            the default value is used. If input_params is None or '',
            then all default values are returned. This function also
            validates the values of these parameters.
        @param input_param_types: dict, The type of each allowed parameter
            name. Supported types are bool, int, float, str and iterable
            types such as list. If not provided, every value is returned
            as a string.
        @param default_values: dict, Default values for each allowed parameter.
            The dict passed in is not modified; the result is a copy.
        @param split_char: str, The character used to split key and value.
            Default set to '='
        @param usage_str: str, An optional usage string to print with error message.
        @param ignore_invalid: bool, If True a parameter name missing from
            input_param_types is skipped silently
        @param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
            where the last occurrence will be reflected
            in the output. If False, then a ValueError is
            raised.
        @param lower_case_names: bool, Convert parameter names to lower case.

    Returns:
        Dict. Dictionary of input parameter values with key as parameter name
        and value as the parameter value

    Throws:
        KeyError - If a pair is malformed or the parameter name is not
                   listed in input_param_types (and ignore_invalid is False).
        ValueError - If a value cannot be converted to its declared type,
                     or a name is duplicated while allow_duplicates is False.
        TypeError - If the declared type of a parameter is not supported.
        plpy.error - If the parameter name is empty, 'none' or 'null'.
    """
    # collections.Iterable was removed in Python 3.10; prefer the abc module
    # and fall back for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    # Work on a copy so the dict supplied by the caller is never modified.
    parameter_dict = default_values.copy() if default_values else {}
    if not input_params:
        return parameter_dict

    seen_params = set()
    for s in preprocess_keyvalue_params(input_params, split_char=split_char):
        items = split_quoted_delimited_str(s, delimiter=split_char)
        if (len(items) != 2):
            raise KeyError("Input parameter list has incorrect format "
                           "{0}".format(usage_str))
        param_name = items[0].strip(" \"")
        if lower_case_names:
            param_name = param_name.lower()
        param_value = items[1].strip()
        if not allow_duplicates and param_name in seen_params:
            raise ValueError("Invalid input: {0} duplicated in the param list".
                             format(param_name))
        if not param_name or param_name in ('none', 'null'):
            plpy.error("Invalid input param name: {0} \n"
                       "{1}".format(param_name, usage_str))
        if input_param_types:
            try:
                param_type = input_param_types[param_name]
            except KeyError:
                if not ignore_invalid:
                    raise KeyError("Invalid input: {0} is not a valid parameter "
                                   "{1}".format(param_name, usage_str))
                else:
                    continue
            try:
                if param_type == bool:  # bool is not subclassable
                    # True values are y, yes, t, true, on and 1;
                    # False values are n, no, f, false, off and 0.
                    # Raises ValueError if anything else.
                    parameter_dict[param_name] = bool(strtobool(param_value))
                elif param_type in (int, str, float):
                    parameter_dict[param_name] = param_type(param_value)
                elif issubclass(param_type, Iterable):
                    # Drop the enclosing brackets, then split on commas
                    # while respecting quoted elements.
                    parameter_dict[param_name] = split_quoted_delimited_str(
                        param_value.strip('[](){} '))
                else:
                    raise TypeError("Invalid input: {0} has unsupported type "
                                    "{1}".format(param_name, usage_str))
            except ValueError:
                raise ValueError("Invalid input: {0} must be {1} \n"
                                 "{2}".format(param_name, param_type, usage_str))
        else:
            # if input parameter types not provided then just return all as string
            parameter_dict[param_name] = str(param_value)
        seen_params.add(param_name)
    return parameter_dict
# -------------------------------------------------------------------------
def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
    """ Parse a delimited-string to return a list of individual tokens taking
    quotes into account.

    Args:
        @param input_str: str, Delimited input string
        @param delimiter: str, The field delimiter character that separates
                          the tokens in input_str. Matched literally.
                          Default = ','
        @param quote: str, One-character string used to quote fields containing
                      special characters, such as the field delimiter.
                      Matched literally. Default = '"'
    Returns:
        List. List of delimited strings, each stripped of surrounding
        whitespace; quotes are kept. An empty list is returned when any of
        the three arguments is empty.
    Raises:
        ValueError: if the string could not be split (the underlying error
                    is reported as a plpy warning first)
    """
    if not input_str or not delimiter or not quote:
        return []
    try:
        # A token is a run of characters that are neither delimiter nor
        # quote, possibly mixed with complete quoted sections. re.escape
        # keeps any delimiter/quote character literal, including letters
        # and digits which would otherwise form escapes such as \d.
        delimiter_reg = re.compile(
            r'((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.format(d=re.escape(delimiter),
                                                     q=re.escape(quote)))
        # split() alternates separators and captured tokens, so the tokens
        # sit at the odd positions.
        return [i.strip() for i in
                delimiter_reg.split(input_str.strip())[1::2]]
    except Exception as e:
        plpy.warning(str(e))
        raise ValueError("Invalid string input for splitting")
# ------------------------------------------------------------------------------
def strip_end_quotes(input_str, quote='"'):
    """ Remove the quote character from the start and end if they are present
    (at both ends). Whitespace at start and end are not ignored.

    Args:
        @param input_str: Value to unquote; anything that is not a non-empty
                          str is returned unchanged
        @param quote: str, Quote marker to remove. Default = '"'
    Returns:
        str. Original string without the quotes at start and end
    """
    if not input_str or not quote:
        return input_str
    if not isinstance(input_str, str):
        return input_str
    # A str marker is removed in full; any other kind of marker (e.g. a
    # tuple of alternatives accepted by startswith) removes one character.
    quote_len = len(quote) if isinstance(quote, str) else 1
    # The opening and closing markers must not overlap, otherwise a lone
    # quote character would count as being quoted at both ends.
    if (len(input_str) >= 2 * quote_len and
            input_str.startswith(quote) and input_str.endswith(quote)):
        return input_str[quote_len:-quote_len]
    return input_str
# ------------------------------------------------------------------------------
def _grp_null_checks(grp_list):
    """
    Build the NULL checks for grouping columns, ready to be placed in a
    WHERE clause.

    Args:
        @param grp_list The list of grouping columns
    Returns:
        str. One ' <col> IS NOT NULL ' term per column, joined by ' AND '
    """
    not_null_terms = (" {i} IS NOT NULL ".format(i=grp_col)
                      for grp_col in grp_list)
    return ' AND '.join(not_null_terms)
# ------------------------------------------------------------------------------
def _check_groups(tbl1, tbl2, grp_list):
    """
    Helper function for joining tables with groups.

    Args:
        @param tbl1 Name of the first table
        @param tbl2 Name of the second table
        @param grp_list The list of grouping columns
    Returns:
        str. One ' <tbl1>.<col> = <tbl2>.<col> ' term per column, joined
        by ' AND '
    """
    # Pass the values explicitly: locals() evaluated inside a comprehension
    # does not include the enclosing function's variables on every Python
    # version.
    return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(tbl1=tbl1,
                                                            tbl2=tbl2,
                                                            i=i)
                         for i in grp_list])
# ------------------------------------------------------------------------------
def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
                                   filter_cols_list):
    """
    Build a condition stating that the filter_cols_list columns of
    include_from_table do NOT appear among the same columns of
    exclude_from_table.

    :param include_from_table: table from which cols in filter_cols_list should
                               be included
    :param exclude_from_table: table from which cols in filter_cols_list should
                               be excluded
    :param filter_cols_list: list of column names
    :return: query string of the form '(<qualified cols>) NOT IN (SELECT ...)'
    """
    included_cols = get_table_qualified_col_str(include_from_table,
                                                filter_cols_list)
    cols = ', '.join(filter_cols_list)
    return """({included_cols}) NOT IN
              (SELECT {cols}
               FROM {exclude_from_table})
           """.format(included_cols=included_cols,
                      cols=cols,
                      exclude_from_table=exclude_from_table)
# ------------------------------------------------------------------------------
def get_table_qualified_col_str(tbl_name, col_list):
    """
    Helper function for selecting columns of a table

    Args:
        @param tbl_name Name of the table
        @param col_list The list of column names to qualify
    Returns:
        str. One ' <tbl_name>.<col> ' term per column, joined by ' , '
    """
    # Pass the values explicitly: locals() evaluated inside a comprehension
    # does not include the enclosing function's variables on every Python
    # version.
    return ' , '.join([" {tbl_name}.{col} ".format(tbl_name=tbl_name, col=col)
                       for col in col_list])
def get_grouping_col_str(schema_madlib, module_name, reserved_cols,
                         source_table, grouping_col):
    """
    Validate the grouping columns and build the grouping expression string.

    @param schema_madlib Schema passed on to explicit_bool_to_text
    @param module_name Name of the module, used in error messages
    @param reserved_cols Column names that must not be used for grouping
    @param source_table Table that must contain the grouping columns
    @param grouping_col Comma-separated grouping columns; empty, None or
                        'null' (any case) means no grouping
    @return tuple (grouping_str, grouping_col): without grouping this is
            ("Null", None); otherwise the comma-joined columns, each cast
            to text, and the grouping_col argument unchanged
    """
    if not grouping_col or grouping_col.lower() == 'null':
        return "Null", None

    cols_in_tbl_valid(source_table,
                      _string_to_array_with_quotes(grouping_col),
                      module_name)
    # Reserved names are compared against the unquoted column names.
    intersect = frozenset(
        _string_to_array(grouping_col)).intersection(frozenset(reserved_cols))
    _assert(len(intersect) == 0,
            "{0} error: Conflicting grouping column name.\n"
            "Some predefined keyword(s) ({1}) are not allowed "
            "for grouping column names!".format(module_name, ', '.join(intersect)))
    text_cols = explicit_bool_to_text(
        source_table,
        _string_to_array_with_quotes(grouping_col),
        schema_madlib)
    grouping_str = ','.join(col_expr + "::text" for col_expr in text_cols)
    return grouping_str, grouping_col
# ------------------------------------------------------------------------------
def collate_plpy_result(plpy_result_rows):
    """
    Turn a list of row dicts into a dict of column lists.

    @param plpy_result_rows sequence of rows, each mapping column -> value
    @return {} for an empty input; otherwise a defaultdict(list) that maps
            every key of the first row to the list of that key's values,
            in row order
    """
    if not plpy_result_rows:
        return {}
    column_names = plpy_result_rows[0].keys()
    collated = collections.defaultdict(list)
    for column in column_names:
        collated[column].extend([row[column] for row in plpy_result_rows])
    return collated
# ------------------------------------------------------------------------------
def validate_module_input_params(source_table, output_table, independent_varname,
                                 dependent_varname, module_name,
                                 other_output_tables=None):
    """
    Validate the parameters shared by supervised-learning style modules
    (e.g. linear regression, mlp), all of which take these same inputs.

    :param source_table: This table should exist and not be empty
    :param output_table: This table should not exist
    :param independent_varname: This should be a valid expression in the source
                                table
    :param dependent_varname: This should be a valid expression in the source
                              table
    :param module_name: Name of the module to be printed with the error messages
    :param other_output_tables: List of additional output tables to validate.
                                These tables should not exist
    """
    input_tbl_valid(source_table, module_name)
    output_tbl_valid(output_table, module_name)
    for extra_tbl in (other_output_tables or []):
        output_tbl_valid(extra_tbl, module_name)

    independent_ok = is_var_valid(source_table, independent_varname)
    independent_msg = ("{module_name} error: invalid independent_varname "
                       "('{independent_varname}') for source_table "
                       "({source_table})!".format(
                           module_name=module_name,
                           independent_varname=independent_varname,
                           source_table=source_table))
    _assert(independent_ok, independent_msg)

    dependent_ok = is_var_valid(source_table, dependent_varname)
    dependent_msg = ("{module_name} error: invalid dependent_varname "
                     "('{dependent_varname}') for source_table "
                     "({source_table})!".format(
                         module_name=module_name,
                         dependent_varname=dependent_varname,
                         source_table=source_table))
    _assert(dependent_ok, dependent_msg)
# ------------------------------------------------------------------------
import unittest
class UtilitiesTestCase(unittest.TestCase):
    """
    Unit tests for the key/value parameter parsing, quoted-string splitting
    and result collating helpers of this file.

    Comment "import plpy" and replace plpy.error calls with appropriate
    Python Exceptions to successfully run the test cases
    """

    def setUp(self):
        # Well-formed list; one key contains '::' and one value is quoted.
        self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4'
        # max_iter is not an integer, so typed extraction must fail.
        self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5'
        # Contains a bare value and a key without value; both are skipped.
        self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}'
        # precision is not a valid float, so typed extraction must fail.
        self.optimizer_params4 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02.01, lambda={1,2,3,4}')
        # 'precision' appears twice once names are lower-cased.
        self.optimizer_params5 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02, PRECISION=2., lambda={1,2,3,4}')
        # Expected type of each parameter name used above.
        self.optimizer_types = {'max_iter': int, 'optimizer': str, 'optimizer::text': str,
                                'lambda': list, 'precision': float}

    def test_preprocess_optimizer(self):
        # Only complete key=value fragments are returned, in input order.
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params1),
                         ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params2),
                         ['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params3),
                         ['max_iter=10', 'lambda={1,"2,2",3,4}'])
        self.assertEqual(preprocess_keyvalue_params(self.optimizer_params4),
                         ['max_iter=10', 'optimizer="irls"', 'precision=0.02.01', 'lambda={1,2,3,4}'])

    def test_extract_optimizers(self):
        # With types given, values are converted; quotes in str values stay.
        self.assertEqual({'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001},
                         extract_keyvalue_params(self.optimizer_params1, self.optimizer_types))
        self.assertEqual({'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']},
                         extract_keyvalue_params(self.optimizer_params3, self.optimizer_types))
        # Without types, every value is returned as the raw string.
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01',
                          'lambda': '{1,2,3,4}'},
                         extract_keyvalue_params(self.optimizer_params4))
        # Names differing only by case are distinct when not lower-cased.
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"',
                          'PRECISION': '2.', 'precision': '0.02',
                          'lambda': '{1,2,3,4}'},
                         extract_keyvalue_params(self.optimizer_params5,
                                                 allow_duplicates=False,
                                                 lower_case_names=False
                                                 ))
        # Invalid int value, duplicated name and invalid float value.
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params2, self.optimizer_types)
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params5, allow_duplicates=False)
        self.assertRaises(ValueError,
                          extract_keyvalue_params, self.optimizer_params4, self.optimizer_types)

    def test_split_delimited_string(self):
        # Delimiters inside a quoted section do not split the token, and
        # the quote characters are kept in the result.
        self.assertEqual(['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'],
                         split_quoted_delimited_str(self.optimizer_params1, quote='"'))
        self.assertEqual(['a', 'b', 'c'], split_quoted_delimited_str('a, b, c', quote='|'))
        self.assertEqual(['a', '|b, c|'], split_quoted_delimited_str('a, |b, c|', quote='|'))
        self.assertEqual(['a', '"b, c"'], split_quoted_delimited_str('a, "b, c"'))
        self.assertEqual(['"a^5,6"', 'b', 'c'], split_quoted_delimited_str('"a^5,6", b, c', quote='"'))
        self.assertEqual(['"A""^5,6"', 'b', 'c'], split_quoted_delimited_str('"A""^5,6", b, c', quote='"'))

    def test_collate_plpy_result(self):
        # Rows are pivoted into one list per column, preserving row order.
        plpy_result1 = [{'classes': '4', 'class_count': 3},
                        {'classes': '1', 'class_count': 18},
                        {'classes': '5', 'class_count': 7},
                        {'classes': '3', 'class_count': 3},
                        {'classes': '6', 'class_count': 7},
                        {'classes': '2', 'class_count': 7}]
        self.assertEqual(collate_plpy_result(plpy_result1),
                         {'classes': ['4', '1', '5', '3', '6', '2'],
                          'class_count': [3, 18, 7, 3, 7, 7]})
        self.assertEqual(collate_plpy_result([]), {})
        self.assertEqual(collate_plpy_result([{'class': 'a'},
                                              {'class': 'b'},
                                              {'class': 'c'}]),
                         {'class': ['a', 'b', 'c']})
# Run the unit tests defined above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()