[multistage] restructure runner test (#9489)

* fix dispatcher/server shutdown

* fix lint

Co-authored-by: Rong Rong <rongr@startree.ai>
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
new file mode 100644
index 0000000..31d1670
--- /dev/null
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query;
+
+import org.testng.annotations.DataProvider;
+
+
+public class QueryTestSet {
+
+  @DataProvider(name = "testSql")
+  public Object[][] provideTestSql() {
+    return new Object[][]{
+        // Order BY LIMIT
+        new Object[]{"SELECT * FROM b ORDER BY col1, col2 DESC LIMIT 3"},
+        new Object[]{"SELECT * FROM a ORDER BY col1, ts LIMIT 10"},
+        new Object[]{"SELECT * FROM a ORDER BY col1 LIMIT 20"},
+
+        // No match filter
+        new Object[]{"SELECT * FROM b WHERE col3 < 0.5"},
+
+        // Hybrid table
+        new Object[]{"SELECT * FROM d"},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // thus the final JOIN result will be 15 x 1 = 15.
+        // Next join with table C which has (5 on server1 and 10 on server2), since data is identical. each of the row
+        // of the A JOIN B will have identical value of col3 as table C.col3 has. Since the values are cycling between
+        // (1, 42, 1, 42, 1). we will have 9 1s, and 6 42s, total result count will be 9 * 9 + 6 * 6 = 117
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col3 = c.col3"},
+        // Reverse the order of join condition and join table order.
+        new Object[]{"SELECT * FROM a JOIN b ON b.col1 = a.col1 JOIN c ON c.col3 = a.col3"},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // thus the final JOIN result will be 15 x 1 = 15.
+        new Object[]{"SELECT * FROM a JOIN b on a.col1 = b.col1"},
+
+        // Query with function in JOIN keys, table A and B are both (1, 42, 1, 42, 1), with table A cycling 3 times.
+        // Because:
+        //   - MOD(a.col3, 2) will have 6 (42)s equal to 0 and 9 (1)s equals to 1
+        //   - MOD(b.col3, 3) will have 2 (42)s equal to 0 and 3 (1)s equals to 1;
+        // final results are 6 * 2 + 9 * 3 = 39 rows
+        new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON MOD(a.col3, 2) = MOD(b.col3, 3)"},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // thus the final JOIN result will be 15 x 1 = 15.
+        new Object[]{"SELECT * FROM a JOIN b on a.col1 = b.col1 AND a.col2 = b.col2"},
+        // Reverse the order of join condition and join table order.
+        new Object[]{"SELECT * FROM a JOIN b on b.col1 = a.col1 AND b.col2 = a.col2"},
+
+        // LEFT JOIN
+        new Object[]{"SELECT * FROM a LEFT JOIN b on a.col1 = b.col2"},
+
+        new Object[]{"SELECT a.col1, SUM(CASE WHEN b.col3 IS NULL THEN 0 ELSE b.col3 END) "
+            + " FROM a LEFT JOIN b on a.col1 = b.col2 GROUP BY a.col1"},
+
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // but only 1 out of 5 rows from table A will be selected out; and all in table B will be selected.
+        // thus the final JOIN result will be 1 x 3 x 1 = 3.
+        new Object[]{"SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+            + " WHERE a.col3 >= 0 AND a.col2 = 'alice' AND b.col3 >= 0"},
+
+        // Join query with IN and Not-IN clause. Table A's side of join will return 9 rows and Table B's side will
+        // return 2 rows. Join will be only on col1=bar and since A will return 3 rows with that value and B will return
+        // 1 row, the final output will have 3 rows.
+        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+            + " WHERE a.col1 IN ('foo', 'bar', 'alice') AND b.col2 NOT IN ('foo', 'alice')"},
+
+        // Same query as above but written using OR/AND instead of IN.
+        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+            + " WHERE (a.col1 = 'foo' OR a.col1 = 'bar' OR a.col1 = 'alice') AND b.col2 != 'foo'"
+            + " AND b.col2 != 'alice'"},
+
+        // Same as above but with single argument IN clauses. Left side of the join returns 3 rows, and the right side
+        // returns 5 rows. Only key where join succeeds is col1=foo, and since table B has only 1 row with that value,
+        // the number of rows should be 3.
+        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+            + " WHERE a.col1 IN ('foo') AND b.col2 NOT IN ('')"},
+
+        // Range conditions with continuous and non-continuous range.
+        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+            + " WHERE a.col3 IN (1, 2, 3) OR (a.col3 > 10 AND a.col3 < 50)"},
+
+        new Object[]{"SELECT col1, SUM(col3) FROM a WHERE a.col3 BETWEEN 23 AND 36 "
+            + " GROUP BY col1 HAVING SUM(col3) > 10.0 AND MIN(col3) <> 123 AND MAX(col3) BETWEEN 10 AND 20"},
+
+        new Object[]{"SELECT col1, SUM(col3) FROM a WHERE (col3 > 0 AND col3 < 45) AND (col3 > 15 AND col3 < 50) "
+            + " GROUP BY col1 HAVING (SUM(col3) > 10 AND SUM(col3) < 20) AND (SUM(col3) > 30 AND SUM(col3) < 40)"},
+
+        // Projection pushdown
+        new Object[]{"SELECT a.col1, a.col3 + a.col3 FROM a WHERE a.col3 >= 0 AND a.col2 = 'alice'"},
+
+        // Inequality JOIN & partial filter pushdown
+        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col3 > b.col3"},
+
+        new Object[]{"SELECT * FROM a, b WHERE a.col1 > b.col2 AND a.col3 > b.col3"},
+
+        // Aggregation with group by
+        new Object[]{"SELECT a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 GROUP BY a.col1"},
+
+        // Aggregation with multiple group key
+        new Object[]{"SELECT a.col2, a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 GROUP BY a.col1, a.col2"},
+
+        // Aggregation without GROUP BY
+        new Object[]{"SELECT SUM(col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'alice'"},
+
+        // Aggregation with GROUP BY on a count star reference
+        new Object[]{"SELECT a.col1, COUNT(*) FROM a WHERE a.col3 >= 0 GROUP BY a.col1"},
+
+        // project in intermediate stage
+        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
+        // col1 on both are "foo", "bar", "alice", "bob", "charlie"
+        // col2 on both are "foo", "bar", "alice", "foo", "bar",
+        //   filtered at :    ^                      ^
+        // thus the final JOIN result will have 6 rows: 3 "foo" <-> "foo"; and 3 "bob" <-> "bob"
+        new Object[]{"SELECT a.col1, a.col2, a.ts, b.col1, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+            + " WHERE a.col3 >= 0 AND a.col2 = 'foo' AND b.col3 >= 0"},
+
+        // Making transform after JOIN, number of rows should be the same as JOIN result.
+        new Object[]{"SELECT a.col1, a.ts, a.col3 - b.col3 FROM a JOIN b ON a.col1 = b.col2 "
+            + " WHERE a.col3 >= 0 AND b.col3 >= 0"},
+
+        // Making transform after GROUP-BY, number of rows should be the same as GROUP-BY result.
+        new Object[]{"SELECT a.col1, a.col2, SUM(a.col3) - MIN(a.col3) FROM a"
+            + " WHERE a.col3 >= 0 GROUP BY a.col1, a.col2"},
+
+        // GROUP BY after JOIN
+        //   - optimizable transport for GROUP BY key after JOIN, using SINGLETON exchange
+        //     only 3 GROUP BY key exist because b.col2 cycles between "foo", "bar", "alice".
+        new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN b ON a.col1 = b.col2 "
+            + " WHERE a.col3 >= 0 GROUP BY a.col1"},
+        //   - non-optimizable transport for GROUP BY key after JOIN, using HASH exchange
+        //     only 2 GROUP BY key exist for b.col3.
+        new Object[]{"SELECT b.col3, SUM(a.col3) FROM a JOIN b"
+            + " on a.col1 = b.col1 AND a.col2 = b.col2 GROUP BY b.col3"},
+
+        // Sub-query
+        new Object[]{"SELECT b.col1, b.col3, i.maxVal FROM b JOIN "
+            + "  (SELECT a.col2 AS joinKey, MAX(a.col3) AS maxVal FROM a GROUP BY a.col2) AS i "
+            + "  ON b.col1 = i.joinKey"},
+
+        // Sub-query with IN clause to SEMI JOIN.
+        new Object[]{"SELECT b.col1, b.col2, SUM(b.col3) * 100 / COUNT(b.col3) FROM b WHERE b.col1 IN "
+            + " (SELECT a.col2 FROM a WHERE a.col2 != 'foo') GROUP BY b.col1, b.col2"},
+        new Object[]{"SELECT SUM(b.col3) FROM b WHERE b.col3 > (SELECT AVG(a.col3) FROM a WHERE a.col2 != 'bar')"},
+
+        // Aggregate query with HAVING clause, "foo" and "bar" occurred 6/2 times each and "alice" occurred 3/1 times
+        // numbers are cycle in (1, 42, 1, 42, 1), and (foo, bar, alice, foo, bar)
+        // - COUNT(*) < 5 matches "alice" (3 times)
+        // - COUNT(*) > 5 matches "foo" and "bar" (6 times); so both will be selected out SUM(a.col3) = (1 + 42) * 3
+        // - last condition doesn't match anything.
+        // total to 3 rows.
+        new Object[]{"SELECT a.col2, COUNT(*), MAX(a.col3), MIN(a.col3), SUM(a.col3) FROM a GROUP BY a.col2 "
+            + "HAVING COUNT(*) < 5 OR (COUNT(*) > 5 AND SUM(a.col3) >= 10)"
+            + "OR (MIN(a.col3) != 20 AND SUM(a.col3) = 100)"},
+        new Object[]{"SELECT COUNT(*) AS Count, MAX(a.col3) AS \"max\" FROM a GROUP BY a.col2 "
+            + "HAVING Count > 1 AND \"max\" < 50"},
+
+        // Order-by
+        new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON a.col1 = b.col1 ORDER BY a.col3, b.col3 DESC"},
+        new Object[]{"SELECT MAX(a.col3) FROM a GROUP BY a.col2 ORDER BY MAX(a.col3) - MIN(a.col3)"},
+
+        // Test CAST
+        //   - implicit CAST
+        new Object[]{"SELECT a.col1, a.col2, AVG(a.col3) FROM a GROUP BY a.col1, a.col2"},
+        new Object[]{"SELECT a.col1 FROM a WHERE a.col3 >= 0.5 AND a.col3 < 0.7 OR a.col3 = 42.0"},
+        new Object[]{"SELECT a.col1, SUM(a.col3) FROM a GROUP BY a.col1 "
+            + " HAVING MIN(a.col3) > 0.5 AND MIN(a.col3) <> 0.7 OR MIN(a.col3) > 30"},
+        //   - explicit CAST
+        new Object[]{"SELECT a.col1, CAST(SUM(a.col3) AS BIGINT) FROM a GROUP BY a.col1"},
+
+        // Test DISTINCT
+        //   - distinct value done via GROUP BY with empty expr aggregation list.
+        new Object[]{"SELECT a.col2, a.col3 FROM a JOIN b ON a.col1 = b.col1 "
+            + " WHERE b.col3 > 0 GROUP BY a.col2, a.col3"},
+
+        // Test optimized constant literal.
+        new Object[]{"SELECT col1 FROM a WHERE col3 > 0 AND col3 < -5"},
+        new Object[]{"SELECT COALESCE(SUM(col3), 0) FROM a WHERE col1 = 'foo' AND col1 = 'bar'"},
+        new Object[]{"SELECT SUM(CAST(col3 AS INTEGER)) FROM a HAVING MIN(col3) BETWEEN 1 AND 0"},
+        new Object[]{"SELECT col1, COUNT(col3) FROM a GROUP BY col1 HAVING SUM(col3) > 40 AND SUM(col3) < 30"},
+    };
+  }
+}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 872eec4..e1f685a 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -46,7 +46,7 @@
     Assert.assertEquals(resultRows.size(), expectedRows);
   }
 
