import collections
import re
import time
import random
from distutils.util import strtobool
from validate_args import _get_table_schema_names
from validate_args import get_first_schema
from validate_args import cols_in_tbl_valid
from validate_args import explicit_bool_to_text
from validate_args import input_tbl_valid
from validate_args import is_var_valid
from validate_args import output_tbl_valid
import plpy
m4_changequote(`<!', `!>')
def has_function_properties():
""" __HAS_FUNCTION_PROPERTIES__ variable defined during configure """
return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
def is_platform_pg():
""" __POSTGRESQL__ variable defined during configure """
return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)
# ------------------------------------------------------------------------------
def is_platform_hawq():
""" __HAWQ__ variable defined during configure """
return m4_ifdef(<!__HAWQ__!>, <!True!>, <!False!>)
# ------------------------------------------------------------------------------
def get_seg_number():
# NOTE(review): this definition is truncated in the copy under review -- the
# docstring below is never closed and the SQL literal has neither a terminator
# nor a result lookup. Restore the missing lines from version control before
# changing this function; nothing has been inferred here.
""" Find out how many primary segments exist in the distribution
Might be useful for partitioning data.
if is_platform_pg():
return 1
return plpy.execute("""
SELECT count(*) from gp_segment_configuration
WHERE role = 'p'
# ------------------------------------------------------------------------------
def is_orca():
    """Return True when the database reports ``optimizer`` as 'on'.

    The setting is only queried when has_function_properties() is true;
    otherwise False is returned without touching the database.
    """
    if not has_function_properties():
        return False
    optimizer = plpy.execute("show optimizer")[0]["optimizer"]
    return optimizer == 'on'
