| import plpy |
| from utilities import __mad_version |
| import re |
| |
# Helper object from utilities; the functions below call its
# select_array_agg() and select_vecfunc() methods.
version_wrapper = __mad_version()
# Converter applied to array values returned by plpy.execute. It is needed for
# GPDB versions less than 4.2 where an array is returned to Python as a string.
_string_to_array = version_wrapper.select_vecfunc()
| |
# PostgreSQL naming restrictions
"""
Both keywords and identifier names in PostgreSQL have a maximum length limit.
By default that limit is 63 bytes (NAMEDATALEN - 1); very old releases used 31
characters. Parsed keywords or identifiers over that length limit are
automatically truncated. Identifiers may begin with any letter (a through z), or
with an underscore, and may then be followed by letters, numbers (0 through 9),
or underscores. While keywords are not permitted to start or end with an
underscore, identifier names are permitted to do so. Neither keywords nor
identifiers should ever begin with a number.

The only instances where quotes are required are either when a
database object's identifier is identical to a keyword, or when the identifier
has at least one capitalized letter in its name. In either of these
circumstances, remember to quote the identifier both when creating the
object and in any subsequent references to that object (e.g., in SELECT,
DELETE, or UPDATE statements).
"""
| |
| |
| def _unquote_name(input_str): |
| """ |
| Returns input_str with starting and trailing double quotes stripped |
| |
| If the input_str is not quoted then a lower case version of the string is |
| returned. |
| Args: |
| @param input_str |
| |
| Returns: |
| String |
| """ |
| if input_str: |
| input_str = input_str.strip() |
| if input_str.startswith('"') and input_str.endswith('"'): |
| # if input_str has pair of double quotes within itself |
| # (not the ones at the two ends) then each pair is same as single |
| # double quote (the first double quote is used to escape the 2nd |
| # double quote) |
| return re.sub(r'""', r'"', input_str[1:-1]) |
| else: |
| return input_str.lower() |
| else: |
| return input_str |
| # ------------------------------------------------------------------------- |
| |
| |
def _get_table_schema_names(tbl, only_first_schema=False):
    """
    Split a table name into its candidate schemas and its bare table name.

    The schemas are returned as a single string formatted like a SQL tuple of
    string literals, e.g. ('schema_a','schema_b'). For a schema-qualified
    input the tuple holds just that schema. For an unqualified input it holds
    either the result of current_schema() (when only_first_schema is True) or
    every entry of current_schemas(True).

    Note: Both parts may be double-quoted in the input; they are unquoted with
    _unquote_name before being returned. The input is split on every '.', so
    a quoted name that itself contains a dot is rejected as malformed.

    Args:
        @param tbl: String, table name (could be schema qualified)
        @param only_first_schema: Boolean, for an unqualified name restrict
                                  the schemas to current_schema()
    Returns:
        Tuple pair (schema tuple string, table name), each element a string
    Throws:
        plpy.error if the name is NULL/empty or has more than two parts
    """
    if tbl is None or tbl.strip(' \'').lower() in ('null', ''):
        plpy.error('Input error: Table name (NULL) is invalid')

    pieces = tbl.split(".")
    if len(pieces) == 2:
        schema_part, table_part = pieces
        schema_str = "('" + _unquote_name(schema_part) + "')"
        table = _unquote_name(table_part)
    elif len(pieces) == 1:
        if only_first_schema:
            # restricted to the first schema in search path
            row = plpy.execute("SELECT current_schema() AS cs")[0]
            search_schemas = [row["cs"]]
        else:
            row = plpy.execute("SELECT current_schemas(True) AS cs")[0]
            search_schemas = _string_to_array(row["cs"])
        unquoted_schemas = [_unquote_name(s) for s in search_schemas]
        schema_str = "('{0}')".format("','".join(unquoted_schemas))
        table = _unquote_name(pieces[0])
    else:
        plpy.error("Incorrect table name ({0}) provided! Table name "
                   "should be of the form: <schema name>.<table name>".format(tbl))
    return (schema_str.strip(), table.strip())