-  @Test(dataProvider = "testDataWithSql")
+  @Test(dataProvider = "testSql")
   public void testSqlWithH2Checker(String sql)
       throws Exception {
     List<Object[]> resultRows = queryRunner(sql);
@@ -144,180 +144,6 @@
     }
   }
 
-  @DataProvider(name = "testDataWithSql")
-  private Object[][] provideTestSql() {
-    return new Object[][]{
-        // Order BY LIMIT
-        new Object[]{"SELECT * FROM b ORDER BY col1, col2 DESC LIMIT 3"},
-        new Object[]{"SELECT * FROM a ORDER BY col1, ts LIMIT 10"},
-        new Object[]{"SELECT * FROM a ORDER BY col1 LIMIT 20"},
-
-        // No match filter
-        new Object[]{"SELECT * FROM b WHERE col3 < 0.5"},
-
-        // Hybrid table
-        new Object[]{"SELECT * FROM d"},
-
-        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-        // thus the final JOIN result will be 15 x 1 = 15.
-        // Next join with table C which has (5 on server1 and 10 on server2), since data is identical. each of the row
-        // of the A JOIN B will have identical value of col3 as table C.col3 has. Since the values are cycling between
-        // (1, 42, 1, 42, 1). we will have 9 1s, and 6 42s, total result count will be 9 * 9 + 6 * 6 = 117
-        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col3 = c.col3"},
-        // Reverse the order of join condition and join table order.
-        new Object[]{"SELECT * FROM a JOIN b ON b.col1 = a.col1 JOIN c ON c.col3 = a.col3"},
-
-        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-        // thus the final JOIN result will be 15 x 1 = 15.
-        new Object[]{"SELECT * FROM a JOIN b on a.col1 = b.col1"},
-
-        // Query with function in JOIN keys, table A and B are both (1, 42, 1, 42, 1), with table A cycling 3 times.
-        // Because:
-        //   - MOD(a.col3, 2) will have 6 (42)s equal to 0 and 9 (1)s equals to 1
-        //   - MOD(b.col3, 3) will have 2 (42)s equal to 0 and 3 (1)s equals to 1;
-        // final results are 6 * 2 + 9 * 3 = 39 rows
-        new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON MOD(a.col3, 2) = MOD(b.col3, 3)"},
-
-        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-        // thus the final JOIN result will be 15 x 1 = 15.
-        new Object[]{"SELECT * FROM a JOIN b on a.col1 = b.col1 AND a.col2 = b.col2"},
-        // Reverse the order of join condition and join table order.
-        new Object[]{"SELECT * FROM a JOIN b on b.col1 = a.col1 AND b.col2 = a.col2"},
-
-        // LEFT JOIN
-        new Object[]{"SELECT * FROM a LEFT JOIN b on a.col1 = b.col2"},
-
-        new Object[]{"SELECT a.col1, SUM(CASE WHEN b.col3 IS NULL THEN 0 ELSE b.col3 END) "
-            + " FROM a LEFT JOIN b on a.col1 = b.col2 GROUP BY a.col1"},
-
-        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-        // but only 1 out of 5 rows from table A will be selected out; and all in table B will be selected.
-        // thus the final JOIN result will be 1 x 3 x 1 = 3.
-        new Object[]{"SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
-            + " WHERE a.col3 >= 0 AND a.col2 = 'alice' AND b.col3 >= 0"},
-
-        // Join query with IN and Not-IN clause. Table A's side of join will return 9 rows and Table B's side will
-        // return 2 rows. Join will be only on col1=bar and since A will return 3 rows with that value and B will return
-        // 1 row, the final output will have 3 rows.
-        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
-            + " WHERE a.col1 IN ('foo', 'bar', 'alice') AND b.col2 NOT IN ('foo', 'alice')"},
-
-        // Same query as above but written using OR/AND instead of IN.
-        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
-            + " WHERE (a.col1 = 'foo' OR a.col1 = 'bar' OR a.col1 = 'alice') AND b.col2 != 'foo'"
-            + " AND b.col2 != 'alice'"},
-
-        // Same as above but with single argument IN clauses. Left side of the join returns 3 rows, and the right side
-        // returns 5 rows. Only key where join succeeds is col1=foo, and since table B has only 1 row with that value,
-        // the number of rows should be 3.
-        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
-            + " WHERE a.col1 IN ('foo') AND b.col2 NOT IN ('')"},
-
-        // Range conditions with continuous and non-continuous range.
-        new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
-            + " WHERE a.col3 IN (1, 2, 3) OR (a.col3 > 10 AND a.col3 < 50)"},
-
-        new Object[]{"SELECT col1, SUM(col3) FROM a WHERE a.col3 BETWEEN 23 AND 36 "
-            + " GROUP BY col1 HAVING SUM(col3) > 10.0 AND MIN(col3) <> 123 AND MAX(col3) BETWEEN 10 AND 20"},
-
-        new Object[]{"SELECT col1, SUM(col3) FROM a WHERE (col3 > 0 AND col3 < 45) AND (col3 > 15 AND col3 < 50) "
-            + " GROUP BY col1 HAVING (SUM(col3) > 10 AND SUM(col3) < 20) AND (SUM(col3) > 30 AND SUM(col3) < 40)"},
-
-        // Projection pushdown
-        new Object[]{"SELECT a.col1, a.col3 + a.col3 FROM a WHERE a.col3 >= 0 AND a.col2 = 'alice'"},
-
-        // Inequality JOIN & partial filter pushdown
-        new Object[]{"SELECT * FROM a JOIN b ON a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col3 > b.col3"},
-
-        new Object[]{"SELECT * FROM a, b WHERE a.col1 > b.col2 AND a.col3 > b.col3"},
-
-        // Aggregation with group by
-        new Object[]{"SELECT a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 GROUP BY a.col1"},
-
-        // Aggregation with multiple group key
-        new Object[]{"SELECT a.col2, a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 GROUP BY a.col1, a.col2"},
-
-        // Aggregation without GROUP BY
-        new Object[]{"SELECT SUM(col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'alice'"},
-
-        // Aggregation with GROUP BY on a count star reference
-        new Object[]{"SELECT a.col1, COUNT(*) FROM a WHERE a.col3 >= 0 GROUP BY a.col1"},
-
-        // project in intermediate stage
-        // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
-        // col1 on both are "foo", "bar", "alice", "bob", "charlie"
-        // col2 on both are "foo", "bar", "alice", "foo", "bar",
-        //   filtered at :    ^                      ^
-        // thus the final JOIN result will have 6 rows: 3 "foo" <-> "foo"; and 3 "bob" <-> "bob"
-        new Object[]{"SELECT a.col1, a.col2, a.ts, b.col1, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
-            + " WHERE a.col3 >= 0 AND a.col2 = 'foo' AND b.col3 >= 0"},
-
-        // Making transform after JOIN, number of rows should be the same as JOIN result.
-        new Object[]{"SELECT a.col1, a.ts, a.col3 - b.col3 FROM a JOIN b ON a.col1 = b.col2 "
-            + " WHERE a.col3 >= 0 AND b.col3 >= 0"},
-
-        // Making transform after GROUP-BY, number of rows should be the same as GROUP-BY result.
-        new Object[]{"SELECT a.col1, a.col2, SUM(a.col3) - MIN(a.col3) FROM a"
-            + " WHERE a.col3 >= 0 GROUP BY a.col1, a.col2"},
-
-        // GROUP BY after JOIN
-        //   - optimizable transport for GROUP BY key after JOIN, using SINGLETON exchange
-        //     only 3 GROUP BY key exist because b.col2 cycles between "foo", "bar", "alice".
-        new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN b ON a.col1 = b.col2 "
-            + " WHERE a.col3 >= 0 GROUP BY a.col1"},
-        //   - non-optimizable transport for GROUP BY key after JOIN, using HASH exchange
-        //     only 2 GROUP BY key exist for b.col3.
-        new Object[]{"SELECT b.col3, SUM(a.col3) FROM a JOIN b"
-            + " on a.col1 = b.col1 AND a.col2 = b.col2 GROUP BY b.col3"},
-
-        // Sub-query
-        new Object[]{"SELECT b.col1, b.col3, i.maxVal FROM b JOIN "
-            + "  (SELECT a.col2 AS joinKey, MAX(a.col3) AS maxVal FROM a GROUP BY a.col2) AS i "
-            + "  ON b.col1 = i.joinKey"},
-
-        // Sub-query with IN clause to SEMI JOIN.
-        new Object[]{"SELECT b.col1, b.col2, SUM(b.col3) * 100 / COUNT(b.col3) FROM b WHERE b.col1 IN "
-            + " (SELECT a.col2 FROM a WHERE a.col2 != 'foo') GROUP BY b.col1, b.col2"},
-        new Object[]{"SELECT SUM(b.col3) FROM b WHERE b.col3 > (SELECT AVG(a.col3) FROM a WHERE a.col2 != 'bar')"},
-
-        // Aggregate query with HAVING clause, "foo" and "bar" occurred 6/2 times each and "alice" occurred 3/1 times
-        // numbers are cycle in (1, 42, 1, 42, 1), and (foo, bar, alice, foo, bar)
-        // - COUNT(*) < 5 matches "alice" (3 times)
-        // - COUNT(*) > 5 matches "foo" and "bar" (6 times); so both will be selected out SUM(a.col3) = (1 + 42) * 3
-        // - last condition doesn't match anything.
-        // total to 3 rows.
-        new Object[]{"SELECT a.col2, COUNT(*), MAX(a.col3), MIN(a.col3), SUM(a.col3) FROM a GROUP BY a.col2 "
-            + "HAVING COUNT(*) < 5 OR (COUNT(*) > 5 AND SUM(a.col3) >= 10)"
-            + "OR (MIN(a.col3) != 20 AND SUM(a.col3) = 100)"},
-        new Object[]{"SELECT COUNT(*) AS Count, MAX(a.col3) AS \"max\" FROM a GROUP BY a.col2 "
-            + "HAVING Count > 1 AND \"max\" < 50"},
-
-        // Order-by
-        new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON a.col1 = b.col1 ORDER BY a.col3, b.col3 DESC"},
-        new Object[]{"SELECT MAX(a.col3) FROM a GROUP BY a.col2 ORDER BY MAX(a.col3) - MIN(a.col3)"},
-
-        // Test CAST
-        //   - implicit CAST
-        new Object[]{"SELECT a.col1, a.col2, AVG(a.col3) FROM a GROUP BY a.col1, a.col2"},
-        new Object[]{"SELECT a.col1 FROM a WHERE a.col3 >= 0.5 AND a.col3 < 0.7 OR a.col3 = 42.0"},
-        new Object[]{"SELECT a.col1, SUM(a.col3) FROM a GROUP BY a.col1 "
-            + " HAVING MIN(a.col3) > 0.5 AND MIN(a.col3) <> 0.7 OR MIN(a.col3) > 30"},
-        //   - explicit CAST
-        new Object[]{"SELECT a.col1, CAST(SUM(a.col3) AS BIGINT) FROM a GROUP BY a.col1"},
-
-        // Test DISTINCT
-        //   - distinct value done via GROUP BY with empty expr aggregation list.
-        new Object[]{"SELECT a.col2, a.col3 FROM a JOIN b ON a.col1 = b.col1 "
-            + " WHERE b.col3 > 0 GROUP BY a.col2, a.col3"},
-
-        // Test optimized constant literal.
-        new Object[]{"SELECT col1 FROM a WHERE col3 > 0 AND col3 < -5"},
-        new Object[]{"SELECT COALESCE(SUM(col3), 0) FROM a WHERE col1 = 'foo' AND col1 = 'bar'"},
-        new Object[]{"SELECT SUM(CAST(col3 AS INTEGER)) FROM a HAVING MIN(col3) BETWEEN 1 AND 0"},
-        new Object[]{"SELECT col1, COUNT(col3) FROM a GROUP BY col1 HAVING SUM(col3) > 40 AND SUM(col3) < 30"},
-    };
-  }
-
   @DataProvider(name = "testDataWithSqlToFinalRowCount")
   private Object[][] provideTestSqlAndRowCount() {
     return new Object[][] {
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
index 702e7b4..9ad1723 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTestBase.java
@@ -36,6 +36,7 @@
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestUtils;
 import org.apache.pinot.query.QueryServerEnclosure;
+import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.mailbox.GrpcMailboxService;
 import org.apache.pinot.query.routing.WorkerInstance;
 import org.apache.pinot.query.service.QueryConfig;
@@ -50,7 +51,7 @@
 
 
 
-public class QueryRunnerTestBase {
+public class QueryRunnerTestBase extends QueryTestSet {
   private static final File INDEX_DIR_S1_A = new File(FileUtils.getTempDirectory(), "QueryRunnerTest_server1_tableA");
   private static final File INDEX_DIR_S1_B = new File(FileUtils.getTempDirectory(), "QueryRunnerTest_server1_tableB");
   private static final File INDEX_DIR_S1_C = new File(FileUtils.getTempDirectory(), "QueryRunnerTest_server1_tableC");
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryDispatcherTest.java
index 9057d1a..34e52f5 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryDispatcherTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryDispatcherTest.java
@@ -25,6 +25,7 @@
 import java.util.Random;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestUtils;
+import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.planner.PlannerUtils;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.runtime.QueryRunner;
@@ -32,11 +33,10 @@
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
-import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 
-public class QueryDispatcherTest {
+public class QueryDispatcherTest extends QueryTestSet {
   private static final Random RANDOM_REQUEST_ID_GEN = new Random();
   private static final int QUERY_SERVER_COUNT = 2;
   private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
@@ -70,23 +70,13 @@
     }
   }
 
-  @Test(dataProvider = "testDataWithSqlToCompiledAsWorkerRequest")
+  @Test(dataProvider = "testSql")
   public void testQueryDispatcherCanSendCorrectPayload(String sql)
       throws Exception {
     QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
     QueryDispatcher dispatcher = new QueryDispatcher();
     int reducerStageId = dispatcher.submit(RANDOM_REQUEST_ID_GEN.nextLong(), queryPlan);
     Assert.assertTrue(PlannerUtils.isRootStage(reducerStageId));
-  }
-
-  @DataProvider(name = "testDataWithSqlToCompiledAsWorkerRequest")
-  private Object[][] provideTestSqlToCompiledToWorkerRequest() {
-    return new Object[][] {
-        new Object[]{"SELECT * FROM b"},
-        new Object[]{"SELECT * FROM a"},
-        new Object[]{"SELECT * FROM a JOIN b ON a.col3 = b.col3"},
-        new Object[]{"SELECT a.col1, a.ts, c.col2, c.col3 FROM a JOIN c ON a.col1 = c.col2 "
-            + " WHERE (a.col3 >= 0 OR a.col2 = 'foo') AND c.col3 >= 0"},
-    };
+    dispatcher.shutdown();
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
index 245f78e..7202ab8 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/QueryServerTest.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.service;
 
 import com.google.common.collect.Lists;
+import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -31,6 +32,7 @@
 import org.apache.pinot.core.transport.ServerInstance;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestUtils;
+import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.planner.QueryPlan;
 import org.apache.pinot.query.planner.StageMetadata;
 import org.apache.pinot.query.planner.stage.StageNode;
@@ -42,13 +44,12 @@
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
-import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.mockito.ArgumentMatchers.any;
 
 
-public class QueryServerTest {
+public class QueryServerTest extends QueryTestSet {
   private static final Random RANDOM_REQUEST_ID_GEN = new Random();
   private static final int QUERY_SERVER_COUNT = 2;
   private final Map<Integer, QueryServer> _queryServerMap = new HashMap<>();
@@ -88,7 +89,7 @@
   }
 
   @SuppressWarnings("unchecked")
-  @Test(dataProvider = "testDataWithSqlToCompiledAsWorkerRequest")
+  @Test(dataProvider = "testSql")
   public void testWorkerAcceptsWorkerRequestCorrect(String sql)
       throws Exception {
     QueryPlan queryPlan = _queryEnvironment.planQuery(sql);
@@ -126,17 +127,6 @@
     }
   }
 
-  @DataProvider(name = "testDataWithSqlToCompiledAsWorkerRequest")
-  private Object[][] provideTestSqlToCompiledToWorkerRequest() {
-    return new Object[][] {
-        new Object[]{"SELECT * FROM b"},
-        new Object[]{"SELECT * FROM a"},
-        new Object[]{"SELECT * FROM a JOIN b ON a.col3 = b.col3"},
-        new Object[]{"SELECT a.col1, a.ts, c.col2, c.col3 FROM a JOIN c ON a.col1 = c.col2 "
-            + " WHERE (a.col3 >= 0 OR a.col2 = 'foo') AND c.col3 >= 0"},
-    };
-  }
-
   private static boolean isMetadataMapsEqual(StageMetadata left, StageMetadata right) {
     return left.getServerInstances().equals(right.getServerInstances())
         && left.getServerInstanceToSegmentsMap().equals(right.getServerInstanceToSegmentsMap())
@@ -163,11 +153,12 @@
   private void submitRequest(Worker.QueryRequest queryRequest) {
     String host = queryRequest.getMetadataMap().get("SERVER_INSTANCE_HOST");
     int port = Integer.parseInt(queryRequest.getMetadataMap().get("SERVER_INSTANCE_PORT"));
-    PinotQueryWorkerGrpc.PinotQueryWorkerBlockingStub stub =
-        PinotQueryWorkerGrpc.newBlockingStub(ManagedChannelBuilder.forAddress(host, port).usePlaintext().build());
+    ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
+    PinotQueryWorkerGrpc.PinotQueryWorkerBlockingStub stub = PinotQueryWorkerGrpc.newBlockingStub(channel);
     Worker.QueryResponse resp = stub.submit(queryRequest);
     // TODO: validate meaningful return value
     Assert.assertNotNull(resp.getMetadataMap().get("OK"));
+    channel.shutdown();
   }
 
   private Worker.QueryRequest getQueryRequest(QueryPlan queryPlan, int stageId) {