# ------------------------------------------------------------------------------
def _assert_equal(o1, o2, msg):
# NOTE(review): truncated in the copy under review -- the docstring delimiters
# and the statement guarded by the `if` below are missing. The description
# says an error carrying `msg` is raised; confirm the exact call against
# version control.
@brief if the given objects are not equal, then raise an error with the message
@param o1 the first object
@param o2 the second object
@param msg the error message to be reported
if not o1 == o2:
# ------------------------------------------------------------------------------
def _assert(condition, msg):
# NOTE(review): truncated in the copy under review -- the docstring delimiters
# and the statement guarded by the `if` below are missing. The description
# says an error carrying `msg` is raised; confirm the exact call against
# version control.
@brief if the given condition is false, then raise an error with the message
@param condition the condition to be asserted
@param msg the error message to be reported
if not condition:
# ------------------------------------------------------------------------------
def warn(condition, msg):
# NOTE(review): truncated in the copy under review -- the docstring delimiters
# and the statement guarded by the `if` below are missing. The description
# says a warning carrying `msg` is raised; confirm the exact call against
# version control.
@brief if the given condition is false, then raise a warning with the message
@param condition the condition to be asserted
@param msg the error message to be reported
if not condition:
# ------------------------------------------------------------------------------
def get_distribution_policy(source_table):
# NOTE(review): truncated in the copy under review -- the docstring is never
# closed and the SQL text below has gaps (the subquery opener and parts of the
# CASE expression are absent). Restore from version control before editing.
# NOTE(review): table_name and schema_name are interpolated into the SQL text
# with str.format; presumably callers validate source_table first -- confirm.
""" Return a list of columns that define the distribution policy of source_table
@param source_table
List of str.
_, table_name = _get_table_schema_names(source_table)
schema_name = get_first_schema(source_table)
dist_attr = plpy.execute("""
SELECT array_agg(pga.attname) as dist_attr
SELECT gdp.localoid,
WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
END AS attnum
FROM gp_distribution_policy gdp
) AS distkey
INNER JOIN pg_class AS pgc
ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
INNER JOIN pg_namespace pgn
ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
LEFT OUTER JOIN pg_attribute pga
ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
""".format(table_name=table_name, schema_name=schema_name))[0]["dist_attr"]
return dist_attr
# ------------------------------------------------------------------------------
def num_features(source_table, independent_varname):
# Queries array_upper(<expr>, 1) on a single row of the table (LIMIT 1).
# NOTE(review): truncated in the copy under review -- the .format(...) call
# that fills {0}/{1} and the result lookup are missing; restore from version
# control.
return plpy.execute("SELECT array_upper({0}, 1) AS dim "
"FROM {1} LIMIT 1"
# ------------------------------------------------------------------------------
def num_samples(source_table):
# Runs a count(*) over the table named by the {0} placeholder.
# NOTE(review): truncated in the copy under review -- the .format(...) call
# that fills {0} and the result lookup are missing; restore from version
# control.
return plpy.execute("SELECT count(*) AS n FROM {0}"
# ------------------------------------------------------------------------------
def unique_string(desp='', **kwargs):
    """Generate a random temporary name for temp tables and other objects.

    It has a SQL interface so both SQL and Python functions can call it.

    @param desp: str, text inserted right after the ``__madlib_temp_`` prefix
    @param kwargs: accepted for interface compatibility and ignored

    @return str of the form ``__madlib_temp_<desp><r1>_<r2>_<r3>__`` where r1
        is a random integer, r2 the current epoch second and r3 the epoch
        second modulo another random integer
    """
    r1 = random.randint(1, 100000000)
    r2 = int(time.time())
    r3 = int(time.time()) % random.randint(1, 100000000)
    return "__madlib_temp_" + desp + "{0}_{1}_{2}__".format(r1, r2, r3)
# ------------------------------------------------------------------------------
def add_postfix(quoted_string, postfix):
    """ Append a string to the end of the table name.

    If input table name is quoted by double quotes, make sure the postfix is
    inside of the double quotes.

    @param quoted_string: str. A string representing a database quoted string
    @param postfix: str. A string to add as a suffix to quoted_string.
                    ** This is assumed to not contain any quotes **

    @return str. The stripped name with postfix appended (inside the closing
        double quote when the name is quoted)
    """
    quoted_string = quoted_string.strip()
    if quoted_string.startswith('"') and quoted_string.endswith('"'):
        # Re-open the quoted identifier so the suffix lands inside the quotes.
        return quoted_string[:-1] + postfix + '"'
    return quoted_string + postfix
# -------------------------------------------------------------------------
def is_psql_numeric_type(arg, exclude=None):
    """
    Checks if argument is one of the various numeric types in PostgreSQL

    @param arg: string, Type name to check
    @param exclude: iterable, List of types to exclude from checking

    @return Boolean. Returns if 'arg' is one of the numeric types that is not
        listed in 'exclude'
    """
    numeric_types = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
                         'real', 'double precision', 'serial', 'bigserial'])
    excluded = set(exclude) if exclude is not None else set()
    return arg in numeric_types and arg not in excluded
# -------------------------------------------------------------------------
def is_string_formatted_as_array_expression(string_to_match):
    """
    Check whether the string starts with array[<something>] (case-insensitive).

    :param string_to_match: str, the expression to inspect
    :return: the regex match object (truthy; group 1 holds the text between
        the brackets) when the string is formatted as an array expression,
        else None
    """
    return re.match(r"(?i)^array\[(.*)\]", string_to_match)
def _string_to_array(s):
    """
    Split a string into an array of strings
    Any space around the substrings are removed

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.

    @param s: str, comma/space separated elements
    @return list of str with enclosing double quotes stripped
    """
    # Each regex match is one element: either a double-quoted run (which may
    # contain \" escapes) or a run free of quotes, commas and whitespace.
    # NOTE(review): the statement that collected the matches was absent from
    # the reviewed copy; the whole match is the only value the pattern offers.
    pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [m.group(1).strip("\"") for m in re.finditer(pattern, s)]
# ------------------------------------------------------------------------
def _string_to_array_with_quotes(s):
    """
    Same as _string_to_array except the double quotes will be kept.

    @param s: str, comma/space separated elements
    @return list of str, quoted elements still wrapped in their double quotes
    """
    # NOTE(review): the statement that collected the matches was absent from
    # the reviewed copy; the whole match is the only value the pattern offers.
    pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    return [m.group(1) for m in re.finditer(pattern, s)]
# ------------------------------------------------------------------------
def py_list_to_sql_string(array, array_type=None, long_format=None):
# NOTE(review): truncated in the copy under review -- the predicate applied to
# each of "text"/"varchar"/"character varying" and the `else:` before
# `long_format = True` are missing. Restore from version control.
# What the visible code shows: array_type defaults to "double precision[]" and
# gets a "[]" suffix when it lacks one; an empty list yields '{ }'::<type>;
# otherwise elements are str()-joined with ',' into ARRAY[ ... ] (long format)
# or '{ ... }' (short format), followed by ::<type>.
"""Convert a list to SQL array string """
if long_format is None:
if (array_type is not None and
for i in ["text", "varchar", "character varying"]))):
long_format = False
long_format = True
if not array_type:
array_type = "double precision[]"
array_type = array_type.strip()
if not array_type.endswith("[]"):
array_type += "[]"
if not array:
return "'{{ }}'::{0}".format(array_type)
array_str = "ARRAY[ {0} ]" if long_format else "'{{ {0} }}'"
return (array_str + "::{1}").format(','.join(map(str, array)), array_type)
# ------------------------------------------------------------------------
def _array_to_string(origin):
    """
    Convert an array to string

    Every element is passed through str(), each double quote in it is
    prefixed with a backslash, and the elements are joined with commas
    inside a pair of curly braces.

    @param origin: iterable of elements
    @return str such as ``{a,b,c}``
    """
    escaped = [str(item).replace('"', '\\"') for item in origin]
    return "{" + ",".join(escaped) + "}"
# ------------------------------------------------------------------------
def _cast_if_null(input, alias=''):
    """Return str(input), or a typed SQL NULL when input is empty.

    @param input: value to render; any falsy value (None, '', 0, ...) is
        treated as missing
    @param alias: str, optional column alias appended to the NULL expression

    @return str. ``str(input)`` for truthy input, otherwise ``NULL::text``
        (followed by `` as <alias>`` when alias is given)
    """
    if input:
        return str(input)
    null_str = "NULL::text"
    if alias:
        return null_str + " as " + alias
    return null_str
# ------------------------------------------------------------------------
def set_client_min_messages(new_level):
# NOTE(review): truncated in the copy under review -- the docstring delimiters
# and the tail of the pg_settings query (string terminator and result lookup)
# are missing. Restore from version control before editing.
# new_level is validated against the fixed tuple below (case-insensitively)
# before it is interpolated into the SET statement.
Set the client_min_message setting in Postgres which controls the messages
sent to the psql client.
@param new_level: string, New level to set the client_min_message
old_msg_level: string, The old client_min_message level before changing it.
ValueError: if the argument new_level is not a valid message level
if new_level.lower() not in ('debug5', 'debug4', 'debug3', 'debug2',
'debug1', 'log', 'notice', 'warning', 'error',
'fatal', 'panic'):
raise ValueError("Not a valid message level sent to client")
old_msg_level = plpy.execute(""" SELECT setting
FROM pg_settings
WHERE name='client_min_messages'
plpy.execute("SET client_min_messages TO {0}".format(new_level))
return old_msg_level
# -------------------------------------------------------------------------
# Deal with earlier versions of PG or GPDB
class __mad_version:
    """Inspect the server version string to select version-specific behavior.

    ``select version()`` is executed once when the object is created; every
    predicate and selector below is a regular-expression test against that
    cached string.
    """

    def __init__(self):
        self.version = plpy.execute("select version()")[0]["version"]

    def select_vecfunc(self):
        """Return the callable used to read a vector passed from SQL.

        PG84 and GP40, GP41 do not have a good support for
        vectors. They convert any vector into a string, surrounded
        by { and }. Thus special care is needed for these older
        versions of GPDB and PG: they get the parsing converter, every
        other version gets a pass-through.
        """
        # GPDB 4.0 or 4.1
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__extract
        return self.__identity

    def __extract(self, origin, text=True):
        """Extract vector elements from a string with {} as the brackets.

        @param origin: str or None, for example "{1,2,3}"
        @param text: bool, pass False to convert every element to float
        @return list of elements, or None when origin is None
        """
        if origin is None:
            return None
        elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
        if text is False:
            elm = [float(e) for e in elm]
        return elm

    def __identity(self, origin, text=True):
        """Return origin unchanged; text is accepted and ignored."""
        return origin

    def select_vec_return(self):
        """Return the callable used to hand a vector back from Python to SQL.

        Special care is needed if one needs to return a vector on the older
        versions: they get the string-condensing converter, every other
        version gets a pass-through.
        """
        if self.is_less_than_gp42() or self.is_less_than_pg90():
            return self.__condense
        return self.__identity

    def __condense(self, origin):
        """Convert the original vector into a string which some
        old versions of SQL system can recognize.
        """
        return _array_to_string(origin)

    def select_array_agg(self, schema_madlib):
        """Return the name of the array_agg aggregate to call.

        GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
        so use the madlib array_agg for those versions.

        @param schema_madlib: str, schema holding the MADlib aggregate
        """
        if self.is_less_than_gp41() or self.is_less_than_pg90():
            return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
        return "array_agg"

    def is_pg(self):
        """True when the string names PostgreSQL and not Greenplum Database."""
        return bool(re.search("PostgreSQL", self.version) and
                    not re.search(r"Greenplum\s*Database", self.version))

    def is_gp43(self):
        """True when the string names Greenplum Database 4.3."""
        return bool(re.search(r"Greenplum\s+Database\s+4\.3", self.version))

    def is_less_than_pg90(self):
        """True for a PostgreSQL server whose major version is below 9."""
        # Only the leading run of digits is needed. Capturing it alone keeps
        # the regex from backtracking into the major number on strings such
        # as "PostgreSQL 10devel" (which would otherwise be read as 1).
        found = re.findall(r'PostgreSQL\s*([0-9]+)', self.version,
                           re.IGNORECASE)
        return bool(found) and self.is_pg() and int(found[0]) < 9

    def __gp_release_below(self, threshold):
        """True when the Greenplum major.minor number is below threshold."""
        found = re.findall(
            r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build',
            self.version, re.IGNORECASE)
        return bool(found) and float(found[0]) < threshold

    def is_less_than_gp41(self):
        """True for a Greenplum Database release older than 4.1."""
        return self.__gp_release_below(4.1)

    def is_less_than_gp42(self):
        """True for a Greenplum Database release older than 4.2."""
        return self.__gp_release_below(4.2)

    def is_hawq(self):
        """True when the string names a HAWQ release."""
        return bool(re.search(r"HAWQ\s+[0-9.]+", self.version))

    def __release_below(self, pattern, compare_version):
        """Compare the version captured by pattern with compare_version.

        Both versions are split on '.', non-numeric pieces are dropped and
        the resulting number lists are compared element by element.
        Returns False when pattern does not match the version string.
        """
        found = re.findall(pattern, self.version, re.IGNORECASE)
        if not found:
            return False
        db_ver = [float(i) for i in found[0].split('.') if i.isdigit()]
        cmp_ver = [float(i) for i in compare_version.split('.') if i.isdigit()]
        return db_ver < cmp_ver

    def is_pg_version_less_than(self, compare_version):
        """ Return True if self is a PostgreSQL database and
            self.version is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5
        """
        if not self.is_pg():
            return False
        return self.__release_below(r'PostgreSQL\s*([0-9.]+)', compare_version)

    def is_gp_version_less_than(self, compare_version):
        """ Return True if self is a Greenplum database and self.version
            is less than compare_version

        @param compare_version: str, String form of the comparison version. Expected
                                format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5
        """
        if self.is_hawq():
            return False
        return self.__release_below(r'Greenplum\s+Database\s*([0-9.]+)',
                                    compare_version)

    def is_hq_version_less_than(self, compare_version):
        """ Return True if self is a HAWQ database and self.version
            is less than compare_version

        @param compare_version: str, String form of the comparison version.
                                Expected format is Semantic Versioning.
                                examples of versions: 1.0, 2.3, 9.3.5
        """
        return self.__release_below(r'HAWQ\s*([0-9.]+)', compare_version)
