| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.dataflow.util; |
| |
| import static org.hamcrest.MatcherAssert.assertThat; |
| import static org.hamcrest.Matchers.emptyIterable; |
| import static org.hamcrest.Matchers.hasItem; |
| import static org.hamcrest.Matchers.instanceOf; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertTrue; |
| |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.io.Serializable; |
| import java.util.Collections; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Objects; |
| import java.util.Set; |
| import org.apache.beam.runners.core.construction.SdkComponents; |
| import org.apache.beam.sdk.coders.AvroCoder; |
| import org.apache.beam.sdk.coders.ByteArrayCoder; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.CoderException; |
| import org.apache.beam.sdk.coders.CollectionCoder; |
| import org.apache.beam.sdk.coders.CustomCoder; |
| import org.apache.beam.sdk.coders.IterableCoder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.coders.LengthPrefixCoder; |
| import org.apache.beam.sdk.coders.ListCoder; |
| import org.apache.beam.sdk.coders.MapCoder; |
| import org.apache.beam.sdk.coders.NullableCoder; |
| import org.apache.beam.sdk.coders.SerializableCoder; |
| import org.apache.beam.sdk.coders.SetCoder; |
| import org.apache.beam.sdk.coders.StructuredCoder; |
| import org.apache.beam.sdk.coders.VarLongCoder; |
| import org.apache.beam.sdk.schemas.Schema; |
| import org.apache.beam.sdk.schemas.Schema.FieldType; |
| import org.apache.beam.sdk.schemas.SchemaCoder; |
| import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; |
| import org.apache.beam.sdk.transforms.SerializableFunction; |
| import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; |
| import org.apache.beam.sdk.transforms.join.CoGbkResultSchema; |
| import org.apache.beam.sdk.transforms.join.UnionCoder; |
| import org.apache.beam.sdk.transforms.windowing.GlobalWindow; |
| import org.apache.beam.sdk.transforms.windowing.IntervalWindow; |
| import org.apache.beam.sdk.util.InstanceBuilder; |
| import org.apache.beam.sdk.util.WindowedValue; |
| import org.apache.beam.sdk.values.Row; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.beam.sdk.values.TypeDescriptors; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.Builder; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; |
| import org.junit.Test; |
| import org.junit.experimental.runners.Enclosed; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.JUnit4; |
| import org.junit.runners.Parameterized; |
| import org.junit.runners.Parameterized.Parameter; |
| import org.junit.runners.Parameterized.Parameters; |
| |
| /** Tests for {@link CloudObjects}. */ |
| @RunWith(Enclosed.class) |
| public class CloudObjectsTest { |
| private static final Schema TEST_SCHEMA = |
| Schema.builder() |
| .addBooleanField("bool") |
| .addByteField("int8") |
| .addInt16Field("int16") |
| .addInt32Field("int32") |
| .addInt64Field("int64") |
| .addFloatField("float") |
| .addDoubleField("double") |
| .addStringField("string") |
| .addArrayField("list_int32", FieldType.INT32) |
| .addLogicalTypeField("fixed_bytes", FixedBytes.of(4)) |
| .build(); |
| |
| /** Tests that all of the Default Coders are tested. */ |
| @RunWith(JUnit4.class) |
| public static class DefaultsPresentTest { |
| @Test |
| public void defaultCodersAllTested() { |
| Set<Class<? extends Coder>> defaultCoderTranslators = |
| new DefaultCoderCloudObjectTranslatorRegistrar().classesToTranslators().keySet(); |
| Set<Class<? extends Coder>> testedClasses = new HashSet<>(); |
| for (Coder<?> tested : DefaultCoders.data()) { |
| if (tested instanceof ObjectCoder || tested instanceof ArbitraryCoder) { |
| testedClasses.add(CustomCoder.class); |
| assertThat(defaultCoderTranslators, hasItem(CustomCoder.class)); |
| } else { |
| testedClasses.add(tested.getClass()); |
| assertThat(defaultCoderTranslators, hasItem(tested.getClass())); |
| } |
| } |
| Set<Class<? extends Coder>> missing = new HashSet<>(); |
| missing.addAll(defaultCoderTranslators); |
| missing.removeAll(testedClasses); |
| assertThat("Coders with custom serializers should all be tested", missing, emptyIterable()); |
| } |
| |
| @Test |
| public void defaultCodersIncludesCustomCoder() { |
| Set<Class<? extends Coder>> defaultCoders = |
| new DefaultCoderCloudObjectTranslatorRegistrar().classesToTranslators().keySet(); |
| assertThat(defaultCoders, hasItem(CustomCoder.class)); |
| } |
| } |
| |
| /** |
| * Tests that all of the registered coders in {@link DefaultCoderCloudObjectTranslatorRegistrar} |
| * can be serialized and deserialized with {@link CloudObjects}. |
| */ |
| @RunWith(Parameterized.class) |
| public static class DefaultCoders { |
| @Parameters(name = "{index}: {0}") |
| public static Iterable<Coder<?>> data() { |
| Builder<Coder<?>> dataBuilder = |
| ImmutableList.<Coder<?>>builder() |
| .add(new ArbitraryCoder()) |
| .add(new ObjectCoder()) |
| .add(GlobalWindow.Coder.INSTANCE) |
| .add(IntervalWindow.getCoder()) |
| .add(LengthPrefixCoder.of(VarLongCoder.of())) |
| .add(IterableCoder.of(VarLongCoder.of())) |
| .add(KvCoder.of(VarLongCoder.of(), ByteArrayCoder.of())) |
| .add( |
| WindowedValue.getFullCoder( |
| KvCoder.of(VarLongCoder.of(), ByteArrayCoder.of()), |
| IntervalWindow.getCoder())) |
| .add(ByteArrayCoder.of()) |
| .add(VarLongCoder.of()) |
| .add(SerializableCoder.of(Record.class)) |
| .add(AvroCoder.of(Record.class)) |
| .add(CollectionCoder.of(VarLongCoder.of())) |
| .add(ListCoder.of(VarLongCoder.of())) |
| .add(SetCoder.of(VarLongCoder.of())) |
| .add(MapCoder.of(VarLongCoder.of(), ByteArrayCoder.of())) |
| .add(NullableCoder.of(IntervalWindow.getCoder())) |
| .add( |
| UnionCoder.of( |
| ImmutableList.of( |
| VarLongCoder.of(), |
| ByteArrayCoder.of(), |
| KvCoder.of(VarLongCoder.of(), ByteArrayCoder.of())))) |
| .add( |
| CoGbkResultCoder.of( |
| CoGbkResultSchema.of( |
| ImmutableList.of(new TupleTag<Long>(), new TupleTag<byte[]>())), |
| UnionCoder.of(ImmutableList.of(VarLongCoder.of(), ByteArrayCoder.of())))) |
| .add( |
| SchemaCoder.of( |
| Schema.builder().build(), |
| TypeDescriptors.rows(), |
| new RowIdentity(), |
| new RowIdentity())) |
| .add( |
| SchemaCoder.of( |
| TEST_SCHEMA, TypeDescriptors.rows(), new RowIdentity(), new RowIdentity())); |
| for (Class<? extends Coder> atomicCoder : |
| DefaultCoderCloudObjectTranslatorRegistrar.KNOWN_ATOMIC_CODERS) { |
| dataBuilder.add(InstanceBuilder.ofType(atomicCoder).fromFactoryMethod("of").build()); |
| } |
| return dataBuilder.build(); |
| } |
| |
| @Parameter(0) |
| public Coder<?> coder; |
| |
| @Test |
| public void toAndFromCloudObject() throws Exception { |
| CloudObject cloudObject = CloudObjects.asCloudObject(coder, /*sdkComponents=*/ null); |
| Coder<?> fromCloudObject = CloudObjects.coderFromCloudObject(cloudObject); |
| |
| assertEquals(coder.getClass(), fromCloudObject.getClass()); |
| assertEquals(coder, fromCloudObject); |
| } |
| |
| @Test |
| public void toAndFromCloudObjectWithSdkComponents() throws Exception { |
| SdkComponents sdkComponents = SdkComponents.create(); |
| CloudObject cloudObject = CloudObjects.asCloudObject(coder, sdkComponents); |
| Coder<?> fromCloudObject = CloudObjects.coderFromCloudObject(cloudObject); |
| |
| assertEquals(coder.getClass(), fromCloudObject.getClass()); |
| assertEquals(coder, fromCloudObject); |
| |
| checkPipelineProtoCoderIds(coder, cloudObject, sdkComponents); |
| } |
| |
| private static void checkPipelineProtoCoderIds( |
| Coder<?> coder, CloudObject cloudObject, SdkComponents sdkComponents) throws Exception { |
| if (CloudObjects.DATAFLOW_KNOWN_CODERS.contains(coder.getClass())) { |
| assertFalse(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID)); |
| } else { |
| assertTrue(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID)); |
| assertEquals( |
| sdkComponents.registerCoder(coder), |
| ((CloudObject) cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID)) |
| .get(PropertyNames.VALUE)); |
| } |
| List<? extends Coder<?>> expectedComponents; |
| if (coder instanceof StructuredCoder) { |
| expectedComponents = ((StructuredCoder) coder).getComponents(); |
| } else { |
| expectedComponents = coder.getCoderArguments(); |
| } |
| Object cloudComponentsObject = cloudObject.get(PropertyNames.COMPONENT_ENCODINGS); |
| List<CloudObject> cloudComponents; |
| if (cloudComponentsObject == null) { |
| cloudComponents = Lists.newArrayList(); |
| } else { |
| assertThat(cloudComponentsObject, instanceOf(List.class)); |
| cloudComponents = (List<CloudObject>) cloudComponentsObject; |
| } |
| assertEquals(expectedComponents.size(), cloudComponents.size()); |
| for (int i = 0; i < expectedComponents.size(); i++) { |
| checkPipelineProtoCoderIds( |
| expectedComponents.get(i), cloudComponents.get(i), sdkComponents); |
| } |
| } |
| } |
| |
| private static class Record implements Serializable {} |
| |
| private static class ObjectCoder extends CustomCoder<Object> { |
| @Override |
| public void encode(Object value, OutputStream outStream) throws CoderException, IOException {} |
| |
| @Override |
| public Object decode(InputStream inStream) throws CoderException, IOException { |
| return new Object(); |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| return other != null && getClass().equals(other.getClass()); |
| } |
| |
| @Override |
| public int hashCode() { |
| return getClass().hashCode(); |
| } |
| } |
| |
| /** A non-custom coder with no registered translator. */ |
| private static class ArbitraryCoder extends StructuredCoder<Record> { |
| @Override |
| public void encode(Record value, OutputStream outStream) throws CoderException, IOException {} |
| |
| @Override |
| public Record decode(InputStream inStream) throws CoderException, IOException { |
| return new Record(); |
| } |
| |
| @Override |
| public List<? extends Coder<?>> getCoderArguments() { |
| return Collections.emptyList(); |
| } |
| |
| @Override |
| public void verifyDeterministic() throws NonDeterministicException {} |
| } |
| |
| /** Hack to satisfy SchemaCoder.equals until BEAM-8146 is fixed. */ |
| private static class RowIdentity implements SerializableFunction<Row, Row> { |
| @Override |
| public Row apply(Row input) { |
| return input; |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(getClass()); |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| if (this == o) { |
| return true; |
| } |
| return o != null && getClass() == o.getClass(); |
| } |
| } |
| } |