| from collections import defaultdict, Iterable |
| import glob |
| import os |
| import re |
| import yaml |
| |
| from utilities import _write_to_file |
| from utilities import is_rev_gte |
| from utilities import get_dbver |
| from utilities import get_rev_num |
| from utilities import remove_comments_from_sql |
| from utilities import run_query |
| |
| |
| # When this module is executed directly (for the unit tests at the bottom of |
| # the file), database access is stubbed out so no live connection is needed. |
| if __name__ != "__main__": |
| def run_sql(sql, portid, con_args): |
| """ |
| @brief Wrapper function for run_query |
| """ |
| return run_query(sql, con_args, True) |
| else: |
| def run_sql(sql, portid, con_args): |
| return [{'dummy': 0}] |
| |
| |
| def get_signature_for_compare(schema, proname, rettype, argument): |
| """ |
| @brief Get the signature of a UDF/UDA for comparison |
| """ |
| signature = '{0} {1}.{2}({3})'.format(rettype.strip(), schema.strip(), |
| proname.strip(), argument.strip()) |
| signature = re.sub('"', '', signature) |
| return signature.lower() |
| |
| |
| class UpgradeBase: |
| """ |
| @brief Base class for handling the upgrade |
| """ |
| |
| def __init__(self, schema, portid, con_args): |
| self._schema = schema.lower() |
| self._portid = portid |
| self._con_args = con_args |
| self._schema_oid = None |
| self._get_schema_oid() |
| self._dbver = get_dbver(self._con_args, self._portid) |
| |
| """ |
| @brief Wrapper function for run_sql |
| """ |
| |
| def _run_sql(self, sql): |
| return run_sql(sql, self._portid, self._con_args) |
| |
| """ |
| @brief Get the oids of some objects from the catalog in the current version |
| """ |
| |
| def _get_schema_oid(self): |
| res = self._run_sql("SELECT oid FROM pg_namespace WHERE nspname = '{0}'". |
| format(self._schema))[0] |
| if 'oid' in res: |
| self._schema_oid = res['oid'] |
| else: |
| self._schema_oid = None |
| return self._schema_oid |
| |
| def _get_function_info(self, oid): |
| """ |
| @brief Get the function name, return type, and arguments given an oid |
| @note The function can only handle the case that proallargtypes is null, |
| refer to pg_catalog.pg_get_function_identity_argument and |
| pg_catalog.pg_get_function_result in PG for a complete implementation, which are |
| not supported by GP |
| """ |
| |
| # Check if the function has any arguments |
| proargtypes = self._run_sql( |
| """ |
| SELECT |
| array_upper(proargtypes,1) as proargtypes |
| FROM pg_proc |
| WHERE oid = {oid} |
| """.format(oid=oid)) |
| # If it does not have any arguments then the unnest will not return |
| # any rows. We need a single row with an empty string. |
| unnest_proargtypes = "\'\'::VARCHAR" |
| gen_series_proargtypes = "1" |
| if proargtypes[0]['proargtypes'] != "-1": |
| # Convert the argument types to text |
| unnest_proargtypes = "textin(regtypeout(unnest(proargtypes)::regtype))" |
| gen_series_proargtypes = "generate_series(0, array_upper(proargtypes, 1))" |
| |
| # Convert the return type to text. The max() aggregate is needed because |
| # array_agg aggregates over all rows; every row carries the same proname |
| # and rettype, so max() simply picks that common value. |
| row = self._run_sql( |
| """ |
| SELECT |
| max(proname) AS proname, |
| max(rettype) AS rettype, |
| array_to_string(array_agg(argtype order by i), ', ') AS argument |
| FROM |
| ( |
| SELECT |
| proname, |
| textin(regtypeout(prorettype::regtype)) AS rettype, |
| {unnest_proargtypes} AS argtype, |
| {gen_series_proargtypes} AS i |
| FROM |
| pg_proc AS p |
| WHERE |
| oid = {oid} |
| ) AS f |
| """.format(**locals())) |
| return {"proname": row[0]['proname'], |
| "rettype": row[0]['rettype'], |
| "argument": row[0]['argument']} |
| |
| |
| class ChangeHandler(UpgradeBase): |
| """ |
| @brief This class reads changes from the configuration file and handles |
| the dropping of objects |
| """ |
| |
| def __init__(self, schema, portid, con_args, maddir, mad_dbrev, |
| output_filehandle, upgrade_to=None): |
| UpgradeBase.__init__(self, schema, portid, con_args) |
| |
| # FIXME: maddir includes the '/src' folder. It's supposed to be the |
| # parent of that directory. |
| self._maddir = maddir |
| self._mad_dbrev = mad_dbrev |
| self._newmodule = {} |
| self._curr_rev = self._get_current_version() if not upgrade_to else upgrade_to |
| self.output_filehandle = output_filehandle |
| self._udt = {} |
| self._udf = {} |
| self._uda = {} |
| self._udc = {} |
| self._udo = {} |
| self._udoc = {} |
| self._load() |
| |
| def _get_current_version(self): |
| """ Get current version of MADlib |
| |
| This currently assumes that version is available in |
| '$MADLIB_HOME/src/config/Version.yml' |
| """ |
| version_filepath = os.path.abspath( |
| os.path.join(self._maddir, 'config', 'Version.yml')) |
| with open(version_filepath) as ver_file: |
| version_str = str(yaml.load(ver_file)['version']) |
| return get_rev_num(version_str) |
| |
| def _load_config_param(self, config_iterable, output_config_dict=None): |
| """ |
| Replace schema_madlib with the appropriate schema name and |
| make all function names lower case to ensure ease of comparison. |
| |
| Args: |
| @param config_iterable is an iterable of dictionaries, each with |
| key = object name (eg. function name) and value = details |
| for the object. The details for the object are assumed |
| to be in a dictionary with following keys: |
| rettype: Return type |
| argument: List of arguments |
| |
| Returns: |
| A dictionary mapping each object name to a list of dictionaries, one |
| per overloaded variant of that object, each holding the formatted |
| details (e.g. rettype and argument). |
| """ |
| _return_obj = defaultdict(list) if not output_config_dict else output_config_dict |
| if config_iterable is not None: |
| for each_config in config_iterable: |
| for obj_name, obj_details in each_config.iteritems(): |
| formatted_obj = {} |
| for k, v in obj_details.items(): |
| v = v.lower().replace('schema_madlib', self._schema) if v else "" |
| formatted_obj[k] = v |
| _return_obj[obj_name].append(formatted_obj) |
| return _return_obj |
| |
| @classmethod |
| def _add_to_dict(cls, src_dict, dest_dict): |
| """ Update dictionary with contents of another dictionary |
| |
| This function performs the same function as dict.update except it adds |
| to an existing value (instead of replacing it) if the value is an |
| Iterable. |
| """ |
| if src_dict: |
| for k, v in src_dict.items(): |
| if k in dest_dict: |
| if (isinstance(dest_dict[k], Iterable) and isinstance(v, Iterable)): |
| dest_dict[k] += v |
| elif isinstance(dest_dict[k], Iterable): |
| dest_dict[k].append(v) |
| else: |
| dest_dict[k] = v |
| else: |
| dest_dict[k] = v |
| return dest_dict |
| |
| def _update_objects(self, config): |
| """ Update each upgrade object """ |
| self._add_to_dict(config['new module'], self._newmodule) |
| self._add_to_dict(config['udt'], self._udt) |
| self._add_to_dict(config['udc'], self._udc) |
| self._add_to_dict(self._load_config_param(config['udf']), self._udf) |
| self._add_to_dict(self._load_config_param(config['uda']), self._uda) |
| self._add_to_dict(self._load_config_param(config['udo']), self._udo) |
| self._add_to_dict(self._load_config_param(config['udoc']), self._udoc) |
| |
| def _get_relevant_filenames(self, upgrade_from): |
| """ Get all changelist files that together describe the upgrade process |
| |
| Args: |
| @param upgrade_from: List. Version to upgrade from - the format is |
| expected to be per the output of get_rev_num |
| |
| Details: |
| Changelist files are named in the format changelist_<src>_<dest>.yaml |
| |
| When upgrading from 'upgrade_from_rev' to 'self._curr_rev', all |
| intermediate changelist files need to be followed to get all upgrade |
| steps. This function globs for such files and filters in changelists |
| that lie between the desired versions. |
| |
| Additional verification: The function also ensures that a valid |
| upgrade path exists. Each version in the changelist files needs to |
| be seen twice (except upgrade_from and upgrade_to) for a valid path. |
| This is verified by performing an xor-like operation by |
| adding/deleting from a list. |
| """ |
| output_filenames = [] |
| upgrade_to = self._curr_rev |
| |
| verify_list = [upgrade_from, upgrade_to] |
| |
| # assuming that changelists are in the same directory as this file |
| glob_filter = os.path.abspath( |
| os.path.join(self._maddir, 'madpack', 'changelist*.yaml')) |
| all_changelists = glob.glob(glob_filter) |
| for each_ch in all_changelists: |
| # split file names to get dest versions |
| # Assumption: changelist format is |
| # changelist_<src>_<dest>.yaml |
| ch_basename = os.path.splitext(os.path.basename(each_ch))[0] # remove extension |
| ch_splits = ch_basename.split('_') # underscore delineates sections |
| if len(ch_splits) >= 3: |
| src_version, dest_version = [get_rev_num(i) for i in ch_splits[1:3]] |
| |
| # file is part of upgrade if |
| # upgrade_to >= dest >= src >= upgrade_from |
| is_part_of_upgrade = ( |
| is_rev_gte(src_version, upgrade_from) and |
| is_rev_gte(upgrade_to, dest_version)) |
| if is_part_of_upgrade: |
| for ver in (src_version, dest_version): |
| if ver in verify_list: |
| verify_list.remove(ver) |
| else: |
| verify_list.append(ver) |
| abs_path = os.path.join(self._maddir, 'src', 'madpack', each_ch) |
| output_filenames.append(abs_path) |
| |
| if verify_list: |
| # any version remaining in verify_list implies upgrade path is broken |
| raise RuntimeError("Upgrade from {0} to {1} broken due to missing " |
| "changelist files ({2}). ". |
| format(upgrade_from, upgrade_to, verify_list)) |
| return output_filenames |
| |
| def _load(self): |
| """ |
| @brief Load the configuration file |
| """ |
| rev = get_rev_num(self._mad_dbrev) |
| upgrade_filenames = self._get_relevant_filenames(rev) |
| for f in upgrade_filenames: |
| with open(f) as handle: |
| config = yaml.load(handle) |
| self._update_objects(config) |
| |
| @property |
| def newmodule(self): |
| return self._newmodule |
| |
| @property |
| def udt(self): |
| return self._udt |
| |
| @property |
| def uda(self): |
| return self._uda |
| |
| @property |
| def udf(self): |
| return self._udf |
| |
| @property |
| def udc(self): |
| return self._udc |
| |
| @property |
| def udo(self): |
| return self._udo |
| |
| @property |
| def udoc(self): |
| return self._udoc |
| |
| def get_udf_signature(self): |
| """ |
| @brief Get the UDF signatures (as a set-like dict) for comparison |
| """ |
| res = defaultdict(bool) |
| for udf in self._udf: |
| for item in self._udf[udf]: |
| udf_arglist = item['argument'] if 'argument' in item else '' |
| signature = get_signature_for_compare( |
| self._schema, udf, item['rettype'], udf_arglist) |
| res[signature] = True |
| return res |
| |
| def get_uda_signature(self): |
| """ |
| @brief Get the UDA signatures (as a set-like dict) for comparison |
| """ |
| res = defaultdict(bool) |
| for uda in self._uda: |
| for item in self._uda[uda]: |
| uda_arglist = item['argument'] if 'argument' in item else '' |
| signature = get_signature_for_compare( |
| self._schema, uda, item['rettype'], uda_arglist) |
| res[signature] = True |
| return res |
| |
| def get_udo_oids(self): |
| """ |
| @brief Get the list of changed/removed UDO OIDs for comparison |
| """ |
| ret = [] |
| |
| changed_ops = set() |
| for op, li in self._udo.items(): |
| for e in li: |
| changed_ops.add((op, e['leftarg'], e['rightarg'])) |
| |
| rows = self._run_sql(""" |
| SELECT |
| o.oid, oprname, oprleft::regtype, oprright::regtype |
| FROM |
| pg_operator AS o, pg_namespace AS ns |
| WHERE |
| o.oprnamespace = ns.oid AND |
| ns.nspname = '{schema}' |
| """.format(schema=self._schema.lower())) |
| for row in rows: |
| if (row['oprname'], row['oprleft'], row['oprright']) in changed_ops: |
| ret.append(row['oid']) |
| |
| return ret |
| |
| def get_udoc_oids(self): |
| """ |
| @brief Get the list of changed/removed UDOC OIDs for comparison |
| """ |
| ret = [] |
| |
| changed_opcs = set() |
| for opc, li in self._udoc.items(): |
| for e in li: |
| changed_opcs.add((opc, e['index'])) |
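| # pg_opclass.opcamid was renamed to opcmethod in PostgreSQL 8.3, on which |
| # GPDB 5 is based; older Greenplum releases still expose opcamid. |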
| gte_gpdb5 = (self._portid == 'greenplum' and |
| is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0'))) |
| if (self._portid == 'postgres' or gte_gpdb5): |
| method_col = 'opcmethod' |
| else: |
| method_col = 'opcamid' |
| rows = self._run_sql(""" |
| SELECT |
| oc.oid, opcname, amname AS index |
| FROM |
| pg_opclass AS oc, pg_am as am |
| WHERE |
| oc.opcnamespace = {madlib_schema_oid} AND |
| oc.{method_col} = am.oid; |
| """.format(method_col=method_col, madlib_schema_oid=self._schema_oid)) |
| for row in rows: |
| if (row['opcname'], row['index']) in changed_opcs: |
| ret.append(row['oid']) |
| |
| return ret |
| |
| def drop_changed_udt(self): |
| """ |
| @brief Drop all types that were updated/removed in the new version |
| @note It is dangerous to drop a UDT because there might be many |
| dependencies |
| """ |
| for udt in self._udt: |
| cascade_str = 'CASCADE' if udt in ('svec', 'bytea8') else '' |
| # CASCADE DROP for svec and bytea8 because the recv/send |
| # functions and the type depend on each other |
| _write_to_file(self.output_filehandle, "DROP TYPE IF EXISTS {0}.{1} {2};". |
| format(self._schema, udt, cascade_str)) |
| |
| def drop_changed_udf(self): |
| """ |
| @brief Drop all functions (UDFs) that were removed in the new version |
| """ |
| for udf in self._udf: |
| for item in self._udf[udf]: |
| # This is a fix for https://issues.apache.org/jira/browse/MADLIB-1197. |
| # kNN had a peculiar case where a UDF with no arguments was defined, |
| # so dropping that function needs this extra check. |
| udf_arglist = item['argument'] if 'argument' in item else '' |
| |
| _write_to_file(self.output_filehandle, "DROP FUNCTION IF EXISTS {schema}.{udf}({arg});". |
| format(schema=self._schema, |
| udf=udf, |
| arg=udf_arglist)) |
| |
| def drop_changed_uda(self): |
| """ |
| @brief Drop all aggregates (UDAs) that were removed in the new version |
| """ |
| for uda in self._uda: |
| for item in self._uda[uda]: |
| _write_to_file(self.output_filehandle, "DROP AGGREGATE IF EXISTS {schema}.{uda}({arg});". |
| format(schema=self._schema, |
| uda=uda, |
| arg=item['argument'])) |
| |
| def drop_changed_udc(self): |
| """ |
| @brief Drop all casts (UDCs) that were updated/removed in the new version |
| @note We have special treatment for UDCs defined in the svec module |
| """ |
| for udc in self._udc: |
| _write_to_file(self.output_filehandle, "DROP CAST IF EXISTS ({sourcetype} AS {targettype});". |
| format(sourcetype=self._udc[udc]['sourcetype'], |
| targettype=self._udc[udc]['targettype'])) |
| |
| def drop_changed_udo(self): |
| """ |
| @brief Drop all operators (UDOs) that were removed/updated in the new version |
| """ |
| for op in self._udo: |
| for value in self._udo[op]: |
| leftarg = value['leftarg'].replace('schema_madlib', self._schema) |
| rightarg = value['rightarg'].replace('schema_madlib', self._schema) |
| _write_to_file(self.output_filehandle, """ |
| DROP OPERATOR IF EXISTS {schema}.{op} ({leftarg}, {rightarg}); |
| """.format(schema=self._schema, **locals())) |
| |
| def drop_changed_udoc(self): |
| """ |
| @brief Drop all operator classes (UDOCs) that were removed/updated in the new version |
| """ |
| for op_cls in self._udoc: |
| for value in self._udoc[op_cls]: |
| index = value['index_method'] |
| _write_to_file(self.output_filehandle, """ |
| DROP OPERATOR CLASS IF EXISTS {schema}.{op_cls} USING {index}; |
| """.format(schema=self._schema, **locals())) |
| |
| |
| class ViewDependency(UpgradeBase): |
| """ |
| @brief This class detects direct and recursive view dependencies on MADlib |
| UDFs/UDAs/UDOs defined in the current version |
| """ |
| |
| def __init__(self, schema, portid, con_args): |
| UpgradeBase.__init__(self, schema, portid, con_args) |
| self._view2proc = None |
| self._view2op = None |
| self._view2view = None |
| self._view2def = None |
| self._detect_direct_view_dependency_udf_uda() |
| self._detect_direct_view_dependency_udo() |
| self._detect_recursive_view_dependency() |
| self._filter_recursive_view_dependency() |
| |
| def _detect_direct_view_dependency_udf_uda(self): |
| """ |
| @brief Detect direct view dependencies on MADlib UDFs/UDAs |
| """ |
| proisagg_wrapper = "p.proisagg" |
| if self._portid == 'postgres' and self._dbver > 11: |
| proisagg_wrapper = "p.prokind = 'a'" |
| rows = self._run_sql(""" |
| SELECT |
| view, nsp.nspname AS schema, procname, procoid, proisagg |
| FROM |
| pg_namespace nsp, |
| ( |
| SELECT |
| c.relname AS view, |
| c.relnamespace AS namespace, |
| p.proname As procname, |
| p.oid AS procoid, |
| {proisagg_wrapper} AS proisagg |
| FROM |
| pg_class AS c, |
| pg_rewrite AS rw, |
| pg_depend AS d, |
| pg_proc AS p |
| WHERE |
| c.oid = rw.ev_class AND |
| rw.oid = d.objid AND |
| d.classid = 'pg_rewrite'::regclass AND |
| d.refclassid = 'pg_proc'::regclass AND |
| d.refobjid = p.oid AND |
| p.pronamespace = {schema_madlib_oid} |
| ) t1 |
| WHERE |
| t1.namespace = nsp.oid |
| """.format(schema_madlib_oid=self._schema_oid, |
| proisagg_wrapper=proisagg_wrapper)) |
| |
| self._view2proc = defaultdict(list) |
| for row in rows: |
| key = (row['schema'], row['view']) |
| self._view2proc[key].append( |
| (row['procname'], row['procoid'], |
| 'UDA' if row['proisagg'] == 't' else 'UDF')) |
| |
| def _detect_direct_view_dependency_udo(self): |
| """ |
| @brief Detect direct view dependencies on MADlib UDOs |
| """ |
| rows = self._run_sql(""" |
| SELECT |
| view, nsp.nspname AS schema, oprname, oproid |
| FROM |
| pg_namespace nsp, |
| ( |
| SELECT |
| c.relname AS view, |
| c.relnamespace AS namespace, |
| p.oprname AS oprname, |
| p.oid AS oproid |
| FROM |
| pg_class AS c, |
| pg_rewrite AS rw, |
| pg_depend AS d, |
| pg_operator AS p |
| WHERE |
| c.oid = rw.ev_class AND |
| rw.oid = d.objid AND |
| d.classid = 'pg_rewrite'::regclass AND |
| d.refclassid = 'pg_operator'::regclass AND |
| d.refobjid = p.oid AND |
| p.oprnamespace = {schema_madlib_oid} |
| ) t1 |
| WHERE |
| t1.namespace = nsp.oid |
| """.format(schema_madlib_oid=self._schema_oid)) |
| |
| self._view2op = defaultdict(list) |
| for row in rows: |
| key = (row['schema'], row['view']) |
| self._view2op[key].append((row['oprname'], row['oproid'], "UDO")) |
| |
| """ |
| @brief Detect recursive view dependencies (view on view) |
| """ |
| |
| def _detect_recursive_view_dependency(self): |
| rows = self._run_sql(""" |
| SELECT |
| nsp1.nspname AS depender_schema, |
| depender, |
| nsp2.nspname AS dependee_schema, |
| dependee |
| FROM |
| pg_namespace AS nsp1, |
| pg_namespace AS nsp2, |
| ( |
| SELECT |
| c.relname depender, |
| c.relnamespace AS depender_nsp, |
| c1.relname AS dependee, |
| c1.relnamespace AS dependee_nsp |
| FROM |
| pg_rewrite AS rw, |
| pg_depend AS d, |
| pg_class AS c, |
| pg_class AS c1 |
| WHERE |
| rw.ev_class = c.oid AND |
| rw.oid = d.objid AND |
| d.classid = 'pg_rewrite'::regclass AND |
| d.refclassid = 'pg_class'::regclass AND |
| d.refobjid = c1.oid AND |
| c1.relkind = 'v' AND |
| c.relname <> c1.relname |
| GROUP BY |
| depender, depender_nsp, dependee, dependee_nsp |
| ) t1 |
| WHERE |
| t1.depender_nsp = nsp1.oid AND |
| t1.dependee_nsp = nsp2.oid |
| """) |
| |
| self._view2view = defaultdict(list) |
| for row in rows: |
| key = (row['depender_schema'], row['depender']) |
| val = (row['dependee_schema'], row['dependee']) |
| self._view2view[key].append(val) |
| |
| """ |
| @brief Filter out recursive view dependencies which are independent of |
| MADLib UDFs/UDAs |
| """ |
| |
| def _filter_recursive_view_dependency(self): |
| # Get initial list |
| checklist = [] |
| checklist.extend(self._view2proc.keys()) |
| checklist.extend(self._view2op.keys()) |
| |
| while True: |
| new_checklist = [] |
| for depender, dependeelist in self._view2view.iteritems(): |
| for dependee in dependeelist: |
| if dependee in checklist and depender not in checklist: |
| new_checklist.append(depender) |
| break |
| if len(new_checklist) == 0: |
| break |
| else: |
| checklist.extend(new_checklist) |
| |
| # Filter out recursive dependencies not related to MADlib UDFs/UDAs |
| filtered_view2view = defaultdict(list) |
| for depender, dependeelist in self._view2view.iteritems(): |
| filtered_dependeelist = [r for r in dependeelist if r in checklist] |
| if len(filtered_dependeelist) > 0: |
| filtered_view2view[depender] = filtered_dependeelist |
| |
| self._view2view = filtered_view2view |
| |
| """ |
| @brief Build the dependency graph (depender-to-dependee adjacency list) |
| """ |
| |
| def _build_dependency_graph(self, hasProcDependency=False): |
| der2dee = self._view2view.copy() |
| for view in self._view2proc: |
| if view not in self._view2view: |
| der2dee[view] = [] |
| if hasProcDependency: |
| der2dee[view].extend(self._view2proc[view]) |
| |
| for view in self._view2op: |
| if view not in self._view2view: |
| der2dee[view] = [] |
| if hasProcDependency: |
| der2dee[view].extend(self._view2op[view]) |
| |
| graph = der2dee.copy() |
| for der in der2dee: |
| for dee in der2dee[der]: |
| if dee not in graph: |
| graph[dee] = [] |
| return graph |
| |
| """ |
| @brief Check dependencies |
| """ |
| |
| def has_dependency(self): |
| return (len(self._view2proc) > 0) or (len(self._view2op) > 0) |
| |
| """ |
| @brief Get the ordered views for creation |
| """ |
| |
| def get_create_order_views(self): |
| graph = self._build_dependency_graph() |
| ordered_views = [] |
| while True: |
| remove_list = [] |
| for depender in graph: |
| if len(graph[depender]) == 0: |
| ordered_views.append(depender) |
| remove_list.append(depender) |
| for view in remove_list: |
| del graph[view] |
| for depender in graph: |
| graph[depender] = [r for r in graph[depender] |
| if r not in remove_list] |
| if len(remove_list) == 0: |
| break |
| return ordered_views |
| |
| """ |
| @brief Get the ordered views for dropping |
| """ |
| |
| def get_drop_order_views(self): |
| ordered_views = self.get_create_order_views() |
| ordered_views.reverse() |
| return ordered_views |
| |
| def get_depended_func_signature(self, tag='UDA'): |
| """ |
| @brief Get the depended UDF/UDA signatures for comparison |
| """ |
| res = {} |
| for procs in self._view2proc.values(): |
| for proc in procs: |
| if proc[2] == tag and (self._schema, proc) not in res: |
| funcinfo = self._get_function_info(proc[1]) |
| signature = get_signature_for_compare(self._schema, proc[0], |
| funcinfo['rettype'], |
| funcinfo['argument']) |
| res[signature] = True |
| return res |
| |
| def get_depended_opr_oids(self): |
| """ |
| @brief Get the depended UDO OIDs for comparison |
| """ |
| res = set() |
| for depended_ops in self._view2op.values(): |
| for op_entry in depended_ops: |
| res.add(op_entry[1]) |
| |
| return list(res) |
| |
| def get_proc_w_dependency(self, tag='UDA'): |
| res = [] |
| for procs in self._view2proc.values(): |
| for proc in procs: |
| if proc[2] == tag and (self._schema, proc) not in res: |
| res.append((self._schema, proc)) |
| res.sort() |
| return res |
| |
| def get_depended_uda(self): |
| """ |
| @brief Get dependent UDAs |
| """ |
| return self.get_proc_w_dependency(tag='UDA') |
| |
| def get_depended_udf(self): |
| """ |
| @brief Get dependent UDFs |
| """ |
| return self.get_proc_w_dependency(tag='UDF') |
| |
| # DEPRECATED ------------------------------------------------------------ |
| def save_and_drop(self): |
| """ |
| @brief Save and drop the dependent views |
| """ |
| self._view2def = {} |
| ordered_views = self.get_drop_order_views() |
| # Save views |
| for view in ordered_views: |
| row = self._run_sql(""" |
| SELECT |
| schemaname, viewname, viewowner, definition |
| FROM |
| pg_views |
| WHERE |
| schemaname = '{schemaname}' AND |
| viewname = '{viewname}' |
| """.format(schemaname=view[0], viewname=view[1])) |
| self._view2def[view] = row[0] |
| |
| # Drop views |
| for view in ordered_views: |
| self._run_sql(""" |
| DROP VIEW IF EXISTS {schema}.{view} |
| """.format(schema=view[0], view=view[1])) |
| |
| # DEPRECATED ------------------------------------------------------------ |
| def restore(self): |
| """ |
| @brief Restore the dependent views |
| """ |
| ordered_views = self.get_create_order_views() |
| for view in ordered_views: |
| row = self._view2def[view] |
| schema = row['schemaname'] |
| view = row['viewname'] |
| owner = row['viewowner'] |
| definition = row['definition'] |
| self._run_sql(""" |
| --Alter view not supported by GP, so use set/reset role as a |
| --workaround |
| --ALTER VIEW {schema}.{view} OWNER TO {owner} |
| SET ROLE {owner}; |
| CREATE OR REPLACE VIEW {schema}.{view} AS {definition}; |
| RESET ROLE |
| """.format(schema=schema, view=view, |
| definition=definition, owner=owner)) |
| |
| def _node_to_str(self, node): |
| if len(node) == 2: |
| res = '%s.%s' % (node[0], node[1]) |
| else: |
| node_type = 'uda' |
| if node[2] == 'UDO': |
| node_type = 'udo' |
| elif node[2] == 'UDF': |
| node_type = 'udf' |
| res = '%s.%s{oid=%s, %s}' % (self._schema, node[0], node[1], node_type) |
| return res |
| |
| def _nodes_to_str(self, nodes): |
| return [self._node_to_str(i) for i in nodes] |
| |
| def get_dependency_graph_str(self): |
| """ |
| @brief Get the dependency graph string for print |
| """ |
| graph = self._build_dependency_graph(True) |
| nodes = list(graph.keys()) |
| nodes.sort() |
| res = ["\tDependency Graph (Depender-Dependee Adjacency List):"] |
| for node in nodes: |
| res.append("{0} -> {1}".format(self._node_to_str(node), |
| self._nodes_to_str(graph[node]))) |
| return "\n\t\t\t\t".join(res) |
| |
| |
| class TableDependency(UpgradeBase): |
| """ |
| @brief This class detects table dependencies on MADlib UDTs and index |
| dependencies on MADlib operator classes defined in the current version |
| """ |
| |
| def __init__(self, schema, portid, con_args): |
| UpgradeBase.__init__(self, schema, portid, con_args) |
| self._table2type = None |
| self._detect_table_dependency() |
| self._detect_index_dependency() |
| |
| def _detect_table_dependency(self): |
| """ |
| @brief Detect the table dependencies on MADLib UDTs |
| """ |
| rows = self._run_sql(""" |
| SELECT |
| nsp.nspname AS schema, |
| relname AS relation, |
| attname AS column, |
| typname AS type |
| FROM |
| pg_attribute a, |
| pg_class c, |
| pg_type t, |
| pg_namespace nsp |
| WHERE |
| t.typnamespace = {schema_madlib_oid} |
| AND a.atttypid = t.oid |
| AND c.oid = a.attrelid |
| AND c.relnamespace = nsp.oid |
| AND c.relkind = 'r' |
| ORDER BY |
| nsp.nspname, relname, attname, typname |
| """.format(schema_madlib_oid=self._schema_oid)) |
| |
| self._table2type = defaultdict(list) |
| for row in rows: |
| key = (row['schema'], row['relation']) |
| self._table2type[key].append( |
| (row['column'], row['type'])) |
| |
| def _detect_index_dependency(self): |
| """ |
| @brief Detect the index dependencies on MADlib UDOCs |
| """ |
| rows = self._run_sql( |
| """ |
| SELECT |
| s.idxname, s.oid AS opcoid, nsp.nspname AS schema, s.name AS opcname |
| FROM |
| pg_namespace nsp |
| JOIN |
| ( |
| SELECT |
| objid::regclass AS idxname, c.relnamespace AS namespace, oc.oid AS oid, |
| oc.opcname AS name |
| FROM |
| pg_depend d |
| JOIN |
| pg_opclass oc |
| ON (d.refclassid = 'pg_opclass'::regclass AND d.refobjid = oc.oid) |
| JOIN |
| pg_class c |
| ON (c.oid = d.objid) |
| WHERE oc.opcnamespace = {schema_madlib_oid} AND c.relkind = 'i' |
| ) s |
| ON (nsp.oid = s.namespace) |
| """.format(schema_madlib_oid=self._schema_oid)) |
| self._index2opclass = defaultdict(list) |
| for row in rows: |
| key = (row['schema'], row['idxname']) |
| self._index2opclass[key].append( |
| (row['opcoid'], row['opcname'])) |
| |
| def has_dependency(self): |
| """ |
| @brief Check dependencies |
| """ |
| return len(self._table2type) > 0 or len(self._index2opclass) > 0 |
| |
| def get_depended_udt(self): |
| """ |
| @brief Get the list of depended UDTs |
| """ |
| res = defaultdict(bool) |
| for table in self._table2type: |
| for (col, typ) in self._table2type[table]: |
| if typ not in res: |
| res[typ] = True |
| return res |
| |
| def get_depended_udoc_oids(self): |
| """ |
| @brief Get the list of depended UDOC OIDs |
| """ |
| res = set() |
| for depended_opcs in self._index2opclass.values(): |
| for opc_entry in depended_opcs: |
| res.add(opc_entry[0]) |
| |
| return list(res) |
| |
| def get_dependency_str(self): |
| """ |
| @brief Get the dependencies in string for print |
| """ |
| res = ['\tTable Dependency (schema.table.column -> MADlib type):'] |
| for table in self._table2type: |
| for (col, udt) in self._table2type[table]: |
| res.append("{0}.{1}.{2} -> {3}".format(table[0], table[1], col, |
| udt)) |
| for index in self._index2opclass: |
| for (oid, name) in self._index2opclass[index]: |
| res.append("{0}.{1} -> {3}(oid={4})".format(index[0], index[1], name, oid)) |
| |
| return "\n\t\t\t\t".join(res) |
| |
| |
| class ScriptCleaner(UpgradeBase): |
| """ |
| @brief This class removes from a SQL script those statements that should |
| not be executed during the upgrade |
| """ |
| |
| def __init__(self, schema, portid, con_args, change_handler): |
| UpgradeBase.__init__(self, schema, portid, con_args) |
| self._ch = change_handler |
| self._sql = None |
| self._existing_uda = None |
| self._existing_udt = None |
| self._aggregate_patterns = self._get_all_aggregate_patterns() |
| self._unchanged_operator_patterns = self._get_unchanged_operator_patterns() |
| self._unchanged_opclass_patterns = self._get_unchanged_opclass_patterns() |
| # print("Number of existing UDAs = " + str(len(self._existing_uda))) |
| # print("Number of UDAs to not create = " + str(len(self._aggregate_patterns))) |
| self._get_existing_udt() |
| |
| def _get_existing_udoc(self): |
| """ |
| @brief Get the existing UDOCs in the current version |
| """ |
| gte_gpdb5 = (self._portid == 'greenplum' and |
| is_rev_gte(get_rev_num(self._dbver), get_rev_num('5.0'))) |
| if (self._portid == 'postgres' or gte_gpdb5): |
| method_col = 'opcmethod' |
| else: |
| method_col = 'opcamid' |
| rows = self._run_sql(""" |
| SELECT |
| opcname, amname AS index |
| FROM |
| pg_opclass AS oc, pg_namespace AS ns, pg_am as am |
| WHERE |
| oc.opcnamespace = ns.oid AND |
| oc.{method_col} = am.oid AND |
| ns.nspname = '{schema}'; |
| """.format(schema=self._schema.lower(), **locals())) |
| self._existing_udoc = defaultdict(list) |
| for row in rows: |
| self._existing_udoc[row['opcname']].append({'index': row['index']}) |
| |
| def _get_existing_udo(self): |
| """ |
| @brief Get the existing UDOs in the current version |
| """ |
| rows = self._run_sql(""" |
| SELECT |
| oprname, oprleft::regtype, oprright::regtype |
| FROM |
| pg_operator AS o, pg_namespace AS ns |
| WHERE |
| o.oprnamespace = ns.oid AND |
| ns.nspname = '{schema}' |
| """.format(schema=self._schema.lower())) |
| self._existing_udo = defaultdict(list) |
| for row in rows: |
| self._existing_udo[row['oprname']].append( |
| {'leftarg': row['oprleft'], |
| 'rightarg': row['oprright']}) |
| |
| def _get_existing_uda(self): |
| """ |
| @brief Get the existing UDAs in the current version. |
| """ |
| # See _get_function_info for explanations. |
| |
| proisagg_wrapper = "p.proisagg = true" |
| if self._portid == 'postgres' and self._dbver > 11: |
| proisagg_wrapper = "p.prokind = 'a'" |
| |
| rows = self._run_sql(""" |
| SELECT |
| max(proname) AS proname, |
| max(rettype) AS rettype, |
| array_to_string(array_agg(argtype order by i), ', ') AS argument |
| FROM |
| ( |
| SELECT |
| p.oid AS procoid, |
| proname, |
| textin(regtypeout(prorettype::regtype)) AS rettype, |
| textin(regtypeout(unnest(proargtypes)::regtype)) AS argtype, |
| generate_series(0, array_upper(proargtypes, 1)) AS i |
| FROM |
| pg_proc AS p, |
| pg_namespace AS nsp |
| WHERE |
| p.pronamespace = nsp.oid AND |
| {proisagg_wrapper} AND |
| nsp.nspname = '{schema}' |
| ) AS f |
| GROUP BY |
| procoid |
| """.format(schema=self._schema, proisagg_wrapper=proisagg_wrapper)) |
| self._existing_uda = defaultdict(list) |
| for row in rows: |
| # Account for overloaded aggregates |
| self._existing_uda[row['proname']].append( |
| {'rettype': row['rettype'], |
| 'argument': row['argument']}) |
| |
| def _get_unchanged_operator_patterns(self): |
| """ |
| Creates a list of string patterns that represent all |
| 'CREATE OPERATOR' statements not changed since the old version. |
| |
| @return unchanged = existing - changed |
| """ |
| self._get_existing_udo() # from the old version |
| operator_patterns = [] |
| # for all, pass the changed ones, add others to ret |
| for each_udo, udo_details in self._existing_udo.items(): |
| for each_item in udo_details: |
| if each_udo in self._ch.udo: |
| if each_item in self._ch.udo[each_udo]: |
| continue |
| p_arg_str = '' |
| # assuming binary ops |
| leftarg = self._rewrite_type_in(each_item['leftarg']) |
| rightarg = self._rewrite_type_in(each_item['rightarg']) |
| p_str = "CREATE\s+OPERATOR\s+{schema}\.{op_name}\s*\(" \ |
| "\s*leftarg\s*=\s*{leftarg}\s*," \ |
| "\s*rightarg\s*=\s*{rightarg}\s*," \ |
| ".*?\)\s*;".format(schema=self._schema, |
| op_name=re.escape(each_udo), **locals()) |
| operator_patterns.append(p_str) |
| return operator_patterns |
| |
| def _get_unchanged_opclass_patterns(self): |
| """ |
| Creates a list of string patterns that represent all |
| 'CREATE OPERATOR CLASS' statements not changed since the old version. |
| |
| @return unchanged = existing - changed |
| """ |
| self._get_existing_udoc() # from the old version |
| opclass_patterns = [] |
| # for all, pass the changed ones, add others to ret |
| for each_udoc, udoc_details in self._existing_udoc.items(): |
| for each_item in udoc_details: |
| if each_udoc in self._ch.udoc: |
| if each_item in self._ch.udoc[each_udoc]: |
| continue |
| p_arg_str = '' |
| # assuming binary ops |
| index = each_item['index'] |
| p_str = "CREATE\s+OPERATOR\s+CLASS\s+{schema}\.{opc_name}" \ |
| ".*?USING\s+{index}" \ |
| ".*?;".format(schema=self._schema, |
| opc_name=each_udoc, **locals()) |
| opclass_patterns.append(p_str) |
| return opclass_patterns |
| |
| def _get_all_aggregate_patterns(self): |
| """ |
| Creates a list of string patterns that represent all possible |
| 'CREATE AGGREGATE' statements except ones that are being |
| replaced/introduced as part of this upgrade. |
| |
| """ |
| self._get_existing_uda() |
| aggregate_patterns = [] |
| |
| for each_uda, uda_details in self._existing_uda.iteritems(): |
| for each_item in uda_details: |
| if each_uda in self._ch.uda: |
| if each_item in self._ch.uda[each_uda]: |
| continue |
| p_arg_str = '' |
| argument = each_item['argument'] |
| args = argument.split(',') |
| for arg in args: |
| arg = self._rewrite_type_in(arg.strip()) |
| if p_arg_str == '': |
| p_arg_str += '%s\s*' % arg |
| else: |
| p_arg_str += ',\s*%s\s*' % arg |
| p_str = "CREATE\s+(ORDERED\s)*\s*AGGREGATE" \ |
| "\s+%s\.(%s)\s*\(\s*%s\)(.*?);" % (self._schema, |
| each_uda, |
| p_arg_str) |
| aggregate_patterns.append(p_str) |
| return aggregate_patterns |
| |
| def _get_existing_udt(self): |
| """ |
| @brief Get the existing UDTs in the current version |
| """ |
| rows = self._run_sql(""" |
| SELECT |
| typname |
| FROM |
| pg_type AS t, |
| pg_namespace AS nsp |
| WHERE |
| t.typnamespace = nsp.oid AND |
| nsp.nspname = '{schema}' |
| """.format(schema=self._schema)) |
| self._existing_udt = [row['typname'] for row in rows] |
| |
| def get_change_handler(self): |
| """ |
| @note The change_handler is needed for deciding which sql statements to |
| remove |
| """ |
| return self._ch |
| |
| def _clean_comment(self): |
| """ |
| @brief Remove comments in the sql script |
| """ |
| self._sql = remove_comments_from_sql(self._sql) |
| |
| def _clean_type(self): |
| """ |
| @brief Remove "drop/create type" statements in the sql script |
| """ |
| # remove 'drop type' |
| pattern = re.compile('DROP(\s+)TYPE(.*?);', re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| # remove 'create type' |
| udt_str = '' |
| for udt in self._existing_udt: |
| if udt in self._ch.udt: |
| continue |
| if udt_str == '': |
| udt_str += udt |
| else: |
| udt_str += '|' + udt |
| p_str = 'CREATE(\s+)TYPE(\s+)%s\.(%s)(.*?);' % (self._schema, udt_str) |
| pattern = re.compile(p_str, re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| """ |
| @brief Remove "drop/create cast" statements in the sql script |
| """ |
| |
| def _clean_cast(self): |
| # remove 'drop cast' |
| pattern = re.compile('DROP(\s+)CAST(.*?);', re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| # remove 'create cast' |
| udc_str = '' |
| for udc in self._ch.udc: |
| if udc_str == '': |
| udc_str += '%s\s+AS\s+%s' % ( |
| self._ch.udc[udc]['sourcetype'], |
| self._ch.udc[udc]['targettype']) |
| else: |
| udc_str += '|' + '%s\s+AS\s+%s' % ( |
| self._ch.udc[udc]['sourcetype'], |
| self._ch.udc[udc]['targettype']) |
| |
| pattern = re.compile('CREATE\s+CAST(.*?);', re.DOTALL | re.IGNORECASE) |
| if udc_str != '': |
| pattern = re.compile('CREATE\s+CAST\s*\(\s*(?!%s)(.*?);' % |
| udc_str, re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| """ |
| @brief Remove "drop/create operator" statements in the sql script |
| """ |
| |
| def _clean_operator(self): |
| # remove 'drop operator' |
| pattern = re.compile('DROP\s+OPERATOR.*?PROCEDURE\s+=.*?;', re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| # for create operator statements: |
| # delete: unchanged, removed (not in the input sql anyway) |
| # keep: new, changed |
| for p in self._unchanged_operator_patterns: |
| regex_pat = re.compile(p, re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(regex_pat, '', self._sql) |
| |
| """ |
| @brief Remove "drop/create operator class" statements in the sql script |
| """ |
| |
| def _clean_opclass(self): |
| # remove 'drop operator class' |
| pattern = re.compile(r'DROP\s+OPERATOR\s*CLASS.*?;', re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| |
| # for create operator class statements: |
| # delete: unchanged, removed (not in the input sql anyway) |
| # keep: new, changed |
| for p in self._unchanged_opclass_patterns: |
| regex_pat = re.compile(p, re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(regex_pat, '', self._sql) |
| |
| """ |
| @brief Rewrite the type |
| """ |
| |
| def _rewrite_type_in(self, arg): |
| type_mapper = { |
| 'smallint': '(int2|smallint)', |
| 'integer': '(int|int4|integer)', |
| 'bigint': '(int8|bigint)', |
| 'double precision': '(float8|double precision)', |
| 'real': '(float4|real)', |
| 'character varying': '(varchar|character varying)' |
| } |
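| # Illustrative rewrite: 'double precision[]' becomes |
| # '(float8|double precision)\[\]' so the pattern matches either spelling. |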
| for typ in type_mapper: |
| arg = arg.replace(typ, type_mapper[typ]) |
| return arg.replace('[', '\[').replace(']', '\]') |
| |
| def _clean_aggregate(self): |
| # remove all drop aggregate statements |
| self._sql = re.sub(re.compile('DROP(\s+)AGGREGATE(.*?);', |
| re.DOTALL | re.IGNORECASE), |
| '', self._sql) |
| # for create aggregate statements: |
| # delete: unchanged, removed (not in the input sql anyway) |
| # keep: new, changed |
| for each_pattern in self._aggregate_patterns: |
| regex_pat = re.compile(each_pattern, re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(regex_pat, '', self._sql) |
| |
| def _clean_function(self): |
| """ |
| @brief Remove "drop function" statements and rewrite "create function" |
| statements in the sql script |
| @note We don't drop any function |
| """ |
| # remove 'drop function' |
| pattern = re.compile(r"""DROP(\s+)FUNCTION(.*?);""", re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, '', self._sql) |
| # replace 'create function' with 'create or replace function' |
| pattern = re.compile(r"""CREATE(\s+)FUNCTION""", re.DOTALL | re.IGNORECASE) |
| self._sql = re.sub(pattern, 'CREATE OR REPLACE FUNCTION', self._sql) |
| |
| def cleanup(self, sql, algoname): |
| """ |
| @brief Entry function for cleaning the sql script |
| """ |
| self._sql = sql |
| # Modify the original sql during the upgrade. Clean only modules that are |
| # not new, since they already exist prior to the upgrade: drop statements |
| # are removed and create statements for unchanged objects are stripped or |
| # rewritten. |
| if algoname not in self.get_change_handler().newmodule: |
| self._clean_comment() |
| self._clean_type() |
| self._clean_cast() |
| # self._clean_operator() |
| # self._clean_opclass() |
| # self._clean_aggregate() |
| # self._clean_function() |
| return self._sql |
| |
| |
| import unittest |
| |
| |
| class TestChangeHandler(unittest.TestCase): |
| |
| def setUp(self): |
| self._dummy_schema = 'madlib' |
| self._dummy_portid = 1 |
| self._dummy_con_args = 'x' |
| # maddir is the directory one level above current file |
| # dirname gives the directory of current file (madpack) |
| # join with pardir adds .. (e.g .../madpack/..) |
| # abspath normalizes the result by resolving the .. |
| self.maddir = os.path.abspath( |
| os.path.join(os.path.dirname(os.path.realpath(__file__)), |
| os.pardir)) |
| |
| def tearDown(self): |
| pass |
| |
| def test_invalid_path(self): |
| with self.assertRaises(RuntimeError): |
| ChangeHandler(self._dummy_schema, self._dummy_portid, |
| self._dummy_con_args, self.maddir, |
| '1.9', upgrade_to=get_rev_num('1.12')) |
| |
| def test_valid_path(self): |
| ch = ChangeHandler(self._dummy_schema, self._dummy_portid, |
| self._dummy_con_args, self.maddir, |
| '1.9.1', upgrade_to=get_rev_num('1.12')) |
| self.assertEqual(ch.newmodule.keys(), |
| ['knn', 'sssp', 'apsp', 'measures', 'stratified_sample', |
| 'encode_categorical', 'bfs', 'mlp', 'pagerank', |
| 'train_test_split', 'wcc']) |
| self.assertEqual(ch.udt, {'kmeans_result': None, 'kmeans_state': None}) |
| self.assertEqual(ch.udf['forest_train'], |
| [{'argument': 'text, text, text, text, text, text, text, ' |
| 'integer, integer, boolean, integer, integer, ' |
| 'integer, integer, integer, text, boolean, ' |
| 'double precision', |
| 'rettype': 'void'}, |
| {'argument': 'text, text, text, text, text, text, text, ' |
| 'integer, integer, boolean, integer, integer, ' |
| 'integer, integer, integer, text, boolean', |
| 'rettype': 'void'}, |
| {'argument': 'text, text, text, text, text, text, text, ' |
| 'integer, integer, boolean, integer, integer, ' |
| 'integer, integer, integer, text', |
| 'rettype': 'void'}]) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |