[BEAM-12388] Add caching to deferred dataframes

This adds caching to any dataframes using the InteractiveRuner.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/__init__.py b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
index 97b1be9..cce3aca 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/__init__.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
@@ -14,4 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
new file mode 100644
index 0000000..5b1b9ef
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import *
+
+import apache_beam as beam
+from apache_beam.dataframe import convert
+from apache_beam.dataframe import expressions
+
+
+class ExpressionCache(object):
+  """Utility class for caching deferred DataFrames expressions.
+
+  This is cache is currently a light-weight wrapper around the
+  TO_PCOLLECTION_CACHE in the beam.dataframes.convert module and the
+  computed_pcollections in the interactive module.
+
+  Example::
+
+    df : beam.dataframe.DeferredDataFrame = ...
+    ...
+    cache = ExpressionCache()
+    cache.replace_with_cached(df._expr)
+
+  This will automatically link the instance to the existing caches. After it is
+  created, the cache can then be used to modify an existing deferred dataframe
+  expression tree to replace nodes with computed PCollections.
+
+  This object can be created and destroyed whenever. This class holds no state
+  and the only side-effect is modifying the given expression.
+  """
+  def __init__(self, pcollection_cache=None, computed_cache=None):
+    from apache_beam.runners.interactive import interactive_environment as ie
+
+    self._pcollection_cache = (
+        convert.TO_PCOLLECTION_CACHE
+        if pcollection_cache is None else pcollection_cache)
+    self._computed_cache = (
+        ie.current_env().computed_pcollections
+        if computed_cache is None else computed_cache)
+
+  def replace_with_cached(
+      self, expr: expressions.Expression) -> Dict[str, expressions.Expression]:
+    """Replaces any previously computed expressions with PlaceholderExpressions.
+
+    This is used to correctly read any expressions that were cached in previous
+    runs. This enables the InteractiveRunner to prune off old calculations from
+    the expression tree.
+    """
+
+    replaced_inputs: Dict[str, expressions.Expression] = {}
+    self._replace_with_cached_recur(expr, replaced_inputs)
+    return replaced_inputs
+
+  def _replace_with_cached_recur(
+      self,
+      expr: expressions.Expression,
+      replaced_inputs: Dict[str, expressions.Expression]) -> None:
+    """Recursive call for `replace_with_cached`.
+
+    Recurses through the expression tree and replaces any cached inputs with
+    `PlaceholderExpression`s.
+    """
+
+    final_inputs = []
+
+    for input in expr.args():
+      pc = self._get_cached(input)
+
+      # Only read from cache when there is the PCollection has been fully
+      # computed. This is so that no partial results are used.
+      if self._is_computed(pc):
+
+        # Reuse previously seen cached expressions. This is so that the same
+        # value isn't cached multiple times.
+        if input._id in replaced_inputs:
+          cached = replaced_inputs[input._id]
+        else:
+          cached = expressions.PlaceholderExpression(
+              input.proxy(), self._pcollection_cache[input._id])
+
+          replaced_inputs[input._id] = cached
+        final_inputs.append(cached)
+      else:
+        final_inputs.append(input)
+        self._replace_with_cached_recur(input, replaced_inputs)
+    expr._args = tuple(final_inputs)
+
+  def _get_cached(self,
+                  expr: expressions.Expression) -> Optional[beam.PCollection]:
+    """Returns the PCollection associated with the expression."""
+    return self._pcollection_cache.get(expr._id, None)
+
+  def _is_computed(self, pc: beam.PCollection) -> bool:
+    """Returns True if the PCollection has been run and computed."""
+    return pc is not None and pc in self._computed_cache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
new file mode 100644
index 0000000..c6e46f3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.dataframe import expressions
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
+
+
+class ExpressionCacheTest(unittest.TestCase):
+  def setUp(self):
+    self._pcollection_cache = {}
+    self._computed_cache = set()
+    self._pipeline = beam.Pipeline()
+    self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)
+
+  def create_trace(self, expr):
+    trace = [expr]
+    for input in expr.args():
+      trace += self.create_trace(input)
+    return trace
+
+  def mock_cache(self, expr):
+    pcoll = beam.PCollection(self._pipeline)
+    self._pcollection_cache[expr._id] = pcoll
+    self._computed_cache.add(pcoll)
+
+  def assertTraceTypes(self, expr, expected):
+    actual_types = [type(e).__name__ for e in self.create_trace(expr)]
+    expected_types = [e.__name__ for e in expected]
+    self.assertListEqual(actual_types, expected_types)
+
+  def test_only_replaces_cached(self):
+    in_expr = expressions.ConstantExpression(0)
+    comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])
+
+    # Expect that no replacement of expressions is performed.
+    expected_trace = [
+        expressions.ComputedExpression, expressions.ConstantExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+
+    self.cache.replace_with_cached(comp_expr)
+
+    self.assertTraceTypes(comp_expr, expected_trace)
+
+    # Now "cache" the expression and assert that the cached expression was
+    # replaced with a placeholder.
+    self.mock_cache(in_expr)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    expected_trace = [
+        expressions.ComputedExpression, expressions.PlaceholderExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertIn(in_expr._id, replaced)
+
+  def test_only_replaces_inputs(self):
+    arg_0_expr = expressions.ConstantExpression(0)
+    ident_val = expressions.ComputedExpression(
+        'ident', lambda x: x, [arg_0_expr])
+
+    arg_1_expr = expressions.ConstantExpression(1)
+    comp_expr = expressions.ComputedExpression(
+        'add', lambda x, y: x + y, [ident_val, arg_1_expr])
+
+    self.mock_cache(ident_val)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    # Assert that ident_val was replaced and that its arguments were removed
+    # from the expression tree.
+    expected_trace = [
+        expressions.ComputedExpression,
+        expressions.PlaceholderExpression,
+        expressions.ConstantExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertIn(ident_val._id, replaced)
+    self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))
+
+  def test_only_caches_same_input(self):
+    arg_0_expr = expressions.ConstantExpression(0)
+    ident_val = expressions.ComputedExpression(
+        'ident', lambda x: x, [arg_0_expr])
+    comp_expr = expressions.ComputedExpression(
+        'add', lambda x, y: x + y, [ident_val, arg_0_expr])
+
+    self.mock_cache(arg_0_expr)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    # Assert that arg_0_expr, being an input to two computations, was replaced
+    # with the same placeholder expression.
+    expected_trace = [
+        expressions.ComputedExpression,
+        expressions.ComputedExpression,
+        expressions.PlaceholderExpression,
+        expressions.PlaceholderExpression
+    ]
+    actual_trace = self.create_trace(comp_expr)
+    unique_placeholders = set(
+        t for t in actual_trace
+        if isinstance(t, expressions.PlaceholderExpression))
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertTrue(
+        all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
+    self.assertIn(arg_0_expr._id, replaced)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index c4d062e..0b3f8a3 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -37,12 +37,12 @@
 import pandas as pd
 
 import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
 from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.runners.interactive import interactive_environment as ie
 from apache_beam.runners.interactive.display import pipeline_graph
 from apache_beam.runners.interactive.display.pcoll_visualization import visualize
 from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
 from apache_beam.runners.interactive.utils import elements_to_df
 from apache_beam.runners.interactive.utils import progress_indicated
 from apache_beam.runners.runner import PipelineState
@@ -455,10 +455,7 @@
   element_types = {}
   for pcoll in flatten_pcolls:
     if isinstance(pcoll, DeferredBase):
-      proxy = pcoll._expr.proxy()
-      pcoll = to_pcollection(
-          pcoll, yield_elements='pandas', label=str(pcoll._expr))
-      element_type = proxy
+      pcoll, element_type = deferred_df_to_pcollection(pcoll)
       watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
     else:
       element_type = pcoll.element_type
@@ -569,11 +566,7 @@
   # collect the result in elements_to_df.
   if isinstance(pcoll, DeferredBase):
     # Get the proxy so we can get the output shape of the DataFrame.
-    # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
-    # instead of the proxy.
-    element_type = pcoll._expr.proxy()
-    pcoll = to_pcollection(
-        pcoll, yield_elements='pandas', label=str(pcoll._expr))
+    pcoll, element_type = deferred_df_to_pcollection(pcoll)
     watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
   else:
     element_type = pcoll.element_type
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index fe6989e..69551de 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -408,6 +408,81 @@
         df_expected['cube'],
         ib.collect(df['cube'], n=10).reset_index(drop=True))
 
+  @unittest.skipIf(
+      not ie.current_env().is_interactive_ready,
+      '[interactive] dependency is not installed.')
+  @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
+  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
+  def test_dataframe_caching(self, cell):
+
+    # Create a pipeline that exercises the DataFrame API. This will also use
+    # caching in the background.
+    with cell:  # Cell 1
+      p = beam.Pipeline(interactive_runner.InteractiveRunner())
+      ib.watch({'p': p})
+
+    with cell:  # Cell 2
+      data = p | beam.Create([
+          1, 2, 3
+      ]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
+
+      with beam.dataframe.allow_non_parallel_operations():
+        df = to_dataframe(data).reset_index(drop=True)
+
+      ib.collect(df)
+
+    with cell:  # Cell 3
+      df['output'] = df['square'] * df['cube']
+      ib.collect(df)
+
+    with cell:  # Cell 4
+      df['output'] = 0
+      ib.collect(df)
+
+    # We use a trace through the graph to perform an isomorphism test. The end
+    # output should look like a linear graph. This indicates that the dataframe
+    # transform was correctly broken into separate pieces to cache. If caching
+    # isn't enabled, all the dataframe computation nodes are connected to a
+    # single shared node.
+    trace = []
+
+    # Only look at the top-level transforms for the isomorphism. The test
+    # doesn't care about the transform implementations, just the overall shape.
+    class TopLevelTracer(beam.pipeline.PipelineVisitor):
+      def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
+        if node is None or not node.full_label:
+          return None
+
+        parent = self._find_root_producer(node.parent)
+        if parent is None:
+          return node
+
+        return parent
+
+      def _add_to_trace(self, node, trace):
+        if '/' not in str(node):
+          if node.inputs:
+            producer = self._find_root_producer(node.inputs[0].producer)
+            producer_name = producer.full_label if producer else ''
+            trace.append((producer_name, node.full_label))
+
+      def visit_transform(self, node: beam.pipeline.AppliedPTransform):
+        self._add_to_trace(node, trace)
+
+      def enter_composite_transform(
+          self, node: beam.pipeline.AppliedPTransform):
+        self._add_to_trace(node, trace)
+
+    p.visit(TopLevelTracer())
+
+    # Do the isomorphism test which states that the topological sort of the
+    # graph yields a linear graph.
+    trace_string = '\n'.join(str(t) for t in trace)
+    prev_producer = ''
+    for producer, consumer in trace:
+      self.assertEqual(producer, prev_producer, trace_string)
+      prev_producer = consumer
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 448f76a..c51a648 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -23,7 +23,6 @@
 import pandas as pd
 
 import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
 from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive import background_caching_job as bcj
@@ -310,7 +309,7 @@
     # TODO(BEAM-12388): investigate the mixing pcollections in multiple
     # pipelines error when using the default label.
     for df in watched_dataframes:
-      pcoll = to_pcollection(df, yield_elements='pandas', label=str(df._expr))
+      pcoll, _ = utils.deferred_df_to_pcollection(df)
       watched_pcollections.add(pcoll)
     for pcoll in pcolls:
       if pcoll not in watched_pcollections:
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index bbe88ff..3e85145 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -25,7 +25,10 @@
 
 import pandas as pd
 
+from apache_beam.dataframe.convert import to_pcollection
+from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
 from apache_beam.testing.test_stream import WindowedValueHolder
 from apache_beam.typehints.schemas import named_fields_from_element_type
 
@@ -267,3 +270,17 @@
       return str(return_value)
 
   return return_as_json
+
+
+def deferred_df_to_pcollection(df):
+  assert isinstance(df, DeferredBase), '{} is not a DeferredBase'.format(df)
+
+  # The proxy is used to output a DataFrame with the correct columns.
+  #
+  # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
+  # instead of the proxy.
+  cache = ExpressionCache()
+  cache.replace_with_cached(df._expr)
+
+  proxy = df._expr.proxy()
+  return to_pcollection(df, yield_elements='pandas', label=str(df._expr)), proxy