blob: 9e813b04c7fba19148daf0211bdb9c29554d9ea3 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Single Source Shortest Path
# Please refer to the sssp.sql_in file for the documentation
"""
@file sssp.py_in
@namespace graph
"""
import plpy
from graph_utils import validate_graph_coding
from graph_utils import get_graph_usage
from graph_utils import get_edge_params
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import _check_groups
from utilities.utilities import get_table_qualified_col_str
from utilities.utilities import add_postfix
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import split_quoted_delimited_str
from utilities.utilities import is_platform_pg
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_expr_type
def graph_sssp(schema_madlib, vertex_table, vertex_id, edge_table,
               edge_args, source_vertex, out_table, grouping_cols, **kwargs):
    """
    Single source shortest path function for graphs using the Bellman-Ford
    algorithm [1].

    Creates {out_table} holding, for every reached vertex (per group when
    grouping columns are given), the minimum path weight found and the parent
    vertex on that path, plus a companion '<out_table>_summary' table that
    records the arguments of this run.

    Args:
        @param schema_madlib    Name of the MADlib schema (not used here).
        @param vertex_table     Name of the table that contains the vertex data.
        @param vertex_id        Name of the column containing the vertex ids.
                                Defaults to "id" when empty.
        @param edge_table       Name of the table that contains the edge data.
        @param edge_args        A comma-delimited string containing multiple
                                named arguments of the form "name=value".
        @param source_vertex    The source vertex id for the algorithm to start.
        @param out_table        Name of the table to store the result of SSSP.
        @param grouping_cols    The list of grouping columns.

    Returns:
        None. Results are written to the output tables.

    Raises:
        plpy.error if no grouping is used and a negative cycle is detected.
        With grouping, the affected groups are removed from the output and a
        warning is issued instead.

    [1] https://en.wikipedia.org/wiki/Bellman-Ford_algorithm
    """
    with MinWarning("warning"):
        INT_MAX = 2147483647
        INFINITY = "'Infinity'"
        EPSILON = 0.000001

        # NOTE: most SQL below is rendered with .format(**locals()), so the
        # local variable names in this function are part of the templates.
        message = unique_string(desp='message')
        oldupdate = unique_string(desp='oldupdate')
        newupdate = unique_string(desp='newupdate')

        params_types = {'src': str, 'dest': str, 'weight': str}
        default_args = {'src': 'src', 'dest': 'dest', 'weight': 'weight'}
        edge_params = extract_keyvalue_params(edge_args,
                                              params_types,
                                              default_args)

        # Prepare the input for recording in the summary table
        if not vertex_id:
            v_st = ''
            vertex_id = "id"
        else:
            v_st = vertex_id
        if not edge_args:
            e_st = ''
        else:
            e_st = edge_args
        if not grouping_cols:
            g_st = ''
            glist = None
        else:
            g_st = grouping_cols
            glist = split_quoted_delimited_str(grouping_cols)

        src = edge_params["src"]
        dest = edge_params["dest"]
        weight = edge_params["weight"]

        distribution = '' if is_platform_pg() else "DISTRIBUTED BY ({0})".format(vertex_id)
        local_distribution = '' if is_platform_pg() else "DISTRIBUTED BY (id)"

        _validate_sssp(vertex_table, vertex_id, edge_table,
                       edge_params, source_vertex, out_table, glist)

        plpy.execute(" DROP TABLE IF EXISTS {0},{1},{2}".format(
            message, oldupdate, newupdate))

        # Initialize grouping related variables
        comma_grp = ""
        comma_grp_e = ""
        comma_grp_m = ""
        grp_comma = ""
        checkg_oo = ""
        checkg_eo = ""
        checkg_ex = ""
        checkg_om = ""

        if grouping_cols:
            comma_grp = " , " + grouping_cols
            comma_grp_e = " , " + get_table_qualified_col_str(edge_table, glist)
            comma_grp_m = " , " + get_table_qualified_col_str("message", glist)
            grp_comma = grouping_cols + " , "

            # "oldupdate", "x", "message" and "out_table" are the SQL aliases
            # used inside the queries below, not the generated table names.
            checkg_oo_sub = _check_groups(out_table, "oldupdate", glist)
            checkg_oo = " AND " + checkg_oo_sub
            checkg_eo = " AND " + _check_groups(edge_table, "oldupdate", glist)
            checkg_ex = " AND " + _check_groups(edge_table, "x", glist)
            checkg_om = " AND " + _check_groups("out_table", "message", glist)

        # Pick the "unreached" initial weight according to the weight type.
        w_type = get_expr_type(weight, edge_table).lower()
        init_w = INT_MAX
        if w_type in ['real', 'double precision', 'float8', 'bigint']:
            init_w = INFINITY

        # We keep a table of every vertex, the minimum cost to that destination
        # seen so far and the parent to this vertex in the associated shortest
        # path. This table will be updated throughout the execution.
        plpy.execute(
            """ CREATE TABLE {out_table} AS (SELECT
{grp_comma} {src} AS {vertex_id}, {weight},
{src} AS parent FROM {edge_table} LIMIT 0)
{distribution} """.format(**locals()))

        # We keep a summary table to keep track of the parameters used for this
        # SSSP run. This table is used in the path finding function to eliminate
        # the need for repetition.
        summary_table = add_postfix(out_table, "_summary")
        plpy.execute(""" CREATE TABLE {summary_table} (
vertex_table TEXT,
vertex_id TEXT,
edge_table TEXT,
edge_args TEXT,
source_vertex BIGINT,
out_table TEXT,
grouping_cols TEXT)
""".format(**locals()))
        plpy.execute(""" INSERT INTO {summary_table} VALUES
('{vertex_table}', '{v_st}', '{edge_table}', '{e_st}',
{source_vertex}, '{out_table}', '{g_st}')
""".format(**locals()))

        # We keep 2 update tables and alternate them during the execution.
        # This is necessary since we need to know which vertices are updated in
        # the previous iteration to calculate the next set of updates.
        plpy.execute(
            """ CREATE TEMP TABLE {oldupdate} AS (SELECT
{src} AS id, {weight},
{src} AS parent {comma_grp} FROM {edge_table} LIMIT 0)
{local_distribution}
""".format(**locals()))
        plpy.execute(
            """ CREATE TEMP TABLE {newupdate} AS (SELECT
{src} AS id, {weight},
{src} AS parent {comma_grp} FROM {edge_table} LIMIT 0)
{local_distribution}
""".format(**locals()))

        # GPDB has distributed by clauses to help them with indexing.
        # For Postgres we add the indices manually.
        if is_platform_pg():
            plpy.execute("""
CREATE INDEX ON {out_table} ({vertex_id});
CREATE INDEX ON {oldupdate} (id);
CREATE INDEX ON {newupdate} (id);
""".format(**locals()))

        # The initialization step is quite different when grouping is involved
        # since not every group (subgraph) will have the same set of vertices.
        # Example:
        # Assume there are two grouping columns g1 and g2
        # g1 values are 0 and 1. g2 values are 5 and 6
        if grouping_cols:
            distinct_grp_table = unique_string(desp='grp')
            plpy.execute("DROP TABLE IF EXISTS {distinct_grp_table}".
                         format(**locals()))
            plpy.execute("""
CREATE TEMP TABLE {distinct_grp_table} AS
SELECT DISTINCT {grouping_cols} FROM {edge_table}
""".format(**locals()))
            subq = unique_string(desp='subquery')

            checkg_ds_sub = _check_groups(distinct_grp_table, subq, glist)
            grp_d_comma = get_table_qualified_col_str(distinct_grp_table, glist) + ","

            # Every vertex that appears as a source or destination of an edge
            # of a group starts as unreached within that group.
            plpy.execute("""
INSERT INTO {out_table}
SELECT {grp_d_comma} {vertex_id} AS {vertex_id},
{init_w} AS {weight}, NULL::INT AS parent
FROM {distinct_grp_table} INNER JOIN
(
SELECT {src} AS {vertex_id} {comma_grp}
FROM {edge_table}
UNION
SELECT {dest} AS {vertex_id} {comma_grp}
FROM {edge_table}
) {subq} ON ({checkg_ds_sub})
WHERE {vertex_id} IS NOT NULL
""".format(**locals()))

            # Seed the source vertex (cost 0, its own parent) for every group.
            plpy.execute("""
INSERT INTO {oldupdate}
SELECT {source_vertex}, 0, {source_vertex},
{grouping_cols}
FROM {distinct_grp_table}
""".format(**locals()))

            # The maximum number of vertices for any group.
            # Used for determining negative cycles.
            v_cnt = plpy.execute("""
SELECT max(count) as max FROM (
SELECT count({vertex_id}) AS count
FROM {out_table}
GROUP BY {grouping_cols}) x
""".format(**locals()))[0]['max']
            plpy.execute("DROP TABLE IF EXISTS {0}".format(distinct_grp_table))
        else:
            plpy.execute("""
INSERT INTO {out_table}
SELECT {vertex_id} AS {vertex_id},
{init_w} AS {weight},
NULL AS parent
FROM {vertex_table}
WHERE {vertex_id} IS NOT NULL
""".format(**locals()))

            # The source can be reached with 0 cost and it has itself as the
            # parent.
            plpy.execute("""
INSERT INTO {oldupdate}
VALUES({source_vertex},0,{source_vertex})
""".format(**locals()))

            v_cnt = plpy.execute("""
SELECT count(*) FROM {vertex_table}
WHERE {vertex_id} IS NOT NULL
""".format(**locals()))[0]['count']

        # Track convergence explicitly. Comparing the loop index with v_cnt
        # after the loop cannot tell "converged on the very last pass" apart
        # from "ran out of passes", and would report a negative cycle for any
        # graph whose longest shortest path uses |V| - 1 edges.
        converged = False
        for _ in range(v_cnt + 1):
            # Apply the updates calculated in the last iteration.
            sql = """
UPDATE {out_table} SET
{weight} = oldupdate.{weight},
parent = oldupdate.parent
FROM
{oldupdate} AS oldupdate
WHERE
{out_table}.{vertex_id} = oldupdate.id AND
{out_table}.{weight} > oldupdate.{weight} {checkg_oo}
"""
            ret = plpy.execute(sql.format(**locals()))
            if ret.nrows() == 0:
                # Nothing improved: all shortest paths are final.
                converged = True
                break

            plpy.execute("TRUNCATE TABLE {0}".format(newupdate))

            # 'oldupdate' table has the update info from the last iteration
            # Consider every edge that has an updated source
            # From these edges:
            # For every destination vertex, find the min total cost to reach.
            # Note that, just calling an aggregate function with group by won't
            # let us store the src field of the edge (needed for the parent).
            # This is why we need the 'x'; it gives a list of destinations and
            # associated min values. Using these values, we identify which edge
            # is selected.
            # Since using '=' with floats is dangerous we use an epsilon value
            # for comparison.
            # Once we have a list of edges and values (stores as 'message'),
            # we check if these values are lower than the existing shortest
            # path values.
            sql = (""" INSERT INTO {newupdate}
SELECT DISTINCT ON (message.id {comma_grp})
message.id AS id,
message.{weight} AS {weight},
message.parent AS parent {comma_grp_m}
FROM {out_table} AS out_table INNER JOIN
(
SELECT {edge_table}.{dest} AS id, x.{weight} AS {weight},
oldupdate.id AS parent {comma_grp_e}
FROM {oldupdate} AS oldupdate INNER JOIN
{edge_table} ON
({edge_table}.{src} = oldupdate.id {checkg_eo})
INNER JOIN
(
SELECT {edge_table}.{dest} AS id,
min(oldupdate.{weight} +
{edge_table}.{weight}) AS {weight} {comma_grp_e}
FROM {oldupdate} AS oldupdate INNER JOIN
{edge_table} ON
({edge_table}.{src}=oldupdate.id {checkg_eo})
GROUP BY {edge_table}.{dest} {comma_grp_e}
) x
ON ({edge_table}.{dest} = x.id {checkg_ex} )
WHERE ABS(oldupdate.{weight} + {edge_table}.{weight}
- x.{weight}) < {EPSILON}
) message
ON (message.id = out_table.{vertex_id} {checkg_om})
WHERE message.{weight}<out_table.{weight}
""".format(**locals()))
            plpy.execute(sql)

            # Swap the update tables for the next iteration.
            oldupdate, newupdate = newupdate, oldupdate

        plpy.execute("DROP TABLE IF EXISTS {0}".format(newupdate))

        # Set when the output table is removed because every group turned out
        # to contain a negative cycle.
        out_table_dropped = False

        # The algorithm should converge in less than |V| iterations.
        # Otherwise there is a negative cycle in the graph.
        if not converged:
            if not grouping_cols:
                plpy.execute("DROP TABLE IF EXISTS {0},{1},{2}".
                             format(out_table, summary_table, oldupdate))
                plpy.error("Graph SSSP: Detected a negative cycle in the graph.")
            # It is possible that not all groups has negative cycles.
            else:
                # negs is the string created by collating grouping columns.
                # By looking at the oldupdate table we can see which groups
                # are in a negative cycle.
                negs = plpy.execute(
                    """ SELECT array_agg(DISTINCT ({grouping_cols})) AS grp
FROM {oldupdate}
""".format(**locals()))[0]['grp']

                # Delete the groups with negative cycles from the output table.
                sql_del = """ DELETE FROM {out_table}
USING {oldupdate} AS oldupdate
WHERE {checkg_oo_sub} """
                plpy.execute(sql_del.format(**locals()))

                # If every group has a negative cycle,
                # drop the output table as well.
                if table_is_empty(out_table):
                    plpy.execute("DROP TABLE IF EXISTS {0},{1}".
                                 format(out_table, summary_table))
                    out_table_dropped = True

                plpy.warning(
                    """Graph SSSP: Detected a negative cycle in the """ +
                    """sub-graphs of following groups: {0}.""".
                    format(str(negs)[1:-1]))

        # Filter out the infinite paths (disconnected pairs).
        # Skipped when the output table no longer exists, otherwise the
        # DELETE would fail on a missing relation.
        if not out_table_dropped:
            plpy.execute(""" DELETE FROM {0} WHERE parent IS NULL
""".format(out_table))
        plpy.execute("DROP TABLE IF EXISTS {0}".format(oldupdate))
    return None
