blob: 2de2fe83ee66000b98c3d428dff841a14b05dadf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects.firstNonNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.toImmutableList;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardCoders;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.ByteCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.Context;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CharStreams;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/** Tests that Java SDK coders standardized by the Fn API meet the common spec. */
@RunWith(Parameterized.class)
public class CommonCoderTest {
private static final String STANDARD_CODERS_YAML_PATH =
"/org/apache/beam/model/fnexecution/v1/standard_coders.yaml";
private static final Map<String, Class<?>> coders =
ImmutableMap.<String, Class<?>>builder()
.put(getUrn(StandardCoders.Enum.BYTES), ByteCoder.class)
.put(getUrn(StandardCoders.Enum.BOOL), BooleanCoder.class)
.put(getUrn(StandardCoders.Enum.STRING_UTF8), StringUtf8Coder.class)
.put(getUrn(StandardCoders.Enum.KV), KvCoder.class)
.put(getUrn(StandardCoders.Enum.VARINT), VarLongCoder.class)
.put(getUrn(StandardCoders.Enum.INTERVAL_WINDOW), IntervalWindowCoder.class)
.put(getUrn(StandardCoders.Enum.ITERABLE), IterableCoder.class)
.put(getUrn(StandardCoders.Enum.TIMER), Timer.Coder.class)
.put(getUrn(StandardCoders.Enum.GLOBAL_WINDOW), GlobalWindow.Coder.class)
.put(getUrn(StandardCoders.Enum.DOUBLE), DoubleCoder.class)
.put(
getUrn(StandardCoders.Enum.WINDOWED_VALUE),
WindowedValue.FullWindowedValueCoder.class)
.put(getUrn(StandardCoders.Enum.ROW), RowCoder.class)
.build();
@AutoValue
abstract static class CommonCoder {
abstract String getUrn();
abstract List<CommonCoder> getComponents();
@SuppressWarnings("mutable")
abstract byte[] getPayload();
abstract Boolean getNonDeterministic();
@JsonCreator
static CommonCoder create(
@JsonProperty("urn") String urn,
@JsonProperty("components") @Nullable List<CommonCoder> components,
@JsonProperty("payload") @Nullable String payload,
@JsonProperty("non_deterministic") @Nullable Boolean nonDeterministic) {
return new AutoValue_CommonCoderTest_CommonCoder(
checkNotNull(urn, "urn"),
firstNonNull(components, Collections.emptyList()),
firstNonNull(payload, "").getBytes(StandardCharsets.ISO_8859_1),
firstNonNull(nonDeterministic, Boolean.FALSE));
}
}
@AutoValue
abstract static class CommonCoderTestSpec {
abstract CommonCoder getCoder();
abstract @Nullable Boolean getNested();
abstract Map<String, Object> getExamples();
@JsonCreator
static CommonCoderTestSpec create(
@JsonProperty("coder") CommonCoder coder,
@JsonProperty("nested") @Nullable Boolean nested,
@JsonProperty("examples") Map<String, Object> examples) {
return new AutoValue_CommonCoderTest_CommonCoderTestSpec(coder, nested, examples);
}
}
@AutoValue
abstract static class OneCoderTestSpec {
abstract CommonCoder getCoder();
abstract boolean getNested();
@SuppressWarnings("mutable")
abstract byte[] getSerialized();
abstract Object getValue();
static OneCoderTestSpec create(
CommonCoder coder, boolean nested, byte[] serialized, Object value) {
return new AutoValue_CommonCoderTest_OneCoderTestSpec(coder, nested, serialized, value);
}
}
private static List<OneCoderTestSpec> loadStandardCodersSuite() throws IOException {
InputStream stream = CommonCoderTest.class.getResourceAsStream(STANDARD_CODERS_YAML_PATH);
if (stream == null) {
fail("Could not load standard coder specs as resource:" + STANDARD_CODERS_YAML_PATH);
}
// Would like to use the InputStream directly with Jackson, but Jackson does not seem to
// support streams of multiple entities. Instead, read the entire YAML as a String and split
// it up manually, passing each to Jackson.
String specString = CharStreams.toString(new InputStreamReader(stream, StandardCharsets.UTF_8));
Iterable<String> specs = Splitter.on("\n---\n").split(specString);
List<OneCoderTestSpec> ret = new ArrayList<>();
for (String spec : specs) {
CommonCoderTestSpec coderTestSpec = parseSpec(spec);
CommonCoder coder = coderTestSpec.getCoder();
for (Map.Entry<String, Object> oneTestSpec : coderTestSpec.getExamples().entrySet()) {
byte[] serialized = oneTestSpec.getKey().getBytes(StandardCharsets.ISO_8859_1);
Object value = oneTestSpec.getValue();
if (coderTestSpec.getNested() == null) {
// Missing nested means both
ret.add(OneCoderTestSpec.create(coder, true, serialized, value));
ret.add(OneCoderTestSpec.create(coder, false, serialized, value));
} else {
ret.add(OneCoderTestSpec.create(coder, coderTestSpec.getNested(), serialized, value));
}
}
}
return ImmutableList.copyOf(ret);
}
@Parameters(name = "{1}")
public static Iterable<Object[]> data() throws IOException {
ImmutableList.Builder<Object[]> ret = ImmutableList.builder();
for (OneCoderTestSpec test : loadStandardCodersSuite()) {
// Some tools cannot handle Unicode in test names, so omit the problematic value field.
String testname =
MoreObjects.toStringHelper(OneCoderTestSpec.class)
.add("coder", test.getCoder())
.add("nested", test.getNested())
.add("serialized", test.getSerialized())
.toString();
ret.add(new Object[] {test, testname});
}
return ret.build();
}
@Parameter(0)
public OneCoderTestSpec testSpec;
@Parameter(1)
public String ignoredTestName;
private static CommonCoderTestSpec parseSpec(String spec) throws IOException {
ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
return mapper.readValue(spec, CommonCoderTestSpec.class);
}
private static void assertCoderIsKnown(CommonCoder coder) {
assertThat("not a known coder", coders.keySet(), hasItem(coder.getUrn()));
for (CommonCoder component : coder.getComponents()) {
assertCoderIsKnown(component);
}
}
/** Converts from JSON-auto-deserialized types into the proper Java types for the known coders. */
private static Object convertValue(Object value, CommonCoder coderSpec, Coder coder) {
String s = coderSpec.getUrn();
if (s.equals(getUrn(StandardCoders.Enum.BYTES))) {
return ((String) value).getBytes(StandardCharsets.ISO_8859_1);
} else if (s.equals(getUrn(StandardCoders.Enum.BOOL))) {
return value;
} else if (s.equals(getUrn(StandardCoders.Enum.STRING_UTF8))) {
return value;
} else if (s.equals(getUrn(StandardCoders.Enum.KV))) {
Coder keyCoder = ((KvCoder) coder).getKeyCoder();
Coder valueCoder = ((KvCoder) coder).getValueCoder();
Map<String, Object> kvMap = (Map<String, Object>) value;
Object k = convertValue(kvMap.get("key"), coderSpec.getComponents().get(0), keyCoder);
Object v = convertValue(kvMap.get("value"), coderSpec.getComponents().get(1), valueCoder);
return KV.of(k, v);
} else if (s.equals(getUrn(StandardCoders.Enum.VARINT))) {
return ((Number) value).longValue();
} else if (s.equals(getUrn(StandardCoders.Enum.TIMER))) {
Map<String, Object> kvMap = (Map<String, Object>) value;
Coder<?> payloadCoder = (Coder) coder.getCoderArguments().get(0);
return Timer.of(
new Instant(((Number) kvMap.get("timestamp")).longValue()),
convertValue(kvMap.get("payload"), coderSpec.getComponents().get(0), payloadCoder));
} else if (s.equals(getUrn(StandardCoders.Enum.INTERVAL_WINDOW))) {
Map<String, Object> kvMap = (Map<String, Object>) value;
Instant end = new Instant(((Number) kvMap.get("end")).longValue());
Duration span = Duration.millis(((Number) kvMap.get("span")).longValue());
return new IntervalWindow(end.minus(span), span);
} else if (s.equals(getUrn(StandardCoders.Enum.ITERABLE))) {
Coder elementCoder = ((IterableCoder) coder).getElemCoder();
List<Object> elements = (List<Object>) value;
List<Object> convertedElements = new ArrayList<>();
for (Object element : elements) {
convertedElements.add(
convertValue(element, coderSpec.getComponents().get(0), elementCoder));
}
return convertedElements;
} else if (s.equals(getUrn(StandardCoders.Enum.GLOBAL_WINDOW))) {
return GlobalWindow.INSTANCE;
} else if (s.equals(getUrn(StandardCoders.Enum.WINDOWED_VALUE))) {
Map<String, Object> kvMap = (Map<String, Object>) value;
Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) coder).getValueCoder();
Coder windowCoder = ((WindowedValue.FullWindowedValueCoder) coder).getWindowCoder();
Object windowValue =
convertValue(kvMap.get("value"), coderSpec.getComponents().get(0), valueCoder);
Instant timestamp = new Instant(((Number) kvMap.get("timestamp")).longValue());
List<BoundedWindow> windows = new ArrayList<>();
for (Object window : (List<Object>) kvMap.get("windows")) {
windows.add(
(BoundedWindow) convertValue(window, coderSpec.getComponents().get(1), windowCoder));
}
Map<String, Object> paneInfoMap = (Map<String, Object>) kvMap.get("pane");
PaneInfo paneInfo =
PaneInfo.createPane(
(boolean) paneInfoMap.get("is_first"),
(boolean) paneInfoMap.get("is_last"),
PaneInfo.Timing.valueOf((String) paneInfoMap.get("timing")),
(int) paneInfoMap.get("index"),
(int) paneInfoMap.get("on_time_index"));
return WindowedValue.of(windowValue, timestamp, windows, paneInfo);
} else if (s.equals(getUrn(StandardCoders.Enum.DOUBLE))) {
return Double.parseDouble((String) value);
} else if (s.equals(getUrn(StandardCoders.Enum.ROW))) {
Schema schema;
try {
schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload()));
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("Failed to parse schema payload for row coder", e);
}
return parseField(value, Schema.FieldType.row(schema));
} else {
throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn());
}
}
private static Object parseField(Object value, Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case BYTE:
return ((Number) value).byteValue();
case INT16:
return ((Number) value).shortValue();
case INT32:
return ((Number) value).intValue();
case INT64:
return ((Number) value).longValue();
case FLOAT:
return Float.parseFloat((String) value);
case DOUBLE:
return Double.parseDouble((String) value);
case STRING:
return (String) value;
case BOOLEAN:
return (Boolean) value;
case BYTES:
// extract String as byte[]
return ((String) value).getBytes(StandardCharsets.ISO_8859_1);
case ARRAY:
return ((List<Object>) value)
.stream()
.map((element) -> parseField(element, fieldType.getCollectionElementType()))
.collect(toImmutableList());
case MAP:
Map<Object, Object> kvMap = (Map<Object, Object>) value;
return kvMap.entrySet().stream()
.collect(
toImmutableMap(
(pair) -> parseField(pair.getKey(), fieldType.getMapKeyType()),
(pair) -> parseField(pair.getValue(), fieldType.getMapValueType())));
case ROW:
Map<String, Object> rowMap = (Map<String, Object>) value;
Schema schema = fieldType.getRowSchema();
Row.Builder row = Row.withSchema(schema);
for (Schema.Field field : schema.getFields()) {
Object element = rowMap.remove(field.getName());
if (element != null) {
element = parseField(element, field.getType());
}
row.addValue(element);
}
if (!rowMap.isEmpty()) {
throw new IllegalArgumentException(
"Value contains keys that are not in the schema: " + rowMap.keySet());
}
return row.build();
default: // DECIMAL, DATETIME, LOGICAL_TYPE
throw new IllegalArgumentException("Unsupported type name: " + fieldType.getTypeName());
}
}
private static Coder<?> instantiateCoder(CommonCoder coder) {
List<Coder<?>> components = new ArrayList<>();
for (CommonCoder innerCoder : coder.getComponents()) {
components.add(instantiateCoder(innerCoder));
}
Class<? extends Coder> coderType =
ModelCoderRegistrar.BEAM_MODEL_CODER_URNS.inverse().get(coder.getUrn());
checkNotNull(coderType, "Unknown coder URN: " + coder.getUrn());
CoderTranslator<?> translator = ModelCoderRegistrar.BEAM_MODEL_CODERS.get(coderType);
checkNotNull(
translator, "No translator found for common coder class: " + coderType.getSimpleName());
return translator.fromComponents(components, coder.getPayload());
}
@Test
public void executeSingleTest() throws IOException {
assertCoderIsKnown(testSpec.getCoder());
Coder coder = instantiateCoder(testSpec.getCoder());
Object testValue = convertValue(testSpec.getValue(), testSpec.getCoder(), coder);
Context context = testSpec.getNested() ? Context.NESTED : Context.OUTER;
byte[] encoded = CoderUtils.encodeToByteArray(coder, testValue, context);
Object decodedValue = CoderUtils.decodeFromByteArray(coder, testSpec.getSerialized(), context);
if (!testSpec.getCoder().getNonDeterministic()) {
assertThat(testSpec.toString(), encoded, equalTo(testSpec.getSerialized()));
}
verifyDecodedValue(testSpec.getCoder(), decodedValue, testValue);
}
private void verifyDecodedValue(CommonCoder coder, Object expectedValue, Object actualValue) {
String s = coder.getUrn();
if (s.equals(getUrn(StandardCoders.Enum.BYTES))) {
assertThat(expectedValue, equalTo(actualValue));
} else if (s.equals(getUrn(StandardCoders.Enum.BOOL))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.STRING_UTF8))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.KV))) {
assertThat(actualValue, instanceOf(KV.class));
verifyDecodedValue(
coder.getComponents().get(0), ((KV) expectedValue).getKey(), ((KV) actualValue).getKey());
verifyDecodedValue(
coder.getComponents().get(0),
((KV) expectedValue).getValue(),
((KV) actualValue).getValue());
} else if (s.equals(getUrn(StandardCoders.Enum.VARINT))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.INTERVAL_WINDOW))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.ITERABLE))) {
assertThat(actualValue, instanceOf(Iterable.class));
CommonCoder componentCoder = coder.getComponents().get(0);
Iterator<Object> expectedValueIterator = ((Iterable<Object>) expectedValue).iterator();
for (Object value : (Iterable<Object>) actualValue) {
verifyDecodedValue(componentCoder, expectedValueIterator.next(), value);
}
assertFalse(expectedValueIterator.hasNext());
} else if (s.equals(getUrn(StandardCoders.Enum.TIMER))) {
assertEquals(((Timer) expectedValue).getTimestamp(), ((Timer) actualValue).getTimestamp());
assertThat(((Timer) expectedValue).getPayload(), equalTo(((Timer) actualValue).getPayload()));
} else if (s.equals(getUrn(StandardCoders.Enum.GLOBAL_WINDOW))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.WINDOWED_VALUE))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.DOUBLE))) {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.ROW))) {
assertEquals(expectedValue, actualValue);
} else {
throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
}
}
/**
* Utility for adding new entries to the common coder spec -- prints the serialized bytes of the
* given value in the given context using JSON-escaped strings.
*/
private static <T> String jsonByteString(Coder<T> coder, T value, Context context)
throws CoderException {
byte[] bytes = CoderUtils.encodeToByteArray(coder, value, context);
ObjectMapper mapper = new ObjectMapper();
mapper.configure(JsonGenerator.Feature.ESCAPE_NON_ASCII, true);
try {
return mapper.writeValueAsString(new String(bytes, StandardCharsets.ISO_8859_1));
} catch (JsonProcessingException e) {
throw new CoderException(String.format("Unable to encode %s with coder %s", value, coder), e);
}
}
}