WCC: Add warm start

WCC creates a large number of subtransactions which may cause system
performance degradation in some cases. This commit adds a parameter to
limit the number of iterations it runs as well as another one to
continue from the incomplete state.
diff --git a/src/ports/postgres/modules/graph/graph_utils.py_in b/src/ports/postgres/modules/graph/graph_utils.py_in
index 9c0a1c7..8b0cf06 100644
--- a/src/ports/postgres/modules/graph/graph_utils.py_in
+++ b/src/ports/postgres/modules/graph/graph_utils.py_in
@@ -74,14 +74,18 @@
                 "Graph WCC: Output table {0} already exists.".format(out_table))
 
 def validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, func_name, **kwargs):
+                          out_table, func_name, warm_start = False, **kwargs):
     """
     Validates graph tables (vertex and edge) as well as the output table.
     """
     _assert(out_table and out_table.strip().lower() not in ('null', ''),
-            "Graph {func_name}: Invalid output table name!".format(**locals()))
-    _assert(not table_exists(out_table),
-            "Graph {func_name}: Output table already exists!".format(**locals()))
+                "Graph {func_name}: Invalid output table name!".format(**locals()))
+    if warm_start:
+        _assert(table_exists(out_table),
+                "Graph {func_name}: Output table is missing for warm start!".format(**locals()))
+    else:
+        _assert(not table_exists(out_table),
+                "Graph {func_name}: Output table already exists!".format(**locals()))
 
     _assert(vertex_table and vertex_table.strip().lower() not in ('null', ''),
             "Graph {func_name}: Invalid vertex table name!".format(**locals()))
diff --git a/src/ports/postgres/modules/graph/test/wcc.sql_in b/src/ports/postgres/modules/graph/test/wcc.sql_in
index f7af686..5012246 100644
--- a/src/ports/postgres/modules/graph/test/wcc.sql_in
+++ b/src/ports/postgres/modules/graph/test/wcc.sql_in
@@ -173,12 +173,12 @@
     'src=src_node,dest=dest_node','out','user_id');
 SELECT * FROM out;
 
-ALTER TABLE vertex RENAME COLUMN dest TO id;
+ALTER TABLE vertex RENAME COLUMN dest TO vertex_id;
 
 -- Test for bigint columns
-
-CREATE TABLE v2 AS SELECT (id+992147483647)::bigint as id FROM vertex;
-CREATE TABLE e2 AS SELECT (src_node+992147483647)::bigint as src, (dest_node+992147483647)::bigint as dest FROM "EDGE";
+DROP TABLE IF EXISTS v2,e2;
+CREATE TABLE v2 AS SELECT (vertex_id+992147483647)::bigint as id FROM vertex;
+CREATE TABLE e2 AS SELECT (src_node+992147483647)::bigint as src, (dest_node+992147483647)::bigint as dest, user_id FROM "EDGE";
 
 SELECT weakly_connected_components('v2',NULL,'e2',NULL,'pg_temp.wcc_out');
 SELECT count(*) from pg_temp.wcc_out;
@@ -188,7 +188,7 @@
 -- The datasets have the columns doubled so that the same tests can be run on the output tables
 
 DROP TABLE IF EXISTS vertex_mult, edge_mult CASCADE;
-CREATE TABLE vertex_mult AS SELECT id AS id1, id AS id2 FROM vertex;
+CREATE TABLE vertex_mult AS SELECT vertex_id AS id1, vertex_id AS id2 FROM vertex;
 CREATE TABLE edge_mult AS
 SELECT src_node AS src1, src_node AS src2,
        dest_node AS dest1, dest_node AS dest2,
