[BEAM-11839] Improve DeferredFrameTest._run_test (#14552)

* Improve _run_test

* DeferredFrameTest, check that nonparallel operations raise

* Move expect_error logic to a separate method

* Add docstrings

* Address review comments
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index f819757..3a8010a 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -37,71 +37,109 @@
 })
 
 
+def _get_deferred_args(*args):
+  return [
+      frame_base.DeferredFrame.wrap(
+          expressions.ConstantExpression(arg, arg[0:0])) for arg in args
+  ]
+
+
 class DeferredFrameTest(unittest.TestCase):
-  def _run_test(self, func, *args, distributed=True, expect_error=False):
-    deferred_args = [
-        frame_base.DeferredFrame.wrap(
-            expressions.ConstantExpression(arg, arg[0:0])) for arg in args
-    ]
+  def _run_error_test(self, func, *args):
+    """Verify that func(*args) raises the same exception in pandas and in Beam.
+
+    Note that for Beam this only checks for exceptions that are raised during
+    expression generation (i.e. construction time). Execution time exceptions
+    are not helpful."""
+    deferred_args = _get_deferred_args(*args)
+
+    # Get expected error
     try:
       expected = func(*args)
     except Exception as e:
-      if not expect_error:
-        raise
-      expected = e
+      expected_error = e
     else:
-      if expect_error:
-        raise AssertionError(
-            "Expected an error but computing expected result successfully "
-            f"returned: {expected}")
+      raise AssertionError(
+          "Expected an error, but executing with pandas successfully "
+          f"returned:\n{expected}")
 
-    session_type = (
-        expressions.PartitioningSession if distributed else expressions.Session)
+    # Get actual error
     try:
-      actual = session_type({}).evaluate(func(*deferred_args)._expr)
+      _ = func(*deferred_args)._expr
     except Exception as e:
-      if not expect_error:
-        raise
       actual = e
     else:
-      if expect_error:
-        raise AssertionError(
-            "Expected an error:\n{expected}\nbut successfully "
-            f"returned:\n{actual}")
+      raise AssertionError(
+          f"Expected an error:\n{expected_error}\nbut Beam successfully "
+          "generated an expression.")
 
-    if expect_error:
-      if not isinstance(actual,
-                        type(expected)) or not str(actual) == str(expected):
-        raise AssertionError(
-            f'Expected {expected!r} to be raised, but got {actual!r}'
-        ) from actual
+    # Verify
+    if (not isinstance(actual, type(expected_error)) or
+        not str(actual) == str(expected_error)):
+      raise AssertionError(
+          f'Expected {expected_error!r} to be raised, but got {actual!r}'
+      ) from actual
+
+  def _run_test(self, func, *args, distributed=True, nonparallel=False):
+    """Verify that func(*args) produces the same result in pandas and in Beam.
+
+    Args:
+        distributed (bool): Whether or not to use PartitioningSession to
+            simulate parallel execution.
+        nonparallel (bool): Whether or not this function contains a
+            non-parallelizable operation. If True, the expression will be
+            generated twice, once outside of an allow_non_parallel_operations
+            block (to verify NonParallelOperation is raised), and again inside
+            of an allow_non_parallel_operations block to actually generate an
+            expression to verify."""
+    # Compute expected value
+    expected = func(*args)
+
+    # Compute actual value
+    deferred_args = _get_deferred_args(*args)
+    if nonparallel:
+      # First run outside a nonparallel block to confirm this raises as expected
+      with self.assertRaises(expressions.NonParallelOperation):
+        _ = func(*deferred_args)
+
+      # Re-run in an allow non parallel block to get an expression to verify
+      with beam.dataframe.allow_non_parallel_operations():
+        expr = func(*deferred_args)._expr
     else:
-      if isinstance(expected, pd.core.generic.NDFrame):
-        if distributed:
-          if expected.index.is_unique:
-            expected = expected.sort_index()
-            actual = actual.sort_index()
-          else:
-            expected = expected.sort_values(list(expected.columns))
-            actual = actual.sort_values(list(actual.columns))
+      expr = func(*deferred_args)._expr
 
