[BEAM-12596] Ensure that SDFs always report a size >= 0.

diff --git a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
index db7021d..ef37772 100644
--- a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
+++ b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
@@ -409,6 +409,7 @@
   "\x01\x00\x00\x00\x00\x02\x03foo\x01\xa9F\x03bar\x01\xff\xff\xff\xff\xff\xff\xff\xff\x7f": {f_map: {"foo": 9001, "bar": 9223372036854775807}}
   "\x01\x00\x00\x00\x00\x04\neverything\x00\x02is\x00\x05null!\x00\r\xc2\xaf\\_(\xe3\x83\x84)_/\xc2\xaf\x00": {f_map: {"everything": null, "is": null, "null!": null, "¯\\_(ツ)_/¯": null}}
 
+---
 
 coder:
   urn: "beam:coder:row:v1"
@@ -440,3 +441,25 @@
     shardId: "",
     key: "key"
   }
+
+---
+
+# Java code snippet to generate example bytes:
+#  TimestampPrefixingWindowCoder<IntervalWindow> coder = TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of());
+#  Instant end = new Instant(-9223372036854410L);
+#  Duration span = Duration.millis(365L);
+#  IntervalWindow window = new IntervalWindow(end.minus(span), span);
+#  byte[] bytes = CoderUtils.encodeToByteArray(coder, window);
+#  String str = new String(bytes, java.nio.charset.StandardCharsets.ISO_8859_1);
+#  String example = "";
+#  for (int i = 0; i < str.length(); i++) {
+#    example += CharUtils.unicodeEscaped(str.charAt(i));
+#  }
+#  System.out.println(example);
+coder:
+  urn: "beam:coder:custom_window:v1"
+  components: [{urn: "beam:coder:interval_window:v1"}]
+
+examples:
+  "\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0067\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0068\u0080\u00dd\u00db\u0001" : {window: {end: 1454293425000, span: 3600000}}
+  "\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0075\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0076\u00ed\u0002" : {window: {end: -9223372036854410, span: 365}}
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index b5ad193..a367321 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -939,6 +939,31 @@
     // Experimental.
     STATE_BACKED_ITERABLE = 9 [(beam_urn) = "beam:coder:state_backed_iterable:v1"];
 
+
+    // Encodes an arbitrary user defined window and its max timestamp (inclusive).
+    // The encoding format is:
+    //   maxTimestamp window
+    //
+    //   maxTimestamp - A big endian 8 byte integer representing millis-since-epoch.
+    //     The encoded representation is shifted so that the byte representation
+    //     of negative values are lexicographically ordered before the byte
+    //     representation of positive values. This is typically done by
+    //     subtracting -9223372036854775808 from the value and encoding it as a
+    //     unsigned big endian integer. Example values:
+    //
+    //     -9223372036854775808: 00 00 00 00 00 00 00 00
+    //                     -255: 7F FF FF FF FF FF FF 01
+    //                       -1: 7F FF FF FF FF FF FF FF
+    //                        0: 80 00 00 00 00 00 00 00
+    //                        1: 80 00 00 00 00 00 00 01
+    //                      256: 80 00 00 00 00 00 01 00
+    //      9223372036854775807: FF FF FF FF FF FF FF FF
+    //
+    //   window - the window is encoded using the supplied window coder.
+    //
+    // Components: Coder for the custom window type.
+    CUSTOM_WINDOW = 16 [(beam_urn) = "beam:coder:custom_window:v1"];
+
     // Additional Standard Coders
     // --------------------------
     // The following coders are not required to be implemented for an SDK or
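For reference, a minimal runnable sketch of the shift described in the comment above (the class name is illustrative; this mirrors the transformation the Java SDK's InstantCoder performs):

    import java.nio.ByteBuffer;

    class ShiftedTimestampExample {
      // Subtracting Long.MIN_VALUE flips the sign bit, so byte-wise
      // (lexicographic) order of the encoding matches numeric order.
      static byte[] encodeShifted(long millis) {
        return ByteBuffer.allocate(8).putLong(millis - Long.MIN_VALUE).array();
      }

      public static void main(String[] args) {
        long[] samples = {Long.MIN_VALUE, -255L, -1L, 0L, 1L, 256L, Long.MAX_VALUE};
        for (long v : samples) {
          StringBuilder hex = new StringBuilder();
          for (byte b : encodeShifted(v)) {
            hex.append(String.format("%02X ", b));
          }
          // Prints exactly the example rows shown in the proto comment above.
          System.out.println(v + ": " + hex.toString().trim());
        }
      }
    }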
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
index 38763e2..992a9eb 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
@@ -28,6 +28,7 @@
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
 import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.SchemaTranslation;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -189,6 +190,20 @@
     };
   }
 
+  static CoderTranslator<TimestampPrefixingWindowCoder<?>> timestampPrefixingWindow() {
+    return new SimpleStructuredCoderTranslator<TimestampPrefixingWindowCoder<?>>() {
+      @Override
+      protected TimestampPrefixingWindowCoder<?> fromComponents(List<Coder<?>> components) {
+        return TimestampPrefixingWindowCoder.of((Coder<? extends BoundedWindow>) components.get(0));
+      }
+
+      @Override
+      public List<? extends Coder<?>> getComponents(TimestampPrefixingWindowCoder<?> from) {
+        return from.getComponents();
+      }
+    };
+  }
+
   public abstract static class SimpleStructuredCoderTranslator<T extends Coder<?>>
       implements CoderTranslator<T> {
     @Override
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
index 44e2caa..1fc8379 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
@@ -32,6 +32,7 @@
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
@@ -74,6 +75,7 @@
           .put(DoubleCoder.class, ModelCoders.DOUBLE_CODER_URN)
           .put(RowCoder.class, ModelCoders.ROW_CODER_URN)
           .put(ShardedKey.Coder.class, ModelCoders.SHARDED_KEY_CODER_URN)
+          .put(TimestampPrefixingWindowCoder.class, ModelCoders.CUSTOM_WINDOW_CODER_URN)
           .build();
 
   public static final Set<String> WELL_KNOWN_CODER_URNS = BEAM_MODEL_CODER_URNS.values();
@@ -96,6 +98,7 @@
           .put(DoubleCoder.class, CoderTranslators.atomic(DoubleCoder.class))
           .put(RowCoder.class, CoderTranslators.row())
           .put(ShardedKey.Coder.class, CoderTranslators.shardedKey())
+          .put(TimestampPrefixingWindowCoder.class, CoderTranslators.timestampPrefixingWindow())
           .build();
 
   static {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
index 86e1e7d..0ff70f1 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
@@ -54,6 +54,8 @@
   public static final String INTERVAL_WINDOW_CODER_URN =
       getUrn(StandardCoders.Enum.INTERVAL_WINDOW);
 
+  public static final String CUSTOM_WINDOW_CODER_URN = getUrn(StandardCoders.Enum.CUSTOM_WINDOW);
+
   public static final String WINDOWED_VALUE_CODER_URN = getUrn(StandardCoders.Enum.WINDOWED_VALUE);
   public static final String PARAM_WINDOWED_VALUE_CODER_URN =
       getUrn(StandardCoders.Enum.PARAM_WINDOWED_VALUE);
@@ -82,6 +84,7 @@
           LENGTH_PREFIX_CODER_URN,
           GLOBAL_WINDOW_CODER_URN,
           INTERVAL_WINDOW_CODER_URN,
+          CUSTOM_WINDOW_CODER_URN,
           WINDOWED_VALUE_CODER_URN,
           DOUBLE_CODER_URN,
           ROW_CODER_URN,
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
index 8a1af22..11543f3 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
@@ -45,6 +45,7 @@
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.Field;
@@ -95,6 +96,7 @@
                       Field.of("map", FieldType.map(FieldType.STRING, FieldType.INT32)),
                       Field.of("bar", FieldType.logicalType(FixedBytes.of(123))))))
           .add(ShardedKey.Coder.of(StringUtf8Coder.of()))
+          .add(TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of()))
           .build();
 
   /**
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
index 6c7744f..d7d3665 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
@@ -59,6 +59,7 @@
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.SchemaTranslation;
@@ -116,6 +117,7 @@
               WindowedValue.ParamWindowedValueCoder.class)
           .put(getUrn(StandardCoders.Enum.ROW), RowCoder.class)
           .put(getUrn(StandardCoders.Enum.SHARDED_KEY), ShardedKey.Coder.class)
+          .put(getUrn(StandardCoders.Enum.CUSTOM_WINDOW), TimestampPrefixingWindowCoder.class)
           .build();
 
   @AutoValue
@@ -345,6 +347,10 @@
       byte[] shardId = ((String) kvMap.get("shardId")).getBytes(StandardCharsets.ISO_8859_1);
       return ShardedKey.of(
           convertValue(kvMap.get("key"), coderSpec.getComponents().get(0), keyCoder), shardId);
+    } else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) {
+      Map<String, Object> kvMap = (Map<String, Object>) value;
+      Coder windowCoder = ((TimestampPrefixingWindowCoder) coder).getWindowCoder();
+      return convertValue(kvMap.get("window"), coderSpec.getComponents().get(0), windowCoder);
     } else {
       throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn());
     }
@@ -502,6 +508,8 @@
       assertEquals(expectedValue, actualValue);
     } else if (s.equals(getUrn(StandardCoders.Enum.SHARDED_KEY))) {
       assertEquals(expectedValue, actualValue);
+    } else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) {
+      assertEquals(expectedValue, actualValue);
     } else {
       throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
     }
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
index c6264f3..0c397f4 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
@@ -21,6 +21,7 @@
 class CloudObjectKinds {
   static final String KIND_GLOBAL_WINDOW = "kind:global_window";
   static final String KIND_INTERVAL_WINDOW = "kind:interval_window";
+  static final String KIND_CUSTOM_WINDOW = "kind:custom_window";
   static final String KIND_LENGTH_PREFIX = "kind:length_prefix";
   static final String KIND_PAIR = "kind:pair";
   static final String KIND_STREAM = "kind:stream";
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
index e57205d..07377fd 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
@@ -33,6 +33,7 @@
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
 import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder;
 import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
@@ -241,6 +242,34 @@
     };
   }
 
+  static CloudObjectTranslator<TimestampPrefixingWindowCoder> customWindow() {
+    return new CloudObjectTranslator<TimestampPrefixingWindowCoder>() {
+      @Override
+      public CloudObject toCloudObject(
+          TimestampPrefixingWindowCoder target, SdkComponents sdkComponents) {
+        CloudObject result = CloudObject.forClassName(CloudObjectKinds.KIND_CUSTOM_WINDOW);
+        return addComponents(result, target.getComponents(), sdkComponents);
+      }
+
+      @Override
+      public TimestampPrefixingWindowCoder fromCloudObject(CloudObject cloudObject) {
+        List<Coder<?>> components = getComponents(cloudObject);
+        checkArgument(components.size() == 1, "Expecting 1 component, got %s", components.size());
+        return TimestampPrefixingWindowCoder.of((Coder<? extends BoundedWindow>) components.get(0));
+      }
+
+      @Override
+      public Class<? extends TimestampPrefixingWindowCoder> getSupportedClass() {
+        return TimestampPrefixingWindowCoder.class;
+      }
+
+      @Override
+      public String cloudObjectClassName() {
+        return CloudObjectKinds.KIND_CUSTOM_WINDOW;
+      }
+    };
+  }
+
   /**
    * Returns a {@link CloudObjectTranslator} that produces a {@link CloudObject} that is of kind
    * "windowed_value".
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
index 766a32d..56c4d57 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
@@ -31,6 +31,7 @@
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
@@ -58,7 +59,8 @@
           Timer.Coder.class,
           LengthPrefixCoder.class,
           GlobalWindow.Coder.class,
-          FullWindowedValueCoder.class);
+          FullWindowedValueCoder.class,
+          TimestampPrefixingWindowCoder.class);
 
   static final Map<Class<? extends Coder>, CloudObjectTranslator<? extends Coder>>
       CODER_TRANSLATORS = populateCoderTranslators();
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
index 01101c6..92070ef 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
@@ -65,6 +65,7 @@
       ImmutableList.of(
           CloudObjectTranslators.globalWindow(),
           CloudObjectTranslators.intervalWindow(),
+          CloudObjectTranslators.customWindow(),
           CloudObjectTranslators.bytes(),
           CloudObjectTranslators.varInt(),
           CloudObjectTranslators.lengthPrefix(),
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
index eb7e6cd..f1a004f 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
@@ -51,6 +51,7 @@
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.SetCoder;
 import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -144,6 +145,7 @@
               .add(new ObjectCoder())
               .add(GlobalWindow.Coder.INSTANCE)
               .add(IntervalWindow.getCoder())
+              .add(TimestampPrefixingWindowCoder.of(IntervalWindow.getCoder()))
               .add(LengthPrefixCoder.of(VarLongCoder.of()))
               .add(IterableCoder.of(VarLongCoder.of()))
               .add(KvCoder.of(VarLongCoder.of(), ByteArrayCoder.of()))
diff --git a/sdks/go/pkg/beam/core/graph/window/fn.go b/sdks/go/pkg/beam/core/graph/window/fn.go
index 7197571..dba90f6 100644
--- a/sdks/go/pkg/beam/core/graph/window/fn.go
+++ b/sdks/go/pkg/beam/core/graph/window/fn.go
@@ -29,7 +29,7 @@
 	GlobalWindows  Kind = "GLO"
 	FixedWindows   Kind = "FIX"
 	SlidingWindows Kind = "SLI"
-	Sessions       Kind = "SES" // TODO
+	Sessions       Kind = "SES"
 )
 
 // NewGlobalWindows returns the default WindowFn, which places all elements
diff --git a/sdks/go/pkg/beam/core/runtime/exec/window.go b/sdks/go/pkg/beam/core/runtime/exec/window.go
index cc7accc..7e8d0c4 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/window.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/window.go
@@ -73,6 +73,11 @@
 			ret = append(ret, window.IntervalWindow{Start: start, End: start.Add(wfn.Size)})
 		}
 		return ret
+	case window.Sessions:
+		// Assign each element to a window spanning from its timestamp until Gap
+		// into the future. Overlapping windows (from elements within Gap of each
+		// other) will be merged.
+		return []typex.Window{window.IntervalWindow{Start: ts, End: ts.Add(wfn.Gap)}}
 
 	default:
 		panic(fmt.Sprintf("Unexpected window fn: %v", wfn))
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 099738f..7472e83 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -1036,7 +1036,7 @@
 	switch w.Kind {
 	case window.GlobalWindows:
 		return coder.NewGlobalWindow(), nil
-	case window.FixedWindows, window.SlidingWindows, URNSlidingWindowsWindowFn:
+	case window.FixedWindows, window.SlidingWindows, window.Sessions, URNSlidingWindowsWindowFn:
 		return coder.NewIntervalWindow(), nil
 	default:
 		return nil, errors.Errorf("unexpected windowing strategy for coder: %v", w)
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index cdb7e47..2297443 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -60,6 +60,8 @@
 var directFilters = []string{
 	// The direct runner does not yet support cross-language.
 	"TestXLang.*",
+	// TODO(BEAM-4152): The direct runner does not support session window merging.
+	"TestWindowSums.*",
 }
 
 var portableFilters = []string{}
diff --git a/sdks/go/test/integration/primitives/windowinto.go b/sdks/go/test/integration/primitives/windowinto.go
index 573f5c7..f6e3fbd 100644
--- a/sdks/go/test/integration/primitives/windowinto.go
+++ b/sdks/go/test/integration/primitives/windowinto.go
@@ -67,9 +67,8 @@
 	validate(s.Scope("SlidingFixed"), window.NewSlidingWindows(windowSize, windowSize), timestampedData, 15, 15, 15)
 	// This will have overlap, but each value should be a multiple of the magic number.
 	validate(s.Scope("Sliding"), window.NewSlidingWindows(windowSize, 3*windowSize), timestampedData, 15, 30, 45, 30, 15)
-	// TODO(BEAM-4152): This will do a smoke test of session windows, once implemented through.
 	// With such a large gap, there should be a single session which will sum to 45.
-	// validate(s.Scope("Session"), window.NewSessions(windowSize), timestampedData, 45)
+	validate(s.Scope("Session"), window.NewSessions(windowSize), timestampedData, 45)
 }
 
 func sumPerKey(ws beam.Window, ts beam.EventTime, key beam.U, iter func(*int) bool) (beam.U, int) {
diff --git a/sdks/go/test/regression/coders/fromyaml/fromyaml.go b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
index 3544a5f..8ba9e99 100644
--- a/sdks/go/test/regression/coders/fromyaml/fromyaml.go
+++ b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
@@ -45,6 +45,7 @@
 	"beam:coder:param_windowed_value:v1": true,
 	"beam:coder:timer:v1":                true,
 	"beam:coder:sharded_key:v1":          true,
+	"beam:coder:custom_window:v1":        true,
 }
 
 // Coder is a representation a serialized beam coder.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java
new file mode 100644
index 0000000..ea7d56e
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.coders;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+
+/**
+ * A {@link TimestampPrefixingWindowCoder} wraps an arbitrary custom window coder. When encoding a
+ * window, it first writes the window's maxTimestamp (inclusive) and then the window itself encoded
+ * with the wrapped custom window coder.
+ *
+ * @param <T> The custom window type.
+ */
+public class TimestampPrefixingWindowCoder<T extends BoundedWindow> extends StructuredCoder<T> {
+  private final Coder<T> windowCoder;
+
+  public static <T extends BoundedWindow> TimestampPrefixingWindowCoder<T> of(
+      Coder<T> windowCoder) {
+    return new TimestampPrefixingWindowCoder<>(windowCoder);
+  }
+
+  TimestampPrefixingWindowCoder(Coder<T> windowCoder) {
+    this.windowCoder = windowCoder;
+  }
+
+  public Coder<T> getWindowCoder() {
+    return windowCoder;
+  }
+
+  @Override
+  public void encode(T value, OutputStream outStream) throws CoderException, IOException {
+    if (value == null) {
+      throw new CoderException("Cannot encode null window");
+    }
+    InstantCoder.of().encode(value.maxTimestamp(), outStream);
+    windowCoder.encode(value, outStream);
+  }
+
+  @Override
+  public T decode(InputStream inStream) throws CoderException, IOException {
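+    // Skip the prefixed max timestamp; the wrapped coder's bytes fully describe the window.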
+    InstantCoder.of().decode(inStream);
+    return windowCoder.decode(inStream);
+  }
+
+  @Override
+  public List<? extends Coder<?>> getCoderArguments() {
+    return Lists.newArrayList(windowCoder);
+  }
+
+  @Override
+  public void verifyDeterministic() throws NonDeterministicException {
+    windowCoder.verifyDeterministic();
+  }
+
+  @Override
+  public boolean consistentWithEquals() {
+    return windowCoder.consistentWithEquals();
+  }
+
+  @Override
+  public boolean isRegisterByteSizeObserverCheap(T value) {
+    return windowCoder.isRegisterByteSizeObserverCheap(value);
+  }
+
+  @Override
+  public void registerByteSizeObserver(T value, ElementByteSizeObserver observer) throws Exception {
+    InstantCoder.of().registerByteSizeObserver(value.maxTimestamp(), observer);
+    windowCoder.registerByteSizeObserver(value, observer);
+  }
+}
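A quick round-trip sketch of the new coder (illustrative snippet, not part of the patch; assumes the Beam Java SDK and Joda-Time on the classpath):

    import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
    import org.apache.beam.sdk.util.CoderUtils;
    import org.joda.time.Instant;

    class RoundTripExample {
      public static void main(String[] args) throws Exception {
        TimestampPrefixingWindowCoder<IntervalWindow> coder =
            TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of());
        IntervalWindow window = new IntervalWindow(new Instant(0L), new Instant(1000L));
        byte[] bytes = CoderUtils.encodeToByteArray(coder, window);
        // The first 8 bytes are the shifted big endian maxTimestamp; the
        // remainder is the wrapped IntervalWindowCoder encoding of the window.
        System.out.println(window.equals(CoderUtils.decodeFromByteArray(coder, bytes)));
      }
    }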
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
index 87ffa4a..143901a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
@@ -27,6 +27,7 @@
 import org.apache.beam.sdk.coders.DurationCoder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -175,6 +176,19 @@
     }
 
     @Override
+    public boolean isRegisterByteSizeObserverCheap(IntervalWindow value) {
+      return instantCoder.isRegisterByteSizeObserverCheap(value.end)
+          && durationCoder.isRegisterByteSizeObserverCheap(new Duration(value.start, value.end));
+    }
+
+    @Override
+    public void registerByteSizeObserver(IntervalWindow value, ElementByteSizeObserver observer)
+        throws Exception {
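+      // Mirror encode(): observe the end instant, then the start-to-end duration.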
+      instantCoder.registerByteSizeObserver(value.end, observer);
+      durationCoder.registerByteSizeObserver(new Duration(value.start, value.end), observer);
+    }
+
+    @Override
     public List<? extends Coder<?>> getCoderArguments() {
       return Collections.emptyList();
     }
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java
new file mode 100644
index 0000000..a9f8123
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.coders;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import java.util.Objects;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.CoderProperties.TestElementByteSizeObserver;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import org.junit.Test;
+
+public class TimestampPrefixingWindowCoderTest {
+
+  private static class CustomWindow extends IntervalWindow {
+    private boolean isBig;
+
+    CustomWindow(Instant start, Instant end, boolean isBig) {
+      super(start, end);
+      this.isBig = isBig;
+    }
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      CustomWindow that = (CustomWindow) o;
+      return super.equals(o) && this.isBig == that.isBig;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(super.hashCode(), isBig);
+    }
+  }
+
+  private static class CustomWindowCoder extends CustomCoder<CustomWindow> {
+
+    private static final Coder<IntervalWindow> INTERVAL_WINDOW_CODER = IntervalWindow.getCoder();
+    private static final int REGISTER_BYTE_SIZE = 1234;
+    private final boolean isConsistentWithEqual;
+    private final boolean isRegisterByteSizeCheap;
+
+    public static CustomWindowCoder of(
+        boolean isConsistentWithEqual, boolean isRegisterByteSizeCheap) {
+      return new CustomWindowCoder(isConsistentWithEqual, isRegisterByteSizeCheap);
+    }
+
+    private CustomWindowCoder(boolean isConsistentWithEqual, boolean isRegisterByteSizeCheap) {
+      this.isConsistentWithEqual = isConsistentWithEqual;
+      this.isRegisterByteSizeCheap = isRegisterByteSizeCheap;
+    }
+
+    @Override
+    public void encode(CustomWindow window, OutputStream outStream) throws IOException {
+      INTERVAL_WINDOW_CODER.encode(window, outStream);
+      BooleanCoder.of().encode(window.isBig, outStream);
+    }
+
+    @Override
+    public CustomWindow decode(InputStream inStream) throws IOException {
+      IntervalWindow superWindow = INTERVAL_WINDOW_CODER.decode(inStream);
+      boolean isBig = BooleanCoder.of().decode(inStream);
+      return new CustomWindow(superWindow.start(), superWindow.end(), isBig);
+    }
+
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {
+      INTERVAL_WINDOW_CODER.verifyDeterministic();
+      BooleanCoder.of().verifyDeterministic();
+    }
+
+    @Override
+    public boolean consistentWithEquals() {
+      return isConsistentWithEqual;
+    }
+
+    @Override
+    public boolean isRegisterByteSizeObserverCheap(CustomWindow value) {
+      return isRegisterByteSizeCheap;
+    }
+
+    @Override
+    public void registerByteSizeObserver(CustomWindow value, ElementByteSizeObserver observer)
+        throws Exception {
+      observer.update(REGISTER_BYTE_SIZE);
+    }
+  }
+
+  private static final List<CustomWindow> CUSTOM_WINDOW_LIST =
+      Lists.newArrayList(
+          new CustomWindow(new Instant(0L), new Instant(1L), true),
+          new CustomWindow(new Instant(100L), new Instant(200L), false),
+          new CustomWindow(new Instant(0L), BoundedWindow.TIMESTAMP_MAX_VALUE, true));
+
+  @Test
+  public void testEncodeAndDecode() throws Exception {
+    List<IntervalWindow> intervalWindowsToTest =
+        Lists.newArrayList(
+            new IntervalWindow(new Instant(0L), new Instant(1L)),
+            new IntervalWindow(new Instant(100L), new Instant(200L)),
+            new IntervalWindow(new Instant(0L), BoundedWindow.TIMESTAMP_MAX_VALUE));
+    TimestampPrefixingWindowCoder<IntervalWindow> coder1 =
+        TimestampPrefixingWindowCoder.of(IntervalWindow.getCoder());
+    for (IntervalWindow window : intervalWindowsToTest) {
+      CoderProperties.coderDecodeEncodeEqual(coder1, window);
+    }
+
+    GlobalWindow globalWindow = GlobalWindow.INSTANCE;
+    TimestampPrefixingWindowCoder<GlobalWindow> coder2 =
+        TimestampPrefixingWindowCoder.of(GlobalWindow.Coder.INSTANCE);
+    CoderProperties.coderDecodeEncodeEqual(coder2, globalWindow);
+    TimestampPrefixingWindowCoder<CustomWindow> coder3 =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+    for (CustomWindow window : CUSTOM_WINDOW_LIST) {
+      CoderProperties.coderDecodeEncodeEqual(coder3, window);
+    }
+  }
+
+  @Test
+  public void testConsistentWithEquals() {
+    TimestampPrefixingWindowCoder<CustomWindow> coder1 =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+    assertThat(coder1.consistentWithEquals(), equalTo(true));
+    TimestampPrefixingWindowCoder<CustomWindow> coder2 =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(false, true));
+    assertThat(coder2.consistentWithEquals(), equalTo(false));
+  }
+
+  @Test
+  public void testIsRegisterByteSizeObserverCheap() {
+    TimestampPrefixingWindowCoder<CustomWindow> coder1 =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+    assertThat(coder1.isRegisterByteSizeObserverCheap(CUSTOM_WINDOW_LIST.get(0)), equalTo(true));
+    TimestampPrefixingWindowCoder<CustomWindow> coder2 =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, false));
+    assertThat(coder2.isRegisterByteSizeObserverCheap(CUSTOM_WINDOW_LIST.get(0)), equalTo(false));
+  }
+
+  @Test
+  public void testGetEncodedElementByteSize() throws Exception {
+    TestElementByteSizeObserver observer = new TestElementByteSizeObserver();
+    TimestampPrefixingWindowCoder<CustomWindow> coder =
+        TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+    for (CustomWindow value : CUSTOM_WINDOW_LIST) {
+      coder.registerByteSizeObserver(value, observer);
+      observer.advance();
+      assertThat(
+          observer.getSumAndReset(),
+          equalTo(
+              CustomWindowCoder.REGISTER_BYTE_SIZE
+                  + InstantCoder.of().getEncodedElementByteSize(value.maxTimestamp())));
+    }
+  }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
index 36daef0..a79ff6f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
@@ -22,10 +22,13 @@
 
 import java.util.List;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DurationCoder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.CoderProperties.TestElementByteSizeObserver;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -87,4 +90,20 @@
     assertThat(encodedHourWindow.length, equalTo(encodedStart.length + encodedHourEnd.length - 4));
     assertThat(encodedDayWindow.length, equalTo(encodedStart.length + encodedDayEnd.length - 4));
   }
+
+  @Test
+  public void testCoderRegisterByteSizeObserver() throws Exception {
+    assertThat(TEST_CODER.isRegisterByteSizeObserverCheap(TEST_VALUES.get(0)), equalTo(true));
+    TestElementByteSizeObserver observer = new TestElementByteSizeObserver();
+    TestElementByteSizeObserver observer2 = new TestElementByteSizeObserver();
+    for (IntervalWindow window : TEST_VALUES) {
+      TEST_CODER.registerByteSizeObserver(window, observer);
+      InstantCoder.of().registerByteSizeObserver(window.maxTimestamp(), observer2);
+      DurationCoder.of()
+          .registerByteSizeObserver(new Duration(window.start(), window.end()), observer2);
+      observer.advance();
+      observer2.advance();
+      assertThat(observer.getSumAndReset(), equalTo(observer2.getSumAndReset()));
+    }
+  }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 6cc5e6e..b40c9d5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -47,6 +47,7 @@
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
 import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
 import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
+import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -310,6 +311,7 @@
               }
             });
     try {
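+      // Tag all log entries emitted while processing this bundle with its instruction id.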
+      BeamFnLoggingMDC.setInstructionId(request.getInstructionId());
       PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry();
       PTransformFunctionRegistry finishFunctionRegistry =
           bundleProcessor.getFinishFunctionRegistry();
@@ -364,6 +366,8 @@
       // Make sure we clean-up from the active set of bundle processors.
       bundleProcessorCache.discard(bundleProcessor);
       throw e;
+    } finally {
+      BeamFnLoggingMDC.setInstructionId(null);
     }
   }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
index 71d3b8e..6cd9116 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
@@ -211,6 +211,11 @@
                       .setSeconds(record.getMillis() / 1000)
                       .setNanos((int) (record.getMillis() % 1000) * 1_000_000));
 
+      String instructionId = BeamFnLoggingMDC.getInstructionId();
+      if (instructionId != null) {
+        builder.setInstructionId(instructionId);
+      }
+
       Throwable thrown = record.getThrown();
       if (thrown != null) {
         builder.setTrace(getStackTraceAsString(thrown));
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java
new file mode 100644
index 0000000..bcfcd4b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Mapped diagnostic context to be consumed and set on LogEntry protos in BeamFnLoggingClient. */
+public class BeamFnLoggingMDC {
+  private static final InheritableThreadLocal<@Nullable String> instructionId =
+      new InheritableThreadLocal<>();
+
+  /** Sets the Instruction ID of the current thread, which will be inherited by child threads. */
+  public static void setInstructionId(@Nullable String newInstructionId) {
+    instructionId.set(newInstructionId);
+  }
+
+  /** Gets the Instruction ID of the current thread. */
+  public static @Nullable String getInstructionId() {
+    return instructionId.get();
+  }
+}
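The intended discipline (set before processing, clear in a finally block) is what ProcessBundleHandler wires up above. A minimal sketch, with an illustrative helper name:

    import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC;

    class MdcScopeExample {
      static void processWithLoggingScope(String instructionId, Runnable bundleWork) {
        BeamFnLoggingMDC.setInstructionId(instructionId);
        try {
          // Log records emitted on this thread (and threads it spawns, via the
          // inheritable thread-local) carry the instruction id.
          bundleWork.run();
        } finally {
          BeamFnLoggingMDC.setInstructionId(null);
        }
      }
    }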
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
index f77798a..dd9e412 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
@@ -50,13 +50,14 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.junit.rules.TestRule;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for {@link BeamFnLoggingClient}. */
 @RunWith(JUnit4.class)
 public class BeamFnLoggingClientTest {
-
+  @Rule public TestRule restoreLogging = new RestoreBeamFnLoggingMDC();
   private static final LogRecord FILTERED_RECORD;
   private static final LogRecord TEST_RECORD;
   private static final LogRecord TEST_RECORD_WITH_EXCEPTION;
@@ -78,6 +79,7 @@
 
   private static final BeamFnApi.LogEntry TEST_ENTRY =
       BeamFnApi.LogEntry.newBuilder()
+          .setInstructionId("instruction-1")
           .setSeverity(BeamFnApi.LogEntry.Severity.Enum.DEBUG)
           .setMessage("Message")
           .setThread("12345")
@@ -86,6 +88,7 @@
           .build();
   private static final BeamFnApi.LogEntry TEST_ENTRY_WITH_EXCEPTION =
       BeamFnApi.LogEntry.newBuilder()
+          .setInstructionId("instruction-1")
           .setSeverity(BeamFnApi.LogEntry.Severity.Enum.WARN)
           .setMessage("MessageWithException")
           .setTrace(getStackTraceAsString(TEST_RECORD_WITH_EXCEPTION.getThrown()))
@@ -97,6 +100,7 @@
 
   @Test
   public void testLogging() throws Exception {
+    BeamFnLoggingMDC.setInstructionId("instruction-1");
     AtomicBoolean clientClosedStream = new AtomicBoolean();
     Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
     AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
@@ -175,6 +179,7 @@
 
   @Test
   public void testWhenServerFailsThatClientIsAbleToCleanup() throws Exception {
+    BeamFnLoggingMDC.setInstructionId("instruction-1");
     Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
     AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
         new AtomicReference<>();
@@ -244,6 +249,7 @@
 
   @Test
   public void testWhenServerHangsUpEarlyThatClientIsAbleCleanup() throws Exception {
+    BeamFnLoggingMDC.setInstructionId("instruction-1");
     Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
     AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
         new AtomicReference<>();
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java
new file mode 100644
index 0000000..0036862
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import org.junit.rules.ExternalResource;
+
+/** Saves, clears and restores the current thread-local logging parameters for tests. */
+public class RestoreBeamFnLoggingMDC extends ExternalResource {
+  private String previousInstruction;
+
+  public RestoreBeamFnLoggingMDC() {}
+
+  @Override
+  protected void before() throws Throwable {
+    previousInstruction = BeamFnLoggingMDC.getInstructionId();
+    BeamFnLoggingMDC.setInstructionId(null);
+  }
+
+  @Override
+  protected void after() {
+    BeamFnLoggingMDC.setInstructionId(previousInstruction);
+  }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java
new file mode 100644
index 0000000..47166b0
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.junit.runners.model.Statement;
+
+/** Tests for {@link RestoreBeamFnLoggingMDC}. */
+@RunWith(JUnit4.class)
+public class RestoreBeamFnLoggingMDCTest {
+  @Rule public TestRule restoreMDCAfterTest = new RestoreBeamFnLoggingMDC();
+
+  @Test
+  public void testOldValuesAreRestored() throws Throwable {
+    // We need our own instance here so that we don't squash any saved values.
+    TestRule restoreMDC = new RestoreBeamFnLoggingMDC();
+
+    final boolean[] evaluateRan = new boolean[1];
+    BeamFnLoggingMDC.setInstructionId("oldInstruction");
+
+    restoreMDC
+        .apply(
+            new Statement() {
+              @Override
+              public void evaluate() {
+                evaluateRan[0] = true;
+                // Ensure parameters are cleared before the test runs
+                assertNull("null Instruction", BeamFnLoggingMDC.getInstructionId());
+
+                // Simulate updating parameters for the test
+                BeamFnLoggingMDC.setInstructionId("newInstruction");
+
+                // Ensure that the values changed
+                assertEquals("newInstruction", BeamFnLoggingMDC.getInstructionId());
+              }
+            },
+            Description.EMPTY)
+        .evaluate();
+
+    // Validate that the statement ran and that the values were reverted
+    assertTrue(evaluateRan[0]);
+    assertEquals("oldInstruction", BeamFnLoggingMDC.getInstructionId());
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
index 0e87243..15ea5c0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
@@ -737,7 +737,7 @@
   private WriteResult writeResult(Pipeline p) {
     PCollection<TableRow> empty =
         p.apply("CreateEmptyFailedInserts", Create.empty(TypeDescriptor.of(TableRow.class)));
-    return WriteResult.in(p, new TupleTag<>("failedInserts"), empty);
+    return WriteResult.in(p, new TupleTag<>("failedInserts"), empty, null);
   }
 
   @Override
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
index 120d76f..484fe3b 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
@@ -70,8 +70,9 @@
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
 class BatchedStreamingWrite<ErrorT, ElementT>
-    extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+    extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
   private static final TupleTag<Void> mainOutputTag = new TupleTag<>("mainOutput");
+  static final TupleTag<TableRow> SUCCESSFUL_ROWS_TAG = new TupleTag<>("successfulRows");
   private static final Logger LOG = LoggerFactory.getLogger(BatchedStreamingWrite.class);
 
   private final BigQueryServices bqServices;
@@ -191,23 +192,24 @@
   }
 
   @Override
-  public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+  public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
     return batchViaStateful
         ? input.apply(new ViaStateful())
         : input.apply(new ViaBundleFinalization());
   }
 
   private class ViaBundleFinalization
-      extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+      extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
     @Override
-    public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+    public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
       PCollectionTuple result =
           input.apply(
               ParDo.of(new BatchAndInsertElements())
-                  .withOutputTags(mainOutputTag, TupleTagList.of(failedOutputTag)));
-      PCollection<ErrorT> failedInserts = result.get(failedOutputTag);
-      failedInserts.setCoder(failedOutputCoder);
-      return failedInserts;
+                  .withOutputTags(
+                      mainOutputTag, TupleTagList.of(failedOutputTag).and(SUCCESSFUL_ROWS_TAG)));
+      result.get(failedOutputTag).setCoder(failedOutputCoder);
+      result.get(SUCCESSFUL_ROWS_TAG).setCoder(TableRowJsonCoder.of());
+      return result;
     }
   }
 
@@ -249,6 +251,7 @@
     @FinishBundle
     public void finishBundle(FinishBundleContext context) throws Exception {
       List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
+      List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
       BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class);
       for (Map.Entry<String, List<FailsafeValueInSingleWindow<TableRow, TableRow>>> entry :
           tableRows.entrySet()) {
@@ -258,7 +261,8 @@
             entry.getValue(),
             uniqueIdsForTableRows.get(entry.getKey()),
             options,
-            failedInserts);
+            failedInserts,
+            successfulInserts);
       }
       tableRows.clear();
       uniqueIdsForTableRows.clear();
@@ -274,9 +278,9 @@
   private static final Duration BATCH_MAX_BUFFERING_DURATION = Duration.millis(200);
 
   private class ViaStateful
-      extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+      extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
     @Override
-    public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+    public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
       BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class);
       Duration maxBufferingDuration =
           options.getMaxBufferingDurationMilliSec() > 0
@@ -306,10 +310,12 @@
                       ShardedKey.Coder.of(StringUtf8Coder.of()), IterableCoder.of(valueCoder)))
               .apply(
                   ParDo.of(new InsertBatchedElements())
-                      .withOutputTags(mainOutputTag, TupleTagList.of(failedOutputTag)));
-      PCollection<ErrorT> failedInserts = result.get(failedOutputTag);
-      failedInserts.setCoder(failedOutputCoder);
-      return failedInserts;
+                      .withOutputTags(
+                          mainOutputTag,
+                          TupleTagList.of(failedOutputTag).and(SUCCESSFUL_ROWS_TAG)));
+      result.get(failedOutputTag).setCoder(failedOutputCoder);
+      result.get(SUCCESSFUL_ROWS_TAG).setCoder(TableRowJsonCoder.of());
+      return result;
     }
   }
 
@@ -340,11 +346,15 @@
       BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class);
       TableReference tableReference = BigQueryHelpers.parseTableSpec(input.getKey().getKey());
       List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
-      flushRows(tableReference, tableRows, uniqueIds, options, failedInserts);
+      List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
+      flushRows(tableReference, tableRows, uniqueIds, options, failedInserts, successfulInserts);
 
       for (ValueInSingleWindow<ErrorT> row : failedInserts) {
         out.get(failedOutputTag).output(row.getValue());
       }
+      for (ValueInSingleWindow<TableRow> row : successfulInserts) {
+        out.get(SUCCESSFUL_ROWS_TAG).output(row.getValue());
+      }
       reportStreamingApiLogging(options);
     }
   }
@@ -367,7 +377,8 @@
       List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows,
       List<String> uniqueIds,
       BigQueryOptions options,
-      List<ValueInSingleWindow<ErrorT>> failedInserts)
+      List<ValueInSingleWindow<ErrorT>> failedInserts,
+      List<ValueInSingleWindow<TableRow>> successfulInserts)
       throws InterruptedException {
     if (!tableRows.isEmpty()) {
       try {
@@ -382,7 +393,8 @@
                     errorContainer,
                     skipInvalidRows,
                     ignoreUnknownValues,
-                    ignoreInsertIds);
+                    ignoreInsertIds,
+                    successfulInserts);
         byteCounter.inc(totalBytes);
       } catch (IOException e) {
         throw new RuntimeException(e);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
index c0b945c..f7b95e7 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
@@ -166,7 +166,8 @@
         ErrorContainer<T> errorContainer,
         boolean skipInvalidRows,
         boolean ignoreUnknownValues,
-        boolean ignoreInsertIds)
+        boolean ignoreInsertIds,
+        List<ValueInSingleWindow<TableRow>> successfulRows)
         throws IOException, InterruptedException;
 
     /** Patch BigQuery {@link Table} description. */
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
index 5e9cdb7..d788bfe 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
@@ -86,9 +86,11 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -814,7 +816,8 @@
         ErrorContainer<T> errorContainer,
         boolean skipInvalidRows,
         boolean ignoreUnkownValues,
-        boolean ignoreInsertIds)
+        boolean ignoreInsertIds,
+        List<ValueInSingleWindow<TableRow>> successfulRows)
         throws IOException, InterruptedException {
       checkNotNull(ref, "ref");
       if (executor == null) {
@@ -829,6 +832,7 @@
                 + "as many elements as rowList");
       }
 
+      final Set<Integer> failedIndices = new HashSet<>();
       long retTotalDataSize = 0;
       List<TableDataInsertAllResponse.InsertErrors> allErrors = new ArrayList<>();
       // These lists contain the rows to publish. Initially they contain the entire list.
@@ -981,6 +985,7 @@
                 throw new IOException("Insert failed: " + error + ", other errors: " + allErrors);
               }
               int errorIndex = error.getIndex().intValue() + strideIndices.get(i);
+              failedIndices.add(errorIndex);
               if (retryPolicy.shouldRetry(new InsertRetryPolicy.Context(error))) {
                 allErrors.add(error);
                 retryRows.add(rowsToPublish.get(errorIndex));
@@ -1022,6 +1027,18 @@
         allErrors.clear();
         LOG.info("Retrying {} failed inserts to BigQuery", rowsToPublish.size());
       }
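+      // Rows whose index never appeared in an error response were accepted by
+      // BigQuery; report them with their original timestamp, window, and pane.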
+      if (successfulRows != null) {
+        for (int i = 0; i < rowsToPublish.size(); i++) {
+          if (!failedIndices.contains(i)) {
+            successfulRows.add(
+                ValueInSingleWindow.of(
+                    rowsToPublish.get(i).getValue(),
+                    rowsToPublish.get(i).getTimestamp(),
+                    rowsToPublish.get(i).getWindow(),
+                    rowsToPublish.get(i).getPane()));
+          }
+        }
+      }
       if (!allErrors.isEmpty()) {
         throw new IOException("Insert failed: " + allErrors);
       } else {
@@ -1039,7 +1056,8 @@
         ErrorContainer<T> errorContainer,
         boolean skipInvalidRows,
         boolean ignoreUnknownValues,
-        boolean ignoreInsertIds)
+        boolean ignoreInsertIds,
+        List<ValueInSingleWindow<TableRow>> successfulRows)
         throws IOException, InterruptedException {
       return insertAll(
           ref,
@@ -1053,7 +1071,8 @@
           errorContainer,
           skipInvalidRows,
           ignoreUnknownValues,
-          ignoreInsertIds);
+          ignoreInsertIds,
+          successfulRows);
     }
 
     protected GoogleJsonError.ErrorInfo getErrorInfo(IOException e) {
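
The new trailing argument is an out-parameter. A hedged caller sketch, with
identifiers (datasetService, ref, rows, ids) mirroring the updated test call
sites; pass null to opt out of collecting successful rows:

  List<ValueInSingleWindow<TableRow>> successfulRows = new ArrayList<>();
  datasetService.insertAll(
      ref,
      rows,
      ids,
      InsertRetryPolicy.alwaysRetry(),
      /* failedInserts= */ null,
      /* errorContainer= */ null,
      /* skipInvalidRows= */ false,
      /* ignoreUnknownValues= */ false,
      /* ignoreInsertIds= */ false,
      successfulRows);
  // On return, successfulRows holds each accepted TableRow together with its
  // original timestamp, window, and pane.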
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
index 602ebea..3c27ddc 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
@@ -142,6 +142,6 @@
     // large.
     PCollection<TableRow> empty =
         p.apply("CreateEmptyFailedInserts", Create.empty(TypeDescriptor.of(TableRow.class)));
-    return WriteResult.in(p, new TupleTag<>("failedInserts"), empty);
+    return WriteResult.in(p, new TupleTag<>("failedInserts"), empty, null);
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
index 9bfe3d2..5afd1ae 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
@@ -34,6 +34,7 @@
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.ShardedKey;
 import org.apache.beam.sdk.values.TupleTag;
 
@@ -245,26 +246,36 @@
   public WriteResult expand(PCollection<KV<TableDestination, ElementT>> input) {
     if (extendedErrorInfo) {
       TupleTag<BigQueryInsertError> failedInsertsTag = new TupleTag<>(FAILED_INSERTS_TAG_ID);
-      PCollection<BigQueryInsertError> failedInserts =
+      PCollectionTuple result =
           writeAndGetErrors(
               input,
               failedInsertsTag,
               BigQueryInsertErrorCoder.of(),
               ErrorContainer.BIG_QUERY_INSERT_ERROR_ERROR_CONTAINER);
-      return WriteResult.withExtendedErrors(input.getPipeline(), failedInsertsTag, failedInserts);
+      PCollection<BigQueryInsertError> failedInserts = result.get(failedInsertsTag);
+      return WriteResult.withExtendedErrors(
+          input.getPipeline(),
+          failedInsertsTag,
+          failedInserts,
+          result.get(BatchedStreamingWrite.SUCCESSFUL_ROWS_TAG));
     } else {
       TupleTag<TableRow> failedInsertsTag = new TupleTag<>(FAILED_INSERTS_TAG_ID);
-      PCollection<TableRow> failedInserts =
+      PCollectionTuple result =
           writeAndGetErrors(
               input,
               failedInsertsTag,
               TableRowJsonCoder.of(),
               ErrorContainer.TABLE_ROW_ERROR_CONTAINER);
-      return WriteResult.in(input.getPipeline(), failedInsertsTag, failedInserts);
+      PCollection<TableRow> failedInserts = result.get(failedInsertsTag);
+      return WriteResult.in(
+          input.getPipeline(),
+          failedInsertsTag,
+          failedInserts,
+          result.get(BatchedStreamingWrite.SUCCESSFUL_ROWS_TAG));
     }
   }
 
-  private <T> PCollection<T> writeAndGetErrors(
+  private <T> PCollectionTuple writeAndGetErrors(
       PCollection<KV<TableDestination, ElementT>> input,
       TupleTag<T> failedInsertsTag,
       AtomicCoder<T> coder,
@@ -336,7 +347,8 @@
 
       return shardedTagged
           .apply(Reshuffle.of())
-          // Put in the global window to ensure that DynamicDestinations side inputs are accessed
+          // Put in the global window to ensure that DynamicDestinations side inputs are
+          // accessed
           // correctly.
           .apply(
               "GlobalWindow",
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
index fd75f671..85ef275 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
@@ -29,6 +29,7 @@
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** The result of a {@link BigQueryIO.Write} transform. */
 @SuppressWarnings({
@@ -40,18 +41,25 @@
   private final PCollection<TableRow> failedInserts;
   private final TupleTag<BigQueryInsertError> failedInsertsWithErrTag;
   private final PCollection<BigQueryInsertError> failedInsertsWithErr;
+  private final PCollection<TableRow> successfulInserts;
 
   /** Creates a {@link WriteResult} in the given {@link Pipeline}. */
   static WriteResult in(
-      Pipeline pipeline, TupleTag<TableRow> failedInsertsTag, PCollection<TableRow> failedInserts) {
-    return new WriteResult(pipeline, failedInsertsTag, failedInserts, null, null);
+      Pipeline pipeline,
+      TupleTag<TableRow> failedInsertsTag,
+      PCollection<TableRow> failedInserts,
+      @Nullable PCollection<TableRow> successfulInserts) {
+    return new WriteResult(
+        pipeline, failedInsertsTag, failedInserts, null, null, successfulInserts);
   }
 
   static WriteResult withExtendedErrors(
       Pipeline pipeline,
       TupleTag<BigQueryInsertError> failedInsertsTag,
-      PCollection<BigQueryInsertError> failedInserts) {
-    return new WriteResult(pipeline, null, null, failedInsertsTag, failedInserts);
+      PCollection<BigQueryInsertError> failedInserts,
+      PCollection<TableRow> successfulInserts) {
+    return new WriteResult(
+        pipeline, null, null, failedInsertsTag, failedInserts, successfulInserts);
   }
 
   @Override
@@ -68,12 +76,14 @@
       TupleTag<TableRow> failedInsertsTag,
       PCollection<TableRow> failedInserts,
       TupleTag<BigQueryInsertError> failedInsertsWithErrTag,
-      PCollection<BigQueryInsertError> failedInsertsWithErr) {
+      PCollection<BigQueryInsertError> failedInsertsWithErr,
+      PCollection<TableRow> successfulInserts) {
     this.pipeline = pipeline;
     this.failedInsertsTag = failedInsertsTag;
     this.failedInserts = failedInserts;
     this.failedInsertsWithErrTag = failedInsertsWithErrTag;
     this.failedInsertsWithErr = failedInsertsWithErr;
+    this.successfulInserts = successfulInserts;
   }
 
   /**
@@ -91,6 +101,14 @@
     return failedInserts;
   }
 
+  /** Returns a {@link PCollection} containing the {@link TableRow}s that were written to BQ. */
+  public PCollection<TableRow> getSuccessfulInserts() {
+    checkArgument(
+        successfulInserts != null,
+        "Retrieving successful inserts is only supported for streaming inserts.");
+    return successfulInserts;
+  }
+
   /**
    * Returns a {@link PCollection} containing the {@link BigQueryInsertError}s with detailed error
    * information.
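
A hedged end-to-end sketch of the new accessor (here `rows` is assumed to be a
PCollection<TableRow> and `schema` a TableSchema; the table name is
illustrative). Only STREAMING_INSERTS populates the collection; otherwise the
checkArgument above throws IllegalArgumentException:

  WriteResult result =
      rows.apply(
          BigQueryIO.writeTableRows()
              .to("project-id:dataset-id.table-id")
              .withMethod(BigQueryIO.Write.Method.STREAMING_INSERTS)
              .withSchema(schema));
  PCollection<TableRow> successful = result.getSuccessfulInserts();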
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
index 6bfd2b8..1ee6501 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
@@ -323,7 +323,8 @@
         null,
         false,
         false,
-        false);
+        false,
+        null);
   }
 
   @Override
@@ -336,7 +337,8 @@
       ErrorContainer<T> errorContainer,
       boolean skipInvalidRows,
       boolean ignoreUnknownValues,
-      boolean ignoreInsertIds)
+      boolean ignoreInsertIds,
+      List<ValueInSingleWindow<TableRow>> successfulRows)
       throws IOException, InterruptedException {
     Map<TableRow, List<TableDataInsertAllResponse.InsertErrors>> insertErrors = getInsertErrors();
     synchronized (tables) {
@@ -371,6 +373,14 @@
           } else {
             dataSize += tableContainer.addRow(row, insertIdList.get(i));
           }
+          if (successfulRows != null) {
+            successfulRows.add(
+                ValueInSingleWindow.of(
+                    row,
+                    rowList.get(i).getTimestamp(),
+                    rowList.get(i).getWindow(),
+                    rowList.get(i).getPane()));
+          }
         } else {
           errorContainer.add(
               failedInserts, allErrors.get(allErrors.size() - 1), ref, rowList.get(i));
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
index e4c32f5..a891c48 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
@@ -173,7 +173,6 @@
         FileSystems.copy(sourceFiles.build(), loadFiles.build());
         filesForLoadJobs.put(jobRef.getProjectId(), jobRef.getJobId(), loadFiles.build());
       }
-
       allJobs.put(jobRef.getProjectId(), jobRef.getJobId(), new JobInfo(job));
     }
   }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 946e02f..bde803d 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -76,6 +76,7 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method;
 import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption;
 import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices;
@@ -789,13 +790,13 @@
             row1, ImmutableList.of(ephemeralError, ephemeralError),
             row2, ImmutableList.of(ephemeralError, ephemeralError, persistentError)));
 
-    PCollection<TableRow> failedRows =
+    WriteResult result =
         p.apply(Create.of(row1, row2, row3))
             .apply(
                 BigQueryIO.writeTableRows()
                     .to("project-id:dataset-id.table-id")
-                    .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
-                    .withMethod(BigQueryIO.Write.Method.STREAMING_INSERTS)
+                    .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED)
+                    .withMethod(Method.STREAMING_INSERTS)
                     .withSchema(
                         new TableSchema()
                             .setFields(
@@ -804,11 +805,15 @@
                                     new TableFieldSchema().setName("number").setType("INTEGER"))))
                     .withFailedInsertRetryPolicy(InsertRetryPolicy.retryTransientErrors())
                     .withTestServices(fakeBqServices)
-                    .withoutValidation())
-            .getFailedInserts();
+                    .withoutValidation());
+
+    PCollection<TableRow> failedRows = result.getFailedInserts();
     // row2 finally fails with a non-retryable error, so we expect to see it in the collection of
     // failed rows.
     PAssert.that(failedRows).containsInAnyOrder(row2);
+    if (useStreaming && !useStorageApi) {
+      PAssert.that(result.getSuccessfulInserts()).containsInAnyOrder(row1, row3);
+    }
     p.run();
 
     // Only row1 and row3 were successfully inserted.
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
index ae788d7..61b94f6 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
@@ -655,7 +655,8 @@
         null,
         false,
         false,
-        false);
+        false,
+        null);
 
     verifyAllResponsesAreRead();
     expectedLogs.verifyInfo("BigQuery insertAll error, retrying:");
@@ -699,7 +700,8 @@
         null,
         false,
         false,
-        false);
+        false,
+        null);
 
     verifyAllResponsesAreRead();
     expectedLogs.verifyInfo("BigQuery insertAll error, retrying:");
@@ -755,7 +757,8 @@
         null,
         false,
         false,
-        false);
+        false,
+        null);
 
     verifyAllResponsesAreRead();
 
@@ -813,7 +816,8 @@
         null,
         false,
         false,
-        false);
+        false,
+        null);
 
     verifyAllResponsesAreRead();
 
@@ -880,7 +884,8 @@
           null,
           false,
           false,
-          false);
+          false,
+          null);
       fail();
     } catch (IOException e) {
       assertThat(e, instanceOf(IOException.class));
@@ -940,7 +945,8 @@
           null,
           false,
           false,
-          false);
+          false,
+          null);
     } finally {
       verify(responses[0], atLeastOnce()).getStatusCode();
       verify(responses[0]).getContent();
@@ -1023,7 +1029,8 @@
         ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
         false,
         false,
-        false);
+        false,
+        null);
     assertEquals(1, failedInserts.size());
     expectedLogs.verifyInfo("Retrying 1 failed inserts to BigQuery");
 
@@ -1069,7 +1076,8 @@
         ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
         false,
         false,
-        false);
+        false,
+        null);
 
     TableDataInsertAllRequest parsedRequest =
         fromString(request.getContentAsString(), TableDataInsertAllRequest.class);
@@ -1090,7 +1098,8 @@
         ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
         true,
         true,
-        true);
+        true,
+        null);
 
     parsedRequest = fromString(request.getContentAsString(), TableDataInsertAllRequest.class);
 
@@ -1325,7 +1334,8 @@
         ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
         false,
         false,
-        false);
+        false,
+        null);
 
     assertThat(failedInserts, is(expected));
   }
@@ -1384,7 +1394,8 @@
         ErrorContainer.BIG_QUERY_INSERT_ERROR_ERROR_CONTAINER,
         false,
         false,
-        false);
+        false,
+        null);
 
     assertThat(failedInserts, is(expected));
   }
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
index 5626812..1e64e22 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
@@ -213,7 +213,16 @@
     try {
       totalBytes =
           datasetService.insertAll(
-              ref, rows, ids, InsertRetryPolicy.alwaysRetry(), null, null, false, false, false);
+              ref,
+              rows,
+              ids,
+              InsertRetryPolicy.alwaysRetry(),
+              null,
+              null,
+              false,
+              false,
+              false,
+              null);
     } finally {
       verifyInsertAll(5);
       // Each of the 25 rows has 1 byte for length and 30 bytes: '{"f":[{"v":"foo"},{"v":1234}]}'
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 0cf0208..618ee55 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -1483,3 +1483,31 @@
     estimated_size += (
         self._key_coder_impl.estimate_size(value.key, nested=True))
     return estimated_size
+
+
+class TimestampPrefixingWindowCoderImpl(StreamCoderImpl):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  A coder for custom window types, which prefixes the window's max_timestamp
+  to the encoded original window.
+
+  The coder encodes and decodes custom window types with the following format:
+    window's max_timestamp()
+    encoded window using its own coder.
+  """
+  def __init__(self, window_coder_impl: CoderImpl) -> None:
+    self._window_coder_impl = window_coder_impl
+
+  def encode_to_stream(self, value, stream, nested):
+    TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, nested)
+    self._window_coder_impl.encode_to_stream(value, stream, nested)
+
+  def decode_from_stream(self, stream, nested):
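+    # Read and drop the prefixed max timestamp; the component window coder
+    # fully reconstructs the window below.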
+    TimestampCoderImpl().decode_from_stream(stream, nested)
+    return self._window_coder_impl.decode_from_stream(stream, nested)
+
+  def estimate_size(self, value: Any, nested: bool = False) -> int:
+    estimated_size = 0
+    estimated_size += TimestampCoderImpl().estimate_size(value)
+    estimated_size += self._window_coder_impl.estimate_size(value, nested)
+    return estimated_size
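
A hedged round-trip sketch of the new impl, driven through the public wrapper
added in coders.py below (the window values are illustrative):

  from apache_beam.coders import coders
  from apache_beam.transforms import window

  coder = coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder())
  w = window.IntervalWindow(0, 10)
  encoded = coder.encode(w)
  # The prefix is the encoded max_timestamp(); decoding skips it and
  # delegates to the component window coder.
  assert coder.decode(encoded) == w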
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 163fd17..2d2b336 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1509,3 +1509,47 @@
 
 Coder.register_structured_urn(
     common_urns.coders.SHARDED_KEY.urn, ShardedKeyCoder)
+
+
+class TimestampPrefixingWindowCoder(FastCoder):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Coder which prefixes the max timestamp of an arbitrary window to its encoded
+  form."""
+  def __init__(self, window_coder: Coder) -> None:
+    self._window_coder = window_coder
+
+  def _create_impl(self):
+    return coder_impl.TimestampPrefixingWindowCoderImpl(
+        self._window_coder.get_impl())
+
+  def to_type_hint(self):
+    return self._window_coder.to_type_hint()
+
+  def _get_component_coders(self) -> List[Coder]:
+    return [self._window_coder]
+
+  def is_deterministic(self) -> bool:
+    return self._window_coder.is_deterministic()
+
+  def as_cloud_object(self, coders_context=None):
+    return {
+        '@type': 'kind:custom_window',
+        'component_encodings': [
+            self._window_coder.as_cloud_object(coders_context)
+        ],
+    }
+
+  def __repr__(self):
+    return 'TimestampPrefixingWindowCoder[%r]' % self._window_coder
+
+  def __eq__(self, other):
+    return (
+        type(self) == type(other) and self._window_coder == other._window_coder)
+
+  def __hash__(self):
+    return hash((type(self), self._window_coder))
+
+
+Coder.register_structured_urn(
+    common_urns.coders.CUSTOM_WINDOW.urn, TimestampPrefixingWindowCoder)
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index 98e0547..11334ed 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -749,6 +749,19 @@
             coders.TupleCoder((coder, other_coder)),
             (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
 
+  def test_timestamp_prefixing_window_coder(self):
+    self.check_coder(
+        coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()),
+        *[
+            window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
+            for y in range(-100, 100)
+        ])
+    self.check_coder(
+        coders.TupleCoder((
+            coders.TimestampPrefixingWindowCoder(
+                coders.IntervalWindowCoder()), )),
+        (window.IntervalWindow(0, 10), ))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py
index 1a74dbe..454939f 100644
--- a/sdks/python/apache_beam/coders/standard_coders_test.py
+++ b/sdks/python/apache_beam/coders/standard_coders_test.py
@@ -189,7 +189,9 @@
       'beam:coder:double:v1': parse_float,
       'beam:coder:sharded_key:v1': lambda x,
       value_parser: ShardedKey(
-          key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8'))
+          key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
+      'beam:coder:custom_window:v1': lambda x,
+      window_parser: window_parser(x['window'])
   }
 
   def test_standard_coders(self):
diff --git a/sdks/python/apache_beam/runners/interactive/caching/__init__.py b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
index 97b1be9..cce3aca 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/__init__.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
@@ -14,4 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
new file mode 100644
index 0000000..5b1b9ef
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional
+
+import apache_beam as beam
+from apache_beam.dataframe import convert
+from apache_beam.dataframe import expressions
+
+
+class ExpressionCache(object):
+  """Utility class for caching deferred DataFrames expressions.
+
+  This cache is currently a light-weight wrapper around the
+  TO_PCOLLECTION_CACHE in the apache_beam.dataframe.convert module and the
+  computed_pcollections in the interactive module.
+
+  Example::
+
+    df : beam.dataframe.DeferredDataFrame = ...
+    ...
+    cache = ExpressionCache()
+    cache.replace_with_cached(df._expr)
+
+  This will automatically link the instance to the existing caches. After it is
+  created, the cache can then be used to modify an existing deferred dataframe
+  expression tree to replace nodes with computed PCollections.
+
+  This object can be created and destroyed at any time. This class holds no
+  state, and its only side-effect is modifying the given expression.
+  """
+  def __init__(self, pcollection_cache=None, computed_cache=None):
+    from apache_beam.runners.interactive import interactive_environment as ie
+
+    self._pcollection_cache = (
+        convert.TO_PCOLLECTION_CACHE
+        if pcollection_cache is None else pcollection_cache)
+    self._computed_cache = (
+        ie.current_env().computed_pcollections
+        if computed_cache is None else computed_cache)
+
+  def replace_with_cached(
+      self, expr: expressions.Expression) -> Dict[str, expressions.Expression]:
+    """Replaces any previously computed expressions with PlaceholderExpressions.
+
+    This is used to correctly read any expressions that were cached in previous
+    runs. This enables the InteractiveRunner to prune off old calculations from
+    the expression tree.
+    """
+
+    replaced_inputs: Dict[str, expressions.Expression] = {}
+    self._replace_with_cached_recur(expr, replaced_inputs)
+    return replaced_inputs
+
+  def _replace_with_cached_recur(
+      self,
+      expr: expressions.Expression,
+      replaced_inputs: Dict[str, expressions.Expression]) -> None:
+    """Recursive call for `replace_with_cached`.
+
+    Recurses through the expression tree and replaces any cached inputs with
+    `PlaceholderExpression`s.
+    """
+
+    final_inputs = []
+
+    for input in expr.args():
+      pc = self._get_cached(input)
+
+      # Only read from the cache when the PCollection has been fully computed.
+      # This is so that no partial results are used.
+      if self._is_computed(pc):
+
+        # Reuse previously seen cached expressions. This is so that the same
+        # value isn't cached multiple times.
+        if input._id in replaced_inputs:
+          cached = replaced_inputs[input._id]
+        else:
+          cached = expressions.PlaceholderExpression(
+              input.proxy(), self._pcollection_cache[input._id])
+
+          replaced_inputs[input._id] = cached
+        final_inputs.append(cached)
+      else:
+        final_inputs.append(input)
+        self._replace_with_cached_recur(input, replaced_inputs)
+    expr._args = tuple(final_inputs)
+
+  def _get_cached(self,
+                  expr: expressions.Expression) -> Optional[beam.PCollection]:
+    """Returns the PCollection associated with the expression."""
+    return self._pcollection_cache.get(expr._id, None)
+
+  def _is_computed(self, pc: beam.PCollection) -> bool:
+    """Returns True if the PCollection has been run and computed."""
+    return pc is not None and pc in self._computed_cache
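
For orientation, a hedged sketch of the expected call pattern (mirroring the
docstring example and the unit tests below; the expressions are illustrative):

  from apache_beam.dataframe import expressions
  from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache

  const = expressions.ConstantExpression(0)
  computed = expressions.ComputedExpression('test', lambda x: x, [const])

  cache = ExpressionCache()  # defaults to the global caches described above
  replaced = cache.replace_with_cached(computed)
  # Any fully computed input of `computed` is now a PlaceholderExpression
  # backed by its cached PCollection; `replaced` maps expression ids to those
  # placeholders (empty here, since nothing was cached yet).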
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
new file mode 100644
index 0000000..c6e46f3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.dataframe import expressions
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
+
+
+class ExpressionCacheTest(unittest.TestCase):
+  def setUp(self):
+    self._pcollection_cache = {}
+    self._computed_cache = set()
+    self._pipeline = beam.Pipeline()
+    self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)
+
+  def create_trace(self, expr):
+    trace = [expr]
+    for input in expr.args():
+      trace += self.create_trace(input)
+    return trace
+
+  def mock_cache(self, expr):
+    pcoll = beam.PCollection(self._pipeline)
+    self._pcollection_cache[expr._id] = pcoll
+    self._computed_cache.add(pcoll)
+
+  def assertTraceTypes(self, expr, expected):
+    actual_types = [type(e).__name__ for e in self.create_trace(expr)]
+    expected_types = [e.__name__ for e in expected]
+    self.assertListEqual(actual_types, expected_types)
+
+  def test_only_replaces_cached(self):
+    in_expr = expressions.ConstantExpression(0)
+    comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])
+
+    # Expect that no replacement of expressions is performed.
+    expected_trace = [
+        expressions.ComputedExpression, expressions.ConstantExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+
+    self.cache.replace_with_cached(comp_expr)
+
+    self.assertTraceTypes(comp_expr, expected_trace)
+
+    # Now "cache" the expression and assert that the cached expression was
+    # replaced with a placeholder.
+    self.mock_cache(in_expr)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    expected_trace = [
+        expressions.ComputedExpression, expressions.PlaceholderExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertIn(in_expr._id, replaced)
+
+  def test_only_replaces_inputs(self):
+    arg_0_expr = expressions.ConstantExpression(0)
+    ident_val = expressions.ComputedExpression(
+        'ident', lambda x: x, [arg_0_expr])
+
+    arg_1_expr = expressions.ConstantExpression(1)
+    comp_expr = expressions.ComputedExpression(
+        'add', lambda x, y: x + y, [ident_val, arg_1_expr])
+
+    self.mock_cache(ident_val)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    # Assert that ident_val was replaced and that its arguments were removed
+    # from the expression tree.
+    expected_trace = [
+        expressions.ComputedExpression,
+        expressions.PlaceholderExpression,
+        expressions.ConstantExpression
+    ]
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertIn(ident_val._id, replaced)
+    self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))
+
+  def test_only_caches_same_input(self):
+    arg_0_expr = expressions.ConstantExpression(0)
+    ident_val = expressions.ComputedExpression(
+        'ident', lambda x: x, [arg_0_expr])
+    comp_expr = expressions.ComputedExpression(
+        'add', lambda x, y: x + y, [ident_val, arg_0_expr])
+
+    self.mock_cache(arg_0_expr)
+
+    replaced = self.cache.replace_with_cached(comp_expr)
+
+    # Assert that arg_0_expr, being an input to two computations, was replaced
+    # with the same placeholder expression.
+    expected_trace = [
+        expressions.ComputedExpression,
+        expressions.ComputedExpression,
+        expressions.PlaceholderExpression,
+        expressions.PlaceholderExpression
+    ]
+    actual_trace = self.create_trace(comp_expr)
+    unique_placeholders = set(
+        t for t in actual_trace
+        if isinstance(t, expressions.PlaceholderExpression))
+    self.assertTraceTypes(comp_expr, expected_trace)
+    self.assertTrue(
+        all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
+    self.assertIn(arg_0_expr._id, replaced)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index c4d062e..0b3f8a3 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -37,12 +37,12 @@
 import pandas as pd
 
 import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
 from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.runners.interactive import interactive_environment as ie
 from apache_beam.runners.interactive.display import pipeline_graph
 from apache_beam.runners.interactive.display.pcoll_visualization import visualize
 from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
 from apache_beam.runners.interactive.utils import elements_to_df
 from apache_beam.runners.interactive.utils import progress_indicated
 from apache_beam.runners.runner import PipelineState
@@ -455,10 +455,7 @@
   element_types = {}
   for pcoll in flatten_pcolls:
     if isinstance(pcoll, DeferredBase):
-      proxy = pcoll._expr.proxy()
-      pcoll = to_pcollection(
-          pcoll, yield_elements='pandas', label=str(pcoll._expr))
-      element_type = proxy
+      pcoll, element_type = deferred_df_to_pcollection(pcoll)
       watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
     else:
       element_type = pcoll.element_type
@@ -569,11 +566,7 @@
   # collect the result in elements_to_df.
   if isinstance(pcoll, DeferredBase):
     # Get the proxy so we can get the output shape of the DataFrame.
-    # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
-    # instead of the proxy.
-    element_type = pcoll._expr.proxy()
-    pcoll = to_pcollection(
-        pcoll, yield_elements='pandas', label=str(pcoll._expr))
+    pcoll, element_type = deferred_df_to_pcollection(pcoll)
     watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
   else:
     element_type = pcoll.element_type
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index fe6989e..69551de 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -408,6 +408,81 @@
         df_expected['cube'],
         ib.collect(df['cube'], n=10).reset_index(drop=True))
 
+  @unittest.skipIf(
+      not ie.current_env().is_interactive_ready,
+      '[interactive] dependency is not installed.')
+  @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
+  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
+  def test_dataframe_caching(self, cell):
+
+    # Create a pipeline that exercises the DataFrame API. This will also use
+    # caching in the background.
+    with cell:  # Cell 1
+      p = beam.Pipeline(interactive_runner.InteractiveRunner())
+      ib.watch({'p': p})
+
+    with cell:  # Cell 2
+      data = p | beam.Create([
+          1, 2, 3
+      ]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
+
+      with beam.dataframe.allow_non_parallel_operations():
+        df = to_dataframe(data).reset_index(drop=True)
+
+      ib.collect(df)
+
+    with cell:  # Cell 3
+      df['output'] = df['square'] * df['cube']
+      ib.collect(df)
+
+    with cell:  # Cell 4
+      df['output'] = 0
+      ib.collect(df)
+
+    # We use a trace through the graph to perform an isomorphism test. The end
+    # output should look like a linear graph. This indicates that the dataframe
+    # transform was correctly broken into separate pieces to cache. If caching
+    # isn't enabled, all the dataframe computation nodes are connected to a
+    # single shared node.
+    trace = []
+
+    # Only look at the top-level transforms for the isomorphism. The test
+    # doesn't care about the transform implementations, just the overall shape.
+    class TopLevelTracer(beam.pipeline.PipelineVisitor):
+      def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
+        if node is None or not node.full_label:
+          return None
+
+        parent = self._find_root_producer(node.parent)
+        if parent is None:
+          return node
+
+        return parent
+
+      def _add_to_trace(self, node, trace):
+        if '/' not in str(node):
+          if node.inputs:
+            producer = self._find_root_producer(node.inputs[0].producer)
+            producer_name = producer.full_label if producer else ''
+            trace.append((producer_name, node.full_label))
+
+      def visit_transform(self, node: beam.pipeline.AppliedPTransform):
+        self._add_to_trace(node, trace)
+
+      def enter_composite_transform(
+          self, node: beam.pipeline.AppliedPTransform):
+        self._add_to_trace(node, trace)
+
+    p.visit(TopLevelTracer())
+
+    # Do the isomorphism test: a topological sort of the graph should yield a
+    # linear chain, with each transform consuming the previous one's output.
+    trace_string = '\n'.join(str(t) for t in trace)
+    prev_producer = ''
+    for producer, consumer in trace:
+      self.assertEqual(producer, prev_producer, trace_string)
+      prev_producer = consumer
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 448f76a..c51a648 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -23,7 +23,6 @@
 import pandas as pd
 
 import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
 from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive import background_caching_job as bcj
@@ -310,7 +309,7 @@
     # TODO(BEAM-12388): investigate the mixing pcollections in multiple
     # pipelines error when using the default label.
     for df in watched_dataframes:
-      pcoll = to_pcollection(df, yield_elements='pandas', label=str(df._expr))
+      pcoll, _ = utils.deferred_df_to_pcollection(df)
       watched_pcollections.add(pcoll)
     for pcoll in pcolls:
       if pcoll not in watched_pcollections:
diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
index 96ed442..44cfc70 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
+++ b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
Binary files differ
diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
index e2e1ad4..aa98c62 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
+++ b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
Binary files differ
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index bbe88ff..3e85145 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -25,7 +25,10 @@
 
 import pandas as pd
 
+from apache_beam.dataframe.convert import to_pcollection
+from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
 from apache_beam.testing.test_stream import WindowedValueHolder
 from apache_beam.typehints.schemas import named_fields_from_element_type
 
@@ -267,3 +270,17 @@
       return str(return_value)
 
   return return_as_json
+
+
+def deferred_df_to_pcollection(df):
+  assert isinstance(df, DeferredBase), '{} is not a DeferredBase'.format(df)
+
+  # The proxy is used to output a DataFrame with the correct columns.
+  #
+  # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
+  # instead of the proxy.
+  cache = ExpressionCache()
+  cache.replace_with_cached(df._expr)
+
+  proxy = df._expr.proxy()
+  return to_pcollection(df, yield_elements='pandas', label=str(df._expr)), proxy
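
A hedged usage sketch of the new helper (assumes `df` is a DeferredBase, e.g.
produced by to_dataframe, as at the call sites updated above):

  from apache_beam.runners.interactive import utils

  pcoll, element_type = utils.deferred_df_to_pcollection(df)
  # `pcoll` yields pandas objects; `element_type` is the proxy carrying the
  # output schema until BEAM-11064 replaces it with real type hints.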
diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py
index 4088d3e..5c92eaf 100644
--- a/sdks/python/apache_beam/transforms/sideinputs.py
+++ b/sdks/python/apache_beam/transforms/sideinputs.py
@@ -55,6 +55,9 @@
   if target_window_fn == window.GlobalWindows():
     return _global_window_mapping_fn
 
+  if isinstance(target_window_fn, window.Sessions):
+    raise RuntimeError("Sessions is not allowed in side inputs")
+
   def map_via_end(source_window):
     # type: (window.BoundedWindow) -> window.BoundedWindow
     return list(
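
A hedged sketch of the new guard's effect, exercised through the enclosing
default_window_mapping_fn (the gap and window sizes are illustrative):

  from apache_beam.transforms import window
  from apache_beam.transforms.sideinputs import default_window_mapping_fn

  default_window_mapping_fn(window.FixedWindows(60))  # fine: maps via window end
  default_window_mapping_fn(window.Sessions(10))  # raises RuntimeError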
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 226252c..e5f7c24 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-"""Support for Dataflow triggers.
+"""Support for Apache Beam triggers.
 
 Triggers control when, in processing time, windows get emitted.
 """