@@ -276,3 +276,82 @@
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+-- Warm Start
+
+-- Without grouping
+DROP TABLE IF EXISTS wcc_non_warm_start_out, wcc_non_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_non_warm_start_out');
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary, wcc_warm_start_out_message;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', NULL, 1);
+
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start summary table.')
+FROM wcc_warm_start_out_summary;
+SELECT assert(nodes_to_update > 0,
+        'Weakly Connected Components: Warm start incorrect nodes_to_update.'
+    ) FROM wcc_warm_start_out_summary;
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start message table.')
+FROM wcc_warm_start_out_message;
+
+SELECT assert(nodes_to_update > 0,
+        'Weakly Connected Components: Incorrect nodes to update count.'
+    ) FROM wcc_warm_start_out_summary;
+
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', NULL, 1, True);
+
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start summary table.')
+FROM wcc_warm_start_out_summary;
+SELECT assert(nodes_to_update > 0,
+        'Weakly Connected Components: Warm start incorrect nodes_to_update.'
+    ) FROM wcc_warm_start_out_summary;
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start message table.')
+FROM wcc_warm_start_out_message;
+
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', NULL, 2, True);
+
+SELECT assert(nodes_to_update = 0,
+        'Weakly Connected Components: Warm start incorrect nodes_to_update.'
+    ) FROM wcc_warm_start_out_summary;
+
+SELECT assert(count(*) < 0.00001, 'Weakly Connected Components: Different warm start result.')
+FROM wcc_non_warm_start_out w1, wcc_warm_start_out w2
+WHERE w1.id = w2.id AND w1.component_id != w2.component_id;
+
+SELECT assert(relative_error(count(*), 4) < 0.00001,
+        'Weakly Connected Components: Warm start incorrect component_id.'
+    ) FROM wcc_warm_start_out
+WHERE component_id = 992147483657;
+
+-- With grouping
+DROP TABLE IF EXISTS wcc_non_warm_start_out, wcc_non_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_non_warm_start_out', 'user_id');
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary, wcc_warm_start_out_message;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);
+
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start summary table.')
+FROM wcc_warm_start_out_summary;
+SELECT assert(count(*) > 0, 'Weakly Connected Components: Empty warm start message table.')
+FROM wcc_warm_start_out_message;
+
+SELECT assert(nodes_to_update > 0,
+        'Weakly Connected Components: Incorrect nodes to update count.'
+    ) FROM wcc_warm_start_out_summary;
+
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2, True);
+
+SELECT assert(count(*) = 0, 'Weakly Connected Components: Different warm start result.')
+FROM wcc_non_warm_start_out w1, wcc_warm_start_out w2
+WHERE w1.id = w2.id AND w1.user_id = w2.user_id AND w1.component_id != w2.component_id;
+
+SELECT assert(relative_error(count(*), 4) < 0.00001,
+        'Weakly Connected Components: Warm start incorrect component_id.'
+    ) FROM wcc_warm_start_out WHERE user_id=1 AND component_id = 992147483657;
+
+SELECT assert(count(table_name) = 0, 'Weakly Connected Components: Found leftover temp tables.')
+FROM
+    information_schema.tables
+WHERE
+    table_schema LIKE 'madlib_installcheck_%' AND
+    table_name LIKE '__madlib_temp_%';
diff --git a/src/ports/postgres/modules/graph/wcc.py_in b/src/ports/postgres/modules/graph/wcc.py_in
index fe0d877..d4f9022 100644
--- a/src/ports/postgres/modules/graph/wcc.py_in
+++ b/src/ports/postgres/modules/graph/wcc.py_in
@@ -32,6 +32,7 @@
 from utilities.utilities import _assert
 from utilities.utilities import _check_groups
 from utilities.utilities import get_table_qualified_col_str
+from utilities.utilities import is_platform_gp6_or_up
 from utilities.utilities import extract_keyvalue_params
 from utilities.utilities import unique_string, split_quoted_delimited_str
 from utilities.validate_args import columns_exist_in_table, get_expr_type
@@ -44,17 +45,37 @@
 from graph_utils import validate_graph_coding, get_graph_usage
 from graph_utils import validate_output_and_summary_tables
 
-
-def validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
-                      edge_params, out_table, out_table_summary,
-                      grouping_cols_list, module_name):
+def validate_wcc_args(schema_madlib, vertex_table, vertex_table_in, vertex_id,
+                      vertex_id_in, edge_table, edge_params, edge_args,
+                      out_table, out_table_summary, grouping_cols,
+                      grouping_cols_list, warm_start, out_table_message,
+                      module_name):
     """
     Function to validate input parameters for wcc
     """
     validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, module_name)