def graph_sssp_get_path(schema_madlib, sssp_table, dest_vertex, path_table,
                        **kwargs):
    """
    Helper function that can be used to get the shortest path for a vertex

    Reads the run parameters from the '<sssp_table>_summary' table, follows
    the 'parent' column of the SSSP output from the destination back to the
    source, and writes the resulting path(s) to {path_table}.

    Args:
        @param schema_madlib    Name of the MADlib schema.
        @param sssp_table       Name of the table that contains the SSSP output.
        @param dest_vertex      The vertex that will be the destination of the
                                desired path.
        @param path_table       Name of the output table that contains the path.

    Returns:
        None. The path is written to {path_table}.

    Raises:
        plpy.error if the resulting path table has no rows, which is reported
        as the destination vertex not being present in the SSSP table.
    """
    with MinWarning("warning"):
        _validate_get_path(sssp_table, dest_vertex, path_table)

        # NOTE: the SQL below is rendered with .format(**locals()), so the
        # local variable names in this function are part of the templates.
        temp1_name = unique_string(desp='temp1')
        temp2_name = unique_string(desp='temp2')

        select_grps = ""
        check_grps_t1 = ""
        check_grps_t2 = ""
        grp_comma = ""

        summary_table = add_postfix(sssp_table, "_summary")
        summary = plpy.execute("SELECT * FROM {0}".format(summary_table))
        vertex_id = summary[0]['vertex_id']
        source_vertex = summary[0]['source_vertex']
        edge_args = summary[0]['edge_args']

        edge_params, final_type = get_edge_params(schema_madlib, sssp_table, edge_args)
        weight = edge_params["weight"]

        # The summary stores an empty string when the defaults were used.
        if not vertex_id:
            vertex_id = "id"

        grouping_cols = summary[0]['grouping_cols']
        if not grouping_cols:
            grouping_cols = None

        if grouping_cols:
            glist = split_quoted_delimited_str(grouping_cols)
            select_grps = get_table_qualified_col_str(sssp_table, glist) + " , "
            check_grps_t1 = " AND " + _check_groups(
                sssp_table, temp1_name, glist)
            check_grps_t2 = " AND " + _check_groups(
                sssp_table, temp2_name, glist)
            grp_comma = grouping_cols + " , "

        # Trivial case: the path consists of the source vertex only.
        # No temporary tables have been created at this point.
        if source_vertex == dest_vertex:
            plpy.execute("""
CREATE TABLE {path_table} AS
SELECT {grp_comma} '{{{dest_vertex}}}'::{final_type}[] AS path
FROM {sssp_table} WHERE {vertex_id} = {dest_vertex}
""".format(**locals()))
            return None

        plpy.execute("DROP TABLE IF EXISTS {0},{1}".
                     format(temp1_name, temp2_name))

        # Start from the destination: remember its parent and a path that
        # contains only the destination itself.
        out = plpy.execute(""" CREATE TEMP TABLE {temp1_name} AS
SELECT {grp_comma} {sssp_table}.parent AS {vertex_id},
ARRAY[{dest_vertex}]::{final_type}[] AS path
FROM {sssp_table}
WHERE {vertex_id} = {dest_vertex}
AND {sssp_table}.parent IS NOT NULL
""".format(**locals()))

        plpy.execute("""
CREATE TEMP TABLE {temp2_name} AS
SELECT * FROM {temp1_name} LIMIT 0
""".format(**locals()))

        # Follow the 'parent' chain until you reach the source.
        while out.nrows() > 0:
            plpy.execute("TRUNCATE TABLE {temp2_name}".format(**locals()))

            # If the vertex id is not the source vertex,
            # Add it to the path and move to its parent
            out = plpy.execute(
                """ INSERT INTO {temp2_name}
SELECT {select_grps} {sssp_table}.parent AS {vertex_id},
{sssp_table}.{vertex_id} || {temp1_name}.path AS path
FROM {sssp_table} INNER JOIN {temp1_name} ON
({sssp_table}.{vertex_id} = {temp1_name}.{vertex_id}
{check_grps_t1})
WHERE {source_vertex} <> {sssp_table}.{vertex_id}
""".format(**locals()))

            # Swap the two work tables together with their matching group
            # join conditions for the next step.
            temp1_name, temp2_name = temp2_name, temp1_name
            check_grps_t1, check_grps_t2 = check_grps_t2, check_grps_t1

        # Add the source vertex to the beginning of every path and
        # add the empty arrays for the groups that don't have a path to reach
        # the destination vertex
        plpy.execute("""
CREATE TABLE {path_table} AS
SELECT {grp_comma} {source_vertex}::{final_type} || path AS path
FROM {temp2_name}
UNION
SELECT {grp_comma} '{{}}'::{final_type}[] AS path
FROM {sssp_table}
WHERE {vertex_id} = {dest_vertex}
AND {sssp_table}.parent IS NULL
""".format(**locals()))

        out = plpy.execute("SELECT 1 FROM {0} LIMIT 1".format(path_table))
        if out.nrows() == 0:
            plpy.error(
                "Graph SSSP: Vertex {0} is not present in the SSSP table {1}".
                format(dest_vertex, sssp_table))

        # Drop BOTH work tables; naming the same table twice here would leave
        # the other temporary table behind.
        plpy.execute("DROP TABLE IF EXISTS {temp1_name}, {temp2_name}".
                     format(**locals()))
    return None
