[NEMO-460] Setting coders in CombinePerKey transformation (#303)
JIRA: [NEMO-460: Setting coders in CombinePerKey transformation](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-460)
**Major changes:**
- Added an additional parameter, "inputCoder", to the GBKTransform constructor.
- Fixed the input coder and the output coder for the partial combine transform and the final combine transform.
**Minor changes to note:**
- Fixed the main output TupleTags for the partial combine transform and the final combine transform.
**Tests for the changes:**
- Current tests suffice.
**Other comments:**
- This needs to be merged after #302 is merged.
Closes #303
diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
index 21ded1c..cd9d7ad 100644
--- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
+++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -406,10 +406,11 @@
KvCoder.of(inputCoder.getKeyCoder(),
accumulatorCoder),
null, mainInput.getWindowingStrategy()));
+ final TupleTag<?> partialMainOutputTag = new TupleTag<>();
final GBKTransform partialCombineStreamTransform =
- new GBKTransform(
- getOutputCoders(pTransform),
- new TupleTag<>(),
+ new GBKTransform(inputCoder,
+ Collections.singletonMap(partialMainOutputTag, KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)),
+ partialMainOutputTag,
mainInput.getWindowingStrategy(),
ctx.getPipelineOptions(),
partialSystemReduceFn,
@@ -418,9 +419,9 @@
true);
final GBKTransform finalCombineStreamTransform =
- new GBKTransform(
+ new GBKTransform(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder),
getOutputCoders(pTransform),
- new TupleTag<>(),
+ Iterables.getOnlyElement(beamNode.getOutputs().keySet()),
mainInput.getWindowingStrategy(),
ctx.getPipelineOptions(),
finalSystemReduceFn,
@@ -556,7 +557,7 @@
final AppliedPTransform<?, ?, ?> pTransform = beamNode.toAppliedPTransform(ctx.getPipeline());
final PCollection<?> mainInput = (PCollection<?>)
Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(pTransform));
- final TupleTag mainOutputTag = new TupleTag<>();
+ final TupleTag mainOutputTag = Iterables.getOnlyElement(beamNode.getOutputs().keySet());
if (isGlobalWindow(beamNode, ctx.getPipeline())) {
// GroupByKey Transform when using a global windowing strategy.
@@ -564,6 +565,7 @@
} else {
// GroupByKey Transform when using a non-global windowing strategy.
return new GBKTransform<>(
+ (KvCoder) mainInput.getCoder(),
getOutputCoders(pTransform),
mainOutputTag,
mainInput.getWindowingStrategy(),
diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java
index 9dd2e5a..1bf6cb8 100644
--- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java
+++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java
@@ -58,7 +58,8 @@
private transient OutputCollector originOc;
private final boolean isPartialCombining;
- public GBKTransform(final Map<TupleTag<?>, Coder<?>> outputCoders,
+ public GBKTransform(final Coder<KV<K, InputT>> inputCoder,
+ final Map<TupleTag<?>, Coder<?>> outputCoders,
final TupleTag<KV<K, OutputT>> mainOutputTag,
final WindowingStrategy<?, ?> windowingStrategy,
final PipelineOptions options,
@@ -67,7 +68,7 @@
final DisplayData displayData,
final boolean isPartialCombining) {
super(null,
- null,
+ inputCoder,
outputCoders,
mainOutputTag,
Collections.emptyList(), /* no additional outputs */
@@ -278,7 +279,7 @@
/** Emit output. If {@param output} is emitted on-time, save its timestamp in the output watermark map. */
@Override
- public void emit(final WindowedValue<KV<K, OutputT>> output) {
+ public final void emit(final WindowedValue<KV<K, OutputT>> output) {
// The watermark advances only in ON_TIME
if (output.getPane().getTiming().equals(PaneInfo.Timing.ON_TIME)) {
KV<K, OutputT> value = output.getValue();
@@ -296,13 +297,13 @@
/** Emit watermark. */
@Override
- public void emitWatermark(final Watermark watermark) {
+ public final void emitWatermark(final Watermark watermark) {
oc.emitWatermark(watermark);
}
/** Emit output value to {@param dstVertexId}. */
@Override
- public <T> void emit(final String dstVertexId, final T output) {
+ public final <T> void emit(final String dstVertexId, final T output) {
oc.emit(dstVertexId, output);
}
}
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java
index 45933b0..3c08c50 100644
--- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java
+++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java
@@ -18,6 +18,7 @@
*/
package org.apache.nemo.compiler.frontend.beam.transform;
+import com.google.common.collect.Iterables;
import junit.framework.TestCase;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.sdk.coders.*;
@@ -41,15 +42,12 @@
import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.*;
import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES;
-import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.DISCARDING_FIRED_PANES;
-import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
public class GBKTransformTest extends TestCase {
private static final Logger LOG = LoggerFactory.getLogger(GBKTransformTest.class.getName());
private final static Coder STRING_CODER = StringUtf8Coder.of();
private final static Coder INTEGER_CODER = BigEndianIntegerCoder.of();
- private final static Map<TupleTag<?>, Coder<?>> NULL_OUTPUT_CODERS = null;
private void checkOutput(final KV<String, Integer> expected, final KV<String, Integer> result) {
// check key
@@ -155,7 +153,8 @@
final GBKTransform<String, Integer, Integer> combine_transform =
new GBKTransform(
- NULL_OUTPUT_CODERS,
+ KvCoder.of(STRING_CODER, INTEGER_CODER),
+ Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)),
outputTag,
WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES),
PipelineOptionsFactory.as(NemoPipelineOptions.class),
@@ -283,7 +282,8 @@
final GBKTransform<String, Integer, Integer> combine_transform =
new GBKTransform(
- NULL_OUTPUT_CODERS,
+ KvCoder.of(STRING_CODER, INTEGER_CODER),
+ Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)),
outputTag,
WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness),
PipelineOptionsFactory.as(NemoPipelineOptions.class),
@@ -377,7 +377,8 @@
final GBKTransform<String, String, Iterable<String>> doFnTransform =
new GBKTransform(
- NULL_OUTPUT_CODERS,
+ KvCoder.of(STRING_CODER, STRING_CODER),
+ Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))),
outputTag,
WindowingStrategy.of(slidingWindows),
PipelineOptionsFactory.as(NemoPipelineOptions.class),
@@ -562,7 +563,8 @@
final GBKTransform<String, String, Iterable<String>> doFnTransform =
new GBKTransform(
- NULL_OUTPUT_CODERS,
+ KvCoder.of(STRING_CODER, STRING_CODER),
+ Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))),
outputTag,
WindowingStrategy.of(window).withTrigger(trigger)
.withMode(ACCUMULATING_FIRED_PANES)