-    _assert(not table_exists(out_table_summary),
-            "Graph {module_name}: Output summary table already exists!".format(**locals()))
+                          out_table, module_name, warm_start)
+    if not warm_start:
+        _assert(not table_exists(out_table_summary),
+                "Graph {module_name}: Output summary table already exists!".format(**locals()))
+        _assert(not table_exists(out_table_message),
+                "Graph {module_name}: Output message table already exists!".format(**locals()))
+    else:
+        _assert(table_exists(out_table_summary),
+                "Graph {module_name}: Output summary table is missing for warm start!".format(**locals()))
+        _assert(table_exists(out_table_message),
+                "Graph {module_name}: Output message table is missing for warm start! Either wcc was completed in the last run or the table got dropped/renamed.".format(**locals()))
+
+        prev_summary = plpy.execute("SELECT * FROM {0}".format(out_table_summary))[0]
+        _assert(prev_summary['vertex_table'] == vertex_table_in, "Graph {module_name}: Warm start vertex_table do not match!".format(**locals()))
+        if vertex_id_in:
+            _assert(prev_summary['vertex_id'] == vertex_id_in, "Graph {module_name}: Warm start vertex_id do not match!".format(**locals()))
+        _assert(prev_summary['edge_table'] == edge_table, "Graph {module_name}: Warm start edge_table do not match!".format(**locals()))
+        if edge_args:
+            _assert(prev_summary['edge_args'] == edge_args, "Graph {module_name}: Warm start edge_args do not match!".format(**locals()))
+
+        _assert(prev_summary['grouping_cols'] == grouping_cols, "Graph {module_name}: Warm start grouping_cols do not match!".format(**locals()))
+
     if grouping_cols_list:
         # validate the grouping columns. We currently only support grouping_cols
         # to be column names in the edge_table, and not expressions!
@@ -63,7 +84,7 @@
                 "One or more grouping columns specified do not exist!")
 
 def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
-        out_table, grouping_cols, **kwargs):
+        out_table, grouping_cols, iteration_limit=0, warm_start=False, **kwargs):
     """
     Function that computes the wcc
 
@@ -76,8 +97,10 @@
         @param grouping_cols
     """
 
+    BIGINT_MAX = 9223372036854775807
     vertex_table_in = vertex_table
     vertex_id_in = vertex_id
