@orhankislal WCC: Optimize subtx count and catalog entry frequency (#573)

* WCC: Optimize subtx count and catalog entry frequency

WCC had a high number of plpy.execute commands. Since each call
creates a new subtransaction, it was constantly hitting the
overflow limit. This commit merges most of them.

WCC also created and dropped temp tables in each iteration.
This bloats the catalog, so this commit uses a few actual
tables and uses truncate/insert.
diff --git a/src/ports/postgres/modules/graph/test/wcc.sql_in b/src/ports/postgres/modules/graph/test/wcc.sql_in
index 0917ae2..e1dd7b5 100644
--- a/src/ports/postgres/modules/graph/test/wcc.sql_in
+++ b/src/ports/postgres/modules/graph/test/wcc.sql_in
@@ -180,7 +180,6 @@
 CREATE TABLE v2 AS SELECT (id+992147483647)::bigint as id FROM vertex;
 CREATE TABLE e2 AS SELECT (src_node+992147483647)::bigint as src, (dest_node+992147483647)::bigint as dest FROM "EDGE";
 
-DROP TABLE IF EXISTS pg_temp.out2, pg_temp.out2_summary;
-SELECT weakly_connected_components('v2',NULL,'e2',NULL,'pg_temp.out2');
-SELECT count(*) from pg_temp.out2;
-SELECT count(*) from pg_temp.out2_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'pg_temp.wcc_out');
+SELECT count(*) from pg_temp.wcc_out;
+SELECT count(*) from pg_temp.wcc_out_summary;
diff --git a/src/ports/postgres/modules/graph/wcc.py_in b/src/ports/postgres/modules/graph/wcc.py_in
index c8b5ab2..0c80aab 100644
--- a/src/ports/postgres/modules/graph/wcc.py_in
+++ b/src/ports/postgres/modules/graph/wcc.py_in
@@ -28,6 +28,7 @@
 """
 
 import plpy
+from utilities.control import SetGUC
 from utilities.utilities import _assert
 from utilities.utilities import _check_groups
 from utilities.utilities import get_table_qualified_col_str
@@ -130,8 +131,8 @@
         # In Greenplum, to avoid redistribution of data when in later queries,
         # edge_table is duplicated by creating a temporary table distributed
         # on dest column
-        plpy.execute(""" CREATE TEMP TABLE {edge_inverse} AS
-                            SELECT * FROM {edge_table} DISTRIBUTED BY ({dest})
+        plpy.execute(""" CREATE TABLE {edge_inverse} AS
+                         SELECT * FROM {edge_table} DISTRIBUTED BY ({dest});
                      """.format(**locals()))
     else:
         edge_inverse = edge_table
@@ -143,10 +144,6 @@
         # Update some variables useful for grouping based query strings
         subq = unique_string(desp='subquery')
         distinct_grp_table = unique_string(desp='grptable')
-        plpy.execute("""
-                CREATE TABLE {distinct_grp_table} AS
-                SELECT DISTINCT {grouping_cols} FROM {edge_table}
-            """.format(**locals()))
 
         comma_toupdate_prefixed_grouping_cols = ', ' + \
             get_table_qualified_col_str(toupdate, grouping_cols_list)