| # ------------------------------------------------------------------------- |
| |
| |
def table_exists(tbl, only_first_schema=False):
    """
    Returns True if a relation with the given name exists in the database.

    The name is looked up in pg_class joined with pg_namespace. Relations with
    relkind 'r', 'v', 'm', 't' or 'f' are accepted (ordinary table, view,
    materialized view, TOAST table and foreign table in the PostgreSQL
    catalog documentation).

    Args:
        @param tbl: String, name of the table. Can be schema qualified. If it
                    is not qualified then the schemas reported by
                    _get_table_schema_names are searched.
        @param only_first_schema: Boolean, passed on to
                    _get_table_schema_names; restricts an unqualified name to
                    the first schema in the search path.

    Returns:
        Boolean
    """
    schema_str, table = _get_table_schema_names(tbl, only_first_schema)
    if not (schema_str and table):
        return False

    if schema_str == "('pg_temp')":
        # match any schema whose name starts with pg_temp
        schema_expr = "LIKE 'pg_temp%'"
    else:
        schema_expr = 'IN {0}'.format(schema_str)

    # The unquoted name is embedded in a SQL string literal, so a single
    # quote inside the name has to be doubled to keep the literal well formed.
    table_literal = table.replace("'", "''")
    existence_sql = """
        SELECT EXISTS(
            SELECT 1
            FROM pg_class, pg_namespace
            WHERE relnamespace = pg_namespace.oid
                AND nspname {schema_expr}
                AND relname = '{table}'
                AND relkind IN ('r', 'v', 'm', 't', 'f')
        ) AS table_exists
        """.format(schema_expr=schema_expr, table=table_literal)
    return bool(plpy.execute(existence_sql)[0]['table_exists'])
| # ------------------------------------------------------------------------- |
| |
| |
def rename_table(schema_madlib, orig_name, new_name):
    """
    Rename a table and, when the new name asks for a different schema, move
    the renamed table into that schema.

    Args:
        @param schema_madlib: not used by this function
        @param orig_name: string, Original name of the table
                (must be schema qualified if table schema is not in search path)
        @param new_name: string, New name of the table
                (can be schema qualified. If it is not then the original
                schema is maintained)
    Returns:
        String. The new table name qualified with the schema name
    Throws:
        AssertionError if either name has more than two dot-separated parts
        or if no schema can be found for an unqualified orig_name
    """
    new_parts = new_name.split(".")
    if len(new_parts) > 2:
        raise AssertionError("Invalid table name")
    new_table_name = new_parts[-1]
    new_table_schema = None
    if len(new_parts) == 2:
        new_table_schema = new_parts[0]

    orig_parts = orig_name.split(".")
    if len(orig_parts) > 2:
        raise AssertionError("Invalid table name")

    if len(orig_parts) == 2:
        orig_table_schema = orig_parts[0]
    else:
        ## The schema of the original table is needed both to move the table
        ## to another schema and to build the returned qualified name. Looking
        ## it up ensures the correct table is addressed in case several
        ## schemas hold a table with the same name.
        orig_table_schema = get_first_schema(orig_name)
        if orig_table_schema is None:
            raise AssertionError("Relation {0} not found during rename".
                                 format(orig_name))

    rename_sql = "ALTER TABLE {orig_table} RENAME TO {new_table}".format(
        orig_table=orig_name, new_table=new_table_name)
    plpy.execute(rename_sql)

    if not new_table_schema:
        return orig_table_schema + "." + new_table_name

    if new_table_schema != orig_table_schema:
        ## set schema only if a change in schema is required; at this point
        ## the table already carries its new name in the original schema
        renamed_in_old_schema = "{0}.{1}".format(orig_table_schema,
                                                 new_table_name)
        plpy.execute("""ALTER TABLE {new_table}
                     SET SCHEMA {schema_name}""".
                     format(new_table=renamed_in_old_schema,
                            schema_name=new_table_schema))
    return new_name