+    edge_table_in = edge_table
 
     old_msg_level = plpy.execute("""
                                   SELECT setting
@@ -90,6 +113,11 @@
     edge_params = extract_keyvalue_params(
         edge_args, params_types, default_args)
 
+    if iteration_limit is None or iteration_limit == 0:
+        iteration_limit = BIGINT_MAX
+    elif iteration_limit < 0:
+        plpy.error("Weakly Connected Components: iteration_limit cannot be a negative number.")
+
     # populate default values for optional params if null, and prepare data
     # to be written into the summary table (*_st variable names)
     vertex_view = unique_string('vertex_view')
@@ -98,7 +126,7 @@
     vertex_view_sql = """
         CREATE VIEW {vertex_view} AS
         SELECT {vertex_sql} AS id, {vertex_sql} AS {single_id}
-        FROM {vertex_table}
+        FROM {vertex_table};
         """
     if not vertex_id:
         vertex_id = "id"
@@ -156,22 +184,28 @@
         grouping_sql = ', {0}'.format(grouping_cols)
 
     out_table_summary = ''
+    message = ''
     if out_table:
         out_table_summary = add_postfix(out_table, "_summary")
+        message = add_postfix(out_table, "_message")
+
     grouping_cols_list = split_quoted_delimited_str(grouping_cols)
-    validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
-                      edge_params, out_table, out_table_summary,
-                      grouping_cols_list, 'Weakly Connected Components')
+
+    validate_wcc_args(schema_madlib, vertex_table, vertex_table_in,
+                      vertex_id, vertex_id_in, edge_table,
+                      edge_params, edge_args, out_table, out_table_summary,
+                      grouping_cols, grouping_cols_list, warm_start, message,
+                      'Weakly Connected Components')
 
     vertex_view_sql = vertex_view_sql.format(**locals())
-    plpy.execute(vertex_view_sql)
 
-    sql = """
+    edge_view_sql = """
         CREATE VIEW {edge_view} AS
         SELECT {src} AS src, {dest} AS dest {grouping_sql}
-        FROM {edge_table}
+        FROM {edge_table};
         """.format(**locals())
-    plpy.execute(sql)
+
+    plpy.execute(vertex_view_sql + edge_view_sql)
 
     vertex_table = vertex_view
     edge_table = edge_view
@@ -181,9 +215,7 @@
 
     distribution = '' if is_platform_pg() else "DISTRIBUTED BY (id)"
 
-    message = unique_string(desp='message')
     oldupdate = unique_string(desp='oldupdate')
-    newupdate = unique_string(desp='newupdate')
     toupdate = unique_string(desp='toupdate')
     temp_out_table = unique_string(desp='tempout')
     edge_inverse = unique_string(desp='edge_inverse')
@@ -196,7 +228,8 @@
     edge_to_update_where_condition = ''
     edge_inverse_to_update_where_condition = ''
 
-    BIGINT_MAX = 9223372036854775807
+    distinct_grp_sql = ""
+
     component_id = 'component_id'
     grouping_cols_comma = '' if not grouping_cols else grouping_cols + ','
     comma_grouping_cols = '' if not grouping_cols else ',' + grouping_cols
@@ -211,6 +244,14 @@
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        out_table_sql = ""
+        msg_sql = ""
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            out_table_sql = """
+                ALTER TABLE {out_table} RENAME COLUMN {vertex_id_in} TO {vertex_id};
+            """.format(**locals())
+
     if grouping_cols:
         distribution = ('' if is_platform_pg() else
                         "DISTRIBUTED BY ({0}, {1})".format(grouping_cols,
@@ -225,9 +266,9 @@
             get_table_qualified_col_str(oldupdate, grouping_cols_list)
         subq_prefixed_grouping_cols = get_table_qualified_col_str(subq, grouping_cols_list)
         old_new_update_where_condition = ' AND ' + \
-            _check_groups(oldupdate, newupdate, grouping_cols_list)
+            _check_groups(oldupdate, out_table, grouping_cols_list)
         new_to_update_where_condition = ' AND ' + \
-            _check_groups(newupdate, toupdate, grouping_cols_list)
+            _check_groups(out_table, toupdate, grouping_cols_list)
         edge_to_update_where_condition = ' AND ' + \
             _check_groups(edge_table, toupdate, grouping_cols_list)
         edge_inverse_to_update_where_condition = ' AND ' + \
@@ -235,78 +276,119 @@
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
         group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
+        select_grouping_cols = ',' + subq_prefixed_grouping_cols
 
-        grp_sql = """
-                CREATE TABLE {distinct_grp_table} AS
-                SELECT DISTINCT {grouping_cols} FROM {edge_table};
-            """
-        plpy.execute(grp_sql.format(**locals()))
-
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {subq}.{vertex_id},
-                    CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-                    {select_grouping_cols}
-            FROM {distinct_grp_table} INNER JOIN (
-                SELECT {grouping_cols_comma} {src} AS {vertex_id}
-                FROM {edge_table}
-                UNION
-                SELECT {grouping_cols_comma} {dest} AS {vertex_id}
-                FROM {edge_inverse}
-            ) {subq}
-            ON {join_grouping_cols}
-            GROUP BY {group_by_clause_newupdate}
-            {distribution};
-
-            DROP TABLE IF EXISTS {distinct_grp_table};
-
-        """.format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
-                   **locals())
-        plpy.execute(prep_sql)
-
-        message_sql = """
-            CREATE TABLE {message} AS
-            SELECT {vertex_table}.{vertex_id},
-                    CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {newupdate} INNER JOIN {vertex_table}
-            ON {vertex_table}.{vertex_id} = {newupdate}.{vertex_id}
-            {distribution};
+        distinct_grp_sql = """
+            CREATE TABLE {distinct_grp_table} AS
+            SELECT DISTINCT {grouping_cols} FROM {edge_table};
         """
-        plpy.execute(message_sql.format(**locals()))
+
+        if not warm_start:
+            out_table_sql = """
+                CREATE TABLE {out_table} AS
+                SELECT {subq}.{vertex_id},
+                        CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                        {select_grouping_cols}
+                FROM {distinct_grp_table} INNER JOIN (
+                    SELECT {grouping_cols_comma} {src} AS {vertex_id}
+                    FROM {edge_table}
+                    UNION
+                    SELECT {grouping_cols_comma} {dest} AS {vertex_id}
+                    FROM {edge_inverse}
+                ) {subq}
+                ON {join_grouping_cols}
+                GROUP BY {group_by_clause_newupdate}
+                {distribution};
+            """.format(**locals())
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_table}.{vertex_id},
+                        CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
+                        {comma_grouping_cols}
+                FROM {out_table} INNER JOIN {vertex_table}
+                ON {vertex_table}.{vertex_id} = {out_table}.{vertex_id}
+                {distribution};
+            """.format(**locals())
+
     else:
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
+        if not warm_start:
+            out_table_sql = """
+                CREATE TABLE {out_table} AS
+                SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
 
-            CREATE TABLE {message} AS
-            SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
-        """
-        plpy.execute(prep_sql.format(**locals()))
-
-    oldupdate_sql = """
-            CREATE TABLE {oldupdate} AS
-            SELECT {message}.{vertex_id},
-                    MIN({message}.{component_id}) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {message}
-            GROUP BY {grouping_cols_comma} {vertex_id}
-            LIMIT 0
-            {distribution};
+    old_update_sql = """
+        CREATE TABLE {oldupdate} AS
+        SELECT {message}.{vertex_id},
+                MIN({message}.{component_id}) AS {component_id}
+                {comma_grouping_cols}
+        FROM {message}
+        GROUP BY {grouping_cols_comma} {vertex_id}
+        LIMIT 0
+        {distribution};
     """
-    plpy.execute(oldupdate_sql.format(**locals()))
+    to_update_sql = """
+        CREATE TABLE {toupdate} AS
+        SELECT * FROM {oldupdate}
+        {distribution};
+    """
 
-    toupdate_sql = """
-            CREATE TABLE {toupdate} AS
-            SELECT * FROM {oldupdate}
-            {distribution};
-        """
-    plpy.execute(toupdate_sql.format(**locals()))
+    # We combine the sql statements as much as possible to reduce the number of
+    # subtransactions we create. Postgres and Greenplum have a limit of 64 for
+    # cached subtx and each plpy.execute creates one.
+    # Postgres and Greenplum 5 do not like creating a table and using it in the
+    # same plpy.execute so we have to keep them separate for this step but the
+    # loop is combined for all platforms.
+    if is_platform_gp6_or_up():
+        plpy.execute((distinct_grp_sql + out_table_sql + msg_sql + old_update_sql + to_update_sql).format(**locals()))
+    else:
+        if distinct_grp_sql != "":
+            plpy.execute(distinct_grp_sql.format(**locals()))
+        plpy.execute(out_table_sql.format(**locals()))
+        plpy.execute(msg_sql.format(**locals()))
+        plpy.execute(old_update_sql.format(**locals()))
+        plpy.execute(to_update_sql.format(**locals()))
+
     nodes_to_update = 1
+
+    """
+    WCC Logic:
+    Assume we have the following graph: [1,2] [2,3] [2,4]
+    The first iteration starts with a number of set up steps.
+    For vertex 2, the component_id is set to 2.
+    The relevant work starts with the creation of the message table.
+    1)
+    message gets filled in two steps, one for incoming edges and one for outgoing.
+    The logic looks for every neighbor of a vertex and takes the minimum component id it sees.
+    For vertex 2, message will have two entries, 2->1 and 2->3. 2->4 got eliminated because 3<4.
+    2)
+    next iteration starts with oldupdate.
+    This table is used to reduce the two possible messages into one.
+    For vertex 2, oldupdate will have 2->1. 2->3 got eliminated because 1<3
+    3)
+    toupdate is used to check if the update is necessary.
+    We compare the incoming component_id value with the existing one.
+    For vertex 2, toupdate will have 2->1 because 1<2.
+    4) The out_table gets updated with the contents of toupdate table.
+    5) A new message is created based on the toupdate table as mentioned above.
+
+    Warm Start:
+    To facilitate warm start we use two tables:
+    - out_table: it contains the results so far, we will continue building on this table
+    - message: This is the message at the end of an iteration for the next one.
+    We save these two tables if the iteration_limit is reached and nodes_to_update > 0.
+    When the user starts again with warm_start on, we plug these two tables back and continue as usual.
+    """
+
+    # Use truncate instead of drop/recreate to avoid catalog bloat.
     loop_sql = """
         TRUNCATE TABLE {oldupdate};
 
@@ -323,15 +405,15 @@
         SELECT {oldupdate}.{vertex_id},
                 {oldupdate}.{component_id}
                 {comma_oldupdate_prefixed_grouping_cols}
-        FROM {oldupdate}, {newupdate}
-        WHERE {oldupdate}.{vertex_id}={newupdate}.{vertex_id}
-            AND {oldupdate}.{component_id}<{newupdate}.{component_id}
+        FROM {oldupdate}, {out_table}
+        WHERE {oldupdate}.{vertex_id}={out_table}.{vertex_id}
+            AND {oldupdate}.{component_id}<{out_table}.{component_id}
             {old_new_update_where_condition};
 
-        UPDATE {newupdate} SET
+        UPDATE {out_table} SET
             {component_id}={toupdate}.{component_id}
             FROM {toupdate}
-            WHERE {newupdate}.{vertex_id}={toupdate}.{vertex_id}
+            WHERE {out_table}.{vertex_id}={toupdate}.{vertex_id}
                 {new_to_update_where_condition};
 
         TRUNCATE TABLE {message};
@@ -355,55 +437,52 @@
         GROUP BY {edge_table}.{dest} {comma_toupdate_prefixed_grouping_cols};
 
         TRUNCATE TABLE {oldupdate};
+
+        SELECT COUNT(*) AS cnt_sum FROM {toupdate};
     """
-    while nodes_to_update > 0:
-        # Look at all the neighbors of a node, and assign the smallest node id
-        # among the neighbors as its component_id. The next table starts off
-        # with very high component_id (BIGINT_MAX). The component_id of all nodes
-        # which obtain a smaller component_id after looking at its neighbors are
-        # updated in the next table. At every iteration update only those nodes
-        # whose component_id in the previous iteration are greater than what was
-        # found in the current iteration.
+    iteration_counter = 0
+    while nodes_to_update > 0 and iteration_counter < iteration_limit:
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-            plpy.execute(loop_sql.format(**locals()))
-
-            if grouping_cols:
-                nodes_to_update = plpy.execute("""
-                                    SELECT SUM(cnt) AS cnt_sum
-                                    FROM (
-                                        SELECT COUNT(*) AS cnt
-                                        FROM {toupdate}
-                                        GROUP BY {grouping_cols}
-                                    ) t
-                    """.format(**locals()))[0]["cnt_sum"]
-            else:
-                nodes_to_update = plpy.execute("""
-                                    SELECT COUNT(*) AS cnt FROM {toupdate}
-                                """.format(**locals()))[0]["cnt"]
+            nodes_to_update = plpy.execute(loop_sql.format(**locals()))[0]["cnt_sum"]
+            iteration_counter += 1
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
         plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
 
-    rename_table(schema_madlib, newupdate, out_table)
     if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
         plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
-    # Create summary table. We only need the vertex_id and grouping columns
-    # in it.
+
+    if nodes_to_update is None or nodes_to_update == 0:
+        nodes_to_update = 0
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(message))
 
     plpy.execute("DROP VIEW IF EXISTS {0}, {1}".format(vertex_view, edge_view))
-    plpy.execute("""
-        CREATE TABLE {out_table_summary} AS SELECT
-            {grouping_cols_summary}
-            '{vertex_table_in}'::TEXT AS vertex_table,
-            '{vertex_id_in}'::TEXT AS vertex_id,
-            '{vertex_type}'::TEXT AS vertex_id_type;
+    if not warm_start:
+        plpy.execute("""
+            CREATE TABLE {out_table_summary} AS SELECT
+                '{grouping_cols_summary}'::TEXT AS grouping_cols,
+                '{vertex_table_in}'::TEXT AS vertex_table,
+                '{vertex_id_in}'::TEXT AS vertex_id,
+                '{vertex_type}'::TEXT AS vertex_id_type,
+                '{edge_table_in}'::TEXT AS edge_table,
+                '{edge_args}'::TEXT AS edge_args,
+                {iteration_counter}::BIGINT AS iteration_counter,
+                {nodes_to_update}::BIGINT AS nodes_to_update;
+        """.format(grouping_cols_summary='' if not grouping_cols else grouping_cols,
+                   **locals()))
+    else:
+        plpy.execute("""
+            UPDATE {out_table_summary} SET
+            iteration_counter = iteration_counter + {iteration_counter},
+            nodes_to_update = {nodes_to_update};
+        """.format(**locals()))
 
-        DROP TABLE IF EXISTS {message},{oldupdate},{newupdate},{toupdate};
-    """.format(grouping_cols_summary='' if not grouping_cols else
-                    "'{0}'::TEXT AS grouping_cols, ".format(grouping_cols),
-               **locals()))
+    if grouping_cols:
+        plpy.execute("DROP TABLE IF EXISTS {distinct_grp_table}".format(**locals()))
+    plpy.execute("DROP TABLE IF EXISTS  {oldupdate}, {toupdate}".format(**locals()))
+    plpy.execute("DROP VIEW IF EXISTS  {vertex_view}, {edge_view}".format(**locals()))
 
 
 # WCC Helper functions:
diff --git a/src/ports/postgres/modules/graph/wcc.sql_in b/src/ports/postgres/modules/graph/wcc.sql_in
index 26a8a8d..594b74a 100644
--- a/src/ports/postgres/modules/graph/wcc.sql_in
+++ b/src/ports/postgres/modules/graph/wcc.sql_in
@@ -61,7 +61,9 @@
             edge_table,
             edge_args,
             out_table,
-            grouping_cols
+            grouping_cols,
+            iteration_limit,
+            warm_start
           )
 </pre>
 
