blob: 0c80aabac5e283a501a405234fc27bfb92746ecc [file] [log] [blame]
# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Weakly Connected Components
# Please refer to the wcc.sql_in file for the documentation
"""
@file wcc.py_in

@namespace graph
"""
import plpy
from utilities.control import SetGUC
from utilities.utilities import _assert
from utilities.utilities import _check_groups
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string, split_quoted_delimited_str
from utilities.validate_args import columns_exist_in_table, get_expr_type
from utilities.utilities import is_platform_pg
from utilities.utilities import add_postfix
from utilities.validate_args import table_exists
from utilities.utilities import rename_table
from utilities.control import MinWarning
from graph_utils import validate_graph_coding, get_graph_usage
from graph_utils import validate_output_and_summary_tables
def validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
                      edge_params, out_table, out_table_summary,
                      grouping_cols_list, module_name):
    """
    Function to validate input parameters for wcc.

    Args:
        @param schema_madlib: Name of the MADlib schema.
        @param vertex_table: Name of the input vertex table.
        @param vertex_id: Name of the vertex id column in vertex_table.
        @param edge_table: Name of the input edge table.
        @param edge_params: Dict holding the 'src' and 'dest' column names.
        @param out_table: Name of the output table.
        @param out_table_summary: Name of the output summary table, which
                                  must not exist yet.
        @param grouping_cols_list: List of grouping column names; may be
                                   empty when no grouping is requested.
        @param module_name: Module name used in error messages.

    Failures are reported through validate_graph_coding() and _assert().
    """
    validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
                          out_table, module_name)
    _assert(not table_exists(out_table_summary),
            "Graph {module_name}: Output summary table already exists!".format(**locals()))
    if grouping_cols_list:
        # validate the grouping columns. We currently only support grouping_cols
        # to be column names in the edge_table, and not expressions!
        _assert(columns_exist_in_table(edge_table, grouping_cols_list, schema_madlib),
                "Weakly Connected Components error: "
                "One or more grouping columns specified do not exist!")
