| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.core.construction; |
| |
| import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; |
| import static org.junit.Assert.assertEquals; |
| |
| import java.io.IOException; |
| import java.util.concurrent.atomic.AtomicReference; |
| import org.apache.beam.model.pipeline.v1.RunnerApi; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload; |
| import org.apache.beam.sdk.Pipeline.PipelineVisitor; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.CoderRegistry; |
| import org.apache.beam.sdk.coders.VoidCoder; |
| import org.apache.beam.sdk.runners.AppliedPTransform; |
| import org.apache.beam.sdk.runners.TransformHierarchy.Node; |
| import org.apache.beam.sdk.testing.TestPipeline; |
| import org.apache.beam.sdk.transforms.Combine; |
| import org.apache.beam.sdk.transforms.Combine.BinaryCombineIntegerFn; |
| import org.apache.beam.sdk.transforms.Combine.CombineFn; |
| import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; |
| import org.apache.beam.sdk.transforms.CombineWithContext.Context; |
| import org.apache.beam.sdk.transforms.Count; |
| import org.apache.beam.sdk.transforms.Create; |
| import org.apache.beam.sdk.transforms.Sum; |
| import org.apache.beam.sdk.transforms.View; |
| import org.apache.beam.sdk.util.SerializableUtils; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionView; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| import org.junit.Rule; |
| import org.junit.Test; |
| import org.junit.rules.ExpectedException; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| import org.junit.runners.Parameterized; |
| import org.junit.runners.Parameterized.Parameter; |
| import org.junit.runners.Parameterized.Parameters; |
| |
| /** Tests for {@link CombineTranslation}. */ |
| public class CombineTranslationTest { |
| |
| /** Tests that simple {@link CombineFn CombineFns} can be translated to and from proto. */ |
| @RunWith(Parameterized.class) |
| public static class TranslateSimpleCombinesTest { |
| @Parameters(name = "{index}: {0}") |
| public static Iterable<Combine.CombineFn<Integer, ?, ?>> params() { |
| BinaryCombineIntegerFn sum = Sum.ofIntegers(); |
| CombineFn<Integer, ?, Long> count = Count.combineFn(); |
| TestCombineFn test = new TestCombineFn(); |
| return ImmutableList.<CombineFn<Integer, ?, ?>>builder() |
| .add(sum) |
| .add(count) |
| .add(test) |
| .build(); |
| } |
| |
| @Rule public TestPipeline pipeline = TestPipeline.create(); |
| |
| @Parameter(0) |
| public Combine.CombineFn<Integer, ?, ?> combineFn; |
| |
| @Test |
| public void testToProto() throws Exception { |
| PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3)); |
| input.apply(Combine.globally(combineFn)); |
| final AtomicReference<AppliedPTransform<?, ?, Combine.Globally<?, ?>>> combine = |
| new AtomicReference<>(); |
| pipeline.traverseTopologically( |
| new PipelineVisitor.Defaults() { |
| @Override |
| public void leaveCompositeTransform(Node node) { |
| if (node.getTransform() instanceof Combine.Globally) { |
| checkState(combine.get() == null); |
| combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline())); |
| } |
| } |
| }); |
| checkState(combine.get() != null); |
| assertEquals(combineFn, combine.get().getTransform().getFn()); |
| |
| SdkComponents sdkComponents = SdkComponents.create(); |
| sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); |
| CombinePayload combineProto = |
| CombineTranslation.CombineGloballyPayloadTranslator.payloadForCombineGlobally( |
| (AppliedPTransform) combine.get(), sdkComponents); |
| RunnerApi.Components componentsProto = sdkComponents.toComponents(); |
| |
| assertEquals( |
| combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()), |
| getAccumulatorCoder(combineProto, RehydratedComponents.forComponents(componentsProto))); |
| assertEquals( |
| combineFn, |
| SerializableUtils.deserializeFromByteArray( |
| combineProto.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn")); |
| } |
| } |
| |
| /** Tests that a {@link CombineFnWithContext} can be translated. */ |
| @RunWith(JUnit4.class) |
| public static class ValidateCombineWithContextTest { |
| @Rule public TestPipeline pipeline = TestPipeline.create(); |
| @Rule public ExpectedException exception = ExpectedException.none(); |
| |
| @Test |
| public void testToProtoWithoutSideInputs() throws Exception { |
| PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3)); |
| CombineFnWithContext<Integer, int[], Integer> combineFn = new TestCombineFnWithContext(); |
| input.apply(Combine.globally(combineFn).withoutDefaults()); |
| final AtomicReference<AppliedPTransform<?, ?, Combine.Globally<?, ?>>> combine = |
| new AtomicReference<>(); |
| pipeline.traverseTopologically( |
| new PipelineVisitor.Defaults() { |
| @Override |
| public void leaveCompositeTransform(Node node) { |
| if (node.getTransform() instanceof Combine.Globally) { |
| checkState(combine.get() == null); |
| combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline())); |
| } |
| } |
| }); |
| checkState(combine.get() != null); |
| assertEquals(combineFn, combine.get().getTransform().getFn()); |
| |
| SdkComponents sdkComponents = SdkComponents.create(); |
| sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); |
| CombinePayload combineProto = |
| CombineTranslation.CombineGloballyPayloadTranslator.payloadForCombineGlobally( |
| (AppliedPTransform) combine.get(), sdkComponents); |
| RunnerApi.Components componentsProto = sdkComponents.toComponents(); |
| |
| assertEquals( |
| combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()), |
| getAccumulatorCoder(combineProto, RehydratedComponents.forComponents(componentsProto))); |
| assertEquals( |
| combineFn, |
| SerializableUtils.deserializeFromByteArray( |
| combineProto.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn")); |
| } |
| |
| @Test |
| public void testToProtoWithSideInputsFails() throws Exception { |
| exception.expect(IllegalArgumentException.class); |
| |
| PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3)); |
| final PCollectionView<Iterable<String>> sideInputs = |
| pipeline.apply(Create.of("foo")).apply(View.asIterable()); |
| |
| CombineFnWithContext<Integer, int[], Integer> combineFn = |
| new TestCombineFnWithContext() { |
| @Override |
| public Integer extractOutput(int[] accumulator, Context c) { |
| Iterable<String> sideInput = c.sideInput(sideInputs); |
| return accumulator[0]; |
| } |
| }; |
| |
| input.apply(Combine.globally(combineFn).withSideInputs(sideInputs).withoutDefaults()); |
| final AtomicReference<AppliedPTransform<?, ?, Combine.Globally<?, ?>>> combine = |
| new AtomicReference<>(); |
| pipeline.traverseTopologically( |
| new PipelineVisitor.Defaults() { |
| @Override |
| public void leaveCompositeTransform(Node node) { |
| if (node.getTransform() instanceof Combine.Globally) { |
| checkState(combine.get() == null); |
| combine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline())); |
| } |
| } |
| }); |
| |
| SdkComponents sdkComponents = SdkComponents.create(); |
| sdkComponents.registerEnvironment(Environments.createDockerEnvironment("java")); |
| CombinePayload payload = |
| CombineTranslation.CombineGloballyPayloadTranslator.payloadForCombineGlobally( |
| (AppliedPTransform) combine.get(), sdkComponents); |
| } |
| } |
| |
| private static Coder<?> getAccumulatorCoder( |
| CombinePayload payload, RehydratedComponents components) throws IOException { |
| String id = payload.getAccumulatorCoderId(); |
| return components.getCoder(id); |
| } |
| |
| private static class TestCombineFn extends Combine.CombineFn<Integer, Void, Void> { |
| @Override |
| public Void createAccumulator() { |
| return null; |
| } |
| |
| @Override |
| public Coder<Void> getAccumulatorCoder(CoderRegistry registry, Coder<Integer> inputCoder) { |
| return (Coder) VoidCoder.of(); |
| } |
| |
| @Override |
| public Void extractOutput(Void accumulator) { |
| return accumulator; |
| } |
| |
| @Override |
| public Void mergeAccumulators(Iterable<Void> accumulators) { |
| return null; |
| } |
| |
| @Override |
| public Void addInput(Void accumulator, Integer input) { |
| return accumulator; |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| return other != null && other.getClass().equals(TestCombineFn.class); |
| } |
| |
| @Override |
| public int hashCode() { |
| return TestCombineFn.class.hashCode(); |
| } |
| } |
| |
| private static class TestCombineFnWithContext |
| extends CombineFnWithContext<Integer, int[], Integer> { |
| |
| @Override |
| public int[] createAccumulator(Context c) { |
| return new int[1]; |
| } |
| |
| @Override |
| public int[] addInput(int[] accumulator, Integer input, Context c) { |
| accumulator[0] += input; |
| return accumulator; |
| } |
| |
| @Override |
| public int[] mergeAccumulators(Iterable<int[]> accumulators, Context c) { |
| int[] res = new int[1]; |
| for (int[] accum : accumulators) { |
| res[0] += accum[0]; |
| } |
| return res; |
| } |
| |
| @Override |
| public Integer extractOutput(int[] accumulator, Context c) { |
| return accumulator[0]; |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| return other instanceof TestCombineFnWithContext; |
| } |
| |
| @Override |
| public int hashCode() { |
| return TestCombineFnWithContext.class.hashCode(); |
| } |
| } |
| } |