blob: 20a11c2a6e5755a10c7905e9cc71a5214fbea770 [file] [log] [blame]
from collections import Iterable
import plpy
import re
import string
from types import StringTypes
# Postgresql naming restrictions
"""
Both keywords and identifier names in PostgreSQL have a maximum length limit of
63 bytes by default (NAMEDATALEN - 1). Parsed keywords or identifiers over that length limit are
automatically truncated. Identifiers may begin with any letter (a through z), or
with an underscore, and may then be followed by letters, numbers (0 through 9),
or underscores. While keywords are not permitted to start or end with an
underscore, identifier names are permitted to do so. Neither keywords nor
identifiers should ever begin with a number.
Quotes are required in only two cases: when a
database object's identifier is identical to a keyword, or when the identifier
has at least one capitalized letter in its name. In either of these
circumstances, remember to quote the identifier both when creating the
object, as well as in any subsequent references to that object (e.g., in SELECT,
DELETE, or UPDATE statements).
"""
m4_changequote(`<!', `!>')
def unquote_ident(input_str):
    """
    Strip the enclosing double quotes from an identifier.

    A double-quoted identifier loses its leading and trailing quote and every
    doubled double quote inside it is collapsed into a single one. An
    identifier that is not quoted is returned lower-cased.

    Args:
        @param input_str: Identifier, optionally double-quoted. Surrounding
                          whitespace is ignored.
    Returns:
        String. A falsy input (None or empty string) is returned unchanged.
    """
    if not input_str:
        return input_str

    ident = input_str.strip()
    is_quoted = ident.startswith('"') and ident.endswith('"')
    if not is_quoted:
        return ident.lower()

    # inside a quoted identifier, "" is the escaped form of a single "
    return ident[1:-1].replace('""', '"')
# -------------------------------------------------------------------------
def quote_ident(input_str):
    """
    Returns input_str with quotes added per Postgres identifier rules.

    This function is available via plpy.quote_ident in PG > 9.1 and GPDB >= 6.0.
    We add this function for compatibility with older versions of Greenplum.

    In the fallback, the string is returned as is only if it starts with a
    character in [a-z_] and every remaining character is in [a-z0-9_].
    Otherwise a double quote is added in front and back of the string and
    every double quote in the original string is preceded by another double
    quote. An identifier that begins with a digit is therefore quoted, since
    an unquoted identifier may not start with a number.

    Note: we don't check for SQL keywords. plpy.quote_ident is a better
    alternative when available.

    Args:
        @param input_str
    Returns:
        String. In the fallback, None is returned for a falsy input (None or
        empty string).
    """
    try:
        return plpy.quote_ident(input_str)
    except AttributeError:
        # plpy.quote_ident does not exist on this platform: use the fallback
        pass

    if not input_str:
        return None

    ident = input_str.strip()
    if not ident:
        # nothing left after trimming whitespace; there is nothing to quote
        return ident

    safe_first = string.ascii_lowercase + '_'
    safe_rest = safe_first + string.digits
    if ident[0] in safe_first and all(ch in safe_rest for ch in ident):
        return ident

    # each embedded double quote is escaped by doubling it
    return '"' + ident.replace('"', '""') + '"'
# -------------------------------------------------------------------------
def _get_table_schema_names(tbl, only_first_schema=False):
    """
    Split a table name into a schema list string and an unquoted table name.

    The schemas are rendered as the text of a parenthesised SQL list, e.g.
    ('schema_a','schema_b'). A schema-qualified input yields a list holding
    just that schema. An unqualified input yields either the schemas returned
    by current_schemas(True) or, when only_first_schema is set, the single
    schema returned by current_schema().

    Note: schema and table names may be double-quoted; they are unquoted with
    unquote_ident before being returned.

    Args:
        @param tbl Input table name (could be schema qualified)
        @param only_first_schema If True, an unqualified name is resolved
                                 against current_schema() only
    Returns:
        Tuple pair (schema list string, table name), each element a string
    """
    if tbl is None or tbl.strip(' \'').lower() in ('null', ''):
        plpy.error('Input error: Table name (NULL) is invalid')

    parts = tbl.split(".")
    n_parts = len(parts)
    if n_parts == 2:
        schema_name, table_name = parts
        schema_str = "('" + unquote_ident(schema_name) + "')"
        table = unquote_ident(table_name)
    elif n_parts == 1:
        if only_first_schema:
            # restricted to the first schema in search path
            first_schema = plpy.execute(
                "SELECT current_schema() AS cs")[0]["cs"]
            all_schemas = [first_schema]
        else:
            all_schemas = plpy.execute(
                "SELECT current_schemas(True) AS cs")[0]["cs"]
        joined_schemas = "','".join(unquote_ident(s) for s in all_schemas)
        schema_str = "('{0}')".format(joined_schemas)
        table = unquote_ident(parts[0])
    else:
        plpy.error("Incorrect table name ({0}) provided! Table name "
                   "should be of the form: <schema name>.<table name>".format(tbl))
    return (schema_str.strip(), table.strip())
# -------------------------------------------------------------------------
def table_exists(tbl, only_first_schema=False):
    """
    Returns True if the table exists in the database.

    The name is resolved with _get_table_schema_names and then looked up in
    pg_class joined with pg_namespace. A schema list of exactly ('pg_temp')
    is matched with LIKE 'pg_temp%' instead of an IN list.

    Args:
        @param tbl Name of the table. Can be schema qualified.
        @param only_first_schema Passed through to _get_table_schema_names
    Returns:
        Boolean. False when either the schema list or the table name is empty.
    """
    schema_str, table = _get_table_schema_names(tbl, only_first_schema)
    if not (schema_str and table):
        return False

    if schema_str == "('pg_temp')":
        schema_expr = "LIKE 'pg_temp%'"
    else:
        schema_expr = 'IN {0}'.format(schema_str)

    lookup_sql = """
SELECT EXISTS(
SELECT 1
FROM pg_class, pg_namespace
WHERE relnamespace = pg_namespace.oid
AND nspname {schema_expr}
AND (relname = '{table}')
) AS table_exists
""".format(schema_expr=schema_expr, table=table)
    return bool(plpy.execute(lookup_sql)[0]['table_exists'])
# -------------------------------------------------------------------------
def get_first_schema(table_name):
    """
    Return first schema name from search path that contains given table.

    The schemas are taken from current_schemas(True), in the order returned
    by that call.

    Args:
        @param table_name: String, table name to search. If table name is
                            schema-qualified then the schema name is returned
                            directly (unquoted).
    Returns:
        String, schema name if a schema containing the table is found.
        None, if none of the schemas in search path contain the table.
    Throws:
        TypeError if the name has more than one '.' separator
    """
    names = table_name.split(".")
    if len(names) > 2:
        raise TypeError("Incorrect table name ({0}) provided! Table name should be "
                        "of the form: <schema name>.<table name>".format(table_name))
    if len(names) == 2:
        return unquote_ident(names[0])

    # create a list of schema names in search path
    current_schemas = plpy.execute(
        "SELECT current_schemas(True) AS cs")[0]["cs"]
    if not current_schemas:
        return None

    # information_schema.tables stores names without identifier quotes, so
    # compare against the unquoted name; single quotes are doubled because
    # the value is embedded in a SQL string literal
    table_literal = unquote_ident(names[0]).replace("'", "''")

    # get all schemas that contain a table with this name
    schemas_w_table = plpy.execute(
        """SELECT array_agg(table_schema::text) AS schemas
FROM information_schema.tables
WHERE table_name='{table_name}'""".
        format(table_name=table_literal))[0]["schemas"]
    if not schemas_w_table:
        return None

    for each_schema in current_schemas:
        # get the first schema in search path that contains the table
        if each_schema in schemas_w_table:
            return each_schema

    # None of the schemas in search path have the table
    return None
# -------------------------------------------------------------------------
def drop_tables(table_list):
    """
    Issue a single DROP TABLE IF EXISTS for every name in table_list.

    Nothing is executed when the joined list of names is empty.

    Args:
        @param table_list Iterable of table name strings
    """
    tables = ', '.join(table_list)
    if not tables:
        return
    plpy.execute("DROP TABLE IF EXISTS {0}".format(tables))
def table_is_empty(tbl, filter_str=None):
    """
    Returns True if the input table has no rows.

    Args:
        @param tbl Name of the table to inspect
        @param filter_str Optional condition; when given, only rows that
                          satisfy it are considered
    Returns:
        Boolean
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')

    if filter_str:
        where_clause = "WHERE " + filter_str
    else:
        where_clause = ''

    count_sql = """
SELECT count(*)
FROM (
-- create a sub-query to get only single tuple
SELECT *
FROM {0}
{1}
LIMIT 1
) q
""".format(tbl, where_clause)
    n_valid_tuples = plpy.execute(count_sql)[0]["count"]
    return not (n_valid_tuples > 0)
# -------------------------------------------------------------------------
def get_cols(tbl):
    """
    Get all column names in a table.

    The table name is cast to regclass and the names are read from
    pg_attribute, ordered by attnum and skipping dropped and system columns.
    Each name is passed through the SQL quote_ident function.

    Args:
        @param tbl Name of the table. Can be schema qualified.
    Returns:
        The 'cols' array produced by the query
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')

    cols_query = """SELECT array_agg(quote_ident(attname)::varchar
ORDER BY attnum) AS cols
FROM pg_attribute
WHERE attrelid = '{tbl}'::regclass
AND NOT attisdropped
AND attnum > 0""".format(tbl=tbl)
    first_row = plpy.execute(cols_query)[0]
    return first_row["cols"]
# -------------------------------------------------------------------------
def get_cols_and_types(tbl):
    """
    Get the column names and data types of a table, in column order.

    The table is first resolved to its exact schema and relation name through
    a regclass cast; the columns are then read from
    information_schema.columns ordered by ordinal_position.

    Args:
        @param tbl: string, Name of the table to search in
    Returns:
        List of (col, type) pairs. A list is used instead of a dictionary so
        that the column order of 'tbl' is kept. The type is the data_type
        value reported by information_schema.columns.
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')

    # determine the exact table_schema and table_name
    # in case that source_table only contains table_name
    resolve_sql = """
SELECT
nspname AS table_schema,
relname AS table_name
FROM
pg_class AS c,
pg_namespace AS nsp
WHERE
c.oid = '{tbl}'::regclass::oid AND
c.relnamespace = nsp.oid
""".format(tbl=tbl)
    resolved = plpy.execute(resolve_sql)
    schema = resolved[0]['table_schema']
    table = resolved[0]['table_name']

    columns_sql = """SELECT array_agg(quote_ident(column_name)::varchar ORDER BY ordinal_position) AS cols,
array_agg(data_type::varchar ORDER BY ordinal_position) AS types
FROM information_schema.columns
WHERE table_name = '{table}'
AND table_schema = '{schema}'
""".format(table=table, schema=schema)
    result = plpy.execute(columns_sql)[0]
    return list(zip(result['cols'], result['types']))
# -------------------------------------------------------------------------
def get_col_value_and_type(table_name, column_name):
    """
    Return the value and type of a column from a table.

    The value is taken from the first row returned by selecting the column
    from the table; the type comes from get_expr_type.

    Args:
        @param table_name
        @param column_name
    Returns:
        Pair (column_value, column_type)
    """
    if table_name is None or table_name.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid.')
    if not is_var_valid(table_name, column_name):
        plpy.error('Input error: Column name is invalid.')

    value_sql = "SELECT {0} AS value FROM {1}".format(column_name, table_name)
    rows = plpy.execute(value_sql)
    value = rows[0]['value']
    col_type = get_expr_type(column_name, table_name)
    return value, col_type
# -------------------------------------------------------------------------
def get_expr_type(expressions, tbl):
    """ Return the type of one or more expressions run on a given table

    Each expression is wrapped in pg_typeof() and all of them are evaluated
    in a single query that reads one row from 'tbl'.

    Args:
        @param expressions: A single expression (string or other
                            non-iterable) or an iterable of expressions
        @param tbl: Name of the table the expressions are evaluated on
    Returns:
        str if a single expression was given, else the array of types
        returned by the query (one entry per expression)
    """
    # FIXME: Below transformation exist to ensure backwards compatibility
    # Remove this when all callers have been modified to pass an Iterable
    # 'expressions'
    input_was_scalar = (isinstance(expressions, StringTypes) or
                        not isinstance(expressions, Iterable))
    if input_was_scalar:
        expressions = [expressions]

    typeof_list = ','.join("pg_typeof({0})".format(e) for e in expressions)
    type_array = "ARRAY[{0}]".format(typeof_list)
    expr_types = plpy.execute("""
SELECT {0} as all_types
FROM {1}
LIMIT 1
""".format(type_array, tbl))
    if not expr_types:
        plpy.error("Unable to get type of expression ({0}). "
                   "Table {1} may not contain any valid tuples".
                   format(expressions, tbl))

    all_types = expr_types[0]["all_types"]
    # output should be same form as input
    return all_types[0] if input_was_scalar else all_types
# # -------------------------------------------------------------------------
def columns_exist_in_table(tbl, cols, schema_madlib="madlib"):
    """
    Does each column exist in the table?

    Names are compared after unquote_ident has been applied to both the
    requested names and the names returned by get_cols.

    Args:
        @param tbl Name of source table
        @param cols Iterable list of column names
        @param schema_madlib Not used by this function
    Returns:
        True if all columns in 'cols' exist in source table else False.
        A None or empty column name counts as not existing.
    """
    existing_cols = set(unquote_ident(name) for name in get_cols(tbl))
    return all(col and unquote_ident(col) in existing_cols for col in cols)
# -------------------------------------------------------------------------
def columns_missing_from_table(tbl, cols):
    """ Get which columns are not present in a given table

    Args:
        @param tbl Name of source table
        @param cols Iterable containing column names
    Returns:
        List of the entries of 'cols' that are invalid (None or empty) or
        not present in the table, in input order. Empty list if 'cols' is
        empty or None.
    """
    if not cols:
        return []

    existing_cols = set(unquote_ident(name) for name in get_cols(tbl))
    missing = []
    for col in cols:
        # an invalid name (None or empty) is reported as missing as well
        if not col or unquote_ident(col) not in existing_cols:
            missing.append(col)
    return missing
# -------------------------------------------------------------------------
def is_col_array(tbl, col):
    """
    Return True if the column is of an array datatype

    The formatted type of the column is read from pg_attribute and checked
    for the '[]' marker.

    Args:
        @param tbl Name of the table to search. This can be schema qualified.
        @param col Name of the column to check datatype of (may be quoted)
    Returns:
        Boolean
    Throws:
        plpy.error if the table or column name is empty, or if the column is
        not found in the table
    """
    if not tbl:
        plpy.error("Input error: Invalid table {0}".format(tbl))
    if not col:
        plpy.error("Input error: Invalid column name {0}".format(col))

    col = unquote_ident(col)
    type_sql = """
SELECT format_type(atttypid, atttypmod) AS data_type
FROM pg_attribute
WHERE attrelid = '{tbl}'::regclass
AND NOT attisdropped
AND attnum > 0
AND attname = '{col}'
""".format(tbl=tbl, col=col)
    data_type_list = plpy.execute(type_sql)
    if not data_type_list:
        plpy.error("Column {0} not found in table {1}".format(col, tbl))
    else:
        return any('[]' in row["data_type"] for row in data_type_list)
# -------------------------------------------------------------------------
def scalar_col_has_no_null(tbl, col):
    """
    Return True if a scalar column has no NULL values.

    Args:
        @param tbl Name of the table
        @param col Column name or expression to test
    Returns:
        Boolean
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    null_count_sql = """SELECT count(*)
FROM {tbl}
WHERE ({col}) IS NULL
""".format(col=col, tbl=tbl)
    n_null_rows = plpy.execute(null_count_sql)[0]["count"]
    return n_null_rows == 0
# -------------------------------------------------------------------------
def array_col_dimension(tbl, col, dim=1):
    """
    Largest upper bound of an array column along one dimension.

    Args:
        @param tbl Name of the table
        @param col Array column name or expression
        @param dim Array dimension passed to array_upper (default 1)
    Returns:
        Result of max(array_upper(col, dim)) over the table
    """
    if tbl is None:
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None:
        plpy.error('Input error: Column name is invalid')

    bound_sql = """
SELECT max(array_upper({col}, {dim})) AS dim
FROM {tbl}
""".format(col=col, tbl=tbl, dim=dim)
    max_bound = plpy.execute(bound_sql)[0]["dim"]
    return max_bound
# ------------------------------------------------------------------------
def array_col_has_same_dimension(tbl, col):
    """
    Do all arrays of an array column have the same length?

    Compares the smallest and the largest first-dimension upper bound found
    in the column.

    Args:
        @param tbl Name of the table
        @param col Array column name or expression
    Returns:
        Boolean
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    bounds_sql = """
SELECT min(array_upper({col}, 1)) AS min_dim,
max(array_upper({col}, 1)) AS max_dim
FROM {tbl}
""".format(col=col, tbl=tbl)
    bounds = plpy.execute(bounds_sql)[0]
    return bounds['max_dim'] == bounds['min_dim']
# ------------------------------------------------------------------------
def explicit_bool_to_text(tbl, cols, schema_madlib):
    """
    Patch madlib.bool_to_text for columns that are of type boolean.

    Args:
        @param tbl Name of the table the entries of 'cols' are evaluated on
        @param cols Column names/expressions to inspect
        @param schema_madlib Schema name used to qualify bool_to_text
    Returns:
        List with one entry per element of 'cols': boolean entries are
        wrapped as <schema_madlib>.bool_to_text(<col>), all others are kept
        as given. When the build-time conditional below is active, 'cols' is
        returned unchanged instead.
    """
    # Build-time switch resolved by the m4 pre-processor: if the symbol is
    # defined the line below becomes 'return cols', otherwise it is empty.
    m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!> , <!return cols!> , <!!>)
    patched = []
    # one type per entry of 'cols', in the same order
    col_types = get_expr_type(cols, tbl)
    for col, col_type in zip(cols, col_types):
        if col_type == 'boolean':
            patched.append("{0}.bool_to_text({1})".format(schema_madlib, col))
        else:
            patched.append(col)
    return patched
# -------------------------------------------------------------------------
def array_col_has_no_null(tbl, col):
    """
    Return True if an array column has no NULL values.

    For every position up to the largest first-dimension upper bound, the
    number of non-NULL elements is compared with the number of rows. A row
    whose array is NULL, empty or shorter than the longest array therefore
    makes the result False.

    Args:
        @param tbl Name of the table
        @param col Array column name or expression
    Returns:
        Boolean. True for a table without rows.
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    row_len = plpy.execute("SELECT count(*) from {tbl}".
                           format(tbl=tbl))[0]["count"]
    dim = plpy.execute("""
SELECT max(array_upper({col}, 1)) AS dim
FROM {tbl}
""".format(col=col, tbl=tbl))[0]["dim"]

    if dim is None:
        # max() yields NULL when the table has no rows or when no row has a
        # non-empty array. Only the empty table is free of NULLs; otherwise
        # stay consistent with the per-position check below, which rejects
        # rows whose element is missing.
        return row_len == 0

    for i in range(1, dim + 1):
        # count() skips NULLs, so a shortfall means a NULL at position i
        n_non_nulls = plpy.execute("SELECT count(({col})[{i}]) FROM {tbl}".
                                   format(col=col, tbl=tbl, i=i))[0]["count"]
        if row_len != n_non_nulls:
            return False
    return True
# -------------------------------------------------------------------------
def get_col_dimension(tbl, col_name, dim=1):
    """
    Returns upper bound of the requested array dimension, read from a single
    row of the table.

    Example:
        col_name : ARRAY[[1,2,3],[4,5,6]]
            dim=1, return value = 2
            dim=2, return value = 3
        col_name : ARRAY[1,2,3,4,5,6]
            dim=1, return value = 6

    Args:
        @param tbl Name of the table
        @param col_name Array column name or expression
        @param dim Array dimension passed to array_upper (default 1)
    """
    dimension_sql = """
SELECT array_upper({col_name}, {dim}) AS dimension
FROM {tbl} LIMIT 1
""".format(col_name=col_name, dim=dim, tbl=tbl)
    first_row = plpy.execute(dimension_sql)[0]
    return first_row["dimension"]
def _tbl_dimension_rownum(schema_madlib, tbl, col_name, skip_row_count=False):
    """
    Measure the dimension and row number of source data table

    Counting the rows needs a pass over the entire dataset, so the count can
    be skipped with skip_row_count.

    Args:
        @param schema_madlib Schema used to qualify array_contains_null
        @param tbl Name of the source table
        @param col_name Array column name or expression
        @param skip_row_count If True, the row count is not computed
    Returns:
        Pair (dimension, row count); the row count is None when skipped
    """
    dimension = get_col_dimension(tbl, col_name)
    if skip_row_count:
        return dimension, None

    # total row number of data source table
    # The WHERE clause here ignores rows in the table that contain one or more
    # NULLs in the independent variable (x). There is no NULL check made for
    # the dependent variable (y), since one of the hard assumptions of the
    # input data to elastic_net is that the dependent variable cannot be NULL.
    count_sql = """
SELECT COUNT(*) FROM {tbl}
WHERE NOT {schema_madlib}.array_contains_null({col_name})
""".format(tbl=tbl, schema_madlib=schema_madlib, col_name=col_name)
    row_num = plpy.execute(count_sql)[0]["count"]
    return (dimension, row_num)
# ------------------------------------------------------------------------
def is_var_valid(tbl, var, order_by=None):
    """
    Check that the variable(s) can be selected from the table by running a
    query that returns no rows.

    Args:
        @param tbl Name of the table
        @param var Column name(s) or expression(s) to select
        @param order_by Optional ORDER BY expression added to the query
    Returns:
        Boolean. False if 'var' is None or blank, or if the query fails (the
        failure is reported through plpy.warning).
    """
    if var is None or var.strip() == '':
        return False
    try:
        order_by_str = "ORDER BY " + order_by if order_by else ""
        probe_sql = """
SELECT {var}
FROM {tbl}
{order_by_str}
LIMIT 0
""".format(var=var, tbl=tbl, order_by_str=order_by_str)
        plpy.execute(probe_sql)
    except Exception as e:
        plpy.warning(str(e))
        return False
    else:
        return True
# -------------------------------------------------------------------------
def input_tbl_valid(tbl, module, check_empty=True, error_suffix_str=None):
    """
    Raise a plpy error unless 'tbl' names an existing (and, optionally,
    non-empty) table.

    Args:
        @param tbl Name of the input table
        @param module Module name used as prefix of the error message
        @param check_empty If True, an empty table is also rejected
        @param error_suffix_str Optional text appended to every error message
    """
    # Add a space in the beginning here instead of at the error message for
    # backward compatibility.
    if error_suffix_str:
        suffix = ' {0}'.format(error_suffix_str)
    else:
        suffix = ''

    if tbl is None or tbl.strip() == '':
        plpy.error("{0} error: NULL/empty input table name!{1}".format(
            module, suffix))

    if not table_exists(tbl):
        plpy.error("{0} error: Input table '{1}' does not exist.{2}".format(
            module, tbl, suffix))

    if check_empty and table_is_empty(tbl):
        plpy.error("{0} error: Input table '{1}' is empty!{2}".format(
            module, tbl, suffix))
# -------------------------------------------------------------------------
def output_tbl_valid(tbl, module):
    """
    Raise a plpy error unless 'tbl' is usable as a new output table name.

    The name must not be NULL/empty, must not already exist in the first
    schema of the search path, and must start with a letter, an underscore
    or a double quote.

    Args:
        @param tbl Name of the output table
        @param module Module name used as prefix of the error message
    """
    if tbl is None or tbl.strip().lower() in ['', 'null']:
        plpy.error(
            "{module} error: NULL/empty output table name!".format(module=module))

    if table_exists(tbl, only_first_schema=True):
        plpy.error("""{module} error: Output table '{tbl}' already exists.
Drop it before calling the function.""".format(module=module, tbl=tbl))

    first_char = tbl.strip().lower()[0]
    starts_ok = first_char.isalpha() or first_char in ('_', '"')
    if not starts_ok:
        plpy.error(
            "{module} error: Invalid output table name {tbl}".format(
                module=module, tbl=tbl))
# -------------------------------------------------------------------------
def cols_in_tbl_valid(tbl, cols, module, invalid_names=None):
    """
    Raise a plpy error unless every name in 'cols' is a usable column of
    'tbl'.

    Args:
        @param tbl Name of the table
        @param cols Iterable of column names
        @param module Module name used as prefix of the error message
        @param invalid_names Optional collection of names that are rejected
                             even if present in the table
    """
    for name in cols:
        if name is None or name.strip() == '':
            plpy.error(
                "{module} error: NULL/empty column name!".format(module=module))
        if invalid_names and name.strip() in invalid_names:
            plpy.error("{module} error: Column {c} is an invalid name.".
                       format(module=module, c=name))

    missing_cols = columns_missing_from_table(tbl, cols)
    # FIXME: still printing just 1 column name for backwards compatibility
    # this should be changed to print all missing columns
    if missing_cols:
        first_missing = missing_cols[0]
        plpy.error("{module} error: Column '{c}' does not exist in table '{tbl}'!".
                   format(module=module, c=first_missing, tbl=tbl))
# -------------------------------------------------------------------------
def regproc_valid(qualified_name, args_str, module):
    """
    Raise a plpy error unless the function with the given signature exists.

    The signature is validated by casting it to regprocedure.

    Args:
        @param qualified_name Function name, optionally schema qualified
        @param args_str Comma-separated argument types of the function
        @param module Module name used as prefix of the error message
    """
    try:
        plpy.execute("""
SELECT '{qualified_name}({args_str})'::regprocedure;
""".format(qualified_name=qualified_name, args_str=args_str))
    except Exception:
        # Exception (not a bare except) so that interpreter-exit and
        # interrupt signals are not turned into a "not found" error
        plpy.error(
            """{module} error: Required function "{qualified_name}({args_str})" not found!""".format(
                module=module, qualified_name=qualified_name, args_str=args_str))
# -------------------------------------------------------------------------
def does_exclude_reserved(targets, reserved):
    """
    Raise a plpy error if any target column name is also a reserved column
    name.

    Args:
        @param targets Iterable of column names to check
        @param reserved Iterable of reserved names
    """
    conflicts = frozenset(targets).intersection(frozenset(reserved))
    if conflicts:
        conflict_list = ', '.join(conflicts)
        plpy.error("Error: Conflicting column names.\n"
                   "Some predefined keyword(s) ({0}) are not allowed "
                   "for column names in module input params.".format(
                       conflict_list))
# -------------------------------------------------------------------------
def get_algorithm_name(algorithm, default, supported_algorithms, module):
    """
    Resolve a user-supplied algorithm name to a supported algorithm.

    Args:
        @param algorithm Name given by the user; may be a prefix of a
                         supported name and is matched case-insensitively
        @param default Value returned when 'algorithm' is empty or None
        @param supported_algorithms Iterable of supported algorithm names
        @param module Module name used as prefix of the error message
    Returns:
        String. 'default' if no algorithm was given, else the first supported
        name that starts with the lower-cased input.
    Throws:
        plpy.error if no supported name matches
    """
    if not algorithm:
        return default

    prefix = algorithm.lower()
    # allow user to specify a prefix substring of
    # supported algorithms. This works because the supported
    # algorithms have unique prefixes.
    for candidate in supported_algorithms:
        if candidate.startswith(prefix):
            return candidate

    plpy.error("{0} Error: Invalid algorithm: "
               "{1}. Supported algorithms are ({2})"
               .format(module, prefix,
                       ','.join(sorted(supported_algorithms))))
    return prefix
import unittest
class TestValidateFunctions(unittest.TestCase):
    """Unit tests for the name-handling helpers that issue no query."""

    def test_table_names(self):
        # A schema-qualified name is resolved without a database call. The
        # schema comes back as the text of a one-element SQL list and both
        # parts are unquoted.
        self.assertEqual(("('test_schema')", 'test_table'),
                         _get_table_schema_names('test_schema.test_table'))
        self.assertEqual(("('test_schema')", 'test_table'),
                         _get_table_schema_names('"test_schema"."test_table"'))

    def test_unquote_ident(self):
        # quoted identifiers keep their case, unquoted ones are lower-cased
        self.assertEqual('Test', unquote_ident('"Test"'))
        self.assertEqual('test', unquote_ident('Test'))
        self.assertEqual('Test123', unquote_ident('"Test123"'))
        self.assertEqual('test', unquote_ident('"test"'))
        # a doubled double quote inside a quoted identifier is one quote
        self.assertEqual('a"b', unquote_ident('"a""b"'))
# Run the unit tests defined above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()