def _string_to_sql_array(schema_madlib, s, **kwargs):
    """
    Split a string into an array of strings
    Any space around the substrings are removed

    Requirement: every individual element in the string
    must be a valid Postgres name, which means that if
    there are spaces or commas in the element then the
    whole element must be quoted by a pair of double
    quotes.

    Usually this is not a problem. Especially in older
    versions of GPDB, when an array is passed from
    SQL to Python, it is converted to a string, which
    automatically adds double quotes if there are spaces or
    commas in the element.

    So use this function, if you are sure that all elements
    are valid Postgres names.

    @param schema_madlib: not used by the body
    @param s: str, comma/space separated elements
    @param kwargs: accepted and ignored
    @return the element list, passed through the version-specific converter
        chosen by __mad_version.select_vec_return()
    """
    # use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0
    version_wrapper = __mad_version()
    array_to_string = version_wrapper.select_vec_return()

    # NOTE(review): the statement that collected the matches was absent from
    # the reviewed copy; the whole match is the only value the pattern offers.
    pattern = r"(\"(\\\"|[^\"])*\"|[^\",\s]+)"
    elm = [m.group(1).strip("\"") for m in re.finditer(pattern, s)]
    return array_to_string(elm)
# ------------------------------------------------------------------------
def current_user():
    """Returns the user name of the current database user."""
    rows = plpy.execute("SELECT current_user")
    return rows[0]['current_user']
