| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # Weakly Connected Components |
| |
| # Please refer to the wcc.sql_in file for the documentation |
| |
| """ |
| @file wcc.py_in |
| |
| @namespace graph |
| """ |
| |
| import plpy |
| from utilities.control import SetGUC |
| from utilities.utilities import _assert |
| from utilities.utilities import _check_groups |
| from utilities.utilities import get_table_qualified_col_str |
| from utilities.utilities import is_platform_gp6_or_up |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import unique_string, split_quoted_delimited_str |
| from utilities.validate_args import columns_exist_in_table, get_expr_type |
| from utilities.utilities import is_platform_pg |
| from utilities.utilities import get_seg_number |
| from utilities.utilities import add_postfix |
| from utilities.validate_args import table_exists |
| from utilities.utilities import rename_table |
| from utilities.control import MinWarning |
| from graph_utils import validate_graph_coding, get_graph_usage |
| from graph_utils import validate_output_and_summary_tables |
| |
def validate_wcc_args(schema_madlib, vertex_table, vertex_table_in, vertex_id,
                      vertex_id_in, edge_table, edge_params, edge_args,
                      out_table, out_table_summary, grouping_cols,
                      grouping_cols_list, warm_start, out_table_message,
                      module_name):
    """
    Validate the input parameters of wcc.

    Args:
        @param schema_madlib        Passed to columns_exist_in_table.
        @param vertex_table         Vertex table handed to validate_graph_coding.
        @param vertex_table_in      Vertex table name as given by the user;
                                    compared with the summary on warm start.
        @param vertex_id            Vertex id handed to validate_graph_coding.
        @param vertex_id_in         Vertex id as given by the user; compared
                                    with the summary on warm start when set.
        @param edge_table           Name of the edge table.
        @param edge_params          Parsed edge parameters.
        @param edge_args            Raw edge argument string; compared with the
                                    summary on warm start when set.
        @param out_table            Name of the output table.
        @param out_table_summary    Name of the summary table.
        @param grouping_cols        Grouping columns as a single string.
        @param grouping_cols_list   Grouping columns as a list.
        @param warm_start           True when continuing a previous run.
        @param out_table_message    Name of the message table.
        @param module_name          Module name used in the error messages.

    Every failed check is reported through _assert.
    """
    validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
                          out_table, module_name, warm_start)

    msg_prefix = "Graph {0}: ".format(module_name)

    if warm_start:
        # A warm start continues from the tables left by the previous run,
        # so both of them must still be around.
        _assert(table_exists(out_table_summary),
                msg_prefix + "Output summary table is missing for warm start!")
        _assert(table_exists(out_table_message),
                msg_prefix + "Output message table is missing for warm start! "
                "Either wcc was completed in the last run or the table got "
                "dropped/renamed.")

        summary_row = plpy.execute(
            "SELECT * FROM {0}".format(out_table_summary))[0]

        # (summary column, value of the current call, compare?)
        # vertex_id and edge_args are only compared when the caller set them.
        expectations = (
            ('vertex_table', vertex_table_in, True),
            ('vertex_id', vertex_id_in, bool(vertex_id_in)),
            ('edge_table', edge_table, True),
            ('edge_args', edge_args, bool(edge_args)),
            ('grouping_cols', grouping_cols, True),
        )
        for column, current_value, must_match in expectations:
            if not must_match:
                continue
            _assert(summary_row[column] == current_value,
                    "{0}Warm start {1} do not match!".format(msg_prefix,
                                                             column))
    else:
        # A cold start must not overwrite existing tables.
        _assert(not table_exists(out_table_summary),
                msg_prefix + "Output summary table already exists!")
        _assert(not table_exists(out_table_message),
                msg_prefix + "Output message table already exists!")

    if grouping_cols_list:
        # validate the grouping columns. We currently only support grouping_cols
        # to be column names in the edge_table, and not expressions!
        _assert(columns_exist_in_table(edge_table, grouping_cols_list,
                                       schema_madlib),
                "Weakly Connected Components error: "
                "One or more grouping columns specified do not exist!")