def _validate_sssp(vertex_table, vertex_id, edge_table, edge_params,
                   source_vertex, out_table, glist, **kwargs):
    """
    Validate the arguments of graph_sssp before any table is created.

    Checks performed, in order:
        - the generic graph checks done by validate_graph_coding,
        - source_vertex is an integer,
        - source_vertex is present in the vertex table,
        - the vertex table has no duplicate (non-NULL) vertex ids,
        - the '<out_table>_summary' table does not exist yet,
        - every grouping column (if any) exists in the edge table.

    Args:
        @param vertex_table     Name of the table that contains the vertex data.
        @param vertex_id        Name of the column containing the vertex ids.
        @param edge_table       Name of the table that contains the edge data.
        @param edge_params      Edge parameters passed on to
                                validate_graph_coding.
        @param source_vertex    The source vertex id.
        @param out_table        Name of the SSSP output table.
        @param glist            List of grouping columns, or None.

    Returns:
        None. Failures are reported through _assert / plpy.error.
    """
    validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
                          out_table, 'SSSP')

    is_integer = (isinstance(source_vertex, int) or
                  isinstance(source_vertex, long))
    _assert(is_integer,
            """Graph SSSP: Source vertex {source_vertex} has to be an integer.""".
            format(source_vertex=source_vertex))

    # The source vertex must be one of the known vertices.
    source_rows = plpy.execute("""
SELECT * FROM {vertex_table} WHERE {vertex_id}={source_vertex}
""".format(vertex_table=vertex_table,
           vertex_id=vertex_id,
           source_vertex=source_vertex))
    if not source_rows.nrows():
        plpy.error("Graph SSSP: Source vertex {source_vertex} is not present "
                   "in the vertex table {vertex_table}.".format(
                       source_vertex=source_vertex,
                       vertex_table=vertex_table))

    # Any id returned here occurs more than once in the vertex table.
    duplicate_rows = plpy.execute("""
SELECT {vertex_id}
FROM {vertex_table}
WHERE {vertex_id} IS NOT NULL
GROUP BY {vertex_id}
HAVING count(*) > 1
""".format(vertex_id=vertex_id, vertex_table=vertex_table))
    if duplicate_rows.nrows():
        plpy.error("Graph SSSP: Source vertex table {vertex_table} "
                   "contains duplicate vertex id's.".format(
                       vertex_table=vertex_table))

    summary_name = add_postfix(out_table, "_summary")
    _assert(not table_exists(summary_name),
            "Graph SSSP: Output summary table already exists!")

    if glist is not None:
        _assert(columns_exist_in_table(edge_table, glist),
                "Graph SSSP: Not all columns from {glist} are present in "
                "edge table ({edge_table}).".format(glist=glist,
                                                    edge_table=edge_table))
    return None