-        if isinstance(expected, pd.Series):
-          pd.testing.assert_series_equal(expected, actual)
-        elif isinstance(expected, pd.DataFrame):
-          pd.testing.assert_frame_equal(expected, actual)
+    # Compute the result of the generated expression
+    session_type = (
+        expressions.PartitioningSession if distributed else expressions.Session)
+
+    actual = session_type({}).evaluate(expr)
+
+    # Verify
+    if isinstance(expected, pd.core.generic.NDFrame):
+      if distributed:
+        if expected.index.is_unique:
+          expected = expected.sort_index()
+          actual = actual.sort_index()
         else:
-          raise ValueError(
-              f"Expected value is a {type(expected)},"
-              "not a Series or DataFrame.")
+          expected = expected.sort_values(list(expected.columns))
+          actual = actual.sort_values(list(actual.columns))
+
+      if isinstance(expected, pd.Series):
+        pd.testing.assert_series_equal(expected, actual)
+      elif isinstance(expected, pd.DataFrame):
+        pd.testing.assert_frame_equal(expected, actual)
       else:
-        # Expectation is not a pandas object
-        if isinstance(expected, float):
-          cmp = lambda x: np.isclose(expected, x)
-        else:
-          cmp = expected.__eq__
-        self.assertTrue(
-            cmp(actual),
-            'Expected:\n\n%r\n\nActual:\n\n%r' % (expected, actual))
+        raise ValueError(
+            f"Expected value is a {type(expected)},"
+            "not a Series or DataFrame.")
+    else:
+      # Expectation is not a pandas object
+      if isinstance(expected, float):
+        cmp = lambda x: np.isclose(expected, x)
+      else:
+        cmp = expected.__eq__
+      self.assertTrue(
+          cmp(actual), 'Expected:\n\n%r\n\nActual:\n\n%r' % (expected, actual))
 
   def test_series_arithmetic(self):
     a = pd.Series([1, 2, 3])
@@ -246,14 +284,10 @@
     self._run_test(lambda df: df.groupby('group').bar.sum(), df)
     self._run_test(lambda df: df.groupby('group')['foo'].sum(), df)
     self._run_test(lambda df: df.groupby('group')['baz'].sum(), df)
-    self._run_test(
-        lambda df: df.groupby('group')[['bar', 'baz']].bar.sum(),
-        df,
-        expect_error=True)
-    self._run_test(
-        lambda df: df.groupby('group')[['bat']].sum(), df, expect_error=True)
-    self._run_test(
-        lambda df: df.groupby('group').bat.sum(), df, expect_error=True)
+    self._run_error_test(
+        lambda df: df.groupby('group')[['bar', 'baz']].bar.sum(), df)
+    self._run_error_test(lambda df: df.groupby('group')[['bat']].sum(), df)
+    self._run_error_test(lambda df: df.groupby('group').bat.sum(), df)
 
     self._run_test(lambda df: df.groupby('group').median(), df)
     self._run_test(lambda df: df.groupby('group').foo.median(), df)
@@ -266,26 +300,19 @@
     df = GROUPBY_DF
 
     # non-existent projection column
-    self._run_test(
-        lambda df: df.groupby('group')[['bar', 'baz']].bar.median(),
-        df,
-        expect_error=True)
-    self._run_test(
-        lambda df: df.groupby('group')[['bad']].median(), df, expect_error=True)
+    self._run_error_test(
+        lambda df: df.groupby('group')[['bar', 'baz']].bar.median(), df)
+    self._run_error_test(lambda df: df.groupby('group')[['bad']].median(), df)
 
-    self._run_test(
-        lambda df: df.groupby('group').bad.median(), df, expect_error=True)
+    self._run_error_test(lambda df: df.groupby('group').bad.median(), df)
 
   def test_groupby_errors_non_existent_label(self):
     df = GROUPBY_DF
 
     # non-existent grouping label