@@ -162,52 +159,128 @@
         edge_inverse_to_update_where_condition = ' AND ' + \
             _check_groups(edge_inverse, toupdate, grouping_cols_list)
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
-        group_by_clause_newupdate = ('' if not grouping_cols else
-                                     '{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
+        group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
-        plpy.execute("""
-                CREATE TABLE {newupdate} AS
-                SELECT {subq}.{vertex_id},
-                        CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-                        {select_grouping_cols}
-                FROM {distinct_grp_table} INNER JOIN (
-                    SELECT {select_grouping_cols_clause} {src} AS {vertex_id}
-                    FROM {edge_table}
-                    UNION
-                    SELECT {select_grouping_cols_clause} {dest} AS {vertex_id}
-                    FROM {edge_inverse}
-                ) {subq}
-                ON {join_grouping_cols}
-                GROUP BY {group_by_clause_newupdate}
-                {distribution}
-            """.format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
-                       select_grouping_cols_clause=grouping_cols_comma,
-                       **locals()))
-        # drop intermediate table
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(distinct_grp_table))
-        plpy.execute("""
-                CREATE TEMP TABLE {message} AS
-                SELECT {vertex_id},
-                        CAST({vertex_id} AS BIGINT) AS {component_id}
-                        {select_grouping_cols_clause}
-                FROM {newupdate}
-                {distribution}
-            """.format(select_grouping_cols_clause=comma_grouping_cols,
-                       **locals()))
+
+        grp_sql = """
+                CREATE TABLE {distinct_grp_table} AS
+                SELECT DISTINCT {grouping_cols} FROM {edge_table};
+            """
+        plpy.execute(grp_sql.format(**locals()))
+
+        prep_sql = """
+            CREATE TABLE {newupdate} AS
+            SELECT {subq}.{vertex_id},
+                    CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                    {select_grouping_cols}
+            FROM {distinct_grp_table} INNER JOIN (
+                SELECT {grouping_cols_comma} {src} AS {vertex_id}
+                FROM {edge_table}
+                UNION
+                SELECT {grouping_cols_comma} {dest} AS {vertex_id}
+                FROM {edge_inverse}
+            ) {subq}
+            ON {join_grouping_cols}
+            GROUP BY {group_by_clause_newupdate}
+            {distribution};
+
+            DROP TABLE IF EXISTS {distinct_grp_table};
+
+        """.format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
+                   **locals())
+        plpy.execute(prep_sql)
+
+        message_sql = """
+            CREATE TABLE {message} AS
+            SELECT {vertex_id},
+                    CAST({vertex_id} AS BIGINT) AS {component_id}
+                    {comma_grouping_cols}
+            FROM {newupdate}
+            {distribution};
+        """
+        plpy.execute(message_sql.format(**locals()))
     else:
-        plpy.execute("""
-                CREATE TABLE {newupdate} AS
-                SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-                FROM {vertex_table}
-                {distribution}
-            """.format(**locals()))
-        plpy.execute("""
-                CREATE TEMP TABLE {message} AS
-                SELECT {vertex_id}, CAST({vertex_id} AS BIGINT) AS {component_id}
-                FROM {vertex_table}
-                {distribution}
-            """.format(**locals()))
+        prep_sql = """
+            CREATE TABLE {newupdate} AS
+            SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+            FROM {vertex_table}
+            {distribution};
+
+            CREATE TABLE {message} AS
+            SELECT {vertex_id}, CAST({vertex_id} AS BIGINT) AS {component_id}
+            FROM {vertex_table}
+            {distribution};
+        """
+        plpy.execute(prep_sql.format(**locals()))
+
+    oldupdate_sql = """
+            CREATE TABLE {oldupdate} AS
+            SELECT {message}.{vertex_id},
+                    MIN({message}.{component_id}) AS {component_id}
+                    {comma_grouping_cols}
+            FROM {message}
+            GROUP BY {grouping_cols_comma} {vertex_id}
+            LIMIT 0
+            {distribution};
+    """
+    plpy.execute(oldupdate_sql.format(**locals()))
+
+    toupdate_sql = """
+            CREATE TABLE {toupdate} AS
+            SELECT * FROM {oldupdate}
+            {distribution};
+        """
+    plpy.execute(toupdate_sql.format(**locals()))
     nodes_to_update = 1