def _validate_get_path(sssp_table, dest_vertex, path_table, **kwargs):
    """
    Validate the arguments of graph_sssp_get_path.

    The SSSP table and its '<sssp_table>_summary' companion must both exist
    and be non-empty, and the requested path table must not exist yet.

    Args:
        @param sssp_table   Name of the table that contains the SSSP output.
        @param dest_vertex  Destination vertex (not checked here).
        @param path_table   Name of the output table for the path.

    Returns:
        None. Failures are reported through _assert.
    """
    # Each check relies on the previous one having passed, so the order
    # of the assertions below must be kept.
    name_given = (sssp_table and
                  sssp_table.strip().lower() not in ('null', ''))
    _assert(name_given, "Graph SSSP: Invalid SSSP table name!")

    sssp_present = table_exists(sssp_table)
    _assert(sssp_present,
            "Graph SSSP: SSSP table ({0}) is missing!".format(sssp_table))

    sssp_has_rows = not table_is_empty(sssp_table)
    _assert(sssp_has_rows,
            "Graph SSSP: SSSP table ({0}) is empty!".format(sssp_table))

    summary_name = add_postfix(sssp_table, "_summary")

    summary_present = table_exists(summary_name)
    _assert(summary_present,
            "Graph SSSP: SSSP summary table ({0}) is missing!".format(summary_name))

    summary_has_rows = not table_is_empty(summary_name)
    _assert(summary_has_rows,
            "Graph SSSP: SSSP summary table ({0}) is empty!".format(summary_name))

    path_is_new = not table_exists(path_table)
    _assert(path_is_new,
            "Graph SSSP: Output path table already exists!")
    return None