-    self._run_test(
-        lambda df: df.groupby(['really_bad', 'foo', 'bad']).foo.sum(),
-        df,
-        expect_error=True)
-    self._run_test(
-        lambda df: df.groupby('bad').foo.sum(), df, expect_error=True)
+    self._run_error_test(
+        lambda df: df.groupby(['really_bad', 'foo', 'bad']).foo.sum(), df)
+    self._run_error_test(lambda df: df.groupby('bad').foo.sum(), df)
 
   def test_groupby_callable(self):
     df = GROUPBY_DF
@@ -307,11 +334,9 @@
     self._run_test(lambda df: df.set_index(['index1', 'index2'], drop=True), df)
     self._run_test(lambda df: df.set_index('values'), df)
 
-    self._run_test(lambda df: df.set_index('bad'), df, expect_error=True)
-    self._run_test(
-        lambda df: df.set_index(['index2', 'bad', 'really_bad']),
-        df,
-        expect_error=True)
+    self._run_error_test(lambda df: df.set_index('bad'), df)
+    self._run_error_test(
+        lambda df: df.set_index(['index2', 'bad', 'really_bad']), df)
 
   def test_series_drop_ignore_errors(self):
     midx = pd.MultiIndex(
@@ -397,22 +422,21 @@
     df2 = pd.DataFrame({
         'rkey': ['foo', 'bar', 'baz', 'foo'], 'value': [5, 6, 7, 8]
     })
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, left_on='lkey', right_on='rkey').rename(
-              index=lambda x: '*'),
-          df1,
-          df2)
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(
-              df2,
-              left_on='lkey',
-              right_on='rkey',
-              suffixes=('_left', '_right')).rename(index=lambda x: '*'),
-          df1,
-          df2)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, left_on='lkey', right_on='rkey').rename(
+            index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(
+            df2, left_on='lkey', right_on='rkey', suffixes=('_left', '_right')).
+        rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
 
   def test_merge_left_join(self):
     # This is from the pandas doctests, but fails due to re-indexing being
@@ -420,12 +444,12 @@
     df1 = pd.DataFrame({'a': ['foo', 'bar'], 'b': [1, 2]})
     df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]})
 
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'),
-          df1,
-          df2)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
 
   def test_merge_on_index(self):
     # This is from the pandas doctests, but fails due to re-indexing being
