import plpy
from utilities import __mad_version
import re
version_wrapper = __mad_version()
_string_to_array = version_wrapper.select_vecfunc()
# PostgreSQL naming restrictions
"""
Both keywords and identifier names in PostgreSQL have a maximum length limit of
63 bytes (NAMEDATALEN - 1). Keywords and identifiers longer than that limit are
automatically truncated. Identifiers may begin with any letter (a through z) or
with an underscore, and may then be followed by letters, digits (0 through 9),
or underscores. While keywords are not permitted to start or end with an
underscore, identifier names are permitted to do so. Neither keywords nor
identifiers may begin with a digit.
Quotes are required only when a database object's identifier is identical to a
keyword, or when the identifier has at least one capitalized letter in its
name. In either of these circumstances, remember to quote the identifier both
when creating the object and in any subsequent references to it (e.g., in
SELECT, DELETE, or UPDATE statements).
"""
def _unquote_name(input_str):
"""
Returns input_str with starting and trailing double quotes stripped.
If input_str is not quoted then a lower-cased version of the string is
returned.
Args:
@param input_str
Returns:
String
"""
if input_str:
input_str = input_str.strip()
if input_str.startswith('"') and input_str.endswith('"'):
# within a quoted identifier, each embedded pair of double quotes
# (i.e. not the enclosing ones) represents a single double-quote
# character, since the first quote escapes the second; collapse
# each such pair into a single double quote
return re.sub(r'""', r'"', input_str[1:-1])
else:
return input_str.lower()
else:
return input_str
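# Illustrative behaviour of _unquote_name (pure string handling, no database
# access required); the names below are made-up examples:
#   _unquote_name('MyTable')      -> 'mytable'   (unquoted names are lower-cased)
#   _unquote_name('"MyTable"')    -> 'MyTable'   (quoted names keep their case)
#   _unquote_name('"My""Table"')  -> 'My"Table'  (doubled quotes collapse to one)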
# -------------------------------------------------------------------------
def _get_table_schema_names(tbl, only_first_schema=False):
"""
Returns a pair containing a set of schema names and the table name parsed
from the input string.
The schema names are returned as the string representation of a SQL tuple,
e.g. "('schema1','schema2')". If the input table name is schema qualified then
only that schema name is included in the tuple string; otherwise all schemas
in the current search path (including implicit schemas) are included.
Note: The table/schema names may be double-quoted. This function unquotes
the names by stripping the leading and trailing quotes and replacing every
pair of double quotes with a single double quote.
Args:
@param tbl Input table name (could be schema qualified)
Returns:
Tuple pair, each element a string
"""
if tbl is None or tbl.strip(' \'').lower() in ('null', ''):
plpy.error('Input error: Table name (NULL) is invalid')
names = tbl.split(".")
if len(names) == 1:
if only_first_schema:
# restricted to the first schema in search path
all_schemas = [plpy.execute("SELECT current_schema() AS cs")[0]["cs"]]
else:
all_schemas = _string_to_array(plpy.execute(
"SELECT current_schemas(True) AS cs")[0]["cs"])
schema_str = "('{0}')".format("','".join(_unquote_name(s)
for s in all_schemas))
table = _unquote_name(names[0])
elif len(names) == 2:
schema_str = "('" + _unquote_name(names[0]) + "')"
table = _unquote_name(names[1])
else:
plpy.error("Incorrect table name ({0}) provided! Table name "
"should be of the form: <schema name>.<table name>".format(tbl))
return (schema_str.strip(), table.strip())
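# Illustrative return values (hypothetical names):
#   _get_table_schema_names('public.my_table')    -> ("('public')", 'my_table')
#   _get_table_schema_names('"MySchema"."MyTbl"') -> ("('MySchema')", 'MyTbl')
# For an unqualified name the schema string lists every schema in the current
# search path, e.g. "('pg_temp_3','pg_catalog','public')".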
# -------------------------------------------------------------------------
def table_exists(tbl, only_first_schema=False):
"""
Returns True if the table exists in the database.
If the table name is not schema qualified then current_schemas() is used.
The table name is looked up in the pg_class / pg_namespace catalogs.
Args:
@param tbl Name of the table. Can be schema qualified. If it is not
qualified then the current schema is used.
"""
schema_str, table = _get_table_schema_names(tbl, only_first_schema)
if schema_str and table:
schema_expr = "LIKE 'pg_temp%'" if schema_str == "('pg_temp')" \
else 'IN {0}'.format(schema_str)
does_table_exist = plpy.execute(
"""
SELECT EXISTS(
SELECT 1
FROM pg_class, pg_namespace
WHERE relnamespace = pg_namespace.oid
AND nspname {schema_expr}
AND relname = '{table}'
AND relkind IN ('r', 'v', 'm', 't', 'f')
) AS table_exists
""".format(**locals()))[0]['table_exists']
return bool(does_table_exist)
else:
return False
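# Typical call sites (illustrative; plpy is only available when running inside
# the database as a PL/Python function):
#   table_exists('public.my_table')                      # match in any schema
#   table_exists('my_out_tbl', only_first_schema=True)   # e.g. before CREATE TABLE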
# -------------------------------------------------------------------------
def rename_table(schema_madlib, orig_name, new_name):
"""
Renames a (possibly schema-qualified) table to a new (possibly schema-qualified)
name, ensuring the schema qualification is changed appropriately.
Args:
@param orig_name: string, Original name of the table
(must be schema qualified if table schema is not in search path)
@param new_name: string, New name of the table
(can be schema qualified. If it is not then the original
schema is maintained)
Returns:
String. The new table name qualified with the schema name
"""
new_names_split = new_name.split(".")
if len(new_names_split) > 2:
raise AssertionError("Invalid table name")
new_table_name = new_names_split[-1]
new_table_schema = new_names_split[0] if len(new_names_split) > 1 else None
orig_names_split = orig_name.split(".")
if len(orig_names_split) > 2:
raise AssertionError("Invalid table name")
if len(orig_names_split) > 1:
orig_table_schema = orig_names_split[0]
else:
## we need to get the schema name of the original table if we are
## to change the schema of the new table. This is to ensure that we
## change the schema of the correct table in case there are multiple
## tables with the same new name.
orig_table_schema = get_first_schema(orig_name)
if orig_table_schema is None:
raise AssertionError("Relation {0} not found during rename".
format(orig_name))
plpy.execute("ALTER TABLE {orig_table} RENAME TO {new_table}".
format(orig_table=orig_name, new_table=new_table_name))
if new_table_schema:
if new_table_schema != orig_table_schema:
## set schema only if a change in schema is required
before_schema_string = "{0}.{1}".format(orig_table_schema,
new_table_name)
plpy.execute("""ALTER TABLE {new_table}
SET SCHEMA {schema_name}""".
format(new_table=before_schema_string,
schema_name=new_table_schema))
return new_name
else:
return orig_table_schema + "." + new_table_name
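# Sketch of a cross-schema rename (hypothetical names, assuming both schemas
# exist and the caller has the required privileges):
#   rename_table(schema_madlib, 's_old.model_tbl', 's_new.model_tbl_v2')
# roughly issues
#   ALTER TABLE s_old.model_tbl RENAME TO model_tbl_v2;
#   ALTER TABLE s_old.model_tbl_v2 SET SCHEMA s_new;
# and returns 's_new.model_tbl_v2'.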
# -------------------------------------------------------------------------
def get_first_schema(table_name):
"""
Return first schema name from search path that contains given table.
The search does not include implicit schemas (like pg_catalog)
Args:
@param table_name: String, table name to search. If table name is
schema-qualified then the schema name is returned
directly.
Returns:
String, schema name if a schema containing the table is found.
None, if none of the schemas in search path contain the table.
"""
names = table_name.split(".")
if not names or len(names) > 2:
raise TypeError("Incorrect table name ({0}) provided! Table name should be "
"of the form: <schema name>.<table name>".format(table_name))
elif len(names) == 2:
return _unquote_name(names[0])
## create a list of schema names in search path
## _string_to_array is used for GPDB versions less than 4.2 where an array
## is returned to Python as a string
current_schemas = _string_to_array(plpy.execute(
"SELECT current_schemas(True) AS cs")[0]["cs"])
if not current_schemas:
return None
## get all schemas that contain a table with this name
schemas_w_table = _string_to_array(plpy.execute(
"""SELECT array_agg(table_schema::text) AS schemas
FROM information_schema.tables
WHERE table_name='{table_name}'""".
format(table_name=table_name))[0]["schemas"])
if not schemas_w_table:
return None
for each_schema in current_schemas:
## get the first schema in search path that contains the table
if each_schema in schemas_w_table:
return each_schema
## None of the schemas in search path have the table
return None
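# Illustrative resolution order (hypothetical search path and tables): with
# search_path set to "user_schema, public" and a table named 'results' present
# in both schemas, get_first_schema('results') returns 'user_schema', while a
# schema-qualified input such as 'public.results' returns 'public' directly.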
# -------------------------------------------------------------------------
def table_is_empty(tbl):
"""
Returns True if the input table has no rows
"""
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
content = plpy.execute("""SELECT count(*) FROM
(SELECT * FROM {0} LIMIT 1) q1""".
format(tbl))[0]["count"]
return not bool(content) # if content == 0 then True, else False
# -------------------------------------------------------------------------
def _get_cols_in_current_schema(tbl, schema_madlib="madlib"):
"""
Get all column names in a table.
All schemas in current_schemas are searched for the table, and the columns
of the first match found are returned.
Note: This function assumes that the table name is *not* qualified with
the schema name
"""
schema = plpy.execute("select current_schemas(True) as cs")[0]["cs"]
# special handling for array in GPDB <= 4.1
schema = _string_to_array(schema)
array_agg_string = version_wrapper.select_array_agg(schema_madlib)
sql_string = "SELECT " + array_agg_string + \
"""(quote_ident(column_name)::varchar) AS cols
FROM information_schema.columns
WHERE table_name = '{table_name}'
AND table_schema = '{s}'
"""
tbl = _unquote_name(tbl)
for s in schema:
s = _unquote_name(s)
existing_cols = plpy.execute(sql_string.format(table_name=tbl,
s=s))[0]["cols"]
if existing_cols is not None:
return existing_cols
return None
#-------------------------------------------------------------------------
def get_cols(tbl, schema_madlib="madlib"):
"""
Get all column names in a table.
If the table is schema qualified then the appropriate schema is searched.
If no schema qualification is provided then the current schema is used.
"""
array_agg_string = version_wrapper.select_array_agg(schema_madlib)
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
if not schema_madlib:
plpy.error('Input error: Invalid MADlib schema name')
names = tbl.split(".")
if len(names) == 1:
return _get_cols_in_current_schema(tbl, schema_madlib)
elif len(names) == 2:
schema = _unquote_name(names[0])
table = _unquote_name(names[1])
sql_string = "SELECT " + array_agg_string + \
"""(quote_ident(attname)::varchar) AS cols
FROM pg_attribute
WHERE attrelid = '{tbl}'::regclass
AND NOT attisdropped
AND attnum > 0"""
existing_cols = plpy.execute(sql_string.format(**locals()))[0]["cols"]
else:
plpy.error("Input error: Invalid table name - {0}!".format(tbl))
return existing_cols
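# Illustrative return value (hypothetical table): get_cols('public.my_points')
# might return ['id', 'label', 'features']; mixed-case column names come back
# already double-quoted via quote_ident, e.g. '"Label"'.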
#-------------------------------------------------------------------------
def get_cols_and_types(tbl):
"""
Get the data types for all columns in a table.
If the table is schema qualified then the appropriate schema is searched.
If no schema qualification is provided then the current schema is used.
Args:
@param tbl: string, Name of the table to search in
Returns:
Dictionary. Key is the column name and the Value is the data type
The data type returned will be the type name if it is a built-in type, or
'ARRAY' if it is an array type. For any other case it will be 'USER-DEFINED'.
"""
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
names = tbl.split(".")
if not names or len(names) > 2:
raise TypeError("Input error: Invalid table name - {0}!".format(tbl))
elif len(names) == 1:
table = _unquote_name(names[0])
schema = get_first_schema(table)
elif len(names) == 2:
schema = _unquote_name(names[0])
table = _unquote_name(names[1])
sql_string = """SELECT array_agg(quote_ident(column_name)::varchar) AS cols,
array_agg(data_type::varchar) AS types
FROM information_schema.columns
WHERE table_name = '{table_name}'
AND table_schema = '{schema_name}'
""".format(table_name=table,
schema_name=schema)
result = plpy.execute(sql_string)[0]
col_names = _string_to_array(result['cols'])
col_types = _string_to_array(result['types'])
return dict(zip(col_names, col_types))
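# Illustrative return value (hypothetical table with an id, a label and an
# array-valued feature column):
#   get_cols_and_types('public.my_points')
#   -> {'id': 'integer', 'label': 'text', 'features': 'ARRAY'}
# information_schema reports 'ARRAY' for any array column and 'USER-DEFINED'
# for composite or extension types, as noted in the docstring above.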
# -------------------------------------------------------------------------
def get_expr_type(expr, tbl):
""" Temporary function to obtain the type of an expression by importing
the expression data into python.
Args:
@param expr
Returns:
str.
FIXME: Currently this utilizes PLPYTHON to get the type of an expression.
This can be improved to obtain the type directly from the parsed tree in SQL.
Also, return types are limited to one of
{TEXT, BOOLLEAN, INTEGER, DOUBLE PRECISION, INTEGER[], DOUBLE PRECISION[]}
"""
expr_type = plpy.execute("SELECT {0} as type from {1} LIMIT 1".format(expr, tbl))
if expr_type:
expr_type = expr_type[0]["type"]
if isinstance(expr_type, str):
return "TEXT"
elif isinstance(expr_type, bool):
return "BOOLEAN"
elif isinstance(expr_type, int):
return "INTEGER"
elif isinstance(expr_type, float):
return "DOUBLE PRECISION"
elif isinstance(expr_type, list):
if isinstance(expr_type[0], int):
return "INTEGER[]"
elif isinstance(expr_type[0], float):
return "DOUBLE PRECISION[]"
else:
raise ValueError("ARRAY type cannot be determined. ")
else:
raise ValueError("Type for {0} cannot to be determined.".format(expr))
# -------------------------------------------------------------------------
def columns_exist_in_table(tbl, cols, schema_madlib="madlib"):
"""
Does each column exist in the table?
Args:
@param tbl Name of source table
@param cols Iterable list of column names
@param schema_madlib Schema in which MADlib is installed
Returns:
True if all columns in 'cols' exist in source table else False
"""
existing_cols = set(_unquote_name(i) for i in get_cols(tbl, schema_madlib))
for col in cols:
if not col or _unquote_name(col) not in existing_cols:
return False
return True
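# Illustrative check (hypothetical names): both sides are normalised with
# _unquote_name before comparison, so
#   columns_exist_in_table('public.my_points', ['id', '"Label"'])
# is True when the table has columns id and "Label", and False if any listed
# column is missing or empty.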
# -------------------------------------------------------------------------
def is_col_array(tbl, col):
"""
Return True if the column is of an array datatype
Args:
@param tbl Name of the table to search. This can be schema qualified,
if it is not qualified then the current_schema is used.
@param col Name of the column to check datatype of
Returns:
Boolean
Throws:
plpy.error if the column is not found in the table
"""
if not tbl:
plpy.error("Input error: Invalid table {0}".format(tbl))
if not col:
plpy.error("Input error: Invalid column name {0}".format(col))
col = _unquote_name(col)
data_type_list = plpy.execute(
"""
SELECT format_type(atttypid, atttypmod) AS data_type
FROM pg_attribute
WHERE attrelid = '{tbl}'::regclass
AND NOT attisdropped
AND attnum > 0
AND attname = '{col}'
""".format(**locals()))
if data_type_list:
for data_type in data_type_list:
if '[]' in data_type["data_type"]:
return True
return False
else:
plpy.error("Column {0} not found in table {1}".format(col, tbl))
# -------------------------------------------------------------------------
def scalar_col_has_no_null(tbl, col):
"""
Return True if a scalar column has no NULL values.
"""
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
if col is None or col.lower() == 'null':
plpy.error('Input error: Column name is invalid')
col_null_rows = plpy.execute("""SELECT count(*)
FROM {tbl}
WHERE ({col}) IS NULL
""".format(col=col, tbl=tbl))[0]["count"]
return (col_null_rows == 0)
# -------------------------------------------------------------------------
def array_col_has_same_dimension(tbl, col):
"""
Do all array elements of an array column have the same length?
"""
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
if col is None or col.lower() == 'null':
plpy.error('Input error: Column name is invalid')
max_dim = plpy.execute("""
SELECT max(array_upper({col}, 1)) AS max_dim
FROM {tbl}
""".format(col=col, tbl=tbl))[0]["max_dim"]
min_dim = plpy.execute("""
SELECT min(array_upper({col}, 1)) AS min_dim
FROM {tbl}
""".format(col=col, tbl=tbl))[0]["min_dim"]
return max_dim == min_dim
# ------------------------------------------------------------------------
def __explicit_bool_to_text(tbl, cols, schema_madlib):
"""
Wrap boolean columns in madlib.bool_to_text() so that they can be handled as text.
"""
col_to_type = get_cols_and_types(tbl)
patched = []
for col in cols:
if col_to_type[col] == 'boolean':
patched.append(schema_madlib + ".bool_to_text(" + col + ")")
else:
patched.append(col)
return patched
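# Illustrative output (hypothetical columns): if col_to_type maps 'is_valid'
# to 'boolean' and 'score' to 'double precision', then
#   __explicit_bool_to_text('my_tbl', ['is_valid', 'score'], 'madlib')
# returns ['madlib.bool_to_text(is_valid)', 'score'].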
# -------------------------------------------------------------------------
def array_col_has_no_null(tbl, col):
"""
Return True if an array column has no NULL values.
"""
if tbl is None or tbl.lower() == 'null':
plpy.error('Input error: Table name (NULL) is invalid')
if col is None or col.lower() == 'null':
plpy.error('Input error: Column name is invalid')
row_len = plpy.execute("SELECT count(*) from {tbl}".
format(tbl=tbl))[0]["count"]
dim = plpy.execute("""
SELECT max(array_upper({col}, 1)) AS dim
FROM {tbl}
""".format(col=col, tbl=tbl))[0]["dim"]
for i in range(1, dim + 1):
l = plpy.execute("SELECT count({col}[{i}]) FROM {tbl}".
format(col=col, tbl=tbl, i=i))[0]["count"]
if row_len != l:
return False
return True
# -------------------------------------------------------------------------
def is_var_valid(tbl, var):
"""
Test whether the variable(s) are valid by actually selecting them from
the table
"""
try:
plpy.execute(
"""
SELECT {var} FROM {tbl} LIMIT 0
""".format(var=var,
tbl=tbl))
except Exception:
return False
return True
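# Illustrative use (hypothetical expression): the probe runs
# "SELECT ... LIMIT 0", so arbitrary expressions can be validated cheaply,
# e.g. is_var_valid('public.my_points', 'ARRAY[x, y]') is True only when the
# expression parses and both columns exist in the table.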
# -------------------------------------------------------------------------
def input_tbl_valid(tbl, module, check_empty=True):
if tbl is None or tbl.strip() == '':
plpy.error("{module} error: NULL/empty input table name!".format(**locals()))
if not table_exists(tbl):
plpy.error("{module} error: Input table '{tbl}' does not exist".format(**locals()))
if check_empty and table_is_empty(tbl):
plpy.error("{module} error: Input table '{tbl}' is empty!".format(**locals()))
# -------------------------------------------------------------------------
def output_tbl_valid(tbl, module):
if tbl is None or tbl.strip() == '':
plpy.error("{module} error: NULL/empty output table name!".format(**locals()))
if table_exists(tbl, only_first_schema=True):
plpy.error("""{module} error: Output table '{tbl}' already exists.
Drop it before calling the function.""".format(**locals()))
# -------------------------------------------------------------------------
def cols_in_tbl_valid(tbl, cols, module):
for c in cols:
if c is None or c.strip() == '':
plpy.error("{module} error: NULL/empty column name!".format(**locals()))
if not columns_exist_in_table(tbl, cols):
for c in cols:
if not columns_exist_in_table(tbl, [c]):
plpy.error("{module} error: Column '{c}' does not exist in table '{tbl}'!".format(**locals()))
# -------------------------------------------------------------------------
def regproc_valid(qualified_name, args_str, module):
try:
plpy.execute("""
SELECT '{qualified_name}({args_str})'::regprocedure;
""".format(**locals()))
except Exception:
plpy.error("""{module} error: Required function "{qualified_name}({args_str})" not found!""".format(**locals()))
# -------------------------------------------------------------------------
import unittest
class TestValidateFunctions(unittest.TestCase):
def test_table_names(self):
self.assertEqual(("('test_schema')", 'test_table'),
_get_table_schema_names('test_schema.test_table'))
self.assertEqual(("('test_schema')", 'test_table'),
_get_table_schema_names('"test_schema"."test_table"'))
self.assertEqual('Test', _unquote_name('"Test"'))
self.assertEqual('test', _unquote_name('Test'))
self.assertEqual('Test123', _unquote_name('"Test123"'))
self.assertEqual('test', _unquote_name('"test"'))
if __name__ == '__main__':
unittest.main()