def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
out_table, grouping_cols, **kwargs):
"""
Function that computes the weakly connected components (wcc) of a graph.

Args:
@param schema_madlib: Name of the MADlib schema.
@param vertex_table: Name of the table containing the vertices.
@param vertex_id: Vertex id column of vertex_table; "id" is used when
empty. Vertex ids are cast to BIGINT to seed component ids.
@param edge_table: Name of the table containing the edges.
@param edge_args: Key/value parameter string naming the 'src' and
'dest' columns of edge_table (defaults: 'src', 'dest').
@param out_table: Name of the output table; "<out_table>_summary" is
created as well.
@param grouping_cols: Comma separated grouping columns of edge_table
(may be empty).

NOTE(review): this copy of the function looks truncated. Several
triple-quoted SQL strings have no visible closing line, .format() or
plpy.execute() call. Some `else:` branches are also missing, e.g. for
the vertex_id default and for the PostgreSQL edge_inverse assignment.
Restore from upstream before relying on this code.
"""
# Remember the current client_min_messages, then lower it to 'warning'.
# NOTE(review): no statement restoring old_msg_level is visible in this
# copy; confirm it is reset (the helper functions below use MinWarning).
old_msg_level = plpy.execute("""
SELECT setting
FROM pg_settings
WHERE name='client_min_messages'
plpy.execute('SET client_min_messages TO warning')
# Parse the key/value edge_args into the src/dest column names,
# falling back to 'src' and 'dest'.
params_types = {'src': str, 'dest': str}
default_args = {'src': 'src', 'dest': 'dest'}
edge_params = extract_keyvalue_params(
edge_args, params_types, default_args)
# populate default values for optional params if null, and prepare data
# to be written into the summary table (*_st variable names)
if not vertex_id:
vertex_id = "id"
v_st = "id"
v_st = vertex_id
# NOTE(review): v_st is not referenced anywhere in the visible code.
if not grouping_cols:
grouping_cols = ''
out_table_summary = ''
if out_table:
out_table_summary = add_postfix(out_table, "_summary")
grouping_cols_list = split_quoted_delimited_str(grouping_cols)
validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
edge_params, out_table, out_table_summary,
grouping_cols_list, 'Weakly Connected Components')
src = edge_params["src"]
dest = edge_params["dest"]
# Unique names for the working tables used by the iteration.
# NOTE(review): temp_out_table is not referenced in the visible code.
message = unique_string(desp='message')
oldupdate = unique_string(desp='oldupdate')
newupdate = unique_string(desp='newupdate')
toupdate = unique_string(desp='toupdate')
temp_out_table = unique_string(desp='tempout')
edge_inverse = unique_string(desp='edge_inverse')
# On Greenplum the working tables are distributed by the vertex id (and by
# the grouping columns as well when grouping is used); PostgreSQL needs no
# distribution clause.
distribution = '' if is_platform_pg() else \
"DISTRIBUTED BY ({0})".format(vertex_id)
# Grouping-specific SQL fragments; they stay empty when there is no
# grouping and are filled in by the `if grouping_cols:` branch.
subq_prefixed_grouping_cols = ''
comma_toupdate_prefixed_grouping_cols = ''
comma_oldupdate_prefixed_grouping_cols = ''
old_new_update_where_condition = ''
new_to_update_where_condition = ''
edge_to_update_where_condition = ''
edge_inverse_to_update_where_condition = ''
# Initial component_id of every vertex in {newupdate}: the largest BIGINT,
# so that any id received from a neighbor is smaller.
BIGINT_MAX = 9223372036854775807
component_id = 'component_id'
grouping_cols_comma = '' if not grouping_cols else grouping_cols + ','
comma_grouping_cols = '' if not grouping_cols else ',' + grouping_cols
if not is_platform_pg():
# In Greenplum, to avoid redistribution of data when in later queries,
# edge_table is duplicated by creating a temporary table distributed
# on dest column
plpy.execute(""" CREATE TABLE {edge_inverse} AS
SELECT * FROM {edge_table} DISTRIBUTED BY ({dest});
edge_inverse = edge_table
# With grouping, the vertex set of each group is taken from the src and
# dest columns of edge_table (joined with the distinct group values)
# rather than from vertex_table.
if grouping_cols:
distribution = ('' if is_platform_pg() else
"DISTRIBUTED BY ({0}, {1})".format(grouping_cols,
# Update some variables useful for grouping based query strings
subq = unique_string(desp='subquery')
distinct_grp_table = unique_string(desp='grptable')
comma_toupdate_prefixed_grouping_cols = ', ' + \
get_table_qualified_col_str(toupdate, grouping_cols_list)
comma_oldupdate_prefixed_grouping_cols = ', ' + \
get_table_qualified_col_str(oldupdate, grouping_cols_list)
subq_prefixed_grouping_cols = get_table_qualified_col_str(subq, grouping_cols_list)
old_new_update_where_condition = ' AND ' + \
_check_groups(oldupdate, newupdate, grouping_cols_list)
new_to_update_where_condition = ' AND ' + \
_check_groups(newupdate, toupdate, grouping_cols_list)
edge_to_update_where_condition = ' AND ' + \
_check_groups(edge_table, toupdate, grouping_cols_list)
edge_inverse_to_update_where_condition = ' AND ' + \
_check_groups(edge_inverse, toupdate, grouping_cols_list)
join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
subq, vertex_id))
# Distinct grouping-column values present in edge_table.
grp_sql = """
CREATE TABLE {distinct_grp_table} AS
SELECT DISTINCT {grouping_cols} FROM {edge_table};
prep_sql = """
CREATE TABLE {newupdate} AS
SELECT {subq}.{vertex_id},
CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
FROM {distinct_grp_table} INNER JOIN (
SELECT {grouping_cols_comma} {src} AS {vertex_id}
FROM {edge_table}
SELECT {grouping_cols_comma} {dest} AS {vertex_id}
FROM {edge_inverse}
) {subq}
ON {join_grouping_cols}
GROUP BY {group_by_clause_newupdate}
DROP TABLE IF EXISTS {distinct_grp_table};
""".format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
message_sql = """
SELECT {vertex_id},
CAST({vertex_id} AS BIGINT) AS {component_id}
FROM {newupdate}
prep_sql = """
CREATE TABLE {newupdate} AS
SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
FROM {vertex_table}
SELECT {vertex_id}, CAST({vertex_id} AS BIGINT) AS {component_id}
FROM {vertex_table}
oldupdate_sql = """
CREATE TABLE {oldupdate} AS
SELECT {message}.{vertex_id},
MIN({message}.{component_id}) AS {component_id}
FROM {message}
GROUP BY {grouping_cols_comma} {vertex_id}
toupdate_sql = """
CREATE TABLE {toupdate} AS
SELECT * FROM {oldupdate}
nodes_to_update = 1
# One propagation round, run repeatedly by the while loop below:
#  1. {oldupdate}: smallest component_id received per vertex in {message};
#  2. {toupdate}: vertices whose received minimum is smaller than their
#     current component_id in {newupdate};
#  3. {newupdate}: apply those smaller component_ids;
#  4. {message}: send each updated vertex's component_id to its neighbors
#     in both edge directions (via {edge_inverse} and {edge_table}).
loop_sql = """
TRUNCATE TABLE {oldupdate};
INSERT INTO {oldupdate}
SELECT {message}.{vertex_id},
MIN({message}.{component_id}) AS {component_id}
FROM {message}
GROUP BY {grouping_cols_comma} {vertex_id};
TRUNCATE TABLE {toupdate};
INSERT INTO {toupdate}
SELECT {oldupdate}.{vertex_id},
FROM {oldupdate}, {newupdate}
WHERE {oldupdate}.{vertex_id}={newupdate}.{vertex_id}
AND {oldupdate}.{component_id}<{newupdate}.{component_id}
UPDATE {newupdate} SET
FROM {toupdate}
WHERE {newupdate}.{vertex_id}={toupdate}.{vertex_id}
INSERT INTO {message}
SELECT {edge_inverse}.{src} AS {vertex_id},
MIN({toupdate}.{component_id}) AS {component_id}
FROM {toupdate}, {edge_inverse}
WHERE {edge_inverse}.{dest} = {toupdate}.{vertex_id}
GROUP BY {edge_inverse}.{src} {comma_toupdate_prefixed_grouping_cols};
INSERT INTO {message}
SELECT {edge_table}.{dest} AS {vertex_id},
MIN({toupdate}.{component_id}) AS {component_id}
FROM {toupdate}, {edge_table}
WHERE {edge_table}.{src} = {toupdate}.{vertex_id}
GROUP BY {edge_table}.{dest} {comma_toupdate_prefixed_grouping_cols};
TRUNCATE TABLE {oldupdate};
while nodes_to_update > 0:
# Look at all the neighbors of a node, and assign the smallest node id
# among the neighbors as its component_id. The next table starts off
# with very high component_id (BIGINT_MAX). The component_id of all nodes
# which obtain a smaller component_id after looking at its neighbors are
# updated in the next table. At every iteration update only those nodes
# whose component_id in the previous iteration are greater than what was
# found in the current iteration.
# The loop ends once {toupdate} is empty, i.e. no vertex changed its
# component_id (row counts are computed per group and summed when
# grouping is used).
# NOTE(review): presumably this GUC permits the TRUNCATE statements in
# loop_sql inside a subtransaction -- confirm against Greenplum docs.
with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
if grouping_cols:
nodes_to_update = plpy.execute("""
SELECT SUM(cnt) AS cnt_sum
FROM {toupdate}
GROUP BY {grouping_cols}
) t
nodes_to_update = plpy.execute("""
SELECT COUNT(*) AS cnt FROM {toupdate}
if not is_platform_pg():
# Drop intermediate table created for Greenplum
plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
# The working table {newupdate} becomes the user's output table.
rename_table(schema_madlib, newupdate, out_table)
# Create the summary table. It records the vertex table, the vertex_id
# column and its type, plus the grouping columns when grouping is used.
vertex_id_type = get_expr_type(vertex_id, vertex_table)
CREATE TABLE {out_table_summary} AS SELECT
'{vertex_table}'::TEXT AS vertex_table,
'{vertex_id}'::TEXT AS vertex_id,
'{vertex_id_type}'::TEXT AS vertex_id_type;
DROP TABLE IF EXISTS {message},{oldupdate},{newupdate},{toupdate};
""".format(grouping_cols_summary='' if not grouping_cols else
"'{0}'::TEXT AS grouping_cols, ".format(grouping_cols),
# WCC Helper functions:
def extract_wcc_summary_cols(wcc_summary_table):
    """
    WCC helper function to find all values stored in the summary table.

    Args:
        @param wcc_summary_table: Name of the summary table created by
                                  weakly_connected_components().

    Returns:
        Dictionary containing the column names and their values (the first
        row of the summary table). The keys are 'vertex_table',
        'vertex_id', 'vertex_id_type' and 'grouping_cols' if grouping cols
        exist.

    Raises (via _assert) if the summary table has no rows.
    """
    rows = plpy.execute("SELECT * FROM {0}".format(wcc_summary_table))
    # Indexing an empty result would fail with an opaque IndexError.
    _assert(len(rows) > 0,
            "Graph WCC: Summary table {0} is empty.".format(wcc_summary_table))
    return rows[0]
def preprocess_wcc_table_args(wcc_table, out_table):
    """
    Validate wcc_table, wcc_table_summary and the output tables. Read
    the summary table and return a dictionary of the summary table.

    Args:
        @param wcc_table: Name of a table produced by
                          weakly_connected_components().
        @param out_table: Name of the output table the calling helper is
                          about to create.

    Returns:
        The contents of "<wcc_table>_summary", as returned by
        extract_wcc_summary_cols().
    """
    validate_output_and_summary_tables(wcc_table, "WCC", out_table)
    wcc_summary_table = add_postfix(wcc_table, "_summary")
    return extract_wcc_summary_cols(wcc_summary_table)
def check_input_vertex_validity(wcc_args, vertices):
    """
    Function to check if vertices are all valid, i.e., are present
    in the WCC's original input vertex table.

    Args:
        @param wcc_args (dict): Contents of the WCC summary table; the keys
                                'vertex_table' and 'vertex_id' are used.
        @param vertices (list): Vertex ids (strings) to look up. The same
                                id may appear more than once.

    Returns:
        None. Raises (via _assert) when the original vertex table no longer
        exists, when `vertices` is empty, or when any of the vertices is not
        present in it.
    """
    vertex_table = wcc_args['vertex_table']
    _assert(table_exists(vertex_table),
            "Graph WCC: Input vertex table '{0}' does not exist.".format(
                vertex_table))
    # An empty list would otherwise produce "WHERE " with no condition.
    _assert(vertices,
            "Graph WCC: Invalid input vertex in {0}.".format(str(vertices)))
    vertex_col = wcc_args['vertex_id']
    # A vertex may legitimately be listed twice (e.g. the pair '1,1'), so the
    # check compares distinct requested ids against distinct matching ids.
    unique_vertices = set(vertices)
    # Double embedded single quotes so each id is a well-formed SQL literal.
    where_clause = ' OR '.join(
        "{0}='{1}'".format(vertex_col, v.replace("'", "''"))
        for v in unique_vertices)
    count = plpy.execute("""
        SELECT COUNT(DISTINCT {vertex_col}) AS count
        FROM {vertex_table}
        WHERE {where_clause}
        """.format(**locals()))[0]['count']
    _assert(count == len(unique_vertices),
            "Graph WCC: Invalid input vertex in {0}.".format(str(vertices)))
def create_component_cnts_table(wcc_table, cnts_out_table,
                                grouping_cols_comma):
    """
    WCC helper function to create a table containing the number of vertices
    per component.

    Args:
        @param wcc_table: Name of the WCC output table.
        @param cnts_out_table: Name of the table to create.
        @param grouping_cols_comma: Grouping columns followed by a trailing
                                    comma (e.g. 'g1, g2, '), or '' when
                                    there is no grouping.

    Returns:
        Creates a new table called cnts_out_table with the columns
        [<grouping cols>,] component_id, num_vertices.
    """
    # The same prefix is used in the SELECT list and in GROUP BY so that
    # vertices are counted per (group, component).
    plpy.execute("""
        CREATE TABLE {cnts_out_table} AS
        SELECT {grouping_cols_select} component_id, COUNT(*) as num_vertices
        FROM {wcc_table}
        GROUP BY {group_by_clause} component_id
        """.format(grouping_cols_select=grouping_cols_comma,
                   group_by_clause=grouping_cols_comma, **locals()))
def graph_wcc_largest_cpt(schema_madlib, wcc_table, largest_cpt_table,
                          **kwargs):
    """
    WCC helper function that computes the largest weakly connected component
    in each group (if grouping cols are defined)

    Args:
        @param wcc_table: Name of the table produced by
                          weakly_connected_components().
        @param largest_cpt_table: Name of the output table to create.

    Returns:
        Creates table largest_cpt_table that contains a column called
        component_id that refers to the largest component and a column
        num_vertices holding its size. If grouping_cols are defined, columns
        corresponding to the grouping_cols are also created, and the largest
        component is computed with regard to a group. Every component tied
        for the maximum size is reported.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, largest_cpt_table)
        # Create temp table containing the number of vertices in each
        # component.
        tmp_cnt_table = unique_string(desp='tmpcnt')
        if 'grouping_cols' in wcc_args:
            grouping_cols = wcc_args['grouping_cols']
        else:
            grouping_cols = ''
        glist = split_quoted_delimited_str(grouping_cols)
        grouping_cols_comma = '' if not grouping_cols else grouping_cols + ','
        subq = unique_string(desp='q')
        subt = unique_string(desp='t')
        create_component_cnts_table(wcc_table, tmp_cnt_table,
                                    grouping_cols_comma)

        # Query to find ALL largest components within groups: {subt} holds
        # the maximum component size (per group), and every component of
        # that size is kept.
        select_grouping_cols_subq = ''
        groupby_clause_subt = ''
        grouping_cols_join = ''
        if grouping_cols:
            select_grouping_cols_subq = get_table_qualified_col_str(subq, glist) + ','
            groupby_clause_subt = ' GROUP BY {0}'.format(grouping_cols)
            grouping_cols_join = ' AND ' + _check_groups(subq, subt, glist)
        plpy.execute("""
            CREATE TABLE {largest_cpt_table} AS
            SELECT {select_grouping_cols_subq} {subq}.component_id,
                   {subt}.maxcnt AS num_vertices
            FROM {tmp_cnt_table} AS {subq}
            INNER JOIN (
                SELECT {grouping_cols_select_subt}
                       MAX(num_vertices) AS maxcnt
                FROM {tmp_cnt_table}
                {groupby_clause_subt}
            ) {subt}
            ON {subq}.num_vertices={subt}.maxcnt
               {grouping_cols_join}
            """.format(grouping_cols_select_subt=grouping_cols_comma,
                       **locals()))
        # Drop temp table
        plpy.execute("DROP TABLE IF EXISTS {0}".format(tmp_cnt_table))
def graph_wcc_histogram(schema_madlib, wcc_table, histogram_table, **kwargs):
    """
    Retrieve Histogram of Vertices Per Connected Component

    Args:
        @param wcc_table: Name of the table produced by
                          weakly_connected_components().
        @param histogram_table: Name of the output table to create.

    Returns:
        Creates and populates histogram_table with number of vertices per
        component (represented by column num_vertices). Columns corresponding
        to grouping_cols are also created if defined.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, histogram_table)
        grouping_cols_comma = ''
        if 'grouping_cols' in wcc_args:
            grouping_cols_comma = wcc_args['grouping_cols'] + ', '
        create_component_cnts_table(wcc_table, histogram_table,
                                    grouping_cols_comma)
def graph_wcc_vertex_check(schema_madlib, wcc_table, vertex_pair, pair_table,
                           **kwargs):
    """
    WCC helper function to check if two vertices belong to the same component.

    Args:
        @param wcc_table: Name of the table produced by
                          weakly_connected_components().
        @param vertex_pair: Pair of vertex ids separated by a comma.
        @param pair_table: Name of the output table to create.

    Returns:
        Creates and populates pair_table with all the components that have
        both the vertices specified in the vertex_pair attribute. There are
        columns for grouping, if specified.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, pair_table)
        vertices = split_quoted_delimited_str(vertex_pair)
        _assert(vertices and len(vertices) == 2,
                "Graph WCC: Invalid vertex pair ({0}) input.".format(
                    vertex_pair))
        check_input_vertex_validity(wcc_args, vertices)
        grouping_cols_comma = ''
        if 'grouping_cols' in wcc_args:
            grouping_cols_comma = wcc_args['grouping_cols'] + ', '
        subq = unique_string(desp='subq')
        inner_select_clause = " SELECT {0} component_id ".format(
            grouping_cols_comma)
        inner_from_clause = " FROM {0} ".format(wcc_table)
        inner_groupby_clause = " GROUP BY {0} component_id".format(
            grouping_cols_comma)
        # Double embedded single quotes so the ids form valid SQL literals.
        vertex1 = vertices[0].replace("'", "''")
        vertex2 = vertices[1].replace("'", "''")
        # Tag the components containing each vertex with 1 or 2; a component
        # (within a group) carrying both tags contains both vertices.
        plpy.execute("""
            CREATE TABLE {pair_table} AS
            SELECT {grouping_cols_comma} component_id
            FROM (
                {inner_select_clause}, 1
                {inner_from_clause}
                WHERE {vertex_id}='{vertex1}'
                {inner_groupby_clause}
                UNION
                {inner_select_clause}, 2
                {inner_from_clause}
                WHERE {vertex_id}='{vertex2}'
                {inner_groupby_clause}
            ) {subq}
            GROUP BY {grouping_cols_comma} component_id
            HAVING COUNT(*) = 2
            """.format(vertex_id=wcc_args['vertex_id'], **locals()))
