import collections
import collections.abc
import re
import time
import random
from distutils.util import strtobool
from validate_args import _get_table_schema_names
from validate_args import cols_in_tbl_valid
from validate_args import does_exclude_reserved
from validate_args import explicit_bool_to_text
from validate_args import get_cols
from validate_args import get_first_schema
from validate_args import input_tbl_valid
from validate_args import is_var_valid
from validate_args import output_tbl_valid
from validate_args import quote_ident
from validate_args import drop_tables
import plpy
m4_changequote(`<!', `!>')
def plpy_execute_debug(sql, *args, **kwargs):
""" Replace plpy.execute(sql, ...) with
plpy_execute_debug(sql, ...) to debug
a query. Shows the query itself, the
EXPLAIN of it, and how long the query
takes to execute.
"""
plpy.info(sql) # Print sql command
# Print EXPLAIN of sql command
res = plpy.execute("EXPLAIN " + sql, *args)
for r in res:
plpy.info(r['QUERY PLAN'])
# Run actual sql command, with timing
start = time.time()
plpy.execute(sql, *args)
# Print how long execution of query took
plpy.info("Query took {0}s".format(time.time() - start))
def has_function_properties():
""" __HAS_FUNCTION_PROPERTIES__ variable defined during configure """
return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
def is_platform_pg():
""" __POSTGRESQL__ variable defined during configure """
return m4_ifdef(<!__POSTGRESQL__!>, <!True!>, <!False!>)
def is_platform_gp6_or_up():
version_wrapper = __mad_version()
return not is_platform_pg() and not version_wrapper.is_gp_version_less_than('6.0')
# ------------------------------------------------------------------------------
def get_seg_number():
""" Find out how many primary segments(not include master segment) exist
in the distribution. Might be useful for partitioning data.
"""
if is_platform_pg():
return 1
else:
count = plpy.execute("""
SELECT count(*) from gp_segment_configuration
WHERE role = 'p' and content != -1
""")[0]['count']
        # In case of some weird gpdb configuration, always return a
        # primary segment count >= 1
return max(1, count)
# ------------------------------------------------------------------------------
def get_segments_per_host():
""" Find out how many primary segments(not include master segment) exist
per host. We assume every host has the same number of segments and
we only return the first one.
"""
if is_platform_pg():
return 1
else:
count = plpy.execute("""
SELECT count(*) from gp_segment_configuration
WHERE role = 'p' and content != -1
GROUP BY hostname
LIMIT 1
""")[0]['count']
        # In case of some weird gpdb configuration, always return a
        # primary segment count >= 1
return max(1, count)
# ------------------------------------------------------------------------------
def is_orca():
if has_function_properties():
optimizer = plpy.execute("show optimizer")[0]["optimizer"]
if optimizer == 'on':
return True
return False
# ------------------------------------------------------------------------------
def _assert_equal(o1, o2, msg):
"""
@brief if the given objects are not equal, then raise an error with the message
@param o1 the first object
@param o2 the second object
@param msg the error message to be reported
"""
if not o1 == o2:
plpy.error(msg)
# ------------------------------------------------------------------------------
def _assert(condition, msg):
"""
@brief if the given condition is false, then raise an error with the message
@param condition the condition to be asserted
@param msg the error message to be reported
"""
if not condition:
plpy.error(msg)
# ------------------------------------------------------------------------------
def warn(condition, msg):
"""
@brief if the given condition is false, then raise a warning with the message
@param condition the condition to be asserted
@param msg the error message to be reported
"""
if not condition:
plpy.warning(msg)
# ------------------------------------------------------------------------------
def get_distributed_by(source_table):
""" Return a "distributed by (...)" clause that defines distribution policy of source_table
Args:
@param source_table
Returns:
        str. A "distributed by (...)" clause, or "distributed randomly" if
        the table has no distribution key.
"""
_, table_name = _get_table_schema_names(source_table)
schema_name = get_first_schema(source_table)
# GPDB 6 has pg_get_table_distributedby(<oid>) function to get the
    # DISTRIBUTED BY clause of a table. In older versions, we have to
    # dig up the column names from the gp_distribution_policy catalog.
version_wrapper = __mad_version()
if version_wrapper.is_gp_version_less_than("6.0"):
dist_attr = plpy.execute("""
SELECT array_agg(pga.attname) as dist_attr
FROM (
SELECT gdp.localoid,
CASE
WHEN ( ARRAY_UPPER(gdp.attrnums, 1) > 0 ) THEN
UNNEST(gdp.attrnums)
ELSE NULL
END AS attnum
FROM gp_distribution_policy gdp
) AS distkey
INNER JOIN pg_class AS pgc
ON distkey.localoid = pgc.oid AND pgc.relname = '{table_name}'
INNER JOIN pg_namespace pgn
ON pgc.relnamespace = pgn.oid AND pgn.nspname = '{schema_name}'
LEFT OUTER JOIN pg_attribute pga
ON distkey.attnum = pga.attnum AND distkey.localoid = pga.attrelid
""".format(table_name=table_name, schema_name=schema_name))[0]["dist_attr"]
if len(dist_attr) > 0:
dist_str = 'distributed by (' + ','.join(['"%s"' % i
for i in dist_attr
if i is not None]) + ')'
else:
dist_str = 'distributed randomly'
else:
dist_str = plpy.execute("""
SELECT pg_catalog.pg_get_table_distributedby(pgc.oid) as distributedby
FROM pg_class AS pgc
INNER JOIN pg_namespace pgn ON pgc.relnamespace = pgn.oid
WHERE pgc.relname = '{table_name}' AND pgn.nspname = '{schema_name}'
""".format(table_name=table_name, schema_name=schema_name))[0]["distributedby"]
return dist_str
# ------------------------------------------------------------------------------
def num_features(source_table, independent_varname):
return plpy.execute("SELECT array_upper({0}, 1) AS dim "
"FROM {1} LIMIT 1"
.format(independent_varname,
source_table))[0]['dim']
# ------------------------------------------------------------------------------
def num_samples(source_table):
return plpy.execute("SELECT count(*) AS n FROM {0}"
.format(source_table))[0]['n']
# ------------------------------------------------------------------------------
def unique_string(desp='', **kwargs):
"""
    Generate random temporary names for temp tables and other objects.
It has a SQL interface so both SQL and Python functions can call it.
"""
r1 = random.randint(1, 100000000)
r2 = int(time.time())
r3 = int(time.time()) % random.randint(1, 100000000)
u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__"
return u_string
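# Illustrative output (a sketch; the random and time-based digits differ on
# every call):
#     unique_string('model_')
#     # -> '__madlib_temp_model_42_1700000000_7__'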
# ------------------------------------------------------------------------------
def add_postfix(quoted_string, postfix):
""" Append a string to the end of the table name.
If input table name is quoted by double quotes, make sure the postfix is
inside of the double quotes.
Arguments:
@param quoted_string: str. A string representing a database quoted string
@param postfix: str. A string to add as a suffix to quoted_string.
** This is assumed to not contain any quotes **
"""
quoted_string = quoted_string.strip()
if quoted_string.startswith('"') and quoted_string.endswith('"'):
output_str = quoted_string[:-1] + postfix + '"'
else:
output_str = quoted_string + postfix
return output_str
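# Illustrative behavior of add_postfix:
#     add_postfix('my_table', '_summary')    # -> 'my_table_summary'
#     add_postfix('"my table"', '_summary')  # -> '"my table_summary"'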
# -------------------------------------------------------------------------
NUMERIC = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
'real', 'double precision', 'float', 'serial', 'bigserial'])
INTEGER = set(['smallint', 'integer', 'bigint'])
TEXT = set(['text', 'varchar', 'character varying', 'char', 'character'])
BOOLEAN = set(['boolean'])
INCLUDE_ARRAY = set([unique_string('__include_array__')])
ONLY_ARRAY = set([unique_string('__only_array__')])
ANY_ARRAY = set([unique_string('__any_array__')])
def is_valid_psql_type(arg, valid_types):
""" Verify if argument is a valid type
Args:
@param arg: str. Name of the Postgres type to validate
@param valid_types: set. Set of valid type names to search.
This is typically created using the global names
in this module.
Three non-type flags are provided
(in descending order of precedence):
- ANY_ARRAY: check if arg is any array type
- ONLY_ARRAY: indicates that only array forms
of the valid types should be checked
- INCLUDE_ARRAY: indicates that array and scalar
forms of the valid types should be checked
Examples: 1. valid_types = BOOLEAN | INTEGER | TEXT
2. valid_types = BOOLEAN | INTEGER | ONLY_ARRAY
3. valid_types = NUMERIC | INCLUDE_ARRAY
"""
if not arg or not valid_types:
return False
arg = arg.lower()
if ANY_ARRAY <= valid_types:
return arg.rstrip().endswith('[]')
if ONLY_ARRAY <= valid_types:
return (arg.rstrip().endswith('[]') and arg.rstrip('[] ') in valid_types)
if INCLUDE_ARRAY <= valid_types:
# Remove the [] from end of the arg type
# The single space is needed to ensure trailing white space is stripped
arg = arg.rstrip('[] ')
return (arg in valid_types)
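# Illustrative behavior of is_valid_psql_type:
#     is_valid_psql_type('text', TEXT)                        # -> True
#     is_valid_psql_type('text[]', TEXT)                      # -> False (scalars only)
#     is_valid_psql_type('integer[]', INTEGER | ONLY_ARRAY)   # -> True
#     is_valid_psql_type('integer', NUMERIC | INCLUDE_ARRAY)  # -> True (scalar or array)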
# ------------------------------------------------------------------------------
def is_psql_numeric_type(arg, exclude=None):
"""
Checks if argument is one of the various numeric types in PostgreSQL
Args:
@param arg: string, Type name to check
@param exclude: iterable, List of types to exclude from checking
Returns:
        Boolean. True if 'arg' is one of the numeric types
"""
if exclude is None:
exclude = []
to_check_types = NUMERIC - set(exclude)
return (arg in to_check_types)
# -------------------------------------------------------------------------
def is_psql_int_type(arg, exclude=None):
"""
    Checks if argument is one of the various integer types in PostgreSQL
Args:
@param arg: string, Type name to check
@param exclude: iterable, List of types to exclude from checking
Returns:
        Boolean. True if 'arg' is one of the integer types
"""
if exclude is None:
to_check_types = INTEGER
else:
to_check_types = INTEGER - set(exclude)
return (arg in to_check_types)
# -------------------------------------------------------------------------
def is_psql_char_type(arg, exclude_list=None):
    """
    This function checks if the given arg is one of the predefined postgres
    character types
    :param arg:
    :param exclude_list: Optionally exclude one or more types from the comparison
    :return: True if it is one of the character types, else False.
    """
    if exclude_list is None:
        exclude_list = []
    elif not isinstance(exclude_list, list):
        exclude_list = [exclude_list]
    return arg in TEXT - set(exclude_list)
def is_psql_boolean_type(arg):
"""
This function checks if the given arg is one of type postgres boolean
:param arg:
:return: True if it is boolean, else False.
"""
return arg == 'boolean'
def is_string_formatted_as_array_expression(string_to_match):
"""
Return true if the string is formatted as array[<something>], else false
:param string_to_match:
"""
    matched = re.match(r"(?i)^array\[(.*)\]", string_to_match)
    return bool(matched)
def _string_to_array(s):
"""
Split a string into an array of strings
    Any whitespace around the substrings is removed
Requirement: every individual element in the string
must be a valid Postgres name, which means that if
there are spaces or commas in the element then the
whole element must be quoted by a pair of double
quotes.
Usually this is not a problem. Especially in older
versions of GPDB, when an array is passed from
SQL to Python, it is converted to a string, which
automatically adds double quotes if there are spaces or
commas in the element.
    So use this function only if you are sure that all elements
are valid Postgres names.
"""
elm = []
for m in re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s):
elm.append(m.group(1))
for i in range(len(elm)):
elm[i] = elm[i].strip("\"")
return elm
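# Illustrative behavior (double quotes around an element are stripped):
#     _string_to_array('a, "b c", d')  # -> ['a', 'b c', 'd']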
# ------------------------------------------------------------------------
def _string_to_array_with_quotes(s):
"""
    Same as _string_to_array, except that the double quotes are kept.
"""
elm = []
for m in re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s):
elm.append(m.group(1))
return elm
# ------------------------------------------------------------------------
def get_col_name_type_sql_string(colnames, coltypes):
    if colnames and coltypes and len(colnames) == len(coltypes):
return ','.join(map(' '.join, zip(colnames, coltypes)))
return None
# ------------------------------------------------------------------------
def py_list_to_sql_string(array, array_type=None, long_format=None):
"""Convert a list to SQL array string
    @note: The long format is recommended, with the input values quoted
    appropriately by 'quote_literal'.
    The short format (e.g. '{1,2,3}') does not take care of escaping the
    string values in the array. Hence an input like '{ t$}t, test}' will fail
    due to the unescaped '{'.
    For example, the following is a valid input for the short format:
        array = ["M'M", 'M"M', "M$M,M'M", "M, M", "M@[\}(:*;M", "MM"]
    Note that double quotes and curly brackets need to be escaped in the
    input if used with the short format.
    """
if long_format is None:
if (array_type is not None and
(any(array_type.startswith(i)
for i in ["text", "varchar", "character varying"]))):
long_format = False
else:
long_format = True
if not array_type:
array_type = "double precision[]"
else:
array_type = array_type.strip()
if not array_type.endswith("[]"):
array_type += "[]"
if not array:
return "'{{ }}'::{0}".format(array_type)
else:
quote_delimiter = "$__MADLIB_OUTER__$"
# This is a quote delimiter that can be used in lieu of
# single quotes and allows the use of single quotes in the
# string without escaping
array_str = "ARRAY[ {val} ]" if long_format else "{qd}{{ {val} }}{qd}"
return (array_str + "::{array_type}").format(
val=','.join(map(str, array)),
array_type=array_type,
qd=quote_delimiter)
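# Illustrative behavior of py_list_to_sql_string:
#     py_list_to_sql_string([1, 2, 3])
#     # -> 'ARRAY[ 1,2,3 ]::double precision[]'
#     py_list_to_sql_string(['a', 'b'], array_type='text')
#     # -> '$__MADLIB_OUTER__${ a,b }$__MADLIB_OUTER__$::text[]'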
# ------------------------------------------------------------------------
def create_cols_from_array_sql_string(py_list, sql_array_col, colname,
coltype, has_one_ele,
module_name='Input Error'):
"""
Create SQL string to convert array of elements into multiple columns and corresponding
SQL string of columns for CREATE TABLE.
@args:
@param: py_list, python list, if None, return sql_array_col as colname.
The py_list can at most have one 'None' element that
is converted to sql 'NULL'
@param: sql_array_col, str pointing to a column in table containing
an array.
@param: colname, name of output column (can be treated as prefix
to multiple cols if has_one_ele=False)
@param: coltype, Type of columns to be created
        @param: has_one_ele, bool. If True, assumes sql_array_col has
                an array of exactly one element, which is treated as an
                index into py_list. If False, then a new
                column is created for every element in py_list,
                whose corresponding values are obtained from sql_array_col.
@examples:
1) Input:
py_list = ['cat', 'dog']
sql_array_col = sqlcol
colname = prob
coltype = TEXT
has_one_ele = FALSE
Output:
prob_cat TEXT, prob_dog TEXT
CAST(sqlcol[1] AS TEXT) AS prob_cat, CAST(sqlcol[2] AS TEXT) AS prob_dog
2) Input:
py_list = ['cat', 'dog']
sql_array_col = sqlcol
colname = estimated_pred
coltype = TEXT
has_one_ele = TRUE
Output:
estimated_pred TEXT
(ARRAY['cat','dog'])[sqlcol[1]+1]::TEXT AS estimated_pred
@NOTE:
        If py_list is [None, 'cat', 'dog', 'NULL']:
            then the SQL query string returned would create the following
            column names:
            prob_NULL, prob_cat, prob_dog, and prob_"NULL".
            1. Notice that for None, which represents Postgres' NULL value,
               the column name will be prob_NULL,
            2. and to differentiate it from the column name for the string
               'NULL', the resulting column name will be prob_"NULL".
The weird quoting in this column name is due to calling strip after
quote_ident in the code below.
@returns:
@param, str, that can be used in a SQL query.
@param, str, that can be used in a SQL query.
"""
_assert(sql_array_col, "{0}: sql_array_col should be a valid string.".
format(module_name))
_assert(colname, "{0}: colname should be a valid string.".format(
module_name))
if py_list:
_assert(is_valid_psql_type(coltype, BOOLEAN | NUMERIC | TEXT),
"{0}: Invalid coltype parameter {1}.".format(
module_name, coltype))
_assert(py_list.count(None) <= 1,
"{0}: Input list should contain at most 1 None element.".
format(module_name))
def py_list_str(ele):
"""
A python None is converted to a SQL NULL.
String 'NULL' is converted to SQL 'NULL' string by quoting
it to '"NULL"'. This quoting is necessary for Postgres to
differentiate between NULL and 'NULL' in the SQL query
string returned by create_cols_from_array_sql_string.
"""
if ele is None:
return 'NULL'
        elif isinstance(ele, str) and ele.lower() == 'null':
return '"{0}"'.format(ele)
return ele
py_list = list(map(py_list_str, py_list))
if has_one_ele:
# Query to choose the value in the first element of
# sql_array_col which is the index to access in py_list.
# The value from that corresponding index in py_list is
# the value of colname.
py_list_sql_str = py_list_to_sql_string(py_list, coltype+'[]')
select_clause = "({0})[{1}[1]+1]::{2} AS {3}".format(
py_list_sql_str, sql_array_col, coltype, colname)
create_columns = "{0} {1}".format(
colname, coltype)
else:
# Create as many columns as the length of py_list. The
# colnames are created based on the elements in py_list,
# while the value for these colnames is obtained from
# sql_array_col.
# we cannot call sql quote_ident on the py_list entries because
# aliasing does not support quote_ident. Hence calling our
# python implementation of quote_ident. We must call strip()
# after quote_ident since the resulting SQL query fails otherwise.
select_clause = ', '.join(
['CAST({sql_array_col}[{j}] AS {coltype}) AS "{final_colname}"'.
format(j=i + 1,
final_colname=quote_ident("{0}_{1}".
format(colname, str(suffix))).strip(' "'),
sql_array_col=sql_array_col,
coltype=coltype)
for i, suffix in enumerate(py_list)
])
create_columns = ', '.join(
['"{final_colname}" {coltype}'.
format(final_colname=quote_ident("{0}_{1}".
format(colname, str(suffix))).strip(' "'),
coltype=coltype)
for i, suffix in enumerate(py_list)
])
else:
if has_one_ele:
select_clause = '{0}[1]+1 AS {1}'.format(sql_array_col, colname)
create_columns = '{0} {1}'.format(colname, coltype)
else:
select_clause = '{0} AS {1}'.format(sql_array_col, colname)
create_columns = '{0} {1}'.format(colname, coltype+'[]')
return select_clause, create_columns
# ------------------------------------------------------------------------
def _array_to_string(origin):
"""
Convert an array to string
"""
def _escape(s):
return re.sub(r'"', r'\"', str(s))
return "{" + ",".join(map(_escape, origin)) + "}"
# ------------------------------------------------------------------------
def _cast_if_null(input_str, alias=''):
    if input_str:
        return str(input_str)
    else:
        null_str = "NULL::text"
        return null_str + " as " + alias if alias else null_str
# ------------------------------------------------------------------------
def set_client_min_messages(new_level):
"""
    Set the client_min_messages setting in Postgres, which controls the
    messages sent to the psql client.
    Args:
        @param new_level: string, New level for client_min_messages
    Returns:
        old_msg_level: string, The old client_min_messages level before changing it.
Raise:
ValueError: if the argument new_level is not a valid message level
"""
if new_level.lower() not in ('debug5', 'debug4', 'debug3', 'debug2',
'debug1', 'log', 'notice', 'warning', 'error',
'fatal', 'panic'):
raise ValueError("Not a valid message level sent to client")
old_msg_level = plpy.execute(""" SELECT setting
FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
plpy.execute("SET client_min_messages TO {0}".format(new_level))
return old_msg_level
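# Illustrative usage (a sketch): temporarily silence NOTICE-level chatter
# around noisy SQL, then restore the previous level.
#     old_level = set_client_min_messages('warning')
#     # ... run chatty statements via plpy.execute(...) ...
#     set_client_min_messages(old_level)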
# -------------------------------------------------------------------------
def is_pg_major_version_less_than(schema_madlib, compare_version, **kwargs):
    version = plpy.execute("select version()")[0]["version"]
    regex = re.compile(r'PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
    version = regex.findall(version)
    return len(version) > 0 and int(version[0][0]) < compare_version
# Deal with earlier versions of PG or GPDB
class __mad_version:
def __init__(self):
self.version = plpy.execute("select version()")[0]["version"]
def select_vecfunc(self):
"""
        PG84 and GP40, GP41 do not have good support for
vectors. They convert any vector into a string, surrounded
by { and }. Thus special care is needed for these older
versions of GPDB and PG.
"""
# GPDB 4.0 or 4.1
if self.is_less_than_gp42() or self.is_less_than_pg90():
return self.__extract
else:
return self.__identity
def __extract(self, origin, text=True):
"""
Extract vector elements from a string with {}
as the brackets
"""
if origin is None:
return None
elm = _string_to_array(re.match(r"^\{(.*)\}$", origin).group(1))
if text is False:
for i in range(len(elm)):
elm[i] = float(elm[i])
return elm
def __identity(self, origin, text=True):
return origin
def select_vec_return(self):
"""
        Special care is needed if one needs to return a
        vector from Python to SQL
"""
if self.is_less_than_gp42() or self.is_less_than_pg90():
return self.__condense
else:
return self.__identity
def __condense(self, origin):
"""
        Convert the original vector into a string that some
        older versions of the database can recognize
"""
return _array_to_string(origin)
def select_array_agg(self, schema_madlib):
"""
GPDB < 4.1 and PG < 9.0 do not have support for array_agg,
so use the madlib array_agg for those versions
"""
if self.is_less_than_gp41() or self.is_less_than_pg90():
return "{schema_madlib}.array_agg".format(schema_madlib=schema_madlib)
else:
return "array_agg"
def is_pg(self):
if (re.search(r"PostgreSQL", self.version) and
not re.search(r"Greenplum\s*Database", self.version)):
return True
return False
def is_gp43(self):
if re.search(r"Greenplum\s+Database\s+4\.3", self.version):
return True
return False
def is_less_than_pg90(self):
        regex = re.compile(r'PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
version = regex.findall(self.version)
if len(version) > 0 and self.is_pg() and int(version[0][0]) < 9:
return True
else:
return False
def is_less_than_gp41(self):
        regex = re.compile(r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build', re.IGNORECASE)
version = regex.findall(self.version)
if len(version) > 0 and float(version[0]) < 4.1:
return True
else:
return False
def is_less_than_gp42(self):
        regex = re.compile(r'Greenplum\s+Database\s*([0-9]\.[0-9])[0-9.]+\s+build', re.IGNORECASE)
version = regex.findall(self.version)
if len(version) > 0 and float(version[0]) < 4.2:
return True
else:
return False
def is_pg_version_less_than(self, compare_version):
""" Return True if self is a PostgreSQL database and
self.version is less than compare_version
@param compare_version: str, String form of the comparison version. Expected
format is Semantic Versioning.
examples of versions: 1.0, 2.3, 9.3.5
"""
        regex = re.compile(r'PostgreSQL\s*([0-9.]+)', re.IGNORECASE)
version = regex.findall(self.version)
if len(version) > 0 and self.is_pg():
db_ver = [float(i) for i in version[0].split('.') if i.isdigit()]
cmp_ver = [float(i) for i in compare_version.split('.') if i.isdigit()]
return db_ver < cmp_ver
else:
return False
def is_gp_version_less_than(self, compare_version):
""" Return True if self is a Greenplum database and self.version
is less than compare_version
@param compare_version: str, String form of the comparison version. Expected
format is Semantic Versioning.
examples of versions: 1.0, 2.3, 9.3.5
"""
        regex = re.compile(r'Greenplum\s+Database\s*([0-9.]+)', re.IGNORECASE)
version = regex.findall(self.version)
if version:
db_ver = [float(i) for i in version[0].split('.') if i.isdigit()]
cmp_ver = [float(i) for i in compare_version.split('.') if i.isdigit()]
return db_ver < cmp_ver
else:
return False
def _string_to_sql_array(schema_madlib, s, **kwargs):
"""
Split a string into an array of strings
    Any whitespace around the substrings is removed
Requirement: every individual element in the string
must be a valid Postgres name, which means that if
there are spaces or commas in the element then the
whole element must be quoted by a pair of double
quotes.
Usually this is not a problem. Especially in older
versions of GPDB, when an array is passed from
SQL to Python, it is converted to a string, which
automatically adds double quotes if there are spaces or
commas in the element.
    So use this function only if you are sure that all elements
are valid Postgres names.
"""
# use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0
version_wrapper = __mad_version()
array_to_string = version_wrapper.select_vec_return()
elm = []
for m in re.finditer(r"(\"(\\\"|[^\"])*\"|[^\",\s]+)", s):
elm.append(m.group(1))
for i in range(len(elm)):
elm[i] = elm[i].strip("\"")
return array_to_string(elm)
# ------------------------------------------------------------------------
def current_user():
"""Returns the user name of the current database user."""
return plpy.execute("SELECT current_user")[0]['current_user']
# ------------------------------------------------------------------------
def madlib_version(schema_madlib):
"""Returns the MADlib version string."""
raw = plpy.execute("""
SELECT {schema_madlib}.version()
""".format(**locals()))[0]['version']
return raw.split(',')[0].split(' ')[-1]
# ------------------------------------------------------------------------
def preprocess_keyvalue_params(input_params, split_char='='):
"""
Parse the input_params string and split it using the split_char
    @param input_params: str, Comma-separated list of parameters.
                         Each parameter is a key = value pair, where key is
                         a string of alphanumeric characters (hyphen and
                         colon are also allowed) and value is either a string
                         (optionally double-quoted) or an array enclosed in
                         brackets, braces, or parentheses
@param split_char: str, character that splits the key and value elements.
Default set to '='
"""
re_str = (r"([-:\w]+\s*" + # key is any alphanumeric character
# (including - and :) string
split_char + # key and value are separated by split_char
r"""
\s*([\(\{\[] # value can be string or array
[^\[\]\(\)\{\}]* # if value is array then accept anything inside
[\)\}\]] # and then match closing braces of array
| # either array (above) or string (below)
(?P<quote>\"?)[\w\s\-\%.]+(?P=quote)
# if value is string, it can be alphanumeric
# character string with a decimal dot,
# hyphen, or percent
# optionally quoted by `quote_char`
)
)"""
)
pattern = re.compile(re_str, re.VERBOSE)
return [m.group(1).strip() for m in pattern.finditer(input_params)]
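# Illustrative behavior (commas inside an array value do not split the pair):
#     preprocess_keyvalue_params('max_iter=20, lambda={0.1, 0.2}')
#     # -> ['max_iter=20', 'lambda={0.1, 0.2}']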
# ------------------------------------------------------------------------
def extract_keyvalue_params(input_params,
input_param_types=None,
default_values=None,
split_char='=',
usage_str='',
ignore_invalid=False,
allow_duplicates=True,
lower_case_names=True):
""" Extract key value pairs from input parameters or set the default values
Args:
@param input_params: string, Format of
'key1=value1, key2=value2,...', assuming default split_char.
The order does not matter. If a parameter is missing, then
the default value is used. If input_params is None or '',
then all default values are returned. This function also
validates the values of these parameters.
@param input_param_types: dict, The type of each allowed parameter
name. Currently supports one of
(int, float, str, list)
@param default_values: dict, Default values for each allowed parameter.
@param split_char: str, The character used to split key and value.
Default set to '='
@param usage_str: str, An optional usage string to print with error message.
        @param ignore_invalid: bool, If True, an invalid param input is silently ignored
@param allow_duplicates: bool, Allow repeat of a 'key' in input_params,
                                 where the last occurrence will be reflected
in the output. If False, then a ValueError is
raised.
@param lower_case_names: bool, Convert parameter names to lower case.
Returns:
Dict. Dictionary of input parameter values with key as parameter name
and value as the parameter value
Throws:
plpy.error - If the parameter is unsupported or the value is
not valid.
"""
if not input_params:
return default_values if default_values is not None else {}
if default_values:
parameter_dict = default_values
else:
parameter_dict = {}
seen_params = set()
for s in preprocess_keyvalue_params(input_params, split_char=split_char):
items = split_quoted_delimited_str(s, delimiter=split_char)
if (len(items) != 2):
raise KeyError("Input parameter list has incorrect format "
"{0}".format(usage_str))
param_name = items[0].strip(" \"")
if lower_case_names:
param_name = param_name.lower()
param_value = items[1].strip()
if not allow_duplicates and param_name in seen_params:
raise ValueError("Invalid input: {0} duplicated in the param list".
format(param_name))
if not param_name or param_name in ('none', 'null'):
plpy.error("Invalid input param name: {0} \n"
"{1}".format(param_name, usage_str))
if input_param_types:
try:
param_type = input_param_types[param_name]
except KeyError:
if not ignore_invalid:
raise KeyError("Invalid input: {0} is not a valid parameter "
"{1}".format(param_name, usage_str))
else:
continue
try:
if param_type == bool: # bool is not subclassable
# True values are y, yes, t, true, on and 1;
# False values are n, no, f, false, off and 0.
# Raises ValueError if anything else.
parameter_dict[param_name] = bool(strtobool(param_value))
elif param_type in (int, str, float):
parameter_dict[param_name] = param_type(param_value)
                elif issubclass(param_type, collections.abc.Iterable):
parameter_dict[param_name] = split_quoted_delimited_str(
param_value.strip('[](){} '))
else:
raise TypeError("Invalid input: {0} has unsupported type "
"{1}".format(param_name, usage_str))
except ValueError:
raise ValueError("Invalid input: {0} must be {1} \n"
"{2}".format(param_name, param_type, usage_str))
else:
# if input parameter types not provided then just return all as string
parameter_dict[param_name] = str(param_value)
seen_params.add(param_name)
return parameter_dict
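# Illustrative usage (a sketch; the parameter names are hypothetical):
#     extract_keyvalue_params(
#         'max_iter=20, tolerance=0.001, metric=dist_norm2',
#         input_param_types={'max_iter': int, 'tolerance': float, 'metric': str})
#     # -> {'max_iter': 20, 'tolerance': 0.001, 'metric': 'dist_norm2'}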
# -------------------------------------------------------------------------
def split_quoted_delimited_str(input_str, delimiter=',', quote='"'):
""" Parse a delimited-string to return a list of individual tokens taking
quotes into account.
Args:
@param input_str: str, Delimited input string
@param delimiter: str, The field delimiter character that separates
the tokens in input_str. Default = ','
@param quote: str, One-character string used to quote fields containing
special characters, such as the field delimiter
Default = '"'
Returns:
List. List of delimited strings.
"""
if not input_str or not delimiter or not quote:
return []
try:
        delimiter_reg = re.compile(r'((?:[^\{d}\{q}]|\{q}[^\{q}]*\{q})+)'.
                                   format(d=delimiter, q=quote))
return [i.strip() for i in
delimiter_reg.split(input_str.strip())[1::2]]
except Exception as e:
plpy.warning(str(e))
raise ValueError("Invalid string input for splitting")
# ------------------------------------------------------------------------------
def strip_end_quotes(input_str, quote='"'):
""" Remove the quote character from the start and end if they are present
        (at both ends). Whitespace at the start and end is not ignored.
Args:
@param input_str
@param quote
Returns:
str. Original string without the quotes at start and end
"""
if not input_str or not quote:
return input_str
if not isinstance(input_str, str):
return input_str
if input_str.startswith(quote) and input_str.endswith(quote):
return input_str[1:-1]
else:
return input_str
# ------------------------------------------------------------------------------
def _grp_null_checks(grp_list):
"""
Helper function for generating NULL checks for grouping columns
to be used within a WHERE clause
Args:
@param grp_list The list of grouping columns
"""
return ' AND '.join([" {i} IS NOT NULL ".format(**locals())
for i in grp_list])
# ------------------------------------------------------------------------------
def _check_groups(tbl1, tbl2, grp_list):
"""
Helper function for joining tables with groups.
Args:
@param tbl1 Name of the first table
@param tbl2 Name of the second table
@param grp_list The list of grouping columns
"""
return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(**locals())
for i in grp_list])
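# Illustrative output of the two helpers above (note the double spaces that
# the joins produce):
#     _grp_null_checks(['g1', 'g2'])     # -> ' g1 IS NOT NULL  AND  g2 IS NOT NULL '
#     _check_groups('t1', 't2', ['g1'])  # -> ' t1.g1 = t2.g1 '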
# ------------------------------------------------------------------------------
def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
filter_cols_list):
"""
This function returns a subquery string with columns in the filter_cols_list
that appear in include_from_table but NOT IN exclude_from_table.
:param include_from_table: table from which cols in filter_cols_list should
be included
:param exclude_from_table: table from which cols in filter_cols_list should
be excluded
:param filter_cols_list: list of column names
:return: query string with relevant column names
"""
included_cols = get_table_qualified_col_str(include_from_table, filter_cols_list)
    cols = ', '.join(filter_cols_list)
return """({included_cols}) NOT IN
(SELECT {cols}
FROM {exclude_from_table})
""".format(**locals())
# ------------------------------------------------------------------------------
def get_table_qualified_col_str(tbl_name, col_list):
"""
Helper function for selecting columns of a table
Args:
        @param tbl_name Name of the table
        @param col_list The list of column names
"""
return ' , '.join([" {tbl_name}.{col} ".format(**locals())
for col in col_list])
# ------------------------------------------------------------------------------
def get_grouping_col_str(schema_madlib, module_name, reserved_cols,
source_table, grouping_col):
if grouping_col and grouping_col.lower() != 'null':
grouping_col_array = _string_to_array_with_quotes(grouping_col)
cols_in_tbl_valid(source_table, grouping_col_array, module_name)
does_exclude_reserved(grouping_col_array, reserved_cols)
grp_array_w_cast = explicit_bool_to_text(source_table,
grouping_col_array,
schema_madlib)
grouping_str = ', '.join(i + "::text" for i in grp_array_w_cast)
else:
grouping_str = "Null"
grouping_col = None
return grouping_str, grouping_col
# ------------------------------------------------------------------------------
def collate_plpy_result(plpy_result_rows):
if not plpy_result_rows:
return {}
else:
all_keys = plpy_result_rows[0].keys()
result = collections.defaultdict(list)
for each_row in plpy_result_rows:
for each_key in all_keys:
result[each_key].append(each_row[each_key])
return result
# ------------------------------------------------------------------------------
def validate_module_input_params(source_table, output_table, independent_varname,
dependent_varname, module_name,
grouping_cols=None, other_output_tables=None):
"""
    This function validates the common input parameters of supervised
    learning algorithms, e.g. linear regression, mlp, etc., since all
    of them need to validate the following parameters.
:param source_table: This table should exist and not be empty
:param output_table: This table should not exist
:param dependent_varname: This should be a valid expression in the source
table
:param independent_varname: This should be a valid expression in the source
table
:param module_name: Name of the module to be printed with the error messages
:param other_output_tables: List of additional output tables to validate.
These tables should not exist
"""
input_tbl_valid(source_table, module_name)
output_tbl_valid(output_table, module_name)
if other_output_tables:
for tbl in other_output_tables:
output_tbl_valid(tbl, module_name)
_assert(is_var_valid(source_table, independent_varname),
"{module_name} error: invalid independent_varname "
"('{independent_varname}') for source_table "
"({source_table})!".format(module_name=module_name,
independent_varname=independent_varname,
source_table=source_table))
_assert(is_var_valid(source_table, dependent_varname),
"{module_name} error: invalid dependent_varname "
"('{dependent_varname}') for source_table "
"({source_table})!".format(module_name=module_name,
dependent_varname=dependent_varname,
source_table=source_table))
if grouping_cols:
_assert(is_var_valid(source_table, grouping_cols),
"{module_name} error: invalid grouping_cols "
"('{grouping_cols}') for source_table "
"({source_table})!".format(module_name=module_name,
grouping_cols=grouping_cols,
source_table=source_table))
# ------------------------------------------------------------------------
def create_table_drop_cols(source_table, out_table, cols_to_drop, **kwargs):
""" Create copy of table while dropping some of the columns
Args:
@param source_table str. Name of the source table
@param out_table str. Name of the output table
@param cols_to_drop str. Comma-separated list of columns to drop
"""
input_tbl_valid(source_table, 'Utilities')
output_tbl_valid(out_table, 'Utilities')
_assert(cols_to_drop and cols_to_drop.strip(),
"Utilities error: cols_to_drop cannot be empty or NULL")
source_table_cols = get_cols(source_table)
cols_to_drop_list = split_quoted_delimited_str(cols_to_drop)
cols_not_in_source = set(cols_to_drop_list) - set(source_table_cols)
_assert(not cols_not_in_source,
"Utilities error: Some column(s) in cols_to_drop are not present "
"in source table")
cols_to_retain = [c for c in source_table_cols if c not in cols_to_drop_list]
_assert(cols_to_retain,
"Utilities error: No valid columns for the output table")
plpy.execute("""
CREATE TABLE {out_table} AS
SELECT {cols}
FROM {source_table}
""".format(cols=', '.join(cols_to_retain),
out_table=out_table,
source_table=source_table))
# ------------------------------------------------------------------------------
def rotate(l, n):
"""Summary
Rotate the list l to right(the index increasing direction) for n elements.
Args:
l (list): The input list to rotate
n (integer): The number of elements to rotate
Returns:
list: The rotated list
"""
return l[-n:] + l[:-n]
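# Illustrative behavior:
#     rotate([1, 2, 3, 4], 1)  # -> [4, 1, 2, 3]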
# ------------------------------------------------------------------------------
def rename_table(schema_madlib, orig_name, new_name):
"""
    Renames a possibly schema-qualified table name to a new schema-qualified
    name, ensuring the schema qualification is changed appropriately
Args:
@param orig_name: string, Original name of the table
(must be schema qualified if table schema is not in search path)
@param new_name: string, New name of the table
(can be schema qualified. If it is not then the original
schema is maintained)
Returns:
String. The new table name qualified with the schema name
"""
new_names_split = new_name.split(".")
if len(new_names_split) > 2:
raise AssertionError("Invalid table name")
new_table_name = new_names_split[-1]
new_table_schema = new_names_split[0] if len(new_names_split) > 1 else None
orig_names_split = orig_name.split(".")
if len(orig_names_split) > 2:
raise AssertionError("Invalid table name")
if len(orig_names_split) > 1:
orig_table_schema = orig_names_split[0]
else:
# we need to get the schema name of the original table if we are
# to change the schema of the new table. This is to ensure that we
# change the schema of the correct table in case there are multiple
# tables with the same new name.
orig_table_schema = get_first_schema(orig_name)
if orig_table_schema is None:
raise AssertionError("Relation {0} not found during rename".
format(orig_name))
return __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
new_table_schema, new_table_name)
# ------------------------------------------------------------------------------
def __do_rename_and_get_new_name(orig_name, new_name, orig_table_schema,
new_table_schema, new_table_name):
"""
Internal private function to perform the rename operation after all the
validation checks
"""
"""
CASE 1
    If the output table's schema is pg_temp, we cannot alter table schemas
    from/to temp schemas. If it looks like a temp schema, we stay safe and
    just use create/drop
Test cases
foo.bar to pg_temp.bar
foo.bar to pg_temp.bar2
foo to pg_temp.bar
pg_temp.foo to pg_temp.bar
"""
if new_table_schema and 'pg_temp' in new_table_schema:
"""
        If both new_table_schema and orig_table_schema are pg_temp schemas,
        just run an ALTER statement instead of CTAS. Without this, pca
        dev-check fails on gpdb5/6 (but not on pg)
"""
if new_table_schema != orig_table_schema:
plpy.info("""CREATE TABLE {new_name} AS SELECT * FROM {orig_name}"""
.format(**locals()))
plpy.execute("""CREATE TABLE {new_name} AS SELECT * FROM {orig_name}"""
.format(**locals()))
drop_tables([orig_name])
return new_name
else:
plpy.execute("ALTER TABLE {orig_name} RENAME TO {new_table_name}".
format(**locals()))
return new_name
"""
CASE 2
Do direct rename if the new table does not have an output schema or
if the new table schema is the same as the original table schema
Test Cases
rename foo to bar
rename foo.bar to foo.bar2
rename foo.bar to bar2
"""
if not new_table_schema or new_table_schema == orig_table_schema:
plpy.execute("ALTER TABLE {orig_name} RENAME TO {new_table_name}".
format(**locals()))
return orig_table_schema + "." + new_table_name
"""
CASE 3
output table is schema qualified
1. rename the original table to an interim temp name
2. set the new schema on that interim table
3. rename interim table to the new table name
Test cases
foo.bar to foo2.bar2
foo.bar to foo2.bar
"""
interim_temp_name = unique_string("rename_table_helper")
plpy.execute(
"ALTER TABLE {orig_name} RENAME to {interim_temp_name}".format(
**locals()))
plpy.execute(
"""ALTER TABLE {interim_temp_name} SET SCHEMA {new_table_schema}""".format(
**locals()))
plpy.execute(
"""ALTER TABLE {new_table_schema}.{interim_temp_name} RENAME to {new_table_name}"""
.format(**locals()))
return new_name
# ------------------------------------------------------------------------------
def get_psql_type(py_value):
    """ Return the name of the PostgreSQL type corresponding to the type of
        the given Python value. bool is checked before int since bool is a
        subclass of int in Python.
    """
    if isinstance(py_value, bool):
        return 'boolean'
    elif isinstance(py_value, int):
        return 'integer'
    elif isinstance(py_value, float):
        return 'double precision'
    elif isinstance(py_value, str):
        return 'varchar'
    else:
        plpy.error("Cannot determine the type of {0}".format(py_value))