| |
def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
        out_table, grouping_cols, iteration_limit=0, warm_start=False, **kwargs):
    """
    Compute the weakly connected components of a graph.

    Args:
        @param schema_madlib    Name of the MADlib schema.
        @param vertex_table     Name of the table that holds the vertices.
        @param vertex_id        Vertex id column. Empty/None means 'id'. A
                                value of the form '[col1,col2,...]' selects a
                                multi-column id.
        @param edge_table       Name of the table that holds the edges.
        @param edge_args        Key-value string naming the 'src' and 'dest'
                                columns of the edge table (defaults: 'src',
                                'dest'); each may list several columns.
        @param out_table        Name of the output table. A summary and a
                                message table are named after it via
                                add_postfix ('_summary', '_message').
        @param grouping_cols    Comma separated grouping columns of the edge
                                table, or empty/None for no grouping.
        @param iteration_limit  Maximum number of iterations of the main loop.
                                None or 0 means no limit; a negative value is
                                reported with plpy.error.
        @param warm_start       If True, continue from the out_table and the
                                message table of a previous run instead of
                                creating them.
        @param kwargs           Unused.

    Results:
        Creates (or, on warm start, updates) out_table and its summary table.
        The message table is dropped once no vertex needs an update any more
        and is kept otherwise so that a later warm start can continue.
    """

    BIGINT_MAX = 9223372036854775807
    vertex_table_in = vertex_table
    vertex_id_in = vertex_id
    edge_table_in = edge_table

    # Remember the caller's message level; it is restored at the very end.
    old_msg_level = plpy.execute("""
        SELECT setting
        FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    plpy.execute('SET client_min_messages TO warning')
    params_types = {'src': list, 'dest': list}
    default_args = {'src': ['src'], 'dest': ['dest']}
    edge_params = extract_keyvalue_params(
        edge_args, params_types, default_args)

    if iteration_limit is None or iteration_limit == 0:
        iteration_limit = BIGINT_MAX
    elif iteration_limit < 0:
        plpy.error("Weakly Connected Components: iteration_limit cannot be a negative number.")

    # Populate default values for optional params if null. The input tables
    # are wrapped in views that expose fixed column names (id/src/dest) so
    # that the queries below do not depend on the user's column names.
    vertex_view = unique_string('vertex_view')
    edge_view = unique_string('edge_view')
    single_id = 'single_id'
    vertex_view_sql = """
        CREATE VIEW {vertex_view} AS
        SELECT {vertex_sql} AS id, {vertex_sql} AS {single_id}
        FROM {vertex_table};
    """
    if not vertex_id:
        vertex_id = "id"
        vertex_sql = vertex_id
        vertex_type = "BIGINT"
    else:
        if vertex_id[0] == '[' and vertex_id[-1] == ']':
            # Multi-column vertex id: the id becomes an array of the columns.
            vertex_id = split_quoted_delimited_str(vertex_id[1:-1])
            vertex_sql = "ARRAY[{0}]".format(','.join(vertex_id))
            vertex_type = "BIGINT[]"

            if is_platform_pg():
                num_segments = 1
                seg_sql = ' 0 '
            else:
                num_segments = get_seg_number()
                seg_sql = " gp_segment_id "
            # The id is an array here, so a separate scalar {single_id} is
            # derived from the two numbers in the row's ctid and the segment
            # id; it is what gets cast to BIGINT as the initial component id.
            vertex_view_sql = """
                CREATE VIEW {vertex_view} AS
                WITH q1 AS (
                    SELECT {vertex_sql} AS id,
                           ctid AS ctid_in,
                           {seg_sql} AS seg_id_in,
                           CAST( regexp_matches(ctid::TEXT, '\\((\\d+),(\\d+)\\)') AS BIGINT[]) AS new_id
                    FROM {vertex_table}),
                q2 AS (SELECT MAX(new_id[1]) AS max_block FROM q1)
                SELECT id, ctid_in, {num_segments}*(new_id[2]*(max_block+1)+new_id[1])+seg_id_in AS {single_id}
                FROM q1, q2;
            """

        else:
            vertex_sql = vertex_id
            vertex_id = [vertex_id]
            vertex_type = "BIGINT"

    src_list = edge_params["src"]
    if len(src_list) > 1:
        src = "ARRAY[{0}]".format(','.join(edge_params["src"]))
    else:
        edge_params["src"] = edge_params["src"][0]
        src = edge_params["src"]

    dest_list = edge_params["dest"]
    if len(dest_list) > 1:
        dest = "ARRAY[{0}]".format(','.join(edge_params["dest"]))
    else:
        edge_params["dest"] = edge_params["dest"][0]
        dest = edge_params["dest"]

    if not grouping_cols:
        grouping_cols = ''
        grouping_sql = ''
    else:
        grouping_sql = ', {0}'.format(grouping_cols)

    out_table_summary = ''
    message = ''
    if out_table:
        out_table_summary = add_postfix(out_table, "_summary")
        message = add_postfix(out_table, "_message")

    grouping_cols_list = split_quoted_delimited_str(grouping_cols)

    validate_wcc_args(schema_madlib, vertex_table, vertex_table_in,
                      vertex_id, vertex_id_in, edge_table,
                      edge_params, edge_args, out_table, out_table_summary,
                      grouping_cols, grouping_cols_list, warm_start, message,
                      'Weakly Connected Components')

    vertex_view_sql = vertex_view_sql.format(**locals())

    edge_view_sql = """
        CREATE VIEW {edge_view} AS
        SELECT {src} AS src, {dest} AS dest {grouping_sql}
        FROM {edge_table};
    """.format(**locals())

    plpy.execute(vertex_view_sql + edge_view_sql)

    # From here on every query runs against the views and their fixed
    # column names.
    vertex_table = vertex_view
    edge_table = edge_view
    vertex_id = 'id'
    src = 'src'
    dest = 'dest'

    distribution = '' if is_platform_pg() else "DISTRIBUTED BY (id)"

    oldupdate = unique_string(desp='oldupdate')
    toupdate = unique_string(desp='toupdate')
    temp_out_table = unique_string(desp='tempout')
    edge_inverse = unique_string(desp='edge_inverse')

    subq_prefixed_grouping_cols = ''
    comma_toupdate_prefixed_grouping_cols = ''
    comma_oldupdate_prefixed_grouping_cols = ''
    old_new_update_where_condition = ''
    new_to_update_where_condition = ''
    edge_to_update_where_condition = ''
    edge_inverse_to_update_where_condition = ''

    distinct_grp_sql = ""

    component_id = 'component_id'
    grouping_cols_comma = '' if not grouping_cols else grouping_cols + ','
    comma_grouping_cols = '' if not grouping_cols else ',' + grouping_cols

    if not is_platform_pg():
        # In Greenplum, to avoid redistribution of data in later queries,
        # edge_table is duplicated by creating a temporary table distributed
        # on dest column
        plpy.execute(""" CREATE TABLE {edge_inverse} AS
            SELECT * FROM {edge_table} DISTRIBUTED BY ({dest});
            """.format(**locals()))
    else:
        edge_inverse = edge_table

    if warm_start:
        # out_table and the message table already exist. The only setup
        # needed is to rename a user-named id column to 'id' for the
        # duration of the run (it is renamed back after the loop).
        out_table_sql = ""
        msg_sql = ""
        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
            out_table_sql = """
                ALTER TABLE {out_table} RENAME COLUMN {vertex_id_in} TO {vertex_id};
            """.format(**locals())

    if grouping_cols:
        distribution = ('' if is_platform_pg() else
                        "DISTRIBUTED BY ({0}, {1})".format(grouping_cols,
                                                           vertex_id))
        # Update some variables useful for grouping based query strings
        subq = unique_string(desp='subquery')
        distinct_grp_table = unique_string(desp='grptable')

        comma_toupdate_prefixed_grouping_cols = ', ' + \
            get_table_qualified_col_str(toupdate, grouping_cols_list)
        comma_oldupdate_prefixed_grouping_cols = ', ' + \
            get_table_qualified_col_str(oldupdate, grouping_cols_list)
        subq_prefixed_grouping_cols = get_table_qualified_col_str(subq, grouping_cols_list)
        old_new_update_where_condition = ' AND ' + \
            _check_groups(oldupdate, out_table, grouping_cols_list)
        new_to_update_where_condition = ' AND ' + \
            _check_groups(out_table, toupdate, grouping_cols_list)
        edge_to_update_where_condition = ' AND ' + \
            _check_groups(edge_table, toupdate, grouping_cols_list)
        edge_inverse_to_update_where_condition = ' AND ' + \
            _check_groups(edge_inverse, toupdate, grouping_cols_list)
        join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
        group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                           subq, vertex_id))
        select_grouping_cols = ',' + subq_prefixed_grouping_cols

        distinct_grp_sql = """
            CREATE TABLE {distinct_grp_table} AS
            SELECT DISTINCT {grouping_cols} FROM {edge_table};
        """

        if not warm_start:
            # Every (group, vertex) pair that appears on either end of an
            # edge starts with the largest possible component id.
            out_table_sql = """
                CREATE TABLE {out_table} AS
                SELECT {subq}.{vertex_id},
                       CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
                       {select_grouping_cols}
                FROM {distinct_grp_table} INNER JOIN (
                    SELECT {grouping_cols_comma} {src} AS {vertex_id}
                    FROM {edge_table}
                    UNION
                    SELECT {grouping_cols_comma} {dest} AS {vertex_id}
                    FROM {edge_inverse}
                ) {subq}
                ON {join_grouping_cols}
                GROUP BY {group_by_clause_newupdate}
                {distribution};
            """.format(**locals())
            msg_sql = """
                CREATE TABLE {message} AS
                SELECT {vertex_table}.{vertex_id},
                       CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
                       {comma_grouping_cols}
                FROM {out_table} INNER JOIN {vertex_table}
                ON {vertex_table}.{vertex_id} = {out_table}.{vertex_id}
                {distribution};
            """.format(**locals())

    else:
        if not warm_start:
            out_table_sql = """
                CREATE TABLE {out_table} AS
                SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
                FROM {vertex_table}
                {distribution};
            """.format(**locals())
            msg_sql = """
                CREATE TABLE {message} AS
                SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
                FROM {vertex_table}
                {distribution};
            """.format(**locals())

    # LIMIT 0: only the structure of the table is needed at this point, the
    # loop below fills it.
    old_update_sql = """
        CREATE TABLE {oldupdate} AS
        SELECT {message}.{vertex_id},
               MIN({message}.{component_id}) AS {component_id}
               {comma_grouping_cols}
        FROM {message}
        GROUP BY {grouping_cols_comma} {vertex_id}
        LIMIT 0
        {distribution};
    """
    to_update_sql = """
        CREATE TABLE {toupdate} AS
        SELECT * FROM {oldupdate}
        {distribution};
    """

    # We combine the sql statements as much as possible to reduce the number of
    # subtransactions we create. Postgres and Greenplum have a limit of 64 for
    # cached subtx and each plpy.execute creates one.
    # Postgres and Greenplum 5 do not like creating a table and using it in the
    # same plpy.execute so we have to keep them separate for this step but the
    # loop is combined for all platforms.
    if is_platform_gp6_or_up():
        plpy.execute((distinct_grp_sql + out_table_sql + msg_sql + old_update_sql + to_update_sql).format(**locals()))
    else:
        # Statements that are empty for this run (no grouping, warm start)
        # are skipped instead of being sent as empty queries.
        for setup_sql in (distinct_grp_sql, out_table_sql, msg_sql,
                          old_update_sql, to_update_sql):
            if setup_sql:
                plpy.execute(setup_sql.format(**locals()))

    nodes_to_update = 1

    # WCC Logic:
    # Assume we have the following graph: [1,2] [2,3] [2,4]
    # The first iteration starts with a number of set up steps.
    # For vertex 2, the component_id is set to 2.
    # The relevant work starts with the creation of the message table.
    # 1)
    #   message gets filled in two steps, one for incoming edges and one for
    #   outgoing. The logic looks for every neighbor of a vertex and takes
    #   the minimum component id it sees.
    #   For vertex 2, message will have two entries, 2->1 and 2->3.
    #   2->4 got eliminated because 3<4.
    # 2)
    #   The next iteration starts with oldupdate.
    #   This table is used to reduce the two possible messages into one.
    #   For vertex 2, oldupdate will have 2->1. 2->3 got eliminated
    #   because 1<3.
    # 3)
    #   toupdate is used to check if the update is necessary.
    #   We compare the incoming component_id value with the existing one.
    #   For vertex 2, toupdate will have 2->1 because 1<2.
    # 4) The out_table gets updated with the contents of toupdate table.
    # 5) A new message is created based on the toupdate table as mentioned
    #    above.
    #
    # Warm Start:
    # To facilitate warm start we use two tables:
    # - out_table: it contains the results so far, we will continue building
    #   on this table
    # - message: This is the message at the end of an iteration for the
    #   next one.
    # We save these two tables if the iteration_limit is reached and
    # nodes_to_update > 0. When the user starts again with warm_start on, we
    # plug these two tables back and continue as usual.

    # Use truncate instead of drop/recreate to avoid catalog bloat.
    loop_sql = """
        TRUNCATE TABLE {oldupdate};

        INSERT INTO {oldupdate}
        SELECT {message}.{vertex_id},
               MIN({message}.{component_id}) AS {component_id}
               {comma_grouping_cols}
        FROM {message}
        GROUP BY {grouping_cols_comma} {vertex_id};

        TRUNCATE TABLE {toupdate};

        INSERT INTO {toupdate}
        SELECT {oldupdate}.{vertex_id},
               {oldupdate}.{component_id}
               {comma_oldupdate_prefixed_grouping_cols}
        FROM {oldupdate}, {out_table}
        WHERE {oldupdate}.{vertex_id}={out_table}.{vertex_id}
            AND {oldupdate}.{component_id}<{out_table}.{component_id}
            {old_new_update_where_condition};

        UPDATE {out_table} SET
            {component_id}={toupdate}.{component_id}
        FROM {toupdate}
        WHERE {out_table}.{vertex_id}={toupdate}.{vertex_id}
            {new_to_update_where_condition};

        TRUNCATE TABLE {message};

        INSERT INTO {message}
        SELECT {edge_inverse}.{src} AS {vertex_id},
               MIN({toupdate}.{component_id}) AS {component_id}
               {comma_toupdate_prefixed_grouping_cols}
        FROM {toupdate}, {edge_inverse}
        WHERE {edge_inverse}.{dest} = {toupdate}.{vertex_id}
            {edge_inverse_to_update_where_condition}
        GROUP BY {edge_inverse}.{src} {comma_toupdate_prefixed_grouping_cols};

        INSERT INTO {message}
        SELECT {edge_table}.{dest} AS {vertex_id},
               MIN({toupdate}.{component_id}) AS {component_id}
               {comma_toupdate_prefixed_grouping_cols}
        FROM {toupdate}, {edge_table}
        WHERE {edge_table}.{src} = {toupdate}.{vertex_id}
            {edge_to_update_where_condition}
        GROUP BY {edge_table}.{dest} {comma_toupdate_prefixed_grouping_cols};

        TRUNCATE TABLE {oldupdate};

        SELECT COUNT(*) AS cnt_sum FROM {toupdate};
    """
    # None of the names used by the template change inside the loop, so the
    # statement is rendered once.
    loop_sql = loop_sql.format(**locals())

    iteration_counter = 0
    while nodes_to_update > 0 and iteration_counter < iteration_limit:
        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
            nodes_to_update = plpy.execute(loop_sql)[0]["cnt_sum"]
            iteration_counter += 1

    if not is_platform_pg():
        # Drop intermediate table created for Greenplum
        plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))

    # Give the id column of the output table the name the user asked for.
    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
        plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))

    # The message table is only kept while there is work left for a later
    # warm start.
    if nodes_to_update is None or nodes_to_update == 0:
        nodes_to_update = 0
        plpy.execute("DROP TABLE IF EXISTS {0}".format(message))

    plpy.execute("DROP VIEW IF EXISTS {0}, {1}".format(vertex_view, edge_view))
    if not warm_start:
        # vertex_id records the column name that out_table actually ends up
        # with: the default 'id' when the caller did not name one. The
        # helper functions read this value to build their queries.
        plpy.execute("""
            CREATE TABLE {out_table_summary} AS SELECT
                '{grouping_cols_summary}'::TEXT AS grouping_cols,
                '{vertex_table_in}'::TEXT AS vertex_table,
                '{vertex_id_summary}'::TEXT AS vertex_id,
                '{vertex_type}'::TEXT AS vertex_id_type,
                '{edge_table_in}'::TEXT AS edge_table,
                '{edge_args}'::TEXT AS edge_args,
                {iteration_counter}::BIGINT AS iteration_counter,
                {nodes_to_update}::BIGINT AS nodes_to_update;
        """.format(grouping_cols_summary='' if not grouping_cols else grouping_cols,
                   vertex_id_summary=vertex_id_in if vertex_id_in else 'id',
                   **locals()))
    else:
        plpy.execute("""
            UPDATE {out_table_summary} SET
                iteration_counter = iteration_counter + {iteration_counter},
                nodes_to_update = {nodes_to_update};
        """.format(**locals()))

    if grouping_cols:
        plpy.execute("DROP TABLE IF EXISTS {distinct_grp_table}".format(**locals()))
    plpy.execute("DROP TABLE IF EXISTS {oldupdate}, {toupdate}".format(**locals()))

    # Put the caller's message level back; without this the session would
    # stay at 'warning' after a successful run.
    plpy.execute("SET client_min_messages TO {0}".format(old_msg_level))
| |
| |
| # WCC Helper functions: |
def extract_wcc_summary_cols(wcc_summary_table):
    """
    WCC helper function to read the contents of the summary table.
    Args:
        @param wcc_summary_table    Name of the WCC summary table.

    Returns:
        The first row of the summary table as returned by plpy.execute,
        i.e. a mapping from every summary column name (for example
        'vertex_id', 'vertex_id_type' and 'grouping_cols') to its value.
    """
    summary_query = "SELECT * FROM {0} ".format(wcc_summary_table)
    summary_rows = plpy.execute(summary_query)
    return summary_rows[0]
| |
| |
def preprocess_wcc_table_args(wcc_table, out_table):
    """
    Run validate_output_and_summary_tables on wcc_table and out_table, then
    read the summary table that belongs to wcc_table.

    Args:
        @param wcc_table    Name of the table that contains the WCC output.
        @param out_table    Name of the table the caller is about to create.

    Returns:
        The summary row, as returned by extract_wcc_summary_cols.
    """
    validate_output_and_summary_tables(wcc_table, "WCC", out_table)
    return extract_wcc_summary_cols(add_postfix(wcc_table, "_summary"))
| |
def check_input_vertex_validity(wcc_args, vertices):
    """
    Assert that all the given vertices are present in the vertex table that
    WCC originally ran on.

    Args:
        @param wcc_args (dict)  Summary row of the WCC run; 'vertex_table'
                                and 'vertex_id' are read from it.
        @param vertices (list)  Vertex ids to look up.
    Returns:
        Nothing. A missing vertex table, an empty vertex list or a vertex
        that is not in the table is reported through _assert.
    """
    vertex_table = wcc_args['vertex_table']
    _assert(table_exists(vertex_table),
            "Graph WCC: Input vertex table '{0}' does not exist.".format(
                vertex_table))
    vertex_col = wcc_args['vertex_id']

    # The same vertex may legitimately be passed more than once (e.g. a pair
    # made of one vertex), so compare distinct values on both sides. Input
    # order is kept to make the generated query deterministic.
    unique_vertices = []
    for vertex in vertices:
        vertex_text = str(vertex)
        if vertex_text not in unique_vertices:
            unique_vertices.append(vertex_text)

    # An empty list would otherwise produce a WHERE clause with no condition.
    _assert(len(unique_vertices) > 0,
            "Graph WCC: Invalid input vertex in {0}.".format(str(vertices)))

    where_clause = ' OR '.join(["{0}='{1}'".format(vertex_col, v)
                                for v in unique_vertices])
    # COUNT(DISTINCT ...) so that an id that occurs in several rows of the
    # vertex table is counted once.
    count = plpy.execute("""
        SELECT COUNT(DISTINCT {vertex_col}) AS count
        FROM {vertex_table}
        WHERE {where_clause}
        """.format(**locals()))[0]['count']
    _assert(count == len(unique_vertices),
            "Graph WCC: Invalid input vertex in {0}.".format(str(vertices)))
| |
def check_input_mcol_vertex_validity(schema_madlib, wcc_args, vertices, vertex_str_list):
    """
    Multi-column counterpart of check_input_vertex_validity: assert that all
    the given vertices are present in the vertex table that WCC originally
    ran on.

    Args:
        @param schema_madlib            Name of the MADlib schema.
        @param wcc_args (dict)          Summary row of the WCC run;
                                        'vertex_table' and 'vertex_id' are
                                        read from it.
        @param vertices (list of list)  Vertex ids, one list per vertex.
        @param vertex_str_list (list)   The same vertices, each rendered as
                                        an array literal such as '{1, 2}'.
    Returns:
        Nothing. A missing vertex table or a mismatch between the number of
        matching rows and the number of vertices is reported through _assert.
    """
    vertex_table = wcc_args['vertex_table']
    table_missing_msg = "Graph WCC: Input vertex table '{0}' does not exist.".format(
        vertex_table)
    _assert(table_exists(vertex_table), table_missing_msg)

    vertex_col = wcc_args['vertex_id']
    psubq = unique_string(desp='psubquery')

    # All vertices go into one 2-D array literal; unnesting it yields one
    # 1-D array per vertex to join against.
    vertex_str = ','.join(vertex_str_list)
    vertex_join = """
        (SELECT ({schema_madlib}.array_unnest_2d_to_1d('{{ {vertex_str} }}'::BIGINT[])).unnest_result
        ) {psubq}
    """.format(schema_madlib=schema_madlib, vertex_str=vertex_str,
               psubq=psubq)

    # vertex_col is prefixed with ARRAY, so the stored value is expected to
    # already carry its brackets, e.g. '[col1,col2]'.
    sql = """
        SELECT COUNT(*) as count
        FROM
            {vertex_table}
            INNER JOIN {vertex_join}
            ON (ARRAY{vertex_col}::BIGINT[] = unnest_result)
    """.format(vertex_table=vertex_table, vertex_join=vertex_join,
               vertex_col=vertex_col)
    matched = plpy.execute(sql)[0]['count']
    _assert(matched == len(vertices),
            "Graph WCC: Invalid input vertex in {0}.".format(str(vertices)))
| |
def create_component_cnts_table(wcc_table, cnts_out_table,
                                grouping_cols_comma):
    """
    WCC helper function to create a table containing the number of vertices
    per component.

    Args:
        @param wcc_table            Name of the table that contains the WCC
                                    output.
        @param cnts_out_table       Name of the table to create.
        @param grouping_cols_comma  Grouping columns followed by a comma, or
                                    an empty string; placed in front of
                                    component_id in both the SELECT list and
                                    the GROUP BY clause.

    Returns:
        Nothing. Creates cnts_out_table with the grouping columns (if any),
        component_id and num_vertices.
    """
    counts_sql = """
        CREATE TABLE {cnts_out_table} AS
        SELECT {grouping_cols_select} component_id, COUNT(*) as num_vertices
        FROM {wcc_table}
        GROUP BY {group_by_clause} component_id
    """.format(cnts_out_table=cnts_out_table,
               wcc_table=wcc_table,
               grouping_cols_select=grouping_cols_comma,
               group_by_clause=grouping_cols_comma)
    plpy.execute(counts_sql)
| |
| |
def graph_wcc_largest_cpt(schema_madlib, wcc_table, largest_cpt_table,
                          **kwargs):
    """
    WCC helper function that computes the largest weakly connected component
    in each group (if grouping cols are defined)

    Args:
        @param schema_madlib        Name of the MADlib schema (unused here).
        @param wcc_table            Name of the table that contains the WCC
                                    output.
        @param largest_cpt_table    Name of the table to create.

    Returns:
        Nothing. Creates largest_cpt_table with the columns component_id and
        num_vertices, preceded by the grouping columns if there are any.
        Every component whose size equals the maximum (of its group) is
        listed.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, largest_cpt_table)

        # Temp table containing the number of vertices in each component.
        tmp_cnt_table = unique_string(desp='tmpcnt')

        grouping_cols = (wcc_args['grouping_cols']
                         if 'grouping_cols' in wcc_args else '')
        glist = split_quoted_delimited_str(grouping_cols)
        grouping_cols_comma = grouping_cols + ',' if grouping_cols else ''

        subq = unique_string(desp='q')
        subt = unique_string(desp='t')
        create_component_cnts_table(wcc_table, tmp_cnt_table,
                                    grouping_cols_comma)

        # Pieces that are only needed when the result is grouped.
        if grouping_cols:
            select_grouping_cols_subq = \
                get_table_qualified_col_str(subq, glist) + ','
            groupby_clause_subt = ' GROUP BY {0}'.format(grouping_cols)
            grouping_cols_join = ' AND ' + _check_groups (subq, subt, glist)
        else:
            select_grouping_cols_subq = ''
            groupby_clause_subt = ''
            grouping_cols_join = ''

        # Query to find ALL largest components within groups: join the
        # counts with the per-group maximum count.
        largest_sql = """
            CREATE TABLE {largest_cpt_table} AS
            SELECT {select_grouping_cols_subq} {subq}.component_id,
                   {subt}.maxcnt AS num_vertices
            FROM {tmp_cnt_table} AS {subq}
            INNER JOIN (
                SELECT {grouping_cols_select_subt}
                       MAX(num_vertices) AS maxcnt
                FROM {tmp_cnt_table}
                {groupby_clause_subt}
            ) {subt}
            ON {subq}.num_vertices={subt}.maxcnt
                {grouping_cols_join}
        """.format(largest_cpt_table=largest_cpt_table,
                   select_grouping_cols_subq=select_grouping_cols_subq,
                   subq=subq,
                   subt=subt,
                   tmp_cnt_table=tmp_cnt_table,
                   grouping_cols_select_subt=grouping_cols_comma,
                   groupby_clause_subt=groupby_clause_subt,
                   grouping_cols_join=grouping_cols_join)
        plpy.execute(largest_sql)

        # Drop temp table
        plpy.execute("DROP TABLE IF EXISTS {0}".format(tmp_cnt_table))