def graph_wcc_reachable_vertices(schema_madlib, wcc_table, src,
                                 reachable_vertices_table, **kwargs):
    """
    WCC helper function to retrieve all vertices reachable from a vertex

    Args:
        @param wcc_table: Name of the table produced by
                          weakly_connected_components().
        @param src: Id of the source vertex.
        @param reachable_vertices_table: Name of the output table to create.

    Returns:
        Creates and populates reachable_vertices_table table with all the
        vertices reachable from src vertex, where reachability is with
        regard to a component. There are columns for grouping, if specified.
        src itself is excluded from the result.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table,
                                             reachable_vertices_table)
        check_input_vertex_validity(wcc_args, split_quoted_delimited_str(src))
        grouping_cols_comma = ''
        grouping_cols = ''
        if 'grouping_cols' in wcc_args:
            grouping_cols = wcc_args['grouping_cols']
            grouping_cols_comma = grouping_cols + ', '
        vertex_id = wcc_args['vertex_id']
        subq = unique_string(desp='subq')
        glist = split_quoted_delimited_str(grouping_cols)
        grouping_cols_join = '' if not grouping_cols else ' AND ' + \
            _check_groups(wcc_table, subq, glist)
        subq_grouping_cols = '' if not grouping_cols else \
            get_table_qualified_col_str(subq, glist) + ', '
        # Double embedded single quotes so src forms a valid SQL literal.
        src_literal = src.replace("'", "''")
        # {subq} holds the component (per group) that src belongs to; every
        # other vertex of that component is reachable from src.
        plpy.execute("""
            CREATE TABLE {reachable_vertices_table} AS
            SELECT {subq_grouping_cols} {subq}.component_id,
                   {wcc_table}.{vertex_id} AS dest
            FROM {wcc_table}
            INNER JOIN (
                SELECT {grouping_cols_comma} component_id, {vertex_id}
                FROM {wcc_table}
                GROUP BY {vertex_id}, {grouping_cols_comma} component_id
                HAVING {vertex_id}='{src_literal}'
            ) {subq}
            ON {wcc_table}.component_id={subq}.component_id
               {grouping_cols_join}
            WHERE {wcc_table}.{vertex_id} != '{src_literal}'
            """.format(**locals()))
def graph_wcc_num_cpts(schema_madlib, wcc_table, count_table, **kwargs):
    """
    WCC helper function to count the number of connected components

    Args:
        @param: wcc_table: Name of the table produced by
                           weakly_connected_components().
        @param: count_table: Name of the output table to create.

    Returns:
        Creates and populates the count_table table with the total number
        of components. If grouping_cols is involved, number of components
        are computed with regard to a group.
    """
    with MinWarning("warning"):
        wcc_args = preprocess_wcc_table_args(wcc_table, count_table)
        grouping_cols = ''
        grouping_cols_comma = ''
        if 'grouping_cols' in wcc_args:
            grouping_cols = wcc_args['grouping_cols']
            grouping_cols_comma = grouping_cols + ', '
        plpy.execute("""
            CREATE TABLE {count_table} AS
            SELECT {grouping_cols_comma}
                   COUNT(DISTINCT component_id) AS num_components
            FROM {wcc_table}
            {grp_by_clause}
            """.format(grp_by_clause='' if not grouping_cols else
                       ' GROUP BY {0}'.format(grouping_cols), **locals()))
def wcc_help(schema_madlib, message, **kwargs):
"""
Help function for wcc