@@ -436,12 +460,12 @@
     df2 = pd.DataFrame({
         'rkey': ['foo', 'bar', 'baz', 'foo'], 'value': [5, 6, 7, 8]
     }).set_index('rkey')
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, left_index=True, right_index=True),
-          df1,
-          df2)
+
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, left_index=True, right_index=True),
+        df1,
+        df2)
 
   def test_merge_same_key(self):
     df1 = pd.DataFrame({
@@ -450,55 +474,58 @@
     df2 = pd.DataFrame({
         'key': ['foo', 'bar', 'baz', 'foo'], 'value': [5, 6, 7, 8]
     })
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, on='key').rename(index=lambda x: '*'),
-          df1,
-          df2)
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, on='key', suffixes=('_left', '_right')).rename(
-              index=lambda x: '*'),
-          df1,
-          df2)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, on='key').rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, on='key', suffixes=('_left', '_right')).rename(
+            index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
 
   def test_merge_same_key_doctest(self):
     df1 = pd.DataFrame({'a': ['foo', 'bar'], 'b': [1, 2]})
     df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]})
 
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'),
-          df1,
-          df2)
-      # Test without specifying 'on'
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, how='left').rename(index=lambda x: '*'),
-          df1,
-          df2)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, how='left', on='a').rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
+    # Test without specifying 'on'
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, how='left').rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
 
   def test_merge_same_key_suffix_collision(self):
     df1 = pd.DataFrame({'a': ['foo', 'bar'], 'b': [1, 2], 'a_lsuffix': [5, 6]})
     df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4], 'a_rsuffix': [7, 8]})
 
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(
-              df2, how='left', on='a', suffixes=('_lsuffix', '_rsuffix')).
-          rename(index=lambda x: '*'),
-          df1,
-          df2)
-      # Test without specifying 'on'
-      self._run_test(
-          lambda df1,
-          df2: df1.merge(df2, how='left', suffixes=('_lsuffix', '_rsuffix')).
-          rename(index=lambda x: '*'),
-          df1,
-          df2)
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(
+            df2, how='left', on='a', suffixes=('_lsuffix', '_rsuffix')).rename(
+                index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
+    # Test without specifying 'on'
+    self._run_test(
+        lambda df1,
+        df2: df1.merge(df2, how='left', suffixes=('_lsuffix', '_rsuffix')).
+        rename(index=lambda x: '*'),
+        df1,
+        df2,
+        nonparallel=True)
 
   def test_series_getitem(self):
     s = pd.Series([x**2 for x in range(10)])
@@ -550,10 +577,9 @@
     s = pd.Series(list(range(16)))
     self._run_test(lambda s: s.agg('sum'), s)
     self._run_test(lambda s: s.agg(['sum']), s)
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(lambda s: s.agg(['sum', 'mean']), s)
-      self._run_test(lambda s: s.agg(['mean']), s)
-      self._run_test(lambda s: s.agg('mean'), s)
+    self._run_test(lambda s: s.agg(['sum', 'mean']), s, nonparallel=True)
+    self._run_test(lambda s: s.agg(['mean']), s, nonparallel=True)
+    self._run_test(lambda s: s.agg('mean'), s, nonparallel=True)
 
   def test_append_sort(self):
     # yapf: disable
@@ -573,12 +599,20 @@
   def test_dataframe_agg(self):
     df = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [2, 3, 5, 7]})
     self._run_test(lambda df: df.agg('sum'), df)
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(lambda df: df.agg(['sum', 'mean']), df)
-      self._run_test(lambda df: df.agg({'A': 'sum', 'B': 'sum'}), df)
-      self._run_test(lambda df: df.agg({'A': 'sum', 'B': 'mean'}), df)
-      self._run_test(lambda df: df.agg({'A': ['sum', 'mean']}), df)
-      self._run_test(lambda df: df.agg({'A': ['sum', 'mean'], 'B': 'min'}), df)
+    self._run_test(lambda df: df.agg(['sum', 'mean']), df, nonparallel=True)
+    self._run_test(lambda df: df.agg({'A': 'sum', 'B': 'sum'}), df)
+    self._run_test(
+        lambda df: df.agg({
+            'A': 'sum', 'B': 'mean'
+        }), df, nonparallel=True)
+    self._run_test(
+        lambda df: df.agg({'A': ['sum', 'mean']}), df, nonparallel=True)
+    self._run_test(
+        lambda df: df.agg({
+            'A': ['sum', 'mean'], 'B': 'min'
+        }),
+        df,
+        nonparallel=True)
 
   def test_smallest_largest(self):
     df = pd.DataFrame({'A': [1, 1, 2, 2], 'B': [2, 3, 5, 7]})
@@ -622,9 +656,8 @@
     df = df.set_index('B')
     # TODO(BEAM-11190): These aggregations can be done in index partitions, but
     # it will require a little more complex logic
-    with beam.dataframe.allow_non_parallel_operations():
-      self._run_test(lambda df: df.groupby(level=0).sum(), df)
-      self._run_test(lambda df: df.groupby(level=0).mean(), df)
+    self._run_test(lambda df: df.groupby(level=0).sum(), df, nonparallel=True)
+    self._run_test(lambda df: df.groupby(level=0).mean(), df, nonparallel=True)
 
   def test_dataframe_eval_query(self):
     df = pd.DataFrame(np.random.randn(20, 3), columns=['a', 'b', 'c'])