| # ------------------------------------------------------------------------- |
| |
| |
def get_first_schema(table_name):
    """
    Return first schema name from search path that contains given table.

    The search path is read with current_schemas(True), and the schemas that
    hold a table of this name are read from information_schema.tables.

    Args:
        @param table_name: String, table name to search. If table name is
                            schema-qualified then the (unquoted) schema name
                            is returned directly without any lookup.

    Returns:
        String, schema name if a schema containing the table is found.
        None, if none of the schemas in search path contain the table.
    Throws:
        TypeError if table_name has more than two dot-separated parts
    """
    parts = table_name.split(".")
    if len(parts) > 2:
        raise TypeError(
            "Incorrect table name ({0}) provided! Table name should be "
            "of the form: <schema name>.<table name>".format(table_name))
    if len(parts) == 2:
        return _unquote_name(parts[0])

    ## list of schema names in search path
    ## _string_to_array is used for GPDB versions less than 4.2 where an array
    ## is returned to Python as a string
    search_path_row = plpy.execute("SELECT current_schemas(True) AS cs")[0]
    search_path = _string_to_array(search_path_row["cs"])
    if not search_path:
        return None

    ## all schemas that contain a table with this name
    lookup_sql = """SELECT array_agg(table_schema::text) AS schemas
                    FROM information_schema.tables
                    WHERE table_name='{table_name}'""".format(
        table_name=table_name)
    owning_schemas = _string_to_array(plpy.execute(lookup_sql)[0]["schemas"])
    if not owning_schemas:
        return None

    ## first schema in search path order that contains the table, else None
    return next((s for s in search_path if s in owning_schemas), None)