def graph_sssp_help(schema_madlib, message, **kwargs):
    """
    Help function for graph_sssp and graph_sssp_get_path

    An empty message yields a short summary, 'usage' / 'help' / '?'
    (case-insensitive) yields the detailed usage text, and anything else
    yields a "No such option" hint.

    Args:
        @param schema_madlib    Schema name substituted into the help text.
        @param message: string, Help message string
        @param kwargs

    Returns:
        String. Help/usage information
    """
    summary_text = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Given a graph and a source vertex, single source shortest path (SSSP)
algorithm finds a path for every vertex such that the sum of the
weights of its constituent edges is minimized.
For more details on function usage:
SELECT {schema_madlib}.graph_sssp('usage')
"""
    usage_text = """
Given a graph and a source vertex, single source shortest path (SSSP)
algorithm finds a path for every vertex such that the sum of the
weights of its constituent edges is minimized.
{graph_usage}
To retrieve the path for a specific vertex:
SELECT {schema_madlib}.graph_sssp_get_path(
sssp_table TEXT, -- Name of the table that contains the SSSP output.
dest_vertex INT, -- The vertex that will be the destination of the
-- desired path.
path_table TEXT -- Name of the output table that contains the path.
);
----------------------------------------------------------------------------
OUTPUT
----------------------------------------------------------------------------
The output of SSSP ('out_table' above) contains a row for every vertex of
every group and have the following columns (in addition to the grouping
columns):
- vertex_id : The id for the destination. Will use the input parameter
'vertex_id' for column naming.
- weight : The total weight of the shortest path from the source vertex
to this particular vertex.
Will use the input parameter 'weight' for column naming.
- parent : The parent of this vertex in the shortest path from source.
Will use 'parent' for column naming.
The output of graph_sssp_get_path ('path_table' above) contains a row for
every group and has the following columns:
- grouping_cols : The grouping columns given in the creation of the SSSP
table. If there are no grouping columns, these columns
will not exist and the table will have a single row.
- path (ARRAY) : The shortest path from the source vertex (as specified
in the SSSP execution) to the destination vertex.
"""
    unknown_text = "No such option. Use {schema_madlib}.graph_sssp()"

    # Pick the template that matches the request.
    if message and message.lower() in ['usage', 'help', '?']:
        template = usage_text
    elif message:
        template = unknown_text
    else:
        template = summary_text

    # Signature description shared with the other graph help functions;
    # it is only referenced by the detailed usage template.
    usage_block = get_graph_usage(
        schema_madlib, 'graph_sssp',
        """source_vertex INT, -- The source vertex id for the algorithm to start.
out_table TEXT, -- Name of the table to store the result of SSSP.
grouping_cols TEXT -- The list of grouping columns.""")

    return template.format(schema_madlib=schema_madlib,
                           graph_usage=usage_block)
# ---------------------------------------------------------------------