This closes #3705: [BEAM-165] Initial implementation of the MapReduce runner
mr-runner: Removes WordCountTest, fixes checkstyle and findbugs issues, and addresses review comments.
mr-runner-hack: disable unrelated modules to shorten build time during development.
mr-runner: support SourceMetrics, this fixes MetricsTest.testBoundedSourceMetrics().
mr-runner: introduces duplicateFactor in FlattenOperation, this fixes testFlattenInputMultipleCopies().
mr-runner: translate empty flatten into EmptySource, this fixes a few empty FlattenTests.
mr-runner: ensure each Operation only starts/finishes once for a diamond-shaped DAG, this fixes ParDoLifecycleTest.
mr-runner: make Graph.getSteps() return steps in topological order, this fixes a few CombineTests.
mr-runner: fail early in the runner when MapReduce job fails.
mr-runner: use InMemoryStateInternals in ParDoOperation, this fixes ParDoTests that use state.
mr-runner: use the correct step name in ParDoTranslator, this fixes MetricsTest.testAttemptedCounterMetrics().
mr-runner: remove the hard-coded GlobalWindow coder, this fixes WindowingTest.
mr-runner: handle no files case in FileSideInputReader for empty views.
mr-runner: fix NPE in PipelineTest.testIdentityTransform().
mr-runner: filter out unsupported features in ValidatesRunner tests.
mr-runner: setMetricsSupported to run ValidatesRunner tests with TestPipeline.
mr-runner: fix a bug where steps are attached multiple times in a diamond-shaped DAG.
[BEAM-2783] support metrics in MapReduceRunner.
mr-runner: setup file paths for read and write sides of materialization.
mr-runner: support side inputs by reading in the contents of all views.
mr-runner: support multiple SourceOperations by composing and partitioning.
mr-runner: support PCollections materialization with multiple MR jobs.
mr-runner: hack to work around ViewAsXXX.expand() returning the wrong output PValue.
mr-runner: support graph visualization with dotfiles.
mr-runner: refactors and creates Graph data structures to handle general Beam pipelines.
mr-runner: add JarClassInstanceFactory to run ValidatesRunner tests.
mr-runner: support reduce side ParDos and WordCount.
core-java: InMemoryTimerInternals exposes getTimers() for timer firings in mr-runner.
mr-runner: add BeamReducer and support GroupByKey.
mr-runner: add ParDoOperation and support ParDos chaining.
mr-runner: add JobPrototype and translate it to a MR job.
mr-runner: support BoundedSource with BeamInputFormat.
MapReduceRunner: add unit tests for GraphConverter and GraphPlanner.
MapReduceRunner: add Graph and its visitors.
Initial commit for MapReduceRunner.
diff --git a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Gearpump.groovy b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Gearpump.groovy
index 1348a19..e1cbafe 100644
--- a/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Gearpump.groovy
+++ b/.test-infra/jenkins/job_beam_PostCommit_Java_ValidatesRunner_Gearpump.groovy
@@ -45,5 +45,5 @@
'Run Gearpump ValidatesRunner')
// Maven goals for this job.
- goals('-B -e clean verify -am -pl runners/gearpump -DforkCount=0 -DvalidatesRunnerPipelineOptions=\'[ "--runner=TestGearpumpRunner", "--streaming=false" ]\'')
+ goals('-B -e clean verify -am -pl runners/gearpump -DforkCount=0 -DvalidatesRunnerPipelineOptions=\'[ "--runner=TestGearpumpRunner"]\'')
}
diff --git a/examples/java/pom.xml b/examples/java/pom.xml
index 12fe06f..ade4cac 100644
--- a/examples/java/pom.xml
+++ b/examples/java/pom.xml
@@ -365,20 +365,7 @@
</profiles>
<build>
-
<plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamUseDummyRunner />
- <beamTestPipelineOptions>
- </beamTestPipelineOptions>
- </systemPropertyVariables>
- </configuration>
- </plugin>
-
<!-- Coverage analysis for unit tests. -->
<plugin>
<groupId>org.jacoco</groupId>
@@ -518,7 +505,6 @@
</dependency>
<!-- Test dependencies -->
-
<!--
For testing the example itself, use the direct runner. This is separate from
the use of ValidatesRunner tests for testing a particular runner.
diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/TfIdf.java b/examples/java/src/main/java/org/apache/beam/examples/complete/TfIdf.java
index 435ffab..cfc413c 100644
--- a/examples/java/src/main/java/org/apache/beam/examples/complete/TfIdf.java
+++ b/examples/java/src/main/java/org/apache/beam/examples/complete/TfIdf.java
@@ -25,7 +25,6 @@
import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringDelegateCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -155,11 +154,6 @@
}
@Override
- public Coder<?> getDefaultOutputCoder() {
- return KvCoder.of(StringDelegateCoder.of(URI.class), StringUtf8Coder.of());
- }
-
- @Override
public PCollection<KV<URI, String>> expand(PBegin input) {
Pipeline pipeline = input.getPipeline();
@@ -179,9 +173,11 @@
uriString = uri.toString();
}
- PCollection<KV<URI, String>> oneUriToLines = pipeline
- .apply("TextIO.Read(" + uriString + ")", TextIO.read().from(uriString))
- .apply("WithKeys(" + uriString + ")", WithKeys.<URI, String>of(uri));
+ PCollection<KV<URI, String>> oneUriToLines =
+ pipeline
+ .apply("TextIO.Read(" + uriString + ")", TextIO.read().from(uriString))
+ .apply("WithKeys(" + uriString + ")", WithKeys.<URI, String>of(uri))
+ .setCoder(KvCoder.of(StringDelegateCoder.of(URI.class), StringUtf8Coder.of()));
urisToLines = urisToLines.and(oneUriToLines);
}
diff --git a/examples/java8/pom.xml b/examples/java8/pom.xml
index 6fd29a4..6e1fe8f 100644
--- a/examples/java8/pom.xml
+++ b/examples/java8/pom.xml
@@ -151,6 +151,18 @@
</dependency>
</dependencies>
</profile>
+
+ <!-- Include the Apache Gearpump (incubating) runner with -P gearpump-runner -->
+ <profile>
+ <id>gearpump-runner</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-gearpump</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ </dependencies>
+ </profile>
</profiles>
<build>
@@ -178,17 +190,6 @@
</configuration>
</plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamTestPipelineOptions>
- </beamTestPipelineOptions>
- </systemPropertyVariables>
- </configuration>
- </plugin>
-
<!-- Coverage analysis for unit tests. -->
<plugin>
<groupId>org.jacoco</groupId>
diff --git a/pom.xml b/pom.xml
index 25cd51b..ae86a9c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -157,10 +157,12 @@
<groovy-maven-plugin.version>2.0</groovy-maven-plugin.version>
<surefire-plugin.version>2.20</surefire-plugin.version>
<failsafe-plugin.version>2.20</failsafe-plugin.version>
- <maven-compiler-plugin.version>3.6.1</maven-compiler-plugin.version>
+ <maven-compiler-plugin.version>3.6.2</maven-compiler-plugin.version>
<maven-dependency-plugin.version>3.0.1</maven-dependency-plugin.version>
+ <maven-enforcer-plugin.version>3.0.0-M1</maven-enforcer-plugin.version>
<maven-exec-plugin.version>1.6.0</maven-exec-plugin.version>
<maven-jar-plugin.version>3.0.2</maven-jar-plugin.version>
+ <maven-javadoc-plugin.version>3.0.0-M1</maven-javadoc-plugin.version>
<maven-resources-plugin.version>3.0.2</maven-resources-plugin.version>
<maven-shade-plugin.version>3.0.0</maven-shade-plugin.version>
@@ -579,6 +581,12 @@
<dependency>
<groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-gearpump</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.beam</groupId>
<artifactId>beam-examples-java</artifactId>
<version>${project.version}</version>
</dependency>
@@ -1389,7 +1397,7 @@
here, we leave things simple here. -->
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
- <version>2.10.4</version>
+ <version>${maven-javadoc-plugin.version}</version>
<configuration>
<additionalparam>${beam.javadoc_opts}</additionalparam>
<windowtitle>Apache Beam SDK for Java, version ${project.version} API</windowtitle>
@@ -1780,7 +1788,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
- <version>1.4.1</version>
+ <version>${maven-enforcer-plugin.version}</version>
<executions>
<execution>
<id>enforce</id>
diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml
index fd5aafb..96aac8b 100644
--- a/runners/apex/pom.xml
+++ b/runners/apex/pom.xml
@@ -63,14 +63,6 @@
<version>${apex.malhar.version}</version>
</dependency>
<dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-core</artifactId>
- </dependency>
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- </dependency>
- <dependency>
<groupId>org.apache.apex</groupId>
<artifactId>apex-engine</artifactId>
<version>${apex.core.version}</version>
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
index fd0a1c9..57d2593 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
@@ -227,9 +227,8 @@
@Override
public PCollection<ElemT> expand(PCollection<ElemT> input) {
- return PCollection.<ElemT>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
}
public PCollectionView<ViewT> getView() {
@@ -380,8 +379,9 @@
public PTransformReplacement<PCollection<InputT>, PCollectionTuple> getReplacementTransform(
AppliedPTransform<PCollection<InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>>
transform) {
- return PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(transform),
- SplittableParDo.forJavaParDo(transform.getTransform()));
+ return PTransformReplacement.of(
+ PTransformReplacements.getSingletonMainInput(transform),
+ SplittableParDo.forAppliedParDo(transform));
}
@Override
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslator.java
index 440b801..189cb65 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslator.java
@@ -110,8 +110,12 @@
}
if (collections.size() > 2) {
- PCollection<T> intermediateCollection = intermediateCollection(collection,
- collection.getCoder());
+ PCollection<T> intermediateCollection =
+ PCollection.createPrimitiveOutputInternal(
+ collection.getPipeline(),
+ collection.getWindowingStrategy(),
+ collection.isBounded(),
+ collection.getCoder());
context.addOperator(operator, operator.out, intermediateCollection);
remainingCollections.add(intermediateCollection);
} else {
@@ -135,11 +139,4 @@
}
}
- static <T> PCollection<T> intermediateCollection(PCollection<T> input, Coder<T> outputCoder) {
- PCollection<T> output = PCollection.createPrimitiveOutputInternal(input.getPipeline(),
- input.getWindowingStrategy(), input.isBounded());
- output.setCoder(outputCoder);
- return output;
- }
-
}
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
index e46687a..be11b02 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
@@ -241,8 +241,11 @@
}
PCollection<Object> resultCollection =
- FlattenPCollectionTranslator.intermediateCollection(
- firstSideInput, firstSideInput.getCoder());
+ PCollection.createPrimitiveOutputInternal(
+ firstSideInput.getPipeline(),
+ firstSideInput.getWindowingStrategy(),
+ firstSideInput.isBounded(),
+ firstSideInput.getCoder());
FlattenPCollectionTranslator.flattenCollections(
sourceCollections, unionTags, resultCollection, context);
return resultCollection;
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java
index 39f681f..5c0d72f 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java
@@ -33,7 +33,6 @@
import org.apache.beam.runners.apex.ApexPipelineOptions;
import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
-import org.apache.beam.runners.apex.translation.utils.SerializablePipelineOptions;
import org.apache.beam.runners.core.NullSideInputReader;
import org.apache.beam.runners.core.OutputWindowedValue;
import org.apache.beam.runners.core.ReduceFnRunner;
@@ -41,6 +40,7 @@
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.TimerInternals.TimerData;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TriggerTranslation;
import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
@@ -149,7 +149,9 @@
@Override
public void setup(OperatorContext context) {
- this.traceTuples = ApexStreamTuple.Logging.isDebugEnabled(serializedOptions.get(), this);
+ this.traceTuples =
+ ApexStreamTuple.Logging.isDebugEnabled(
+ serializedOptions.get().as(ApexPipelineOptions.class), this);
}
@Override
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index c3cbab2..4dc807d 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -40,7 +40,6 @@
import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
import org.apache.beam.runners.apex.translation.utils.NoOpStepContext;
-import org.apache.beam.runners.apex.translation.utils.SerializablePipelineOptions;
import org.apache.beam.runners.apex.translation.utils.StateInternalsProxy;
import org.apache.beam.runners.apex.translation.utils.ValueAndCoderKryoSerializable;
import org.apache.beam.runners.core.DoFnRunner;
@@ -64,6 +63,7 @@
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.runners.core.TimerInternalsFactory;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
@@ -386,7 +386,9 @@
@Override
public void setup(OperatorContext context) {
- this.traceTuples = ApexStreamTuple.Logging.isDebugEnabled(pipelineOptions.get(), this);
+ this.traceTuples =
+ ApexStreamTuple.Logging.isDebugEnabled(
+ pipelineOptions.get().as(ApexPipelineOptions.class), this);
SideInputReader sideInputReader = NullSideInputReader.of(sideInputs);
if (!sideInputs.isEmpty()) {
sideInputHandler = new SideInputHandler(sideInputs, sideInputStateInternals);
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexReadUnboundedInputOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexReadUnboundedInputOperator.java
index 1549560..21fb9d2 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexReadUnboundedInputOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexReadUnboundedInputOperator.java
@@ -30,8 +30,8 @@
import org.apache.beam.runners.apex.ApexPipelineOptions;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple.DataTuple;
-import org.apache.beam.runners.apex.translation.utils.SerializablePipelineOptions;
import org.apache.beam.runners.apex.translation.utils.ValuesSource;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -119,7 +119,9 @@
@Override
public void setup(OperatorContext context) {
- this.traceTuples = ApexStreamTuple.Logging.isDebugEnabled(pipelineOptions.get(), this);
+ this.traceTuples =
+ ApexStreamTuple.Logging.isDebugEnabled(
+ pipelineOptions.get().as(ApexPipelineOptions.class), this);
try {
reader = source.createReader(this.pipelineOptions.get(), null);
available = reader.start();
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/SerializablePipelineOptions.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/SerializablePipelineOptions.java
deleted file mode 100644
index 46b04fc..0000000
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/SerializablePipelineOptions.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.apex.translation.utils;
-
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.concurrent.atomic.AtomicBoolean;
-import org.apache.beam.runners.apex.ApexPipelineOptions;
-import org.apache.beam.sdk.io.FileSystems;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.common.ReflectHelpers;
-
-/**
- * A wrapper to enable serialization of {@link PipelineOptions}.
- */
-public class SerializablePipelineOptions implements Externalizable {
-
- /* Used to ensure we initialize file systems exactly once, because it's a slow operation. */
- private static final AtomicBoolean FILE_SYSTEMS_INTIIALIZED = new AtomicBoolean(false);
-
- private transient ApexPipelineOptions pipelineOptions;
-
- public SerializablePipelineOptions(ApexPipelineOptions pipelineOptions) {
- this.pipelineOptions = pipelineOptions;
- }
-
- public SerializablePipelineOptions() {
- }
-
- public ApexPipelineOptions get() {
- return this.pipelineOptions;
- }
-
- @Override
- public void writeExternal(ObjectOutput out) throws IOException {
- out.writeUTF(createMapper().writeValueAsString(pipelineOptions));
- }
-
- @Override
- public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- String s = in.readUTF();
- this.pipelineOptions = createMapper().readValue(s, PipelineOptions.class)
- .as(ApexPipelineOptions.class);
-
- if (FILE_SYSTEMS_INTIIALIZED.compareAndSet(false, true)) {
- FileSystems.setDefaultPipelineOptions(pipelineOptions);
- }
- }
-
- /**
- * Use an {@link ObjectMapper} configured with any {@link Module}s in the class path allowing
- * for user specified configuration injection into the ObjectMapper. This supports user custom
- * types on {@link PipelineOptions}.
- */
- private static ObjectMapper createMapper() {
- return new ObjectMapper().registerModules(
- ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
- }
-}
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ValuesSource.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ValuesSource.java
index 41f027f..4a00ff1 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ValuesSource.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ValuesSource.java
@@ -81,7 +81,7 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return iterableCoder.getElemCoder();
}
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/UnboundedTextSource.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/UnboundedTextSource.java
index c590a2e..8f3e6bc 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/UnboundedTextSource.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/examples/UnboundedTextSource.java
@@ -59,7 +59,7 @@
}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java
index 206b430..63a218b 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java
@@ -59,9 +59,9 @@
WindowingStrategy<?, ?> ws = WindowingStrategy.of(FixedWindows.of(
Duration.standardSeconds(10)));
- PCollection<KV<String, Integer>> input = PCollection.createPrimitiveOutputInternal(pipeline,
- ws, IsBounded.BOUNDED);
- input.setCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
+ PCollection<KV<String, Integer>> input =
+ PCollection.createPrimitiveOutputInternal(
+ pipeline, ws, IsBounded.BOUNDED, KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
ApexGroupByKeyOperator<String, Integer> operator = new ApexGroupByKeyOperator<>(options,
input, new ApexStateInternals.ApexStateBackend()
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslatorTest.java
index 9c61b47..58f33ae 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslatorTest.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslatorTest.java
@@ -153,7 +153,7 @@
}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/CollectionSource.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/CollectionSource.java
index 288aade..01a2a85 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/CollectionSource.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/CollectionSource.java
@@ -67,7 +67,7 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return coder;
}
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/PipelineOptionsTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/PipelineOptionsTest.java
deleted file mode 100644
index 118ff99..0000000
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/PipelineOptionsTest.java
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.apex.translation.utils;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-
-import com.datatorrent.common.util.FSStorageAgent;
-import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind;
-import com.esotericsoftware.kryo.serializers.JavaSerializer;
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonDeserializer;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.SerializerProvider;
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.fasterxml.jackson.databind.module.SimpleModule;
-import com.google.auto.service.AutoService;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import org.apache.beam.runners.apex.ApexPipelineOptions;
-import org.apache.beam.sdk.options.Default;
-import org.apache.beam.sdk.options.Description;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.junit.Test;
-
-/**
- * Tests the serialization of PipelineOptions.
- */
-public class PipelineOptionsTest {
-
- /**
- * Interface for testing.
- */
- public interface MyOptions extends ApexPipelineOptions {
- @Description("Bla bla bla")
- @Default.String("Hello")
- String getTestOption();
- void setTestOption(String value);
- }
-
- private static class OptionsWrapper {
- private OptionsWrapper() {
- this(null); // required for Kryo
- }
- private OptionsWrapper(ApexPipelineOptions options) {
- this.options = new SerializablePipelineOptions(options);
- }
- @Bind(JavaSerializer.class)
- private final SerializablePipelineOptions options;
- }
-
- @Test
- public void testSerialization() {
- OptionsWrapper wrapper = new OptionsWrapper(
- PipelineOptionsFactory.fromArgs("--testOption=nothing").as(MyOptions.class));
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- FSStorageAgent.store(bos, wrapper);
-
- ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
- OptionsWrapper wrapperCopy = (OptionsWrapper) FSStorageAgent.retrieve(bis);
- assertNotNull(wrapperCopy.options);
- assertEquals("nothing", wrapperCopy.options.get().as(MyOptions.class).getTestOption());
- }
-
- @Test
- public void testSerializationWithUserCustomType() {
- OptionsWrapper wrapper = new OptionsWrapper(
- PipelineOptionsFactory.fromArgs("--jacksonIncompatible=\"testValue\"")
- .as(JacksonIncompatibleOptions.class));
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- FSStorageAgent.store(bos, wrapper);
-
- ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
- OptionsWrapper wrapperCopy = (OptionsWrapper) FSStorageAgent.retrieve(bis);
- assertNotNull(wrapperCopy.options);
- assertEquals("testValue",
- wrapperCopy.options.get().as(JacksonIncompatibleOptions.class)
- .getJacksonIncompatible().value);
- }
-
- /** PipelineOptions used to test auto registration of Jackson modules. */
- public interface JacksonIncompatibleOptions extends ApexPipelineOptions {
- JacksonIncompatible getJacksonIncompatible();
- void setJacksonIncompatible(JacksonIncompatible value);
- }
-
- /** A Jackson {@link Module} to test auto-registration of modules. */
- @AutoService(Module.class)
- public static class RegisteredTestModule extends SimpleModule {
- public RegisteredTestModule() {
- super("RegisteredTestModule");
- setMixInAnnotation(JacksonIncompatible.class, JacksonIncompatibleMixin.class);
- }
- }
-
- /** A class which Jackson does not know how to serialize/deserialize. */
- public static class JacksonIncompatible {
- private final String value;
- public JacksonIncompatible(String value) {
- this.value = value;
- }
- }
-
- /** A Jackson mixin used to add annotations to other classes. */
- @JsonDeserialize(using = JacksonIncompatibleDeserializer.class)
- @JsonSerialize(using = JacksonIncompatibleSerializer.class)
- public static final class JacksonIncompatibleMixin {}
-
- /** A Jackson deserializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleDeserializer extends
- JsonDeserializer<JacksonIncompatible> {
-
- @Override
- public JacksonIncompatible deserialize(JsonParser jsonParser,
- DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
- return new JacksonIncompatible(jsonParser.readValueAs(String.class));
- }
- }
-
- /** A Jackson serializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleSerializer extends JsonSerializer<JacksonIncompatible> {
-
- @Override
- public void serialize(JacksonIncompatible jacksonIncompatible, JsonGenerator jsonGenerator,
- SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
- jsonGenerator.writeString(jacksonIncompatible.value);
- }
- }
-}
diff --git a/runners/core-construction-java/pom.xml b/runners/core-construction-java/pom.xml
index b85b5f5..1a52914 100644
--- a/runners/core-construction-java/pom.xml
+++ b/runners/core-construction-java/pom.xml
@@ -65,6 +65,21 @@
</dependency>
<dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ </dependency>
+
+ <dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</dependency>
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DisplayDataTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DisplayDataTranslation.java
new file mode 100644
index 0000000..ff7f9f2
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DisplayDataTranslation.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.BoolValue;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+
+/** Utilities for going to/from DisplayData protos. */
+public class DisplayDataTranslation {
+  // Static utility class; not meant to be instantiated.
+  private DisplayDataTranslation() {}
+
+  /**
+   * Converts {@link DisplayData} into its Runner API proto form.
+   *
+   * <p>This is currently a stub: {@code displayData} is ignored, and the returned proto always
+   * holds one boolean item with key {@code "stubImplementation"} and value {@code true}.
+   *
+   * @param displayData the display data to translate; currently unused
+   * @return a placeholder {@link RunnerApi.DisplayData} proto
+   */
+  public static RunnerApi.DisplayData toProto(DisplayData displayData) {
+    // TODO https://issues.apache.org/jira/browse/BEAM-2645
+    return RunnerApi.DisplayData.newBuilder()
+        .addItems(
+            RunnerApi.DisplayData.Item.newBuilder()
+                .setId(RunnerApi.DisplayData.Identifier.newBuilder().setKey("stubImplementation"))
+                .setLabel("Stub implementation")
+                .setType(RunnerApi.DisplayData.Type.BOOLEAN)
+                .setValue(Any.pack(BoolValue.newBuilder().setValue(true).build())))
+        .build();
+  }
+}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ForwardingPTransform.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ForwardingPTransform.java
index ca25ba7..ccf41f3 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ForwardingPTransform.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ForwardingPTransform.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.core.construction;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -37,7 +36,16 @@
@Override
public OutputT expand(InputT input) {
- return delegate().expand(input);
+ OutputT res = delegate().expand(input);
+ if (res instanceof PCollection) {
+ PCollection pc = (PCollection) res;
+ try {
+ pc.setCoder(delegate().getDefaultOutputCoder(input, pc));
+ } catch (CannotProvideCoderException e) {
+ // Let coder inference happen later.
+ }
+ }
+ return res;
}
@Override
@@ -51,12 +59,6 @@
}
@Override
- public <T> Coder<T> getDefaultOutputCoder(InputT input, PCollection<T> output)
- throws CannotProvideCoderException {
- return delegate().getDefaultOutputCoder(input, output);
- }
-
- @Override
public void populateDisplayData(DisplayData.Builder builder) {
builder.delegate(delegate());
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java
index c0a5acf..c256e4c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionTranslation.java
@@ -52,10 +52,10 @@
Coder<?> coder = components.getCoder(pCollection.getCoderId());
return PCollection.createPrimitiveOutputInternal(
- pipeline,
- components.getWindowingStrategy(pCollection.getWindowingStrategyId()),
- fromProto(pCollection.getIsBounded()))
- .setCoder((Coder) coder);
+ pipeline,
+ components.getWindowingStrategy(pCollection.getWindowingStrategyId()),
+ fromProto(pCollection.getIsBounded()),
+ (Coder) coder);
}
public static IsBounded isBounded(RunnerApi.PCollection pCollection) {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
index 706a956..35bad15 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
@@ -20,6 +20,7 @@
import static com.google.common.base.Preconditions.checkArgument;
+import com.google.common.collect.Iterables;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdk.runners.AppliedPTransform;
@@ -66,4 +67,9 @@
ignoredTags);
return mainInput;
}
+
+ public static <T> PCollection<T> getSingletonMainOutput(
+ AppliedPTransform<?, PCollection<T>, ? extends PTransform<?, PCollection<T>>> transform) {
+ return ((PCollection<T>) Iterables.getOnlyElement(transform.getOutputs().values()));
+ }
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
index 3b94724..b8365c9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java
@@ -33,6 +33,7 @@
import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
@@ -91,6 +92,7 @@
List<AppliedPTransform<?, ?, ?>> subtransforms,
SdkComponents components)
throws IOException {
+ // TODO include DisplayData https://issues.apache.org/jira/browse/BEAM-2645
RunnerApi.PTransform.Builder transformBuilder = RunnerApi.PTransform.newBuilder();
for (Map.Entry<TupleTag<?>, PValue> taggedInput : appliedPTransform.getInputs().entrySet()) {
checkArgument(
@@ -118,7 +120,8 @@
}
transformBuilder.setUniqueName(appliedPTransform.getFullName());
- // TODO: Display Data
+ transformBuilder.setDisplayData(
+ DisplayDataTranslation.toProto(DisplayData.from(appliedPTransform.getTransform())));
PTransform<?, ?> transform = appliedPTransform.getTransform();
// A RawPTransform directly vends its payload. Because it will generally be
@@ -134,6 +137,7 @@
}
transformBuilder.setSpec(payload);
}
+ rawPTransform.registerComponents(components);
} else if (KNOWN_PAYLOAD_TRANSLATORS.containsKey(transform.getClass())) {
FunctionSpec payload =
KNOWN_PAYLOAD_TRANSLATORS
@@ -223,6 +227,8 @@
public Any getPayload() {
return null;
}
+
+ public void registerComponents(SdkComponents components) {}
}
/**
@@ -253,6 +259,10 @@
transformSpec.setParameter(payload);
}
+ // Transforms like Combine may have Coders that need to be added but do not
+ // occur in a black-box traversal
+ transform.getTransform().registerComponents(components);
+
return transformSpec.build();
}
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index d7b0e9f..5765c51 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -484,7 +484,7 @@
});
}
- private static SideInput toProto(PCollectionView<?> view) {
+ public static SideInput toProto(PCollectionView<?> view) {
Builder builder = SideInput.newBuilder();
builder.setAccessPattern(
FunctionSpec.newBuilder()
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
new file mode 100644
index 0000000..9e4839a
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.base.MoreObjects;
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.protobuf.Any;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PCollectionViews;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+/** Utilities for going to/from Runner API pipelines. */
+public class PipelineTranslation {
+
+  /**
+   * Translates a {@link Pipeline} into a Runner API {@link RunnerApi.Pipeline} proto.
+   *
+   * <p>Every transform is registered with a fresh {@link SdkComponents}. A composite is
+   * registered when the traversal leaves it, after all of its children, so that their ids can be
+   * referenced. The transforms directly enclosed by the root node become the proto's root
+   * transforms, in traversal order.
+   *
+   * @param pipeline the pipeline to translate
+   * @return the translated pipeline proto
+   * @throws IllegalStateException if a transform cannot be registered
+   */
+  public static RunnerApi.Pipeline toProto(final Pipeline pipeline) {
+    final SdkComponents components = SdkComponents.create();
+    // Insertion-ordered, so root ids keep the order in which the traversal produced them.
+    final Collection<String> rootIds = new LinkedHashSet<>();
+    pipeline.traverseTopologically(
+        new PipelineVisitor.Defaults() {
+          // Children of each enclosing node, accumulated in visit order.
+          private final ListMultimap<Node, AppliedPTransform<?, ?, ?>> children =
+              ArrayListMultimap.create();
+
+          @Override
+          public void leaveCompositeTransform(Node node) {
+            if (node.isRootNode()) {
+              for (AppliedPTransform<?, ?, ?> pipelineRoot : children.get(node)) {
+                rootIds.add(components.getExistingPTransformId(pipelineRoot));
+              }
+            } else {
+              AppliedPTransform<?, ?, ?> applied = node.toAppliedPTransform(pipeline);
+              children.put(node.getEnclosingNode(), applied);
+              registerTransform(applied, children.get(node));
+            }
+          }
+
+          @Override
+          public void visitPrimitiveTransform(Node node) {
+            AppliedPTransform<?, ?, ?> applied = node.toAppliedPTransform(pipeline);
+            children.put(node.getEnclosingNode(), applied);
+            registerTransform(applied, Collections.<AppliedPTransform<?, ?, ?>>emptyList());
+          }
+
+          // TODO: Include DisplayData in the proto
+          private void registerTransform(
+              AppliedPTransform<?, ?, ?> applied, List<AppliedPTransform<?, ?, ?>> subtransforms) {
+            try {
+              components.registerPTransform(applied, subtransforms);
+            } catch (IOException e) {
+              throw new IllegalStateException(e);
+            }
+          }
+        });
+    return RunnerApi.Pipeline.newBuilder()
+        .setComponents(components.toComponents())
+        .addAllRootTransformIds(rootIds)
+        .build();
+  }
+
+  /**
+   * Rehydrates a {@link Pipeline} from a Runner API {@link RunnerApi.Pipeline} proto.
+   *
+   * <p>Each transform is rebuilt as a placeholder {@code RehydratedPTransform} that carries the
+   * transform's URN and payload but cannot be expanded again. The pipeline is created with
+   * {@link PipelineOptionsFactory#create() default options}.
+   *
+   * @param pipelineProto the proto to rehydrate
+   * @return the rehydrated pipeline
+   * @throws IOException if a component referenced by the proto cannot be rehydrated
+   */
+  public static Pipeline fromProto(final RunnerApi.Pipeline pipelineProto) throws IOException {
+    TransformHierarchy transforms = new TransformHierarchy();
+    Pipeline pipeline = Pipeline.forTransformHierarchy(transforms, PipelineOptionsFactory.create());
+
+    // Keeping the PCollections straight is a semantic necessity, but being careful not to explode
+    // the number of coders and windowing strategies is also nice, and helps testing.
+    RehydratedComponents rehydratedComponents =
+        RehydratedComponents.forComponents(pipelineProto.getComponents()).withPipeline(pipeline);
+
+    for (String rootId : pipelineProto.getRootTransformIdsList()) {
+      addRehydratedTransform(
+          transforms,
+          pipelineProto.getComponents().getTransformsOrThrow(rootId),
+          pipeline,
+          pipelineProto.getComponents().getTransformsMap(),
+          rehydratedComponents);
+    }
+
+    return pipeline;
+  }
+
+  /**
+   * Adds {@code transformProto}, and recursively its subtransforms, to {@code transforms}.
+   *
+   * <p>Primitive transforms (see {@link #isPrimitive}) become finalized leaf nodes. Any other
+   * transform is pushed as a finalized composite node, its subtransforms are added beneath it,
+   * and the node is then popped.
+   */
+  private static void addRehydratedTransform(
+      TransformHierarchy transforms,
+      RunnerApi.PTransform transformProto,
+      Pipeline pipeline,
+      Map<String, RunnerApi.PTransform> transformProtos,
+      RehydratedComponents rehydratedComponents)
+      throws IOException {
+
+    Map<TupleTag<?>, PValue> rehydratedInputs = new HashMap<>();
+    for (Map.Entry<String, String> inputEntry : transformProto.getInputsMap().entrySet()) {
+      rehydratedInputs.put(
+          new TupleTag<>(inputEntry.getKey()),
+          rehydratedComponents.getPCollection(inputEntry.getValue()));
+    }
+
+    Map<TupleTag<?>, PValue> rehydratedOutputs = new HashMap<>();
+    for (Map.Entry<String, String> outputEntry : transformProto.getOutputsMap().entrySet()) {
+      rehydratedOutputs.put(
+          new TupleTag<>(outputEntry.getKey()),
+          rehydratedComponents.getPCollection(outputEntry.getValue()));
+    }
+
+    RunnerApi.FunctionSpec transformSpec = transformProto.getSpec();
+
+    // By default, no "additional" inputs, since that is an SDK-specific thing.
+    // Only ParDo really separates main from side inputs
+    Map<TupleTag<?>, PValue> additionalInputs = Collections.emptyMap();
+
+    // TODO: ParDoTranslator should own it - https://issues.apache.org/jira/browse/BEAM-2674
+    if (transformSpec.getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN)) {
+      RunnerApi.ParDoPayload payload =
+          transformSpec.getParameter().unpack(RunnerApi.ParDoPayload.class);
+
+      List<PCollectionView<?>> views = new ArrayList<>();
+      for (Map.Entry<String, RunnerApi.SideInput> sideInputEntry :
+          payload.getSideInputsMap().entrySet()) {
+        String localName = sideInputEntry.getKey();
+        RunnerApi.SideInput sideInput = sideInputEntry.getValue();
+        // Each side input's PCollection is expected among the inputs under its local name.
+        PCollection<?> pCollection =
+            (PCollection<?>)
+                checkNotNull(
+                    rehydratedInputs.get(new TupleTag<>(localName)),
+                    "Side input %s of transform %s is missing from its inputs",
+                    localName,
+                    transformProto.getUniqueName());
+        views.add(
+            ParDoTranslation.viewFromProto(
+                sideInput, localName, pCollection, transformProto, rehydratedComponents));
+      }
+      additionalInputs = PCollectionViews.toAdditionalInputs(views);
+    }
+
+    // TODO: CombineTranslator should own it - https://issues.apache.org/jira/browse/BEAM-2674
+    List<Coder<?>> additionalCoders = Collections.emptyList();
+    if (transformSpec.getUrn().equals(PTransformTranslation.COMBINE_TRANSFORM_URN)) {
+      RunnerApi.CombinePayload payload =
+          transformSpec.getParameter().unpack(RunnerApi.CombinePayload.class);
+      // The accumulator coder would not be found by a black-box traversal, so it is carried on
+      // the rehydrated transform, whose registerComponents registers it again.
+      additionalCoders =
+          Collections.<Coder<?>>singletonList(
+              rehydratedComponents.getCoder(payload.getAccumulatorCoderId()));
+    }
+
+    RehydratedPTransform transform =
+        RehydratedPTransform.of(
+            transformSpec.getUrn(),
+            transformSpec.getParameter(),
+            additionalInputs,
+            additionalCoders);
+
+    if (isPrimitive(transformProto)) {
+      transforms.addFinalizedPrimitiveNode(
+          transformProto.getUniqueName(), rehydratedInputs, transform, rehydratedOutputs);
+    } else {
+      transforms.pushFinalizedNode(
+          transformProto.getUniqueName(), rehydratedInputs, transform, rehydratedOutputs);
+
+      for (String childTransformId : transformProto.getSubtransformsList()) {
+        addRehydratedTransform(
+            transforms,
+            checkNotNull(
+                transformProtos.get(childTransformId),
+                "Subtransform %s of transform %s is not present in the pipeline components",
+                childTransformId,
+                transformProto.getUniqueName()),
+            pipeline,
+            transformProtos,
+            rehydratedComponents);
+      }
+
+      transforms.popNode();
+    }
+  }
+
+  // A primitive transform is one with outputs that are not in its input and also
+  // not produced by a subtransform.
+  private static boolean isPrimitive(RunnerApi.PTransform transformProto) {
+    return transformProto.getSubtransformsCount() == 0
+        && !transformProto
+            .getInputsMap()
+            .values()
+            .containsAll(transformProto.getOutputsMap().values());
+  }
+
+  /**
+   * A placeholder for a transform rehydrated from its proto. It exposes the URN, payload,
+   * additional inputs and extra coders recovered from the proto, and refuses to be expanded.
+   */
+  @AutoValue
+  abstract static class RehydratedPTransform extends RawPTransform<PInput, POutput> {
+
+    @Nullable
+    public abstract String getUrn();
+
+    @Nullable
+    public abstract Any getPayload();
+
+    @Override
+    public abstract Map<TupleTag<?>, PValue> getAdditionalInputs();
+
+    /** Extra coders to register when this transform is translated back to a proto. */
+    public abstract List<Coder<?>> getCoders();
+
+    public static RehydratedPTransform of(
+        String urn,
+        Any payload,
+        Map<TupleTag<?>, PValue> additionalInputs,
+        List<Coder<?>> additionalCoders) {
+      return new AutoValue_PipelineTranslation_RehydratedPTransform(
+          urn, payload, additionalInputs, additionalCoders);
+    }
+
+    @Override
+    public POutput expand(PInput input) {
+      throw new IllegalStateException(
+          String.format(
+              "%s should never be asked to expand;"
+                  + " it is the result of deserializing an already-constructed Pipeline",
+              getClass().getSimpleName()));
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("urn", getUrn())
+          .add("payload", getPayload())
+          .toString();
+    }
+
+    @Override
+    public void registerComponents(SdkComponents components) {
+      for (Coder<?> coder : getCoders()) {
+        try {
+          components.registerCoder(coder);
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+      }
+    }
+  }
+}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
index f43d23b..62b6d0a 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
@@ -18,7 +18,9 @@
package org.apache.beam.runners.core.construction;
+import com.google.common.collect.Iterables;
import java.util.Map;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.Create;
@@ -36,15 +38,17 @@
*/
public class PrimitiveCreate<T> extends PTransform<PBegin, PCollection<T>> {
private final Create.Values<T> transform;
+ private final Coder<T> coder;
- private PrimitiveCreate(Create.Values<T> transform) {
+ private PrimitiveCreate(Create.Values<T> transform, Coder<T> coder) {
this.transform = transform;
+ this.coder = coder;
}
@Override
public PCollection<T> expand(PBegin input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.BOUNDED, coder);
}
public Iterable<T> getElements() {
@@ -60,7 +64,11 @@
public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
AppliedPTransform<PBegin, PCollection<T>, Values<T>> transform) {
return PTransformReplacement.of(
- transform.getPipeline().begin(), new PrimitiveCreate<T>(transform.getTransform()));
+ transform.getPipeline().begin(),
+ new PrimitiveCreate<T>(
+ transform.getTransform(),
+ ((PCollection<T>) Iterables.getOnlyElement(transform.getOutputs().values()))
+ .getCoder()));
}
@Override
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
index a9a34d7..ccdd4a7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/RehydratedComponents.java
@@ -93,7 +93,8 @@
PCollection.class.getSimpleName(),
Pipeline.class.getSimpleName());
return PCollectionTranslation.fromProto(
- components.getPcollectionsOrThrow(id), pipeline, RehydratedComponents.this);
+ components.getPcollectionsOrThrow(id), pipeline, RehydratedComponents.this)
+ .setName(id);
}
});
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
index 0d3ba60..54d2e9d 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SdkComponents.java
@@ -22,24 +22,16 @@
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.base.Equivalence;
-import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
-import com.google.common.collect.ListMultimap;
import java.io.IOException;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
import java.util.List;
import java.util.Set;
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.NameUtils;
import org.apache.beam.sdk.values.PCollection;
@@ -62,50 +54,6 @@
return new SdkComponents();
}
- public static RunnerApi.Pipeline translatePipeline(Pipeline pipeline) {
- final SdkComponents components = create();
- final Collection<String> rootIds = new HashSet<>();
- pipeline.traverseTopologically(
- new PipelineVisitor.Defaults() {
- private final ListMultimap<Node, AppliedPTransform<?, ?, ?>> children =
- ArrayListMultimap.create();
-
- @Override
- public void leaveCompositeTransform(Node node) {
- if (node.isRootNode()) {
- for (AppliedPTransform<?, ?, ?> pipelineRoot : children.get(node)) {
- rootIds.add(components.getExistingPTransformId(pipelineRoot));
- }
- } else {
- children.put(node.getEnclosingNode(), node.toAppliedPTransform(getPipeline()));
- try {
- components.registerPTransform(
- node.toAppliedPTransform(getPipeline()), children.get(node));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
- }
-
- @Override
- public void visitPrimitiveTransform(Node node) {
- children.put(node.getEnclosingNode(), node.toAppliedPTransform(getPipeline()));
- try {
- components.registerPTransform(
- node.toAppliedPTransform(getPipeline()),
- Collections.<AppliedPTransform<?, ?, ?>>emptyList());
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- }
- });
- // TODO: Display Data
- return RunnerApi.Pipeline.newBuilder()
- .setComponents(components.toComponents())
- .addAllRootTransformIds(rootIds)
- .build();
- }
-
private SdkComponents() {
this.componentsBuilder = RunnerApi.Components.newBuilder();
this.transformIds = HashBiMap.create();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SerializablePipelineOptions.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SerializablePipelineOptions.java
new file mode 100644
index 0000000..e697fb2
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SerializablePipelineOptions.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.Serializable;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
+
+/**
+ * Holds a {@link PipelineOptions} in JSON serialized form and calls {@link
+ * FileSystems#setDefaultPipelineOptions(PipelineOptions)} on construction or on deserialization.
+ */
+public class SerializablePipelineOptions implements Serializable {
+  // Declared explicitly so that the serialized form does not depend on the compiler-specific
+  // default serialVersionUID computation.
+  private static final long serialVersionUID = 1L;
+
+  private static final ObjectMapper MAPPER =
+      new ObjectMapper()
+          .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
+
+  // The JSON form of the options; this is the only instance state that is Java-serialized.
+  private final String serializedPipelineOptions;
+  // Not serialized; rebuilt from serializedPipelineOptions in readObject.
+  private transient PipelineOptions options;
+
+  /**
+   * Captures {@code options} as JSON and installs them via {@link
+   * FileSystems#setDefaultPipelineOptions(PipelineOptions)}.
+   *
+   * @param options the options to hold
+   * @throws IllegalArgumentException if {@code options} cannot be serialized to JSON
+   */
+  public SerializablePipelineOptions(PipelineOptions options) {
+    this.serializedPipelineOptions = serializeToJson(options);
+    this.options = options;
+    FileSystems.setDefaultPipelineOptions(options);
+  }
+
+  /**
+   * Returns the held options: the instance given to the constructor or, on a deserialized copy,
+   * the options decoded from the stored JSON.
+   */
+  public PipelineOptions get() {
+    return options;
+  }
+
+  // Java serialization hook: restores the transient options from their JSON form.
+  private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException {
+    is.defaultReadObject();
+    this.options = deserializeFromJson(serializedPipelineOptions);
+    // TODO https://issues.apache.org/jira/browse/BEAM-2712: remove this call.
+    FileSystems.setDefaultPipelineOptions(options);
+  }
+
+  private static String serializeToJson(PipelineOptions options) {
+    try {
+      return MAPPER.writeValueAsString(options);
+    } catch (JsonProcessingException e) {
+      throw new IllegalArgumentException("Failed to serialize PipelineOptions", e);
+    }
+  }
+
+  private static PipelineOptions deserializeFromJson(String options) {
+    try {
+      return MAPPER.readValue(options, PipelineOptions.class);
+    } catch (IOException e) {
+      throw new IllegalArgumentException("Failed to deserialize PipelineOptions", e);
+    }
+  }
+}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
index e71187b..32d3409 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
@@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkArgument;
+import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.List;
import java.util.Map;
@@ -73,6 +74,7 @@
private final List<PCollectionView<?>> sideInputs;
private final TupleTag<OutputT> mainOutputTag;
private final TupleTagList additionalOutputTags;
+ private final Map<TupleTag<?>, Coder<?>> outputTagsToCoders;
public static final String SPLITTABLE_PROCESS_URN =
"urn:beam:runners_core:transforms:splittable_process:v1";
@@ -85,34 +87,18 @@
private SplittableParDo(
DoFn<InputT, OutputT> doFn,
- TupleTag<OutputT> mainOutputTag,
List<PCollectionView<?>> sideInputs,
- TupleTagList additionalOutputTags) {
+ TupleTag<OutputT> mainOutputTag,
+ TupleTagList additionalOutputTags,
+ Map<TupleTag<?>, Coder<?>> outputTagsToCoders) {
checkArgument(
DoFnSignatures.getSignature(doFn.getClass()).processElement().isSplittable(),
"fn must be a splittable DoFn");
this.doFn = doFn;
- this.mainOutputTag = mainOutputTag;
this.sideInputs = sideInputs;
+ this.mainOutputTag = mainOutputTag;
this.additionalOutputTags = additionalOutputTags;
- }
-
- /**
- * Creates a {@link SplittableParDo} from an original Java {@link ParDo}.
- *
- * @param parDo The splittable {@link ParDo} transform.
- */
- public static <InputT, OutputT> SplittableParDo<InputT, OutputT, ?> forJavaParDo(
- ParDo.MultiOutput<InputT, OutputT> parDo) {
- checkArgument(parDo != null, "parDo must not be null");
- checkArgument(
- DoFnSignatures.getSignature(parDo.getFn().getClass()).processElement().isSplittable(),
- "fn must be a splittable DoFn");
- return new SplittableParDo(
- parDo.getFn(),
- parDo.getMainOutputTag(),
- parDo.getSideInputs(),
- parDo.getAdditionalOutputTags());
+ this.outputTagsToCoders = outputTagsToCoders;
}
/**
@@ -121,15 +107,22 @@
* <p>The input may generally be a deserialized transform so it may not actually be a {@link
* ParDo}. Instead {@link ParDoTranslation} will be used to extract fields.
*/
- public static SplittableParDo<?, ?, ?> forAppliedParDo(AppliedPTransform<?, ?, ?> parDo) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public static <InputT, OutputT> SplittableParDo<InputT, OutputT, ?> forAppliedParDo(
+ AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> parDo) {
checkArgument(parDo != null, "parDo must not be null");
try {
- return new SplittableParDo<>(
+ Map<TupleTag<?>, Coder<?>> outputTagsToCoders = Maps.newHashMap();
+ for (Map.Entry<TupleTag<?>, PValue> entry : parDo.getOutputs().entrySet()) {
+ outputTagsToCoders.put(entry.getKey(), ((PCollection) entry.getValue()).getCoder());
+ }
+ return new SplittableParDo(
ParDoTranslation.getDoFn(parDo),
- (TupleTag) ParDoTranslation.getMainOutputTag(parDo),
ParDoTranslation.getSideInputs(parDo),
- ParDoTranslation.getAdditionalOutputTags(parDo));
+ ParDoTranslation.getMainOutputTag(parDo),
+ ParDoTranslation.getAdditionalOutputTags(parDo),
+ outputTagsToCoders);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
@@ -168,7 +161,8 @@
(WindowingStrategy<InputT, ?>) input.getWindowingStrategy(),
sideInputs,
mainOutputTag,
- additionalOutputTags));
+ additionalOutputTags,
+ outputTagsToCoders));
}
@Override
@@ -202,6 +196,7 @@
private final List<PCollectionView<?>> sideInputs;
private final TupleTag<OutputT> mainOutputTag;
private final TupleTagList additionalOutputTags;
+ private final Map<TupleTag<?>, Coder<?>> outputTagsToCoders;
/**
* @param fn the splittable {@link DoFn}.
@@ -209,7 +204,9 @@
      * @param sideInputs list of side inputs that should be available to the {@link DoFn}.
      * @param mainOutputTag {@link TupleTag Tag} of the {@link DoFn DoFn's} main output.
      * @param additionalOutputTags {@link TupleTagList Tags} of the {@link DoFn DoFn's} additional
      *     outputs.
+     * @param outputTagsToCoders A map from output tag to the coder for that output, which should
+     *     provide mappings for the main and all additional tags.
      */
     public ProcessKeyedElements(
         DoFn<InputT, OutputT> fn,
@@ -218,7 +214,8 @@
WindowingStrategy<InputT, ?> windowingStrategy,
List<PCollectionView<?>> sideInputs,
TupleTag<OutputT> mainOutputTag,
- TupleTagList additionalOutputTags) {
+ TupleTagList additionalOutputTags,
+ Map<TupleTag<?>, Coder<?>> outputTagsToCoders) {
this.fn = fn;
this.elementCoder = elementCoder;
this.restrictionCoder = restrictionCoder;
@@ -226,6 +223,7 @@
this.sideInputs = sideInputs;
this.mainOutputTag = mainOutputTag;
this.additionalOutputTags = additionalOutputTags;
+ this.outputTagsToCoders = outputTagsToCoders;
}
public DoFn<InputT, OutputT> getFn() {
@@ -256,10 +254,14 @@
return additionalOutputTags;
}
+ public Map<TupleTag<?>, Coder<?>> getOutputTagsToCoders() {
+ return outputTagsToCoders;
+ }
+
@Override
public PCollectionTuple expand(PCollection<KV<String, KV<InputT, RestrictionT>>> input) {
return createPrimitiveOutputFor(
- input, fn, mainOutputTag, additionalOutputTags, windowingStrategy);
+ input, fn, mainOutputTag, additionalOutputTags, outputTagsToCoders, windowingStrategy);
}
public static <OutputT> PCollectionTuple createPrimitiveOutputFor(
@@ -267,12 +269,14 @@
DoFn<?, OutputT> fn,
TupleTag<OutputT> mainOutputTag,
TupleTagList additionalOutputTags,
+ Map<TupleTag<?>, Coder<?>> outputTagsToCoders,
WindowingStrategy<?, ?> windowingStrategy) {
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
PCollectionTuple outputs =
PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(mainOutputTag).and(additionalOutputTags.getAll()),
+ outputTagsToCoders,
windowingStrategy,
input.isBounded().and(signature.isBoundedPerElement()));
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
index b1d2da4..7954b0e 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/WriteFilesTranslation.java
@@ -19,29 +19,35 @@
package org.apache.beam.runners.core.construction;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
import com.google.auto.service.AutoService;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.FunctionSpec;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.SdkFunctionSpec;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.SideInput;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.WriteFilesPayload;
import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.io.WriteFiles;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TupleTag;
/**
* Utility methods for translating a {@link WriteFiles} to and from {@link RunnerApi}
@@ -53,28 +59,25 @@
public static final String CUSTOM_JAVA_FILE_BASED_SINK_URN =
"urn:beam:file_based_sink:javasdk:0.1";
- public static final String CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN =
- "urn:beam:file_based_sink_format_function:javasdk:0.1";
-
@VisibleForTesting
static WriteFilesPayload toProto(WriteFiles<?, ?, ?> transform) {
+ Map<String, SideInput> sideInputs = Maps.newHashMap();
+ for (PCollectionView<?> view : transform.getSink().getDynamicDestinations().getSideInputs()) {
+ sideInputs.put(view.getTagInternal().getId(), ParDoTranslation.toProto(view));
+ }
return WriteFilesPayload.newBuilder()
.setSink(toProto(transform.getSink()))
- .setFormatFunction(toProto(transform.getFormatFunction()))
.setWindowedWrites(transform.isWindowedWrites())
.setRunnerDeterminedSharding(
transform.getNumShards() == null && transform.getSharding() == null)
+ .putAllSideInputs(sideInputs)
.build();
}
- private static SdkFunctionSpec toProto(FileBasedSink<?, ?> sink) {
+ private static SdkFunctionSpec toProto(FileBasedSink<?, ?, ?> sink) {
return toProto(CUSTOM_JAVA_FILE_BASED_SINK_URN, sink);
}
- private static SdkFunctionSpec toProto(SerializableFunction<?, ?> serializableFunction) {
- return toProto(CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN, serializableFunction);
- }
-
private static SdkFunctionSpec toProto(String urn, Serializable serializable) {
return SdkFunctionSpec.newBuilder()
.setSpec(
@@ -91,7 +94,7 @@
}
@VisibleForTesting
- static FileBasedSink<?, ?> sinkFromProto(SdkFunctionSpec sinkProto) throws IOException {
+ static FileBasedSink<?, ?, ?> sinkFromProto(SdkFunctionSpec sinkProto) throws IOException {
checkArgument(
sinkProto.getSpec().getUrn().equals(CUSTOM_JAVA_FILE_BASED_SINK_URN),
"Cannot extract %s instance from %s with URN %s",
@@ -102,44 +105,44 @@
byte[] serializedSink =
sinkProto.getSpec().getParameter().unpack(BytesValue.class).getValue().toByteArray();
- return (FileBasedSink<?, ?>)
+ return (FileBasedSink<?, ?, ?>)
SerializableUtils.deserializeFromByteArray(
serializedSink, FileBasedSink.class.getSimpleName());
}
- @VisibleForTesting
- static <InputT, OutputT> SerializableFunction<InputT, OutputT> formatFunctionFromProto(
- SdkFunctionSpec sinkProto) throws IOException {
- checkArgument(
- sinkProto.getSpec().getUrn().equals(CUSTOM_JAVA_FILE_BASED_SINK_FORMAT_FUNCTION_URN),
- "Cannot extract %s instance from %s with URN %s",
- SerializableFunction.class.getSimpleName(),
- FunctionSpec.class.getSimpleName(),
- sinkProto.getSpec().getUrn());
-
- byte[] serializedFunction =
- sinkProto.getSpec().getParameter().unpack(BytesValue.class).getValue().toByteArray();
-
- return (SerializableFunction<InputT, OutputT>)
- SerializableUtils.deserializeFromByteArray(
- serializedFunction, FileBasedSink.class.getSimpleName());
- }
-
- public static <UserT, DestinationT, OutputT> FileBasedSink<OutputT, DestinationT> getSink(
+ public static <UserT, DestinationT, OutputT> FileBasedSink<UserT, DestinationT, OutputT> getSink(
AppliedPTransform<PCollection<UserT>, PDone, ? extends PTransform<PCollection<UserT>, PDone>>
transform)
throws IOException {
- return (FileBasedSink<OutputT, DestinationT>)
+ return (FileBasedSink<UserT, DestinationT, OutputT>)
sinkFromProto(getWriteFilesPayload(transform).getSink());
}
- public static <InputT, OutputT> SerializableFunction<InputT, OutputT> getFormatFunction(
- AppliedPTransform<
- PCollection<InputT>, PDone, ? extends PTransform<PCollection<InputT>, PDone>>
- transform)
- throws IOException {
- return formatFunctionFromProto(
- getWriteFilesPayload(transform).<InputT, OutputT>getFormatFunction());
+ public static <UserT, DestinationT, OutputT>
+ List<PCollectionView<?>> getDynamicDestinationSideInputs(
+ AppliedPTransform<
+ PCollection<UserT>, PDone, ? extends PTransform<PCollection<UserT>, PDone>>
+ transform)
+ throws IOException {
+ SdkComponents sdkComponents = SdkComponents.create();
+ RunnerApi.PTransform transformProto = PTransformTranslation.toProto(transform, sdkComponents);
+ List<PCollectionView<?>> views = Lists.newArrayList();
+ Map<String, SideInput> sideInputs = getWriteFilesPayload(transform).getSideInputsMap();
+ for (Map.Entry<String, SideInput> entry : sideInputs.entrySet()) {
+ PCollection<?> originalPCollection =
+ checkNotNull(
+ (PCollection<?>) transform.getInputs().get(new TupleTag<>(entry.getKey())),
+ "no input with tag %s",
+ entry.getKey());
+ views.add(
+ ParDoTranslation.viewFromProto(
+ entry.getValue(),
+ entry.getKey(),
+ originalPCollection,
+ transformProto,
+ RehydratedComponents.forComponents(sdkComponents.toComponents())));
+ }
+ return views;
}
public static <T> boolean isWindowedWrites(
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ForwardingPTransformTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ForwardingPTransformTest.java
index 74c056c..4741b6b 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ForwardingPTransformTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ForwardingPTransformTest.java
@@ -26,6 +26,7 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -90,14 +91,23 @@
@Test
public void getDefaultOutputCoderDelegates() throws Exception {
- @SuppressWarnings("unchecked")
- PCollection<Integer> input = Mockito.mock(PCollection.class);
+ PCollection<Integer> input =
+ PCollection.createPrimitiveOutputInternal(
+ null /* pipeline */,
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.BOUNDED,
+ null /* coder */);
- @SuppressWarnings("unchecked")
- PCollection<String> output = Mockito.mock(PCollection.class);
+ PCollection<String> output =
+ PCollection.createPrimitiveOutputInternal(
+ null /* pipeline */,
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.BOUNDED,
+ null /* coder */);
@SuppressWarnings("unchecked")
Coder<String> outputCoder = Mockito.mock(Coder.class);
+ Mockito.when(delegate.expand(input)).thenReturn(output);
Mockito.when(delegate.getDefaultOutputCoder(input, output)).thenReturn(outputCoder);
- assertThat(forwarding.getDefaultOutputCoder(input, output), equalTo(outputCoder));
+ assertThat(forwarding.expand(input).getCoder(), equalTo(outputCoder));
}
@Test
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
index 99d3dd1..fa7e1e9 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
@@ -27,6 +27,7 @@
import com.google.common.collect.ImmutableMap;
import java.io.Serializable;
import java.util.Collections;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
@@ -56,13 +57,14 @@
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
@@ -98,15 +100,16 @@
private AppliedPTransform<?, ?, ?> getAppliedTransform(PTransform pardo) {
PCollection<KV<String, Integer>> input =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p,
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
input.setName("dummy input");
- input.setCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
PCollection<Integer> output =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
output.setName("dummy output");
- output.setCoder(VarIntCoder.of());
return AppliedPTransform.of("pardo", input.expand(), output.expand(), pardo, p);
}
@@ -131,7 +134,7 @@
@Override
public PCollection<Integer> expand(PCollection<KV<String, Integer>> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), VarIntCoder.of());
}
}
PTransformMatcher matcher = PTransformMatchers.classEqualTo(MyPTransform.class);
@@ -423,14 +426,14 @@
public void emptyFlattenWithEmptyFlatten() {
AppliedPTransform application =
AppliedPTransform
- .<PCollectionList<Object>, PCollection<Object>, Flatten.PCollections<Object>>of(
+ .<PCollectionList<Integer>, PCollection<Integer>, Flatten.PCollections<Integer>>of(
"EmptyFlatten",
Collections.<TupleTag<?>, PValue>emptyMap(),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.pCollections(),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ Flatten.<Integer>pCollections(),
p);
assertThat(PTransformMatchers.emptyFlatten().matches(application), is(true));
@@ -440,17 +443,17 @@
public void emptyFlattenWithNonEmptyFlatten() {
AppliedPTransform application =
AppliedPTransform
- .<PCollectionList<Object>, PCollection<Object>, Flatten.PCollections<Object>>of(
+ .<PCollectionList<Integer>, PCollection<Integer>, Flatten.PCollections<Integer>>of(
"Flatten",
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.pCollections(),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ Flatten.<Integer>pCollections(),
p);
assertThat(PTransformMatchers.emptyFlatten().matches(application), is(false));
@@ -460,15 +463,15 @@
public void emptyFlattenWithNonFlatten() {
AppliedPTransform application =
AppliedPTransform
- .<PCollection<Iterable<Object>>, PCollection<Object>, Flatten.Iterables<Object>>of(
+ .<PCollection<Iterable<Integer>>, PCollection<Integer>, Flatten.Iterables<Integer>>of(
"EmptyFlatten",
Collections.<TupleTag<?>, PValue>emptyMap(),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.iterables() /* This isn't actually possible to construct,
- * but for the sake of example */,
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ /* This isn't actually possible to construct, but for the sake of example */
+ Flatten.<Integer>iterables(),
p);
assertThat(PTransformMatchers.emptyFlatten().matches(application), is(false));
@@ -478,17 +481,17 @@
public void flattenWithDuplicateInputsWithoutDuplicates() {
AppliedPTransform application =
AppliedPTransform
- .<PCollectionList<Object>, PCollection<Object>, Flatten.PCollections<Object>>of(
+ .<PCollectionList<Integer>, PCollection<Integer>, Flatten.PCollections<Integer>>of(
"Flatten",
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.pCollections(),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ Flatten.<Integer>pCollections(),
p);
assertThat(PTransformMatchers.flattenWithDuplicateInputs().matches(application), is(false));
@@ -496,22 +499,22 @@
@Test
public void flattenWithDuplicateInputsWithDuplicates() {
- PCollection<Object> duplicate =
+ PCollection<Integer> duplicate =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
AppliedPTransform application =
AppliedPTransform
- .<PCollectionList<Object>, PCollection<Object>, Flatten.PCollections<Object>>of(
+ .<PCollectionList<Integer>, PCollection<Integer>, Flatten.PCollections<Integer>>of(
"Flatten",
ImmutableMap.<TupleTag<?>, PValue>builder()
- .put(new TupleTag<Object>(), duplicate)
- .put(new TupleTag<Object>(), duplicate)
+ .put(new TupleTag<Integer>(), duplicate)
+ .put(new TupleTag<Integer>(), duplicate)
.build(),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.pCollections(),
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ Flatten.<Integer>pCollections(),
p);
assertThat(PTransformMatchers.flattenWithDuplicateInputs().matches(application), is(true));
@@ -521,15 +524,15 @@
public void flattenWithDuplicateInputsNonFlatten() {
AppliedPTransform application =
AppliedPTransform
- .<PCollection<Iterable<Object>>, PCollection<Object>, Flatten.Iterables<Object>>of(
+ .<PCollection<Iterable<Integer>>, PCollection<Integer>, Flatten.Iterables<Integer>>of(
"EmptyFlatten",
Collections.<TupleTag<?>, PValue>emptyMap(),
Collections.<TupleTag<?>, PValue>singletonMap(
- new TupleTag<Object>(),
+ new TupleTag<Integer>(),
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)),
- Flatten.iterables() /* This isn't actually possible to construct,
- * but for the sake of example */,
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of())),
+ /* This isn't actually possible to construct, but for the sake of example */
+ Flatten.<Integer>iterables(),
p);
assertThat(PTransformMatchers.flattenWithDuplicateInputs().matches(application), is(false));
@@ -546,14 +549,14 @@
false);
WriteFiles<Integer, Void, Integer> write =
WriteFiles.to(
- new FileBasedSink<Integer, Void>(
- StaticValueProvider.of(outputDirectory), DynamicFileDestinations.constant(null)) {
+ new FileBasedSink<Integer, Void, Integer>(
+ StaticValueProvider.of(outputDirectory),
+ DynamicFileDestinations.<Integer>constant(new FakeFilenamePolicy())) {
@Override
- public WriteOperation<Integer, Void> createWriteOperation() {
+ public WriteOperation<Void, Integer> createWriteOperation() {
return null;
}
- },
- SerializableFunctions.<Integer>identity());
+ });
assertThat(
PTransformMatchers.writeWithRunnerDeterminedSharding().matches(appliedWrite(write)),
is(true));
@@ -580,4 +583,23 @@
write,
p);
}
+
+ private static class FakeFilenamePolicy extends FilenamePolicy {
+ @Override
+ public ResourceId windowedFilename(
+ int shardNumber,
+ int numShards,
+ BoundedWindow window,
+ PaneInfo paneInfo,
+ FileBasedSink.OutputFileHints outputFileHints) {
+ throw new UnsupportedOperationException("should not be called");
+ }
+
+ @Nullable
+ @Override
+ public ResourceId unwindowedFilename(
+ int shardNumber, int numShards, FileBasedSink.OutputFileHints outputFileHints) {
+ throw new UnsupportedOperationException("should not be called");
+ }
+ }
}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
new file mode 100644
index 0000000..9e6dff4
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.base.Equivalence;
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.coders.BigEndianLongCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.transforms.windowing.AfterPane;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.joda.time.Duration;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/** Tests for {@link PipelineTranslation}. */
+@RunWith(Parameterized.class)
+public class PipelineTranslationTest {
+ @Parameter(0)
+ public Pipeline pipeline;
+
+ @Parameters(name = "{index}")
+ public static Iterable<Pipeline> testPipelines() {
+ Pipeline trivialPipeline = Pipeline.create();
+ trivialPipeline.apply(Create.of(1, 2, 3));
+
+ Pipeline sideInputPipeline = Pipeline.create();
+ final PCollectionView<String> singletonView =
+ sideInputPipeline.apply(Create.of("foo")).apply(View.<String>asSingleton());
+ sideInputPipeline
+ .apply(Create.of("main input"))
+ .apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @ProcessElement
+ public void process(ProcessContext c) {
+ // Never actually executed; has no effect on translation.
+ c.sideInput(singletonView);
+ }
+ })
+ .withSideInputs(singletonView));
+
+ Pipeline complexPipeline = Pipeline.create();
+ BigEndianLongCoder customCoder = BigEndianLongCoder.of();
+ PCollection<Long> elems = complexPipeline.apply(GenerateSequence.from(0L).to(207L));
+ PCollection<Long> counted = elems.apply(Count.<Long>globally()).setCoder(customCoder);
+ PCollection<Long> windowed =
+ counted.apply(
+ Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
+ .triggering(
+ AfterWatermark.pastEndOfWindow()
+ .withEarlyFirings(AfterPane.elementCountAtLeast(19)))
+ .accumulatingFiredPanes()
+ .withAllowedLateness(Duration.standardMinutes(3L)));
+ PCollection<KV<String, Long>> keyed = windowed.apply(WithKeys.<String, Long>of("foo"));
+ // Only applying the GroupByKey matters for translation; its output is not needed.
+ keyed.apply(GroupByKey.<String, Long>create());
+
+ return ImmutableList.of(trivialPipeline, sideInputPipeline, complexPipeline);
+ }
+
+ @Test
+ public void testProtoDirectly() {
+ final RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+ pipeline.traverseTopologically(
+ new PipelineProtoVerificationVisitor(pipelineProto));
+ }
+
+ @Test
+ public void testProtoAgainstRehydrated() throws Exception {
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+ Pipeline rehydrated = PipelineTranslation.fromProto(pipelineProto);
+
+ rehydrated.traverseTopologically(
+ new PipelineProtoVerificationVisitor(pipelineProto));
+ }
+
+ /**
+ * Traverses a {@link Pipeline}, collecting every transform node, {@link PCollection}, coder and
+ * windowing strategy it visits; on leaving the root node it asserts that each collected count
+ * equals the corresponding component count in the expected {@link RunnerApi.Pipeline} proto.
+ */
+ private static class PipelineProtoVerificationVisitor extends PipelineVisitor.Defaults {
+
+ private final RunnerApi.Pipeline pipelineProto;
+ private final Set<Node> transforms = new HashSet<>();
+ private final Set<PCollection<?>> pcollections = new HashSet<>();
+ private final Set<Equivalence.Wrapper<? extends Coder<?>>> coders = new HashSet<>();
+ private final Set<WindowingStrategy<?, ?>> windowingStrategies = new HashSet<>();
+
+ public PipelineProtoVerificationVisitor(RunnerApi.Pipeline pipelineProto) {
+ this.pipelineProto = pipelineProto;
+ }
+
+ @Override
+ public void leaveCompositeTransform(Node node) {
+ if (node.isRootNode()) {
+ assertThat(
+ "Unexpected number of PTransforms",
+ pipelineProto.getComponents().getTransformsCount(),
+ equalTo(transforms.size()));
+ assertThat(
+ "Unexpected number of PCollections",
+ pipelineProto.getComponents().getPcollectionsCount(),
+ equalTo(pcollections.size()));
+ assertThat(
+ "Unexpected number of Coders",
+ pipelineProto.getComponents().getCodersCount(),
+ equalTo(coders.size()));
+ assertThat(
+ "Unexpected number of Windowing Strategies",
+ pipelineProto.getComponents().getWindowingStrategiesCount(),
+ equalTo(windowingStrategies.size()));
+ } else {
+ transforms.add(node);
+ if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals(
+ PTransformTranslation.urnForTransformOrNull(node.getTransform()))) {
+ // Combine translation introduces a coder that is not assigned to any PCollection
+ // in the default expansion, and must be explicitly added here.
+ try {
+ addCoders(
+ CombineTranslation.getAccumulatorCoder(node.toAppliedPTransform(getPipeline())));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void visitPrimitiveTransform(Node node) {
+ transforms.add(node);
+ }
+
+ @Override
+ public void visitValue(PValue value, Node producer) {
+ if (value instanceof PCollection) {
+ PCollection<?> pc = (PCollection<?>) value;
+ pcollections.add(pc);
+ addCoders(pc.getCoder());
+ windowingStrategies.add(pc.getWindowingStrategy());
+ addCoders(pc.getWindowingStrategy().getWindowFn().windowCoder());
+ }
+ }
+
+ /** Records {@code coder} by identity, recursing into components of known structured coders. */
+ private void addCoders(Coder<?> coder) {
+ coders.add(Equivalence.<Coder<?>>identity().wrap(coder));
+ if (CoderTranslation.KNOWN_CODER_URNS.containsKey(coder.getClass())) {
+ for (Coder<?> component : ((StructuredCoder<?>) coder).getComponents()) {
+ addCoders(component);
+ }
+ }
+ }
+ }
+}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReadTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReadTranslationTest.java
index 740b324..f85bd79 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReadTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReadTranslationTest.java
@@ -112,7 +112,7 @@
public void validate() {}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
@@ -132,7 +132,7 @@
public void validate() {}
@Override
- public Coder<byte[]> getDefaultOutputCoder() {
+ public Coder<byte[]> getOutputCoder() {
return ByteArrayCoder.of();
}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReplacementOutputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReplacementOutputsTest.java
index f8d01e9..0165e4b 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReplacementOutputsTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ReplacementOutputsTest.java
@@ -24,6 +24,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Map;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
@@ -50,23 +52,23 @@
private PCollection<Integer> ints =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
private PCollection<Integer> moreInts =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
private PCollection<String> strs =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());
private PCollection<Integer> replacementInts =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
private PCollection<Integer> moreReplacementInts =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
private PCollection<String> replacementStrs =
PCollection.createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());
@Test
public void singletonSucceeds() {
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
index ce6a99f..82840d6 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SdkComponentsTest.java
@@ -24,43 +24,25 @@
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertThat;
-import com.google.common.base.Equivalence;
import java.io.IOException;
import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-import org.apache.beam.sdk.Pipeline.PipelineVisitor;
-import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
-import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.Components;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.WithKeys;
-import org.apache.beam.sdk.transforms.windowing.AfterPane;
-import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
-import org.apache.beam.sdk.transforms.windowing.FixedWindows;
-import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode;
import org.hamcrest.Matchers;
-import org.joda.time.Duration;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -78,95 +60,6 @@
private SdkComponents components = SdkComponents.create();
@Test
- public void translatePipeline() {
- BigEndianLongCoder customCoder = BigEndianLongCoder.of();
- PCollection<Long> elems = pipeline.apply(GenerateSequence.from(0L).to(207L));
- PCollection<Long> counted = elems.apply(Count.<Long>globally()).setCoder(customCoder);
- PCollection<Long> windowed =
- counted.apply(
- Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
- .triggering(
- AfterWatermark.pastEndOfWindow()
- .withEarlyFirings(AfterPane.elementCountAtLeast(19)))
- .accumulatingFiredPanes()
- .withAllowedLateness(Duration.standardMinutes(3L)));
- final WindowingStrategy<?, ?> windowedStrategy = windowed.getWindowingStrategy();
- PCollection<KV<String, Long>> keyed = windowed.apply(WithKeys.<String, Long>of("foo"));
- PCollection<KV<String, Iterable<Long>>> grouped =
- keyed.apply(GroupByKey.<String, Long>create());
-
- final RunnerApi.Pipeline pipelineProto = SdkComponents.translatePipeline(pipeline);
- pipeline.traverseTopologically(
- new PipelineVisitor.Defaults() {
- Set<Node> transforms = new HashSet<>();
- Set<PCollection<?>> pcollections = new HashSet<>();
- Set<Equivalence.Wrapper<? extends Coder<?>>> coders = new HashSet<>();
- Set<WindowingStrategy<?, ?>> windowingStrategies = new HashSet<>();
-
- @Override
- public void leaveCompositeTransform(Node node) {
- if (node.isRootNode()) {
- assertThat(
- "Unexpected number of PTransforms",
- pipelineProto.getComponents().getTransformsCount(),
- equalTo(transforms.size()));
- assertThat(
- "Unexpected number of PCollections",
- pipelineProto.getComponents().getPcollectionsCount(),
- equalTo(pcollections.size()));
- assertThat(
- "Unexpected number of Coders",
- pipelineProto.getComponents().getCodersCount(),
- equalTo(coders.size()));
- assertThat(
- "Unexpected number of Windowing Strategies",
- pipelineProto.getComponents().getWindowingStrategiesCount(),
- equalTo(windowingStrategies.size()));
- } else {
- transforms.add(node);
- if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals(
- PTransformTranslation.urnForTransformOrNull(node.getTransform()))) {
- // Combine translation introduces a coder that is not assigned to any PCollection
- // in the default expansion, and must be explicitly added here.
- try {
- addCoders(
- CombineTranslation.getAccumulatorCoder(
- node.toAppliedPTransform(getPipeline())));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
- }
- }
-
- @Override
- public void visitPrimitiveTransform(Node node) {
- transforms.add(node);
- }
-
- @Override
- public void visitValue(PValue value, Node producer) {
- if (value instanceof PCollection) {
- PCollection pc = (PCollection) value;
- pcollections.add(pc);
- addCoders(pc.getCoder());
- windowingStrategies.add(pc.getWindowingStrategy());
- addCoders(pc.getWindowingStrategy().getWindowFn().windowCoder());
- }
- }
-
- private void addCoders(Coder<?> coder) {
- coders.add(Equivalence.<Coder<?>>identity().wrap(coder));
- if (coder instanceof StructuredCoder) {
- for (Coder<?> component : ((StructuredCoder<?>) coder).getComponents()) {
- addCoders(component);
- }
- }
- }
- });
- }
-
- @Test
public void registerCoder() throws IOException {
Coder<?> coder =
KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(SetCoder.of(ByteArrayCoder.of())));
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SerializablePipelineOptionsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SerializablePipelineOptionsTest.java
new file mode 100644
index 0000000..cd470b2
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SerializablePipelineOptionsTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.junit.Assert.assertEquals;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link SerializablePipelineOptions}. */
+@RunWith(JUnit4.class)
+public class SerializablePipelineOptionsTest {
+ /** Options for testing. */
+ public interface MyOptions extends PipelineOptions {
+ String getFoo();
+
+ void setFoo(String foo);
+
+ @JsonIgnore
+ @Default.String("not overridden")
+ String getIgnoredField();
+
+ void setIgnoredField(String value);
+ }
+
+ @Test
+ public void testSerializationAndDeserialization() throws Exception {
+ PipelineOptions options =
+ PipelineOptionsFactory.fromArgs("--foo=testValue", "--ignoredField=overridden")
+ .as(MyOptions.class);
+
+ SerializablePipelineOptions serializableOptions = new SerializablePipelineOptions(options);
+ assertEquals("testValue", serializableOptions.get().as(MyOptions.class).getFoo());
+ assertEquals("overridden", serializableOptions.get().as(MyOptions.class).getIgnoredField());
+
+ SerializablePipelineOptions copy = SerializableUtils.clone(serializableOptions);
+ assertEquals("testValue", copy.get().as(MyOptions.class).getFoo());
+ assertEquals("not overridden", copy.get().as(MyOptions.class).getIgnoredField());
+ }
+
+ @Test
+ public void testIndependence() throws Exception {
+ SerializablePipelineOptions first =
+ new SerializablePipelineOptions(
+ PipelineOptionsFactory.fromArgs("--foo=first").as(MyOptions.class));
+ SerializablePipelineOptions firstCopy = SerializableUtils.clone(first);
+ SerializablePipelineOptions second =
+ new SerializablePipelineOptions(
+ PipelineOptionsFactory.fromArgs("--foo=second").as(MyOptions.class));
+ SerializablePipelineOptions secondCopy = SerializableUtils.clone(second);
+
+ assertEquals("first", first.get().as(MyOptions.class).getFoo());
+ assertEquals("first", firstCopy.get().as(MyOptions.class).getFoo());
+ assertEquals("second", second.get().as(MyOptions.class).getFoo());
+ assertEquals("second", secondCopy.get().as(MyOptions.class).getFoo());
+
+ first.get().as(MyOptions.class).setFoo("new first");
+ firstCopy.get().as(MyOptions.class).setFoo("new firstCopy");
+ second.get().as(MyOptions.class).setFoo("new second");
+ secondCopy.get().as(MyOptions.class).setFoo("new secondCopy");
+
+ assertEquals("new first", first.get().as(MyOptions.class).getFoo());
+ assertEquals("new firstCopy", firstCopy.get().as(MyOptions.class).getFoo());
+ assertEquals("new second", second.get().as(MyOptions.class).getFoo());
+ assertEquals("new secondCopy", secondCopy.get().as(MyOptions.class).getFoo());
+ }
+}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java
index 267232c..05c471d 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SplittableParDoTest.java
@@ -22,6 +22,7 @@
import java.io.Serializable;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
@@ -29,6 +30,7 @@
import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.junit.Rule;
@@ -106,12 +108,18 @@
private static final TupleTag<String> MAIN_OUTPUT_TAG = new TupleTag<String>() {};
- private ParDo.MultiOutput<Integer, String> makeParDo(DoFn<Integer, String> fn) {
- return ParDo.of(fn).withOutputTags(MAIN_OUTPUT_TAG, TupleTagList.empty());
+ private PCollection<String> applySplittableParDo(
+ String name, PCollection<Integer> input, DoFn<Integer, String> fn) {
+ ParDo.MultiOutput<Integer, String> multiOutput =
+ ParDo.of(fn).withOutputTags(MAIN_OUTPUT_TAG, TupleTagList.empty());
+ PCollectionTuple output = multiOutput.expand(input);
+ output.get(MAIN_OUTPUT_TAG).setName("main");
+ AppliedPTransform<PCollection<Integer>, PCollectionTuple, ?> transform =
+ AppliedPTransform.of("ParDo", input.expand(), output.expand(), multiOutput, pipeline);
+ return input.apply(name, SplittableParDo.forAppliedParDo(transform)).get(MAIN_OUTPUT_TAG);
}
- @Rule
- public TestPipeline pipeline = TestPipeline.create();
+ @Rule public TestPipeline pipeline = TestPipeline.create();
@Test
public void testBoundednessForBoundedFn() {
@@ -121,16 +129,12 @@
assertEquals(
"Applying a bounded SDF to a bounded collection produces a bounded collection",
PCollection.IsBounded.BOUNDED,
- makeBoundedCollection(pipeline)
- .apply("bounded to bounded", SplittableParDo.forJavaParDo(makeParDo(boundedFn)))
- .get(MAIN_OUTPUT_TAG)
+ applySplittableParDo("bounded to bounded", makeBoundedCollection(pipeline), boundedFn)
.isBounded());
assertEquals(
"Applying a bounded SDF to an unbounded collection produces an unbounded collection",
PCollection.IsBounded.UNBOUNDED,
- makeUnboundedCollection(pipeline)
- .apply("bounded to unbounded", SplittableParDo.forJavaParDo(makeParDo(boundedFn)))
- .get(MAIN_OUTPUT_TAG)
+ applySplittableParDo("bounded to unbounded", makeUnboundedCollection(pipeline), boundedFn)
.isBounded());
}
@@ -142,16 +146,13 @@
assertEquals(
"Applying an unbounded SDF to a bounded collection produces a bounded collection",
PCollection.IsBounded.UNBOUNDED,
- makeBoundedCollection(pipeline)
- .apply("unbounded to bounded", SplittableParDo.forJavaParDo(makeParDo(unboundedFn)))
- .get(MAIN_OUTPUT_TAG)
+ applySplittableParDo("unbounded to bounded", makeBoundedCollection(pipeline), unboundedFn)
.isBounded());
assertEquals(
"Applying an unbounded SDF to an unbounded collection produces an unbounded collection",
PCollection.IsBounded.UNBOUNDED,
- makeUnboundedCollection(pipeline)
- .apply("unbounded to unbounded", SplittableParDo.forJavaParDo(makeParDo(unboundedFn)))
- .get(MAIN_OUTPUT_TAG)
+ applySplittableParDo(
+ "unbounded to unbounded", makeUnboundedCollection(pipeline), unboundedFn)
.isBounded());
}
}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java
index 0e48a9d..62b06b7 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java
@@ -320,7 +320,7 @@
}
@Override
- public Coder<Byte> getDefaultOutputCoder() {
+ public Coder<Byte> getOutputCoder() {
return SerializableCoder.of(Byte.class);
}
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
index 4259ac8..e067fac 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/WriteFilesTranslationTest.java
@@ -38,7 +38,6 @@
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -63,12 +62,11 @@
public static class TestWriteFilesPayloadTranslation {
@Parameters(name = "{index}: {0}")
public static Iterable<WriteFiles<Object, Void, Object>> data() {
- SerializableFunction<Object, Object> format = SerializableFunctions.constant(null);
return ImmutableList.of(
- WriteFiles.to(new DummySink(), format),
- WriteFiles.to(new DummySink(), format).withWindowedWrites(),
- WriteFiles.to(new DummySink(), format).withNumShards(17),
- WriteFiles.to(new DummySink(), format).withWindowedWrites().withNumShards(42));
+ WriteFiles.to(new DummySink()),
+ WriteFiles.to(new DummySink()).withWindowedWrites(),
+ WriteFiles.to(new DummySink()).withNumShards(17),
+ WriteFiles.to(new DummySink()).withWindowedWrites().withNumShards(42));
}
@Parameter(0)
@@ -87,7 +85,8 @@
assertThat(payload.getWindowedWrites(), equalTo(writeFiles.isWindowedWrites()));
assertThat(
- (FileBasedSink<String, Void>) WriteFilesTranslation.sinkFromProto(payload.getSink()),
+ (FileBasedSink<String, Void, String>)
+ WriteFilesTranslation.sinkFromProto(payload.getSink()),
equalTo(writeFiles.getSink()));
}
@@ -118,16 +117,17 @@
* A simple {@link FileBasedSink} for testing serialization/deserialization. Not mocked to avoid
* any issues serializing mocks.
*/
- private static class DummySink extends FileBasedSink<Object, Void> {
+ private static class DummySink extends FileBasedSink<Object, Void, Object> {
DummySink() {
super(
StaticValueProvider.of(FileSystems.matchNewResource("nowhere", false)),
- DynamicFileDestinations.constant(new DummyFilenamePolicy()));
+ DynamicFileDestinations.constant(
+ new DummyFilenamePolicy(), SerializableFunctions.constant(null)));
}
@Override
- public WriteOperation<Object, Void> createWriteOperation() {
+ public WriteOperation<Void, Object> createWriteOperation() {
return new DummyWriteOperation(this);
}
@@ -152,13 +152,13 @@
}
}
- private static class DummyWriteOperation extends FileBasedSink.WriteOperation<Object, Void> {
- public DummyWriteOperation(FileBasedSink<Object, Void> sink) {
+ private static class DummyWriteOperation extends FileBasedSink.WriteOperation<Void, Object> {
+ public DummyWriteOperation(FileBasedSink<Object, Void, Object> sink) {
super(sink);
}
@Override
- public FileBasedSink.Writer<Object, Void> createWriter() throws Exception {
+ public FileBasedSink.Writer<Void, Object> createWriter() throws Exception {
throw new UnsupportedOperationException("Should never be called.");
}
}
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupByKeyViaGroupByKeyOnly.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupByKeyViaGroupByKeyOnly.java
index fca3c76..1fdf07c 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupByKeyViaGroupByKeyOnly.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupByKeyViaGroupByKeyOnly.java
@@ -111,12 +111,10 @@
@Override
public PCollection<KV<K, Iterable<WindowedValue<V>>>> expand(PCollection<KV<K, V>> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
- }
-
- @Override
- public Coder<KV<K, Iterable<V>>> getDefaultOutputCoder(PCollection<KV<K, V>> input) {
- return GroupByKey.getOutputKvCoder(input.getCoder());
+ input.getPipeline(),
+ input.getWindowingStrategy(),
+ input.isBounded(),
+ (Coder) GroupByKey.getOutputKvCoder(input.getCoder()));
}
}
@@ -244,9 +242,8 @@
Coder<Iterable<V>> outputValueCoder = IterableCoder.of(inputIterableElementValueCoder);
Coder<KV<K, Iterable<V>>> outputKvCoder = KvCoder.of(keyCoder, outputValueCoder);
- return PCollection.<KV<K, Iterable<V>>>createPrimitiveOutputInternal(
- input.getPipeline(), windowingStrategy, input.isBounded())
- .setCoder(outputKvCoder);
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), windowingStrategy, input.isBounded(), outputKvCoder);
}
}
}
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
index 0c956d5..d830db5 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.core;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.Futures;
@@ -37,6 +38,7 @@
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Duration;
@@ -158,7 +160,8 @@
// TODO: verify that if there was a failed tryClaim() call, then cont.shouldResume() is false.
// Currently we can't verify this because there are no hooks into tryClaim().
// See https://issues.apache.org/jira/browse/BEAM-2607
- RestrictionT residual = processContext.extractCheckpoint();
+ processContext.cancelScheduledCheckpoint();
+ KV<RestrictionT, Instant> residual = processContext.getTakenCheckpoint();
if (cont.shouldResume()) {
if (residual == null) {
// No checkpoint had been taken by the runner while the ProcessElement call ran, however
@@ -166,7 +169,7 @@
// a checkpoint now: checkpoint() guarantees that the primary restriction describes exactly
// the work that was done in the current ProcessElement call, and returns a residual
// restriction that describes exactly the work that wasn't done in the current call.
- residual = tracker.checkpoint();
+ residual = checkNotNull(processContext.takeCheckpointNow());
} else {
// A checkpoint was taken by the runner, and then the ProcessElement call returned resume()
// without making more tryClaim() calls (since no tryClaim() calls can succeed after
@@ -185,7 +188,13 @@
// special needs to be done.
}
tracker.checkDone();
- return new Result(residual, cont, processContext.getLastReportedWatermark());
+ if (residual == null) {
+ // Can only be true if cont.shouldResume() is false and no checkpoint was taken.
+ // This means the restriction has been fully processed.
+ checkState(!cont.shouldResume());
+ return new Result(null, cont, BoundedWindow.TIMESTAMP_MAX_VALUE);
+ }
+ return new Result(residual.getKey(), cont, residual.getValue());
}
private class ProcessContext extends DoFn<InputT, OutputT>.ProcessContext {
@@ -199,6 +208,9 @@
// This is either the result of the sole tracker.checkpoint() call, or null if
// the call completed before reaching the given number of outputs or duration.
private RestrictionT checkpoint;
+ // Watermark captured at the moment before checkpoint was taken, describing a lower bound
+ // on the output from "checkpoint".
+ private Instant residualWatermark;
// A handle on the scheduled action to take a checkpoint.
private Future<?> scheduledCheckpoint;
private Instant lastReportedWatermark;
@@ -213,34 +225,36 @@
new Runnable() {
@Override
public void run() {
- initiateCheckpoint();
+ takeCheckpointNow();
}
},
maxDuration.getMillis(),
TimeUnit.MILLISECONDS);
}
- @Nullable
- RestrictionT extractCheckpoint() {
+ void cancelScheduledCheckpoint() {
scheduledCheckpoint.cancel(true);
try {
Futures.getUnchecked(scheduledCheckpoint);
} catch (CancellationException e) {
// This is expected if the call took less than the maximum duration.
}
- // By now, a checkpoint may or may not have been taken;
- // via .output() or via scheduledCheckpoint.
- synchronized (this) {
- return checkpoint;
- }
}
- private synchronized void initiateCheckpoint() {
+ synchronized KV<RestrictionT, Instant> takeCheckpointNow() {
// This method may be entered either via .output(), or via scheduledCheckpoint.
// Only one of them "wins" - tracker.checkpoint() must be called only once.
if (checkpoint == null) {
+ residualWatermark = lastReportedWatermark;
checkpoint = checkNotNull(tracker.checkpoint());
}
+ return getTakenCheckpoint();
+ }
+
+ @Nullable
+ synchronized KV<RestrictionT, Instant> getTakenCheckpoint() {
+ // The checkpoint may or may not have been taken.
+ return (checkpoint == null) ? null : KV.of(checkpoint, residualWatermark);
}
@Override
@@ -271,10 +285,6 @@
lastReportedWatermark = watermark;
}
- public synchronized Instant getLastReportedWatermark() {
- return lastReportedWatermark;
- }
-
@Override
public PipelineOptions getPipelineOptions() {
return pipelineOptions;
@@ -306,7 +316,7 @@
private void noteOutput() {
++numOutputs;
if (numOutputs >= maxNumOutputs) {
- initiateCheckpoint();
+ takeCheckpointNow();
}
}
}
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
index 6e97645..251260e 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
@@ -72,8 +72,15 @@
PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>> {
@Override
public PCollection<KeyedWorkItem<KeyT, InputT>> expand(PCollection<KV<KeyT, InputT>> input) {
+ KvCoder<KeyT, InputT> kvCoder = (KvCoder<KeyT, InputT>) input.getCoder();
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), input.isBounded());
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ input.isBounded(),
+ KeyedWorkItemCoder.of(
+ kvCoder.getKeyCoder(),
+ kvCoder.getValueCoder(),
+ input.getWindowingStrategy().getWindowFn().windowCoder()));
}
@Override
@@ -177,6 +184,7 @@
original.getFn(),
original.getMainOutputTag(),
original.getAdditionalOutputTags(),
+ original.getOutputTagsToCoders(),
original.getInputWindowingStrategy());
}
}
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
index 4f13af1..2341502 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnRunnerTest.java
@@ -39,6 +39,7 @@
import com.google.common.collect.Iterables;
import java.util.List;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.runners.core.triggers.DefaultTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachine;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.metrics.MetricName;
@@ -247,6 +248,88 @@
}
/**
+ * When the watermark passes the end-of-window and window expiration time
+ * in a single update, this tests that it does not crash.
+ */
+ @Test
+ public void testSessionEowAndGcTogether() throws Exception {
+ ReduceFnTester<Integer, Iterable<Integer>, IntervalWindow> tester =
+ ReduceFnTester.nonCombining(
+ Sessions.withGapDuration(Duration.millis(10)),
+ DefaultTriggerStateMachine.<IntervalWindow>of(),
+ AccumulationMode.ACCUMULATING_FIRED_PANES,
+ Duration.millis(50),
+ ClosingBehavior.FIRE_ALWAYS);
+
+ tester.setAutoAdvanceOutputWatermark(true);
+
+ tester.advanceInputWatermark(new Instant(0));
+ injectElement(tester, 1);
+ tester.advanceInputWatermark(new Instant(100));
+
+ assertThat(
+ tester.extractOutput(),
+ contains(
+ isSingleWindowedValue(
+ contains(1), 1, 1, 11, PaneInfo.createPane(true, true, Timing.ON_TIME))));
+ }
+
+ /**
+ * When the watermark passes the end-of-window and window expiration time
+ * in a single update, this tests that it does not crash.
+ */
+ @Test
+ public void testFixedWindowsEowAndGcTogether() throws Exception {
+ ReduceFnTester<Integer, Iterable<Integer>, IntervalWindow> tester =
+ ReduceFnTester.nonCombining(
+ FixedWindows.of(Duration.millis(10)),
+ DefaultTriggerStateMachine.<IntervalWindow>of(),
+ AccumulationMode.ACCUMULATING_FIRED_PANES,
+ Duration.millis(50),
+ ClosingBehavior.FIRE_ALWAYS);
+
+ tester.setAutoAdvanceOutputWatermark(true);
+
+ tester.advanceInputWatermark(new Instant(0));
+ injectElement(tester, 1);
+ tester.advanceInputWatermark(new Instant(100));
+
+ assertThat(
+ tester.extractOutput(),
+ contains(
+ isSingleWindowedValue(
+ contains(1), 1, 0, 10, PaneInfo.createPane(true, true, Timing.ON_TIME))));
+ }
+
+ /**
+ * When the watermark passes the end-of-window and window expiration time
+ * in a single update, this tests that it does not crash.
+ */
+ @Test
+ public void testFixedWindowsEowAndGcTogetherFireIfNonEmpty() throws Exception {
+ ReduceFnTester<Integer, Iterable<Integer>, IntervalWindow> tester =
+ ReduceFnTester.nonCombining(
+ FixedWindows.of(Duration.millis(10)),
+ DefaultTriggerStateMachine.<IntervalWindow>of(),
+ AccumulationMode.ACCUMULATING_FIRED_PANES,
+ Duration.millis(50),
+ ClosingBehavior.FIRE_IF_NON_EMPTY);
+
+ tester.setAutoAdvanceOutputWatermark(true);
+
+ tester.advanceInputWatermark(new Instant(0));
+ injectElement(tester, 1);
+ tester.advanceInputWatermark(new Instant(100));
+
+ List<WindowedValue<Iterable<Integer>>> output = tester.extractOutput();
+ assertThat(
+ output,
+ contains(
+ isSingleWindowedValue(
+ contains(1), 1, 0, 10, PaneInfo.createPane(true, true, Timing.ON_TIME))));
+ }
+
+ /**
* Tests that with the default trigger we will not produce two ON_TIME panes, even
* if there are two outputs that are both candidates.
*/
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java
index 06b8e29..3ba04e7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java
@@ -24,7 +24,6 @@
import org.apache.beam.runners.core.KeyedWorkItemCoder;
import org.apache.beam.runners.core.construction.ForwardingPTransform;
import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
@@ -78,22 +77,18 @@
@Override
public PCollection<KeyedWorkItem<K, V>> expand(PCollection<KV<K, V>> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), input.isBounded());
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ input.isBounded(),
+ KeyedWorkItemCoder.of(
+ GroupByKey.getKeyCoder(input.getCoder()),
+ GroupByKey.getInputValueCoder(input.getCoder()),
+ input.getWindowingStrategy().getWindowFn().windowCoder()));
}
DirectGroupByKeyOnly() {}
@Override
- protected Coder<?> getDefaultOutputCoder(
- @SuppressWarnings("unused") PCollection<KV<K, V>> input)
- throws CannotProvideCoderException {
- return KeyedWorkItemCoder.of(
- GroupByKey.getKeyCoder(input.getCoder()),
- GroupByKey.getInputValueCoder(input.getCoder()),
- input.getWindowingStrategy().getWindowFn().windowCoder());
- }
-
- @Override
public String getUrn() {
return DIRECT_GBKO_URN;
}
@@ -135,17 +130,11 @@
}
@Override
- protected Coder<?> getDefaultOutputCoder(
- @SuppressWarnings("unused") PCollection<KeyedWorkItem<K, V>> input)
- throws CannotProvideCoderException {
- KeyedWorkItemCoder<K, V> inputCoder = getKeyedWorkItemCoder(input.getCoder());
- return KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getElementCoder()));
- }
-
- @Override
public PCollection<KV<K, Iterable<V>>> expand(PCollection<KeyedWorkItem<K, V>> input) {
+ KeyedWorkItemCoder<K, V> inputCoder = getKeyedWorkItemCoder(input.getCoder());
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), outputWindowingStrategy, input.isBounded());
+ input.getPipeline(), outputWindowingStrategy, input.isBounded(),
+ KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getElementCoder())));
}
@Override
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 4621224..642ce8f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
@@ -31,6 +32,7 @@
import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems;
import org.apache.beam.runners.core.construction.PTransformMatchers;
import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult;
import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
@@ -156,7 +158,14 @@
}
@Override
- public DirectPipelineResult run(Pipeline pipeline) {
+ public DirectPipelineResult run(Pipeline originalPipeline) {
+ Pipeline pipeline;
+ try {
+ pipeline = PipelineTranslation.fromProto(
+ PipelineTranslation.toProto(originalPipeline));
+ } catch (IOException exception) {
+ throw new RuntimeException("Error preparing pipeline for direct execution.", exception);
+ }
pipeline.replaceAll(defaultTransformOverrides());
MetricsEnvironment.setMetricsSupported(true);
DirectGraphVisitor graphVisitor = new DirectGraphVisitor();
@@ -224,36 +233,41 @@
PTransformMatchers.writeWithRunnerDeterminedSharding(),
new WriteWithShardingFactory())); /* Uses a view internally. */
}
- builder = builder.add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN),
- new ViewOverrideFactory())) /* Uses pardos and GBKs */
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN),
- new DirectTestStreamFactory(this))) /* primitive */
- // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra
- // primitives
- .add(
- PTransformOverride.of(
- PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory()))
- // state and timer pardos are implemented in terms of simple ParDos and extra primitives
- .add(
- PTransformOverride.of(
- PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory()))
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(
- SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
- new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN),
- new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN),
- new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. */
+ builder =
+ builder
+ .add(
+ PTransformOverride.of(
+ MultiStepCombine.matcher(), MultiStepCombine.Factory.create()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN),
+ new ViewOverrideFactory())) /* Uses pardos and GBKs */
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN),
+ new DirectTestStreamFactory(this))) /* primitive */
+ // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra
+ // primitives
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory()))
+ // state and timer pardos are implemented in terms of simple ParDos and extra primitives
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(
+ SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
+ new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN),
+ new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN),
+ new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. */
return builder.build();
}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java
new file mode 100644
index 0000000..ae21b4d
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.CombineTranslation;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform;
+import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformMatcher;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Combine.PerKey;
+import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.joda.time.Instant;
+
+/** A {@link Combine} that performs the combine in multiple steps. */
+class MultiStepCombine<K, InputT, AccumT, OutputT>
+ extends RawPTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> {
+ public static PTransformMatcher matcher() {
+ return new PTransformMatcher() {
+ @Override
+ public boolean matches(AppliedPTransform<?, ?, ?> application) {
+ if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals(
+ PTransformTranslation.urnForTransformOrNull(application.getTransform()))) {
+ try {
+ GlobalCombineFn fn = CombineTranslation.getCombineFn(application);
+ return isApplicable(application.getInputs(), fn);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+
+ private <K, InputT> boolean isApplicable(
+ Map<TupleTag<?>, PValue> inputs, GlobalCombineFn<InputT, ?, ?> fn) {
+ if (!(fn instanceof CombineFn)) {
+ return false;
+ }
+ if (inputs.size() == 1) {
+ PCollection<KV<K, InputT>> input =
+ (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+ WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+ boolean windowFnApplicable = windowingStrategy.getWindowFn().isNonMerging();
+ // Triggering with count based triggers is not appropriately handled here. Disabling
+ // most triggers is safe, though more broad than is technically required.
+ boolean triggerApplicable = DefaultTrigger.of().equals(windowingStrategy.getTrigger());
+ boolean accumulatorCoderAvailable;
+ try {
+ if (input.getCoder() instanceof KvCoder) {
+ KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<?> accumulatorCoder =
+ fn.getAccumulatorCoder(
+ input.getPipeline().getCoderRegistry(), kvCoder.getValueCoder());
+ accumulatorCoderAvailable = accumulatorCoder != null;
+ } else {
+ accumulatorCoderAvailable = false;
+ }
+ } catch (CannotProvideCoderException e) {
+ throw new RuntimeException(
+ String.format(
+ "Could not construct an accumulator %s for %s. Accumulator %s for a %s may be"
+ + " null, but may not throw an exception",
+ Coder.class.getSimpleName(),
+ fn,
+ Coder.class.getSimpleName(),
+ Combine.class.getSimpleName()),
+ e);
+ }
+ return windowFnApplicable && triggerApplicable && accumulatorCoderAvailable;
+ }
+ return false;
+ }
+ };
+ }
+
+ static class Factory<K, InputT, AccumT, OutputT>
+ extends SingleInputOutputOverrideFactory<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+ PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
+ public static PTransformOverrideFactory create() {
+ return new Factory<>();
+ }
+
+ private Factory() {}
+
+ @Override
+ public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>
+ getReplacementTransform(
+ AppliedPTransform<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+ PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>>
+ transform) {
+ try {
+ GlobalCombineFn<?, ?, ?> globalFn = CombineTranslation.getCombineFn(transform);
+ checkState(
+ globalFn instanceof CombineFn,
+ "%s.matcher() should only match %s instances using %s, got %s",
+ MultiStepCombine.class.getSimpleName(),
+ PerKey.class.getSimpleName(),
+ CombineFn.class.getSimpleName(),
+ globalFn.getClass().getName());
+ @SuppressWarnings("unchecked")
+ CombineFn<InputT, AccumT, OutputT> fn = (CombineFn<InputT, AccumT, OutputT>) globalFn;
+ @SuppressWarnings("unchecked")
+ PCollection<KV<K, InputT>> input =
+ (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values());
+ @SuppressWarnings("unchecked")
+ PCollection<KV<K, OutputT>> output =
+ (PCollection<KV<K, OutputT>>) Iterables.getOnlyElement(transform.getOutputs().values());
+ return PTransformReplacement.of(input, new MultiStepCombine<>(fn, output.getCoder()));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ // ===========================================================================================
+
+ private final CombineFn<InputT, AccumT, OutputT> combineFn;
+ private final Coder<KV<K, OutputT>> outputCoder;
+
+ public static <K, InputT, AccumT, OutputT> MultiStepCombine<K, InputT, AccumT, OutputT> of(
+ CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) {
+ return new MultiStepCombine<>(combineFn, outputCoder);
+ }
+
+ private MultiStepCombine(
+ CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) {
+ this.combineFn = combineFn;
+ this.outputCoder = outputCoder;
+ }
+
+ @Nullable
+ @Override
+ public String getUrn() {
+ return "urn:beam:directrunner:transforms:multistepcombine:v1";
+ }
+
+ @Override
+ public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, InputT>> input) {
+ checkArgument(
+ input.getCoder() instanceof KvCoder,
+ "Expected input to have a %s of type %s, got %s",
+ Coder.class.getSimpleName(),
+ KvCoder.class.getSimpleName(),
+ input.getCoder());
+ KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<InputT> inputValueCoder = inputCoder.getValueCoder();
+ Coder<AccumT> accumulatorCoder;
+ try {
+ accumulatorCoder =
+ combineFn.getAccumulatorCoder(input.getPipeline().getCoderRegistry(), inputValueCoder);
+ } catch (CannotProvideCoderException e) {
+ throw new IllegalStateException(
+ String.format(
+ "Could not construct an Accumulator Coder with the provided %s %s",
+ CombineFn.class.getSimpleName(), combineFn),
+ e);
+ }
+ return input
+ .apply(
+ ParDo.of(
+ new CombineInputs<>(
+ combineFn,
+ input.getWindowingStrategy().getTimestampCombiner(),
+ inputCoder.getKeyCoder())))
+ .setCoder(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder))
+ .apply(GroupByKey.<K, AccumT>create())
+ .apply(new MergeAndExtractAccumulatorOutput<>(combineFn, outputCoder));
+ }
+
+ private static class CombineInputs<K, InputT, AccumT> extends DoFn<KV<K, InputT>, KV<K, AccumT>> {
+ private final CombineFn<InputT, AccumT, ?> combineFn;
+ private final TimestampCombiner timestampCombiner;
+ private final Coder<K> keyCoder;
+
+ /**
+ * Per-bundle state. Accumulators and output timestamps should only be tracked while a bundle
+ * is being processed, and must be cleared when a bundle is completed.
+ */
+ private transient Map<WindowedStructuralKey<K>, AccumT> accumulators;
+ private transient Map<WindowedStructuralKey<K>, Instant> timestamps;
+
+ private CombineInputs(
+ CombineFn<InputT, AccumT, ?> combineFn,
+ TimestampCombiner timestampCombiner,
+ Coder<K> keyCoder) {
+ this.combineFn = combineFn;
+ this.timestampCombiner = timestampCombiner;
+ this.keyCoder = keyCoder;
+ }
+
+ @StartBundle
+ public void startBundle() {
+ accumulators = new LinkedHashMap<>();
+ timestamps = new LinkedHashMap<>();
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext context, BoundedWindow window) {
+ WindowedStructuralKey<K>
+ key = WindowedStructuralKey.create(keyCoder, context.element().getKey(), window);
+ AccumT accumulator = accumulators.get(key);
+ Instant assignedTs = timestampCombiner.assign(window, context.timestamp());
+ if (accumulator == null) {
+ accumulator = combineFn.createAccumulator();
+ accumulators.put(key, accumulator);
+ timestamps.put(key, assignedTs);
+ }
+ accumulators.put(key, combineFn.addInput(accumulator, context.element().getValue()));
+ timestamps.put(key, timestampCombiner.combine(assignedTs, timestamps.get(key)));
+ }
+
+ @FinishBundle
+ public void outputAccumulators(FinishBundleContext context) {
+ for (Map.Entry<WindowedStructuralKey<K>, AccumT> preCombineEntry : accumulators.entrySet()) {
+ context.output(
+ KV.of(preCombineEntry.getKey().getKey(), combineFn.compact(preCombineEntry.getValue())),
+ timestamps.get(preCombineEntry.getKey()),
+ preCombineEntry.getKey().getWindow());
+ }
+ accumulators = null;
+ timestamps = null;
+ }
+ }
+
+ static class WindowedStructuralKey<K> {
+ public static <K> WindowedStructuralKey<K> create(
+ Coder<K> keyCoder, K key, BoundedWindow window) {
+ return new WindowedStructuralKey<>(StructuralKey.of(key, keyCoder), window);
+ }
+
+ private final StructuralKey<K> key;
+ private final BoundedWindow window;
+
+ private WindowedStructuralKey(StructuralKey<K> key, BoundedWindow window) {
+ this.key = checkNotNull(key, "key cannot be null");
+ this.window = checkNotNull(window, "Window cannot be null");
+ }
+
+ public K getKey() {
+ return key.getKey();
+ }
+
+ public BoundedWindow getWindow() {
+ return window;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof MultiStepCombine.WindowedStructuralKey)) {
+ return false;
+ }
+ WindowedStructuralKey that = (WindowedStructuralKey<?>) other;
+ return this.window.equals(that.window) && this.key.equals(that.key);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(window, key);
+ }
+ }
+
+ static final String DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN =
+ "urn:beam:directrunner:transforms:merge_accumulators_extract_output:v1";
+ /**
+ * A primitive {@link PTransform} that merges iterables of accumulators and extracts the output.
+ *
+ * <p>Required to ensure that Immutability Enforcement is not applied. Accumulators
+ * are explicitly mutable.
+ */
+ static class MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>
+ extends RawPTransform<PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>> {
+ private final CombineFn<?, AccumT, OutputT> combineFn;
+ private final Coder<KV<K, OutputT>> outputCoder;
+
+ private MergeAndExtractAccumulatorOutput(
+ CombineFn<?, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) {
+ this.combineFn = combineFn;
+ this.outputCoder = outputCoder;
+ }
+
+ CombineFn<?, AccumT, OutputT> getCombineFn() {
+ return combineFn;
+ }
+
+ @Override
+ public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, Iterable<AccumT>>> input) {
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), outputCoder);
+ }
+
+ @Nullable
+ @Override
+ public String getUrn() {
+ return DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN;
+ }
+ }
+
+ static class MergeAndExtractAccumulatorOutputEvaluatorFactory
+ implements TransformEvaluatorFactory {
+ private final EvaluationContext ctxt;
+
+ public MergeAndExtractAccumulatorOutputEvaluatorFactory(EvaluationContext ctxt) {
+ this.ctxt = ctxt;
+ }
+
+ @Nullable
+ @Override
+ public <InputT> TransformEvaluator<InputT> forApplication(
+ AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception {
+ return createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle);
+ }
+
+ private <K, AccumT, OutputT> TransformEvaluator<KV<K, Iterable<AccumT>>> createEvaluator(
+ AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application,
+ CommittedBundle<KV<K, Iterable<AccumT>>> inputBundle) {
+ return new MergeAccumulatorsAndExtractOutputEvaluator<>(ctxt, application);
+ }
+
+ @Override
+ public void cleanup() throws Exception {}
+ }
+
+ private static class MergeAccumulatorsAndExtractOutputEvaluator<K, AccumT, OutputT>
+ implements TransformEvaluator<KV<K, Iterable<AccumT>>> {
+ private final AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application;
+ private final CombineFn<?, AccumT, OutputT> combineFn;
+ private final UncommittedBundle<KV<K, OutputT>> output;
+
+ public MergeAccumulatorsAndExtractOutputEvaluator(
+ EvaluationContext ctxt,
+ AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application) {
+ this.application = application;
+ this.combineFn = application.getTransform().getCombineFn();
+ this.output =
+ ctxt.createBundle(
+ (PCollection<KV<K, OutputT>>)
+ Iterables.getOnlyElement(application.getOutputs().values()));
+ }
+
+ @Override
+ public void processElement(WindowedValue<KV<K, Iterable<AccumT>>> element) throws Exception {
+ checkState(
+ element.getWindows().size() == 1,
+ "Expected inputs to %s to be in exactly one window. Got %s",
+ MergeAccumulatorsAndExtractOutputEvaluator.class.getSimpleName(),
+ element.getWindows().size());
+ Iterable<AccumT> inputAccumulators = element.getValue().getValue();
+ try {
+ AccumT first = combineFn.createAccumulator();
+ AccumT merged = combineFn.mergeAccumulators(Iterables.concat(Collections.singleton(first),
+ inputAccumulators,
+ Collections.singleton(combineFn.createAccumulator())));
+ OutputT extracted = combineFn.extractOutput(merged);
+ output.add(element.withValue(KV.of(element.getValue().getKey(), extracted)));
+ } catch (Exception e) {
+ throw UserCodeException.wrap(e);
+ }
+ }
+
+ @Override
+ public TransformResult<KV<K, Iterable<AccumT>>> finishBundle() throws Exception {
+ return StepTransformResult.<KV<K, Iterable<AccumT>>>withoutHold(application)
+ .addOutput(output)
+ .build();
+ }
+ }
+}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 891d102..26f30b0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -20,6 +20,7 @@
import static com.google.common.base.Preconditions.checkState;
import java.io.IOException;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.KeyedWorkItem;
@@ -95,7 +96,7 @@
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
if (signature.processElement().isSplittable()) {
- return (PTransform) SplittableParDo.forAppliedParDo(application);
+ return SplittableParDo.forAppliedParDo((AppliedPTransform) application);
} else if (signature.stateDeclarations().size() > 0
|| signature.timerDeclarations().size() > 0) {
return new GbkThenStatefulParDo(
@@ -248,6 +249,8 @@
PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(getMainOutputTag()).and(getAdditionalOutputTags().getAll()),
+ // TODO
+ Collections.<TupleTag<?>, Coder<?>>emptyMap(),
input.getWindowingStrategy(),
input.isBounded());
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
index e6b51b7..bc7b193 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
@@ -178,8 +178,10 @@
.setDaemon(true)
.setNameFormat("direct-splittable-process-element-checkpoint-executor")
.build()),
- 10000,
- Duration.standardSeconds(10)));
+ // Setting small values here to stimulate frequent checkpointing and better exercise
+ // splittable DoFn's in that respect.
+ 100,
+ Duration.standardSeconds(1)));
return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(parDoEvaluator, fnManager);
}
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
index 16c8589..49e7be7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
@@ -207,9 +207,11 @@
@Override
public PCollection<T> expand(PBegin input) {
runner.setClockSupplier(new TestClockSupplier());
- return PCollection.<T>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(original.getValueCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ IsBounded.UNBOUNDED,
+ original.getValueCoder());
}
@Override
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index 0c907df..30666db 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -26,6 +26,7 @@
import static org.apache.beam.runners.core.construction.SplittableParDo.SPLITTABLE_PROCESS_URN;
import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GABW_URN;
import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GBKO_URN;
+import static org.apache.beam.runners.direct.MultiStepCombine.DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN;
import static org.apache.beam.runners.direct.ParDoMultiOverrideFactory.DIRECT_STATEFUL_PAR_DO_URN;
import static org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DIRECT_TEST_STREAM_URN;
import static org.apache.beam.runners.direct.ViewOverrideFactory.DIRECT_WRITE_VIEW_URN;
@@ -73,6 +74,9 @@
.put(DIRECT_GBKO_URN, new GroupByKeyOnlyEvaluatorFactory(ctxt))
.put(DIRECT_GABW_URN, new GroupAlsoByWindowEvaluatorFactory(ctxt))
.put(DIRECT_TEST_STREAM_URN, new TestStreamEvaluatorFactory(ctxt))
+ .put(
+ DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN,
+ new MultiStepCombine.MergeAndExtractAccumulatorOutputEvaluatorFactory(ctxt))
// Runners-core primitives
.put(SPLITTABLE_PROCESS_URN, new SplittableProcessElementsEvaluatorFactory<>(ctxt))
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
index 5dcf016..c2255fe 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
@@ -115,9 +115,8 @@
@Override
@SuppressWarnings("deprecation")
public PCollection<Iterable<ElemT>> expand(PCollection<Iterable<ElemT>> input) {
- return PCollection.<Iterable<ElemT>>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
}
@SuppressWarnings("deprecation")
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index ba796ae..3557c5d 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -24,10 +24,12 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.WriteFilesTranslation;
+import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.io.WriteFiles;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -61,10 +63,10 @@
AppliedPTransform<PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>>
transform) {
try {
- WriteFiles<InputT, ?, ?> replacement =
- WriteFiles.to(
- WriteFilesTranslation.getSink(transform),
- WriteFilesTranslation.getFormatFunction(transform));
+ List<PCollectionView<?>> sideInputs =
+ WriteFilesTranslation.getDynamicDestinationSideInputs(transform);
+ FileBasedSink sink = WriteFilesTranslation.getSink(transform);
+ WriteFiles<InputT, ?, ?> replacement = WriteFiles.to(sink).withSideInputs(sideInputs);
if (WriteFilesTranslation.isWindowedWrites(transform)) {
replacement = replacement.withWindowedWrites();
}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java
index 6180d29..3d81884 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java
@@ -395,7 +395,7 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return coder;
}
}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java
index 8b95b34..29ed55d 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java
@@ -27,6 +27,7 @@
import java.util.EnumSet;
import java.util.List;
import org.apache.beam.runners.direct.CommittedResult.OutputType;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
@@ -113,13 +114,24 @@
@Test
public void getOutputsEqualInput() {
- List<? extends CommittedBundle<?>> outputs =
- ImmutableList.of(bundleFactory.createBundle(PCollection.createPrimitiveOutputInternal(p,
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.BOUNDED)).commit(Instant.now()),
- bundleFactory.createBundle(PCollection.createPrimitiveOutputInternal(p,
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.UNBOUNDED)).commit(Instant.now()));
+ List<? extends CommittedBundle<Integer>> outputs =
+ ImmutableList.of(
+ bundleFactory
+ .createBundle(
+ PCollection.createPrimitiveOutputInternal(
+ p,
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.BOUNDED,
+ VarIntCoder.of()))
+ .commit(Instant.now()),
+ bundleFactory
+ .createBundle(
+ PCollection.createPrimitiveOutputInternal(
+ p,
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ VarIntCoder.of()))
+ .commit(Instant.now()));
CommittedResult result =
CommittedResult.create(
StepTransformResult.withoutHold(transform).build(),
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
index 943d27c..d3f407a 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
@@ -573,8 +573,8 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
- return underlying.getDefaultOutputCoder();
+ public Coder<T> getOutputCoder() {
+ return underlying.getOutputCoder();
}
}
}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
index 699a318..cc9ce60 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java
@@ -40,6 +40,7 @@
import org.apache.beam.runners.direct.WatermarkManager.FiredTimers;
import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate;
import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
@@ -127,8 +128,11 @@
public void writeToViewWriterThenReadReads() {
PCollectionViewWriter<Integer, Iterable<Integer>> viewWriter =
context.createPCollectionViewWriter(
- PCollection.<Iterable<Integer>>createPrimitiveOutputInternal(
- p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED),
+ PCollection.createPrimitiveOutputInternal(
+ p,
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ IterableCoder.of(VarIntCoder.of())),
view);
BoundedWindow window = new TestBoundedWindow(new Instant(1024L));
BoundedWindow second = new TestBoundedWindow(new Instant(899999L));
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java
new file mode 100644
index 0000000..0c11a8a
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+
+import com.google.auto.value.AutoValue;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link MultiStepCombine}.
+ *
+ * <p>Every test combines with {@link MultiStepCombineFn}, whose {@code extractOutput} fails unless
+ * its accumulator (or one merged into it) was decoded by {@link MultiStepAccumulatorCoder}. Each
+ * test therefore also verifies that accumulators were encoded and decoded within the pipeline.
+ */
+@RunWith(JUnit4.class)
+public class MultiStepCombineTest implements Serializable {
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  /** Sums values per key in the global window. */
+  @Test
+  public void testMultiStepCombine() {
+    PCollection<KV<String, Long>> combined =
+        pipeline
+            .apply(
+                Create.of(
+                    KV.of("foo", 1L),
+                    KV.of("bar", 2L),
+                    KV.of("bizzle", 3L),
+                    KV.of("bar", 4L),
+                    KV.of("bizzle", 11L)))
+            .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+
+    PAssert.that(combined)
+        .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 14L));
+    pipeline.run();
+  }
+
+  /** Sums values per key and sliding window; every element is assigned to two windows. */
+  @Test
+  public void testMultiStepCombineWindowed() {
+    SlidingWindows windowFn = SlidingWindows.of(Duration.millis(6L)).every(Duration.millis(3L));
+    PCollection<KV<String, Long>> combined =
+        pipeline
+            .apply(
+                Create.timestamped(
+                    TimestampedValue.of(KV.of("foo", 1L), new Instant(1L)),
+                    TimestampedValue.of(KV.of("bar", 2L), new Instant(2L)),
+                    TimestampedValue.of(KV.of("bizzle", 3L), new Instant(3L)),
+                    TimestampedValue.of(KV.of("bar", 4L), new Instant(4L)),
+                    TimestampedValue.of(KV.of("bizzle", 11L), new Instant(11L))))
+            .apply(Window.<KV<String, Long>>into(windowFn))
+            .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+
+    PAssert.that("Windows should combine only elements in their windows", combined)
+        .inWindow(new IntervalWindow(new Instant(0L), Duration.millis(6L)))
+        .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 3L));
+    PAssert.that("Elements should appear in all the windows they are assigned to", combined)
+        .inWindow(new IntervalWindow(new Instant(-3L), Duration.millis(6L)))
+        .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 2L));
+    PAssert.that(combined)
+        .inWindow(new IntervalWindow(new Instant(6L), Duration.millis(6L)))
+        .containsInAnyOrder(KV.of("bizzle", 11L));
+    PAssert.that(combined)
+        .containsInAnyOrder(
+            KV.of("foo", 1L),
+            KV.of("foo", 1L),
+            KV.of("bar", 6L),
+            KV.of("bar", 2L),
+            KV.of("bar", 4L),
+            KV.of("bizzle", 11L),
+            KV.of("bizzle", 11L),
+            KV.of("bizzle", 3L),
+            KV.of("bizzle", 3L));
+    pipeline.run();
+  }
+
+  /** Checks that the windowing strategy's {@link TimestampCombiner} sets output timestamps. */
+  @Test
+  public void testMultiStepCombineTimestampCombiner() {
+    TimestampCombiner combiner = TimestampCombiner.LATEST;
+    PCollection<KV<String, Long>> combined =
+        pipeline
+            .apply(
+                Create.timestamped(
+                    TimestampedValue.of(KV.of("foo", 4L), new Instant(1L)),
+                    TimestampedValue.of(KV.of("foo", 1L), new Instant(4L)),
+                    TimestampedValue.of(KV.of("bazzle", 4L), new Instant(4L)),
+                    TimestampedValue.of(KV.of("foo", 12L), new Instant(12L))))
+            .apply(
+                Window.<KV<String, Long>>into(FixedWindows.of(Duration.millis(5L)))
+                    .withTimestampCombiner(combiner))
+            .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+    PCollection<KV<String, TimestampedValue<Long>>> reified =
+        combined.apply(
+            ParDo.of(
+                new DoFn<KV<String, Long>, KV<String, TimestampedValue<Long>>>() {
+                  @ProcessElement
+                  public void reifyTimestamp(ProcessContext context) {
+                    context.output(
+                        KV.of(
+                            context.element().getKey(),
+                            TimestampedValue.of(
+                                context.element().getValue(), context.timestamp())));
+                  }
+                }));
+
+    PAssert.that(reified)
+        .containsInAnyOrder(
+            KV.of("foo", TimestampedValue.of(5L, new Instant(4L))),
+            KV.of("bazzle", TimestampedValue.of(4L, new Instant(4L))),
+            KV.of("foo", TimestampedValue.of(12L, new Instant(12L))));
+    pipeline.run();
+  }
+
+  /**
+   * Sums {@code Long} inputs. {@link #extractOutput} fails unless the accumulator it receives
+   * carries the "deserialized" flag set by {@link MultiStepAccumulatorCoder}.
+   */
+  private static class MultiStepCombineFn extends CombineFn<Long, MultiStepAccumulator, Long> {
+    @Override
+    public Coder<MultiStepAccumulator> getAccumulatorCoder(
+        CoderRegistry registry, Coder<Long> inputCoder) throws CannotProvideCoderException {
+      return new MultiStepAccumulatorCoder();
+    }
+
+    @Override
+    public MultiStepAccumulator createAccumulator() {
+      return MultiStepAccumulator.of(0L, false);
+    }
+
+    @Override
+    public MultiStepAccumulator addInput(MultiStepAccumulator accumulator, Long input) {
+      return MultiStepAccumulator.of(accumulator.getValue() + input, accumulator.isDeserialized());
+    }
+
+    @Override
+    public MultiStepAccumulator mergeAccumulators(Iterable<MultiStepAccumulator> accumulators) {
+      MultiStepAccumulator result = MultiStepAccumulator.of(0L, false);
+      for (MultiStepAccumulator accumulator : accumulators) {
+        result = result.merge(accumulator);
+      }
+      return result;
+    }
+
+    @Override
+    public Long extractOutput(MultiStepAccumulator accumulator) {
+      assertThat(
+          "Accumulators should have been serialized and deserialized within the Pipeline",
+          accumulator.isDeserialized(),
+          is(true));
+      return accumulator.getValue();
+    }
+  }
+
+  /**
+   * A running sum plus a flag recording whether this accumulator, or any accumulator merged into
+   * it, was produced by {@link MultiStepAccumulatorCoder#decode}.
+   */
+  @AutoValue
+  abstract static class MultiStepAccumulator {
+    private static MultiStepAccumulator of(long value, boolean deserialized) {
+      return new AutoValue_MultiStepCombineTest_MultiStepAccumulator(value, deserialized);
+    }
+
+    MultiStepAccumulator merge(MultiStepAccumulator other) {
+      return MultiStepAccumulator.of(
+          this.getValue() + other.getValue(), this.isDeserialized() || other.isDeserialized());
+    }
+
+    abstract long getValue();
+
+    abstract boolean isDeserialized();
+  }
+
+  /**
+   * Encodes only the sum. {@link #decode} always sets the deserialized flag, which is how
+   * {@link MultiStepCombineFn} detects that an accumulator went through this coder.
+   */
+  private static class MultiStepAccumulatorCoder extends CustomCoder<MultiStepAccumulator> {
+    @Override
+    public void encode(MultiStepAccumulator value, OutputStream outStream)
+        throws CoderException, IOException {
+      VarInt.encode(value.getValue(), outStream);
+    }
+
+    @Override
+    public MultiStepAccumulator decode(InputStream inStream) throws CoderException, IOException {
+      return MultiStepAccumulator.of(VarInt.decodeLong(inStream), true);
+    }
+  }
+}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java
index 2a01db5..cc6847d 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java
@@ -477,7 +477,7 @@
public void validate() {}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return coder;
}
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
index 6af9273..94d8d70 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
@@ -23,22 +23,17 @@
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
-import com.google.common.collect.ImmutableSet;
import java.io.Serializable;
import java.util.List;
-import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
-import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
import org.apache.beam.sdk.transforms.ViewFn;
import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
@@ -62,42 +57,6 @@
new ViewOverrideFactory<>();
@Test
- public void replacementSucceeds() {
- PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
- final PCollectionView<List<Integer>> view =
- PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder());
- PTransformReplacement<PCollection<Integer>, PCollection<Integer>>
- replacementTransform =
- factory.getReplacementTransform(
- AppliedPTransform
- .<PCollection<Integer>, PCollection<Integer>,
- PTransform<PCollection<Integer>, PCollection<Integer>>>
- of(
- "foo",
- ints.expand(),
- view.expand(),
- CreatePCollectionView.<Integer, List<Integer>>of(view),
- p));
- ints.apply(replacementTransform.getTransform());
-
- PCollection<Set<Integer>> outputViewContents =
- p.apply("CreateSingleton", Create.of(0))
- .apply(
- "OutputContents",
- ParDo.of(
- new DoFn<Integer, Set<Integer>>() {
- @ProcessElement
- public void outputSideInput(ProcessContext context) {
- context.output(ImmutableSet.copyOf(context.sideInput(view)));
- }
- })
- .withSideInputs(view));
- PAssert.thatSingleton(outputViewContents).isEqualTo(ImmutableSet.of(1, 2, 3));
-
- p.run();
- }
-
- @Test
public void replacementGetViewReturnsOriginal() {
final PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
final PCollectionView<List<Integer>> view =
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index 546a181..d0db44e 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -36,6 +36,7 @@
import java.util.Collections;
import java.util.List;
import java.util.UUID;
+import javax.annotation.Nullable;
import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.coders.VoidCoder;
@@ -54,8 +55,9 @@
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnTester;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PCollectionViews;
@@ -140,14 +142,14 @@
PTransform<PCollection<Object>, PDone> original =
WriteFiles.to(
- new FileBasedSink<Object, Void>(
- StaticValueProvider.of(outputDirectory), DynamicFileDestinations.constant(null)) {
+ new FileBasedSink<Object, Void, Object>(
+ StaticValueProvider.of(outputDirectory),
+ DynamicFileDestinations.constant(new FakeFilenamePolicy())) {
@Override
- public WriteOperation<Object, Void> createWriteOperation() {
+ public WriteOperation<Void, Object> createWriteOperation() {
throw new IllegalArgumentException("Should not be used");
}
- },
- SerializableFunctions.identity());
+ });
@SuppressWarnings("unchecked")
PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of()));
@@ -234,4 +236,29 @@
     List<Integer> shards = fnTester.processBundle((long) count);
     assertThat(shards, containsInAnyOrder(13));
   }
+
+  /**
+   * Placeholder {@link FileBasedSink.FilenamePolicy}: the sink under test is not expected to name
+   * any files, so both methods throw {@link IllegalArgumentException} if called.
+   */
+  private static class FakeFilenamePolicy extends FileBasedSink.FilenamePolicy {
+    @Override
+    public ResourceId windowedFilename(
+        int shardNumber,
+        int numShards,
+        BoundedWindow window,
+        PaneInfo paneInfo,
+        FileBasedSink.OutputFileHints outputFileHints) {
+      throw new IllegalArgumentException("Should not be used");
+    }
+
+    @Nullable
+    @Override
+    public ResourceId unwindowedFilename(
+        int shardNumber,
+        int numShards,
+        FileBasedSink.OutputFileHints outputFileHints) {
+      throw new IllegalArgumentException("Should not be used");
+    }
+  }
 }
diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml
index c063a2d..06746fd 100644
--- a/runners/flink/pom.xml
+++ b/runners/flink/pom.xml
@@ -256,16 +256,6 @@
</dependency>
<dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-core</artifactId>
- </dependency>
-
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- </dependency>
-
- <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
index 0cc3aec..3114a6f 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
@@ -120,9 +120,8 @@
@Override
public PCollection<List<ElemT>> expand(PCollection<List<ElemT>> input) {
- return PCollection.<List<ElemT>>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
}
public PCollectionView<ViewT> getView() {
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
index d8ed622..3048168 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
@@ -22,9 +22,9 @@
import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
@@ -50,7 +50,7 @@
public class FlinkDoFnFunction<InputT, OutputT>
extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> {
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
private final DoFn<InputT, OutputT> doFn;
private final String stepName;
@@ -75,7 +75,7 @@
this.doFn = doFn;
this.stepName = stepName;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(options);
+ this.serializedOptions = new SerializablePipelineOptions(options);
this.windowingStrategy = windowingStrategy;
this.outputMap = outputMap;
this.mainOutputTag = mainOutputTag;
@@ -101,7 +101,7 @@
List<TupleTag<?>> additionalOutputTags = Lists.newArrayList(outputMap.keySet());
DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner(
- serializedOptions.getPipelineOptions(), doFn,
+ serializedOptions.get(), doFn,
new FlinkSideInputReader(sideInputs, runtimeContext),
outputManager,
mainOutputTag,
@@ -109,7 +109,7 @@
new FlinkNoOpStepContext(),
windowingStrategy);
- if ((serializedOptions.getPipelineOptions().as(FlinkPipelineOptions.class))
+ if ((serializedOptions.get().as(FlinkPipelineOptions.class))
.getEnableMetrics()) {
doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, getRuntimeContext());
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java
index 13be913..c73dade 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMergingNonShuffleReduceFunction.java
@@ -18,7 +18,7 @@
package org.apache.beam.runners.flink.translation.functions;
import java.util.Map;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -47,7 +47,7 @@
private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
public FlinkMergingNonShuffleReduceFunction(
CombineFnBase.GlobalCombineFn<InputT, AccumT, OutputT> combineFn,
@@ -60,7 +60,7 @@
this.windowingStrategy = windowingStrategy;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
}
@@ -69,7 +69,7 @@
Iterable<WindowedValue<KV<K, InputT>>> elements,
Collector<WindowedValue<KV<K, OutputT>>> out) throws Exception {
- PipelineOptions options = serializedOptions.getPipelineOptions();
+ PipelineOptions options = serializedOptions.get();
FlinkSideInputReader sideInputReader =
new FlinkSideInputReader(sideInputs, getRuntimeContext());
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java
index db12a49..49e821c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java
@@ -18,7 +18,7 @@
package org.apache.beam.runners.flink.translation.functions;
import java.util.Map;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -46,7 +46,7 @@
protected final WindowingStrategy<Object, W> windowingStrategy;
- protected final SerializedPipelineOptions serializedOptions;
+ protected final SerializablePipelineOptions serializedOptions;
protected final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
@@ -59,7 +59,7 @@
this.combineFn = combineFn;
this.windowingStrategy = windowingStrategy;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
}
@@ -68,7 +68,7 @@
Iterable<WindowedValue<KV<K, InputT>>> elements,
Collector<WindowedValue<KV<K, AccumT>>> out) throws Exception {
- PipelineOptions options = serializedOptions.getPipelineOptions();
+ PipelineOptions options = serializedOptions.get();
FlinkSideInputReader sideInputReader =
new FlinkSideInputReader(sideInputs, getRuntimeContext());
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java
index 53d71d8..6645b3a 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java
@@ -18,7 +18,7 @@
package org.apache.beam.runners.flink.translation.functions;
import java.util.Map;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -48,7 +48,7 @@
protected final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
- protected final SerializedPipelineOptions serializedOptions;
+ protected final SerializablePipelineOptions serializedOptions;
public FlinkReduceFunction(
CombineFnBase.GlobalCombineFn<?, AccumT, OutputT> combineFn,
@@ -61,7 +61,7 @@
this.windowingStrategy = windowingStrategy;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
}
@@ -70,7 +70,7 @@
Iterable<WindowedValue<KV<K, AccumT>>> elements,
Collector<WindowedValue<KV<K, OutputT>>> out) throws Exception {
- PipelineOptions options = serializedOptions.getPipelineOptions();
+ PipelineOptions options = serializedOptions.get();
FlinkSideInputReader sideInputReader =
new FlinkSideInputReader(sideInputs, getRuntimeContext());
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
index 11d4fee4..412269c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
@@ -31,9 +31,9 @@
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
@@ -61,7 +61,7 @@
private String stepName;
private final WindowingStrategy<?, ?> windowingStrategy;
private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
private final Map<TupleTag<?>, Integer> outputMap;
private final TupleTag<OutputT> mainOutputTag;
private transient DoFnInvoker doFnInvoker;
@@ -79,7 +79,7 @@
this.stepName = stepName;
this.windowingStrategy = windowingStrategy;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
this.outputMap = outputMap;
this.mainOutputTag = mainOutputTag;
}
@@ -118,7 +118,7 @@
List<TupleTag<?>> additionalOutputTags = Lists.newArrayList(outputMap.keySet());
DoFnRunner<KV<K, V>, OutputT> doFnRunner = DoFnRunners.simpleRunner(
- serializedOptions.getPipelineOptions(), dofn,
+ serializedOptions.get(), dofn,
new FlinkSideInputReader(sideInputs, runtimeContext),
outputManager,
mainOutputTag,
@@ -135,7 +135,7 @@
},
windowingStrategy);
- if ((serializedOptions.getPipelineOptions().as(FlinkPipelineOptions.class))
+ if ((serializedOptions.get().as(FlinkPipelineOptions.class))
.getEnableMetrics()) {
doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, getRuntimeContext());
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/SerializedPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/SerializedPipelineOptions.java
deleted file mode 100644
index 40b6dd6..0000000
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/SerializedPipelineOptions.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.flink.translation.utils;
-
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.Serializable;
-import org.apache.beam.sdk.io.FileSystems;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.common.ReflectHelpers;
-
-/**
- * Encapsulates the PipelineOptions in serialized form to ship them to the cluster.
- */
-public class SerializedPipelineOptions implements Serializable {
-
- private final byte[] serializedOptions;
-
- /** Lazily initialized copy of deserialized options. */
- private transient PipelineOptions pipelineOptions;
-
- public SerializedPipelineOptions(PipelineOptions options) {
- checkNotNull(options, "PipelineOptions must not be null.");
-
- try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
- createMapper().writeValue(baos, options);
- this.serializedOptions = baos.toByteArray();
- } catch (Exception e) {
- throw new RuntimeException("Couldn't serialize PipelineOptions.", e);
- }
-
- }
-
- public PipelineOptions getPipelineOptions() {
- if (pipelineOptions == null) {
- try {
- pipelineOptions = createMapper().readValue(serializedOptions, PipelineOptions.class);
-
- FileSystems.setDefaultPipelineOptions(pipelineOptions);
- } catch (IOException e) {
- throw new RuntimeException("Couldn't deserialize the PipelineOptions.", e);
- }
- }
-
- return pipelineOptions;
- }
-
- /**
- * Use an {@link ObjectMapper} configured with any {@link Module}s in the class path allowing
- * for user specified configuration injection into the ObjectMapper. This supports user custom
- * types on {@link PipelineOptions}.
- */
- private static ObjectMapper createMapper() {
- return new ObjectMapper().registerModules(
- ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
- }
-}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
index 27e6912..3f9d601 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
@@ -19,9 +19,9 @@
import java.io.IOException;
import java.util.List;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
import org.apache.beam.runners.flink.metrics.ReaderInvocationUtil;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -50,7 +50,7 @@
private final BoundedSource<T> initialSource;
private transient PipelineOptions options;
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
private transient BoundedSource.BoundedReader<T> reader;
private boolean inputAvailable = false;
@@ -61,12 +61,12 @@
String stepName, BoundedSource<T> initialSource, PipelineOptions options) {
this.stepName = stepName;
this.initialSource = initialSource;
- this.serializedOptions = new SerializedPipelineOptions(options);
+ this.serializedOptions = new SerializablePipelineOptions(options);
}
@Override
public void configure(Configuration configuration) {
- options = serializedOptions.getPipelineOptions();
+ options = serializedOptions.get();
}
@Override
@@ -76,7 +76,7 @@
readerInvoker =
new ReaderInvocationUtil<>(
stepName,
- serializedOptions.getPipelineOptions(),
+ serializedOptions.get(),
metricContainer);
reader = ((BoundedSource<T>) sourceInputSplit.getSource()).createReader(options);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index 7995ea8..62de423 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -47,10 +47,10 @@
import org.apache.beam.runners.core.StatefulDoFnRunner;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.TimerInternals.TimerData;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals;
@@ -106,7 +106,7 @@
protected DoFn<InputT, OutputT> doFn;
- protected final SerializedPipelineOptions serializedOptions;
+ protected final SerializablePipelineOptions serializedOptions;
protected final TupleTag<OutputT> mainOutputTag;
protected final List<TupleTag<?>> additionalOutputTags;
@@ -174,7 +174,7 @@
this.additionalOutputTags = additionalOutputTags;
this.sideInputTagMapping = sideInputTagMapping;
this.sideInputs = sideInputs;
- this.serializedOptions = new SerializedPipelineOptions(options);
+ this.serializedOptions = new SerializablePipelineOptions(options);
this.windowingStrategy = windowingStrategy;
this.outputManagerFactory = outputManagerFactory;
@@ -256,7 +256,7 @@
org.apache.beam.runners.core.StepContext stepContext = createStepContext();
doFnRunner = DoFnRunners.simpleRunner(
- serializedOptions.getPipelineOptions(),
+ serializedOptions.get(),
doFn,
sideInputReader,
outputManager,
@@ -301,7 +301,7 @@
stateCleaner);
}
- if ((serializedOptions.getPipelineOptions().as(FlinkPipelineOptions.class))
+ if ((serializedOptions.get().as(FlinkPipelineOptions.class))
.getEnableMetrics()) {
doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, getRuntimeContext());
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
index 2f095d4..be758a6 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java
@@ -115,7 +115,7 @@
((ProcessFn) doFn).setProcessElementInvoker(
new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
doFn,
- serializedOptions.getPipelineOptions(),
+ serializedOptions.get(),
new OutputWindowedValue<OutputT>() {
@Override
public void outputWindowedValue(
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
index 6d75688..5ddc46f 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
@@ -20,9 +20,9 @@
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
import org.apache.beam.runners.flink.metrics.ReaderInvocationUtil;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -48,7 +48,7 @@
/**
* Keep the options so that we can initialize the readers.
*/
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
/**
* The split sources. We split them in the constructor to ensure that all parallel
@@ -74,7 +74,7 @@
BoundedSource<OutputT> source,
int parallelism) throws Exception {
this.stepName = stepName;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
long desiredBundleSize = source.getEstimatedSizeBytes(pipelineOptions) / parallelism;
@@ -109,13 +109,13 @@
ReaderInvocationUtil<OutputT, BoundedSource.BoundedReader<OutputT>> readerInvoker =
new ReaderInvocationUtil<>(
stepName,
- serializedOptions.getPipelineOptions(),
+ serializedOptions.get(),
metricContainer);
readers = new ArrayList<>();
// initialize readers from scratch
for (BoundedSource<OutputT> source : localSources) {
- readers.add(source.createReader(serializedOptions.getPipelineOptions()));
+ readers.add(source.createReader(serializedOptions.get()));
}
if (readers.size() == 1) {
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
index 910a33f..49e4ddc 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
@@ -123,7 +123,7 @@
}
@Override
- public Coder getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return DEFAULT_SOCKET_CODER;
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
index e75072a..817dd74 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
@@ -22,10 +22,10 @@
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
import org.apache.beam.runners.flink.metrics.ReaderInvocationUtil;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
@@ -72,7 +72,7 @@
/**
* Keep the options so that we can initialize the localReaders.
*/
- private final SerializedPipelineOptions serializedOptions;
+ private final SerializablePipelineOptions serializedOptions;
/**
* For snapshot and restore.
@@ -141,7 +141,7 @@
UnboundedSource<OutputT, CheckpointMarkT> source,
int parallelism) throws Exception {
this.stepName = stepName;
- this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+ this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
if (source.requiresDeduping()) {
LOG.warn("Source {} requires deduping but Flink runner doesn't support this yet.", source);
@@ -189,7 +189,7 @@
stateForCheckpoint.get()) {
localSplitSources.add(restored.getKey());
localReaders.add(restored.getKey().createReader(
- serializedOptions.getPipelineOptions(), restored.getValue()));
+ serializedOptions.get(), restored.getValue()));
}
} else {
// initialize localReaders and localSources from scratch
@@ -198,7 +198,7 @@
UnboundedSource<OutputT, CheckpointMarkT> source =
splitSources.get(i);
UnboundedSource.UnboundedReader<OutputT> reader =
- source.createReader(serializedOptions.getPipelineOptions(), null);
+ source.createReader(serializedOptions.get(), null);
localSplitSources.add(source);
localReaders.add(reader);
}
@@ -221,7 +221,7 @@
ReaderInvocationUtil<OutputT, UnboundedSource.UnboundedReader<OutputT>> readerInvoker =
new ReaderInvocationUtil<>(
stepName,
- serializedOptions.getPipelineOptions(),
+ serializedOptions.get(),
metricContainer);
if (localReaders.size() == 0) {
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
index d0281ec..eb06026 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
@@ -17,32 +17,8 @@
*/
package org.apache.beam.runners.flink;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonDeserializer;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.SerializerProvider;
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.fasterxml.jackson.databind.module.SimpleModule;
-import com.google.auto.service.AutoService;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
import java.util.Collections;
import java.util.HashMap;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.Default;
@@ -60,12 +36,10 @@
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.joda.time.Instant;
import org.junit.Assert;
-import org.junit.BeforeClass;
import org.junit.Test;
/**
@@ -73,9 +47,7 @@
*/
public class PipelineOptionsTest {
- /**
- * Pipeline options.
- */
+ /** Pipeline options. */
public interface MyOptions extends FlinkPipelineOptions {
@Description("Bla bla bla")
@Default.String("Hello")
@@ -83,60 +55,12 @@
void setTestOption(String value);
}
- private static MyOptions options;
- private static SerializedPipelineOptions serializedOptions;
-
- private static final String[] args = new String[]{"--testOption=nothing"};
-
- @BeforeClass
- public static void beforeTest() {
- options = PipelineOptionsFactory.fromArgs(args).as(MyOptions.class);
- serializedOptions = new SerializedPipelineOptions(options);
- }
-
- @Test
- public void testDeserialization() {
- MyOptions deserializedOptions = serializedOptions.getPipelineOptions().as(MyOptions.class);
- assertEquals("nothing", deserializedOptions.getTestOption());
- }
-
- @Test
- public void testIgnoredFieldSerialization() {
- FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class);
- options.setStateBackend(new MemoryStateBackend());
-
- FlinkPipelineOptions deserialized =
- new SerializedPipelineOptions(options).getPipelineOptions().as(FlinkPipelineOptions.class);
-
- assertNull(deserialized.getStateBackend());
- }
-
- @Test
- public void testEnableMetrics() {
- FlinkPipelineOptions options = PipelineOptionsFactory.as(FlinkPipelineOptions.class);
- options.setEnableMetrics(false);
- assertFalse(options.getEnableMetrics());
- }
-
- @Test
- public void testCaching() {
- PipelineOptions deserializedOptions =
- serializedOptions.getPipelineOptions().as(PipelineOptions.class);
-
- assertNotNull(deserializedOptions);
- assertTrue(deserializedOptions == serializedOptions.getPipelineOptions());
- assertTrue(deserializedOptions == serializedOptions.getPipelineOptions());
- assertTrue(deserializedOptions == serializedOptions.getPipelineOptions());
- }
-
- @Test(expected = Exception.class)
- public void testNonNull() {
- new SerializedPipelineOptions(null);
- }
+ private static MyOptions options =
+ PipelineOptionsFactory.fromArgs("--testOption=nothing").as(MyOptions.class);
@Test(expected = Exception.class)
public void parDoBaseClassPipelineOptionsNullTest() {
- DoFnOperator<String, String> doFnOperator = new DoFnOperator<>(
+ new DoFnOperator<>(
new TestDoFn(),
"stepName",
WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()),
@@ -196,18 +120,7 @@
}
- @Test
- public void testExternalizedCheckpointsConfigs() {
- String[] args = new String[] { "--externalizedCheckpointsEnabled=true",
- "--retainExternalizedCheckpointsOnCancellation=false" };
- final FlinkPipelineOptions options = PipelineOptionsFactory.fromArgs(args)
- .as(FlinkPipelineOptions.class);
- assertEquals(options.isExternalizedCheckpointsEnabled(), true);
- assertEquals(options.getRetainExternalizedCheckpointsOnCancellation(), false);
- }
-
private static class TestDoFn extends DoFn<String, String> {
-
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
Assert.assertNotNull(c.getPipelineOptions());
@@ -216,74 +129,4 @@
c.getPipelineOptions().as(MyOptions.class).getTestOption());
}
}
-
- /** PipelineOptions used to test auto registration of Jackson modules. */
- public interface JacksonIncompatibleOptions extends PipelineOptions {
- JacksonIncompatible getJacksonIncompatible();
- void setJacksonIncompatible(JacksonIncompatible value);
- }
-
- /** A Jackson {@link Module} to test auto-registration of modules. */
- @AutoService(Module.class)
- public static class RegisteredTestModule extends SimpleModule {
- public RegisteredTestModule() {
- super("RegisteredTestModule");
- setMixInAnnotation(JacksonIncompatible.class, JacksonIncompatibleMixin.class);
- }
- }
-
- /** A class which Jackson does not know how to serialize/deserialize. */
- public static class JacksonIncompatible {
- private final String value;
- public JacksonIncompatible(String value) {
- this.value = value;
- }
- }
-
- /** A Jackson mixin used to add annotations to other classes. */
- @JsonDeserialize(using = JacksonIncompatibleDeserializer.class)
- @JsonSerialize(using = JacksonIncompatibleSerializer.class)
- public static final class JacksonIncompatibleMixin {}
-
- /** A Jackson deserializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleDeserializer extends
- JsonDeserializer<JacksonIncompatible> {
-
- @Override
- public JacksonIncompatible deserialize(JsonParser jsonParser,
- DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
- return new JacksonIncompatible(jsonParser.readValueAs(String.class));
- }
- }
-
- /** A Jackson serializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleSerializer extends JsonSerializer<JacksonIncompatible> {
-
- @Override
- public void serialize(JacksonIncompatible jacksonIncompatible, JsonGenerator jsonGenerator,
- SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
- jsonGenerator.writeString(jacksonIncompatible.value);
- }
- }
-
- @Test
- public void testSerializingPipelineOptionsWithCustomUserType() throws Exception {
- String expectedValue = "testValue";
- PipelineOptions options = PipelineOptionsFactory
- .fromArgs("--jacksonIncompatible=\"" + expectedValue + "\"")
- .as(JacksonIncompatibleOptions.class);
- SerializedPipelineOptions context = new SerializedPipelineOptions(options);
-
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- try (ObjectOutputStream outputStream = new ObjectOutputStream(baos)) {
- outputStream.writeObject(context);
- }
- try (ObjectInputStream inputStream =
- new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
- SerializedPipelineOptions copy = (SerializedPipelineOptions) inputStream.readObject();
- assertEquals(expectedValue,
- copy.getPipelineOptions().as(JacksonIncompatibleOptions.class)
- .getJacksonIncompatible().value);
- }
- }
}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
index edf548a..fcb9282 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
@@ -238,7 +238,7 @@
public void validate() {}
@Override
- public Coder<KV<Integer, Integer>> getDefaultOutputCoder() {
+ public Coder<KV<Integer, Integer>> getOutputCoder() {
return KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
}
diff --git a/runners/gearpump/README.md b/runners/gearpump/README.md
new file mode 100644
index 0000000..e8ce794
--- /dev/null
+++ b/runners/gearpump/README.md
@@ -0,0 +1,61 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+## Gearpump Beam Runner
+
+The Gearpump Beam runner allows users to execute pipelines written using the Apache Beam programming API with Apache Gearpump (incubating) as an execution engine.
+
+## Getting Started
+
+The following shows how to run the WordCount example that ships with the Beam source code.
+
+### Installing Beam
+
+To get the latest version of Beam with the Gearpump runner, first clone the Beam repository, switch to the newly created directory, and check out the runner branch:
+
+```
+git clone https://github.com/apache/beam
+cd beam
+git checkout gearpump-runner
+```
+
+Then run Maven from the repository root to build Apache Beam:
+
+```
+mvn clean install -DskipTests
+```
+
+Now Apache Beam and the Gearpump Runner are installed in your local Maven repository.
+
+### Running the WordCount Example
+
+Download something to count:
+
+```
+curl http://www.gutenberg.org/cache/epub/1128/pg1128.txt > /tmp/kinglear.txt
+```
+
+Run the pipeline, using the Gearpump runner:
+
+```
+cd examples/java
+mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount -Dexec.args="--inputFile=/tmp/kinglear.txt --output=/tmp/wordcounts.txt --runner=TestGearpumpRunner" -Pgearpump-runner
+```
+
+Once the pipeline completes, check the output file `/tmp/wordcounts.txt-00000-of-00001`.
diff --git a/runners/gearpump/pom.xml b/runners/gearpump/pom.xml
new file mode 100644
index 0000000..3a4722f
--- /dev/null
+++ b/runners/gearpump/pom.xml
@@ -0,0 +1,280 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"
+ xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-parent</artifactId>
+ <version>2.2.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>beam-runners-gearpump</artifactId>
+
+ <name>Apache Beam :: Runners :: Gearpump</name>
+ <packaging>jar</packaging>
+
+ <repositories>
+ <repository>
+ <id>apache-repo</id>
+ <name>apache maven repo</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ </repository>
+ </repositories>
+
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
+ <gearpump.version>0.8.4</gearpump.version>
+ </properties>
+
+ <profiles>
+ <profile>
+ <id>local-validates-runner-tests</id>
+ <activation><activeByDefault>false</activeByDefault></activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>validates-runner-tests</id>
+ <phase>integration-test</phase>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ <configuration>
+ <groups>org.apache.beam.sdk.testing.ValidatesRunner</groups>
+ <excludedGroups>
+ org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders,
+ org.apache.beam.sdk.testing.UsesStatefulParDo,
+ org.apache.beam.sdk.testing.UsesTimersInParDo,
+ org.apache.beam.sdk.testing.UsesSplittableParDo,
+ org.apache.beam.sdk.testing.UsesAttemptedMetrics,
+ org.apache.beam.sdk.testing.UsesCommittedMetrics,
+ org.apache.beam.sdk.testing.UsesTestStream
+ </excludedGroups>
+ <parallel>none</parallel>
+ <failIfNoTests>true</failIfNoTests>
+ <dependenciesToScan>
+ <dependency>org.apache.beam:beam-sdks-java-core</dependency>
+ </dependenciesToScan>
+ <systemPropertyVariables>
+ <beamTestPipelineOptions>
+ [
+ "--runner=TestGearpumpRunner",
+ "--streaming=true"
+ ]
+ </beamTestPipelineOptions>
+ </systemPropertyVariables>
+ <threadCount>4</threadCount>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.gearpump</groupId>
+ <artifactId>gearpump-streaming_2.11</artifactId>
+ <version>${gearpump.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.gearpump</groupId>
+ <artifactId>gearpump-core_2.11</artifactId>
+ <version>${gearpump.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.typesafe</groupId>
+ <artifactId>config</artifactId>
+ <version>1.3.0</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scala-library</artifactId>
+ <version>2.11.8</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-sdks-java-core</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jdk14</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.google.collections</groupId>
+ <artifactId>google-collections</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-core-java</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-core-construction-java</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jdk14</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>joda-time</groupId>
+ <artifactId>joda-time</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.google.code.findbugs</groupId>
+ <artifactId>jsr305</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-sdks-java-core</artifactId>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jdk14</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.dataformat</groupId>
+ <artifactId>jackson-dataformat-yaml</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.auto.service</groupId>
+ <artifactId>auto-service</artifactId>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <plugins>
+ <!-- JAR Packaging -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <archive>
+ <manifest>
+ <addDefaultImplementationEntries>true</addDefaultImplementationEntries>
+ <addDefaultSpecificationEntries>true</addDefaultSpecificationEntries>
+ </manifest>
+ </archive>
+ </configuration>
+ </plugin>
+
+ <!-- Java compiler -->
+ <plugin>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <source>1.8</source>
+ <target>1.8</target>
+ <testSource>1.8</testSource>
+ <testTarget>1.8</testTarget>
+ </configuration>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-enforcer-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>enforce</id>
+ <goals>
+ <goal>enforce</goal>
+ </goals>
+ <configuration>
+ <rules>
+ <enforceBytecodeVersion>
+ <maxJdkVersion>1.8</maxJdkVersion>
+ </enforceBytecodeVersion>
+ <requireJavaVersion>
+ <version>[1.8,)</version>
+ </requireJavaVersion>
+ </rules>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+
+ <!-- uber jar -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <configuration>
+ <descriptorRefs>
+ <descriptorRef>jar-with-dependencies</descriptorRef>
+ </descriptorRefs>
+ </configuration>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-checkstyle-plugin</artifactId>
+ </plugin>
+
+ </plugins>
+ </build>
+</project>
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java
new file mode 100644
index 0000000..e02cbbc
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineOptions.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.Map;
+
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+import org.apache.gearpump.cluster.client.ClientContext;
+import org.apache.gearpump.cluster.embedded.EmbeddedCluster;
+
+/**
+ * Options that configure the Gearpump pipeline.
+ *
+ * <p>Property annotations ({@link Description}, {@link Default}, {@link JsonIgnore}) are placed on
+ * the getters, which is where Beam's options machinery reads them (e.g. for {@code --help}).
+ */
+public interface GearpumpPipelineOptions extends PipelineOptions {
+
+  @Description("set unique application name for Gearpump runner")
+  String getApplicationName();
+  void setApplicationName(String name);
+
+  @Description("set parallelism for Gearpump processor")
+  @Default.Integer(1)
+  int getParallelism();
+  void setParallelism(int parallelism);
+
+  /** Kryo serializer registrations; excluded from JSON-serialized options. */
+  @Description("register Kryo serializers")
+  @JsonIgnore
+  Map<String, String> getSerializers();
+  void setSerializers(Map<String, String> serializers);
+
+  /** Embedded cluster used by tests; excluded from JSON-serialized options. */
+  @Description("set EmbeddedCluster for tests")
+  @JsonIgnore
+  EmbeddedCluster getEmbeddedCluster();
+  void setEmbeddedCluster(EmbeddedCluster cluster);
+
+  /** Client handle for querying application status; excluded from JSON-serialized options. */
+  @Description("get client context to query application status")
+  @JsonIgnore
+  ClientContext getClientContext();
+  void setClientContext(ClientContext clientContext);
+}
+
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineResult.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineResult.java
new file mode 100644
index 0000000..dd7fa23
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpPipelineResult.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.gearpump;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.metrics.MetricResults;
+
+import org.apache.gearpump.cluster.ApplicationStatus;
+import org.apache.gearpump.cluster.MasterToAppMaster.AppMasterData;
+import org.apache.gearpump.cluster.client.ClientContext;
+import org.apache.gearpump.cluster.client.RunningApplication;
+import org.joda.time.Duration;
+
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+/**
+ * Result of executing a {@link Pipeline} with Gearpump.
+ *
+ * <p>Once the application is cancelled through {@link #cancel()} or waited on through
+ * {@link #waitUntilFinish()}, the resulting terminal state is remembered and reported by
+ * {@link #getState()} without querying the cluster again.
+ */
+public class GearpumpPipelineResult implements PipelineResult {
+
+  private final ClientContext client;
+  private final RunningApplication app;
+
+  /**
+   * Terminal state reached through {@link #cancel()} or {@link #waitUntilFinish()}, or
+   * {@code null} while neither has completed.
+   */
+  private State terminalState = null;
+
+  public GearpumpPipelineResult(ClientContext client, RunningApplication app) {
+    this.client = client;
+    this.app = app;
+  }
+
+  /**
+   * Returns the remembered terminal state if there is one, otherwise the status reported by
+   * the Gearpump cluster for this application.
+   */
+  @Override
+  public State getState() {
+    return terminalState != null ? terminalState : getGearpumpState();
+  }
+
+  /**
+   * Shuts the application down unless a terminal state was already reached.
+   *
+   * @return {@link State#CANCELLED} if this call (or an earlier cancel) shut the application
+   *     down, or {@link State#DONE} if it had already been waited on to completion
+   */
+  @Override
+  public State cancel() throws IOException {
+    if (terminalState == null) {
+      app.shutDown();
+      terminalState = State.CANCELLED;
+    }
+    return terminalState;
+  }
+
+  /**
+   * Waits for the application to finish.
+   *
+   * <p>NOTE: {@code duration} is currently ignored, so this blocks exactly like
+   * {@link #waitUntilFinish()}.
+   */
+  @Override
+  public State waitUntilFinish(Duration duration) {
+    return waitUntilFinish();
+  }
+
+  /**
+   * Blocks until the application finishes.
+   *
+   * @return {@link State#CANCELLED} if the application was cancelled through this result,
+   *     otherwise {@link State#DONE}
+   */
+  @Override
+  public State waitUntilFinish() {
+    if (terminalState == null) {
+      app.waitUntilFinish();
+      terminalState = State.DONE;
+    }
+    return terminalState;
+  }
+
+  /**
+   * Metrics are not supported by this runner yet.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public MetricResults metrics() {
+    return null;
+  }
+
+  /**
+   * Looks up this application in the cluster's application list and maps its status: missing
+   * or {@code NONEXIST} to {@link State#UNKNOWN}, {@code ACTIVE} to {@link State#RUNNING},
+   * {@code SUCCEEDED} to {@link State#DONE}, and anything else to {@link State#FAILED}.
+   */
+  private State getGearpumpState() {
+    ApplicationStatus status = null;
+    List<AppMasterData> apps =
+        JavaConverters.<AppMasterData>seqAsJavaListConverter(
+            (Seq<AppMasterData>) client.listApps().appMasters()).asJava();
+    for (AppMasterData appData : apps) {
+      if (appData.appId() == app.appId()) {
+        status = appData.status();
+        break;
+      }
+    }
+    if (null == status || status instanceof ApplicationStatus.NONEXIST$) {
+      return State.UNKNOWN;
+    } else if (status instanceof ApplicationStatus.ACTIVE$) {
+      return State.RUNNING;
+    } else if (status instanceof ApplicationStatus.SUCCEEDED$) {
+      return State.DONE;
+    } else {
+      return State.FAILED;
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java
new file mode 100644
index 0000000..5febf3c
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunner.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.gearpump;
+
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigValueFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.beam.runners.gearpump.translators.GearpumpPipelineTranslator;
+import org.apache.beam.runners.gearpump.translators.TranslationContext;
+
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsValidator;
+
+import org.apache.gearpump.cluster.ClusterConfig;
+import org.apache.gearpump.cluster.UserConfig;
+import org.apache.gearpump.cluster.client.ClientContext;
+import org.apache.gearpump.cluster.client.RunningApplication;
+import org.apache.gearpump.cluster.embedded.EmbeddedCluster;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStreamApp;
+
+/**
+ * A {@link PipelineRunner} that translates a Beam {@link Pipeline} into the Gearpump Stream
+ * DSL and submits it as an application, either to the {@link EmbeddedCluster} found in the
+ * options or to a cluster reached through the default cluster configuration.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class GearpumpRunner extends PipelineRunner<GearpumpPipelineResult> {
+
+  private static final String GEARPUMP_SERIALIZERS = "gearpump.serializers";
+  private static final String DEFAULT_APPNAME = "beam_gearpump_app";
+
+  /** Beam runtime classes registered with Gearpump's default Kryo serializer on every run. */
+  private static final String[] DEFAULT_SERIALIZED_CLASSES = {
+    "org.apache.beam.sdk.util.WindowedValue$ValueInGlobalWindow",
+    "org.apache.beam.sdk.util.WindowedValue$TimestampedValueInSingleWindow",
+    "org.apache.beam.sdk.util.WindowedValue$TimestampedValueInGlobalWindow",
+    "org.apache.beam.sdk.util.WindowedValue$TimestampedValueInMultipleWindows",
+    "org.apache.beam.sdk.transforms.windowing.PaneInfo",
+    "org.apache.beam.sdk.transforms.windowing.PaneInfo$Timing",
+    "org.joda.time.Instant",
+    "org.apache.beam.sdk.values.KV",
+    "org.apache.beam.sdk.transforms.windowing.IntervalWindow",
+    "org.apache.beam.sdk.values.TimestampedValue",
+    "org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils$RawUnionValue"
+  };
+
+  private final GearpumpPipelineOptions options;
+
+  public GearpumpRunner(GearpumpPipelineOptions options) {
+    this.options = options;
+  }
+
+  /** Validates {@code options} as {@link GearpumpPipelineOptions} and wraps them in a runner. */
+  public static GearpumpRunner fromOptions(PipelineOptions options) {
+    return new GearpumpRunner(
+        PipelineOptionsValidator.validate(GearpumpPipelineOptions.class, options));
+  }
+
+  /**
+   * Translates {@code pipeline} into a Gearpump stream application and submits it. Returns as
+   * soon as the application is submitted; use the result to wait for or cancel it.
+   */
+  @Override
+  public GearpumpPipelineResult run(Pipeline pipeline) {
+    String configuredName = options.getApplicationName();
+    String appName = configuredName == null ? DEFAULT_APPNAME : configuredName;
+
+    Config config = registerSerializers(ClusterConfig.defaultConfig(), options.getSerializers());
+    ClientContext clientContext = getClientContext(options, config);
+    options.setClientContext(clientContext);
+
+    JavaStreamApp streamApp = new JavaStreamApp(appName, clientContext, UserConfig.empty());
+    TranslationContext translationContext = new TranslationContext(streamApp, options);
+    new GearpumpPipelineTranslator(translationContext).translate(pipeline);
+    RunningApplication submitted = streamApp.submit();
+
+    return new GearpumpPipelineResult(clientContext, submitted);
+  }
+
+  /**
+   * Returns a client for the embedded cluster set in {@code options}, or one built from
+   * {@code config} when no embedded cluster is set.
+   */
+  private ClientContext getClientContext(GearpumpPipelineOptions options, Config config) {
+    EmbeddedCluster embedded = options.getEmbeddedCluster();
+    return embedded == null ? ClientContext.apply(config) : embedded.newClientContext();
+  }
+
+  /**
+   * Registers the Beam runtime classes with Gearpump's default Kryo serializers, then applies
+   * {@code userSerializers} on top so user entries win on key collisions.
+   */
+  private Config registerSerializers(Config config, Map<String, String> userSerializers) {
+    Map<String, String> serializers = new HashMap<>();
+    for (String className : DEFAULT_SERIALIZED_CLASSES) {
+      // An empty serializer name means "use the default Kryo serializer".
+      serializers.put(className, "");
+    }
+    if (userSerializers != null) {
+      serializers.putAll(userSerializers);
+    }
+    return config.withValue(GEARPUMP_SERIALIZERS, ConfigValueFactory.fromMap(serializers));
+  }
+
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrar.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrar.java
new file mode 100644
index 0000000..5152105
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrar.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump;
+
+import com.google.auto.service.AutoService;
+import com.google.common.collect.ImmutableList;
+
+import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsRegistrar;
+import org.apache.beam.sdk.runners.PipelineRunnerRegistrar;
+
+/**
+ * Holder for the {@link PipelineRunnerRegistrar} and {@link PipelineOptionsRegistrar} of the
+ * {@link GearpumpRunner}.
+ *
+ * <p>Both nested classes carry {@link AutoService}, which registers them as service
+ * implementations so the Gearpump runners and options can be discovered at runtime.
+ */
+public class GearpumpRunnerRegistrar {
+
+  /** Not instantiable; only the nested registrars are used. */
+  private GearpumpRunnerRegistrar() {}
+
+  /** Registers {@link GearpumpRunner} and the embedded-cluster {@link TestGearpumpRunner}. */
+  @AutoService(PipelineRunnerRegistrar.class)
+  public static class Runner implements PipelineRunnerRegistrar {
+
+    @Override
+    public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
+      return ImmutableList.<Class<? extends PipelineRunner<?>>>builder()
+          .add(GearpumpRunner.class)
+          .add(TestGearpumpRunner.class)
+          .build();
+    }
+  }
+
+  /** Registers {@link GearpumpPipelineOptions}. */
+  @AutoService(PipelineOptionsRegistrar.class)
+  public static class Options implements PipelineOptionsRegistrar {
+
+    @Override
+    public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() {
+      return ImmutableList.<Class<? extends PipelineOptions>>builder()
+          .add(GearpumpPipelineOptions.class)
+          .build();
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java
new file mode 100644
index 0000000..0a88849
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/TestGearpumpRunner.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump;
+
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigValueFactory;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsValidator;
+
+import org.apache.gearpump.cluster.ClusterConfig;
+import org.apache.gearpump.cluster.embedded.EmbeddedCluster;
+import org.apache.gearpump.util.Constants;
+
+/**
+ * Gearpump {@link PipelineRunner} for tests, which uses {@link EmbeddedCluster}.
+ *
+ * <p>Each instance starts its own embedded cluster when it is constructed and stops it when
+ * {@link #run(Pipeline)} returns or throws, so an instance can run a single pipeline.
+ */
+public class TestGearpumpRunner extends PipelineRunner<GearpumpPipelineResult> {
+
+  private final GearpumpRunner delegate;
+  private final EmbeddedCluster cluster;
+
+  private TestGearpumpRunner(GearpumpPipelineOptions options) {
+    Config config = ClusterConfig.master(null);
+    // Zero total application retries, presumably so a failing pipeline is not restarted.
+    config = config.withValue(Constants.APPLICATION_TOTAL_RETRIES(),
+        ConfigValueFactory.fromAnyRef(0));
+    cluster = new EmbeddedCluster(config);
+    cluster.start();
+    options.setEmbeddedCluster(cluster);
+    delegate = GearpumpRunner.fromOptions(options);
+  }
+
+  /** Validates {@code options} and creates a runner backed by a freshly started cluster. */
+  public static TestGearpumpRunner fromOptions(PipelineOptions options) {
+    GearpumpPipelineOptions pipelineOptions =
+        PipelineOptionsValidator.validate(GearpumpPipelineOptions.class, options);
+    return new TestGearpumpRunner(pipelineOptions);
+  }
+
+  /**
+   * Runs {@code pipeline} on the embedded cluster and blocks until it finishes.
+   *
+   * <p>The cluster is stopped in a {@code finally} block so that a failure during submission or
+   * while waiting does not leave the embedded cluster running.
+   */
+  @Override
+  public GearpumpPipelineResult run(Pipeline pipeline) {
+    try {
+      GearpumpPipelineResult result = delegate.run(pipeline);
+      result.waitUntilFinish();
+      return result;
+    } finally {
+      cluster.stop();
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/package-info.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/package-info.java
new file mode 100644
index 0000000..5013616
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Internal implementation of the Beam runner for Apache Gearpump ({@link GearpumpRunner}).
+ */
+package org.apache.beam.runners.gearpump;
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java
new file mode 100644
index 0000000..559cb28
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslator.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import java.util.List;
+
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+
+/**
+ * Translates {@link CreateStreamingGearpumpView.CreateGearpumpPCollectionView} without adding
+ * a Gearpump operator: the stream of the transform's input, which carries the concatenated
+ * view contents as {@code List<ElemT>} values, is registered as the stream of its view.
+ */
+public class CreateGearpumpPCollectionViewTranslator<ElemT, ViewT> implements
+    TransformTranslator<CreateStreamingGearpumpView.CreateGearpumpPCollectionView<ElemT, ViewT>> {
+
+  private static final long serialVersionUID = -3955521308055056034L;
+
+  @Override
+  public void translate(
+      CreateStreamingGearpumpView.CreateGearpumpPCollectionView<ElemT, ViewT> transform,
+      TranslationContext context) {
+    PCollectionView<ViewT> view = transform.getView();
+    JavaStream<WindowedValue<List<ElemT>>> viewContents =
+        context.getInputStream(context.getInput());
+    context.setOutputStream(view, viewContents);
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateStreamingGearpumpView.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateStreamingGearpumpView.java
new file mode 100644
index 0000000..3ebe5c8
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/CreateStreamingGearpumpView.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.gearpump.translators;
+
+import com.google.common.collect.Iterables;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.beam.runners.core.construction.ReplacementOutputs;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ * Gearpump streaming override for {@link CreatePCollectionView} (side input) transforms.
+ *
+ * <p>The input is concatenated into a {@code List<ElemT>} with {@link Combine#globally}
+ * (without defaults) and handed to the primitive {@link CreateGearpumpPCollectionView}, which
+ * backs the side-input view. The input collection itself is returned unchanged.
+ */
+class CreateStreamingGearpumpView<ElemT, ViewT>
+    extends PTransform<PCollection<ElemT>, PCollection<ElemT>> {
+  private final PCollectionView<ViewT> view;
+
+  public CreateStreamingGearpumpView(PCollectionView<ViewT> view) {
+    this.view = view;
+  }
+
+  @Override
+  public PCollection<ElemT> expand(PCollection<ElemT> input) {
+    input
+        .apply(Combine.globally(new Concatenate<ElemT>()).withoutDefaults())
+        .apply(CreateGearpumpPCollectionView.<ElemT, ViewT>of(view));
+    // The view is built on a side branch; downstream main-input consumers see the input as is.
+    return input;
+  }
+
+  /**
+   * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
+   *
+   * <p>For internal use by {@link CreateStreamingGearpumpView}. This combiner requires that
+   * the input {@link PCollection} fits in memory. For a large {@link PCollection} this is
+   * expected to crash!
+   *
+   * @param <T> the type of elements to concatenate.
+   */
+  private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
+    @Override
+    public List<T> createAccumulator() {
+      return new ArrayList<>();
+    }
+
+    @Override
+    public List<T> addInput(List<T> accumulator, T input) {
+      accumulator.add(input);
+      return accumulator;
+    }
+
+    @Override
+    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
+      List<T> result = createAccumulator();
+      for (List<T> accumulator : accumulators) {
+        result.addAll(accumulator);
+      }
+      return result;
+    }
+
+    @Override
+    public List<T> extractOutput(List<T> accumulator) {
+      return accumulator;
+    }
+
+    @Override
+    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
+      return ListCoder.of(inputCoder);
+    }
+
+    @Override
+    public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
+      return ListCoder.of(inputCoder);
+    }
+  }
+
+  /**
+   * Creates a primitive {@link PCollectionView}.
+   *
+   * <p>For internal use only by runner implementors. The output is a primitive collection with
+   * the input's windowing strategy, boundedness and coder.
+   *
+   * @param <ElemT> The type of the elements of the input PCollection
+   * @param <ViewT> The type associated with the {@link PCollectionView} used as a side input
+   */
+  public static class CreateGearpumpPCollectionView<ElemT, ViewT>
+      extends PTransform<PCollection<List<ElemT>>, PCollection<List<ElemT>>> {
+    private final PCollectionView<ViewT> view;
+
+    private CreateGearpumpPCollectionView(PCollectionView<ViewT> view) {
+      this.view = view;
+    }
+
+    public static <ElemT, ViewT> CreateGearpumpPCollectionView<ElemT, ViewT> of(
+        PCollectionView<ViewT> view) {
+      return new CreateGearpumpPCollectionView<>(view);
+    }
+
+    @Override
+    public PCollection<List<ElemT>> expand(PCollection<List<ElemT>> input) {
+      return PCollection.<List<ElemT>>createPrimitiveOutputInternal(
+          input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
+    }
+
+    /** Returns the view this transform materializes. */
+    public PCollectionView<ViewT> getView() {
+      return view;
+    }
+  }
+
+  /**
+   * Replaces {@link CreatePCollectionView} with a {@link CreateStreamingGearpumpView} for the
+   * same view, keeping the original single input and mapping the single output through.
+   */
+  public static class Factory<ElemT, ViewT>
+      implements PTransformOverrideFactory<
+          PCollection<ElemT>, PCollection<ElemT>, CreatePCollectionView<ElemT, ViewT>> {
+    public Factory() {}
+
+    @Override
+    @SuppressWarnings("unchecked") // CreatePCollectionView's only input is a PCollection<ElemT>
+    public PTransformReplacement<PCollection<ElemT>, PCollection<ElemT>> getReplacementTransform(
+        AppliedPTransform<
+            PCollection<ElemT>, PCollection<ElemT>, CreatePCollectionView<ElemT, ViewT>>
+            transform) {
+      return PTransformReplacement.of(
+          (PCollection<ElemT>) Iterables.getOnlyElement(transform.getInputs().values()),
+          new CreateStreamingGearpumpView<ElemT, ViewT>(transform.getTransform().getView()));
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PValue> outputs, PCollection<ElemT> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslator.java
new file mode 100644
index 0000000..8cc0058
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslator.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.beam.runners.gearpump.translators.io.UnboundedSourceWrapper;
+import org.apache.beam.runners.gearpump.translators.io.ValuesSource;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.values.PCollection;
+
+import org.apache.beam.sdk.values.PValue;
+import org.apache.gearpump.streaming.dsl.api.functions.MapFunction;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+
+/**
+ * Translates {@link Flatten.PCollections} into Gearpump stream merges.
+ *
+ * <p>The first input stream is merged with each following one. An input that appears more
+ * than once is first passed through an identity map node, because Gearpump graphs do not
+ * allow duplicate edges. Without any inputs, a single-element {@code "dummy"} source stands in
+ * for the output stream.
+ */
+public class FlattenPCollectionsTranslator<T> implements
+    TransformTranslator<Flatten.PCollections<T>> {
+
+  private static final long serialVersionUID = -5552148802472944759L;
+
+  @Override
+  public void translate(Flatten.PCollections<T> transform, TranslationContext context) {
+    Set<PCollection<T>> seen = new HashSet<>();
+    JavaStream<T> merged = null;
+    for (PValue value : context.getInputs().values()) {
+      PCollection<T> collection = (PCollection<T>) value;
+      // true when this very collection was already an input earlier in the list
+      boolean repeated = !seen.add(collection);
+      JavaStream<T> stream = context.getInputStream(collection);
+      if (merged == null) {
+        merged = stream;
+        continue;
+      }
+      if (repeated) {
+        // Duplicate edges are not allowed in a Gearpump graph, so the repeated input gets
+        // its own identity node before being merged.
+        stream = stream.map(new DummyFunction<T>(), "dummy");
+      }
+      merged = merged.merge(stream, 1, transform.getName());
+    }
+
+    if (merged == null) {
+      // NOTE(review): this source emits the String "dummy" into a PCollection<T>; confirm the
+      // element never reaches consumers of an empty flatten.
+      UnboundedSourceWrapper<String, ?> emptyStandIn = new UnboundedSourceWrapper<>(
+          new ValuesSource<>(Lists.newArrayList("dummy"), StringUtf8Coder.of()),
+          context.getPipelineOptions());
+      merged = context.getSourceStream(emptyStandIn);
+    }
+    context.setOutputStream(context.getOutput(), merged);
+  }
+
+  /** Identity {@link MapFunction}; gives a repeated input a node of its own. */
+  private static class DummyFunction<T> extends MapFunction<T, T> {
+
+    private static final long serialVersionUID = 5454396869997290471L;
+
+    @Override
+    public T map(T t) {
+      return t;
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java
new file mode 100644
index 0000000..ca98aac
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GearpumpPipelineTranslator.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.runners.PTransformOverride;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.PValue;
+
+import org.apache.gearpump.util.Graph;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Translates a Beam {@link Pipeline} into a Gearpump {@link Graph}.
+ *
+ * <p>The pipeline is first rewritten with the Gearpump overrides and then traversed
+ * topologically. Composite transforms are always entered; each primitive transform is handed
+ * to the {@link TransformTranslator} registered for its exact class.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class GearpumpPipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+
+  private static final Logger LOG = LoggerFactory.getLogger(GearpumpPipelineTranslator.class);
+
+  /** Translators keyed by the exact {@link PTransform} class they handle. */
+  private static final Map<Class<? extends PTransform>, TransformTranslator>
+      transformTranslators = new HashMap<>();
+
+  static {
+    registerTransformTranslator(Read.Unbounded.class, new ReadUnboundedTranslator());
+    registerTransformTranslator(Read.Bounded.class, new ReadBoundedTranslator());
+    registerTransformTranslator(GroupByKey.class, new GroupByKeyTranslator());
+    registerTransformTranslator(Flatten.PCollections.class, new FlattenPCollectionsTranslator());
+    registerTransformTranslator(ParDo.MultiOutput.class, new ParDoMultiOutputTranslator());
+    registerTransformTranslator(Window.Assign.class, new WindowAssignTranslator());
+    registerTransformTranslator(
+        CreateStreamingGearpumpView.CreateGearpumpPCollectionView.class,
+        new CreateGearpumpPCollectionViewTranslator());
+  }
+
+  private final TranslationContext translationContext;
+
+  public GearpumpPipelineTranslator(TranslationContext translationContext) {
+    this.translationContext = translationContext;
+  }
+
+  /**
+   * Applies the Gearpump overrides to {@code pipeline} (which mutates it) and then translates
+   * every primitive transform through this visitor.
+   */
+  public void translate(Pipeline pipeline) {
+    List<PTransformOverride> overrides = ImmutableList.of(
+        PTransformOverride.of(
+            PTransformMatchers.classEqualTo(View.CreatePCollectionView.class),
+            new CreateStreamingGearpumpView.Factory()));
+    pipeline.replaceAll(overrides);
+    pipeline.traverseTopologically(this);
+  }
+
+  @Override
+  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
+    LOG.debug("entering composite transform {}", node.getTransform());
+    return CompositeBehavior.ENTER_TRANSFORM;
+  }
+
+  @Override
+  public void leaveCompositeTransform(TransformHierarchy.Node node) {
+    LOG.debug("leaving composite transform {}", node.getTransform());
+  }
+
+  /**
+   * Translates one primitive transform with the translator registered for its class.
+   *
+   * @throws IllegalStateException if no translator is registered for the transform's class
+   */
+  @Override
+  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
+    PTransform transform = node.getTransform();
+    LOG.debug("visiting transform {}", transform);
+    TransformTranslator translator = getTransformTranslator(transform.getClass());
+    if (translator == null) {
+      throw new IllegalStateException("no translator registered for " + transform);
+    }
+    translationContext.setCurrentTransform(node, getPipeline());
+    translator.translate(transform, translationContext);
+  }
+
+  @Override
+  public void visitValue(PValue value, TransformHierarchy.Node producer) {
+    LOG.debug("visiting value {}", value);
+  }
+
+  /**
+   * Registers {@code transformTranslator} for instances of exactly {@code transformClass}.
+   *
+   * @throws IllegalArgumentException if a translator is already registered for the class
+   */
+  private static <TransformT extends PTransform> void registerTransformTranslator(
+      Class<TransformT> transformClass,
+      TransformTranslator<? extends TransformT> transformTranslator) {
+    TransformTranslator previous = transformTranslators.put(transformClass, transformTranslator);
+    if (previous != null) {
+      throw new IllegalArgumentException("defining multiple translators for " + transformClass);
+    }
+  }
+
+  /** Returns the translator registered for exactly {@code transformClass}, or {@code null}. */
+  private <TransformT extends PTransform>
+      TransformTranslator<TransformT> getTransformTranslator(Class<TransformT> transformClass) {
+    return transformTranslators.get(transformClass);
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslator.java
new file mode 100644
index 0000000..8409beb
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslator.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.gearpump.streaming.dsl.api.functions.FoldFunction;
+import org.apache.gearpump.streaming.dsl.api.functions.MapFunction;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.dsl.javaapi.functions.GroupByFunction;
+import org.apache.gearpump.streaming.dsl.window.api.Discarding$;
+import org.apache.gearpump.streaming.dsl.window.api.EventTimeTrigger$;
+import org.apache.gearpump.streaming.dsl.window.api.WindowFunction;
+import org.apache.gearpump.streaming.dsl.window.api.Windows;
+import org.apache.gearpump.streaming.dsl.window.impl.Window;
+import org.joda.time.Instant;
+
+/**
+ * {@link GroupByKey} is translated to Gearpump groupBy function.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class GroupByKeyTranslator<K, V> implements TransformTranslator<GroupByKey<K, V>> {
+
+ private static final long serialVersionUID = -8742202583992787659L;
+
+ @Override
+ public void translate(GroupByKey<K, V> transform, TranslationContext context) {
+ PCollection<KV<K, V>> input = (PCollection<KV<K, V>>) context.getInput();
+ Coder<K> inputKeyCoder = ((KvCoder<K, V>) input.getCoder()).getKeyCoder();
+ JavaStream<WindowedValue<KV<K, V>>> inputStream =
+ context.getInputStream(input);
+ int parallelism = context.getPipelineOptions().getParallelism();
+ TimestampCombiner timestampCombiner = input.getWindowingStrategy().getTimestampCombiner();
+ WindowFn<KV<K, V>, BoundedWindow> windowFn = (WindowFn<KV<K, V>, BoundedWindow>)
+ input.getWindowingStrategy().getWindowFn();
+ JavaStream<WindowedValue<KV<K, List<V>>>> outputStream = inputStream
+ .window(Windows.apply(
+ new GearpumpWindowFn(windowFn.isNonMerging()),
+ EventTimeTrigger$.MODULE$, Discarding$.MODULE$, windowFn.toString()))
+ .groupBy(new GroupByFn<K, V>(inputKeyCoder), parallelism, "group_by_Key_and_Window")
+ .map(new KeyedByTimestamp<K, V>(windowFn, timestampCombiner), "keyed_by_timestamp")
+ .fold(new Merge<>(windowFn, timestampCombiner), "merge")
+ .map(new Values<K, V>(), "values");
+
+ context.setOutputStream(context.getOutput(), outputStream);
+ }
+
+ /**
+ * A transform used internally to translate Beam's Window to Gearpump's Window.
+ */
+ protected static class GearpumpWindowFn<T, W extends BoundedWindow>
+ implements WindowFunction, Serializable {
+
+ private final boolean isNonMerging;
+
+ public GearpumpWindowFn(boolean isNonMerging) {
+ this.isNonMerging = isNonMerging;
+ }
+
+ @Override
+ public <T> Window[] apply(Context<T> context) {
+ try {
+ Object element = context.element();
+ if (element instanceof TranslatorUtils.RawUnionValue) {
+ element = ((TranslatorUtils.RawUnionValue) element).getValue();
+ }
+ return toGearpumpWindows(((WindowedValue<T>) element).getWindows()
+ .toArray(new BoundedWindow[0]));
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public boolean isNonMerging() {
+ return isNonMerging;
+ }
+
+ private Window[] toGearpumpWindows(BoundedWindow[] windows) {
+ Window[] gwins = new Window[windows.length];
+ for (int i = 0; i < windows.length; i++) {
+ gwins[i] = TranslatorUtils.boundedWindowToGearpumpWindow(windows[i]);
+ }
+ return gwins;
+ }
+ }
+
+ /**
+ * A transform used internally to group KV message by its key.
+ */
+ protected static class GroupByFn<K, V> extends
+ GroupByFunction<WindowedValue<KV<K, V>>, ByteBuffer> {
+
+ private static final long serialVersionUID = -807905402490735530L;
+ private final Coder<K> keyCoder;
+
+ GroupByFn(Coder<K> keyCoder) {
+ this.keyCoder = keyCoder;
+ }
+
+ @Override
+ public ByteBuffer groupBy(WindowedValue<KV<K, V>> wv) {
+ try {
+ return ByteBuffer.wrap(CoderUtils.encodeToByteArray(keyCoder, wv.getValue().getKey()));
+ } catch (CoderException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ /**
+ * A transform used internally to transform WindowedValue to KV.
+ */
+ protected static class KeyedByTimestamp<K, V>
+ extends MapFunction<WindowedValue<KV<K, V>>,
+ KV<Instant, WindowedValue<KV<K, V>>>> {
+
+ private final WindowFn<KV<K, V>, BoundedWindow> windowFn;
+ private final TimestampCombiner timestampCombiner;
+
+ public KeyedByTimestamp(WindowFn<KV<K, V>, BoundedWindow> windowFn,
+ TimestampCombiner timestampCombiner) {
+ this.windowFn = windowFn;
+ this.timestampCombiner = timestampCombiner;
+ }
+
+ @Override
+ public KV<org.joda.time.Instant, WindowedValue<KV<K, V>>> map(
+ WindowedValue<KV<K, V>> wv) {
+ BoundedWindow window = Iterables.getOnlyElement(wv.getWindows());
+ Instant timestamp = timestampCombiner.assign(window
+ , windowFn.getOutputTime(wv.getTimestamp(), window));
+ return KV.of(timestamp, wv);
+ }
+ }
+
+ /**
+ * A transform used internally by Gearpump which encapsulates the merge logic.
+ */
+ protected static class Merge<K, V> extends
+ FoldFunction<KV<Instant, WindowedValue<KV<K, V>>>,
+ KV<Instant, WindowedValue<KV<K, List<V>>>>> {
+
+ private final WindowFn<KV<K, V>, BoundedWindow> windowFn;
+ private final TimestampCombiner timestampCombiner;
+
+ Merge(WindowFn<KV<K, V>, BoundedWindow> windowFn,
+ TimestampCombiner timestampCombiner) {
+ this.windowFn = windowFn;
+ this.timestampCombiner = timestampCombiner;
+ }
+
+ @Override
+ public KV<Instant, WindowedValue<KV<K, List<V>>>> init() {
+ return KV.of(null, null);
+ }
+
+ @Override
+ public KV<Instant, WindowedValue<KV<K, List<V>>>> fold(
+ KV<Instant, WindowedValue<KV<K, List<V>>>> accum,
+ KV<Instant, WindowedValue<KV<K, V>>> iter) {
+ if (accum.getKey() == null) {
+ WindowedValue<KV<K, V>> wv = iter.getValue();
+ KV<K, V> kv = wv.getValue();
+ V v = kv.getValue();
+ List<V> nv = Lists.newArrayList(v);
+ return KV.of(iter.getKey(), wv.withValue(KV.of(kv.getKey(), nv)));
+ }
+
+ Instant t1 = accum.getKey();
+ Instant t2 = iter.getKey();
+
+ final WindowedValue<KV<K, List<V>>> wv1 = accum.getValue();
+ final WindowedValue<KV<K, V>> wv2 = iter.getValue();
+ wv1.getValue().getValue().add(wv2.getValue().getValue());
+
+ final List<BoundedWindow> mergedWindows = new ArrayList<>();
+ if (!windowFn.isNonMerging()) {
+ try {
+ windowFn.mergeWindows(windowFn.new MergeContext() {
+
+ @Override
+ public Collection<BoundedWindow> windows() {
+ ArrayList<BoundedWindow> windows = new ArrayList<>();
+ windows.addAll(wv1.getWindows());
+ windows.addAll(wv2.getWindows());
+ return windows;
+ }
+
+ @Override
+ public void merge(Collection<BoundedWindow> toBeMerged,
+ BoundedWindow mergeResult) throws Exception {
+ mergedWindows.add(mergeResult);
+ }
+ });
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ } else {
+ mergedWindows.addAll(wv1.getWindows());
+ }
+
+ Instant timestamp = timestampCombiner.combine(t1, t2);
+ return KV.of(timestamp,
+ WindowedValue.of(wv1.getValue(), timestamp,
+ mergedWindows, wv1.getPane()));
+ }
+ }
+
+ private static class Values<K, V> extends
+ MapFunction<KV<Instant, WindowedValue<KV<K, List<V>>>>,
+ WindowedValue<KV<K, List<V>>>> {
+
+ @Override
+ public WindowedValue<KV<K, List<V>>> map(KV<org.joda.time.Instant,
+ WindowedValue<KV<K, List<V>>>> kv) {
+ Instant timestamp = kv.getKey();
+ WindowedValue<KV<K, List<V>>> wv = kv.getValue();
+ return WindowedValue.of(wv.getValue(), timestamp, wv.getWindows(), wv.getPane());
+ }
+ }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java
new file mode 100644
index 0000000..d92979b
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.beam.runners.gearpump.translators.functions.DoFnFunction;
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+import org.apache.gearpump.streaming.dsl.api.functions.FilterFunction;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+
+/**
+ * {@link ParDo.MultiOutput} is translated to Gearpump flatMap function
+ * with {@link DoFn} wrapped in {@link DoFnFunction}. The outputs are
+ * further filtered with Gearpump filter function by output tag.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class ParDoMultiOutputTranslator<InputT, OutputT> implements
+    TransformTranslator<ParDo.MultiOutput<InputT, OutputT>> {
+
+  private static final long serialVersionUID = -6023461558200028849L;
+
+  @Override
+  public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationContext context) {
+    PCollection<InputT> inputT = (PCollection<InputT>) context.getInput();
+    JavaStream<WindowedValue<InputT>> inputStream = context.getInputStream(inputT);
+    Collection<PCollectionView<?>> sideInputs = transform.getSideInputs();
+    Map<String, PCollectionView<?>> tagsToSideInputs =
+        TranslatorUtils.getTagsToSideInputs(sideInputs);
+
+    Map<TupleTag<?>, PValue> outputs = context.getOutputs();
+    final TupleTag<OutputT> mainOutput = transform.getMainOutputTag();
+    // Every output other than the main one (compared by tag id) is a side output. The capacity
+    // is clamped so an empty outputs map cannot yield a negative initial capacity.
+    List<TupleTag<?>> sideOutputs = new ArrayList<>(Math.max(0, outputs.size() - 1));
+    for (TupleTag<?> tag : outputs.keySet()) {
+      if (tag != null && !tag.getId().equals(mainOutput.getId())) {
+        sideOutputs.add(tag);
+      }
+    }
+
+    JavaStream<TranslatorUtils.RawUnionValue> unionStream = TranslatorUtils.withSideInputStream(
+        context, inputStream, tagsToSideInputs);
+
+    // A single flatMap runs the DoFn; its union-tagged results are then split per output tag.
+    JavaStream<TranslatorUtils.RawUnionValue> outputStream =
+        TranslatorUtils.toList(unionStream).flatMap(
+            new DoFnFunction<>(
+                context.getPipelineOptions(),
+                transform.getFn(),
+                inputT.getWindowingStrategy(),
+                sideInputs,
+                tagsToSideInputs,
+                mainOutput,
+                sideOutputs), transform.getName());
+    for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
+      JavaStream<WindowedValue<OutputT>> taggedStream = outputStream
+          .filter(new FilterByOutputTag(output.getKey().getId()),
+              "filter_by_output_tag")
+          .map(new TranslatorUtils.FromRawUnionValue<OutputT>(), "from_RawUnionValue");
+      context.setOutputStream(output.getValue(), taggedStream);
+    }
+  }
+
+  /**
+   * Keeps only the union values whose tag equals the configured output tag id.
+   */
+  private static class FilterByOutputTag extends FilterFunction<TranslatorUtils.RawUnionValue> {
+
+    private static final long serialVersionUID = 7276155265895637526L;
+    private final String tag;
+
+    FilterByOutputTag(String tag) {
+      this.tag = tag;
+    }
+
+    @Override
+    public boolean filter(TranslatorUtils.RawUnionValue value) {
+      // Null-safe: a value without a union tag is filtered out instead of throwing an NPE.
+      return tag.equals(value.getUnionTag());
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslator.java
new file mode 100644
index 0000000..8f71a8e
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslator.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import org.apache.beam.runners.gearpump.translators.io.BoundedSourceWrapper;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.source.DataSource;
+
+/**
+ * Translates {@link Read.Bounded} into a Gearpump source function: the {@link BoundedSource} is
+ * wrapped into a Gearpump {@link DataSource}, and the resulting stream becomes the transform's
+ * output.
+ */
+public class ReadBoundedTranslator<T> implements TransformTranslator<Read.Bounded<T>> {
+
+  private static final long serialVersionUID = -3899020490896998330L;
+
+  @Override
+  public void translate(Read.Bounded<T> transform, TranslationContext context) {
+    BoundedSource<T> source = transform.getSource();
+    JavaStream<WindowedValue<T>> sourceStream =
+        context.getSourceStream(new BoundedSourceWrapper<>(source, context.getPipelineOptions()));
+    context.setOutputStream(context.getOutput(), sourceStream);
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java
new file mode 100644
index 0000000..0462c57
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslator.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import org.apache.beam.runners.gearpump.translators.io.UnboundedSourceWrapper;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.source.DataSource;
+
+/**
+ * Translates {@link Read.Unbounded} into a Gearpump source function: the {@link UnboundedSource}
+ * is wrapped into a Gearpump {@link DataSource}, and the resulting stream becomes the
+ * transform's output.
+ */
+public class ReadUnboundedTranslator<T> implements TransformTranslator<Read.Unbounded<T>> {
+
+  private static final long serialVersionUID = 3529494817859948619L;
+
+  @Override
+  public void translate(Read.Unbounded<T> transform, TranslationContext context) {
+    UnboundedSource<T, ?> source = transform.getSource();
+    UnboundedSourceWrapper<T, ?> wrapper =
+        new UnboundedSourceWrapper<>(source, context.getPipelineOptions());
+    JavaStream<WindowedValue<T>> stream = context.getSourceStream(wrapper);
+    context.setOutputStream(context.getOutput(), stream);
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java
new file mode 100644
index 0000000..c7becad
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TransformTranslator.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import java.io.Serializable;
+
+import org.apache.beam.sdk.transforms.PTransform;
+
+/**
+ * Translates a Beam {@link PTransform} into Gearpump functions.
+ *
+ * <p>Translators are {@link Serializable}.
+ *
+ * @param <TransformT> the kind of transform handled by this translator
+ */
+public interface TransformTranslator<TransformT extends PTransform> extends Serializable {
+
+  /**
+   * Translates {@code transform}. Implementations in this package read input streams from, and
+   * register output streams with, {@code context}.
+   */
+  void translate(TransformT transform, TranslationContext context);
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java
new file mode 100644
index 0000000..42b7a53
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/TranslationContext.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.beam.runners.core.construction.TransformInputs;
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.values.PValue;
+
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.gearpump.cluster.UserConfig;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStreamApp;
+import org.apache.gearpump.streaming.source.DataSource;
+
+/**
+ * Maintains context data for {@link TransformTranslator}s.
+ *
+ * <p>It records the Gearpump {@link JavaStream} produced for each translated {@link PValue}, and
+ * the {@link AppliedPTransform} currently being translated, which backs the input/output
+ * accessors.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class TranslationContext {
+
+  private final JavaStreamApp streamApp;
+  private final GearpumpPipelineOptions pipelineOptions;
+  private AppliedPTransform<?, ?, ?> currentTransform;
+  /** Stream registered for each already-translated value. */
+  private final Map<PValue, JavaStream<?>> streams = new HashMap<>();
+
+  public TranslationContext(JavaStreamApp streamApp, GearpumpPipelineOptions pipelineOptions) {
+    this.streamApp = streamApp;
+    this.pipelineOptions = pipelineOptions;
+  }
+
+  /** Makes {@code treeNode} the transform that subsequent input/output lookups refer to. */
+  public void setCurrentTransform(TransformHierarchy.Node treeNode, Pipeline pipeline) {
+    this.currentTransform = treeNode.toAppliedPTransform(pipeline);
+  }
+
+  public GearpumpPipelineOptions getPipelineOptions() {
+    return pipelineOptions;
+  }
+
+  /**
+   * Returns the stream registered for {@code input}, or {@code null} if none has been set yet.
+   */
+  public <InputT> JavaStream<InputT> getInputStream(PValue input) {
+    return (JavaStream<InputT>) streams.get(input);
+  }
+
+  /**
+   * Registers {@code outputStream} as the stream producing {@code output}.
+   *
+   * @throws IllegalArgumentException if a stream was already registered for {@code output}
+   */
+  public <OutputT> void setOutputStream(PValue output, JavaStream<OutputT> outputStream) {
+    checkArgument(
+        !streams.containsKey(output), "set stream for duplicated output %s", output);
+    streams.put(output, outputStream);
+  }
+
+  public Map<TupleTag<?>, PValue> getInputs() {
+    return getCurrentTransform().getInputs();
+  }
+
+  /**
+   * Returns the single main (non-additional) input of the current transform.
+   *
+   * @throws IllegalArgumentException if there is not exactly one such input
+   */
+  public PValue getInput() {
+    return Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform()));
+  }
+
+  public Map<TupleTag<?>, PValue> getOutputs() {
+    return getCurrentTransform().getOutputs();
+  }
+
+  /**
+   * Returns the only output of the current transform.
+   *
+   * @throws IllegalArgumentException if the transform has more than one output
+   */
+  public PValue getOutput() {
+    return Iterables.getOnlyElement(getOutputs().values());
+  }
+
+  /**
+   * Returns the transform being translated.
+   *
+   * @throws IllegalStateException if {@link #setCurrentTransform} has not been called yet; this
+   *     is a precondition on the context's state, not on any argument
+   */
+  private AppliedPTransform<?, ?, ?> getCurrentTransform() {
+    checkState(
+        currentTransform != null,
+        "current transform not set");
+    return currentTransform;
+  }
+
+  /** Adds {@code dataSource} to the Gearpump application with the configured parallelism. */
+  public <T> JavaStream<T> getSourceStream(DataSource dataSource) {
+    return streamApp.source(dataSource, pipelineOptions.getParallelism(),
+        UserConfig.empty(), "source");
+  }
+
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java
new file mode 100644
index 0000000..d144b95
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslator.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import com.google.common.collect.Iterables;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.dsl.javaapi.functions.FlatMapFunction;
+import org.joda.time.Instant;
+
+/**
+ * {@link Window.Assign} is translated to Gearpump flatMap function.
+ *
+ * <p>Windows are assigned with the {@link WindowFn} of the output's windowing strategy.
+ */
+@SuppressWarnings("unchecked")
+public class WindowAssignTranslator<T> implements TransformTranslator<Window.Assign<T>> {
+
+  private static final long serialVersionUID = -964887482120489061L;
+
+  @Override
+  public void translate(Window.Assign<T> transform, TranslationContext context) {
+    PCollection<T> input = (PCollection<T>) context.getInput();
+    PCollection<T> output = (PCollection<T>) context.getOutput();
+    JavaStream<WindowedValue<T>> inputStream = context.getInputStream(input);
+    WindowingStrategy<?, ?> outputStrategy = output.getWindowingStrategy();
+    WindowFn<T, BoundedWindow> windowFn = (WindowFn<T, BoundedWindow>) outputStrategy.getWindowFn();
+    JavaStream<WindowedValue<T>> outputStream =
+        inputStream
+            .flatMap(new AssignWindows(windowFn), "assign_windows");
+
+    context.setOutputStream(output, outputStream);
+  }
+
+  /**
+   * A Function used internally by Gearpump to wrap the actual Beam's WindowFn.
+   *
+   * <p>An input element may already belong to several windows, while
+   * {@link WindowFn.AssignContext#window()} can expose only one. The element is therefore
+   * exploded into one value per existing window, and each of those is assigned separately.
+   * Every assigned window yields its own output value.
+   */
+  protected static class AssignWindows<T> extends
+      FlatMapFunction<WindowedValue<T>, WindowedValue<T>> {
+
+    private static final long serialVersionUID = 7284565861938681360L;
+    private final WindowFn<T, BoundedWindow> windowFn;
+
+    AssignWindows(WindowFn<T, BoundedWindow> windowFn) {
+      this.windowFn = windowFn;
+    }
+
+    @Override
+    public Iterator<WindowedValue<T>> flatMap(final WindowedValue<T> value) {
+      List<WindowedValue<T>> values = new ArrayList<>();
+      for (WindowedValue<T> single : value.explodeWindows()) {
+        for (BoundedWindow win : assignWindows(single)) {
+          values.add(WindowedValue.of(
+              single.getValue(), single.getTimestamp(), win, single.getPane()));
+        }
+      }
+      return values.iterator();
+    }
+
+    /** Runs the wrapped WindowFn on a value that carries exactly one window. */
+    private Collection<BoundedWindow> assignWindows(final WindowedValue<T> value) {
+      try {
+        return windowFn.assignWindows(windowFn.new AssignContext() {
+          @Override
+          public T element() {
+            return value.getValue();
+          }
+
+          @Override
+          public Instant timestamp() {
+            return value.getTimestamp();
+          }
+
+          @Override
+          public BoundedWindow window() {
+            return Iterables.getOnlyElement(value.getWindows());
+          }
+        });
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java
new file mode 100644
index 0000000..fde265a
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.functions;
+
+import com.google.common.collect.Iterables;
+
+import com.google.common.collect.Lists;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.PushbackSideInputDoFnRunner;
+import org.apache.beam.runners.core.SideInputHandler;
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.runners.gearpump.translators.utils.DoFnRunnerFactory;
+import org.apache.beam.runners.gearpump.translators.utils.NoOpStepContext;
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils.RawUnionValue;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.gearpump.streaming.dsl.javaapi.functions.FlatMapFunction;
+
+/**
+ * Gearpump {@link FlatMapFunction} wrapper over Beam {@link DoFn}.
+ *
+ * <p>Each call to {@link #flatMap} treats its input list as one bundle. Values whose union tag
+ * is {@code "0"} are main-input elements; any other tag is looked up in the side-input tag
+ * mapping given to the constructor. Only outputs for the main output tag or one of the side
+ * output tags are returned.
+ */
+@SuppressWarnings("unchecked")
+public class DoFnFunction<InputT, OutputT> extends
+    FlatMapFunction<List<RawUnionValue>, RawUnionValue> {
+
+  private static final long serialVersionUID = -5701440128544343353L;
+
+  /** Union tag carried by main-input elements; every other tag denotes a side input. */
+  private static final String MAIN_INPUT_TAG = "0";
+
+  private final DoFnRunnerFactory<InputT, OutputT> doFnRunnerFactory;
+  private final DoFn<InputT, OutputT> doFn;
+  private transient DoFnInvoker<InputT, OutputT> doFnInvoker;
+  private transient PushbackSideInputDoFnRunner<InputT, OutputT> doFnRunner;
+  private transient SideInputHandler sideInputReader;
+  private transient List<WindowedValue<InputT>> pushedBackValues;
+  private final Collection<PCollectionView<?>> sideInputs;
+  private final Map<String, PCollectionView<?>> tagsToSideInputs;
+  private final TupleTag<OutputT> mainOutput;
+  private final List<TupleTag<?>> sideOutputs;
+  private final DoFnOutputManager outputManager;
+
+  public DoFnFunction(
+      GearpumpPipelineOptions pipelineOptions,
+      DoFn<InputT, OutputT> doFn,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Collection<PCollectionView<?>> sideInputs,
+      Map<String, PCollectionView<?>> sideInputTagMapping,
+      TupleTag<OutputT> mainOutput,
+      List<TupleTag<?>> sideOutputs) {
+    this.doFn = doFn;
+    this.outputManager = new DoFnOutputManager();
+    this.doFnRunnerFactory = new DoFnRunnerFactory<>(
+        pipelineOptions,
+        doFn,
+        sideInputs,
+        outputManager,
+        mainOutput,
+        sideOutputs,
+        new NoOpStepContext(),
+        windowingStrategy
+    );
+    this.sideInputs = sideInputs;
+    this.tagsToSideInputs = sideInputTagMapping;
+    this.mainOutput = mainOutput;
+    this.sideOutputs = sideOutputs;
+  }
+
+  /**
+   * Creates the per-worker runtime state: the side input handler, the DoFn invoker (running the
+   * DoFn's {@code @Setup}), the runner and the output buffer.
+   */
+  @Override
+  public void setup() {
+    sideInputReader = new SideInputHandler(sideInputs,
+        InMemoryStateInternals.<Void>forKey(null));
+    doFnInvoker = DoFnInvokers.invokerFor(doFn);
+    doFnInvoker.invokeSetup();
+
+    doFnRunner = doFnRunnerFactory.createRunner(sideInputReader);
+
+    pushedBackValues = new LinkedList<>();
+    outputManager.setup(mainOutput, sideOutputs);
+  }
+
+  /** Runs the DoFn's {@code @Teardown}, if setup got as far as creating the invoker. */
+  @Override
+  public void teardown() {
+    // setup() may have failed before the invoker existed; don't mask that failure with an NPE.
+    if (doFnInvoker != null) {
+      doFnInvoker.invokeTeardown();
+    }
+  }
+
+  /**
+   * Processes one batch of union values as a single bundle.
+   *
+   * @param inputs main-input elements and side-input values, distinguished by union tag
+   * @return the outputs emitted while processing this bundle
+   * @throws IllegalStateException if a value carries a tag that maps to no known side input
+   */
+  @Override
+  public Iterator<TranslatorUtils.RawUnionValue> flatMap(List<RawUnionValue> inputs) {
+    outputManager.clear();
+
+    doFnRunner.startBundle();
+
+    for (RawUnionValue unionValue : inputs) {
+      final String tag = unionValue.getUnionTag();
+      if (MAIN_INPUT_TAG.equals(tag)) {
+        // main input
+        pushedBackValues.add((WindowedValue<InputT>) unionValue.getValue());
+      } else {
+        // side input
+        PCollectionView<?> sideInput = tagsToSideInputs.get(tag);
+        if (sideInput == null) {
+          throw new IllegalStateException("No side input registered for union tag " + tag);
+        }
+        WindowedValue<Iterable<?>> sideInputValue =
+            (WindowedValue<Iterable<?>>) unionValue.getValue();
+        sideInputReader.addSideInputValue(sideInput, sideInputValue);
+      }
+    }
+
+    // A side input window that has received no data yet is filled with an empty value, so the
+    // runner considers it ready and main-input elements are not pushed back on it.
+    for (PCollectionView<?> sideInput : sideInputs) {
+      for (WindowedValue<InputT> value : pushedBackValues) {
+        for (BoundedWindow win : value.getWindows()) {
+          BoundedWindow sideInputWindow =
+              sideInput.getWindowMappingFn().getSideInputWindow(win);
+          if (!sideInputReader.isReady(sideInput, sideInputWindow)) {
+            Object emptyValue = WindowedValue.of(
+                Lists.newArrayList(), value.getTimestamp(), sideInputWindow, value.getPane());
+            sideInputReader.addSideInputValue(sideInput, (WindowedValue<Iterable<?>>) emptyValue);
+          }
+        }
+      }
+    }
+
+    List<WindowedValue<InputT>> nextPushedBackValues = new LinkedList<>();
+    for (WindowedValue<InputT> value : pushedBackValues) {
+      Iterable<WindowedValue<InputT>> values = doFnRunner.processElementInReadyWindows(value);
+      Iterables.addAll(nextPushedBackValues, values);
+    }
+    pushedBackValues.clear();
+    Iterables.addAll(pushedBackValues, nextPushedBackValues);
+
+    doFnRunner.finishBundle();
+
+    return outputManager.getOutputs();
+  }
+
+  /**
+   * Buffers the outputs of the current bundle, keeping only values tagged with the main output
+   * or one of the declared side outputs.
+   */
+  private static class DoFnOutputManager implements DoFnRunners.OutputManager, Serializable {
+
+    private static final long serialVersionUID = 4967375172737408160L;
+    private transient List<RawUnionValue> outputs;
+    private transient Set<TupleTag<?>> outputTags;
+
+    @Override
+    public <T> void output(TupleTag<T> outputTag, WindowedValue<T> output) {
+      if (outputTags.contains(outputTag)) {
+        outputs.add(new RawUnionValue(outputTag.getId(), output));
+      }
+    }
+
+    void setup(TupleTag<?> mainOutput, List<TupleTag<?>> sideOutputs) {
+      outputs = new LinkedList<>();
+      outputTags = new HashSet<>();
+      outputTags.add(mainOutput);
+      outputTags.addAll(sideOutputs);
+    }
+
+    /**
+     * Starts buffering a new bundle. A fresh list is allocated instead of clearing the old one,
+     * so an iterator returned by {@link #getOutputs()} for the previous bundle stays valid.
+     */
+    void clear() {
+      outputs = new LinkedList<>();
+    }
+
+    Iterator<RawUnionValue> getOutputs() {
+      return outputs.iterator();
+    }
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/package-info.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/package-info.java
new file mode 100644
index 0000000..cba2363
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Gearpump specific wrappers for Beam DoFn.
+ */
+package org.apache.beam.runners.gearpump.translators.functions;
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java
new file mode 100644
index 0000000..2c18735
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/BoundedSourceWrapper.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.io;
+
+import java.io.IOException;
+
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.Source;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+/**
+ * Adapts a Beam {@link BoundedSource} to the Gearpump DataSource API by supplying
+ * {@link GearpumpSource} with a reader for the wrapped source.
+ */
+public class BoundedSourceWrapper<T> extends GearpumpSource<T> {
+
+  private static final long serialVersionUID = 8199570485738786123L;
+
+  /** The bounded source whose reader backs this DataSource. */
+  private final BoundedSource<T> source;
+
+  public BoundedSourceWrapper(BoundedSource<T> source, PipelineOptions options) {
+    super(options);
+    this.source = source;
+  }
+
+  /** Asks the wrapped bounded source for a new reader. */
+  @Override
+  protected Source.Reader<T> createReader(PipelineOptions options) throws IOException {
+    BoundedSource.BoundedReader<T> boundedReader = source.createReader(options);
+    return boundedReader;
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java
new file mode 100644
index 0000000..2f53139
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSource.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.io;
+
+import java.io.IOException;
+import java.time.Instant;
+
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.sdk.io.Source;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+
+import org.apache.gearpump.DefaultMessage;
+import org.apache.gearpump.Message;
+import org.apache.gearpump.streaming.source.DataSource;
+import org.apache.gearpump.streaming.source.Watermark;
+import org.apache.gearpump.streaming.task.TaskContext;
+
+/**
+ * Common methods for {@link BoundedSourceWrapper} and {@link UnboundedSourceWrapper}.
+ *
+ * <p>Adapts a Beam {@link Source.Reader} to Gearpump's {@link DataSource}: each element is
+ * emitted as a {@link WindowedValue} in the global window, stamped with the reader's current
+ * timestamp.
+ */
+public abstract class GearpumpSource<T> implements DataSource {
+
+  private final byte[] serializedOptions;
+
+  // Runtime-only state created in open(); not meant to be serialized with the source.
+  private transient Source.Reader<T> reader;
+  private transient boolean available;
+
+  GearpumpSource(PipelineOptions options) {
+    this.serializedOptions = TranslatorUtils.serializePipelineOptions(options);
+  }
+
+  /**
+   * Creates the Beam reader that backs this source.
+   *
+   * @param options the pipeline options, deserialized in {@link #open}
+   * @return a reader on which {@code start()} has not been called yet
+   * @throws IOException if the reader cannot be created
+   */
+  protected abstract Source.Reader<T> createReader(PipelineOptions options) throws IOException;
+
+  @Override
+  public void open(TaskContext context, Instant startTime) {
+    try {
+      PipelineOptions options = TranslatorUtils.deserializePipelineOptions(serializedOptions);
+      this.reader = createReader(options);
+      this.available = reader.start();
+    } catch (Exception e) {
+      close();
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Returns the reader's current element, or {@code null} if none is available, then advances.
+   *
+   * <p>A bounded reader that has reported no more input is not advanced again. An unbounded
+   * reader is always polled, because {@code false} from it only means no data right now.
+   */
+  @Override
+  public Message read() {
+    Message message = null;
+    try {
+      if (available) {
+        T data = reader.getCurrent();
+        org.joda.time.Instant timestamp = reader.getCurrentTimestamp();
+        message = new DefaultMessage(
+            WindowedValue.timestampedValueInGlobalWindow(data, timestamp),
+            timestamp.getMillis());
+      }
+      if (available || isUnbounded()) {
+        available = reader.advance();
+      }
+    } catch (Exception e) {
+      close();
+      throw new RuntimeException(e);
+    }
+    return message;
+  }
+
+  @Override
+  public void close() {
+    try {
+      if (reader != null) {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Unbounded readers report their own watermark, mapped to {@link Watermark#MAX()} once it
+   * reaches {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. Bounded readers report
+   * {@link Watermark#MIN()} while data remains and {@link Watermark#MAX()} once exhausted.
+   */
+  @Override
+  public Instant getWatermark() {
+    if (isUnbounded()) {
+      org.joda.time.Instant watermark =
+          ((UnboundedSource.UnboundedReader<?>) reader).getWatermark();
+      // Compare by value, not reference: a reader may return an Instant that is equal to,
+      // but not the same object as, TIMESTAMP_MAX_VALUE.
+      if (!watermark.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
+        return Watermark.MAX();
+      }
+      return TranslatorUtils.jodaTimeToJava8Time(watermark);
+    }
+    return available ? Watermark.MIN() : Watermark.MAX();
+  }
+
+  /** Whether the current reader is an unbounded reader ({@code false} if there is none). */
+  private boolean isUnbounded() {
+    return reader instanceof UnboundedSource.UnboundedReader;
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java
new file mode 100644
index 0000000..cb912c1
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/UnboundedSourceWrapper.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.io;
+
+import java.io.IOException;
+
+import org.apache.beam.sdk.io.Source;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+/**
+ * Adapts a Beam {@link UnboundedSource} to the Gearpump DataSource API by supplying
+ * {@link GearpumpSource} with a reader for the wrapped source.
+ */
+public class UnboundedSourceWrapper<OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark>
+    extends GearpumpSource<OutputT> {
+
+  private static final long serialVersionUID = -2453956849834747150L;
+
+  /** The unbounded source whose reader backs this DataSource. */
+  private final UnboundedSource<OutputT, CheckpointMarkT> source;
+
+  public UnboundedSourceWrapper(
+      UnboundedSource<OutputT, CheckpointMarkT> source, PipelineOptions options) {
+    super(options);
+    this.source = source;
+  }
+
+  /**
+   * Asks the wrapped unbounded source for a new reader. No checkpoint is restored: the reader
+   * is always created with a {@code null} checkpoint mark.
+   */
+  @Override
+  protected Source.Reader<OutputT> createReader(PipelineOptions options) throws IOException {
+    final CheckpointMarkT noCheckpoint = null;
+    UnboundedSource.UnboundedReader<OutputT> unboundedReader =
+        source.createReader(options, noCheckpoint);
+    return unboundedReader;
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/ValuesSource.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/ValuesSource.java
new file mode 100644
index 0000000..b62da19
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/ValuesSource.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.io;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import javax.annotation.Nullable;
+
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Instant;
+
+/**
+ * unbounded source that reads from a Java {@link Iterable}.
+ */
+public class ValuesSource<T> extends UnboundedSource<T, UnboundedSource.CheckpointMark> {
+
+ private static final long serialVersionUID = 9113026175795235710L;
+ private final byte[] values;
+ private final IterableCoder<T> iterableCoder;
+
+ public ValuesSource(Iterable<T> values, Coder<T> coder) {
+ this.iterableCoder = IterableCoder.of(coder);
+ this.values = encode(values, iterableCoder);
+ }
+
+ private byte[] encode(Iterable<T> values, IterableCoder<T> coder) {
+ try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
+ coder.encode(values, stream, Coder.Context.OUTER);
+ return stream.toByteArray();
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ private Iterable<T> decode(byte[] bytes) throws IOException{
+ try (ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes)) {
+ return iterableCoder.decode(inputStream, Coder.Context.OUTER);
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ @Override
+ public java.util.List<? extends UnboundedSource<T, CheckpointMark>> split(
+ int desiredNumSplits, PipelineOptions options) throws Exception {
+ return Collections.singletonList(this);
+ }
+
+ @Override
+ public UnboundedReader<T> createReader(PipelineOptions options,
+ @Nullable CheckpointMark checkpointMark) {
+ try {
+ return new ValuesReader<>(decode(values), this);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Nullable
+ @Override
+ public Coder<CheckpointMark> getCheckpointMarkCoder() {
+ return null;
+ }
+
+ @Override
+ public void validate() {
+ }
+
+ @Override
+ public Coder<T> getDefaultOutputCoder() {
+ return iterableCoder.getElemCoder();
+ }
+
+ private static class ValuesReader<T> extends UnboundedReader<T> {
+ private final UnboundedSource<T, CheckpointMark> source;
+ private final Iterable<T> values;
+ private transient Iterator<T> iterator;
+ private T current;
+
+ ValuesReader(Iterable<T> values,
+ UnboundedSource<T, CheckpointMark> source) {
+ this.values = values;
+ this.source = source;
+ }
+
+ @Override
+ public boolean start() throws IOException {
+ if (null == iterator) {
+ iterator = values.iterator();
+ }
+ return advance();
+ }
+
+ @Override
+ public boolean advance() throws IOException {
+ if (iterator.hasNext()) {
+ current = iterator.next();
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public T getCurrent() throws NoSuchElementException {
+ return current;
+ }
+
+ @Override
+ public Instant getCurrentTimestamp() throws NoSuchElementException {
+ return getTimestamp(current);
+ }
+
+ @Override
+ public void close() throws IOException {
+ }
+
+ @Override
+ public Instant getWatermark() {
+ if (iterator.hasNext()) {
+ return getTimestamp(current);
+ } else {
+ return BoundedWindow.TIMESTAMP_MAX_VALUE;
+ }
+ }
+
+ @Override
+ public CheckpointMark getCheckpointMark() {
+ return null;
+ }
+
+ @Override
+ public UnboundedSource<T, ?> getCurrentSource() {
+ return source;
+ }
+
+ private Instant getTimestamp(Object value) {
+ if (value instanceof TimestampedValue) {
+ return ((TimestampedValue) value).getTimestamp();
+ } else {
+ return Instant.now();
+ }
+ }
+ }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/package-info.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/package-info.java
new file mode 100644
index 0000000..dfdf51a
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/io/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Gearpump specific wrappers for Beam I/O.
+ */
+package org.apache.beam.runners.gearpump.translators.io;
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/package-info.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/package-info.java
new file mode 100644
index 0000000..612096a
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Gearpump specific translators.
+ */
+package org.apache.beam.runners.gearpump.translators;
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java
new file mode 100644
index 0000000..375b696
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.utils;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.PushbackSideInputDoFnRunner;
+import org.apache.beam.runners.core.ReadyCheckingSideInputReader;
+import org.apache.beam.runners.core.SimpleDoFnRunner;
+import org.apache.beam.runners.core.SimplePushbackSideInputDoFnRunner;
+import org.apache.beam.runners.core.StepContext;
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+
+/**
+ * Serializable factory that builds a {@link PushbackSideInputDoFnRunner} on top of a
+ * {@link SimpleDoFnRunner}.
+ *
+ * <p>Everything the runner needs is captured at construction time; the pipeline options are
+ * held in serialized form and only deserialized when {@link #createRunner} is called.
+ */
+public class DoFnRunnerFactory<InputT, OutputT> implements Serializable {
+
+  private static final long serialVersionUID = -4109539010014189725L;
+  private final DoFn<InputT, OutputT> fn;
+  private final byte[] serializedOptions;
+  private final Collection<PCollectionView<?>> sideInputs;
+  private final DoFnRunners.OutputManager outputManager;
+  private final TupleTag<OutputT> mainOutputTag;
+  private final List<TupleTag<?>> sideOutputTags;
+  private final StepContext stepContext;
+  private final WindowingStrategy<?, ?> windowingStrategy;
+
+  public DoFnRunnerFactory(
+      GearpumpPipelineOptions pipelineOptions,
+      DoFn<InputT, OutputT> doFn,
+      Collection<PCollectionView<?>> sideInputs,
+      DoFnRunners.OutputManager outputManager,
+      TupleTag<OutputT> mainOutputTag,
+      List<TupleTag<?>> sideOutputTags,
+      StepContext stepContext,
+      WindowingStrategy<?, ?> windowingStrategy) {
+    this.fn = doFn;
+    this.serializedOptions = TranslatorUtils.serializePipelineOptions(pipelineOptions);
+    this.sideInputs = sideInputs;
+    this.outputManager = outputManager;
+    this.mainOutputTag = mainOutputTag;
+    this.sideOutputTags = sideOutputTags;
+    this.stepContext = stepContext;
+    this.windowingStrategy = windowingStrategy;
+  }
+
+  /**
+   * Builds a runner that reads side inputs through {@code sideInputReader}.
+   *
+   * @param sideInputReader reader used both by the DoFn and for side-input readiness checks
+   * @return a pushback runner wrapping a freshly created simple runner
+   */
+  public PushbackSideInputDoFnRunner<InputT, OutputT> createRunner(
+      ReadyCheckingSideInputReader sideInputReader) {
+    PipelineOptions runnerOptions = TranslatorUtils.deserializePipelineOptions(serializedOptions);
+    DoFnRunner<InputT, OutputT> simpleRunner =
+        DoFnRunners.simpleRunner(
+            runnerOptions,
+            fn,
+            sideInputReader,
+            outputManager,
+            mainOutputTag,
+            sideOutputTags,
+            stepContext,
+            windowingStrategy);
+    return SimplePushbackSideInputDoFnRunner.create(simpleRunner, sideInputs, sideInputReader);
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java
new file mode 100644
index 0000000..b795ed9
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/NoOpStepContext.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.utils;
+
+import java.io.Serializable;
+
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StepContext;
+import org.apache.beam.runners.core.TimerInternals;
+
+/**
+ * Serializable {@link StepContext} that provides neither state nor timers.
+ *
+ * <p>Both accessors throw {@link UnsupportedOperationException}, so this context can only be
+ * used for DoFns that never ask for state or timer internals.
+ */
+public class NoOpStepContext implements StepContext, Serializable {
+
+  private static final long serialVersionUID = 1L;
+
+  @Override
+  public StateInternals stateInternals() {
+    throw new UnsupportedOperationException("NoOpStepContext does not support state");
+  }
+
+  @Override
+  public TimerInternals timerInternals() {
+    throw new UnsupportedOperationException("NoOpStepContext does not support timers");
+  }
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java
new file mode 100644
index 0000000..c14298f
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtils.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.utils;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.Lists;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import org.apache.beam.runners.gearpump.translators.TranslationContext;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollectionView;
+
+import org.apache.gearpump.streaming.dsl.api.functions.FoldFunction;
+import org.apache.gearpump.streaming.dsl.api.functions.MapFunction;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.dsl.window.impl.Window;
+
+/**
+ * Utility methods for translators.
+ */
+public class TranslatorUtils {
+
+  /**
+   * Shared mapper for (de)serializing {@link PipelineOptions}. {@link ObjectMapper} is
+   * thread-safe once configured, and this instance is never reconfigured.
+   */
+  private static final ObjectMapper MAPPER = new ObjectMapper();
+
+  /** Converts a Joda-Time instant to a {@link Instant}, keeping millisecond precision. */
+  public static Instant jodaTimeToJava8Time(org.joda.time.Instant time) {
+    return Instant.ofEpochMilli(time.getMillis());
+  }
+
+  /** Converts a {@link Instant} to a Joda-Time instant, dropping sub-millisecond precision. */
+  public static org.joda.time.Instant java8TimeToJodaTime(Instant time) {
+    return new org.joda.time.Instant(time.toEpochMilli());
+  }
+
+  /**
+   * Converts a Beam window to a Gearpump {@link Window}.
+   *
+   * <p>Beam's {@link BoundedWindow#maxTimestamp()} is inclusive while a Gearpump window's
+   * end is exclusive, so the end is {@code maxTimestamp() + 1ms}. A {@link GlobalWindow}
+   * starts at {@link BoundedWindow#TIMESTAMP_MIN_VALUE}.
+   *
+   * @throws RuntimeException if {@code window} is neither an {@link IntervalWindow} nor a
+   *     {@link GlobalWindow}
+   */
+  public static Window boundedWindowToGearpumpWindow(BoundedWindow window) {
+    // Gearpump window upper bound is exclusive
+    Instant end = TranslatorUtils.jodaTimeToJava8Time(window.maxTimestamp().plus(1L));
+    if (window instanceof IntervalWindow) {
+      IntervalWindow intervalWindow = (IntervalWindow) window;
+      Instant start = TranslatorUtils.jodaTimeToJava8Time(intervalWindow.start());
+      return new Window(start, end);
+    } else if (window instanceof GlobalWindow) {
+      return new Window(TranslatorUtils.jodaTimeToJava8Time(BoundedWindow.TIMESTAMP_MIN_VALUE),
+          end);
+    } else {
+      throw new RuntimeException("unknown window " + window.getClass().getName());
+    }
+  }
+
+  /**
+   * Merges the main input with all side inputs into a single stream of {@link RawUnionValue}s.
+   * Main-input elements are tagged {@code "0"}; each side input's elements are tagged with its
+   * key in {@code tagsToSideInputs}, as produced by {@link #getTagsToSideInputs}.
+   */
+  public static <InputT> JavaStream<RawUnionValue> withSideInputStream(
+      TranslationContext context,
+      JavaStream<WindowedValue<InputT>> inputStream,
+      Map<String, PCollectionView<?>> tagsToSideInputs) {
+    JavaStream<RawUnionValue> mainStream =
+        inputStream.map(new ToRawUnionValue<>("0"), "map_to_RawUnionValue");
+
+    for (Map.Entry<String, PCollectionView<?>> tagToSideInput : tagsToSideInputs.entrySet()) {
+      JavaStream<WindowedValue<List<?>>> sideInputStream = context.getInputStream(
+          tagToSideInput.getValue());
+      mainStream = mainStream.merge(sideInputStream.map(new ToRawUnionValue<>(
+          tagToSideInput.getKey()), "map_to_RawUnionValue"), 1, "merge_to_MainStream");
+    }
+    return mainStream;
+  }
+
+  /**
+   * Assigns each side input a distinct tag {@code "1"}, {@code "2"}, ... following the
+   * iteration order of {@code sideInputs}; tag {@code "0"} is reserved for the main input.
+   */
+  public static Map<String, PCollectionView<?>> getTagsToSideInputs(
+      Collection<PCollectionView<?>> sideInputs) {
+    Map<String, PCollectionView<?>> tagsToSideInputs = new HashMap<>();
+    // tag 0 is reserved for main input
+    int tag = 1;
+    for (PCollectionView<?> sideInput : sideInputs) {
+      tagsToSideInputs.put(Integer.toString(tag), sideInput);
+      tag++;
+    }
+    return tagsToSideInputs;
+  }
+
+  /**
+   * Collects the {@link RawUnionValue}s of {@code stream} into a list with a Gearpump fold,
+   * appending each value in the order the fold receives it.
+   */
+  public static JavaStream<List<RawUnionValue>> toList(JavaStream<RawUnionValue> stream) {
+    return stream.fold(new FoldFunction<RawUnionValue, List<RawUnionValue>>() {
+
+      @Override
+      public List<RawUnionValue> init() {
+        return Lists.newArrayList();
+      }
+
+      @Override
+      public List<RawUnionValue> fold(List<RawUnionValue> accumulator,
+          RawUnionValue rawUnionValue) {
+        accumulator.add(rawUnionValue);
+        return accumulator;
+      }
+    }, "fold_to_iterable");
+  }
+
+  /**
+   * Converts a {@link RawUnionValue} to the {@link WindowedValue} it wraps.
+   */
+  public static class FromRawUnionValue<OutputT> extends
+      MapFunction<RawUnionValue, WindowedValue<OutputT>> {
+
+    private static final long serialVersionUID = -4764968219713478955L;
+
+    // The payload is stored as Object; callers are responsible for the element type.
+    @SuppressWarnings("unchecked")
+    @Override
+    public WindowedValue<OutputT> map(RawUnionValue value) {
+      return (WindowedValue<OutputT>) value.getValue();
+    }
+  }
+
+  /**
+   * Wraps each {@link WindowedValue} in a {@link RawUnionValue} carrying a fixed tag.
+   */
+  private static class ToRawUnionValue<T> extends
+      MapFunction<WindowedValue<T>, RawUnionValue> {
+
+    private static final long serialVersionUID = 8648852871014813583L;
+    private final String tag;
+
+    ToRawUnionValue(String tag) {
+      this.tag = tag;
+    }
+
+    @Override
+    public RawUnionValue map(WindowedValue<T> windowedValue) {
+      return new RawUnionValue(tag, windowedValue);
+    }
+  }
+
+  /**
+   * Serializes {@code options} to JSON bytes with Jackson.
+   *
+   * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
+   */
+  public static byte[] serializePipelineOptions(PipelineOptions options) {
+    try {
+      return MAPPER.writeValueAsBytes(options);
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Deserializes {@link PipelineOptions} from JSON bytes such as those produced by
+   * {@link #serializePipelineOptions}.
+   *
+   * @throws RuntimeException wrapping the {@link IOException} if deserialization fails
+   */
+  public static PipelineOptions deserializePipelineOptions(byte[] serializedOptions) {
+    try {
+      return MAPPER.readValue(serializedOptions, PipelineOptions.class);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * A value tagged with the string identifier of the input it came from.
+   *
+   * <p>Adapted from {@code org.apache.beam.sdk.transforms.join.RawUnionValue}, whose tag is an
+   * {@code int}; because the tag here is a {@link String}, it is compared by content.
+   */
+  public static class RawUnionValue {
+    private final String unionTag;
+    private final Object value;
+
+    /**
+     * Constructs a partial union from the given union tag and value.
+     */
+    public RawUnionValue(String unionTag, Object value) {
+      this.unionTag = unionTag;
+      this.value = value;
+    }
+
+    public String getUnionTag() {
+      return unionTag;
+    }
+
+    public Object getValue() {
+      return value;
+    }
+
+    @Override
+    public String toString() {
+      return unionTag + ":" + value;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      RawUnionValue that = (RawUnionValue) o;
+      // Tags are Strings: compare by content, not by reference.
+      return Objects.equals(unionTag, that.unionTag) && Objects.equals(value, that.value);
+    }
+
+    @Override
+    public int hashCode() {
+      // Null-safe, and consistent with equals(): equal tags and values give equal hashes.
+      return 31 * Objects.hashCode(unionTag) + Objects.hashCode(value);
+    }
+  }
+
+}
diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/package-info.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/package-info.java
new file mode 100644
index 0000000..ab2a6ea
--- /dev/null
+++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Utilities for translators.
+ */
+package org.apache.beam.runners.gearpump.translators.utils;
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrarTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrarTest.java
new file mode 100644
index 0000000..9a01d20
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/GearpumpRunnerRegistrarTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GearpumpRunnerRegistrar}.
+ */
+public class GearpumpRunnerRegistrarTest {
+
+  /** The runner can be selected by its fully qualified class name. */
+  @Test
+  public void testFullName() {
+    String[] args =
+        new String[] {String.format("--runner=%s", GearpumpRunner.class.getName())};
+    PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create();
+    assertEquals(GearpumpRunner.class, opts.getRunner());
+  }
+
+  /** The runner can be selected by its simple class name. */
+  @Test
+  public void testClassName() {
+    String[] args =
+        new String[] {String.format("--runner=%s", GearpumpRunner.class.getSimpleName())};
+    PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create();
+    assertEquals(GearpumpRunner.class, opts.getRunner());
+  }
+
+  /** The registrar exposes exactly {@link GearpumpPipelineOptions}. */
+  @Test
+  public void testOptions() {
+    assertEquals(
+        ImmutableList.of(GearpumpPipelineOptions.class),
+        new GearpumpRunnerRegistrar.Options().getPipelineOptions());
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java
new file mode 100644
index 0000000..994856b
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/PipelineOptionsTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.Maps;
+import com.typesafe.config.Config;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.gearpump.cluster.ClusterConfig;
+import org.apache.gearpump.cluster.embedded.EmbeddedCluster;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GearpumpPipelineOptions}.
+ */
+public class PipelineOptionsTest {
+
+  /**
+   * The embedded cluster and the serializer map are not carried through a Jackson round trip,
+   * while the application name and parallelism are.
+   */
+  @Test
+  public void testIgnoredFieldSerialization() throws IOException {
+    String appName = "forTest";
+    Map<String, String> serializers = Maps.newHashMap();
+    serializers.put("classA", "SerializerA");
+    GearpumpPipelineOptions options = PipelineOptionsFactory.create()
+        .as(GearpumpPipelineOptions.class);
+    Config config = ClusterConfig.master(null);
+    EmbeddedCluster cluster = new EmbeddedCluster(config);
+    options.setSerializers(serializers);
+    options.setApplicationName(appName);
+    options.setEmbeddedCluster(cluster);
+    options.setParallelism(10);
+
+    byte[] serializedOptions = serialize(options);
+    GearpumpPipelineOptions deserializedOptions = new ObjectMapper()
+        .readValue(serializedOptions, PipelineOptions.class).as(GearpumpPipelineOptions.class);
+
+    assertNull(deserializedOptions.getEmbeddedCluster());
+    assertNull(deserializedOptions.getSerializers());
+    assertEquals(10, deserializedOptions.getParallelism());
+    assertEquals(appName, deserializedOptions.getApplicationName());
+  }
+
+  /** Serializes {@code obj} to JSON bytes, wrapping any Jackson failure in an unchecked one. */
+  private byte[] serialize(Object obj) {
+    try {
+      return new ObjectMapper().writeValueAsBytes(obj);
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException("Couldn't serialize PipelineOptions.", e);
+    }
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslatorTest.java
new file mode 100644
index 0000000..511eed1
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/CreateGearpumpPCollectionViewTranslatorTest.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.junit.Test;
+
+/**
+ * Tests for {@link CreateGearpumpPCollectionViewTranslator}.
+ */
+public class CreateGearpumpPCollectionViewTranslatorTest {
+
+  /**
+   * Translating the transform registers the input's stream as the output stream of the view
+   * returned by {@code getView()}.
+   */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testTranslate() {
+    CreateGearpumpPCollectionViewTranslator underTest =
+        new CreateGearpumpPCollectionViewTranslator();
+
+    // Input side: the context resolves the current input to a stream.
+    PValue input = mock(PValue.class);
+    JavaStream inputStream = mock(JavaStream.class);
+    TranslationContext context = mock(TranslationContext.class);
+    when(context.getInput()).thenReturn(input);
+    when(context.getInputStream(input)).thenReturn(inputStream);
+
+    // Transform side: the view that should end up backed by that stream.
+    PCollectionView expectedView = mock(PCollectionView.class);
+    CreateStreamingGearpumpView.CreateGearpumpPCollectionView transform =
+        mock(CreateStreamingGearpumpView.CreateGearpumpPCollectionView.class);
+    when(transform.getView()).thenReturn(expectedView);
+
+    underTest.translate(transform, context);
+
+    verify(context, times(1)).setOutputStream(expectedView, inputStream);
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java
new file mode 100644
index 0000000..1262177
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/FlattenPCollectionsTranslatorTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.beam.runners.gearpump.translators.io.UnboundedSourceWrapper;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.gearpump.streaming.dsl.api.functions.MapFunction;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.source.DataSource;
+import org.junit.Test;
+import org.mockito.ArgumentMatcher;
+
+/** Tests for {@link FlattenPCollectionsTranslator}. */
+public class FlattenPCollectionsTranslatorTest {
+
+  private final FlattenPCollectionsTranslator translator = new FlattenPCollectionsTranslator();
+  private final Flatten.PCollections transform = mock(Flatten.PCollections.class);
+
+  /** Matches any {@link DataSource} that is an {@link UnboundedSourceWrapper}. */
+  private static class UnboundedSourceWrapperMatcher extends ArgumentMatcher<DataSource> {
+    @Override
+    public boolean matches(Object o) {
+      return o instanceof UnboundedSourceWrapper;
+    }
+  }
+
+  /** An empty flatten is translated into a source stream over an unbounded source wrapper. */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testTranslateWithEmptyCollection() {
+    PCollection mockOutput = mock(PCollection.class);
+    TranslationContext translationContext = mock(TranslationContext.class);
+
+    when(translationContext.getInputs()).thenReturn(Collections.emptyMap());
+    when(translationContext.getOutput()).thenReturn(mockOutput);
+
+    translator.translate(transform, translationContext);
+    verify(translationContext).getSourceStream(argThat(new UnboundedSourceWrapperMatcher()));
+  }
+
+  /** A single input's stream is used directly as the output stream. */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testTranslateWithOneCollection() {
+    JavaStream javaStream = mock(JavaStream.class);
+    TranslationContext translationContext = mock(TranslationContext.class);
+
+    Map<TupleTag<?>, PValue> inputs = new HashMap<>();
+    TupleTag tag = mock(TupleTag.class);
+    PCollection mockCollection = mock(PCollection.class);
+    inputs.put(tag, mockCollection);
+
+    when(translationContext.getInputs()).thenReturn(inputs);
+    when(translationContext.getInputStream(mockCollection)).thenReturn(javaStream);
+
+    PValue mockOutput = mock(PValue.class);
+    when(translationContext.getOutput()).thenReturn(mockOutput);
+
+    translator.translate(transform, translationContext);
+    verify(translationContext, times(1)).setOutputStream(mockOutput, javaStream);
+  }
+
+  /** Distinct inputs are merged into one stream named after the transform. */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testWithMoreThanOneCollections() {
+    String transformName = "transform";
+    when(transform.getName()).thenReturn(transformName);
+
+    JavaStream javaStream1 = mock(JavaStream.class);
+    JavaStream javaStream2 = mock(JavaStream.class);
+    JavaStream mergedStream = mock(JavaStream.class);
+    TranslationContext translationContext = mock(TranslationContext.class);
+
+    Map<TupleTag<?>, PValue> inputs = new HashMap<>();
+    TupleTag tag1 = mock(TupleTag.class);
+    PCollection mockCollection1 = mock(PCollection.class);
+    inputs.put(tag1, mockCollection1);
+
+    TupleTag tag2 = mock(TupleTag.class);
+    PCollection mockCollection2 = mock(PCollection.class);
+    inputs.put(tag2, mockCollection2);
+
+    PCollection output = mock(PCollection.class);
+
+    when(translationContext.getInputs()).thenReturn(inputs);
+    when(translationContext.getInputStream(mockCollection1)).thenReturn(javaStream1);
+    when(translationContext.getInputStream(mockCollection2)).thenReturn(javaStream2);
+    // HashMap iteration order is unspecified, so stub the merge in both directions.
+    when(javaStream1.merge(javaStream2, 1, transformName)).thenReturn(mergedStream);
+    when(javaStream2.merge(javaStream1, 1, transformName)).thenReturn(mergedStream);
+
+    when(translationContext.getOutput()).thenReturn(output);
+
+    translator.translate(transform, translationContext);
+    verify(translationContext).setOutputStream(output, mergedStream);
+  }
+
+  /**
+   * When the same input appears twice, its stream is mapped through a {@code "dummy"}
+   * function and the result is merged under the transform's name.
+   */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testWithDuplicatedCollections() {
+    String transformName = "transform";
+    when(transform.getName()).thenReturn(transformName);
+
+    JavaStream javaStream1 = mock(JavaStream.class);
+    TranslationContext translationContext = mock(TranslationContext.class);
+
+    Map<TupleTag<?>, PValue> inputs = new HashMap<>();
+    TupleTag tag1 = mock(TupleTag.class);
+    PCollection mockCollection1 = mock(PCollection.class);
+    inputs.put(tag1, mockCollection1);
+
+    TupleTag tag2 = mock(TupleTag.class);
+    inputs.put(tag2, mockCollection1);
+
+    when(translationContext.getInputs()).thenReturn(inputs);
+    when(translationContext.getInputStream(mockCollection1)).thenReturn(javaStream1);
+
+    translator.translate(transform, translationContext);
+    verify(javaStream1).map(any(MapFunction.class), eq("dummy"));
+    verify(javaStream1).merge(any(JavaStream.class), eq(1), eq(transformName));
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java
new file mode 100644
index 0000000..d5b931b
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/GroupByKeyTranslatorTest.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+
+import java.time.Instant;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.beam.runners.gearpump.translators.GroupByKeyTranslator.GearpumpWindowFn;
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.gearpump.streaming.dsl.window.api.WindowFunction;
+import org.apache.gearpump.streaming.dsl.window.impl.Window;
+import org.joda.time.Duration;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Tests for {@link GroupByKeyTranslator}.
+ *
+ * <p>Parameterized over {@link TimestampCombiner}; {@link #testGearpumpWindowFn()} does not
+ * depend on the parameter and simply runs once per value.
+ */
+@RunWith(Parameterized.class)
+public class GroupByKeyTranslatorTest {
+
+  /** Each Beam window of an element is converted to the matching Gearpump window. */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testGearpumpWindowFn() {
+    GearpumpWindowFn windowFn = new GearpumpWindowFn(true);
+    List<BoundedWindow> windows =
+        Lists.newArrayList(
+            new IntervalWindow(new org.joda.time.Instant(0), new org.joda.time.Instant(10)),
+            new IntervalWindow(new org.joda.time.Instant(5), new org.joda.time.Instant(15)));
+
+    WindowFunction.Context<WindowedValue<String>> context =
+        new WindowFunction.Context<WindowedValue<String>>() {
+          @Override
+          public Instant timestamp() {
+            return Instant.EPOCH;
+          }
+
+          @Override
+          public WindowedValue<String> element() {
+            return WindowedValue.of(
+                "v1", new org.joda.time.Instant(6), windows, PaneInfo.NO_FIRING);
+          }
+        };
+
+    Window[] result = windowFn.apply(context);
+    List<Window> expected = Lists.newArrayList();
+    for (BoundedWindow w : windows) {
+      expected.add(TranslatorUtils.boundedWindowToGearpumpWindow(w));
+    }
+    assertThat(result, equalTo(expected.toArray()));
+  }
+
+  @Parameterized.Parameters(name = "{index}: {0}")
+  public static Iterable<TimestampCombiner> data() {
+    return ImmutableList.of(
+        TimestampCombiner.EARLIEST,
+        TimestampCombiner.LATEST,
+        TimestampCombiner.END_OF_WINDOW);
+  }
+
+  @Parameterized.Parameter(0)
+  public TimestampCombiner timestampCombiner;
+
+  /**
+   * An element is keyed by the timestamp {@link #timestampCombiner} assigns for its window. A
+   * fixed element timestamp inside the window keeps the test deterministic.
+   */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testKeyedByTimestamp() {
+    WindowFn sessionWindows = Sessions.withGapDuration(Duration.millis(10));
+    BoundedWindow window =
+        new IntervalWindow(new org.joda.time.Instant(0), new org.joda.time.Instant(10));
+    GroupByKeyTranslator.KeyedByTimestamp keyedByTimestamp =
+        new GroupByKeyTranslator.KeyedByTimestamp(sessionWindows, timestampCombiner);
+    WindowedValue<KV<String, String>> value =
+        WindowedValue.of(
+            KV.of("key", "val"), new org.joda.time.Instant(5), window, PaneInfo.NO_FIRING);
+    KV<org.joda.time.Instant, WindowedValue<KV<String, String>>> result =
+        keyedByTimestamp.map(value);
+    org.joda.time.Instant time =
+        timestampCombiner.assign(window,
+            sessionWindows.getOutputTime(value.getTimestamp(), window));
+    assertThat(result, equalTo(KV.of(time, value)));
+  }
+
+  /**
+   * Folding two elements whose session windows overlap yields a single merged window, keys
+   * combined by {@link #timestampCombiner}, and the values of both elements.
+   */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testMerge() {
+    WindowFn sessionWindows = Sessions.withGapDuration(Duration.millis(10));
+    GroupByKeyTranslator.Merge merge = new GroupByKeyTranslator.Merge(sessionWindows,
+        timestampCombiner);
+    org.joda.time.Instant key1 = new org.joda.time.Instant(5);
+    WindowedValue<KV<String, String>> value1 =
+        WindowedValue.of(
+            KV.of("key1", "value1"),
+            key1,
+            new IntervalWindow(new org.joda.time.Instant(5), new org.joda.time.Instant(10)),
+            PaneInfo.NO_FIRING);
+
+    org.joda.time.Instant key2 = new org.joda.time.Instant(10);
+    WindowedValue<KV<String, String>> value2 =
+        WindowedValue.of(
+            KV.of("key2", "value2"),
+            key2,
+            new IntervalWindow(new org.joda.time.Instant(9), new org.joda.time.Instant(14)),
+            PaneInfo.NO_FIRING);
+
+    KV<org.joda.time.Instant, WindowedValue<KV<String, List<String>>>> result1 =
+        merge.fold(KV.<org.joda.time.Instant, WindowedValue<KV<String, List<String>>>>of(
+            null, null), KV.of(key1, value1));
+    assertThat(result1.getKey(), equalTo(key1));
+    assertThat(result1.getValue().getValue().getValue(), equalTo(Lists.newArrayList("value1")));
+
+    KV<org.joda.time.Instant, WindowedValue<KV<String, List<String>>>> result2 =
+        merge.fold(result1, KV.of(key2, value2));
+    assertThat(result2.getKey(), equalTo(timestampCombiner.combine(key1, key2)));
+    Collection<? extends BoundedWindow> resultWindows = result2.getValue().getWindows();
+    assertThat(resultWindows.size(), equalTo(1));
+    IntervalWindow expectedWindow =
+        new IntervalWindow(new org.joda.time.Instant(5), new org.joda.time.Instant(14));
+    assertThat(resultWindows.toArray()[0], equalTo(expectedWindow));
+    assertThat(
+        result2.getValue().getValue().getValue(), equalTo(Lists.newArrayList("value1", "value2")));
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslatorTest.java
new file mode 100644
index 0000000..20ee1a2
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadBoundedTranslatorTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.runners.gearpump.translators.io.BoundedSourceWrapper;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.source.DataSource;
+import org.junit.Test;
+import org.mockito.ArgumentMatcher;
+
+/** Tests for {@link ReadBoundedTranslator}. */
+public class ReadBoundedTranslatorTest {
+
+  /** Matches any {@link DataSource} that is a {@link BoundedSourceWrapper}. */
+  private static class BoundedSourceWrapperMatcher extends ArgumentMatcher<DataSource> {
+    @Override
+    public boolean matches(Object o) {
+      return o instanceof BoundedSourceWrapper;
+    }
+  }
+
+  /**
+   * A bounded read is translated into a source stream over a {@link BoundedSourceWrapper}, and
+   * that stream becomes the transform's output stream.
+   */
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testTranslate() {
+    ReadBoundedTranslator translator = new ReadBoundedTranslator();
+    GearpumpPipelineOptions options =
+        PipelineOptionsFactory.create().as(GearpumpPipelineOptions.class);
+    Read.Bounded transform = mock(Read.Bounded.class);
+    BoundedSource source = mock(BoundedSource.class);
+    when(transform.getSource()).thenReturn(source);
+
+    TranslationContext translationContext = mock(TranslationContext.class);
+    when(translationContext.getPipelineOptions()).thenReturn(options);
+
+    JavaStream stream = mock(JavaStream.class);
+    PValue mockOutput = mock(PValue.class);
+    when(translationContext.getOutput()).thenReturn(mockOutput);
+    when(translationContext.getSourceStream(any(DataSource.class))).thenReturn(stream);
+
+    translator.translate(transform, translationContext);
+    verify(translationContext).getSourceStream(argThat(new BoundedSourceWrapperMatcher()));
+    verify(translationContext).setOutputStream(mockOutput, stream);
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslatorTest.java
new file mode 100644
index 0000000..f27b568
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/ReadUnboundedTranslatorTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.runners.gearpump.translators.io.UnboundedSourceWrapper;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.gearpump.streaming.dsl.javaapi.JavaStream;
+import org.apache.gearpump.streaming.source.DataSource;
+import org.junit.Test;
+import org.mockito.ArgumentMatcher;
+
+/** Tests for {@link ReadUnboundedTranslator}. */
+public class ReadUnboundedTranslatorTest {
+  /** Accepts only {@link DataSource}s that are {@link UnboundedSourceWrapper} instances. */
+  private static class UnboundedSourceWrapperMatcher extends ArgumentMatcher<DataSource> {
+    @Override
+    public boolean matches(Object o) {
+      return o instanceof UnboundedSourceWrapper;
+    }
+  }
+
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testTranslate() {
+    ReadUnboundedTranslator translator = new ReadUnboundedTranslator();
+    GearpumpPipelineOptions options =
+        PipelineOptionsFactory.create().as(GearpumpPipelineOptions.class);
+    Read.Unbounded transform = mock(Read.Unbounded.class);
+    UnboundedSource source = mock(UnboundedSource.class);
+    when(transform.getSource()).thenReturn(source);
+
+    TranslationContext translationContext = mock(TranslationContext.class);
+    when(translationContext.getPipelineOptions()).thenReturn(options);
+
+    JavaStream stream = mock(JavaStream.class);
+    PValue mockOutput = mock(PValue.class);
+    when(translationContext.getOutput()).thenReturn(mockOutput);
+    when(translationContext.getSourceStream(any(DataSource.class))).thenReturn(stream);
+    // Expect an UnboundedSourceWrapper to reach Gearpump and its stream to become the output.
+    translator.translate(transform, translationContext);
+    verify(translationContext).getSourceStream(argThat(new UnboundedSourceWrapperMatcher()));
+    verify(translationContext).setOutputStream(mockOutput, stream);
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslatorTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslatorTest.java
new file mode 100644
index 0000000..06ccaaf
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/WindowAssignTranslatorTest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Lists;
+import java.util.ArrayList;
+import java.util.Iterator;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+
+/** Tests for {@link WindowAssignTranslator}. */
+public class WindowAssignTranslatorTest {
+
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testAssignWindowsWithSlidingWindow() {
+    WindowFn slidingWindows = SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5));
+    WindowAssignTranslator.AssignWindows<String> assignWindows =
+        new WindowAssignTranslator.AssignWindows(slidingWindows);
+    // A 10ms window sliding every 5ms is expected to place timestamp 1 in [0, 10) and [-5, 5).
+    String value = "v1";
+    Instant timestamp = new Instant(1);
+    WindowedValue<String> windowedValue =
+        WindowedValue.timestampedValueInGlobalWindow(value, timestamp);
+    ArrayList<WindowedValue<String>> expected = new ArrayList<>();
+    expected.add(
+        WindowedValue.of(
+            value,
+            timestamp,
+            new IntervalWindow(new Instant(0), new Instant(10)),
+            PaneInfo.NO_FIRING));
+    expected.add(
+        WindowedValue.of(
+            value,
+            timestamp,
+            new IntervalWindow(new Instant(-5), new Instant(5)),
+            PaneInfo.NO_FIRING));
+
+    Iterator<WindowedValue<String>> result = assignWindows.flatMap(windowedValue);
+    assertThat(Lists.newArrayList(result), equalTo(expected));
+  }
+
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testAssignWindowsWithSessions() {
+    WindowFn sessions = Sessions.withGapDuration(Duration.millis(10));
+    WindowAssignTranslator.AssignWindows<String> assignWindows =
+        new WindowAssignTranslator.AssignWindows(sessions);
+    // A lone element is expected to open the session [timestamp, timestamp + gap).
+    String value = "v1";
+    Instant timestamp = new Instant(1);
+    WindowedValue<String> windowedValue =
+        WindowedValue.timestampedValueInGlobalWindow(value, timestamp);
+    ArrayList<WindowedValue<String>> expected = new ArrayList<>();
+    expected.add(
+        WindowedValue.of(
+            value,
+            timestamp,
+            new IntervalWindow(new Instant(1), new Instant(11)),
+            PaneInfo.NO_FIRING));
+
+    Iterator<WindowedValue<String>> result = assignWindows.flatMap(windowedValue);
+    assertThat(Lists.newArrayList(result), equalTo(expected));
+  }
+
+  @Test
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  public void testAssignWindowsGlobal() {
+    WindowFn globalWindows = new GlobalWindows();
+    WindowAssignTranslator.AssignWindows<String> assignWindows =
+        new WindowAssignTranslator.AssignWindows(globalWindows);
+    // Assigning a global-window element to GlobalWindows is expected to leave it unchanged.
+    String value = "v1";
+    Instant timestamp = new Instant(1);
+    WindowedValue<String> windowedValue =
+        WindowedValue.timestampedValueInGlobalWindow(value, timestamp);
+    ArrayList<WindowedValue<String>> expected = new ArrayList<>();
+    expected.add(WindowedValue.timestampedValueInGlobalWindow(value, timestamp));
+
+    Iterator<WindowedValue<String>> result = assignWindows.flatMap(windowedValue);
+    assertThat(Lists.newArrayList(result), equalTo(expected));
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java
new file mode 100644
index 0000000..cc4284f
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/GearpumpSourceTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.io;
+
+import com.google.common.collect.Lists;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.runners.gearpump.translators.utils.TranslatorUtils;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.Source;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.gearpump.DefaultMessage;
+import org.apache.gearpump.Message;
+import org.apache.gearpump.streaming.source.Watermark;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@link GearpumpSource}. */
+public class GearpumpSourceTest {
+  private static final List<TimestampedValue<String>> TEST_VALUES =
+      Lists.newArrayList(
+          TimestampedValue.of("a", BoundedWindow.TIMESTAMP_MIN_VALUE),
+          TimestampedValue.of("b", new org.joda.time.Instant(0)),
+          TimestampedValue.of("c", new org.joda.time.Instant(53)),
+          TimestampedValue.of("d", BoundedWindow.TIMESTAMP_MAX_VALUE));
+
+  /** A {@link GearpumpSource} that delegates reading to a {@link ValuesSource}. */
+  private static class SourceForTest<T> extends GearpumpSource<T> {
+    private final ValuesSource<T> valuesSource;
+
+    SourceForTest(PipelineOptions options, ValuesSource<T> valuesSource) {
+      super(options);
+      this.valuesSource = valuesSource;
+    }
+
+    @Override
+    protected Source.Reader<T> createReader(PipelineOptions options) throws IOException {
+      return this.valuesSource.createReader(options, null);
+    }
+  }
+
+  @Test
+  public void testGearpumpSource() {
+    GearpumpPipelineOptions options =
+        PipelineOptionsFactory.create().as(GearpumpPipelineOptions.class);
+    ValuesSource<TimestampedValue<String>> valuesSource =
+        new ValuesSource<>(
+            TEST_VALUES, TimestampedValue.TimestampedValueCoder.of(StringUtf8Coder.of()));
+    SourceForTest<TimestampedValue<String>> sourceForTest =
+        new SourceForTest<>(options, valuesSource);
+    sourceForTest.open(null, Instant.EPOCH);
+
+    for (int i = 0; i < TEST_VALUES.size(); i++) {
+      TimestampedValue<String> value = TEST_VALUES.get(i);
+      // Check the watermark before reading: the source advances when opened, so it should match
+      // the pending element's timestamp; the last element is expected to report Watermark.MAX().
+      if (i < TEST_VALUES.size() - 1) {
+        Instant expectedWatermark = TranslatorUtils.jodaTimeToJava8Time(value.getTimestamp());
+        Assert.assertEquals(expectedWatermark, sourceForTest.getWatermark());
+      } else {
+        Assert.assertEquals(Watermark.MAX(), sourceForTest.getWatermark());
+      }
+
+      Message expectedMsg =
+          new DefaultMessage(
+              WindowedValue.timestampedValueInGlobalWindow(value, value.getTimestamp()),
+              value.getTimestamp().getMillis());
+      Message message = sourceForTest.read();
+      Assert.assertEquals(expectedMsg, message);
+    }
+
+    Assert.assertNull(sourceForTest.read());
+    Assert.assertEquals(Watermark.MAX(), sourceForTest.getWatermark());
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java
new file mode 100644
index 0000000..439e1b1
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/io/ValueSoureTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.gearpump.translators.io;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigValueFactory;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.beam.runners.gearpump.GearpumpPipelineOptions;
+import org.apache.beam.runners.gearpump.GearpumpRunner;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.gearpump.cluster.ClusterConfig;
+import org.apache.gearpump.cluster.embedded.EmbeddedCluster;
+import org.apache.gearpump.util.Constants;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@link ValuesSource}, run end to end on an embedded Gearpump cluster. */
+public class ValueSoureTest {
+  @Test
+  public void testValueSource() {
+    GearpumpPipelineOptions options =
+        PipelineOptionsFactory.create().as(GearpumpPipelineOptions.class);
+    Config config = ClusterConfig.master(null);
+    config =
+        config.withValue(Constants.APPLICATION_TOTAL_RETRIES(), ConfigValueFactory.fromAnyRef(0));
+    List<String> values = Lists.newArrayList("1", "2", "3", "4", "5");
+    EmbeddedCluster cluster = new EmbeddedCluster(config);
+    cluster.start();
+    try {
+      options.setEmbeddedCluster(cluster);
+      options.setRunner(GearpumpRunner.class);
+      options.setParallelism(1);
+      Pipeline p = Pipeline.create(options);
+      ValuesSource<String> source = new ValuesSource<>(values, StringUtf8Coder.of());
+      p.apply(Read.from(source)).apply(ParDo.of(new ResultCollector()));
+      p.run().waitUntilFinish();
+    } finally {
+      cluster.stop();
+    }
+    Assert.assertEquals(Sets.newHashSet(values), ResultCollector.RESULTS);
+  }
+
+  /** Records every element it processes so the test can inspect them afterwards. */
+  private static class ResultCollector extends DoFn<Object, Void> {
+    private static final Set<Object> RESULTS = Collections.synchronizedSet(new HashSet<>());
+
+    @ProcessElement
+    public void processElement(ProcessContext c) throws Exception {
+      RESULTS.add(c.element());
+    }
+  }
+}
diff --git a/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java
new file mode 100644
index 0000000..6ebe59b
--- /dev/null
+++ b/runners/gearpump/src/test/java/org/apache/beam/runners/gearpump/translators/utils/TranslatorUtilsTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.gearpump.translators.utils;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Lists;
+
+import java.time.Instant;
+import java.util.List;
+
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.gearpump.streaming.dsl.window.impl.Window;
+import org.junit.Test;
+
+/** Tests for {@link TranslatorUtils}. */
+public class TranslatorUtilsTest {
+
+  /** Joda-Time instants (keys) paired with the java.time instants they convert to and from. */
+  private static final List<KV<org.joda.time.Instant, Instant>> TEST_VALUES =
+      Lists.newArrayList(
+          KV.of(new org.joda.time.Instant(0), Instant.EPOCH),
+          KV.of(new org.joda.time.Instant(42), Instant.ofEpochMilli(42)),
+          KV.of(new org.joda.time.Instant(Long.MIN_VALUE), Instant.ofEpochMilli(Long.MIN_VALUE)),
+          KV.of(new org.joda.time.Instant(Long.MAX_VALUE), Instant.ofEpochMilli(Long.MAX_VALUE)));
+
+  @Test
+  public void testJodaTimeAndJava8TimeConversion() {
+    for (KV<org.joda.time.Instant, Instant> pair : TEST_VALUES) {
+      org.joda.time.Instant jodaInstant = pair.getKey();
+      Instant javaInstant = pair.getValue();
+      assertThat(TranslatorUtils.jodaTimeToJava8Time(jodaInstant), equalTo(javaInstant));
+      assertThat(TranslatorUtils.java8TimeToJodaTime(javaInstant), equalTo(jodaInstant));
+    }
+  }
+
+  @Test
+  public void testBoundedWindowToGearpumpWindow() {
+    org.joda.time.Instant jodaMax = new org.joda.time.Instant(Long.MAX_VALUE);
+    Instant javaMax = Instant.ofEpochMilli(Long.MAX_VALUE);
+    IntervalWindow fromEpoch = new IntervalWindow(new org.joda.time.Instant(0), jodaMax);
+    assertThat(TranslatorUtils.boundedWindowToGearpumpWindow(fromEpoch),
+        equalTo(Window.apply(Instant.EPOCH, javaMax)));
+    IntervalWindow full = new IntervalWindow(new org.joda.time.Instant(Long.MIN_VALUE), jodaMax);
+    assertThat(TranslatorUtils.boundedWindowToGearpumpWindow(full),
+        equalTo(Window.apply(Instant.ofEpochMilli(Long.MIN_VALUE), javaMax)));
+    // The global window is expected to end one millisecond after its maxTimestamp().
+    BoundedWindow globalWindow = GlobalWindow.INSTANCE;
+    assertThat(TranslatorUtils.boundedWindowToGearpumpWindow(globalWindow),
+        equalTo(Window.apply(Instant.ofEpochMilli(BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis()),
+            Instant.ofEpochMilli(globalWindow.maxTimestamp().getMillis() + 1))));
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml
index 1181b79..46352fb 100644
--- a/runners/google-cloud-dataflow-java/pom.xml
+++ b/runners/google-cloud-dataflow-java/pom.xml
@@ -245,6 +245,7 @@
-->
<excludedGroups>
org.apache.beam.sdk.testing.LargeKeys$Above10MB,
+ org.apache.beam.sdk.testing.UsesAttemptedMetrics,
org.apache.beam.sdk.testing.UsesDistributionMetrics,
org.apache.beam.sdk.testing.UsesGaugeMetrics,
org.apache.beam.sdk.testing.UsesSetState,
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/AssignWindows.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/AssignWindows.java
index 572b005..7d1dadb 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/AssignWindows.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/AssignWindows.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.dataflow;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
@@ -59,8 +58,8 @@
transform.getOutputStrategyInternal(input.getWindowingStrategy());
if (transform.getWindowFn() != null) {
// If the windowFn changed, we create a primitive, and run the AssignWindows operation here.
- return PCollection.<T>createPrimitiveOutputInternal(
- input.getPipeline(), outputStrategy, input.isBounded());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), outputStrategy, input.isBounded(), input.getCoder());
} else {
// If the windowFn didn't change, we just run a pass-through transform and then set the
// new windowing strategy.
@@ -69,7 +68,7 @@
public void processElement(ProcessContext c) throws Exception {
c.output(c.element());
}
- })).setWindowingStrategyInternal(outputStrategy);
+ })).setWindowingStrategyInternal(outputStrategy).setCoder(input.getCoder());
}
}
@@ -79,11 +78,6 @@
}
@Override
- protected Coder<?> getDefaultOutputCoder(PCollection<T> input) {
- return input.getCoder();
- }
-
- @Override
protected String getKindString() {
return "Window.Into()";
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
index ad3faed1..9a77b4b 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
@@ -1258,18 +1258,13 @@
@Override
public PCollection<KV<K1, Iterable<KV<K2, V>>>> expand(PCollection<KV<K1, KV<K2, V>>> input) {
- PCollection<KV<K1, Iterable<KV<K2, V>>>> rval =
- PCollection.<KV<K1, Iterable<KV<K2, V>>>>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- IsBounded.BOUNDED);
-
- @SuppressWarnings({"unchecked", "rawtypes"})
- KvCoder<K1, KV<K2, V>> inputCoder = (KvCoder) input.getCoder();
- rval.setCoder(
- KvCoder.of(inputCoder.getKeyCoder(),
- IterableCoder.of(inputCoder.getValueCoder())));
- return rval;
+ @SuppressWarnings("unchecked")
+ KvCoder<K1, KV<K2, V>> inputCoder = (KvCoder<K1, KV<K2, V>>) input.getCoder();
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ KvCoder.of(inputCoder.getKeyCoder(), IterableCoder.of(inputCoder.getValueCoder())));
}
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java
index caad7f8..3b01d69 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/CreateDataflowView.java
@@ -37,9 +37,8 @@
@Override
public PCollection<ElemT> expand(PCollection<ElemT> input) {
- return PCollection.<ElemT>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
}
public PCollectionView<ViewT> getView() {
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 762ac9f..6999616 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -92,6 +92,7 @@
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.extensions.gcp.storage.PathValidator;
import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.UnboundedSource;
@@ -320,7 +321,7 @@
overridesBuilder.add(
PTransformOverride.of(
PTransformMatchers.classEqualTo(PubsubUnboundedSource.class),
- new ReflectiveRootOverrideFactory(StreamingPubsubIORead.class, this)));
+ new StreamingPubsubIOReadOverrideFactory()));
}
if (!hasExperiment(options, "enable_custom_pubsub_sink")) {
overridesBuilder.add(
@@ -358,11 +359,11 @@
// must precede it
PTransformOverride.of(
PTransformMatchers.classEqualTo(Read.Bounded.class),
- new ReflectiveRootOverrideFactory(StreamingBoundedRead.class, this)))
+ new StreamingBoundedReadOverrideFactory()))
.add(
PTransformOverride.of(
PTransformMatchers.classEqualTo(Read.Unbounded.class),
- new ReflectiveRootOverrideFactory(StreamingUnboundedRead.class, this)))
+ new StreamingUnboundedReadOverrideFactory()))
.add(
PTransformOverride.of(
PTransformMatchers.classEqualTo(View.CreatePCollectionView.class),
@@ -447,38 +448,6 @@
}
}
- private static class ReflectiveRootOverrideFactory<T>
- implements PTransformOverrideFactory<
- PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>> {
- private final Class<PTransform<PBegin, PCollection<T>>> replacement;
- private final DataflowRunner runner;
-
- private ReflectiveRootOverrideFactory(
- Class<PTransform<PBegin, PCollection<T>>> replacement, DataflowRunner runner) {
- this.replacement = replacement;
- this.runner = runner;
- }
-
- @Override
- public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
- AppliedPTransform<PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>> transform) {
- PTransform<PInput, PCollection<T>> original = transform.getTransform();
- return PTransformReplacement.of(
- transform.getPipeline().begin(),
- InstanceBuilder.ofType(replacement)
- .withArg(DataflowRunner.class, runner)
- .withArg(
- (Class<? super PTransform<PInput, PCollection<T>>>) original.getClass(), original)
- .build());
- }
-
- @Override
- public Map<PValue, ReplacementOutput> mapOutputs(
- Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) {
- return ReplacementOutputs.singleton(outputs, newOutput);
- }
- }
-
private String debuggerMessage(String projectId, String uniquifier) {
return String.format("To debug your job, visit Google Cloud Debugger at: "
+ "https://console.developers.google.com/debug?project=%s&dbgee=%s",
@@ -837,6 +806,24 @@
// PubsubIO translations
// ================================================================================
+  /** Replaces each {@link PubsubUnboundedSource} with a {@link StreamingPubsubIORead} of it. */
+  private static class StreamingPubsubIOReadOverrideFactory
+      implements PTransformOverrideFactory<
+          PBegin, PCollection<PubsubMessage>, PubsubUnboundedSource> {
+    @Override
+    public PTransformReplacement<PBegin, PCollection<PubsubMessage>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<PubsubMessage>, PubsubUnboundedSource> transform) {
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(), new StreamingPubsubIORead(transform.getTransform()));
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PValue> outputs, PCollection<PubsubMessage> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
/**
* Suppress application of {@link PubsubUnboundedSource#expand} in streaming mode so that we can
* instead defer to Windmill's implementation.
@@ -845,9 +832,7 @@
extends PTransform<PBegin, PCollection<PubsubMessage>> {
private final PubsubUnboundedSource transform;
- /** Builds an instance of this class from the overridden transform. */
- @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
- public StreamingPubsubIORead(DataflowRunner runner, PubsubUnboundedSource transform) {
+ public StreamingPubsubIORead(PubsubUnboundedSource transform) {
this.transform = transform;
}
@@ -857,9 +842,11 @@
@Override
public PCollection<PubsubMessage> expand(PBegin input) {
- return PCollection.<PubsubMessage>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(new PubsubMessageWithAttributesCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ IsBounded.UNBOUNDED,
+ new PubsubMessageWithAttributesCoder());
}
@Override
@@ -1128,12 +1115,7 @@
@Override
public PCollection<byte[]> expand(PBegin input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), isBounded);
- }
-
- @Override
- protected Coder<?> getDefaultOutputCoder() {
- return ByteArrayCoder.of();
+ input.getPipeline(), WindowingStrategy.globalDefault(), isBounded, ByteArrayCoder.of());
}
private static class Translator implements TransformTranslator<Impulse> {
@@ -1156,6 +1138,24 @@
}
}
+  /** Overrides {@link Read.Unbounded} with a {@link StreamingUnboundedRead} built from it. */
+  private static class StreamingUnboundedReadOverrideFactory<T>
+      implements PTransformOverrideFactory<PBegin, PCollection<T>, Read.Unbounded<T>> {
+    @Override
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, Read.Unbounded<T>> transform) {
+      PBegin root = transform.getPipeline().begin();
+      Read.Unbounded<T> original = transform.getTransform();
+      return PTransformReplacement.of(root, new StreamingUnboundedRead<>(original));
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
/**
* Specialized implementation for
* {@link org.apache.beam.sdk.io.Read.Unbounded Read.Unbounded} for the
@@ -1167,18 +1165,11 @@
private static class StreamingUnboundedRead<T> extends PTransform<PBegin, PCollection<T>> {
private final UnboundedSource<T, ?> source;
- /** Builds an instance of this class from the overridden transform. */
- @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
- public StreamingUnboundedRead(DataflowRunner runner, Read.Unbounded<T> transform) {
+ public StreamingUnboundedRead(Read.Unbounded<T> transform) {
this.source = transform.getSource();
}
@Override
- protected Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
- }
-
- @Override
public final PCollection<T> expand(PBegin input) {
source.validate();
@@ -1205,13 +1196,9 @@
@Override
public final PCollection<ValueWithRecordId<T>> expand(PInput input) {
- return PCollection.<ValueWithRecordId<T>>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED);
- }
-
- @Override
- protected Coder<ValueWithRecordId<T>> getDefaultOutputCoder() {
- return ValueWithRecordId.ValueWithRecordIdCoder.of(source.getDefaultOutputCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED,
+ ValueWithRecordId.ValueWithRecordIdCoder.of(source.getOutputCoder()));
}
@Override
@@ -1275,6 +1262,24 @@
}
}
+  /** Overrides {@link Read.Bounded} with a {@link StreamingBoundedRead} built from it. */
+  private static class StreamingBoundedReadOverrideFactory<T>
+      implements PTransformOverrideFactory<PBegin, PCollection<T>, Read.Bounded<T>> {
+    @Override
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, Read.Bounded<T>> transform) {
+      PBegin root = transform.getPipeline().begin();
+      Read.Bounded<T> original = transform.getTransform();
+      return PTransformReplacement.of(root, new StreamingBoundedRead<>(original));
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
/**
* Specialized implementation for {@link org.apache.beam.sdk.io.Read.Bounded Read.Bounded} for the
* Dataflow runner in streaming mode.
@@ -1282,18 +1285,11 @@
private static class StreamingBoundedRead<T> extends PTransform<PBegin, PCollection<T>> {
private final BoundedSource<T> source;
- /** Builds an instance of this class from the overridden transform. */
- @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
- public StreamingBoundedRead(DataflowRunner runner, Read.Bounded<T> transform) {
+ public StreamingBoundedRead(Read.Bounded<T> transform) {
this.source = transform.getSource();
}
@Override
- protected Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
- }
-
- @Override
public final PCollection<T> expand(PBegin input) {
source.validate();
@@ -1403,15 +1399,19 @@
static class CombineGroupedValues<K, InputT, OutputT>
extends PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>> {
private final Combine.GroupedValues<K, InputT, OutputT> original;
+ private final Coder<KV<K, OutputT>> outputCoder;
- CombineGroupedValues(GroupedValues<K, InputT, OutputT> original) {
+ CombineGroupedValues(
+ GroupedValues<K, InputT, OutputT> original, Coder<KV<K, OutputT>> outputCoder) {
this.original = original;
+ this.outputCoder = outputCoder;
}
@Override
public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, Iterable<InputT>>> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(),
+ outputCoder);
}
public Combine.GroupedValues<K, InputT, OutputT> getOriginalCombine() {
@@ -1432,7 +1432,9 @@
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
- new CombineGroupedValues<>(transform.getTransform()));
+ new CombineGroupedValues<>(
+ transform.getTransform(),
+ PTransformReplacements.getSingletonMainOutput(transform).getCoder()));
}
@Override
@@ -1501,10 +1503,11 @@
}
try {
+ List<PCollectionView<?>> sideInputs =
+ WriteFilesTranslation.getDynamicDestinationSideInputs(transform);
+ FileBasedSink sink = WriteFilesTranslation.getSink(transform);
WriteFiles<UserT, DestinationT, OutputT> replacement =
- WriteFiles.<UserT, DestinationT, OutputT>to(
- WriteFilesTranslation.<UserT, DestinationT, OutputT>getSink(transform),
- WriteFilesTranslation.<UserT, OutputT>getFormatFunction(transform));
+ WriteFiles.to(sink).withSideInputs(sideInputs);
if (WriteFilesTranslation.isWindowedWrites(transform)) {
replacement = replacement.withWindowedWrites();
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
index 8611d3c..9252c64 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
@@ -22,6 +22,7 @@
import org.apache.beam.runners.core.construction.ForwardingPTransform;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi.DisplayData;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -49,7 +50,9 @@
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
- new ParDoSingle<>(transform.getTransform()));
+ new ParDoSingle<>(
+ transform.getTransform(),
+ PTransformReplacements.getSingletonMainOutput(transform).getCoder()));
}
/**
@@ -58,15 +61,18 @@
public static class ParDoSingle<InputT, OutputT>
extends ForwardingPTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
private final ParDo.SingleOutput<InputT, OutputT> original;
+ private final Coder<OutputT> outputCoder;
- private ParDoSingle(ParDo.SingleOutput<InputT, OutputT> original) {
+ private ParDoSingle(SingleOutput<InputT, OutputT> original, Coder<OutputT> outputCoder) {
this.original = original;
+ this.outputCoder = outputCoder;
}
@Override
public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(),
+ outputCoder);
}
public DoFn<InputT, OutputT> getFn() {
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java
index fc010f8..7b65950 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/SplittableParDoOverrides.java
@@ -64,7 +64,7 @@
appliedTransform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(appliedTransform),
- SplittableParDo.forJavaParDo(appliedTransform.getTransform()));
+ SplittableParDo.forAppliedParDo(appliedTransform));
}
@Override
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index 9a0bdf8..f756065 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -606,11 +606,6 @@
// Return a value unrelated to the input.
return input.getPipeline().apply(Create.of(1, 2, 3, 4));
}
-
- @Override
- protected Coder<?> getDefaultOutputCoder() {
- return VarIntCoder.of();
- }
}
/**
@@ -626,11 +621,6 @@
return PDone.in(input.getPipeline());
}
-
- @Override
- protected Coder<?> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
}
/**
@@ -650,10 +640,13 @@
// Fails here when attempting to construct a tuple with an unbound object.
return PCollectionTuple.of(sumTag, sum)
- .and(doneTag, PCollection.<Void>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- input.isBounded()));
+ .and(
+ doneTag,
+ PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ input.isBounded(),
+ VoidCoder.of()));
}
}
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index 94985f8..55264a1 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -64,6 +64,7 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
+import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.DataflowRunner.StreamingShardedWriteFactory;
import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions;
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
@@ -71,7 +72,6 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.extensions.gcp.auth.NoopCredentialFactory;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
@@ -81,6 +81,7 @@
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.WriteFiles;
+import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -102,6 +103,8 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.GcsUtil;
@@ -949,15 +952,11 @@
@Override
public PCollection<Integer> expand(PCollection<Integer> input) {
- return PCollection.<Integer>createPrimitiveOutputInternal(
+ return PCollection.createPrimitiveOutputInternal(
input.getPipeline(),
WindowingStrategy.globalDefault(),
- input.isBounded());
- }
-
- @Override
- protected Coder<?> getDefaultOutputCoder(PCollection<Integer> input) {
- return input.getCoder();
+ input.isBounded(),
+ input.getCoder());
}
}
@@ -1267,8 +1266,7 @@
StreamingShardedWriteFactory<Object, Void, Object> factory =
new StreamingShardedWriteFactory<>(p.getOptions());
- WriteFiles<Object, Void, Object> original =
- WriteFiles.to(new TestSink(tmpFolder.toString()), SerializableFunctions.identity());
+ WriteFiles<Object, Void, Object> original = WriteFiles.to(new TestSink(tmpFolder.toString()));
PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of()));
AppliedPTransform<PCollection<Object>, PDone, WriteFiles<Object, Void, Object>>
originalApplication =
@@ -1286,18 +1284,37 @@
assertThat(replacement.getNumShards().get(), equalTo(expectedNumShards));
}
- private static class TestSink extends FileBasedSink<Object, Void> {
+ private static class TestSink extends FileBasedSink<Object, Void, Object> {
@Override
public void validate(PipelineOptions options) {}
TestSink(String tmpFolder) {
super(
StaticValueProvider.of(FileSystems.matchNewResource(tmpFolder, true)),
- DynamicFileDestinations.constant(null));
+ DynamicFileDestinations.constant(
+ new FilenamePolicy() {
+ @Override
+ public ResourceId windowedFilename(
+ int shardNumber,
+ int numShards,
+ BoundedWindow window,
+ PaneInfo paneInfo,
+ OutputFileHints outputFileHints) {
+ throw new UnsupportedOperationException("should not be called");
+ }
+
+ @Nullable
+ @Override
+ public ResourceId unwindowedFilename(
+ int shardNumber, int numShards, OutputFileHints outputFileHints) {
+ throw new UnsupportedOperationException("should not be called");
+ }
+ },
+ SerializableFunctions.identity()));
}
@Override
- public WriteOperation<Object, Void> createWriteOperation() {
+ public WriteOperation<Void, Object> createWriteOperation() {
throw new IllegalArgumentException("Should not be used");
}
}
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowGroupByKeyTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowGroupByKeyTest.java
index 737b408..c198ebf 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowGroupByKeyTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowGroupByKeyTest.java
@@ -26,6 +26,7 @@
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.gcp.storage.NoopPathValidator;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
@@ -36,7 +37,6 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Duration;
import org.junit.Before;
@@ -105,11 +105,11 @@
new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@Override
public PCollection<KV<String, Integer>> expand(PBegin input) {
- return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.UNBOUNDED)
- .setTypeDescriptor(new TypeDescriptor<KV<String, Integer>>() {});
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
}
});
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowViewTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowViewTest.java
index dea96b9..e2e42a6 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowViewTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/transforms/DataflowViewTest.java
@@ -21,6 +21,9 @@
import org.apache.beam.runners.dataflow.DataflowRunner;
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.gcp.storage.NoopPathValidator;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
@@ -33,7 +36,6 @@
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
@@ -94,11 +96,11 @@
new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@Override
public PCollection<KV<String, Integer>> expand(PBegin input) {
- return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.UNBOUNDED)
- .setTypeDescriptor(new TypeDescriptor<KV<String, Integer>>() {});
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
}
})
.apply(view);
diff --git a/runners/pom.xml b/runners/pom.xml
index 4cba41a..39a9811 100644
--- a/runners/pom.xml
+++ b/runners/pom.xml
@@ -55,6 +55,15 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>java8</id>
+ <activation>
+ <jdk>[1.8,)</jdk>
+ </activation>
+ <modules>
+ <module>gearpump</module>
+ </modules>
+ </profile>
</profiles>
<build>
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index 7f70204..b2e7fe4 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -35,7 +35,6 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<kafka.version>0.9.0.1</kafka.version>
- <jackson.version>2.4.4</jackson.version>
<dropwizard.metrics.version>3.1.2</dropwizard.metrics.version>
</properties>
@@ -77,7 +76,8 @@
<excludedGroups>
org.apache.beam.sdk.testing.UsesSplittableParDo,
org.apache.beam.sdk.testing.UsesCommittedMetrics,
- org.apache.beam.sdk.testing.UsesTestStream
+ org.apache.beam.sdk.testing.UsesTestStream,
+ org.apache.beam.sdk.testing.UsesCustomWindowMerging
</excludedGroups>
<parallel>none</parallel>
<forkCount>1</forkCount>
@@ -183,22 +183,12 @@
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-core</artifactId>
- <version>${jackson.version}</version>
- </dependency>
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
- <version>${jackson.version}</version>
- </dependency>
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- <version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>org.apache.avro</groupId>
<artifactId>avro</artifactId>
+ <scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java
index 27f2ec8..a9f2c445 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/NamedAggregators.java
@@ -19,18 +19,11 @@
package org.apache.beam.runners.spark.aggregators;
import com.google.common.base.Function;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Map;
import java.util.TreeMap;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.Combine;
/**
@@ -52,17 +45,6 @@
}
/**
- * Constructs a new named aggregators instance that contains a mapping from the specified
- * `named` to the associated initial state.
- *
- * @param name Name of aggregator.
- * @param state Associated State.
- */
- public NamedAggregators(String name, State<?, ?, ?> state) {
- this.mNamedAggregators.put(name, state);
- }
-
- /**
* @param name Name of aggregator to retrieve.
* @param typeClass Type class to cast the value to.
* @param <T> Type to be returned.
@@ -152,79 +134,4 @@
Combine.CombineFn<InputT, InterT, OutputT> getCombineFn();
}
- /**
- * @param <InputT> Input data type
- * @param <InterT> Intermediate data type (useful for averages)
- * @param <OutputT> Output data type
- */
- public static class CombineFunctionState<InputT, InterT, OutputT>
- implements State<InputT, InterT, OutputT> {
-
- private Combine.CombineFn<InputT, InterT, OutputT> combineFn;
- private Coder<InputT> inCoder;
- private SparkRuntimeContext ctxt;
- private transient InterT state;
-
- public CombineFunctionState(
- Combine.CombineFn<InputT, InterT, OutputT> combineFn,
- Coder<InputT> inCoder,
- SparkRuntimeContext ctxt) {
- this.combineFn = combineFn;
- this.inCoder = inCoder;
- this.ctxt = ctxt;
- this.state = combineFn.createAccumulator();
- }
-
- @Override
- public void update(InputT element) {
- combineFn.addInput(state, element);
- }
-
- @Override
- public State<InputT, InterT, OutputT> merge(State<InputT, InterT, OutputT> other) {
- this.state = combineFn.mergeAccumulators(ImmutableList.of(current(), other.current()));
- return this;
- }
-
- @Override
- public InterT current() {
- return state;
- }
-
- @Override
- public OutputT render() {
- return combineFn.extractOutput(state);
- }
-
- @Override
- public Combine.CombineFn<InputT, InterT, OutputT> getCombineFn() {
- return combineFn;
- }
-
- private void writeObject(ObjectOutputStream oos) throws IOException {
- oos.writeObject(ctxt);
- oos.writeObject(combineFn);
- oos.writeObject(inCoder);
- try {
- combineFn.getAccumulatorCoder(ctxt.getCoderRegistry(), inCoder)
- .encode(state, oos);
- } catch (CannotProvideCoderException e) {
- throw new IllegalStateException("Could not determine coder for accumulator", e);
- }
- }
-
- @SuppressWarnings("unchecked")
- private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
- ctxt = (SparkRuntimeContext) ois.readObject();
- combineFn = (Combine.CombineFn<InputT, InterT, OutputT>) ois.readObject();
- inCoder = (Coder<InputT>) ois.readObject();
- try {
- state = combineFn.getAccumulatorCoder(ctxt.getCoderRegistry(), inCoder)
- .decode(ois);
- } catch (CannotProvideCoderException e) {
- throw new IllegalStateException("Could not determine coder for accumulator", e);
- }
- }
- }
-
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java
index fdcea99..d485d25 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/CreateStream.java
@@ -27,7 +27,6 @@
import java.util.List;
import java.util.Queue;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -203,11 +202,9 @@
@Override
public PCollection<T> expand(PBegin input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), PCollection.IsBounded.UNBOUNDED);
- }
-
- @Override
- protected Coder<T> getDefaultOutputCoder() throws CannotProvideCoderException {
- return coder;
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ coder);
}
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
index 3b48caf..ae873a3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
@@ -140,8 +140,8 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
+ public Coder<T> getOutputCoder() {
+ return source.getOutputCoder();
}
public Coder<CheckpointMarkT> getCheckpointMarkCoder() {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
index 20aca5f..b7000b4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
@@ -20,8 +20,8 @@
import static com.google.common.base.Preconditions.checkArgument;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.spark.api.java.JavaSparkContext$;
@@ -58,7 +58,7 @@
private static final Logger LOG = LoggerFactory.getLogger(SourceDStream.class);
private final UnboundedSource<T, CheckpointMarkT> unboundedSource;
- private final SparkRuntimeContext runtimeContext;
+ private final SerializablePipelineOptions options;
private final Duration boundReadDuration;
// Reader cache interval to expire readers if they haven't been accessed in the last microbatch.
// The reason we expire readers is that upon executor death/addition source split ownership can be
@@ -81,20 +81,20 @@
SourceDStream(
StreamingContext ssc,
UnboundedSource<T, CheckpointMarkT> unboundedSource,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
Long boundMaxRecords) {
super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
this.unboundedSource = unboundedSource;
- this.runtimeContext = runtimeContext;
+ this.options = options;
- SparkPipelineOptions options = runtimeContext.getPipelineOptions().as(
+ SparkPipelineOptions sparkOptions = options.get().as(
SparkPipelineOptions.class);
// Reader cache expiration interval. 50% of batch interval is added to accommodate latency.
- this.readerCacheInterval = 1.5 * options.getBatchIntervalMillis();
+ this.readerCacheInterval = 1.5 * sparkOptions.getBatchIntervalMillis();
- this.boundReadDuration = boundReadDuration(options.getReadTimePercentage(),
- options.getMinReadTimeMillis());
+ this.boundReadDuration = boundReadDuration(sparkOptions.getReadTimePercentage(),
+ sparkOptions.getMinReadTimeMillis());
// set initial parallelism once.
this.initialParallelism = ssc().sparkContext().defaultParallelism();
checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
@@ -104,7 +104,7 @@
try {
this.numPartitions =
createMicrobatchSource()
- .split(options)
+ .split(sparkOptions)
.size();
} catch (Exception e) {
throw new RuntimeException(e);
@@ -116,7 +116,7 @@
RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd =
new SourceRDD.Unbounded<>(
ssc().sparkContext(),
- runtimeContext,
+ options,
createMicrobatchSource(),
numPartitions);
return scala.Option.apply(rdd);
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index 01cc176..a225e0f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -28,9 +28,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
@@ -66,7 +66,7 @@
private static final Logger LOG = LoggerFactory.getLogger(SourceRDD.Bounded.class);
private final BoundedSource<T> source;
- private final SparkRuntimeContext runtimeContext;
+ private final SerializablePipelineOptions options;
private final int numPartitions;
private final String stepName;
private final Accumulator<MetricsContainerStepMap> metricsAccum;
@@ -79,11 +79,11 @@
public Bounded(
SparkContext sc,
BoundedSource<T> source,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
String stepName) {
super(sc, NIL, JavaSparkContext$.MODULE$.<WindowedValue<T>>fakeClassTag());
this.source = source;
- this.runtimeContext = runtimeContext;
+ this.options = options;
// the input parallelism is determined by Spark's scheduler backend.
// when running on YARN/SparkDeploy it's the result of max(totalCores, 2).
// when running on Mesos it's 8.
@@ -103,14 +103,14 @@
long desiredSizeBytes = DEFAULT_BUNDLE_SIZE;
try {
desiredSizeBytes = source.getEstimatedSizeBytes(
- runtimeContext.getPipelineOptions()) / numPartitions;
+ options.get()) / numPartitions;
} catch (Exception e) {
LOG.warn("Failed to get estimated bundle size for source {}, using default bundle "
+ "size of {} bytes.", source, DEFAULT_BUNDLE_SIZE);
}
try {
List<? extends Source<T>> partitionedSources = source.split(desiredSizeBytes,
- runtimeContext.getPipelineOptions());
+ options.get());
Partition[] partitions = new SourcePartition[partitionedSources.size()];
for (int i = 0; i < partitionedSources.size(); i++) {
partitions[i] = new SourcePartition<>(id(), i, partitionedSources.get(i));
@@ -125,7 +125,7 @@
private BoundedSource.BoundedReader<T> createReader(SourcePartition<T> partition) {
try {
return ((BoundedSource<T>) partition.source).createReader(
- runtimeContext.getPipelineOptions());
+ options.get());
} catch (IOException e) {
throw new RuntimeException("Failed to create reader from a BoundedSource.", e);
}
@@ -293,7 +293,7 @@
UnboundedSource.CheckpointMark> extends RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> {
private final MicrobatchSource<T, CheckpointMarkT> microbatchSource;
- private final SparkRuntimeContext runtimeContext;
+ private final SerializablePipelineOptions options;
private final Partitioner partitioner;
// to satisfy Scala API.
@@ -302,12 +302,12 @@
.asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();
public Unbounded(SparkContext sc,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
MicrobatchSource<T, CheckpointMarkT> microbatchSource,
int initialNumPartitions) {
super(sc, NIL,
JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
- this.runtimeContext = runtimeContext;
+ this.options = options;
this.microbatchSource = microbatchSource;
this.partitioner = new HashPartitioner(initialNumPartitions);
}
@@ -316,7 +316,7 @@
public Partition[] getPartitions() {
try {
final List<? extends Source<T>> partitionedSources =
- microbatchSource.split(runtimeContext.getPipelineOptions());
+ microbatchSource.split(options.get());
final Partition[] partitions = new CheckpointableSourcePartition[partitionedSources.size()];
for (int i = 0; i < partitionedSources.size(); i++) {
partitions[i] =
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index 7106c73..26af0c0 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -22,12 +22,12 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.stateful.StateSpecFunctions;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.streaming.UnboundedDataset;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
@@ -80,11 +80,11 @@
public static <T, CheckpointMarkT extends CheckpointMark> UnboundedDataset<T> read(
JavaStreamingContext jssc,
- SparkRuntimeContext rc,
+ SerializablePipelineOptions rc,
UnboundedSource<T, CheckpointMarkT> source,
String stepName) {
- SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
+ SparkPipelineOptions options = rc.get().as(SparkPipelineOptions.class);
Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
SourceDStream<T, CheckpointMarkT> sourceDStream =
new SourceDStream<>(jssc.ssc(), source, rc, maxRecordsPerBatch);
@@ -116,7 +116,7 @@
// output the actual (deserialized) stream.
WindowedValue.FullWindowedValueCoder<T> coder =
WindowedValue.FullWindowedValueCoder.of(
- source.getDefaultOutputCoder(),
+ source.getOutputCoder(),
GlobalWindow.Coder.INSTANCE);
JavaDStream<WindowedValue<T>> readUnboundedStream =
mapWithStateDStream
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index 1385e07..1263618 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -31,6 +31,7 @@
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.UnsupportedSideInputReader;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TriggerTranslation;
import org.apache.beam.runners.core.metrics.CounterCell;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
@@ -38,7 +39,6 @@
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.translation.WindowingHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
@@ -108,11 +108,11 @@
final Coder<K> keyCoder,
final Coder<WindowedValue<InputT>> wvCoder,
final WindowingStrategy<?, W> windowingStrategy,
- final SparkRuntimeContext runtimeContext,
+ final SerializablePipelineOptions options,
final List<Integer> sourceIds) {
final long batchDurationMillis =
- runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis();
+ options.get().as(SparkPipelineOptions.class).getBatchIntervalMillis();
final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
final Coder<? extends BoundedWindow> wCoder =
@@ -123,7 +123,7 @@
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
long checkpointDurationMillis =
- runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class)
+ options.get().as(SparkPipelineOptions.class)
.getCheckpointDurationMillis();
// we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819.
@@ -268,7 +268,7 @@
outputHolder,
new UnsupportedSideInputReader("GroupAlsoByWindow"),
reduceFn,
- runtimeContext.getPipelineOptions());
+ options.get());
outputHolder.clear(); // clear before potential use.
if (!seq.isEmpty()) {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
index 549bd30..ca54715 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
@@ -27,12 +27,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
@@ -91,7 +91,7 @@
*
* <p>See also <a href="https://issues.apache.org/jira/browse/SPARK-4819">SPARK-4819</a>.</p>
*
- * @param runtimeContext A serializable {@link SparkRuntimeContext}.
+ * @param options A serializable {@link SerializablePipelineOptions}.
* @param <T> The type of the input stream elements.
* @param <CheckpointMarkT> The type of the {@link UnboundedSource.CheckpointMark}.
* @return The appropriate {@link org.apache.spark.streaming.StateSpec} function.
@@ -99,7 +99,7 @@
public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, State<Tuple2<byte[], Instant>>,
Tuple2<Iterable<byte[]>, Metadata>> mapSourceFunction(
- final SparkRuntimeContext runtimeContext, final String stepName) {
+ final SerializablePipelineOptions options, final String stepName) {
return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>,
State<Tuple2<byte[], Instant>>, Tuple2<Iterable<byte[]>, Metadata>>() {
@@ -151,7 +151,7 @@
try {
microbatchReader =
(MicrobatchSource.Reader)
- microbatchSource.getOrCreateReader(runtimeContext.getPipelineOptions(),
+ microbatchSource.getOrCreateReader(options.get(),
checkpointMark);
} catch (IOException e) {
throw new RuntimeException(e);
@@ -161,7 +161,7 @@
final List<byte[]> readValues = new ArrayList<>();
WindowedValue.FullWindowedValueCoder<T> coder =
WindowedValue.FullWindowedValueCoder.of(
- source.getDefaultOutputCoder(),
+ source.getOutputCoder(),
GlobalWindow.Coder.INSTANCE);
try {
// measure how long a read takes per-partition.
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 0c6c4d1..463e507 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -26,6 +26,7 @@
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -50,7 +51,6 @@
public class EvaluationContext {
private final JavaSparkContext jsc;
private JavaStreamingContext jssc;
- private final SparkRuntimeContext runtime;
private final Pipeline pipeline;
private final Map<PValue, Dataset> datasets = new LinkedHashMap<>();
private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>();
@@ -60,12 +60,13 @@
private final SparkPCollectionView pviews = new SparkPCollectionView();
private final Map<PCollection, Long> cacheCandidates = new HashMap<>();
private final PipelineOptions options;
+ private final SerializablePipelineOptions serializableOptions;
public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, PipelineOptions options) {
this.jsc = jsc;
this.pipeline = pipeline;
this.options = options;
- this.runtime = new SparkRuntimeContext(pipeline, options);
+ this.serializableOptions = new SerializablePipelineOptions(options);
}
public EvaluationContext(
@@ -90,8 +91,8 @@
return options;
}
- public SparkRuntimeContext getRuntimeContext() {
- return runtime;
+ public SerializablePipelineOptions getSerializableOptions() {
+ return serializableOptions;
}
public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
@@ -254,7 +255,7 @@
}
private String storageLevel() {
- return runtime.getPipelineOptions().as(SparkPipelineOptions.class).getStorageLevel();
+ return serializableOptions.get().as(SparkPipelineOptions.class).getStorageLevel();
}
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 23d5b32..7299583 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -34,8 +34,8 @@
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.runners.spark.util.SparkSideInputReader;
import org.apache.beam.sdk.transforms.DoFn;
@@ -59,11 +59,10 @@
public class MultiDoFnFunction<InputT, OutputT>
implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>, WindowedValue<?>> {
- private final Accumulator<NamedAggregators> aggAccum;
private final Accumulator<MetricsContainerStepMap> metricsAccum;
private final String stepName;
private final DoFn<InputT, OutputT> doFn;
- private final SparkRuntimeContext runtimeContext;
+ private final SerializablePipelineOptions options;
private final TupleTag<OutputT> mainOutputTag;
private final List<TupleTag<?>> additionalOutputTags;
private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
@@ -71,10 +70,9 @@
private final boolean stateful;
/**
- * @param aggAccum The Spark {@link Accumulator} that backs the Beam Aggregators.
* @param metricsAccum The Spark {@link Accumulator} that backs the Beam metrics.
* @param doFn The {@link DoFn} to be wrapped.
- * @param runtimeContext The {@link SparkRuntimeContext}.
+ * @param options The {@link SerializablePipelineOptions}.
* @param mainOutputTag The main output {@link TupleTag}.
* @param additionalOutputTags Additional {@link TupleTag output tags}.
* @param sideInputs Side inputs used in this {@link DoFn}.
@@ -82,21 +80,19 @@
* @param stateful Stateful {@link DoFn}.
*/
public MultiDoFnFunction(
- Accumulator<NamedAggregators> aggAccum,
Accumulator<MetricsContainerStepMap> metricsAccum,
String stepName,
DoFn<InputT, OutputT> doFn,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy,
boolean stateful) {
- this.aggAccum = aggAccum;
this.metricsAccum = metricsAccum;
this.stepName = stepName;
this.doFn = doFn;
- this.runtimeContext = runtimeContext;
+ this.options = options;
this.mainOutputTag = mainOutputTag;
this.additionalOutputTags = additionalOutputTags;
this.sideInputs = sideInputs;
@@ -140,7 +136,7 @@
final DoFnRunner<InputT, OutputT> doFnRunner =
DoFnRunners.simpleRunner(
- runtimeContext.getPipelineOptions(),
+ options.get(),
doFn,
new SparkSideInputReader(sideInputs),
outputManager,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAbstractCombineFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAbstractCombineFn.java
index 315f7fb..d8d71ff 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAbstractCombineFn.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAbstractCombineFn.java
@@ -30,6 +30,7 @@
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.runners.spark.util.SparkSideInputReader;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -48,16 +49,16 @@
* {@link org.apache.beam.sdk.transforms.Combine.CombineFn}.
*/
public class SparkAbstractCombineFn implements Serializable {
- protected final SparkRuntimeContext runtimeContext;
+ protected final SerializablePipelineOptions options;
protected final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
protected final WindowingStrategy<?, BoundedWindow> windowingStrategy;
public SparkAbstractCombineFn(
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
- this.runtimeContext = runtimeContext;
+ this.options = options;
this.sideInputs = sideInputs;
this.windowingStrategy = (WindowingStrategy<?, BoundedWindow>) windowingStrategy;
}
@@ -71,7 +72,7 @@
private transient SparkCombineContext combineContext;
protected SparkCombineContext ctxtForInput(WindowedValue<?> input) {
if (combineContext == null) {
- combineContext = new SparkCombineContext(runtimeContext.getPipelineOptions(),
+ combineContext = new SparkCombineContext(options.get(),
new SparkSideInputReader(sideInputs));
}
return combineContext.forInput(input);
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGlobalCombineFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGlobalCombineFn.java
index d0e9038..81416a3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGlobalCombineFn.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGlobalCombineFn.java
@@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -49,10 +50,10 @@
public SparkGlobalCombineFn(
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
- super(runtimeContext, sideInputs, windowingStrategy);
+ super(options, sideInputs, windowingStrategy);
this.combineFn = combineFn;
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowViaOutputBufferFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowViaOutputBufferFn.java
index d2a3424..fcf438c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowViaOutputBufferFn.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowViaOutputBufferFn.java
@@ -30,6 +30,7 @@
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.UnsupportedSideInputReader;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TriggerTranslation;
import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
@@ -55,18 +56,18 @@
private final WindowingStrategy<?, W> windowingStrategy;
private final StateInternalsFactory<K> stateInternalsFactory;
private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn;
- private final SparkRuntimeContext runtimeContext;
+ private final SerializablePipelineOptions options;
public SparkGroupAlsoByWindowViaOutputBufferFn(
WindowingStrategy<?, W> windowingStrategy,
StateInternalsFactory<K> stateInternalsFactory,
SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
Accumulator<NamedAggregators> accumulator) {
this.windowingStrategy = windowingStrategy;
this.stateInternalsFactory = stateInternalsFactory;
this.reduceFn = reduceFn;
- this.runtimeContext = runtimeContext;
+ this.options = options;
}
@Override
@@ -98,7 +99,7 @@
outputter,
new UnsupportedSideInputReader("GroupAlsoByWindow"),
reduceFn,
- runtimeContext.getPipelineOptions());
+ options.get());
// Process the grouped values.
reduceFnRunner.processElements(values);
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java
index 7ac8e7d..55392e9 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java
@@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -49,10 +50,10 @@
public SparkKeyedCombineFn(
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn,
- SparkRuntimeContext runtimeContext,
+ SerializablePipelineOptions options,
Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
WindowingStrategy<?, ?> windowingStrategy) {
- super(runtimeContext, sideInputs, windowingStrategy);
+ super(options, sideInputs, windowingStrategy);
this.combineFn = combineFn;
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
deleted file mode 100644
index f3fe99c..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import java.io.IOException;
-import java.io.Serializable;
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.io.FileSystems;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.common.ReflectHelpers;
-
-/**
- * The SparkRuntimeContext allows us to define useful features on the client side before our
- * data flow program is launched.
- */
-public class SparkRuntimeContext implements Serializable {
- private final String serializedPipelineOptions;
- private transient CoderRegistry coderRegistry;
-
- SparkRuntimeContext(Pipeline pipeline, PipelineOptions options) {
- this.serializedPipelineOptions = serializePipelineOptions(options);
- }
-
- /**
- * Use an {@link ObjectMapper} configured with any {@link Module}s in the class path allowing
- * for user specified configuration injection into the ObjectMapper. This supports user custom
- * types on {@link PipelineOptions}.
- */
- private static ObjectMapper createMapper() {
- return new ObjectMapper().registerModules(
- ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
- }
-
- private String serializePipelineOptions(PipelineOptions pipelineOptions) {
- try {
- return createMapper().writeValueAsString(pipelineOptions);
- } catch (JsonProcessingException e) {
- throw new IllegalStateException("Failed to serialize the pipeline options.", e);
- }
- }
-
- private static PipelineOptions deserializePipelineOptions(String serializedPipelineOptions) {
- try {
- return createMapper().readValue(serializedPipelineOptions, PipelineOptions.class);
- } catch (IOException e) {
- throw new IllegalStateException("Failed to deserialize the pipeline options.", e);
- }
- }
-
- public PipelineOptions getPipelineOptions() {
- return PipelineOptionsHolder.getOrInit(serializedPipelineOptions);
- }
-
- public CoderRegistry getCoderRegistry() {
- if (coderRegistry == null) {
- coderRegistry = CoderRegistry.createDefault();
- }
- return coderRegistry;
- }
-
- private static class PipelineOptionsHolder {
- // on executors, this should deserialize once.
- private static transient volatile PipelineOptions pipelineOptions = null;
-
- static PipelineOptions getOrInit(String serializedPipelineOptions) {
- if (pipelineOptions == null) {
- synchronized (PipelineOptionsHolder.class) {
- if (pipelineOptions == null) {
- pipelineOptions = deserializePipelineOptions(serializedPipelineOptions);
- }
- }
- // Register standard FileSystems.
- FileSystems.setDefaultPipelineOptions(pipelineOptions);
- }
- return pipelineOptions;
- }
- }
-}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
index 0ecfa75..b236ce7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.spark.translation;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
@@ -32,12 +31,7 @@
public PCollection<String> expand(PCollection<?> input) {
return PCollection.createPrimitiveOutputInternal(input.getPipeline(),
WindowingStrategy.globalDefault(),
- PCollection.IsBounded.BOUNDED);
+ PCollection.IsBounded.BOUNDED,
+ StringUtf8Coder.of());
}
-
- @Override
- public Coder getDefaultOutputCoder() {
- return StringUtf8Coder.of();
- }
-
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index ac5e0cd..e060e1d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -146,7 +146,7 @@
windowingStrategy,
new TranslationUtils.InMemoryStateInternalsFactory<K>(),
SystemReduceFn.<K, V, W>buffering(coder.getValueCoder()),
- context.getRuntimeContext(),
+ context.getSerializableOptions(),
accum));
context.putDataset(transform, new BoundedDataset<>(groupedAlsoByWindow));
@@ -171,7 +171,7 @@
(CombineWithContext.CombineFnWithContext<InputT, ?, OutputT>)
CombineFnUtil.toFnWithContext(transform.getFn());
final SparkKeyedCombineFn<K, InputT, ?, OutputT> sparkCombineFn =
- new SparkKeyedCombineFn<>(combineFn, context.getRuntimeContext(),
+ new SparkKeyedCombineFn<>(combineFn, context.getSerializableOptions(),
TranslationUtils.getSideInputs(transform.getSideInputs(), context),
context.getInput(transform).getWindowingStrategy());
@@ -222,18 +222,18 @@
final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder =
WindowedValue.FullWindowedValueCoder.of(oCoder,
windowingStrategy.getWindowFn().windowCoder());
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
final boolean hasDefault = transform.isInsertDefault();
final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn =
new SparkGlobalCombineFn<>(
combineFn,
- runtimeContext,
+ context.getSerializableOptions(),
TranslationUtils.getSideInputs(transform.getSideInputs(), context),
windowingStrategy);
final Coder<AccumT> aCoder;
try {
- aCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder);
+ aCoder = combineFn.getAccumulatorCoder(
+ context.getPipeline().getCoderRegistry(), iCoder);
} catch (CannotProvideCoderException e) {
throw new IllegalStateException("Could not determine coder for accumulator", e);
}
@@ -295,16 +295,16 @@
(CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>)
CombineFnUtil.toFnWithContext(transform.getFn());
final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
TranslationUtils.getSideInputs(transform.getSideInputs(), context);
final SparkKeyedCombineFn<K, InputT, AccumT, OutputT> sparkCombineFn =
- new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, windowingStrategy);
+ new SparkKeyedCombineFn<>(
+ combineFn, context.getSerializableOptions(), sideInputs, windowingStrategy);
final Coder<AccumT> vaCoder;
try {
vaCoder =
combineFn.getAccumulatorCoder(
- runtimeContext.getCoderRegistry(), inputCoder.getValueCoder());
+ context.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
} catch (CannotProvideCoderException e) {
throw new IllegalStateException("Could not determine coder for accumulator", e);
}
@@ -360,7 +360,6 @@
((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
- Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
Accumulator<MetricsContainerStepMap> metricsAccum = MetricsAccumulator.getInstance();
JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
@@ -370,11 +369,10 @@
|| signature.timerDeclarations().size() > 0;
MultiDoFnFunction<InputT, OutputT> multiDoFnFunction = new MultiDoFnFunction<>(
- aggAccum,
metricsAccum,
stepName,
doFn,
- context.getRuntimeContext(),
+ context.getSerializableOptions(),
transform.getMainOutputTag(),
transform.getAdditionalOutputTags().getAll(),
TranslationUtils.getSideInputs(transform.getSideInputs(), context),
@@ -452,10 +450,11 @@
public void evaluate(Read.Bounded<T> transform, EvaluationContext context) {
String stepName = context.getCurrentTransform().getFullName();
final JavaSparkContext jsc = context.getSparkContext();
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
// create an RDD from a BoundedSource.
- JavaRDD<WindowedValue<T>> input = new SourceRDD.Bounded<>(
- jsc.sc(), transform.getSource(), runtimeContext, stepName).toJavaRDD();
+ JavaRDD<WindowedValue<T>> input =
+ new SourceRDD.Bounded<>(
+ jsc.sc(), transform.getSource(), context.getSerializableOptions(), stepName)
+ .toJavaRDD();
// cache to avoid re-evaluation of the source by Spark's lazy DAG evaluation.
context.putDataset(transform, new BoundedDataset<>(input.cache()));
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index cd5bb3e..38d6119 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -32,9 +32,8 @@
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import javax.annotation.Nonnull;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.ConsoleIO;
import org.apache.beam.runners.spark.io.CreateStream;
@@ -50,7 +49,6 @@
import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
import org.apache.beam.runners.spark.translation.SparkPCollectionView;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
-import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.TransformEvaluator;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.translation.WindowingHelpers;
@@ -125,7 +123,7 @@
transform,
SparkUnboundedSource.read(
context.getStreamingContext(),
- context.getRuntimeContext(),
+ context.getSerializableOptions(),
transform.getSource(),
stepName));
}
@@ -273,7 +271,6 @@
JavaDStream<WindowedValue<KV<K, V>>> dStream = inputDataset.getDStream();
@SuppressWarnings("unchecked")
final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder();
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
@SuppressWarnings("unchecked")
final WindowingStrategy<?, W> windowingStrategy =
(WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy();
@@ -303,7 +300,7 @@
coder.getKeyCoder(),
wvCoder,
windowingStrategy,
- runtimeContext,
+ context.getSerializableOptions(),
streamSources);
context.putDataset(transform, new UnboundedDataset<>(outStream, streamSources));
@@ -336,7 +333,7 @@
((UnboundedDataset<KV<K, Iterable<InputT>>>) context.borrowDataset(transform));
JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> dStream = unboundedDataset.getDStream();
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
+ final SerializablePipelineOptions options = context.getSerializableOptions();
final SparkPCollectionView pviews = context.getPViews();
JavaDStream<WindowedValue<KV<K, OutputT>>> outStream = dStream.transform(
@@ -347,7 +344,7 @@
call(JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> rdd)
throws Exception {
SparkKeyedCombineFn<K, InputT, ?, OutputT> combineFnWithContext =
- new SparkKeyedCombineFn<>(fn, runtimeContext,
+ new SparkKeyedCombineFn<>(fn, options,
TranslationUtils.getSideInputs(transform.getSideInputs(),
new JavaSparkContext(rdd.context()), pviews),
windowingStrategy);
@@ -374,7 +371,7 @@
final DoFn<InputT, OutputT> doFn = transform.getFn();
rejectSplittable(doFn);
rejectStateAndTimers(doFn);
- final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
+ final SerializablePipelineOptions options = context.getSerializableOptions();
final SparkPCollectionView pviews = context.getPViews();
final WindowingStrategy<?, ?> windowingStrategy =
context.getInput(transform).getWindowingStrategy();
@@ -393,8 +390,6 @@
@Override
public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
- final Accumulator<NamedAggregators> aggAccum =
- AggregatorsAccumulator.getInstance();
final Accumulator<MetricsContainerStepMap> metricsAccum =
MetricsAccumulator.getInstance();
final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
@@ -405,11 +400,10 @@
pviews);
return rdd.mapPartitionsToPair(
new MultiDoFnFunction<>(
- aggAccum,
metricsAccum,
stepName,
doFn,
- runtimeContext,
+ options,
transform.getMainOutputTag(),
transform.getAdditionalOutputTags().getAll(),
sideInputs,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SinglePrimitiveOutputPTransform.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SinglePrimitiveOutputPTransform.java
deleted file mode 100644
index 299f5ba..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SinglePrimitiveOutputPTransform.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.util;
-
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollection.IsBounded;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.WindowingStrategy;
-
-/**
- * A {@link PTransform} wrapping another transform.
- */
-public class SinglePrimitiveOutputPTransform<T> extends PTransform<PInput, PCollection<T>> {
- private PTransform<PInput, PCollection<T>> transform;
-
- public SinglePrimitiveOutputPTransform(PTransform<PInput, PCollection<T>> transform) {
- this.transform = transform;
- }
-
- @Override
- public PCollection<T> expand(PInput input) {
- try {
- PCollection<T> collection = PCollection.<T>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
- collection.setCoder(transform.getDefaultOutputCoder(input, collection));
- return collection;
- } catch (CannotProvideCoderException e) {
- throw new IllegalArgumentException(
- "Unable to infer a coder and no Coder was specified. "
- + "Please set a coder by invoking Create.withCoder() explicitly.",
- e);
- }
- }
-}
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkRuntimeContextTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkRuntimeContextTest.java
deleted file mode 100644
index e8f578a..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkRuntimeContextTest.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation;
-
-import static org.junit.Assert.assertEquals;
-
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonDeserializer;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.Module;
-import com.fasterxml.jackson.databind.SerializerProvider;
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.fasterxml.jackson.databind.module.SimpleModule;
-import com.google.auto.service.AutoService;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.testing.CrashingRunner;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/**
- * Tests for {@link SparkRuntimeContext}.
- */
-@RunWith(JUnit4.class)
-public class SparkRuntimeContextTest {
- /** PipelineOptions used to test auto registration of Jackson modules. */
- public interface JacksonIncompatibleOptions extends PipelineOptions {
- JacksonIncompatible getJacksonIncompatible();
- void setJacksonIncompatible(JacksonIncompatible value);
- }
-
- /** A Jackson {@link Module} to test auto-registration of modules. */
- @AutoService(Module.class)
- public static class RegisteredTestModule extends SimpleModule {
- public RegisteredTestModule() {
- super("RegisteredTestModule");
- setMixInAnnotation(JacksonIncompatible.class, JacksonIncompatibleMixin.class);
- }
- }
-
- /** A class which Jackson does not know how to serialize/deserialize. */
- public static class JacksonIncompatible {
- private final String value;
- public JacksonIncompatible(String value) {
- this.value = value;
- }
- }
-
- /** A Jackson mixin used to add annotations to other classes. */
- @JsonDeserialize(using = JacksonIncompatibleDeserializer.class)
- @JsonSerialize(using = JacksonIncompatibleSerializer.class)
- public static final class JacksonIncompatibleMixin {}
-
- /** A Jackson deserializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleDeserializer extends
- JsonDeserializer<JacksonIncompatible> {
-
- @Override
- public JacksonIncompatible deserialize(JsonParser jsonParser,
- DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
- return new JacksonIncompatible(jsonParser.readValueAs(String.class));
- }
- }
-
- /** A Jackson serializer for {@link JacksonIncompatible}. */
- public static class JacksonIncompatibleSerializer extends JsonSerializer<JacksonIncompatible> {
-
- @Override
- public void serialize(JacksonIncompatible jacksonIncompatible, JsonGenerator jsonGenerator,
- SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
- jsonGenerator.writeString(jacksonIncompatible.value);
- }
- }
-
- @Test
- public void testSerializingPipelineOptionsWithCustomUserType() throws Exception {
- PipelineOptions options = PipelineOptionsFactory.fromArgs("--jacksonIncompatible=\"testValue\"")
- .as(JacksonIncompatibleOptions.class);
- options.setRunner(CrashingRunner.class);
- Pipeline p = Pipeline.create(options);
- SparkRuntimeContext context = new SparkRuntimeContext(p, options);
-
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- try (ObjectOutputStream outputStream = new ObjectOutputStream(baos)) {
- outputStream.writeObject(context);
- }
- try (ObjectInputStream inputStream =
- new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
- SparkRuntimeContext copy = (SparkRuntimeContext) inputStream.readObject();
- assertEquals("testValue",
- copy.getPipelineOptions().as(JacksonIncompatibleOptions.class)
- .getJacksonIncompatible().value);
- }
- }
-}
diff --git a/sdks/common/runner-api/pom.xml b/sdks/common/runner-api/pom.xml
index 8bc4123..e138ca8 100644
--- a/sdks/common/runner-api/pom.xml
+++ b/sdks/common/runner-api/pom.xml
@@ -65,11 +65,14 @@
<artifactId>protobuf-maven-plugin</artifactId>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+ <pluginId>grpc-java</pluginId>
+ <pluginArtifact>io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
</configuration>
<executions>
<execution>
<goals>
<goal>compile</goal>
+ <goal>compile-custom</goal>
</goals>
</execution>
</executions>
@@ -82,5 +85,25 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</dependency>
+
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-core</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-protobuf</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-stub</artifactId>
+ </dependency>
</dependencies>
</project>
diff --git a/sdks/common/runner-api/src/main/proto/beam_job_api.proto b/sdks/common/runner-api/src/main/proto/beam_job_api.proto
new file mode 100644
index 0000000..7be14cc
--- /dev/null
+++ b/sdks/common/runner-api/src/main/proto/beam_job_api.proto
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Protocol Buffers describing the Job API, the API for communicating with a runner
+ * for job submission over gRPC.
+ */
+
+syntax = "proto3";
+
+package org.apache.beam.runner_api.v1;
+
+option java_package = "org.apache.beam.sdk.common.runner.v1";
+option java_outer_classname = "JobApi";
+
+import "beam_runner_api.proto";
+import "google/protobuf/struct.proto";
+
+
+// Job Service for running RunnerAPI pipelines
+service JobService {
+ // Submit the job for execution
+ rpc run (SubmitJobRequest) returns (SubmitJobResponse) {}
+
+ // Get the current state of the job
+ rpc getState (GetJobStateRequest) returns (GetJobStateResponse) {}
+
+ // Cancel the job
+ rpc cancel (CancelJobRequest) returns (CancelJobResponse) {}
+
+ // Subscribe to a stream of state changes of the job; the current state of the job is returned immediately as the first response.
+ rpc getStateStream (GetJobStateRequest) returns (stream GetJobStateResponse) {}
+
+ // Subscribe to a stream of state changes and messages from the job
+ rpc getMessageStream (JobMessagesRequest) returns (stream JobMessagesResponse) {}
+}
+
+
+// Submit is a synchronous request that returns a jobId.
+// Throws error GRPC_STATUS_UNAVAILABLE if the server is down.
+// Throws error ALREADY_EXISTS if the jobName is reused, as runners are permitted to deduplicate based on the name of the job.
+// Throws error UNKNOWN for all other issues.
+message SubmitJobRequest {
+ org.apache.beam.runner_api.v1.Pipeline pipeline = 1; // (required)
+ google.protobuf.Struct pipelineOptions = 2; // (required)
+ string jobName = 3; // (required)
+}
+
+message SubmitJobResponse {
+ // JobId is used as an identifier for the job in all future calls.
+ string jobId = 1; // (required)
+}
+
+
+// Cancel is a synchronous request that returns a jobState.
+// Throws error GRPC_STATUS_UNAVAILABLE if the server is down.
+// Throws error NOT_FOUND if the jobId is not found.
+message CancelJobRequest {
+ string jobId = 1; // (required)
+
+}
+
+// Valid responses include any terminal state or CANCELLING
+message CancelJobResponse {
+ JobState.JobStateType state = 1; // (required)
+}
+
+
+// GetState is a synchronous request that returns a jobState.
+// Throws error GRPC_STATUS_UNAVAILABLE if the server is down.
+// Throws error NOT_FOUND if the jobId is not found.
+message GetJobStateRequest {
+ string jobId = 1; // (required)
+
+}
+
+message GetJobStateResponse {
+ JobState.JobStateType state = 1; // (required)
+}
+
+
+// GetJobMessages is a streaming API for streaming job messages from the service.
+// One request connects you to the job and returns a stream of job states and job
+// messages; the messages are meant for logging and the states for detecting that
+// the job has ended.
+message JobMessagesRequest {
+ string jobId = 1; // (required)
+
+}
+
+message JobMessage {
+ string messageId = 1;
+ string time = 2;
+ MessageImportance importance = 3;
+ string messageText = 4;
+
+ enum MessageImportance {
+ JOB_MESSAGE_DEBUG = 0;
+ JOB_MESSAGE_DETAILED = 1;
+ JOB_MESSAGE_BASIC = 2;
+ JOB_MESSAGE_WARNING = 3;
+ JOB_MESSAGE_ERROR = 4;
+ }
+}
+
+message JobMessagesResponse {
+ oneof response {
+ JobMessage messageResponse = 1;
+ GetJobStateResponse stateResponse = 2;
+ }
+}
+
+message JobState {
+ // Enumeration of all JobStates
+ enum JobStateType {
+ UNKNOWN = 0;
+ STOPPED = 1;
+ RUNNING = 2;
+ DONE = 3;
+ FAILED = 4;
+ CANCELLED = 5;
+ UPDATED = 6;
+ DRAINING = 7;
+ DRAINED = 8;
+ STARTING = 9;
+ CANCELLING = 10;
+ }
+}
diff --git a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto
index 711da2a..9afb565 100644
--- a/sdks/common/runner-api/src/main/proto/beam_runner_api.proto
+++ b/sdks/common/runner-api/src/main/proto/beam_runner_api.proto
@@ -92,7 +92,9 @@
// this pipeline.
Components components = 1;
- // (Required) The ids of all PTransforms that are not contained within another PTransform
+ // (Required) The ids of all PTransforms that are not contained within another PTransform.
+ // These must be in shallow topological order, so that traversing them recursively
+ // in this order yields a recursively topological traversal.
repeated string root_transform_ids = 2;
// (Optional) Static display data for the pipeline. If there is none,
@@ -286,8 +288,8 @@
}
enum IsBounded {
- BOUNDED = 0;
- UNBOUNDED = 1;
+ UNBOUNDED = 0;
+ BOUNDED = 1;
}
// The payload for the primitive Read transform.
@@ -373,6 +375,8 @@
bool windowed_writes = 3;
bool runner_determined_sharding = 4;
+
+ map<string, SideInput> side_inputs = 5;
}
// A coder, the binary format for serialization and deserialization of data in
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index bdf8a12..760efb3 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -180,6 +180,12 @@
return begin().apply(name, root);
}
+ @Internal
+ public static Pipeline forTransformHierarchy(
+ TransformHierarchy transforms, PipelineOptions options) {
+ return new Pipeline(transforms, options);
+ }
+
/**
* <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
*
@@ -476,16 +482,21 @@
/////////////////////////////////////////////////////////////////////////////
// Below here are internal operations, never called by users.
- private final TransformHierarchy transforms = new TransformHierarchy();
+ private final TransformHierarchy transforms;
private Set<String> usedFullNames = new HashSet<>();
private CoderRegistry coderRegistry;
private final List<String> unstableNames = new ArrayList<>();
private final PipelineOptions defaultOptions;
- protected Pipeline(PipelineOptions options) {
+ private Pipeline(TransformHierarchy transforms, PipelineOptions options) {
+ this.transforms = transforms;
this.defaultOptions = options;
}
+ protected Pipeline(PipelineOptions options) {
+ this(new TransformHierarchy(), options);
+ }
+
@Override
public String toString() {
return "Pipeline#" + hashCode();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/annotations/Experimental.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/annotations/Experimental.java
index 8224ebb..80c4613 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/annotations/Experimental.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/annotations/Experimental.java
@@ -72,8 +72,9 @@
OUTPUT_TIME,
/**
- * <a href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a>.
- * Do not use: API is unstable and runner support is incomplete.
+ * <a href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a>. See <a
+ * href="https://beam.apache.org/documentation/runners/capability-matrix/">capability matrix</a>
+ * for runner support.
*/
SPLITTABLE_DO_FN,
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BooleanCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BooleanCoder.java
new file mode 100644
index 0000000..e7f7543
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BooleanCoder.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.coders;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A {@link Coder} for {@link Boolean}.
+ *
+ * <p>Each value is encoded as exactly one byte: {@code 1} for {@code true} and {@code 0} for
+ * {@code false}. When decoding, any byte other than {@code 1} is read as {@code false}.
+ */
+public class BooleanCoder extends AtomicCoder<Boolean> {
+  /** Encodes the single byte that carries each value. */
+  private static final ByteCoder BYTE_CODER = ByteCoder.of();
+
+  private static final BooleanCoder INSTANCE = new BooleanCoder();
+
+  /** Returns the singleton instance of {@link BooleanCoder}. */
+  public static BooleanCoder of() {
+    return INSTANCE;
+  }
+
+  /**
+   * Encodes {@code value} as a single byte.
+   *
+   * @throws CoderException if {@code value} is {@code null}
+   */
+  @Override
+  public void encode(Boolean value, OutputStream os) throws IOException {
+    if (value == null) {
+      // Fail with a descriptive CoderException rather than an NPE from auto-unboxing below.
+      throw new CoderException("cannot encode a null Boolean");
+    }
+    BYTE_CODER.encode(value ? (byte) 1 : (byte) 0, os);
+  }
+
+  /** Decodes one byte: {@code 1} yields {@code true}, any other value yields {@code false}. */
+  @Override
+  public Boolean decode(InputStream is) throws IOException {
+    return BYTE_CODER.decode(is) == 1;
+  }
+
+  /** Returns {@code true}: equal {@link Boolean} values always produce equal encodings. */
+  @Override
+  public boolean consistentWithEquals() {
+    return true;
+  }
+
+  /** Returns {@code true}: the encoded size is always one byte. */
+  @Override
+  public boolean isRegisterByteSizeObserverCheap(Boolean value) {
+    return true;
+  }
+
+  /** Returns {@code 1}: every encoded value occupies exactly one byte. */
+  @Override
+  protected long getEncodedElementByteSize(Boolean value) throws Exception {
+    return 1;
+  }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
index 53cb6d3..c335bda 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
@@ -43,6 +43,10 @@
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.CannotProvideCoderException.ReasonCode;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
+import org.apache.beam.sdk.io.fs.MetadataCoder;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.io.fs.ResourceIdCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.util.CoderUtils;
@@ -89,6 +93,8 @@
private CommonTypes() {
ImmutableMap.Builder<Class<?>, CoderProvider> builder = ImmutableMap.builder();
+ builder.put(Boolean.class,
+ CoderProviders.fromStaticMethods(Boolean.class, BooleanCoder.class));
builder.put(Byte.class,
CoderProviders.fromStaticMethods(Byte.class, ByteCoder.class));
builder.put(BitSet.class,
@@ -109,6 +115,10 @@
CoderProviders.fromStaticMethods(Long.class, VarLongCoder.class));
builder.put(Map.class,
CoderProviders.fromStaticMethods(Map.class, MapCoder.class));
+ builder.put(Metadata.class,
+ CoderProviders.fromStaticMethods(Metadata.class, MetadataCoder.class));
+ builder.put(ResourceId.class,
+ CoderProviders.fromStaticMethods(ResourceId.class, ResourceIdCoder.class));
builder.put(Set.class,
CoderProviders.fromStaticMethods(Set.class, SetCoder.class));
builder.put(String.class,
@@ -147,9 +157,13 @@
Set<CoderProviderRegistrar> registrars = Sets.newTreeSet(ObjectsClassComparator.INSTANCE);
registrars.addAll(Lists.newArrayList(
ServiceLoader.load(CoderProviderRegistrar.class, ReflectHelpers.findClassLoader())));
+
+ // DefaultCoder should have the highest precedence and SerializableCoder the lowest
+ codersToRegister.addAll(new DefaultCoder.DefaultCoderProviderRegistrar().getCoderProviders());
for (CoderProviderRegistrar registrar : registrars) {
codersToRegister.addAll(registrar.getCoderProviders());
}
+ codersToRegister.add(SerializableCoder.getCoderProvider());
REGISTERED_CODER_FACTORIES = ImmutableList.copyOf(codersToRegister);
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/DefaultCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/DefaultCoder.java
index 6eff9e9..7eb2ecb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/DefaultCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/DefaultCoder.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.sdk.coders;
-import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableList;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
@@ -57,7 +56,6 @@
* the {@code @DefaultCoder} annotation to provide {@link CoderProvider coder providers} that
* creates {@link Coder}s.
*/
- @AutoService(CoderProviderRegistrar.class)
class DefaultCoderProviderRegistrar implements CoderProviderRegistrar {
@Override
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SerializableCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SerializableCoder.java
index 6691876..9204942 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SerializableCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/SerializableCoder.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.sdk.coders;
-import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.io.InputStream;
@@ -80,7 +79,6 @@
* A {@link CoderProviderRegistrar} which registers a {@link CoderProvider} which can handle
* serializable types.
*/
- @AutoService(CoderProviderRegistrar.class)
public static class SerializableCoderProviderRegistrar implements CoderProviderRegistrar {
@Override
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
index 89cadbd..653b806 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroIO.java
@@ -18,11 +18,13 @@
package org.apache.beam.sdk.io;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
import com.google.auto.value.AutoValue;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
-import com.google.common.io.BaseEncoding;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.avro.Schema;
@@ -32,36 +34,48 @@
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.AvroCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
-import org.apache.beam.sdk.io.Read.Bounded;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
/**
* {@link PTransform}s for reading and writing Avro files.
*
- * <p>To read a {@link PCollection} from one or more Avro files, use {@code AvroIO.read()}, using
- * {@link AvroIO.Read#from} to specify the filename or filepattern to read from. See {@link
- * FileSystems} for information on supported file systems and filepatterns.
+ * <p>To read a {@link PCollection} from one or more Avro files with the same schema known at
+ * pipeline construction time, use {@code AvroIO.read()}, using {@link AvroIO.Read#from} to specify
+ * the filename or filepattern to read from. Alternatively, if the filepatterns to be read are
+ * themselves in a {@link PCollection}, apply {@link #readAll}.
+ *
+ * <p>See {@link FileSystems} for information on supported file systems and filepatterns.
*
* <p>To read specific records, such as Avro-generated classes, use {@link #read(Class)}. To read
* {@link GenericRecord GenericRecords}, use {@link #readGenericRecords(Schema)} which takes a
* {@link Schema} object, or {@link #readGenericRecords(String)} which takes an Avro schema in a
* JSON-encoded string form. An exception will be thrown if a record doesn't match the specified
- * schema.
+ * schema. Likewise, to read a {@link PCollection} of filepatterns, apply {@link
+ * #readAllGenericRecords}.
+ *
+ * <p>To read records from files whose schema is unknown at pipeline construction time or differs
+ * between files, use {@link #parseGenericRecords} - in this case, you will need to specify a
+ * parsing function for converting each {@link GenericRecord} into a value of your custom type.
+ * Likewise, to read a {@link PCollection} of filepatterns with unknown schema, use {@link
+ * #parseAllGenericRecords}.
*
* <p>For example:
*
@@ -77,6 +91,33 @@
* PCollection<GenericRecord> records =
* p.apply(AvroIO.readGenericRecords(schema)
* .from("gs://my_bucket/path/to/records-*.avro"));
+ *
+ * PCollection<Foo> records =
+ * p.apply(AvroIO.parseGenericRecords(new SerializableFunction<GenericRecord, Foo>() {
+ * public Foo apply(GenericRecord record) {
+ * // If needed, access the schema of the record using record.getSchema()
+ * return ...;
+ * }
+ * }).from("gs://my_bucket/path/to/records-*.avro"));
+ * }</pre>
+ *
+ * <p>If it is known that the filepattern will match a very large number of files (e.g. tens of
+ * thousands or more), use {@link Read#withHintMatchesManyFiles} or {@link
+ * Parse#withHintMatchesManyFiles} for better performance and scalability. Note that it may decrease
+ * performance if the filepattern matches only a small number of files.
+ *
+ * <p>Reading from a {@link PCollection} of filepatterns:
+ *
+ * <pre>{@code
+ * Pipeline p = ...;
+ *
+ * PCollection<String> filepatterns = p.apply(...);
+ * PCollection<AvroAutoGenClass> records =
+ * filepatterns.apply(AvroIO.read(AvroAutoGenClass.class));
+ * PCollection<GenericRecord> genericRecords =
+ * filepatterns.apply(AvroIO.readGenericRecords(schema));
+ * PCollection<Foo> records =
+ * filepatterns.apply(AvroIO.parseAllGenericRecords(new SerializableFunction...);
* }</pre>
*
* <p>To write a {@link PCollection} to one or more Avro files, use {@link AvroIO.Write}, using
@@ -116,6 +157,51 @@
* .withSuffix(".avro"));
* }</pre>
*
+ * <p>The following shows a more complex example of AvroIO.Write usage, generating dynamic file
+ * destinations as well as a dynamic Avro schema per file. In this example, a PCollection of user
+ * events (e.g. actions on a website) is written out to Avro files. Each event contains the user id
+ * as an integer field. We want events for each user to go into a specific directory for that user,
+ * and each user's data should be written with a specific schema for that user; a side input is
+ * used, so the schema can be calculated in a different stage.
+ *
+ * <pre>{@code
+ * // This is the user class that controls dynamic destinations for this avro write. The input to
+ * // AvroIO.Write will be UserEvent, and we will be writing GenericRecords to the file (in order
+ * // to have dynamic schemas). Everything is per userid, so we define a dynamic destination type
+ * // of Integer.
+ * class UserDynamicAvroDestinations
+ * extends DynamicAvroDestinations<UserEvent, Integer, GenericRecord> {
+ * private final PCollectionView<Map<Integer, String>> userToSchemaMap;
+ * public UserDynamicAvroDestinations(PCollectionView<Map<Integer, String>> userToSchemaMap) {
+ * this.userToSchemaMap = userToSchemaMap;
+ * }
+ * public GenericRecord formatRecord(UserEvent record) {
+ * return formatUserRecord(record, getSchema(record.getUserId()));
+ * }
+ * public Schema getSchema(Integer userId) {
+ * return new Schema.Parser().parse(sideInput(userToSchemaMap).get(userId));
+ * }
+ * public Integer getDestination(UserEvent record) {
+ * return record.getUserId();
+ * }
+ * public Integer getDefaultDestination() {
+ * return 0;
+ * }
+ * public FilenamePolicy getFilenamePolicy(Integer userId) {
+ * return DefaultFilenamePolicy.fromParams(new Params().withBaseFilename(baseDir + "/user-"
+ * + userId + "/events"));
+ * }
+ * public List<PCollectionView<?>> getSideInputs() {
+ * return ImmutableList.<PCollectionView<?>>of(userToSchemaMap);
+ * }
+ * }
+ * PCollection<UserEvent> events = ...;
+ * PCollectionView<Map<Integer, String>> schemaMap = events.apply(
+ * "ComputeSchemas", new ComputePerUserSchemas());
+ * events.apply("WriteAvros", AvroIO.<UserEvent>writeCustomTypeToGenericRecords()
+ * .to(new UserDynamicAvroDestinations(schemaMap)));
+ * }</pre>
+ *
* <p>By default, {@link AvroIO.Write} produces output files that are compressed using the {@link
* org.apache.avro.file.Codec CodecFactory.deflateCodec(6)}. This default can be changed or
* overridden using {@link AvroIO.Write#withCodec}.
@@ -130,6 +216,19 @@
return new AutoValue_AvroIO_Read.Builder<T>()
.setRecordClass(recordClass)
.setSchema(ReflectData.get().getSchema(recordClass))
+ .setHintMatchesManyFiles(false)
+ .build();
+ }
+
+ /** Like {@link #read}, but reads each filepattern in the input {@link PCollection}. */
+ public static <T> ReadAll<T> readAll(Class<T> recordClass) {
+ return new AutoValue_AvroIO_ReadAll.Builder<T>()
+ .setRecordClass(recordClass)
+ .setSchema(ReflectData.get().getSchema(recordClass))
+ // 64MB is a reasonable value that amortizes the cost of opening files,
+ // but is not so large as to exhaust a typical runner's maximum amount of output per
+ // ProcessElement call.
+ .setDesiredBundleSizeBytes(64 * 1024 * 1024L)
.build();
}
@@ -138,6 +237,19 @@
return new AutoValue_AvroIO_Read.Builder<GenericRecord>()
.setRecordClass(GenericRecord.class)
.setSchema(schema)
+ .setHintMatchesManyFiles(false)
+ .build();
+ }
+
+ /**
+ * Like {@link #readGenericRecords(Schema)}, but reads each filepattern in the input {@link
+ * PCollection}.
+ */
+ public static ReadAll<GenericRecord> readAllGenericRecords(Schema schema) {
+ return new AutoValue_AvroIO_ReadAll.Builder<GenericRecord>()
+ .setRecordClass(GenericRecord.class)
+ .setSchema(schema)
+ .setDesiredBundleSizeBytes(64 * 1024 * 1024L)
.build();
}
@@ -150,22 +262,88 @@
}
/**
+ * Like {@link #readGenericRecords(String)}, but reads each filepattern in the input {@link
+ * PCollection}.
+ */
+ public static ReadAll<GenericRecord> readAllGenericRecords(String schema) {
+ return readAllGenericRecords(new Schema.Parser().parse(schema));
+ }
+
+ /**
+ * Reads Avro file(s) containing records of an unspecified schema and converting each record to a
+ * custom type.
+ */
+ public static <T> Parse<T> parseGenericRecords(SerializableFunction<GenericRecord, T> parseFn) {
+ return new AutoValue_AvroIO_Parse.Builder<T>()
+ .setParseFn(parseFn)
+ .setHintMatchesManyFiles(false)
+ .build();
+ }
+
+ /**
+ * Like {@link #parseGenericRecords(SerializableFunction)}, but reads each filepattern in the
+ * input {@link PCollection}.
+ */
+ public static <T> ParseAll<T> parseAllGenericRecords(
+ SerializableFunction<GenericRecord, T> parseFn) {
+ return new AutoValue_AvroIO_ParseAll.Builder<T>()
+ .setParseFn(parseFn)
+ .setDesiredBundleSizeBytes(64 * 1024 * 1024L)
+ .build();
+ }
+
+ /**
* Writes a {@link PCollection} to an Avro file (or multiple Avro files matching a sharding
* pattern).
*/
public static <T> Write<T> write(Class<T> recordClass) {
- return AvroIO.<T>defaultWriteBuilder()
- .setRecordClass(recordClass)
- .setSchema(ReflectData.get().getSchema(recordClass))
- .build();
+ return new Write<>(
+ AvroIO.<T, T>defaultWriteBuilder()
+ .setGenericRecords(false)
+ .setSchema(ReflectData.get().getSchema(recordClass))
+ .build());
}
/** Writes Avro records of the specified schema. */
public static Write<GenericRecord> writeGenericRecords(Schema schema) {
- return AvroIO.<GenericRecord>defaultWriteBuilder()
- .setRecordClass(GenericRecord.class)
- .setSchema(schema)
- .build();
+ return new Write<>(
+ AvroIO.<GenericRecord, GenericRecord>defaultWriteBuilder()
+ .setGenericRecords(true)
+ .setSchema(schema)
+ .build());
+ }
+
+ /**
+ * A {@link PTransform} that writes a {@link PCollection} to an avro file (or multiple avro files
+ * matching a sharding pattern), with each element of the input collection encoded into its own
+ * record of type OutputT.
+ *
+ * <p>This version allows you to apply {@link AvroIO} writes to a PCollection of a custom type
+   * {@code UserT}. A format mechanism that converts the input type {@code UserT} to the output type
+ * that will be written to the file must be specified. If using a custom {@link
+ * DynamicAvroDestinations} object this is done using {@link
+ * DynamicAvroDestinations#formatRecord}, otherwise the {@link
+ * AvroIO.TypedWrite#withFormatFunction} can be used to specify a format function.
+ *
+   * <p>The advantage of using a custom type is that it allows a user-provided {@link
+ * DynamicAvroDestinations} object, set via {@link AvroIO.Write#to(DynamicAvroDestinations)} to
+ * examine the custom type when choosing a destination.
+ *
+ * <p>If the output type is {@link GenericRecord} use {@link #writeCustomTypeToGenericRecords()}
+ * instead.
+ */
+ public static <UserT, OutputT> TypedWrite<UserT, OutputT> writeCustomType() {
+ return AvroIO.<UserT, OutputT>defaultWriteBuilder().setGenericRecords(false).build();
+ }
+
+ /**
+ * Similar to {@link #writeCustomType()}, but specialized for the case where the output type is
+ * {@link GenericRecord}. A schema must be specified either in {@link
+ * DynamicAvroDestinations#getSchema} or if not using dynamic destinations, by using {@link
+ * TypedWrite#withSchema(Schema)}.
+ */
+ public static <UserT> TypedWrite<UserT, GenericRecord> writeCustomTypeToGenericRecords() {
+ return AvroIO.<UserT, GenericRecord>defaultWriteBuilder().setGenericRecords(true).build();
}
/**
@@ -175,86 +353,315 @@
return writeGenericRecords(new Schema.Parser().parse(schema));
}
- private static <T> Write.Builder<T> defaultWriteBuilder() {
- return new AutoValue_AvroIO_Write.Builder<T>()
+ private static <UserT, OutputT> TypedWrite.Builder<UserT, OutputT> defaultWriteBuilder() {
+ return new AutoValue_AvroIO_TypedWrite.Builder<UserT, OutputT>()
.setFilenameSuffix(null)
.setShardTemplate(null)
.setNumShards(0)
- .setCodec(Write.DEFAULT_CODEC)
+ .setCodec(TypedWrite.DEFAULT_SERIALIZABLE_CODEC)
.setMetadata(ImmutableMap.<String, Object>of())
.setWindowedWrites(false);
}
- /** Implementation of {@link #read}. */
+ /** Implementation of {@link #read} and {@link #readGenericRecords}. */
@AutoValue
public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {
- @Nullable abstract String getFilepattern();
+ @Nullable abstract ValueProvider<String> getFilepattern();
@Nullable abstract Class<T> getRecordClass();
@Nullable abstract Schema getSchema();
+ abstract boolean getHintMatchesManyFiles();
abstract Builder<T> toBuilder();
@AutoValue.Builder
abstract static class Builder<T> {
- abstract Builder<T> setFilepattern(String filepattern);
+ abstract Builder<T> setFilepattern(ValueProvider<String> filepattern);
abstract Builder<T> setRecordClass(Class<T> recordClass);
abstract Builder<T> setSchema(Schema schema);
+ abstract Builder<T> setHintMatchesManyFiles(boolean hintManyFiles);
abstract Read<T> build();
}
- /** Reads from the given filename or filepattern. */
- public Read<T> from(String filepattern) {
+ /**
+ * Reads from the given filename or filepattern.
+ *
+ * <p>If it is known that the filepattern will match a very large number of files (at least tens
+ * of thousands), use {@link #withHintMatchesManyFiles} for better performance and scalability.
+ */
+ public Read<T> from(ValueProvider<String> filepattern) {
return toBuilder().setFilepattern(filepattern).build();
}
+ /** Like {@link #from(ValueProvider)}. */
+ public Read<T> from(String filepattern) {
+ return from(StaticValueProvider.of(filepattern));
+ }
+
+ /**
+ * Hints that the filepattern specified in {@link #from(String)} matches a very large number of
+ * files.
+ *
+ * <p>This hint may cause a runner to execute the transform differently, in a way that improves
+ * performance for this case, but it may worsen performance if the filepattern matches only a
+ * small number of files (e.g., in a runner that supports dynamic work rebalancing, it will
+ * happen less efficiently within individual files).
+ */
+ public Read<T> withHintMatchesManyFiles() {
+ return toBuilder().setHintMatchesManyFiles(true).build();
+ }
+
@Override
public PCollection<T> expand(PBegin input) {
- if (getFilepattern() == null) {
- throw new IllegalStateException(
- "need to set the filepattern of an AvroIO.Read transform");
+ checkNotNull(getFilepattern(), "filepattern");
+ checkNotNull(getSchema(), "schema");
+ if (getHintMatchesManyFiles()) {
+ ReadAll<T> readAll =
+ (getRecordClass() == GenericRecord.class)
+ ? (ReadAll<T>) readAllGenericRecords(getSchema())
+ : readAll(getRecordClass());
+ return input
+ .apply(Create.ofProvider(getFilepattern(), StringUtf8Coder.of()))
+ .apply(readAll);
+ } else {
+ return input
+ .getPipeline()
+ .apply(
+ "Read",
+ org.apache.beam.sdk.io.Read.from(
+ createSource(getFilepattern(), getRecordClass(), getSchema())));
}
- if (getSchema() == null) {
- throw new IllegalStateException("need to set the schema of an AvroIO.Read transform");
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+ builder.addIfNotNull(
+ DisplayData.item("filePattern", getFilepattern()).withLabel("Input File Pattern"));
+ }
+
+ @SuppressWarnings("unchecked")
+ private static <T> AvroSource<T> createSource(
+ ValueProvider<String> filepattern, Class<T> recordClass, Schema schema) {
+ return recordClass == GenericRecord.class
+ ? (AvroSource<T>) AvroSource.from(filepattern).withSchema(schema)
+ : AvroSource.from(filepattern).withSchema(recordClass);
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ /** Implementation of {@link #readAll}. */
+ @AutoValue
+ public abstract static class ReadAll<T> extends PTransform<PCollection<String>, PCollection<T>> {
+ @Nullable abstract Class<T> getRecordClass();
+ @Nullable abstract Schema getSchema();
+ abstract long getDesiredBundleSizeBytes();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setRecordClass(Class<T> recordClass);
+ abstract Builder<T> setSchema(Schema schema);
+ abstract Builder<T> setDesiredBundleSizeBytes(long desiredBundleSizeBytes);
+
+ abstract ReadAll<T> build();
+ }
+
+ @VisibleForTesting
+ ReadAll<T> withDesiredBundleSizeBytes(long desiredBundleSizeBytes) {
+ return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build();
+ }
+
+ @Override
+ public PCollection<T> expand(PCollection<String> input) {
+ checkNotNull(getSchema(), "schema");
+ return input
+ .apply(Match.filepatterns())
+ .apply(
+ "Read all via FileBasedSource",
+ new ReadAllViaFileBasedSource<>(
+ SerializableFunctions.<String, Boolean>constant(true) /* isSplittable */,
+ getDesiredBundleSizeBytes(),
+ new CreateSourceFn<>(getRecordClass(), getSchema().toString())))
+ .setCoder(AvroCoder.of(getRecordClass(), getSchema()));
+ }
+ }
+
+ private static class CreateSourceFn<T>
+ implements SerializableFunction<String, FileBasedSource<T>> {
+ private final Class<T> recordClass;
+ private final Supplier<Schema> schemaSupplier;
+
+ public CreateSourceFn(Class<T> recordClass, String jsonSchema) {
+ this.recordClass = recordClass;
+ this.schemaSupplier = AvroUtils.serializableSchemaSupplier(jsonSchema);
+ }
+
+ @Override
+ public FileBasedSource<T> apply(String input) {
+ return Read.createSource(
+ StaticValueProvider.of(input), recordClass, schemaSupplier.get());
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ /** Implementation of {@link #parseGenericRecords}. */
+ @AutoValue
+ public abstract static class Parse<T> extends PTransform<PBegin, PCollection<T>> {
+ @Nullable abstract ValueProvider<String> getFilepattern();
+ abstract SerializableFunction<GenericRecord, T> getParseFn();
+ @Nullable abstract Coder<T> getCoder();
+ abstract boolean getHintMatchesManyFiles();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setFilepattern(ValueProvider<String> filepattern);
+ abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);
+ abstract Builder<T> setCoder(Coder<T> coder);
+ abstract Builder<T> setHintMatchesManyFiles(boolean hintMatchesManyFiles);
+
+ abstract Parse<T> build();
+ }
+
+ /** Reads from the given filename or filepattern. */
+ public Parse<T> from(String filepattern) {
+ return from(StaticValueProvider.of(filepattern));
+ }
+
+ /** Like {@link #from(String)}. */
+ public Parse<T> from(ValueProvider<String> filepattern) {
+ return toBuilder().setFilepattern(filepattern).build();
+ }
+
+ /** Sets a coder for the result of the parse function. */
+ public Parse<T> withCoder(Coder<T> coder) {
+ return toBuilder().setCoder(coder).build();
+ }
+
+ /** Like {@link Read#withHintMatchesManyFiles()}. */
+ public Parse<T> withHintMatchesManyFiles() {
+ return toBuilder().setHintMatchesManyFiles(true).build();
+ }
+
+ @Override
+ public PCollection<T> expand(PBegin input) {
+ checkNotNull(getFilepattern(), "filepattern");
+ Coder<T> coder = inferCoder(getCoder(), getParseFn(), input.getPipeline().getCoderRegistry());
+ if (getHintMatchesManyFiles()) {
+ return input
+ .apply(Create.ofProvider(getFilepattern(), StringUtf8Coder.of()))
+ .apply(parseAllGenericRecords(getParseFn()).withCoder(getCoder()));
}
+ return input.apply(
+ org.apache.beam.sdk.io.Read.from(
+ AvroSource.from(getFilepattern()).withParseFn(getParseFn(), coder)));
+ }
- @SuppressWarnings("unchecked")
- Bounded<T> read =
- getRecordClass() == GenericRecord.class
- ? (Bounded<T>) org.apache.beam.sdk.io.Read.from(
- AvroSource.from(getFilepattern()).withSchema(getSchema()))
- : org.apache.beam.sdk.io.Read.from(
- AvroSource.from(getFilepattern()).withSchema(getRecordClass()));
-
- PCollection<T> pcol = input.getPipeline().apply("Read", read);
- // Honor the default output coder that would have been used by this PTransform.
- pcol.setCoder(getDefaultOutputCoder());
- return pcol;
+ private static <T> Coder<T> inferCoder(
+ @Nullable Coder<T> explicitCoder,
+ SerializableFunction<GenericRecord, T> parseFn,
+ CoderRegistry coderRegistry) {
+ if (explicitCoder != null) {
+ return explicitCoder;
+ }
+ // If a coder was not specified explicitly, infer it from parse fn.
+ TypeDescriptor<T> descriptor = TypeDescriptors.outputOf(parseFn);
+ String message =
+ "Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().";
+ checkArgument(descriptor != null, message);
+ try {
+ return coderRegistry.getCoder(descriptor);
+ } catch (CannotProvideCoderException e) {
+ throw new IllegalArgumentException(message, e);
+ }
}
@Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
builder
- .addIfNotNull(DisplayData.item("filePattern", getFilepattern())
- .withLabel("Input File Pattern"));
- }
-
- @Override
- protected Coder<T> getDefaultOutputCoder() {
- return AvroCoder.of(getRecordClass(), getSchema());
+ .addIfNotNull(
+ DisplayData.item("filePattern", getFilepattern()).withLabel("Input File Pattern"))
+ .add(DisplayData.item("parseFn", getParseFn().getClass()).withLabel("Parse function"));
}
}
/////////////////////////////////////////////////////////////////////////////
+ /** Implementation of {@link #parseAllGenericRecords}. */
+ @AutoValue
+ public abstract static class ParseAll<T> extends PTransform<PCollection<String>, PCollection<T>> {
+ abstract SerializableFunction<GenericRecord, T> getParseFn();
+ @Nullable abstract Coder<T> getCoder();
+ abstract long getDesiredBundleSizeBytes();
+
+ abstract Builder<T> toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+ abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> parseFn);
+ abstract Builder<T> setCoder(Coder<T> coder);
+ abstract Builder<T> setDesiredBundleSizeBytes(long desiredBundleSizeBytes);
+
+ abstract ParseAll<T> build();
+ }
+
+ /** Specifies the coder for the result of the {@code parseFn}. */
+ public ParseAll<T> withCoder(Coder<T> coder) {
+ return toBuilder().setCoder(coder).build();
+ }
+
+ @VisibleForTesting
+ ParseAll<T> withDesiredBundleSizeBytes(long desiredBundleSizeBytes) {
+ return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build();
+ }
+
+ @Override
+ public PCollection<T> expand(PCollection<String> input) {
+ final Coder<T> coder =
+ Parse.inferCoder(getCoder(), getParseFn(), input.getPipeline().getCoderRegistry());
+ SerializableFunction<String, FileBasedSource<T>> createSource =
+ new SerializableFunction<String, FileBasedSource<T>>() {
+ @Override
+ public FileBasedSource<T> apply(String input) {
+ return AvroSource.from(input).withParseFn(getParseFn(), coder);
+ }
+ };
+ return input
+ .apply(Match.filepatterns())
+ .apply(
+ "Parse all via FileBasedSource",
+ new ReadAllViaFileBasedSource<>(
+ SerializableFunctions.<String, Boolean>constant(true) /* isSplittable */,
+ getDesiredBundleSizeBytes(),
+ createSource))
+ .setCoder(coder);
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+ builder.add(DisplayData.item("parseFn", getParseFn().getClass()).withLabel("Parse function"));
+ }
+ }
+
+ // ///////////////////////////////////////////////////////////////////////////
+
/** Implementation of {@link #write}. */
@AutoValue
- public abstract static class Write<T> extends PTransform<PCollection<T>, PDone> {
- private static final SerializableAvroCodecFactory DEFAULT_CODEC =
- new SerializableAvroCodecFactory(CodecFactory.deflateCodec(6));
- // This should be a multiple of 4 to not get a partial encoded byte.
- private static final int METADATA_BYTES_MAX_LENGTH = 40;
+ public abstract static class TypedWrite<UserT, OutputT>
+ extends PTransform<PCollection<UserT>, PDone> {
+ static final CodecFactory DEFAULT_CODEC = CodecFactory.deflateCodec(6);
+ static final SerializableAvroCodecFactory DEFAULT_SERIALIZABLE_CODEC =
+ new SerializableAvroCodecFactory(DEFAULT_CODEC);
+
+ @Nullable
+ abstract SerializableFunction<UserT, OutputT> getFormatFunction();
@Nullable abstract ValueProvider<ResourceId> getFilenamePrefix();
@Nullable abstract String getShardTemplate();
@@ -264,11 +671,16 @@
abstract ValueProvider<ResourceId> getTempDirectory();
abstract int getNumShards();
- @Nullable abstract Class<T> getRecordClass();
+
+ abstract boolean getGenericRecords();
+
@Nullable abstract Schema getSchema();
abstract boolean getWindowedWrites();
@Nullable abstract FilenamePolicy getFilenamePolicy();
+ @Nullable
+ abstract DynamicAvroDestinations<UserT, ?, OutputT> getDynamicDestinations();
+
/**
* The codec used to encode the blocks in the Avro file. String value drawn from those in
* https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html
@@ -277,25 +689,39 @@
/** Avro file metadata. */
abstract ImmutableMap<String, Object> getMetadata();
- abstract Builder<T> toBuilder();
+ abstract Builder<UserT, OutputT> toBuilder();
@AutoValue.Builder
- abstract static class Builder<T> {
- abstract Builder<T> setFilenamePrefix(ValueProvider<ResourceId> filenamePrefix);
- abstract Builder<T> setFilenameSuffix(String filenameSuffix);
+ abstract static class Builder<UserT, OutputT> {
+ abstract Builder<UserT, OutputT> setFormatFunction(
+ SerializableFunction<UserT, OutputT> formatFunction);
- abstract Builder<T> setTempDirectory(ValueProvider<ResourceId> tempDirectory);
+ abstract Builder<UserT, OutputT> setFilenamePrefix(ValueProvider<ResourceId> filenamePrefix);
- abstract Builder<T> setNumShards(int numShards);
- abstract Builder<T> setShardTemplate(String shardTemplate);
- abstract Builder<T> setRecordClass(Class<T> recordClass);
- abstract Builder<T> setSchema(Schema schema);
- abstract Builder<T> setWindowedWrites(boolean windowedWrites);
- abstract Builder<T> setFilenamePolicy(FilenamePolicy filenamePolicy);
- abstract Builder<T> setCodec(SerializableAvroCodecFactory codec);
- abstract Builder<T> setMetadata(ImmutableMap<String, Object> metadata);
+ abstract Builder<UserT, OutputT> setFilenameSuffix(String filenameSuffix);
- abstract Write<T> build();
+ abstract Builder<UserT, OutputT> setTempDirectory(ValueProvider<ResourceId> tempDirectory);
+
+ abstract Builder<UserT, OutputT> setNumShards(int numShards);
+
+ abstract Builder<UserT, OutputT> setShardTemplate(String shardTemplate);
+
+ abstract Builder<UserT, OutputT> setGenericRecords(boolean genericRecords);
+
+ abstract Builder<UserT, OutputT> setSchema(Schema schema);
+
+ abstract Builder<UserT, OutputT> setWindowedWrites(boolean windowedWrites);
+
+ abstract Builder<UserT, OutputT> setFilenamePolicy(FilenamePolicy filenamePolicy);
+
+ abstract Builder<UserT, OutputT> setCodec(SerializableAvroCodecFactory codec);
+
+ abstract Builder<UserT, OutputT> setMetadata(ImmutableMap<String, Object> metadata);
+
+ abstract Builder<UserT, OutputT> setDynamicDestinations(
+ DynamicAvroDestinations<UserT, ?, OutputT> dynamicDestinations);
+
+ abstract TypedWrite<UserT, OutputT> build();
}
/**
@@ -309,7 +735,7 @@
* common suffix (if supplied using {@link #withSuffix(String)}). This default can be overridden
* using {@link #to(FilenamePolicy)}.
*/
- public Write<T> to(String outputPrefix) {
+ public TypedWrite<UserT, OutputT> to(String outputPrefix) {
return to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix));
}
@@ -332,14 +758,12 @@
* infer a directory for temporary files.
*/
@Experimental(Kind.FILESYSTEM)
- public Write<T> to(ResourceId outputPrefix) {
+ public TypedWrite<UserT, OutputT> to(ResourceId outputPrefix) {
return toResource(StaticValueProvider.of(outputPrefix));
}
- /**
- * Like {@link #to(String)}.
- */
- public Write<T> to(ValueProvider<String> outputPrefix) {
+ /** Like {@link #to(String)}. */
+ public TypedWrite<UserT, OutputT> to(ValueProvider<String> outputPrefix) {
return toResource(NestedValueProvider.of(outputPrefix,
new SerializableFunction<String, ResourceId>() {
@Override
@@ -349,11 +773,9 @@
}));
}
- /**
- * Like {@link #to(ResourceId)}.
- */
+ /** Like {@link #to(ResourceId)}. */
@Experimental(Kind.FILESYSTEM)
- public Write<T> toResource(ValueProvider<ResourceId> outputPrefix) {
+ public TypedWrite<UserT, OutputT> toResource(ValueProvider<ResourceId> outputPrefix) {
return toBuilder().setFilenamePrefix(outputPrefix).build();
}
@@ -361,16 +783,52 @@
* Writes to files named according to the given {@link FileBasedSink.FilenamePolicy}. A
* directory for temporary files must be specified using {@link #withTempDirectory}.
*/
- public Write<T> to(FilenamePolicy filenamePolicy) {
+ @Experimental(Kind.FILESYSTEM)
+ public TypedWrite<UserT, OutputT> to(FilenamePolicy filenamePolicy) {
return toBuilder().setFilenamePolicy(filenamePolicy).build();
}
+ /**
+ * Use a {@link DynamicAvroDestinations} object to vend {@link FilenamePolicy} objects. These
+ * objects can examine the input record when creating a {@link FilenamePolicy}. A directory for
+ * temporary files must be specified using {@link #withTempDirectory}.
+ */
+ @Experimental(Kind.FILESYSTEM)
+ public TypedWrite<UserT, OutputT> to(
+ DynamicAvroDestinations<UserT, ?, OutputT> dynamicDestinations) {
+ return toBuilder().setDynamicDestinations(dynamicDestinations).build();
+ }
+
+ /**
+     * Sets the output schema. Can only be used when the output type is {@link GenericRecord}
+ * and when not using {@link #to(DynamicAvroDestinations)}.
+ */
+ public TypedWrite<UserT, OutputT> withSchema(Schema schema) {
+ return toBuilder().setSchema(schema).build();
+ }
+
+ /**
+     * Specifies a format function to convert {@code UserT} to the output type. If {@link
+ * #to(DynamicAvroDestinations)} is used, {@link DynamicAvroDestinations#formatRecord} must be
+ * used instead.
+ */
+ public TypedWrite<UserT, OutputT> withFormatFunction(
+ SerializableFunction<UserT, OutputT> formatFunction) {
+ return toBuilder().setFormatFunction(formatFunction).build();
+ }
+
/** Set the base directory used to generate temporary files. */
@Experimental(Kind.FILESYSTEM)
- public Write<T> withTempDirectory(ValueProvider<ResourceId> tempDirectory) {
+ public TypedWrite<UserT, OutputT> withTempDirectory(ValueProvider<ResourceId> tempDirectory) {
return toBuilder().setTempDirectory(tempDirectory).build();
}
+ /** Set the base directory used to generate temporary files. */
+ @Experimental(Kind.FILESYSTEM)
+ public TypedWrite<UserT, OutputT> withTempDirectory(ResourceId tempDirectory) {
+ return withTempDirectory(StaticValueProvider.of(tempDirectory));
+ }
+
/**
* Uses the given {@link ShardNameTemplate} for naming output files. This option may only be
* used when using one of the default filename-prefix to() overrides.
@@ -378,7 +836,7 @@
* <p>See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are
* used.
*/
- public Write<T> withShardNameTemplate(String shardTemplate) {
+ public TypedWrite<UserT, OutputT> withShardNameTemplate(String shardTemplate) {
return toBuilder().setShardTemplate(shardTemplate).build();
}
@@ -389,7 +847,7 @@
* <p>See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are
* used.
*/
- public Write<T> withSuffix(String filenameSuffix) {
+ public TypedWrite<UserT, OutputT> withSuffix(String filenameSuffix) {
return toBuilder().setFilenameSuffix(filenameSuffix).build();
}
@@ -403,7 +861,7 @@
*
* @param numShards the number of shards to use, or 0 to let the system decide.
*/
- public Write<T> withNumShards(int numShards) {
+ public TypedWrite<UserT, OutputT> withNumShards(int numShards) {
checkArgument(numShards >= 0);
return toBuilder().setNumShards(numShards).build();
}
@@ -418,7 +876,7 @@
*
* <p>This is equivalent to {@code .withNumShards(1).withShardNameTemplate("")}
*/
- public Write<T> withoutSharding() {
+ public TypedWrite<UserT, OutputT> withoutSharding() {
return withNumShards(1).withShardNameTemplate("");
}
@@ -428,12 +886,12 @@
* <p>If using {@link #to(FileBasedSink.FilenamePolicy)}. Filenames will be generated using
* {@link FilenamePolicy#windowedFilename}. See also {@link WriteFiles#withWindowedWrites()}.
*/
- public Write<T> withWindowedWrites() {
+ public TypedWrite<UserT, OutputT> withWindowedWrites() {
return toBuilder().setWindowedWrites(true).build();
}
/** Writes to Avro file(s) compressed using specified codec. */
- public Write<T> withCodec(CodecFactory codec) {
+ public TypedWrite<UserT, OutputT> withCodec(CodecFactory codec) {
return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build();
}
@@ -442,7 +900,7 @@
*
* <p>Supported value types are String, Long, and byte[].
*/
- public Write<T> withMetadata(Map<String, Object> metadata) {
+ public TypedWrite<UserT, OutputT> withMetadata(Map<String, Object> metadata) {
Map<String, String> badKeys = Maps.newLinkedHashMap();
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
Object v = entry.getValue();
@@ -457,18 +915,31 @@
return toBuilder().setMetadata(ImmutableMap.copyOf(metadata)).build();
}
- DynamicDestinations<T, Void> resolveDynamicDestinations() {
- FilenamePolicy usedFilenamePolicy = getFilenamePolicy();
- if (usedFilenamePolicy == null) {
- usedFilenamePolicy =
- DefaultFilenamePolicy.fromStandardParameters(
- getFilenamePrefix(), getShardTemplate(), getFilenameSuffix(), getWindowedWrites());
+ DynamicAvroDestinations<UserT, ?, OutputT> resolveDynamicDestinations() {
+ DynamicAvroDestinations<UserT, ?, OutputT> dynamicDestinations = getDynamicDestinations();
+ if (dynamicDestinations == null) {
+ FilenamePolicy usedFilenamePolicy = getFilenamePolicy();
+ if (usedFilenamePolicy == null) {
+ usedFilenamePolicy =
+ DefaultFilenamePolicy.fromStandardParameters(
+ getFilenamePrefix(),
+ getShardTemplate(),
+ getFilenameSuffix(),
+ getWindowedWrites());
+ }
+ dynamicDestinations =
+ constantDestinations(
+ usedFilenamePolicy,
+ getSchema(),
+ getMetadata(),
+ getCodec().getCodec(),
+ getFormatFunction());
}
- return DynamicFileDestinations.constant(usedFilenamePolicy);
+ return dynamicDestinations;
}
@Override
- public PDone expand(PCollection<T> input) {
+ public PDone expand(PCollection<UserT> input) {
checkArgument(
getFilenamePrefix() != null || getTempDirectory() != null,
"Need to set either the filename prefix or the tempDirectory of a AvroIO.Write "
@@ -479,24 +950,25 @@
"shardTemplate and filenameSuffix should only be used with the default "
+ "filename policy");
}
+ if (getDynamicDestinations() != null) {
+ checkArgument(
+ getFormatFunction() == null,
+ "A format function should not be specified "
+ + "with DynamicDestinations. Use DynamicDestinations.formatRecord instead");
+ }
+
return expandTyped(input, resolveDynamicDestinations());
}
public <DestinationT> PDone expandTyped(
- PCollection<T> input, DynamicDestinations<T, DestinationT> dynamicDestinations) {
+ PCollection<UserT> input,
+ DynamicAvroDestinations<UserT, DestinationT, OutputT> dynamicDestinations) {
ValueProvider<ResourceId> tempDirectory = getTempDirectory();
if (tempDirectory == null) {
tempDirectory = getFilenamePrefix();
}
- WriteFiles<T, DestinationT, T> write =
- WriteFiles.to(
- new AvroSink<>(
- tempDirectory,
- dynamicDestinations,
- AvroCoder.of(getRecordClass(), getSchema()),
- getCodec(),
- getMetadata()),
- SerializableFunctions.<T>identity());
+ WriteFiles<UserT, DestinationT, OutputT> write =
+ WriteFiles.to(new AvroSink<>(tempDirectory, dynamicDestinations, getGenericRecords()));
if (getNumShards() > 0) {
write = write.withNumShards(getNumShards());
}
@@ -519,41 +991,139 @@
: getTempDirectory().toString();
}
builder
- .add(DisplayData.item("schema", getRecordClass()).withLabel("Record Schema"))
.addIfNotDefault(
DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"), 0)
- .addIfNotDefault(
- DisplayData.item("codec", getCodec().toString()).withLabel("Avro Compression Codec"),
- DEFAULT_CODEC.toString())
.addIfNotNull(
DisplayData.item("tempDirectory", tempDirectory)
.withLabel("Directory for temporary files"));
- builder.include("Metadata", new Metadata());
- }
-
- private class Metadata implements HasDisplayData {
- @Override
- public void populateDisplayData(DisplayData.Builder builder) {
- for (Map.Entry<String, Object> entry : getMetadata().entrySet()) {
- DisplayData.Type type = DisplayData.inferType(entry.getValue());
- if (type != null) {
- builder.add(DisplayData.item(entry.getKey(), type, entry.getValue()));
- } else {
- String base64 = BaseEncoding.base64().encode((byte[]) entry.getValue());
- String repr = base64.length() <= METADATA_BYTES_MAX_LENGTH
- ? base64 : base64.substring(0, METADATA_BYTES_MAX_LENGTH) + "...";
- builder.add(DisplayData.item(entry.getKey(), repr));
- }
- }
- }
- }
-
- @Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
}
}
+ /**
+   * This class is used as the default return value of {@link AvroIO#write} and {@link AvroIO#writeGenericRecords}.
+ *
+ * <p>All methods in this class delegate to the appropriate method of {@link AvroIO.TypedWrite}.
+ * This class exists for backwards compatibility, and will be removed in Beam 3.0.
+ */
+ public static class Write<T> extends PTransform<PCollection<T>, PDone> {
+ @VisibleForTesting TypedWrite<T, T> inner;
+
+ Write(TypedWrite<T, T> inner) {
+ this.inner = inner;
+ }
+
+ /** See {@link TypedWrite#to(String)}. */
+ public Write<T> to(String outputPrefix) {
+ return new Write<>(
+ inner
+ .to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix))
+ .withFormatFunction(SerializableFunctions.<T>identity()));
+ }
+
+    /** See {@link TypedWrite#to(ResourceId)}. */
+ @Experimental(Kind.FILESYSTEM)
+ public Write<T> to(ResourceId outputPrefix) {
+ return new Write<T>(
+ inner.to(outputPrefix).withFormatFunction(SerializableFunctions.<T>identity()));
+ }
+
+ /** See {@link TypedWrite#to(ValueProvider)}. */
+ public Write<T> to(ValueProvider<String> outputPrefix) {
+ return new Write<>(
+ inner.to(outputPrefix).withFormatFunction(SerializableFunctions.<T>identity()));
+ }
+
+    /** See {@link TypedWrite#toResource(ValueProvider)}. */
+ @Experimental(Kind.FILESYSTEM)
+ public Write<T> toResource(ValueProvider<ResourceId> outputPrefix) {
+ return new Write<>(
+ inner.toResource(outputPrefix).withFormatFunction(SerializableFunctions.<T>identity()));
+ }
+
+ /** See {@link TypedWrite#to(FilenamePolicy)}. */
+ public Write<T> to(FilenamePolicy filenamePolicy) {
+ return new Write<>(
+ inner.to(filenamePolicy).withFormatFunction(SerializableFunctions.<T>identity()));
+ }
+
+    /** See {@link TypedWrite#to(DynamicAvroDestinations)}; returns a typed {@code Write<T>}. */
+    public Write<T> to(DynamicAvroDestinations<T, ?, T> dynamicDestinations) {
+      return new Write<>(inner.to(dynamicDestinations).withFormatFunction(null));
+    }
+
+    /** See {@link TypedWrite#withSchema}; returns a typed {@code Write<T>}. */
+    public Write<T> withSchema(Schema schema) {
+      return new Write<>(inner.withSchema(schema));
+    }
+ /** See {@link TypedWrite#withTempDirectory(ValueProvider)}. */
+ @Experimental(Kind.FILESYSTEM)
+ public Write<T> withTempDirectory(ValueProvider<ResourceId> tempDirectory) {
+ return new Write<>(inner.withTempDirectory(tempDirectory));
+ }
+
+ /** See {@link TypedWrite#withTempDirectory(ResourceId)}. */
+ public Write<T> withTempDirectory(ResourceId tempDirectory) {
+ return new Write<>(inner.withTempDirectory(tempDirectory));
+ }
+
+ /** See {@link TypedWrite#withShardNameTemplate}. */
+ public Write<T> withShardNameTemplate(String shardTemplate) {
+ return new Write<>(inner.withShardNameTemplate(shardTemplate));
+ }
+
+ /** See {@link TypedWrite#withSuffix}. */
+ public Write<T> withSuffix(String filenameSuffix) {
+ return new Write<>(inner.withSuffix(filenameSuffix));
+ }
+
+ /** See {@link TypedWrite#withNumShards}. */
+ public Write<T> withNumShards(int numShards) {
+ return new Write<>(inner.withNumShards(numShards));
+ }
+
+ /** See {@link TypedWrite#withoutSharding}. */
+ public Write<T> withoutSharding() {
+ return new Write<>(inner.withoutSharding());
+ }
+
+    /** See {@link TypedWrite#withWindowedWrites}; returns a typed {@code Write<T>}. */
+    public Write<T> withWindowedWrites() {
+      return new Write<>(inner.withWindowedWrites());
+    }
+
+ /** See {@link TypedWrite#withCodec}. */
+ public Write<T> withCodec(CodecFactory codec) {
+ return new Write<>(inner.withCodec(codec));
+ }
+
+    /** See {@link TypedWrite#withMetadata}; returns a typed {@code Write<T>}. */
+    public Write<T> withMetadata(Map<String, Object> metadata) {
+      return new Write<>(inner.withMetadata(metadata));
+    }
+
+ @Override
+ public PDone expand(PCollection<T> input) {
+ return inner.expand(input);
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ inner.populateDisplayData(builder);
+ }
+ }
+
+ /**
+ * Returns a {@link DynamicAvroDestinations} that always returns the same {@link FilenamePolicy},
+ * schema, metadata, and codec.
+ */
+ public static <UserT, OutputT> DynamicAvroDestinations<UserT, Void, OutputT> constantDestinations(
+ FilenamePolicy filenamePolicy,
+ Schema schema,
+ Map<String, Object> metadata,
+ CodecFactory codec,
+ SerializableFunction<UserT, OutputT> formatFunction) {
+ return new ConstantAvroDestination<>(filenamePolicy, schema, metadata, codec, formatFunction);
+ }
/////////////////////////////////////////////////////////////////////////////
/** Disallow construction of utility class. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java
index c78870b..acd3ea6 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSink.java
@@ -17,93 +17,90 @@
*/
package org.apache.beam.sdk.io;
-import com.google.common.collect.ImmutableMap;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Map;
+import org.apache.avro.Schema;
+import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericDatumWriter;
-import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.reflect.ReflectDatumWriter;
-import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.util.MimeTypes;
/** A {@link FileBasedSink} for Avro files. */
-class AvroSink<T, DestinationT> extends FileBasedSink<T, DestinationT> {
- private final AvroCoder<T> coder;
- private final SerializableAvroCodecFactory codec;
- private final ImmutableMap<String, Object> metadata;
+class AvroSink<UserT, DestinationT, OutputT> extends FileBasedSink<UserT, DestinationT, OutputT> {
+ private final DynamicAvroDestinations<UserT, DestinationT, OutputT> dynamicDestinations;
+ private final boolean genericRecords; // GenericDatumWriter if true, else ReflectDatumWriter.
AvroSink(
ValueProvider<ResourceId> outputPrefix,
- DynamicDestinations<T, DestinationT> dynamicDestinations,
- AvroCoder<T> coder,
- SerializableAvroCodecFactory codec,
- ImmutableMap<String, Object> metadata) {
+ DynamicAvroDestinations<UserT, DestinationT, OutputT> dynamicDestinations,
+ boolean genericRecords) {
// Avro handle compression internally using the codec.
super(outputPrefix, dynamicDestinations, CompressionType.UNCOMPRESSED);
- this.coder = coder;
- this.codec = codec;
- this.metadata = metadata;
+ this.dynamicDestinations = dynamicDestinations;
+ this.genericRecords = genericRecords;
}
@Override
- public WriteOperation<T, DestinationT> createWriteOperation() {
- return new AvroWriteOperation<>(this, coder, codec, metadata);
+ public DynamicAvroDestinations<UserT, DestinationT, OutputT> getDynamicDestinations() {
+ return (DynamicAvroDestinations<UserT, DestinationT, OutputT>) super.getDynamicDestinations();
+ }
+
+ @Override
+ public WriteOperation<DestinationT, OutputT> createWriteOperation() {
+ return new AvroWriteOperation<>(this, genericRecords);
}
/** A {@link WriteOperation WriteOperation} for Avro files. */
- private static class AvroWriteOperation<T, DestinationT> extends WriteOperation<T, DestinationT> {
- private final AvroCoder<T> coder;
- private final SerializableAvroCodecFactory codec;
- private final ImmutableMap<String, Object> metadata;
+ private static class AvroWriteOperation<DestinationT, OutputT>
+ extends WriteOperation<DestinationT, OutputT> {
+ private final DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations;
+ private final boolean genericRecords; // Forwarded unchanged to every AvroWriter.
- private AvroWriteOperation(
- AvroSink<T, DestinationT> sink,
- AvroCoder<T> coder,
- SerializableAvroCodecFactory codec,
- ImmutableMap<String, Object> metadata) {
+ private AvroWriteOperation(AvroSink<?, DestinationT, OutputT> sink, boolean genericRecords) {
super(sink);
- this.coder = coder;
- this.codec = codec;
- this.metadata = metadata;
+ this.dynamicDestinations = sink.getDynamicDestinations();
+ this.genericRecords = genericRecords;
}
@Override
- public Writer<T, DestinationT> createWriter() throws Exception {
- return new AvroWriter<>(this, coder, codec, metadata);
+ public Writer<DestinationT, OutputT> createWriter() throws Exception {
+ return new AvroWriter<>(this, dynamicDestinations, genericRecords);
}
}
/** A {@link Writer Writer} for Avro files. */
- private static class AvroWriter<T, DestinationT> extends Writer<T, DestinationT> {
- private final AvroCoder<T> coder;
- private DataFileWriter<T> dataFileWriter;
- private SerializableAvroCodecFactory codec;
- private final ImmutableMap<String, Object> metadata;
+ private static class AvroWriter<DestinationT, OutputT> extends Writer<DestinationT, OutputT> {
+ private DataFileWriter<OutputT> dataFileWriter;
+ private final DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations;
+ private final boolean genericRecords; // GenericDatumWriter if true, else ReflectDatumWriter.
public AvroWriter(
- WriteOperation<T, DestinationT> writeOperation,
- AvroCoder<T> coder,
- SerializableAvroCodecFactory codec,
- ImmutableMap<String, Object> metadata) {
+ WriteOperation<DestinationT, OutputT> writeOperation,
+ DynamicAvroDestinations<?, DestinationT, ?> dynamicDestinations,
+ boolean genericRecords) {
super(writeOperation, MimeTypes.BINARY);
- this.coder = coder;
- this.codec = codec;
- this.metadata = metadata;
+ this.dynamicDestinations = dynamicDestinations;
+ this.genericRecords = genericRecords;
}
@SuppressWarnings("deprecation") // uses internal test functionality.
@Override
protected void prepareWrite(WritableByteChannel channel) throws Exception {
- DatumWriter<T> datumWriter = coder.getType().equals(GenericRecord.class)
- ? new GenericDatumWriter<T>(coder.getSchema())
- : new ReflectDatumWriter<T>(coder.getSchema());
+ DestinationT destination = getDestination(); // Codec, schema and metadata are per destination.
+ CodecFactory codec = dynamicDestinations.getCodec(destination);
+ Schema schema = dynamicDestinations.getSchema(destination);
+ Map<String, Object> metadata = dynamicDestinations.getMetadata(destination);
- dataFileWriter = new DataFileWriter<>(datumWriter).setCodec(codec.getCodec());
+ DatumWriter<OutputT> datumWriter =
+ genericRecords
+ ? new GenericDatumWriter<OutputT>(schema)
+ : new ReflectDatumWriter<OutputT>(schema);
+ dataFileWriter = new DataFileWriter<>(datumWriter).setCodec(codec);
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
Object v = entry.getValue();
if (v instanceof String) {
@@ -118,11 +115,11 @@
+ v.getClass().getSimpleName());
}
}
- dataFileWriter.create(coder.getSchema(), Channels.newOutputStream(channel));
+ dataFileWriter.create(schema, Channels.newOutputStream(channel));
}
@Override
- public void write(T value) throws Exception {
+ public void write(OutputT value) throws Exception {
dataFileWriter.append(value);
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
index 7cd97a8..8dd3125 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroSource.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io;
+import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
@@ -27,8 +28,10 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidObjectException;
+import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.PushbackInputStream;
+import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
@@ -53,9 +56,12 @@
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.AvroCoder;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
import org.apache.commons.compress.compressors.snappy.SnappyCompressorInputStream;
@@ -129,40 +135,127 @@
// The default sync interval is 64k.
private static final long DEFAULT_MIN_BUNDLE_SIZE = 2 * DataFileConstants.DEFAULT_SYNC_INTERVAL;
- // The type of the records contained in the file.
- private final Class<T> type;
+ // Use cases of AvroSource are:
+ // 1) AvroSource<GenericRecord>: reads GenericRecord records with a specified schema.
+ // 2) AvroSource<Foo>: reads records of a generated Avro class Foo.
+ // 3) AvroSource<T>: reads GenericRecord records with an unspecified schema (the file's writer
+ // schema is used) and converts them to type T with parseFn.
+ //                    | Case 1        | Case 2   | Case 3        |
+ // type               | GenericRecord | Foo      | GenericRecord |
+ // readerSchemaString | non-null      | non-null | null          |
+ // parseFn            | null          | null     | non-null      |
+ // outputCoder        | null          | null     | non-null      |
+ private static class Mode<T> implements Serializable {
+ private final Class<?> type;
- // The JSON schema used to decode records.
- @Nullable
- private final String readerSchemaString;
+ // The JSON reader schema used to decode records; if null, the file's writer schema is used.
+ @Nullable
+ private String readerSchemaString; // Not final so that readObject() can re-intern it.
+
+ @Nullable
+ private final SerializableFunction<GenericRecord, T> parseFn;
+
+ @Nullable
+ private final Coder<T> outputCoder;
+
+ private Mode(
+ Class<?> type,
+ @Nullable String readerSchemaString,
+ @Nullable SerializableFunction<GenericRecord, T> parseFn,
+ @Nullable Coder<T> outputCoder) {
+ this.type = type;
+ this.readerSchemaString = internSchemaString(readerSchemaString);
+ this.parseFn = parseFn;
+ this.outputCoder = outputCoder;
+ }
+
+ private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException {
+ is.defaultReadObject();
+ readerSchemaString = internSchemaString(readerSchemaString);
+ }
+
+ private Coder<T> getOutputCoder() {
+ if (parseFn == null) { // Cases 1 and 2: records are decoded straight into the output type.
+ return AvroCoder.of((Class<T>) type, internOrParseSchemaString(readerSchemaString));
+ } else {
+ return outputCoder;
+ }
+ }
+
+ private void validate() {
+ if (parseFn == null) { // Schema-based modes need a reader schema, e.g. for the AvroCoder.
+ checkArgument(
+ readerSchemaString != null,
+ "schema must be specified using withSchema() when not using a parse fn");
+ }
+ }
+ }
+
+ private static Mode<GenericRecord> readGenericRecordsWithSchema(String schema) {
+ return new Mode<>(GenericRecord.class, schema, null, null); // schema may be null; see from().
+ }
+ private static <T> Mode<T> readGeneratedClasses(Class<T> clazz) {
+ return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null);
+ }
+ private static <T> Mode<T> parseGenericRecords(
+ SerializableFunction<GenericRecord, T> parseFn, Coder<T> outputCoder) {
+ return new Mode<>(GenericRecord.class, null, parseFn, outputCoder);
+ }
+
+ private final Mode<T> mode;
/**
- * Reads from the given file name or pattern ("glob"). The returned source can be further
+ * Reads from the given file name or pattern ("glob"). The returned source needs to be further
* configured by calling {@link #withSchema} to return a type other than {@link GenericRecord}.
*/
+ public static AvroSource<GenericRecord> from(ValueProvider<String> fileNameOrPattern) {
+ return new AvroSource<>(
+ fileNameOrPattern,
+ DEFAULT_MIN_BUNDLE_SIZE,
+ readGenericRecordsWithSchema(null /* will need to be specified in withSchema */));
+ }
+
+ /** Like {@link #from(ValueProvider)}, but with a fixed file name or pattern. */
public static AvroSource<GenericRecord> from(String fileNameOrPattern) {
- return new AvroSource<>(fileNameOrPattern, DEFAULT_MIN_BUNDLE_SIZE, null, GenericRecord.class);
+ return from(ValueProvider.StaticValueProvider.of(fileNameOrPattern));
}
/** Reads files containing records that conform to the given schema. */
public AvroSource<GenericRecord> withSchema(String schema) {
+ checkNotNull(schema, "schema");
return new AvroSource<>(
- getFileOrPatternSpec(), getMinBundleSize(), schema, GenericRecord.class);
+ getFileOrPatternSpecProvider(),
+ getMinBundleSize(),
+ readGenericRecordsWithSchema(schema));
}
/** Like {@link #withSchema(String)}. */
public AvroSource<GenericRecord> withSchema(Schema schema) {
- return new AvroSource<>(
- getFileOrPatternSpec(), getMinBundleSize(), schema.toString(), GenericRecord.class);
+ checkNotNull(schema, "schema");
+ return withSchema(schema.toString()); // Both overloads share one construction path.
}
/** Reads files containing records of the given class. */
public <X> AvroSource<X> withSchema(Class<X> clazz) {
+ checkNotNull(clazz, "clazz");
return new AvroSource<>(
- getFileOrPatternSpec(),
+ getFileOrPatternSpecProvider(),
getMinBundleSize(),
- ReflectData.get().getSchema(clazz).toString(),
- clazz);
+ readGeneratedClasses(clazz));
+ }
+
+ /**
+ * Reads {@link GenericRecord}s of unspecified schema, converts each one to a custom type with
+ * {@code parseFn}, and encodes the results with {@code coder}.
+ */
+ public <X> AvroSource<X> withParseFn(
+ SerializableFunction<GenericRecord, X> parseFn, Coder<X> coder) {
+ checkNotNull(parseFn, "parseFn");
+ checkNotNull(coder, "coder");
+ return new AvroSource<>(
+ getFileOrPatternSpecProvider(),
+ getMinBundleSize(),
+ parseGenericRecords(parseFn, coder));
}
/**
@@ -170,15 +263,16 @@
* minBundleSize} and its use.
*/
public AvroSource<T> withMinBundleSize(long minBundleSize) {
- return new AvroSource<>(getFileOrPatternSpec(), minBundleSize, readerSchemaString, type);
+ return new AvroSource<>(getFileOrPatternSpecProvider(), minBundleSize, mode);
}
/** Constructor for FILEPATTERN mode. */
private AvroSource(
- String fileNameOrPattern, long minBundleSize, String readerSchemaString, Class<T> type) {
+ ValueProvider<String> fileNameOrPattern,
+ long minBundleSize,
+ Mode<T> mode) {
super(fileNameOrPattern, minBundleSize);
- this.readerSchemaString = internSchemaString(readerSchemaString);
- this.type = type;
+ this.mode = mode;
}
/** Constructor for SINGLE_FILE_OR_SUBRANGE mode. */
@@ -187,18 +281,15 @@
long minBundleSize,
long startOffset,
long endOffset,
- String readerSchemaString,
- Class<T> type) {
+ Mode<T> mode) {
super(metadata, minBundleSize, startOffset, endOffset);
- this.readerSchemaString = internSchemaString(readerSchemaString);
- this.type = type;
+ this.mode = mode;
}
@Override
public void validate() {
- // AvroSource objects do not need to be configured with more than a file pattern. Overridden to
- // make this explicit.
super.validate();
+ mode.validate(); // Fails unless a schema (withSchema) or a parse fn (withParseFn) was set.
}
/**
@@ -215,7 +306,7 @@
@Override
public BlockBasedSource<T> createForSubrangeOfFile(Metadata fileMetadata, long start, long end) {
- return new AvroSource<>(fileMetadata, getMinBundleSize(), start, end, readerSchemaString, type);
+ return new AvroSource<>(fileMetadata, getMinBundleSize(), start, end, mode);
}
@Override
@@ -224,14 +315,14 @@
}
@Override
- public AvroCoder<T> getDefaultOutputCoder() {
- return AvroCoder.of(type, internOrParseSchemaString(readerSchemaString));
+ public Coder<T> getOutputCoder() {
+ return mode.getOutputCoder(); // AvroCoder, or the coder passed to withParseFn().
}
@VisibleForTesting
@Nullable
String getReaderSchemaString() {
- return readerSchemaString;
+ return mode.readerSchemaString;
}
/** Avro file metadata. */
@@ -370,15 +461,9 @@
switch (getMode()) {
case SINGLE_FILE_OR_SUBRANGE:
return new AvroSource<>(
- getSingleFileMetadata(),
- getMinBundleSize(),
- getStartOffset(),
- getEndOffset(),
- readerSchemaString,
- type);
+ getSingleFileMetadata(), getMinBundleSize(), getStartOffset(), getEndOffset(), mode);
case FILEPATTERN:
- return new AvroSource<>(
- getFileOrPatternSpec(), getMinBundleSize(), readerSchemaString, type);
+ return new AvroSource<>(getFileOrPatternSpecProvider(), getMinBundleSize(), mode);
default:
throw new InvalidObjectException(
String.format("Unknown mode %s for AvroSource %s", getMode(), this));
@@ -392,6 +477,8 @@
*/
@Experimental(Experimental.Kind.SOURCE_SINK)
static class AvroBlock<T> extends Block<T> {
+ private final Mode<T> mode; // Supplies the reader schema, datum type and optional parseFn.
+
// The number of records in the block.
private final long numRecords;
@@ -402,7 +489,7 @@
private long currentRecordIndex = 0;
// A DatumReader to read records from the block.
- private final DatumReader<T> reader;
+ private final DatumReader<?> reader; // Yields T, or a GenericRecord that parseFn turns into T.
// A BinaryDecoder used by the reader to decode records.
private final BinaryDecoder decoder;
@@ -445,19 +532,19 @@
AvroBlock(
byte[] data,
long numRecords,
- Class<? extends T> type,
- String readerSchemaString,
+ Mode<T> mode,
String writerSchemaString,
String codec)
throws IOException {
+ this.mode = mode;
this.numRecords = numRecords;
checkNotNull(writerSchemaString, "writerSchemaString");
Schema writerSchema = internOrParseSchemaString(writerSchemaString);
Schema readerSchema =
internOrParseSchemaString(
- MoreObjects.firstNonNull(readerSchemaString, writerSchemaString));
+ MoreObjects.firstNonNull(mode.readerSchemaString, writerSchemaString));
this.reader =
- (type == GenericRecord.class)
+ (mode.type == GenericRecord.class) // Generic reads (incl. parseFn mode), else reflection.
? new GenericDatumReader<T>(writerSchema, readerSchema)
: new ReflectDatumReader<T>(writerSchema, readerSchema);
this.decoder = DecoderFactory.get().binaryDecoder(decodeAsInputStream(data, codec), null);
@@ -473,7 +560,9 @@
if (currentRecordIndex >= numRecords) {
return false;
}
- currentRecord = reader.read(null, decoder);
+ Object record = reader.read(null, decoder); // Converted to T below if a parseFn is set.
+ currentRecord =
+ (mode.parseFn == null) ? ((T) record) : mode.parseFn.apply((GenericRecord) record);
currentRecordIndex++;
return true;
}
@@ -575,8 +664,7 @@
new AvroBlock<>(
data,
numRecords,
- getCurrentSource().type,
- getCurrentSource().readerSchemaString,
+ getCurrentSource().mode,
metadata.getSchemaString(),
metadata.getCodec());
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroUtils.java
new file mode 100644
index 0000000..65c5bf1
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/AvroUtils.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.base.Function;
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
+import java.io.Serializable;
+import org.apache.avro.Schema;
+
+/** Helpers for working with Avro. */
+final class AvroUtils {
+ /** Disallow construction of utility class. */
+ private AvroUtils() {}
+
+ /**
+ * Returns a serializable {@link Supplier} that lazily parses {@code jsonSchema} into a
+ * {@link Schema} and caches it. This gets around {@link Schema} itself not being serializable.
+ */
+ public static Supplier<Schema> serializableSchemaSupplier(String jsonSchema) {
+ checkNotNull(jsonSchema, "jsonSchema");
+ return Suppliers.memoize(
+ Suppliers.compose(new JsonToSchema(), Suppliers.ofInstance(jsonSchema)));
+ }
+
+ /** Serializable function that parses a JSON Avro schema. */
+ private static class JsonToSchema implements Function<String, Schema>, Serializable {
+ @Override
+ public Schema apply(String input) {
+ return new Schema.Parser().parse(input);
+ }
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java
index cf6671e..25e8483 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java
@@ -23,6 +23,7 @@
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
/**
@@ -69,6 +70,11 @@
super(StaticValueProvider.of(fileOrPatternSpec), minBundleSize);
}
+ /** Like {@link #BlockBasedSource(String, long)}, but takes a {@link ValueProvider}. */
+ public BlockBasedSource(ValueProvider<String> fileOrPatternSpec, long minBundleSize) {
+ super(fileOrPatternSpec, minBundleSize);
+ }
+
/**
* Creates a {@code BlockBasedSource} for a single file. Subclasses must call this constructor
* when implementing {@link BlockBasedSource#createForSubrangeOfFile}. See documentation in
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedReadFromUnboundedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedReadFromUnboundedSource.java
index c882447..80a03eb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedReadFromUnboundedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BoundedReadFromUnboundedSource.java
@@ -114,12 +114,8 @@
}
}));
}
- return read.apply("StripIds", ParDo.of(new ValueWithRecordId.StripIdsDoFn<T>()));
- }
-
- @Override
- protected Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
+ return read.apply("StripIds", ParDo.of(new ValueWithRecordId.StripIdsDoFn<T>()))
+ .setCoder(source.getOutputCoder());
}
@Override
@@ -211,8 +207,8 @@
}
@Override
- public Coder<ValueWithRecordId<T>> getDefaultOutputCoder() {
- return ValueWithRecordId.ValueWithRecordIdCoder.of(getSource().getDefaultOutputCoder());
+ public Coder<ValueWithRecordId<T>> getOutputCoder() {
+ return ValueWithRecordId.ValueWithRecordIdCoder.of(getSource().getOutputCoder());
}
@Override
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
index 4baac36..6943a02 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
@@ -146,7 +146,7 @@
public ReadableByteChannel createDecompressingChannel(ReadableByteChannel channel)
throws IOException {
return Channels.newChannel(
- new BZip2CompressorInputStream(Channels.newInputStream(channel)));
+ new BZip2CompressorInputStream(Channels.newInputStream(channel), true)); // Concatenated too.
}
},
@@ -404,11 +404,11 @@
}
/**
- * Returns the delegate source's default output coder.
+ * Returns the delegate source's output coder.
*/
@Override
- public final Coder<T> getDefaultOutputCoder() {
- return sourceDelegate.getDefaultOutputCoder();
+ public final Coder<T> getOutputCoder() {
+ return sourceDelegate.getOutputCoder();
}
public final DecompressingChannelFactory getChannelFactory() {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java
new file mode 100644
index 0000000..b006e26
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ConstantAvroDestination.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.io;
+
+import com.google.common.base.Supplier;
+import com.google.common.io.BaseEncoding;
+import java.util.Map;
+import org.apache.avro.Schema;
+import org.apache.avro.file.CodecFactory;
+import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.display.HasDisplayData;
+
+/** Always returns a constant {@link FilenamePolicy}, {@link Schema}, metadata, and codec. */
+class ConstantAvroDestination<UserT, OutputT>
+ extends DynamicAvroDestinations<UserT, Void, OutputT> {
+ // byte[] metadata is shown base64-encoded and cut to this many characters in display data.
+ // This should be a multiple of 4 so the cut never splits a base64 group.
+ private static final int METADATA_BYTES_MAX_LENGTH = 40;
+ private final FilenamePolicy filenamePolicy;
+ // Schema is not serializable; this supplier is, and it parses the JSON lazily and caches it.
+ private final Supplier<Schema> schema;
+ private final Map<String, Object> metadata;
+ private final SerializableAvroCodecFactory codec;
+ private final SerializableFunction<UserT, OutputT> formatFunction;
+
+ /** Reports the file {@code metadata} map as display data. */
+ private class Metadata implements HasDisplayData {
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ for (Map.Entry<String, Object> entry : metadata.entrySet()) {
+ Object value = entry.getValue();
+ DisplayData.Type type = DisplayData.inferType(value);
+ if (type != null) {
+ builder.add(DisplayData.item(entry.getKey(), type, value));
+ } else if (value instanceof byte[]) {
+ String base64 = BaseEncoding.base64().encode((byte[]) value);
+ String repr =
+ base64.length() <= METADATA_BYTES_MAX_LENGTH
+ ? base64
+ : base64.substring(0, METADATA_BYTES_MAX_LENGTH) + "...";
+ builder.add(DisplayData.item(entry.getKey(), repr));
+ } else {
+ // Any other untyped value: show its string form instead of failing a byte[] cast.
+ builder.add(DisplayData.item(entry.getKey(), String.valueOf(value)));
+ }
+ }
+ }
+ }
+
+ /**
+ * Creates a destination that uses the given policy, schema, metadata and codec for every element
+ * and converts elements with {@code formatFunction}.
+ */
+ public ConstantAvroDestination(
+ FilenamePolicy filenamePolicy,
+ Schema schema,
+ Map<String, Object> metadata,
+ CodecFactory codec,
+ SerializableFunction<UserT, OutputT> formatFunction) {
+ this.filenamePolicy = filenamePolicy;
+ this.schema = AvroUtils.serializableSchemaSupplier(schema.toString());
+ this.metadata = metadata;
+ this.codec = new SerializableAvroCodecFactory(codec);
+ this.formatFunction = formatFunction;
+ }
+
+ @Override
+ public OutputT formatRecord(UserT record) {
+ return formatFunction.apply(record);
+ }
+
+ @Override
+ public Void getDestination(UserT element) {
+ return (Void) null;
+ }
+
+ @Override
+ public Void getDefaultDestination() {
+ return (Void) null;
+ }
+
+ @Override
+ public FilenamePolicy getFilenamePolicy(Void destination) {
+ return filenamePolicy;
+ }
+
+ @Override
+ public Schema getSchema(Void destination) {
+ return schema.get();
+ }
+
+ @Override
+ public Map<String, Object> getMetadata(Void destination) {
+ return metadata;
+ }
+
+ @Override
+ public CodecFactory getCodec(Void destination) {
+ return codec.getCodec();
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ filenamePolicy.populateDisplayData(builder);
+ builder.add(DisplayData.item("schema", schema.get().toString()).withLabel("Record Schema"));
+ builder.addIfNotDefault(
+ DisplayData.item("codec", codec.getCodec().toString()).withLabel("Avro Compression Codec"),
+ AvroIO.TypedWrite.DEFAULT_SERIALIZABLE_CODEC.toString());
+ builder.include("Metadata", new Metadata());
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
index 6202c2b..b47edc7 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
@@ -188,7 +188,7 @@
}
@Override
- public Coder<Long> getDefaultOutputCoder() {
+ public Coder<Long> getOutputCoder() {
return VarLongCoder.of();
}
@@ -364,7 +364,7 @@
public void validate() {}
@Override
- public Coder<Long> getDefaultOutputCoder() {
+ public Coder<Long> getOutputCoder() {
return VarLongCoder.of();
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java
index 4021609..1f438d5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java
@@ -157,7 +157,6 @@
&& shardTemplate.equals(other.shardTemplate)
&& suffix.equals(other.suffix);
}
-
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java
new file mode 100644
index 0000000..f4e8ee6
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicAvroDestinations.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.io;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.apache.avro.Schema;
+import org.apache.avro.file.CodecFactory;
+import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations;
+
+/**
+ * A {@link DynamicDestinations} specialization for {@link AvroIO}. Besides choosing the file
+ * destination, it lets every destination carry its own AVRO schema, file metadata, and
+ * compression codec.
+ */
+public abstract class DynamicAvroDestinations<UserT, DestinationT, OutputT>
+ extends DynamicDestinations<UserT, DestinationT, OutputT> {
+ /** Returns the AVRO {@link Schema} to write for {@code destination}. */
+ public abstract Schema getSchema(DestinationT destination);
+
+ /** Returns the AVRO file metadata for {@code destination}; empty unless overridden. */
+ public Map<String, Object> getMetadata(DestinationT destination) {
+ return ImmutableMap.of();
+ }
+
+ /** Returns the AVRO codec for {@code destination}; AvroIO's default codec unless overridden. */
+ public CodecFactory getCodec(DestinationT destination) {
+ return AvroIO.TypedWrite.DEFAULT_CODEC;
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java
index e7ef0f6..b087bc5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DynamicFileDestinations.java
@@ -18,6 +18,8 @@
package org.apache.beam.sdk.io;
+import static com.google.common.base.Preconditions.checkState;
+
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params;
@@ -25,20 +27,30 @@
import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations;
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.display.DisplayData;
/** Some helper classes that derive from {@link FileBasedSink.DynamicDestinations}. */
public class DynamicFileDestinations {
/** Always returns a constant {@link FilenamePolicy}. */
- private static class ConstantFilenamePolicy<T> extends DynamicDestinations<T, Void> {
+ private static class ConstantFilenamePolicy<UserT, OutputT>
+ extends DynamicDestinations<UserT, Void, OutputT> {
private final FilenamePolicy filenamePolicy;
+ private final SerializableFunction<UserT, OutputT> formatFunction;
- public ConstantFilenamePolicy(FilenamePolicy filenamePolicy) {
+ public ConstantFilenamePolicy(
+ FilenamePolicy filenamePolicy, SerializableFunction<UserT, OutputT> formatFunction) {
this.filenamePolicy = filenamePolicy;
+ this.formatFunction = formatFunction;
}
@Override
- public Void getDestination(T element) {
+    public OutputT formatRecord(UserT record) {
+      // Conversion is delegated entirely to the format function supplied at construction.
+      OutputT formatted = formatFunction.apply(record);
+      return formatted;
+    }
+
+ @Override
+ public Void getDestination(UserT element) {
return (Void) null;
}
@@ -59,6 +71,7 @@
@Override
public void populateDisplayData(DisplayData.Builder builder) {
+ checkState(filenamePolicy != null);
filenamePolicy.populateDisplayData(builder);
}
}
@@ -67,14 +80,24 @@
* A base class for a {@link DynamicDestinations} object that returns differently-configured
* instances of {@link DefaultFilenamePolicy}.
*/
- private static class DefaultPolicyDestinations<UserT> extends DynamicDestinations<UserT, Params> {
- SerializableFunction<UserT, Params> destinationFunction;
- Params emptyDestination;
+ private static class DefaultPolicyDestinations<UserT, OutputT>
+ extends DynamicDestinations<UserT, Params, OutputT> {
+ private final SerializableFunction<UserT, Params> destinationFunction;
+ private final Params emptyDestination;
+ private final SerializableFunction<UserT, OutputT> formatFunction;
public DefaultPolicyDestinations(
- SerializableFunction<UserT, Params> destinationFunction, Params emptyDestination) {
+ SerializableFunction<UserT, Params> destinationFunction,
+ Params emptyDestination,
+ SerializableFunction<UserT, OutputT> formatFunction) {
this.destinationFunction = destinationFunction;
this.emptyDestination = emptyDestination;
+ this.formatFunction = formatFunction;
+ }
+
+ @Override
+ public OutputT formatRecord(UserT record) {
+ return formatFunction.apply(record);
}
@Override
@@ -100,16 +123,28 @@
}
/** Returns a {@link DynamicDestinations} that always returns the same {@link FilenamePolicy}. */
- public static <T> DynamicDestinations<T, Void> constant(FilenamePolicy filenamePolicy) {
- return new ConstantFilenamePolicy<>(filenamePolicy);
+  public static <UserT, OutputT> DynamicDestinations<UserT, Void, OutputT> constant(
+      FilenamePolicy filenamePolicy, SerializableFunction<UserT, OutputT> formatFunction) {
+    // The destination type is Void: every element shares the single given FilenamePolicy.
+    return new ConstantFilenamePolicy<>(filenamePolicy, formatFunction);
+  }
+
+  /**
+   * Like {@link #constant(FilenamePolicy, SerializableFunction)}, for the common case where the
+   * input and output types coincide and every element is written unchanged (identity format).
+   */
+  public static <UserT> DynamicDestinations<UserT, Void, UserT> constant(
+      FilenamePolicy filenamePolicy) {
+    return constant(filenamePolicy, SerializableFunctions.<UserT>identity());
   }
/**
* Returns a {@link DynamicDestinations} that returns instances of {@link DefaultFilenamePolicy}
* configured with the given {@link Params}.
*/
- public static <UserT> DynamicDestinations<UserT, Params> toDefaultPolicies(
- SerializableFunction<UserT, Params> destinationFunction, Params emptyDestination) {
- return new DefaultPolicyDestinations<>(destinationFunction, emptyDestination);
+  public static <UserT, OutputT> DynamicDestinations<UserT, Params, OutputT> toDefaultPolicies(
+      SerializableFunction<UserT, Params> destinationFunction,
+      Params emptyDestination,
+      SerializableFunction<UserT, OutputT> formatFunction) {
+    // destinationFunction picks the Params for each element; formatFunction converts elements
+    // into the type that is actually written.
+    DefaultPolicyDestinations<UserT, OutputT> destinations =
+        new DefaultPolicyDestinations<>(destinationFunction, emptyDestination, formatFunction);
+    return destinations;
   }
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java
index 9953975..4e2b61c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java
@@ -23,24 +23,25 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verifyNotNull;
import static org.apache.beam.sdk.io.WriteFiles.UNKNOWN_SHARDNUM;
+import static org.apache.beam.sdk.values.TypeDescriptors.extractFromTypeParameters;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
-import java.lang.reflect.TypeVariable;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
-import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -67,6 +68,7 @@
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
@@ -74,7 +76,9 @@
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorOutputStream;
import org.apache.commons.compress.compressors.deflate.DeflateCompressorOutputStream;
import org.joda.time.Instant;
@@ -94,9 +98,9 @@
* <p>The process of writing to file-based sink is as follows:
*
* <ol>
- * <li>An optional subclass-defined initialization,
- * <li>a parallel write of bundles to temporary files, and finally,
- * <li>these temporary files are renamed with final output filenames.
+ * <li>An optional subclass-defined initialization,
+ * <li>a parallel write of bundles to temporary files, and finally,
+ * <li>these temporary files are renamed with final output filenames.
* </ol>
*
* <p>In order to ensure fault-tolerance, a bundle may be executed multiple times (e.g., in the
@@ -120,7 +124,8 @@
* @param <OutputT> the type of values written to the sink.
*/
@Experimental(Kind.FILESYSTEM)
-public abstract class FileBasedSink<OutputT, DestinationT> implements Serializable, HasDisplayData {
+public abstract class FileBasedSink<UserT, DestinationT, OutputT>
+ implements Serializable, HasDisplayData {
private static final Logger LOG = LoggerFactory.getLogger(FileBasedSink.class);
/** Directly supported file output compression types. */
@@ -198,7 +203,7 @@
}
}
- private final DynamicDestinations<?, DestinationT> dynamicDestinations;
+ private final DynamicDestinations<?, DestinationT, OutputT> dynamicDestinations;
/**
* The {@link WritableByteChannelFactory} that is used to wrap the raw data output to the
@@ -214,8 +219,54 @@
* destination type into an instance of {@link FilenamePolicy}.
*/
@Experimental(Kind.FILESYSTEM)
- public abstract static class DynamicDestinations<UserT, DestinationT>
+ public abstract static class DynamicDestinations<UserT, DestinationT, OutputT>
implements HasDisplayData, Serializable {
+    /** Strategy for looking up the value of a {@link PCollectionView} side input. */
+    interface SideInputAccessor {
+      <SideInputT> SideInputT sideInput(PCollectionView<SideInputT> view);
+    }
+
+    // Populated by setSideInputAccessor / setSideInputAccessorFromProcessContext; null until then.
+    private SideInputAccessor sideInputAccessor;
+
+    /** A {@link SideInputAccessor} that reads side inputs from a {@link DoFn.ProcessContext}. */
+    static class SideInputAccessorViaProcessContext implements SideInputAccessor {
+      // Never reassigned after construction.
+      private final DoFn<?, ?>.ProcessContext processContext;
+
+      SideInputAccessorViaProcessContext(DoFn<?, ?>.ProcessContext processContext) {
+        this.processContext = checkNotNull(processContext, "processContext");
+      }
+
+      @Override
+      public <SideInputT> SideInputT sideInput(PCollectionView<SideInputT> view) {
+        return processContext.sideInput(view);
+      }
+    }
+
+    /**
+     * Override to specify that this object needs access to one or more side inputs. These side
+     * inputs must be globally windowed, as they will be accessed from the global window.
+     *
+     * <p>The default implementation returns an empty list (no side inputs).
+     */
+    public List<PCollectionView<?>> getSideInputs() {
+      return ImmutableList.of();
+    }
+
+    /**
+     * Returns the value of a given side input. The view must be present in {@link
+     * #getSideInputs()}, and a side input accessor must have been set before this is called.
+     *
+     * @throws IllegalStateException if no side input accessor has been set
+     */
+    protected final <SideInputT> SideInputT sideInput(PCollectionView<SideInputT> view) {
+      // Without this check, an early call fails with a bare NullPointerException.
+      checkState(
+          sideInputAccessor != null, "sideInput() called before a SideInputAccessor was set");
+      return sideInputAccessor.sideInput(view);
+    }
+
+    /** Installs the accessor that {@link #sideInput} uses to resolve side input values. */
+    final void setSideInputAccessor(SideInputAccessor sideInputAccessor) {
+      this.sideInputAccessor = sideInputAccessor;
+    }
+
+    /** Installs an accessor that reads side inputs from the given {@link DoFn.ProcessContext}. */
+    final void setSideInputAccessorFromProcessContext(DoFn<?, ?>.ProcessContext context) {
+      setSideInputAccessor(new SideInputAccessorViaProcessContext(context));
+    }
+
+    /**
+     * Converts an input element of type {@code UserT} into the {@code OutputT} value that is
+     * written to the sink.
+     */
+    public abstract OutputT formatRecord(UserT record);
+
/**
* Returns an object that represents at a high level the destination being written to. May not
* return null. A destination must have deterministic hash and equality methods defined.
@@ -255,17 +306,17 @@
return destinationCoder;
}
// If dynamicDestinations doesn't provide a coder, try to find it in the coder registry.
- // We must first use reflection to figure out what the type parameter is.
- TypeDescriptor<?> superDescriptor =
- TypeDescriptor.of(getClass()).getSupertype(DynamicDestinations.class);
- if (!superDescriptor.getRawType().equals(DynamicDestinations.class)) {
- throw new AssertionError(
- "Couldn't find the DynamicDestinations superclass of " + this.getClass());
- }
- TypeVariable typeVariable = superDescriptor.getTypeParameter("DestinationT");
- @SuppressWarnings("unchecked")
+ @Nullable
TypeDescriptor<DestinationT> descriptor =
- (TypeDescriptor<DestinationT>) superDescriptor.resolveType(typeVariable);
+ extractFromTypeParameters(
+ this,
+ DynamicDestinations.class,
+ new TypeVariableExtractor<
+ DynamicDestinations<UserT, DestinationT, OutputT>, DestinationT>() {});
+ checkArgument(
+ descriptor != null,
+ "Unable to infer a coder for DestinationT, "
+ + "please specify it explicitly by overriding getDestinationCoder()");
return registry.getCoder(descriptor);
}
}
@@ -323,7 +374,7 @@
@Experimental(Kind.FILESYSTEM)
public FileBasedSink(
ValueProvider<ResourceId> tempDirectoryProvider,
- DynamicDestinations<?, DestinationT> dynamicDestinations) {
+ DynamicDestinations<?, DestinationT, OutputT> dynamicDestinations) {
this(tempDirectoryProvider, dynamicDestinations, CompressionType.UNCOMPRESSED);
}
@@ -331,7 +382,7 @@
@Experimental(Kind.FILESYSTEM)
public FileBasedSink(
ValueProvider<ResourceId> tempDirectoryProvider,
- DynamicDestinations<?, DestinationT> dynamicDestinations,
+ DynamicDestinations<?, DestinationT, OutputT> dynamicDestinations,
WritableByteChannelFactory writableByteChannelFactory) {
this.tempDirectoryProvider =
NestedValueProvider.of(tempDirectoryProvider, new ExtractDirectory());
@@ -341,8 +392,8 @@
/** Return the {@link DynamicDestinations} used. */
@SuppressWarnings("unchecked")
- public <UserT> DynamicDestinations<UserT, DestinationT> getDynamicDestinations() {
- return (DynamicDestinations<UserT, DestinationT>) dynamicDestinations;
+ public DynamicDestinations<UserT, DestinationT, OutputT> getDynamicDestinations() {
+ return (DynamicDestinations<UserT, DestinationT, OutputT>) dynamicDestinations;
}
/**
@@ -357,7 +408,7 @@
public void validate(PipelineOptions options) {}
/** Return a subclass of {@link WriteOperation} that will manage the write to the sink. */
- public abstract WriteOperation<OutputT, DestinationT> createWriteOperation();
+ public abstract WriteOperation<DestinationT, OutputT> createWriteOperation();
public void populateDisplayData(DisplayData.Builder builder) {
getDynamicDestinations().populateDisplayData(builder);
@@ -371,11 +422,11 @@
* written,
*
* <ol>
- * <li>{@link WriteOperation#finalize} is given a list of the temporary files containing the
- * output bundles.
- * <li>During finalize, these temporary files are copied to final output locations and named
- * according to a file naming template.
- * <li>Finally, any temporary files that were created during the write are removed.
+ * <li>{@link WriteOperation#finalize} is given a list of the temporary files containing the
+ * output bundles.
+ * <li>During finalize, these temporary files are copied to final output locations and named
+ * according to a file naming template.
+ * <li>Finally, any temporary files that were created during the write are removed.
* </ol>
*
* <p>Subclass implementations of WriteOperation must implement {@link
@@ -400,9 +451,9 @@
*
* @param <OutputT> the type of values written to the sink.
*/
- public abstract static class WriteOperation<OutputT, DestinationT> implements Serializable {
+ public abstract static class WriteOperation<DestinationT, OutputT> implements Serializable {
/** The Sink that this WriteOperation will write to. */
- protected final FileBasedSink<OutputT, DestinationT> sink;
+ protected final FileBasedSink<?, DestinationT, OutputT> sink;
/** Directory for temporary output files. */
protected final ValueProvider<ResourceId> tempDirectory;
@@ -428,7 +479,7 @@
*
* @param sink the FileBasedSink that will be used to configure this write operation.
*/
- public WriteOperation(FileBasedSink<OutputT, DestinationT> sink) {
+ public WriteOperation(FileBasedSink<?, DestinationT, OutputT> sink) {
this(
sink,
NestedValueProvider.of(sink.getTempDirectoryProvider(), new TemporaryDirectoryBuilder()));
@@ -463,12 +514,12 @@
* @param tempDirectory the base directory to be used for temporary output files.
*/
@Experimental(Kind.FILESYSTEM)
- public WriteOperation(FileBasedSink<OutputT, DestinationT> sink, ResourceId tempDirectory) {
+ public WriteOperation(FileBasedSink<?, DestinationT, OutputT> sink, ResourceId tempDirectory) {
this(sink, StaticValueProvider.of(tempDirectory));
}
private WriteOperation(
- FileBasedSink<OutputT, DestinationT> sink, ValueProvider<ResourceId> tempDirectory) {
+ FileBasedSink<?, DestinationT, OutputT> sink, ValueProvider<ResourceId> tempDirectory) {
this.sink = sink;
this.tempDirectory = tempDirectory;
this.windowedWrites = false;
@@ -478,7 +529,7 @@
* Clients must implement to return a subclass of {@link Writer}. This method must not mutate
* the state of the object.
*/
- public abstract Writer<OutputT, DestinationT> createWriter() throws Exception;
+ public abstract Writer<DestinationT, OutputT> createWriter() throws Exception;
/** Indicates that the operation will be performing windowed writes. */
public void setWindowedWrites(boolean windowedWrites) {
@@ -533,7 +584,7 @@
protected final Map<ResourceId, ResourceId> buildOutputFilenames(
Iterable<FileResult<DestinationT>> writerResults) {
int numShards = Iterables.size(writerResults);
- Map<ResourceId, ResourceId> outputFilenames = new HashMap<>();
+ Map<ResourceId, ResourceId> outputFilenames = Maps.newHashMap();
// Either all results have a shard number set (if the sink is configured with a fixed
// number of shards), or they all don't (otherwise).
@@ -597,7 +648,6 @@
"Only generated %s distinct file names for %s files.",
numDistinctShards,
outputFilenames.size());
-
return outputFilenames;
}
@@ -691,7 +741,7 @@
}
/** Returns the FileBasedSink for this write operation. */
- public FileBasedSink<OutputT, DestinationT> getSink() {
+ public FileBasedSink<?, DestinationT, OutputT> getSink() {
return sink;
}
@@ -727,10 +777,10 @@
*
* @param <OutputT> the type of values to write.
*/
- public abstract static class Writer<OutputT, DestinationT> {
+ public abstract static class Writer<DestinationT, OutputT> {
private static final Logger LOG = LoggerFactory.getLogger(Writer.class);
- private final WriteOperation<OutputT, DestinationT> writeOperation;
+ private final WriteOperation<DestinationT, OutputT> writeOperation;
/** Unique id for this output bundle. */
private String id;
@@ -757,7 +807,7 @@
private final String mimeType;
/** Construct a new {@link Writer} that will produce files of the given MIME type. */
- public Writer(WriteOperation<OutputT, DestinationT> writeOperation, String mimeType) {
+ public Writer(WriteOperation<DestinationT, OutputT> writeOperation, String mimeType) {
checkNotNull(writeOperation);
this.writeOperation = writeOperation;
this.mimeType = mimeType;
@@ -930,9 +980,14 @@
}
/** Return the WriteOperation that this Writer belongs to. */
- public WriteOperation<OutputT, DestinationT> getWriteOperation() {
+ public WriteOperation<DestinationT, OutputT> getWriteOperation() {
return writeOperation;
}
+
+    /** Returns the user-defined destination object associated with this writer. */
+    public DestinationT getDestination() {
+      return destination;
+    }
}
/**
@@ -987,7 +1042,7 @@
@Experimental(Kind.FILESYSTEM)
public ResourceId getDestinationFile(
- DynamicDestinations<?, DestinationT> dynamicDestinations,
+ DynamicDestinations<?, DestinationT, ?> dynamicDestinations,
int numShards,
OutputFileHints outputFileHints) {
checkArgument(getShard() != UNKNOWN_SHARDNUM);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java
index d4413c9..7f865de 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java
@@ -23,19 +23,17 @@
import static com.google.common.base.Verify.verify;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
-import org.apache.beam.sdk.io.fs.MatchResult.Status;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
@@ -68,6 +66,7 @@
private static final Logger LOG = LoggerFactory.getLogger(FileBasedSource.class);
private final ValueProvider<String> fileOrPatternSpec;
+ private final EmptyMatchTreatment emptyMatchTreatment;
@Nullable private MatchResult.Metadata singleFileMetadata;
private final Mode mode;
@@ -80,12 +79,25 @@
}
/**
- * Create a {@code FileBaseSource} based on a file or a file pattern specification.
+   * Creates a {@code FileBasedSource} based on a file or a file pattern specification, with the
+   * given strategy for treating filepatterns that do not match any files.
+   *
+   * @param fileOrPatternSpec the file or filepattern to read
+   * @param emptyMatchTreatment how to treat a filepattern that matches no files; must not be null
+   * @param minBundleSize the minimum bundle size used when splitting this source
+   */
+  protected FileBasedSource(
+      ValueProvider<String> fileOrPatternSpec,
+      EmptyMatchTreatment emptyMatchTreatment,
+      long minBundleSize) {
+    super(0, Long.MAX_VALUE, minBundleSize);
+    this.mode = Mode.FILEPATTERN;
+    // A null treatment is not a meaningful strategy in FILEPATTERN mode; reject it up front.
+    this.emptyMatchTreatment = checkNotNull(emptyMatchTreatment, "emptyMatchTreatment");
+    this.fileOrPatternSpec = fileOrPatternSpec;
+  }
+
+ /**
+ * Like {@link #FileBasedSource(ValueProvider, EmptyMatchTreatment, long)}, but uses the default
+ * value of {@link EmptyMatchTreatment#DISALLOW}.
*/
protected FileBasedSource(ValueProvider<String> fileOrPatternSpec, long minBundleSize) {
- super(0, Long.MAX_VALUE, minBundleSize);
- mode = Mode.FILEPATTERN;
- this.fileOrPatternSpec = fileOrPatternSpec;
+ this(fileOrPatternSpec, EmptyMatchTreatment.DISALLOW, minBundleSize);
}
/**
@@ -110,6 +122,9 @@
mode = Mode.SINGLE_FILE_OR_SUBRANGE;
this.singleFileMetadata = checkNotNull(fileMetadata, "fileMetadata");
this.fileOrPatternSpec = StaticValueProvider.of(fileMetadata.resourceId().toString());
+
+ // This field will be unused in this mode.
+ this.emptyMatchTreatment = null;
}
/**
@@ -204,14 +219,7 @@
if (mode == Mode.FILEPATTERN) {
long totalSize = 0;
- List<MatchResult> inputs = FileSystems.match(Collections.singletonList(fileOrPattern));
- MatchResult result = Iterables.getOnlyElement(inputs);
- checkArgument(
- result.status() == Status.OK,
- "Error matching the pattern or glob %s: status %s",
- fileOrPattern,
- result.status());
- List<Metadata> allMatches = result.metadata();
+ List<Metadata> allMatches = FileSystems.match(fileOrPattern, emptyMatchTreatment).metadata();
for (Metadata metadata : allMatches) {
totalSize += metadata.sizeBytes();
}
@@ -254,9 +262,8 @@
if (mode == Mode.FILEPATTERN) {
long startTime = System.currentTimeMillis();
- List<Metadata> expandedFiles = FileBasedSource.expandFilePattern(fileOrPattern);
- checkArgument(!expandedFiles.isEmpty(),
- "Unable to find any files matching %s", fileOrPattern);
+ List<Metadata> expandedFiles =
+ FileSystems.match(fileOrPattern, emptyMatchTreatment).metadata();
List<FileBasedSource<T>> splitResults = new ArrayList<>(expandedFiles.size());
for (Metadata metadata : expandedFiles) {
FileBasedSource<T> split = createForSubrangeOfFile(metadata, 0, metadata.sizeBytes());
@@ -327,7 +334,9 @@
if (mode == Mode.FILEPATTERN) {
long startTime = System.currentTimeMillis();
- List<Metadata> fileMetadata = FileBasedSource.expandFilePattern(fileOrPattern);
+ List<Metadata> fileMetadata =
+ FileSystems.match(fileOrPattern, emptyMatchTreatment).metadata();
+ LOG.info("Matched {} files for pattern {}", fileMetadata.size(), fileOrPattern);
List<FileBasedReader<T>> fileReaders = new ArrayList<>();
for (Metadata metadata : fileMetadata) {
long endOffset = metadata.sizeBytes();
@@ -389,13 +398,6 @@
return metadata.sizeBytes();
}
- private static List<Metadata> expandFilePattern(String fileOrPatternSpec) throws IOException {
- MatchResult matches =
- Iterables.getOnlyElement(FileSystems.match(Collections.singletonList(fileOrPatternSpec)));
- LOG.info("Matched {} files for pattern {}", matches.metadata().size(), fileOrPatternSpec);
- return ImmutableList.copyOf(matches.metadata());
- }
-
/**
* A {@link Source.Reader reader} that implements code common to readers of
* {@code FileBasedSource}s.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java
index 2ed29e3..96394b8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java
@@ -54,6 +54,7 @@
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.io.fs.CreateOptions;
import org.apache.beam.sdk.io.fs.CreateOptions.StandardCreateOptions;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.MatchResult.Status;
@@ -69,16 +70,23 @@
@Experimental(Kind.FILESYSTEM)
public class FileSystems {
- public static final String DEFAULT_SCHEME = "default";
+ public static final String DEFAULT_SCHEME = "file";
private static final Pattern FILE_SCHEME_PATTERN =
Pattern.compile("(?<scheme>[a-zA-Z][-a-zA-Z0-9+.]*):.*");
+ private static final Pattern GLOB_PATTERN =
+ Pattern.compile("[*?{}]");
private static final AtomicReference<Map<String, FileSystem>> SCHEME_TO_FILESYSTEM =
new AtomicReference<Map<String, FileSystem>>(
- ImmutableMap.<String, FileSystem>of("file", new LocalFileSystem()));
+ ImmutableMap.<String, FileSystem>of(DEFAULT_SCHEME, new LocalFileSystem()));
/********************************** METHODS FOR CLIENT **********************************/
+  /**
+   * Checks whether the given spec contains a glob wildcard character, i.e. any of the characters
+   * {@code *?{}}.
+   */
+  public static boolean hasGlobWildcard(String spec) {
+    Matcher wildcard = GLOB_PATTERN.matcher(spec);
+    return wildcard.find();
+  }
+
/**
* This is the entry point to convert user-provided specs to {@link ResourceId ResourceIds}.
* Callers should use {@link #match} to resolve users specs ambiguities before
@@ -99,6 +107,12 @@
* component of {@link ResourceId}. This allows SDK libraries to construct file system agnostic
* spec. {@link FileSystem FileSystems} can support additional patterns for user-provided specs.
*
+ * <p>In case the spec schemes don't match any known {@link FileSystem} implementations,
+ * FileSystems will attempt to use {@link LocalFileSystem} to resolve a path.
+ *
+ * <p>Specs that do not match any resources are treated according to
+ * {@link EmptyMatchTreatment#DISALLOW}.
+ *
* @return {@code List<MatchResult>} in the same order of the input specs.
*
* @throws IllegalArgumentException if specs are invalid -- empty or have different schemes.
@@ -111,6 +125,17 @@
return getFileSystemInternal(getOnlyScheme(specs)).match(specs);
}
+  /**
+   * Like {@link #match(List)}, but with a configurable {@link EmptyMatchTreatment}: a
+   * {@link Status#NOT_FOUND} result is turned into an empty {@link Status#OK} result when the
+   * treatment allows it for the corresponding spec.
+   *
+   * @throws IllegalStateException if the file system returns a different number of results than
+   *     the number of specs
+   */
+  public static List<MatchResult> match(List<String> specs, EmptyMatchTreatment emptyMatchTreatment)
+      throws IOException {
+    List<MatchResult> matches = getFileSystemInternal(getOnlyScheme(specs)).match(specs);
+    if (matches.size() != specs.size()) {
+      // Results are paired with specs by index; a count mismatch would drop results or fail with
+      // an index error instead of reporting the real problem.
+      throw new IllegalStateException(
+          String.format(
+              "FileSystem implementation for %s returned %d MatchResults for %d specs: %s",
+              specs, matches.size(), specs.size(), matches));
+    }
+    List<MatchResult> res = Lists.newArrayListWithExpectedSize(matches.size());
+    for (int i = 0; i < matches.size(); i++) {
+      res.add(maybeAdjustEmptyMatchResult(specs.get(i), matches.get(i), emptyMatchTreatment));
+    }
+    return res;
+  }
+
/**
* Like {@link #match(List)}, but for a single resource specification.
@@ -127,6 +152,30 @@
matches);
return matches.get(0);
}
+
+  /**
+   * Like {@link #match(String)}, but with a configurable {@link EmptyMatchTreatment} applied to a
+   * {@link Status#NOT_FOUND} result.
+   */
+  public static MatchResult match(String spec, EmptyMatchTreatment emptyMatchTreatment)
+      throws IOException {
+    return maybeAdjustEmptyMatchResult(spec, match(spec), emptyMatchTreatment);
+  }
+
+  /**
+   * Rewrites a {@link Status#NOT_FOUND} result as an empty {@link Status#OK} result when
+   * {@code emptyMatchTreatment} permits it: always for {@link EmptyMatchTreatment#ALLOW}, and for
+   * {@link EmptyMatchTreatment#ALLOW_IF_WILDCARD} only if {@code spec} contains a glob wildcard.
+   * Any other result is returned unchanged.
+   */
+  private static MatchResult maybeAdjustEmptyMatchResult(
+      String spec, MatchResult res, EmptyMatchTreatment emptyMatchTreatment)
+      throws IOException {
+    boolean isNotFound = res.status() == Status.NOT_FOUND;
+    if (isNotFound
+        && (emptyMatchTreatment == EmptyMatchTreatment.ALLOW
+            || (hasGlobWildcard(spec)
+                && emptyMatchTreatment == EmptyMatchTreatment.ALLOW_IF_WILDCARD))) {
+      return MatchResult.create(Status.OK, Collections.<Metadata>emptyList());
+    }
+    return res;
+  }
+
/**
* Returns the {@link Metadata} for a single file resource. Expects a resource specification
* {@code spec} that matches a single result.
@@ -176,7 +225,7 @@
.transform(new Function<ResourceId, String>() {
@Override
public String apply(@Nonnull ResourceId resourceId) {
- return resourceId.toString();
+ return resourceId.toString();
}})
.toList());
}
@@ -423,7 +472,7 @@
Matcher matcher = FILE_SCHEME_PATTERN.matcher(spec);
if (!matcher.matches()) {
- return "file";
+ return DEFAULT_SCHEME;
} else {
return matcher.group("scheme").toLowerCase();
}
@@ -440,11 +489,7 @@
if (rval != null) {
return rval;
}
- rval = schemeToFileSystem.get(DEFAULT_SCHEME);
- if (rval != null) {
- return rval;
- }
- throw new IllegalStateException("Unable to find registrar for " + scheme);
+ return schemeToFileSystem.get(DEFAULT_SCHEME);
}
/********************************** METHODS FOR REGISTRATION **********************************/
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/LocalFileSystem.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/LocalFileSystem.java
index b732bee..5fe894d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/LocalFileSystem.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/LocalFileSystem.java
@@ -38,6 +38,7 @@
import java.nio.file.PathMatcher;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -46,11 +47,32 @@
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.MatchResult.Status;
+import org.apache.commons.lang3.SystemUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* {@link FileSystem} implementation for local files.
+ *
+ * <p>{@link #match} interprets {@code spec} and resolves paths according to the operating system
+ * in use. To do so, specs should be given in one of the formats below:
+ *
+ * <p>Linux/Mac:
+ * <ul>
+ * <li>pom.xml</li>
+ * <li>/Users/beam/Documents/pom.xml</li>
+ * <li>file:/Users/beam/Documents/pom.xml</li>
+ * <li>file:///Users/beam/Documents/pom.xml</li>
+ * </ul>
+ *
+ * <p>Windows OS:
+ * <ul>
+ * <li>pom.xml</li>
+ * <li>C:/Users/beam/Documents/pom.xml</li>
+ * <li>C:\Users\beam\Documents\pom.xml</li>
+ * <li>file:/C:/Users/beam/Documents/pom.xml</li>
+ * <li>file:///C:/Users/beam/Documents/pom.xml</li>
+ * </ul>
*/
class LocalFileSystem extends FileSystem<LocalResourceId> {
@@ -176,8 +198,20 @@
}
private MatchResult matchOne(String spec) throws IOException {
- File file = Paths.get(spec).toFile();
+ if (spec.toLowerCase().startsWith("file:")) {
+ spec = spec.substring("file:".length());
+ }
+ if (SystemUtils.IS_OS_WINDOWS) {
+ List<String> prefixes = Arrays.asList("///", "/");
+ for (String prefix : prefixes) {
+ if (spec.toLowerCase().startsWith(prefix)) {
+ spec = spec.substring(prefix.length());
+ }
+ }
+ }
+
+ File file = Paths.get(spec).toFile();
if (file.exists()) {
return MatchResult.create(Status.OK, ImmutableList.of(toMetadata(file)));
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Match.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Match.java
new file mode 100644
index 0000000..bb44fac
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Match.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import com.google.auto.value.AutoValue;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
+import org.apache.beam.sdk.io.fs.MatchResult;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.Watch;
+import org.apache.beam.sdk.transforms.Watch.Growth.PollResult;
+import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
+import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Expands each filepattern in an input {@link PCollection} with {@link FileSystems#match} and
+ * emits every matched resource, files and directories alike, as {@link Metadata}. Matches are
+ * not deduplicated across filepatterns: a resource matched by several filepatterns is emitted
+ * once for each of them.
+ *
+ * <p>By default every filepattern is matched a single time and the output {@link PCollection} is
+ * bounded. {@link Filepatterns#continuously(Duration, TerminationCondition)} instead keeps polling
+ * each filepattern for new matches and yields an unbounded {@link PCollection}.
+ *
+ * <p>Filepatterns that match no resources are handled according to {@link
+ * EmptyMatchTreatment#ALLOW_IF_WILDCARD} unless configured otherwise via {@link
+ * Filepatterns#withEmptyMatchTreatment}.
+ */
+public class Match {
+  private static final Logger LOG = LoggerFactory.getLogger(Match.class);
+
+  /** Creates a {@link Filepatterns} transform with default settings. See {@link Match}. */
+  public static Filepatterns filepatterns() {
+    Filepatterns.Builder builder = new AutoValue_Match_Filepatterns.Builder();
+    builder.setEmptyMatchTreatment(EmptyMatchTreatment.ALLOW_IF_WILDCARD);
+    return builder.build();
+  }
+
+  /** Implementation of {@link #filepatterns}. */
+  @AutoValue
+  public abstract static class Filepatterns
+      extends PTransform<PCollection<String>, PCollection<Metadata>> {
+    abstract EmptyMatchTreatment getEmptyMatchTreatment();
+
+    @Nullable
+    abstract Duration getWatchInterval();
+
+    @Nullable
+    abstract TerminationCondition<String, ?> getWatchTerminationCondition();
+
+    abstract Builder toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder setEmptyMatchTreatment(EmptyMatchTreatment treatment);
+
+      abstract Builder setWatchInterval(Duration watchInterval);
+
+      abstract Builder setWatchTerminationCondition(TerminationCondition<String, ?> condition);
+
+      abstract Filepatterns build();
+    }
+
+    /**
+     * Returns a copy of this transform that handles filepatterns matching no resources according
+     * to {@code treatment}. The setting has no effect in combination with {@link #continuously},
+     * where empty matches are always permitted.
+     */
+    public Filepatterns withEmptyMatchTreatment(EmptyMatchTreatment treatment) {
+      Builder builder = toBuilder();
+      builder.setEmptyMatchTreatment(treatment);
+      return builder.build();
+    }
+
+    /**
+     * Returns a copy of this transform that re-matches each filepattern every {@code pollInterval}
+     * to discover new resources, until {@code terminationCondition} is reached. The output {@link
+     * PCollection} is unbounded.
+     *
+     * <p>Requires a runner that supports {@link Experimental.Kind#SPLITTABLE_DO_FN}.
+     *
+     * @see TerminationCondition
+     */
+    @Experimental(Experimental.Kind.SPLITTABLE_DO_FN)
+    public Filepatterns continuously(
+        Duration pollInterval, TerminationCondition<String, ?> terminationCondition) {
+      Builder builder = toBuilder();
+      builder.setWatchInterval(pollInterval);
+      builder.setWatchTerminationCondition(terminationCondition);
+      return builder.build();
+    }
+
+    @Override
+    public PCollection<Metadata> expand(PCollection<String> input) {
+      Duration pollInterval = getWatchInterval();
+      if (pollInterval == null) {
+        // One-shot mode: each filepattern is matched once.
+        MatchFn matchFn = new MatchFn(getEmptyMatchTreatment());
+        return input.apply("Match filepatterns", ParDo.of(matchFn));
+      }
+      // Watch mode: Watch's output is keyed, so Values keeps only the matched Metadata.
+      return input
+          .apply(
+              "Continuously match filepatterns",
+              Watch.growthOf(new MatchPollFn())
+                  .withPollInterval(pollInterval)
+                  .withTerminationPerInput(getWatchTerminationCondition()))
+          .apply(Values.<Metadata>create());
+    }
+
+    /** Matches each input filepattern once and outputs every resulting {@link Metadata}. */
+    private static class MatchFn extends DoFn<String, Metadata> {
+      private final EmptyMatchTreatment emptyMatchTreatment;
+
+      public MatchFn(EmptyMatchTreatment emptyMatchTreatment) {
+        this.emptyMatchTreatment = emptyMatchTreatment;
+      }
+
+      @ProcessElement
+      public void process(ProcessContext context) throws Exception {
+        String pattern = context.element();
+        MatchResult result = FileSystems.match(pattern, emptyMatchTreatment);
+        LOG.info("Matched {} files for pattern {}", result.metadata().size(), pattern);
+        for (Metadata resource : result.metadata()) {
+          context.output(resource);
+        }
+      }
+    }
+
+    /**
+     * Polls a filepattern for {@link Watch}, allowing empty matches and reporting every result as
+     * {@link PollResult#incomplete}, stamped with the wall-clock time taken before matching.
+     */
+    private static class MatchPollFn implements Watch.Growth.PollFn<String, Metadata> {
+      @Override
+      public PollResult<Metadata> apply(String input, Instant timestamp) throws Exception {
+        Instant pollTime = Instant.now();
+        MatchResult current = FileSystems.match(input, EmptyMatchTreatment.ALLOW);
+        return PollResult.incomplete(pollTime, current.metadata());
+      }
+    }
+  }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
index a07fca8..9b273f8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -18,7 +18,6 @@
package org.apache.beam.sdk.io;
import javax.annotation.Nullable;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.NameUtils;
@@ -95,17 +94,14 @@
}
@Override
- protected Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
- }
-
- @Override
public final PCollection<T> expand(PBegin input) {
source.validate();
- return PCollection.<T>createPrimitiveOutputInternal(input.getPipeline(),
- WindowingStrategy.globalDefault(), IsBounded.BOUNDED)
- .setCoder(getDefaultOutputCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ source.getOutputCoder());
}
/**
@@ -163,16 +159,13 @@
}
@Override
- protected Coder<T> getDefaultOutputCoder() {
- return source.getDefaultOutputCoder();
- }
-
- @Override
public final PCollection<T> expand(PBegin input) {
source.validate();
-
- return PCollection.<T>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED);
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ IsBounded.UNBOUNDED,
+ source.getOutputCoder());
}
/**
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java
new file mode 100644
index 0000000..990f508
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.io.IOException;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * Reads each file in the input {@link PCollection} of {@link Metadata} using given parameters for
+ * splitting files into offset ranges and for creating a {@link FileBasedSource} for a file. The
+ * input {@link PCollection} must not contain {@link ResourceId#isDirectory directories}.
+ *
+ * <p>Both {@code isSplittable} and {@code createSource} receive the file's {@link ResourceId} in
+ * its string form.
+ *
+ * <p>To obtain the collection of {@link Metadata} from a filepattern, use {@link
+ * Match#filepatterns()}.
+ */
+class ReadAllViaFileBasedSource<T> extends PTransform<PCollection<Metadata>, PCollection<T>> {
+  private final SerializableFunction<String, Boolean> isSplittable;
+  private final long desiredBundleSizeBytes;
+  private final SerializableFunction<String, FileBasedSource<T>> createSource;
+
+  /**
+   * @param isSplittable decides, from a file's resource id string, whether the file may be split
+   *     into several offset ranges; files that are not read-seek efficient are never split
+   * @param desiredBundleSizeBytes target size in bytes of each offset range; must be positive
+   * @param createSource creates, from a file's resource id string, the {@link FileBasedSource}
+   *     used to read ranges of that file
+   * @throws IllegalArgumentException if {@code desiredBundleSizeBytes} is not positive
+   */
+  public ReadAllViaFileBasedSource(
+      SerializableFunction<String, Boolean> isSplittable,
+      long desiredBundleSizeBytes,
+      SerializableFunction<String, FileBasedSource<T>> createSource) {
+    // Splitting a file into offset ranges requires a positive range size; fail fast otherwise.
+    checkArgument(
+        desiredBundleSizeBytes > 0,
+        "desiredBundleSizeBytes must be positive, but was %s",
+        desiredBundleSizeBytes);
+    this.isSplittable = isSplittable;
+    this.desiredBundleSizeBytes = desiredBundleSizeBytes;
+    this.createSource = createSource;
+  }
+
+  @Override
+  public PCollection<T> expand(PCollection<Metadata> input) {
+    return input
+        .apply(
+            "Split into ranges",
+            ParDo.of(new SplitIntoRangesFn(isSplittable, desiredBundleSizeBytes)))
+        .apply("Reshuffle", new ReshuffleWithUniqueKey<KV<Metadata, OffsetRange>>())
+        .apply("Read ranges", ParDo.of(new ReadFileRangesFn<T>(createSource)));
+  }
+
+  /**
+   * Redistributes elements by keying each one via {@link AssignUniqueKeyFn}, applying {@link
+   * Reshuffle}, and dropping the keys again. NOTE(review): presumably this lets the ranges of a
+   * single file be read in parallel instead of being fused with the splitting step.
+   */
+  private static class ReshuffleWithUniqueKey<T>
+      extends PTransform<PCollection<T>, PCollection<T>> {
+    @Override
+    public PCollection<T> expand(PCollection<T> input) {
+      return input
+          .apply("Unique key", ParDo.of(new AssignUniqueKeyFn<T>()))
+          .apply("Reshuffle", Reshuffle.<Integer, T>of())
+          .apply("Values", Values.<T>create());
+    }
+  }
+
+  /**
+   * Keys each element with a counter that starts at a random offset in each DoFn instance. Keys
+   * are distinct within an instance (until the int counter wraps) but may collide across
+   * instances, which is harmless because they are dropped right after the {@link Reshuffle}.
+   */
+  private static class AssignUniqueKeyFn<T> extends DoFn<T, KV<Integer, T>> {
+    private int index;
+
+    @Setup
+    public void setup() {
+      this.index = ThreadLocalRandom.current().nextInt();
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      c.output(KV.of(++index, c.element()));
+    }
+  }
+
+  /**
+   * Splits each file into offset ranges of roughly {@code desiredBundleSizeBytes} bytes. A file
+   * that is not read-seek efficient, or that {@code isSplittable} rejects, is emitted as a single
+   * range covering the whole file. Directories are rejected with an IllegalArgumentException.
+   */
+  private static class SplitIntoRangesFn extends DoFn<Metadata, KV<Metadata, OffsetRange>> {
+    private final SerializableFunction<String, Boolean> isSplittable;
+    private final long desiredBundleSizeBytes;
+
+    private SplitIntoRangesFn(
+        SerializableFunction<String, Boolean> isSplittable, long desiredBundleSizeBytes) {
+      this.isSplittable = isSplittable;
+      this.desiredBundleSizeBytes = desiredBundleSizeBytes;
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      Metadata metadata = c.element();
+      checkArgument(
+          !metadata.resourceId().isDirectory(),
+          "Resource %s is a directory",
+          metadata.resourceId());
+      if (!metadata.isReadSeekEfficient()
+          || !isSplittable.apply(metadata.resourceId().toString())) {
+        c.output(KV.of(metadata, new OffsetRange(0, metadata.sizeBytes())));
+        return;
+      }
+      for (OffsetRange range :
+          new OffsetRange(0, metadata.sizeBytes()).split(desiredBundleSizeBytes, 0)) {
+        c.output(KV.of(metadata, range));
+      }
+    }
+  }
+
+  /** Reads the records of one offset range of one file and outputs them. */
+  private static class ReadFileRangesFn<T> extends DoFn<KV<Metadata, OffsetRange>, T> {
+    private final SerializableFunction<String, FileBasedSource<T>> createSource;
+
+    private ReadFileRangesFn(SerializableFunction<String, FileBasedSource<T>> createSource) {
+      this.createSource = createSource;
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) throws IOException {
+      Metadata metadata = c.element().getKey();
+      OffsetRange range = c.element().getValue();
+      // Hand createSource the file's resource id, the same string isSplittable receives in
+      // SplitIntoRangesFn; Metadata#toString() is not guaranteed to be the file name.
+      FileBasedSource<T> source = createSource.apply(metadata.resourceId().toString());
+      try (BoundedSource.BoundedReader<T> reader =
+          source
+              .createForSubrangeOfFile(metadata, range.getFrom(), range.getTo())
+              .createReader(c.getPipelineOptions())) {
+        for (boolean more = reader.start(); more; more = reader.advance()) {
+          c.output(reader.getCurrent());
+        }
+      }
+    }
+  }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Source.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Source.java
index 542d91c..872c135 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Source.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Source.java
@@ -61,10 +61,28 @@
*/
public abstract void validate();
- /**
- * Returns the default {@code Coder} to use for the data read from this source.
- */
- public abstract Coder<T> getDefaultOutputCoder();
+ /** @deprecated Override {@link #getOutputCoder()} instead. */
+ @Deprecated
+ public Coder<T> getDefaultOutputCoder() {
+ // If the subclass doesn't override getDefaultOutputCoder(), hopefully it overrides the proper
+ // version - getOutputCoder(). Check that it does, before calling the method (if subclass
+ // doesn't override it, we'll call the default implementation and get infinite recursion).
+ try {
+ if (getClass().getMethod("getOutputCoder").getDeclaringClass().equals(Source.class)) {
+ throw new UnsupportedOperationException(
+ getClass() + " needs to override getOutputCoder().");
+ }
+ } catch (NoSuchMethodException e) {
+ throw new RuntimeException(e);
+ }
+ return getOutputCoder();
+ }
+
+ /** Returns the {@code Coder} to use for the data read from this source. */
+ public Coder<T> getOutputCoder() {
+ // Call the old method for compatibility.
+ return getDefaultOutputCoder();
+ }
/**
* {@inheritDoc}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
index 6e7b243..c75051f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -35,8 +35,6 @@
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.io.Read.Bounded;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.ResourceId;
@@ -45,7 +43,6 @@
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.values.PBegin;
@@ -172,11 +169,7 @@
}
}
- final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
- PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
- // Honor the default output coder that would have been used by this PTransform.
- pcol.setCoder(getDefaultOutputCoder());
- return pcol;
+ return input.apply("Read", org.apache.beam.sdk.io.Read.from(getSource()));
}
// Helper to create a source specific to the requested compression type.
@@ -213,11 +206,6 @@
.addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
.withLabel("File Pattern"));
}
-
- @Override
- protected Coder<byte[]> getDefaultOutputCoder() {
- return ByteArrayCoder.of();
- }
}
/////////////////////////////////////////////////////////////////////////////
@@ -357,10 +345,12 @@
checkState(getOutputPrefix() != null,
"need to set the output prefix of a TFRecordIO.Write transform");
WriteFiles<byte[], Void, byte[]> write =
- WriteFiles.<byte[], Void, byte[]>to(
+ WriteFiles.to(
new TFRecordSink(
- getOutputPrefix(), getShardTemplate(), getFilenameSuffix(), getCompressionType()),
- SerializableFunctions.<byte[]>identity());
+ getOutputPrefix(),
+ getShardTemplate(),
+ getFilenameSuffix(),
+ getCompressionType()));
if (getNumShards() > 0) {
write = write.withNumShards(getNumShards());
}
@@ -390,11 +380,6 @@
.add(DisplayData.item("compressionType", getCompressionType().toString())
.withLabel("Compression Type"));
}
-
- @Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
}
/**
@@ -473,7 +458,7 @@
}
@Override
- public Coder<byte[]> getDefaultOutputCoder() {
+ public Coder<byte[]> getOutputCoder() {
return DEFAULT_BYTE_ARRAY_CODER;
}
@@ -548,7 +533,7 @@
/** A {@link FileBasedSink} for TFRecord files. Produces TFRecord files. */
@VisibleForTesting
- static class TFRecordSink extends FileBasedSink<byte[], Void> {
+ static class TFRecordSink extends FileBasedSink<byte[], Void, byte[]> {
@VisibleForTesting
TFRecordSink(
ValueProvider<ResourceId> outputPrefix,
@@ -557,7 +542,7 @@
TFRecordIO.CompressionType compressionType) {
super(
outputPrefix,
- DynamicFileDestinations.constant(
+ DynamicFileDestinations.<byte[]>constant(
DefaultFilenamePolicy.fromStandardParameters(
outputPrefix, shardTemplate, suffix, false)),
writableByteChannelFactory(compressionType));
@@ -571,7 +556,7 @@
}
@Override
- public WriteOperation<byte[], Void> createWriteOperation() {
+ public WriteOperation<Void, byte[]> createWriteOperation() {
return new TFRecordWriteOperation(this);
}
@@ -591,23 +576,23 @@
}
/** A {@link WriteOperation WriteOperation} for TFRecord files. */
- private static class TFRecordWriteOperation extends WriteOperation<byte[], Void> {
+ private static class TFRecordWriteOperation extends WriteOperation<Void, byte[]> {
private TFRecordWriteOperation(TFRecordSink sink) {
super(sink);
}
@Override
- public Writer<byte[], Void> createWriter() throws Exception {
+ public Writer<Void, byte[]> createWriter() throws Exception {
return new TFRecordWriter(this);
}
}
/** A {@link Writer Writer} for TFRecord files. */
- private static class TFRecordWriter extends Writer<byte[], Void> {
+ private static class TFRecordWriter extends Writer<Void, byte[]> {
private WritableByteChannel outChannel;
private TFRecordCodec codec;
- private TFRecordWriter(WriteOperation<byte[], Void> writeOperation) {
+ private TFRecordWriter(WriteOperation<Void, byte[]> writeOperation) {
super(writeOperation, MimeTypes.BINARY);
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
index 7b4c483..612f5c5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
@@ -23,52 +23,53 @@
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
-import java.io.IOException;
-import java.util.concurrent.ThreadLocalRandom;
+import com.google.common.base.Predicates;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import java.util.List;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.CompressedSource.CompressionMode;
import org.apache.beam.sdk.io.DefaultFilenamePolicy.Params;
import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations;
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
import org.apache.beam.sdk.io.FileBasedSink.WritableByteChannelFactory;
-import org.apache.beam.sdk.io.Read.Bounded;
-import org.apache.beam.sdk.io.fs.MatchResult;
-import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
-import org.apache.beam.sdk.io.fs.MatchResult.Status;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.ResourceId;
-import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
-import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
-import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
+import org.joda.time.Duration;
/**
* {@link PTransform}s for reading and writing text files.
*
* <p>To read a {@link PCollection} from one or more text files, use {@code TextIO.read()} to
* instantiate a transform and use {@link TextIO.Read#from(String)} to specify the path of the
- * file(s) to be read. Alternatively, if the filenames to be read are themselves in a
- * {@link PCollection}, apply {@link TextIO#readAll()}.
+ * file(s) to be read. Alternatively, if the filenames to be read are themselves in a {@link
+ * PCollection}, apply {@link TextIO#readAll()}.
*
- * <p>{@link TextIO.Read} returns a {@link PCollection} of {@link String Strings}, each
- * corresponding to one line of an input UTF-8 text file (split into lines delimited by '\n', '\r',
- * or '\r\n').
+ * <p>{@link #read} returns a {@link PCollection} of {@link String Strings}, each corresponding to
+ * one line of an input UTF-8 text file (split into lines delimited by '\n', '\r', or '\r\n').
+ *
+ * <p>By default, the filepatterns are expanded only once. {@link Read#watchForNewFiles} and {@link
+ * ReadAll#watchForNewFiles} allow streaming of new files matching the filepattern(s).
+ *
+ * <p>By default, {@link #read} prohibits filepatterns that match no files, and {@link #readAll}
+ * allows them in case the filepattern contains a glob wildcard character. Use {@link
+ * TextIO.Read#withEmptyMatchTreatment} and {@link TextIO.ReadAll#withEmptyMatchTreatment} to
+ * configure this behavior.
*
* <p>Example 1: reading a file or filepattern.
*
@@ -79,6 +80,11 @@
* PCollection<String> lines = p.apply(TextIO.read().from("/local/path/to/file.txt"));
* }</pre>
*
+ * <p>If it is known that the filepattern will match a very large number of files (e.g. tens of
+ * thousands or more), use {@link Read#withHintMatchesManyFiles} for better performance and
+ * scalability. Note that it may decrease performance if the filepattern matches only a small number
+ * of files.
+ *
* <p>Example 2: reading a PCollection of filenames.
*
* <pre>{@code
@@ -92,6 +98,20 @@
* PCollection<String> lines = filenames.apply(TextIO.readAll());
* }</pre>
*
+ * <p>Example 3: streaming new files matching a filepattern.
+ *
+ * <pre>{@code
+ * Pipeline p = ...;
+ *
+ * PCollection<String> lines = p.apply(TextIO.read()
+ * .from("/local/path/to/files/*")
+ * .watchForNewFiles(
+ * // Check for new files every minute
+ * Duration.standardMinutes(1),
+ * // Stop watching the filepattern if no new files appear within an hour
+ * afterTimeSinceNewOutput(Duration.standardHours(1))));
+ * }</pre>
+ *
* <p>To write a {@link PCollection} to one or more text files, use {@code TextIO.write()}, using
* {@link TextIO.Write#to(String)} to specify the output prefix of the files to write.
*
@@ -127,9 +147,9 @@
* allows you to convert any input value into a custom destination object, and map that destination
* object to a {@link FilenamePolicy}. This allows using different filename policies (or more
* commonly, differently-configured instances of the same policy) based on the input record. Often
- * this is used in conjunction with {@link TextIO#writeCustomType(SerializableFunction)}, which
- * allows your {@link DynamicDestinations} object to examine the input type and takes a format
- * function to convert that type to a string for writing.
+ * this is used in conjunction with {@link TextIO#writeCustomType}, which allows your {@link
+ * DynamicDestinations} object to examine the input type and takes a format function to convert that
+ * type to a string for writing.
*
* <p>A convenience shortcut is provided for the case where the default naming policy is used, but
* different configurations of this policy are wanted based on the input record. Default naming
@@ -154,7 +174,11 @@
* {@link PCollection} containing one element for each line of the input files.
*/
public static Read read() {
- return new AutoValue_TextIO_Read.Builder().setCompressionType(CompressionType.AUTO).build();
+ return new AutoValue_TextIO_Read.Builder()
+ .setCompressionType(CompressionType.AUTO)
+ .setHintMatchesManyFiles(false)
+ .setEmptyMatchTreatment(EmptyMatchTreatment.DISALLOW)
+ .build();
}
/**
@@ -174,6 +198,7 @@
// but is not so large as to exhaust a typical runner's maximum amount of output per
// ProcessElement call.
.setDesiredBundleSizeBytes(64 * 1024 * 1024L)
+ .setEmptyMatchTreatment(EmptyMatchTreatment.ALLOW_IF_WILDCARD)
.build();
}
@@ -192,20 +217,23 @@
* line.
*
* <p>This version allows you to apply {@link TextIO} writes to a PCollection of a custom type
- * {@link T}, along with a format function that converts the input type {@link T} to the String
- * that will be written to the file. The advantage of this is it allows a user-provided {@link
+ * {@code UserT}. A format mechanism that converts the input type {@code UserT} to the String that
+ * will be written to the file must be specified. If using a custom {@link DynamicDestinations}
+ * object this is done using {@link DynamicDestinations#formatRecord}, otherwise the {@link
+ * TypedWrite#withFormatFunction} can be used to specify a format function.
+ *
+ * <p>The advantage of using a custom type is that it allows a user-provided {@link
* DynamicDestinations} object, set via {@link Write#to(DynamicDestinations)} to examine the
- * user's custom type when choosing a destination.
+ * custom type when choosing a destination.
*/
- public static <T> TypedWrite<T> writeCustomType(SerializableFunction<T, String> formatFunction) {
- return new AutoValue_TextIO_TypedWrite.Builder<T>()
+ public static <UserT> TypedWrite<UserT> writeCustomType() {
+ return new AutoValue_TextIO_TypedWrite.Builder<UserT>()
.setFilenamePrefix(null)
.setTempDirectory(null)
.setShardTemplate(null)
.setFilenameSuffix(null)
.setFilenamePolicy(null)
.setDynamicDestinations(null)
- .setFormatFunction(formatFunction)
.setWritableByteChannelFactory(FileBasedSink.CompressionType.UNCOMPRESSED)
.setWindowedWrites(false)
.setNumShards(0)
@@ -218,12 +246,25 @@
@Nullable abstract ValueProvider<String> getFilepattern();
abstract CompressionType getCompressionType();
+ @Nullable
+ abstract Duration getWatchForNewFilesInterval();
+
+ @Nullable
+ abstract TerminationCondition getWatchForNewFilesTerminationCondition();
+
+ abstract boolean getHintMatchesManyFiles();
+ abstract EmptyMatchTreatment getEmptyMatchTreatment();
+
abstract Builder toBuilder();
@AutoValue.Builder
abstract static class Builder {
abstract Builder setFilepattern(ValueProvider<String> filepattern);
abstract Builder setCompressionType(CompressionType compressionType);
+ abstract Builder setWatchForNewFilesInterval(Duration watchForNewFilesInterval);
+ abstract Builder setWatchForNewFilesTerminationCondition(TerminationCondition condition);
+ abstract Builder setHintMatchesManyFiles(boolean hintManyFiles);
+ abstract Builder setEmptyMatchTreatment(EmptyMatchTreatment treatment);
abstract Read build();
}
@@ -237,6 +278,9 @@
*
* <p>Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html" >Java
* Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
+ *
+ * <p>If it is known that the filepattern will match a very large number of files (at least tens
+ * of thousands), use {@link #withHintMatchesManyFiles} for better performance and scalability.
*/
public Read from(String filepattern) {
checkNotNull(filepattern, "Filepattern cannot be empty.");
@@ -250,8 +294,7 @@
}
/**
- * Returns a new transform for reading from text files that's like this one but
- * reads from input sources using the specified compression type.
+ * Reads from input sources using the specified compression type.
*
* <p>If no compression type is specified, the default is {@link TextIO.CompressionType#AUTO}.
*/
@@ -259,22 +302,70 @@
return toBuilder().setCompressionType(compressionType).build();
}
+ /**
+ * Continuously watches for new files matching the filepattern, polling it at the given
+ * interval, until the given termination condition is reached. The returned {@link PCollection}
+ * is unbounded.
+ *
+ * <p>This works only in runners supporting {@link Kind#SPLITTABLE_DO_FN}.
+ *
+ * @see TerminationCondition
+ */
+ @Experimental(Kind.SPLITTABLE_DO_FN)
+ public Read watchForNewFiles(Duration pollInterval, TerminationCondition terminationCondition) {
+ return toBuilder()
+ .setWatchForNewFilesInterval(pollInterval)
+ .setWatchForNewFilesTerminationCondition(terminationCondition)
+ .build();
+ }
+
+ /**
+ * Hints that the filepattern specified in {@link #from(String)} matches a very large number of
+ * files.
+ *
+ * <p>This hint may cause a runner to execute the transform differently, in a way that improves
+ * performance for this case, but it may worsen performance if the filepattern matches only
+ * a small number of files (e.g., in a runner that supports dynamic work rebalancing, it will
+ * happen less efficiently within individual files).
+ */
+ public Read withHintMatchesManyFiles() {
+ return toBuilder().setHintMatchesManyFiles(true).build();
+ }
+
+ /**
+ * Configures whether or not a filepattern matching no files is allowed. When using {@link
+ * #watchForNewFiles}, it is always allowed and this parameter is ignored.
+ */
+ public Read withEmptyMatchTreatment(EmptyMatchTreatment treatment) {
+ return toBuilder().setEmptyMatchTreatment(treatment).build();
+ }
+
@Override
public PCollection<String> expand(PBegin input) {
- if (getFilepattern() == null) {
- throw new IllegalStateException("need to set the filepattern of a TextIO.Read transform");
+ checkNotNull(getFilepattern(), "need to set the filepattern of a TextIO.Read transform");
+ if (getWatchForNewFilesInterval() == null && !getHintMatchesManyFiles()) {
+ return input.apply("Read", org.apache.beam.sdk.io.Read.from(getSource()));
}
-
- final Bounded<String> read = org.apache.beam.sdk.io.Read.from(getSource());
- PCollection<String> pcol = input.getPipeline().apply("Read", read);
- // Honor the default output coder that would have been used by this PTransform.
- pcol.setCoder(getDefaultOutputCoder());
- return pcol;
+ // All other cases go through ReadAll.
+ ReadAll readAll =
+ readAll()
+ .withCompressionType(getCompressionType())
+ .withEmptyMatchTreatment(getEmptyMatchTreatment());
+ if (getWatchForNewFilesInterval() != null) {
+ readAll =
+ readAll.watchForNewFiles(
+ getWatchForNewFilesInterval(), getWatchForNewFilesTerminationCondition());
+ }
+ return input
+ .apply("Create filepattern", Create.ofProvider(getFilepattern(), StringUtf8Coder.of()))
+ .apply("Via ReadAll", readAll);
}
// Helper to create a source specific to the requested compression type.
protected FileBasedSource<String> getSource() {
- return wrapWithCompression(new TextSource(getFilepattern()), getCompressionType());
+ return wrapWithCompression(
+ new TextSource(getFilepattern(), getEmptyMatchTreatment()),
+ getCompressionType());
}
private static FileBasedSource<String> wrapWithCompression(
@@ -312,15 +403,17 @@
String filepatternDisplay = getFilepattern().isAccessible()
? getFilepattern().get() : getFilepattern().toString();
builder
- .add(DisplayData.item("compressionType", getCompressionType().toString())
- .withLabel("Compression Type"))
- .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
- .withLabel("File Pattern"));
- }
-
- @Override
- protected Coder<String> getDefaultOutputCoder() {
- return StringUtf8Coder.of();
+ .add(
+ DisplayData.item("compressionType", getCompressionType().toString())
+ .withLabel("Compression Type"))
+ .addIfNotNull(
+ DisplayData.item("filePattern", filepatternDisplay).withLabel("File Pattern"))
+ .add(
+ DisplayData.item("emptyMatchTreatment", getEmptyMatchTreatment().toString())
+ .withLabel("Treatment of filepatterns that match no files"))
+ .addIfNotNull(
+ DisplayData.item("watchForNewFilesInterval", getWatchForNewFilesInterval())
+ .withLabel("Interval to watch for new files"));
}
}
@@ -331,6 +424,14 @@
public abstract static class ReadAll
extends PTransform<PCollection<String>, PCollection<String>> {
abstract CompressionType getCompressionType();
+
+ @Nullable
+ abstract Duration getWatchForNewFilesInterval();
+
+ @Nullable
+ abstract TerminationCondition<String, ?> getWatchForNewFilesTerminationCondition();
+
+ abstract EmptyMatchTreatment getEmptyMatchTreatment();
abstract long getDesiredBundleSizeBytes();
abstract Builder toBuilder();
@@ -338,6 +439,10 @@
@AutoValue.Builder
abstract static class Builder {
abstract Builder setCompressionType(CompressionType compressionType);
+ abstract Builder setWatchForNewFilesInterval(Duration watchForNewFilesInterval);
+ abstract Builder setWatchForNewFilesTerminationCondition(
+ TerminationCondition<String, ?> condition);
+ abstract Builder setEmptyMatchTreatment(EmptyMatchTreatment treatment);
abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes);
abstract ReadAll build();
@@ -348,6 +453,21 @@
return toBuilder().setCompressionType(compressionType).build();
}
+ /** Same as {@link Read#withEmptyMatchTreatment}. */
+ public ReadAll withEmptyMatchTreatment(EmptyMatchTreatment treatment) {
+ return toBuilder().setEmptyMatchTreatment(treatment).build();
+ }
+
+ /** Same as {@link Read#watchForNewFiles(Duration, TerminationCondition)}. */
+ @Experimental(Kind.SPLITTABLE_DO_FN)
+ public ReadAll watchForNewFiles(
+ Duration pollInterval, TerminationCondition<String, ?> terminationCondition) {
+ return toBuilder()
+ .setWatchForNewFilesInterval(pollInterval)
+ .setWatchForNewFilesTerminationCondition(terminationCondition)
+ .build();
+ }
+
@VisibleForTesting
ReadAll withDesiredBundleSizeBytes(long desiredBundleSizeBytes) {
return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build();
@@ -355,130 +475,71 @@
@Override
public PCollection<String> expand(PCollection<String> input) {
+ Match.Filepatterns matchFilepatterns =
+ Match.filepatterns().withEmptyMatchTreatment(getEmptyMatchTreatment());
+ if (getWatchForNewFilesInterval() != null) {
+ matchFilepatterns =
+ matchFilepatterns.continuously(
+ getWatchForNewFilesInterval(), getWatchForNewFilesTerminationCondition());
+ }
return input
- .apply("Expand glob", ParDo.of(new ExpandGlobFn()))
+ .apply(matchFilepatterns)
.apply(
- "Split into ranges",
- ParDo.of(new SplitIntoRangesFn(getCompressionType(), getDesiredBundleSizeBytes())))
- .apply("Reshuffle", new ReshuffleWithUniqueKey<KV<Metadata, OffsetRange>>())
- .apply("Read", ParDo.of(new ReadTextFn(this)));
+ "Read all via FileBasedSource",
+ new ReadAllViaFileBasedSource<>(
+ new IsSplittableFn(getCompressionType()),
+ getDesiredBundleSizeBytes(),
+ new CreateTextSourceFn(getCompressionType(), getEmptyMatchTreatment())))
+ .setCoder(StringUtf8Coder.of());
}
- private static class ReshuffleWithUniqueKey<T>
- extends PTransform<PCollection<T>, PCollection<T>> {
- @Override
- public PCollection<T> expand(PCollection<T> input) {
- return input
- .apply("Unique key", ParDo.of(new AssignUniqueKeyFn<T>()))
- .apply("Reshuffle", Reshuffle.<Integer, T>of())
- .apply("Values", Values.<T>create());
- }
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+
+ builder.add(
+ DisplayData.item("compressionType", getCompressionType().toString())
+ .withLabel("Compression Type"));
}
- private static class AssignUniqueKeyFn<T> extends DoFn<T, KV<Integer, T>> {
- private int index;
-
- @Setup
- public void setup() {
- this.index = ThreadLocalRandom.current().nextInt();
- }
-
- @ProcessElement
- public void process(ProcessContext c) {
- c.output(KV.of(++index, c.element()));
- }
- }
-
- private static class ExpandGlobFn extends DoFn<String, Metadata> {
- @ProcessElement
- public void process(ProcessContext c) throws Exception {
- MatchResult match = FileSystems.match(c.element());
- checkArgument(
- match.status().equals(Status.OK),
- "Failed to match filepattern %s: %s",
- c.element(),
- match.status());
- for (Metadata metadata : match.metadata()) {
- c.output(metadata);
- }
- }
- }
-
- private static class SplitIntoRangesFn extends DoFn<Metadata, KV<Metadata, OffsetRange>> {
+ private static class CreateTextSourceFn
+ implements SerializableFunction<String, FileBasedSource<String>> {
private final CompressionType compressionType;
- private final long desiredBundleSize;
+ private final EmptyMatchTreatment emptyMatchTreatment;
- private SplitIntoRangesFn(CompressionType compressionType, long desiredBundleSize) {
+ private CreateTextSourceFn(
+ CompressionType compressionType, EmptyMatchTreatment emptyMatchTreatment) {
this.compressionType = compressionType;
- this.desiredBundleSize = desiredBundleSize;
+ this.emptyMatchTreatment = emptyMatchTreatment;
}
- @ProcessElement
- public void process(ProcessContext c) {
- Metadata metadata = c.element();
- final boolean isSplittable = isSplittable(metadata, compressionType);
- if (!isSplittable) {
- c.output(KV.of(metadata, new OffsetRange(0, metadata.sizeBytes())));
- return;
- }
- for (OffsetRange range :
- new OffsetRange(0, metadata.sizeBytes()).split(desiredBundleSize, 0)) {
- c.output(KV.of(metadata, range));
- }
- }
-
- static boolean isSplittable(Metadata metadata, CompressionType compressionType) {
- if (!metadata.isReadSeekEfficient()) {
- return false;
- }
- switch (compressionType) {
- case AUTO:
- return !CompressionMode.isCompressed(metadata.resourceId().toString());
- case UNCOMPRESSED:
- return true;
- case GZIP:
- case BZIP2:
- case ZIP:
- case DEFLATE:
- return false;
- default:
- throw new UnsupportedOperationException("Unknown compression type: " + compressionType);
- }
+ @Override
+ public FileBasedSource<String> apply(String input) {
+ return Read.wrapWithCompression(
+ new TextSource(StaticValueProvider.of(input), emptyMatchTreatment), compressionType);
}
}
- private static class ReadTextFn extends DoFn<KV<Metadata, OffsetRange>, String> {
- private final TextIO.ReadAll spec;
+ private static class IsSplittableFn implements SerializableFunction<String, Boolean> {
+ private final CompressionType compressionType;
- private ReadTextFn(ReadAll spec) {
- this.spec = spec;
+ private IsSplittableFn(CompressionType compressionType) {
+ this.compressionType = compressionType;
}
- @ProcessElement
- public void process(ProcessContext c) throws IOException {
- Metadata metadata = c.element().getKey();
- OffsetRange range = c.element().getValue();
- FileBasedSource<String> source =
- TextIO.Read.wrapWithCompression(
- new TextSource(StaticValueProvider.of(metadata.toString())),
- spec.getCompressionType());
- try (BoundedSource.BoundedReader<String> reader =
- source
- .createForSubrangeOfFile(metadata, range.getFrom(), range.getTo())
- .createReader(c.getPipelineOptions())) {
- for (boolean more = reader.start(); more; more = reader.advance()) {
- c.output(reader.getCurrent());
- }
- }
+ @Override
+ public Boolean apply(String filename) {
+ return compressionType == CompressionType.UNCOMPRESSED
+ || (compressionType == CompressionType.AUTO && !CompressionMode.isCompressed(filename));
}
}
}
- /////////////////////////////////////////////////////////////////////////////
+ // ///////////////////////////////////////////////////////////////////////////
/** Implementation of {@link #write}. */
@AutoValue
- public abstract static class TypedWrite<T> extends PTransform<PCollection<T>, PDone> {
+ public abstract static class TypedWrite<UserT> extends PTransform<PCollection<UserT>, PDone> {
/** The prefix of each file written, combined with suffix and shardTemplate. */
@Nullable abstract ValueProvider<ResourceId> getFilenamePrefix();
@@ -506,10 +567,19 @@
/** Allows for value-dependent {@link DynamicDestinations} to be vended. */
@Nullable
- abstract DynamicDestinations<T, ?> getDynamicDestinations();
+ abstract DynamicDestinations<UserT, ?, String> getDynamicDestinations();
- /** A function that converts T to a String, for writing to the file. */
- abstract SerializableFunction<T, String> getFormatFunction();
+ /** A destination function for using {@link DefaultFilenamePolicy}. */
+ @Nullable
+ abstract SerializableFunction<UserT, Params> getDestinationFunction();
+
+ /** A default destination for empty PCollections. */
+ @Nullable
+ abstract Params getEmptyDestination();
+
+ /** A function that converts UserT to a String, for writing to the file. */
+ @Nullable
+ abstract SerializableFunction<UserT, String> getFormatFunction();
/** Whether to write windowed output files. */
abstract boolean getWindowedWrites();
@@ -520,37 +590,42 @@
*/
abstract WritableByteChannelFactory getWritableByteChannelFactory();
- abstract Builder<T> toBuilder();
+ abstract Builder<UserT> toBuilder();
@AutoValue.Builder
- abstract static class Builder<T> {
- abstract Builder<T> setFilenamePrefix(ValueProvider<ResourceId> filenamePrefix);
+ abstract static class Builder<UserT> {
+ abstract Builder<UserT> setFilenamePrefix(ValueProvider<ResourceId> filenamePrefix);
- abstract Builder<T> setTempDirectory(ValueProvider<ResourceId> tempDirectory);
+ abstract Builder<UserT> setTempDirectory(ValueProvider<ResourceId> tempDirectory);
- abstract Builder<T> setShardTemplate(@Nullable String shardTemplate);
+ abstract Builder<UserT> setShardTemplate(@Nullable String shardTemplate);
- abstract Builder<T> setFilenameSuffix(@Nullable String filenameSuffix);
+ abstract Builder<UserT> setFilenameSuffix(@Nullable String filenameSuffix);
- abstract Builder<T> setHeader(@Nullable String header);
+ abstract Builder<UserT> setHeader(@Nullable String header);
- abstract Builder<T> setFooter(@Nullable String footer);
+ abstract Builder<UserT> setFooter(@Nullable String footer);
- abstract Builder<T> setFilenamePolicy(@Nullable FilenamePolicy filenamePolicy);
+ abstract Builder<UserT> setFilenamePolicy(@Nullable FilenamePolicy filenamePolicy);
- abstract Builder<T> setDynamicDestinations(
- @Nullable DynamicDestinations<T, ?> dynamicDestinations);
+ abstract Builder<UserT> setDynamicDestinations(
+ @Nullable DynamicDestinations<UserT, ?, String> dynamicDestinations);
- abstract Builder<T> setFormatFunction(SerializableFunction<T, String> formatFunction);
+ abstract Builder<UserT> setDestinationFunction(
+ @Nullable SerializableFunction<UserT, Params> destinationFunction);
- abstract Builder<T> setNumShards(int numShards);
+ abstract Builder<UserT> setEmptyDestination(Params emptyDestination);
- abstract Builder<T> setWindowedWrites(boolean windowedWrites);
+ abstract Builder<UserT> setFormatFunction(SerializableFunction<UserT, String> formatFunction);
- abstract Builder<T> setWritableByteChannelFactory(
+ abstract Builder<UserT> setNumShards(int numShards);
+
+ abstract Builder<UserT> setWindowedWrites(boolean windowedWrites);
+
+ abstract Builder<UserT> setWritableByteChannelFactory(
WritableByteChannelFactory writableByteChannelFactory);
- abstract TypedWrite<T> build();
+ abstract TypedWrite<UserT> build();
}
/**
@@ -570,18 +645,18 @@
* <p>If {@link #withTempDirectory} has not been called, this filename prefix will be used to
* infer a directory for temporary files.
*/
- public TypedWrite<T> to(String filenamePrefix) {
+ public TypedWrite<UserT> to(String filenamePrefix) {
return to(FileBasedSink.convertToFileResourceIfPossible(filenamePrefix));
}
/** Like {@link #to(String)}. */
@Experimental(Kind.FILESYSTEM)
- public TypedWrite<T> to(ResourceId filenamePrefix) {
+ public TypedWrite<UserT> to(ResourceId filenamePrefix) {
return toResource(StaticValueProvider.of(filenamePrefix));
}
/** Like {@link #to(String)}. */
- public TypedWrite<T> to(ValueProvider<String> outputPrefix) {
+ public TypedWrite<UserT> to(ValueProvider<String> outputPrefix) {
return toResource(NestedValueProvider.of(outputPrefix,
new SerializableFunction<String, ResourceId>() {
@Override
@@ -595,7 +670,7 @@
* Writes to files named according to the given {@link FileBasedSink.FilenamePolicy}. A
* directory for temporary files must be specified using {@link #withTempDirectory}.
*/
- public TypedWrite<T> to(FilenamePolicy filenamePolicy) {
+ public TypedWrite<UserT> to(FilenamePolicy filenamePolicy) {
return toBuilder().setFilenamePolicy(filenamePolicy).build();
}
@@ -604,7 +679,7 @@
* objects can examine the input record when creating a {@link FilenamePolicy}. A directory for
* temporary files must be specified using {@link #withTempDirectory}.
*/
- public TypedWrite<T> to(DynamicDestinations<T, ?> dynamicDestinations) {
+ public TypedWrite<UserT> to(DynamicDestinations<UserT, ?, String> dynamicDestinations) {
return toBuilder().setDynamicDestinations(dynamicDestinations).build();
}
@@ -615,26 +690,39 @@
* emptyDestination parameter specified where empty files should be written for when the written
* {@link PCollection} is empty.
*/
- public TypedWrite<T> to(
- SerializableFunction<T, Params> destinationFunction, Params emptyDestination) {
- return to(DynamicFileDestinations.toDefaultPolicies(destinationFunction, emptyDestination));
+ public TypedWrite<UserT> to(
+ SerializableFunction<UserT, Params> destinationFunction, Params emptyDestination) {
+ return toBuilder()
+ .setDestinationFunction(destinationFunction)
+ .setEmptyDestination(emptyDestination)
+ .build();
}
/** Like {@link #to(ResourceId)}. */
@Experimental(Kind.FILESYSTEM)
- public TypedWrite<T> toResource(ValueProvider<ResourceId> filenamePrefix) {
+ public TypedWrite<UserT> toResource(ValueProvider<ResourceId> filenamePrefix) {
return toBuilder().setFilenamePrefix(filenamePrefix).build();
}
+ /**
+ * Specifies a format function to convert {@code UserT} to the output type. If {@link
+ * #to(DynamicDestinations)} is used, {@link DynamicDestinations#formatRecord(Object)} must be
+ * used instead.
+ */
+ public TypedWrite<UserT> withFormatFunction(
+ SerializableFunction<UserT, String> formatFunction) {
+ return toBuilder().setFormatFunction(formatFunction).build();
+ }
+
/** Set the base directory used to generate temporary files. */
@Experimental(Kind.FILESYSTEM)
- public TypedWrite<T> withTempDirectory(ValueProvider<ResourceId> tempDirectory) {
+ public TypedWrite<UserT> withTempDirectory(ValueProvider<ResourceId> tempDirectory) {
return toBuilder().setTempDirectory(tempDirectory).build();
}
/** Set the base directory used to generate temporary files. */
@Experimental(Kind.FILESYSTEM)
- public TypedWrite<T> withTempDirectory(ResourceId tempDirectory) {
+ public TypedWrite<UserT> withTempDirectory(ResourceId tempDirectory) {
return withTempDirectory(StaticValueProvider.of(tempDirectory));
}
@@ -646,7 +734,7 @@
* <p>See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are
* used.
*/
- public TypedWrite<T> withShardNameTemplate(String shardTemplate) {
+ public TypedWrite<UserT> withShardNameTemplate(String shardTemplate) {
return toBuilder().setShardTemplate(shardTemplate).build();
}
@@ -658,7 +746,7 @@
* <p>See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are
* used.
*/
- public TypedWrite<T> withSuffix(String filenameSuffix) {
+ public TypedWrite<UserT> withSuffix(String filenameSuffix) {
return toBuilder().setFilenameSuffix(filenameSuffix).build();
}
@@ -672,7 +760,7 @@
*
* @param numShards the number of shards to use, or 0 to let the system decide.
*/
- public TypedWrite<T> withNumShards(int numShards) {
+ public TypedWrite<UserT> withNumShards(int numShards) {
checkArgument(numShards >= 0);
return toBuilder().setNumShards(numShards).build();
}
@@ -686,7 +774,7 @@
*
* <p>This is equivalent to {@code .withNumShards(1).withShardNameTemplate("")}
*/
- public TypedWrite<T> withoutSharding() {
+ public TypedWrite<UserT> withoutSharding() {
return withNumShards(1).withShardNameTemplate("");
}
@@ -695,7 +783,7 @@
*
* <p>A {@code null} value will clear any previously configured header.
*/
- public TypedWrite<T> withHeader(@Nullable String header) {
+ public TypedWrite<UserT> withHeader(@Nullable String header) {
return toBuilder().setHeader(header).build();
}
@@ -704,7 +792,7 @@
*
* <p>A {@code null} value will clear any previously configured footer.
*/
- public TypedWrite<T> withFooter(@Nullable String footer) {
+ public TypedWrite<UserT> withFooter(@Nullable String footer) {
return toBuilder().setFooter(footer).build();
}
@@ -715,7 +803,7 @@
*
* <p>A {@code null} value will reset the value to the default value mentioned above.
*/
- public TypedWrite<T> withWritableByteChannelFactory(
+ public TypedWrite<UserT> withWritableByteChannelFactory(
WritableByteChannelFactory writableByteChannelFactory) {
return toBuilder().setWritableByteChannelFactory(writableByteChannelFactory).build();
}
@@ -726,36 +814,58 @@
* <p>If using {@link #to(FileBasedSink.FilenamePolicy)}. Filenames will be generated using
* {@link FilenamePolicy#windowedFilename}. See also {@link WriteFiles#withWindowedWrites()}.
*/
- public TypedWrite<T> withWindowedWrites() {
+ public TypedWrite<UserT> withWindowedWrites() {
return toBuilder().setWindowedWrites(true).build();
}
- private DynamicDestinations<T, ?> resolveDynamicDestinations() {
- DynamicDestinations<T, ?> dynamicDestinations = getDynamicDestinations();
+ private DynamicDestinations<UserT, ?, String> resolveDynamicDestinations() {
+ DynamicDestinations<UserT, ?, String> dynamicDestinations = getDynamicDestinations();
if (dynamicDestinations == null) {
- FilenamePolicy usedFilenamePolicy = getFilenamePolicy();
- if (usedFilenamePolicy == null) {
- usedFilenamePolicy =
- DefaultFilenamePolicy.fromStandardParameters(
- getFilenamePrefix(),
- getShardTemplate(),
- getFilenameSuffix(),
- getWindowedWrites());
+ if (getDestinationFunction() != null) {
+ dynamicDestinations =
+ DynamicFileDestinations.toDefaultPolicies(
+ getDestinationFunction(), getEmptyDestination(), getFormatFunction());
+ } else {
+ FilenamePolicy usedFilenamePolicy = getFilenamePolicy();
+ if (usedFilenamePolicy == null) {
+ usedFilenamePolicy =
+ DefaultFilenamePolicy.fromStandardParameters(
+ getFilenamePrefix(),
+ getShardTemplate(),
+ getFilenameSuffix(),
+ getWindowedWrites());
+ }
+ dynamicDestinations =
+ DynamicFileDestinations.constant(usedFilenamePolicy, getFormatFunction());
}
- dynamicDestinations = DynamicFileDestinations.constant(usedFilenamePolicy);
}
return dynamicDestinations;
}
@Override
- public PDone expand(PCollection<T> input) {
+ public PDone expand(PCollection<UserT> input) {
checkState(
getFilenamePrefix() != null || getTempDirectory() != null,
"Need to set either the filename prefix or the tempDirectory of a TextIO.Write "
+ "transform.");
- checkState(
- getFilenamePolicy() == null || getDynamicDestinations() == null,
- "Cannot specify both a filename policy and dynamic destinations");
+
+ List<?> allToArgs =
+ Lists.newArrayList(
+ getFilenamePolicy(),
+ getDynamicDestinations(),
+ getFilenamePrefix(),
+ getDestinationFunction());
+ checkArgument(
+ 1 == Iterables.size(Iterables.filter(allToArgs, Predicates.notNull())),
+ "Exactly one of filename policy, dynamic destinations, filename prefix, or destination "
+ + "function must be set");
+
+ if (getDynamicDestinations() != null) {
+ checkArgument(
+ getFormatFunction() == null,
+ "A format function should not be specified "
+ + "with DynamicDestinations. Use DynamicDestinations.formatRecord instead");
+ }
if (getFilenamePolicy() != null || getDynamicDestinations() != null) {
checkState(
getShardTemplate() == null && getFilenameSuffix() == null,
@@ -766,20 +876,20 @@
}
public <DestinationT> PDone expandTyped(
- PCollection<T> input, DynamicDestinations<T, DestinationT> dynamicDestinations) {
+ PCollection<UserT> input,
+ DynamicDestinations<UserT, DestinationT, String> dynamicDestinations) {
ValueProvider<ResourceId> tempDirectory = getTempDirectory();
if (tempDirectory == null) {
tempDirectory = getFilenamePrefix();
}
- WriteFiles<T, DestinationT, String> write =
+ WriteFiles<UserT, DestinationT, String> write =
WriteFiles.to(
new TextSink<>(
tempDirectory,
dynamicDestinations,
getHeader(),
getFooter(),
- getWritableByteChannelFactory()),
- getFormatFunction());
+ getWritableByteChannelFactory()));
if (getNumShards() > 0) {
write = write.withNumShards(getNumShards());
}
@@ -814,11 +924,6 @@
"writableByteChannelFactory", getWritableByteChannelFactory().toString())
.withLabel("Compression/Transformation Type"));
}
-
- @Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
}
/**
@@ -831,7 +936,7 @@
@VisibleForTesting TypedWrite<String> inner;
Write() {
- this(TextIO.writeCustomType(SerializableFunctions.<String>identity()));
+ this(TextIO.<String>writeCustomType());
}
Write(TypedWrite<String> inner) {
@@ -840,43 +945,53 @@
/** See {@link TypedWrite#to(String)}. */
public Write to(String filenamePrefix) {
- return new Write(inner.to(filenamePrefix));
+ return new Write(
+ inner.to(filenamePrefix).withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#to(ResourceId)}. */
@Experimental(Kind.FILESYSTEM)
public Write to(ResourceId filenamePrefix) {
- return new Write(inner.to(filenamePrefix));
+ return new Write(
+ inner.to(filenamePrefix).withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#to(ValueProvider)}. */
public Write to(ValueProvider<String> outputPrefix) {
- return new Write(inner.to(outputPrefix));
+ return new Write(
+ inner.to(outputPrefix).withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#toResource(ValueProvider)}. */
@Experimental(Kind.FILESYSTEM)
public Write toResource(ValueProvider<ResourceId> filenamePrefix) {
- return new Write(inner.toResource(filenamePrefix));
+ return new Write(
+ inner
+ .toResource(filenamePrefix)
+ .withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#to(FilenamePolicy)}. */
@Experimental(Kind.FILESYSTEM)
public Write to(FilenamePolicy filenamePolicy) {
- return new Write(inner.to(filenamePolicy));
+ return new Write(
+ inner.to(filenamePolicy).withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#to(DynamicDestinations)}. */
@Experimental(Kind.FILESYSTEM)
- public Write to(DynamicDestinations<String, ?> dynamicDestinations) {
- return new Write(inner.to(dynamicDestinations));
+ public Write to(DynamicDestinations<String, ?, String> dynamicDestinations) {
+ return new Write(inner.to(dynamicDestinations).withFormatFunction(null));
}
/** See {@link TypedWrite#to(SerializableFunction, Params)}. */
@Experimental(Kind.FILESYSTEM)
public Write to(
SerializableFunction<String, Params> destinationFunction, Params emptyDestination) {
- return new Write(inner.to(destinationFunction, emptyDestination));
+ return new Write(
+ inner
+ .to(destinationFunction, emptyDestination)
+ .withFormatFunction(SerializableFunctions.<String>identity()));
}
/** See {@link TypedWrite#withTempDirectory(ValueProvider)}. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java
index b57b28c..387e0ac 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSink.java
@@ -34,13 +34,13 @@
* '\n'} represented in {@code UTF-8} format as the record separator. Each record (including the
* last) is terminated.
*/
-class TextSink<UserT, DestinationT> extends FileBasedSink<String, DestinationT> {
+class TextSink<UserT, DestinationT> extends FileBasedSink<UserT, DestinationT, String> {
@Nullable private final String header;
@Nullable private final String footer;
TextSink(
ValueProvider<ResourceId> baseOutputFilename,
- DynamicDestinations<UserT, DestinationT> dynamicDestinations,
+ DynamicDestinations<UserT, DestinationT, String> dynamicDestinations,
@Nullable String header,
@Nullable String footer,
WritableByteChannelFactory writableByteChannelFactory) {
@@ -50,13 +50,13 @@
}
@Override
- public WriteOperation<String, DestinationT> createWriteOperation() {
+ public WriteOperation<DestinationT, String> createWriteOperation() {
return new TextWriteOperation<>(this, header, footer);
}
/** A {@link WriteOperation WriteOperation} for text files. */
private static class TextWriteOperation<DestinationT>
- extends WriteOperation<String, DestinationT> {
+ extends WriteOperation<DestinationT, String> {
@Nullable private final String header;
@Nullable private final String footer;
@@ -67,20 +67,20 @@
}
@Override
- public Writer<String, DestinationT> createWriter() throws Exception {
+ public Writer<DestinationT, String> createWriter() throws Exception {
return new TextWriter<>(this, header, footer);
}
}
/** A {@link Writer Writer} for text files. */
- private static class TextWriter<DestinationT> extends Writer<String, DestinationT> {
+ private static class TextWriter<DestinationT> extends Writer<DestinationT, String> {
private static final String NEWLINE = "\n";
@Nullable private final String header;
@Nullable private final String footer;
private OutputStreamWriter out;
public TextWriter(
- WriteOperation<String, DestinationT> writeOperation,
+ WriteOperation<DestinationT, String> writeOperation,
@Nullable String header,
@Nullable String footer) {
super(writeOperation, MimeTypes.TEXT);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
index 4d9fa77..29188dc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
@@ -28,6 +28,7 @@
import java.util.NoSuchElementException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
@@ -48,7 +49,11 @@
@VisibleForTesting
class TextSource extends FileBasedSource<String> {
TextSource(ValueProvider<String> fileSpec) {
- super(fileSpec, 1L);
+ this(fileSpec, EmptyMatchTreatment.DISALLOW);
+ }
+
+ TextSource(ValueProvider<String> fileSpec, EmptyMatchTreatment emptyMatchTreatment) {
+ super(fileSpec, emptyMatchTreatment, 1L);
}
private TextSource(MatchResult.Metadata metadata, long start, long end) {
@@ -69,7 +74,7 @@
}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
index d8d7478..85c5652 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
@@ -60,7 +60,6 @@
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -76,7 +75,9 @@
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PCollectionViews;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.ShardedKey;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
@@ -121,9 +122,8 @@
private static final int SPILLED_RECORD_SHARDING_FACTOR = 10;
static final int UNKNOWN_SHARDNUM = -1;
- private FileBasedSink<OutputT, DestinationT> sink;
- private SerializableFunction<UserT, OutputT> formatFunction;
- private WriteOperation<OutputT, DestinationT> writeOperation;
+ private FileBasedSink<UserT, DestinationT, OutputT> sink;
+ private WriteOperation<DestinationT, OutputT> writeOperation;
// This allows the number of shards to be dynamically computed based on the input
// PCollection.
@Nullable private final PTransform<PCollection<UserT>, PCollectionView<Integer>> computeNumShards;
@@ -133,37 +133,44 @@
private final ValueProvider<Integer> numShardsProvider;
private final boolean windowedWrites;
private int maxNumWritersPerBundle;
+ // This is the set of side inputs used by this transform. This is usually populated by the user's
+ // DynamicDestinations object.
+ private final List<PCollectionView<?>> sideInputs;
/**
* Creates a {@link WriteFiles} transform that writes to the given {@link FileBasedSink}, letting
* the runner control how many different shards are produced.
*/
public static <UserT, DestinationT, OutputT> WriteFiles<UserT, DestinationT, OutputT> to(
- FileBasedSink<OutputT, DestinationT> sink,
- SerializableFunction<UserT, OutputT> formatFunction) {
+ FileBasedSink<UserT, DestinationT, OutputT> sink) {
checkNotNull(sink, "sink");
return new WriteFiles<>(
sink,
- formatFunction,
null /* runner-determined sharding */,
null,
false,
- DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE);
+ DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE,
+ sink.getDynamicDestinations().getSideInputs());
}
private WriteFiles(
- FileBasedSink<OutputT, DestinationT> sink,
- SerializableFunction<UserT, OutputT> formatFunction,
+ FileBasedSink<UserT, DestinationT, OutputT> sink,
@Nullable PTransform<PCollection<UserT>, PCollectionView<Integer>> computeNumShards,
@Nullable ValueProvider<Integer> numShardsProvider,
boolean windowedWrites,
- int maxNumWritersPerBundle) {
+ int maxNumWritersPerBundle,
+ List<PCollectionView<?>> sideInputs) {
this.sink = sink;
- this.formatFunction = checkNotNull(formatFunction);
this.computeNumShards = computeNumShards;
this.numShardsProvider = numShardsProvider;
this.windowedWrites = windowedWrites;
this.maxNumWritersPerBundle = maxNumWritersPerBundle;
+ this.sideInputs = sideInputs;
+ }
+
+ @Override
+ public Map<TupleTag<?>, PValue> getAdditionalInputs() {
+ return PCollectionViews.toAdditionalInputs(sideInputs);
}
@Override
@@ -207,15 +214,10 @@
}
/** Returns the {@link FileBasedSink} associated with this PTransform. */
- public FileBasedSink<OutputT, DestinationT> getSink() {
+ public FileBasedSink<UserT, DestinationT, OutputT> getSink() {
return sink;
}
- /** Returns the the format function that maps the user type to the record written to files. */
- public SerializableFunction<UserT, OutputT> getFormatFunction() {
- return formatFunction;
- }
-
/**
* Returns whether or not to perform windowed writes.
*/
@@ -266,11 +268,11 @@
ValueProvider<Integer> numShardsProvider) {
return new WriteFiles<>(
sink,
- formatFunction,
computeNumShards,
numShardsProvider,
windowedWrites,
- maxNumWritersPerBundle);
+ maxNumWritersPerBundle,
+ sideInputs);
}
/** Set the maximum number of writers created in a bundle before spilling to shuffle. */
@@ -278,11 +280,22 @@
int maxNumWritersPerBundle) {
return new WriteFiles<>(
sink,
- formatFunction,
computeNumShards,
numShardsProvider,
windowedWrites,
- maxNumWritersPerBundle);
+ maxNumWritersPerBundle,
+ sideInputs);
+ }
+
+ public WriteFiles<UserT, DestinationT, OutputT> withSideInputs(
+ List<PCollectionView<?>> sideInputs) {
+ return new WriteFiles<>(
+ sink,
+ computeNumShards,
+ numShardsProvider,
+ windowedWrites,
+ maxNumWritersPerBundle,
+ sideInputs);
}
/**
@@ -297,7 +310,7 @@
checkNotNull(
sharding, "Cannot provide null sharding. Use withRunnerDeterminedSharding() instead");
return new WriteFiles<>(
- sink, formatFunction, sharding, null, windowedWrites, maxNumWritersPerBundle);
+ sink, sharding, null, windowedWrites, maxNumWritersPerBundle, sideInputs);
}
/**
@@ -305,8 +318,7 @@
* runner-determined sharding.
*/
public WriteFiles<UserT, DestinationT, OutputT> withRunnerDeterminedSharding() {
- return new WriteFiles<>(
- sink, formatFunction, null, null, windowedWrites, maxNumWritersPerBundle);
+ return new WriteFiles<>(sink, null, null, windowedWrites, maxNumWritersPerBundle, sideInputs);
}
/**
@@ -323,7 +335,7 @@
*/
public WriteFiles<UserT, DestinationT, OutputT> withWindowedWrites() {
return new WriteFiles<>(
- sink, formatFunction, computeNumShards, numShardsProvider, true, maxNumWritersPerBundle);
+ sink, computeNumShards, numShardsProvider, true, maxNumWritersPerBundle, sideInputs);
}
private static class WriterKey<DestinationT> {
@@ -374,7 +386,7 @@
private final Coder<DestinationT> destinationCoder;
private final boolean windowedWrites;
- private Map<WriterKey<DestinationT>, Writer<OutputT, DestinationT>> writers;
+ private Map<WriterKey<DestinationT>, Writer<DestinationT, OutputT>> writers;
private int spilledShardNum = UNKNOWN_SHARDNUM;
WriteBundles(
@@ -394,6 +406,7 @@
@ProcessElement
public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
+ sink.getDynamicDestinations().setSideInputAccessorFromProcessContext(c);
PaneInfo paneInfo = c.pane();
// If we are doing windowed writes, we need to ensure that we have separate files for
// data in different windows/panes. Similar for dynamic writes, make sure that different
@@ -402,7 +415,7 @@
// the map will only have a single element.
DestinationT destination = sink.getDynamicDestinations().getDestination(c.element());
WriterKey<DestinationT> key = new WriterKey<>(window, c.pane(), destination);
- Writer<OutputT, DestinationT> writer = writers.get(key);
+ Writer<DestinationT, OutputT> writer = writers.get(key);
if (writer == null) {
if (writers.size() <= maxNumWritersPerBundle) {
String uuid = UUID.randomUUID().toString();
@@ -436,14 +449,14 @@
return;
}
}
- writeOrClose(writer, formatFunction.apply(c.element()));
+ writeOrClose(writer, getSink().getDynamicDestinations().formatRecord(c.element()));
}
@FinishBundle
public void finishBundle(FinishBundleContext c) throws Exception {
- for (Map.Entry<WriterKey<DestinationT>, Writer<OutputT, DestinationT>> entry :
+ for (Map.Entry<WriterKey<DestinationT>, Writer<DestinationT, OutputT>> entry :
writers.entrySet()) {
- Writer<OutputT, DestinationT> writer = entry.getValue();
+ Writer<DestinationT, OutputT> writer = entry.getValue();
FileResult<DestinationT> result;
try {
result = writer.close();
@@ -478,13 +491,14 @@
@ProcessElement
public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
+ sink.getDynamicDestinations().setSideInputAccessorFromProcessContext(c);
// Since we key by a 32-bit hash of the destination, there might be multiple destinations
// in this iterable. The number of destinations is generally very small (1000s or less), so
// there will rarely be hash collisions.
- Map<DestinationT, Writer<OutputT, DestinationT>> writers = Maps.newHashMap();
+ Map<DestinationT, Writer<DestinationT, OutputT>> writers = Maps.newHashMap();
for (UserT input : c.element().getValue()) {
DestinationT destination = sink.getDynamicDestinations().getDestination(input);
- Writer<OutputT, DestinationT> writer = writers.get(destination);
+ Writer<DestinationT, OutputT> writer = writers.get(destination);
if (writer == null) {
LOG.debug("Opening writer for write operation {}", writeOperation);
writer = writeOperation.createWriter();
@@ -501,12 +515,12 @@
LOG.debug("Done opening writer");
writers.put(destination, writer);
}
- writeOrClose(writer, formatFunction.apply(input));
- }
+ writeOrClose(writer, getSink().getDynamicDestinations().formatRecord(input));
+ }
// Close all writers.
- for (Map.Entry<DestinationT, Writer<OutputT, DestinationT>> entry : writers.entrySet()) {
- Writer<OutputT, DestinationT> writer = entry.getValue();
+ for (Map.Entry<DestinationT, Writer<DestinationT, OutputT>> entry : writers.entrySet()) {
+ Writer<DestinationT, OutputT> writer = entry.getValue();
FileResult<DestinationT> result;
try {
// Close the writer; if this throws let the error propagate.
@@ -526,8 +540,8 @@
}
}
- private static <OutputT, DestinationT> void writeOrClose(
- Writer<OutputT, DestinationT> writer, OutputT t) throws Exception {
+ private static <DestinationT, OutputT> void writeOrClose(
+ Writer<DestinationT, OutputT> writer, OutputT t) throws Exception {
try {
writer.write(t);
} catch (Exception e) {
@@ -678,6 +692,7 @@
input.apply(
writeName,
ParDo.of(new WriteBundles(windowedWrites, unwrittedRecordsTag, destinationCoder))
+ .withSideInputs(sideInputs)
.withOutputTags(writtenRecordsTag, TupleTagList.of(unwrittedRecordsTag)));
PCollection<FileResult<DestinationT>> writtenBundleFiles =
writeTuple
@@ -692,17 +707,18 @@
.apply("GroupUnwritten", GroupByKey.<ShardedKey<Integer>, UserT>create())
.apply(
"WriteUnwritten",
- ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_IN_FINALIZE)))
+ ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_IN_FINALIZE))
+ .withSideInputs(sideInputs))
.setCoder(FileResultCoder.of(shardedWindowCoder, destinationCoder));
results =
PCollectionList.of(writtenBundleFiles)
.and(writtenGroupedFiles)
.apply(Flatten.<FileResult<DestinationT>>pCollections());
} else {
- List<PCollectionView<?>> sideInputs = Lists.newArrayList();
+ List<PCollectionView<?>> shardingSideInputs = Lists.newArrayList();
if (computeNumShards != null) {
numShardsView = input.apply(computeNumShards);
- sideInputs.add(numShardsView);
+ shardingSideInputs.add(numShardsView);
} else {
numShardsView = null;
}
@@ -715,7 +731,7 @@
numShardsView,
(numShardsView != null) ? null : numShardsProvider,
destinationCoder))
- .withSideInputs(sideInputs))
+ .withSideInputs(shardingSideInputs))
.setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder()))
.apply("GroupIntoShards", GroupByKey.<ShardedKey<Integer>, UserT>create());
shardedWindowCoder =
@@ -728,7 +744,8 @@
results =
sharded.apply(
"WriteShardedBundles",
- ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_WHEN_WRITING)));
+ ParDo.of(new WriteShardedBundles(ShardAssignment.ASSIGN_WHEN_WRITING))
+ .withSideInputs(sideInputs));
}
results.setCoder(FileResultCoder.of(shardedWindowCoder, destinationCoder));
@@ -773,11 +790,12 @@
} else {
final PCollectionView<Iterable<FileResult<DestinationT>>> resultsView =
results.apply(View.<FileResult<DestinationT>>asIterable());
- ImmutableList.Builder<PCollectionView<?>> sideInputs =
+ ImmutableList.Builder<PCollectionView<?>> finalizeSideInputs =
ImmutableList.<PCollectionView<?>>builder().add(resultsView);
if (numShardsView != null) {
- sideInputs.add(numShardsView);
+ finalizeSideInputs.add(numShardsView);
}
+ finalizeSideInputs.addAll(sideInputs);
// Finalize the write in another do-once ParDo on the singleton collection containing the
// Writer. The results from the per-bundle writes are given as an Iterable side input.
@@ -794,7 +812,7 @@
new DoFn<Void, Integer>() {
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
- LOG.info("Finalizing write operation {}.", writeOperation);
+ sink.getDynamicDestinations().setSideInputAccessorFromProcessContext(c);
// We must always output at least 1 shard, and honor user-specified numShards
// if
// set.
@@ -827,7 +845,7 @@
writeOperation.removeTemporaryFiles(tempFiles);
}
})
- .withSideInputs(sideInputs.build()));
+ .withSideInputs(finalizeSideInputs.build()));
}
return PDone.in(input.getPipeline());
}
@@ -857,7 +875,7 @@
minShardsNeeded,
destination);
for (int i = 0; i < extraShardsNeeded; ++i) {
- Writer<OutputT, DestinationT> writer = writeOperation.createWriter();
+ Writer<DestinationT, OutputT> writer = writeOperation.createWriter();
// Currently this code path is only called in the unwindowed case.
writer.openUnwindowed(UUID.randomUUID().toString(), UNKNOWN_SHARDNUM, destination);
FileResult<DestinationT> emptyWrite = writer.close();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/EmptyMatchTreatment.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/EmptyMatchTreatment.java
new file mode 100644
index 0000000..8e12993
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/EmptyMatchTreatment.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.fs;
+
+import org.apache.beam.sdk.io.fs.MatchResult.Status;
+
+/**
+ * Options for allowing or disallowing filepatterns that match no resources in {@link
+ * org.apache.beam.sdk.io.FileSystems#match}.
+ */
+public enum EmptyMatchTreatment {
+ /**
+ * Filepatterns matching no resources are allowed. For such a filepattern, {@link
+ * MatchResult#status} will be {@link Status#OK} and {@link MatchResult#metadata} will return an
+ * empty list.
+ */
+ ALLOW,
+
+ /**
+ * Filepatterns matching no resources are disallowed. For such a filepattern, {@link
+ * MatchResult#status} will be {@link Status#NOT_FOUND} and {@link MatchResult#metadata} will
+ * throw a {@link java.io.FileNotFoundException}.
+ */
+ DISALLOW,
+
+ /**
+ * Filepatterns matching no resources are allowed if the filepattern contains a glob wildcard
+ * character, and disallowed otherwise (i.e. if the filepattern specifies a single file).
+ */
+ ALLOW_IF_WILDCARD
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MatchResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MatchResult.java
index 642c049..aa80b96 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MatchResult.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MatchResult.java
@@ -21,6 +21,7 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
+import org.apache.beam.sdk.io.FileSystems;
/**
* The result of {@link org.apache.beam.sdk.io.FileSystem#match}.
@@ -78,7 +79,9 @@
public abstract Status status();
/**
- * {@link Metadata} of matched files.
+ * {@link Metadata} of matched files. Note that if {@link #status()} is {@link Status#NOT_FOUND},
+ * this may either throw a {@link java.io.FileNotFoundException} or return an empty list,
+ * depending on the {@link EmptyMatchTreatment} used in the {@link FileSystems#match} call.
*/
public abstract List<Metadata> metadata() throws IOException;
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MetadataCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MetadataCoder.java
new file mode 100644
index 0000000..5c9c4d7
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/MetadataCoder.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.fs;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
+
+/** A {@link Coder} for {@link Metadata}. */
+public class MetadataCoder extends AtomicCoder<Metadata> {
+ private static final ResourceIdCoder RESOURCE_ID_CODER = ResourceIdCoder.of();
+ private static final VarIntCoder INT_CODER = VarIntCoder.of();
+ private static final VarLongCoder LONG_CODER = VarLongCoder.of();
+
+ /** Creates a {@link MetadataCoder}. */
+ public static MetadataCoder of() {
+ return new MetadataCoder();
+ }
+
+ @Override
+ public void encode(Metadata value, OutputStream os) throws IOException {
+ RESOURCE_ID_CODER.encode(value.resourceId(), os);
+ INT_CODER.encode(value.isReadSeekEfficient() ? 1 : 0, os);
+ LONG_CODER.encode(value.sizeBytes(), os);
+ }
+
+ @Override
+ public Metadata decode(InputStream is) throws IOException {
+ ResourceId resourceId = RESOURCE_ID_CODER.decode(is);
+ boolean isReadSeekEfficient = INT_CODER.decode(is) == 1;
+ long sizeBytes = LONG_CODER.decode(is);
+ return Metadata.builder()
+ .setResourceId(resourceId)
+ .setIsReadSeekEfficient(isReadSeekEfficient)
+ .setSizeBytes(sizeBytes)
+ .build();
+ }
+
+ @Override
+ public boolean consistentWithEquals() {
+ return true;
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/ResourceIdCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/ResourceIdCoder.java
new file mode 100644
index 0000000..d7649c0
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/fs/ResourceIdCoder.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.fs;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.FileSystems;
+
+/** A {@link Coder} for {@link ResourceId}. */
+public class ResourceIdCoder extends AtomicCoder<ResourceId> {
+ private static final Coder<String> STRING_CODER = StringUtf8Coder.of();
+ private static final Coder<Boolean> BOOL_CODER = BooleanCoder.of();
+
+ /** Creates a {@link ResourceIdCoder}. */
+ public static ResourceIdCoder of() {
+ return new ResourceIdCoder();
+ }
+
+ @Override
+ public void encode(ResourceId value, OutputStream os) throws IOException {
+ STRING_CODER.encode(value.toString(), os);
+ BOOL_CODER.encode(value.isDirectory(), os);
+ }
+
+ @Override
+ public ResourceId decode(InputStream is) throws IOException {
+ String spec = STRING_CODER.decode(is);
+ boolean isDirectory = BOOL_CODER.decode(is);
+ return FileSystems.matchNewResource(spec, isDirectory);
+ }
+
+ @Override
+ public boolean consistentWithEquals() {
+ return true;
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java
index 9a4d25a..5cc0b3f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptions.java
@@ -176,7 +176,12 @@
*
* <h2>Serialization Of PipelineOptions</h2>
*
- * {@link PipelineRunner}s require support for options to be serialized. Each property
+ * {@link PipelineOptions} is intentionally <i>not</i> marked {@link java.io.Serializable}, in order
+ * to discourage pipeline authors from capturing {@link PipelineOptions} at pipeline construction
+ * time, because a pipeline may be saved as a template and run with a different set of options
+ * than the ones it was constructed with. See {@link Pipeline#run(PipelineOptions)}.
+ *
+ * <p>However, {@link PipelineRunner}s require support for options to be serialized. Each property
* within {@link PipelineOptions} must be able to be serialized using Jackson's
* {@link ObjectMapper} or the getter method for the property annotated with
* {@link JsonIgnore @JsonIgnore}.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProvider.java
index c7f1e09..94187a9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProvider.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProvider.java
@@ -41,13 +41,19 @@
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.PCollection;
/**
* A {@link ValueProvider} abstracts the notion of fetching a value that may or may not be currently
* available.
*
* <p>This can be used to parameterize transforms that only read values in at runtime, for example.
+ *
+ * <p>A common task is to create a {@link PCollection} containing the value of this
+ * {@link ValueProvider} regardless of whether it's accessible at construction time or not.
+ * For that, use {@link Create#ofProvider}.
*/
@JsonSerialize(using = ValueProvider.Serializer.class)
@JsonDeserialize(using = ValueProvider.Deserializer.class)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProviders.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProviders.java
index 1cc46fe..2fffffa 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProviders.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ValueProviders.java
@@ -19,16 +19,14 @@
import static com.google.common.base.Preconditions.checkNotNull;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.util.Map;
-import org.apache.beam.sdk.util.common.ReflectHelpers;
/**
* Utilities for working with the {@link ValueProvider} interface.
*/
-class ValueProviders {
+public class ValueProviders {
private ValueProviders() {}
/**
@@ -37,11 +35,9 @@
*/
public static String updateSerializedOptions(
String serializedOptions, Map<String, String> runtimeValues) {
- ObjectMapper mapper = new ObjectMapper().registerModules(
- ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
ObjectNode root, options;
try {
- root = mapper.readValue(serializedOptions, ObjectNode.class);
+ root = PipelineOptionsFactory.MAPPER.readValue(serializedOptions, ObjectNode.class);
options = (ObjectNode) root.get("options");
checkNotNull(options, "Unable to locate 'options' in %s", serializedOptions);
} catch (IOException e) {
@@ -53,7 +49,7 @@
options.put(entry.getKey(), entry.getValue());
}
try {
- return mapper.writeValueAsString(root);
+ return PipelineOptionsFactory.MAPPER.writeValueAsString(root);
} catch (IOException e) {
throw new RuntimeException("Unable to parse re-serialize options", e);
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
index d8ff59e..c2d5771 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
@@ -34,6 +34,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior;
@@ -98,6 +99,48 @@
return current;
}
+ @Internal
+ public Node pushFinalizedNode(
+ String name,
+ Map<TupleTag<?>, PValue> inputs,
+ PTransform<?, ?> transform,
+ Map<TupleTag<?>, PValue> outputs) {
+ checkNotNull(
+ transform, "A %s must be provided for all Nodes", PTransform.class.getSimpleName());
+ checkNotNull(
+ name, "A name must be provided for all %s Nodes", PTransform.class.getSimpleName());
+ checkNotNull(
+ inputs, "An input must be provided for all %s Nodes", PTransform.class.getSimpleName());
+ Node node = new Node(current, transform, name, inputs, outputs);
+ node.finishedSpecifying = true;
+ current.addComposite(node);
+ current = node;
+ return current;
+ }
+
+ @Internal
+ public Node addFinalizedPrimitiveNode(
+ String name,
+ Map<TupleTag<?>, PValue> inputs,
+ PTransform<?, ?> transform,
+ Map<TupleTag<?>, PValue> outputs) {
+ checkNotNull(
+ transform, "A %s must be provided for all Nodes", PTransform.class.getSimpleName());
+ checkNotNull(
+ name, "A name must be provided for all %s Nodes", PTransform.class.getSimpleName());
+ checkNotNull(
+ inputs, "Inputs must be provided for all %s Nodes", PTransform.class.getSimpleName());
+ checkNotNull(
+ outputs, "Outputs must be provided for all %s Nodes", PTransform.class.getSimpleName());
+ Node node = new Node(current, transform, name, inputs, outputs);
+ node.finishedSpecifying = true;
+ for (PValue output : outputs.values()) {
+ producers.put(output, node);
+ }
+ current.addComposite(node);
+ return node;
+ }
+
public Node replaceNode(Node existing, PInput input, PTransform<?, ?> transform) {
checkNotNull(existing);
checkNotNull(input);
@@ -321,6 +364,32 @@
}
/**
+ * Creates a new {@link Node} with the given parent and transform, where inputs and outputs
+ * are already known.
+ *
+ * <p>EnclosingNode and transform may both be null for a root-level node, which holds all other
+ * nodes.
+ *
+ * @param enclosingNode the composite node containing this node
+ * @param transform the PTransform tracked by this node
+ * @param fullName the fully qualified name of the transform
+ * @param inputs the expanded inputs to the transform
+ * @param outputs the expanded outputs of the transform
+ */
+ private Node(
+ @Nullable Node enclosingNode,
+ @Nullable PTransform<?, ?> transform,
+ String fullName,
+ @Nullable Map<TupleTag<?>, PValue> inputs,
+ @Nullable Map<TupleTag<?>, PValue> outputs) {
+ this.enclosingNode = enclosingNode;
+ this.transform = transform;
+ this.fullName = fullName;
+ this.inputs = inputs == null ? Collections.<TupleTag<?>, PValue>emptyMap() : inputs;
+ this.outputs = outputs == null ? Collections.<TupleTag<?>, PValue>emptyMap() : outputs;
+ }
+
+ /**
* Returns the transform associated with this transform node.
*/
public PTransform<?, ?> getTransform() {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CombineFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CombineFnTester.java
new file mode 100644
index 0000000..efd2af3
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/CombineFnTester.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.hamcrest.Matcher;
+
+/**
+ * Utilities for testing {@link CombineFn CombineFns}. Ensures that the {@link CombineFn} gives
+ * correct results across various permutations and shardings of the input.
+ */
+public class CombineFnTester {
+ /**
+ * Tests that the {@link CombineFn}, when applied to the provided input, produces the provided
+ * output. Tests a variety of permutations of the input.
+ */
+ public static <InputT, AccumT, OutputT> void testCombineFn(
+ CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, final OutputT expected) {
+ testCombineFn(fn, input, is(expected));
+ Collections.shuffle(input);
+ testCombineFn(fn, input, is(expected));
+ }
+
+ public static <InputT, AccumT, OutputT> void testCombineFn(
+ CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, Matcher<? super OutputT> matcher) {
+ int size = input.size();
+ checkCombineFnShardsMultipleOrders(fn, Collections.singletonList(input), matcher);
+ checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, 2), matcher);
+ if (size > 4) {
+ checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, size / 2), matcher);
+ checkCombineFnShardsMultipleOrders(
+ fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher);
+ }
+ checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 1.4), matcher);
+ checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 2), matcher);
+ checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, Math.E), matcher);
+ }
+
+ private static <InputT, AccumT, OutputT> void checkCombineFnShardsMultipleOrders(
+ CombineFn<InputT, AccumT, OutputT> fn,
+ List<? extends Iterable<InputT>> shards,
+ Matcher<? super OutputT> matcher) {
+ checkCombineFnShardsSingleMerge(fn, shards, matcher);
+ checkCombineFnShardsWithEmptyAccumulators(fn, shards, matcher);
+ checkCombineFnShardsIncrementalMerging(fn, shards, matcher);
+ Collections.shuffle(shards);
+ checkCombineFnShardsSingleMerge(fn, shards, matcher);
+ checkCombineFnShardsWithEmptyAccumulators(fn, shards, matcher);
+ checkCombineFnShardsIncrementalMerging(fn, shards, matcher);
+ }
+
+ private static <InputT, AccumT, OutputT> void checkCombineFnShardsSingleMerge(
+ CombineFn<InputT, AccumT, OutputT> fn,
+ Iterable<? extends Iterable<InputT>> shards,
+ Matcher<? super OutputT> matcher) {
+ List<AccumT> accumulators = combineInputs(fn, shards);
+ AccumT merged = fn.mergeAccumulators(accumulators);
+ assertThat(fn.extractOutput(merged), matcher);
+ }
+
+ private static <InputT, AccumT, OutputT> void checkCombineFnShardsWithEmptyAccumulators(
+ CombineFn<InputT, AccumT, OutputT> fn,
+ Iterable<? extends Iterable<InputT>> shards,
+ Matcher<? super OutputT> matcher) {
+ List<AccumT> accumulators = combineInputs(fn, shards);
+ accumulators.add(0, fn.createAccumulator());
+ accumulators.add(fn.createAccumulator());
+ AccumT merged = fn.mergeAccumulators(accumulators);
+ assertThat(fn.extractOutput(merged), matcher);
+ }
+
+ private static <InputT, AccumT, OutputT> void checkCombineFnShardsIncrementalMerging(
+ CombineFn<InputT, AccumT, OutputT> fn,
+ List<? extends Iterable<InputT>> shards,
+ Matcher<? super OutputT> matcher) {
+ AccumT accumulator = null;
+ for (AccumT inputAccum : combineInputs(fn, shards)) {
+ if (accumulator == null) {
+ accumulator = inputAccum;
+ } else {
+ accumulator = fn.mergeAccumulators(Arrays.asList(accumulator, inputAccum));
+ }
+ }
+ assertThat(fn.extractOutput(accumulator), matcher);
+ }
+
+ private static <InputT, AccumT, OutputT> List<AccumT> combineInputs(
+ CombineFn<InputT, AccumT, OutputT> fn, Iterable<? extends Iterable<InputT>> shards) {
+ List<AccumT> accumulators = new ArrayList<>();
+ int maybeCompact = 0;
+ for (Iterable<InputT> shard : shards) {
+ AccumT accumulator = fn.createAccumulator();
+ for (InputT elem : shard) {
+ accumulator = fn.addInput(accumulator, elem);
+ }
+ if (maybeCompact++ % 2 == 0) {
+ accumulator = fn.compact(accumulator);
+ }
+ accumulators.add(accumulator);
+ }
+ return accumulators;
+ }
+
+ private static <T> List<List<T>> shardEvenly(List<T> input, int numShards) {
+ List<List<T>> shards = new ArrayList<>(numShards);
+ for (int i = 0; i < numShards; i++) {
+ shards.add(input.subList(i * input.size() / numShards,
+ (i + 1) * input.size() / numShards));
+ }
+ return shards;
+ }
+
+ private static <T> List<List<T>> shardExponentially(
+ List<T> input, double base) {
+ assert base > 1.0;
+ List<List<T>> shards = new ArrayList<>();
+ int end = input.size();
+ while (end > 0) {
+ int start = (int) (end / base);
+ shards.add(input.subList(start, end));
+ end = start;
+ }
+ return shards;
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SourceTestUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SourceTestUtils.java
index cde0b94..e147221 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SourceTestUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SourceTestUtils.java
@@ -212,7 +212,7 @@
List<? extends BoundedSource<T>> sources,
PipelineOptions options)
throws Exception {
- Coder<T> coder = referenceSource.getDefaultOutputCoder();
+ Coder<T> coder = referenceSource.getOutputCoder();
List<T> referenceRecords = readFromSource(referenceSource, options);
List<T> bundleRecords = new ArrayList<>();
for (BoundedSource<T> source : sources) {
@@ -221,7 +221,7 @@
+ source
+ " is not compatible with Coder type for referenceSource "
+ referenceSource,
- source.getDefaultOutputCoder(),
+ source.getOutputCoder(),
equalTo(coder));
List<T> elems = readFromSource(source, options);
bundleRecords.addAll(elems);
@@ -239,7 +239,7 @@
*/
public static <T> void assertUnstartedReaderReadsSameAsItsSource(
BoundedSource.BoundedReader<T> reader, PipelineOptions options) throws Exception {
- Coder<T> coder = reader.getCurrentSource().getDefaultOutputCoder();
+ Coder<T> coder = reader.getCurrentSource().getOutputCoder();
List<T> expected = readFromUnstartedReader(reader);
List<T> actual = readFromSource(reader.getCurrentSource(), options);
List<ReadableStructuralValue<T>> expectedStructural = createStructuralValues(coder, expected);
@@ -415,7 +415,7 @@
source,
primary,
residual);
- Coder<T> coder = primary.getDefaultOutputCoder();
+ Coder<T> coder = primary.getOutputCoder();
List<ReadableStructuralValue<T>> primaryValues =
createStructuralValues(coder, primaryItems);
List<ReadableStructuralValue<T>> currentValues =
@@ -728,8 +728,8 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
- return boundedSource.getDefaultOutputCoder();
+ public Coder<T> getOutputCoder() {
+ return boundedSource.getOutputCoder();
}
private static class UnsplittableReader<T> extends BoundedReader<T> {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
index 34f1c834..b67b14f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
@@ -328,6 +328,11 @@
* testing.
*/
public PipelineResult run() {
+ return run(getOptions());
+ }
+
+ /** Like {@link #run} but with the given potentially modified options. */
+ public PipelineResult run(PipelineOptions options) {
checkState(
enforcement.isPresent(),
"Is your TestPipeline declaration missing a @Rule annotation? Usage: "
@@ -336,7 +341,7 @@
final PipelineResult pipelineResult;
try {
enforcement.get().beforePipelineExecution();
- pipelineResult = super.run();
+ pipelineResult = super.run(options);
verifyPAssertsSucceeded(this, pipelineResult);
} catch (RuntimeException exc) {
Throwable cause = exc.getCause();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java
index d13fcf1..45f4413 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java
@@ -253,9 +253,8 @@
@Override
public PCollection<T> expand(PBegin input) {
- return PCollection.<T>createPrimitiveOutputInternal(
- input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(coder);
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED, coder);
}
public Coder<T> getValueCoder() {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCustomWindowMerging.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCustomWindowMerging.java
new file mode 100644
index 0000000..fc40e02
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesCustomWindowMerging.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/** Category tag for validation tests which utilize custom window merging. */
+public interface UsesCustomWindowMerging {
+ // Marker interface used only for test categorization; intentionally declares no members.
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
index c195352..fab98f8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
@@ -2156,8 +2156,13 @@
}).withSideInputs(sideInputs));
try {
- Coder<KV<K, OutputT>> outputCoder = getDefaultOutputCoder(input);
- output.setCoder(outputCoder);
+ KvCoder<K, InputT> kvCoder = getKvCoder(input.getCoder());
+ @SuppressWarnings("unchecked")
+ Coder<OutputT> outputValueCoder =
+ ((GlobalCombineFn<InputT, ?, OutputT>) fn)
+ .getDefaultOutputCoder(
+ input.getPipeline().getCoderRegistry(), kvCoder.getValueCoder());
+ output.setCoder(KvCoder.of(kvCoder.getKeyCoder(), outputValueCoder));
} catch (CannotProvideCoderException exc) {
// let coder inference happen later, if it can
}
@@ -2200,19 +2205,6 @@
}
@Override
- public Coder<KV<K, OutputT>> getDefaultOutputCoder(
- PCollection<? extends KV<K, ? extends Iterable<InputT>>> input)
- throws CannotProvideCoderException {
- KvCoder<K, InputT> kvCoder = getKvCoder(input.getCoder());
- @SuppressWarnings("unchecked")
- Coder<OutputT> outputValueCoder =
- ((GlobalCombineFn<InputT, ?, OutputT>) fn)
- .getDefaultOutputCoder(
- input.getPipeline().getCoderRegistry(), kvCoder.getValueCoder());
- return KvCoder.of(kvCoder.getKeyCoder(), outputValueCoder);
- }
-
- @Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
Combine.populateDisplayData(builder, fn, fnDisplayData);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
index 7af8fb8..2635bc8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.transforms;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
@@ -52,6 +53,7 @@
import org.apache.beam.sdk.io.OffsetBasedSource.OffsetBasedReader;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
@@ -200,6 +202,14 @@
}
/**
+ * Returns an {@link OfValueProvider} transform producing a {@link PCollection} with exactly one
+ * element: the value of the given {@link ValueProvider}, encoded with {@code coder}.
+ */
+ public static <T> OfValueProvider<T> ofProvider(ValueProvider<T> provider, Coder<T> coder) {
+ return new OfValueProvider<T>(provider, coder);
+ }
+
+ /**
* Returns a new {@link Create.TimestampedValues} transform that produces a
* {@link PCollection} containing the elements of the provided {@code Iterable}
* with the specified timestamps.
@@ -305,29 +315,25 @@
@Override
public PCollection<T> expand(PBegin input) {
+ Coder<T> coder;
try {
- Coder<T> coder = getDefaultOutputCoder(input);
- try {
- CreateSource<T> source = CreateSource.fromIterable(elems, coder);
- return input.getPipeline().apply(Read.from(source));
- } catch (IOException e) {
- throw new RuntimeException(
- String.format("Unable to apply Create %s using Coder %s.", this, coder), e);
- }
+ CoderRegistry registry = input.getPipeline().getCoderRegistry();
+ coder =
+ this.coder.isPresent()
+ ? this.coder.get()
+ : typeDescriptor.isPresent()
+ ? registry.getCoder(typeDescriptor.get())
+ : getDefaultCreateCoder(registry, elems);
} catch (CannotProvideCoderException e) {
throw new IllegalArgumentException("Unable to infer a coder and no Coder was specified. "
+ "Please set a coder by invoking Create.withCoder() explicitly.", e);
}
- }
-
- @Override
- public Coder<T> getDefaultOutputCoder(PBegin input) throws CannotProvideCoderException {
- if (coder.isPresent()) {
- return coder.get();
- } else if (typeDescriptor.isPresent()) {
- return input.getPipeline().getCoderRegistry().getCoder(typeDescriptor.get());
- } else {
- return getDefaultCreateCoder(input.getPipeline().getCoderRegistry(), elems);
+ try {
+ CreateSource<T> source = CreateSource.fromIterable(elems, coder);
+ return input.getPipeline().apply(Read.from(source));
+ } catch (IOException e) {
+ throw new RuntimeException(
+ String.format("Unable to apply Create %s using Coder %s.", this, coder), e);
}
}
@@ -401,7 +407,7 @@
public void validate() {}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return coder;
}
@@ -485,6 +491,38 @@
/////////////////////////////////////////////////////////////////////////////
+ /** Implementation of {@link #ofProvider}. */
+ public static class OfValueProvider<T> extends PTransform<PBegin, PCollection<T>> {
+ private final ValueProvider<T> provider;
+ private final Coder<T> coder;
+
+ private OfValueProvider(ValueProvider<T> provider, Coder<T> coder) {
+ this.provider = checkNotNull(provider, "provider");
+ this.coder = checkNotNull(coder, "coder");
+ }
+
+ @Override
+ public PCollection<T> expand(PBegin input) {
+ if (!provider.isAccessible()) {
+ // Not accessible yet: map one placeholder element to provider.get() when it is processed.
+ PCollection<Void> placeholder = input.apply(Create.of((Void) null));
+ return placeholder
+ .apply(MapElements.via(new SimpleFunction<Void, T>() {
+ @Override
+ public T apply(Void ignored) {
+ return provider.get();
+ }
+ }))
+ .setCoder(coder);
+ }
+ // Accessible now: create the single element directly.
+ Values<T> values = Create.of(provider.get());
+ return input.apply(values.withCoder(coder));
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
/**
* A {@code PTransform} that creates a {@code PCollection} whose elements have
* associated timestamps.
@@ -528,7 +566,23 @@
@Override
public PCollection<T> expand(PBegin input) {
try {
- Coder<T> coder = getDefaultOutputCoder(input);
+ Coder<T> coder;
+ if (elementCoder.isPresent()) {
+ coder = elementCoder.get();
+ } else if (typeDescriptor.isPresent()) {
+ coder = input.getPipeline().getCoderRegistry().getCoder(typeDescriptor.get());
+ } else {
+ Iterable<T> rawElements =
+ Iterables.transform(
+ timestampedElements,
+ new Function<TimestampedValue<T>, T>() {
+ @Override
+ public T apply(TimestampedValue<T> timestampedValue) {
+ return timestampedValue.getValue();
+ }
+ });
+ coder = getDefaultCreateCoder(input.getPipeline().getCoderRegistry(), rawElements);
+ }
PCollection<TimestampedValue<T>> intermediate = Pipeline.applyTransform(input,
Create.of(timestampedElements).withCoder(TimestampedValueCoder.of(coder)));
@@ -568,26 +622,6 @@
c.outputWithTimestamp(c.element().getValue(), c.element().getTimestamp());
}
}
-
- @Override
- public Coder<T> getDefaultOutputCoder(PBegin input) throws CannotProvideCoderException {
- if (elementCoder.isPresent()) {
- return elementCoder.get();
- } else if (typeDescriptor.isPresent()) {
- return input.getPipeline().getCoderRegistry().getCoder(typeDescriptor.get());
- } else {
- Iterable<T> rawElements =
- Iterables.transform(
- timestampedElements,
- new Function<TimestampedValue<T>, T>() {
- @Override
- public T apply(TimestampedValue<T> input) {
- return input.getValue();
- }
- });
- return getDefaultCreateCoder(input.getPipeline().getCoderRegistry(), rawElements);
- }
- }
}
private static <T> Coder<T> getDefaultCreateCoder(CoderRegistry registry, Iterable<T> elems)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index 37c6263..3e023db 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -524,12 +524,15 @@
* <li>It must return {@code void}.
* </ul>
*
- * <h2>Splittable DoFn's (WARNING: work in progress, do not use)</h2>
+ * <h2>Splittable DoFn's</h2>
*
* <p>A {@link DoFn} is <i>splittable</i> if its {@link ProcessElement} method has a parameter
* whose type is a subtype of {@link RestrictionTracker}. This is an advanced feature and an
- * overwhelming majority of users will never need to write a splittable {@link DoFn}. Right now
- * the implementation of this feature is in progress and it's not ready for any use.
+ * overwhelming majority of users will never need to write a splittable {@link DoFn}.
+ *
+ * <p>Not all runners support Splittable DoFn. See the
+ * <a href="https://beam.apache.org/documentation/runners/capability-matrix/">capability
+ * matrix</a>.
*
* <p>See <a href="https://s.apache.org/splittable-do-fn">the proposal</a> for an overview of the
* involved concepts (<i>splittable DoFn</i>, <i>restriction</i>, <i>restriction tracker</i>).
@@ -558,8 +561,6 @@
* </ul>
*
* <p>A non-splittable {@link DoFn} <i>must not</i> define any of these methods.
- *
- * <p>More documentation will be added when the feature becomes ready for general usage.
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java
index d0314eb..2fd12de 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.sdk.transforms;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
@@ -229,19 +228,18 @@
@Override
public PCollection<T> expand(PCollection<T> input) {
- return input.apply(ParDo.of(new DoFn<T, T>() {
- @ProcessElement
- public void processElement(ProcessContext c) {
- if (predicate.apply(c.element())) {
- c.output(c.element());
- }
- }
- }));
- }
-
- @Override
- protected Coder<T> getDefaultOutputCoder(PCollection<T> input) {
- return input.getCoder();
+ return input
+ .apply(
+ ParDo.of(
+ new DoFn<T, T>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ if (predicate.apply(c.element())) {
+ c.output(c.element());
+ }
+ }
+ }))
+ .setCoder(input.getCoder());
}
@Override
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java
index 25d9c05..8247a58 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.sdk.transforms;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableLikeCoder;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
@@ -129,25 +128,12 @@
windowingStrategy = WindowingStrategy.globalDefault();
}
- return PCollection.<T>createPrimitiveOutputInternal(
+ return PCollection.createPrimitiveOutputInternal(
inputs.getPipeline(),
windowingStrategy,
- isBounded);
- }
-
- @Override
- protected Coder<?> getDefaultOutputCoder(PCollectionList<T> input)
- throws CannotProvideCoderException {
-
- // Take coder from first collection
- for (PCollection<T> pCollection : input.getAll()) {
- return pCollection.getCoder();
- }
-
- // No inputs
- throw new CannotProvideCoderException(
- this.getClass().getSimpleName() + " cannot provide a Coder for"
- + " empty " + PCollectionList.class.getSimpleName());
+ isBounded,
+ // Take coder from first collection. If there are none, will be left unspecified.
+ inputs.getAll().isEmpty() ? null : inputs.get(0).getCoder());
}
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
index 7516b25..3cb0d23 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
@@ -217,13 +217,11 @@
// merging windows as needed, using the windows assigned to the
// key/value input elements and the window merge operation of the
// window function associated with the input PCollection.
- return PCollection.createPrimitiveOutputInternal(input.getPipeline(),
- updateWindowingStrategy(input.getWindowingStrategy()), input.isBounded());
- }
-
- @Override
- protected Coder<KV<K, Iterable<V>>> getDefaultOutputCoder(PCollection<KV<K, V>> input) {
- return getOutputKvCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ updateWindowingStrategy(input.getWindowingStrategy()),
+ input.isBounded(),
+ getOutputKvCoder(input.getCoder()));
}
/**
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
index 58051df..f5e7830 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
@@ -277,13 +277,16 @@
}
/**
- * Returns the default {@code Coder} to use for the output of this
- * single-output {@code PTransform}.
+ * Returns the default {@code Coder} to use for the output of this single-output {@code
+ * PTransform}.
*
* <p>By default, always throws
*
* @throws CannotProvideCoderException if no coder can be inferred
+ * @deprecated Instead, the PTransform should explicitly call {@link PCollection#setCoder} on the
+ * returned PCollection.
*/
+ @Deprecated
protected Coder<?> getDefaultOutputCoder() throws CannotProvideCoderException {
throw new CannotProvideCoderException("PTransform.getOutputCoder called.");
}
@@ -295,7 +298,10 @@
* <p>By default, always throws.
*
* @throws CannotProvideCoderException if none can be inferred.
+ * @deprecated Instead, the PTransform should explicitly call {@link PCollection#setCoder} on the
+ * returned PCollection.
*/
+ @Deprecated
protected Coder<?> getDefaultOutputCoder(@SuppressWarnings("unused") InputT input)
throws CannotProvideCoderException {
return getDefaultOutputCoder();
@@ -308,7 +314,10 @@
* <p>By default, always throws.
*
* @throws CannotProvideCoderException if none can be inferred.
+ * @deprecated Instead, the PTransform should explicitly call {@link PCollection#setCoder} on the
+ * returned PCollection.
*/
+ @Deprecated
public <T> Coder<T> getDefaultOutputCoder(
InputT input, @SuppressWarnings("unused") PCollection<T> output)
throws CannotProvideCoderException {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 0d03835..a0e1eb2 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -636,19 +636,21 @@
@Override
public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
- finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), input.getCoder());
+ CoderRegistry registry = input.getPipeline().getCoderRegistry();
+ finishSpecifyingStateSpecs(fn, registry, input.getCoder());
TupleTag<OutputT> mainOutput = new TupleTag<>();
- return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
- }
-
- @Override
- @SuppressWarnings("unchecked")
- protected Coder<OutputT> getDefaultOutputCoder(PCollection<? extends InputT> input)
- throws CannotProvideCoderException {
- return input.getPipeline().getCoderRegistry().getCoder(
- getFn().getOutputTypeDescriptor(),
- getFn().getInputTypeDescriptor(),
- ((PCollection<InputT>) input).getCoder());
+ PCollection<OutputT> res =
+ input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
+ try {
+ res.setCoder(
+ registry.getCoder(
+ getFn().getOutputTypeDescriptor(),
+ getFn().getInputTypeDescriptor(),
+ ((PCollection<InputT>) input).getCoder()));
+ } catch (CannotProvideCoderException e) {
+ // Ignore and leave coder unset.
+ }
+ return res;
}
@Override
@@ -757,7 +759,8 @@
validateWindowType(input, fn);
// Use coder registry to determine coders for all StateSpec defined in the fn signature.
- finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), input.getCoder());
+ CoderRegistry registry = input.getPipeline().getCoderRegistry();
+ finishSpecifyingStateSpecs(fn, registry, input.getCoder());
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
if (signature.usesState() || signature.usesTimers()) {
@@ -767,8 +770,22 @@
PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
input.getPipeline(),
TupleTagList.of(mainOutputTag).and(additionalOutputTags.getAll()),
+ // TODO: pass the output coders here instead of setting them on each output below.
+ Collections.<TupleTag<?>, Coder<?>>emptyMap(),
input.getWindowingStrategy(),
input.isBounded());
+ @SuppressWarnings("unchecked")
+ Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
+ for (PCollection<?> out : outputs.getAll().values()) {
+ try {
+ out.setCoder(
+ (Coder)
+ registry.getCoder(
+ out.getTypeDescriptor(), getFn().getInputTypeDescriptor(), inputCoder));
+ } catch (CannotProvideCoderException e) {
+ // Ignore and let coder inference happen later.
+ }
+ }
// The fn will likely be an instance of an anonymous subclass
// such as DoFn<Integer, String> { }, thus will have a high-fidelity
@@ -779,24 +796,6 @@
}
@Override
- protected Coder<OutputT> getDefaultOutputCoder() {
- throw new RuntimeException(
- "internal error: shouldn't be calling this on a multi-output ParDo");
- }
-
- @Override
- public <T> Coder<T> getDefaultOutputCoder(
- PCollection<? extends InputT> input, PCollection<T> output)
- throws CannotProvideCoderException {
- @SuppressWarnings("unchecked")
- Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
- return input.getPipeline().getCoderRegistry().getCoder(
- output.getTypeDescriptor(),
- getFn().getInputTypeDescriptor(),
- inputCoder);
- }
-
- @Override
protected String getKindString() {
return String.format("ParMultiDo(%s)", NameUtils.approximateSimpleName(getFn()));
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
index c94fad6..f6f3af5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
@@ -229,7 +229,7 @@
* <pre>
* {@code
* PCollection<KV<K, V>> input = ... // maybe more than one occurrence of a some keys
- * PCollectionView<Map<K, V>> output = input.apply(View.<K, V>asMultimap());
+ * PCollectionView<Map<K, Iterable<V>>> output = input.apply(View.<K, V>asMultimap());
* }</pre>
*
* <p>Currently, the resulting map is required to fit into memory.
@@ -509,9 +509,8 @@
@Override
public PCollection<ElemT> expand(PCollection<ElemT> input) {
- return PCollection.<ElemT>createPrimitiveOutputInternal(
- input.getPipeline(), input.getWindowingStrategy(), input.isBounded())
- .setCoder(input.getCoder());
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded(), input.getCoder());
}
}
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
new file mode 100644
index 0000000..9da2408
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
@@ -0,0 +1,1010 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Ordering;
+import com.google.common.hash.Funnel;
+import com.google.common.hash.Funnels;
+import com.google.common.hash.HashCode;
+import com.google.common.hash.Hashing;
+import com.google.common.hash.PrimitiveSink;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DurationCoder;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.MapCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.joda.time.ReadableDuration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Given a "poll function" that produces a potentially growing set of outputs for an input, this
+ * transform continuously and simultaneously watches the growth of the output sets of all inputs,
+ * until a per-input termination condition is reached.
+ *
+ * <p>The output is returned as an unbounded {@link PCollection} of {@code KV<InputT, OutputT>},
+ * where each {@code OutputT} is associated with the {@code InputT} that produced it, and is
+ * assigned the timestamp that the poll function returned when this output was detected for the
+ * first time.
+ *
+ * <p>Hypothetical usage example for watching new files in a collection of directories, where for
+ * each directory we assume that new files will not appear if the directory contains a file named
+ * ".complete":
+ *
+ * <pre>{@code
+ * PCollection<String> directories = ...; // E.g. Create.of(single directory)
+ * PCollection<KV<String, String>> matches = directories.apply(Watch.<String, String>growthOf(
+ * new PollFn<String, String>() {
+ * public PollResult<String> apply(String directory, Instant timestamp) {
+ * // "timestamp" is not used in this example.
+ * List<TimestampedValue<String>> outputs = new ArrayList<>();
+ * ... List the directory and get creation times of all files ...
+ * boolean isComplete = ... does a file ".complete" exist in the directory ...
+ * return isComplete ? PollResult.complete(outputs) : PollResult.incomplete(outputs);
+ * }
+ * })
+ * // Poll each directory every 5 seconds
+ * .withPollInterval(Duration.standardSeconds(5))
+ * // Stop watching each directory 12 hours after it's seen even if it's incomplete
+ * .withTerminationPerInput(afterTotalOf(Duration.standardHours(12))));
+ * }</pre>
+ *
+ * <p>By default, the watermark for a particular input is computed from a poll result as "earliest
+ * timestamp of new elements in this poll result". It can also be set explicitly via {@link
+ * Growth.PollResult#withWatermark} if the {@link Growth.PollFn} can provide a more optimistic
+ * estimate.
+ *
+ * <p>Note: This transform works only in runners supporting Splittable DoFn: see <a
+ * href="https://beam.apache.org/documentation/runners/capability-matrix/">capability matrix</a>.
+ */
+@Experimental(Experimental.Kind.SPLITTABLE_DO_FN)
+public class Watch {
+ private static final Logger LOG = LoggerFactory.getLogger(Watch.class);
+
+ /** Watches the growth of the given poll function. See class documentation for more details. */
+ public static <InputT, OutputT> Growth<InputT, OutputT> growthOf(
+ Growth.PollFn<InputT, OutputT> pollFn) {
+ return new AutoValue_Watch_Growth.Builder<InputT, OutputT>()
+ .setTerminationPerInput(Watch.Growth.<InputT>never())
+ .setPollFn(pollFn)
+ .build();
+ }
+
+ /** Implementation of {@link #growthOf}. */
+ @AutoValue
+ public abstract static class Growth<InputT, OutputT>
+ extends PTransform<PCollection<InputT>, PCollection<KV<InputT, OutputT>>> {
+ /** The result of a single invocation of a {@link PollFn}. */
+ public static final class PollResult<OutputT> {
+ private final List<TimestampedValue<OutputT>> outputs;
+ // null means unspecified (infer automatically).
+ @Nullable private final Instant watermark;
+
+ private PollResult(List<TimestampedValue<OutputT>> outputs, @Nullable Instant watermark) {
+ this.outputs = outputs;
+ this.watermark = watermark;
+ }
+
+ List<TimestampedValue<OutputT>> getOutputs() {
+ return outputs;
+ }
+
+ @Nullable
+ Instant getWatermark() {
+ return watermark;
+ }
+
+ /**
+ * Sets the watermark - an approximate lower bound on timestamps of future new outputs from
+ * this {@link PollFn}.
+ */
+ public PollResult<OutputT> withWatermark(Instant watermark) {
+ checkNotNull(watermark, "watermark");
+ return new PollResult<>(outputs, watermark);
+ }
+
+ /**
+ * Constructs a {@link PollResult} with the given outputs and declares that there will be no
+ * new outputs for the current input. The {@link PollFn} will not be called again for this
+ * input.
+ */
+ public static <OutputT> PollResult<OutputT> complete(
+ List<TimestampedValue<OutputT>> outputs) {
+ return new PollResult<>(outputs, BoundedWindow.TIMESTAMP_MAX_VALUE);
+ }
+
+ /** Like {@link #complete(List)}, but assigns the same timestamp to all new outputs. */
+ public static <OutputT> PollResult<OutputT> complete(
+ Instant timestamp, List<OutputT> outputs) {
+ return new PollResult<>(
+ addTimestamp(timestamp, outputs), BoundedWindow.TIMESTAMP_MAX_VALUE);
+ }
+
+ /**
+ * Constructs a {@link PollResult} with the given outputs and declares that new outputs might
+ * appear for the current input. By default, {@link Watch} will estimate the watermark for
+ * future new outputs as equal to the earliest of the new outputs from this {@link
+ * PollResult}. To specify a more exact watermark, use {@link #withWatermark(Instant)}.
+ */
+ public static <OutputT> PollResult<OutputT> incomplete(
+ List<TimestampedValue<OutputT>> outputs) {
+ return new PollResult<>(outputs, null);
+ }
+
+ /** Like {@link #incomplete(List)}, but assigns the same timestamp to all new outputs. */
+ public static <OutputT> PollResult<OutputT> incomplete(
+ Instant timestamp, List<OutputT> outputs) {
+ return new PollResult<>(addTimestamp(timestamp, outputs), null);
+ }
+
+ private static <OutputT> List<TimestampedValue<OutputT>> addTimestamp(
+ Instant timestamp, List<OutputT> outputs) {
+ List<TimestampedValue<OutputT>> res = Lists.newArrayListWithExpectedSize(outputs.size());
+ for (OutputT output : outputs) {
+ res.add(TimestampedValue.of(output, timestamp));
+ }
+ return res;
+ }
+ }
+
+ /**
+ * A function that computes the current set of outputs for the given input and its timestamp,
+ * in the form of a {@link PollResult}.
+ *
+ * <p>It is held (via the {@link Growth} spec) by the {@link DoFn} that executes the transform,
+ * hence it must be {@link Serializable}.
+ */
+ public interface PollFn<InputT, OutputT> extends Serializable {
+ /**
+ * Polls for the outputs currently available for {@code input}.
+ *
+ * @param input the element being watched
+ * @param timestamp the timestamp of {@code input}, as seen by the processing DoFn
+ * @return the outputs observed by this poll, optionally with a watermark or a completeness
+ *     declaration
+ * @throws Exception any failure propagates out of the processing of {@code input}
+ */
+ PollResult<OutputT> apply(InputT input, Instant timestamp) throws Exception;
+ }
+
+ /**
+ * A strategy for determining whether it is time to stop polling the current input regardless of
+ * whether its output is complete or not.
+ *
+ * <p>Some built-in termination conditions are {@link #never}, {@link #afterTotalOf} and {@link
+ * #afterTimeSinceNewOutput}. Conditions can be combined using {@link #eitherOf} and {@link
+ * #allOf}. Users can also develop custom termination conditions, for example, one might imagine
+ * a condition that terminates after a given time after the first output appears for the input
+ * (unlike {@link #afterTotalOf} which operates relative to when the input itself arrives).
+ *
+ * <p>A {@link TerminationCondition} is provided to {@link
+ * Growth#withTerminationPerInput(TerminationCondition)} and is used to maintain an independent
+ * state of the termination condition for every input, represented as {@code StateT} which must
+ * be immutable, non-null, and encodable via {@link #getStateCoder()}.
+ *
+ * <p>All functions take the wall-clock timestamp as {@link Instant} for convenience of
+ * unit-testing custom termination conditions; the transform passes {@code Instant.now()}.
+ */
+ public interface TerminationCondition<InputT, StateT> extends Serializable {
+ /** Used to encode the state of this {@link TerminationCondition}. */
+ Coder<StateT> getStateCoder();
+
+ /**
+ * Called by the {@link Watch} transform to create a new independent termination state for a
+ * newly arrived {@code InputT}.
+ *
+ * @param now the current wall-clock time
+ * @param input the newly arrived input
+ * @return the initial state for {@code input}
+ */
+ StateT forNewInput(Instant now, InputT input);
+
+ /**
+ * Called by the {@link Watch} transform to compute a new termination state, in case after
+ * calling the {@link PollFn} for the current input, the {@link PollResult} included a
+ * previously unseen {@code OutputT}.
+ *
+ * @param now the current wall-clock time
+ * @param state the current state for the input
+ * @return the updated state; must not modify {@code state} in place
+ */
+ StateT onSeenNewOutput(Instant now, StateT state);
+
+ /**
+ * Called by the {@link Watch} transform to determine whether the given termination state
+ * signals that {@link Watch} should stop calling {@link PollFn} for the current input,
+ * regardless of whether the last {@link PollResult} was complete or incomplete.
+ *
+ * @param now the current wall-clock time
+ * @param state the current state for the input
+ * @return {@code true} if polling of the input may stop
+ */
+ boolean canStopPolling(Instant now, StateT state);
+
+ /**
+ * Creates a human-readable representation of the given state of this condition.
+ *
+ * <p>NOTE(review): the transform may call this with a {@code null} state for a restriction
+ * whose output is already complete; implementations should tolerate that.
+ */
+ String toString(StateT state);
+ }
+
+ /**
+ * Returns a {@link TerminationCondition} that never holds (i.e., poll each input until its
+ * output is complete).
+ */
+ public static <InputT> Never<InputT> never() {
+ return new Never<>();
+ }
+
+ /**
+ * Returns a {@link TerminationCondition} that holds after the given time has elapsed after the
+ * current input was seen.
+ */
+ public static <InputT> AfterTotalOf<InputT> afterTotalOf(ReadableDuration timeSinceInput) {
+ return afterTotalOf(SerializableFunctions.<InputT, ReadableDuration>constant(timeSinceInput));
+ }
+
+ /** Like {@link #afterTotalOf(ReadableDuration)}, but the duration is input-dependent. */
+ public static <InputT> AfterTotalOf<InputT> afterTotalOf(
+ SerializableFunction<InputT, ReadableDuration> timeSinceInput) {
+ return new AfterTotalOf<>(timeSinceInput);
+ }
+
+ /**
+ * Returns a {@link TerminationCondition} that holds after the given time has elapsed after the
+ * last time the {@link PollResult} for the current input contained a previously unseen output.
+ */
+ public static <InputT> AfterTimeSinceNewOutput<InputT> afterTimeSinceNewOutput(
+ ReadableDuration timeSinceNewOutput) {
+ return afterTimeSinceNewOutput(
+ SerializableFunctions.<InputT, ReadableDuration>constant(timeSinceNewOutput));
+ }
+
+ /**
+ * Like {@link #afterTimeSinceNewOutput(ReadableDuration)}, but the duration is input-dependent.
+ */
+ public static <InputT> AfterTimeSinceNewOutput<InputT> afterTimeSinceNewOutput(
+ SerializableFunction<InputT, ReadableDuration> timeSinceNewOutput) {
+ return new AfterTimeSinceNewOutput<>(timeSinceNewOutput);
+ }
+
+ /**
+ * Returns a {@link TerminationCondition} that holds when at least one of the given two
+ * conditions holds.
+ */
+ public static <InputT, FirstStateT, SecondStateT>
+ BinaryCombined<InputT, FirstStateT, SecondStateT> eitherOf(
+ TerminationCondition<InputT, FirstStateT> first,
+ TerminationCondition<InputT, SecondStateT> second) {
+ return new BinaryCombined<>(BinaryCombined.Operation.OR, first, second);
+ }
+
+ /**
+ * Returns a {@link TerminationCondition} that holds when both of the given two conditions hold.
+ */
+ public static <InputT, FirstStateT, SecondStateT>
+ BinaryCombined<InputT, FirstStateT, SecondStateT> allOf(
+ TerminationCondition<InputT, FirstStateT> first,
+ TerminationCondition<InputT, SecondStateT> second) {
+ return new BinaryCombined<>(BinaryCombined.Operation.AND, first, second);
+ }
+
+ /**
+  * A {@link TerminationCondition} that never allows polling to stop early, so each input is
+  * polled until its output is complete.
+  *
+  * <p>Uses {@link Integer} rather than {@link Void} for state, because termination state must be
+  * non-null; the state is always {@code 0}.
+  */
+ static class Never<InputT> implements TerminationCondition<InputT, Integer> {
+   @Override
+   public Coder<Integer> getStateCoder() {
+     return VarIntCoder.of();
+   }
+
+   @Override
+   public Integer forNewInput(Instant now, InputT input) {
+     return Integer.valueOf(0);
+   }
+
+   @Override
+   public Integer onSeenNewOutput(Instant now, Integer state) {
+     // New outputs are irrelevant to this condition: keep the state unchanged.
+     return state;
+   }
+
+   @Override
+   public boolean canStopPolling(Instant now, Integer state) {
+     // Only completion of the output may stop polling.
+     return false;
+   }
+
+   @Override
+   public String toString(Integer state) {
+     return "Never";
+   }
+ }
+
+ static class AfterTotalOf<InputT>
+ implements TerminationCondition<
+ InputT, KV<Instant /* timeStarted */, ReadableDuration /* maxTimeSinceInput */>> {
+ private final SerializableFunction<InputT, ReadableDuration> maxTimeSinceInput;
+
+ private AfterTotalOf(SerializableFunction<InputT, ReadableDuration> maxTimeSinceInput) {
+ this.maxTimeSinceInput = maxTimeSinceInput;
+ }
+
+ @Override
+ public Coder<KV<Instant, ReadableDuration>> getStateCoder() {
+ return KvCoder.of(InstantCoder.of(), DurationCoder.of());
+ }
+
+ @Override
+ public KV<Instant, ReadableDuration> forNewInput(Instant now, InputT input) {
+ return KV.of(now, maxTimeSinceInput.apply(input));
+ }
+
+ @Override
+ public KV<Instant, ReadableDuration> onSeenNewOutput(
+ Instant now, KV<Instant, ReadableDuration> state) {
+ return state;
+ }
+
+ @Override
+ public boolean canStopPolling(Instant now, KV<Instant, ReadableDuration> state) {
+ return new Duration(state.getKey(), now).isLongerThan(state.getValue());
+ }
+
+ @Override
+ public String toString(KV<Instant, ReadableDuration> state) {
+ return "AfterTotalOf{"
+ + "timeStarted="
+ + state.getKey()
+ + ", maxTimeSinceInput="
+ + state.getValue()
+ + '}';
+ }
+ }
+
+ /**
+  * A {@link TerminationCondition} that holds once a per-input maximum duration has elapsed since
+  * the last previously unseen output was observed. The state is the pair (time of the last new
+  * output, or {@code null} if none yet, maximum duration).
+  */
+ static class AfterTimeSinceNewOutput<InputT>
+     implements TerminationCondition<
+         InputT,
+         KV<Instant /* timeOfLastNewOutput */, ReadableDuration /* maxTimeSinceNewOutput */>> {
+   private final SerializableFunction<InputT, ReadableDuration> maxTimeSinceNewOutput;
+
+   private AfterTimeSinceNewOutput(
+       SerializableFunction<InputT, ReadableDuration> maxTimeSinceNewOutput) {
+     this.maxTimeSinceNewOutput = maxTimeSinceNewOutput;
+   }
+
+   @Override
+   public Coder<KV<Instant, ReadableDuration>> getStateCoder() {
+     // The key is null until the first new output is seen.
+     Coder<Instant> lastOutputCoder = NullableCoder.of(InstantCoder.of());
+     Coder<ReadableDuration> limitCoder = DurationCoder.of();
+     return KvCoder.of(lastOutputCoder, limitCoder);
+   }
+
+   @Override
+   public KV<Instant, ReadableDuration> forNewInput(Instant now, InputT input) {
+     Instant noNewOutputYet = null;
+     return KV.of(noNewOutputYet, maxTimeSinceNewOutput.apply(input));
+   }
+
+   @Override
+   public KV<Instant, ReadableDuration> onSeenNewOutput(
+       Instant now, KV<Instant, ReadableDuration> state) {
+     ReadableDuration limit = state.getValue();
+     return KV.of(now, limit);
+   }
+
+   @Override
+   public boolean canStopPolling(Instant now, KV<Instant, ReadableDuration> state) {
+     Instant lastNewOutput = state.getKey();
+     if (lastNewOutput == null) {
+       // The countdown only starts once some new output has been seen.
+       return false;
+     }
+     return new Duration(lastNewOutput, now).isLongerThan(state.getValue());
+   }
+
+   @Override
+   public String toString(KV<Instant, ReadableDuration> state) {
+     return "AfterTimeSinceNewOutput{timeOfLastNewOutput="
+         + state.getKey()
+         + ", maxTimeSinceNewOutput="
+         + state.getValue()
+         + "}";
+   }
+ }
+
+ static class BinaryCombined<InputT, FirstStateT, SecondStateT>
+ implements TerminationCondition<InputT, KV<FirstStateT, SecondStateT>> {
+ private enum Operation {
+ OR,
+ AND
+ }
+
+ private final Operation operation;
+ private final TerminationCondition<InputT, FirstStateT> first;
+ private final TerminationCondition<InputT, SecondStateT> second;
+
+ public BinaryCombined(
+ Operation operation,
+ TerminationCondition<InputT, FirstStateT> first,
+ TerminationCondition<InputT, SecondStateT> second) {
+ this.operation = operation;
+ this.first = first;
+ this.second = second;
+ }
+
+ @Override
+ public Coder<KV<FirstStateT, SecondStateT>> getStateCoder() {
+ return KvCoder.of(first.getStateCoder(), second.getStateCoder());
+ }
+
+ @Override
+ public KV<FirstStateT, SecondStateT> forNewInput(Instant now, InputT input) {
+ return KV.of(first.forNewInput(now, input), second.forNewInput(now, input));
+ }
+
+ @Override
+ public KV<FirstStateT, SecondStateT> onSeenNewOutput(
+ Instant now, KV<FirstStateT, SecondStateT> state) {
+ return KV.of(
+ first.onSeenNewOutput(now, state.getKey()),
+ second.onSeenNewOutput(now, state.getValue()));
+ }
+
+ @Override
+ public boolean canStopPolling(Instant now, KV<FirstStateT, SecondStateT> state) {
+ switch (operation) {
+ case OR:
+ return first.canStopPolling(now, state.getKey())
+ || second.canStopPolling(now, state.getValue());
+ case AND:
+ return first.canStopPolling(now, state.getKey())
+ && second.canStopPolling(now, state.getValue());
+ default:
+ throw new UnsupportedOperationException("Unexpected operation " + operation);
+ }
+ }
+
+ @Override
+ public String toString(KV<FirstStateT, SecondStateT> state) {
+ return operation
+ + "{first="
+ + first.toString(state.getKey())
+ + ", second="
+ + second.toString(state.getValue())
+ + '}';
+ }
+ }
+
+ // The function called to poll each input.
+ abstract PollFn<InputT, OutputT> getPollFn();
+
+ // Delay between successive polls of the same input. May be null here, but expand() requires it
+ // to be set (e.g. via withPollInterval()).
+ @Nullable
+ abstract Duration getPollInterval();
+
+ // Condition deciding when to stop polling an input early. May be null here, but expand()
+ // requires it to be set (e.g. via withTerminationPerInput()).
+ @Nullable
+ abstract TerminationCondition<InputT, ?> getTerminationPerInput();
+
+ // Coder for the outputs; null means expand() infers it from the PollFn's OutputT parameter.
+ @Nullable
+ abstract Coder<OutputT> getOutputCoder();
+
+ abstract Builder<InputT, OutputT> toBuilder();
+
+ // AutoValue builder used by the with*() methods to derive modified copies of this transform.
+ @AutoValue.Builder
+ abstract static class Builder<InputT, OutputT> {
+ abstract Builder<InputT, OutputT> setPollFn(PollFn<InputT, OutputT> pollFn);
+
+ abstract Builder<InputT, OutputT> setTerminationPerInput(
+ TerminationCondition<InputT, ?> terminationPerInput);
+
+ abstract Builder<InputT, OutputT> setPollInterval(Duration pollInterval);
+
+ abstract Builder<InputT, OutputT> setOutputCoder(Coder<OutputT> outputCoder);
+
+ abstract Growth<InputT, OutputT> build();
+ }
+
+ /** Specifies a {@link TerminationCondition} that will be independently used for every input. */
+ public Growth<InputT, OutputT> withTerminationPerInput(
+ TerminationCondition<InputT, ?> terminationPerInput) {
+ return toBuilder().setTerminationPerInput(terminationPerInput).build();
+ }
+
+ /**
+ * Specifies how long to wait after a call to {@link PollFn} before calling it again (if at all
+ * - according to {@link PollResult} and the {@link TerminationCondition}).
+ */
+ public Growth<InputT, OutputT> withPollInterval(Duration pollInterval) {
+ return toBuilder().setPollInterval(pollInterval).build();
+ }
+
+ /**
+ * Specifies a {@link Coder} to use for the outputs. If unspecified, it will be inferred from
+ * the output type of {@link PollFn} whenever possible.
+ *
+ * <p>The coder must be deterministic, because the transform will compare encoded outputs for
+ * deduplication between polling rounds.
+ */
+ public Growth<InputT, OutputT> withOutputCoder(Coder<OutputT> outputCoder) {
+ return toBuilder().setOutputCoder(outputCoder).build();
+ }
+
+ /**
+  * Expands into a splittable {@link ParDo} that polls each input and emits
+  * {@code KV<input, output>} for every previously unseen output.
+  *
+  * @throws NullPointerException if the poll interval or the termination condition is unset
+  * @throws RuntimeException if no output coder was given and none can be inferred
+  * @throws IllegalArgumentException if the output coder is not deterministic
+  */
+ @Override
+ public PCollection<KV<InputT, OutputT>> expand(PCollection<InputT> input) {
+   checkNotNull(getPollInterval(), "pollInterval");
+   checkNotNull(getTerminationPerInput(), "terminationPerInput");
+
+   Coder<OutputT> outputCoder = getOutputCoder();
+   if (outputCoder == null) {
+     // If a coder was not specified explicitly, infer it from the OutputT type parameter
+     // of the PollFn.
+     TypeDescriptor<OutputT> outputT =
+         TypeDescriptors.extractFromTypeParameters(
+             getPollFn(),
+             PollFn.class,
+             new TypeVariableExtractor<PollFn<InputT, OutputT>, OutputT>() {});
+     try {
+       outputCoder = input.getPipeline().getCoderRegistry().getCoder(outputT);
+     } catch (CannotProvideCoderException e) {
+       // Keep the cause: it explains why the registry could not provide a coder.
+       throw new RuntimeException(
+           "Unable to infer coder for OutputT. Specify it explicitly using withOutputCoder().",
+           e);
+     }
+   }
+   // Outputs are deduplicated by comparing their encodings, which requires determinism.
+   try {
+     outputCoder.verifyDeterministic();
+   } catch (Coder.NonDeterministicException e) {
+     // Keep the cause: it names the component coder that is non-deterministic.
+     throw new IllegalArgumentException(
+         "Output coder " + outputCoder + " must be deterministic", e);
+   }
+
+   return input
+       .apply(ParDo.of(new WatchGrowthFn<>(this, outputCoder)))
+       .setCoder(KvCoder.of(input.getCoder(), outputCoder));
+ }
+ }
+
+ /**
+ * Splittable {@link DoFn} that implements {@link Growth}: for each input it calls the
+ * {@link PollFn}, emits every previously unseen output as {@code KV<input, output>}, and
+ * resumes after the poll interval while more polling is allowed. Polling progress for an input
+ * is kept in a {@link GrowthState} restriction, tracked by a {@link GrowthTracker}.
+ */
+ private static class WatchGrowthFn<InputT, OutputT, TerminationStateT>
+ extends DoFn<InputT, KV<InputT, OutputT>> {
+ private final Watch.Growth<InputT, OutputT> spec;
+ // Used by the tracker to hash outputs and by the restriction coder to encode pending outputs.
+ private final Coder<OutputT> outputCoder;
+
+ private WatchGrowthFn(Growth<InputT, OutputT> spec, Coder<OutputT> outputCoder) {
+ this.spec = spec;
+ this.outputCoder = outputCoder;
+ }
+
+ /**
+ * Processes one input. Polls only if nothing is left pending from a previous poll and the
+ * output is not yet known to be complete, then emits the pending outputs one at a time.
+ *
+ * @return {@code stop()} if the tracker was checkpointed or no more polling is allowed;
+ *     otherwise {@code resume()} delayed by the poll interval, so the input is polled again
+ */
+ @ProcessElement
+ public ProcessContinuation process(
+ ProcessContext c, final GrowthTracker<OutputT, TerminationStateT> tracker)
+ throws Exception {
+ if (!tracker.hasPending() && !tracker.currentRestriction().isOutputComplete) {
+ LOG.debug("{} - polling input", c.element());
+ Growth.PollResult<OutputT> res = spec.getPollFn().apply(c.element(), c.timestamp());
+ // TODO (https://issues.apache.org/jira/browse/BEAM-2680):
+ // Consider truncating the pending outputs if there are too many, to avoid blowing
+ // up the state. In that case, we'd rely on the next poll cycle to provide more outputs.
+ // All outputs would still have to be stored in state.completed, but it is more compact
+ // because it stores hashes and because it could potentially be garbage-collected.
+ int numPending = tracker.addNewAsPending(res);
+ if (numPending > 0) {
+ LOG.info(
+ "{} - polling returned {} results, of which {} were new. The output is {}.",
+ c.element(),
+ res.getOutputs().size(),
+ numPending,
+ BoundedWindow.TIMESTAMP_MAX_VALUE.equals(res.getWatermark())
+ ? "complete"
+ : "incomplete");
+ }
+ }
+ // Before each output, report the watermark: the earlier of the poll watermark and the
+ // earliest remaining pending timestamp.
+ while (tracker.hasPending()) {
+ c.updateWatermark(tracker.getWatermark());
+
+ TimestampedValue<OutputT> nextPending = tracker.tryClaimNextPending();
+ if (nextPending == null) {
+ // The tracker was checkpointed; the unclaimed outputs now belong to the residual.
+ return stop();
+ }
+ c.outputWithTimestamp(
+ KV.of(c.element(), nextPending.getValue()), nextPending.getTimestamp());
+ }
+ Instant watermark = tracker.getWatermark();
+ if (watermark != null) {
+ // Null means the poll result did not provide a watermark and there were no new elements,
+ // so we have no information to update the watermark and should keep it as-is.
+ c.updateWatermark(watermark);
+ }
+ // No more pending outputs - future output will come from more polling,
+ // unless output is complete or termination condition is reached.
+ if (tracker.shouldPollMore()) {
+ return resume().withResumeDelay(spec.getPollInterval());
+ }
+ return stop();
+ }
+
+ // Unchecked cast: the spec stores the condition with a wildcard state type; this DoFn's
+ // TerminationStateT is assumed to be that state type.
+ private Growth.TerminationCondition<InputT, TerminationStateT> getTerminationCondition() {
+ return ((Growth.TerminationCondition<InputT, TerminationStateT>)
+ spec.getTerminationPerInput());
+ }
+
+ // The termination state starts from the wall-clock time at which the restriction is created.
+ @GetInitialRestriction
+ public GrowthState<OutputT, TerminationStateT> getInitialRestriction(InputT element) {
+ return new GrowthState<>(getTerminationCondition().forNewInput(Instant.now(), element));
+ }
+
+ @NewTracker
+ public GrowthTracker<OutputT, TerminationStateT> newTracker(
+ GrowthState<OutputT, TerminationStateT> restriction) {
+ return new GrowthTracker<>(outputCoder, restriction, getTerminationCondition());
+ }
+
+ // Raw cast bridges the wildcard state type of the spec's condition to TerminationStateT.
+ @GetRestrictionCoder
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public Coder<GrowthState<OutputT, TerminationStateT>> getRestrictionCoder() {
+ return GrowthStateCoder.of(
+ outputCoder, (Coder) spec.getTerminationPerInput().getStateCoder());
+ }
+ }
+
+ /**
+  * The restriction processed by {@code WatchGrowthFn} for a single input: which outputs were
+  * already emitted, which are known but not yet emitted, whether polling is finished, the state
+  * of the termination condition, and the watermark of future polls.
+  *
+  * <p>Fields are final and collections are exposed as unmodifiable views.
+  */
+ @VisibleForTesting
+ static class GrowthState<OutputT, TerminationStateT> {
+   // Hashes and timestamps of outputs that have already been output and should be omitted
+   // from future polls. Timestamps are preserved to allow garbage-collecting this state
+   // in the future, e.g. dropping elements from "completed" and from addNewAsPending() if their
+   // timestamp is more than X behind the watermark.
+   // As of writing, we don't do this, but preserve the information for forward compatibility
+   // in case of pipeline update. TODO: do this.
+   private final Map<HashCode, Instant> completed;
+   // Outputs that are known to be present in a poll result, but have not yet been returned
+   // from a ProcessElement call, sorted by timestamp to help smooth watermark progress.
+   private final List<TimestampedValue<OutputT>> pending;
+   // If true, processing of this restriction should only output "pending". Otherwise, it should
+   // also continue polling.
+   private final boolean isOutputComplete;
+   // Can be null only if isOutputComplete is true.
+   @Nullable private final TerminationStateT terminationState;
+   // A lower bound on timestamps of future outputs from PollFn, excluding completed and pending.
+   @Nullable private final Instant pollWatermark;
+
+   /** Creates the initial restriction for a newly arrived input: nothing seen, nothing pending. */
+   GrowthState(TerminationStateT terminationState) {
+     this.completed = Collections.emptyMap();
+     this.pending = Collections.emptyList();
+     this.isOutputComplete = false;
+     this.terminationState = checkNotNull(terminationState);
+     this.pollWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
+   }
+
+   /**
+    * Creates a restriction from all of its parts.
+    *
+    * @throws NullPointerException if {@code terminationState} is null while the output is not
+    *     complete
+    */
+   GrowthState(
+       Map<HashCode, Instant> completed,
+       List<TimestampedValue<OutputT>> pending,
+       boolean isOutputComplete,
+       @Nullable TerminationStateT terminationState,
+       @Nullable Instant pollWatermark) {
+     if (!isOutputComplete) {
+       checkNotNull(terminationState);
+     }
+     this.completed = Collections.unmodifiableMap(completed);
+     this.pending = Collections.unmodifiableList(pending);
+     this.isOutputComplete = isOutputComplete;
+     this.terminationState = terminationState;
+     this.pollWatermark = pollWatermark;
+   }
+
+   /** Renders this state, using {@code terminationCondition} to render the termination state. */
+   public String toString(Growth.TerminationCondition<?, TerminationStateT> terminationCondition) {
+     // terminationState is null for complete restrictions (e.g. the primary of a checkpoint).
+     // Condition-specific renderers such as AfterTotalOf dereference the state, so don't hand
+     // them null.
+     String terminationStateString =
+         (terminationState == null) ? "null" : terminationCondition.toString(terminationState);
+     return "GrowthState{"
+         + "completed=<"
+         + completed.size()
+         + " elements>, pending=<"
+         + pending.size()
+         + " elements"
+         + (pending.isEmpty() ? "" : (", earliest " + pending.get(0)))
+         + ">, isOutputComplete="
+         + isOutputComplete
+         + ", terminationState="
+         + terminationStateString
+         + ", pollWatermark="
+         + pollWatermark
+         + '}';
+   }
+ }
+
+ /**
+  * {@link RestrictionTracker} for {@link GrowthState}.
+  *
+  * <p>Tracks which pending outputs the current {@code ProcessElement} call has claimed and, in
+  * {@link #checkpoint()}, splits the restriction into the claimed work (primary) and everything
+  * else (residual). Methods touching mutable state are {@code synchronized} on the tracker.
+  */
+ @VisibleForTesting
+ static class GrowthTracker<OutputT, TerminationStateT>
+     implements RestrictionTracker<GrowthState<OutputT, TerminationStateT>> {
+   // Feeds the encoded form of an output into a hasher, so outputs are identified by encoding.
+   private final Funnel<OutputT> coderFunnel;
+   private final Growth.TerminationCondition<?, TerminationStateT> terminationCondition;
+
+   // The restriction describing the entire work to be done by the current ProcessElement call.
+   // Changes only in checkpoint().
+   private GrowthState<OutputT, TerminationStateT> state;
+
+   // Mutable state changed by the ProcessElement call itself, and used to compute the primary
+   // and residual restrictions in checkpoint().
+
+   // Remaining pending outputs; initialized from state.pending (if non-empty) or in
+   // addNewAsPending(); drained via tryClaimNextPending().
+   private LinkedList<TimestampedValue<OutputT>> pending;
+   // Outputs that have been claimed in the current ProcessElement call. A prefix of "pending".
+   private List<TimestampedValue<OutputT>> claimed = Lists.newArrayList();
+   private boolean isOutputComplete;
+   private TerminationStateT terminationState;
+   @Nullable private Instant pollWatermark;
+   // Set by checkpoint(); afterwards no more outputs can be claimed.
+   private boolean shouldStop = false;
+
+   GrowthTracker(final Coder<OutputT> outputCoder, GrowthState<OutputT, TerminationStateT> state,
+       Growth.TerminationCondition<?, TerminationStateT> terminationCondition) {
+     this.coderFunnel =
+         new Funnel<OutputT>() {
+           @Override
+           public void funnel(OutputT from, PrimitiveSink into) {
+             try {
+               outputCoder.encode(from, Funnels.asOutputStream(into));
+             } catch (IOException e) {
+               throw new RuntimeException(e);
+             }
+           }
+         };
+     this.terminationCondition = terminationCondition;
+     this.state = state;
+     this.isOutputComplete = state.isOutputComplete;
+     this.pollWatermark = state.pollWatermark;
+     this.terminationState = state.terminationState;
+     this.pending = Lists.newLinkedList(state.pending);
+   }
+
+   @Override
+   public synchronized GrowthState<OutputT, TerminationStateT> currentRestriction() {
+     return state;
+   }
+
+   @Override
+   public synchronized GrowthState<OutputT, TerminationStateT> checkpoint() {
+     // primary should contain exactly the work claimed in the current ProcessElement call - i.e.
+     // claimed outputs become pending, and it shouldn't poll again.
+     GrowthState<OutputT, TerminationStateT> primary =
+         new GrowthState<>(
+             state.completed /* completed */,
+             claimed /* pending */,
+             true /* isOutputComplete */,
+             null /* terminationState */,
+             BoundedWindow.TIMESTAMP_MAX_VALUE /* pollWatermark */);
+
+     // residual should contain exactly the work *not* claimed in the current ProcessElement call -
+     // unclaimed pending outputs plus future polling outputs.
+     Map<HashCode, Instant> newCompleted = Maps.newHashMap(state.completed);
+     for (TimestampedValue<OutputT> claimedOutput : claimed) {
+       newCompleted.put(hash128(claimedOutput.getValue()), claimedOutput.getTimestamp());
+     }
+     GrowthState<OutputT, TerminationStateT> residual =
+         new GrowthState<>(
+             newCompleted /* completed */,
+             pending /* pending */,
+             isOutputComplete /* isOutputComplete */,
+             terminationState,
+             pollWatermark);
+
+     // Morph ourselves into primary, except for "pending" - the current call has already claimed
+     // everything from it.
+     this.state = primary;
+     this.isOutputComplete = primary.isOutputComplete;
+     this.pollWatermark = primary.pollWatermark;
+     this.terminationState = null;
+     this.pending = Lists.newLinkedList();
+
+     this.shouldStop = true;
+     return residual;
+   }
+
+   // 128-bit murmur3 hash of the output's encoding; identifies outputs across poll rounds.
+   private HashCode hash128(OutputT value) {
+     return Hashing.murmur3_128().hashObject(value, coderFunnel);
+   }
+
+   @Override
+   public synchronized void checkDone() throws IllegalStateException {
+     if (shouldStop) {
+       return;
+     }
+     checkState(!shouldPollMore(), "Polling is still allowed to continue");
+     checkState(pending.isEmpty(), "There are %s unclaimed pending outputs", pending.size());
+   }
+
+   @VisibleForTesting
+   synchronized boolean hasPending() {
+     return !pending.isEmpty();
+   }
+
+   /**
+    * Claims the earliest unclaimed pending output.
+    *
+    * @return the claimed output, or {@code null} if the tracker has been checkpointed
+    * @throws IllegalStateException if there are no unclaimed pending outputs
+    */
+   @VisibleForTesting
+   @Nullable
+   synchronized TimestampedValue<OutputT> tryClaimNextPending() {
+     if (shouldStop) {
+       return null;
+     }
+     checkState(!pending.isEmpty(), "No more unclaimed pending outputs");
+     TimestampedValue<OutputT> value = pending.removeFirst();
+     claimed.add(value);
+     return value;
+   }
+
+   /** Whether the output is incomplete and the termination condition still allows polling. */
+   @VisibleForTesting
+   synchronized boolean shouldPollMore() {
+     return !isOutputComplete
+         && !terminationCondition.canStopPolling(Instant.now(), terminationState);
+   }
+
+   /**
+    * Replaces the pending outputs with the outputs of {@code pollResult} that were not emitted
+    * before, sorted by timestamp, and updates the termination state and poll watermark.
+    *
+    * <p>An output appearing several times in the same poll result is added only once, with the
+    * timestamp of its first occurrence.
+    *
+    * @return the number of new pending outputs
+    * @throws IllegalStateException if the restriction started with pending outputs
+    */
+   @VisibleForTesting
+   synchronized int addNewAsPending(Growth.PollResult<OutputT> pollResult) {
+     checkState(
+         state.pending.isEmpty(),
+         "Should have drained all old pending outputs before adding new, "
+             + "but there are %s old pending outputs",
+         state.pending.size());
+     // Keyed by hash so that duplicates within this poll result are dropped, not emitted twice.
+     // Insertion order is kept so the stable sort below preserves the PollFn's order among
+     // outputs with equal timestamps.
+     Map<HashCode, TimestampedValue<OutputT>> newPending = Maps.newLinkedHashMap();
+     for (TimestampedValue<OutputT> output : pollResult.getOutputs()) {
+       OutputT value = output.getValue();
+       HashCode hash = hash128(value);
+       if (state.completed.containsKey(hash) || newPending.containsKey(hash)) {
+         continue;
+       }
+       // TODO (https://issues.apache.org/jira/browse/BEAM-2680):
+       // Consider adding only at most N pending elements and ignoring others,
+       // instead relying on future poll rounds to provide them, in order to avoid
+       // blowing up the state. Combined with garbage collection of GrowthState.completed,
+       // this would make the transform scalable to very large poll results.
+       newPending.put(hash, TimestampedValue.of(value, output.getTimestamp()));
+     }
+     if (!newPending.isEmpty()) {
+       terminationState = terminationCondition.onSeenNewOutput(Instant.now(), terminationState);
+     }
+     this.pending =
+         Lists.newLinkedList(
+             Ordering.natural()
+                 .onResultOf(
+                     new Function<TimestampedValue<OutputT>, Instant>() {
+                       @Override
+                       public Instant apply(TimestampedValue<OutputT> output) {
+                         return output.getTimestamp();
+                       }
+                     })
+                 .sortedCopy(newPending.values()));
+     // If poll result doesn't provide a watermark, assume that future new outputs may
+     // arrive with about the same timestamps as the current new outputs.
+     if (pollResult.getWatermark() != null) {
+       this.pollWatermark = pollResult.getWatermark();
+     } else if (!pending.isEmpty()) {
+       this.pollWatermark = pending.getFirst().getTimestamp();
+     }
+     if (BoundedWindow.TIMESTAMP_MAX_VALUE.equals(pollWatermark)) {
+       isOutputComplete = true;
+     }
+     return pending.size();
+   }
+
+   /** Returns the watermark of this restriction's future outputs; null if unknown. */
+   @VisibleForTesting
+   synchronized Instant getWatermark() {
+     // Future elements that can be claimed in this restriction come either from
+     // "pending" or from future polls, so the total watermark is
+     // min(watermark for future polling, earliest remaining pending element)
+     return Ordering.natural()
+         .nullsLast()
+         .min(pollWatermark, pending.isEmpty() ? null : pending.getFirst().getTimestamp());
+   }
+
+   @Override
+   public synchronized String toString() {
+     return "GrowthTracker{"
+         + "state="
+         + state.toString(terminationCondition)
+         + ", pending=<"
+         + pending.size()
+         + " elements"
+         + (pending.isEmpty() ? "" : (", earliest " + pending.get(0)))
+         + ">, claimed=<"
+         + claimed.size()
+         + " elements>, isOutputComplete="
+         + isOutputComplete
+         + ", terminationState="
+         + terminationState
+         + ", pollWatermark="
+         + pollWatermark
+         + ", shouldStop="
+         + shouldStop
+         + '}';
+   }
+ }
+
+ /** Encodes a 128-bit {@link HashCode} as exactly 16 raw bytes, with no length prefix. */
+ private static class HashCode128Coder extends AtomicCoder<HashCode> {
+   private static final HashCode128Coder INSTANCE = new HashCode128Coder();
+   private static final int NUM_BYTES = 16;
+
+   public static HashCode128Coder of() {
+     return INSTANCE;
+   }
+
+   /**
+    * Writes the 16 bytes of {@code value}.
+    *
+    * @throws IllegalArgumentException if {@code value} is not a 128-bit hash code
+    */
+   @Override
+   public void encode(HashCode value, OutputStream os) throws IOException {
+     checkArgument(
+         value.bits() == 128, "Expected a 128-bit hash code, but got %s bits", value.bits());
+     byte[] res = new byte[NUM_BYTES];
+     value.writeBytesTo(res, 0, NUM_BYTES);
+     os.write(res);
+   }
+
+   /**
+    * Reads exactly 16 bytes and wraps them into a {@link HashCode}.
+    *
+    * @throws java.io.EOFException if the stream ends before 16 bytes were read
+    */
+   @Override
+   public HashCode decode(InputStream is) throws IOException {
+     byte[] res = new byte[NUM_BYTES];
+     // A single InputStream.read() may legitimately return fewer bytes than requested (e.g. at
+     // a buffer boundary), so keep reading until all bytes have arrived or the stream ends.
+     int numRead = 0;
+     while (numRead < NUM_BYTES) {
+       int n = is.read(res, numRead, NUM_BYTES - numRead);
+       if (n < 0) {
+         throw new java.io.EOFException(
+             "Expected to read " + NUM_BYTES + " bytes, but read " + numRead);
+       }
+       numRead += n;
+     }
+     return HashCode.fromBytes(res);
+   }
+ }
+
+ /**
+  * {@link Coder} for {@link GrowthState}. Encodes, in order: completed hashes, pending outputs,
+  * the completeness flag, the (nullable) termination state, and the (nullable) poll watermark.
+  */
+ private static class GrowthStateCoder<OutputT, TerminationStateT>
+     extends StructuredCoder<GrowthState<OutputT, TerminationStateT>> {
+   public static <OutputT, TerminationStateT> GrowthStateCoder<OutputT, TerminationStateT> of(
+       Coder<OutputT> outputCoder, Coder<TerminationStateT> terminationStateCoder) {
+     return new GrowthStateCoder<>(outputCoder, terminationStateCoder);
+   }
+
+   private static final Coder<Boolean> BOOLEAN_CODER = BooleanCoder.of();
+   private static final Coder<Instant> INSTANT_CODER = NullableCoder.of(InstantCoder.of());
+   private static final Coder<HashCode> HASH_CODE_CODER = HashCode128Coder.of();
+
+   private final Coder<OutputT> outputCoder;
+   private final Coder<Map<HashCode, Instant>> completedCoder;
+   private final Coder<List<TimestampedValue<OutputT>>> pendingCoder;
+   private final Coder<TerminationStateT> terminationStateCoder;
+   // GrowthState.terminationState is null for complete restrictions (e.g. the primary produced
+   // by a checkpoint), which state coders such as VarIntCoder cannot encode on their own.
+   private final Coder<TerminationStateT> nullableTerminationStateCoder;
+
+   private GrowthStateCoder(
+       Coder<OutputT> outputCoder, Coder<TerminationStateT> terminationStateCoder) {
+     this.outputCoder = outputCoder;
+     this.terminationStateCoder = terminationStateCoder;
+     this.nullableTerminationStateCoder = NullableCoder.of(terminationStateCoder);
+     this.completedCoder = MapCoder.of(HASH_CODE_CODER, INSTANT_CODER);
+     this.pendingCoder = ListCoder.of(TimestampedValue.TimestampedValueCoder.of(outputCoder));
+   }
+
+   @Override
+   public void encode(GrowthState<OutputT, TerminationStateT> value, OutputStream os)
+       throws IOException {
+     completedCoder.encode(value.completed, os);
+     pendingCoder.encode(value.pending, os);
+     BOOLEAN_CODER.encode(value.isOutputComplete, os);
+     nullableTerminationStateCoder.encode(value.terminationState, os);
+     INSTANT_CODER.encode(value.pollWatermark, os);
+   }
+
+   @Override
+   public GrowthState<OutputT, TerminationStateT> decode(InputStream is) throws IOException {
+     Map<HashCode, Instant> completed = completedCoder.decode(is);
+     List<TimestampedValue<OutputT>> pending = pendingCoder.decode(is);
+     boolean isOutputComplete = BOOLEAN_CODER.decode(is);
+     TerminationStateT terminationState = nullableTerminationStateCoder.decode(is);
+     Instant pollWatermark = INSTANT_CODER.decode(is);
+     return new GrowthState<>(
+         completed, pending, isOutputComplete, terminationState, pollWatermark);
+   }
+
+   @Override
+   public List<? extends Coder<?>> getCoderArguments() {
+     return Arrays.asList(outputCoder, terminationStateCoder);
+   }
+
+   @Override
+   public void verifyDeterministic() throws NonDeterministicException {
+     outputCoder.verifyDeterministic();
+   }
+ }
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java
index a12be6d..2337798 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java
@@ -23,7 +23,6 @@
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
@@ -453,11 +452,6 @@
}
@Override
- protected Coder<?> getDefaultOutputCoder(PCollection<T> input) {
- return input.getCoder();
- }
-
- @Override
protected String getKindString() {
return "Window.Into()";
}
@@ -484,7 +478,7 @@
@Override
public PCollection<T> expand(PCollection<T> input) {
return PCollection.createPrimitiveOutputInternal(
- input.getPipeline(), updatedStrategy, input.isBounded());
+ input.getPipeline(), updatedStrategy, input.isBounded(), input.getCoder());
}
@Override
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
index 4063d11..e8bf9b8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
@@ -366,10 +366,15 @@
public static <T> PCollection<T> createPrimitiveOutputInternal(
Pipeline pipeline,
WindowingStrategy<?, ?> windowingStrategy,
- IsBounded isBounded) {
- return new PCollection<T>(pipeline)
+ IsBounded isBounded,
+ @Nullable Coder<T> coder) {
+ PCollection<T> res = new PCollection<T>(pipeline)
.setWindowingStrategyInternal(windowingStrategy)
.setIsBoundedInternal(isBounded);
+ if (coder != null) {
+ res.setCoder(coder);
+ }
+ return res;
}
private static class CoderOrFailure<T> {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
index 793994f..9799d0e 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java
@@ -24,6 +24,7 @@
import java.util.Objects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection.IsBounded;
@@ -201,6 +202,7 @@
public static PCollectionTuple ofPrimitiveOutputsInternal(
Pipeline pipeline,
TupleTagList outputTags,
+ Map<TupleTag<?>, Coder<?>> coders,
WindowingStrategy<?, ?> windowingStrategy,
IsBounded isBounded) {
Map<TupleTag<?>, PCollection<?>> pcollectionMap = new LinkedHashMap<>();
@@ -217,10 +219,10 @@
// erasure as the correct type. When a transform adds
// elements to `outputCollection` they will be of type T.
@SuppressWarnings("unchecked")
- TypeDescriptor<Object> token = (TypeDescriptor<Object>) outputTag.getTypeDescriptor();
- PCollection<Object> outputCollection = PCollection
- .createPrimitiveOutputInternal(pipeline, windowingStrategy, isBounded)
- .setTypeDescriptor(token);
+ PCollection outputCollection =
+ PCollection.createPrimitiveOutputInternal(
+ pipeline, windowingStrategy, isBounded, coders.get(outputTag))
+ .setTypeDescriptor((TypeDescriptor) outputTag.getTypeDescriptor());
pcollectionMap.put(outputTag, outputCollection);
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java
index 14f2cb8..dd6a0fd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java
@@ -328,30 +328,64 @@
}
/**
- * Returns a new {@code TypeDescriptor} where type variables represented by
- * {@code typeParameter} are substituted by {@code typeDescriptor}. For example, it can be used to
- * construct {@code Map<K, V>} for any {@code K} and {@code V} type: <pre> {@code
- * static <K, V> TypeDescriptor<Map<K, V>> mapOf(
- * TypeDescriptor<K> keyType, TypeDescriptor<V> valueType) {
- * return new TypeDescriptor<Map<K, V>>() {}
- * .where(new TypeParameter<K>() {}, keyType)
- * .where(new TypeParameter<V>() {}, valueType);
- * }}</pre>
+ * Returns a new {@code TypeDescriptor} where the type variable represented by {@code
+ * typeParameter} is substituted by {@code typeDescriptor}. For example, it can be used to
+ * construct {@code Map<K, V>} for any {@code K} and {@code V} type:
+ *
+ * <pre>{@code
+ * static <K, V> TypeDescriptor<Map<K, V>> mapOf(
+ * TypeDescriptor<K> keyType, TypeDescriptor<V> valueType) {
+ * return new TypeDescriptor<Map<K, V>>() {}
+ * .where(new TypeParameter<K>() {}, keyType)
+ * .where(new TypeParameter<V>() {}, valueType);
+ * }
+ * }</pre>
*
* @param <X> The parameter type
* @param typeParameter the parameter type variable
* @param typeDescriptor the actual type to substitute
+ * @return a new {@code TypeDescriptor} with the substitution applied
*/
- @SuppressWarnings("unchecked")
- public <X> TypeDescriptor<T> where(TypeParameter<X> typeParameter,
- TypeDescriptor<X> typeDescriptor) {
- TypeResolver resolver =
- new TypeResolver()
- .where(
- typeParameter.typeVariable, typeDescriptor.getType());
+ public <X> TypeDescriptor<T> where(
+ TypeParameter<X> typeParameter, TypeDescriptor<X> typeDescriptor) {
+ return where(typeParameter.typeVariable, typeDescriptor.getType());
+ }
+
+ /**
+ * A more general form of {@link #where(TypeParameter, TypeDescriptor)} that returns a new {@code
+ * TypeDescriptor} by matching {@code formal} against {@code actual} to resolve type variables in
+ * the current {@link TypeDescriptor}.
+ *
+ * @param formal a type that may mention type variables, e.g. a {@code TypeVariable} or a
+ * parameterized type such as {@code SerializableFunction<BarT, String>}
+ * @param actual the type that {@code formal} is matched against; the type variables in {@code
+ * formal} are bound to the corresponding parts of {@code actual}
+ * @return a new {@code TypeDescriptor} with the bound type variables substituted
+ */
+ @SuppressWarnings("unchecked")
+ public TypeDescriptor<T> where(Type formal, Type actual) {
+ TypeResolver resolver = new TypeResolver().where(formal, actual);
return (TypeDescriptor<T>) TypeDescriptor.of(resolver.resolveType(token.getType()));
}
+ /**
+ * Returns whether this {@link TypeDescriptor} has any unresolved type parameters, as opposed to
+ * being a concrete type.
+ *
+ * <p>For example:
+ *
+ * <pre>{@code
+ * TypeDescriptor.of(new ArrayList<String>() {}.getClass()).hasUnresolvedParameters()
+ * => false, because the anonymous class is instantiated with a concrete type
+ *
+ * class TestUtils {
+ * static <T> ArrayList<T> createTypeErasedList() {
+ * return new ArrayList<T>() {};
+ * }
+ * }
+ *
+ * TypeDescriptor.of(TestUtils.<String>createTypeErasedList().getClass())
+ * .hasUnresolvedParameters()
+ * => true, because the type variable T got type-erased and the anonymous ArrayList class
+ * is instantiated with an unresolved type variable T.
+ * }</pre>
+ *
+ * @return {@code true} if any type parameter of this type is unresolved
+ */
+ public boolean hasUnresolvedParameters() {
+ return hasUnresolvedParameters(getType());
+ }
+
@Override
public String toString() {
return token.toString();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java
index a4626c9..8207f06 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptors.java
@@ -17,16 +17,20 @@
*/
package org.apache.beam.sdk.values;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Set;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.transforms.SerializableFunction;
/**
- * A utility class containing the Java primitives for
- * {@link TypeDescriptor} equivalents. Also, has methods
- * for classes that wrap Java primitives like {@link KV},
- * {@link Set}, {@link List}, and {@link Iterable}.
+ * A utility class for creating {@link TypeDescriptor} objects for different types, such as Java
+ * primitive types, containers and {@link KV KVs} of other {@link TypeDescriptor} objects, and
+ * extracting type variables of parameterized types (e.g. extracting the {@code OutputT} type
+ * variable of a {@code DoFn<InputT, OutputT>}).
*/
public class TypeDescriptors {
/**
@@ -286,4 +290,110 @@
return typeDescriptor;
}
+
+ /**
+ * A helper interface for use with {@link #extractFromTypeParameters(Object, Class,
+ * TypeVariableExtractor)}.
+ *
+ * <p>Create instances as subclasses (typically anonymous) that spell out both type arguments,
+ * e.g. {@code new TypeVariableExtractor<SerializableFunction<BarT, String>, BarT>() {}}. The type
+ * arguments are read back reflectively from the subclass's generic supertype, so they must not
+ * be erased.
+ *
+ * @param <InputT> the parameterized type that is matched against the actual type being analyzed
+ * @param <OutputT> the type to extract, written in terms of the type variables of {@code InputT}
+ */
+ public interface TypeVariableExtractor<InputT, OutputT> {}
+
+ /**
+ * Pulls a type out of the actual type arguments of a parameterized supertype, as far as Java type
+ * erasure allows. The caller describes what to extract with a {@link TypeVariableExtractor}
+ * rather than by naming a type variable or giving its index. That description keeps working if
+ * the type signature of the supertype changes.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * class Foo<BarT> {
+ * private SerializableFunction<BarT, String> fn;
+ *
+ * TypeDescriptor<BarT> inferBarTypeDescriptorFromFn() {
+ * return TypeDescriptors.extractFromTypeParameters(
+ * fn,
+ * SerializableFunction.class,
+ * // The actual type of "fn" is matched against the input type of the extractor,
+ * // and the resulting bindings of the supertype's type variables are substituted
+ * // into the output type of the extractor.
+ * new TypeVariableExtractor<SerializableFunction<BarT, String>, BarT>() {});
+ * }
+ * }
+ * }</pre>
+ *
+ * @param instance the object whose runtime class is analyzed
+ * @param supertype the parameterized superclass or interface of interest
+ * @param extractor describes which type to extract from {@code supertype}
+ * @return a {@link TypeDescriptor} for the extractor's output type, or {@code null} if that type
+ * was erased
+ */
+ @SuppressWarnings("unchecked")
+ @Nullable
+ public static <T, V> TypeDescriptor<V> extractFromTypeParameters(
+ T instance, Class<? super T> supertype, TypeVariableExtractor<T, V> extractor) {
+ // Only the runtime class of the instance is available, so start from its TypeDescriptor.
+ TypeDescriptor<T> instanceType = (TypeDescriptor<T>) TypeDescriptor.of(instance.getClass());
+ return extractFromTypeParameters(instanceType, supertype, extractor);
+ }
+
+ /**
+ * Like {@link #extractFromTypeParameters(Object, Class, TypeVariableExtractor)}, but takes a
+ * {@link TypeDescriptor} of the instance being analyzed rather than the instance itself.
+ *
+ * @param type the type being analyzed
+ * @param supertype parameterized superclass or interface of {@code type} whose actual type is
+ * matched against the input type of the extractor
+ * @param extractor a {@link TypeVariableExtractor} subclass whose type arguments describe what
+ * to extract
+ * @return a {@link TypeDescriptor} for the extracted type, or {@code null} if it still has
+ * unresolved type parameters (e.g. because they were erased)
+ * @throws IllegalArgumentException if the class of {@code extractor} does not specify the type
+ * arguments of {@link TypeVariableExtractor}
+ */
+ @SuppressWarnings("unchecked")
+ @Nullable
+ public static <T, V> TypeDescriptor<V> extractFromTypeParameters(
+ TypeDescriptor<T> type, Class<? super T> supertype, TypeVariableExtractor<T, V> extractor) {
+ // Get the type signature of the extractor, e.g.
+ // TypeVariableExtractor<SerializableFunction<BarT, String>, BarT>
+ TypeDescriptor<TypeVariableExtractor<T, V>> extractorSupertype =
+ (TypeDescriptor<TypeVariableExtractor<T, V>>)
+ TypeDescriptor.of(extractor.getClass()).getSupertype(TypeVariableExtractor.class);
+ Type extractorType = extractorSupertype.getType();
+ if (!(extractorType instanceof ParameterizedType)) {
+ // A raw implementation gives nothing to match against; report it clearly instead of
+ // failing with a ClassCastException below.
+ throw new IllegalArgumentException(
+ "Extractor "
+ + extractor.getClass().getName()
+ + " must specify the type arguments of TypeVariableExtractor, e.g. by being declared"
+ + " as new TypeVariableExtractor<InputT, OutputT>() {}");
+ }
+
+ // Get the actual type argument, e.g. SerializableFunction<BarT, String>
+ Type inputT = ((ParameterizedType) extractorType).getActualTypeArguments()[0];
+
+ // Get the actual supertype of the type being analyzed, hopefully with all type parameters
+ // resolved, e.g. SerializableFunction<Integer, String>
+ TypeDescriptor<? super T> supertypeDescriptor = type.getSupertype(supertype);
+
+ // Substitute actual supertype into the extractor, e.g.
+ // TypeVariableExtractor<SerializableFunction<Integer, String>, Integer>
+ TypeDescriptor<TypeVariableExtractor<T, V>> extractorT =
+ extractorSupertype.where(inputT, supertypeDescriptor.getType());
+
+ // Get output of the extractor.
+ Type outputT = ((ParameterizedType) extractorT.getType()).getActualTypeArguments()[1];
+ TypeDescriptor<?> res = TypeDescriptor.of(outputT);
+ if (res.hasUnresolvedParameters()) {
+ return null;
+ } else {
+ return (TypeDescriptor<V>) res;
+ }
+ }
+
+ /**
+ * Returns a {@link TypeDescriptor} for the input type of {@code fn}. Because of Java type
+ * erasure this may be impossible to recover, in which case {@code null} is returned.
+ */
+ @Nullable
+ public static <InputT, OutputT> TypeDescriptor<InputT> inputOf(
+ SerializableFunction<InputT, OutputT> fn) {
+ // Match fn against SerializableFunction<InputT, OutputT> and pull out the first argument.
+ TypeVariableExtractor<SerializableFunction<InputT, OutputT>, InputT> inputExtractor =
+ new TypeVariableExtractor<SerializableFunction<InputT, OutputT>, InputT>() {};
+ return extractFromTypeParameters(fn, SerializableFunction.class, inputExtractor);
+ }
+
+ /**
+ * Returns a {@link TypeDescriptor} for the output type of {@code fn}. Because of Java type
+ * erasure this may be impossible to recover, in which case {@code null} is returned.
+ */
+ @Nullable
+ public static <InputT, OutputT> TypeDescriptor<OutputT> outputOf(
+ SerializableFunction<InputT, OutputT> fn) {
+ // Match fn against SerializableFunction<InputT, OutputT> and pull out the second argument.
+ TypeVariableExtractor<SerializableFunction<InputT, OutputT>, OutputT> outputExtractor =
+ new TypeVariableExtractor<SerializableFunction<InputT, OutputT>, OutputT>() {};
+ return extractFromTypeParameters(fn, SerializableFunction.class, outputExtractor);
+ }
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/TestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/TestUtils.java
index 1224f10..5ccc1ac 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/TestUtils.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/TestUtils.java
@@ -17,15 +17,9 @@
*/
package org.apache.beam.sdk;
-import static org.junit.Assert.assertThat;
-
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.values.KV;
-import org.hamcrest.CoreMatchers;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
@@ -127,86 +121,4 @@
.appendText(")");
}
}
-
- ////////////////////////////////////////////////////////////////////////////
- // Utilities for testing CombineFns, ensuring they give correct results
- // across various permutations and shardings of the input.
-
- public static <InputT, AccumT, OutputT> void checkCombineFn(
- CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, final OutputT expected) {
- checkCombineFn(fn, input, CoreMatchers.is(expected));
- }
-
- public static <InputT, AccumT, OutputT> void checkCombineFn(
- CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, Matcher<? super OutputT> matcher) {
- checkCombineFnInternal(fn, input, matcher);
- Collections.shuffle(input);
- checkCombineFnInternal(fn, input, matcher);
- }
-
- private static <InputT, AccumT, OutputT> void checkCombineFnInternal(
- CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, Matcher<? super OutputT> matcher) {
- int size = input.size();
- checkCombineFnShards(fn, Collections.singletonList(input), matcher);
- checkCombineFnShards(fn, shardEvenly(input, 2), matcher);
- if (size > 4) {
- checkCombineFnShards(fn, shardEvenly(input, size / 2), matcher);
- checkCombineFnShards(
- fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher);
- }
- checkCombineFnShards(fn, shardExponentially(input, 1.4), matcher);
- checkCombineFnShards(fn, shardExponentially(input, 2), matcher);
- checkCombineFnShards(fn, shardExponentially(input, Math.E), matcher);
- }
-
- public static <InputT, AccumT, OutputT> void checkCombineFnShards(
- CombineFn<InputT, AccumT, OutputT> fn,
- List<? extends Iterable<InputT>> shards,
- Matcher<? super OutputT> matcher) {
- checkCombineFnShardsInternal(fn, shards, matcher);
- Collections.shuffle(shards);
- checkCombineFnShardsInternal(fn, shards, matcher);
- }
-
- private static <InputT, AccumT, OutputT> void checkCombineFnShardsInternal(
- CombineFn<InputT, AccumT, OutputT> fn,
- Iterable<? extends Iterable<InputT>> shards,
- Matcher<? super OutputT> matcher) {
- List<AccumT> accumulators = new ArrayList<>();
- int maybeCompact = 0;
- for (Iterable<InputT> shard : shards) {
- AccumT accumulator = fn.createAccumulator();
- for (InputT elem : shard) {
- accumulator = fn.addInput(accumulator, elem);
- }
- if (maybeCompact++ % 2 == 0) {
- accumulator = fn.compact(accumulator);
- }
- accumulators.add(accumulator);
- }
- AccumT merged = fn.mergeAccumulators(accumulators);
- assertThat(fn.extractOutput(merged), matcher);
- }
-
- private static <T> List<List<T>> shardEvenly(List<T> input, int numShards) {
- List<List<T>> shards = new ArrayList<>(numShards);
- for (int i = 0; i < numShards; i++) {
- shards.add(input.subList(i * input.size() / numShards,
- (i + 1) * input.size() / numShards));
- }
- return shards;
- }
-
- private static <T> List<List<T>> shardExponentially(
- List<T> input, double base) {
- assert base > 1.0;
- List<List<T>> shards = new ArrayList<>();
- int end = input.size();
- while (end > 0) {
- int start = (int) (end / base);
- shards.add(input.subList(start, end));
- end = start;
- }
- return shards;
- }
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/CoderRegistryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/CoderRegistryTest.java
index d1113f7..b6430e5 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/CoderRegistryTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/CoderRegistryTest.java
@@ -467,4 +467,73 @@
AutoRegistrationClassCoder.INSTANCE));
}
}
+
+ @Test
+ public void testCoderPrecedence() throws Exception {
+ CoderRegistry registry = CoderRegistry.createDefault();
+
+ // An @DefaultCoder annotation takes priority over a coder supplied by a registrar.
+ Coder<MyValueA> coderForA = registry.getCoder(MyValueA.class);
+ assertEquals(AvroCoder.of(MyValueA.class), coderForA);
+
+ // A registrar-supplied coder takes priority over the SerializableCoder fallback.
+ Coder<MyValueB> coderForB = registry.getCoder(MyValueB.class);
+ assertEquals(MyValueBCoder.INSTANCE, coderForB);
+
+ // With neither of the above, a Serializable type falls back to SerializableCoder.
+ Coder<MyValueC> coderForC = registry.getCoder(MyValueC.class);
+ assertEquals(SerializableCoder.of(MyValueC.class), coderForC);
+ }
+
+ // Has both @DefaultCoder and a registrar-provided coder; testCoderPrecedence expects AvroCoder.
+ @DefaultCoder(AvroCoder.class)
+ private static class MyValueA implements Serializable {}
+
+ // Serializable with a registrar-provided coder; testCoderPrecedence expects MyValueBCoder.
+ private static class MyValueB implements Serializable {}
+
+ // Only Serializable; testCoderPrecedence expects the SerializableCoder fallback.
+ private static class MyValueC implements Serializable {}
+
+ /** Stub coder for {@code MyValueA}: writes nothing and decodes to {@code null}. */
+ private static class MyValueACoder extends CustomCoder<MyValueA> {
+ private static final MyValueACoder INSTANCE = new MyValueACoder();
+
+ @Override
+ public MyValueA decode(InputStream inStream) throws CoderException, IOException {
+ // Nothing was written by encode(), so there is nothing to read back.
+ return null;
+ }
+
+ @Override
+ public void encode(MyValueA value, OutputStream outStream) throws CoderException, IOException {
+ // Intentionally empty: this coder is only looked up via the registry in tests.
+ }
+ }
+
+ /**
+ * A {@link CoderProviderRegistrar} to demonstrate default {@link Coder} registration.
+ */
+ @AutoService(CoderProviderRegistrar.class)
+ public static class MyValueACoderProviderRegistrar implements CoderProviderRegistrar {
+ @Override
+ public List<CoderProvider> getCoderProviders() {
+ return ImmutableList.of(
+ CoderProviders.forCoder(TypeDescriptor.of(MyValueA.class), MyValueACoder.INSTANCE));
+ }
+ }
+
+ /** Stub coder for {@code MyValueB}: writes nothing and decodes to {@code null}. */
+ private static class MyValueBCoder extends CustomCoder<MyValueB> {
+ private static final MyValueBCoder INSTANCE = new MyValueBCoder();
+
+ @Override
+ public MyValueB decode(InputStream inStream) throws CoderException, IOException {
+ // Nothing was written by encode(), so there is nothing to read back.
+ return null;
+ }
+
+ @Override
+ public void encode(MyValueB value, OutputStream outStream) throws CoderException, IOException {
+ // Intentionally empty: this coder is only looked up via the registry in tests.
+ }
+ }
+
+ /**
+ * A {@link CoderProviderRegistrar} to demonstrate default {@link Coder} registration.
+ */
+ @AutoService(CoderProviderRegistrar.class)
+ public static class MyValueBCoderProviderRegistrar implements CoderProviderRegistrar {
+ @Override
+ public List<CoderProvider> getCoderProviders() {
+ return ImmutableList.of(
+ CoderProviders.forCoder(TypeDescriptor.of(MyValueB.class), MyValueBCoder.INSTANCE));
+ }
+ }
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java
index 4380c57..a96b6be 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroIOTest.java
@@ -30,8 +30,11 @@
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
@@ -40,6 +43,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
@@ -47,6 +51,7 @@
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.DataFileStream;
+import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.reflect.Nullable;
@@ -54,6 +59,7 @@
import org.apache.avro.reflect.ReflectDatumReader;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.DefaultCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy;
import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints;
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions;
@@ -66,6 +72,8 @@
import org.apache.beam.sdk.testing.UsesTestStream;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayDataEvaluator;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -75,6 +83,7 @@
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -113,9 +122,9 @@
public GenericClass() {}
- public GenericClass(int intValue, String stringValue) {
- this.intField = intValue;
- this.stringField = stringValue;
+ public GenericClass(int intField, String stringField) {
+ this.intField = intField;
+ this.stringField = stringField;
}
@Override
@@ -141,9 +150,18 @@
}
}
+ private static class ParseGenericClass
+ implements SerializableFunction<GenericRecord, GenericClass> {
+ @Override
+ public GenericClass apply(GenericRecord input) {
+ return new GenericClass(
+ (int) input.get("intField"), input.get("stringField").toString());
+ }
+ }
+
@Test
@Category(NeedsRunner.class)
- public void testAvroIOWriteAndReadASingleFile() throws Throwable {
+ public void testAvroIOWriteAndReadAndParseASingleFile() throws Throwable {
List<GenericClass> values =
ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar"));
File outputFile = tmpFolder.newFile("output.avro");
@@ -152,10 +170,106 @@
.apply(AvroIO.write(GenericClass.class).to(outputFile.getAbsolutePath()).withoutSharding());
writePipeline.run().waitUntilFinish();
- PCollection<GenericClass> input =
- readPipeline.apply(AvroIO.read(GenericClass.class).from(outputFile.getAbsolutePath()));
+ // Test the same data using all versions of read().
+ PCollection<String> path =
+ readPipeline.apply("Create path", Create.of(outputFile.getAbsolutePath()));
+ PAssert.that(
+ readPipeline.apply(
+ "Read", AvroIO.read(GenericClass.class).from(outputFile.getAbsolutePath())))
+ .containsInAnyOrder(values);
+ PAssert.that(
+ readPipeline.apply(
+ "Read withHintMatchesManyFiles",
+ AvroIO.read(GenericClass.class)
+ .from(outputFile.getAbsolutePath())
+ .withHintMatchesManyFiles()))
+ .containsInAnyOrder(values);
+ PAssert.that(
+ path.apply(
+ "ReadAll", AvroIO.readAll(GenericClass.class).withDesiredBundleSizeBytes(10)))
+ .containsInAnyOrder(values);
+ PAssert.that(
+ readPipeline.apply(
+ "Parse",
+ AvroIO.parseGenericRecords(new ParseGenericClass())
+ .from(outputFile.getAbsolutePath())
+ .withCoder(AvroCoder.of(GenericClass.class))))
+ .containsInAnyOrder(values);
+ PAssert.that(
+ readPipeline.apply(
+ "Parse withHintMatchesManyFiles",
+ AvroIO.parseGenericRecords(new ParseGenericClass())
+ .from(outputFile.getAbsolutePath())
+ .withCoder(AvroCoder.of(GenericClass.class))
+ .withHintMatchesManyFiles()))
+ .containsInAnyOrder(values);
+ PAssert.that(
+ path.apply(
+ "ParseAll",
+ AvroIO.parseAllGenericRecords(new ParseGenericClass())
+ .withCoder(AvroCoder.of(GenericClass.class))
+ .withDesiredBundleSizeBytes(10)))
+ .containsInAnyOrder(values);
- PAssert.that(input).containsInAnyOrder(values);
+ readPipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testAvroIOWriteAndReadMultipleFilepatterns() throws Throwable {
+ String firstPrefix = tmpFolder.getRoot().getAbsolutePath() + "/first";
+ String secondPrefix = tmpFolder.getRoot().getAbsolutePath() + "/second";
+
+ List<GenericClass> firstValues = Lists.newArrayList();
+ List<GenericClass> secondValues = Lists.newArrayList();
+ for (int i = 0; i < 10; ++i) {
+ firstValues.add(new GenericClass(i, "a" + i));
+ secondValues.add(new GenericClass(i, "b" + i));
+ }
+ writePipeline
+ .apply("Create first", Create.of(firstValues))
+ .apply("Write first", AvroIO.write(GenericClass.class).to(firstPrefix).withNumShards(2));
+ writePipeline
+ .apply("Create second", Create.of(secondValues))
+ .apply("Write second", AvroIO.write(GenericClass.class).to(secondPrefix).withNumShards(3));
+ writePipeline.run().waitUntilFinish();
+
+ // Read each filepattern on its own with read(), then both together with readAll() and
+ // parseAllGenericRecords().
+ Iterable<GenericClass> allValues = Iterables.concat(firstValues, secondValues);
+ PAssert.that(
+ readPipeline.apply(
+ "Read first", AvroIO.read(GenericClass.class).from(firstPrefix + "*")))
+ .containsInAnyOrder(firstValues);
+ PAssert.that(
+ readPipeline.apply(
+ "Read second", AvroIO.read(GenericClass.class).from(secondPrefix + "*")))
+ .containsInAnyOrder(secondValues);
+ PCollection<String> paths =
+ readPipeline.apply("Create paths", Create.of(firstPrefix + "*", secondPrefix + "*"));
+ PAssert.that(
+ paths.apply(
+ "Read all", AvroIO.readAll(GenericClass.class).withDesiredBundleSizeBytes(10)))
+ .containsInAnyOrder(allValues);
+ PAssert.that(
+ paths.apply(
+ "Parse all",
+ AvroIO.parseAllGenericRecords(new ParseGenericClass())
+ .withCoder(AvroCoder.of(GenericClass.class))
+ .withDesiredBundleSizeBytes(10)))
+ .containsInAnyOrder(allValues);
+
readPipeline.run();
}
@@ -428,17 +542,147 @@
assertThat(actualElements, containsInAnyOrder(allElements.toArray()));
}
+ // Avro record schema template. schemaFromPrefix() replaces every "$$" with a destination prefix,
+ // giving each prefix its own record name and field names.
+ private static final String SCHEMA_TEMPLATE_STRING =
+ "{\"namespace\": \"example.avro\",\n"
+ + " \"type\": \"record\",\n"
+ + " \"name\": \"TestTemplateSchema$$\",\n"
+ + " \"fields\": [\n"
+ + " {\"name\": \"$$full\", \"type\": \"string\"},\n"
+ + " {\"name\": \"$$suffix\", \"type\": [\"string\", \"null\"]}\n"
+ + " ]\n"
+ + "}";
+
+ private static String schemaFromPrefix(String prefix) {
+ return SCHEMA_TEMPLATE_STRING.replace("$$", prefix);
+ }
+
+ /**
+ * Creates a record of {@code schema} whose "full" field holds {@code record} and whose "suffix"
+ * field holds {@code record} minus its first character; field names carry {@code prefix}.
+ */
+ private static GenericRecord createRecord(String record, String prefix, Schema schema) {
+ GenericRecord result = new GenericData.Record(schema);
+ result.put(prefix + "full", record);
+ String suffix = record.substring(1);
+ result.put(prefix + "suffix", suffix);
+ return result;
+ }
+
+ /**
+ * Routes each string to a destination named after its first character and writes it with that
+ * destination's schema, which is looked up from a side input.
+ */
+ private static class TestDynamicDestinations
+ extends DynamicAvroDestinations<String, String, GenericRecord> {
+ private final ResourceId baseDir;
+ private final PCollectionView<Map<String, String>> schemaView;
+
+ TestDynamicDestinations(ResourceId baseDir, PCollectionView<Map<String, String>> schemaView) {
+ this.baseDir = baseDir;
+ this.schemaView = schemaView;
+ }
+
+ @Override
+ public Schema getSchema(String destination) {
+ // Return a per-destination schema.
+ String schema = sideInput(schemaView).get(destination);
+ return new Schema.Parser().parse(schema);
+ }
+
+ @Override
+ public List<PCollectionView<?>> getSideInputs() {
+ return ImmutableList.<PCollectionView<?>>of(schemaView);
+ }
+
+ @Override
+ public GenericRecord formatRecord(String record) {
+ // The destination doubles as the field-name prefix of that destination's schema.
+ String destination = getDestination(record);
+ return createRecord(record, destination, getSchema(destination));
+ }
+
+ @Override
+ public String getDestination(String element) {
+ // Destination is based on first character of string.
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public String getDefaultDestination() {
+ return "";
+ }
+
+ @Override
+ public FilenamePolicy getFilenamePolicy(String destination) {
+ return DefaultFilenamePolicy.fromStandardParameters(
+ StaticValueProvider.of(
+ baseDir.resolve("file_" + destination + ".txt", StandardResolveOptions.RESOLVE_FILE)),
+ null,
+ null,
+ false);
+ }
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testDynamicDestinations() throws Exception {
+ ResourceId baseDir =
+ FileSystems.matchNewResource(
+ Files.createTempDirectory(tmpFolder.getRoot().toPath(), "testDynamicDestinations")
+ .toString(),
+ true);
+
+ List<String> elements = Lists.newArrayList("aaaa", "aaab", "baaa", "baab", "caaa", "caab");
+ List<GenericRecord> expectedElements = Lists.newArrayListWithExpectedSize(elements.size());
+ Map<String, String> schemaMap = Maps.newHashMap();
+ for (String element : elements) {
+ String prefix = element.substring(0, 1);
+ String jsonSchema = schemaFromPrefix(prefix);
+ schemaMap.put(prefix, jsonSchema);
+ expectedElements.add(createRecord(element, prefix, new Schema.Parser().parse(jsonSchema)));
+ }
+ PCollectionView<Map<String, String>> schemaView =
+ writePipeline
+ .apply("createSchemaView", Create.of(schemaMap))
+ .apply(View.<String, String>asMap());
+
+ PCollection<String> input =
+ writePipeline.apply("createInput", Create.of(elements).withCoder(StringUtf8Coder.of()));
+ input.apply(
+ AvroIO.<String>writeCustomTypeToGenericRecords()
+ .to(new TestDynamicDestinations(baseDir, schemaView))
+ .withoutSharding()
+ .withTempDirectory(baseDir));
+ // Block until the write finishes: the output files are read directly below.
+ writePipeline.run().waitUntilFinish();
+
+ // Validate that every destination prefix produced its unsharded output file and that,
+ // together, those files contain exactly the expected records (in any order).
+
+ List<String> prefixes = Lists.newArrayList();
+ for (String element : elements) {
+ prefixes.add(element.substring(0, 1));
+ }
+ prefixes = ImmutableSet.copyOf(prefixes).asList();
+
+ List<GenericRecord> actualElements = new ArrayList<>();
+ for (String prefix : prefixes) {
+ File expectedFile =
+ new File(
+ baseDir
+ .resolve(
+ "file_" + prefix + ".txt-00000-of-00001", StandardResolveOptions.RESOLVE_FILE)
+ .toString());
+ assertTrue("Expected output file " + expectedFile.getAbsolutePath(), expectedFile.exists());
+ Schema schema = new Schema.Parser().parse(schemaFromPrefix(prefix));
+ try (DataFileReader<GenericRecord> reader =
+ new DataFileReader<>(expectedFile, new GenericDatumReader<GenericRecord>(schema))) {
+ Iterators.addAll(actualElements, reader);
+ }
+ expectedFile.delete();
+ }
+ assertThat(actualElements, containsInAnyOrder(expectedElements.toArray()));
+ }
+
@Test
public void testWriteWithDefaultCodec() throws Exception {
AvroIO.Write<String> write = AvroIO.write(String.class).to("/tmp/foo/baz");
+ // Without withCodec(), the codec reported through write.inner is expected to be deflate(6).
- assertEquals(CodecFactory.deflateCodec(6).toString(), write.getCodec().toString());
+ assertEquals(CodecFactory.deflateCodec(6).toString(), write.inner.getCodec().toString());
}
@Test
public void testWriteWithCustomCodec() throws Exception {
AvroIO.Write<String> write =
AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.snappyCodec());
+ // The codec set via withCodec() should be the one reported through write.inner.
- assertEquals(SNAPPY_CODEC, write.getCodec().toString());
+ assertEquals(SNAPPY_CODEC, write.inner.getCodec().toString());
}
@Test
@@ -449,7 +693,7 @@
assertEquals(
CodecFactory.deflateCodec(9).toString(),
- SerializableUtils.clone(write.getCodec()).getCodec().toString());
+ SerializableUtils.clone(write.inner.getCodec()).getCodec().toString());
}
@Test
@@ -460,7 +704,7 @@
assertEquals(
CodecFactory.xzCodec(9).toString(),
- SerializableUtils.clone(write.getCodec()).getCodec().toString());
+ SerializableUtils.clone(write.inner.getCodec()).getCodec().toString());
}
@Test
@@ -511,7 +755,8 @@
String shardNameTemplate =
firstNonNull(
- write.getShardTemplate(), DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE);
+ write.inner.getShardTemplate(),
+ DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE);
assertTestOutputs(expectedElements, numShards, outputFilePrefix, shardNameTemplate);
}
@@ -603,7 +848,13 @@
assertThat(displayData, hasDisplayItem("filePrefix", "/foo"));
assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
assertThat(displayData, hasDisplayItem("fileSuffix", "bar"));
- assertThat(displayData, hasDisplayItem("schema", GenericClass.class));
+ assertThat(
+ displayData,
+ hasDisplayItem(
+ "schema",
+ "{\"type\":\"record\",\"name\":\"GenericClass\",\"namespace\":\"org.apache.beam.sdk.io"
+ + ".AvroIOTest$\",\"fields\":[{\"name\":\"intField\",\"type\":\"int\"},"
+ + "{\"name\":\"stringField\",\"type\":\"string\"}]}"));
assertThat(displayData, hasDisplayItem("numShards", 100));
assertThat(displayData, hasDisplayItem("codec", CodecFactory.snappyCodec().toString()));
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
index bf2ac95..714e029 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/AvroSourceTest.java
@@ -59,6 +59,7 @@
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.SourceTestUtils;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.SerializableUtils;
import org.hamcrest.Matchers;
@@ -407,11 +408,6 @@
source = AvroSource.from(filename).withSchema(schemaString);
records = SourceTestUtils.readFromSource(source, null);
assertEqualsWithGeneric(expected, records);
-
- // Create a source with no schema
- source = AvroSource.from(filename);
- records = SourceTestUtils.readFromSource(source, null);
- assertEqualsWithGeneric(expected, records);
}
@Test
@@ -449,6 +445,30 @@
assertSame(sourceA.getReaderSchemaString(), sourceC.getReaderSchemaString());
}
+ @Test
+ public void testParseFn() throws Exception {
+ List<Bird> expected = createRandomRecords(100);
+ String filename = generateTestFile("tmp.avro", expected, SyncBehavior.SYNC_DEFAULT, 0,
+ AvroCoder.of(Bird.class), DataFileConstants.NULL_CODEC);
+
+ AvroSource<Bird> source =
+ AvroSource.from(filename)
+ .withParseFn(
+ new SerializableFunction<GenericRecord, Bird>() {
+ @Override
+ public Bird apply(GenericRecord input) {
+ return new Bird(
+ (long) input.get("number"),
+ input.get("species").toString(),
+ input.get("quality").toString(),
+ (long) input.get("quantity"));
+ }
+ },
+ AvroCoder.of(Bird.class));
+ List<Bird> actual = SourceTestUtils.readFromSource(source, null);
+ assertThat(actual, containsInAnyOrder(expected.toArray()));
+ }
+
private void assertEqualsWithGeneric(List<Bird> expected, List<GenericRecord> actual) {
assertEquals(expected.size(), actual.size());
for (int i = 0; i < expected.size(); i++) {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java
index 3fff319..fe6f01f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/CompressedSourceTest.java
@@ -253,6 +253,39 @@
}
/**
+ * Test that a bzip2 file containing multiple streams is correctly decompressed.
+ *
+ * <p>A bzip2 file may contain multiple streams and should decompress as the concatenation of
+ * those streams.
+ */
+ @Test
+ @Category(NeedsRunner.class)
+ public void testReadMultiStreamBzip2() throws IOException {
+ CompressionMode mode = CompressionMode.BZIP2;
+ byte[] input1 = generateInput(5, 587973);
+ byte[] input2 = generateInput(5, 387374);
+
+ ByteArrayOutputStream stream1 = new ByteArrayOutputStream();
+ try (OutputStream os = getOutputStreamForMode(mode, stream1)) {
+ os.write(input1);
+ }
+
+ ByteArrayOutputStream stream2 = new ByteArrayOutputStream();
+ try (OutputStream os = getOutputStreamForMode(mode, stream2)) {
+ os.write(input2);
+ }
+
+ File tmpFile = tmpFolder.newFile();
+ try (OutputStream os = new FileOutputStream(tmpFile)) {
+ os.write(stream1.toByteArray());
+ os.write(stream2.toByteArray());
+ }
+
+ byte[] output = Bytes.concat(input1, input2);
+ verifyReadContents(output, tmpFile, mode);
+ }
+
+ /**
* Test reading empty input with bzip2.
*/
@Test
@@ -470,7 +503,16 @@
*/
private byte[] generateInput(int size) {
// Arbitrary but fixed seed
- Random random = new Random(285930);
+ return generateInput(size, 285930);
+ }
+
+
+ /**
+ * Generate a byte array of the given size, deterministically derived from {@code seed}.
+ */
+ private byte[] generateInput(int size, int seed) {
+ // The same seed always yields the same bytes, keeping the test deterministic.
+ Random random = new Random(seed);
byte[] buff = new byte[size];
random.nextBytes(buff);
return buff;
@@ -596,7 +638,7 @@
}
@Override
- public Coder<Byte> getDefaultOutputCoder() {
+ public Coder<Byte> getOutputCoder() {
return SerializableCoder.of(Byte.class);
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java
index a6ad746..ff30e33 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSinkTest.java
@@ -231,7 +231,7 @@
SimpleSink.makeSimpleSink(
getBaseOutputDirectory(), prefix, "", "", CompressionType.UNCOMPRESSED);
- WriteOperation<String, Void> writeOp =
+ WriteOperation<Void, String> writeOp =
new SimpleSink.SimpleWriteOperation<>(sink, tempDirectory);
List<File> temporaryFiles = new ArrayList<>();
@@ -482,11 +482,11 @@
public void testFileBasedWriterWithWritableByteChannelFactory() throws Exception {
final String testUid = "testId";
ResourceId root = getBaseOutputDirectory();
- WriteOperation<String, Void> writeOp =
+ WriteOperation<Void, String> writeOp =
SimpleSink.makeSimpleSink(
root, "file", "-SS-of-NN", "txt", new DrunkWritableByteChannelFactory())
.createWriteOperation();
- final Writer<String, Void> writer = writeOp.createWriter();
+ final Writer<Void, String> writer = writeOp.createWriter();
final ResourceId expectedFile =
writeOp.tempDirectory.get().resolve(testUid, StandardResolveOptions.RESOLVE_FILE);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java
index c15e667..ea9e06b 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileBasedSourceTest.java
@@ -47,6 +47,7 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.FileBasedSource.FileBasedReader;
import org.apache.beam.sdk.io.Source.Reader;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -94,6 +95,15 @@
}
public TestFileBasedSource(
+ String fileOrPattern,
+ EmptyMatchTreatment emptyMatchTreatment,
+ long minBundleSize,
+ String splitHeader) {
+ super(StaticValueProvider.of(fileOrPattern), emptyMatchTreatment, minBundleSize);
+ this.splitHeader = splitHeader;
+ }
+
+ public TestFileBasedSource(
Metadata fileOrPattern,
long minBundleSize,
long startOffset,
@@ -107,7 +117,7 @@
public void validate() {}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
@@ -371,6 +381,47 @@
}
@Test
+ public void testEmptyFilepatternTreatmentDefaultDisallow() throws IOException {
+ PipelineOptions options = PipelineOptionsFactory.create();
+ TestFileBasedSource source =
+ new TestFileBasedSource(new File(tempFolder.getRoot(), "doesNotExist").getPath(), 64, null);
+ thrown.expect(FileNotFoundException.class);
+ readFromSource(source, options);
+ }
+
+ @Test
+ public void testEmptyFilepatternTreatmentAllow() throws IOException {
+ PipelineOptions options = PipelineOptionsFactory.create();
+ TestFileBasedSource source =
+ new TestFileBasedSource(
+ new File(tempFolder.getRoot(), "doesNotExist").getPath(),
+ EmptyMatchTreatment.ALLOW,
+ 64,
+ null);
+ TestFileBasedSource sourceWithWildcard =
+ new TestFileBasedSource(
+ new File(tempFolder.getRoot(), "doesNotExist*").getPath(),
+ EmptyMatchTreatment.ALLOW_IF_WILDCARD,
+ 64,
+ null);
+ assertEquals(0, readFromSource(source, options).size());
+ assertEquals(0, readFromSource(sourceWithWildcard, options).size());
+ }
+
+ @Test
+ public void testEmptyFilepatternTreatmentAllowIfWildcard() throws IOException {
+ PipelineOptions options = PipelineOptionsFactory.create();
+ TestFileBasedSource source =
+ new TestFileBasedSource(
+ new File(tempFolder.getRoot(), "doesNotExist").getPath(),
+ EmptyMatchTreatment.ALLOW_IF_WILDCARD,
+ 64,
+ null);
+ thrown.expect(FileNotFoundException.class);
+ readFromSource(source, options);
+ }
+
+ @Test
public void testCloseUnstartedFilePatternReader() throws IOException {
PipelineOptions options = PipelineOptionsFactory.create();
List<String> data1 = createStringDataset(3, 50);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/LocalFileSystemTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/LocalFileSystemTest.java
index 048908f..aaaeb83 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/LocalFileSystemTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/LocalFileSystemTest.java
@@ -45,7 +45,9 @@
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions;
+import org.apache.beam.sdk.testing.RestoreSystemProperties;
import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.commons.lang3.SystemUtils;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
@@ -61,6 +63,7 @@
public class LocalFileSystemTest {
@Rule public ExpectedException thrown = ExpectedException.none();
@Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
+ @Rule public RestoreSystemProperties restoreSystemProperties = new RestoreSystemProperties();
private LocalFileSystem localFileSystem = new LocalFileSystem();
@Test
@@ -242,6 +245,52 @@
}
@Test
+ public void testMatchInDirectory() throws Exception {
+ List<String> expected = ImmutableList.of(temporaryFolder.newFile("a").toString());
+ temporaryFolder.newFile("aa");
+ temporaryFolder.newFile("ab");
+
+ String expectedFile = expected.get(0);
+ int slashIndex = expectedFile.lastIndexOf('/');
+ if (SystemUtils.IS_OS_WINDOWS) {
+ slashIndex = expectedFile.lastIndexOf('\\');
+ }
+ String directory = expectedFile.substring(0, slashIndex);
+ String relative = expectedFile.substring(slashIndex + 1);
+ System.setProperty("user.dir", directory);
+ List<MatchResult> results = localFileSystem.match(ImmutableList.of(relative));
+ assertThat(
+ toFilenames(results),
+ containsInAnyOrder(expected.toArray(new String[expected.size()])));
+ }
+
+ @Test
+ public void testMatchWithFileSlashPrefix() throws Exception {
+ List<String> expected = ImmutableList.of(temporaryFolder.newFile("a").toString());
+ temporaryFolder.newFile("aa");
+ temporaryFolder.newFile("ab");
+
+ String file = "file:/" + temporaryFolder.getRoot().toPath().resolve("a").toString();
+ List<MatchResult> results = localFileSystem.match(ImmutableList.of(file));
+ assertThat(
+ toFilenames(results),
+ containsInAnyOrder(expected.toArray(new String[expected.size()])));
+ }
+
+ @Test
+ public void testMatchWithFileThreeSlashesPrefix() throws Exception {
+ List<String> expected = ImmutableList.of(temporaryFolder.newFile("a").toString());
+ temporaryFolder.newFile("aa");
+ temporaryFolder.newFile("ab");
+
+ String file = "file:///" + temporaryFolder.getRoot().toPath().resolve("a").toString();
+ List<MatchResult> results = localFileSystem.match(ImmutableList.of(file));
+ assertThat(
+ toFilenames(results),
+ containsInAnyOrder(expected.toArray(new String[expected.size()])));
+ }
+
+ @Test
public void testMatchMultipleWithoutSubdirectoryExpansion() throws Exception {
File unmatchedSubDir = temporaryFolder.newFolder("aaa");
File unmatchedSubDirFile = File.createTempFile("sub-dir-file", "", unmatchedSubDir);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java
index 25168a3..feda355 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/OffsetBasedSourceTest.java
@@ -65,7 +65,7 @@
public void validate() {}
@Override
- public Coder<Integer> getDefaultOutputCoder() {
+ public Coder<Integer> getOutputCoder() {
return BigEndianIntegerCoder.of();
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
index 74acf18..4277dc3 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
@@ -171,7 +171,7 @@
public void validate() {}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
}
@@ -207,7 +207,7 @@
public void validate() {}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java
index 9196178..382898d 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/SimpleSink.java
@@ -28,10 +28,10 @@
/**
* A simple {@link FileBasedSink} that writes {@link String} values as lines with header and footer.
*/
-class SimpleSink<DestinationT> extends FileBasedSink<String, DestinationT> {
+class SimpleSink<DestinationT> extends FileBasedSink<String, DestinationT, String> {
public SimpleSink(
ResourceId tempDirectory,
- DynamicDestinations<String, DestinationT> dynamicDestinations,
+ DynamicDestinations<String, DestinationT, String> dynamicDestinations,
WritableByteChannelFactory writableByteChannelFactory) {
super(StaticValueProvider.of(tempDirectory), dynamicDestinations, writableByteChannelFactory);
}
@@ -50,7 +50,7 @@
String shardTemplate,
String suffix,
WritableByteChannelFactory writableByteChannelFactory) {
- DynamicDestinations<String, Void> dynamicDestinations =
+ DynamicDestinations<String, Void, String> dynamicDestinations =
DynamicFileDestinations.constant(
DefaultFilenamePolicy.fromParams(
new Params()
@@ -67,7 +67,7 @@
}
static final class SimpleWriteOperation<DestinationT>
- extends WriteOperation<String, DestinationT> {
+ extends WriteOperation<DestinationT, String> {
public SimpleWriteOperation(SimpleSink sink, ResourceId tempOutputDirectory) {
super(sink, tempOutputDirectory);
}
@@ -82,7 +82,7 @@
}
}
- static final class SimpleWriter<DestinationT> extends Writer<String, DestinationT> {
+ static final class SimpleWriter<DestinationT> extends Writer<DestinationT, String> {
static final String HEADER = "header";
static final String FOOTER = "footer";
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java
index e733010..aa6090d 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java
@@ -25,6 +25,7 @@
import static org.apache.beam.sdk.io.TextIO.CompressionType.GZIP;
import static org.apache.beam.sdk.io.TextIO.CompressionType.UNCOMPRESSED;
import static org.apache.beam.sdk.io.TextIO.CompressionType.ZIP;
+import static org.apache.beam.sdk.transforms.Watch.Growth.afterTimeSinceNewOutput;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasValue;
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -63,6 +64,7 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
import org.apache.beam.sdk.io.TextIO.CompressionType;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.ValueProvider;
@@ -70,6 +72,7 @@
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.SourceTestUtils;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesSplittableParDo;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -78,6 +81,7 @@
import org.apache.beam.sdk.values.PCollection;
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorOutputStream;
import org.apache.commons.compress.compressors.deflate.DeflateCompressorOutputStream;
+import org.joda.time.Duration;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
@@ -290,9 +294,16 @@
}
/**
- * Helper method that runs TextIO.read().from(filename).withCompressionType(compressionType) and
- * TextIO.readAll().withCompressionType(compressionType) applied to the single filename,
- * and asserts that the results match the given expected output.
+ * Helper method that runs a variety of ways to read a single file using TextIO
+ * and checks that they all match the given expected output.
+ *
+ * <p>The transforms being verified are:
+ * <ul>
+ * <li>TextIO.read().from(filename).withCompressionType(compressionType)
+ * <li>TextIO.read().from(filename).withCompressionType(compressionType)
+ * .withHintMatchesManyFiles()
+ * <li>TextIO.readAll().withCompressionType(compressionType)
+ * </ul>
*/
private void assertReadingCompressedFileMatchesExpected(
File file, CompressionType compressionType, List<String> expected) {
@@ -300,10 +311,17 @@
int thisUniquifier = ++uniquifier;
TextIO.Read read = TextIO.read().from(file.getPath()).withCompressionType(compressionType);
+
PAssert.that(
p.apply("Read_" + file + "_" + compressionType.toString() + "_" + thisUniquifier, read))
.containsInAnyOrder(expected);
+ PAssert.that(
+ p.apply(
+ "Read_" + file + "_" + compressionType.toString() + "_many" + "_" + thisUniquifier,
+ read.withHintMatchesManyFiles()))
+ .containsInAnyOrder(expected);
+
TextIO.ReadAll readAll =
TextIO.readAll().withCompressionType(compressionType).withDesiredBundleSizeBytes(10);
PAssert.that(
@@ -773,7 +791,8 @@
private TextSource prepareSource(byte[] data) throws IOException {
Path path = Files.createTempFile(tempFolder, "tempfile", "ext");
Files.write(path, data);
- return new TextSource(ValueProvider.StaticValueProvider.of(path.toString()));
+ return new TextSource(
+ ValueProvider.StaticValueProvider.of(path.toString()), EmptyMatchTreatment.DISALLOW);
}
@Test
@@ -858,4 +877,51 @@
PAssert.that(lines).containsInAnyOrder(Iterables.concat(TINY, TINY, LARGE, LARGE));
p.run();
}
+
+ @Test
+ @Category({NeedsRunner.class, UsesSplittableParDo.class})
+ public void testReadWatchForNewFiles() throws IOException, InterruptedException {
+ final Path basePath = tempFolder.resolve("readWatch");
+ basePath.toFile().mkdir();
+ PCollection<String> lines =
+ p.apply(
+ TextIO.read()
+ .from(basePath.resolve("*").toString())
+ // Make sure that compression type propagates into readAll()
+ .withCompressionType(ZIP)
+ .watchForNewFiles(
+ Duration.millis(100), afterTimeSinceNewOutput(Duration.standardSeconds(3))));
+
+ Thread writer =
+ new Thread() {
+ @Override
+ public void run() {
+ try {
+ Thread.sleep(1000);
+ writeToFile(
+ Arrays.asList("a.1", "a.2"),
+ basePath.resolve("fileA").toString(),
+ CompressionType.ZIP);
+ Thread.sleep(300);
+ writeToFile(
+ Arrays.asList("b.1", "b.2"),
+ basePath.resolve("fileB").toString(),
+ CompressionType.ZIP);
+ Thread.sleep(300);
+ writeToFile(
+ Arrays.asList("c.1", "c.2"),
+ basePath.resolve("fileC").toString(),
+ CompressionType.ZIP);
+ } catch (IOException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+ writer.start();
+
+ PAssert.that(lines).containsInAnyOrder("a.1", "a.2", "b.1", "b.2", "c.1", "c.2");
+ p.run();
+
+ writer.join();
+ }
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
index a73ed7d..7f80c26 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
@@ -110,7 +110,8 @@
});
}
- static class TestDynamicDestinations extends FileBasedSink.DynamicDestinations<String, String> {
+ static class TestDynamicDestinations
+ extends FileBasedSink.DynamicDestinations<String, String, String> {
ResourceId baseDir;
TestDynamicDestinations(ResourceId baseDir) {
@@ -118,6 +119,11 @@
}
@Override
+ public String formatRecord(String record) {
+ return record;
+ }
+
+ @Override
public String getDestination(String element) {
// Destination is based on first character of string.
return element.substring(0, 1);
@@ -169,10 +175,7 @@
List<String> elements = Lists.newArrayList("aaaa", "aaab", "baaa", "baab", "caaa", "caab");
PCollection<String> input = p.apply(Create.of(elements).withCoder(StringUtf8Coder.of()));
- input.apply(
- TextIO.write()
- .to(new TestDynamicDestinations(baseDir))
- .withTempDirectory(FileSystems.matchNewResource(baseDir.toString(), true)));
+ input.apply(TextIO.write().to(new TestDynamicDestinations(baseDir)).withTempDirectory(baseDir));
p.run();
assertOutputFiles(
@@ -268,8 +271,14 @@
new UserWriteType("caab", "sixth"));
PCollection<UserWriteType> input = p.apply(Create.of(elements));
input.apply(
- TextIO.writeCustomType(new SerializeUserWrite())
- .to(new UserWriteDestination(baseDir), new DefaultFilenamePolicy.Params())
+ TextIO.<UserWriteType>writeCustomType()
+ .to(
+ new UserWriteDestination(baseDir),
+ new DefaultFilenamePolicy.Params()
+ .withBaseFilename(
+ baseDir.resolve(
+ "empty", ResolveOptions.StandardResolveOptions.RESOLVE_FILE)))
+ .withFormatFunction(new SerializeUserWrite())
.withTempDirectory(FileSystems.matchNewResource(baseDir.toString(), true)));
p.run();
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
index 60088de..1d4ce08 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
@@ -68,8 +68,6 @@
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.transforms.View;
@@ -178,11 +176,7 @@
"Intimidating pigeon",
"Pedantic gull",
"Frisky finch");
- runWrite(
- inputs,
- IDENTITY_MAP,
- getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()));
+ runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink()));
}
/** Test that WriteFiles with an empty input still produces one shard. */
@@ -193,7 +187,7 @@
Collections.<String>emptyList(),
IDENTITY_MAP,
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()));
+ WriteFiles.to(makeSimpleSink()));
checkFileContents(getBaseOutputFilename(), Collections.<String>emptyList(), Optional.of(1));
}
@@ -208,7 +202,7 @@
Arrays.asList("one", "two", "three", "four", "five", "six"),
IDENTITY_MAP,
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()).withNumShards(1));
+ WriteFiles.to(makeSimpleSink()));
}
private ResourceId getBaseOutputDirectory() {
@@ -241,9 +235,7 @@
}
SimpleSink<Void> sink = makeSimpleSink();
- WriteFiles<String, ?, String> write =
- WriteFiles.to(sink, SerializableFunctions.<String>identity())
- .withSharding(new LargestInt());
+ WriteFiles<String, ?, String> write = WriteFiles.to(sink).withSharding(new LargestInt());
p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
.apply(IDENTITY_MAP)
.apply(write);
@@ -264,8 +256,7 @@
Arrays.asList("one", "two", "three", "four", "five", "six"),
IDENTITY_MAP,
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity())
- .withNumShards(20));
+ WriteFiles.to(makeSimpleSink()).withNumShards(20));
}
/** Test a WriteFiles transform with an empty PCollection. */
@@ -273,11 +264,7 @@
@Category(NeedsRunner.class)
public void testWriteWithEmptyPCollection() throws IOException {
List<String> inputs = new ArrayList<>();
- runWrite(
- inputs,
- IDENTITY_MAP,
- getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()));
+ runWrite(inputs, IDENTITY_MAP, getBaseOutputFilename(), WriteFiles.to(makeSimpleSink()));
}
/** Test a WriteFiles with a windowed PCollection. */
@@ -295,7 +282,7 @@
inputs,
new WindowAndReshuffle<>(Window.<String>into(FixedWindows.of(Duration.millis(2)))),
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()));
+ WriteFiles.to(makeSimpleSink()));
}
/** Test a WriteFiles with sessions. */
@@ -314,7 +301,7 @@
inputs,
new WindowAndReshuffle<>(Window.<String>into(Sessions.withGapDuration(Duration.millis(1)))),
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity()));
+ WriteFiles.to(makeSimpleSink()));
}
@Test
@@ -328,15 +315,12 @@
inputs,
Window.<String>into(FixedWindows.of(Duration.millis(2))),
getBaseOutputFilename(),
- WriteFiles.to(makeSimpleSink(), SerializableFunctions.<String>identity())
- .withMaxNumWritersPerBundle(2)
- .withWindowedWrites());
+ WriteFiles.to(makeSimpleSink()).withMaxNumWritersPerBundle(2).withWindowedWrites());
}
public void testBuildWrite() {
SimpleSink<Void> sink = makeSimpleSink();
- WriteFiles<String, ?, String> write =
- WriteFiles.to(sink, SerializableFunctions.<String>identity()).withNumShards(3);
+ WriteFiles<String, ?, String> write = WriteFiles.to(sink).withNumShards(3);
assertThat((SimpleSink<Void>) write.getSink(), is(sink));
PTransform<PCollection<String>, PCollectionView<Integer>> originalSharding =
write.getSharding();
@@ -358,7 +342,7 @@
@Test
public void testDisplayData() {
- DynamicDestinations<String, Void> dynamicDestinations =
+ DynamicDestinations<String, Void, String> dynamicDestinations =
DynamicFileDestinations.constant(
DefaultFilenamePolicy.fromParams(
new Params()
@@ -374,8 +358,7 @@
builder.add(DisplayData.item("foo", "bar"));
}
};
- WriteFiles<String, ?, String> write =
- WriteFiles.to(sink, SerializableFunctions.<String>identity());
+ WriteFiles<String, ?, String> write = WriteFiles.to(sink);
DisplayData displayData = DisplayData.from(write);
@@ -391,9 +374,7 @@
"Must use windowed writes when applying WriteFiles to an unbounded PCollection");
SimpleSink<Void> sink = makeSimpleSink();
- p.apply(Create.of("foo"))
- .setIsBoundedInternal(IsBounded.UNBOUNDED)
- .apply(WriteFiles.to(sink, SerializableFunctions.<String>identity()));
+ p.apply(Create.of("foo")).setIsBoundedInternal(IsBounded.UNBOUNDED).apply(WriteFiles.to(sink));
p.run();
}
@@ -408,13 +389,13 @@
SimpleSink<Void> sink = makeSimpleSink();
p.apply(Create.of("foo"))
.setIsBoundedInternal(IsBounded.UNBOUNDED)
- .apply(WriteFiles.to(sink, SerializableFunctions.<String>identity()).withWindowedWrites());
+ .apply(WriteFiles.to(sink).withWindowedWrites());
p.run();
}
// Test DynamicDestinations class. Expects user values to be string-encoded integers.
// Stores the integer mod 5 as the destination, and uses that in the file prefix.
- static class TestDestinations extends DynamicDestinations<String, Integer> {
+ static class TestDestinations extends DynamicDestinations<String, Integer, String> {
private ResourceId baseOutputDirectory;
TestDestinations(ResourceId baseOutputDirectory) {
@@ -422,6 +403,11 @@
}
@Override
+ public String formatRecord(String record) {
+ return "record_" + record;
+ }
+
+ @Override
public Integer getDestination(String element) {
return Integer.valueOf(element) % 5;
}
@@ -444,14 +430,6 @@
}
}
- // Test format function. Prepend a string to each record before writing.
- static class TestDynamicFormatFunction implements SerializableFunction<String, String> {
- @Override
- public String apply(String input) {
- return "record_" + input;
- }
- }
-
@Test
@Category(NeedsRunner.class)
public void testDynamicDestinationsBounded() throws Exception {
@@ -495,8 +473,7 @@
// If emptyShards==true make numShards larger than the number of elements per destination.
// This will force every destination to generate some empty shards.
int numShards = emptyShards ? 2 * numInputs / 5 : 2;
- WriteFiles<String, Integer, String> writeFiles =
- WriteFiles.to(sink, new TestDynamicFormatFunction()).withNumShards(numShards);
+ WriteFiles<String, Integer, String> writeFiles = WriteFiles.to(sink).withNumShards(numShards);
PCollection<String> input = p.apply(Create.timestamped(inputs, timestamps));
if (!bounded) {
@@ -521,7 +498,7 @@
@Test
public void testShardedDisplayData() {
- DynamicDestinations<String, Void> dynamicDestinations =
+ DynamicDestinations<String, Void, String> dynamicDestinations =
DynamicFileDestinations.constant(
DefaultFilenamePolicy.fromParams(
new Params()
@@ -537,8 +514,7 @@
builder.add(DisplayData.item("foo", "bar"));
}
};
- WriteFiles<String, ?, String> write =
- WriteFiles.to(sink, SerializableFunctions.<String>identity()).withNumShards(1);
+ WriteFiles<String, ?, String> write = WriteFiles.to(sink).withNumShards(1);
DisplayData displayData = DisplayData.from(write);
assertThat(displayData, hasDisplayItem("sink", sink.getClass()));
assertThat(displayData, includesDisplayDataFor("sink", sink));
@@ -547,7 +523,7 @@
@Test
public void testCustomShardStrategyDisplayData() {
- DynamicDestinations<String, Void> dynamicDestinations =
+ DynamicDestinations<String, Void, String> dynamicDestinations =
DynamicFileDestinations.constant(
DefaultFilenamePolicy.fromParams(
new Params()
@@ -564,7 +540,7 @@
}
};
WriteFiles<String, ?, String> write =
- WriteFiles.to(sink, SerializableFunctions.<String>identity())
+ WriteFiles.to(sink)
.withSharding(
new PTransform<PCollection<String>, PCollectionView<Integer>>() {
@Override
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java
index 93650dd..12fe633 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java
@@ -35,6 +35,7 @@
import org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.CountingSource;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.Read;
@@ -110,7 +111,7 @@
public void emptyCompositeSucceeds() {
PCollection<Long> created =
PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarLongCoder.of());
TransformHierarchy.Node node = hierarchy.pushNode("Create", PBegin.in(pipeline), Create.of(1));
hierarchy.setOutput(created);
hierarchy.popNode();
@@ -139,7 +140,7 @@
public void producingOwnAndOthersOutputsFails() {
PCollection<Long> created =
PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarLongCoder.of());
hierarchy.pushNode("Create", PBegin.in(pipeline), Create.of(1));
hierarchy.setOutput(created);
hierarchy.popNode();
@@ -147,8 +148,11 @@
final PCollectionList<Long> appended =
pcList.and(
- PCollection.<Long>createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)
+ PCollection.createPrimitiveOutputInternal(
+ pipeline,
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ VarLongCoder.of())
.setName("prim"));
hierarchy.pushNode(
"AddPc",
@@ -171,7 +175,7 @@
public void producingOwnOutputWithCompositeFails() {
final PCollection<Long> comp =
PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarLongCoder.of());
PTransform<PBegin, PCollection<Long>> root =
new PTransform<PBegin, PCollection<Long>>() {
@Override
@@ -327,7 +331,7 @@
PCollection<Long> created =
PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarLongCoder.of());
SingleOutput<Long, Long> pardo =
ParDo.of(
@@ -340,7 +344,7 @@
PCollection<Long> mapped =
PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarLongCoder.of());
TransformHierarchy.Node compositeNode = hierarchy.pushNode("Create", begin, create);
hierarchy.finishSpecifyingInput();
@@ -499,13 +503,11 @@
@Test
public void visitIsTopologicallyOrdered() {
PCollection<String> one =
- PCollection.<String>createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)
- .setCoder(StringUtf8Coder.of());
+ PCollection.createPrimitiveOutputInternal(
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());
final PCollection<Integer> two =
- PCollection.<Integer>createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(VarIntCoder.of());
+ PCollection.createPrimitiveOutputInternal(
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED, VarIntCoder.of());
final PDone done = PDone.in(pipeline);
final TupleTag<String> oneTag = new TupleTag<String>() {};
final TupleTag<Integer> twoTag = new TupleTag<Integer>() {};
@@ -617,13 +619,14 @@
@Test
public void visitDoesNotVisitSkippedNodes() {
PCollection<String> one =
- PCollection.<String>createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED)
- .setCoder(StringUtf8Coder.of());
+ PCollection.createPrimitiveOutputInternal(
+ pipeline,
+ WindowingStrategy.globalDefault(),
+ IsBounded.BOUNDED,
+ StringUtf8Coder.of());
final PCollection<Integer> two =
- PCollection.<Integer>createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED)
- .setCoder(VarIntCoder.of());
+ PCollection.createPrimitiveOutputInternal(
+ pipeline, WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED, VarIntCoder.of());
final PDone done = PDone.in(pipeline);
final TupleTag<String> oneTag = new TupleTag<String>() {};
final TupleTag<Integer> twoTag = new TupleTag<Integer>() {};
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
index e7b680a..9956d5c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
@@ -27,8 +27,7 @@
import java.util.Arrays;
import java.util.EnumSet;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.WriteFiles;
@@ -85,10 +84,13 @@
// Issue below: PCollection.createPrimitiveOutput should not be used
// from within a composite transform.
return PCollectionList.of(
- Arrays.asList(result, PCollection.<String>createPrimitiveOutputInternal(
- b.getPipeline(),
- WindowingStrategy.globalDefault(),
- result.isBounded())));
+ Arrays.asList(
+ result,
+ PCollection.createPrimitiveOutputInternal(
+ b.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ result.isBounded(),
+ StringUtf8Coder.of())));
}
}
@@ -105,11 +107,6 @@
return PDone.in(input.getPipeline());
}
-
- @Override
- protected Coder<?> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
}
@Test
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/dataflow/TestCountingSource.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/dataflow/TestCountingSource.java
index 9fcc3c5..338ea38 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/dataflow/TestCountingSource.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/dataflow/TestCountingSource.java
@@ -248,7 +248,7 @@
public void validate() {}
@Override
- public Coder<KV<Integer, Integer>> getDefaultOutputCoder() {
+ public Coder<KV<Integer, Integer>> getOutputCoder() {
return KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
}
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CombineFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CombineFnTesterTest.java
new file mode 100644
index 0000000..15198b2
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/CombineFnTesterTest.java
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.testing;
+
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Sum;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
+import org.hamcrest.TypeSafeMatcher;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link CombineFnTester}.
+ *
+ * <p>Most tests build an instrumented {@link CombineFn} that sets an {@link AtomicBoolean}
+ * when it observes one particular kind of invocation (an empty accumulator being merged, a
+ * single accumulator holding all input, many accumulators merged at once, repeated merges, or
+ * inputs added out of order), then assert that {@link CombineFnTester#testCombineFn} produced it.
+ */
+@RunWith(JUnit4.class)
+public class CombineFnTesterTest {
+  /**
+   * Checks that the tester merges at least one empty accumulator. Every input is positive, so an
+   * accumulator whose value is zero can only be one that never received an input.
+   */
+  @Test
+  public void checksMergeWithEmptyAccumulators() {
+    final AtomicBoolean sawEmpty = new AtomicBoolean(false);
+    CombineFn<Integer, Integer, Integer> combineFn =
+        new CombineFn<Integer, Integer, Integer>() {
+          @Override
+          public Integer createAccumulator() {
+            return 0;
+          }
+
+          @Override
+          public Integer addInput(Integer accumulator, Integer input) {
+            return accumulator + input;
+          }
+
+          @Override
+          public Integer mergeAccumulators(Iterable<Integer> accumulators) {
+            int result = 0;
+            for (int accum : accumulators) {
+              if (accum == 0) {
+                sawEmpty.set(true);
+              }
+              result += accum;
+            }
+            return result;
+          }
+
+          @Override
+          public Integer extractOutput(Integer accumulator) {
+            return accumulator;
+          }
+        };
+
+    CombineFnTester.testCombineFn(combineFn, Arrays.asList(1, 2, 3, 4, 5), 15);
+    assertThat(sawEmpty.get(), is(true));
+  }
+
+  /**
+   * Checks that at least one of the tester's runs combines all input in a single accumulator.
+   * {@code accumCount} counts {@code createAccumulator} calls since the previous
+   * {@code extractOutput}, so a count of exactly one at extraction time means a single shard.
+   */
+  @Test
+  public void checksWithSingleShard() {
+    final AtomicBoolean sawSingleShard = new AtomicBoolean(false);
+    CombineFn<Integer, Integer, Integer> combineFn =
+        new CombineFn<Integer, Integer, Integer>() {
+          int accumCount = 0;
+
+          @Override
+          public Integer createAccumulator() {
+            accumCount++;
+            return 0;
+          }
+
+          @Override
+          public Integer addInput(Integer accumulator, Integer input) {
+            return accumulator + input;
+          }
+
+          @Override
+          public Integer mergeAccumulators(Iterable<Integer> accumulators) {
+            int result = 0;
+            for (int accum : accumulators) {
+              result += accum;
+            }
+            return result;
+          }
+
+          @Override
+          public Integer extractOutput(Integer accumulator) {
+            if (accumCount == 1) {
+              sawSingleShard.set(true);
+            }
+            // Reset so that later invocations are counted independently.
+            accumCount = 0;
+            return accumulator;
+          }
+        };
+
+    CombineFnTester.testCombineFn(combineFn, Arrays.asList(1, 2, 3, 4, 5), 15);
+    assertThat(sawSingleShard.get(), is(true));
+  }
+
+  /**
+   * Checks that the tester passes more than two accumulators to a single
+   * {@code mergeAccumulators} call.
+   */
+  @Test
+  public void checksWithShards() {
+    final AtomicBoolean sawManyShards = new AtomicBoolean(false);
+    CombineFn<Integer, Integer, Integer> combineFn =
+        new CombineFn<Integer, Integer, Integer>() {
+          @Override
+          public Integer createAccumulator() {
+            return 0;
+          }
+
+          @Override
+          public Integer addInput(Integer accumulator, Integer input) {
+            return accumulator + input;
+          }
+
+          @Override
+          public Integer mergeAccumulators(Iterable<Integer> accumulators) {
+            if (Iterables.size(accumulators) > 2) {
+              sawManyShards.set(true);
+            }
+            int result = 0;
+            for (int accum : accumulators) {
+              result += accum;
+            }
+            return result;
+          }
+
+          @Override
+          public Integer extractOutput(Integer accumulator) {
+            return accumulator;
+          }
+        };
+
+    CombineFnTester.testCombineFn(
+        combineFn, Arrays.asList(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3), 30);
+    assertThat(sawManyShards.get(), is(true));
+  }
+
+  /**
+   * Checks that at least one of the tester's runs calls {@code mergeAccumulators} more than once
+   * before extracting the output.
+   */
+  @Test
+  public void checksWithMultipleMerges() {
+    final AtomicBoolean sawMultipleMerges = new AtomicBoolean(false);
+    CombineFn<Integer, Integer, Integer> combineFn =
+        new CombineFn<Integer, Integer, Integer>() {
+          int mergeCalls = 0;
+
+          @Override
+          public Integer createAccumulator() {
+            return 0;
+          }
+
+          @Override
+          public Integer addInput(Integer accumulator, Integer input) {
+            return accumulator + input;
+          }
+
+          @Override
+          public Integer mergeAccumulators(Iterable<Integer> accumulators) {
+            mergeCalls++;
+            int result = 0;
+            for (int accum : accumulators) {
+              result += accum;
+            }
+            return result;
+          }
+
+          @Override
+          public Integer extractOutput(Integer accumulator) {
+            if (mergeCalls > 1) {
+              sawMultipleMerges.set(true);
+            }
+            // Reset so that later invocations are counted independently.
+            mergeCalls = 0;
+            return accumulator;
+          }
+        };
+
+    CombineFnTester.testCombineFn(combineFn, Arrays.asList(1, 1, 2, 2, 3, 3, 4, 4, 5, 5), 30);
+    assertThat(sawMultipleMerges.get(), is(true));
+  }
+
+  /**
+   * Checks that the tester does not always add inputs in their given (ascending) order, by
+   * recording when an element is added to an accumulator after a larger one.
+   */
+  @Test
+  public void checksAlternateOrder() {
+    final AtomicBoolean sawOutOfOrder = new AtomicBoolean(false);
+    CombineFn<Integer, List<Integer>, Integer> combineFn =
+        new CombineFn<Integer, List<Integer>, Integer>() {
+          @Override
+          public List<Integer> createAccumulator() {
+            return new ArrayList<>();
+          }
+
+          @Override
+          public List<Integer> addInput(List<Integer> accumulator, Integer input) {
+            // If the input is being added to an empty accumulator, it's not known to be
+            // out of order, and it cannot be compared to the previous element. If the elements
+            // are out of order (relative to the input) a greater element will be added before
+            // a smaller one.
+            if (!accumulator.isEmpty() && accumulator.get(accumulator.size() - 1) > input) {
+              sawOutOfOrder.set(true);
+            }
+            accumulator.add(input);
+            return accumulator;
+          }
+
+          @Override
+          public List<Integer> mergeAccumulators(Iterable<List<Integer>> accumulators) {
+            List<Integer> result = new ArrayList<>();
+            for (List<Integer> accum : accumulators) {
+              result.addAll(accum);
+            }
+            return result;
+          }
+
+          @Override
+          public Integer extractOutput(List<Integer> accumulator) {
+            int value = 0;
+            for (int i : accumulator) {
+              value += i;
+            }
+            return value;
+          }
+        };
+
+    CombineFnTester.testCombineFn(
+        combineFn, Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), 105);
+    assertThat(sawOutOfOrder.get(), is(true));
+  }
+
+  /**
+   * Checks that the {@link Matcher} overload of {@link CombineFnTester#testCombineFn} applies the
+   * matcher to the combined output, and that a non-matching output surfaces as an
+   * {@link AssertionError}.
+   */
+  @Test
+  public void usesMatcher() {
+    final AtomicBoolean matcherUsed = new AtomicBoolean(false);
+    Matcher<Integer> matcher =
+        new TypeSafeMatcher<Integer>() {
+          @Override
+          public void describeTo(Description description) {
+            // With an empty description a mismatch is reported with a blank "Expected:"
+            // clause, which makes failures hard to diagnose.
+            description.appendText("an integer equal to 30");
+          }
+
+          @Override
+          protected boolean matchesSafely(Integer item) {
+            matcherUsed.set(true);
+            return item == 30;
+          }
+        };
+    CombineFnTester.testCombineFn(
+        Sum.ofIntegers(), Arrays.asList(1, 1, 2, 2, 3, 3, 4, 4, 5, 5), matcher);
+    assertThat(matcherUsed.get(), is(true));
+    try {
+      CombineFnTester.testCombineFn(
+          Sum.ofIntegers(), Arrays.asList(1, 2, 3, 4, 5), Matchers.not(Matchers.equalTo(15)));
+    } catch (AssertionError ignored) {
+      // Success! Return to avoid the call to fail();
+      return;
+    }
+    fail("The matcher should have failed, throwing an error");
+  }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ApproximateQuantilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ApproximateQuantilesTest.java
index 9e0b3cc..e180833 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ApproximateQuantilesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ApproximateQuantilesTest.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.transforms;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
@@ -129,7 +129,7 @@
@Test
public void testSingleton() {
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
Arrays.asList(389),
Arrays.asList(389, 389, 389, 389, 389));
@@ -137,7 +137,7 @@
@Test
public void testSimpleQuantiles() {
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
intRange(101),
Arrays.asList(0, 25, 50, 75, 100));
@@ -145,7 +145,7 @@
@Test
public void testUnevenQuantiles() {
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(37),
intRange(5000),
quantileMatcher(5000, 37, 20 /* tolerance */));
@@ -153,7 +153,7 @@
@Test
public void testLargerQuantiles() {
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(50),
intRange(10001),
quantileMatcher(10001, 50, 20 /* tolerance */));
@@ -161,7 +161,7 @@
@Test
public void testTightEpsilon() {
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(10).withEpsilon(0.01),
intRange(10001),
quantileMatcher(10001, 10, 5 /* tolerance */));
@@ -174,7 +174,7 @@
for (int i = 0; i < 10; i++) {
all.addAll(intRange(size));
}
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
all,
Arrays.asList(0, 25, 50, 75, 100));
@@ -190,7 +190,7 @@
for (int i = 300; i < 1000; i++) {
all.add(3);
}
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
all,
Arrays.asList(1, 2, 3, 3, 3));
@@ -202,7 +202,7 @@
for (int i = 1; i < 1000; i++) {
all.add((int) Math.log(i));
}
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
all,
Arrays.asList(0, 5, 6, 6, 6));
@@ -214,7 +214,7 @@
for (int i = 1; i < 1000; i++) {
all.add(1000 / i);
}
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<Integer>create(5),
all,
Arrays.asList(1, 1, 2, 4, 1000));
@@ -224,11 +224,11 @@
public void testAlternateComparator() {
List<String> inputs = Arrays.asList(
"aa", "aaa", "aaaa", "b", "ccccc", "dddd", "zz");
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.<String>create(3),
inputs,
Arrays.asList("aa", "b", "zz"));
- checkCombineFn(
+ testCombineFn(
ApproximateQuantilesCombineFn.create(3, new OrderByLength()),
inputs,
Arrays.asList("b", "aaa", "ccccc"));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java
index b24d82d..52fedc6 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CombineTest.java
@@ -19,7 +19,7 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasNamespace;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
@@ -695,11 +695,11 @@
@Test
public void testBinaryCombineFnWithNulls() {
- checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, 5), 45);
- checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, 5), 30);
- checkCombineFn(new NullCombiner(), Arrays.asList(3, 3, null), 18);
- checkCombineFn(new NullCombiner(), Arrays.asList(null, 3, null), 12);
- checkCombineFn(new NullCombiner(), Arrays.<Integer>asList(null, null, null), 8);
+ testCombineFn(new NullCombiner(), Arrays.asList(3, 3, 5), 45);
+ testCombineFn(new NullCombiner(), Arrays.asList(null, 3, 5), 30);
+ testCombineFn(new NullCombiner(), Arrays.asList(3, 3, null), 18);
+ testCombineFn(new NullCombiner(), Arrays.asList(null, 3, null), 12);
+ testCombineFn(new NullCombiner(), Arrays.<Integer>asList(null, null, null), 8);
}
private static final class TestProdInt extends Combine.BinaryCombineIntegerFn {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
index a05d31c..81ad947 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
@@ -25,7 +25,9 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.io.InputStream;
@@ -47,6 +49,10 @@
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
+import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.options.ValueProviders;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.SourceTestUtils;
@@ -54,8 +60,8 @@
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create.Values.CreateSource;
import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -307,28 +313,24 @@
@Test
public void testCreateTimestampedDefaultOutputCoderUsingCoder() throws Exception {
Coder<Record> coder = new RecordCoder();
- PBegin pBegin = PBegin.in(p);
Create.TimestampedValues<Record> values =
Create.timestamped(
TimestampedValue.of(new Record(), new Instant(0)),
TimestampedValue.<Record>of(new Record2(), new Instant(0)))
.withCoder(coder);
- Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
- assertThat(defaultCoder, equalTo(coder));
+ assertThat(p.apply(values).getCoder(), equalTo(coder));
}
@Test
public void testCreateTimestampedDefaultOutputCoderUsingTypeDescriptor() throws Exception {
Coder<Record> coder = new RecordCoder();
p.getCoderRegistry().registerCoderForClass(Record.class, coder);
- PBegin pBegin = PBegin.in(p);
Create.TimestampedValues<Record> values =
Create.timestamped(
TimestampedValue.of(new Record(), new Instant(0)),
TimestampedValue.<Record>of(new Record2(), new Instant(0)))
.withType(new TypeDescriptor<Record>() {});
- Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
- assertThat(defaultCoder, equalTo(coder));
+ assertThat(p.apply(values).getCoder(), equalTo(coder));
}
@Test
@@ -353,6 +355,62 @@
 p.run();
 }
+  /** JSON mapper used by {@link #testCreateOfProvider()} to round-trip pipeline options. */
+  private static final ObjectMapper MAPPER =
+      new ObjectMapper()
+          .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
+
+  /** Testing options for {@link #testCreateOfProvider()}. */
+  public interface CreateOfProviderOptions extends PipelineOptions {
+    ValueProvider<String> getFoo();
+
+    void setFoo(ValueProvider<String> value);
+  }
+
+  /**
+   * Checks {@link Create#ofProvider} with a static provider, a nested provider derived from a
+   * static one, and a provider read from the pipeline options whose asserted value is injected
+   * through the options passed to {@code p.run}.
+   */
+  @Test
+  @Category(ValidatesRunner.class)
+  public void testCreateOfProvider() throws Exception {
+    // Value fixed at construction time.
+    ValueProvider<String> staticFoo = StaticValueProvider.of("foo");
+    PAssert.that(p.apply("Static", Create.ofProvider(staticFoo, StringUtf8Coder.of())))
+        .containsInAnyOrder("foo");
+
+    // Value derived from a static provider by appending a suffix.
+    SerializableFunction<String, String> appendBar =
+        new SerializableFunction<String, String>() {
+          @Override
+          public String apply(String input) {
+            return input + "bar";
+          }
+        };
+    ValueProvider<String> nestedFoo =
+        NestedValueProvider.of(StaticValueProvider.of("foo"), appendBar);
+    PAssert.that(p.apply("Static nested", Create.ofProvider(nestedFoo, StringUtf8Coder.of())))
+        .containsInAnyOrder("foobar");
+
+    // Provider taken from the options; the asserted value is injected via run()'s options.
+    CreateOfProviderOptions templateOptions = p.getOptions().as(CreateOfProviderOptions.class);
+    ValueProvider<String> runtimeFoo = templateOptions.getFoo();
+    PAssert.that(p.apply("Runtime", Create.ofProvider(runtimeFoo, StringUtf8Coder.of())))
+        .containsInAnyOrder("runtime foo");
+
+    // Serialize the construction-time options, set "foo", and run with the rebuilt options.
+    String optionsJson = MAPPER.writeValueAsString(p.getOptions());
+    String runtimeOptionsJson =
+        ValueProviders.updateSerializedOptions(
+            optionsJson, ImmutableMap.of("foo", "runtime foo"));
+    CreateOfProviderOptions runtimeOptions =
+        MAPPER.readValue(runtimeOptionsJson, PipelineOptions.class)
+            .as(CreateOfProviderOptions.class);
+
+    p.run(runtimeOptions);
+  }
+
 @Test
 public void testCreateGetName() {
 assertEquals("Create.Values", Create.of(1, 2, 3).getName());
@@ -364,31 +412,25 @@
public void testCreateDefaultOutputCoderUsingInference() throws Exception {
Coder<Record> coder = new RecordCoder();
p.getCoderRegistry().registerCoderForClass(Record.class, coder);
- PBegin pBegin = PBegin.in(p);
- Create.Values<Record> values = Create.of(new Record(), new Record(), new Record());
- Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
- assertThat(defaultCoder, equalTo(coder));
+ assertThat(
+ p.apply(Create.of(new Record(), new Record(), new Record())).getCoder(), equalTo(coder));
}
@Test
public void testCreateDefaultOutputCoderUsingCoder() throws Exception {
Coder<Record> coder = new RecordCoder();
- PBegin pBegin = PBegin.in(p);
- Create.Values<Record> values =
- Create.of(new Record(), new Record2()).withCoder(coder);
- Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
- assertThat(defaultCoder, equalTo(coder));
+ assertThat(
+ p.apply(Create.of(new Record(), new Record2()).withCoder(coder)).getCoder(),
+ equalTo(coder));
}
@Test
public void testCreateDefaultOutputCoderUsingTypeDescriptor() throws Exception {
Coder<Record> coder = new RecordCoder();
p.getCoderRegistry().registerCoderForClass(Record.class, coder);
- PBegin pBegin = PBegin.in(p);
Create.Values<Record> values =
Create.of(new Record(), new Record2()).withType(new TypeDescriptor<Record>() {});
- Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
- assertThat(defaultCoder, equalTo(coder));
+ assertThat(p.apply(values).getCoder(), equalTo(coder));
}
@Test
@@ -434,12 +476,12 @@
}
@Test
- public void testSourceGetDefaultOutputCoderReturnsConstructorCoder() throws Exception {
+ public void testSourceGetOutputCoderReturnsConstructorCoder() throws Exception {
Coder<Integer> coder = VarIntCoder.of();
CreateSource<Integer> source =
CreateSource.fromIterable(ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8), coder);
- Coder<Integer> defaultCoder = source.getDefaultOutputCoder();
+ Coder<Integer> defaultCoder = source.getOutputCoder();
assertThat(defaultCoder, equalTo(coder));
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java
index a8cb843..5dbe176 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java
@@ -228,7 +228,7 @@
public void testFlattenNoListsNoCoder() {
// not ValidatesRunner because it should fail at pipeline construction time anyhow.
thrown.expect(IllegalStateException.class);
- thrown.expectMessage("cannot provide a Coder for empty");
+ thrown.expectMessage("Unable to return a default Coder");
PCollectionList.<ClassWithoutCoder>empty(p)
.apply(Flatten.<ClassWithoutCoder>pCollections());
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
index 8fcb4c0..a76714c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
@@ -70,7 +70,6 @@
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.hamcrest.Matcher;
import org.joda.time.Duration;
@@ -423,11 +422,11 @@
new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@Override
public PCollection<KV<String, Integer>> expand(PBegin input) {
- return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.UNBOUNDED)
- .setTypeDescriptor(new TypeDescriptor<KV<String, Integer>>() {});
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
}
});
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MaxTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MaxTest.java
index 52043e1..a298a5e 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MaxTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MaxTest.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.transforms;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
@@ -45,7 +45,7 @@
@Test
public void testMaxIntegerFn() {
- checkCombineFn(
+ testCombineFn(
Max.ofIntegers(),
Lists.newArrayList(1, 2, 3, 4),
4);
@@ -53,7 +53,7 @@
@Test
public void testMaxLongFn() {
- checkCombineFn(
+ testCombineFn(
Max.ofLongs(),
Lists.newArrayList(1L, 2L, 3L, 4L),
4L);
@@ -61,7 +61,7 @@
@Test
public void testMaxDoubleFn() {
- checkCombineFn(
+ testCombineFn(
Max.ofDoubles(),
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
4.0);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MeanTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MeanTest.java
index 79ebc25..e138135 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MeanTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MeanTest.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.transforms;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.junit.Assert.assertEquals;
import com.google.common.collect.Lists;
@@ -64,7 +64,7 @@
@Test
public void testMeanFn() throws Exception {
- checkCombineFn(
+ testCombineFn(
Mean.<Integer>of(),
Lists.newArrayList(1, 2, 3, 4),
2.5);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MinTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MinTest.java
index 1ece09b..a515b63 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MinTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MinTest.java
@@ -18,7 +18,7 @@
package org.apache.beam.sdk.transforms;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
@@ -45,7 +45,7 @@
}
@Test
public void testMinIntegerFn() {
- checkCombineFn(
+ testCombineFn(
Min.ofIntegers(),
Lists.newArrayList(1, 2, 3, 4),
1);
@@ -53,7 +53,7 @@
@Test
public void testMinLongFn() {
- checkCombineFn(
+ testCombineFn(
Min.ofLongs(),
Lists.newArrayList(1L, 2L, 3L, 4L),
1L);
@@ -61,7 +61,7 @@
@Test
public void testMinDoubleFn() {
- checkCombineFn(
+ testCombineFn(
Min.ofDoubles(),
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
1.0);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SumTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SumTest.java
index 9d2c6f6..e5bf904 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SumTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SumTest.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.transforms;
-import static org.apache.beam.sdk.TestUtils.checkCombineFn;
+import static org.apache.beam.sdk.testing.CombineFnTester.testCombineFn;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
@@ -51,7 +51,7 @@
@Test
public void testSumIntegerFn() {
- checkCombineFn(
+ testCombineFn(
Sum.ofIntegers(),
Lists.newArrayList(1, 2, 3, 4),
10);
@@ -59,7 +59,7 @@
@Test
public void testSumLongFn() {
- checkCombineFn(
+ testCombineFn(
Sum.ofLongs(),
Lists.newArrayList(1L, 2L, 3L, 4L),
10L);
@@ -67,7 +67,7 @@
@Test
public void testSumDoubleFn() {
- checkCombineFn(
+ testCombineFn(
Sum.ofDoubles(),
Lists.newArrayList(1.0, 2.0, 3.0, 4.0),
10.0);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ViewTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ViewTest.java
index cdd03d9..bfb8b5a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ViewTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ViewTest.java
@@ -60,7 +60,6 @@
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
@@ -1340,11 +1339,11 @@
new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@Override
public PCollection<KV<String, Integer>> expand(PBegin input) {
- return PCollection.<KV<String, Integer>>createPrimitiveOutputInternal(
- input.getPipeline(),
- WindowingStrategy.globalDefault(),
- PCollection.IsBounded.UNBOUNDED)
- .setTypeDescriptor(new TypeDescriptor<KV<String, Integer>>() {});
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(),
+ WindowingStrategy.globalDefault(),
+ PCollection.IsBounded.UNBOUNDED,
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()));
}
})
.apply(view);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java
new file mode 100644
index 0000000..132a1ff
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java
@@ -0,0 +1,763 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static org.apache.beam.sdk.transforms.Watch.Growth.afterTimeSinceNewOutput;
+import static org.apache.beam.sdk.transforms.Watch.Growth.afterTotalOf;
+import static org.apache.beam.sdk.transforms.Watch.Growth.allOf;
+import static org.apache.beam.sdk.transforms.Watch.Growth.eitherOf;
+import static org.apache.beam.sdk.transforms.Watch.Growth.never;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.joda.time.Duration.standardSeconds;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Ordering;
+import com.google.common.collect.Sets;
+import com.google.common.hash.HashCode;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesSplittableParDo;
+import org.apache.beam.sdk.transforms.Watch.Growth.PollFn;
+import org.apache.beam.sdk.transforms.Watch.Growth.PollResult;
+import org.apache.beam.sdk.transforms.Watch.GrowthState;
+import org.apache.beam.sdk.transforms.Watch.GrowthTracker;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.joda.time.ReadableDuration;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link Watch}. */
+@RunWith(JUnit4.class)
+public class WatchTest implements Serializable {
+  @Rule public transient TestPipeline p = TestPipeline.create();
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testSinglePollMultipleInputs() {
+    PCollection<KV<String, String>> res =
+        p.apply(Create.of("a", "b"))
+            .apply(
+                Watch.growthOf(
+                        new PollFn<String, String>() {
+                          @Override
+                          public PollResult<String> apply(String input, Instant time) {
+                            return PollResult.complete(
+                                time, Arrays.asList(input + ".foo", input + ".bar"));
+                          }
+                        })
+                    .withPollInterval(Duration.ZERO));
+
+    PAssert.that(res)
+        .containsInAnyOrder(
+            Arrays.asList(
+                KV.of("a", "a.foo"), KV.of("a", "a.bar"),
+                KV.of("b", "b.foo"), KV.of("b", "b.bar")));
+
+    p.run();
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testMultiplePollsWithTerminationBecauseOutputIsFinal() {
+    testMultiplePolls(false);
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testMultiplePollsWithTerminationDueToTerminationCondition() {
+    testMultiplePolls(true);
+  }
+
+  private void testMultiplePolls(boolean terminationConditionElapsesBeforeOutputIsFinal) {
+    List<Integer> all = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
+
+    PCollection<Integer> res =
+        p.apply(Create.of("a"))
+            .apply(
+                Watch.growthOf(
+                        new TimedPollFn<String, Integer>(
+                            all,
+                            standardSeconds(1) /* timeToOutputEverything */,
+                            standardSeconds(3) /* timeToDeclareOutputFinal */,
+                            standardSeconds(30) /* timeToFail */))
+                    .withTerminationPerInput(
+                        Watch.Growth.<String>afterTotalOf(
+                            standardSeconds(
+                                // At 2 seconds, all output has been yielded, but not yet
+                                // declared final - so polling should terminate per termination
+                                // condition.
+                                // At 3 seconds, all output has been yielded (and declared final),
+                                // so polling should terminate because of that without waiting for
+                                // 100 seconds.
+                                terminationConditionElapsesBeforeOutputIsFinal ? 2 : 100)))
+                    .withPollInterval(Duration.millis(300))
+                    .withOutputCoder(VarIntCoder.of()))
+            .apply("Drop input", Values.<Integer>create());
+
+    PAssert.that(res).containsInAnyOrder(all);
+
+    p.run();
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testMultiplePollsStopAfterTimeSinceNewOutput() {
+    List<Integer> all = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
+
+    PCollection<Integer> res =
+        p.apply(Create.of("a"))
+            .apply(
+                Watch.growthOf(
+                        new TimedPollFn<String, Integer>(
+                            all,
+                            standardSeconds(1) /* timeToOutputEverything */,
+                            // Never declare output final
+                            standardSeconds(1000) /* timeToDeclareOutputFinal */,
+                            standardSeconds(30) /* timeToFail */))
+                    // Should terminate after 4 seconds - earlier than timeToFail
+                    .withTerminationPerInput(
+                        Watch.Growth.<String>afterTimeSinceNewOutput(standardSeconds(3)))
+                    .withPollInterval(Duration.millis(300))
+                    .withOutputCoder(VarIntCoder.of()))
+            .apply("Drop input", Values.<Integer>create());
+
+    PAssert.that(res).containsInAnyOrder(all);
+
+    p.run();
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testSinglePollWithManyResults() {
+    // More than the default 100 elements per checkpoint for direct runner.
+    final long numResults = 3000;
+    PCollection<KV<String, Integer>> res =
+        p.apply(Create.of("a"))
+            .apply(
+                Watch.growthOf(
+                        new PollFn<String, KV<String, Integer>>() {
+                          @Override
+                          public PollResult<KV<String, Integer>> apply(String input, Instant time) {
+                            String pollId = UUID.randomUUID().toString();
+                            List<KV<String, Integer>> output = Lists.newArrayList();
+                            for (int i = 0; i < numResults; ++i) {
+                              output.add(KV.of(pollId, i));
+                            }
+                            return PollResult.complete(time, output);
+                          }
+                        })
+                    .withTerminationPerInput(Watch.Growth.<String>afterTotalOf(standardSeconds(1)))
+                    .withPollInterval(Duration.millis(1))
+                    .withOutputCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())))
+            .apply("Drop input", Values.<KV<String, Integer>>create());
+
+    PAssert.that("Poll called only once", res.apply(Keys.<String>create()))
+        .satisfies(
+            new SerializableFunction<Iterable<String>, Void>() {
+              @Override
+              public Void apply(Iterable<String> pollIds) {
+                assertEquals(1, Sets.newHashSet(pollIds).size());
+                return null;
+              }
+            });
+    PAssert.that("Yields all expected results", res.apply("Drop poll id", Values.<Integer>create()))
+        .satisfies(
+            new SerializableFunction<Iterable<Integer>, Void>() {
+              @Override
+              public Void apply(Iterable<Integer> input) {
+                assertEquals(
+                    "Total number of results mismatches",
+                    numResults,
+                    Lists.newArrayList(input).size());
+                assertEquals("Results are not unique", numResults, Sets.newHashSet(input).size());
+                return null;
+              }
+            });
+
+    p.run();
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testMultiplePollsWithManyResults() {
+    final long numResults = 3000;
+    List<Integer> all = Lists.newArrayList();
+    for (int i = 0; i < numResults; ++i) {
+      all.add(i);
+    }
+
+    PCollection<TimestampedValue<Integer>> res =
+        p.apply(Create.of("a"))
+            .apply(
+                Watch.growthOf(
+                        new TimedPollFn<String, Integer>(
+                            all,
+                            standardSeconds(1) /* timeToOutputEverything */,
+                            standardSeconds(3) /* timeToDeclareOutputFinal */,
+                            standardSeconds(30) /* timeToFail */))
+                    .withPollInterval(Duration.millis(500))
+                    .withOutputCoder(VarIntCoder.of()))
+            .apply(ReifyTimestamps.<String, Integer>inValues())
+            .apply("Drop timestamped input", Values.<TimestampedValue<Integer>>create());
+
+    PAssert.that(res)
+        .satisfies(
+            new SerializableFunction<Iterable<TimestampedValue<Integer>>, Void>() {
+              @Override
+              public Void apply(Iterable<TimestampedValue<Integer>> outputs) {
+                Function<TimestampedValue<Integer>, Integer> extractValueFn =
+                    new Function<TimestampedValue<Integer>, Integer>() {
+                      @Nullable
+                      @Override
+                      public Integer apply(@Nullable TimestampedValue<Integer> input) {
+                        return input.getValue();
+                      }
+                    };
+                Function<TimestampedValue<Integer>, Instant> extractTimestampFn =
+                    new Function<TimestampedValue<Integer>, Instant>() {
+                      @Nullable
+                      @Override
+                      public Instant apply(@Nullable TimestampedValue<Integer> input) {
+                        return input.getTimestamp();
+                      }
+                    };
+
+                Ordering<TimestampedValue<Integer>> byValue =
+                    Ordering.natural().onResultOf(extractValueFn);
+                Ordering<TimestampedValue<Integer>> byTimestamp =
+                    Ordering.natural().onResultOf(extractTimestampFn);
+                // New outputs appear in timestamp order because each output's assigned timestamp
+                // is Instant.now() at the time of poll.
+                assertTrue(
+                    "Outputs must be in timestamp order",
+                    byTimestamp.isOrdered(byValue.sortedCopy(outputs)));
+                assertEquals(
+                    "Yields all expected values",
+                    numResults,
+                    Sets.newHashSet(Iterables.transform(outputs, extractValueFn)).size());
+                assertThat(
+                    "Poll called more than once",
+                    Sets.newHashSet(Iterables.transform(outputs, extractTimestampFn)).size(),
+                    greaterThan(1));
+                return null;
+              }
+            });
+
+    p.run();
+  }
+
+  /**
+   * Gradually emits items from the given list over timeToOutputEverything, stamping each with the
+   * poll time; declares output final after timeToDeclareOutputFinal and fails after timeToFail.
+   */
+  private static class TimedPollFn<InputT, OutputT> implements PollFn<InputT, OutputT> {
+    private final Instant baseTime;
+    private final List<OutputT> outputs;
+    private final Duration timeToOutputEverything;
+    private final Duration timeToDeclareOutputFinal;
+    private final Duration timeToFail;
+
+    public TimedPollFn(
+        List<OutputT> outputs,
+        Duration timeToOutputEverything,
+        Duration timeToDeclareOutputFinal,
+        Duration timeToFail) {
+      this.baseTime = Instant.now();
+      this.outputs = outputs;
+      this.timeToOutputEverything = timeToOutputEverything;
+      this.timeToDeclareOutputFinal = timeToDeclareOutputFinal;
+      this.timeToFail = timeToFail;
+    }
+
+    @Override
+    public PollResult<OutputT> apply(InputT input, Instant time) {
+      Instant now = Instant.now();
+      Duration elapsed = new Duration(baseTime, now); // Same instant as emitted timestamps.
+      if (elapsed.isLongerThan(timeToFail)) {
+        fail(
+            String.format(
+                "Poll called %s after base time, which is longer than the threshold of %s",
+                elapsed, timeToFail));
+      }
+
+      double fractionElapsed = 1.0 * elapsed.getMillis() / timeToOutputEverything.getMillis();
+      int numToEmit = (int) Math.min(outputs.size(), fractionElapsed * outputs.size());
+      List<TimestampedValue<OutputT>> toEmit = Lists.newArrayList();
+      for (int i = 0; i < numToEmit; ++i) {
+        toEmit.add(TimestampedValue.of(outputs.get(i), now));
+      }
+      return elapsed.isLongerThan(timeToDeclareOutputFinal)
+          ? PollResult.complete(toEmit)
+          : PollResult.incomplete(toEmit).withWatermark(now);
+    }
+  }
+
+  @Test
+  public void testTerminationConditionsNever() {
+    Watch.Growth.Never<Object> c = never();
+    Integer state = c.forNewInput(Instant.now(), null);
+    assertFalse(c.canStopPolling(Instant.now(), state));
+  }
+
+  @Test
+  public void testTerminationConditionsAfterTotalOf() {
+    Instant now = Instant.now();
+    Watch.Growth.AfterTotalOf<Object> c = afterTotalOf(standardSeconds(5));
+    KV<Instant, ReadableDuration> state = c.forNewInput(now, null);
+    assertFalse(c.canStopPolling(now, state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(6)), state));
+  }
+
+  @Test
+  public void testTerminationConditionsAfterTimeSinceNewOutput() {
+    Instant now = Instant.now();
+    Watch.Growth.AfterTimeSinceNewOutput<Object> c = afterTimeSinceNewOutput(standardSeconds(5));
+    KV<Instant, ReadableDuration> state = c.forNewInput(now, null);
+    assertFalse(c.canStopPolling(now, state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(6)), state));
+
+    state = c.onSeenNewOutput(now.plus(standardSeconds(3)), state);
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(6)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(9)), state));
+
+    state = c.onSeenNewOutput(now.plus(standardSeconds(5)), state);
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(6)), state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(9)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(11)), state));
+  }
+
+  @Test
+  public void testTerminationConditionsEitherOf() {
+    Instant now = Instant.now();
+    Watch.Growth.AfterTotalOf<Object> a = afterTotalOf(standardSeconds(5));
+    Watch.Growth.AfterTotalOf<Object> b = afterTotalOf(standardSeconds(10));
+
+    Watch.Growth.BinaryCombined<
+            Object, KV<Instant, ReadableDuration>, KV<Instant, ReadableDuration>>
+        c = eitherOf(a, b);
+    KV<KV<Instant, ReadableDuration>, KV<Instant, ReadableDuration>> state =
+        c.forNewInput(now, null);
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(7)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(12)), state));
+  }
+
+  @Test
+  public void testTerminationConditionsAllOf() {
+    Instant now = Instant.now();
+    Watch.Growth.AfterTotalOf<Object> a = afterTotalOf(standardSeconds(5));
+    Watch.Growth.AfterTotalOf<Object> b = afterTotalOf(standardSeconds(10));
+
+    Watch.Growth.BinaryCombined<
+            Object, KV<Instant, ReadableDuration>, KV<Instant, ReadableDuration>>
+        c = allOf(a, b);
+    KV<KV<Instant, ReadableDuration>, KV<Instant, ReadableDuration>> state =
+        c.forNewInput(now, null);
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(3)), state));
+    assertFalse(c.canStopPolling(now.plus(standardSeconds(7)), state));
+    assertTrue(c.canStopPolling(now.plus(standardSeconds(12)), state));
+  }
+
+  private static GrowthTracker<String, Integer> newTracker(GrowthState<String, Integer> state) {
+    return new GrowthTracker<>(StringUtf8Coder.of(), state, never());
+  }
+
+  private static GrowthTracker<String, Integer> newTracker() {
+    return newTracker(new GrowthState<String, Integer>(never().forNewInput(Instant.now(), null)));
+  }
+
+  @Test
+  public void testGrowthTrackerCheckpointEmpty() {
+    // Checkpoint an empty tracker.
+    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthState<String, Integer> residual = tracker.checkpoint();
+    GrowthState<String, Integer> primary = tracker.currentRestriction();
+    Watch.Growth.Never<String> condition = never();
+    assertEquals(
+        new GrowthState<>(
+                Collections.<HashCode, Instant>emptyMap() /* completed */,
+                Collections.<TimestampedValue<String>>emptyList() /* pending */,
+                true /* isOutputFinal */,
+                (Integer) null /* terminationState */,
+                BoundedWindow.TIMESTAMP_MAX_VALUE /* pollWatermark */)
+            .toString(condition),
+        primary.toString(condition));
+    assertEquals(
+        new GrowthState<>(
+                Collections.<HashCode, Instant>emptyMap() /* completed */,
+                Collections.<TimestampedValue<String>>emptyList() /* pending */,
+                false /* isOutputFinal */,
+                0 /* terminationState */,
+                BoundedWindow.TIMESTAMP_MIN_VALUE /* pollWatermark */)
+            .toString(condition),
+        residual.toString(condition));
+  }
+
+  @Test
+  public void testGrowthTrackerCheckpointNonEmpty() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.incomplete(
+                Arrays.asList(
+                    TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                    TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                    TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                    TimestampedValue.of("b", now.plus(standardSeconds(2)))))
+            .withWatermark(now.plus(standardSeconds(7))));
+
+    assertEquals(now.plus(standardSeconds(1)), tracker.getWatermark());
+    assertTrue(tracker.hasPending());
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertTrue(tracker.hasPending());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertTrue(tracker.hasPending());
+    assertEquals(now.plus(standardSeconds(3)), tracker.getWatermark());
+
+    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+    GrowthTracker<String, Integer> primaryTracker = newTracker(tracker.currentRestriction());
+
+    // Verify primary: should contain what the current tracker claimed, and nothing else.
+    assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("a", primaryTracker.tryClaimNextPending().getValue());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("b", primaryTracker.tryClaimNextPending().getValue());
+    assertFalse(primaryTracker.hasPending());
+    assertFalse(primaryTracker.shouldPollMore());
+    // No more pending elements in primary restriction, and no polling.
+    primaryTracker.checkDone();
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, primaryTracker.getWatermark());
+
+    // Verify residual: should contain what the current tracker didn't claim.
+    assertEquals(now.plus(standardSeconds(3)), residualTracker.getWatermark());
+    assertTrue(residualTracker.hasPending());
+    assertEquals("c", residualTracker.tryClaimNextPending().getValue());
+    assertTrue(residualTracker.hasPending());
+    assertEquals("d", residualTracker.tryClaimNextPending().getValue());
+    assertFalse(residualTracker.hasPending());
+    assertTrue(residualTracker.shouldPollMore());
+    // No more pending elements in residual restriction, but poll watermark still holds.
+    assertEquals(now.plus(standardSeconds(7)), residualTracker.getWatermark());
+
+    // Verify current tracker: it was checkpointed, so should contain nothing else.
+    assertNull(tracker.tryClaimNextPending());
+    tracker.checkDone();
+    assertFalse(tracker.hasPending());
+    assertFalse(tracker.shouldPollMore());
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark());
+  }
+
+  @Test
+  public void testGrowthTrackerOutputFullyBeforeCheckpointIncomplete() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.incomplete(
+                Arrays.asList(
+                    TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                    TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                    TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                    TimestampedValue.of("b", now.plus(standardSeconds(2)))))
+            .withWatermark(now.plus(standardSeconds(7))));
+
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertEquals("c", tracker.tryClaimNextPending().getValue());
+    assertEquals("d", tracker.tryClaimNextPending().getValue());
+    assertFalse(tracker.hasPending());
+    assertEquals(now.plus(standardSeconds(7)), tracker.getWatermark());
+
+    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+    GrowthTracker<String, Integer> primaryTracker = newTracker(tracker.currentRestriction());
+
+    // Verify primary: should contain what the current tracker claimed, and nothing else.
+    assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("a", primaryTracker.tryClaimNextPending().getValue());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("b", primaryTracker.tryClaimNextPending().getValue());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("c", primaryTracker.tryClaimNextPending().getValue());
+    assertTrue(primaryTracker.hasPending());
+    assertEquals("d", primaryTracker.tryClaimNextPending().getValue());
+    assertFalse(primaryTracker.hasPending());
+    assertFalse(primaryTracker.shouldPollMore());
+    // No more pending elements in primary restriction, and no polling.
+    primaryTracker.checkDone();
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, primaryTracker.getWatermark());
+
+    // Verify residual: everything was claimed, so nothing is pending, but polling continues.
+    assertFalse(residualTracker.hasPending());
+    assertTrue(residualTracker.shouldPollMore());
+    // No more pending elements in residual restriction, but poll watermark still holds.
+    assertEquals(now.plus(standardSeconds(7)), residualTracker.getWatermark());
+
+    // Verify current tracker: it was checkpointed, so should contain nothing else.
+    tracker.checkDone();
+    assertFalse(tracker.hasPending());
+    assertFalse(tracker.shouldPollMore());
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark());
+  }
+
+  @Test
+  public void testGrowthTrackerPollAfterCheckpointIncompleteWithNewOutputs() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.incomplete(
+                Arrays.asList(
+                    TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                    TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                    TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                    TimestampedValue.of("b", now.plus(standardSeconds(2)))))
+            .withWatermark(now.plus(standardSeconds(7))));
+
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertEquals("c", tracker.tryClaimNextPending().getValue());
+    assertEquals("d", tracker.tryClaimNextPending().getValue());
+
+    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
+    // Simulate resuming from the checkpoint and adding more elements.
+    {
+      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(
+                  Arrays.asList(
+                      TimestampedValue.of("e", now.plus(standardSeconds(5))),
+                      TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                      TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                      TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                      TimestampedValue.of("b", now.plus(standardSeconds(2))),
+                      TimestampedValue.of("f", now.plus(standardSeconds(8)))))
+              .withWatermark(now.plus(standardSeconds(12))));
+
+      assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark());
+      assertEquals("e", residualTracker.tryClaimNextPending().getValue());
+      assertEquals(now.plus(standardSeconds(8)), residualTracker.getWatermark());
+      assertEquals("f", residualTracker.tryClaimNextPending().getValue());
+
+      assertFalse(residualTracker.hasPending());
+      assertTrue(residualTracker.shouldPollMore());
+      assertEquals(now.plus(standardSeconds(12)), residualTracker.getWatermark());
+    }
+    // Try same without an explicitly specified watermark.
+    {
+      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(
+              Arrays.asList(
+                  TimestampedValue.of("e", now.plus(standardSeconds(5))),
+                  TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                  TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                  TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                  TimestampedValue.of("b", now.plus(standardSeconds(2))),
+                  TimestampedValue.of("f", now.plus(standardSeconds(8))))));
+
+      assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark());
+      assertEquals("e", residualTracker.tryClaimNextPending().getValue());
+      assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark());
+      assertEquals("f", residualTracker.tryClaimNextPending().getValue());
+
+      assertFalse(residualTracker.hasPending());
+      assertTrue(residualTracker.shouldPollMore());
+      assertEquals(now.plus(standardSeconds(5)), residualTracker.getWatermark());
+    }
+  }
+
+  @Test
+  public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputs() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.incomplete(
+                Arrays.asList(
+                    TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                    TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                    TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                    TimestampedValue.of("b", now.plus(standardSeconds(2)))))
+            .withWatermark(now.plus(standardSeconds(7))));
+
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertEquals("c", tracker.tryClaimNextPending().getValue());
+    assertEquals("d", tracker.tryClaimNextPending().getValue());
+
+    // Simulate resuming from the checkpoint but there are no new elements.
+    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
+    {
+      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(
+                  Arrays.asList(
+                      TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                      TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                      TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                      TimestampedValue.of("b", now.plus(standardSeconds(2)))))
+              .withWatermark(now.plus(standardSeconds(12))));
+
+      assertFalse(residualTracker.hasPending());
+      assertTrue(residualTracker.shouldPollMore());
+      assertEquals(now.plus(standardSeconds(12)), residualTracker.getWatermark());
+    }
+    // Try the same without an explicitly specified watermark
+    {
+      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(
+              Arrays.asList(
+                  TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                  TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                  TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                  TimestampedValue.of("b", now.plus(standardSeconds(2))))));
+      // No new elements and no explicit watermark supplied - should reuse old watermark.
+      assertEquals(now.plus(standardSeconds(7)), residualTracker.getWatermark());
+    }
+  }
+
+  @Test
+  public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputsNoWatermark() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.incomplete(
+            Arrays.asList(
+                TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                TimestampedValue.of("b", now.plus(standardSeconds(2))))));
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertEquals("c", tracker.tryClaimNextPending().getValue());
+    assertEquals("d", tracker.tryClaimNextPending().getValue());
+    assertEquals(now.plus(standardSeconds(1)), tracker.getWatermark());
+
+    // Simulate resuming from the checkpoint but there are no new elements.
+    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
+    GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+    residualTracker.addNewAsPending(
+        PollResult.incomplete(
+            Arrays.asList(
+                TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                TimestampedValue.of("b", now.plus(standardSeconds(2))))));
+    // No new elements and no explicit watermark supplied - should keep old watermark.
+    assertEquals(now.plus(standardSeconds(1)), residualTracker.getWatermark());
+  }
+
+  @Test
+  public void testGrowthTrackerRepeatedEmptyPollWatermark() {
+    // Empty poll result with no watermark
+    {
+      GrowthTracker<String, Integer> tracker = newTracker();
+      tracker.addNewAsPending(
+          PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
+      assertEquals(BoundedWindow.TIMESTAMP_MIN_VALUE, tracker.getWatermark());
+
+      // Simulate resuming from the checkpoint but there are still no new elements.
+      GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
+      // No new elements and no explicit watermark supplied - still no watermark.
+      assertEquals(BoundedWindow.TIMESTAMP_MIN_VALUE, residualTracker.getWatermark());
+    }
+    // Empty poll result with watermark
+    {
+      Instant now = Instant.now();
+      GrowthTracker<String, Integer> tracker = newTracker();
+      tracker.addNewAsPending(
+          PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList())
+              .withWatermark(now));
+      assertEquals(now, tracker.getWatermark());
+
+      // Simulate resuming from the checkpoint but there are still no new elements.
+      GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+      residualTracker.addNewAsPending(
+          PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
+      // No new elements and no explicit watermark supplied - should keep old watermark.
+      assertEquals(now, residualTracker.getWatermark());
+    }
+  }
+
+  @Test
+  public void testGrowthTrackerOutputFullyBeforeCheckpointComplete() {
+    Instant now = Instant.now();
+    GrowthTracker<String, Integer> tracker = newTracker();
+    tracker.addNewAsPending(
+        PollResult.complete(
+            Arrays.asList(
+                TimestampedValue.of("d", now.plus(standardSeconds(4))),
+                TimestampedValue.of("c", now.plus(standardSeconds(3))),
+                TimestampedValue.of("a", now.plus(standardSeconds(1))),
+                TimestampedValue.of("b", now.plus(standardSeconds(2))))));
+
+    assertEquals("a", tracker.tryClaimNextPending().getValue());
+    assertEquals("b", tracker.tryClaimNextPending().getValue());
+    assertEquals("c", tracker.tryClaimNextPending().getValue());
+    assertEquals("d", tracker.tryClaimNextPending().getValue());
+    assertFalse(tracker.hasPending());
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark());
+
+    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+
+    // Verify residual: should be empty, since output was final.
+    residualTracker.checkDone();
+    assertFalse(residualTracker.hasPending());
+    assertFalse(residualTracker.shouldPollMore());
+    // Residual is done (output was final), so its watermark has advanced to the end of time.
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, residualTracker.getWatermark());
+
+    // Verify current tracker: it was checkpointed, so should contain nothing else.
+    tracker.checkDone();
+    assertFalse(tracker.hasPending());
+    assertFalse(tracker.shouldPollMore());
+    assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark());
+  }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowTest.java
index 65af7a1..5b6d046 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/WindowTest.java
@@ -31,19 +31,30 @@
import static org.mockito.Mockito.when;
import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
+import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesCustomWindowMerging;
import org.apache.beam.sdk.testing.ValidatesRunner;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
@@ -570,4 +581,177 @@
assertThat(data, not(hasDisplayItem("trigger")));
assertThat(data, not(hasDisplayItem("allowedLateness")));
}
+
+  @Test
+  @Category({ValidatesRunner.class, UsesCustomWindowMerging.class})
+  public void testMergingCustomWindows() {
+    Instant start = new Instant(0L);
+    List<TimestampedValue<String>> input = new ArrayList<>();
+    input.add(TimestampedValue.of("big", start.plus(Duration.standardSeconds(10))));
+    input.add(TimestampedValue.of("small1", start.plus(Duration.standardSeconds(20))));
+    // "small2" lies outside the big window, so it must not be merged into it.
+    input.add(TimestampedValue.of("small2", start.plus(Duration.standardSeconds(39))));
+
+    PCollection<Long> count =
+        pipeline
+            .apply(Create.timestamped(input))
+            .apply(Window.<String>into(new CustomWindowFn<String>()))
+            .apply(Combine.globally(Count.<String>combineFn()).withoutDefaults());
+
+    // "big" and "small1" share the merged big window (count 2); "small2" is alone (count 1).
+    PAssert.that("Wrong number of elements in output collection", count)
+        .containsInAnyOrder(2L, 1L);
+    pipeline.run();
+  }
+
+  // Some runners have a dedicated merge implementation for keyed collections, so the custom
+  // WindowFn is exercised on a PCollection<KV> as well.
+  @Test
+  @Category({ValidatesRunner.class, UsesCustomWindowMerging.class})
+  public void testMergingCustomWindowsKeyedCollection() {
+    Instant start = new Instant(0L);
+    List<TimestampedValue<KV<Integer, String>>> input = new ArrayList<>();
+    input.add(TimestampedValue.of(KV.of(0, "big"), start.plus(Duration.standardSeconds(10))));
+    input.add(TimestampedValue.of(KV.of(1, "small1"), start.plus(Duration.standardSeconds(20))));
+    // "small2" lies outside the big window, so it must not be merged into it.
+    input.add(TimestampedValue.of(KV.of(2, "small2"), start.plus(Duration.standardSeconds(39))));
+
+    PCollection<Long> count =
+        pipeline
+            .apply(Create.timestamped(input))
+            .apply(Window.<KV<Integer, String>>into(new CustomWindowFn<KV<Integer, String>>()))
+            .apply(Combine.globally(Count.<KV<Integer, String>>combineFn()).withoutDefaults());
+
+    // "big" and "small1" share the merged big window (count 2); "small2" is alone (count 1).
+    PAssert.that("Wrong number of elements in output collection", count)
+        .containsInAnyOrder(2L, 1L);
+    pipeline.run();
+  }
+
+  /** An {@link IntervalWindow} that also records whether it was created for a "big" element. */
+  private static class CustomWindow extends IntervalWindow {
+
+    /**
+     * Whether this window was assigned to a "big" element. Final because windows are used as
+     * hash keys (see {@link #hashCode()}) and must not change after construction.
+     */
+    private final boolean isBig;
+
+    /** Creates a small (non-big) window covering {@code [start, end)}. */
+    CustomWindow(Instant start, Instant end) {
+      this(start, end, false);
+    }
+
+    CustomWindow(Instant start, Instant end, boolean isBig) {
+      super(start, end);
+      this.isBig = isBig;
+    }
+
+    // Two windows with the same interval but a different size flag are distinct windows.
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      if (!super.equals(o)) {
+        return false;
+      }
+      CustomWindow that = (CustomWindow) o;
+      return isBig == that.isBig;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(super.hashCode(), isBig);
+    }
+  }
+
+  /**
+   * Encodes a {@link CustomWindow} as its interval (via the {@link IntervalWindow} coder) followed
+   * by the "big" flag written as a varint {@code 1} or {@code 0}.
+   */
+  private static class CustomWindowCoder extends CustomCoder<CustomWindow> {
+
+    private static final Coder<IntervalWindow> INTERVAL_CODER = IntervalWindow.getCoder();
+    private static final VarIntCoder FLAG_CODER = VarIntCoder.of();
+    private static final CustomWindowCoder INSTANCE = new CustomWindowCoder();
+
+    /** Returns the shared coder instance. */
+    public static CustomWindowCoder of() {
+      return INSTANCE;
+    }
+
+    @Override
+    public void encode(CustomWindow window, OutputStream outStream) throws IOException {
+      INTERVAL_CODER.encode(window, outStream);
+      int flag = window.isBig ? 1 : 0;
+      FLAG_CODER.encode(flag, outStream);
+    }
+
+    @Override
+    public CustomWindow decode(InputStream inStream) throws IOException {
+      IntervalWindow interval = INTERVAL_CODER.decode(inStream);
+      int flag = FLAG_CODER.decode(inStream);
+      return new CustomWindow(interval.start(), interval.end(), flag != 0);
+    }
+
+    // Deterministic exactly when both component coders are.
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {
+      INTERVAL_CODER.verifyDeterministic();
+      FLAG_CODER.verifyDeterministic();
+    }
+  }
+
+  /**
+   * Assigns each element whose value is {@code "big"} to a 30 second {@link CustomWindow} and
+   * every other element to a 5 second one, then merges each big window with the small windows
+   * lying strictly inside it.
+   *
+   * <p>Elements must be {@code String}s or {@link KV}s with {@code String} values.
+   */
+  private static class CustomWindowFn<T> extends WindowFn<T, CustomWindow> {
+
+    @Override
+    public Collection<CustomWindow> assignWindows(AssignContext c) throws Exception {
+      // Accepting both String and KV<?, String> elements lets one WindowFn serve the keyed and the
+      // unkeyed test; checking the element type at runtime avoids an unchecked generic cast.
+      Object element = c.element();
+      String value =
+          element instanceof KV ? (String) ((KV<?, ?>) element).getValue() : (String) element;
+      boolean isBig = "big".equals(value);
+      Duration size = Duration.standardSeconds(isBig ? 30 : 5);
+      return Collections.singletonList(
+          new CustomWindow(c.timestamp(), c.timestamp().plus(size), isBig));
+    }
+
+    /**
+     * Merges each big window with the small windows that lie strictly inside it.
+     *
+     * <p>Big windows are collected in a first pass, so the result does not depend on whether
+     * {@code c.windows()} yields a small window before or after the big window containing it.
+     * Each small window joins at most one big window (the first containing one, in iteration
+     * order), so no window is merged twice. A big window containing no small window is left
+     * untouched.
+     */
+    @Override
+    public void mergeWindows(MergeContext c) throws Exception {
+      List<CustomWindow> bigWindows = new ArrayList<>();
+      // mergeGroups.get(i) holds bigWindows.get(i) plus the small windows it absorbs.
+      List<List<CustomWindow>> mergeGroups = new ArrayList<>();
+      for (CustomWindow window : c.windows()) {
+        if (window.isBig) {
+          bigWindows.add(window);
+          List<CustomWindow> group = new ArrayList<>();
+          group.add(window);
+          mergeGroups.add(group);
+        }
+      }
+      for (CustomWindow window : c.windows()) {
+        if (window.isBig) {
+          continue;
+        }
+        for (int i = 0; i < bigWindows.size(); i++) {
+          CustomWindow bigWindow = bigWindows.get(i);
+          if (window.start().isAfter(bigWindow.start())
+              && window.end().isBefore(bigWindow.end())) {
+            mergeGroups.get(i).add(window);
+            break;
+          }
+        }
+      }
+      for (int i = 0; i < bigWindows.size(); i++) {
+        List<CustomWindow> group = mergeGroups.get(i);
+        if (group.size() > 1) {
+          c.merge(group, bigWindows.get(i));
+        }
+      }
+    }
+
+    @Override
+    public boolean isCompatible(WindowFn<?, ?> other) {
+      return other instanceof CustomWindowFn;
+    }
+
+    @Override
+    public Coder<CustomWindow> windowCoder() {
+      return CustomWindowCoder.of();
+    }
+
+    @Override
+    public WindowMappingFn<CustomWindow> getDefaultWindowMappingFn() {
+      throw new UnsupportedOperationException("side inputs not supported");
+    }
+  }
+
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java
index 58e2bbd..33503b6 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java
@@ -31,6 +31,7 @@
import java.util.Map;
import java.util.Map.Entry;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -59,9 +60,9 @@
@Test
public void testOfThenHas() {
- PCollection<Object> pCollection = PCollection.createPrimitiveOutputInternal(
- pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED);
- TupleTag<Object> tag = new TupleTag<>();
+    TupleTag<Integer> tag = new TupleTag<>();
+    PCollection<Integer> pCollection = PCollection.createPrimitiveOutputInternal(
+        pipeline, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
assertTrue(PCollectionTuple.of(tag, pCollection).has(tag));
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/TypeDescriptorsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/TypeDescriptorsTest.java
index 1bf0fc9..a4f58da 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/TypeDescriptorsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/TypeDescriptorsTest.java
@@ -25,6 +25,7 @@
import static org.apache.beam.sdk.values.TypeDescriptors.strings;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
import java.util.List;
import java.util.Set;
@@ -70,4 +71,52 @@
assertNotEquals(descriptor, new TypeDescriptor<List<String>>() {});
assertNotEquals(descriptor, new TypeDescriptor<List<Boolean>>() {});
}
+
+  /** Two-parameter fixture whose type arguments the extractor helpers try to resolve. */
+  private interface Generic<FooT, BarT> {}
+
+  /**
+   * Returns an anonymous {@link Generic} whose {@code FooT} is this method's own type variable,
+   * so it cannot be resolved from the instance, while {@code BarT} is fixed to {@code String}.
+   */
+  private static <ActualFooT> Generic<ActualFooT, String> typeErasedGeneric() {
+    Generic<ActualFooT, String> erased = new Generic<ActualFooT, String>() {};
+    return erased;
+  }
+
+  /** Extracts {@code FooT} of {@code instance}; the erased-type test expects {@code null}. */
+  private static <ActualFooT, ActualBarT> TypeDescriptor<ActualFooT> extractFooT(
+      Generic<ActualFooT, ActualBarT> instance) {
+    TypeDescriptors.TypeVariableExtractor<Generic<ActualFooT, ActualBarT>, ActualFooT> extractor =
+        new TypeDescriptors.TypeVariableExtractor<
+            Generic<ActualFooT, ActualBarT>, ActualFooT>() {};
+    return TypeDescriptors.extractFromTypeParameters(instance, Generic.class, extractor);
+  }
+
+  /** Extracts {@code BarT} of {@code instance} as a {@link TypeDescriptor}. */
+  private static <ActualFooT, ActualBarT> TypeDescriptor<ActualBarT> extractBarT(
+      Generic<ActualFooT, ActualBarT> instance) {
+    TypeDescriptors.TypeVariableExtractor<Generic<ActualFooT, ActualBarT>, ActualBarT> extractor =
+        new TypeDescriptors.TypeVariableExtractor<
+            Generic<ActualFooT, ActualBarT>, ActualBarT>() {};
+    return TypeDescriptors.extractFromTypeParameters(instance, Generic.class, extractor);
+  }
+
+  /**
+   * Extracts {@code KV<FooT, BarT>} for {@code instance}; the erased-type test expects {@code null}
+   * when {@code FooT} cannot be resolved.
+   */
+  private static <ActualFooT, ActualBarT> TypeDescriptor<KV<ActualFooT, ActualBarT>> extractKV(
+      Generic<ActualFooT, ActualBarT> instance) {
+    TypeDescriptors.TypeVariableExtractor<
+            Generic<ActualFooT, ActualBarT>, KV<ActualFooT, ActualBarT>>
+        extractor =
+            new TypeDescriptors.TypeVariableExtractor<
+                Generic<ActualFooT, ActualBarT>, KV<ActualFooT, ActualBarT>>() {};
+    return TypeDescriptors.extractFromTypeParameters(instance, Generic.class, extractor);
+  }
+
+  @Test
+  public void testTypeDescriptorsTypeParameterOf() throws Exception {
+    // An anonymous subclass with concrete type arguments, so every parameter is resolvable.
+    Generic<String, Integer> instance = new Generic<String, Integer>() {};
+    assertEquals(strings(), extractFooT(instance));
+    assertEquals(integers(), extractBarT(instance));
+    assertEquals(kvs(strings(), integers()), extractKV(instance));
+  }
+
+  @Test
+  public void testTypeDescriptorsTypeParameterOfErased() throws Exception {
+    Generic<Integer, String> erased = typeErasedGeneric();
+    // Only BarT is fixed in the anonymous class; FooT (and any KV built from it) is unresolvable.
+    assertEquals(strings(), extractBarT(erased));
+    assertNull(extractFooT(erased));
+    assertNull(extractKV(erased));
+  }
}
diff --git a/sdks/java/extensions/sorter/src/main/java/org/apache/beam/sdk/extensions/sorter/SortValues.java b/sdks/java/extensions/sorter/src/main/java/org/apache/beam/sdk/extensions/sorter/SortValues.java
index d1b4d07..cb9d984 100644
--- a/sdks/java/extensions/sorter/src/main/java/org/apache/beam/sdk/extensions/sorter/SortValues.java
+++ b/sdks/java/extensions/sorter/src/main/java/org/apache/beam/sdk/extensions/sorter/SortValues.java
@@ -76,18 +76,14 @@
@Override
public PCollection<KV<PrimaryKeyT, Iterable<KV<SecondaryKeyT, ValueT>>>> expand(
PCollection<KV<PrimaryKeyT, Iterable<KV<SecondaryKeyT, ValueT>>>> input) {
- return input.apply(
- ParDo.of(
- new SortValuesDoFn<PrimaryKeyT, SecondaryKeyT, ValueT>(
- sorterOptions,
- getSecondaryKeyCoder(input.getCoder()),
- getValueCoder(input.getCoder()))));
- }
-
- @Override
- protected Coder<KV<PrimaryKeyT, Iterable<KV<SecondaryKeyT, ValueT>>>> getDefaultOutputCoder(
- PCollection<KV<PrimaryKeyT, Iterable<KV<SecondaryKeyT, ValueT>>>> input) {
- return input.getCoder();
+    SortValuesDoFn<PrimaryKeyT, SecondaryKeyT, ValueT> sortFn =
+        new SortValuesDoFn<PrimaryKeyT, SecondaryKeyT, ValueT>(
+            sorterOptions,
+            getSecondaryKeyCoder(input.getCoder()),
+            getValueCoder(input.getCoder()));
+    return input
+        .apply(ParDo.of(sortFn))
+        .setCoder(input.getCoder());
}
/** Retrieves the {@link Coder} for the secondary key-value pairs. */
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index e2c17b0..1e611db 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -24,7 +24,6 @@
import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimap;
-import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
@@ -35,8 +34,8 @@
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.dataflow.util.CloudObject;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -91,8 +90,9 @@
.setPrimitiveTransformReference(pTransformId)
.setName(getOnlyElement(pTransform.getOutputsMap().keySet()))
.build();
- RunnerApi.Coder coderSpec = coders.get(pCollections.get(
- getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
+ RunnerApi.Coder coderSpec =
+ coders.get(
+ pCollections.get(getOnlyElement(pTransform.getOutputsMap().values())).getCoderId());
Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers =
(Collection) pCollectionIdsToConsumers.get(
getOnlyElement(pTransform.getOutputsMap().values()));
@@ -102,6 +102,7 @@
processBundleInstructionId,
target,
coderSpec,
+ coders,
beamFnDataClient,
consumers);
addStartFunction.accept(runner::registerInputLocation);
@@ -124,6 +125,7 @@
Supplier<String> processBundleInstructionIdSupplier,
BeamFnApi.Target inputTarget,
RunnerApi.Coder coderSpec,
+ Map<String, RunnerApi.Coder> coders,
BeamFnDataClient beamFnDataClientFactory,
Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers)
throws IOException {
@@ -137,17 +139,10 @@
@SuppressWarnings("unchecked")
Coder<WindowedValue<OutputT>> coder =
(Coder<WindowedValue<OutputT>>)
- CloudObjects.coderFromCloudObject(
- CloudObject.fromSpec(
- OBJECT_MAPPER.readValue(
- coderSpec
- .getSpec()
- .getSpec()
- .getParameter()
- .unpack(BytesValue.class)
- .getValue()
- .newInput(),
- Map.class)));
+ CoderTranslation.fromProto(
+ coderSpec,
+ RehydratedComponents.forComponents(
+ RunnerApi.Components.newBuilder().putAllCoders(coders).build()));
this.coder = coder;
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
index eec4dfd..bbed753 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java
@@ -24,7 +24,6 @@
import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimap;
-import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.Map;
import java.util.function.Consumer;
@@ -34,8 +33,8 @@
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.dataflow.util.CloudObject;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -93,6 +92,7 @@
processBundleInstructionId,
target,
coderSpec,
+ coders,
beamFnDataClient);
addStartFunction.accept(runner::registerForOutput);
pCollectionIdsToConsumers.put(
@@ -117,6 +117,7 @@
Supplier<String> processBundleInstructionIdSupplier,
BeamFnApi.Target outputTarget,
RunnerApi.Coder coderSpec,
+ Map<String, RunnerApi.Coder> coders,
BeamFnDataClient beamFnDataClientFactory)
throws IOException {
this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class)
@@ -128,17 +129,10 @@
@SuppressWarnings("unchecked")
Coder<WindowedValue<InputT>> coder =
(Coder<WindowedValue<InputT>>)
- CloudObjects.coderFromCloudObject(
- CloudObject.fromSpec(
- OBJECT_MAPPER.readValue(
- coderSpec
- .getSpec()
- .getSpec()
- .getParameter()
- .unpack(BytesValue.class)
- .getValue()
- .newInput(),
- Map.class)));
+ CoderTranslation.fromProto(
+ coderSpec,
+ RehydratedComponents.forComponents(
+ RunnerApi.Components.newBuilder().putAllCoders(coders).build()));
this.coder = coder;
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/stream/DataStreams.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/stream/DataStreams.java
new file mode 100644
index 0000000..d23d784
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/stream/DataStreams.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.stream;
+
+import com.google.common.io.ByteStreams;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.concurrent.BlockingQueue;
+import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
+
+/**
+ * Adapters between sequences of {@link ByteString}s and Java I/O streams.
+ *
+ * <p>{@link #inbound(Iterator)} treats multiple {@link ByteString}s as a single input stream and
+ * {@link #outbound(CloseableThrowingConsumer)} treats a single {@link OutputStream} as multiple
+ * {@link ByteString}s.
+ */
+public class DataStreams {
+  /**
+   * Converts multiple {@link ByteString}s into a single {@link InputStream}.
+   *
+   * <p>The iterator is accessed lazily. The supplied {@link Iterator} should block until
+   * either it knows that no more values will be provided or it has the next {@link ByteString}.
+   */
+  public static InputStream inbound(Iterator<ByteString> bytes) {
+    return new Inbound(bytes);
+  }
+
+  /**
+   * Converts a single {@link OutputStream} into multiple {@link ByteString}s.
+   *
+   * <p>Not implemented yet; always throws {@link UnsupportedOperationException}.
+   */
+  public static OutputStream outbound(CloseableThrowingConsumer<ByteString> consumer) {
+    // TODO: Migrate logic from BeamFnDataBufferingOutboundObserver
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * An input stream which concatenates multiple {@link ByteString}s. The {@link Iterator} is not
+   * consulted until the first read, and afterwards is only advanced once the current
+   * {@link ByteString} has been fully read.
+   *
+   * <p>Closing this input stream has no effect.
+   */
+  private static class Inbound extends InputStream {
+    /** Stand-in for the current stream before anything has been pulled from the iterator. */
+    private static final InputStream EMPTY_STREAM = new InputStream() {
+      @Override
+      public int read() throws IOException {
+        return -1;
+      }
+    };
+
+    private final Iterator<ByteString> bytes;
+    private InputStream currentStream;
+
+    public Inbound(Iterator<ByteString> bytes) {
+      this.currentStream = EMPTY_STREAM;
+      this.bytes = bytes;
+    }
+
+    @Override
+    public int read() throws IOException {
+      int rval;
+      // Move on to the next ByteString while the current one is exhausted; -1 is only returned
+      // once the iterator has no more ByteStrings.
+      while ((rval = currentStream.read()) == -1 && bytes.hasNext()) {
+        currentStream = bytes.next().newInput();
+      }
+      return rval;
+    }
+
+    /**
+     * Reads up to {@code len} bytes into {@code b} starting at {@code off}, spanning as many
+     * {@link ByteString}s as needed.
+     *
+     * @return the number of bytes read ({@code 0} when {@code len} is {@code 0}), or {@code -1}
+     *     if no bytes remain
+     */
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+      int totalRead = 0;
+      while (true) {
+        // ByteStreams.read only returns fewer bytes than requested at the end of currentStream.
+        totalRead += ByteStreams.read(currentStream, b, off + totalRead, len - totalRead);
+        if (totalRead >= len) {
+          return totalRead;
+        }
+        if (!bytes.hasNext()) {
+          return totalRead > 0 ? totalRead : -1;
+        }
+        currentStream = bytes.next().newInput();
+      }
+    }
+  }
+
+  /**
+   * Allows for one or more writing threads to append values to this iterator while one reading
+   * thread reads values. {@link #hasNext()} and {@link #next()} will block until a value is
+   * available or this has been closed.
+   *
+   * <p>External synchronization must be provided if multiple readers would like to access the
+   * {@link Iterator#hasNext()} and {@link Iterator#next()} methods.
+   *
+   * <p>The order of values which are appended to this iterator is nondeterministic when multiple
+   * threads call {@link #accept(Object)}.
+   *
+   * <p>{@code null} values are not supported: {@code null} marks "no element buffered" in
+   * {@link #hasNext()}.
+   */
+  public static class BlockingQueueIterator<T> implements
+      CloseableThrowingConsumer<T>, Iterator<T> {
+    /** Sentinel enqueued by {@link #close()} to signal that no more values will arrive. */
+    private static final Object POISON_PILL = new Object();
+    private final BlockingQueue<T> queue;
+
+    /** Only accessed by {@link Iterator#hasNext()} and {@link Iterator#next()} methods. */
+    private T currentElement;
+
+    public BlockingQueueIterator(BlockingQueue<T> queue) {
+      this.queue = queue;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked") // The pill is only compared by identity, never returned as T.
+    public void close() throws Exception {
+      queue.put((T) POISON_PILL);
+    }
+
+    @Override
+    public void accept(T t) throws Exception {
+      queue.put(t);
+    }
+
+    /**
+     * Blocks until a value or the close signal is available.
+     *
+     * @throws IllegalStateException if interrupted while waiting; the interrupt flag is restored
+     */
+    @Override
+    public boolean hasNext() {
+      if (currentElement == null) {
+        try {
+          currentElement = queue.take();
+        } catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+          throw new IllegalStateException(e);
+        }
+      }
+      return currentElement != POISON_PILL;
+    }
+
+    @Override
+    public T next() {
+      if (!hasNext()) {
+        throw new NoSuchElementException();
+      }
+      T rval = currentElement;
+      currentElement = null;
+      return rval;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index a7c6666..d712f5f 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -30,7 +30,6 @@
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
@@ -39,8 +38,6 @@
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.Any;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -56,10 +53,11 @@
import org.apache.beam.fn.harness.test.TestExecutors;
import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.MessageWithComponents;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
@@ -79,7 +77,6 @@
@RunWith(JUnit4.class)
public class BeamFnDataReadRunnerTest {
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
.setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
@@ -88,19 +85,19 @@
WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
private static final String CODER_SPEC_ID = "string-coder-id";
private static final RunnerApi.Coder CODER_SPEC;
+ private static final RunnerApi.Components COMPONENTS;
private static final String URN = "urn:org.apache.beam:source:runner:0.1";
static {
try {
- CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
- RunnerApi.SdkFunctionSpec.newBuilder().setSpec(
- RunnerApi.FunctionSpec.newBuilder().setParameter(
- Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER))))
- .build()))
- .build())
- .build())
- .build();
+ MessageWithComponents coderAndComponents = CoderTranslation.toProto(CODER);
+ CODER_SPEC = coderAndComponents.getCoder();
+ COMPONENTS =
+ coderAndComponents
+ .getComponents()
+ .toBuilder()
+ .putCoders(CODER_SPEC_ID, CODER_SPEC)
+ .build();
} catch (IOException e) {
throw new ExceptionInInitializerError(e);
}
@@ -150,7 +147,7 @@
Suppliers.ofInstance(bundleId)::get,
ImmutableMap.of("outputPC",
RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()),
- ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC),
+ COMPONENTS.getCodersMap(),
consumers,
startFunctions::add,
finishFunctions::add);
@@ -200,6 +197,7 @@
bundleId::get,
INPUT_TARGET,
CODER_SPEC,
+ COMPONENTS.getCodersMap(),
mockBeamFnDataClient,
ImmutableList.of(valuesA::add, valuesB::add));
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
index 28838b1..0caf19e 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java
@@ -32,15 +32,12 @@
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.protobuf.Any;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.BytesValue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -53,10 +50,11 @@
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.fn.v1.BeamFnApi;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi.MessageWithComponents;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
@@ -74,7 +72,6 @@
@RunWith(JUnit4.class)
public class BeamFnDataWriteRunnerTest {
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
.setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
@@ -83,19 +80,15 @@
private static final Coder<WindowedValue<String>> CODER =
WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
private static final RunnerApi.Coder CODER_SPEC;
+ private static final RunnerApi.Components COMPONENTS;
private static final String URN = "urn:org.apache.beam:sink:runner:0.1";
static {
try {
- CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
- RunnerApi.SdkFunctionSpec.newBuilder().setSpec(
- RunnerApi.FunctionSpec.newBuilder().setParameter(
- Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER))))
- .build()))
- .build())
- .build())
- .build();
+ MessageWithComponents coderAndComponents = CoderTranslation.toProto(CODER);
+ CODER_SPEC = coderAndComponents.getCoder();
+ COMPONENTS =
+ coderAndComponents.getComponents().toBuilder().putCoders(CODER_ID, CODER_SPEC).build();
} catch (IOException e) {
throw new ExceptionInInitializerError(e);
}
@@ -140,7 +133,7 @@
Suppliers.ofInstance(bundleId)::get,
ImmutableMap.of("inputPC",
RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()),
- ImmutableMap.of(CODER_ID, CODER_SPEC),
+ COMPONENTS.getCodersMap(),
consumers,
startFunctions::add,
finishFunctions::add);
@@ -201,6 +194,7 @@
bundleId::get,
OUTPUT_TARGET,
CODER_SPEC,
+ COMPONENTS.getCodersMap(),
mockBeamFnDataClient);
// Process for bundle id 0
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 98362a2..e269bcc 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -25,7 +25,6 @@
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
@@ -35,19 +34,14 @@
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
-import com.google.protobuf.Message;
-import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
import org.apache.beam.fn.harness.fn.ThrowingConsumer;
import org.apache.beam.fn.harness.fn.ThrowingRunnable;
import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.runners.dataflow.util.DoFnInfo;
-import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.common.runner.v1.RunnerApi;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -66,28 +60,6 @@
/** Tests for {@link FnApiDoFnRunner}. */
@RunWith(JUnit4.class)
public class FnApiDoFnRunnerTest {
-
- private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
- private static final Coder<WindowedValue<String>> STRING_CODER =
- WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
- private static final String STRING_CODER_SPEC_ID = "999L";
- private static final RunnerApi.Coder STRING_CODER_SPEC;
-
- static {
- try {
- STRING_CODER_SPEC = RunnerApi.Coder.newBuilder()
- .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
- .setSpec(RunnerApi.FunctionSpec.newBuilder()
- .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
- OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER))))
- .build())))
- .build())
- .build();
- } catch (IOException e) {
- throw new ExceptionInInitializerError(e);
- }
- }
-
private static class TestDoFn extends DoFn<String, String> {
private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
@@ -117,7 +89,6 @@
*/
@Test
public void testCreatingAndProcessingDoFn() throws Exception {
- Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
String pTransformId = "pTransformId";
String mainOutputId = "101";
String additionalOutputId = "102";
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DataStreamsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DataStreamsTest.java
new file mode 100644
index 0000000..d141570
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DataStreamsTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.stream;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.common.collect.Iterators;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.SynchronousQueue;
+import org.apache.beam.fn.harness.stream.DataStreams.BlockingQueueIterator;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link DataStreams}. */
+@RunWith(JUnit4.class)
+public class DataStreamsTest {
+ private static final ByteString BYTES_A = ByteString.copyFromUtf8("TestData");
+ private static final ByteString BYTES_B = ByteString.copyFromUtf8("SomeOtherTestData");
+
+ @Test
+ public void testEmptyRead() throws Exception {
+ assertEquals(ByteString.EMPTY, read());
+ assertEquals(ByteString.EMPTY, read(ByteString.EMPTY));
+ assertEquals(ByteString.EMPTY, read(ByteString.EMPTY, ByteString.EMPTY));
+ }
+
+ @Test
+ public void testRead() throws Exception {
+ assertEquals(BYTES_A.concat(BYTES_B), read(BYTES_A, BYTES_B));
+ assertEquals(BYTES_A.concat(BYTES_B), read(BYTES_A, ByteString.EMPTY, BYTES_B));
+ assertEquals(BYTES_A.concat(BYTES_B), read(BYTES_A, BYTES_B, ByteString.EMPTY));
+ }
+
+ @Test(timeout = 10_000)
+ public void testBlockingQueueIteratorWithoutBlocking() throws Exception {
+ BlockingQueueIterator<String> iterator =
+ new BlockingQueueIterator<>(new ArrayBlockingQueue<>(3));
+
+ iterator.accept("A");
+ iterator.accept("B");
+ iterator.close();
+
+ assertEquals(Arrays.asList("A", "B"),
+ Arrays.asList(Iterators.toArray(iterator, String.class)));
+ }
+
+ @Test(timeout = 10_000)
+ public void testBlockingQueueIteratorWithBlocking() throws Exception {
+ // The synchronous queue only allows for one element to transfer at a time and blocks
+ // the sending/receiving parties until both parties are there.
+ final BlockingQueueIterator<String> iterator =
+ new BlockingQueueIterator<>(new SynchronousQueue<>());
+ final CompletableFuture<List<String>> valuesFuture = new CompletableFuture<>();
+ Thread appender = new Thread() {
+ @Override
+ public void run() {
+ valuesFuture.complete(Arrays.asList(Iterators.toArray(iterator, String.class)));
+ }
+ };
+ appender.start();
+ iterator.accept("A");
+ iterator.accept("B");
+ iterator.close();
+ assertEquals(Arrays.asList("A", "B"), valuesFuture.get());
+ appender.join();
+ }
+
+ private static ByteString read(ByteString... bytes) throws IOException {
+ return ByteString.readFrom(DataStreams.inbound(Arrays.asList(bytes).iterator()));
+ }
+}
diff --git a/sdks/java/io/amqp/pom.xml b/sdks/java/io/amqp/pom.xml
index 8da9448..4369bb8 100644
--- a/sdks/java/io/amqp/pom.xml
+++ b/sdks/java/io/amqp/pom.xml
@@ -39,6 +39,7 @@
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
+ <scope>test</scope>
</dependency>
<dependency>
diff --git a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java
index 1f307b2..508373f 100644
--- a/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java
+++ b/sdks/java/io/amqp/src/main/java/org/apache/beam/sdk/io/amqp/AmqpIO.java
@@ -246,7 +246,7 @@
}
@Override
- public Coder<Message> getDefaultOutputCoder() {
+ public Coder<Message> getOutputCoder() {
return new AmqpMessageCoder();
}
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
index 32905b7..eacc3e4 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
@@ -289,7 +289,7 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return spec.coder();
}
diff --git a/sdks/java/io/elasticsearch/pom.xml b/sdks/java/io/elasticsearch/pom.xml
index e0a7f21..a021420 100644
--- a/sdks/java/io/elasticsearch/pom.xml
+++ b/sdks/java/io/elasticsearch/pom.xml
@@ -37,11 +37,6 @@
</dependency>
<dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </dependency>
-
- <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
@@ -86,11 +81,6 @@
<version>4.5.2</version>
</dependency>
- <dependency>
- <groupId>joda-time</groupId>
- <artifactId>joda-time</artifactId>
- </dependency>
-
<!-- compile dependencies -->
<dependency>
<groupId>com.google.auto.value</groupId>
@@ -133,6 +123,12 @@
<dependency>
<groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<scope>test</scope>
</dependency>
diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
index 4d76887..5046888 100644
--- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
+++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
@@ -183,10 +183,8 @@
* @param index the index toward which the requests will be issued
* @param type the document type toward which the requests will be issued
* @return the connection configuration object
- * @throws IOException when it fails to connect to Elasticsearch
*/
- public static ConnectionConfiguration create(String[] addresses, String index, String type)
- throws IOException {
+ public static ConnectionConfiguration create(String[] addresses, String index, String type){
checkArgument(
addresses != null,
"ConnectionConfiguration.create(addresses, index, type) called with null address");
@@ -206,25 +204,9 @@
.setIndex(index)
.setType(type)
.build();
- checkVersion(connectionConfiguration);
return connectionConfiguration;
}
- private static void checkVersion(ConnectionConfiguration connectionConfiguration)
- throws IOException {
- RestClient restClient = connectionConfiguration.createClient();
- Response response = restClient.performRequest("GET", "", new BasicHeader("", ""));
- JsonNode jsonNode = parseResponse(response);
- String version = jsonNode.path("version").path("number").asText();
- boolean version2x = version.startsWith("2.");
- restClient.close();
- checkArgument(
- version2x,
- "ConnectionConfiguration.create(addresses, index, type): "
- + "the Elasticsearch version to connect to is different of 2.x. "
- + "This version of the ElasticsearchIO is only compatible with Elasticsearch v2.x");
- }
-
/**
* If Elasticsearch authentication is enabled, provide the username.
*
@@ -398,16 +380,20 @@
@Override
public void validate(PipelineOptions options) {
+ ConnectionConfiguration connectionConfiguration = getConnectionConfiguration();
checkState(
- getConnectionConfiguration() != null,
+ connectionConfiguration != null,
"ElasticsearchIO.read() requires a connection configuration"
+ " to be set via withConnectionConfiguration(configuration)");
+ checkVersion(connectionConfiguration);
}
@Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
builder.addIfNotNull(DisplayData.item("query", getQuery()));
+ builder.addIfNotNull(DisplayData.item("batchSize", getBatchSize()));
+ builder.addIfNotNull(DisplayData.item("scrollKeepalive", getScrollKeepalive()));
getConnectionConfiguration().populateDisplayData(builder);
}
}
@@ -498,7 +484,7 @@
}
@Override
- public Coder<String> getDefaultOutputCoder() {
+ public Coder<String> getOutputCoder() {
return StringUtf8Coder.of();
}
@@ -715,10 +701,12 @@
@Override
public void validate(PipelineOptions options) {
+ ConnectionConfiguration connectionConfiguration = getConnectionConfiguration();
checkState(
- getConnectionConfiguration() != null,
+ connectionConfiguration != null,
"ElasticsearchIO.write() requires a connection configuration"
+ " to be set via withConnectionConfiguration(configuration)");
+ checkVersion(connectionConfiguration);
}
@Override
@@ -828,4 +816,29 @@
}
}
}
+
+ /**
+ * Verifies that the Elasticsearch cluster targeted by {@code connectionConfiguration} reports a
+ * 2.x version, the only major version this ElasticsearchIO supports.
+ *
+ * @param connectionConfiguration configuration used to create the REST client for the check
+ * @throws IllegalArgumentException if the reported version does not start with "2.", or if the
+ * version cannot be retrieved; in the latter case the underlying {@link IOException} is kept as
+ * the cause
+ */
+ private static void checkVersion(ConnectionConfiguration connectionConfiguration) {
+ // try-with-resources guarantees the client is closed even when the check fails.
+ try (RestClient restClient = connectionConfiguration.createClient()) {
+ // Query the root path; the version is read from the response's version.number field.
+ Response response = restClient.performRequest("GET", "", new BasicHeader("", ""));
+ JsonNode jsonNode = parseResponse(response);
+ String version = jsonNode.path("version").path("number").asText();
+ boolean version2x = version.startsWith("2.");
+ checkArgument(version2x, "The Elasticsearch version to connect to is different of 2.x. "
+ + "This version of the ElasticsearchIO is only compatible with Elasticsearch v2.x");
+ } catch (IOException e) {
+ // Attach the cause so connection/authentication failures remain diagnosable.
+ throw new IllegalArgumentException("Cannot check Elasticsearch version", e);
+ }
+ }
}
diff --git a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java
index 2a2dbe9..a6e1cc0 100644
--- a/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java
+++ b/sdks/java/io/elasticsearch/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchTestDataSet.java
@@ -18,7 +18,6 @@
package org.apache.beam.sdk.io.elasticsearch;
-import java.io.IOException;
import org.apache.beam.sdk.io.common.IOTestPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.elasticsearch.client.RestClient;
@@ -37,7 +36,7 @@
public static final long NUM_DOCS = 60000;
public static final int AVERAGE_DOC_SIZE = 25;
public static final int MAX_DOC_SIZE = 35;
- private static String writeIndex = ES_INDEX + org.joda.time.Instant.now().getMillis();
+ private static final String writeIndex = ES_INDEX + System.currentTimeMillis();
/**
* Use this to create the index for reading before IT read tests.
@@ -63,17 +62,15 @@
}
private static void createAndPopulateReadIndex(IOTestPipelineOptions options) throws Exception {
- RestClient restClient = getConnectionConfiguration(options, ReadOrWrite.READ).createClient();
// automatically creates the index and insert docs
- try {
+ try (RestClient restClient = getConnectionConfiguration(options, ReadOrWrite.READ)
+ .createClient()) {
ElasticSearchIOTestUtils.insertTestDocuments(ES_INDEX, ES_TYPE, NUM_DOCS, restClient);
- } finally {
- restClient.close();
}
}
static ElasticsearchIO.ConnectionConfiguration getConnectionConfiguration(
- IOTestPipelineOptions options, ReadOrWrite rOw) throws IOException {
+ IOTestPipelineOptions options, ReadOrWrite rOw){
ElasticsearchIO.ConnectionConfiguration connectionConfiguration =
ElasticsearchIO.ConnectionConfiguration.create(
new String[] {
diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml
index adb7e32..7842bcd 100644
--- a/sdks/java/io/google-cloud-platform/pom.xml
+++ b/sdks/java/io/google-cloud-platform/pom.xml
@@ -46,18 +46,6 @@
</plugin>
</plugins>
</pluginManagement>
-
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamUseDummyRunner>false</beamUseDummyRunner>
- </systemPropertyVariables>
- </configuration>
- </plugin>
- </plugins>
</build>
<profiles>
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index cc288e1..6edbd06 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -50,7 +50,6 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.FileSystems;
@@ -185,13 +184,16 @@
* quotes.apply(Window.<TableRow>into(CalendarWindows.days(1)))
* .apply(BigQueryIO.writeTableRows()
* .withSchema(schema)
- * .to(new SerializableFunction<ValueInSingleWindow, String>() {
- * public String apply(ValueInSingleWindow value) {
+ * .to(new SerializableFunction<ValueInSingleWindow<TableRow>, TableDestination>() {
+ * public TableDestination apply(ValueInSingleWindow<TableRow> value) {
* // The cast below is safe because CalendarWindows.days(1) produces IntervalWindows.
* String dayString = DateTimeFormat.forPattern("yyyy_MM_dd")
* .withZone(DateTimeZone.UTC)
* .print(((IntervalWindow) value.getWindow()).start());
- * return "my-project:output.output_table_" + dayString;
+ * return new TableDestination(
+ * "my-project:output.output_table_" + dayString, // Table spec
+ * "Output for day " + dayString // Table description
+ * );
* }
* }));
* }</pre>
@@ -540,9 +542,7 @@
p.apply("TriggerIdCreation", Create.of(staticJobUuid))
.apply("ViewId", View.<String>asSingleton());
// Apply the traditional Source model.
- rows =
- p.apply(org.apache.beam.sdk.io.Read.from(createSource(staticJobUuid)))
- .setCoder(getDefaultOutputCoder());
+ rows = p.apply(org.apache.beam.sdk.io.Read.from(createSource(staticJobUuid)));
} else {
// Create a singleton job ID token at execution time.
jobIdTokenCollection =
@@ -622,7 +622,8 @@
}
}
})
- .withSideInputs(schemaView, jobIdTokenView));
+ .withSideInputs(schemaView, jobIdTokenView))
+ .setCoder(TableRowJsonCoder.of());
}
PassThroughThenCleanup.CleanupOperation cleanupOperation =
new PassThroughThenCleanup.CleanupOperation() {
@@ -655,11 +656,6 @@
}
@Override
- protected Coder<TableRow> getDefaultOutputCoder() {
- return TableRowJsonCoder.of();
- }
-
- @Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
builder
@@ -1138,11 +1134,6 @@
}
@Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
-
- @Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java
index 2b1eafe..abe559c 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java
@@ -133,7 +133,7 @@
}
@Override
- public Coder<TableRow> getDefaultOutputCoder() {
+ public Coder<TableRow> getOutputCoder() {
return TableRowJsonCoder.of();
}
@@ -183,8 +183,8 @@
List<BoundedSource<TableRow>> avroSources = Lists.newArrayList();
for (ResourceId file : files) {
- avroSources.add(new TransformingSource<>(
- AvroSource.from(file.toString()), function, getDefaultOutputCoder()));
+ avroSources.add(
+ AvroSource.from(file.toString()).withParseFn(function, getOutputCoder()));
}
return ImmutableList.copyOf(avroSources);
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java
index c5c2462..ea4fc4e 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/DynamicDestinations.java
@@ -19,11 +19,11 @@
package org.apache.beam.sdk.io.gcp.bigquery;
import static com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.sdk.values.TypeDescriptors.extractFromTypeParameters;
import com.google.api.services.bigquery.model.TableSchema;
import com.google.common.collect.Lists;
import java.io.Serializable;
-import java.lang.reflect.TypeVariable;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
@@ -32,6 +32,7 @@
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.values.ValueInSingleWindow;
/**
@@ -157,17 +158,16 @@
return destinationCoder;
}
// If dynamicDestinations doesn't provide a coder, try to find it in the coder registry.
- // We must first use reflection to figure out what the type parameter is.
- TypeDescriptor<?> superDescriptor =
- TypeDescriptor.of(getClass()).getSupertype(DynamicDestinations.class);
- if (!superDescriptor.getRawType().equals(DynamicDestinations.class)) {
- throw new AssertionError(
- "Couldn't find the DynamicDestinations superclass of " + this.getClass());
- }
- TypeVariable typeVariable = superDescriptor.getTypeParameter("DestinationT");
- @SuppressWarnings("unchecked")
TypeDescriptor<DestinationT> descriptor =
- (TypeDescriptor<DestinationT>) superDescriptor.resolveType(typeVariable);
+ extractFromTypeParameters(
+ this,
+ DynamicDestinations.class,
+ new TypeDescriptors.TypeVariableExtractor<
+ DynamicDestinations<T, DestinationT>, DestinationT>() {});
+ checkArgument(
+ descriptor != null,
+ "Unable to infer a coder for DestinationT, "
+ + "please specify it explicitly by overriding getDestinationCoder()");
return registry.getCoder(descriptor);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/PassThroughThenCleanup.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/PassThroughThenCleanup.java
index de26c8d..2f7da08 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/PassThroughThenCleanup.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/PassThroughThenCleanup.java
@@ -74,7 +74,7 @@
})
.withSideInputs(jobIdSideInput, cleanupSignalView));
- return outputs.get(mainOutput);
+ return outputs.get(mainOutput).setCoder(input.getCoder());
}
private static class IdentityFn<T> extends DoFn<T, T> {
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingInserts.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingInserts.java
index ba09cb3..747f2b0 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingInserts.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingInserts.java
@@ -19,8 +19,6 @@
package org.apache.beam.sdk.io.gcp.bigquery;
import com.google.api.services.bigquery.model.TableRow;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.KV;
@@ -67,12 +65,6 @@
return new StreamingInserts<>(
createDisposition, dynamicDestinations, bigQueryServices, retryPolicy); }
-
- @Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
-
@Override
public WriteResult expand(PCollection<KV<DestinationT, TableRow>> input) {
PCollection<KV<TableDestination, TableRow>> writes =
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TransformingSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TransformingSource.java
deleted file mode 100644
index b8e6b39..0000000
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TransformingSource.java
+++ /dev/null
@@ -1,136 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.sdk.io.gcp.bigquery;
-
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Function;
-import com.google.common.collect.Lists;
-import java.io.IOException;
-import java.util.List;
-import java.util.NoSuchElementException;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.joda.time.Instant;
-
-/**
- * A {@link BoundedSource} that reads from {@code BoundedSource<T>}
- * and transforms elements to type {@code V}.
-*/
-@VisibleForTesting
-class TransformingSource<T, V> extends BoundedSource<V> {
- private final BoundedSource<T> boundedSource;
- private final SerializableFunction<T, V> function;
- private final Coder<V> outputCoder;
-
- TransformingSource(
- BoundedSource<T> boundedSource,
- SerializableFunction<T, V> function,
- Coder<V> outputCoder) {
- this.boundedSource = checkNotNull(boundedSource, "boundedSource");
- this.function = checkNotNull(function, "function");
- this.outputCoder = checkNotNull(outputCoder, "outputCoder");
- }
-
- @Override
- public List<? extends BoundedSource<V>> split(
- long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
- return Lists.transform(
- boundedSource.split(desiredBundleSizeBytes, options),
- new Function<BoundedSource<T>, BoundedSource<V>>() {
- @Override
- public BoundedSource<V> apply(BoundedSource<T> input) {
- return new TransformingSource<>(input, function, outputCoder);
- }
- });
- }
-
- @Override
- public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
- return boundedSource.getEstimatedSizeBytes(options);
- }
-
- @Override
- public BoundedReader<V> createReader(PipelineOptions options) throws IOException {
- return new TransformingReader(boundedSource.createReader(options));
- }
-
- @Override
- public void validate() {
- boundedSource.validate();
- }
-
- @Override
- public Coder<V> getDefaultOutputCoder() {
- return outputCoder;
- }
-
- private class TransformingReader extends BoundedReader<V> {
- private final BoundedReader<T> boundedReader;
-
- private TransformingReader(BoundedReader<T> boundedReader) {
- this.boundedReader = checkNotNull(boundedReader, "boundedReader");
- }
-
- @Override
- public synchronized BoundedSource<V> getCurrentSource() {
- return new TransformingSource<>(boundedReader.getCurrentSource(), function, outputCoder);
- }
-
- @Override
- public boolean start() throws IOException {
- return boundedReader.start();
- }
-
- @Override
- public boolean advance() throws IOException {
- return boundedReader.advance();
- }
-
- @Override
- public V getCurrent() throws NoSuchElementException {
- T current = boundedReader.getCurrent();
- return function.apply(current);
- }
-
- @Override
- public void close() throws IOException {
- boundedReader.close();
- }
-
- @Override
- public synchronized BoundedSource<V> splitAtFraction(double fraction) {
- BoundedSource<T> split = boundedReader.splitAtFraction(fraction);
- return split == null ? null : new TransformingSource<>(split, function, outputCoder);
- }
-
- @Override
- public Double getFractionConsumed() {
- return boundedReader.getFractionConsumed();
- }
-
- @Override
- public Instant getCurrentTimestamp() throws NoSuchElementException {
- return boundedReader.getCurrentTimestamp();
- }
- }
-}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
index 0a90dde..c5b0fbf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
@@ -893,7 +893,7 @@
}
@Override
- public Coder<Row> getDefaultOutputCoder() {
+ public Coder<Row> getOutputCoder() {
return ProtoCoder.of(Row.class);
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java
index 1ed6430..7e40db4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java
@@ -71,7 +71,7 @@
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
-import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
@@ -99,7 +99,6 @@
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
-import org.apache.beam.sdk.values.TypeDescriptor;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -611,10 +610,10 @@
if (getQuery() != null) {
inputQuery = input.apply(Create.of(getQuery()));
} else {
- inputQuery = input
- .apply(Create.of(getLiteralGqlQuery())
- .withCoder(SerializableCoder.of(new TypeDescriptor<ValueProvider<String>>() {})))
- .apply(ParDo.of(new GqlQueryTranslateFn(v1Options)));
+ inputQuery =
+ input
+ .apply(Create.ofProvider(getLiteralGqlQuery(), StringUtf8Coder.of()))
+ .apply(ParDo.of(new GqlQueryTranslateFn(v1Options)));
}
PCollection<KV<Integer, Query>> splitQueries = inputQuery
@@ -730,7 +729,7 @@
/**
* A DoFn that translates a Cloud Datastore gql query string to {@code Query}.
*/
- static class GqlQueryTranslateFn extends DoFn<ValueProvider<String>, Query> {
+ static class GqlQueryTranslateFn extends DoFn<String, Query> {
private final V1Options v1Options;
private transient Datastore datastore;
private final V1DatastoreFactory datastoreFactory;
@@ -751,9 +750,9 @@
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
- ValueProvider<String> gqlQuery = c.element();
- LOG.info("User query: '{}'", gqlQuery.get());
- Query query = translateGqlQueryWithLimitCheck(gqlQuery.get(), datastore,
+ String gqlQuery = c.element();
+ LOG.info("User query: '{}'", gqlQuery);
+ Query query = translateGqlQueryWithLimitCheck(gqlQuery, datastore,
v1Options.getNamespace());
LOG.info("User gql query translated to Query({})", query);
c.output(query);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
index 4f33d61..46c2df4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
@@ -36,7 +36,6 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.OutgoingMessage;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.ProjectPath;
@@ -727,7 +726,7 @@
getTimestampAttribute(),
getIdAttribute(),
getNeedsAttributes());
- return input.apply(source).apply(MapElements.via(getParseFn()));
+ return input.apply(source).apply(MapElements.via(getParseFn())).setCoder(getCoder());
}
@Override
@@ -743,11 +742,6 @@
.withLabel("Pubsub Subscription"));
}
}
-
- @Override
- protected Coder<T> getDefaultOutputCoder() {
- return getCoder();
- }
}
/////////////////////////////////////////////////////////////////////////////
@@ -870,11 +864,6 @@
builder, getTimestampAttribute(), getIdAttribute(), getTopicProvider());
}
- @Override
- protected Coder<Void> getDefaultOutputCoder() {
- return VoidCoder.of();
- }
-
/**
* Writer to Pubsub which batches messages from bounded collections.
*
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java
index b7df804..8da6ff4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java
@@ -1164,7 +1164,7 @@
}
@Override
- public Coder<PubsubMessage> getDefaultOutputCoder() {
+ public Coder<PubsubMessage> getOutputCoder() {
return outer.getNeedsAttributes()
? PubsubMessageWithAttributesCoder.of()
: PubsubMessagePayloadOnlyCoder.of();
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
index 00008f1..50efdea 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/AbstractSpannerFn.java
@@ -22,12 +22,16 @@
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.util.ReleaseInfo;
/**
* Abstract {@link DoFn} that manages {@link Spanner} lifecycle. Use {@link
* AbstractSpannerFn#databaseClient} to access the Cloud Spanner database client.
*/
abstract class AbstractSpannerFn<InputT, OutputT> extends DoFn<InputT, OutputT> {
+ // A common user agent token that indicates that this request was originated from Apache Beam.
+ private static final String USER_AGENT_PREFIX = "Apache_Beam_Java";
+
private transient Spanner spanner;
private transient DatabaseClient databaseClient;
@@ -36,7 +40,16 @@
@Setup
public void setup() throws Exception {
SpannerConfig spannerConfig = getSpannerConfig();
- SpannerOptions options = spannerConfig.buildSpannerOptions();
+ SpannerOptions.Builder builder = SpannerOptions.newBuilder();
+ if (spannerConfig.getProjectId() != null) {
+ builder.setProjectId(spannerConfig.getProjectId().get());
+ }
+ if (spannerConfig.getServiceFactory() != null) {
+ builder.setServiceFactory(spannerConfig.getServiceFactory());
+ }
+ ReleaseInfo releaseInfo = ReleaseInfo.getReleaseInfo();
+ builder.setUserAgentPrefix(USER_AGENT_PREFIX + "/" + releaseInfo.getVersion());
+ SpannerOptions options = builder.build();
spanner = options.getService();
databaseClient = spanner.getDatabaseClient(DatabaseId
.of(options.getProjectId(), spannerConfig.getInstanceId().get(),
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
index 02716fb..034c38a 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
@@ -49,17 +49,6 @@
abstract Builder toBuilder();
- SpannerOptions buildSpannerOptions() {
- SpannerOptions.Builder builder = SpannerOptions.newBuilder();
- if (getProjectId() != null) {
- builder.setProjectId(getProjectId().get());
- }
- if (getServiceFactory() != null) {
- builder.setServiceFactory(getServiceFactory());
- }
- return builder.build();
- }
-
public static SpannerConfig create() {
return builder().build();
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java
index 3465b4e..8db4e94 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java
@@ -86,7 +86,6 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.CountingSource;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.fs.ResourceId;
@@ -1510,8 +1509,6 @@
// Simulate a repeated call to split(), like a Dataflow worker will sometimes do.
sources = bqSource.split(200, options);
assertEquals(2, sources.size());
- BoundedSource<TableRow> actual = sources.get(0);
- assertThat(actual, CoreMatchers.instanceOf(TransformingSource.class));
// A repeated call to split() should not have caused a duplicate extract job.
assertEquals(1, fakeJobService.getNumExtractJobCalls());
@@ -1594,8 +1591,6 @@
List<? extends BoundedSource<TableRow>> sources = bqSource.split(100, options);
assertEquals(2, sources.size());
- BoundedSource<TableRow> actual = sources.get(0);
- assertThat(actual, CoreMatchers.instanceOf(TransformingSource.class));
}
@Test
@@ -1673,69 +1668,6 @@
List<? extends BoundedSource<TableRow>> sources = bqSource.split(100, options);
assertEquals(2, sources.size());
- BoundedSource<TableRow> actual = sources.get(0);
- assertThat(actual, CoreMatchers.instanceOf(TransformingSource.class));
- }
-
- @Test
- public void testTransformingSource() throws Exception {
- int numElements = 10000;
- @SuppressWarnings("deprecation")
- BoundedSource<Long> longSource = CountingSource.upTo(numElements);
- SerializableFunction<Long, String> toStringFn =
- new SerializableFunction<Long, String>() {
- @Override
- public String apply(Long input) {
- return input.toString();
- }};
- BoundedSource<String> stringSource = new TransformingSource<>(
- longSource, toStringFn, StringUtf8Coder.of());
-
- List<String> expected = Lists.newArrayList();
- for (int i = 0; i < numElements; i++) {
- expected.add(String.valueOf(i));
- }
-
- PipelineOptions options = PipelineOptionsFactory.create();
- Assert.assertThat(
- SourceTestUtils.readFromSource(stringSource, options),
- CoreMatchers.is(expected));
- SourceTestUtils.assertSplitAtFractionBehavior(
- stringSource, 100, 0.3, ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT, options);
-
- SourceTestUtils.assertSourcesEqualReferenceSource(
- stringSource, stringSource.split(100, options), options);
- }
-
- @Test
- public void testTransformingSourceUnsplittable() throws Exception {
- int numElements = 10000;
- @SuppressWarnings("deprecation")
- BoundedSource<Long> longSource =
- SourceTestUtils.toUnsplittableSource(CountingSource.upTo(numElements));
- SerializableFunction<Long, String> toStringFn =
- new SerializableFunction<Long, String>() {
- @Override
- public String apply(Long input) {
- return input.toString();
- }
- };
- BoundedSource<String> stringSource =
- new TransformingSource<>(longSource, toStringFn, StringUtf8Coder.of());
-
- List<String> expected = Lists.newArrayList();
- for (int i = 0; i < numElements; i++) {
- expected.add(String.valueOf(i));
- }
-
- PipelineOptions options = PipelineOptionsFactory.create();
- Assert.assertThat(
- SourceTestUtils.readFromSource(stringSource, options), CoreMatchers.is(expected));
- SourceTestUtils.assertSplitAtFractionBehavior(
- stringSource, 100, 0.3, ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS, options);
-
- SourceTestUtils.assertSourcesEqualReferenceSource(
- stringSource, stringSource.split(100, options), options);
}
@Test
diff --git a/sdks/java/io/hadoop-file-system/pom.xml b/sdks/java/io/hadoop-file-system/pom.xml
index a9c2e57..3cc7e00 100644
--- a/sdks/java/io/hadoop-file-system/pom.xml
+++ b/sdks/java/io/hadoop-file-system/pom.xml
@@ -30,20 +30,6 @@
<name>Apache Beam :: SDKs :: Java :: IO :: Hadoop File System</name>
<description>Library to read and write Hadoop/HDFS file formats from Beam.</description>
- <build>
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamUseDummyRunner>false</beamUseDummyRunner>
- </systemPropertyVariables>
- </configuration>
- </plugin>
- </plugins>
- </build>
-
<dependencies>
<dependency>
<groupId>org.apache.beam</groupId>
diff --git a/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java b/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java
index 0b4c23f..20ca50a 100644
--- a/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java
+++ b/sdks/java/io/hadoop/input-format/src/main/java/org/apache/beam/sdk/io/hadoop/inputformat/HadoopInputFormatIO.java
@@ -552,7 +552,7 @@
}
@Override
- public Coder<KV<K, V>> getDefaultOutputCoder() {
+ public Coder<KV<K, V>> getOutputCoder() {
return KvCoder.of(keyCoder, valueCoder);
}
diff --git a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
index 12944f4..8df2552 100644
--- a/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
+++ b/sdks/java/io/hadoop/jdk1.8-tests/pom.xml
@@ -39,7 +39,6 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
- <version>1.4.1</version>
<executions>
<execution>
<id>enforce</id>
diff --git a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
index 90ede4c..2ba6826 100644
--- a/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
+++ b/sdks/java/io/hbase/src/main/java/org/apache/beam/sdk/io/hbase/HBaseIO.java
@@ -457,7 +457,7 @@
}
@Override
- public Coder<Result> getDefaultOutputCoder() {
+ public Coder<Result> getOutputCoder() {
return HBaseResultCoder.of();
}
}
diff --git a/sdks/java/io/hcatalog/pom.xml b/sdks/java/io/hcatalog/pom.xml
index 2aa661e..34e60da 100644
--- a/sdks/java/io/hcatalog/pom.xml
+++ b/sdks/java/io/hcatalog/pom.xml
@@ -61,6 +61,7 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
+ <scope>provided</scope>
<exclusions>
<!-- Fix build on JDK-9 -->
<exclusion>
@@ -71,12 +72,6 @@
</dependency>
<dependency>
- <groupId>commons-io</groupId>
- <artifactId>commons-io</artifactId>
- <version>${apache.commons.version}</version>
- </dependency>
-
- <dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
@@ -95,6 +90,7 @@
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>${hive.version}</version>
+ <scope>provided</scope>
</dependency>
<dependency>
@@ -107,6 +103,7 @@
<groupId>org.apache.hive.hcatalog</groupId>
<artifactId>hive-hcatalog-core</artifactId>
<version>${hive.version}</version>
+ <scope>provided</scope>
<exclusions>
<exclusion>
<groupId>org.apache.hive</groupId>
@@ -133,6 +130,13 @@
</dependency>
<dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <version>${apache.commons.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java
index 4199b805..d8e462b 100644
--- a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java
+++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatalogIO.java
@@ -210,7 +210,7 @@
@Override
@SuppressWarnings({"unchecked", "rawtypes"})
- public Coder<HCatRecord> getDefaultOutputCoder() {
+ public Coder<HCatRecord> getOutputCoder() {
return (Coder) WritableCoder.of(DefaultHCatRecord.class);
}
diff --git a/sdks/java/io/jdbc/pom.xml b/sdks/java/io/jdbc/pom.xml
index 357ddc0..c559ad4 100644
--- a/sdks/java/io/jdbc/pom.xml
+++ b/sdks/java/io/jdbc/pom.xml
@@ -257,11 +257,6 @@
</dependency>
<dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </dependency>
-
- <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
@@ -320,6 +315,11 @@
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<scope>test</scope>
</dependency>
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index bf73dbe..51f34ae 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -31,7 +31,9 @@
import javax.sql.DataSource;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
@@ -272,7 +274,7 @@
@AutoValue
public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {
@Nullable abstract DataSourceConfiguration getDataSourceConfiguration();
- @Nullable abstract String getQuery();
+ @Nullable abstract ValueProvider<String> getQuery();
@Nullable abstract StatementPreparator getStatementPreparator();
@Nullable abstract RowMapper<T> getRowMapper();
@Nullable abstract Coder<T> getCoder();
@@ -282,7 +284,7 @@
@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setDataSourceConfiguration(DataSourceConfiguration config);
- abstract Builder<T> setQuery(String query);
+ abstract Builder<T> setQuery(ValueProvider<String> query);
abstract Builder<T> setStatementPreparator(StatementPreparator statementPreparator);
abstract Builder<T> setRowMapper(RowMapper<T> rowMapper);
abstract Builder<T> setCoder(Coder<T> coder);
@@ -297,6 +299,11 @@
public Read<T> withQuery(String query) {
checkArgument(query != null, "JdbcIO.read().withQuery(query) called with null query");
+ return withQuery(ValueProvider.StaticValueProvider.of(query));
+ }
+
+ public Read<T> withQuery(ValueProvider<String> query) {
+ checkArgument(query != null, "JdbcIO.read().withQuery(query) called with null query");
return toBuilder().setQuery(query).build();
}
@@ -321,7 +328,7 @@
@Override
public PCollection<T> expand(PBegin input) {
return input
- .apply(Create.of(getQuery()))
+ .apply(Create.ofProvider(getQuery(), StringUtf8Coder.of()))
.apply(ParDo.of(new ReadFn<>(this))).setCoder(getCoder())
.apply(ParDo.of(new DoFn<T, KV<Integer, T>>() {
private Random random;
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index f8cba5e..2af0ce9 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -373,7 +373,7 @@
}
@Override
- public Coder<JmsRecord> getDefaultOutputCoder() {
+ public Coder<JmsRecord> getOutputCoder() {
return SerializableCoder.of(JmsRecord.class);
}
diff --git a/sdks/java/io/kafka/pom.xml b/sdks/java/io/kafka/pom.xml
index 1256c46..3902510 100644
--- a/sdks/java/io/kafka/pom.xml
+++ b/sdks/java/io/kafka/pom.xml
@@ -46,18 +46,6 @@
</plugin>
</plugins>
</pluginManagement>
-
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamUseDummyRunner>false</beamUseDummyRunner>
- </systemPropertyVariables>
- </configuration>
- </plugin>
- </plugins>
</build>
<dependencies>
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 026313a..7fb4260 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -844,7 +844,7 @@
}
@Override
- public Coder<KafkaRecord<K, V>> getDefaultOutputCoder() {
+ public Coder<KafkaRecord<K, V>> getOutputCoder() {
return KafkaRecordCoder.of(spec.getKeyCoder(), spec.getValueCoder());
}
}
diff --git a/sdks/java/io/kinesis/pom.xml b/sdks/java/io/kinesis/pom.xml
index 46d5e26..872c590 100644
--- a/sdks/java/io/kinesis/pom.xml
+++ b/sdks/java/io/kinesis/pom.xml
@@ -31,16 +31,6 @@
<build>
<plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <configuration>
- <systemPropertyVariables>
- <beamUseDummyRunner>false</beamUseDummyRunner>
- </systemPropertyVariables>
- </configuration>
- </plugin>
-
<!-- Integration Tests -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
index 362792b..144bd80 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
@@ -107,7 +107,7 @@
}
@Override
- public Coder<KinesisRecord> getDefaultOutputCoder() {
+ public Coder<KinesisRecord> getOutputCoder() {
return KinesisRecordCoder.of();
}
}
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java
index 5b5412c..c612d52 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbGridFSIO.java
@@ -440,7 +440,7 @@
}
@Override
- public Coder<ObjectId> getDefaultOutputCoder() {
+ public Coder<ObjectId> getOutputCoder() {
return SerializableCoder.of(ObjectId.class);
}
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
index 3b14182..087123a 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
@@ -277,7 +277,7 @@
}
@Override
- public Coder<Document> getDefaultOutputCoder() {
+ public Coder<Document> getOutputCoder() {
return SerializableCoder.of(Document.class);
}
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index add5cb5..5aadb80 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -387,7 +387,7 @@
}
@Override
- public Coder<byte[]> getDefaultOutputCoder() {
+ public Coder<byte[]> getOutputCoder() {
return ByteArrayCoder.of();
}
}
diff --git a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java
index 442fba5..7255a94 100644
--- a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java
+++ b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlIO.java
@@ -36,7 +36,6 @@
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -522,8 +521,7 @@
@Override
public PDone expand(PCollection<T> input) {
- return input.apply(
- org.apache.beam.sdk.io.WriteFiles.to(createSink(), SerializableFunctions.<T>identity()));
+ return input.apply(org.apache.beam.sdk.io.WriteFiles.to(createSink()));
}
@VisibleForTesting
diff --git a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java
index 74e0bda..b663544 100644
--- a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java
+++ b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSink.java
@@ -35,7 +35,7 @@
import org.apache.beam.sdk.util.MimeTypes;
/** Implementation of {@link XmlIO#write}. */
-class XmlSink<T> extends FileBasedSink<T, Void> {
+class XmlSink<T> extends FileBasedSink<T, Void, T> {
private static final String XML_EXTENSION = ".xml";
private final XmlIO.Write<T> spec;
@@ -46,7 +46,7 @@
}
XmlSink(XmlIO.Write<T> spec) {
- super(spec.getFilenamePrefix(), DynamicFileDestinations.constant(makeFilenamePolicy(spec)));
+ super(spec.getFilenamePrefix(), DynamicFileDestinations.<T>constant(makeFilenamePolicy(spec)));
this.spec = spec;
}
@@ -77,7 +77,7 @@
}
/** {@link WriteOperation} for XML {@link FileBasedSink}s. */
- protected static final class XmlWriteOperation<T> extends WriteOperation<T, Void> {
+ protected static final class XmlWriteOperation<T> extends WriteOperation<Void, T> {
public XmlWriteOperation(XmlSink<T> sink) {
super(sink);
}
@@ -112,7 +112,7 @@
}
/** A {@link Writer} that can write objects as XML elements. */
- protected static final class XmlWriter<T> extends Writer<T, Void> {
+ protected static final class XmlWriter<T> extends Writer<Void, T> {
final Marshaller marshaller;
private OutputStream os = null;
diff --git a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSource.java b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSource.java
index 7aa42c5..b893d43 100644
--- a/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSource.java
+++ b/sdks/java/io/xml/src/main/java/org/apache/beam/sdk/io/xml/XmlSource.java
@@ -85,7 +85,7 @@
}
@Override
- public Coder<T> getDefaultOutputCoder() {
+ public Coder<T> getOutputCoder() {
return JAXBCoder.of(spec.getRecordClass());
}
diff --git a/sdks/java/javadoc/pom.xml b/sdks/java/javadoc/pom.xml
index 51109fb..e1adb79 100644
--- a/sdks/java/javadoc/pom.xml
+++ b/sdks/java/javadoc/pom.xml
@@ -79,6 +79,11 @@
<dependency>
<groupId>org.apache.beam</groupId>
+ <artifactId>beam-runners-gearpump</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.beam</groupId>
<artifactId>beam-sdks-java-core</artifactId>
</dependency>
diff --git a/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml b/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
index 6056fb0..91da6eb 100644
--- a/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
+++ b/sdks/java/maven-archetypes/starter/src/test/resources/projects/basic/reference/pom.xml
@@ -27,9 +27,9 @@
<properties>
<beam.version>@project.version@</beam.version>
- <maven-compiler-plugin.version>3.6.1</maven-compiler-plugin.version>
- <maven-exec-plugin.version>1.6.0</maven-exec-plugin.version>
- <slf4j.version>1.7.14</slf4j.version>
+ <maven-compiler-plugin.version>@maven-compiler-plugin.version@</maven-compiler-plugin.version>
+ <maven-exec-plugin.version>@maven-exec-plugin.version@</maven-exec-plugin.version>
+ <slf4j.version>@slf4j.version@</slf4j.version>
</properties>
<repositories>
diff --git a/sdks/python/apache_beam/coders/stream.pxd b/sdks/python/apache_beam/coders/stream.pxd
index 4e01a89..ade9b72 100644
--- a/sdks/python/apache_beam/coders/stream.pxd
+++ b/sdks/python/apache_beam/coders/stream.pxd
@@ -53,7 +53,7 @@
cdef bytes all
cdef char* allc
- cpdef size_t size(self) except? -1
+ cpdef ssize_t size(self) except? -1
cpdef bytes read(self, size_t len)
cpdef long read_byte(self) except? -1
cpdef libc.stdint.int64_t read_var_int64(self) except? -1
diff --git a/sdks/python/apache_beam/coders/stream.pyx b/sdks/python/apache_beam/coders/stream.pyx
index 8d97681..7c9521a 100644
--- a/sdks/python/apache_beam/coders/stream.pyx
+++ b/sdks/python/apache_beam/coders/stream.pyx
@@ -167,7 +167,7 @@
# unsigned char here.
return <long>(<unsigned char> self.allc[self.pos - 1])
- cpdef size_t size(self) except? -1:
+ cpdef ssize_t size(self) except? -1:
return len(self.all) - self.pos
cpdef bytes read_all(self, bint nested=False):
diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_it_test.py
index 5d2ee7c..05ee3c5 100644
--- a/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_it_test.py
+++ b/sdks/python/apache_beam/examples/cookbook/bigquery_tornadoes_it_test.py
@@ -26,6 +26,7 @@
from apache_beam.examples.cookbook import bigquery_tornadoes
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryMatcher
+from apache_beam.io.gcp.tests import utils
from apache_beam.testing.pipeline_verifiers import PipelineStateMatcher
from apache_beam.testing.test_pipeline import TestPipeline
@@ -44,17 +45,24 @@
test_pipeline = TestPipeline(is_integration_test=True)
# Set extra options to the pipeline for test purpose
- output_table = ('BigQueryTornadoesIT'
- '.monthly_tornadoes_%s' % int(round(time.time() * 1000)))
+ project = test_pipeline.get_option('project')
+
+ dataset = 'BigQueryTornadoesIT'
+ table = 'monthly_tornadoes_%s' % int(round(time.time() * 1000))
+ output_table = '.'.join([dataset, table])
query = 'SELECT month, tornado_count FROM [%s]' % output_table
+
pipeline_verifiers = [PipelineStateMatcher(),
BigqueryMatcher(
- project=test_pipeline.get_option('project'),
+ project=project,
query=query,
checksum=self.DEFAULT_CHECKSUM)]
extra_opts = {'output': output_table,
'on_success_matcher': all_of(*pipeline_verifiers)}
+ # Register cleanup before pipeline execution.
+ self.addCleanup(utils.delete_bq_table, project, dataset, table)
+
# Get pipeline options from command argument: --test-pipeline-options,
# and start pipeline job by calling pipeline main function.
bigquery_tornadoes.run(
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py
index 31f71b3..9183d0d 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py
@@ -589,22 +589,6 @@
snippets.model_textio_compressed(
{'read': gzip_file_name}, ['aa', 'bb', 'cc'])
- def test_model_textio_gzip_concatenated(self):
- temp_path_1 = self.create_temp_file('a\nb\nc\n')
- temp_path_2 = self.create_temp_file('p\nq\nr\n')
- temp_path_3 = self.create_temp_file('x\ny\nz')
- gzip_file_name = temp_path_1 + '.gz'
- with open(temp_path_1) as src, gzip.open(gzip_file_name, 'wb') as dst:
- dst.writelines(src)
- with open(temp_path_2) as src, gzip.open(gzip_file_name, 'ab') as dst:
- dst.writelines(src)
- with open(temp_path_3) as src, gzip.open(gzip_file_name, 'ab') as dst:
- dst.writelines(src)
- # Add the temporary gzip file to be cleaned up as well.
- self.temp_files.append(gzip_file_name)
- snippets.model_textio_compressed(
- {'read': gzip_file_name}, ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z'])
-
@unittest.skipIf(datastore_pb2 is None, 'GCP dependencies are not installed')
def test_model_datastoreio(self):
# We cannot test datastoreio functionality in unit tests therefore we limit
diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py
index 1f65d0a..ef3040c 100644
--- a/sdks/python/apache_beam/io/filesystem.py
+++ b/sdks/python/apache_beam/io/filesystem.py
@@ -187,29 +187,26 @@
del buf # Free up some possibly large and no-longer-needed memory.
self._read_buffer.write(decompressed)
else:
- # EOF reached.
- # Verify completeness and no corruption and flush (if needed by
- # the underlying algorithm).
- if self._compression_type == CompressionTypes.BZIP2:
- # Having unused_data past end of stream would imply file corruption.
- assert not self._decompressor.unused_data, 'Possible file corruption.'
- try:
- # EOF implies that the underlying BZIP2 stream must also have
- # reached EOF. We expect this to raise an EOFError and we catch it
- # below. Any other kind of error though would be problematic.
- self._decompressor.decompress('dummy')
- assert False, 'Possible file corruption.'
- except EOFError:
- pass # All is as expected!
- elif self._compression_type == CompressionTypes.GZIP:
- # If Gzip file check if there is unused data generated by gzip concat
+ # EOF of current stream reached.
+ #
+ # Any uncompressed data at the end of the stream of a gzip or bzip2
+ # file that is not corrupted points to a concatenated compressed
+ # file. We read concatenated files by recursively creating decompressor
+ # objects for the unused compressed data.
+ if (self._compression_type == CompressionTypes.BZIP2 or
+ self._compression_type == CompressionTypes.GZIP):
if self._decompressor.unused_data != '':
buf = self._decompressor.unused_data
- self._decompressor = zlib.decompressobj(self._gzip_mask)
+ self._decompressor = (
+ bz2.BZ2Decompressor()
+ if self._compression_type == CompressionTypes.BZIP2
+ else zlib.decompressobj(self._gzip_mask))
decompressed = self._decompressor.decompress(buf)
self._read_buffer.write(decompressed)
continue
else:
+ # Gzip and bzip2 formats do not require flushing remaining data in the
+ # decompressor into the read buffer when fully decompressing files.
self._read_buffer.write(self._decompressor.flush())
# Record that we have hit the end of file, so we won't unnecessarily
diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 23fd310..db6715a 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -1002,12 +1002,23 @@
if found_table and write_disposition != BigQueryDisposition.WRITE_TRUNCATE:
return found_table
else:
+ created_table = self._create_table(project_id=project_id,
+ dataset_id=dataset_id,
+ table_id=table_id,
+ schema=schema or found_table.schema)
# if write_disposition == BigQueryDisposition.WRITE_TRUNCATE we delete
# the table before this point.
- return self._create_table(project_id=project_id,
- dataset_id=dataset_id,
- table_id=table_id,
- schema=schema or found_table.schema)
+ if write_disposition == BigQueryDisposition.WRITE_TRUNCATE:
+ # BigQuery can route data to the old table for 2 mins max so wait
+ # that much time before creating the table and writing it
+ logging.warning('Sleeping for 150 seconds before the write as ' +
+ 'BigQuery inserts can be routed to deleted table ' +
+ 'for 2 mins after the delete and create.')
+ # TODO(BEAM-2673): Remove this sleep by migrating to load api
+ time.sleep(150)
+ return created_table
+ else:
+ return created_table
def run_query(self, project_id, query, use_legacy_sql, flatten_results,
dry_run=False):
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index 14247ba..bfd06ac 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -650,7 +650,8 @@
self.assertFalse(client.tables.Delete.called)
self.assertFalse(client.tables.Insert.called)
- def test_table_with_write_disposition_truncate(self):
+ @mock.patch('time.sleep', return_value=None)
+ def test_table_with_write_disposition_truncate(self, _patched_sleep):
client = mock.Mock()
table = bigquery.Table(
tableReference=bigquery.TableReference(
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py
index 32d388a..7d1f355 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -183,6 +183,9 @@
raise NotImplementedError(
'PubSubPayloadSource is not supported in local execution.')
+ def is_bounded(self):
+ return False
+
class _PubSubPayloadSink(dataflow_io.NativeSink):
"""Sink for the payload of a message as bytes to a Cloud Pub/Sub topic."""
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils.py b/sdks/python/apache_beam/io/gcp/tests/utils.py
new file mode 100644
index 0000000..40eb975
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/tests/utils.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+"""Utility methods for testing on GCP."""
+
+import logging
+
+from apache_beam.utils import retry
+
+# Protect against environments where bigquery library is not available.
+try:
+ from google.cloud import bigquery
+except ImportError:
+ bigquery = None
+
+
+class GcpTestIOError(retry.PermanentException):
+ """Basic GCP IO error for testing.
+
+ Subclasses retry.PermanentException: a function that raises this error
+ should not be retried.
+ """
+ pass
+
+
+@retry.with_exponential_backoff(
+    num_retries=3,
+    retry_filter=retry.retry_on_server_errors_filter)
+def delete_bq_table(project, dataset, table):
+  """Delete a BigQuery table and verify that it is gone afterwards.
+
+  Args:
+    project: Name of the project.
+    dataset: Name of the dataset where table is.
+    table: Name of the table.
+
+  Raises:
+    GcpTestIOError: if the dataset or the table does not exist. This is a
+      retry.PermanentException, meant not to be retried.
+    RuntimeError: if the table still exists after the delete call.
+  """
+  logging.info('Clean up a Bigquery table with project: %s, dataset: %s, '
+               'table: %s.', project, dataset, table)
+  bq_dataset = bigquery.Client(project=project).dataset(dataset)
+  if not bq_dataset.exists():
+    # All format arguments go in one tuple; formatting with only the first
+    # argument raised TypeError instead of the intended GcpTestIOError.
+    raise GcpTestIOError('Failed to cleanup. Bigquery dataset %s doesn\'t '
+                         'exist in project %s.' % (dataset, project))
+  bq_table = bq_dataset.table(table)
+  if not bq_table.exists():
+    raise GcpTestIOError('Failed to cleanup. Bigquery table %s doesn\'t '
+                         'exist in project %s, dataset %s.' %
+                         (table, project, dataset))
+  bq_table.delete()
+  if bq_table.exists():
+    raise RuntimeError('Failed to cleanup. Bigquery table %s still exists '
+                       'after cleanup.' % table)
diff --git a/sdks/python/apache_beam/io/gcp/tests/utils_test.py b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
new file mode 100644
index 0000000..270750a
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/tests/utils_test.py
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unittest for GCP testing utils."""
+
+import logging
+import unittest
+from mock import Mock, patch
+
+from apache_beam.io.gcp.tests import utils
+from apache_beam.testing.test_utils import patch_retry
+
+# Protect against environments where bigquery library is not available.
+try:
+ from google.cloud import bigquery
+except ImportError:
+ bigquery = None
+
+
+@unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
+class UtilsTest(unittest.TestCase):
+  """Tests utils.delete_bq_table against mocked BigQuery Dataset/Table."""
+
+  def setUp(self):
+    self._mock_result = Mock()
+    # NOTE(review): patch_retry presumably stubs out the retry backoff in
+    # `utils` so retried failures do not really sleep; confirm in test_utils.
+    patch_retry(self, utils)
+
+  @patch('google.cloud.bigquery.Table.delete')
+  @patch('google.cloud.bigquery.Table.exists', side_effect=[True, False])
+  @patch('google.cloud.bigquery.Dataset.exists', return_value=True)
+  def test_delete_bq_table_succeeds(self, *_):
+    # The table exists before the delete and is gone after it.
+    utils.delete_bq_table('unused_project',
+                          'unused_dataset',
+                          'unused_table')
+
+  @patch('google.cloud.bigquery.Table.delete', side_effect=Exception)
+  @patch('google.cloud.bigquery.Table.exists', return_value=True)
+  @patch('google.cloud.bigquery.Dataset.exists', return_value=True)
+  def test_delete_bq_table_fails_with_server_error(self, *_):
+    with self.assertRaises(Exception):
+      utils.delete_bq_table('unused_project',
+                            'unused_dataset',
+                            'unused_table')
+
+  @patch('google.cloud.bigquery.Table.delete')
+  # A constant return_value (not a finite side_effect list) keeps the table
+  # "present" on every call, even if the RuntimeError path is retried.
+  @patch('google.cloud.bigquery.Table.exists', return_value=True)
+  @patch('google.cloud.bigquery.Dataset.exists', return_value=True)
+  def test_delete_bq_table_fails_with_delete_error(self, *_):
+    with self.assertRaises(RuntimeError):
+      utils.delete_bq_table('unused_project',
+                            'unused_dataset',
+                            'unused_table')
+
+
+if __name__ == '__main__':
+ # Allow running this test module directly, with INFO-level logging.
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 7e40d83..db75fe3 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -37,6 +37,7 @@
from apache_beam import pvalue
from apache_beam import coders
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsIter
from apache_beam.pvalue import AsSingleton
from apache_beam.transforms import core
@@ -44,6 +45,7 @@
from apache_beam.transforms import window
from apache_beam.transforms.display import HasDisplayData
from apache_beam.transforms.display import DisplayDataItem
+from apache_beam.utils import urns
from apache_beam.utils.windowed_value import WindowedValue
__all__ = ['BoundedSource', 'RangeTracker', 'Read', 'Sink', 'Write', 'Writer']
@@ -70,7 +72,13 @@
'weight source start_position stop_position')
-class BoundedSource(HasDisplayData):
+class SourceBase(HasDisplayData, urns.RunnerApiFn):
+ """Base class for all sources that can be passed to beam.io.Read(...).
+
+ Read.to_runner_api_parameter calls is_bounded() on its source, so concrete
+ subclasses are expected to provide it.
+ """
+ # Registers PICKLED_SOURCE as this hierarchy's runner-API urn (presumably
+ # the source is serialized by pickling; see urns.RunnerApiFn).
+ urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_SOURCE)
+
+
+class BoundedSource(SourceBase):
"""A source that reads a finite amount of input records.
This class defines following operations which can be used to read the source
@@ -189,6 +197,9 @@
"""
return coders.registry.get_coder(object)
+ def is_bounded(self):
+ # A BoundedSource reads a finite amount of input records.
+ return True
+
class RangeTracker(object):
"""A thread safe object used by Dataflow source framework.
@@ -820,6 +831,24 @@
label='Read Source'),
'source_dd': self.source}
+ def to_runner_api_parameter(self, context):
+ # Serialize as READ_TRANSFORM: the ReadPayload carries the serialized
+ # source plus BOUNDED/UNBOUNDED according to self.source.is_bounded().
+ return (urns.READ_TRANSFORM,
+ beam_runner_api_pb2.ReadPayload(
+ source=self.source.to_runner_api(context),
+ is_bounded=beam_runner_api_pb2.BOUNDED
+ if self.source.is_bounded()
+ else beam_runner_api_pb2.UNBOUNDED))
+
+ @staticmethod
+ def from_runner_api_parameter(parameter, context):
+ # Inverse of to_runner_api_parameter. Only the source is restored; the
+ # payload's is_bounded field is not read here.
+ return Read(SourceBase.from_runner_api(parameter.source, context))
+
+
+# Register Read's decoder so READ_TRANSFORM specs (with a ReadPayload) are
+# turned back into Read transforms.
+ptransform.PTransform.register_urn(
+ urns.READ_TRANSFORM,
+ beam_runner_api_pb2.ReadPayload,
+ Read.from_runner_api_parameter)
+
class Write(ptransform.PTransform):
"""A ``PTransform`` that writes to a sink.
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 9a4ec47..8bd7116 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -401,6 +401,64 @@
assert_that(pcoll, equal_to(lines))
pipeline.run()
+ def test_read_corrupted_bzip2_fails(self):
+ _, lines = write_data(15)
+ file_name = self._create_temp_file()
+ with bz2.BZ2File(file_name, 'wb') as f:
+ f.write('\n'.join(lines))
+
+ # Overwrite the valid bzip2 data with plain bytes so decompression fails.
+ with open(file_name, 'wb') as f:
+ f.write('corrupt')
+
+ pipeline = TestPipeline()
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ compression_type=CompressionTypes.BZIP2)
+ assert_that(pcoll, equal_to(lines))
+ # The test expects the failure to surface from pipeline.run().
+ with self.assertRaises(Exception):
+ pipeline.run()
+
+  def test_read_bzip2_concat(self):
+    # A file made of several complete bzip2 streams appended back to back
+    # must be read in full, not just up to the end of the first stream.
+    part_file_names = []
+    for lines in (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z']):
+      part_file_name = self._create_temp_file()
+      with bz2.BZ2File(part_file_name, 'wb') as dst:
+        dst.write('\n'.join(lines) + '\n')
+      part_file_names.append(part_file_name)
+
+    # Concatenate the raw compressed bytes of every part into one file.
+    final_bzip2_file = self._create_temp_file()
+    with open(final_bzip2_file, 'wb') as dst:
+      for part_file_name in part_file_names:
+        with open(part_file_name, 'rb') as src:
+          dst.write(src.read())
+
+    pipeline = TestPipeline()
+    pcoll = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
+        final_bzip2_file,
+        compression_type=beam.io.filesystem.CompressionTypes.BZIP2)
+
+    expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
+    assert_that(pcoll, equal_to(expected))
+    pipeline.run()
+
def test_read_gzip(self):
_, lines = write_data(15)
file_name = self._create_temp_file()
@@ -415,6 +473,63 @@
assert_that(pcoll, equal_to(lines))
pipeline.run()
+ def test_read_corrupted_gzip_fails(self):
+ _, lines = write_data(15)
+ file_name = self._create_temp_file()
+ with gzip.GzipFile(file_name, 'wb') as f:
+ f.write('\n'.join(lines))
+
+ # Overwrite the valid gzip data with plain bytes so decompression fails.
+ with open(file_name, 'wb') as f:
+ f.write('corrupt')
+
+ pipeline = TestPipeline()
+ # NOTE(review): the positional args are presumably min_bundle_size,
+ # compression_type, strip_trailing_newlines and coder; verify against
+ # ReadFromText's signature.
+ pcoll = pipeline | 'Read' >> ReadFromText(
+ file_name,
+ 0, CompressionTypes.GZIP,
+ True, coders.StrUtf8Coder())
+ assert_that(pcoll, equal_to(lines))
+
+ # The test expects the failure to surface from pipeline.run().
+ with self.assertRaises(Exception):
+ pipeline.run()
+
+  def test_read_gzip_concat(self):
+    # A file made of several complete gzip members appended back to back
+    # must be read in full, not just up to the end of the first member.
+    part_file_names = []
+    for lines in (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z']):
+      part_file_name = self._create_temp_file()
+      with gzip.open(part_file_name, 'wb') as dst:
+        dst.write('\n'.join(lines) + '\n')
+      part_file_names.append(part_file_name)
+
+    # Concatenate the raw compressed bytes of every part into one file.
+    final_gzip_file = self._create_temp_file()
+    with open(final_gzip_file, 'wb') as dst:
+      for part_file_name in part_file_names:
+        with open(part_file_name, 'rb') as src:
+          dst.write(src.read())
+
+    pipeline = TestPipeline()
+    pcoll = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
+        final_gzip_file,
+        compression_type=beam.io.filesystem.CompressionTypes.GZIP)
+
+    expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
+    assert_that(pcoll, equal_to(expected))
+    pipeline.run()
+
def test_read_gzip_large(self):
_, lines = write_data(10000)
file_name = self._create_temp_file()
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index fe36d85..e7c2322 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -66,6 +66,7 @@
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator
from apache_beam.utils.annotations import deprecated
+from apache_beam.utils import urns
__all__ = ['Pipeline']
@@ -474,6 +475,12 @@
return str, ('Pickled pipeline stub.',)
def _verify_runner_api_compatible(self):
+ if self._options.view_as(TypeOptions).runtime_type_check:
+ # This option is incompatible with the runner API as it requires
+ # the runner to inspect non-serialized hints on the transform
+ # itself.
+ return False
+
class Visitor(PipelineVisitor): # pylint: disable=used-before-assignment
ok = True # Really a nonlocal.
@@ -723,7 +730,8 @@
return beam_runner_api_pb2.PTransform(
unique_name=self.full_label,
spec=transform_to_runner_api(self.transform, context),
- subtransforms=[context.transforms.get_id(part) for part in self.parts],
+ subtransforms=[context.transforms.get_id(part, label=part.full_label)
+ for part in self.parts],
# TODO(BEAM-115): Side inputs.
inputs={tag: context.pcollections.get_id(pc)
for tag, pc in self.named_inputs().items()},
@@ -745,6 +753,10 @@
result.outputs = {
None if tag == 'None' else tag: context.pcollections.get_by_id(id)
for tag, id in proto.outputs.items()}
+ # This annotation is expected by some runners.
+ if proto.spec.urn == urns.PARDO_TRANSFORM:
+ result.transform.output_tags = set(proto.outputs.keys()).difference(
+ {'None'})
if not result.parts:
for tag, pc in result.outputs.items():
if pc not in result.inputs:
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index aec7d00..880901e 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -188,18 +188,31 @@
if not input_type:
input_type = typehints.Any
- if not isinstance(input_type, typehints.TupleHint.TupleConstraint):
- if isinstance(input_type, typehints.AnyTypeConstraint):
+ def coerce_to_kv_type(element_type):
+ """Returns element_type as a 2-tuple (KV) typehint, or raises ValueError."""
+ if isinstance(element_type, typehints.TupleHint.TupleConstraint):
+ if len(element_type.tuple_types) == 2:
+ return element_type
+ else:
+ raise ValueError(
+ "Tuple input to GroupByKey must have two components. "
+ "Found %s for %s" % (element_type, pcoll))
+ # Test element_type, not the enclosing input_type: they differ in the
+ # recursive Union case below, where Any members must still be coerced.
+ elif isinstance(element_type, typehints.AnyTypeConstraint):
# `Any` type needs to be replaced with a KV[Any, Any] to
# force a KV coder as the main output coder for the pcollection
# preceding a GroupByKey.
- pcoll.element_type = typehints.KV[typehints.Any, typehints.Any]
+ return typehints.KV[typehints.Any, typehints.Any]
+ elif isinstance(element_type, typehints.UnionConstraint):
+ union_types = [
+ coerce_to_kv_type(t) for t in element_type.union_types]
+ return typehints.KV[
+ typehints.Union[tuple(t.tuple_types[0] for t in union_types)],
+ typehints.Union[tuple(t.tuple_types[1] for t in union_types)]]
else:
- # TODO: Handle other valid types,
- # e.g. Union[KV[str, int], KV[str, float]]
+ # TODO: Possibly handle other valid types.
raise ValueError(
"Input to GroupByKey must be of Tuple or Any type. "
- "Found %s for %s" % (input_type, pcoll))
+ "Found %s for %s" % (element_type, pcoll))
+ pcoll.element_type = coerce_to_kv_type(input_type)
return GroupByKeyInputVisitor()
@@ -517,10 +530,12 @@
si_labels[side_pval] = si_label
# Now create the step for the ParDo transform being handled.
+ transform_name = transform_node.full_label.rsplit('/', 1)[-1]
step = self._add_step(
TransformNames.DO,
transform_node.full_label + (
- '/Do' if transform_node.side_inputs else ''),
+ '/{}'.format(transform_name)
+ if transform_node.side_inputs else ''),
transform_node,
transform_node.transform.output_tags)
fn_data = self._pardo_fn_data(transform_node, lookup_label)
@@ -591,8 +606,8 @@
PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
# Note that the accumulator must not have a WindowedValue encoding, while
# the output of this step does in fact have a WindowedValue encoding.
- accumulator_encoding = self._get_encoded_output_coder(transform_node,
- window_value=False)
+ accumulator_encoding = self._get_cloud_encoding(
+ transform_node.transform.fn.get_accumulator_coder())
output_encoding = self._get_encoded_output_coder(transform_node)
step.encoding = output_encoding
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index a9b8fdb..80414d6 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -256,6 +256,28 @@
for _ in range(num_inputs):
self.assertEqual(inputs[0].element_type, output_type)
+ def test_gbk_then_flatten_input_visitor(self):
+ # Flatten two Create outputs built from dicts with a None key and str vs
+ # int values, then GroupByKey. After the GBK-input and flatten-input
+ # visitors run, the flattened type must be a tuple and equal each input's.
+ p = TestPipeline(
+ runner=DataflowRunner(),
+ options=PipelineOptions(self.default_properties))
+ none_str_pc = p | 'c1' >> beam.Create({None: 'a'})
+ none_int_pc = p | 'c2' >> beam.Create({None: 3})
+ flat = (none_str_pc, none_int_pc) | beam.Flatten()
+ _ = flat | beam.GroupByKey()
+
+ # This may change if type inference changes, but we assert it here
+ # to make sure the check below is not vacuous.
+ self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint)
+
+ p.visit(DataflowRunner.group_by_key_input_visitor())
+ p.visit(DataflowRunner.flatten_input_visitor())
+
+ # The dataflow runner requires gbk input to be tuples *and* flatten
+ # inputs to be equal to their outputs. Assert both hold.
+ self.assertIsInstance(flat.element_type, typehints.TupleConstraint)
+ self.assertEqual(flat.element_type, none_str_pc.element_type)
+ self.assertEqual(flat.element_type, none_int_pc.element_type)
+
def test_serialize_windowing_strategy(self):
# This just tests the basic path; more complete tests
# are in window_test.py.
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
index c1f4238..2f2316f 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
@@ -23,6 +23,7 @@
import logging
from apache_beam import pvalue
+from apache_beam.io import iobase
from apache_beam.transforms import ptransform
from apache_beam.transforms.display import HasDisplayData
@@ -42,7 +43,7 @@
'compression_type']
-class NativeSource(HasDisplayData):
+class NativeSource(iobase.SourceBase):
"""A source implemented by Dataflow service.
This class is to be only inherited by sources natively implemented by Cloud
@@ -55,6 +56,9 @@
"""Returns a NativeSourceReader instance associated with this source."""
raise NotImplementedError
+ def is_bounded(self):
+ # Default for native sources; unbounded ones (e.g. the Pub/Sub payload
+ # source) override this to return False.
+ return True
+
def __repr__(self):
return '<{name} {vals}>'.format(
name=self.__class__.__name__,
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 1a94b3d..7a88d0e 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -26,6 +26,8 @@
import collections
import logging
+from google.protobuf import wrappers_pb2
+
import apache_beam as beam
from apache_beam import typehints
from apache_beam.metrics.execution import MetricsEnvironment
@@ -35,6 +37,7 @@
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.runners.runner import PValueCache
+from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.core import _GroupAlsoByWindow
from apache_beam.transforms.core import _GroupByKeyOnly
from apache_beam.options.pipeline_options import DirectOptions
@@ -54,14 +57,34 @@
@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
class _StreamingGroupByKeyOnly(_GroupByKeyOnly):
"""Streaming GroupByKeyOnly placeholder for overriding in DirectRunner."""
- pass
+ # Runner-API urn for this placeholder; its payload is empty (None).
+ urn = "direct_runner:streaming_gbko:v0.1"
+
+ # These are needed due to apply overloads.
+ def to_runner_api_parameter(self, unused_context):
+ return _StreamingGroupByKeyOnly.urn, None
+
+ # Registered as the decoder for `urn`. It is defined without `self`, so
+ # the registry presumably calls it as a plain (payload, context) function.
+ @PTransform.register_urn(urn, None)
+ def from_runner_api_parameter(unused_payload, unused_context):
+ return _StreamingGroupByKeyOnly()
@typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]])
@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow):
"""Streaming GroupAlsoByWindow placeholder for overriding in DirectRunner."""
- pass
+ # Runner-API urn for this placeholder. The payload is a BytesValue holding
+ # the id of self.windowing in the context's windowing_strategies.
+ urn = "direct_runner:streaming_gabw:v0.1"
+
+ # These are needed due to apply overloads.
+ def to_runner_api_parameter(self, context):
+ return (
+ _StreamingGroupAlsoByWindow.urn,
+ wrappers_pb2.BytesValue(value=context.windowing_strategies.get_id(
+ self.windowing)))
+
+ # Decoder for `urn`: looks the windowing strategy back up by its id.
+ @PTransform.register_urn(urn, wrappers_pb2.BytesValue)
+ def from_runner_api_parameter(payload, context):
+ return _StreamingGroupAlsoByWindow(
+ context.windowing_strategies.get_by_id(payload.value))
class DirectRunner(PipelineRunner):
diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py
index a40069b..42d7f5d 100644
--- a/sdks/python/apache_beam/runners/pipeline_context.py
+++ b/sdks/python/apache_beam/runners/pipeline_context.py
@@ -40,21 +40,21 @@
self._obj_type = obj_type
self._obj_to_id = {}
self._id_to_obj = {}
- self._id_to_proto = proto_map if proto_map else {}
+ self._id_to_proto = dict(proto_map) if proto_map else {}
self._counter = 0
- def _unique_ref(self, obj=None):
+ def _unique_ref(self, obj=None, label=None):
+ # Returns "ref_<obj_type name>_<label or type(obj) name>_<counter>"; the
+ # per-map running counter keeps ids distinct even when labels repeat.
self._counter += 1
return "ref_%s_%s_%s" % (
- self._obj_type.__name__, type(obj).__name__, self._counter)
+ self._obj_type.__name__, label or type(obj).__name__, self._counter)
def populate_map(self, proto_map):
+ # Copy every proto known to this map (including any supplied to the
+ # constructor) into the caller's proto_map.
for id, proto in self._id_to_proto.items():
proto_map[id].CopyFrom(proto)
- def get_id(self, obj):
+ def get_id(self, obj, label=None):
if obj not in self._obj_to_id:
- id = self._unique_ref(obj)
+ id = self._unique_ref(obj, label)
self._id_to_obj[id] = obj
self._obj_to_id[obj] = id
self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context)
@@ -66,6 +66,12 @@
self._id_to_proto[id], self._pipeline_context)
return self._id_to_obj[id]
+ def __getitem__(self, id):
+ # Dict-style access: self[id] is equivalent to self.get_by_id(id).
+ return self.get_by_id(id)
+
+ def __contains__(self, id):
+ # Membership is by known proto id, so ids from the constructor's proto_map
+ # count even if no object has been decoded for them yet.
+ return id in self._id_to_proto
+
class PipelineContext(object):
"""For internal use only; no backwards-compatibility guarantees.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index f88fe53..3222bcb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -19,6 +19,7 @@
"""
import base64
import collections
+import copy
import logging
import Queue as queue
import threading
@@ -28,21 +29,26 @@
import grpc
import apache_beam as beam # pylint: disable=ungrouped-imports
+from apache_beam.coders import registry
from apache_beam.coders import WindowedValueCoder
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.internal import pickler
from apache_beam.io import iobase
-from apache_beam.transforms.window import GlobalWindows
+from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.runner import PipelineState
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import sdk_worker
+from apache_beam.transforms.window import GlobalWindows
from apache_beam.utils import proto_utils
+from apache_beam.utils import urns
+
# This module is experimental. No backwards-compatibility guarantees.
@@ -113,6 +119,30 @@
beam.transforms.core.Windowing(GlobalWindows())))
+class _GroupingBuffer(object):
+ """Used to accumulate grouped (shuffled) results.
+
+ append() decodes pre_grouped_coder-encoded (key, value) elements and groups
+ the values by the key's encoded bytes. Iterating yields a single chunk in
+ which every (key, values) pair is encoded with post_grouped_coder in the
+ global window.
+ """
+ def __init__(self, pre_grouped_coder, post_grouped_coder):
+ self._key_coder = pre_grouped_coder.value_coder().key_coder()
+ self._pre_grouped_coder = pre_grouped_coder
+ self._post_grouped_coder = post_grouped_coder
+ # Maps encoded key bytes -> list of values seen for that key.
+ self._table = collections.defaultdict(list)
+
+ def append(self, elements_data):
+ input_stream = create_InputStream(elements_data)
+ while input_stream.size() > 0:
+ # Only .value is kept; the input element's windowing is discarded.
+ key, value = self._pre_grouped_coder.get_impl().decode_from_stream(
+ input_stream, True).value
+ self._table[self._key_coder.encode(key)].append(value)
+
+ def __iter__(self):
+ output_stream = create_OutputStream()
+ for encoded_key, values in self._table.items():
+ key = self._key_coder.decode(encoded_key)
+ self._post_grouped_coder.get_impl().encode_to_stream(
+ GlobalWindows.windowed_value((key, values)), output_stream, True)
+ return iter([output_stream.get()])
+
+
class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
def __init__(self):
@@ -126,6 +156,520 @@
self._last_uid += 1
return str(self._last_uid)
+ def run(self, pipeline):
+ # Use the runner-API path when pipeline._verify_runner_api_compatible()
+ # allows it; otherwise fall back to the MapTaskExecutorRunner behavior.
+ MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+ if pipeline._verify_runner_api_compatible():
+ return self.run_via_runner_api(pipeline.to_runner_api())
+ else:
+ return super(FnApiRunner, self).run(pipeline)
+
+ def run_via_runner_api(self, pipeline_proto):
+ # create_stages returns a sequence whose items are passed positionally to
+ # run_stages.
+ return self.run_stages(*self.create_stages(pipeline_proto))
+
+ def create_stages(self, pipeline_proto):
+
+ # First define a couple of helpers.
+
+ def union(a, b):
+ # Minimize the number of distinct sets.
+ if not a or a == b:
+ return b
+ elif not b:
+ return a
+ else:
+ return frozenset.union(a, b)
+
+ class Stage(object):
+ """A set of Transforms that can be sent to the worker for processing."""
+ def __init__(self, name, transforms,
+ downstream_side_inputs=None, must_follow=frozenset()):
+ self.name = name
+ self.transforms = transforms
+ self.downstream_side_inputs = downstream_side_inputs
+ self.must_follow = must_follow
+
+ def __repr__(self):
+ must_follow = ', '.join(prev.name for prev in self.must_follow)
+ return "%s\n %s\n must follow: %s" % (
+ self.name,
+ '\n'.join(["%s:%s" % (transform.unique_name, transform.spec.urn)
+ for transform in self.transforms]),
+ must_follow)
+
+ def can_fuse(self, consumer):
+ def no_overlap(a, b):
+ return not a.intersection(b)
+ return (
+ not self in consumer.must_follow
+ and not self.is_flatten() and not consumer.is_flatten()
+ and no_overlap(self.downstream_side_inputs, consumer.side_inputs()))
+
+ def fuse(self, other):
+ return Stage(
+ "(%s)+(%s)" % (self.name, other.name),
+ self.transforms + other.transforms,
+ union(self.downstream_side_inputs, other.downstream_side_inputs),
+ union(self.must_follow, other.must_follow))
+
+ def is_flatten(self):
+ return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
+ for transform in self.transforms)
+
+ def side_inputs(self):
+ for transform in self.transforms:
+ if transform.spec.urn == urns.PARDO_TRANSFORM:
+ payload = proto_utils.unpack_Any(
+ transform.spec.parameter, beam_runner_api_pb2.ParDoPayload)
+ for side_input in payload.side_inputs:
+ yield transform.inputs[side_input]
+
+ def has_as_main_input(self, pcoll):
+ for transform in self.transforms:
+ if transform.spec.urn == urns.PARDO_TRANSFORM:
+ payload = proto_utils.unpack_Any(
+ transform.spec.parameter, beam_runner_api_pb2.ParDoPayload)
+ local_side_inputs = payload.side_inputs
+ else:
+ local_side_inputs = {}
+ for local_id, pipeline_id in transform.inputs.items():
+ if pcoll == pipeline_id and local_id not in local_side_inputs:
+ return True
+
+ def deduplicate_read(self):
+ seen_pcolls = set()
+ new_transforms = []
+ for transform in self.transforms:
+ if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+ pcoll = only_element(transform.outputs.items())[1]
+ if pcoll in seen_pcolls:
+ continue
+ seen_pcolls.add(pcoll)
+ new_transforms.append(transform)
+ self.transforms = new_transforms
+
+ # Now define the "optimization" phases.
+
+ def expand_gbk(stages):
+ """Transforms each GBK into a write followed by a read.
+ """
+ for stage in stages:
+ assert len(stage.transforms) == 1
+ transform = stage.transforms[0]
+ if transform.spec.urn == urns.GROUP_BY_KEY_ONLY_TRANSFORM:
+ # This is used later to correlate the read and write.
+ param = proto_utils.pack_Any(
+ wrappers_pb2.BytesValue(
+ value=str("group:%s" % stage.name)))
+ gbk_write = Stage(
+ transform.unique_name + '/Write',
+ [beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Write',
+ inputs=transform.inputs,
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_OUTPUT_URN,
+ parameter=param))],
+ downstream_side_inputs=frozenset(),
+ must_follow=stage.must_follow)
+ yield gbk_write
+
+ yield Stage(
+ transform.unique_name + '/Read',
+ [beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Read',
+ outputs=transform.outputs,
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_INPUT_URN,
+ parameter=param))],
+ downstream_side_inputs=frozenset(),
+ must_follow=union(frozenset([gbk_write]), stage.must_follow))
+ else:
+ yield stage
+
+ def sink_flattens(stages):
+ """Sink flattens and remove them from the graph.
+
+ A flatten that cannot be sunk/fused away becomes multiple writes (to the
+ same logical sink) followed by a read.
+ """
+ # TODO(robertwb): Actually attempt to sink rather than always materialize.
+ # TODO(robertwb): Possibly fuse this into one of the stages.
+ pcollections = pipeline_components.pcollections
+ for stage in stages:
+ assert len(stage.transforms) == 1
+ transform = stage.transforms[0]
+ if transform.spec.urn == urns.FLATTEN_TRANSFORM:
+ # This is used later to correlate the read and writes.
+ param = proto_utils.pack_Any(
+ wrappers_pb2.BytesValue(
+ value=str("materialize:%s" % transform.unique_name)))
+ output_pcoll_id, = transform.outputs.values()
+ output_coder_id = pcollections[output_pcoll_id].coder_id
+ flatten_writes = []
+ for local_in, pcoll_in in transform.inputs.items():
+
+ if pcollections[pcoll_in].coder_id != output_coder_id:
+ # Flatten inputs must all be written with the same coder as is
+ # used to read them.
+ pcollections[pcoll_in].coder_id = output_coder_id
+ transcoded_pcollection = (
+ transform.unique_name + '/Transcode/' + local_in + '/out')
+ yield Stage(
+ transform.unique_name + '/Transcode/' + local_in,
+ [beam_runner_api_pb2.PTransform(
+ unique_name=
+ transform.unique_name + '/Transcode/' + local_in,
+ inputs={local_in: pcoll_in},
+ outputs={'out': transcoded_pcollection},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.IDENTITY_DOFN_URN))],
+ downstream_side_inputs=frozenset(),
+ must_follow=stage.must_follow)
+ pcollections[transcoded_pcollection].CopyFrom(
+ pcollections[pcoll_in])
+ pcollections[transcoded_pcollection].coder_id = output_coder_id
+ else:
+ transcoded_pcollection = pcoll_in
+
+ flatten_write = Stage(
+ transform.unique_name + '/Write/' + local_in,
+ [beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Write/' + local_in,
+ inputs={local_in: transcoded_pcollection},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_OUTPUT_URN,
+ parameter=param))],
+ downstream_side_inputs=frozenset(),
+ must_follow=stage.must_follow)
+ flatten_writes.append(flatten_write)
+ yield flatten_write
+
+ yield Stage(
+ transform.unique_name + '/Read',
+ [beam_runner_api_pb2.PTransform(
+ unique_name=transform.unique_name + '/Read',
+ outputs=transform.outputs,
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_INPUT_URN,
+ parameter=param))],
+ downstream_side_inputs=frozenset(),
+ must_follow=union(frozenset(flatten_writes), stage.must_follow))
+
+ else:
+ yield stage
+
+ def annotate_downstream_side_inputs(stages):
+ """Annotate each stage with fusion-prohibiting information.
+
+ Each stage is annotated with the (transitive) set of pcollections that
+ depend on this stage that are also used later in the pipeline as a
+ side input.
+
+ While theoretically this could result in O(n^2) annotations, the size of
+ each set is bounded by the number of side inputs (typically much smaller
+ than the number of total nodes) and the number of *distinct* side-input
+ sets is also generally small (and shared due to the use of union
+ defined above).
+
+ This representation is also amenable to simple recomputation on fusion.
+ """
+ consumers = collections.defaultdict(list)
+ all_side_inputs = set()
+ for stage in stages:
+ for transform in stage.transforms:
+ for input in transform.inputs.values():
+ consumers[input].append(stage)
+ for si in stage.side_inputs():
+ all_side_inputs.add(si)
+ all_side_inputs = frozenset(all_side_inputs)
+
+ downstream_side_inputs_by_stage = {}
+
+ def compute_downstream_side_inputs(stage):
+ if stage not in downstream_side_inputs_by_stage:
+ downstream_side_inputs = frozenset()
+ for transform in stage.transforms:
+ for output in transform.outputs.values():
+ if output in all_side_inputs:
+ # union() operates on sets: wrap the single PCollection id. Passing
+ # the bare string returned the string itself (when the left side is
+ # empty) or raised TypeError in frozenset.union.
+ downstream_side_inputs = union(
+ downstream_side_inputs, frozenset([output]))
+ for consumer in consumers[output]:
+ downstream_side_inputs = union(
+ downstream_side_inputs,
+ compute_downstream_side_inputs(consumer))
+ downstream_side_inputs_by_stage[stage] = downstream_side_inputs
+ return downstream_side_inputs_by_stage[stage]
+
+ for stage in stages:
+ stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
+ return stages
+
+ def greedily_fuse(stages):
+ """Places transforms sharing an edge in the same stage, whenever possible.
+
+ For each PCollection, its producing stage is fused with every consuming
+ stage for which producer.can_fuse(consumer) holds. Where fusion is not
+ allowed, the edge is materialized instead: a '<pcoll>/Write' DATA_OUTPUT
+ stage (created at most once per PCollection) is fused into the producer
+ and, if the consumer reads the PCollection as a main input, a
+ '<pcoll>/Read' DATA_INPUT stage that must follow that write is fused into
+ the consumer. Both use the BytesValue parameter 'materialize:<pcoll>'.
+
+ Args:
+ stages: the stages to fuse.
+
+ Returns:
+ A frozenset of the stages that survive fusion, with their must_follow
+ sets rewritten to refer to surviving stages.
+ """
+ producers_by_pcoll = {}
+ consumers_by_pcoll = collections.defaultdict(list)
+
+ # Used to always reference the correct stage as the producer and
+ # consumer maps are not updated when stages are fused away.
+ replacements = {}
+
+ def replacement(s):
+ # Follows the replacement chain from s to the stage that currently
+ # stands in for it, then points every intermediate entry (except the
+ # last, which already maps there) directly at that stage.
+ old_ss = []
+ while s in replacements:
+ old_ss.append(s)
+ s = replacements[s]
+ for old_s in old_ss[:-1]:
+ replacements[old_s] = s
+ return s
+
+ def fuse(producer, consumer):
+ # Both inputs are superseded by the fused stage from now on.
+ fused = producer.fuse(consumer)
+ replacements[producer] = fused
+ replacements[consumer] = fused
+
+ # First record the producers and consumers of each PCollection.
+ for stage in stages:
+ for transform in stage.transforms:
+ for input in transform.inputs.values():
+ consumers_by_pcoll[input].append(stage)
+ for output in transform.outputs.values():
+ producers_by_pcoll[output] = stage
+
+ logging.debug('consumers\n%s', consumers_by_pcoll)
+ logging.debug('producers\n%s', producers_by_pcoll)
+
+ # Now try to fuse away all pcollections.
+ for pcoll, producer in producers_by_pcoll.items():
+ # Buffer id shared by the Write/Read pair if this edge cannot be fused.
+ pcoll_as_param = proto_utils.pack_Any(
+ wrappers_pb2.BytesValue(
+ value=str("materialize:%s" % pcoll)))
+ write_pcoll = None
+ for consumer in consumers_by_pcoll[pcoll]:
+ # Earlier fusions may have replaced either stage; use the current ones.
+ producer = replacement(producer)
+ consumer = replacement(consumer)
+ # Update consumer.must_follow set, as it's used in can_fuse.
+ consumer.must_follow = set(
+ replacement(s) for s in consumer.must_follow)
+ if producer.can_fuse(consumer):
+ fuse(producer, consumer)
+ else:
+ # If we can't fuse, do a read + write.
+ if write_pcoll is None:
+ write_pcoll = Stage(
+ pcoll + '/Write',
+ [beam_runner_api_pb2.PTransform(
+ unique_name=pcoll + '/Write',
+ inputs={'in': pcoll},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_OUTPUT_URN,
+ parameter=pcoll_as_param))])
+ fuse(producer, write_pcoll)
+ # NOTE(review): a consumer that uses pcoll only as a side input gets no
+ # Read stage here; presumably side inputs are delivered some other
+ # way -- confirm.
+ if consumer.has_as_main_input(pcoll):
+ read_pcoll = Stage(
+ pcoll + '/Read',
+ [beam_runner_api_pb2.PTransform(
+ unique_name=pcoll + '/Read',
+ outputs={'out': pcoll},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_INPUT_URN,
+ parameter=pcoll_as_param))],
+ must_follow={write_pcoll})
+ fuse(read_pcoll, consumer)
+
+ # Everything that was originally a stage or a replacement, but wasn't
+ # replaced, should be in the final graph.
+ final_stages = frozenset(stages).union(replacements.values()).difference(
+ replacements.keys())
+
+ for stage in final_stages:
+ # Update all references to their final values before throwing
+ # the replacement data away.
+ stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
+ # Two reads of the same stage may have been fused. This is unneeded.
+ stage.deduplicate_read()
+ return final_stages
+
+ def sort_stages(stages):
+ """Order stages suitable for sequential execution.
+
+ Every stage is placed after all stages in its (transitive) must_follow
+ set. The walk uses an explicit stack instead of recursion, so long
+ dependency chains cannot exceed Python's recursion limit. The resulting
+ order is the same as a recursive post-order traversal.
+
+ Args:
+ stages: iterable of stages, each exposing a must_follow collection.
+
+ Returns:
+ A list containing each reachable stage exactly once.
+ """
+ seen = set()
+ ordered = []
+ for root in stages:
+ if root in seen:
+ continue
+ seen.add(root)
+ # Each entry pairs a stage with an iterator over its predecessors, so
+ # the scan resumes where it left off after a predecessor is finished.
+ stack = [(root, iter(root.must_follow))]
+ while stack:
+ stage, prevs = stack[-1]
+ pending = next((p for p in prevs if p not in seen), None)
+ if pending is None:
+ # All predecessors have been emitted; this stage can run now.
+ stack.pop()
+ ordered.append(stage)
+ else:
+ seen.add(pending)
+ stack.append((pending, iter(pending.must_follow)))
+ return ordered
+
+ # Now actually apply the operations.
+
+ pipeline_components = copy.deepcopy(pipeline_proto.components)
+
+ # Reify coders.
+ # TODO(BEAM-2717): Remove once Coders are already in proto.
+ coders = pipeline_context.PipelineContext(pipeline_components).coders
+ for pcoll in pipeline_components.pcollections.values():
+ if pcoll.coder_id not in coders:
+ window_coder = coders[
+ pipeline_components.windowing_strategies[
+ pcoll.windowing_strategy_id].window_coder_id]
+ coder = WindowedValueCoder(
+ registry.get_coder(pickler.loads(pcoll.coder_id)),
+ window_coder=window_coder)
+ pcoll.coder_id = coders.get_id(coder)
+ coders.populate_map(pipeline_components.coders)
+
+ # Initial set of stages are singleton transforms.
+ stages = [
+ Stage(name, [transform])
+ for name, transform in pipeline_proto.components.transforms.items()
+ if not transform.subtransforms]
+
+ # Apply each phase in order.
+ for phase in [
+ annotate_downstream_side_inputs, expand_gbk, sink_flattens,
+ greedily_fuse, sort_stages]:
+ logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
+ stages = list(phase(stages))
+ logging.debug('Stages: %s', [str(s) for s in stages])
+
+ # Return the (possibly mutated) context and ordered set of stages.
+ return pipeline_components, stages
+
+ def run_stages(self, pipeline_components, stages, direct=True):
+ """Runs the given stages one after another on a single controller.
+
+ Args:
+ pipeline_components: components the per-stage bundle descriptors are
+ built from.
+ stages: the stages to execute, run in the given order.
+ direct: truthy selects FnApiRunner.DirectController, otherwise
+ FnApiRunner.GrpcController is used.
+
+ Returns:
+ A WorkerRunnerResult in the DONE state.
+ """
+ controller = (FnApiRunner.DirectController() if direct
+ else FnApiRunner.GrpcController())
+ try:
+ # Shared across stages so one stage's outputs can be replayed as a
+ # later stage's inputs (see run_stage).
+ buffers = collections.defaultdict(list)
+ for stage in stages:
+ self.run_stage(controller, pipeline_components, stage, buffers)
+ finally:
+ controller.close()
+ return maptask_executor_runner.WorkerRunnerResult(PipelineState.DONE)
+
+ def run_stage(self, controller, pipeline_components, stage, pcoll_buffers):
+ """Executes a single fused stage as one bundle.
+
+ Replays buffered data into the stage's DATA_INPUT transforms, runs the
+ bundle to completion, and stores what the stage's DATA_OUTPUT transforms
+ emit back into pcoll_buffers.
+
+ Args:
+ controller: supplies the control and data plane handlers.
+ pipeline_components: components the bundle descriptor is built from.
+ stage: the Stage to run. The spec parameters of its IO transforms are
+ overwritten with the controller's data_operation_spec (or cleared).
+ pcoll_buffers: buffer id ('materialize:...' or 'group:...') to
+ buffered data; read for inputs and updated with outputs.
+
+ Raises:
+ RuntimeError: if the bundle's result reports an error.
+ NotImplementedError: for side inputs or an unrecognized output buffer id.
+ """
+
+ coders = pipeline_context.PipelineContext(pipeline_components).coders
+ data_operation_spec = controller.data_operation_spec()
+
+ def extract_endpoints(stage):
+ # Returns maps of (transform name, local tag) targets to buffer ids.
+ # Also mutates IO transforms to point to the data_operation_spec.
+ data_input = {}
+ data_side_input = {}
+ data_output = {}
+ for transform in stage.transforms:
+ if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+ bundle_processor.DATA_OUTPUT_URN):
+ # Only IO transforms carry a BytesValue buffer id as their parameter;
+ # other transforms (e.g. ParDo with a ParDoPayload) carry unrelated
+ # payloads and must not be unpacked as one.
+ pcoll_id = proto_utils.unpack_Any(
+ transform.spec.parameter, wrappers_pb2.BytesValue).value
+ if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+ target = transform.unique_name, only_element(transform.outputs)
+ data_input[target] = pcoll_id
+ else:
+ target = transform.unique_name, only_element(transform.inputs)
+ data_output[target] = pcoll_id
+ if data_operation_spec:
+ transform.spec.parameter.CopyFrom(data_operation_spec)
+ else:
+ transform.spec.parameter.Clear()
+ return data_input, data_side_input, data_output
+
+ logging.info('Running %s', stage.name)
+ logging.debug(' %s', stage)
+ data_input, data_side_input, data_output = extract_endpoints(stage)
+ if data_side_input:
+ raise NotImplementedError('Side inputs.')
+
+ process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+ id=self._next_uid(),
+ transforms={transform.unique_name: transform
+ for transform in stage.transforms},
+ pcollections=dict(pipeline_components.pcollections.items()),
+ coders=dict(pipeline_components.coders.items()),
+ windowing_strategies=dict(
+ pipeline_components.windowing_strategies.items()),
+ environments=dict(pipeline_components.environments.items()))
+
+ process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=[process_bundle_descriptor]))
+
+ process_bundle = beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+ process_bundle_descriptor_reference=
+ process_bundle_descriptor.id))
+
+ # Write all the input data to the channel.
+ for (transform_id, name), pcoll_id in data_input.items():
+ data_out = controller.data_plane_handler.output_stream(
+ process_bundle.instruction_id, beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_id, name=name))
+ for element_data in pcoll_buffers[pcoll_id]:
+ data_out.write(element_data)
+ data_out.close()
+
+ # Register and start running the bundle.
+ controller.control_handler.push(process_bundle_registration)
+ controller.control_handler.push(process_bundle)
+
+ # Wait for the bundle to finish; results for other instructions (such as
+ # the registration) are skipped.
+ while True:
+ result = controller.control_handler.pull()
+ if result.instruction_id == process_bundle.instruction_id:
+ if result.error:
+ raise RuntimeError(result.error)
+ break
+
+ # Gather all output data.
+ expected_targets = [
+ beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+ name=output_name)
+ for (transform_id, output_name), _ in data_output.items()]
+ for output in controller.data_plane_handler.input_elements(
+ process_bundle.instruction_id, expected_targets):
+ target_tuple = (
+ output.target.primitive_transform_reference, output.target.name)
+ if target_tuple in data_output:
+ pcoll_id = data_output[target_tuple]
+ if pcoll_id.startswith('materialize:'):
+ # Just store the data chunks for replay.
+ pcoll_buffers[pcoll_id].append(output.data)
+ elif pcoll_id.startswith('group:'):
+ # This is a grouping write, create a grouping buffer if needed.
+ # Uses `in` rather than indexing so a defaultdict (as built by
+ # run_stages) does not insert a plain list first.
+ if pcoll_id not in pcoll_buffers:
+ original_gbk_transform = pcoll_id.split(':', 1)[1]
+ transform_proto = pipeline_components.transforms[
+ original_gbk_transform]
+ input_pcoll = only_element(transform_proto.inputs.values())
+ output_pcoll = only_element(transform_proto.outputs.values())
+ pre_gbk_coder = coders[
+ pipeline_components.pcollections[input_pcoll].coder_id]
+ post_gbk_coder = coders[
+ pipeline_components.pcollections[output_pcoll].coder_id]
+ pcoll_buffers[pcoll_id] = _GroupingBuffer(
+ pre_gbk_coder, post_gbk_coder)
+ pcoll_buffers[pcoll_id].append(output.data)
+ else:
+ # These should be the only two identifiers we produce for now,
+ # but special side input writes may go here.
+ raise NotImplementedError(pcoll_id)
+
+ # This is the "old" way of executing pipelines.
+ # TODO(robertwb): Remove once runner API supports side inputs.
+
def _map_task_registration(self, map_task, state_handler,
data_operation_spec):
input_data, side_input_data, runner_sinks, process_bundle_descriptor = (
@@ -175,10 +719,6 @@
return {tag: pcollection_id(op_ix, out_ix)
for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))}
- def only_element(iterable):
- element, = iterable
- return element
-
for op_ix, (stage_name, operation) in enumerate(map_task):
transform_id = uniquify(stage_name)
@@ -332,6 +872,15 @@
finally:
controller.close()
+ @staticmethod
+ def _reencode_elements(elements, element_coder):
+ """Encodes all elements into one byte string.
+
+ Each element is written with element_coder's impl via
+ encode_to_stream(..., True) into a single shared output stream, and the
+ stream's accumulated contents are returned.
+ """
+ stream = create_OutputStream()
+ for value in elements:
+ impl = element_coder.get_impl()
+ impl.encode_to_stream(value, stream, True)
+ return stream.get()
+
+ # These classes are used to interact with the worker.
+
class SimpleState(object): # TODO(robertwb): Inherit from GRPC servicer.
def __init__(self):
@@ -429,9 +978,7 @@
self.control_server.stop(5).wait()
self.data_server.stop(5).wait()
- @staticmethod
- def _reencode_elements(elements, element_coder):
- output_stream = create_OutputStream()
- for element in elements:
- element_coder.get_impl().encode_to_stream(element, output_stream, True)
- return output_stream.get()
+
+def only_element(iterable):
+ """Returns the single element of iterable.
+
+ Args:
+ iterable: any iterable expected to yield exactly one item. At most two
+ items are consumed from it.
+
+ Raises:
+ ValueError: if iterable yields no items or more than one item.
+ """
+ iterator = iter(iterable)
+ missing = object()
+ element = next(iterator, missing)
+ if element is missing:
+ raise ValueError('only_element: expected exactly one element, got none')
+ if next(iterator, missing) is not missing:
+ raise ValueError(
+ 'only_element: expected exactly one element, got more than one')
+ return element
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 163e980..ba21954 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -51,7 +51,7 @@
def test_assert_that(self):
# TODO: figure out a way for fn_api_runner to parse and raise the
# underlying exception.
- with self.assertRaisesRegexp(RuntimeError, 'BeamAssertException'):
+ # Any exception type is accepted; only the 'Failed assert' text in its
+ # message is checked.
+ with self.assertRaisesRegexp(Exception, 'Failed assert'):
with self.create_pipeline() as p:
assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 2669bfc..9474eda 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -28,17 +28,20 @@
from google.protobuf import wrappers_pb2
+import apache_beam as beam
from apache_beam.coders import coder_impl
from apache_beam.coders import WindowedValueCoder
from apache_beam.internal import pickler
from apache_beam.io import iobase
from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.dataflow.native_io import iobase as native_iobase
from apache_beam.runners import pipeline_context
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.utils import counters
from apache_beam.utils import proto_utils
+from apache_beam.utils import urns
# This module is experimental. No backwards-compatibility guarantees.
@@ -374,6 +377,24 @@
consumers)
+@BeamTransformFactory.register_urn(
+ urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ """Builds a ReadOperation for a runner-API Read transform.
+
+ The source is rebuilt from parameter.source (a ReadPayload) and read as a
+ single, unsplit SourceBundle. transform_id is unused.
+ """
+ name = transform_proto.unique_name
+ source = iobase.SourceBase.from_runner_api(parameter.source, factory.context)
+ read_spec = operation_specs.WorkerRead(
+ iobase.SourceBundle(1.0, source, None, None),
+ [WindowedValueCoder(source.default_output_coder())])
+ read_op = operations.ReadOperation(
+ name, read_spec, factory.counter_factory, factory.state_sampler)
+ return factory.augment_oldstyle_op(read_op, name, consumers)
+
+
@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue)
def create(factory, transform_id, transform_proto, parameter, consumers):
dofn_data = pickler.loads(parameter.value)
@@ -383,7 +404,32 @@
else:
# No side input data.
serialized_fn, side_input_data = parameter.value, []
+ return _create_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ serialized_fn, side_input_data)
+
+@BeamTransformFactory.register_urn(
+ urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ """Builds the operation for a runner-API ParDo transform.
+
+ parameter.do_fn must be a PICKLED_DO_FN_INFO spec whose parameter is a
+ BytesValue holding the pickled DoFn data.
+ """
+ fn_spec = parameter.do_fn.spec
+ assert fn_spec.urn == urns.PICKLED_DO_FN_INFO
+ serialized_fn = proto_utils.unpack_Any(
+ fn_spec.parameter, wrappers_pb2.BytesValue).value
+ dofn_data = pickler.loads(serialized_fn)
+ side_input_data = []
+ if len(dofn_data) == 2:
+ # A (serialized_fn, side_input_data) pair.
+ # NOTE(review): ParDo._pardo_fn_data pickles a 5-tuple, so this branch
+ # looks unreachable for payloads it produces; confirm before removing.
+ serialized_fn, side_input_data = dofn_data
+ return _create_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ serialized_fn, side_input_data)
+
+
+def _create_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ serialized_fn, side_input_data):
def create_side_input(tag, coder):
# TODO(robertwb): Extract windows (and keys) out of element data.
# TODO(robertwb): Extract state key from ParDoPayload.
@@ -395,10 +441,27 @@
key=side_input_tag(transform_id, tag)),
coder=coder))
output_tags = list(transform_proto.outputs.keys())
+
+ # Hack to match the 'out' tag prefix injected by the Dataflow runner.
+ def mutate_tag(tag):
+ if 'None' in output_tags:
+ if tag == 'None':
+ return 'out'
+ else:
+ return 'out_' + tag
+ else:
+ return tag
+ dofn_data = pickler.loads(serialized_fn)
+ if not dofn_data[-1]:
+ # Windowing not set.
+ pcoll_id, = transform_proto.inputs.values()
+ windowing = factory.context.windowing_strategies.get_by_id(
+ factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
+ serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
output_coders = factory.get_output_coders(transform_proto)
spec = operation_specs.WorkerDoFn(
serialized_fn=serialized_fn,
- output_tags=output_tags,
+ output_tags=[mutate_tag(tag) for tag in output_tags],
input=None,
side_inputs=[
create_side_input(tag, coder) for tag, coder in side_input_data],
@@ -414,12 +477,52 @@
output_tags)
+def _create_simple_pardo_operation(
+ factory, transform_id, transform_proto, consumers, dofn):
+ """Builds a ParDo operation for dofn with no args, kwargs or side inputs.
+
+ The DoFn is pickled in the (fn, args, kwargs, side inputs, windowing)
+ layout. Windowing is left as None, so _create_pardo_operation fills it in
+ from the input PCollection's windowing strategy.
+ """
+ fn_data = (dofn, (), {}, [], None)
+ return _create_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ pickler.dumps(fn_data), [])
+
+
+@BeamTransformFactory.register_urn(
+ urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ """Builds a GroupAlsoByWindow operation.
+
+ parameter.value is the id of the windowing strategy to group by.
+ """
+ # Perhaps this hack can go away once all apply overloads are gone.
+ from apache_beam.transforms.core import _GroupAlsoByWindowDoFn
+ windowing = factory.context.windowing_strategies.get_by_id(parameter.value)
+ return _create_simple_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ _GroupAlsoByWindowDoFn(windowing))
+
+
+@BeamTransformFactory.register_urn(
+ urns.WINDOW_INTO_TRANSFORM, beam_runner_api_pb2.WindowingStrategy)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ """Builds an operation that re-assigns each element's windows.
+
+ parameter is the target WindowingStrategy proto; its windowfn assigns the
+ new windows from each element and its timestamp.
+ """
+ # Imported up front, before the local DoFn below refers to them.
+ from apache_beam.transforms.core import Windowing
+ from apache_beam.transforms.window import WindowFn, WindowedValue
+
+ class WindowIntoDoFn(beam.DoFn):
+ def __init__(self, windowing):
+ self.windowing = windowing
+
+ def process(self, element, timestamp=beam.DoFn.TimestampParam):
+ assign_context = WindowFn.AssignContext(timestamp, element=element)
+ new_windows = self.windowing.windowfn.assign(assign_context)
+ yield WindowedValue(element, timestamp, new_windows)
+
+ windowing = Windowing.from_runner_api(parameter, factory.context)
+ return _create_simple_pardo_operation(
+ factory, transform_id, transform_proto, consumers,
+ WindowIntoDoFn(windowing))
+
+
@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
def create(factory, transform_id, transform_proto, unused_parameter, consumers):
return factory.augment_oldstyle_op(
operations.FlattenOperation(
transform_proto.unique_name,
- None,
+ operation_specs.WorkerFlatten(
+ None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
factory.state_sampler),
transform_proto.unique_name,
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 92b8737..3f92ce9 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -23,9 +23,13 @@
import inspect
import types
+from google.protobuf import wrappers_pb2
+
from apache_beam import pvalue
from apache_beam import typehints
+from apache_beam import coders
from apache_beam.coders import typecoders
+from apache_beam.internal import pickler
from apache_beam.internal import util
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import ptransform
@@ -49,6 +53,7 @@
from apache_beam.typehints.decorators import WithTypeHints
from apache_beam.typehints.trivial_inference import element_type
from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import proto_utils
from apache_beam.utils import urns
from apache_beam.options.pipeline_options import TypeOptions
@@ -135,7 +140,7 @@
self.windows = windowed_value.windows
-class DoFn(WithTypeHints, HasDisplayData):
+class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
"""A function object used by a transform with custom processing.
The ParDo transform is such a transform. The ParDo.apply
@@ -235,6 +240,8 @@
return False # Method is a classmethod
return True
+ urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_DO_FN)
+
def _fn_takes_side_inputs(fn):
try:
@@ -311,7 +318,7 @@
return getattr(self._fn, '_argspec_fn', self._fn)
-class CombineFn(WithTypeHints, HasDisplayData):
+class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
"""A function object used by a Combine transform with custom processing.
A CombineFn specifies how multiple values in all or part of a PCollection can
@@ -430,6 +437,11 @@
def maybe_from_callable(fn):
return fn if isinstance(fn, CombineFn) else CallableWrapperCombineFn(fn)
+ def get_accumulator_coder(self):
+ """Returns the coder used for this CombineFn's accumulators.
+
+ This base implementation asks the coder registry for the coder of
+ `object`.
+ """
+ registry = coders.registry
+ return registry.get_coder(object)
+
+ urns.RunnerApiFn.register_pickle_urn(urns.PICKLED_COMBINE_FN)
+
class CallableWrapperCombineFn(CombineFn):
"""For internal use only; no backwards-compatibility guarantees.
@@ -680,6 +692,37 @@
raise ValueError('Unexpected keyword arguments: %s' % main_kw.keys())
return _MultiParDo(self, tags, main_tag)
+ def _pardo_fn_data(self):
+ """Returns the tuple pickled into this ParDo's runner API payload.
+
+ Layout: (fn, args, kwargs, side-input tags and types, windowing). Side
+ inputs and windowing are not captured here, so those slots are always an
+ empty list and None.
+ """
+ no_side_inputs = []
+ no_windowing = None
+ return self.fn, self.args, self.kwargs, no_side_inputs, no_windowing
+
+ def to_runner_api_parameter(self, context):
+ """Encodes this ParDo as PARDO_TRANSFORM with a pickled DoFn payload."""
+ # Subclasses would come back as plain ParDo from
+ # from_runner_api_parameter, so only exact ParDo instances are allowed.
+ assert self.__class__ is ParDo
+ pickled = pickler.dumps(self._pardo_fn_data())
+ fn_spec = beam_runner_api_pb2.FunctionSpec(
+ urn=urns.PICKLED_DO_FN_INFO,
+ parameter=proto_utils.pack_Any(wrappers_pb2.BytesValue(value=pickled)))
+ payload = beam_runner_api_pb2.ParDoPayload(
+ do_fn=beam_runner_api_pb2.SdkFunctionSpec(spec=fn_spec))
+ return urns.PARDO_TRANSFORM, payload
+
+ @PTransform.register_urn(
+ urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload)
+ def from_runner_api_parameter(pardo_payload, context):
+ """Rebuilds a ParDo from a payload written by to_runner_api_parameter.
+
+ Raises:
+ NotImplementedError: if the payload carries side inputs or windowing.
+ """
+ fn_spec = pardo_payload.do_fn.spec
+ assert fn_spec.urn == urns.PICKLED_DO_FN_INFO
+ fn_data = pickler.loads(
+ proto_utils.unpack_Any(fn_spec.parameter, wrappers_pb2.BytesValue).value)
+ fn, args, kwargs, si_tags_and_types, windowing = fn_data
+ if si_tags_and_types:
+ raise NotImplementedError('deferred side inputs')
+ if windowing:
+ raise NotImplementedError('explicit windowing')
+ return ParDo(fn, *args, **kwargs)
+
class _MultiParDo(PTransform):
@@ -816,6 +859,13 @@
return pardo
+def _combine_payload(combine_fn, context):
+ """Builds a CombinePayload for combine_fn.
+
+ The accumulator coder is referenced by the id context.coders assigns it.
+ """
+ fn_proto = combine_fn.to_runner_api(context)
+ coder_id = context.coders.get_id(combine_fn.get_accumulator_coder())
+ return beam_runner_api_pb2.CombinePayload(
+ combine_fn=fn_proto, accumulator_coder_id=coder_id)
+
+
class CombineGlobally(PTransform):
"""A CombineGlobally transform.
@@ -973,6 +1023,17 @@
return pcoll | GroupByKey() | 'Combine' >> CombineValues(
self.fn, *args, **kwargs)
+ def to_runner_api_parameter(self, context):
+ """Encodes this transform as COMBINE_PER_KEY_TRANSFORM + CombinePayload."""
+ # NOTE(review): only self.fn is encoded; extra combine args/kwargs (if
+ # any) appear not to round-trip -- confirm.
+ payload = _combine_payload(self.fn, context)
+ return urns.COMBINE_PER_KEY_TRANSFORM, payload
+
+ @PTransform.register_urn(
+ urns.COMBINE_PER_KEY_TRANSFORM, beam_runner_api_pb2.CombinePayload)
+ def from_runner_api_parameter(combine_payload, context):
+ """Rebuilds a CombinePerKey from its CombinePayload."""
+ combine_fn = CombineFn.from_runner_api(combine_payload.combine_fn, context)
+ return CombinePerKey(combine_fn)
+
# TODO(robertwb): Rename to CombineGroupedValues?
class CombineValues(PTransformWithSideInputs):
@@ -995,6 +1056,17 @@
CombineValuesDoFn(key_type, self.fn, runtime_type_check),
*args, **kwargs)
+ def to_runner_api_parameter(self, context):
+ """Encodes this transform as COMBINE_GROUPED_VALUES_TRANSFORM."""
+ payload = _combine_payload(self.fn, context)
+ return urns.COMBINE_GROUPED_VALUES_TRANSFORM, payload
+
+ @PTransform.register_urn(
+ urns.COMBINE_GROUPED_VALUES_TRANSFORM, beam_runner_api_pb2.CombinePayload)
+ def from_runner_api_parameter(combine_payload, context):
+ """Rebuilds a CombineValues from its CombinePayload."""
+ combine_fn = CombineFn.from_runner_api(combine_payload.combine_fn, context)
+ return CombineValues(combine_fn)
+
class CombineValuesDoFn(DoFn):
"""DoFn for performing per-key Combine transforms."""
@@ -1112,6 +1184,13 @@
| 'GroupByKey' >> _GroupByKeyOnly()
| 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing))
+ def to_runner_api_parameter(self, unused_context):
+ """GroupByKey has no payload; only its URN is recorded."""
+ urn = urns.GROUP_BY_KEY_TRANSFORM
+ return urn, None
+
+ @PTransform.register_urn(urns.GROUP_BY_KEY_TRANSFORM, None)
+ def from_runner_api_parameter(unused_payload, unused_context):
+ """Rebuilds a GroupByKey; there is no payload to decode."""
+ transform = GroupByKey()
+ return transform
+
@typehints.with_input_types(typehints.KV[K, V])
@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
@@ -1125,6 +1204,13 @@
self._check_pcollection(pcoll)
return pvalue.PCollection(pcoll.pipeline)
+ def to_runner_api_parameter(self, unused_context):
+ """_GroupByKeyOnly has no payload; only its URN is recorded."""
+ urn = urns.GROUP_BY_KEY_ONLY_TRANSFORM
+ return urn, None
+
+ @PTransform.register_urn(urns.GROUP_BY_KEY_ONLY_TRANSFORM, None)
+ def from_runner_api_parameter(unused_payload, unused_context):
+ """Rebuilds a _GroupByKeyOnly; there is no payload to decode."""
+ transform = _GroupByKeyOnly()
+ return transform
+
@typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]])
@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]])
@@ -1139,6 +1225,18 @@
self._check_pcollection(pcoll)
return pvalue.PCollection(pcoll.pipeline)
+ def to_runner_api_parameter(self, context):
+ """Encodes the windowing strategy by reference, as its id in context."""
+ strategy_id = context.windowing_strategies.get_id(self.windowing)
+ return (urns.GROUP_ALSO_BY_WINDOW_TRANSFORM,
+ wrappers_pb2.BytesValue(value=strategy_id))
+
+ @PTransform.register_urn(
+ urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue)
+ def from_runner_api_parameter(payload, context):
+ """Rebuilds the transform, resolving payload.value as a strategy id."""
+ windowing = context.windowing_strategies.get_by_id(payload.value)
+ return _GroupAlsoByWindow(windowing)
+
class _GroupAlsoByWindowDoFn(DoFn):
# TODO(robertwb): Support combiner lifting.
@@ -1363,6 +1461,7 @@
# (Right now only WindowFn is used, but we need this to reconstitute the
# WindowInto transform, and in the future will need it at runtime to
# support meta-data driven triggers.)
+ # TODO(robertwb): Use a reference rather than embedding?
beam_runner_api_pb2.WindowingStrategy,
WindowInto.from_runner_api_parameter)
@@ -1402,7 +1501,10 @@
def expand(self, pcolls):
for pcoll in pcolls:
self._check_pcollection(pcoll)
- return pvalue.PCollection(self.pipeline)
+ result = pvalue.PCollection(self.pipeline)
+ # The flattened output's element type is the Union of the inputs'
+ # element types. NOTE(review): pcolls is iterated twice, so it must not
+ # be a one-shot iterator.
+ result.element_type = typehints.Union[
+ tuple(pcoll.element_type for pcoll in pcolls)]
+ return result
def get_windowing(self, inputs):
if not inputs:
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index cd84122..da113e0 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -426,8 +426,16 @@
_known_urns = {}
@classmethod
- def register_urn(cls, urn, parameter_type, constructor):
- cls._known_urns[urn] = parameter_type, constructor
+ def register_urn(cls, urn, parameter_type, constructor=None):
+ """Registers a constructor that rebuilds transforms with the given urn.
+
+ Can be called directly with a constructor, or used as a decorator when
+ constructor is omitted. In decorator form the decorated function is
+ returned wrapped in staticmethod, so it can sit in a class body.
+ """
+ def register(fn):
+ cls._known_urns[urn] = parameter_type, fn
+ return staticmethod(fn)
+ if not constructor:
+ # Decorator form: registration happens when the function is decorated.
+ return register
+ # Statement form: register immediately; the wrapper is not needed.
+ register(constructor)
def to_runner_api(self, context):
from apache_beam.portability.api import beam_runner_api_pb2
diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py
index e553eea..0013cb3 100644
--- a/sdks/python/apache_beam/utils/urns.py
+++ b/sdks/python/apache_beam/utils/urns.py
@@ -32,12 +32,24 @@
SLIDING_WINDOWS_FN = "beam:windowfn:sliding_windows:v0.1"
SESSION_WINDOWS_FN = "beam:windowfn:session_windows:v0.1"
+# Pickled Python user functions. PICKLED_DO_FN_INFO carries the DoFn together
+# with its args, kwargs, side-input info and windowing (see
+# ParDo._pardo_fn_data in apache_beam.transforms.core).
+PICKLED_DO_FN = "beam:dofn:pickled_python:v0.1"
+PICKLED_DO_FN_INFO = "beam:dofn:pickled_python_info:v0.1"
+PICKLED_COMBINE_FN = "beam:combinefn:pickled_python:v0.1"
PICKLED_CODER = "beam:coder:pickled_python:v0.1"
PICKLED_TRANSFORM = "beam:ptransform:pickled_python:v0.1"
+# PTransform URNs with runner API translations in the Python SDK.
+PARDO_TRANSFORM = "beam:ptransform:pardo:v0.1"
+GROUP_BY_KEY_TRANSFORM = "beam:ptransform:group_by_key:v0.1"
+GROUP_BY_KEY_ONLY_TRANSFORM = "beam:ptransform:group_by_key_only:v0.1"
+GROUP_ALSO_BY_WINDOW_TRANSFORM = "beam:ptransform:group_also_by_window:v0.1"
+COMBINE_PER_KEY_TRANSFORM = "beam:ptransform:combine_per_key:v0.1"
+COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1"
FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1"
+READ_TRANSFORM = "beam:ptransform:read:v0.1"
WINDOW_INTO_TRANSFORM = "beam:ptransform:window_into:v0.1"
+# Pickled Python sources.
+PICKLED_SOURCE = "beam:source:pickled_python:v0.1"
+
class RunnerApiFn(object):
"""Abstract base class that provides urn registration utilities.
@@ -50,7 +62,8 @@
to register serialization via pickling.
"""
- __metaclass__ = abc.ABCMeta
+ # TODO(BEAM-2685): Issue with dill + local classes + abc metaclass
+ # __metaclass__ = abc.ABCMeta
_known_urns = {}