| # ------------------------------------------------------------------------- |
| |
| |
def table_is_empty(tbl):
    """
    Returns True if the input table has no rows

    Only a single row is fetched (LIMIT 1) and counted, so the check does not
    count the whole table.

    Args:
        @param tbl: String, name of the table to check
    Throws:
        plpy.error if tbl is None or the string 'null' (any letter case)
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    probe_sql = """SELECT count(*) FROM
                    (SELECT * FROM {0} LIMIT 1) q1""".format(tbl)
    probe_count = plpy.execute(probe_sql)[0]["count"]
    # the count is 0 for an empty table and 1 otherwise
    return not probe_count
| # ------------------------------------------------------------------------- |
| |
| |
def _get_cols_in_current_schema(tbl, schema_madlib="madlib"):
    """
    Get all column names of an unqualified table.

    The schemas returned by current_schemas(True) are tried in order and the
    columns from the first schema that reports any for this table are
    returned. Column names are passed through quote_ident.

    Note: This function assumes that the table name is *not* qualified with
    the schema name

    Args:
        @param tbl: String, table name without schema qualification
        @param schema_madlib: String, passed to
                              version_wrapper.select_array_agg
    Returns:
        The aggregated column names as returned by plpy, or None if no schema
        in the search path has columns for this table.
    """
    search_path = plpy.execute("select current_schemas(True) as cs")[0]["cs"]
    # special handling for array in GPDB <= 4.1
    search_path = _string_to_array(search_path)
    array_agg_string = version_wrapper.select_array_agg(schema_madlib)
    query_template = "SELECT " + array_agg_string + \
        """(quote_ident(column_name)::varchar) AS cols
        FROM information_schema.columns
        WHERE table_name = '{table_name}'
        AND table_schema = '{s}'
        """
    bare_table = _unquote_name(tbl)
    for schema_name in search_path:
        query = query_template.format(table_name=bare_table,
                                      s=_unquote_name(schema_name))
        found_cols = plpy.execute(query)[0]["cols"]
        if found_cols is not None:
            return found_cols
    return None
| #------------------------------------------------------------------------- |
| |
| |
def get_cols(tbl, schema_madlib="madlib"):
    """
    Get all column names in a table.

    A schema-qualified name is resolved through a '<name>'::regclass cast and
    its live (not dropped, attnum > 0) columns are read from pg_attribute. An
    unqualified name is delegated to _get_cols_in_current_schema.

    Args:
        @param tbl: String, table name, optionally schema qualified
        @param schema_madlib: String, schema in which madlib is installed
    Returns:
        The aggregated, quote_ident-ed column names as returned by plpy
    Throws:
        plpy.error if tbl is NULL, schema_madlib is empty, or tbl has more
        than two dot-separated parts
    """
    array_agg_string = version_wrapper.select_array_agg(schema_madlib)

    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if not schema_madlib:
        plpy.error('Input error: Invalid MADlib schema name')

    n_parts = len(tbl.split("."))
    if n_parts == 1:
        return _get_cols_in_current_schema(tbl, schema_madlib)
    elif n_parts == 2:
        # the full qualified name (quotes included) is handed to regclass
        query_template = "SELECT " + array_agg_string + \
            """(quote_ident(attname)::varchar) AS cols
            FROM pg_attribute
            WHERE attrelid = '{tbl}'::regclass
            AND NOT attisdropped
            AND attnum > 0"""
        return plpy.execute(query_template.format(tbl=tbl))[0]["cols"]
    else:
        plpy.error("Input error: Invalid table name - {0}!".format(tbl))
| #------------------------------------------------------------------------- |
| |
| |
def get_cols_and_types(tbl):
    """
    Get the data types for all columns in a table.

    If the table is schema qualified then that schema is searched. Otherwise
    the schema is the one returned by get_first_schema for the table.

    Args:
        @param tbl: string, Name of the table to search in

    Returns:
        Dictionary. Key is the column name (as produced by quote_ident) and
        the Value is the data_type reported by information_schema.columns

    The data type returned will be the type name if it is a built-in type, or
    'ARRAY' if it is some array. For any other case it will be 'USER-DEFINED'.

    Throws:
        plpy.error if tbl is NULL
        TypeError if tbl has more than two dot-separated parts
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')

    parts = tbl.split(".")
    if len(parts) > 2:
        raise TypeError("Input error: Invalid table name - {0}!".format(tbl))

    if len(parts) == 2:
        schema, table = [_unquote_name(part) for part in parts]
    else:
        table = _unquote_name(parts[0])
        schema = get_first_schema(table)

    type_query = """SELECT array_agg(quote_ident(column_name)::varchar) AS cols,
                           array_agg(data_type::varchar) AS types
                    FROM information_schema.columns
                    WHERE table_name = '{table_name}'
                    AND table_schema = '{schema_name}'
                 """.format(table_name=table,
                            schema_name=schema)
    row = plpy.execute(type_query)[0]
    # both aggregates run over the same rows, so names and types pair up
    names = _string_to_array(row['cols'])
    types = _string_to_array(row['types'])
    return dict(zip(names, types))
| # ------------------------------------------------------------------------- |
| |
| |
def get_expr_type(expr, tbl):
    """ Temporary function to obtain the type of an expression by importing
    the expression data into python.

    The expression is evaluated for one row of the table and the SQL type is
    inferred from the Python type of the value that comes back.

    Args:
        @param expr: string, SQL expression to evaluate
        @param tbl: string, table the expression is evaluated against

    Returns:
        str.

    Throws:
        ValueError if the type cannot be determined. This covers a table that
        yields no row, a value of an unsupported Python type (including
        None), and an array that is empty or whose first element is neither
        an int nor a float.

    FIXME: Currently this utilizes PLPYTHON to get the type of an expression.
    This can be improved to obtain the type directly from the parsed tree in SQL.
    Also, return types are limited to one of
    {TEXT, BOOLEAN, INTEGER, DOUBLE PRECISION, INTEGER[], DOUBLE PRECISION[]}
    """
    rows = plpy.execute("SELECT {0} as type from {1} LIMIT 1".format(expr, tbl))
    if not rows:
        raise ValueError("Type for {0} cannot to be determined.".format(expr))

    value = rows[0]["type"]
    if isinstance(value, str):
        return "TEXT"
    # bool must be tested before int because bool is a subclass of int
    if isinstance(value, bool):
        return "BOOLEAN"
    if isinstance(value, int):
        return "INTEGER"
    if isinstance(value, float):
        return "DOUBLE PRECISION"
    if isinstance(value, list):
        # An empty array has no first element to inspect; report it like any
        # other undeterminable array instead of failing with IndexError.
        if value and isinstance(value[0], int):
            return "INTEGER[]"
        if value and isinstance(value[0], float):
            return "DOUBLE PRECISION[]"
        raise ValueError("ARRAY type cannot be determined. ")
    raise ValueError("Type for {0} cannot to be determined.".format(expr))