# ------------------------------------------------------------------------
def madlib_version(schema_madlib):
# NOTE(review): truncated in the copy under review -- the SQL literal is not
# terminated and the .format(...)/result lookup that produces `raw` is
# missing. Restore from version control before editing.
# The last line keeps the text before the first comma of `raw` and returns
# its final space-separated word.
"""Returns the MADlib version string."""
raw = plpy.execute("""
SELECT {schema_madlib}.version()
return raw.split(',')[0].split(' ')[-1]
# ------------------------------------------------------------------------
def preprocess_keyvalue_params(input_params, split_char='='):
# NOTE(review): truncated in the copy under review -- the docstring
# delimiters, several pieces of the verbose regex literal (string quotes and
# the plain-string alternative) and the expression collected from each match
# are missing. Restore from version control before editing.
# NOTE(review): split_char is concatenated into the pattern as-is; presumably
# callers only pass characters with no special regex meaning -- confirm.
Parse the input_params string and split it using the split_char
@param input_params: str, Comma-separated list of parameters
The parameter can be any key = value, where
key is a string of alphanumeric character
value is either and
@param split_char: str, character that splits the key and value elements.
Default set to '='
re_str = (r"([-:\w]+\s*" + # key is any alphanumeric character
# (including - and :) string
split_char + # key and value are separated by split_char
\s*([\(\{\[] # value can be string or array
[^\[\]\(\)\{\}]* # if value is array then accept anything inside
[\)\}\]] # and then match closing braces of array
| # either array (above) or string (below)
# if value is string, it can be alphanumeric
# character string with a decimal dot,
# hyphen, or percent
# optionally quoted by `quote_char`
pattern = re.compile(re_str, re.VERBOSE)
return [ for m in pattern.finditer(input_params)]
# ------------------------------------------------------------------------
def extract_keyvalue_params(input_params,
# NOTE(review): truncated in the copy under review -- the rest of the
# signature, the docstring terminator, the `try:`/`else:` lines and the tails
# of several raise statements are missing. Restore from version control
# before editing.
# NOTE(review): `parameter_dict = default_values` binds the caller's dict, so
# parsed values are written into the object passed as default_values; copy it
# first if callers reuse their defaults.
# NOTE(review): `collections.Iterable` only exists up to Python 3.9 (it lives
# in collections.abc on Python 3) -- relevant if this module is ported.
""" Extract key value pairs from input parameters or set the default values
@param input_params: string, Format of
'key1=value1, key2=value2,...', assuming default split_char.
The order does not matter. If a parameter is missing, then
the default value is used. If input_params is None or '',
then all default values are returned. This function also
validates the values of these parameters.
@param input_param_types: dict, The type of each allowed parameter
name. Currently supports one of
(int, float, str, list)
@param default_values: dict, Default values for each allowed parameter.
@param split_char: str, The character used to split key and value.
Default set to '='
@param usage_str: str, An optional usage string to print with error message.
@param ignore_invalid: bool, If True an invalid param input is ignored silently
@param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
where the last occurrence will be reflected
in the output. If False, then a ValueError is
@param lower_case_names: bool, Convert parameter names to lower case.
Dict. Dictionary of input parameter values with key as parameter name
and value as the parameter value
plpy.error - If the parameter is unsupported or the value is
not valid.
if not input_params:
return default_values if default_values is not None else {}
if default_values:
parameter_dict = default_values
parameter_dict = {}
seen_params = set()
for s in preprocess_keyvalue_params(input_params, split_char=split_char):
items = split_quoted_delimited_str(s, delimiter=split_char)
if (len(items) != 2):
raise KeyError("Input parameter list has incorrect format "
param_name = items[0].strip(" \"")
if lower_case_names:
param_name = param_name.lower()
param_value = items[1].strip()
if not allow_duplicates and param_name in seen_params:
raise ValueError("Invalid input: {0} duplicated in the param list".
if not param_name or param_name in ('none', 'null'):
plpy.error("Invalid input param name: {0} \n"
"{1}".format(param_name, usage_str))
if input_param_types:
param_type = input_param_types[param_name]
except KeyError:
if not ignore_invalid:
raise KeyError("Invalid input: {0} is not a valid parameter "
"{1}".format(param_name, usage_str))
if param_type == bool: # bool is not subclassable
# True values are y, yes, t, true, on and 1;
# False values are n, no, f, false, off and 0.
# Raises ValueError if anything else.
parameter_dict[param_name] = bool(strtobool(param_value))
elif param_type in (int, str, float):
parameter_dict[param_name] = param_type(param_value)
elif issubclass(param_type, collections.Iterable):
parameter_dict[param_name] = split_quoted_delimited_str(
param_value.strip('[](){} '))
raise TypeError("Invalid input: {0} has unsupported type "
"{1}".format(param_name, usage_str))
except ValueError:
raise ValueError("Invalid input: {0} must be {1} \n"
"{2}".format(param_name, param_type, usage_str))
# if input parameter types not provided then just return all as string
parameter_dict[param_name] = str(param_value)
return parameter_dict
# -------------------------------------------------------------------------
def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
    """ Parse a delimited-string to return a list of individual tokens taking
    quotes into account.

    @param input_str: str, Delimited input string
    @param delimiter: str, The field delimiter character that separates
                        the tokens in input_str. Default = ','
    @param quote: str, One-character string used to quote fields containing
                    special characters, such as the field delimiter
                    Default = '"'

    @return List. List of delimited strings, each stripped of surrounding
        whitespace; quoted fields keep their quote characters. An empty list
        is returned when any argument is empty.

    @raise ValueError: if the pattern cannot be built or applied
    """
    if not input_str or not delimiter or not quote:
        return []
    try:
        # re.escape (rather than a blind backslash prefix) keeps alphanumeric
        # delimiters literal: a bare backslash would turn 'd' into the digit
        # class \d and 'x' into an invalid escape.
        d = re.escape(delimiter)
        q = re.escape(quote)
        # A token is a run of: any character that is neither delimiter nor
        # quote, or a complete quoted section (which may hold delimiters).
        delimiter_reg = re.compile(
            '((?:[^{d}{q}]|{q}[^{q}]*{q})+)'.format(d=d, q=q))
        return [i.strip() for i in delimiter_reg.findall(input_str)]
    except Exception:
        # Callers rely on ValueError as the single failure signal.
        raise ValueError("Invalid string input for splitting")
# ------------------------------------------------------------------------------
def strip_end_quotes(input_str, quote='"'):
    """ Remove the quote character from the start and end if they are present
    (at both ends). Whitespace at start and end are not ignored.

    @param input_str: value to unquote; anything that is empty or not a str
        is returned untouched
    @param quote: str, the quote marker to look for. Default = '"'

    @return str. Original string without the quotes at start and end
    """
    if not input_str or not quote or not isinstance(input_str, str):
        return input_str
    width = len(quote)
    # The length check stops a lone quote marker from counting as both the
    # opening and the closing quote.
    if (len(input_str) >= 2 * width and
            input_str.startswith(quote) and input_str.endswith(quote)):
        return input_str[width:-width]
    return input_str
# ------------------------------------------------------------------------------
def _grp_null_checks(grp_list):
    """
    Helper function for generating NULL checks for grouping columns
    to be used within a WHERE clause

    @param grp_list The list of grouping columns
    @return str, one " <col> IS NOT NULL " term per column joined by ' AND '
    """
    # Name the format argument explicitly instead of expanding locals():
    # inside a comprehension locals() is scope-dependent across Python
    # versions.
    return ' AND '.join([" {i} IS NOT NULL ".format(i=col)
                         for col in grp_list])
# ------------------------------------------------------------------------------
def _check_groups(tbl1, tbl2, grp_list):
    """
    Helper function for joining tables with groups.

    @param tbl1 Name of the first table
    @param tbl2 Name of the second table
    @param grp_list The list of grouping columns
    @return str, one " <tbl1>.<col> = <tbl2>.<col> " term per column joined
        by ' AND '
    """
    # Pass the names explicitly: on Python 3 before 3.12 a comprehension has
    # its own scope, so locals() there does not contain tbl1/tbl2 and
    # format(**locals()) raises KeyError.
    return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(
                             tbl1=tbl1, tbl2=tbl2, i=col)
                         for col in grp_list])
# ------------------------------------------------------------------------------
def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
# NOTE(review): truncated in the copy under review -- the end of the signature
# (the docstring names a filter_cols_list parameter), the docstring
# delimiters and the terminator/formatting of the returned SQL template are
# missing. Restore from version control before editing.
This function returns a subquery string with columns in the filter_cols_list
that appear in include_from_table but NOT IN exclude_from_table.
:param include_from_table: table from which cols in filter_cols_list should
be included
:param exclude_from_table: table from which cols in filter_cols_list should
be excluded
:param filter_cols_list: list of column names
:return: query string with relevant column names
included_cols = get_table_qualified_col_str(include_from_table, filter_cols_list)
cols = ', '.join([col for col in filter_cols_list])
return """({included_cols}) NOT IN
(SELECT {cols}
FROM {exclude_from_table})
# ------------------------------------------------------------------------------
def get_table_qualified_col_str(tbl_name, col_list):
    """
    Helper function for selecting columns of a table

    @param tbl_name Name of the table
    @param col_list The list of column names to qualify
    @return str, one " <tbl_name>.<col> " term per column joined by ' , '
    """
    # Pass the names explicitly: on Python 3 before 3.12 a comprehension has
    # its own scope, so locals() there does not contain tbl_name and
    # format(**locals()) raises KeyError.
    return ' , '.join([" {tbl_name}.{col} ".format(tbl_name=tbl_name, col=col)
                       for col in col_list])
def get_grouping_col_str(schema_madlib, module_name, reserved_cols,
source_table, grouping_col):
# NOTE(review): truncated in the copy under review -- the expression building
# `intersect`, the arguments of explicit_bool_to_text(...) and the `else:`
# before the "Null" defaults are missing. Restore from version control.
# Visible contract: returns the pair (grouping_str, grouping_col); when
# grouping_col is empty or the text 'null' (any case) the pair is
# ("Null", None), otherwise grouping_str is a comma-joined list of
# "<col>::text" items.
if grouping_col and grouping_col.lower() != 'null':
intersect = frozenset(
_assert(len(intersect) == 0,
"{0} error: Conflicting grouping column name.\n"
"Some predefined keyword(s) ({1}) are not allowed "
"for grouping column names!".format(module_name, ', '.join(intersect)))
grouping_list = [i + "::text"
for i in explicit_bool_to_text(
grouping_str = ','.join(grouping_list)
grouping_str = "Null"
grouping_col = None
return grouping_str, grouping_col
# ------------------------------------------------------------------------------
def collate_plpy_result(plpy_result_rows):
    """Turn a list of row dicts into a dict of column lists.

    @param plpy_result_rows: sequence of dict-like rows; the keys of the
        first row decide which columns are collected
    @return dict mapping each key to the list of that key's values in row
        order; an empty dict when there are no rows
    """
    if not plpy_result_rows:
        return {}
    all_keys = plpy_result_rows[0].keys()
    result = collections.defaultdict(list)
    for each_row in plpy_result_rows:
        for each_key in all_keys:
            result[each_key].append(each_row[each_key])
    return result
# ------------------------------------------------------------------------------
def validate_module_input_params(source_table, output_table, independent_varname,
dependent_varname, module_name,
# NOTE(review): truncated in the copy under review -- the end of the signature
# (the docstring names an other_output_tables parameter), the docstring
# delimiters and the tails of both _assert(...) messages are missing. Restore
# from version control before editing.
# Visible order of checks: source table, output table, each extra output
# table, then independent_varname and dependent_varname via is_var_valid.
This function is supposed to be used for validating params for
supervised learning like algos, e.g. linear regression, mlp, etc. since all
of them need to validate the following 4 parameters.
:param source_table: This table should exist and not be empty
:param output_table: This table should not exist
:param dependent_varname: This should be a valid expression in the source
:param independent_varname: This should be a valid expression in the source
:param module_name: Name of the module to be printed with the error messages
:param other_output_tables: List of additional output tables to validate.
These tables should not exist
input_tbl_valid(source_table, module_name)
output_tbl_valid(output_table, module_name)
if other_output_tables:
for tbl in other_output_tables:
output_tbl_valid(tbl, module_name)
_assert(is_var_valid(source_table, independent_varname),
"{module_name} error: invalid independent_varname "
"('{independent_varname}') for source_table "
_assert(is_var_valid(source_table, dependent_varname),
"{module_name} error: invalid dependent_varname "
"('{dependent_varname}') for source_table "
# ------------------------------------------------------------------------
import unittest
class UtilitiesTestCase(unittest.TestCase):
# NOTE(review): truncated in the copy under review -- the class docstring
# delimiters and the opening halves of many assertions (the assert call and
# its first argument) are missing, so several lines below are orphaned
# expected values. Restore from version control before editing.
# The fixtures in setUp cover: a well-formed parameter string (params1), a
# value of the wrong type (params2), malformed entries (params3), an invalid
# float (params4) and a key repeated in different case (params5).
Comment "import plpy" and replace plpy.error calls with appropriate
Python Exceptions to successfully run the test cases
def setUp(self):
self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4'
self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5'
self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}'
self.optimizer_params4 = ('max_iter=10, optimizer="irls",'
'precision=0.02.01, lambda={1,2,3,4}')
self.optimizer_params5 = ('max_iter=10, optimizer="irls",'
'precision=0.02, PRECISION=2., lambda={1,2,3,4}')
self.optimizer_types = {'max_iter': int, 'optimizer': str, 'optimizer::text': str,
'lambda': list, 'precision': float}
def test_preprocess_optimizer(self):
['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'])
['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5'])
['max_iter=10', 'lambda={1,"2,2",3,4}'])
['max_iter=10', 'optimizer="irls"', 'precision=0.02.01', 'lambda={1,2,3,4}'])
def test_extract_optimizers(self):
self.assertEqual({'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001},
extract_keyvalue_params(self.optimizer_params1, self.optimizer_types))
self.assertEqual({'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']},
extract_keyvalue_params(self.optimizer_params3, self.optimizer_types))
self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01',
'lambda': '{1,2,3,4}'},
self.assertEqual({'max_iter': '10', 'optimizer': '"irls"',
'PRECISION': '2.', 'precision': '0.02',
'lambda': '{1,2,3,4}'},
extract_keyvalue_params, self.optimizer_params2, self.optimizer_types)
extract_keyvalue_params, self.optimizer_params5, allow_duplicates=False)
extract_keyvalue_params, self.optimizer_params4, self.optimizer_types)
def test_split_delimited_string(self):
self.assertEqual(['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'],
split_quoted_delimited_str(self.optimizer_params1, quote='"'))
self.assertEqual(['a', 'b', 'c'], split_quoted_delimited_str('a, b, c', quote='|'))
self.assertEqual(['a', '|b, c|'], split_quoted_delimited_str('a, |b, c|', quote='|'))
self.assertEqual(['a', '"b, c"'], split_quoted_delimited_str('a, "b, c"'))
self.assertEqual(['"a^5,6"', 'b', 'c'], split_quoted_delimited_str('"a^5,6", b, c', quote='"'))
self.assertEqual(['"A""^5,6"', 'b', 'c'], split_quoted_delimited_str('"A""^5,6", b, c', quote='"'))
def test_collate_plpy_result(self):
plpy_result1 = [{'classes': '4', 'class_count': 3},
{'classes': '1', 'class_count': 18},
{'classes': '5', 'class_count': 7},
{'classes': '3', 'class_count': 3},
{'classes': '6', 'class_count': 7},
{'classes': '2', 'class_count': 7}]
{'classes': ['4', '1', '5', '3', '6', '2'],
'class_count': [3, 18, 7, 3, 7, 7]})
self.assertEqual(collate_plpy_result([]), {})
self.assertEqual(collate_plpy_result([{'class': 'a'},
{'class': 'b'},
{'class': 'c'}]),
{'class': ['a', 'b', 'c']})
if __name__ == '__main__':