| |
| |
def graph_wcc_histogram(schema_madlib, wcc_table, histogram_table, **kwargs):
    """
    Retrieve Histogram of Vertices Per Connected Component

    Args:
        @param schema_madlib    Name of the MADlib schema (unused here).
        @param wcc_table        Name of the table that contains the WCC
                                output.
        @param histogram_table  Name of the table to create.

    Returns:
        Nothing. Creates and populates histogram_table with number of
        vertices per component (represented by column num_vertices). Columns
        corresponding to grouping_cols are also created if defined.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, histogram_table)
        # The summary may carry grouping_cols as an empty value when WCC ran
        # without grouping; a lone ', ' must not reach the SELECT list then.
        grouping_cols = (wcc_args['grouping_cols']
                         if 'grouping_cols' in wcc_args else '')
        grouping_cols_comma = grouping_cols + ', ' if grouping_cols else ''
        create_component_cnts_table(wcc_table, histogram_table,
                                    grouping_cols_comma)
| |
| |
def graph_wcc_vertex_check(schema_madlib, wcc_table, vertex_pair, pair_table,
                           **kwargs):
    """
    WCC helper function to check if two vertices belong to the same component.

    Args:
        @param schema_madlib    Name of the MADlib schema.
        @param wcc_table        Name of the table that contains the WCC
                                output.
        @param vertex_pair      The two vertex ids; each one is a list itself
                                when the vertex id spans multiple columns.
        @param pair_table       Name of the table to create.

    Returns:
        Nothing. Creates and populates pair_table with all the components
        that have both the vertices specified in the vertex_pair attribute.
        There are columns for grouping, if specified. An invalid pair or an
        unknown vertex is reported through _assert.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, pair_table)

        _assert(vertex_pair and len(vertex_pair) == 2,
                "Graph WCC: Invalid vertex pair ({0}) input.".format(
                    vertex_pair))

        if isinstance(vertex_pair[0], list):
            # Multi-column ids are compared as array literals, e.g. '{1, 2}'.
            vertex_str_list = []
            for i in vertex_pair:
                vertex_str_list.append('{' + ', '.join([str(j) for j in i]) + '}')
            check_input_mcol_vertex_validity(schema_madlib, wcc_args, vertex_pair, vertex_str_list)
        else:
            check_input_vertex_validity(wcc_args, vertex_pair)
            vertex_str_list = vertex_pair

        # The summary may carry grouping_cols as an empty value when WCC ran
        # without grouping; a lone ', ' must not reach the SELECT list then.
        grouping_cols = (wcc_args['grouping_cols']
                         if 'grouping_cols' in wcc_args else '')
        grouping_cols_comma = grouping_cols + ', ' if grouping_cols else ''
        subq = unique_string(desp='subq')
        inner_select_clause = " SELECT {0} component_id ".format(
            grouping_cols_comma)
        inner_from_clause = " FROM {0} ".format(wcc_table)
        inner_groupby_clause = " GROUP BY {0} component_id".format(
            grouping_cols_comma)

        # A bracketed vertex_id marks a multi-column id; the output table
        # of WCC then holds the id in a column named 'id'.
        vertex_id = wcc_args['vertex_id']
        if vertex_id.startswith('[') and vertex_id.endswith(']'):
            vertex_id = 'id'

        # One row per (group, component) for each of the two vertices; a
        # component that shows up twice contains both of them.
        sql = """
            CREATE TABLE {pair_table} AS
            SELECT {grouping_cols_comma} component_id
            FROM (
                {inner_select_clause}, 1
                {inner_from_clause}
                WHERE {vertex_id}='{vertex1}'
                {inner_groupby_clause}
                UNION ALL
                {inner_select_clause}, 2
                {inner_from_clause}
                WHERE {vertex_id}='{vertex2}'
                {inner_groupby_clause}
            ) {subq}
            GROUP BY {grouping_cols_comma} component_id
            HAVING COUNT(*)=2
        """.format(vertex1=vertex_str_list[0], vertex2=vertex_str_list[1], **locals())
        plpy.execute(sql)