+    loop_sql = """
+        TRUNCATE TABLE {oldupdate};
+
+        INSERT INTO {oldupdate}
+        SELECT {message}.{vertex_id},
+                MIN({message}.{component_id}) AS {component_id}
+                {comma_grouping_cols}
+        FROM {message}
+        GROUP BY {grouping_cols_comma} {vertex_id};
+
+        TRUNCATE TABLE {toupdate};
+
+        INSERT INTO {toupdate}
+        SELECT {oldupdate}.{vertex_id},
+                {oldupdate}.{component_id}
+                {comma_oldupdate_prefixed_grouping_cols}
+        FROM {oldupdate}, {newupdate}
+        WHERE {oldupdate}.{vertex_id}={newupdate}.{vertex_id}
+            AND {oldupdate}.{component_id}<{newupdate}.{component_id}
+            {old_new_update_where_condition};
+
+        UPDATE {newupdate} SET
+            {component_id}={toupdate}.{component_id}
+            FROM {toupdate}
+            WHERE {newupdate}.{vertex_id}={toupdate}.{vertex_id}
+                {new_to_update_where_condition};
+
+        TRUNCATE TABLE {message};
+
+        INSERT INTO {message}
+        SELECT {edge_inverse}.{src} AS {vertex_id},
+            MIN({toupdate}.{component_id}) AS {component_id}
+            {comma_toupdate_prefixed_grouping_cols}
+        FROM {toupdate}, {edge_inverse}
+        WHERE {edge_inverse}.{dest} = {toupdate}.{vertex_id}
+            {edge_inverse_to_update_where_condition}
+        GROUP BY {edge_inverse}.{src} {comma_toupdate_prefixed_grouping_cols};
+
+        INSERT INTO {message}
+        SELECT {edge_table}.{dest} AS {vertex_id},
+            MIN({toupdate}.{component_id}) AS {component_id}
+            {comma_toupdate_prefixed_grouping_cols}
+        FROM {toupdate}, {edge_table}
+        WHERE {edge_table}.{src} = {toupdate}.{vertex_id}
+            {edge_to_update_where_condition}
+        GROUP BY {edge_table}.{dest} {comma_toupdate_prefixed_grouping_cols};
+
+        TRUNCATE TABLE {oldupdate};
+    """
     while nodes_to_update > 0:
         # Look at all the neighbors of a node, and assign the smallest node id
         # among the neighbors as its component_id. The next table starts off