@@ -115,6 +117,26 @@
 (single graph).
 @note Expressions are not currently supported for 'grouping_cols'.</dd>
 
+<dt>iteration_limit (optional)</dt>
+<dd>INTEGER, default: NULL. Maximum number of iterations to run wcc. This
+parameter is used to stop wcc early to limit the number of subtransactions
+created by wcc. For such subtx issues, it is advised to set this parameter to
+40. A wcc run that was stopped early by this parameter can resume its progress by
+using the warm_start parameter.
+An additional table named <out_table>_message is also created. This table is
+necessary in case the iteration_limit is reached and there are still vertices to
+update. It gets used when the wcc continues the process via warm_start and
+gets dropped when the wcc determines there are no more updates necessary.
+The user might determine if the wcc is completed or not by checking the
+nodes_to_update column of <out_table>_summary table where 0 means wcc is
+complete. </dd>
+
+<dt>warm_start (optional)</dt>
+<dd>BOOLEAN, default: NULL. If True, wcc will look for the <out_table>_message
+table and continue using it and the partial output from <out_table> to continue
+the wcc process.
+</dd>
+
 </dl>
 
 @note On a Greenplum cluster, the edge table should be distributed
@@ -592,6 +614,33 @@
     edge_table              TEXT,
     edge_args               TEXT,
     out_table               TEXT,
+    grouping_cols           TEXT,
+    iteration_limit         INTEGER,
+    warm_start              BOOLEAN
+) RETURNS VOID AS $$
+    PythonFunction(graph, wcc, wcc)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+-------------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.weakly_connected_components(
+    vertex_table            TEXT,
+    vertex_id               TEXT,
+    edge_table              TEXT,
+    edge_args               TEXT,
+    out_table               TEXT,
+    grouping_cols           TEXT,
+    iteration_limit         INTEGER
+) RETURNS VOID AS $$
+    PythonFunction(graph, wcc, wcc)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+-------------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.weakly_connected_components(
+    vertex_table            TEXT,
+    vertex_id               TEXT,
+    edge_table              TEXT,
+    edge_args               TEXT,
+    out_table               TEXT,
     grouping_cols           TEXT
 
 ) RETURNS VOID AS $$