Args:
@param schema_madlib: Name of the MADlib schema; substituted for the
schema placeholder in the returned text.
@param message: string, Help message string. "usage", "help" or "?"
(case-insensitive) selects the detailed usage text.
@param kwargs

Returns:
String. Help/usage information

NOTE(review): this copy appears to be missing lines. For example, the
`else:` branch before the general description and the terminators of
the usage string are missing, so as shown the usage text would be
overwritten by the general description.
"""
# Detailed usage text for "usage"/"help"/"?".
if message is not None and \
message.lower() in ("usage", "help", "?"):
help_string = "Get from method below"
# NOTE(review): the placeholder above is immediately overwritten here.
help_string = get_graph_usage(
'Weakly Connected Components',
"""out_table TEXT, -- Output table of weakly connected components
grouping_col TEXT -- Comma separated column names to group on
-- (DEFAULT = NULL, no grouping)
""") + """
Once the above function is used to obtain the out_table, it can be used to
call several other helper functions based on weakly connected components:
(1) To retrieve the largest connected component:
SELECT {schema_madlib}.graph_wcc_largest_cpt(
wcc_table TEXT, -- Name of the table that contains the WCC output.
largest_cpt_table TEXT -- Name of the output table that contains the
-- largest components details.
(2) To retrieve the histogram of vertices per connected component:
SELECT {schema_madlib}.graph_wcc_histogram(
wcc_table TEXT, -- Name of the table that contains the WCC output.
histogram_table TEXT -- Name of the output table that contains the
-- histogram of vertices per connected component.
(3) To check if two vertices belong to the same component:
SELECT {schema_madlib}.graph_wcc_vertex_check(
wcc_table TEXT, -- Name of the table that contains the WCC output.
vertex_pair TEXT, -- Pair of vertex IDs, separated by a comma.
pair_table TEXT -- Name of the output table that contains the all
-- components that contain the two vertices.
(4) To retrieve all vertices reachable from a vertex:
SELECT {schema_madlib}.graph_wcc_reachable_vertices(
wcc_table TEXT, -- Name of the table that contains the WCC output.
src TEXT, -- Initial source vertex.
reachable_vertices_table TEXT -- Name of the output table that
-- contains all vertices in a
-- component reachable from src.
(5) To count the number of connected components:
SELECT {schema_madlib}.graph_wcc_num_cpts(
wcc_table TEXT, -- Name of the table that contains the WCC output.
count_table TEXT -- Name of the output table that contains the count
-- of number of components.
help_string = """
Given a directed graph, a weakly connected component is a sub-graph of the
original graph where all vertices are connected to each other by some path,
ignoring the direction of edges. In case of an undirected graph, a weakly
connected component is also a strongly connected component.
For an overview on usage, run:
SELECT {schema_madlib}.weakly_connected_components('usage');
return help_string.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------