| # ------------------------------------------------------------------------- |
| |
| |
def columns_exist_in_table(tbl, cols, schema_madlib="madlib"):
    """
    Does each column exist in the table?

    Names are compared after unquoting with _unquote_name, on both the
    requested columns and the columns reported by get_cols.

    Args:
        @param tbl Name of source table
        @param cols Iterable list of column names
        @param schema_madlib Schema in which madlib is installed

    Returns:
        True if all columns in 'cols' exist in source table else False.
        False is also returned when get_cols reports no columns at all for
        the table (it returns None in that case).
    """
    table_cols = get_cols(tbl, schema_madlib)
    if table_cols is None:
        # nothing to match against; iterating None would raise TypeError
        return False

    existing_cols = set(_unquote_name(name) for name in table_cols)
    for col in cols:
        # an empty or None column name can never exist
        if not col or _unquote_name(col) not in existing_cols:
            return False
    return True
| # ------------------------------------------------------------------------- |
| |
| |
def is_col_array(tbl, col):
    """
    Return True if the column is of an array datatype

    The column's type is read from pg_attribute with format_type and is
    treated as an array when the formatted type contains '[]'.

    Args:
        @param tbl Name of the table to search. The name is resolved through
                    a '<name>'::regclass cast and can be schema qualified.
        @param col Name of the column to check datatype of
    Returns:
        Boolean
    Throws:
        plpy.error if tbl or col is empty, or if the column is not found in
        the table
    """
    if not tbl:
        plpy.error("Input error: Invalid table {0}".format(tbl))
    if not col:
        plpy.error("Input error: Invalid column name {0}".format(col))
    col = _unquote_name(col)

    type_query = """
        SELECT format_type(atttypid, atttypmod) AS data_type
        FROM pg_attribute
        WHERE attrelid = '{tbl}'::regclass
           AND NOT attisdropped
           AND attnum > 0
           AND attname = '{col}'
        """.format(tbl=tbl, col=col)
    type_rows = plpy.execute(type_query)

    if type_rows:
        return any('[]' in row["data_type"] for row in type_rows)
    else:
        plpy.error("Column {0} not found in table {1}".format(col, tbl))
| # ------------------------------------------------------------------------- |
| |
| |
def scalar_col_has_no_null(tbl, col):
    """
    Return True if a scalar column has no NULL values

    Args:
        @param tbl: String, table to inspect
        @param col: String, column (or expression) tested with IS NULL
    Throws:
        plpy.error if tbl or col is None or the string 'null' (any case)
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    null_count_sql = """SELECT count(*)
                        FROM {tbl}
                        WHERE ({col}) IS NULL
                     """.format(col=col, tbl=tbl)
    null_rows = plpy.execute(null_count_sql)[0]["count"]
    return null_rows == 0
| # ------------------------------------------------------------------------- |
| |
| |
def array_col_has_same_dimension(tbl, col):
    """
    Do all array elements of an array column have the same length?

    Only the first dimension is compared: the largest and the smallest
    array_upper({col}, 1) over the table must be equal. Both values are
    fetched in one query so the table is scanned once.

    Args:
        @param tbl: String, table to inspect
        @param col: String, array column (or expression)
    Returns:
        Boolean. True is also returned when both aggregates come back NULL,
        which is the case for a table without rows.
    Throws:
        plpy.error if tbl or col is None or the string 'null' (any case)
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    dims = plpy.execute("""
        SELECT max(array_upper({col}, 1)) AS max_dim,
               min(array_upper({col}, 1)) AS min_dim
        FROM {tbl}
        """.format(col=col, tbl=tbl))[0]
    return dims["max_dim"] == dims["min_dim"]