| |
def graph_wcc_reachable_vertices(schema_madlib, wcc_table, src,
                                 reachable_vertices_table, **kwargs):
    """
    WCC helper function to retrieve all vertices reachable from a vertex

    Args:
        @param schema_madlib                Name of the MADlib schema.
        @param wcc_table                    Name of the table that contains
                                            the WCC output.
        @param src                          Source vertex id; a list when the
                                            vertex id spans multiple columns.
        @param reachable_vertices_table     Name of the table to create.

    Results:
        Creates and populates reachable_vertices_table table with all the
        vertices reachable from src vertex, where reachability is with
        regard to a component. The src vertex itself is not listed. There
        are columns for grouping, if specified. An unknown src vertex is
        reported through _assert.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table,
                                             reachable_vertices_table)
        if not isinstance(src, list):
            vertex_str = str(src)
            check_input_vertex_validity(wcc_args, [vertex_str])
        else:
            # Multi-column ids are compared as array literals, e.g. '{1, 2}'.
            vertex_str = '{' + ', '.join([str(j) for j in src]) + '}'
            check_input_mcol_vertex_validity(schema_madlib, wcc_args, [src], [vertex_str])

        # The summary may carry grouping_cols as an empty value when WCC ran
        # without grouping; a lone ', ' must not reach the SELECT list then.
        grouping_cols = (wcc_args['grouping_cols']
                         if 'grouping_cols' in wcc_args else '')
        if not grouping_cols:
            grouping_cols = ''
        grouping_cols_comma = grouping_cols + ', ' if grouping_cols else ''

        # A bracketed vertex_id marks a multi-column id; the output table
        # of WCC then holds the id in a column named 'id'.
        vertex_id = wcc_args['vertex_id']
        if vertex_id.startswith('[') and vertex_id.endswith(']'):
            vertex_id = 'id'

        subq = unique_string(desp='subq')
        glist = split_quoted_delimited_str(grouping_cols)
        grouping_cols_join = '' if not grouping_cols else ' AND ' + \
            _check_groups(wcc_table, subq, glist)
        subq_grouping_cols = '' if not grouping_cols else \
            get_table_qualified_col_str(subq, glist) + ', '
        # The subquery finds the component(s) of src; the join then lists
        # every other vertex of those components.
        plpy.execute("""
            CREATE TABLE {reachable_vertices_table} AS
            SELECT {subq_grouping_cols} {subq}.component_id,
                   {wcc_table}.{vertex_id} AS dest
            FROM {wcc_table}
            INNER JOIN (
                SELECT {grouping_cols_comma} component_id, {vertex_id}
                FROM {wcc_table}
                GROUP BY {vertex_id}, {grouping_cols_comma} component_id
                HAVING {vertex_id} ='{vertex_str}'
            ) {subq}
            ON {wcc_table}.component_id={subq}.component_id
                {grouping_cols_join}
            WHERE {wcc_table}.{vertex_id} != '{vertex_str}'
        """.format(**locals()))
| |
| |
def graph_wcc_num_cpts(schema_madlib, wcc_table, count_table, **kwargs):
    """
    WCC helper function to count the number of connected components

    Args:
        @param schema_madlib    Name of the MADlib schema (unused here).
        @param wcc_table        Name of the table that contains the WCC
                                output.
        @param count_table      Name of the table to create.

    Results:
        Creates and populates the count_table table with the total number
        of components. If grouping_cols is involved, number of components
        are computed with regard to a group.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, count_table)
        # The summary may carry grouping_cols as an empty value when WCC ran
        # without grouping; a lone ', ' must not reach the SELECT list then.
        grouping_cols = (wcc_args['grouping_cols']
                         if 'grouping_cols' in wcc_args else '')
        if not grouping_cols:
            grouping_cols = ''
        grouping_cols_comma = grouping_cols + ', ' if grouping_cols else ''
        plpy.execute("""
            CREATE TABLE {count_table} AS
            SELECT {grouping_cols_comma}
                   COUNT(DISTINCT component_id) AS num_components
            FROM {wcc_table}
            {grp_by_clause}
        """.format(grp_by_clause='' if not grouping_cols else
                   ' GROUP BY {0}'.format(grouping_cols), **locals()))