@@ -216,81 +289,23 @@
         # updated in the next table. At every iteration update only those nodes
         # whose component_id in the previous iteration are greater than what was
         # found in the current iteration.
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(oldupdate))
-        plpy.execute("""
-            CREATE TEMP TABLE {oldupdate} AS
-            SELECT {message}.{vertex_id},
-                    MIN({message}.{component_id}) AS {component_id}
-                    {grouping_cols_select}
-            FROM {message}
-            GROUP BY {group_by_clause} {vertex_id}
-            {distribution}
-        """.format(grouping_cols_select='' if not grouping_cols else
-                   ', {0}'.format(grouping_cols),
-                   group_by_clause=grouping_cols_comma,
-                   **locals()))
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(toupdate))
-        plpy.execute("""
-            CREATE TEMP TABLE {toupdate} AS
-            SELECT {oldupdate}.{vertex_id},
-                    {oldupdate}.{component_id}
-                    {comma_oldupdate_prefixed_grouping_cols}
-            FROM {oldupdate}, {newupdate}
-            WHERE {oldupdate}.{vertex_id}={newupdate}.{vertex_id}
-                AND {oldupdate}.{component_id}<{newupdate}.{component_id}
-                {old_new_update_where_condition}
-            {distribution}
-        """.format(**locals()))
+        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-        plpy.execute("""
-                UPDATE {newupdate} SET
-                {component_id}={toupdate}.{component_id}
-                FROM {toupdate}
-                WHERE {newupdate}.{vertex_id}={toupdate}.{vertex_id}
-                    {new_to_update_where_condition}
-            """.format(**locals()))
+            plpy.execute(loop_sql.format(**locals()))
 
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(message))
-        plpy.execute("""
-            CREATE TEMP TABLE {message} AS
-                SELECT {edge_inverse}.{src} AS {vertex_id},
-                    MIN({toupdate}.{component_id}) AS {component_id}
-                    {comma_toupdate_prefixed_grouping_cols}
-                FROM {toupdate}, {edge_inverse}
-                WHERE {edge_inverse}.{dest} = {toupdate}.{vertex_id}
-                    {edge_inverse_to_update_where_condition}
-                GROUP BY {edge_inverse}.{src} {comma_toupdate_prefixed_grouping_cols}
-        """.format(select_grouping_cols='' if not grouping_cols
-                        else ', {0}'.format(grouping_cols),
-                   **locals()))
-
-        plpy.execute("""
-            INSERT INTO {message}
-                SELECT {edge_table}.{dest} AS {vertex_id},
-                    MIN({toupdate}.{component_id}) AS {component_id}
-                    {comma_toupdate_prefixed_grouping_cols}
-                FROM {toupdate}, {edge_table}
-                WHERE {edge_table}.{src} = {toupdate}.{vertex_id}
-                    {edge_to_update_where_condition}
-                GROUP BY {edge_table}.{dest} {comma_toupdate_prefixed_grouping_cols}
-        """.format(select_grouping_cols='' if not grouping_cols
-                        else ', {0}'.format(grouping_cols),
-                   **locals()))
-
-        plpy.execute("DROP TABLE {0}".format(oldupdate))
-        if grouping_cols:
-            nodes_to_update = plpy.execute("""
-                                SELECT SUM(cnt) AS cnt_sum
-                                FROM (
-                                    SELECT COUNT(*) AS cnt
-                                    FROM {toupdate}
-                                    GROUP BY {grouping_cols}
-                                ) t
-                """.format(**locals()))[0]["cnt_sum"]
-        else:
-            nodes_to_update = plpy.execute("""
-                                SELECT COUNT(*) AS cnt FROM {toupdate}
-                            """.format(**locals()))[0]["cnt"]
+            if grouping_cols:
+                nodes_to_update = plpy.execute("""
+                                    SELECT SUM(cnt) AS cnt_sum
+                                    FROM (
+                                        SELECT COUNT(*) AS cnt
+                                        FROM {toupdate}
+                                        GROUP BY {grouping_cols}
+                                    ) t
+                    """.format(**locals()))[0]["cnt_sum"]
+            else:
+                nodes_to_update = plpy.execute("""
+                                    SELECT COUNT(*) AS cnt FROM {toupdate}
+                                """.format(**locals()))[0]["cnt"]
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
@@ -299,25 +314,19 @@
     rename_table(schema_madlib, newupdate, out_table)
     # Create summary table. We only need the vertex_id and grouping columns
     # in it.
-    plpy.execute("""
-            CREATE TABLE {out_table_summary} (
-                {grouping_cols_summary}
-                vertex_table    TEXT,
-                vertex_id   TEXT,
-                vertex_id_type  TEXT
-            )
-        """.format(grouping_cols_summary='' if not grouping_cols else
-                   'grouping_cols TEXT, ', **locals()))
     vertex_id_type = get_expr_type(vertex_id, vertex_table)
-    plpy.execute("""
-            INSERT INTO {out_table_summary} VALUES
-            ({grouping_cols_summary} '{vertex_table}', '{vertex_id}',
-            '{vertex_id_type}')
-        """.format(grouping_cols_summary='' if not grouping_cols else
-                   "'{0}', ".format(grouping_cols), **locals()))
-    plpy.execute("DROP TABLE IF EXISTS {0},{1},{2},{3} ".
-                 format(message, oldupdate, newupdate, toupdate))
 
+    plpy.execute("""
+        CREATE TABLE {out_table_summary} AS SELECT
+            {grouping_cols_summary}
+            '{vertex_table}'::TEXT AS vertex_table,
+            '{vertex_id}'::TEXT AS vertex_id,
+            '{vertex_id_type}'::TEXT AS vertex_id_type;
+
+        DROP TABLE IF EXISTS {message},{oldupdate},{newupdate},{toupdate};
+    """.format(grouping_cols_summary='' if not grouping_cols else
+                    "'{0}'::TEXT AS grouping_cols, ".format(grouping_cols),
+               **locals()))
 
 # WCC Helper functions:
 def extract_wcc_summary_cols(wcc_summary_table):