| # ------------------------------------------------------------------------ |
| |
| |
def __explicit_bool_to_text(tbl, cols, schema_madlib):
    """
    Wrap every boolean column in a call to <schema_madlib>.bool_to_text.

    Args:
        @param tbl: String, table whose column types are looked up
        @param cols: Iterable of column names; each must be a key of the
                     dictionary returned by get_cols_and_types
        @param schema_madlib: String, schema that holds bool_to_text
    Returns:
        List with one entry per column, in input order: the wrapped
        expression for columns of type 'boolean', the unchanged name for all
        other columns.
    """
    col_to_type = get_cols_and_types(tbl)

    def _patch(col):
        if col_to_type[col] == 'boolean':
            return schema_madlib + ".bool_to_text(" + col + ")"
        return col

    return [_patch(col) for col in cols]
| # ------------------------------------------------------------------------- |
| |
| |
def array_col_has_no_null(tbl, col):
    """
    Return True if an array column has no NULL values

    For every position up to the largest array_upper({col}, 1) in the table,
    the number of non-NULL elements at that position must equal the number of
    rows. A row whose array is NULL or shorter than the longest array
    therefore makes the result False.

    Args:
        @param tbl: String, table to inspect
        @param col: String, array column (or expression)
    Returns:
        Boolean
    Throws:
        plpy.error if tbl or col is None or the string 'null' (any case)
    """
    if tbl is None or tbl.lower() == 'null':
        plpy.error('Input error: Table name (NULL) is invalid')
    if col is None or col.lower() == 'null':
        plpy.error('Input error: Column name is invalid')

    row_len = plpy.execute("SELECT count(*) from {tbl}".
                           format(tbl=tbl))[0]["count"]
    dim = plpy.execute("""
                       SELECT max(array_upper({col}, 1)) AS dim
                       FROM {tbl}
                       """.format(col=col, tbl=tbl))[0]["dim"]
    if dim is None:
        # No array has any element (empty table, or only NULL/empty arrays),
        # so there is no position to inspect and range(1, dim + 1) would
        # raise TypeError. The column is NULL-free exactly when every row
        # holds a non-NULL array.
        non_null_arrays = plpy.execute("SELECT count({col}) FROM {tbl}".
                                       format(col=col, tbl=tbl))[0]["count"]
        return row_len == non_null_arrays

    for i in range(1, dim + 1):
        # count() skips NULLs, so a shortfall means a NULL at position i
        non_null_at_i = plpy.execute("SELECT count({col}[{i}]) FROM {tbl}".
                                     format(col=col, tbl=tbl, i=i))[0]["count"]
        if row_len != non_null_at_i:
            return False
    return True
| # ------------------------------------------------------------------------- |
| |
| |
def is_var_valid(tbl, var):
    """
    Test whether the variable(s) is valid by actually selecting it from
    the table

    The query uses LIMIT 0, so it is only checked for validity and returns no
    rows. Any exception raised while building or running it is reported as
    "not valid".

    Args:
        @param tbl: String, table to select from
        @param var: String, column name(s) or expression(s) to select
    Returns:
        Boolean
    """
    try:
        plpy.execute(
            """
            SELECT {var} FROM {tbl} LIMIT 0
            """.format(var=var,
                       tbl=tbl))
    except Exception:
        return False
    else:
        return True
| # ------------------------------------------------------------------------- |
| |
| |
def input_tbl_valid(tbl, module, check_empty=True):
    """
    Validate an input table name and report problems through plpy.error.

    Args:
        @param tbl: String, name of the input table
        @param module: String, module name used as prefix of error messages
        @param check_empty: Boolean, also reject a table without rows
    Throws:
        plpy.error if the name is None/blank, the table does not exist, or
        (when check_empty is set) the table is empty
    """
    if tbl is None or tbl.strip() == '':
        plpy.error("{module} error: NULL/empty input table name!".format(
            module=module))

    if not table_exists(tbl):
        plpy.error("{module} error: Input table '{tbl}' does not exist".format(
            module=module, tbl=tbl))

    if check_empty and table_is_empty(tbl):
        plpy.error("{module} error: Input table '{tbl}' is empty!".format(
            module=module, tbl=tbl))