| |
| |
def wcc_help(schema_madlib, message, **kwargs):
    """
    Build the help text for weakly connected components.

    Args:
        @param schema_madlib    Name of the MADlib schema; substituted into
                                the help text.
        @param message: string, selects the text. 'usage', 'help' and '?'
                        (case-insensitive) give the usage text, anything
                        else (including None) gives the summary.
        @param kwargs           Unused.

    Returns:
        String. Help/usage information
    """
    wants_usage = (message is not None and
                   message.lower() in ("usage", "help", "?"))

    if not wants_usage:
        summary_text = """
----------------------------------------------------------------------------
SUMMARY
----------------------------------------------------------------------------
Given a directed graph, a weakly connected component is a sub-graph of the
original graph where all vertices are connected to each other by some path,
ignoring the direction of edges. In case of an undirected graph, a weakly
connected component is also a strongly connected component.
--
For an overview on usage, run:
SELECT {schema_madlib}.weakly_connected_components('usage');
"""
        return summary_text.format(schema_madlib=schema_madlib)

    # Parameters that are specific to this module; the rest of the usage
    # text comes from get_graph_usage.
    module_params = """out_table TEXT, -- Output table of weakly connected components
grouping_col TEXT -- Comma separated column names to group on
-- (DEFAULT = NULL, no grouping)
"""
    helper_functions_text = """

Once the above function is used to obtain the out_table, it can be used to
call several other helper functions based on weakly connected components:

(1) To retrieve the largest connected component:
SELECT {schema_madlib}.graph_wcc_largest_cpt(
wcc_table TEXT, -- Name of the table that contains the WCC output.
largest_cpt_table TEXT -- Name of the output table that contains the
-- largest components details.
);

(2) To retrieve the histogram of vertices per connected component:
SELECT {schema_madlib}.graph_wcc_histogram(
wcc_table TEXT, -- Name of the table that contains the WCC output.
histogram_table TEXT -- Name of the output table that contains the
-- histogram of vertices per connected component.
);

(3) To check if two vertices belong to the same component:
SELECT {schema_madlib}.graph_wcc_vertex_check(
wcc_table TEXT, -- Name of the table that contains the WCC output.
vertex_pair TEXT, -- Pair of vertex IDs, separated by a comma.
pair_table TEXT -- Name of the output table that contains the all
-- components that contain the two vertices.
);

(4) To retrieve all vertices reachable from a vertex:
SELECT {schema_madlib}.graph_wcc_reachable_vertices(
wcc_table TEXT, -- Name of the table that contains the WCC output.
src TEXT, -- Initial source vertex.
reachable_vertices_table TEXT -- Name of the output table that
-- contains all vertices in a
-- component reachable from src.
);

(5) To count the number of connected components:
SELECT {schema_madlib}.graph_wcc_num_cpts(
wcc_table TEXT, -- Name of the table that contains the WCC output.
count_table TEXT -- Name of the output table that contains the count
-- of number of components.
);"""
    usage_text = get_graph_usage(schema_madlib,
                                 'Weakly Connected Components',
                                 module_params)
    # The schema placeholder is filled in over the combined text.
    return (usage_text + helper_functions_text).format(
        schema_madlib=schema_madlib)
| # --------------------------------------------------------------------- |