[BEAM-12596] Ensure that SDFs always report a size >= 0.
diff --git a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
index db7021d..ef37772 100644
--- a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
+++ b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
@@ -409,6 +409,7 @@
"\x01\x00\x00\x00\x00\x02\x03foo\x01\xa9F\x03bar\x01\xff\xff\xff\xff\xff\xff\xff\xff\x7f": {f_map: {"foo": 9001, "bar": 9223372036854775807}}
"\x01\x00\x00\x00\x00\x04\neverything\x00\x02is\x00\x05null!\x00\r\xc2\xaf\\_(\xe3\x83\x84)_/\xc2\xaf\x00": {f_map: {"everything": null, "is": null, "null!": null, "¯\\_(ツ)_/¯": null}}
+---
coder:
urn: "beam:coder:row:v1"
@@ -440,3 +441,25 @@
shardId: "",
key: "key"
}
+
+---
+
+# Java code snippet to generate example bytes:
+# TimestampPrefixingWindowCoder<IntervalWindow> coder = TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of());
+# Instant end = new Instant(-9223372036854410L);
+# Duration span = Duration.millis(365L);
+# IntervalWindow window = new IntervalWindow(end.minus(span), span);
+# byte[] bytes = CoderUtils.encodeToByteArray(coder, window);
+# String str = new String(bytes, java.nio.charset.StandardCharsets.ISO_8859_1);
+# String example = "";
+# for(int i = 0; i < str.length(); i++){
+# example += CharUtils.unicodeEscaped(str.charAt(i));
+# }
+# System.out.println(example);
+coder:
+ urn: "beam:coder:custom_window:v1"
+ components: [{urn: "beam:coder:interval_window:v1"}]
+
+examples:
+ "\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0067\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0068\u0080\u00dd\u00db\u0001" : {window: {end: 1454293425000, span: 3600000}}
+ "\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0075\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0076\u00ed\u0002" : {window: {end: -9223372036854410, span: 365}}
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index b5ad193..a367321 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -939,6 +939,31 @@
// Experimental.
STATE_BACKED_ITERABLE = 9 [(beam_urn) = "beam:coder:state_backed_iterable:v1"];
+
+ // Encodes an arbitrary user defined window and its max timestamp (inclusive).
+ // The encoding format is:
+ // maxTimestamp window
+ //
+ // maxTimestamp - A big endian 8 byte integer representing millis-since-epoch.
+ // The encoded representation is shifted so that the byte representations
+ // of negative values are lexicographically ordered before those of
+ // positive values. This is typically done by subtracting
+ // -9223372036854775808 from the value and encoding the result as an
+ // unsigned big endian integer. Example values:
+ //
+ // -9223372036854775808: 00 00 00 00 00 00 00 00
+ // -255: 7F FF FF FF FF FF FF 01
+ // -1: 7F FF FF FF FF FF FF FF
+ // 0: 80 00 00 00 00 00 00 00
+ // 1: 80 00 00 00 00 00 00 01
+ // 256: 80 00 00 00 00 00 01 00
+ // 9223372036854775807: FF FF FF FF FF FF FF FF
+ //
+ // window - the window is encoded using the supplied window coder.
+ //
+ // Components: Coder for the custom window type.
+ CUSTOM_WINDOW = 16 [(beam_urn) = "beam:coder:custom_window:v1"];
+
// Additional Standard Coders
// --------------------------
// The following coders are not required to be implemented for an SDK or
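The shift described in the proto comment is effectively a sign-bit flip: subtracting -9223372036854775808 under two's-complement wraparound is the same as XOR-ing with Long.MIN_VALUE. A minimal standalone sketch (class name hypothetical, not part of this change) that reproduces the example table:

```java
import java.nio.ByteBuffer;

// Hypothetical demo of the maxTimestamp shift from the proto comment above.
public class ShiftedTimestampDemo {
  // Flipping the sign bit makes the big-endian bytes of the result compare
  // lexicographically in the same order as the original signed values.
  static byte[] encodeShifted(long millis) {
    return ByteBuffer.allocate(8).putLong(millis ^ Long.MIN_VALUE).array();
  }

  public static void main(String[] args) {
    long[] samples = {Long.MIN_VALUE, -255L, -1L, 0L, 1L, 256L, Long.MAX_VALUE};
    for (long v : samples) {
      StringBuilder hex = new StringBuilder();
      for (byte b : encodeShifted(v)) {
        hex.append(String.format("%02X ", b));
      }
      // Prints e.g. "0: 80 00 00 00 00 00 00 00", matching the table above.
      System.out.printf("%d: %s%n", v, hex.toString().trim());
    }
  }
}
```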
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
index 38763e2..992a9eb 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
@@ -28,6 +28,7 @@
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -189,6 +190,20 @@
};
}
+ static CoderTranslator<TimestampPrefixingWindowCoder<?>> timestampPrefixingWindow() {
+ return new SimpleStructuredCoderTranslator<TimestampPrefixingWindowCoder<?>>() {
+ @Override
+ protected TimestampPrefixingWindowCoder<?> fromComponents(List<Coder<?>> components) {
+ return TimestampPrefixingWindowCoder.of((Coder<? extends BoundedWindow>) components.get(0));
+ }
+
+ @Override
+ public List<? extends Coder<?>> getComponents(TimestampPrefixingWindowCoder<?> from) {
+ return from.getComponents();
+ }
+ };
+ }
+
public abstract static class SimpleStructuredCoderTranslator<T extends Coder<?>>
implements CoderTranslator<T> {
@Override
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
index 44e2caa..1fc8379 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
@@ -32,6 +32,7 @@
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
@@ -74,6 +75,7 @@
.put(DoubleCoder.class, ModelCoders.DOUBLE_CODER_URN)
.put(RowCoder.class, ModelCoders.ROW_CODER_URN)
.put(ShardedKey.Coder.class, ModelCoders.SHARDED_KEY_CODER_URN)
+ .put(TimestampPrefixingWindowCoder.class, ModelCoders.CUSTOM_WINDOW_CODER_URN)
.build();
public static final Set<String> WELL_KNOWN_CODER_URNS = BEAM_MODEL_CODER_URNS.values();
@@ -96,6 +98,7 @@
.put(DoubleCoder.class, CoderTranslators.atomic(DoubleCoder.class))
.put(RowCoder.class, CoderTranslators.row())
.put(ShardedKey.Coder.class, CoderTranslators.shardedKey())
+ .put(TimestampPrefixingWindowCoder.class, CoderTranslators.timestampPrefixingWindow())
.build();
static {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
index 86e1e7d..0ff70f1 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
@@ -54,6 +54,8 @@
public static final String INTERVAL_WINDOW_CODER_URN =
getUrn(StandardCoders.Enum.INTERVAL_WINDOW);
+ public static final String CUSTOM_WINDOW_CODER_URN = getUrn(StandardCoders.Enum.CUSTOM_WINDOW);
+
public static final String WINDOWED_VALUE_CODER_URN = getUrn(StandardCoders.Enum.WINDOWED_VALUE);
public static final String PARAM_WINDOWED_VALUE_CODER_URN =
getUrn(StandardCoders.Enum.PARAM_WINDOWED_VALUE);
@@ -82,6 +84,7 @@
LENGTH_PREFIX_CODER_URN,
GLOBAL_WINDOW_CODER_URN,
INTERVAL_WINDOW_CODER_URN,
+ CUSTOM_WINDOW_CODER_URN,
WINDOWED_VALUE_CODER_URN,
DOUBLE_CODER_URN,
ROW_CODER_URN,
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
index 8a1af22..11543f3 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
@@ -45,6 +45,7 @@
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
@@ -95,6 +96,7 @@
Field.of("map", FieldType.map(FieldType.STRING, FieldType.INT32)),
Field.of("bar", FieldType.logicalType(FixedBytes.of(123))))))
.add(ShardedKey.Coder.of(StringUtf8Coder.of()))
+ .add(TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of()))
.build();
/**
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
index 6c7744f..d7d3665 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
@@ -59,6 +59,7 @@
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
@@ -116,6 +117,7 @@
WindowedValue.ParamWindowedValueCoder.class)
.put(getUrn(StandardCoders.Enum.ROW), RowCoder.class)
.put(getUrn(StandardCoders.Enum.SHARDED_KEY), ShardedKey.Coder.class)
+ .put(getUrn(StandardCoders.Enum.CUSTOM_WINDOW), TimestampPrefixingWindowCoder.class)
.build();
@AutoValue
@@ -345,6 +347,10 @@
byte[] shardId = ((String) kvMap.get("shardId")).getBytes(StandardCharsets.ISO_8859_1);
return ShardedKey.of(
convertValue(kvMap.get("key"), coderSpec.getComponents().get(0), keyCoder), shardId);
+ } else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) {
+ Map<String, Object> kvMap = (Map<String, Object>) value;
+ Coder windowCoder = ((TimestampPrefixingWindowCoder) coder).getWindowCoder();
+ return convertValue(kvMap.get("window"), coderSpec.getComponents().get(0), windowCoder);
} else {
throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn());
}
@@ -502,6 +508,8 @@
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.SHARDED_KEY))) {
assertEquals(expectedValue, actualValue);
+ } else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) {
+ assertEquals(expectedValue, actualValue);
} else {
throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
index c6264f3..0c397f4 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectKinds.java
@@ -21,6 +21,7 @@
class CloudObjectKinds {
static final String KIND_GLOBAL_WINDOW = "kind:global_window";
static final String KIND_INTERVAL_WINDOW = "kind:interval_window";
+ static final String KIND_CUSTOM_WINDOW = "kind:custom_window";
static final String KIND_LENGTH_PREFIX = "kind:length_prefix";
static final String KIND_PAIR = "kind:pair";
static final String KIND_STREAM = "kind:stream";
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
index e57205d..07377fd 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjectTranslators.java
@@ -33,6 +33,7 @@
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
@@ -241,6 +242,34 @@
};
}
+ static CloudObjectTranslator<TimestampPrefixingWindowCoder> customWindow() {
+ return new CloudObjectTranslator<TimestampPrefixingWindowCoder>() {
+ @Override
+ public CloudObject toCloudObject(
+ TimestampPrefixingWindowCoder target, SdkComponents sdkComponents) {
+ CloudObject result = CloudObject.forClassName(CloudObjectKinds.KIND_CUSTOM_WINDOW);
+ return addComponents(result, target.getComponents(), sdkComponents);
+ }
+
+ @Override
+ public TimestampPrefixingWindowCoder fromCloudObject(CloudObject cloudObject) {
+ List<Coder<?>> components = getComponents(cloudObject);
+ checkArgument(components.size() == 1, "Expecting 1 component, got %s", components.size());
+ return TimestampPrefixingWindowCoder.of((Coder<? extends BoundedWindow>) components.get(0));
+ }
+
+ @Override
+ public Class<? extends TimestampPrefixingWindowCoder> getSupportedClass() {
+ return TimestampPrefixingWindowCoder.class;
+ }
+
+ @Override
+ public String cloudObjectClassName() {
+ return CloudObjectKinds.KIND_CUSTOM_WINDOW;
+ }
+ };
+ }
+
/**
* Returns a {@link CloudObjectTranslator} that produces a {@link CloudObject} that is of kind
* "windowed_value".
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
index 766a32d..56c4d57 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java
@@ -31,6 +31,7 @@
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
@@ -58,7 +59,8 @@
Timer.Coder.class,
LengthPrefixCoder.class,
GlobalWindow.Coder.class,
- FullWindowedValueCoder.class);
+ FullWindowedValueCoder.class,
+ TimestampPrefixingWindowCoder.class);
static final Map<Class<? extends Coder>, CloudObjectTranslator<? extends Coder>>
CODER_TRANSLATORS = populateCoderTranslators();
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
index 01101c6..92070ef 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DefaultCoderCloudObjectTranslatorRegistrar.java
@@ -65,6 +65,7 @@
ImmutableList.of(
CloudObjectTranslators.globalWindow(),
CloudObjectTranslators.intervalWindow(),
+ CloudObjectTranslators.customWindow(),
CloudObjectTranslators.bytes(),
CloudObjectTranslators.varInt(),
CloudObjectTranslators.lengthPrefix(),
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
index eb7e6cd..f1a004f 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java
@@ -51,6 +51,7 @@
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -144,6 +145,7 @@
.add(new ObjectCoder())
.add(GlobalWindow.Coder.INSTANCE)
.add(IntervalWindow.getCoder())
+ .add(TimestampPrefixingWindowCoder.of(IntervalWindow.getCoder()))
.add(LengthPrefixCoder.of(VarLongCoder.of()))
.add(IterableCoder.of(VarLongCoder.of()))
.add(KvCoder.of(VarLongCoder.of(), ByteArrayCoder.of()))
diff --git a/sdks/go/pkg/beam/core/graph/window/fn.go b/sdks/go/pkg/beam/core/graph/window/fn.go
index 7197571..dba90f6 100644
--- a/sdks/go/pkg/beam/core/graph/window/fn.go
+++ b/sdks/go/pkg/beam/core/graph/window/fn.go
@@ -29,7 +29,7 @@
GlobalWindows Kind = "GLO"
FixedWindows Kind = "FIX"
SlidingWindows Kind = "SLI"
- Sessions Kind = "SES" // TODO
+ Sessions Kind = "SES"
)
// NewGlobalWindows returns the default WindowFn, which places all elements
diff --git a/sdks/go/pkg/beam/core/runtime/exec/window.go b/sdks/go/pkg/beam/core/runtime/exec/window.go
index cc7accc..7e8d0c4 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/window.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/window.go
@@ -73,6 +73,11 @@
ret = append(ret, window.IntervalWindow{Start: start, End: start.Add(wfn.Size)})
}
return ret
+ case window.Sessions:
+ // Assign each element into a window from its timestamp until Gap in the
+ // future. Overlapping windows (representing elements within Gap of
+ // each other) will be merged.
+ return []typex.Window{window.IntervalWindow{Start: ts, End: ts.Add(wfn.Gap)}}
default:
panic(fmt.Sprintf("Unexpected window fn: %v", wfn))
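For context, the session assignment above creates one proto-window per element spanning [ts, ts+Gap); the runner later merges overlapping intervals into a single session. A hedged Java sketch of that merge step (simplified types; not the runner's actual merging code):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of session-window merging: sort by start time, then
// fold any window that overlaps its predecessor into one larger session.
public class SessionMergeSketch {
  static final class Interval {
    final long start; // inclusive, millis
    final long end; // exclusive, millis

    Interval(long start, long end) {
      this.start = start;
      this.end = end;
    }
  }

  static List<Interval> merge(List<Interval> assigned) {
    List<Interval> sorted = new ArrayList<>(assigned);
    sorted.sort(Comparator.comparingLong((Interval w) -> w.start));
    List<Interval> merged = new ArrayList<>();
    for (Interval w : sorted) {
      if (!merged.isEmpty() && merged.get(merged.size() - 1).end >= w.start) {
        // Overlap: elements were within Gap of each other, extend the session.
        Interval prev = merged.remove(merged.size() - 1);
        merged.add(new Interval(prev.start, Math.max(prev.end, w.end)));
      } else {
        merged.add(w);
      }
    }
    return merged;
  }
}
```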
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 099738f..7472e83 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -1036,7 +1036,7 @@
switch w.Kind {
case window.GlobalWindows:
return coder.NewGlobalWindow(), nil
- case window.FixedWindows, window.SlidingWindows, URNSlidingWindowsWindowFn:
+ case window.FixedWindows, window.SlidingWindows, window.Sessions, URNSlidingWindowsWindowFn:
return coder.NewIntervalWindow(), nil
default:
return nil, errors.Errorf("unexpected windowing strategy for coder: %v", w)
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index cdb7e47..2297443 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -60,6 +60,8 @@
var directFilters = []string{
// The direct runner does not yet support cross-language.
"TestXLang.*",
+ // TODO(BEAM-4152): The direct runner does not support session window merging.
+ "TestWindowSums.*",
}
var portableFilters = []string{}
diff --git a/sdks/go/test/integration/primitives/windowinto.go b/sdks/go/test/integration/primitives/windowinto.go
index 573f5c7..f6e3fbd 100644
--- a/sdks/go/test/integration/primitives/windowinto.go
+++ b/sdks/go/test/integration/primitives/windowinto.go
@@ -67,9 +67,8 @@
validate(s.Scope("SlidingFixed"), window.NewSlidingWindows(windowSize, windowSize), timestampedData, 15, 15, 15)
// This will have overlap, but each value should be a multiple of the magic number.
validate(s.Scope("Sliding"), window.NewSlidingWindows(windowSize, 3*windowSize), timestampedData, 15, 30, 45, 30, 15)
- // TODO(BEAM-4152): This will do a smoke test of session windows, once implemented through.
// With such a large gap, there should be a single session which will sum to 45.
- // validate(s.Scope("Session"), window.NewSessions(windowSize), timestampedData, 45)
+ validate(s.Scope("Session"), window.NewSessions(windowSize), timestampedData, 45)
}
func sumPerKey(ws beam.Window, ts beam.EventTime, key beam.U, iter func(*int) bool) (beam.U, int) {
diff --git a/sdks/go/test/regression/coders/fromyaml/fromyaml.go b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
index 3544a5f..8ba9e99 100644
--- a/sdks/go/test/regression/coders/fromyaml/fromyaml.go
+++ b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
@@ -45,6 +45,7 @@
"beam:coder:param_windowed_value:v1": true,
"beam:coder:timer:v1": true,
"beam:coder:sharded_key:v1": true,
+ "beam:coder:custom_window:v1": true,
}
// Coder is a representation a serialized beam coder.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java
new file mode 100644
index 0000000..ea7d56e
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoder.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.coders;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+
+/**
+ * A {@link TimestampPrefixingWindowCoder} wraps an arbitrary custom window coder. When encoding a
+ * window, it first encodes the window's maxTimestamp (inclusive), then appends the bytes produced
+ * by the wrapped custom window coder.
+ *
+ * @param <T> The custom window type.
+ */
+public class TimestampPrefixingWindowCoder<T extends BoundedWindow> extends StructuredCoder<T> {
+ private final Coder<T> windowCoder;
+
+ public static <T extends BoundedWindow> TimestampPrefixingWindowCoder<T> of(
+ Coder<T> windowCoder) {
+ return new TimestampPrefixingWindowCoder<>(windowCoder);
+ }
+
+ TimestampPrefixingWindowCoder(Coder<T> windowCoder) {
+ this.windowCoder = windowCoder;
+ }
+
+ public Coder<T> getWindowCoder() {
+ return windowCoder;
+ }
+
+ @Override
+ public void encode(T value, OutputStream outStream) throws CoderException, IOException {
+ if (value == null) {
+ throw new CoderException("Cannot encode null window");
+ }
+ InstantCoder.of().encode(value.maxTimestamp(), outStream);
+ windowCoder.encode(value, outStream);
+ }
+
+ @Override
+ public T decode(InputStream inStream) throws CoderException, IOException {
+ InstantCoder.of().decode(inStream);
+ return windowCoder.decode(inStream);
+ }
+
+ @Override
+ public List<? extends Coder<?>> getCoderArguments() {
+ return Lists.newArrayList(windowCoder);
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {
+ windowCoder.verifyDeterministic();
+ }
+
+ @Override
+ public boolean consistentWithEquals() {
+ return windowCoder.consistentWithEquals();
+ }
+
+ @Override
+ public boolean isRegisterByteSizeObserverCheap(T value) {
+ return windowCoder.isRegisterByteSizeObserverCheap(value);
+ }
+
+ @Override
+ public void registerByteSizeObserver(T value, ElementByteSizeObserver observer) throws Exception {
+ InstantCoder.of().registerByteSizeObserver(value.maxTimestamp(), observer);
+ windowCoder.registerByteSizeObserver(value, observer);
+ }
+}
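A round-trip usage sketch of the new coder (window values illustrative):

```java
import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.joda.time.Duration;
import org.joda.time.Instant;

// Sketch: wrap the interval-window coder and round-trip a window. The encoded
// bytes start with the shifted maxTimestamp, followed by the window payload.
public class TimestampPrefixingWindowCoderDemo {
  public static void main(String[] args) throws Exception {
    TimestampPrefixingWindowCoder<IntervalWindow> coder =
        TimestampPrefixingWindowCoder.of(IntervalWindow.getCoder());
    IntervalWindow window =
        new IntervalWindow(new Instant(0L), new Instant(0L).plus(Duration.standardHours(1)));
    byte[] bytes = CoderUtils.encodeToByteArray(coder, window);
    IntervalWindow roundTripped = CoderUtils.decodeFromByteArray(coder, bytes);
    // decode() skips the timestamp prefix and delegates to the wrapped coder,
    // so the round trip yields an equal window.
    System.out.println(window.equals(roundTripped)); // true
  }
}
```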
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
index 87ffa4a..143901a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/IntervalWindow.java
@@ -27,6 +27,7 @@
import org.apache.beam.sdk.coders.DurationCoder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -175,6 +176,19 @@
}
@Override
+ public boolean isRegisterByteSizeObserverCheap(IntervalWindow value) {
+ return instantCoder.isRegisterByteSizeObserverCheap(value.end)
+ && durationCoder.isRegisterByteSizeObserverCheap(new Duration(value.start, value.end));
+ }
+
+ @Override
+ public void registerByteSizeObserver(IntervalWindow value, ElementByteSizeObserver observer)
+ throws Exception {
+ instantCoder.registerByteSizeObserver(value.end, observer);
+ durationCoder.registerByteSizeObserver(new Duration(value.start, value.end), observer);
+ }
+
+ @Override
public List<? extends Coder<?>> getCoderArguments() {
return Collections.emptyList();
}
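Since an IntervalWindow is encoded as its end instant followed by its duration, the new observer methods can report the size by delegating to those two coders rather than re-encoding the whole window. A sketch that computes the same total by encoding the parts directly (values illustrative):

```java
import org.apache.beam.sdk.coders.DurationCoder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.util.CoderUtils;
import org.joda.time.Duration;
import org.joda.time.Instant;

// Sketch: the observed byte size of an IntervalWindow is the encoded size of
// its end instant plus the encoded size of its duration.
public class IntervalWindowSizeSketch {
  public static void main(String[] args) throws Exception {
    Instant start = new Instant(0L);
    Instant end = new Instant(3_600_000L); // a one-hour window
    int instantBytes = CoderUtils.encodeToByteArray(InstantCoder.of(), end).length;
    int durationBytes =
        CoderUtils.encodeToByteArray(DurationCoder.of(), new Duration(start, end)).length;
    System.out.println(instantBytes + durationBytes);
  }
}
```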
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java
new file mode 100644
index 0000000..a9f8123
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/TimestampPrefixingWindowCoderTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.coders;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import java.util.Objects;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.CoderProperties.TestElementByteSizeObserver;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import org.junit.Test;
+
+public class TimestampPrefixingWindowCoderTest {
+
+ private static class CustomWindow extends IntervalWindow {
+ private boolean isBig;
+
+ CustomWindow(Instant start, Instant end, boolean isBig) {
+ super(start, end);
+ this.isBig = isBig;
+ }
+
+ @Override
+ public boolean equals(@Nullable Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ CustomWindow that = (CustomWindow) o;
+ return super.equals(o) && this.isBig == that.isBig;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), isBig);
+ }
+ }
+
+ private static class CustomWindowCoder extends CustomCoder<CustomWindow> {
+
+ private static final Coder<IntervalWindow> INTERVAL_WINDOW_CODER = IntervalWindow.getCoder();
+ private static final int REGISTER_BYTE_SIZE = 1234;
+ private final boolean isConsistentWithEqual;
+ private final boolean isRegisterByteSizeCheap;
+
+ public static CustomWindowCoder of(
+ boolean isConsistentWithEqual, boolean isRegisterByteSizeCheap) {
+ return new CustomWindowCoder(isConsistentWithEqual, isRegisterByteSizeCheap);
+ }
+
+ private CustomWindowCoder(boolean isConsistentWithEqual, boolean isRegisterByteSizeCheap) {
+ this.isConsistentWithEqual = isConsistentWithEqual;
+ this.isRegisterByteSizeCheap = isRegisterByteSizeCheap;
+ }
+
+ @Override
+ public void encode(CustomWindow window, OutputStream outStream) throws IOException {
+ INTERVAL_WINDOW_CODER.encode(window, outStream);
+ BooleanCoder.of().encode(window.isBig, outStream);
+ }
+
+ @Override
+ public CustomWindow decode(InputStream inStream) throws IOException {
+ IntervalWindow superWindow = INTERVAL_WINDOW_CODER.decode(inStream);
+ boolean isBig = BooleanCoder.of().decode(inStream);
+ return new CustomWindow(superWindow.start(), superWindow.end(), isBig);
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {
+ INTERVAL_WINDOW_CODER.verifyDeterministic();
+ BooleanCoder.of().verifyDeterministic();
+ }
+
+ @Override
+ public boolean consistentWithEquals() {
+ return isConsistentWithEqual;
+ }
+
+ @Override
+ public boolean isRegisterByteSizeObserverCheap(CustomWindow value) {
+ return isRegisterByteSizeCheap;
+ }
+
+ @Override
+ public void registerByteSizeObserver(CustomWindow value, ElementByteSizeObserver observer)
+ throws Exception {
+ observer.update(REGISTER_BYTE_SIZE);
+ }
+ }
+
+ private static final List<CustomWindow> CUSTOM_WINDOW_LIST =
+ Lists.newArrayList(
+ new CustomWindow(new Instant(0L), new Instant(1L), true),
+ new CustomWindow(new Instant(100L), new Instant(200L), false),
+ new CustomWindow(new Instant(0L), BoundedWindow.TIMESTAMP_MAX_VALUE, true));
+
+ @Test
+ public void testEncodeAndDecode() throws Exception {
+ List<IntervalWindow> intervalWindowsToTest =
+ Lists.newArrayList(
+ new IntervalWindow(new Instant(0L), new Instant(1L)),
+ new IntervalWindow(new Instant(100L), new Instant(200L)),
+ new IntervalWindow(new Instant(0L), BoundedWindow.TIMESTAMP_MAX_VALUE));
+ TimestampPrefixingWindowCoder<IntervalWindow> coder1 =
+ TimestampPrefixingWindowCoder.of(IntervalWindow.getCoder());
+ for (IntervalWindow window : intervalWindowsToTest) {
+ CoderProperties.coderDecodeEncodeEqual(coder1, window);
+ }
+
+ GlobalWindow globalWindow = GlobalWindow.INSTANCE;
+ TimestampPrefixingWindowCoder<GlobalWindow> coder2 =
+ TimestampPrefixingWindowCoder.of(GlobalWindow.Coder.INSTANCE);
+ CoderProperties.coderDecodeEncodeEqual(coder2, globalWindow);
+ TimestampPrefixingWindowCoder<CustomWindow> coder3 =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+ for (CustomWindow window : CUSTOM_WINDOW_LIST) {
+ CoderProperties.coderDecodeEncodeEqual(coder3, window);
+ }
+ }
+
+ @Test
+ public void testConsistentWithEquals() {
+ TimestampPrefixingWindowCoder<CustomWindow> coder1 =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+ assertThat(coder1.consistentWithEquals(), equalTo(true));
+ TimestampPrefixingWindowCoder<CustomWindow> coder2 =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(false, true));
+ assertThat(coder2.consistentWithEquals(), equalTo(false));
+ }
+
+ @Test
+ public void testIsRegisterByteSizeObserverCheap() {
+ TimestampPrefixingWindowCoder<CustomWindow> coder1 =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+ assertThat(coder1.isRegisterByteSizeObserverCheap(CUSTOM_WINDOW_LIST.get(0)), equalTo(true));
+ TimestampPrefixingWindowCoder<CustomWindow> coder2 =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, false));
+ assertThat(coder2.isRegisterByteSizeObserverCheap(CUSTOM_WINDOW_LIST.get(0)), equalTo(false));
+ }
+
+ @Test
+ public void testGetEncodedElementByteSize() throws Exception {
+ TestElementByteSizeObserver observer = new TestElementByteSizeObserver();
+ TimestampPrefixingWindowCoder<CustomWindow> coder =
+ TimestampPrefixingWindowCoder.of(CustomWindowCoder.of(true, true));
+ for (CustomWindow value : CUSTOM_WINDOW_LIST) {
+ coder.registerByteSizeObserver(value, observer);
+ observer.advance();
+ assertThat(
+ observer.getSumAndReset(),
+ equalTo(
+ CustomWindowCoder.REGISTER_BYTE_SIZE
+ + InstantCoder.of().getEncodedElementByteSize(value.maxTimestamp())));
+ }
+ }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
index 36daef0..a79ff6f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/IntervalWindowTest.java
@@ -22,10 +22,13 @@
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DurationCoder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.CoderProperties.TestElementByteSizeObserver;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -87,4 +90,20 @@
assertThat(encodedHourWindow.length, equalTo(encodedStart.length + encodedHourEnd.length - 4));
assertThat(encodedDayWindow.length, equalTo(encodedStart.length + encodedDayEnd.length - 4));
}
+
+ @Test
+ public void testCoderRegisterByteSizeObserver() throws Exception {
+ assertThat(TEST_CODER.isRegisterByteSizeObserverCheap(TEST_VALUES.get(0)), equalTo(true));
+ TestElementByteSizeObserver observer = new TestElementByteSizeObserver();
+ TestElementByteSizeObserver observer2 = new TestElementByteSizeObserver();
+ for (IntervalWindow window : TEST_VALUES) {
+ TEST_CODER.registerByteSizeObserver(window, observer);
+ InstantCoder.of().registerByteSizeObserver(window.maxTimestamp(), observer2);
+ DurationCoder.of()
+ .registerByteSizeObserver(new Duration(window.start(), window.end()), observer2);
+ observer.advance();
+ observer2.advance();
+ assertThat(observer.getSumAndReset(), equalTo(observer2.getSumAndReset()));
+ }
+ }
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 6cc5e6e..b40c9d5 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -47,6 +47,7 @@
import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
+import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -310,6 +311,7 @@
}
});
try {
+ BeamFnLoggingMDC.setInstructionId(request.getInstructionId());
PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry();
PTransformFunctionRegistry finishFunctionRegistry =
bundleProcessor.getFinishFunctionRegistry();
@@ -364,6 +366,8 @@
// Make sure we clean-up from the active set of bundle processors.
bundleProcessorCache.discard(bundleProcessor);
throw e;
+ } finally {
+ BeamFnLoggingMDC.setInstructionId(null);
}
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
index 71d3b8e..6cd9116 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClient.java
@@ -211,6 +211,11 @@
.setSeconds(record.getMillis() / 1000)
.setNanos((int) (record.getMillis() % 1000) * 1_000_000));
+ String instructionId = BeamFnLoggingMDC.getInstructionId();
+ if (instructionId != null) {
+ builder.setInstructionId(instructionId);
+ }
+
Throwable thrown = record.getThrown();
if (thrown != null) {
builder.setTrace(getStackTraceAsString(thrown));
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java
new file mode 100644
index 0000000..bcfcd4b
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/logging/BeamFnLoggingMDC.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Mapped diagnostic context to be consumed and set on LogEntry protos in BeamFnLoggingClient. */
+public class BeamFnLoggingMDC {
+ private static final ThreadLocal<@Nullable String> instructionId = new ThreadLocal<>();
+
+ /** Sets the Instruction ID of the current thread. */
+ public static void setInstructionId(@Nullable String newInstructionId) {
+ instructionId.set(newInstructionId);
+ }
+
+ /** Gets the Instruction ID of the current thread. */
+ public static @Nullable String getInstructionId() {
+ return instructionId.get();
+ }
+}
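ProcessBundleHandler (edited above) pairs every set with a finally-block clear so an instruction id never leaks across bundles. A minimal sketch of that pattern (class and method names hypothetical):

```java
package org.apache.beam.fn.harness.logging;

// Sketch of the set/clear pattern used by ProcessBundleHandler: scope the
// instruction id to the processing of one request so log entries emitted on
// this thread can be attributed to it.
public class InstructionScopeSketch {
  static void processWithInstructionId(String instructionId, Runnable body) {
    try {
      BeamFnLoggingMDC.setInstructionId(instructionId);
      body.run(); // any LogRecord published here picks up the id
    } finally {
      BeamFnLoggingMDC.setInstructionId(null); // never leak ids across bundles
    }
  }
}
```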
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
index f77798a..dd9e412 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java
@@ -50,13 +50,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
+import org.junit.rules.TestRule;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link BeamFnLoggingClient}. */
@RunWith(JUnit4.class)
public class BeamFnLoggingClientTest {
-
+ @Rule public TestRule restoreLogging = new RestoreBeamFnLoggingMDC();
private static final LogRecord FILTERED_RECORD;
private static final LogRecord TEST_RECORD;
private static final LogRecord TEST_RECORD_WITH_EXCEPTION;
@@ -78,6 +79,7 @@
private static final BeamFnApi.LogEntry TEST_ENTRY =
BeamFnApi.LogEntry.newBuilder()
+ .setInstructionId("instruction-1")
.setSeverity(BeamFnApi.LogEntry.Severity.Enum.DEBUG)
.setMessage("Message")
.setThread("12345")
@@ -86,6 +88,7 @@
.build();
private static final BeamFnApi.LogEntry TEST_ENTRY_WITH_EXCEPTION =
BeamFnApi.LogEntry.newBuilder()
+ .setInstructionId("instruction-1")
.setSeverity(BeamFnApi.LogEntry.Severity.Enum.WARN)
.setMessage("MessageWithException")
.setTrace(getStackTraceAsString(TEST_RECORD_WITH_EXCEPTION.getThrown()))
@@ -97,6 +100,7 @@
@Test
public void testLogging() throws Exception {
+ BeamFnLoggingMDC.setInstructionId("instruction-1");
AtomicBoolean clientClosedStream = new AtomicBoolean();
Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
@@ -175,6 +179,7 @@
@Test
public void testWhenServerFailsThatClientIsAbleToCleanup() throws Exception {
+ BeamFnLoggingMDC.setInstructionId("instruction-1");
Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
new AtomicReference<>();
@@ -244,6 +249,7 @@
@Test
public void testWhenServerHangsUpEarlyThatClientIsAbleCleanup() throws Exception {
+ BeamFnLoggingMDC.setInstructionId("instruction-1");
Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver =
new AtomicReference<>();
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java
new file mode 100644
index 0000000..0036862
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDC.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import org.junit.rules.ExternalResource;
+
+/** Saves, clears and restores the current thread-local logging parameters for tests. */
+public class RestoreBeamFnLoggingMDC extends ExternalResource {
+ private String previousInstruction;
+
+ public RestoreBeamFnLoggingMDC() {}
+
+ @Override
+ protected void before() throws Throwable {
+ previousInstruction = BeamFnLoggingMDC.getInstructionId();
+ BeamFnLoggingMDC.setInstructionId(null);
+ }
+
+ @Override
+ protected void after() {
+ BeamFnLoggingMDC.setInstructionId(previousInstruction);
+ }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java
new file mode 100644
index 0000000..47166b0
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/RestoreBeamFnLoggingMDCTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.logging;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.junit.runners.model.Statement;
+
+/** Tests for {@link RestoreBeamFnLoggingMDC}. */
+@RunWith(JUnit4.class)
+public class RestoreBeamFnLoggingMDCTest {
+ @Rule public TestRule restoreMDCAfterTest = new RestoreBeamFnLoggingMDC();
+
+ @Test
+ public void testOldValuesAreRestored() throws Throwable {
+ // We need our own instance here so that we don't squash any saved values.
+ TestRule restoreMDC = new RestoreBeamFnLoggingMDC();
+
+ final boolean[] evaluateRan = new boolean[1];
+ BeamFnLoggingMDC.setInstructionId("oldInstruction");
+
+ restoreMDC
+ .apply(
+ new Statement() {
+ @Override
+ public void evaluate() {
+ evaluateRan[0] = true;
+ // Ensure parameters are cleared before the test runs
+ assertNull("null Instruction", BeamFnLoggingMDC.getInstructionId());
+
+ // Simulate updating parameters for the test
+ BeamFnLoggingMDC.setInstructionId("newInstruction");
+
+ // Ensure that the values changed
+ assertEquals("newInstruction", BeamFnLoggingMDC.getInstructionId());
+ }
+ },
+ Description.EMPTY)
+ .evaluate();
+
+ // Validate that the statement ran and that the values were reverted
+ assertTrue(evaluateRan[0]);
+ assertEquals("oldInstruction", BeamFnLoggingMDC.getInstructionId());
+ }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
index 0e87243..15ea5c0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
@@ -737,7 +737,7 @@
private WriteResult writeResult(Pipeline p) {
PCollection<TableRow> empty =
p.apply("CreateEmptyFailedInserts", Create.empty(TypeDescriptor.of(TableRow.class)));
- return WriteResult.in(p, new TupleTag<>("failedInserts"), empty);
+ return WriteResult.in(p, new TupleTag<>("failedInserts"), empty, null);
}
@Override
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
index 120d76f..484fe3b 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
@@ -70,8 +70,9 @@
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
class BatchedStreamingWrite<ErrorT, ElementT>
- extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+ extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
private static final TupleTag<Void> mainOutputTag = new TupleTag<>("mainOutput");
+ static final TupleTag<TableRow> SUCCESSFUL_ROWS_TAG = new TupleTag<>("successfulRows");
private static final Logger LOG = LoggerFactory.getLogger(BatchedStreamingWrite.class);
private final BigQueryServices bqServices;
@@ -191,23 +192,24 @@
}
@Override
- public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+ public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
return batchViaStateful
? input.apply(new ViaStateful())
: input.apply(new ViaBundleFinalization());
}
private class ViaBundleFinalization
- extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+ extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
@Override
- public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+ public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
PCollectionTuple result =
input.apply(
ParDo.of(new BatchAndInsertElements())
- .withOutputTags(mainOutputTag, TupleTagList.of(failedOutputTag)));
- PCollection<ErrorT> failedInserts = result.get(failedOutputTag);
- failedInserts.setCoder(failedOutputCoder);
- return failedInserts;
+ .withOutputTags(
+ mainOutputTag, TupleTagList.of(failedOutputTag).and(SUCCESSFUL_ROWS_TAG)));
+ result.get(failedOutputTag).setCoder(failedOutputCoder);
+ result.get(SUCCESSFUL_ROWS_TAG).setCoder(TableRowJsonCoder.of());
+ return result;
}
}
@@ -249,6 +251,7 @@
@FinishBundle
public void finishBundle(FinishBundleContext context) throws Exception {
List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
+ List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class);
for (Map.Entry<String, List<FailsafeValueInSingleWindow<TableRow, TableRow>>> entry :
tableRows.entrySet()) {
@@ -258,7 +261,8 @@
entry.getValue(),
uniqueIdsForTableRows.get(entry.getKey()),
options,
- failedInserts);
+ failedInserts,
+ successfulInserts);
}
tableRows.clear();
uniqueIdsForTableRows.clear();
@@ -274,9 +278,9 @@
private static final Duration BATCH_MAX_BUFFERING_DURATION = Duration.millis(200);
private class ViaStateful
- extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollection<ErrorT>> {
+ extends PTransform<PCollection<KV<String, TableRowInfo<ElementT>>>, PCollectionTuple> {
@Override
- public PCollection<ErrorT> expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
+ public PCollectionTuple expand(PCollection<KV<String, TableRowInfo<ElementT>>> input) {
BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class);
Duration maxBufferingDuration =
options.getMaxBufferingDurationMilliSec() > 0
@@ -306,10 +310,12 @@
ShardedKey.Coder.of(StringUtf8Coder.of()), IterableCoder.of(valueCoder)))
.apply(
ParDo.of(new InsertBatchedElements())
- .withOutputTags(mainOutputTag, TupleTagList.of(failedOutputTag)));
- PCollection<ErrorT> failedInserts = result.get(failedOutputTag);
- failedInserts.setCoder(failedOutputCoder);
- return failedInserts;
+ .withOutputTags(
+ mainOutputTag,
+ TupleTagList.of(failedOutputTag).and(SUCCESSFUL_ROWS_TAG)));
+ result.get(failedOutputTag).setCoder(failedOutputCoder);
+ result.get(SUCCESSFUL_ROWS_TAG).setCoder(TableRowJsonCoder.of());
+ return result;
}
}
@@ -340,11 +346,15 @@
BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class);
TableReference tableReference = BigQueryHelpers.parseTableSpec(input.getKey().getKey());
List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
- flushRows(tableReference, tableRows, uniqueIds, options, failedInserts);
+ List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
+ flushRows(tableReference, tableRows, uniqueIds, options, failedInserts, successfulInserts);
for (ValueInSingleWindow<ErrorT> row : failedInserts) {
out.get(failedOutputTag).output(row.getValue());
}
+ for (ValueInSingleWindow<TableRow> row : successfulInserts) {
+ out.get(SUCCESSFUL_ROWS_TAG).output(row.getValue());
+ }
reportStreamingApiLogging(options);
}
}
@@ -367,7 +377,8 @@
List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows,
List<String> uniqueIds,
BigQueryOptions options,
- List<ValueInSingleWindow<ErrorT>> failedInserts)
+ List<ValueInSingleWindow<ErrorT>> failedInserts,
+ List<ValueInSingleWindow<TableRow>> successfulInserts)
throws InterruptedException {
if (!tableRows.isEmpty()) {
try {
@@ -382,7 +393,8 @@
errorContainer,
skipInvalidRows,
ignoreUnknownValues,
- ignoreInsertIds);
+ ignoreInsertIds,
+ successfulInserts);
byteCounter.inc(totalBytes);
} catch (IOException e) {
throw new RuntimeException(e);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
index c0b945c..f7b95e7 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServices.java
@@ -166,7 +166,8 @@
ErrorContainer<T> errorContainer,
boolean skipInvalidRows,
boolean ignoreUnknownValues,
- boolean ignoreInsertIds)
+ boolean ignoreInsertIds,
+ List<ValueInSingleWindow<TableRow>> successfulRows)
throws IOException, InterruptedException;
/** Patch BigQuery {@link Table} description. */
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
index 5e9cdb7..d788bfe 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
@@ -86,9 +86,11 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@@ -814,7 +816,8 @@
ErrorContainer<T> errorContainer,
boolean skipInvalidRows,
boolean ignoreUnkownValues,
- boolean ignoreInsertIds)
+ boolean ignoreInsertIds,
+ List<ValueInSingleWindow<TableRow>> successfulRows)
throws IOException, InterruptedException {
checkNotNull(ref, "ref");
if (executor == null) {
@@ -829,6 +832,7 @@
+ "as many elements as rowList");
}
+ final Set<Integer> failedIndices = new HashSet<>();
long retTotalDataSize = 0;
List<TableDataInsertAllResponse.InsertErrors> allErrors = new ArrayList<>();
// These lists contain the rows to publish. Initially they contain the entire list.
@@ -981,6 +985,7 @@
throw new IOException("Insert failed: " + error + ", other errors: " + allErrors);
}
int errorIndex = error.getIndex().intValue() + strideIndices.get(i);
+ failedIndices.add(errorIndex);
if (retryPolicy.shouldRetry(new InsertRetryPolicy.Context(error))) {
allErrors.add(error);
retryRows.add(rowsToPublish.get(errorIndex));
@@ -1022,6 +1027,18 @@
allErrors.clear();
LOG.info("Retrying {} failed inserts to BigQuery", rowsToPublish.size());
}
+ if (successfulRows != null) {
+ for (int i = 0; i < rowsToPublish.size(); i++) {
+ if (!failedIndices.contains(i)) {
+ successfulRows.add(
+ ValueInSingleWindow.of(
+ rowsToPublish.get(i).getValue(),
+ rowsToPublish.get(i).getTimestamp(),
+ rowsToPublish.get(i).getWindow(),
+ rowsToPublish.get(i).getPane()));
+ }
+ }
+ }
if (!allErrors.isEmpty()) {
throw new IOException("Insert failed: " + allErrors);
} else {
@@ -1039,7 +1056,8 @@
ErrorContainer<T> errorContainer,
boolean skipInvalidRows,
boolean ignoreUnknownValues,
- boolean ignoreInsertIds)
+ boolean ignoreInsertIds,
+ List<ValueInSingleWindow<TableRow>> successfulRows)
throws IOException, InterruptedException {
return insertAll(
ref,
@@ -1053,7 +1071,8 @@
errorContainer,
skipInvalidRows,
ignoreUnknownValues,
- ignoreInsertIds);
+ ignoreInsertIds,
+ successfulRows);
}
protected GoogleJsonError.ErrorInfo getErrorInfo(IOException e) {
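The new successfulRows bookkeeping reduces to: any published row whose index was never recorded in failedIndices is reported as successful. A simplified sketch of that filter (generic types, not the actual insertAll signature):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Simplified sketch of the successful-row filter added to insertAll: a row is
// successful when its index never appeared in an insert error.
public class SuccessfulRowFilterSketch {
  static <T> List<T> successfulRows(List<T> published, Set<Integer> failedIndices) {
    List<T> successful = new ArrayList<>();
    for (int i = 0; i < published.size(); i++) {
      if (!failedIndices.contains(i)) {
        successful.add(published.get(i));
      }
    }
    return successful;
  }
}
```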
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
index 602ebea..3c27ddc 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
@@ -142,6 +142,6 @@
// large.
PCollection<TableRow> empty =
p.apply("CreateEmptyFailedInserts", Create.empty(TypeDescriptor.of(TableRow.class)));
- return WriteResult.in(p, new TupleTag<>("failedInserts"), empty);
+ return WriteResult.in(p, new TupleTag<>("failedInserts"), empty, null);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
index 9bfe3d2..5afd1ae 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java
@@ -34,6 +34,7 @@
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.ShardedKey;
import org.apache.beam.sdk.values.TupleTag;
@@ -245,26 +246,36 @@
public WriteResult expand(PCollection<KV<TableDestination, ElementT>> input) {
if (extendedErrorInfo) {
TupleTag<BigQueryInsertError> failedInsertsTag = new TupleTag<>(FAILED_INSERTS_TAG_ID);
- PCollection<BigQueryInsertError> failedInserts =
+ PCollectionTuple result =
writeAndGetErrors(
input,
failedInsertsTag,
BigQueryInsertErrorCoder.of(),
ErrorContainer.BIG_QUERY_INSERT_ERROR_ERROR_CONTAINER);
- return WriteResult.withExtendedErrors(input.getPipeline(), failedInsertsTag, failedInserts);
+ PCollection<BigQueryInsertError> failedInserts = result.get(failedInsertsTag);
+ return WriteResult.withExtendedErrors(
+ input.getPipeline(),
+ failedInsertsTag,
+ failedInserts,
+ result.get(BatchedStreamingWrite.SUCCESSFUL_ROWS_TAG));
} else {
TupleTag<TableRow> failedInsertsTag = new TupleTag<>(FAILED_INSERTS_TAG_ID);
- PCollection<TableRow> failedInserts =
+ PCollectionTuple result =
writeAndGetErrors(
input,
failedInsertsTag,
TableRowJsonCoder.of(),
ErrorContainer.TABLE_ROW_ERROR_CONTAINER);
- return WriteResult.in(input.getPipeline(), failedInsertsTag, failedInserts);
+ PCollection<TableRow> failedInserts = result.get(failedInsertsTag);
+ return WriteResult.in(
+ input.getPipeline(),
+ failedInsertsTag,
+ failedInserts,
+ result.get(BatchedStreamingWrite.SUCCESSFUL_ROWS_TAG));
}
}
- private <T> PCollection<T> writeAndGetErrors(
+ private <T> PCollectionTuple writeAndGetErrors(
PCollection<KV<TableDestination, ElementT>> input,
TupleTag<T> failedInsertsTag,
AtomicCoder<T> coder,
@@ -336,7 +347,8 @@
return shardedTagged
.apply(Reshuffle.of())
- // Put in the global window to ensure that DynamicDestinations side inputs are accessed
+ // Put in the global window to ensure that DynamicDestinations side inputs are
+ // accessed
// correctly.
.apply(
"GlobalWindow",
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
index fd75f671..85ef275 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java
@@ -29,6 +29,7 @@
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
/** The result of a {@link BigQueryIO.Write} transform. */
@SuppressWarnings({
@@ -40,18 +41,25 @@
private final PCollection<TableRow> failedInserts;
private final TupleTag<BigQueryInsertError> failedInsertsWithErrTag;
private final PCollection<BigQueryInsertError> failedInsertsWithErr;
+ private final PCollection<TableRow> successfulInserts;
/** Creates a {@link WriteResult} in the given {@link Pipeline}. */
static WriteResult in(
- Pipeline pipeline, TupleTag<TableRow> failedInsertsTag, PCollection<TableRow> failedInserts) {
- return new WriteResult(pipeline, failedInsertsTag, failedInserts, null, null);
+ Pipeline pipeline,
+ TupleTag<TableRow> failedInsertsTag,
+ PCollection<TableRow> failedInserts,
+ @Nullable PCollection<TableRow> successfulInserts) {
+ return new WriteResult(
+ pipeline, failedInsertsTag, failedInserts, null, null, successfulInserts);
}
static WriteResult withExtendedErrors(
Pipeline pipeline,
TupleTag<BigQueryInsertError> failedInsertsTag,
- PCollection<BigQueryInsertError> failedInserts) {
- return new WriteResult(pipeline, null, null, failedInsertsTag, failedInserts);
+ PCollection<BigQueryInsertError> failedInserts,
+ PCollection<TableRow> successfulInserts) {
+ return new WriteResult(
+ pipeline, null, null, failedInsertsTag, failedInserts, successfulInserts);
}
@Override
@@ -68,12 +76,14 @@
TupleTag<TableRow> failedInsertsTag,
PCollection<TableRow> failedInserts,
TupleTag<BigQueryInsertError> failedInsertsWithErrTag,
- PCollection<BigQueryInsertError> failedInsertsWithErr) {
+ PCollection<BigQueryInsertError> failedInsertsWithErr,
+ PCollection<TableRow> successfulInserts) {
this.pipeline = pipeline;
this.failedInsertsTag = failedInsertsTag;
this.failedInserts = failedInserts;
this.failedInsertsWithErrTag = failedInsertsWithErrTag;
this.failedInsertsWithErr = failedInsertsWithErr;
+ this.successfulInserts = successfulInserts;
}
/**
@@ -91,6 +101,14 @@
return failedInserts;
}
+  /** Returns a {@link PCollection} containing the {@link TableRow}s that were written to BigQuery. */
+ public PCollection<TableRow> getSuccessfulInserts() {
+ checkArgument(
+ successfulInserts != null,
+ "Retrieving successful inserts is only supported for streaming inserts.");
+ return successfulInserts;
+ }
+
/**
* Returns a {@link PCollection} containing the {@link BigQueryInsertError}s with detailed error
* information.
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
index 6bfd2b8..1ee6501 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeDatasetService.java
@@ -323,7 +323,8 @@
null,
false,
false,
- false);
+ false,
+ null);
}
@Override
@@ -336,7 +337,8 @@
ErrorContainer<T> errorContainer,
boolean skipInvalidRows,
boolean ignoreUnknownValues,
- boolean ignoreInsertIds)
+ boolean ignoreInsertIds,
+ List<ValueInSingleWindow<TableRow>> successfulRows)
throws IOException, InterruptedException {
Map<TableRow, List<TableDataInsertAllResponse.InsertErrors>> insertErrors = getInsertErrors();
synchronized (tables) {
@@ -371,6 +373,14 @@
} else {
dataSize += tableContainer.addRow(row, insertIdList.get(i));
}
+ if (successfulRows != null) {
+ successfulRows.add(
+ ValueInSingleWindow.of(
+ row,
+ rowList.get(i).getTimestamp(),
+ rowList.get(i).getWindow(),
+ rowList.get(i).getPane()));
+ }
} else {
errorContainer.add(
failedInserts, allErrors.get(allErrors.size() - 1), ref, rowList.get(i));
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
index e4c32f5..a891c48 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java
@@ -173,7 +173,6 @@
FileSystems.copy(sourceFiles.build(), loadFiles.build());
filesForLoadJobs.put(jobRef.getProjectId(), jobRef.getJobId(), loadFiles.build());
}
-
allJobs.put(jobRef.getProjectId(), jobRef.getJobId(), new JobInfo(job));
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 946e02f..bde803d 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -76,6 +76,7 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write;
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption;
import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices;
@@ -789,13 +790,13 @@
row1, ImmutableList.of(ephemeralError, ephemeralError),
row2, ImmutableList.of(ephemeralError, ephemeralError, persistentError)));
- PCollection<TableRow> failedRows =
+ WriteResult result =
p.apply(Create.of(row1, row2, row3))
.apply(
BigQueryIO.writeTableRows()
.to("project-id:dataset-id.table-id")
- .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
- .withMethod(BigQueryIO.Write.Method.STREAMING_INSERTS)
+ .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED)
+ .withMethod(Method.STREAMING_INSERTS)
.withSchema(
new TableSchema()
.setFields(
@@ -804,11 +805,15 @@
new TableFieldSchema().setName("number").setType("INTEGER"))))
.withFailedInsertRetryPolicy(InsertRetryPolicy.retryTransientErrors())
.withTestServices(fakeBqServices)
- .withoutValidation())
- .getFailedInserts();
+ .withoutValidation());
+
+ PCollection<TableRow> failedRows = result.getFailedInserts();
// row2 finally fails with a non-retryable error, so we expect to see it in the collection of
// failed rows.
PAssert.that(failedRows).containsInAnyOrder(row2);
+ if (useStorageApi || !useStreaming) {
+ PAssert.that(result.getSuccessfulInserts()).containsInAnyOrder(row1, row3);
+ }
p.run();
// Only row1 and row3 were successfully inserted.
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
index ae788d7..61b94f6 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImplTest.java
@@ -655,7 +655,8 @@
null,
false,
false,
- false);
+ false,
+ null);
verifyAllResponsesAreRead();
expectedLogs.verifyInfo("BigQuery insertAll error, retrying:");
@@ -699,7 +700,8 @@
null,
false,
false,
- false);
+ false,
+ null);
verifyAllResponsesAreRead();
expectedLogs.verifyInfo("BigQuery insertAll error, retrying:");
@@ -755,7 +757,8 @@
null,
false,
false,
- false);
+ false,
+ null);
verifyAllResponsesAreRead();
@@ -813,7 +816,8 @@
null,
false,
false,
- false);
+ false,
+ null);
verifyAllResponsesAreRead();
@@ -880,7 +884,8 @@
null,
false,
false,
- false);
+ false,
+ null);
fail();
} catch (IOException e) {
assertThat(e, instanceOf(IOException.class));
@@ -940,7 +945,8 @@
null,
false,
false,
- false);
+ false,
+ null);
} finally {
verify(responses[0], atLeastOnce()).getStatusCode();
verify(responses[0]).getContent();
@@ -1023,7 +1029,8 @@
ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
false,
false,
- false);
+ false,
+ null);
assertEquals(1, failedInserts.size());
expectedLogs.verifyInfo("Retrying 1 failed inserts to BigQuery");
@@ -1069,7 +1076,8 @@
ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
false,
false,
- false);
+ false,
+ null);
TableDataInsertAllRequest parsedRequest =
fromString(request.getContentAsString(), TableDataInsertAllRequest.class);
@@ -1090,7 +1098,8 @@
ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
true,
true,
- true);
+ true,
+ null);
parsedRequest = fromString(request.getContentAsString(), TableDataInsertAllRequest.class);
@@ -1325,7 +1334,8 @@
ErrorContainer.TABLE_ROW_ERROR_CONTAINER,
false,
false,
- false);
+ false,
+ null);
assertThat(failedInserts, is(expected));
}
@@ -1384,7 +1394,8 @@
ErrorContainer.BIG_QUERY_INSERT_ERROR_ERROR_CONTAINER,
false,
false,
- false);
+ false,
+ null);
assertThat(failedInserts, is(expected));
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
index 5626812..1e64e22 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilTest.java
@@ -213,7 +213,16 @@
try {
totalBytes =
datasetService.insertAll(
- ref, rows, ids, InsertRetryPolicy.alwaysRetry(), null, null, false, false, false);
+ ref,
+ rows,
+ ids,
+ InsertRetryPolicy.alwaysRetry(),
+ null,
+ null,
+ false,
+ false,
+ false,
+ null);
} finally {
verifyInsertAll(5);
// Each of the 25 rows has 1 byte for length and 30 bytes: '{"f":[{"v":"foo"},{"v":1234}]}'
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 0cf0208..618ee55 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -1483,3 +1483,31 @@
estimated_size += (
self._key_coder_impl.estimate_size(value.key, nested=True))
return estimated_size
+
+
+class TimestampPrefixingWindowCoderImpl(StreamCoderImpl):
+ """For internal use only; no backwards-compatibility guarantees.
+
+  A coder for custom window types, which prefixes the required max_timestamp
+  to the encoded original window.
+
+  The coder encodes and decodes custom window types with the following format:
+    window's max_timestamp()
+    window encoded using its own coder.
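+
+  A minimal round-trip sketch (illustration only; assumes
+  IntervalWindowCoderImpl and the stream helpers from this module, plus
+  IntervalWindow from apache_beam.transforms.window)::
+
+    impl = TimestampPrefixingWindowCoderImpl(IntervalWindowCoderImpl())
+    out = create_OutputStream()
+    impl.encode_to_stream(IntervalWindow(0, 10), out, True)
+    decoded = impl.decode_from_stream(create_InputStream(out.get()), True)
+    assert decoded == IntervalWindow(0, 10)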
+ """
+ def __init__(self, window_coder_impl: CoderImpl) -> None:
+ self._window_coder_impl = window_coder_impl
+
+ def encode_to_stream(self, value, stream, nested):
+ TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, nested)
+ self._window_coder_impl.encode_to_stream(value, stream, nested)
+
+ def decode_from_stream(self, stream, nested):
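+    # Read and discard the timestamp prefix; the window decoded below already
+    # knows its own max timestamp.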
+ TimestampCoderImpl().decode_from_stream(stream, nested)
+ return self._window_coder_impl.decode_from_stream(stream, nested)
+
+ def estimate_size(self, value: Any, nested: bool = False) -> int:
+ estimated_size = 0
+ estimated_size += TimestampCoderImpl().estimate_size(value)
+ estimated_size += self._window_coder_impl.estimate_size(value, nested)
+ return estimated_size
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 163fd17..2d2b336 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1509,3 +1509,47 @@
Coder.register_structured_urn(
common_urns.coders.SHARDED_KEY.urn, ShardedKeyCoder)
+
+
+class TimestampPrefixingWindowCoder(FastCoder):
+ """For internal use only; no backwards-compatibility guarantees.
+
+  Coder which prefixes the max timestamp of an arbitrary window to its encoded
+  form."""
+ def __init__(self, window_coder: Coder) -> None:
+ self._window_coder = window_coder
+
+ def _create_impl(self):
+ return coder_impl.TimestampPrefixingWindowCoderImpl(
+ self._window_coder.get_impl())
+
+ def to_type_hint(self):
+ return self._window_coder.to_type_hint()
+
+ def _get_component_coders(self) -> List[Coder]:
+ return [self._window_coder]
+
+ def is_deterministic(self) -> bool:
+ return self._window_coder.is_deterministic()
+
+ def as_cloud_object(self, coders_context=None):
+ return {
+ '@type': 'kind:custom_window',
+ 'component_encodings': [
+ self._window_coder.as_cloud_object(coders_context)
+ ],
+ }
+
+ def __repr__(self):
+ return 'TimestampPrefixingWindowCoder[%r]' % self._window_coder
+
+ def __eq__(self, other):
+ return (
+ type(self) == type(other) and self._window_coder == other._window_coder)
+
+ def __hash__(self):
+ return hash((type(self), self._window_coder))
+
+
+Coder.register_structured_urn(
+ common_urns.coders.CUSTOM_WINDOW.urn, TimestampPrefixingWindowCoder)
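+
+# A minimal usage sketch (illustration only; assumes IntervalWindowCoder from
+# this module and IntervalWindow from apache_beam.transforms.window):
+#
+#   coder = TimestampPrefixingWindowCoder(IntervalWindowCoder())
+#   encoded = coder.encode(IntervalWindow(0, 10))
+#   assert coder.decode(encoded) == IntervalWindow(0, 10)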
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index 98e0547..11334ed 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -749,6 +749,19 @@
coders.TupleCoder((coder, other_coder)),
(ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
+ def test_timestamp_prefixing_window_coder(self):
+ self.check_coder(
+ coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()),
+ *[
+ window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
+ for y in range(-100, 100)
+ ])
+ self.check_coder(
+ coders.TupleCoder((
+ coders.TimestampPrefixingWindowCoder(
+ coders.IntervalWindowCoder()), )),
+ (window.IntervalWindow(0, 10), ))
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py
index 1a74dbe..454939f 100644
--- a/sdks/python/apache_beam/coders/standard_coders_test.py
+++ b/sdks/python/apache_beam/coders/standard_coders_test.py
@@ -189,7 +189,9 @@
'beam:coder:double:v1': parse_float,
'beam:coder:sharded_key:v1': lambda x,
value_parser: ShardedKey(
- key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8'))
+ key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
+ 'beam:coder:custom_window:v1': lambda x,
+ window_parser: window_parser(x['window'])
}
def test_standard_coders(self):
diff --git a/sdks/python/apache_beam/runners/interactive/caching/__init__.py b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
index 97b1be9..cce3aca 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/__init__.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/__init__.py
@@ -14,4 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
new file mode 100644
index 0000000..5b1b9ef
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import *
+
+import apache_beam as beam
+from apache_beam.dataframe import convert
+from apache_beam.dataframe import expressions
+
+
+class ExpressionCache(object):
+ """Utility class for caching deferred DataFrames expressions.
+
+ This is cache is currently a light-weight wrapper around the
+ TO_PCOLLECTION_CACHE in the beam.dataframes.convert module and the
+ computed_pcollections in the interactive module.
+
+ Example::
+
+ df : beam.dataframe.DeferredDataFrame = ...
+ ...
+ cache = ExpressionCache()
+ cache.replace_with_cached(df._expr)
+
+ This will automatically link the instance to the existing caches. After it is
+ created, the cache can then be used to modify an existing deferred dataframe
+ expression tree to replace nodes with computed PCollections.
+
+  This object can be created and destroyed at any time. This class holds no
+  state, and its only side-effect is modifying the given expression.
+ """
+ def __init__(self, pcollection_cache=None, computed_cache=None):
+ from apache_beam.runners.interactive import interactive_environment as ie
+
+ self._pcollection_cache = (
+ convert.TO_PCOLLECTION_CACHE
+ if pcollection_cache is None else pcollection_cache)
+ self._computed_cache = (
+ ie.current_env().computed_pcollections
+ if computed_cache is None else computed_cache)
+
+ def replace_with_cached(
+ self, expr: expressions.Expression) -> Dict[str, expressions.Expression]:
+ """Replaces any previously computed expressions with PlaceholderExpressions.
+
+ This is used to correctly read any expressions that were cached in previous
+ runs. This enables the InteractiveRunner to prune off old calculations from
+ the expression tree.
+ """
+
+ replaced_inputs: Dict[str, expressions.Expression] = {}
+ self._replace_with_cached_recur(expr, replaced_inputs)
+ return replaced_inputs
+
+ def _replace_with_cached_recur(
+ self,
+ expr: expressions.Expression,
+ replaced_inputs: Dict[str, expressions.Expression]) -> None:
+ """Recursive call for `replace_with_cached`.
+
+ Recurses through the expression tree and replaces any cached inputs with
+ `PlaceholderExpression`s.
+ """
+
+ final_inputs = []
+
+ for input in expr.args():
+ pc = self._get_cached(input)
+
+      # Only read from the cache when the PCollection has been fully
+      # computed, so that no partial results are used.
+ if self._is_computed(pc):
+
+        # Reuse previously seen cached expressions so that the same input
+        # isn't replaced with multiple distinct placeholders.
+ if input._id in replaced_inputs:
+ cached = replaced_inputs[input._id]
+ else:
+ cached = expressions.PlaceholderExpression(
+ input.proxy(), self._pcollection_cache[input._id])
+
+ replaced_inputs[input._id] = cached
+ final_inputs.append(cached)
+ else:
+ final_inputs.append(input)
+ self._replace_with_cached_recur(input, replaced_inputs)
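+    # Mutate the expression in place so that consumers of this tree see the
+    # cached placeholders instead of the already-computed subtrees.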
+ expr._args = tuple(final_inputs)
+
+ def _get_cached(self,
+ expr: expressions.Expression) -> Optional[beam.PCollection]:
+ """Returns the PCollection associated with the expression."""
+ return self._pcollection_cache.get(expr._id, None)
+
+ def _is_computed(self, pc: beam.PCollection) -> bool:
+ """Returns True if the PCollection has been run and computed."""
+ return pc is not None and pc in self._computed_cache
diff --git a/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
new file mode 100644
index 0000000..c6e46f3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/caching/expression_cache_test.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+from apache_beam.dataframe import expressions
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
+
+
+class ExpressionCacheTest(unittest.TestCase):
+ def setUp(self):
+ self._pcollection_cache = {}
+ self._computed_cache = set()
+ self._pipeline = beam.Pipeline()
+ self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)
+
+ def create_trace(self, expr):
+ trace = [expr]
+ for input in expr.args():
+ trace += self.create_trace(input)
+ return trace
+
+ def mock_cache(self, expr):
+ pcoll = beam.PCollection(self._pipeline)
+ self._pcollection_cache[expr._id] = pcoll
+ self._computed_cache.add(pcoll)
+
+ def assertTraceTypes(self, expr, expected):
+ actual_types = [type(e).__name__ for e in self.create_trace(expr)]
+ expected_types = [e.__name__ for e in expected]
+ self.assertListEqual(actual_types, expected_types)
+
+ def test_only_replaces_cached(self):
+ in_expr = expressions.ConstantExpression(0)
+ comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])
+
+ # Expect that no replacement of expressions is performed.
+ expected_trace = [
+ expressions.ComputedExpression, expressions.ConstantExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+
+ self.cache.replace_with_cached(comp_expr)
+
+ self.assertTraceTypes(comp_expr, expected_trace)
+
+ # Now "cache" the expression and assert that the cached expression was
+ # replaced with a placeholder.
+ self.mock_cache(in_expr)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ expected_trace = [
+ expressions.ComputedExpression, expressions.PlaceholderExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertIn(in_expr._id, replaced)
+
+ def test_only_replaces_inputs(self):
+ arg_0_expr = expressions.ConstantExpression(0)
+ ident_val = expressions.ComputedExpression(
+ 'ident', lambda x: x, [arg_0_expr])
+
+ arg_1_expr = expressions.ConstantExpression(1)
+ comp_expr = expressions.ComputedExpression(
+ 'add', lambda x, y: x + y, [ident_val, arg_1_expr])
+
+ self.mock_cache(ident_val)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ # Assert that ident_val was replaced and that its arguments were removed
+ # from the expression tree.
+ expected_trace = [
+ expressions.ComputedExpression,
+ expressions.PlaceholderExpression,
+ expressions.ConstantExpression
+ ]
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertIn(ident_val._id, replaced)
+ self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))
+
+ def test_only_caches_same_input(self):
+ arg_0_expr = expressions.ConstantExpression(0)
+ ident_val = expressions.ComputedExpression(
+ 'ident', lambda x: x, [arg_0_expr])
+ comp_expr = expressions.ComputedExpression(
+ 'add', lambda x, y: x + y, [ident_val, arg_0_expr])
+
+ self.mock_cache(arg_0_expr)
+
+ replaced = self.cache.replace_with_cached(comp_expr)
+
+ # Assert that arg_0_expr, being an input to two computations, was replaced
+ # with the same placeholder expression.
+ expected_trace = [
+ expressions.ComputedExpression,
+ expressions.ComputedExpression,
+ expressions.PlaceholderExpression,
+ expressions.PlaceholderExpression
+ ]
+ actual_trace = self.create_trace(comp_expr)
+ unique_placeholders = set(
+ t for t in actual_trace
+ if isinstance(t, expressions.PlaceholderExpression))
+ self.assertTraceTypes(comp_expr, expected_trace)
+ self.assertTrue(
+ all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
+ self.assertIn(arg_0_expr._id, replaced)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
index c4d062e..0b3f8a3 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -37,12 +37,12 @@
import pandas as pd
import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.display import pipeline_graph
from apache_beam.runners.interactive.display.pcoll_visualization import visualize
from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
from apache_beam.runners.interactive.utils import elements_to_df
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.runners.runner import PipelineState
@@ -455,10 +455,7 @@
element_types = {}
for pcoll in flatten_pcolls:
if isinstance(pcoll, DeferredBase):
- proxy = pcoll._expr.proxy()
- pcoll = to_pcollection(
- pcoll, yield_elements='pandas', label=str(pcoll._expr))
- element_type = proxy
+ pcoll, element_type = deferred_df_to_pcollection(pcoll)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
else:
element_type = pcoll.element_type
@@ -569,11 +566,7 @@
# collect the result in elements_to_df.
if isinstance(pcoll, DeferredBase):
# Get the proxy so we can get the output shape of the DataFrame.
- # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
- # instead of the proxy.
- element_type = pcoll._expr.proxy()
- pcoll = to_pcollection(
- pcoll, yield_elements='pandas', label=str(pcoll._expr))
+ pcoll, element_type = deferred_df_to_pcollection(pcoll)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
else:
element_type = pcoll.element_type
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index fe6989e..69551de 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -408,6 +408,81 @@
df_expected['cube'],
ib.collect(df['cube'], n=10).reset_index(drop=True))
+ @unittest.skipIf(
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+ @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
+ @patch('IPython.get_ipython', new_callable=mock_get_ipython)
+ def test_dataframe_caching(self, cell):
+
+ # Create a pipeline that exercises the DataFrame API. This will also use
+ # caching in the background.
+ with cell: # Cell 1
+ p = beam.Pipeline(interactive_runner.InteractiveRunner())
+ ib.watch({'p': p})
+
+ with cell: # Cell 2
+ data = p | beam.Create([
+ 1, 2, 3
+ ]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
+
+ with beam.dataframe.allow_non_parallel_operations():
+ df = to_dataframe(data).reset_index(drop=True)
+
+ ib.collect(df)
+
+ with cell: # Cell 3
+ df['output'] = df['square'] * df['cube']
+ ib.collect(df)
+
+ with cell: # Cell 4
+ df['output'] = 0
+ ib.collect(df)
+
+ # We use a trace through the graph to perform an isomorphism test. The end
+ # output should look like a linear graph. This indicates that the dataframe
+ # transform was correctly broken into separate pieces to cache. If caching
+ # isn't enabled, all the dataframe computation nodes are connected to a
+ # single shared node.
+ trace = []
+
+ # Only look at the top-level transforms for the isomorphism. The test
+ # doesn't care about the transform implementations, just the overall shape.
+ class TopLevelTracer(beam.pipeline.PipelineVisitor):
+ def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
+ if node is None or not node.full_label:
+ return None
+
+ parent = self._find_root_producer(node.parent)
+ if parent is None:
+ return node
+
+ return parent
+
+ def _add_to_trace(self, node, trace):
+ if '/' not in str(node):
+ if node.inputs:
+ producer = self._find_root_producer(node.inputs[0].producer)
+ producer_name = producer.full_label if producer else ''
+ trace.append((producer_name, node.full_label))
+
+ def visit_transform(self, node: beam.pipeline.AppliedPTransform):
+ self._add_to_trace(node, trace)
+
+ def enter_composite_transform(
+ self, node: beam.pipeline.AppliedPTransform):
+ self._add_to_trace(node, trace)
+
+ p.visit(TopLevelTracer())
+
+ # Do the isomorphism test which states that the topological sort of the
+ # graph yields a linear graph.
+ trace_string = '\n'.join(str(t) for t in trace)
+ prev_producer = ''
+ for producer, consumer in trace:
+ self.assertEqual(producer, prev_producer, trace_string)
+ prev_producer = consumer
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 448f76a..c51a648 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -23,7 +23,6 @@
import pandas as pd
import apache_beam as beam
-from apache_beam.dataframe.convert import to_pcollection
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive import background_caching_job as bcj
@@ -310,7 +309,7 @@
# TODO(BEAM-12388): investigate the mixing pcollections in multiple
# pipelines error when using the default label.
for df in watched_dataframes:
- pcoll = to_pcollection(df, yield_elements='pandas', label=str(df._expr))
+ pcoll, _ = utils.deferred_df_to_pcollection(df)
watched_pcollections.add(pcoll)
for pcoll in pcolls:
if pcoll not in watched_pcollections:
diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
index 96ed442..44cfc70 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
+++ b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/29c9237ddf4f3d5988a503069b4d3c47.png
Binary files differ
diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
index e2e1ad4..aa98c62 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
+++ b/sdks/python/apache_beam/runners/interactive/testing/integration/goldens/Linux/7a35f487b2a5f3a9b9852a8659eeb4bd.png
Binary files differ
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index bbe88ff..3e85145 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -25,7 +25,10 @@
import pandas as pd
+from apache_beam.dataframe.convert import to_pcollection
+from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
from apache_beam.testing.test_stream import WindowedValueHolder
from apache_beam.typehints.schemas import named_fields_from_element_type
@@ -267,3 +270,17 @@
return str(return_value)
return return_as_json
+
+
+def deferred_df_to_pcollection(df):
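+  """Converts a DeferredBase into a (PCollection, proxy) pair, replacing any
+  sub-expressions computed in previous runs with their cached PCollections."""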
+ assert isinstance(df, DeferredBase), '{} is not a DeferredBase'.format(df)
+
+ # The proxy is used to output a DataFrame with the correct columns.
+ #
+ # TODO(BEAM-11064): Once type hints are implemented for pandas, use those
+ # instead of the proxy.
+ cache = ExpressionCache()
+ cache.replace_with_cached(df._expr)
+
+ proxy = df._expr.proxy()
+ return to_pcollection(df, yield_elements='pandas', label=str(df._expr)), proxy
diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py
index 4088d3e..5c92eaf 100644
--- a/sdks/python/apache_beam/transforms/sideinputs.py
+++ b/sdks/python/apache_beam/transforms/sideinputs.py
@@ -55,6 +55,9 @@
if target_window_fn == window.GlobalWindows():
return _global_window_mapping_fn
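+  # Session windows are data-dependent and merged at runtime, so no
+  # deterministic mapping from a main input window to a session window can be
+  # built for side input lookup.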
+ if isinstance(target_window_fn, window.Sessions):
+ raise RuntimeError("Sessions is not allowed in side inputs")
+
def map_via_end(source_window):
# type: (window.BoundedWindow) -> window.BoundedWindow
return list(
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 226252c..e5f7c24 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-"""Support for Dataflow triggers.
+"""Support for Apache Beam triggers.
Triggers control when, in processing time, windows get emitted.
"""