| # ------------------------------------------------------------------------- |
| |
| |
def output_tbl_valid(tbl, module):
    """
    Validate an output table name and report problems through plpy.error.

    For an unqualified name only the first schema in the search path is
    checked for an existing table (table_exists is called with
    only_first_schema=True).

    Args:
        @param tbl: String, name of the output table
        @param module: String, module name used as prefix of error messages
    Throws:
        plpy.error if the name is None/blank or the table already exists
    """
    if tbl is None or tbl.strip() == '':
        plpy.error("{module} error: NULL/empty output table name!".format(
            module=module))

    already_there = table_exists(tbl, only_first_schema=True)
    if already_there:
        plpy.error("""{module} error: Output table '{tbl}' already exists.
            Drop it before calling the function.""".format(module=module,
                                                           tbl=tbl))
| # ------------------------------------------------------------------------- |
| |
| |
def cols_in_tbl_valid(tbl, cols, module):
    """
    Validate that every given column name is non-empty and exists in a table.

    Args:
        @param tbl: String, name of the table
        @param cols: Iterable of column names
        @param module: String, module name used as prefix of error messages
    Throws:
        plpy.error for the first None/blank column name, or for the first
        column that does not exist in the table
    """
    for c in cols:
        if c is None or c.strip() == '':
            plpy.error("{module} error: NULL/empty column name!".format(
                module=module))

    if columns_exist_in_table(tbl, cols):
        return

    # at least one column is missing; check one at a time to name it
    for c in cols:
        if not columns_exist_in_table(tbl, [c]):
            plpy.error("{module} error: Column '{c}' does not exist in table '{tbl}'!".format(
                module=module, c=c, tbl=tbl))
| # ------------------------------------------------------------------------- |
| |
| |
def regproc_valid(qualified_name, args_str, module):
    """
    Check that a function with the given signature exists in the database.

    The signature is cast to regprocedure; the cast fails when no such
    function exists.

    Args:
        @param qualified_name: String, function name (can be schema qualified)
        @param args_str: String, argument type list placed between the
                         parentheses of the signature
        @param module: String, module name used as prefix of error messages
    Throws:
        plpy.error if the cast fails
    """
    try:
        plpy.execute("""
            SELECT '{qualified_name}({args_str})'::regprocedure;
            """.format(qualified_name=qualified_name, args_str=args_str))
    except Exception:
        # Exception rather than a bare except, so that SystemExit and
        # KeyboardInterrupt are not turned into a "not found" error
        plpy.error("""{module} error: Required function "{qualified_name}({args_str})" not found!""".format(
            module=module, qualified_name=qualified_name, args_str=args_str))
| # ------------------------------------------------------------------------- |
| |
| |
| import unittest |
| |
| |
class TestValidateFunctions(unittest.TestCase):
    """
    Unit tests for the name-parsing helpers.

    Only schema-qualified names are passed to _get_table_schema_names, because
    that path does not query the database.
    """

    def test_table_names(self):
        # The schema comes back as a SQL tuple string and both parts are
        # returned unquoted, whether or not the input was quoted.
        self.assertEqual(("('test_schema')", 'test_table'),
                         _get_table_schema_names('test_schema.test_table'))
        self.assertEqual(("('test_schema')", 'test_table'),
                         _get_table_schema_names('"test_schema"."test_table"'))
        self.assertEqual(("('Test_Schema')", 'Test_Table'),
                         _get_table_schema_names('"Test_Schema"."Test_Table"'))
        self.assertEqual('Test', _unquote_name('"Test"'))
        self.assertEqual('test', _unquote_name('Test'))
        self.assertEqual('Test123', _unquote_name('"Test123"'))
        self.assertEqual('test', _unquote_name('"test"'))
        # a doubled double quote inside a quoted name is one literal quote
        self.assertEqual('a"b', _unquote_name('"a""b"'))
| |
| |
# Run the unit tests above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()