[BEAM-12388] Add caching to deferred dataframes
This adds caching to deferred DataFrames when using the InteractiveRunner.
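
The core of the change is ExpressionCache, which rewrites a deferred
DataFrame's expression tree before it is converted to a PCollection,
replacing sub-expressions that were computed in earlier runs with
placeholders that read from the cache. A minimal sketch of the flow (`df`
stands in for any watched deferred DataFrame):

    from apache_beam.runners.interactive.caching.expression_cache import (
        ExpressionCache)

    cache = ExpressionCache()  # links itself to the existing caches
    # Prune already-computed sub-trees; they are read from cache instead.
    replaced = cache.replace_with_cached(df._expr)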
diff --git a/sdks/python/apache_beam/runners/interactive/caching/__init__.py b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
index 97b1be9..cce3aca 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/__init__.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
@@ -14,4 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
new file mode 100644
index 0000000..5b1b9ef
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional
+
+import apache_beam as beam
+from apache_beam.dataframe import convert
+from apache_beam.dataframe import expressions
+
+
+class ExpressionCache(object):
+ """Utility class for caching deferred DataFrames expressions.
+
+  This cache is currently a lightweight wrapper around the
+  TO_PCOLLECTION_CACHE in the apache_beam.dataframe.convert module and the
+  computed_pcollections in the interactive module.
+
+ Example::
+
+ df : beam.dataframe.DeferredDataFrame = ...
+ ...
+ cache = ExpressionCache()
+ cache.replace_with_cached(df._expr)
+
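+  `replace_with_cached` mutates the expression tree in place and returns the
+  replaced sub-expressions as a dict keyed by expression id.
+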
+  Constructing the cache automatically links it to the existing caches. Once
+  created, it can be used to modify an existing deferred DataFrame expression
+  tree, replacing nodes with their previously computed PCollections.
+
+  This object can be created and destroyed at any time. It holds no state of
+  its own; its only side effect is modifying the given expression.
+ """
+ def __init__(self, pcollection_cache=None, computed_cache=None):
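+    # Import here (rather than at module level) to avoid a circular
+    # dependency with the interactive environment module.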
+ from apache_beam.runners.interactive import interactive_environment as ie
+
+ self._pcollection_cache = (
+ convert.TO_PCOLLECTION_CACHE
+ if pcollection_cache is None else pcollection_cache)
+ self._computed_cache = (
+ ie.current_env().computed_pcollections
+ if computed_cache is None else computed_cache)
+
+ def replace_with_cached(
+ self, expr: expressions.Expression) -> Dict[str, expressions.Expression]:
+ """Replaces any previously computed expressions with PlaceholderExpressions.
+
+ This is used to correctly read any expressions that were cached in previous
+ runs. This enables the InteractiveRunner to prune off old calculations from
+ the expression tree.
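+
+    Returns:
+      A dict mapping the ids of the replaced sub-expressions to their new
+      PlaceholderExpressions.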
+ """
+
+ replaced_inputs: Dict[str, expressions.Expression] = {}
+ self._replace_with_cached_recur(expr, replaced_inputs)
+ return replaced_inputs
+
+ def _replace_with_cached_recur(
+ self,
+ expr: expressions.Expression,
+ replaced_inputs: Dict[str, expressions.Expression]) -> None:
+ """Recursive call for `replace_with_cached`.
+
+ Recurses through the expression tree and replaces any cached inputs with
+ `PlaceholderExpression`s.
+ """
+
+ final_inputs = []
+
+ for input in expr.args():
+ pc = self._get_cached(input)
+
+      # Only read from the cache when the PCollection has been fully
+      # computed, so that no partial results are ever used.
+ if self._is_computed(pc):
+
+        # Reuse previously created placeholders so that the same cached value
+        # isn't wrapped in multiple distinct expressions.
+ if input._id in replaced_inputs:
+ cached = replaced_inputs[input._id]
+ else:
+ cached = expressions.PlaceholderExpression(
+ input.proxy(), self._pcollection_cache[input._id])
+
+ replaced_inputs[input._id] = cached
+ final_inputs.append(cached)
+ else:
+ final_inputs.append(input)
+ self._replace_with_cached_recur(input, replaced_inputs)
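+
+    # Swap the expression's inputs in place so that later conversion of the
+    # tree (e.g. via to_pcollection) sees the pruned, cache-aware graph.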
+ expr._args = tuple(final_inputs)
+
+ def _get_cached(self,
+ expr: expressions.Expression) -> Optional[beam.PCollection]:
+ """Returns the PCollection associated with the expression."""
+ return self._pcollection_cache.get(expr._id, None)
+
+ def _is_computed(self, pc: beam.PCollection) -> bool:
+ """Returns True if the PCollection has been run and computed."""
+ return pc is not None and pc in self._computed_cache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
new file mode 100644
index 0000000..c6e46f3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.dataframe import expressions
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
+
+
+class ExpressionCacheTest(unittest.TestCase):
+ def setUp(self):
+ self._pcollection_cache = {}
+ self._computed_cache = set()
+ self._pipeline = beam.Pipeline()
+ self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)
+
+ def create_trace(self, expr):
+ trace = [expr]
+ for input in expr.args():
+ trace += self.create_trace(input)
+ return trace
+
+ def mock_cache(self, expr):
+ pcoll = beam.PCollection(self._pipeline)
+ self._pcollection_cache[expr._id] = pcoll
+ self._computed_cache.add(pcoll)
+
+ def assertTraceTypes(self, expr, expected):
+ actual_types = [type(e).__name__ for e in self.create_trace(expr)]
+ expected_types = [e.__name__ for e in expected]
+ self.assertListEqual(actual_types, expected_types)
+
+ def test_only_replaces_cached(self):
+ in_expr = expressions.ConstantExpression(0)
+ comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])
+
+ # Expect that no replacement of expressions is performed.
+ expected_trace = [
+ expressions.ComputedExpression, expressions.ConstantExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+
+ self.cache.replace_with_cached(comp_expr)
+
+ self.assertTraceTypes(comp_expr, expected_trace)
+
+ # Now "cache" the expression and assert that the cached expression was
+ # replaced with a placeholder.
+ self.mock_cache(in_expr)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ expected_trace = [
+ expressions.ComputedExpression, expressions.PlaceholderExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertIn(in_expr._id, replaced)
+
+ def test_only_replaces_inputs(self):
+ arg_0_expr = expressions.ConstantExpression(0)
+ ident_val = expressions.ComputedExpression(
+ 'ident', lambda x: x, [arg_0_expr])
+
+ arg_1_expr = expressions.ConstantExpression(1)
+ comp_expr = expressions.ComputedExpression(
+ 'add', lambda x, y: x + y, [ident_val, arg_1_expr])
+
+ self.mock_cache(ident_val)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ # Assert that ident_val was replaced and that its arguments were removed
+ # from the expression tree.
+ expected_trace = [
+ expressions.ComputedExpression,
+ expressions.PlaceholderExpression,
+ expressions.ConstantExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertIn(ident_val._id, replaced)
+ self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))
+
+ def test_only_caches_same_input(self):
+ arg_0_expr = expressions.ConstantExpression(0)
+ ident_val = expressions.ComputedExpression(
+ 'ident', lambda x: x, [arg_0_expr])
+ comp_expr = expressions.ComputedExpression(
+ 'add', lambda x, y: x + y, [ident_val, arg_0_expr])
+
+ self.mock_cache(arg_0_expr)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ # Assert that arg_0_expr, being an input to two computations, was replaced
+ # with the same placeholder expression.
+ expected_trace = [
+ expressions.ComputedExpression,
+ expressions.ComputedExpression,
+ expressions.PlaceholderExpression,
+ expressions.PlaceholderExpression
+ ]
+ actual_trace = self.create_trace(comp_expr)
+ unique_placeholders = set(
+ t for t in actual_trace
+ if isinstance(t, expressions.PlaceholderExpression))
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertTrue(
+ all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
+ self.assertIn(arg_0_expr._id, replaced)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index c4d062e..0b3f8a3 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -37,12 +37,12 @@
import pandas as pd
import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.display import pipeline_graph
from apache_beam.runners.interactive.display.pcoll_visualization import visualize
from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
from apache_beam.runners.interactive.utils import elements_to_df
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.runners.runner import PipelineState
@@ -455,10 +455,7 @@
element_types = {}
for pcoll in flatten_pcolls:
if isinstance(pcoll, DeferredBase):
- proxy = pcoll._expr.proxy()
- pcoll = to_pcollection(
- pcoll, yield_elements='pandas', label=str(pcoll._expr))
- element_type = proxy
+ pcoll, element_type = deferred_df_to_pcollection(pcoll)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
else:
element_type = pcoll.element_type
@@ -569,11 +566,7 @@
# collect the result in elements_to_df.
if isinstance(pcoll, DeferredBase):
# Get the proxy so we can get the output shape of the DataFrame.
- # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
- # instead of the proxy.
- element_type = pcoll._expr.proxy()
- pcoll = to_pcollection(
- pcoll, yield_elements='pandas', label=str(pcoll._expr))
+ pcoll, element_type = deferred_df_to_pcollection(pcoll)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
else:
element_type = pcoll.element_type
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index fe6989e..69551de 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -408,6 +408,81 @@
df_expected['cube'],
ib.collect(df['cube'], n=10).reset_index(drop=True))
+ @unittest.skipIf(
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+ @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
+ @patch('IPython.get_ipython', new_callable=mock_get_ipython)
+ def test_dataframe_caching(self, cell):
+
+ # Create a pipeline that exercises the DataFrame API. This will also use
+ # caching in the background.
+ with cell: # Cell 1
+ p = beam.Pipeline(interactive_runner.InteractiveRunner())
+ ib.watch({'p': p})
+
+ with cell: # Cell 2
+ data = p | beam.Create([
+ 1, 2, 3
+ ]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
+
+ with beam.dataframe.allow_non_parallel_operations():
+ df = to_dataframe(data).reset_index(drop=True)
+
+ ib.collect(df)
+
+ with cell: # Cell 3
+ df['output'] = df['square'] * df['cube']
+ ib.collect(df)
+
+ with cell: # Cell 4
+ df['output'] = 0
+ ib.collect(df)
+
+ # We use a trace through the graph to perform an isomorphism test. The end
+ # output should look like a linear graph. This indicates that the dataframe
+ # transform was correctly broken into separate pieces to cache. If caching
+ # isn't enabled, all the dataframe computation nodes are connected to a
+ # single shared node.
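+    #
+    # A fully cached run should therefore produce a trace shaped like
+    # [('', 'A'), ('A', 'B'), ('B', 'C')] (labels are illustrative), with
+    # each top-level transform consuming the previous transform's output.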
+ trace = []
+
+ # Only look at the top-level transforms for the isomorphism. The test
+ # doesn't care about the transform implementations, just the overall shape.
+ class TopLevelTracer(beam.pipeline.PipelineVisitor):
+ def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
+ if node is None or not node.full_label:
+ return None
+
+ parent = self._find_root_producer(node.parent)
+ if parent is None:
+ return node
+
+ return parent
+
+ def _add_to_trace(self, node, trace):
+ if '/' not in str(node):
+ if node.inputs:
+ producer = self._find_root_producer(node.inputs[0].producer)
+ producer_name = producer.full_label if producer else ''
+ trace.append((producer_name, node.full_label))
+
+ def visit_transform(self, node: beam.pipeline.AppliedPTransform):
+ self._add_to_trace(node, trace)
+
+ def enter_composite_transform(
+ self, node: beam.pipeline.AppliedPTransform):
+ self._add_to_trace(node, trace)
+
+ p.visit(TopLevelTracer())
+
+    # Perform the isomorphism test: a topological sort of the graph should
+    # yield a linear chain, with each transform consuming its predecessor.
+ trace_string = '\n'.join(str(t) for t in trace)
+ prev_producer = ''
+ for producer, consumer in trace:
+ self.assertEqual(producer, prev_producer, trace_string)
+ prev_producer = consumer
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 448f76a..c51a648 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -23,7 +23,6 @@
import pandas as pd
import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive import background_caching_job as bcj
@@ -310,7 +309,7 @@
# TODO(BEAM-12388): investigate the mixing pcollections in multiple
# pipelines error when using the default label.
for df in watched_dataframes:
- pcoll = to_pcollection(df, yield_elements='pandas', label=str(df._expr))
+ pcoll, _ = utils.deferred_df_to_pcollection(df)
watched_pcollections.add(pcoll)
for pcoll in pcolls:
if pcoll not in watched_pcollections:
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index bbe88ff..3e85145 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -25,7 +25,10 @@
import pandas as pd
+from apache_beam.dataframe.convert import to_pcollection
+from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
from apache_beam.testing.test_stream import WindowedValueHolder
from apache_beam.typehints.schemas import named_fields_from_element_type
@@ -267,3 +270,17 @@
return str(return_value)
return return_as_json
+
+
+def deferred_df_to_pcollection(df):
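+  """Converts a deferred DataFrame to a PCollection, reusing cached results.
+
+  Rewrites the DataFrame's expression tree to read previously computed
+  sub-expressions from the cache, then converts it with to_pcollection.
+  Returns a tuple of the resulting PCollection and the DataFrame's proxy,
+  which describes the schema of the output elements.
+  """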
+ assert isinstance(df, DeferredBase), '{} is not a DeferredBase'.format(df)
+
+ # The proxy is used to output a DataFrame with the correct columns.
+ #
+ # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
+ # instead of the proxy.
+ cache = ExpressionCache()
+ cache.replace_with_cached(df._expr)
+
+ proxy = df._expr.proxy()
+ return to_pcollection(df, yield_elements='pandas', label=str(df._expr)), proxy