Merge pull request #10294: [BEAM-8895] Add BigQuery table name sanitization to BigQueryIOIT
diff --git a/model/job-management/src/main/proto/beam_artifact_api.proto b/model/job-management/src/main/proto/beam_artifact_api.proto
index 34eb389..2cfede9 100644
--- a/model/job-management/src/main/proto/beam_artifact_api.proto
+++ b/model/job-management/src/main/proto/beam_artifact_api.proto
@@ -29,6 +29,8 @@
option java_package = "org.apache.beam.model.jobmanagement.v1";
option java_outer_classname = "ArtifactApi";
+import "beam_runner_api.proto";
+
// A service to stage artifacts for use in a Job.
service ArtifactStagingService {
// Stage an artifact to be available during job execution. The first request must contain the
@@ -142,6 +144,10 @@
// The result of committing a manifest.
message CommitManifestResponse {
+ enum Constants {
+ // Token indicating that no artifacts were staged and therefore no retrieval attempt is necessary.
+ NO_ARTIFACTS_STAGED_TOKEN = 0 [(org.apache.beam.model.pipeline.v1.beam_constant) = "__no_artifacts_staged__"];
+ }
// (Required) An opaque token representing the entirety of the staged artifacts.
// This can be used to retrieve the manifest and artifacts from an associated
// ArtifactRetrievalService.
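The enum-constant pattern above keeps the canonical token string in the proto files so every SDK language reads the same value instead of duplicating a string literal. A sketch of how Java code recovers the constant from the generated descriptors, assuming beam_runner_api.proto declares the beam_constant extension on EnumValueOptions (the AbstractArtifactStagingService hunk below does exactly this):

    String token =
        ArtifactApi.CommitManifestResponse.Constants.NO_ARTIFACTS_STAGED_TOKEN
            .getValueDescriptor()                  // descriptor of the enum value
            .getOptions()                          // its EnumValueOptions
            .getExtension(RunnerApi.beamConstant); // -> "__no_artifacts_staged__"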
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
index d975b4f..8a41d7e 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
@@ -147,7 +147,7 @@
ProcessContinuation continuation =
invoker.invokeProcessElement(new NestedProcessContext<>(fn, c, element, w, tracker));
if (continuation.shouldResume()) {
- restriction = tracker.currentRestriction();
+ restriction = tracker.trySplit(0).getResidual();
Uninterruptibles.sleepUninterruptibly(
continuation.resumeDelay().getMillis(), TimeUnit.MILLISECONDS);
} else {
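This one-line change fixes a resume bug: currentRestriction() still spans work the tracker has already claimed, so resuming from it would reprocess those elements. trySplit(0) checkpoints the tracker, keeping claimed work in the primary and returning only the unclaimed remainder as the residual. A minimal sketch of the semantics, assuming the SDK's OffsetRangeTracker:

    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0, 100));
    tracker.tryClaim(42L);                                // offsets up to 42 are claimed
    SplitResult<OffsetRange> split = tracker.trySplit(0);
    split.getPrimary();  // [0, 43): already-claimed work, treated as done
    split.getResidual(); // [43, 100): the correct restriction to resume from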
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java
index 86e79a5..0b0fadf 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java
@@ -23,6 +23,7 @@
import java.nio.channels.WritableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactMetadata;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.CommitManifestRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.CommitManifestResponse;
@@ -32,6 +33,7 @@
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.PutArtifactRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.PutArtifactResponse;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.fnexecution.FnService;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.util.JsonFormat;
@@ -50,7 +52,11 @@
public abstract class AbstractArtifactStagingService extends ArtifactStagingServiceImplBase
implements FnService {
- public static final String NO_ARTIFACTS_STAGED_TOKEN = "__no_artifacts_staged__";
+ public static final String NO_ARTIFACTS_STAGED_TOKEN =
+ ArtifactApi.CommitManifestResponse.Constants.NO_ARTIFACTS_STAGED_TOKEN
+ .getValueDescriptor()
+ .getOptions()
+ .getExtension(RunnerApi.beamConstant);
private static final Logger LOG = LoggerFactory.getLogger(AbstractArtifactStagingService.class);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java
index dc27606..a6ecc45 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java
@@ -17,17 +17,11 @@
*/
package org.apache.beam.sdk.schemas;
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
-
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
-import java.lang.reflect.ParameterizedType;
-import java.lang.reflect.Type;
import java.util.Arrays;
-import java.util.Collection;
-import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.sdk.schemas.utils.ReflectUtils;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -129,9 +123,13 @@
}
public static FieldValueTypeInformation forSetter(Method method) {
+ return forSetter(method, "set");
+ }
+
+ public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) {
String name;
- if (method.getName().startsWith("set")) {
- name = ReflectUtils.stripPrefix(method.getName(), "set");
+ if (method.getName().startsWith(setterPrefix)) {
+ name = ReflectUtils.stripPrefix(method.getName(), setterPrefix);
} else {
throw new RuntimeException("Setter has wrong prefix " + method.getName());
}
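The new overload lets callers with non-standard setter prefixes reuse the same name-derivation logic. A hypothetical fluent-style bean illustrates the intent (WithNameBean and withName are illustrative names, not part of this change):

    // Hypothetical bean: public void withName(String name) { ... }
    Method setter = WithNameBean.class.getMethod("withName", String.class);
    FieldValueTypeInformation info = FieldValueTypeInformation.forSetter(setter, "with");
    info.getName(); // expected "name"; a non-matching prefix throws RuntimeException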
@@ -162,25 +160,9 @@
}
@Nullable
- private static FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) {
+ static FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) {
// TODO: Figure out nullable elements.
- TypeDescriptor componentType = null;
- if (valueType.isArray()) {
- Type component = valueType.getComponentType().getType();
- if (!component.equals(byte.class)) {
- componentType = TypeDescriptor.of(component);
- }
- } else if (valueType.isSubtypeOf(TypeDescriptor.of(Iterable.class))) {
- TypeDescriptor<Iterable<?>> collection = valueType.getSupertype(Iterable.class);
- if (collection.getType() instanceof ParameterizedType) {
- ParameterizedType ptype = (ParameterizedType) collection.getType();
- java.lang.reflect.Type[] params = ptype.getActualTypeArguments();
- checkArgument(params.length == 1);
- componentType = TypeDescriptor.of(params[0]);
- } else {
- throw new RuntimeException("Collection parameter is not parameterized!");
- }
- }
+ TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType);
if (componentType == null) {
return null;
}
@@ -223,17 +205,7 @@
@SuppressWarnings("unchecked")
@Nullable
private static FieldValueTypeInformation getMapType(TypeDescriptor valueType, int index) {
- TypeDescriptor mapType = null;
- if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) {
- TypeDescriptor<Collection<?>> map = valueType.getSupertype(Map.class);
- if (map.getType() instanceof ParameterizedType) {
- ParameterizedType ptype = (ParameterizedType) map.getType();
- java.lang.reflect.Type[] params = ptype.getActualTypeArguments();
- mapType = TypeDescriptor.of(params[index]);
- } else {
- throw new RuntimeException("Map type is not parameterized! " + map);
- }
- }
+ TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index);
if (mapType == null) {
return null;
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
index b1b8ee8..61c0d05 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
@@ -21,7 +21,7 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.lang.reflect.Type;
-import java.util.Iterator;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -31,6 +31,9 @@
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.RowWithGetters;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
@@ -107,7 +110,8 @@
return (ValueT) fromRow((Row) value, (Class) fieldType, typeFactory);
} else if (TypeName.ARRAY.equals(type.getTypeName())) {
return (ValueT)
- fromListValue(type.getCollectionElementType(), (List) value, elementType, typeFactory);
+ fromCollectionValue(
+ type.getCollectionElementType(), (Collection) value, elementType, typeFactory);
} else if (TypeName.ITERABLE.equals(type.getTypeName())) {
return (ValueT)
fromIterableValue(
@@ -127,25 +131,35 @@
}
}
+ private static <SourceT, DestT> Collection<DestT> transformCollection(
+ Collection<SourceT> collection, Function<SourceT, DestT> function) {
+ if (collection instanceof List) {
+      // For performance reasons, if the input is a List make sure we produce a List as well;
+      // otherwise Row unwrapping is forced to physically copy the collection into a new List.
+ return Lists.transform((List) collection, function);
+ } else {
+ return Collections2.transform(collection, function);
+ }
+ }
+
@SuppressWarnings("unchecked")
- private <ElementT> List fromListValue(
+ private <ElementT> Collection fromCollectionValue(
FieldType elementType,
- List<ElementT> rowList,
+ Collection<ElementT> rowCollection,
FieldValueTypeInformation elementTypeInformation,
Factory<List<FieldValueTypeInformation>> typeFactory) {
- List list = Lists.newArrayList();
- for (ElementT element : rowList) {
- list.add(
- fromValue(
- elementType,
- element,
- elementTypeInformation.getType().getType(),
- elementTypeInformation.getElementType(),
- elementTypeInformation.getMapKeyType(),
- elementTypeInformation.getMapValueType(),
- typeFactory));
- }
- return list;
+ return transformCollection(
+ rowCollection,
+ element ->
+ fromValue(
+ elementType,
+ element,
+ elementTypeInformation.getType().getType(),
+ elementTypeInformation.getElementType(),
+ elementTypeInformation.getMapKeyType(),
+ elementTypeInformation.getMapValueType(),
+ typeFactory));
}
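transformCollection is lazy in both branches: Lists.transform and Collections2.transform return live views that apply the function on access rather than copying eagerly. Keeping the List branch separate preserves list semantics, so downstream code that needs a List is not forced into a physical copy. A small illustration with plain Guava collections:

    List<Integer> source = Arrays.asList(1, 2, 3);
    List<String> view = Lists.transform(source, i -> "v" + i); // no copy performed
    view.get(1); // function applied on access -> "v2"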
@SuppressWarnings("unchecked")
@@ -154,32 +168,17 @@
Iterable<ElementT> rowIterable,
FieldValueTypeInformation elementTypeInformation,
Factory<List<FieldValueTypeInformation>> typeFactory) {
- return new Iterable<ElementT>() {
- @Override
- public Iterator<ElementT> iterator() {
- return new Iterator<ElementT>() {
- Iterator<ElementT> innerIter = rowIterable.iterator();
-
- @Override
- public boolean hasNext() {
- return innerIter.hasNext();
- }
-
- @Override
- public ElementT next() {
- ElementT element = innerIter.next();
- return fromValue(
+ return Iterables.transform(
+ rowIterable,
+ element ->
+ fromValue(
elementType,
element,
elementTypeInformation.getType().getType(),
elementTypeInformation.getElementType(),
elementTypeInformation.getMapKeyType(),
elementTypeInformation.getMapValueType(),
- typeFactory);
- }
- };
- }
- };
+ typeFactory));
}
@SuppressWarnings("unchecked")
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
index b12ad5e..c998037 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
@@ -75,6 +75,11 @@
public int hashCode() {
return Arrays.hashCode(array);
}
+
+ @Override
+ public String toString() {
+ return Arrays.toString(array);
+ }
}
// A mapping between field names an indices.
private final BiMap<String, Integer> fieldIndices = HashBiMap.create();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java
index 9604b95..791dafb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java
@@ -31,6 +31,8 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
import org.apache.beam.sdk.schemas.FieldValueGetter;
import org.apache.beam.sdk.schemas.FieldValueSetter;
import org.apache.beam.sdk.schemas.FieldValueTypeInformation;
@@ -44,6 +46,7 @@
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.description.type.TypeDescription;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.description.type.TypeDescription.ForLoadedType;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.DynamicType;
+import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.scaffold.InstrumentedType;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.Implementation;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.ByteCodeAppender;
@@ -64,8 +67,12 @@
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.member.MethodVariableAccess;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.matcher.ElementMatchers;
import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.utility.RandomString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Primitives;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ClassUtils;
import org.joda.time.DateTimeZone;
@@ -82,7 +89,7 @@
private static final ForLoadedType CHAR_SEQUENCE_TYPE = new ForLoadedType(CharSequence.class);
private static final ForLoadedType INSTANT_TYPE = new ForLoadedType(Instant.class);
private static final ForLoadedType DATE_TIME_ZONE_TYPE = new ForLoadedType(DateTimeZone.class);
- private static final ForLoadedType LIST_TYPE = new ForLoadedType(List.class);
+ private static final ForLoadedType COLLECTION_TYPE = new ForLoadedType(Collection.class);
private static final ForLoadedType READABLE_INSTANT_TYPE =
new ForLoadedType(ReadableInstant.class);
private static final ForLoadedType READABLE_PARTIAL_TYPE =
@@ -90,6 +97,8 @@
private static final ForLoadedType OBJECT_TYPE = new ForLoadedType(Object.class);
private static final ForLoadedType INTEGER_TYPE = new ForLoadedType(Integer.class);
private static final ForLoadedType ENUM_TYPE = new ForLoadedType(Enum.class);
+ private static final ForLoadedType BYTE_BUDDY_UTILS_TYPE =
+ new ForLoadedType(ByteBuddyUtils.class);
/**
* A naming strategy for ByteBuddy classes.
@@ -98,7 +107,7 @@
* This way, if the class fields or methods are package private, our generated class can still
* access them.
*/
- static class InjectPackageStrategy extends NamingStrategy.AbstractBase {
+ public static class InjectPackageStrategy extends NamingStrategy.AbstractBase {
/** A resolver for the base name for naming the unnamed type. */
private static final BaseNameResolver baseNameResolver =
BaseNameResolver.ForUnnamedType.INSTANCE;
@@ -123,6 +132,30 @@
}
};
+ // Create a new FieldValueGetter subclass.
+ @SuppressWarnings("unchecked")
+ static DynamicType.Builder<FieldValueGetter> subclassGetterInterface(
+ ByteBuddy byteBuddy, Type objectType, Type fieldType) {
+ TypeDescription.Generic getterGenericType =
+ TypeDescription.Generic.Builder.parameterizedType(
+ FieldValueGetter.class, objectType, fieldType)
+ .build();
+ return (DynamicType.Builder<FieldValueGetter>)
+ byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType);
+ }
+
+ // Create a new FieldValueSetter subclass.
+ @SuppressWarnings("unchecked")
+ static DynamicType.Builder<FieldValueSetter> subclassSetterInterface(
+ ByteBuddy byteBuddy, Type objectType, Type fieldType) {
+ TypeDescription.Generic setterGenericType =
+ TypeDescription.Generic.Builder.parameterizedType(
+ FieldValueSetter.class, objectType, fieldType)
+ .build();
+ return (DynamicType.Builder<FieldValueSetter>)
+ byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType);
+ }
+
public interface TypeConversionsFactory {
TypeConversion<Type> createTypeConversion(boolean returnRawTypes);
@@ -148,30 +181,6 @@
}
}
- // Create a new FieldValueGetter subclass.
- @SuppressWarnings("unchecked")
- static DynamicType.Builder<FieldValueGetter> subclassGetterInterface(
- ByteBuddy byteBuddy, Type objectType, Type fieldType) {
- TypeDescription.Generic getterGenericType =
- TypeDescription.Generic.Builder.parameterizedType(
- FieldValueGetter.class, objectType, fieldType)
- .build();
- return (DynamicType.Builder<FieldValueGetter>)
- byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType);
- }
-
- // Create a new FieldValueSetter subclass.
- @SuppressWarnings("unchecked")
- static DynamicType.Builder<FieldValueSetter> subclassSetterInterface(
- ByteBuddy byteBuddy, Type objectType, Type fieldType) {
- TypeDescription.Generic setterGenericType =
- TypeDescription.Generic.Builder.parameterizedType(
- FieldValueSetter.class, objectType, fieldType)
- .build();
- return (DynamicType.Builder<FieldValueSetter>)
- byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType);
- }
-
// Base class used below to convert types.
@SuppressWarnings("unchecked")
public abstract static class TypeConversion<T> {
@@ -195,7 +204,9 @@
} else if (typeDescriptor.getRawType().isEnum()) {
return convertEnum(typeDescriptor);
} else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class))) {
- if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Collection.class))) {
+ if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(List.class))) {
+ return convertList(typeDescriptor);
+ } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Collection.class))) {
return convertCollection(typeDescriptor);
} else {
return convertIterable(typeDescriptor);
@@ -211,6 +222,8 @@
protected abstract T convertCollection(TypeDescriptor<?> type);
+ protected abstract T convertList(TypeDescriptor<?> type);
+
protected abstract T convertMap(TypeDescriptor<?> type);
protected abstract T convertDateTime(TypeDescriptor<?> type);
@@ -253,18 +266,26 @@
@Override
protected Type convertArray(TypeDescriptor<?> type) {
- TypeDescriptor ret = createListType(type);
+ TypeDescriptor ret = createCollectionType(type.getComponentType());
return returnRawTypes ? ret.getRawType() : ret.getType();
}
@Override
protected Type convertCollection(TypeDescriptor<?> type) {
- return Collection.class;
+ TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type));
+ return returnRawTypes ? ret.getRawType() : ret.getType();
+ }
+
+ @Override
+ protected Type convertList(TypeDescriptor<?> type) {
+ TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type));
+ return returnRawTypes ? ret.getRawType() : ret.getType();
}
@Override
protected Type convertIterable(TypeDescriptor<?> type) {
- return Iterable.class;
+ TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type));
+ return returnRawTypes ? ret.getRawType() : ret.getType();
}
@Override
@@ -305,11 +326,190 @@
}
@SuppressWarnings("unchecked")
- private <ElementT> TypeDescriptor<List<ElementT>> createListType(TypeDescriptor<?> type) {
- TypeDescriptor componentType =
- TypeDescriptor.of(ClassUtils.primitiveToWrapper(type.getComponentType().getRawType()));
- return new TypeDescriptor<List<ElementT>>() {}.where(
- new TypeParameter<ElementT>() {}, componentType);
+ private <ElementT> TypeDescriptor<Collection<ElementT>> createCollectionType(
+ TypeDescriptor<?> componentType) {
+ TypeDescriptor wrappedComponentType =
+ TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType()));
+ return new TypeDescriptor<Collection<ElementT>>() {}.where(
+ new TypeParameter<ElementT>() {}, wrappedComponentType);
+ }
+
+ @SuppressWarnings("unchecked")
+ private <ElementT> TypeDescriptor<Iterable<ElementT>> createIterableType(
+ TypeDescriptor<?> componentType) {
+ TypeDescriptor wrappedComponentType =
+ TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType()));
+ return new TypeDescriptor<Iterable<ElementT>>() {}.where(
+ new TypeParameter<ElementT>() {}, wrappedComponentType);
+ }
+ }
+
+ private static final ByteBuddy BYTE_BUDDY = new ByteBuddy();
+
+  // When processing a container (e.g. List<T>) we need to recursively process the element type.
+  // This function generates a subclass of Function that can be used to transform each element
+  // of the container.
+ static Class createCollectionTransformFunction(
+ Type fromType, Type toType, Function<StackManipulation, StackManipulation> convertElement) {
+ // Generate a TypeDescription for the class we want to generate.
+ TypeDescription.Generic functionGenericType =
+ TypeDescription.Generic.Builder.parameterizedType(
+ Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType))
+ .build();
+
+ DynamicType.Builder<Function> builder =
+ (DynamicType.Builder<Function>)
+ BYTE_BUDDY
+ .subclass(functionGenericType)
+ .method(ElementMatchers.named("apply"))
+ .intercept(
+ new Implementation() {
+ @Override
+ public ByteCodeAppender appender(Target target) {
+ return (methodVisitor, implementationContext, instrumentedMethod) -> {
+ // this + method parameters.
+ int numLocals = 1 + instrumentedMethod.getParameters().size();
+
+ StackManipulation readValue = MethodVariableAccess.REFERENCE.loadFrom(1);
+ StackManipulation stackManipulation =
+ new StackManipulation.Compound(
+ convertElement.apply(readValue), MethodReturn.REFERENCE);
+
+ StackManipulation.Size size =
+ stackManipulation.apply(methodVisitor, implementationContext);
+ return new Size(size.getMaximalSize(), numLocals);
+ };
+ }
+
+ @Override
+ public InstrumentedType prepare(InstrumentedType instrumentedType) {
+ return instrumentedType;
+ }
+ });
+
+ return builder
+ .make()
+ .load(ByteBuddyUtils.class.getClassLoader(), ClassLoadingStrategy.Default.INJECTION)
+ .getLoaded();
+ }
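The generated apply method is tiny: it loads its single argument, runs the convertElement stack manipulation over it, and returns the result. For fromType = Integer and toType = String it is roughly equivalent to this hand-written class, where convert stands in for whatever conversion convertElement encodes:

    class GeneratedTransform implements Function<Integer, String> {
      @Override
      public String apply(Integer value) {
        return convert(value); // inlined convertElement logic
      }
    }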
+
+  // A function to transform a container, special-casing List and Collection types. This is used
+  // in ByteBuddy-generated code.
+ public static <FromT, DestT> Iterable<DestT> transformContainer(
+ Iterable<FromT> iterable, Function<FromT, DestT> function) {
+ if (iterable instanceof List) {
+ return Lists.transform((List<FromT>) iterable, function);
+ } else if (iterable instanceof Collection) {
+ return Collections2.transform((Collection<FromT>) iterable, function);
+ } else {
+ return Iterables.transform(iterable, function);
+ }
+ }
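transformContainer picks the cheapest lazy view that preserves the input's shape: a List stays a List, another Collection stays a Collection, and anything else degrades to a plain Iterable. A usage sketch:

    Iterable<String> out =
        ByteBuddyUtils.transformContainer(Arrays.asList(1, 2, 3), i -> "v" + i);
    ((List<String>) out).get(0); // lazy List view -> "v1"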
+
+ static StackManipulation createTransformingContainer(
+ ForLoadedType functionType, StackManipulation readValue) {
+ StackManipulation stackManipulation =
+ new Compound(
+ readValue,
+ TypeCreation.of(functionType),
+ Duplication.SINGLE,
+ MethodInvocation.invoke(
+ functionType
+ .getDeclaredMethods()
+ .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
+ .getOnly()),
+ MethodInvocation.invoke(
+ BYTE_BUDDY_UTILS_TYPE
+ .getDeclaredMethods()
+ .filter(ElementMatchers.named("transformContainer"))
+ .getOnly()));
+ return stackManipulation;
+ }
+
+ public static <K1, V1, K2, V2> TransformingMap<K1, V1, K2, V2> getTransformingMap(
+ Map<K1, V1> sourceMap, Function<K1, K2> keyFunction, Function<V1, V2> valueFunction) {
+ return new TransformingMap<>(sourceMap, keyFunction, valueFunction);
+ }
+
+ public static class TransformingMap<K1, V1, K2, V2> implements Map<K2, V2> {
+ private final Map<K2, V2> delegateMap;
+
+ public TransformingMap(
+ Map<K1, V1> sourceMap, Function<K1, K2> keyFunction, Function<V1, V2> valueFunction) {
+      if (sourceMap instanceof SortedMap) {
+        // We don't support copying the comparator; it makes no sense when the key type is
+        // changing.
+        delegateMap = (Map<K2, V2>) Maps.newTreeMap();
+ } else {
+ delegateMap = Maps.newHashMap();
+ }
+ for (Map.Entry<K1, V1> entry : sourceMap.entrySet()) {
+ delegateMap.put(keyFunction.apply(entry.getKey()), valueFunction.apply(entry.getValue()));
+ }
+ }
+
+ @Override
+ public int size() {
+ return delegateMap.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return delegateMap.isEmpty();
+ }
+
+ @Override
+ public boolean containsKey(Object key) {
+ return delegateMap.containsKey(key);
+ }
+
+ @Override
+ public boolean containsValue(Object value) {
+ return delegateMap.containsValue(value);
+ }
+
+ @Override
+ public V2 get(Object key) {
+ return delegateMap.get(key);
+ }
+
+ @Override
+ public V2 put(K2 key, V2 value) {
+ return delegateMap.put(key, value);
+ }
+
+ @Override
+ public V2 remove(Object key) {
+ return delegateMap.remove(key);
+ }
+
+ @Override
+ public void putAll(Map<? extends K2, ? extends V2> m) {
+ delegateMap.putAll(m);
+ }
+
+ @Override
+ public void clear() {
+      delegateMap.clear();
+ }
+
+ @Override
+ public Set<K2> keySet() {
+ return delegateMap.keySet();
+ }
+
+ @Override
+ public Collection<V2> values() {
+ return delegateMap.values();
+ }
+
+ @Override
+ public Set<Entry<K2, V2>> entrySet() {
+ return delegateMap.entrySet();
}
}
@@ -338,46 +538,153 @@
// return isComponentTypePrimitive ? Arrays.asList(ArrayUtils.toObject(value))
// : Arrays.asList(value);
- ForLoadedType loadedType = new ForLoadedType(type.getRawType());
- StackManipulation stackManipulation = readValue;
+ TypeDescriptor<?> componentType = type.getComponentType();
+ ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType());
+ StackManipulation readArrayValue = readValue;
// Row always expects to get an Iterable back for array types. Wrap this array into a
// List using Arrays.asList before returning.
- if (loadedType.getComponentType().isPrimitive()) {
+ if (loadedArrayType.getComponentType().isPrimitive()) {
// Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject.
- stackManipulation =
+ readArrayValue =
new Compound(
- stackManipulation,
+ readArrayValue,
MethodInvocation.invoke(
ARRAY_UTILS_TYPE
.getDeclaredMethods()
.filter(
ElementMatchers.isStatic()
.and(ElementMatchers.named("toObject"))
- .and(ElementMatchers.takesArguments(loadedType)))
+ .and(ElementMatchers.takesArguments(loadedArrayType)))
.getOnly()));
+
+ componentType = TypeDescriptor.of(Primitives.wrap(componentType.getRawType()));
}
- return new Compound(
- stackManipulation,
- MethodInvocation.invoke(
- ARRAYS_TYPE
- .getDeclaredMethods()
- .filter(ElementMatchers.isStatic().and(ElementMatchers.named("asList")))
- .getOnly()));
+ // Now convert to a List object.
+ StackManipulation readListValue =
+ new Compound(
+ readArrayValue,
+ MethodInvocation.invoke(
+ ARRAYS_TYPE
+ .getDeclaredMethods()
+ .filter(ElementMatchers.isStatic().and(ElementMatchers.named("asList")))
+ .getOnly()));
+
+ // Generate a SerializableFunction to convert the element-type objects.
+ final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType);
+ if (!finalComponentType.hasUnresolvedParameters()) {
+ Type convertedComponentType =
+ getFactory().createTypeConversion(true).convert(componentType);
+ ForLoadedType functionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ componentType.getRawType(),
+ convertedComponentType,
+ (s) -> getFactory().createGetterConversions(s).convert(finalComponentType)));
+ return createTransformingContainer(functionType, readListValue);
+ } else {
+ return readListValue;
+ }
}
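The generated array-getter bytecode corresponds to this plain-Java pipeline (shown for an int[] field; the final transform step is only emitted when the element type itself needs conversion):

    int[] raw = {1, 2, 3};
    Integer[] boxed = ArrayUtils.toObject(raw);   // box primitives first
    List<Integer> asList = Arrays.asList(boxed);  // wrap as a List view
    Iterable<Long> converted =
        ByteBuddyUtils.transformContainer(asList, i -> (long) i); // per-element conversion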
@Override
protected StackManipulation convertIterable(TypeDescriptor<?> type) {
- return readValue;
+ TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type);
+ Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType);
+ final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType);
+ if (!finalComponentType.hasUnresolvedParameters()) {
+ ForLoadedType functionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ componentType.getRawType(),
+ convertedComponentType,
+ (s) -> getFactory().createGetterConversions(s).convert(finalComponentType)));
+ return createTransformingContainer(functionType, readValue);
+ } else {
+ return readValue;
+ }
}
@Override
protected StackManipulation convertCollection(TypeDescriptor<?> type) {
- return readValue;
+ TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type);
+ Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType);
+ final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType);
+ if (!finalComponentType.hasUnresolvedParameters()) {
+ ForLoadedType functionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ componentType.getRawType(),
+ convertedComponentType,
+ (s) -> getFactory().createGetterConversions(s).convert(finalComponentType)));
+ return createTransformingContainer(functionType, readValue);
+ } else {
+ return readValue;
+ }
+ }
+
+ @Override
+ protected StackManipulation convertList(TypeDescriptor<?> type) {
+ TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type);
+ Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType);
+ final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType);
+ if (!finalComponentType.hasUnresolvedParameters()) {
+ ForLoadedType functionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ componentType.getRawType(),
+ convertedComponentType,
+ (s) -> getFactory().createGetterConversions(s).convert(finalComponentType)));
+ return createTransformingContainer(functionType, readValue);
+ } else {
+ return readValue;
+ }
}
@Override
protected StackManipulation convertMap(TypeDescriptor<?> type) {
- return readValue;
+ final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0);
+ final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1);
+
+ Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType);
+ Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType);
+
+ if (!keyType.hasUnresolvedParameters() && !valueType.hasUnresolvedParameters()) {
+ ForLoadedType keyFunctionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ keyType.getRawType(),
+ convertedKeyType,
+ (s) -> getFactory().createGetterConversions(s).convert(keyType)));
+ ForLoadedType valueFunctionType =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ valueType.getRawType(),
+ convertedValueType,
+ (s) -> getFactory().createGetterConversions(s).convert(valueType)));
+ return new Compound(
+ readValue,
+ TypeCreation.of(keyFunctionType),
+ Duplication.SINGLE,
+ MethodInvocation.invoke(
+ keyFunctionType
+ .getDeclaredMethods()
+ .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
+ .getOnly()),
+ TypeCreation.of(valueFunctionType),
+ Duplication.SINGLE,
+ MethodInvocation.invoke(
+ valueFunctionType
+ .getDeclaredMethods()
+ .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
+ .getOnly()),
+ MethodInvocation.invoke(
+ BYTE_BUDDY_UTILS_TYPE
+ .getDeclaredMethods()
+ .filter(ElementMatchers.named("getTransformingMap"))
+ .getOnly()));
+ } else {
+ return readValue;
+ }
}
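The map case composes two generated transform functions, one for keys and one for values, and passes them to getTransformingMap. Unlike the lazy collection views, TransformingMap copies eagerly in its constructor because converted keys must be re-hashed (or re-sorted). Equivalent plain-Java behavior, assuming Guava's ImmutableMap for brevity:

    Map<String, Integer> source = ImmutableMap.of("a", 1, "b", 2);
    Map<String, Long> converted =
        ByteBuddyUtils.getTransformingMap(source, String::toUpperCase, v -> (long) v);
    converted.get("A"); // -> 1L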
@Override
@@ -529,7 +836,7 @@
* there. This class generates code to convert between these types.
*/
public static class ConvertValueForSetter extends TypeConversion<StackManipulation> {
- StackManipulation readValue;
+ protected StackManipulation readValue;
protected ConvertValueForSetter(StackManipulation readValue) {
this.readValue = readValue;
@@ -553,18 +860,31 @@
.build()
.asErasure();
+ Type rowElementType =
+ getFactory().createTypeConversion(false).convert(type.getComponentType());
+ final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType());
+ if (!arrayElementType.hasUnresolvedParameters()) {
+ ForLoadedType conversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowElementType).getRawType(),
+ Primitives.wrap(arrayElementType.getRawType()),
+ (s) -> getFactory().createSetterConversions(s).convert(arrayElementType)));
+ readValue = createTransformingContainer(conversionFunction, readValue);
+ }
+
// Extract an array from the collection.
StackManipulation stackManipulation =
new Compound(
readValue,
- TypeCasting.to(LIST_TYPE),
+ TypeCasting.to(COLLECTION_TYPE),
// Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack
// before
// calling toArray.
ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType())
.withValues(Collections.emptyList()),
MethodInvocation.invoke(
- LIST_TYPE
+ COLLECTION_TYPE
.getDeclaredMethods()
.filter(
ElementMatchers.named("toArray").and(ElementMatchers.takesArguments(1)))
@@ -591,16 +911,128 @@
@Override
protected StackManipulation convertIterable(TypeDescriptor<?> type) {
- return readValue;
+ Type rowElementType =
+ getFactory()
+ .createTypeConversion(false)
+ .convert(ReflectUtils.getIterableComponentType(type));
+ final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type);
+ if (!iterableElementType.hasUnresolvedParameters()) {
+ ForLoadedType conversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowElementType).getRawType(),
+ iterableElementType.getRawType(),
+ (s) -> getFactory().createSetterConversions(s).convert(iterableElementType)));
+ StackManipulation transformedContainer =
+ createTransformingContainer(conversionFunction, readValue);
+ return transformedContainer;
+ } else {
+ return readValue;
+ }
}
@Override
protected StackManipulation convertCollection(TypeDescriptor<?> type) {
- return readValue;
+ Type rowElementType =
+ getFactory()
+ .createTypeConversion(false)
+ .convert(ReflectUtils.getIterableComponentType(type));
+ final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type);
+
+ if (!collectionElementType.hasUnresolvedParameters()) {
+ ForLoadedType conversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowElementType).getRawType(),
+ collectionElementType.getRawType(),
+ (s) -> getFactory().createSetterConversions(s).convert(collectionElementType)));
+ StackManipulation transformedContainer =
+ createTransformingContainer(conversionFunction, readValue);
+ return transformedContainer;
+ } else {
+ return readValue;
+ }
+ }
+
+ @Override
+ protected StackManipulation convertList(TypeDescriptor<?> type) {
+ Type rowElementType =
+ getFactory()
+ .createTypeConversion(false)
+ .convert(ReflectUtils.getIterableComponentType(type));
+ final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type);
+
+ if (!collectionElementType.hasUnresolvedParameters()) {
+ ForLoadedType conversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowElementType).getRawType(),
+ collectionElementType.getRawType(),
+ (s) -> getFactory().createSetterConversions(s).convert(collectionElementType)));
+ readValue = createTransformingContainer(conversionFunction, readValue);
+ }
+ // TODO: Don't copy if already a list!
+ StackManipulation transformedList =
+ new Compound(
+ readValue,
+ MethodInvocation.invoke(
+ new ForLoadedType(Lists.class)
+ .getDeclaredMethods()
+ .filter(
+ ElementMatchers.named("newArrayList")
+ .and(ElementMatchers.takesArguments(Iterable.class)))
+ .getOnly()));
+ return transformedList;
}
@Override
protected StackManipulation convertMap(TypeDescriptor<?> type) {
+ Type rowKeyType =
+ getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0));
+ final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0);
+ Type rowValueType =
+ getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1));
+ final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1);
+
+ if (!keyElementType.hasUnresolvedParameters()
+ && !valueElementType.hasUnresolvedParameters()) {
+ ForLoadedType keyConversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowKeyType).getRawType(),
+ keyElementType.getRawType(),
+ (s) -> getFactory().createSetterConversions(s).convert(keyElementType)));
+ ForLoadedType valueConversionFunction =
+ new ForLoadedType(
+ createCollectionTransformFunction(
+ TypeDescriptor.of(rowValueType).getRawType(),
+ valueElementType.getRawType(),
+ (s) -> getFactory().createSetterConversions(s).convert(valueElementType)));
+ readValue =
+ new Compound(
+ readValue,
+ TypeCreation.of(keyConversionFunction),
+ Duplication.SINGLE,
+ MethodInvocation.invoke(
+ keyConversionFunction
+ .getDeclaredMethods()
+ .filter(
+ ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
+ .getOnly()),
+ TypeCreation.of(valueConversionFunction),
+ Duplication.SINGLE,
+ MethodInvocation.invoke(
+ valueConversionFunction
+ .getDeclaredMethods()
+ .filter(
+ ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
+ .getOnly()),
+ MethodInvocation.invoke(
+ BYTE_BUDDY_UTILS_TYPE
+ .getDeclaredMethods()
+ .filter(ElementMatchers.named("getTransformingMap"))
+ .getOnly()));
+ }
return readValue;
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java
index b9f1ae5..d56f0bd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java
@@ -23,8 +23,11 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
import java.security.InvalidParameterException;
import java.util.Arrays;
+import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -33,16 +36,19 @@
import javax.annotation.Nullable;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
+import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Primitives;
/** A set of reflection helper methods. */
public class ReflectUtils {
- static class ClassWithSchema {
+ /** Represents a class and a schema. */
+ public static class ClassWithSchema {
private final Class clazz;
private final Schema schema;
- ClassWithSchema(Class clazz, Schema schema) {
+ public ClassWithSchema(Class clazz, Schema schema) {
this.clazz = clazz;
this.schema = schema;
}
@@ -78,6 +84,9 @@
clazz,
c -> {
return Arrays.stream(c.getDeclaredMethods())
+              // Covariant overloads insert bridge methods, which we must ignore.
+              .filter(m -> !m.isBridge())
.filter(m -> !Modifier.isPrivate(m.getModifiers()))
.filter(m -> !Modifier.isProtected(m.getModifiers()))
.filter(m -> !Modifier.isStatic(m.getModifiers()))
@@ -183,4 +192,49 @@
public static String stripSetterPrefix(String method) {
return stripPrefix(method, "set");
}
+
+ /** For an array T[] or a subclass of Iterable<T>, return a TypeDescriptor describing T. */
+ @Nullable
+ public static TypeDescriptor getIterableComponentType(TypeDescriptor valueType) {
+ TypeDescriptor componentType = null;
+ if (valueType.isArray()) {
+ Type component = valueType.getComponentType().getType();
+ if (!component.equals(byte.class)) {
+ // Byte arrays are special cased since we have a schema type corresponding to them.
+ componentType = TypeDescriptor.of(component);
+ }
+ } else if (valueType.isSubtypeOf(TypeDescriptor.of(Iterable.class))) {
+ TypeDescriptor<Iterable<?>> collection = valueType.getSupertype(Iterable.class);
+ if (collection.getType() instanceof ParameterizedType) {
+ ParameterizedType ptype = (ParameterizedType) collection.getType();
+ java.lang.reflect.Type[] params = ptype.getActualTypeArguments();
+ checkArgument(params.length == 1);
+ componentType = TypeDescriptor.of(params[0]);
+ } else {
+ throw new RuntimeException("Collection parameter is not parameterized!");
+ }
+ }
+ return componentType;
+ }
+
+ public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) {
+ TypeDescriptor mapType = null;
+ if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) {
+ TypeDescriptor<Collection<?>> map = valueType.getSupertype(Map.class);
+ if (map.getType() instanceof ParameterizedType) {
+ ParameterizedType ptype = (ParameterizedType) map.getType();
+ java.lang.reflect.Type[] params = ptype.getActualTypeArguments();
+ mapType = TypeDescriptor.of(params[index]);
+ } else {
+ throw new RuntimeException("Map type is not parameterized! " + map);
+ }
+ }
+ return mapType;
+ }
+
+ public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) {
+ return typeDescriptor.getRawType().isPrimitive()
+ ? TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType()))
+ : typeDescriptor;
+ }
}
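These helpers centralize reflection logic that was previously inlined in FieldValueTypeInformation. Expected behavior at a glance (the byte[] special case exists because byte arrays map to the BYTES schema type rather than ARRAY):

    ReflectUtils.getIterableComponentType(new TypeDescriptor<List<String>>() {}); // -> String
    ReflectUtils.getIterableComponentType(TypeDescriptor.of(byte[].class));       // -> null
    ReflectUtils.getMapType(new TypeDescriptor<Map<String, Long>>() {}, 0);       // -> String
    ReflectUtils.getMapType(new TypeDescriptor<Map<String, Long>>() {}, 1);       // -> Long
    ReflectUtils.boxIfPrimitive(TypeDescriptor.of(int.class));                    // -> Integer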
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index d437b06..be28467 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -25,6 +25,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
@@ -187,7 +188,7 @@
* match.
*/
@Nullable
- public <T> List<T> getArray(String fieldName) {
+ public <T> Collection<T> getArray(String fieldName) {
return getArray(getSchema().indexOf(fieldName));
}
@@ -332,7 +333,7 @@
* match.
*/
@Nullable
- public <T> List<T> getArray(int idx) {
+ public <T> Collection<T> getArray(int idx) {
return getValue(idx);
}
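This is the user-visible API change of the diff: Row.getArray now returns Collection rather than List, which allows RowWithGetters to return lazy views without copying. Callers that genuinely need list semantics must cast or copy, as the test updates below do. A defensive pattern, sketched:

    Collection<String> tags = row.getArray("tags");
    List<String> tagList =
        (tags instanceof List) ? (List<String>) tags : new ArrayList<>(tags);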
@@ -421,8 +422,8 @@
} else if (fieldType.getTypeName() == Schema.TypeName.BYTES) {
return Arrays.equals((byte[]) a, (byte[]) b);
} else if (fieldType.getTypeName() == TypeName.ARRAY) {
- return deepEqualsForList(
- (List<Object>) a, (List<Object>) b, fieldType.getCollectionElementType());
+ return deepEqualsForCollection(
+ (Collection<Object>) a, (Collection<Object>) b, fieldType.getCollectionElementType());
} else if (fieldType.getTypeName() == TypeName.ITERABLE) {
return deepEqualsForIterable(
(Iterable<Object>) a, (Iterable<Object>) b, fieldType.getCollectionElementType());
@@ -493,7 +494,8 @@
return h;
}
- static boolean deepEqualsForList(List<Object> a, List<Object> b, Schema.FieldType elementType) {
+ static boolean deepEqualsForCollection(
+ Collection<Object> a, Collection<Object> b, Schema.FieldType elementType) {
if (a == b) {
return true;
}
@@ -584,7 +586,7 @@
return addValues(Arrays.asList(values));
}
- public <T> Builder addArray(List<T> values) {
+ public <T> Builder addArray(Collection<T> values) {
this.values.add(values);
return this;
}
@@ -662,16 +664,16 @@
private List<Object> verifyArray(
Object value, FieldType collectionElementType, String fieldName) {
boolean collectionElementTypeNullable = collectionElementType.getNullable();
- if (!(value instanceof List)) {
+ if (!(value instanceof Collection)) {
throw new IllegalArgumentException(
String.format(
- "For field name %s and array type expected List class. Instead "
+ "For field name %s and array type expected Collection class. Instead "
+ "class type was %s.",
fieldName, value.getClass()));
}
- List<Object> valueList = (List<Object>) value;
- List<Object> verifiedList = Lists.newArrayListWithCapacity(valueList.size());
- for (Object listValue : valueList) {
+ Collection<Object> valueCollection = (Collection<Object>) value;
+ List<Object> verifiedList = Lists.newArrayListWithCapacity(valueCollection.size());
+ for (Object listValue : valueCollection) {
if (listValue == null) {
if (!collectionElementTypeNullable) {
throw new IllegalArgumentException(
@@ -696,8 +698,8 @@
+ "class type was %s.",
fieldName, value.getClass()));
}
- Iterable<Object> valueList = (Iterable<Object>) value;
- for (Object listValue : valueList) {
+ Iterable<Object> valueIterable = (Iterable<Object>) value;
+ for (Object listValue : valueIterable) {
if (listValue == null) {
if (!collectionElementTypeNullable) {
throw new IllegalArgumentException(
@@ -708,7 +710,7 @@
verify(listValue, collectionElementType, fieldName);
}
}
- return valueList;
+ return valueIterable;
}
private Map<Object, Object> verifyMap(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java
index 0d78731..ebf59b9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.values;
-import java.util.Iterator;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -29,6 +29,8 @@
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
@@ -44,7 +46,7 @@
private final Object getterTarget;
private final List<FieldValueGetter> getters;
- private final Map<Integer, List> cachedLists = Maps.newHashMap();
+ private final Map<Integer, Collection> cachedCollections = Maps.newHashMap();
private final Map<Integer, Iterable> cachedIterables = Maps.newHashMap();
private final Map<Integer, Map> cachedMaps = Maps.newHashMap();
@@ -69,36 +71,22 @@
return fieldValue != null ? getValue(type, fieldValue, fieldIdx) : null;
}
- private List getListValue(FieldType elementType, Object fieldValue) {
- Iterable iterable = (Iterable) fieldValue;
- List<Object> list = Lists.newArrayList();
- for (Object o : iterable) {
- list.add(getValue(elementType, o, null));
+ private Collection getCollectionValue(FieldType elementType, Object fieldValue) {
+ Collection collection = (Collection) fieldValue;
+ if (collection instanceof List) {
+      // For performance reasons, if the input is a List make sure we produce a List as well;
+      // otherwise Row forwarding is forced to physically copy the collection into a new List.
+ return Lists.transform((List) collection, v -> getValue(elementType, v, null));
+ } else {
+ return Collections2.transform(collection, v -> getValue(elementType, v, null));
}
- return list;
}
private Iterable getIterableValue(FieldType elementType, Object fieldValue) {
Iterable iterable = (Iterable) fieldValue;
// Wrap the iterable to avoid having to materialize the entire collection.
- return new Iterable() {
- @Override
- public Iterator iterator() {
- return new Iterator() {
- Iterator iterator = iterable.iterator();
-
- @Override
- public boolean hasNext() {
- return iterator.hasNext();
- }
-
- @Override
- public Object next() {
- return getValue(elementType, iterator.next(), null);
- }
- };
- }
- };
+ return Iterables.transform(iterable, v -> getValue(elementType, v, null));
}
private Map<?, ?> getMapValue(FieldType keyType, FieldType valueType, Map<?, ?> fieldValue) {
@@ -117,9 +105,9 @@
} else if (type.getTypeName().equals(TypeName.ARRAY)) {
return cacheKey != null
? (T)
- cachedLists.computeIfAbsent(
- cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue))
- : (T) getListValue(type.getCollectionElementType(), fieldValue);
+ cachedCollections.computeIfAbsent(
+ cacheKey, i -> getCollectionValue(type.getCollectionElementType(), fieldValue))
+ : (T) getCollectionValue(type.getCollectionElementType(), fieldValue);
} else if (type.getTypeName().equals(TypeName.ITERABLE)) {
return cacheKey != null
? (T)
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java
index 6ceac8b..feb51db 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.schemas;
+import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA;
import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ITERABLE_BEAM_SCHEMA;
import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.NESTED_ARRAYS_BEAM_SCHEMA;
import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.NESTED_ARRAY_BEAN_SCHEMA;
@@ -30,11 +31,13 @@
import static org.junit.Assert.assertTrue;
import java.math.BigDecimal;
+import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.schemas.utils.SchemaTestUtils;
+import org.apache.beam.sdk.schemas.utils.TestJavaBeans.ArrayOfByteArray;
import org.apache.beam.sdk.schemas.utils.TestJavaBeans.IterableBean;
import org.apache.beam.sdk.schemas.utils.TestJavaBeans.MismatchingNullableBean;
import org.apache.beam.sdk.schemas.utils.TestJavaBeans.NestedArrayBean;
@@ -272,7 +275,7 @@
NestedArrayBean bean = new NestedArrayBean(simple1, simple2, simple3);
Row row = registry.getToRowFunction(NestedArrayBean.class).apply(bean);
- List<Row> rows = row.getArray("beans");
+ List<Row> rows = (List) row.getArray("beans");
assertSame(simple1, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(0)));
assertSame(simple2, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(1)));
assertSame(simple3, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(2)));
@@ -422,4 +425,38 @@
list.add("three");
assertEquals(list, Lists.newArrayList(converted.getStrings()));
}
+
+ @Test
+ public void testToRowArrayOfBytes() throws NoSuchSchemaException {
+ SchemaRegistry registry = SchemaRegistry.createDefault();
+ Schema schema = registry.getSchema(ArrayOfByteArray.class);
+ SchemaTestUtils.assertSchemaEquivalent(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA, schema);
+
+ ArrayOfByteArray arrayOfByteArray =
+ new ArrayOfByteArray(
+ ImmutableList.of(ByteBuffer.wrap(BYTE_ARRAY), ByteBuffer.wrap(BYTE_ARRAY)));
+ Row expectedRow =
+ Row.withSchema(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA)
+ .addArray(ImmutableList.of(BYTE_ARRAY, BYTE_ARRAY))
+ .build();
+ Row converted = registry.getToRowFunction(ArrayOfByteArray.class).apply(arrayOfByteArray);
+ assertEquals(expectedRow, converted);
+ }
+
+ @Test
+ public void testFromRowArrayOfBytes() throws NoSuchSchemaException {
+ SchemaRegistry registry = SchemaRegistry.createDefault();
+ Schema schema = registry.getSchema(ArrayOfByteArray.class);
+ SchemaTestUtils.assertSchemaEquivalent(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA, schema);
+
+ ArrayOfByteArray expectedArrayOfByteArray =
+ new ArrayOfByteArray(
+ ImmutableList.of(ByteBuffer.wrap(BYTE_ARRAY), ByteBuffer.wrap(BYTE_ARRAY)));
+ Row row =
+ Row.withSchema(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA)
+ .addArray(ImmutableList.of(BYTE_ARRAY, BYTE_ARRAY))
+ .build();
+ ArrayOfByteArray converted = registry.getFromRowFunction(ArrayOfByteArray.class).apply(row);
+ assertEquals(expectedArrayOfByteArray, converted);
+ }
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
index 992c6bb..4134a57 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java
@@ -305,7 +305,7 @@
NestedArrayPOJO pojo = new NestedArrayPOJO(simple1, simple2, simple3);
Row row = registry.getToRowFunction(NestedArrayPOJO.class).apply(pojo);
- List<Row> rows = row.getArray("pojos");
+ List<Row> rows = (List) row.getArray("pojos");
assertSame(simple1, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(0)));
assertSame(simple2, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(1)));
assertSame(simple3, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(2)));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
index 73aa085..81517af 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java
@@ -23,6 +23,7 @@
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
import static org.junit.Assert.assertThat;
+import java.util.Collection;
import java.util.List;
import org.apache.beam.sdk.TestUtils.KvMatcher;
import org.apache.beam.sdk.schemas.Schema;
@@ -691,7 +692,7 @@
Schema valueSchema = value.getSchema();
for (int i = 0; i < valueSchema.getFieldCount(); ++i) {
assertEquals(TypeName.ARRAY, valueSchema.getField(i).getType().getTypeName());
- fieldMatchers.add(new ArrayFieldMatchesAnyOrder(i, value.getArray(i)));
+ fieldMatchers.add(new ArrayFieldMatchesAnyOrder(i, (List) value.getArray(i)));
}
matchers.add(
KvMatcher.isKv(equalTo(row.getKey()), allOf(fieldMatchers.toArray(new Matcher[0]))));
@@ -715,7 +716,7 @@
return false;
}
Row row = (Row) item;
- List<Row> actual = row.getArray(fieldIndex);
+ Collection<Row> actual = row.getArray(fieldIndex);
return containsInAnyOrder(expected).matches(actual);
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
index 15d1379..6deab6d 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
@@ -424,12 +424,17 @@
return false;
}
PartialRowMultipleArray that = (PartialRowMultipleArray) o;
- return Objects.equals(field1, that.field1);
+ return Objects.equals(field1, that.field1) && Objects.equals(field3, that.field3);
}
@Override
public int hashCode() {
- return Objects.hash(field1);
+ return Objects.hash(field1, field3);
+ }
+
+ @Override
+ public String toString() {
+ return "PartialRowMultipleArray{" + "field1=" + field1 + ", field3=" + field3 + '}';
}
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java
index f137477..32cf264 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java
@@ -972,4 +972,45 @@
/** The schema for {@link NestedArrayBean}. * */
public static final Schema ITERABLE_BEAM_SCHEMA =
Schema.builder().addIterableField("strings", FieldType.STRING).build();
+
+ /** A bean containing an Array of ByteArray. * */
+ @DefaultSchema(JavaBeanSchema.class)
+ public static class ArrayOfByteArray {
+ private List<ByteBuffer> byteBuffers;
+
+ public ArrayOfByteArray(List<ByteBuffer> byteBuffers) {
+ this.byteBuffers = byteBuffers;
+ }
+
+ public ArrayOfByteArray() {}
+
+ public List<ByteBuffer> getByteBuffers() {
+ return byteBuffers;
+ }
+
+ public void setByteBuffers(List<ByteBuffer> byteBuffers) {
+ this.byteBuffers = byteBuffers;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ArrayOfByteArray that = (ArrayOfByteArray) o;
+ return Objects.equals(byteBuffers, that.byteBuffers);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(byteBuffers);
+ }
+ }
+
+ /** The schema for {@link ArrayOfByteArray}. */
+ public static final Schema ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA =
+ Schema.builder().addArrayField("byteBuffers", FieldType.BYTES).build();
}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
index 1263b3d..9f27a4a 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
@@ -17,7 +17,7 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
-import java.util.List;
+import java.util.Collection;
import javax.annotation.Nullable;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
@@ -129,7 +129,7 @@
@ProcessElement
public void process(@Element Row row, OutputReceiver<Row> out) {
- @Nullable List<Object> rawValues = row.getArray(unnestIndex);
+ @Nullable Collection<Object> rawValues = row.getArray(unnestIndex);
if (rawValues == null) {
return;
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
index 7b8ce03..9b06a12 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
@@ -20,15 +20,22 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import java.io.Serializable;
+import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
+import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTableFilter;
+import org.apache.beam.sdk.extensions.sql.meta.DefaultTableFilter;
+import org.apache.beam.sdk.extensions.sql.meta.ProjectSupport;
import org.apache.beam.sdk.extensions.sql.meta.SchemaBaseBeamTable;
import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.io.mongodb.FindQuery;
import org.apache.beam.sdk.io.mongodb.MongoDbIO;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.JsonToRow;
import org.apache.beam.sdk.transforms.MapElements;
@@ -85,6 +92,29 @@
}
@Override
+ public PCollection<Row> buildIOReader(
+ PBegin begin, BeamSqlTableFilter filters, List<String> fieldNames) {
+ MongoDbIO.Read readInstance =
+ MongoDbIO.read().withUri(dbUri).withDatabase(dbName).withCollection(dbCollection);
+
+ final FieldAccessDescriptor resolved =
+ FieldAccessDescriptor.withFieldNames(fieldNames)
+ .withOrderByFieldInsertionOrder()
+ .resolve(getSchema());
+ final Schema newSchema = SelectHelpers.getOutputSchema(getSchema(), resolved);
+
+ if (!(filters instanceof DefaultTableFilter)) {
+ throw new AssertionError("Predicate push-down is unsupported, yet received a predicate.");
+ }
+
+ if (!fieldNames.isEmpty()) {
+ readInstance = readInstance.withQueryFn(FindQuery.create().withProjection(fieldNames));
+ }
+
+ return readInstance.expand(begin).apply(DocumentToRow.withSchema(newSchema));
+ }
+
+ @Override
public POutput buildIOWriter(PCollection<Row> input) {
return input
.apply(new RowToDocument())
@@ -92,6 +122,11 @@
}
@Override
+ public ProjectSupport supportsProjects() {
+ return ProjectSupport.WITH_FIELD_REORDERING;
+ }
+
+ @Override
public IsBounded isBounded() {
return IsBounded.BOUNDED;
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
index aa11690..0d4296a 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
@@ -25,9 +25,16 @@
import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
import static org.apache.beam.sdk.schemas.Schema.FieldType.INT64;
import static org.apache.beam.sdk.schemas.Schema.FieldType.STRING;
-import static org.junit.Assert.assertEquals;
+import static org.apache.beam.sdk.testing.SerializableMatchers.containsInAnyOrder;
+import static org.apache.beam.sdk.testing.SerializableMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsInstanceOf.instanceOf;
+import com.mongodb.BasicDBObject;
import com.mongodb.MongoClient;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoDatabase;
+import com.mongodb.client.model.Filters;
import de.flapdoodle.embed.mongo.MongodExecutable;
import de.flapdoodle.embed.mongo.MongodProcess;
import de.flapdoodle.embed.mongo.MongodStarter;
@@ -40,6 +47,7 @@
import de.flapdoodle.embed.process.runtime.Network;
import java.util.Arrays;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamPushDownIOSourceRel;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
import org.apache.beam.sdk.io.common.NetworkTestHelper;
@@ -49,7 +57,10 @@
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
+import org.bson.Document;
+import org.junit.After;
import org.junit.AfterClass;
+import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Rule;
@@ -69,7 +80,6 @@
private static final Logger LOG = LoggerFactory.getLogger(MongoDbReadWriteIT.class);
private static final Schema SOURCE_SCHEMA =
Schema.builder()
- .addNullableField("_id", STRING)
.addNullableField("c_bigint", INT64)
.addNullableField("c_tinyint", BYTE)
.addNullableField("c_smallint", INT16)
@@ -83,7 +93,6 @@
private static final String hostname = "localhost";
private static final String database = "beam";
private static final String collection = "collection";
- private static int port;
@ClassRule public static final TemporaryFolder MONGODB_LOCATION = new TemporaryFolder();
@@ -92,12 +101,15 @@
private static MongodProcess mongodProcess;
private static MongoClient client;
+ private static BeamSqlEnv sqlEnv;
+ private static String mongoSqlUrl;
+
@Rule public final TestPipeline writePipeline = TestPipeline.create();
@Rule public final TestPipeline readPipeline = TestPipeline.create();
@BeforeClass
public static void setUp() throws Exception {
- port = NetworkTestHelper.getAvailableLocalPort();
+ int port = NetworkTestHelper.getAvailableLocalPort();
LOG.info("Starting MongoDB embedded instance on {}", port);
IMongodConfig mongodConfig =
new MongodConfigBuilder()
@@ -117,6 +129,8 @@
mongodExecutable = mongodStarter.prepare(mongodConfig);
mongodProcess = mongodExecutable.start();
client = new MongoClient(hostname, port);
+
+ mongoSqlUrl = String.format("mongodb://%s:%d/%s/%s", hostname, port, database, collection);
}
@AfterClass
@@ -127,15 +141,23 @@
mongodExecutable.stop();
}
+ @Before
+ public void init() {
+ sqlEnv = BeamSqlEnv.inMemory(new MongoDbTableProvider());
+ MongoDatabase db = client.getDatabase(database);
+ // Enable full profiling (level 2) so executed queries are recorded in
+ // system.profile for the push-down assertions below.
+ db.runCommand(new BasicDBObject().append("profile", 2));
+ }
+
+ @After
+ public void cleanUp() {
+ client.getDatabase(database).drop();
+ }
+
@Test
public void testWriteAndRead() {
- final String mongoSqlUrl =
- String.format("mongodb://%s:%d/%s/%s", hostname, port, database, collection);
-
Row testRow =
row(
SOURCE_SCHEMA,
- "object_id",
9223372036854775807L,
(byte) 127,
(short) 32767,
@@ -148,7 +170,6 @@
String createTableStatement =
"CREATE EXTERNAL TABLE TEST( \n"
- + " _id VARCHAR, \n "
+ " c_bigint BIGINT, \n "
+ " c_tinyint TINYINT, \n"
+ " c_smallint SMALLINT, \n"
@@ -163,12 +184,10 @@
+ "LOCATION '"
+ mongoSqlUrl
+ "'";
- BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new MongoDbTableProvider());
sqlEnv.executeDdl(createTableStatement);
String insertStatement =
"INSERT INTO TEST VALUES ("
- + "'object_id', "
+ "9223372036854775807, "
+ "127, "
+ "32767, "
@@ -187,13 +206,95 @@
PCollection<Row> output =
BeamSqlRelUtils.toPCollection(readPipeline, sqlEnv.parseQuery("select * from TEST"));
- assertEquals(output.getSchema(), SOURCE_SCHEMA);
+ assertThat(output.getSchema(), equalTo(SOURCE_SCHEMA));
PAssert.that(output).containsInAnyOrder(testRow);
readPipeline.run().waitUntilFinish();
}
+ @Test
+ public void testProjectPushDown() {
+ final Schema expectedSchema =
+ Schema.builder()
+ .addNullableField("c_varchar", STRING)
+ .addNullableField("c_boolean", BOOLEAN)
+ .addNullableField("c_integer", INT32)
+ .build();
+ Row testRow = row(expectedSchema, "varchar", true, 2147483647);
+
+ String createTableStatement =
+ "CREATE EXTERNAL TABLE TEST( \n"
+ + " c_bigint BIGINT, \n "
+ + " c_tinyint TINYINT, \n"
+ + " c_smallint SMALLINT, \n"
+ + " c_integer INTEGER, \n"
+ + " c_float FLOAT, \n"
+ + " c_double DOUBLE, \n"
+ + " c_boolean BOOLEAN, \n"
+ + " c_varchar VARCHAR, \n "
+ + " c_arr ARRAY<VARCHAR> \n"
+ + ") \n"
+ + "TYPE 'mongodb' \n"
+ + "LOCATION '"
+ + mongoSqlUrl
+ + "'";
+ sqlEnv.executeDdl(createTableStatement);
+
+ String insertStatement =
+ "INSERT INTO TEST VALUES ("
+ + "9223372036854775807, "
+ + "127, "
+ + "32767, "
+ + "2147483647, "
+ + "1.0, "
+ + "1.0, "
+ + "TRUE, "
+ + "'varchar', "
+ + "ARRAY['123', '456']"
+ + ")";
+
+ BeamRelNode insertRelNode = sqlEnv.parseQuery(insertStatement);
+ BeamSqlRelUtils.toPCollection(writePipeline, insertRelNode);
+ writePipeline.run().waitUntilFinish();
+
+ BeamRelNode node = sqlEnv.parseQuery("select c_varchar, c_boolean, c_integer from TEST");
+ // Calc should be dropped, since MongoDb supports project push-down and field reordering.
+ assertThat(node, instanceOf(BeamPushDownIOSourceRel.class));
+ // Only selected fields are projected.
+ assertThat(
+ node.getRowType().getFieldNames(),
+ containsInAnyOrder("c_varchar", "c_boolean", "c_integer"));
+ PCollection<Row> output = BeamSqlRelUtils.toPCollection(readPipeline, node);
+
+ assertThat(output.getSchema(), equalTo(expectedSchema));
+ PAssert.that(output).containsInAnyOrder(testRow);
+
+ readPipeline.run().waitUntilFinish();
+
+ MongoDatabase db = client.getDatabase(database);
+ MongoCollection coll = db.getCollection("system.profile");
+ // Find the last executed query.
+ Object query =
+ coll.find()
+ .filter(Filters.eq("op", "query"))
+ .sort(new BasicDBObject().append("ts", -1))
+ .iterator()
+ .next();
+
+ // Retrieve the projection parameters.
+ assertThat(query, instanceOf(Document.class));
+ Object command = ((Document) query).get("command");
+ assertThat(command, instanceOf(Document.class));
+ Object projection = ((Document) command).get("projection");
+ assertThat(projection, instanceOf(Document.class));
+
+ // Validate projected fields.
+ assertThat(
+ ((Document) projection).keySet(),
+ containsInAnyOrder("c_varchar", "c_boolean", "c_integer"));
+ }
+
private Row row(Schema schema, Object... values) {
return Row.withSchema(schema).addValues(values).build();
}
diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
index 13073a1..59f6057 100644
--- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
+++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java
@@ -73,6 +73,7 @@
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
+import org.apache.http.conn.ssl.TrustStrategy;
import org.apache.http.entity.BufferedHttpEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.impl.client.BasicCredentialsProvider;
@@ -253,6 +254,8 @@
@Nullable
public abstract Integer getConnectTimeout();
+ public abstract boolean isTrustSelfSignedCerts();
+
abstract Builder builder();
@AutoValue.Builder
@@ -275,6 +278,8 @@
abstract Builder setConnectTimeout(Integer connectTimeout);
+ abstract Builder setTrustSelfSignedCerts(boolean trustSelfSignedCerts);
+
abstract ConnectionConfiguration build();
}
@@ -295,6 +300,7 @@
.setAddresses(Arrays.asList(addresses))
.setIndex(index)
.setType(type)
+ .setTrustSelfSignedCerts(false)
.build();
}
@@ -352,6 +358,18 @@
}
/**
+ * If Elasticsearch uses SSL/TLS, configures whether to trust self-signed certificates. The
+ * default is false.
+ *
+ * @param trustSelfSignedCerts Whether to trust self-signed certificates
+ * @return a {@link ConnectionConfiguration} describing the connection configuration to
+ * Elasticsearch.
+ */
+ public ConnectionConfiguration withTrustSelfSignedCerts(boolean trustSelfSignedCerts) {
+ return builder().setTrustSelfSignedCerts(trustSelfSignedCerts).build();
+ }
+
+ /**
* If set, overwrites the default max retry timeout (30000ms) in the Elastic {@link RestClient}
* and the default socket timeout (30000ms) in the {@link RequestConfig} of the Elastic {@link
* RestClient}.
@@ -386,6 +404,7 @@
builder.addIfNotNull(DisplayData.item("keystore.path", getKeystorePath()));
builder.addIfNotNull(DisplayData.item("socketAndRetryTimeout", getSocketAndRetryTimeout()));
builder.addIfNotNull(DisplayData.item("connectTimeout", getConnectTimeout()));
+ builder.addIfNotNull(DisplayData.item("trustSelfSignedCerts", isTrustSelfSignedCerts()));
}
@VisibleForTesting
@@ -413,10 +432,10 @@
String keystorePassword = getKeystorePassword();
keyStore.load(is, (keystorePassword == null) ? null : keystorePassword.toCharArray());
}
+ final TrustStrategy trustStrategy =
+ isTrustSelfSignedCerts() ? new TrustSelfSignedStrategy() : null;
final SSLContext sslContext =
- SSLContexts.custom()
- .loadTrustMaterial(keyStore, new TrustSelfSignedStrategy())
- .build();
+ SSLContexts.custom().loadTrustMaterial(keyStore, trustStrategy).build();
final SSLIOSessionStrategy sessionStrategy = new SSLIOSessionStrategy(sslContext);
restClientBuilder.setHttpClientConfigCallback(
httpClientBuilder ->
diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py
index e149451..fb7ee42 100644
--- a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py
+++ b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input.py
@@ -49,7 +49,7 @@
selected = None
len_corpus = len(corpus)
while not selected:
- c = list(corpus[randrange(0, len_corpus - 1)].values())[0]
+ c = list(corpus[randrange(0, len_corpus)].values())[0]
if c != ignore:
selected = c
@@ -59,7 +59,7 @@
selected = None
len_words = len(words)
while not selected:
- c = list(words[randrange(0, len_words - 1)].values())[0]
+ c = list(words[randrange(0, len_words)].values())[0]
if c != ignore:
selected = c
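Aside on the two hunks above: Python's random.randrange(start, stop) already excludes stop, so the old "len - 1" upper bound could never select the final element of the side-input list. A standalone sketch of the corrected sampling:

    import random

    corpus = [{'f': 'corpus1'}, {'f': 'corpus2'}, {'f': 'corpus3'}]
    # randrange(0, len(corpus)) may return any index 0..2; the previous
    # randrange(0, len(corpus) - 1) silently excluded the last entry.
    c = list(corpus[random.randrange(0, len(corpus))].values())[0]
    print(c)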
diff --git a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py
index 11fb95b..031eeb3 100644
--- a/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py
+++ b/sdks/python/apache_beam/examples/cookbook/bigquery_side_input_test.py
@@ -36,9 +36,9 @@
group_ids_pcoll = p | 'CreateGroupIds' >> beam.Create(['A', 'B', 'C'])
corpus_pcoll = p | 'CreateCorpus' >> beam.Create(
- [{'f': 'corpus1'}, {'f': 'corpus2'}, {'f': 'corpus3'}])
+ [{'f': 'corpus1'}, {'f': 'corpus2'}])
words_pcoll = p | 'CreateWords' >> beam.Create(
- [{'f': 'word1'}, {'f': 'word2'}, {'f': 'word3'}])
+ [{'f': 'word1'}, {'f': 'word2'}])
ignore_corpus_pcoll = p | 'CreateIgnoreCorpus' >> beam.Create(['corpus1'])
ignore_word_pcoll = p | 'CreateIgnoreWord' >> beam.Create(['word1'])
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py
index 4b52266..f0f53e2 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py
@@ -528,11 +528,12 @@
def test_model_pcollection(self):
temp_path = self.create_temp_file()
snippets.model_pcollection(['--output=%s' % temp_path])
- self.assertEqual(self.get_output(temp_path, sorted_output=False), [
+ self.assertEqual(self.get_output(temp_path), [
+ 'Or to take arms against a sea of troubles, ',
+ 'The slings and arrows of outrageous fortune, ',
'To be, or not to be: that is the question: ',
'Whether \'tis nobler in the mind to suffer ',
- 'The slings and arrows of outrageous fortune, ',
- 'Or to take arms against a sea of troubles, '])
+ ])
def test_construct_pipeline(self):
temp_path = self.create_temp_file(
diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample.py
new file mode 100644
index 0000000..d5abc37
--- /dev/null
+++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+
+def sample_fixed_size_globally(test=None):
+ # [START sample_fixed_size_globally]
+ import apache_beam as beam
+
+ with beam.Pipeline() as pipeline:
+ sample = (
+ pipeline
+ | 'Create produce' >> beam.Create([
+ '🍓 Strawberry',
+ '🥕 Carrot',
+ '🍆 Eggplant',
+ '🍅 Tomato',
+ '🥔 Potato',
+ ])
+ | 'Sample N elements' >> beam.combiners.Sample.FixedSizeGlobally(3)
+ | beam.Map(print)
+ )
+ # [END sample_fixed_size_globally]
+ if test:
+ test(sample)
+
+
+def sample_fixed_size_per_key(test=None):
+ # [START sample_fixed_size_per_key]
+ import apache_beam as beam
+
+ with beam.Pipeline() as pipeline:
+ samples_per_key = (
+ pipeline
+ | 'Create produce' >> beam.Create([
+ ('spring', '🍓'),
+ ('spring', '🥕'),
+ ('spring', '🍆'),
+ ('spring', '🍅'),
+ ('summer', '🥕'),
+ ('summer', '🍅'),
+ ('summer', '🌽'),
+ ('fall', '🥕'),
+ ('fall', '🍅'),
+ ('winter', '🍆'),
+ ])
+ | 'Samples per key' >> beam.combiners.Sample.FixedSizePerKey(3)
+ | beam.Map(print)
+ )
+ # [END sample_fixed_size_per_key]
+ if test:
+ test(samples_per_key)
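The new snippet can also be exercised outside the test suite below; a minimal local check, assuming apache_beam is installed and the default DirectRunner is used:

    from apache_beam.examples.snippets.transforms.aggregation import sample

    # Runs the snippet's pipeline and prints one list of 3 sampled elements.
    sample.sample_fixed_size_globally()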
diff --git a/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py
new file mode 100644
index 0000000..22cd656
--- /dev/null
+++ b/sdks/python/apache_beam/examples/snippets/transforms/aggregation/sample_test.py
@@ -0,0 +1,63 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import unittest
+
+import mock
+
+from apache_beam.examples.snippets.util import assert_matches_stdout
+from apache_beam.testing.test_pipeline import TestPipeline
+
+from . import sample
+
+
+def check_sample(actual):
+ expected = '''[START sample]
+['🥕 Carrot', '🍆 Eggplant', '🍅 Tomato']
+[END sample]'''.splitlines()[1:-1]
+ # The sampled elements are non-deterministic, so check the sample size.
+ assert_matches_stdout(actual, expected, lambda elements: len(elements))
+
+
+def check_samples_per_key(actual):
+ expected = '''[START samples_per_key]
+('spring', ['🍓', '🥕', '🍆'])
+('summer', ['🥕', '🍅', '🌽'])
+('fall', ['🥕', '🍅'])
+('winter', ['🍆'])
+[END samples_per_key]'''.splitlines()[1:-1]
+ # The sampled elements are non-deterministic, so check the sample size.
+ assert_matches_stdout(actual, expected, lambda pair: (pair[0], len(pair[1])))
+
+
+@mock.patch('apache_beam.Pipeline', TestPipeline)
+@mock.patch(
+ 'apache_beam.examples.snippets.transforms.aggregation.sample.print', str)
+class SampleTest(unittest.TestCase):
+ def test_sample_fixed_size_globally(self):
+ sample.sample_fixed_size_globally(check_sample)
+
+ def test_sample_fixed_size_per_key(self):
+ sample.sample_fixed_size_per_key(check_samples_per_key)
+
+
+if __name__ == '__main__':
+ unittest.main()
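The lambdas handed to assert_matches_stdout above act as normalizers: elements are mapped through them before comparison, reducing nondeterministic sample contents to deterministic sizes. Roughly (illustration only):

    # The per-key normalizer used in check_samples_per_key.
    normalize = lambda pair: (pair[0], len(pair[1]))
    assert normalize(('spring', ['🍓', '🥕', '🍆'])) == ('spring', 3)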
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index bbf8d3a..92e37c6 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -122,7 +122,7 @@
with TestPipeline() as p:
output_pcs = (
p
- | beam.Create(_DESTINATION_ELEMENT_PAIRS)
+ | beam.Create(_DESTINATION_ELEMENT_PAIRS, reshuffle=False)
| beam.ParDo(fn, self.tmpdir)
.with_outputs(fn.WRITTEN_FILE_TAG, fn.UNWRITTEN_RECORD_TAG))
@@ -325,7 +325,7 @@
('destination0', ['file2', 'file3'])]
single_partition_result = [('destination1', ['file0', 'file1'])]
with TestPipeline() as p:
- destination_file_pairs = p | beam.Create(self._ELEMENTS)
+ destination_file_pairs = p | beam.Create(self._ELEMENTS, reshuffle=False)
partitioned_files = (
destination_file_pairs
| beam.ParDo(bqfl.PartitionFiles(1000, 2))
@@ -347,7 +347,7 @@
('destination0', ['file3'])]
single_partition_result = [('destination1', ['file0', 'file1'])]
with TestPipeline() as p:
- destination_file_pairs = p | beam.Create(self._ELEMENTS)
+ destination_file_pairs = p | beam.Create(self._ELEMENTS, reshuffle=False)
partitioned_files = (
destination_file_pairs
| beam.ParDo(bqfl.PartitionFiles(150, 10))
@@ -533,7 +533,7 @@
with TestPipeline('DirectRunner') as p:
outputs = (p
- | beam.Create(_ELEMENTS)
+ | beam.Create(_ELEMENTS, reshuffle=False)
| bqfl.BigQueryBatchFileLoads(
destination,
custom_gcs_temp_location=self._new_tempdir(),
@@ -660,7 +660,7 @@
experiments='use_beam_bq_sink')
with beam.Pipeline(argv=args) as p:
- input = p | beam.Create(_ELEMENTS)
+ input = p | beam.Create(_ELEMENTS, reshuffle=False)
schema_map_pcv = beam.pvalue.AsDict(
p | "MakeSchemas" >> beam.Create(schema_kv_pairs))
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index 505a683..ac62774 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -43,6 +43,7 @@
from apache_beam.io.gcp.bigquery import _StreamToBigQuery
from apache_beam.io.gcp.bigquery_file_loads_test import _ELEMENTS
from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
+from apache_beam.io.gcp.bigquery_tools import RetryStrategy
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.io.gcp.pubsub import ReadFromPubSub
from apache_beam.io.gcp.tests import utils
@@ -647,7 +648,6 @@
method='FILE_LOADS'))
@attr('IT')
- @unittest.skip('BEAM-8842: Disabled due to reliance on old retry behavior.')
def test_multiple_destinations_transform(self):
streaming = self.test_pipeline.options.view_as(StandardOptions).streaming
if streaming and isinstance(self.test_pipeline.runner, TestDataflowRunner):
@@ -735,6 +735,7 @@
table_side_inputs=(table_record_pcv,),
schema=lambda dest, table_map: table_map.get(dest, None),
schema_side_inputs=(schema_table_pcv,),
+ insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR,
method='STREAMING_INSERTS'))
assert_that(r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
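Re-enabling the test leans on the explicit retry strategy: RETRY_ON_TRANSIENT_ERROR retries only transient insert failures, so permanently failing rows surface on FAILED_ROWS instead of retrying forever. A hedged usage sketch (the table name is hypothetical and assumed to exist):

    import apache_beam as beam
    from apache_beam.io.gcp.bigquery_tools import RetryStrategy

    with beam.Pipeline() as p:
      _ = (p
           | beam.Create([{'name': 'beam'}])
           | beam.io.WriteToBigQuery(
               'my-project:my_dataset.my_table',  # hypothetical table
               insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR,
               method='STREAMING_INSERTS'))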
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index e2bd696..0a1b211 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -851,7 +851,6 @@
@staticmethod
def get_desired_chunk_size(total_size):
- total_size
if total_size:
# 1MB = 1 shard, 1GB = 32 shards, 1TB = 1000 shards, 1PB = 32k shards
chunk_size = max(1 << 20, 1000 * int(math.sqrt(total_size)))
@@ -860,31 +859,11 @@
return chunk_size
def expand(self, pbegin):
- from apache_beam.options.pipeline_options import DebugOptions
- from apache_beam.transforms import util
-
- assert isinstance(pbegin, pvalue.PBegin)
- self.pipeline = pbegin.pipeline
-
- debug_options = self.pipeline._options.view_as(DebugOptions)
- if debug_options.experiments and 'beam_fn_api' in debug_options.experiments:
- source = self.source
-
- def split_source(unused_impulse):
- return source.split(
- self.get_desired_chunk_size(self.source.estimate_size()))
-
- return (
- pbegin
- | core.Impulse()
- | 'Split' >> core.FlatMap(split_source)
- | util.Reshuffle()
- | 'ReadSplits' >> core.FlatMap(lambda split: split.source.read(
- split.source.get_range_tracker(
- split.start_position, split.stop_position))))
+ if isinstance(self.source, BoundedSource):
+ return pbegin | _SDFBoundedSourceWrapper(self.source)
else:
# Treat Read itself as a primitive.
- return pvalue.PCollection(self.pipeline,
+ return pvalue.PCollection(pbegin.pipeline,
is_bounded=self.source.is_bounded())
def get_windowing(self, unused_inputs):
@@ -1534,7 +1513,11 @@
def _create_sdf_bounded_source_dofn(self):
source = self.source
- chunk_size = Read.get_desired_chunk_size(source.estimate_size())
+ try:
+ estimated_size = source.estimate_size()
+ except NotImplementedError:
+ estimated_size = None
+ chunk_size = Read.get_desired_chunk_size(estimated_size)
class SDFBoundedSourceDoFn(core.DoFn):
def __init__(self, read_source):
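For reference, get_desired_chunk_size (still used by the SDF wrapper above) grows the chunk size with the square root of the estimated source size, subject to a 1 MB floor. A worked example:

    import math

    # For a 1 GiB source: 1000 * sqrt(2**30) = 32,768,000 bytes per chunk,
    # i.e. about 32 splits -- matching the "1GB = 32 shards" comment.
    total_size = 1 << 30
    chunk_size = max(1 << 20, 1000 * int(math.sqrt(total_size)))
    print(chunk_size, total_size // chunk_size)  # 32768000 32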
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 9772591..a574d20 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -198,8 +198,6 @@
experiments = (p._options.view_as(DebugOptions).experiments or [])
# Setup experiment option to enable using SDFBoundedSourceWrapper
- if 'use_sdf_bounded_source' not in experiments:
- experiments.append('use_sdf_bounded_source')
if 'beam_fn_api' not in experiments:
# Required so mocking below doesn't mock Create used in assert_that.
experiments.append('beam_fn_api')
diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py
index 34f8eba..719bf55 100644
--- a/sdks/python/apache_beam/io/parquetio_test.py
+++ b/sdks/python/apache_beam/io/parquetio_test.py
@@ -312,7 +312,7 @@
path = dst.name
with TestPipeline() as p:
_ = p \
- | Create(self.RECORDS) \
+ | Create(self.RECORDS, reshuffle=False) \
| WriteToParquet(
path, self.SCHEMA, num_shards=1, shard_name_template='')
with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index ecfa6fb..ad336c5 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -1113,7 +1113,7 @@
with open(file_name, 'rb') as f:
read_result.extend(f.read().splitlines())
- self.assertEqual(read_result, self.lines)
+ self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_auto_compression(self):
pipeline = TestPipeline()
@@ -1126,7 +1126,7 @@
with gzip.GzipFile(file_name, 'rb') as f:
read_result.extend(f.read().splitlines())
- self.assertEqual(read_result, self.lines)
+ self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_auto_compression_unsharded(self):
pipeline = TestPipeline()
@@ -1142,7 +1142,7 @@
with gzip.GzipFile(file_name, 'rb') as f:
read_result.extend(f.read().splitlines())
- self.assertEqual(read_result, self.lines)
+ self.assertEqual(sorted(read_result), sorted(self.lines))
def test_write_dataflow_header(self):
pipeline = TestPipeline()
@@ -1159,7 +1159,8 @@
with gzip.GzipFile(file_name, 'rb') as f:
read_result.extend(f.read().splitlines())
# header_text is automatically encoded in WriteToText
- self.assertEqual(read_result, [header_text.encode('utf-8')] + self.lines)
+ self.assertEqual(read_result[0], header_text.encode('utf-8'))
+ self.assertEqual(sorted(read_result[1:]), sorted(self.lines))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index 1f7ba2a..dfb154a 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -228,7 +228,7 @@
file_name, options=tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP)):
actual.append(r)
- self.assertEqual(actual, input_data)
+ self.assertEqual(sorted(actual), sorted(input_data))
def test_write_record_auto(self):
with TempDir() as temp_dir:
@@ -244,7 +244,7 @@
file_name, options=tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP)):
actual.append(r)
- self.assertEqual(actual, input_data)
+ self.assertEqual(sorted(actual), sorted(input_data))
class TestReadFromTFRecord(unittest.TestCase):
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index c1c25d1..b1b2e68 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -253,7 +253,7 @@
def test_visit_entire_graph(self):
pipeline = Pipeline()
- pcoll1 = pipeline | 'pcoll' >> Create([1, 2, 3])
+ pcoll1 = pipeline | 'pcoll' >> beam.Impulse()
pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
pcoll4 = pcoll2 | 'do3' >> FlatMap(lambda x: [x + 1])
@@ -266,9 +266,9 @@
set(visitor.visited))
self.assertEqual(set(visitor.enter_composite),
set(visitor.leave_composite))
- self.assertEqual(3, len(visitor.enter_composite))
- self.assertEqual(visitor.enter_composite[2].transform, transform)
- self.assertEqual(visitor.leave_composite[1].transform, transform)
+ self.assertEqual(2, len(visitor.enter_composite))
+ self.assertEqual(visitor.enter_composite[1].transform, transform)
+ self.assertEqual(visitor.leave_composite[0].transform, transform)
def test_apply_custom_transform(self):
pipeline = TestPipeline()
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 8632cfd..c045231 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -50,13 +50,14 @@
class NameContext(object):
"""Holds the name information for a step."""
- def __init__(self, step_name):
+ def __init__(self, step_name, transform_id=None):
"""Creates a new step NameContext.
Args:
step_name: The name of the step.
"""
self.step_name = step_name
+ self.transform_id = transform_id
def __eq__(self, other):
return self.step_name == other.step_name
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 718ab61..1fae45d 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -101,12 +101,19 @@
# TODO: Remove the apache_beam.pipeline dependency in CreatePTransformOverride
from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride
from apache_beam.runners.dataflow.ptransform_overrides import ReadPTransformOverride
+ from apache_beam.runners.dataflow.ptransform_overrides import JrhReadPTransformOverride
_PTRANSFORM_OVERRIDES = [
- CreatePTransformOverride(),
]
- _SDF_PTRANSFORM_OVERRIDES = [
+ _JRH_PTRANSFORM_OVERRIDES = [
+ JrhReadPTransformOverride(),
+ ]
+
+ # These overrides should be applied after the proto representation of the
+ # graph is created.
+ _NON_PORTABLE_PTRANSFORM_OVERRIDES = [
+ CreatePTransformOverride(),
ReadPTransformOverride(),
]
@@ -395,8 +402,10 @@
# done before Runner API serialization, since the new proto needs to contain
# any added PTransforms.
pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
- if apiclient._use_sdf_bounded_source(options):
- pipeline.replace_all(DataflowRunner._SDF_PTRANSFORM_OVERRIDES)
+
+ if (apiclient._use_fnapi(options)
+ and not apiclient._use_unified_worker(options)):
+ pipeline.replace_all(DataflowRunner._JRH_PTRANSFORM_OVERRIDES)
use_fnapi = apiclient._use_fnapi(options)
from apache_beam.transforms import environments
@@ -424,6 +433,11 @@
self.proto_pipeline, self.proto_context = pipeline.to_runner_api(
return_context=True, default_environment=default_environment)
+ else:
+ # Perform configured PTransform overrides that should not be reflected
+ # in the proto representation of the graph.
+ pipeline.replace_all(DataflowRunner._NON_PORTABLE_PTRANSFORM_OVERRIDES)
+
# Add setup_options for all the BeamPlugin imports
setup_options = options.view_as(SetupOptions)
plugins = BeamPlugin.get_all_plugin_paths()
@@ -504,10 +518,10 @@
result.metric_results = self._metrics
return result
- def _get_typehint_based_encoding(self, typehint, window_coder, use_fnapi):
+ def _get_typehint_based_encoding(self, typehint, window_coder):
"""Returns an encoding based on a typehint object."""
return self._get_cloud_encoding(
- self._get_coder(typehint, window_coder=window_coder), use_fnapi)
+ self._get_coder(typehint, window_coder=window_coder))
@staticmethod
def _get_coder(typehint, window_coder):
@@ -518,13 +532,12 @@
window_coder=window_coder)
return coders.registry.get_coder(typehint)
- def _get_cloud_encoding(self, coder, use_fnapi):
+ def _get_cloud_encoding(self, coder, unused=None):
"""Returns an encoding based on a coder object."""
if not isinstance(coder, coders.Coder):
raise TypeError('Coder object must inherit from coders.Coder: %s.' %
str(coder))
- return coder.as_cloud_object(self.proto_context
- .coders if use_fnapi else None)
+ return coder.as_cloud_object(self.proto_context.coders)
def _get_side_input_encoding(self, input_encoding):
"""Returns an encoding for the output of a view transform.
@@ -567,11 +580,7 @@
output_tag].windowing.windowfn.get_window_coder())
else:
window_coder = None
- from apache_beam.runners.dataflow.internal import apiclient
- use_fnapi = apiclient._use_fnapi(
- list(transform_node.outputs.values())[0].pipeline._options)
- return self._get_typehint_based_encoding(element_type, window_coder,
- use_fnapi)
+ return self._get_typehint_based_encoding(element_type, window_coder)
def _add_step(self, step_kind, step_label, transform_node, side_tags=()):
"""Creates a Step object and adds it to the cache."""
@@ -879,6 +888,8 @@
serialized_data = pickler.dumps(
self._pardo_fn_data(transform_node, lookup_label))
step.add_property(PropertyNames.SERIALIZED_FN, serialized_data)
+ # TODO(BEAM-8882): Enable once dataflow service doesn't reject this.
+ # step.add_property(PropertyNames.PIPELINE_PROTO_TRANSFORM_ID, transform_id)
step.add_property(
PropertyNames.PARALLEL_INPUT,
{'@type': 'OutputReference',
@@ -935,10 +946,9 @@
# Add the restriction encoding if we are a splittable DoFn
# and are using the Fn API on the unified worker.
restriction_coder = transform.get_restriction_coder()
- if (use_fnapi and use_unified_worker and restriction_coder):
+ if restriction_coder:
step.add_property(PropertyNames.RESTRICTION_ENCODING,
- self._get_cloud_encoding(
- restriction_coder, use_fnapi))
+ self._get_cloud_encoding(restriction_coder))
@staticmethod
def _pardo_fn_data(transform_node, get_label):
@@ -958,6 +968,7 @@
input_step = self._cache.get_pvalue(transform_node.inputs[0])
step = self._add_step(
TransformNames.COMBINE, transform_node.full_label, transform_node)
+ transform_id = self.proto_context.transforms.get_id(transform_node.parent)
# The data transmitted in SERIALIZED_FN is different depending on whether
# this is a fnapi pipeline or not.
@@ -967,8 +978,7 @@
# Fnapi pipelines send the transform ID of the CombineValues transform's
# parent composite because Dataflow expects the ID of a CombinePerKey
# transform.
- serialized_data = self.proto_context.transforms.get_id(
- transform_node.parent)
+ serialized_data = transform_id
else:
# Combiner functions do not take deferred side-inputs (i.e. PValues) and
# therefore the code to handle extra args/kwargs is simpler than for the
@@ -977,6 +987,8 @@
serialized_data = pickler.dumps((transform.fn, transform.args,
transform.kwargs, ()))
step.add_property(PropertyNames.SERIALIZED_FN, serialized_data)
+ # TODO(BEAM-8882): Enable once dataflow service doesn't reject this.
+ # step.add_property(PropertyNames.PIPELINE_PROTO_TRANSFORM_ID, transform_id)
step.add_property(
PropertyNames.PARALLEL_INPUT,
{'@type': 'OutputReference',
@@ -985,7 +997,7 @@
# Note that the accumulator must not have a WindowedValue encoding, while
# the output of this step does in fact have a WindowedValue encoding.
accumulator_encoding = self._get_cloud_encoding(
- transform_node.transform.fn.get_accumulator_coder(), use_fnapi)
+ transform_node.transform.fn.get_accumulator_coder())
output_encoding = self._get_encoded_output_coder(transform_node)
step.encoding = output_encoding
@@ -1005,16 +1017,7 @@
# Consider native Read to be a primitive for dataflow.
return beam.pvalue.PCollection.from_(pbegin)
else:
- debug_options = options.view_as(DebugOptions)
- if (
- debug_options.experiments and
- 'beam_fn_api' in debug_options.experiments
- ):
- # Expand according to FnAPI primitives.
- return self.apply_PTransform(transform, pbegin, options)
- else:
- # Custom Read is also a primitive for non-FnAPI on dataflow.
- return beam.pvalue.PCollection.from_(pbegin)
+ return self.apply_PTransform(transform, pbegin, options)
def run_Read(self, transform_node, options):
transform = transform_node.transform
@@ -1125,8 +1128,7 @@
coders.coders.GlobalWindowCoder())
from apache_beam.runners.dataflow.internal import apiclient
- use_fnapi = apiclient._use_fnapi(options)
- step.encoding = self._get_cloud_encoding(coder, use_fnapi)
+ step.encoding = self._get_cloud_encoding(coder)
step.add_property(
PropertyNames.OUTPUT_INFO,
[{PropertyNames.USER_NAME: (
@@ -1212,8 +1214,7 @@
coder = coders.WindowedValueCoder(transform.sink.coder,
coders.coders.GlobalWindowCoder())
from apache_beam.runners.dataflow.internal import apiclient
- use_fnapi = apiclient._use_fnapi(options)
- step.encoding = self._get_cloud_encoding(coder, use_fnapi)
+ step.encoding = self._get_cloud_encoding(coder)
step.add_property(PropertyNames.ENCODING, step.encoding)
step.add_property(
PropertyNames.PARALLEL_INPUT,
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index c47ab88..58c722c 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -238,13 +238,14 @@
p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
p.run()
job_dict = json.loads(str(remote_runner.job))
- self.assertEqual(len(job_dict[u'steps']), 2)
+ self.assertEqual(len(job_dict[u'steps']), 3)
self.assertEqual(job_dict[u'steps'][0][u'kind'], u'ParallelRead')
self.assertEqual(
job_dict[u'steps'][0][u'properties'][u'pubsub_subscription'],
'_starting_signal/')
self.assertEqual(job_dict[u'steps'][1][u'kind'], u'ParallelDo')
+ self.assertEqual(job_dict[u'steps'][2][u'kind'], u'ParallelDo')
def test_biqquery_read_streaming_fail(self):
remote_runner = DataflowRunner()
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 1ccbd13..da37813 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -894,13 +894,6 @@
'use_unified_worker' in debug_options.experiments)
-def _use_sdf_bounded_source(pipeline_options):
- debug_options = pipeline_options.view_as(DebugOptions)
- return _use_fnapi(pipeline_options) and (
- debug_options.experiments and
- 'use_sdf_bounded_source' in debug_options.experiments)
-
-
def _get_container_image_tag():
base_version = pkg_resources.parse_version(
beam_version.__version__).base_version
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py
index fdce49b..111259d 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/names.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py
@@ -106,6 +106,7 @@
OUTPUT_INFO = 'output_info'
OUTPUT_NAME = 'output_name'
PARALLEL_INPUT = 'parallel_input'
+ PIPELINE_PROTO_TRANSFORM_ID = 'pipeline_proto_transform_id'
PUBSUB_ID_LABEL = 'pubsub_id_label'
PUBSUB_SERIALIZED_ATTRIBUTES_FN = 'pubsub_serialized_attributes_fn'
PUBSUB_SUBSCRIPTION = 'pubsub_subscription'
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
index 828455b..a0a6541 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
@@ -188,7 +188,7 @@
p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned
p.run()
- self.assertEqual(['a', 'b', 'c'], sink.written_values)
+ self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values))
class Test_NativeWrite(unittest.TestCase):
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py b/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py
deleted file mode 100644
index 481209e..0000000
--- a/sdks/python/apache_beam/runners/dataflow/native_io/streaming_create.py
+++ /dev/null
@@ -1,76 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Create transform for streaming."""
-
-from __future__ import absolute_import
-
-from builtins import map
-
-from apache_beam import DoFn
-from apache_beam import ParDo
-from apache_beam import PTransform
-from apache_beam import Windowing
-from apache_beam import pvalue
-from apache_beam.transforms.window import GlobalWindows
-
-
-class StreamingCreate(PTransform):
- """A specialized implementation for ``Create`` transform in streaming mode.
-
- Note: There is no unbounded source API in python to wrap the Create source,
- so we map this to composite of Impulse primitive and an SDF.
- """
-
- def __init__(self, values, coder):
- self.coder = coder
- self.encoded_values = list(map(coder.encode, values))
-
- class DecodeAndEmitDoFn(DoFn):
- """A DoFn which stores encoded versions of elements.
-
- It also stores a Coder to decode and emit those elements.
- TODO: BEAM-2422 - Make this a SplittableDoFn.
- """
-
- def __init__(self, encoded_values, coder):
- self.encoded_values = encoded_values
- self.coder = coder
-
- def process(self, unused_element):
- for encoded_value in self.encoded_values:
- yield self.coder.decode(encoded_value)
-
- class Impulse(PTransform):
- """The Dataflow specific override for the impulse primitive."""
-
- def expand(self, pbegin):
- assert isinstance(pbegin, pvalue.PBegin), (
- 'Input to Impulse transform must be a PBegin but found %s' % pbegin)
- return pvalue.PCollection(pbegin.pipeline, is_bounded=False)
-
- def get_windowing(self, inputs):
- return Windowing(GlobalWindows())
-
- def infer_output_type(self, unused_input_type):
- return bytes
-
- def expand(self, pbegin):
- return (pbegin
- | 'Impulse' >> self.Impulse()
- | 'Decode Values' >> ParDo(
- self.DecodeAndEmitDoFn(self.encoded_values, self.coder)))
diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
index 6e84c15..e3e76a5 100644
--- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
+++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
@@ -19,7 +19,6 @@
from __future__ import absolute_import
-from apache_beam.coders import typecoders
from apache_beam.pipeline import PTransformOverride
@@ -30,24 +29,24 @@
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam import Create
- from apache_beam.options.pipeline_options import StandardOptions
+ from apache_beam.runners.dataflow.internal import apiclient
if isinstance(applied_ptransform.transform, Create):
- standard_options = (applied_ptransform
- .outputs[None]
- .pipeline._options
- .view_as(StandardOptions))
- return standard_options.streaming
+ return not apiclient._use_fnapi(
+ applied_ptransform.outputs[None].pipeline._options)
else:
return False
def get_replacement_transform(self, ptransform):
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
- from apache_beam.runners.dataflow.native_io.streaming_create import \
- StreamingCreate
- coder = typecoders.registry.get_coder(ptransform.get_output_type())
- return StreamingCreate(ptransform.values, coder)
+ from apache_beam import PTransform
+ # Return a wrapper rather than ptransform.as_read() directly to
+ # ensure backwards compatibility of the pipeline structure.
+ class LegacyCreate(PTransform):
+ def expand(self, pbegin):
+ return pbegin | ptransform.as_read()
+ return LegacyCreate().with_output_types(ptransform.get_output_type())
class ReadPTransformOverride(PTransformOverride):
@@ -57,11 +56,51 @@
from apache_beam.io import Read
from apache_beam.io.iobase import BoundedSource
# Only overrides Read(BoundedSource) transform
- if isinstance(applied_ptransform.transform, Read):
+ if (isinstance(applied_ptransform.transform, Read)
+ and not getattr(applied_ptransform.transform, 'override', False)):
if isinstance(applied_ptransform.transform.source, BoundedSource):
return True
return False
def get_replacement_transform(self, ptransform):
- from apache_beam.io.iobase import _SDFBoundedSourceWrapper
- return _SDFBoundedSourceWrapper(ptransform.source)
+ from apache_beam import pvalue
+ from apache_beam.io import iobase
+ class Read(iobase.Read):
+ override = True
+ def expand(self, pbegin):
+ return pvalue.PCollection(
+ self.pipeline, is_bounded=self.source.is_bounded())
+ return Read(ptransform.source).with_output_types(
+ ptransform.get_type_hints().simple_output_type('Read'))
+
+
+class JrhReadPTransformOverride(PTransformOverride):
+ """A ``PTransformOverride`` for ``Read(BoundedSource)``"""
+
+ def matches(self, applied_ptransform):
+ from apache_beam.io import Read
+ from apache_beam.io.iobase import BoundedSource
+ return (isinstance(applied_ptransform.transform, Read)
+ and isinstance(applied_ptransform.transform.source, BoundedSource))
+
+ def get_replacement_transform(self, ptransform):
+ from apache_beam.io import Read
+ from apache_beam.transforms import core
+ from apache_beam.transforms import util
+ # Make this a local to narrow what's captured in the closure.
+ source = ptransform.source
+
+ class JrhRead(core.PTransform):
+ def expand(self, pbegin):
+ return (
+ pbegin
+ | core.Impulse()
+ | 'Split' >> core.FlatMap(lambda _: source.split(
+ Read.get_desired_chunk_size(source.estimate_size())))
+ | util.Reshuffle()
+ | 'ReadSplits' >> core.FlatMap(lambda split: split.source.read(
+ split.source.get_range_tracker(
+ split.start_position, split.stop_position))))
+
+ return JrhRead().with_output_types(
+ ptransform.get_type_hints().simple_output_type('Read'))
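All three overrides in this file share the same PTransformOverride contract: matches() picks out applied transforms and get_replacement_transform() supplies the substitute, applied via pipeline.replace_all(). A minimal sketch with hypothetical names:

    from apache_beam.pipeline import PTransformOverride

    class LabelBasedOverride(PTransformOverride):
      """Hypothetical override selecting a transform by its full label."""

      def matches(self, applied_ptransform):
        return applied_ptransform.full_label == 'StepToReplace'

      def get_replacement_transform(self, ptransform):
        # Return the substitute transform; this sketch keeps the original.
        return ptransform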
diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py
index 97d4375..6d21e55 100644
--- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py
+++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor_test.py
@@ -21,9 +21,8 @@
import logging
import unittest
+import apache_beam as beam
from apache_beam import pvalue
-from apache_beam.io import Read
-from apache_beam.io import iobase
from apache_beam.pipeline import Pipeline
from apache_beam.pvalue import AsList
from apache_beam.runners.direct import DirectRunner
@@ -51,10 +50,7 @@
pass
def test_root_transforms(self):
- class DummySource(iobase.BoundedSource):
- pass
-
- root_read = Read(DummySource())
+ root_read = beam.Impulse()
root_flatten = Flatten(pipeline=self.pipeline)
pbegin = pvalue.PBegin(self.pipeline)
@@ -88,10 +84,7 @@
def process(self, element, negatives):
yield element
- class DummySource(iobase.BoundedSource):
- pass
-
- root_read = Read(DummySource())
+ root_read = beam.Impulse()
result = (self.pipeline
| 'read' >> root_read
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
index 22a930c..95df81d 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
@@ -79,7 +79,7 @@
return [element]
p = Pipeline(DirectRunner())
- pcoll = (p | beam.Create([1, 2, 3, 4, 5])
+ pcoll = (p | beam.Create([1, 2, 3, 4, 5], reshuffle=False)
| 'Do' >> beam.ParDo(MyDoFn()))
assert_that(pcoll, equal_to([1, 2, 3, 4, 5]))
result = p.run()
@@ -132,6 +132,10 @@
| beam.Create([[]]).with_output_types(beam.typehints.List[int])
| beam.combiners.Count.Globally())
+ def test_impulse(self):
+ with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
+ assert_that(p | beam.Impulse(), equal_to([b'']))
+
class DirectRunnerRetryTests(unittest.TestCase):
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index fd04d4c..d9d68cc 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -147,8 +147,16 @@
super(SDFDirectRunnerTest, self).setUp()
# Importing the following for the DirectRunner SDF implementation for testing.
from apache_beam.runners.direct import transform_evaluator
- self._default_max_num_outputs = (
+ self._old_default_max_num_outputs = (
transform_evaluator._ProcessElementsEvaluator.DEFAULT_MAX_NUM_OUTPUTS)
+ self._default_max_num_outputs = (
+ transform_evaluator._ProcessElementsEvaluator.DEFAULT_MAX_NUM_OUTPUTS
+ ) = 100
+
+ def tearDown(self):
+ from apache_beam.runners.direct import transform_evaluator
+ transform_evaluator._ProcessElementsEvaluator.DEFAULT_MAX_NUM_OUTPUTS = (
+ self._old_default_max_num_outputs)
def run_sdf_read_pipeline(
self, num_files, num_records_per_file, resume_count=None):
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index 4617711..c893d8b 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -83,6 +83,7 @@
io.Read: _BoundedReadEvaluator,
_DirectReadFromPubSub: _PubSubReadEvaluator,
core.Flatten: _FlattenEvaluator,
+ core.Impulse: _ImpulseEvaluator,
core.ParDo: _ParDoEvaluator,
core._GroupByKeyOnly: _GroupByKeyOnlyEvaluator,
_StreamingGroupByKeyOnly: _StreamingGroupByKeyOnlyEvaluator,
@@ -517,6 +518,17 @@
return TransformResult(self, bundles, [], None, None)
+class _ImpulseEvaluator(_TransformEvaluator):
+ """TransformEvaluator for Impulse transform."""
+
+ def finish_bundle(self):
+ assert len(self._outputs) == 1
+ output_pcollection = list(self._outputs)[0]
+ bundle = self._evaluation_context.create_bundle(output_pcollection)
+ bundle.output(GlobalWindows.windowed_value(b''))
+ return TransformResult(self, [bundle], [], None, None)
+
+
class _TaggedReceivers(dict):
"""Received ParDo output and redirect to the associated output bundle."""
@@ -811,15 +823,17 @@
k = self.key_coder.decode(encoded_k)
state = self._step_context.get_keyed_state(encoded_k)
+ watermarks = self._evaluation_context._watermark_manager.get_watermarks(
+ self._applied_ptransform)
for timer_firing in timer_firings:
for wvalue in self.driver.process_timer(
timer_firing.window, timer_firing.name, timer_firing.time_domain,
- timer_firing.timestamp, state):
+ timer_firing.timestamp, state, watermarks.input_watermark):
self.gabw_items.append(wvalue.with_value((k, wvalue.value)))
- watermark = self._evaluation_context._watermark_manager.get_watermarks(
- self._applied_ptransform).output_watermark
if vs:
- for wvalue in self.driver.process_elements(state, vs, watermark):
+ for wvalue in self.driver.process_elements(state, vs,
+ watermarks.output_watermark,
+ watermarks.input_watermark):
self.gabw_items.append(wvalue.with_value((k, wvalue.value)))
self.keyed_holds[encoded_k] = state.get_earliest_hold()
@@ -903,7 +917,7 @@
# Maximum number of elements that will be produced by a Splittable DoFn before
# a checkpoint is requested by the runner.
- DEFAULT_MAX_NUM_OUTPUTS = 100
+ DEFAULT_MAX_NUM_OUTPUTS = None
# Maximum duration a Splittable DoFn will process an element before a
# checkpoint is requested by the runner.
DEFAULT_MAX_DURATION = 1
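The new _ImpulseEvaluator gives the direct runner a primitive Impulse: exactly one empty-bytes element in the global window, from which sources and other fan-outs are derived. For example:

    import apache_beam as beam

    with beam.Pipeline() as p:
      # The single b'' impulse element seeds the downstream fan-out.
      _ = (p
           | beam.Impulse()
           | beam.FlatMap(lambda _: range(3))
           | beam.Map(print))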
diff --git a/sdks/python/apache_beam/runners/direct/util.py b/sdks/python/apache_beam/runners/direct/util.py
index 407ea39..57650ac 100644
--- a/sdks/python/apache_beam/runners/direct/util.py
+++ b/sdks/python/apache_beam/runners/direct/util.py
@@ -64,9 +64,8 @@
self.timestamp = timestamp
def __repr__(self):
- return 'TimerFiring(%r, %r, %s, %s)' % (self.encoded_key,
- self.name, self.time_domain,
- self.timestamp)
+ return 'TimerFiring({!r}, {!r}, {}, {})'.format(
+ self.encoded_key, self.name, self.time_domain, self.timestamp)
class KeyedWorkItem(object):
@@ -75,3 +74,7 @@
self.encoded_key = encoded_key
self.timer_firings = timer_firings or []
self.elements = elements or []
+
+ def __repr__(self):
+ return 'KeyedWorkItem({!r}, {}, {})'.format(
+ self.encoded_key, self.timer_firings, self.elements)
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_analyzer_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_analyzer_test.py
index e860226..b0433ff 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_analyzer_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_analyzer_test.py
@@ -226,7 +226,7 @@
pipeline_proto = to_stable_runner_api(p)
pipeline_info = pipeline_analyzer.PipelineInfo(pipeline_proto.components)
- pcoll_id = 'ref_PCollection_PCollection_3' # Output PCollection of Square
+ pcoll_id = 'ref_PCollection_PCollection_12' # Output PCollection of Square
cache_label1 = pipeline_info.cache_label(pcoll_id)
analyzer = pipeline_analyzer.PipelineAnalyzer(self.cache_manager,
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
index af9b5c5..cd06b42 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
@@ -26,6 +26,7 @@
import apache_beam as beam
from apache_beam.io.gcp.pubsub import ReadFromPubSub
from apache_beam.pipeline import PipelineVisitor
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_environment as ie
@@ -73,13 +74,20 @@
pipeline.to_runner_api(use_fake_coders=True),
pipeline.runner,
options)
+
+ self._background_caching_pipeline = beam.pipeline.Pipeline.from_runner_api(
+ pipeline.to_runner_api(use_fake_coders=True),
+ pipeline.runner,
+ options)
+
# Snapshot of original pipeline information.
(self._original_pipeline_proto,
self._original_context) = self._pipeline_snap.to_runner_api(
return_context=True, use_fake_coders=True)
# All compute-once-against-original-pipeline fields.
- self._has_unbounded_source = has_unbounded_source(self._pipeline_snap)
+ self._unbounded_sources = unbounded_sources(
+ self._background_caching_pipeline)
# TODO(BEAM-7760): once cache scope changed, this is not needed to manage
# relationships across pipelines, runners, and jobs.
self._pcolls_to_pcoll_id = pcolls_to_pcoll_id(self._pipeline_snap,
@@ -103,11 +111,131 @@
"""Always returns a new instance of portable instrumented proto."""
return self._pipeline.to_runner_api(use_fake_coders=True)
+ def _required_components(self, pipeline_proto, required_transforms_ids):
+ """Returns the components and subcomponents of the given transforms.
+
+ This method returns all the components (transforms, PCollections, coders,
+ and windowing strategies) related to the given transforms and to all of
+ their subtransforms, collected recursively.
+ """
+ transforms = pipeline_proto.components.transforms
+ pcollections = pipeline_proto.components.pcollections
+ coders = pipeline_proto.components.coders
+ windowing_strategies = pipeline_proto.components.windowing_strategies
+
+ # Cache the transforms that will be copied into the new pipeline proto.
+ required_transforms = {k: transforms[k] for k in required_transforms_ids}
+
+ # Cache all the output PCollections of the transforms.
+ pcollection_ids = [pc for t in required_transforms.values()
+ for pc in t.outputs.values()]
+ required_pcollections = {pc_id: pcollections[pc_id]
+ for pc_id in pcollection_ids}
+
+ # Cache all the PCollection coders.
+ coder_ids = [pc.coder_id for pc in required_pcollections.values()]
+ required_coders = {c_id: coders[c_id] for c_id in coder_ids}
+
+ # Cache all the windowing strategy ids.
+ windowing_strategies_ids = [pc.windowing_strategy_id
+ for pc in required_pcollections.values()]
+ required_windowing_strategies = {ws_id: windowing_strategies[ws_id]
+ for ws_id in windowing_strategies_ids}
+
+ subtransforms = {}
+ subpcollections = {}
+ subcoders = {}
+ subwindowing_strategies = {}
+
+ # Recursively go through all the subtransforms and add their components.
+ for transform_id, transform in required_transforms.items():
+ if transform_id in pipeline_proto.root_transform_ids:
+ continue
+ (t, pc, c, ws) = self._required_components(pipeline_proto,
+ transform.subtransforms)
+ subtransforms.update(t)
+ subpcollections.update(pc)
+ subcoders.update(c)
+ subwindowing_strategies.update(ws)
+
+ # Now that we have all the components and their subcomponents, return the
+ # complete collection.
+ required_transforms.update(subtransforms)
+ required_pcollections.update(subpcollections)
+ required_coders.update(subcoders)
+ required_windowing_strategies.update(subwindowing_strategies)
+
+ return (required_transforms, required_pcollections, required_coders,
+ required_windowing_strategies)
+
+ def background_caching_pipeline_proto(self):
+ """Returns the background caching pipeline.
+
+ This method creates a background caching pipeline by adding writes to cache
+ from each unbounded source (done in the instrument method) and cutting out
+ all components (transforms, PCollections, coders, windowing strategies) that
+ are neither the unbounded sources nor writes to cache (nor subtransforms
+ thereof).
+ """
+ # Create the pipeline_proto to read all the components from. A new pipeline
+ # proto will later be created from the cut-out components.
+ pipeline_proto = self._background_caching_pipeline.to_runner_api(
+ return_context=False, use_fake_coders=True)
+
+ # Get all the sources we want to cache.
+ sources = unbounded_sources(self._background_caching_pipeline)
+
+ # Get all the root transforms. The caching transforms will be subtransforms
+ # of one of these roots.
+ roots = [root for root in pipeline_proto.root_transform_ids]
+
+ # Get the transform IDs of the caching transforms. These caching operations
+ # are added to the _background_caching_pipeline in the instrument() method.
+ # They are added there so that multiple calls to this method won't add
+ # duplicate caching operations (idempotent).
+ transforms = pipeline_proto.components.transforms
+ caching_transform_ids = [t_id for root in roots
+ for t_id in transforms[root].subtransforms
+ if WRITE_CACHE in t_id]
+
+ # Get the IDs of the unbounded sources.
+ required_transform_labels = [src.full_label for src in sources]
+ unbounded_source_ids = [k for k, v in transforms.items()
+ if v.unique_name in required_transform_labels]
+
+ # The required transforms are the transforms that we want to cut out of
+ # the pipeline_proto and insert into a new pipeline to return.
+ required_transform_ids = (roots + caching_transform_ids +
+ unbounded_source_ids)
+ (t, p, c, w) = self._required_components(pipeline_proto,
+ required_transform_ids)
+
+ def set_proto_map(proto_map, new_value):
+ proto_map.clear()
+ for key, value in new_value.items():
+ proto_map[key].CopyFrom(value)
+
+ # Copy the transforms into the new pipeline.
+ pipeline_to_execute = beam_runner_api_pb2.Pipeline()
+ pipeline_to_execute.root_transform_ids[:] = roots
+ set_proto_map(pipeline_to_execute.components.transforms, t)
+ set_proto_map(pipeline_to_execute.components.pcollections, p)
+ set_proto_map(pipeline_to_execute.components.coders, c)
+ set_proto_map(pipeline_to_execute.components.windowing_strategies, w)
+
+ # Cut out all subtransforms in the root that aren't the required transforms.
+ for root_id in roots:
+ root = pipeline_to_execute.components.transforms[root_id]
+ root.subtransforms[:] = [
+ transform_id for transform_id in root.subtransforms
+ if transform_id in pipeline_to_execute.components.transforms]
+
+ return pipeline_to_execute
+
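The set_proto_map helper above exists because protobuf map fields cannot be assigned wholesale; entries must be copied one at a time. A standalone sketch of the idiom, using the Beam pipeline proto (assuming apache_beam is installed):

    from apache_beam.portability.api import beam_runner_api_pb2

    src = beam_runner_api_pb2.Pipeline()
    src.components.transforms['t1'].unique_name = 'MyTransform'

    dst = beam_runner_api_pb2.Pipeline()
    # Direct assignment to dst.components.transforms raises; copy entry by
    # entry instead, as set_proto_map does.
    for key, value in src.components.transforms.items():
      dst.components.transforms[key].CopyFrom(value)
    assert dst.components.transforms['t1'].unique_name == 'MyTransform'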
@property
- def has_unbounded_source(self):
+ def has_unbounded_sources(self):
"""Returns whether the pipeline has any `REPLACEABLE_UNBOUNDED_SOURCES`.
"""
- return self._has_unbounded_source
+ return len(self._unbounded_sources) > 0
@property
def cacheables(self):
@@ -165,12 +293,19 @@
self._pipeline.visit(v)
# Create ReadCache transforms.
for cacheable_input in cacheable_inputs:
- self._read_cache(cacheable_input)
+ self._read_cache(self._pipeline, cacheable_input)
# Replace/wire inputs w/ cached PCollections from ReadCache transforms.
- self._replace_with_cached_inputs()
+ self._replace_with_cached_inputs(self._pipeline)
# Write cache for all cacheables.
for _, cacheable in self.cacheables.items():
- self._write_cache(cacheable['pcoll'])
+ self._write_cache(self._pipeline, cacheable['pcoll'])
+
+ # Instrument the background caching pipeline if we can.
+ if self.has_unbounded_sources:
+ for source in self._unbounded_sources:
+ self._write_cache(self._background_caching_pipeline,
+ source.outputs[None])
+
# TODO(BEAM-7760): prune sub graphs that don't need to be executed.
def preprocess(self):
@@ -207,7 +342,7 @@
v = PreprocessVisitor(self)
self._pipeline.visit(v)
- def _write_cache(self, pcoll):
+ def _write_cache(self, pipeline, pcoll):
"""Caches a cacheable PCollection.
For the given PCollection, by appending sub transform part that materialize
@@ -218,29 +353,29 @@
the pipeline being instrumented and the keyed cache is absent.
Modifies:
- self._pipeline
+ pipeline
"""
# Makes sure the pcoll belongs to the pipeline being instrumented.
- if pcoll.pipeline is not self._pipeline:
+ if pcoll.pipeline is not pipeline:
return
# The keyed cache is always valid within this instrumentation.
key = self.cache_key(pcoll)
# Only need to write when the cache with expected key doesn't exist.
if not self._cache_manager.exists('full', key):
- _ = pcoll | '{}{}'.format(WRITE_CACHE, key) >> cache.WriteCache(
- self._cache_manager, key)
+ label = '{}{}'.format(WRITE_CACHE, key)
+ _ = pcoll | label >> cache.WriteCache(self._cache_manager, key)
- def _read_cache(self, pcoll):
+ def _read_cache(self, pipeline, pcoll):
"""Reads a cached pvalue.
A noop will cause the pipeline to execute the transform as
it is and cache nothing from this transform for next run.
Modifies:
- self._pipeline
+ pipeline
"""
# Makes sure the pcoll belongs to the pipeline being instrumented.
- if pcoll.pipeline is not self._pipeline:
+ if pcoll.pipeline is not pipeline:
return
# The keyed cache is always valid within this instrumentation.
key = self.cache_key(pcoll)
@@ -250,13 +385,13 @@
# Mutates the pipeline with cache read transform attached
# to root of the pipeline.
pcoll_from_cache = (
- self._pipeline
+ pipeline
| '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
self._cache_manager, key))
self._cached_pcoll_read[key] = pcoll_from_cache
# else: NOOP when cache doesn't exist, just compute the original graph.
- def _replace_with_cached_inputs(self):
+ def _replace_with_cached_inputs(self, pipeline):
"""Replace PCollection inputs in the pipeline with cache if possible.
For any input PCollection, find out whether there is valid cache. If so,
@@ -287,7 +422,7 @@
transform_node.inputs = tuple(input_list)
v = ReadCacheWireVisitor(self)
- self._pipeline.visit(v)
+ pipeline.visit(v)
def _cacheable_inputs(self, transform):
inputs = set()
@@ -374,9 +509,10 @@
cacheable['version'] = str(id(val))
cacheable['pcoll'] = val
cacheable['producer_version'] = str(id(val.producer))
- cacheables[cacheable_key(val, pcolls_to_pcoll_id)] = cacheable
pcoll_version_map[cacheable['pcoll_id']] = cacheable['version']
+ cacheables[cacheable_key(val, pcolls_to_pcoll_id)] = cacheable
cacheable_var_by_pcoll_id[cacheable['pcoll_id']] = key
+
return pcoll_version_map, cacheables, cacheable_var_by_pcoll_id
@@ -390,8 +526,13 @@
return '_'.join((pcoll_version, pcoll_id))
-def has_unbounded_source(pipeline):
+def has_unbounded_sources(pipeline):
"""Checks if a given pipeline has replaceable unbounded sources."""
+ return len(unbounded_sources(pipeline)) > 0
+
+
+def unbounded_sources(pipeline):
+ """Returns a pipeline's replaceable unbounded sources."""
class CheckUnboundednessVisitor(PipelineVisitor):
"""Visitor checks if there are any unbounded read sources in the Pipeline.
@@ -401,18 +542,18 @@
"""
def __init__(self):
- self.has_unbounded_source = False
+ self.unbounded_sources = []
def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)
def visit_transform(self, transform_node):
- self.has_unbounded_source |= isinstance(transform_node.transform,
- REPLACEABLE_UNBOUNDED_SOURCES)
+ if isinstance(transform_node.transform, REPLACEABLE_UNBOUNDED_SOURCES):
+ self.unbounded_sources.append(transform_node)
v = CheckUnboundednessVisitor()
pipeline.visit(v)
- return v.has_unbounded_source
+ return v.unbounded_sources
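unbounded_sources relies on Beam's PipelineVisitor protocol. A minimal standalone sketch of the same traversal pattern, counting leaf transforms instead of collecting sources:

    import apache_beam as beam
    from apache_beam.pipeline import PipelineVisitor

    class CountingVisitor(PipelineVisitor):
      def __init__(self):
        self.count = 0

      def visit_transform(self, transform_node):
        # Called once per leaf transform during pipeline.visit().
        self.count += 1

    p = beam.Pipeline()
    _ = p | beam.Impulse() | beam.Map(lambda x: x)
    v = CountingVisitor()
    p.visit(v)
    print(v.count)  # number of leaf transforms visited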
def pcolls_to_pcoll_id(pipeline, original_context):
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index 3d9a611..c45b8e3 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -44,10 +44,8 @@
def setUp(self):
ie.new_env(cache_manager=cache.FileBasedCacheManager())
- def assertPipelineEqual(self, actual_pipeline, expected_pipeline):
- actual_pipeline_proto = actual_pipeline.to_runner_api(use_fake_coders=True)
- expected_pipeline_proto = expected_pipeline.to_runner_api(
- use_fake_coders=True)
+ def assertPipelineProtoEqual(self, actual_pipeline_proto,
+ expected_pipeline_proto):
components1 = actual_pipeline_proto.components
components2 = expected_pipeline_proto.components
self.assertEqual(len(components1.transforms), len(components2.transforms))
@@ -64,6 +62,13 @@
expected_pipeline_proto,
expected_pipeline_proto.root_transform_ids[0])
+ def assertPipelineEqual(self, actual_pipeline, expected_pipeline):
+ actual_pipeline_proto = actual_pipeline.to_runner_api(use_fake_coders=True)
+ expected_pipeline_proto = expected_pipeline.to_runner_api(
+ use_fake_coders=True)
+ self.assertPipelineProtoEqual(actual_pipeline_proto,
+ expected_pipeline_proto)
+
def assertTransformEqual(self, actual_pipeline_proto, actual_transform_id,
expected_pipeline_proto, expected_transform_id):
transform_proto1 = actual_pipeline_proto.components.transforms[
@@ -83,7 +88,7 @@
def test_pcolls_to_pcoll_id(self):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
# pylint: disable=range-builtin-not-iterating
- init_pcoll = p | 'Init Create' >> beam.Create(range(10))
+ init_pcoll = p | 'Init Create' >> beam.Impulse()
_, ctx = p.to_runner_api(use_fake_coders=True, return_context=True)
self.assertEqual(instr.pcolls_to_pcoll_id(p, ctx), {
str(init_pcoll): 'ref_PCollection_PCollection_1'})
@@ -95,7 +100,7 @@
_, ctx = p.to_runner_api(use_fake_coders=True, return_context=True)
self.assertEqual(
instr.cacheable_key(init_pcoll, instr.pcolls_to_pcoll_id(p, ctx)),
- str(id(init_pcoll)) + '_ref_PCollection_PCollection_1')
+ str(id(init_pcoll)) + '_ref_PCollection_PCollection_10')
def test_cacheable_key_with_version_map(self):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -118,8 +123,8 @@
# init_pcoll_2 is supplied as long as the version map is given.
self.assertEqual(
instr.cacheable_key(init_pcoll_2, instr.pcolls_to_pcoll_id(p2, ctx), {
- 'ref_PCollection_PCollection_1': str(id(init_pcoll))}),
- str(id(init_pcoll)) + '_ref_PCollection_PCollection_1')
+ 'ref_PCollection_PCollection_10': str(id(init_pcoll))}),
+ str(id(init_pcoll)) + '_ref_PCollection_PCollection_10')
def test_cache_key(self):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -132,13 +137,13 @@
pin = instr.pin(p)
self.assertEqual(pin.cache_key(init_pcoll), 'init_pcoll_' + str(
- id(init_pcoll)) + '_ref_PCollection_PCollection_1_' + str(id(
+ id(init_pcoll)) + '_ref_PCollection_PCollection_10_' + str(id(
init_pcoll.producer)))
self.assertEqual(pin.cache_key(squares), 'squares_' + str(
- id(squares)) + '_ref_PCollection_PCollection_2_' + str(id(
+ id(squares)) + '_ref_PCollection_PCollection_11_' + str(id(
squares.producer)))
self.assertEqual(pin.cache_key(cubes), 'cubes_' + str(
- id(cubes)) + '_ref_PCollection_PCollection_3_' + str(id(
+ id(cubes)) + '_ref_PCollection_PCollection_12_' + str(id(
cubes.producer)))
def test_cacheables(self):
@@ -154,21 +159,21 @@
pin._cacheable_key(init_pcoll): {
'var': 'init_pcoll',
'version': str(id(init_pcoll)),
- 'pcoll_id': 'ref_PCollection_PCollection_1',
+ 'pcoll_id': 'ref_PCollection_PCollection_10',
'producer_version': str(id(init_pcoll.producer)),
'pcoll': init_pcoll
},
pin._cacheable_key(squares): {
'var': 'squares',
'version': str(id(squares)),
- 'pcoll_id': 'ref_PCollection_PCollection_2',
+ 'pcoll_id': 'ref_PCollection_PCollection_11',
'producer_version': str(id(squares.producer)),
'pcoll': squares
},
pin._cacheable_key(cubes): {
'var': 'cubes',
'version': str(id(cubes)),
- 'pcoll_id': 'ref_PCollection_PCollection_3',
+ 'pcoll_id': 'ref_PCollection_PCollection_12',
'producer_version': str(id(cubes.producer)),
'pcoll': cubes
}
@@ -178,14 +183,48 @@
p = beam.Pipeline(interactive_runner.InteractiveRunner())
_ = p | 'ReadUnboundedSource' >> beam.io.ReadFromPubSub(
subscription='projects/fake-project/subscriptions/fake_sub')
- self.assertTrue(instr.has_unbounded_source(p))
+ self.assertTrue(instr.has_unbounded_sources(p))
def test_not_has_unbounded_source(self):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(b'test')
_ = p | 'ReadBoundedSource' >> beam.io.ReadFromText(f.name)
- self.assertFalse(instr.has_unbounded_source(p))
+ self.assertFalse(instr.has_unbounded_sources(p))
+
+ def test_background_caching_pipeline_proto(self):
+ p = beam.Pipeline(interactive_runner.InteractiveRunner())
+
+ # Test that the two ReadFromPubSub are correctly cut out.
+ a = p | 'ReadUnboundedSourceA' >> beam.io.ReadFromPubSub(
+ subscription='projects/fake-project/subscriptions/fake_sub')
+ b = p | 'ReadUnboundedSourceB' >> beam.io.ReadFromPubSub(
+ subscription='projects/fake-project/subscriptions/fake_sub')
+
+ # Add some extra PTransform afterwards to make sure that only the unbounded
+ # sources remain.
+ c = (a, b) | beam.CoGroupByKey()
+ _ = c | beam.Map(lambda x: x)
+
+ ib.watch(locals())
+ instrumenter = instr.pin(p)
+ actual_pipeline = instrumenter.background_caching_pipeline_proto()
+
+ # Now recreate the expected pipeline, which should only have the unbounded
+ # sources.
+ p = beam.Pipeline(interactive_runner.InteractiveRunner())
+ a = p | 'ReadUnboundedSourceA' >> beam.io.ReadFromPubSub(
+ subscription='projects/fake-project/subscriptions/fake_sub')
+ _ = a | 'a' >> cache.WriteCache(ie.current_env().cache_manager(), '')
+
+ b = p | 'ReadUnboundedSourceB' >> beam.io.ReadFromPubSub(
+ subscription='projects/fake-project/subscriptions/fake_sub')
+ _ = b | 'b' >> cache.WriteCache(ie.current_env().cache_manager(), '')
+
+ expected_pipeline = p.to_runner_api(return_context=False,
+ use_fake_coders=True)
+
+ self.assertPipelineProtoEqual(actual_pipeline, expected_pipeline)
def _example_pipeline(self, watch=True):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -244,11 +283,11 @@
# Mock as if cacheable PCollections are cached.
init_pcoll_cache_key = 'init_pcoll_' + str(
- id(init_pcoll)) + '_ref_PCollection_PCollection_1_' + str(id(
+ id(init_pcoll)) + '_ref_PCollection_PCollection_10_' + str(id(
init_pcoll.producer))
self._mock_write_cache(init_pcoll, init_pcoll_cache_key)
second_pcoll_cache_key = 'second_pcoll_' + str(
- id(second_pcoll)) + '_ref_PCollection_PCollection_2_' + str(id(
+ id(second_pcoll)) + '_ref_PCollection_PCollection_11_' + str(id(
second_pcoll.producer))
self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
ie.current_env().cache_manager().exists = MagicMock(return_value=True)
diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py b/sdks/python/apache_beam/runners/portability/artifact_service.py
index 100eca5..1ba9602 100644
--- a/sdks/python/apache_beam/runners/portability/artifact_service.py
+++ b/sdks/python/apache_beam/runners/portability/artifact_service.py
@@ -146,12 +146,12 @@
Writing to zip files requires Python 3.6+.
"""
- def __init__(self, path, chunk_size=None):
+ def __init__(self, path, internal_root, chunk_size=None):
if sys.version_info < (3, 6):
raise RuntimeError(
'Writing to zip files requires Python 3.6+, '
'but current version is %s' % sys.version)
- super(ZipFileArtifactService, self).__init__('', chunk_size)
+ super(ZipFileArtifactService, self).__init__(internal_root, chunk_size)
self._zipfile = zipfile.ZipFile(path, 'a')
self._lock = threading.Lock()
@@ -172,6 +172,10 @@
pass
def _open(self, path, mode):
+ if path.startswith('/'):
+ raise ValueError(
+ 'ZIP file entry %s invalid: '
+ 'path must not contain a leading slash.' % path)
return self._zipfile.open(path, mode, force_zip64=True)
def PutArtifact(self, request_iterator, context=None):
diff --git a/sdks/python/apache_beam/runners/portability/artifact_service_test.py b/sdks/python/apache_beam/runners/portability/artifact_service_test.py
index f5da724..6efb60d 100644
--- a/sdks/python/apache_beam/runners/portability/artifact_service_test.py
+++ b/sdks/python/apache_beam/runners/portability/artifact_service_test.py
@@ -219,7 +219,7 @@
class ZipFileArtifactServiceTest(AbstractArtifactServiceTest):
def create_service(self, staging_dir):
return artifact_service.ZipFileArtifactService(
- os.path.join(staging_dir, 'test.zip'), chunk_size=10)
+ os.path.join(staging_dir, 'test.zip'), 'root', chunk_size=10)
class BeamFilesystemArtifactServiceTest(AbstractArtifactServiceTest):
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
index c6348cf..b318971 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
@@ -105,6 +105,7 @@
[PIPELINE_FOLDER, PIPELINE_NAME, 'pipeline-options.json'])
ARTIFACT_MANIFEST_PATH = '/'.join(
[PIPELINE_FOLDER, PIPELINE_NAME, 'artifact-manifest.json'])
+ ARTIFACT_FOLDER = '/'.join([PIPELINE_FOLDER, PIPELINE_NAME, 'artifacts'])
def __init__(
self, master_url, executable_jar, job_id, job_name, pipeline, options,
@@ -134,7 +135,7 @@
def _start_artifact_service(self, jar, requested_port):
self._artifact_staging_service = artifact_service.ZipFileArtifactService(
- jar)
+ jar, self.ARTIFACT_FOLDER)
self._artifact_staging_server = grpc.server(futures.ThreadPoolExecutor())
port = self._artifact_staging_server.add_insecure_port(
'[::]:%s' % requested_port)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index ea9d02b..71343e5 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -385,9 +385,6 @@
self._profiler_factory = profiler.Profile.factory_from_options(
options.view_as(pipeline_options.ProfilingOptions))
- if 'use_sdf_bounded_source' in experiments:
- pipeline.replace_all(DataflowRunner._SDF_PTRANSFORM_OVERRIDES)
-
self._latest_run_result = self.run_via_runner_api(pipeline.to_runner_api(
default_environment=self._default_environment))
return self._latest_run_result
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 23480ce..ef09b1f 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -271,8 +271,10 @@
('B', 'b', 3)]
with self.create_pipeline() as p:
- assert_that(p | beam.Create(inputs) | beam.ParDo(AddIndex()),
- equal_to(expected))
+ # TODO(BEAM-8893): Allow the reshuffle.
+ assert_that(
+ p | beam.Create(inputs, reshuffle=False) | beam.ParDo(AddIndex()),
+ equal_to(expected))
@unittest.skip('TestStream not yet supported')
def test_teststream_pardo_timers(self):
@@ -417,7 +419,8 @@
with self.create_pipeline() as p:
actual = (
p
- | beam.Create(elements)
+ # TODO(BEAM-8893): Allow the reshuffle.
+ | beam.Create(elements, reshuffle=False)
# Send even and odd elements to different windows.
| beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2))
| beam.WindowInto(window.FixedWindows(1) if windowed
@@ -777,8 +780,10 @@
self, monitoring_infos, urn, labels, value=None, ge_value=None):
# TODO(ajamato): Consider adding a matcher framework
found = 0
+ matches = []
for mi in monitoring_infos:
if has_urn_and_labels(mi, urn, labels):
+ matches.append(mi.metric.counter_data.int64_value)
if ge_value is not None:
if mi.metric.counter_data.int64_value >= ge_value:
found = found + 1
@@ -790,8 +795,8 @@
ge_value_str = {'ge_value' : ge_value} if ge_value else ''
value_str = {'value' : value} if value else ''
self.assertEqual(
- 1, found, "Found (%s) Expected only 1 monitoring_info for %s." %
- (found, (urn, labels, value_str, ge_value_str),))
+ 1, found, "Found (%s, %s) Expected only 1 monitoring_info for %s." %
+ (found, matches, (urn, labels, value_str, ge_value_str),))
def assert_has_distribution(
self, monitoring_infos, urn, labels,
@@ -833,10 +838,7 @@
(found, (urn, labels, str(description)),))
def create_pipeline(self):
- p = beam.Pipeline(runner=fn_api_runner.FnApiRunner())
- # TODO(BEAM-8448): Fix these tests.
- p.options.view_as(DebugOptions).experiments.remove('beam_fn_api')
- return p
+ return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
def test_element_count_metrics(self):
class GenerateTwoOutputs(beam.DoFn):
@@ -854,7 +856,8 @@
# Produce enough elements to make sure byte sampling occurs.
num_source_elems = 100
- pcoll = p | beam.Create(['a%d' % i for i in range(num_source_elems)])
+ pcoll = p | beam.Create(
+ ['a%d' % i for i in range(num_source_elems)], reshuffle=False)
# pylint: disable=expression-not-assigned
pardo = ('StepThatDoesTwoOutputs' >> beam.ParDo(
@@ -883,13 +886,14 @@
and
monitoring_infos.PCOLLECTION_LABEL not in x.labels])
try:
- labels = {monitoring_infos.PCOLLECTION_LABEL : 'Impulse'}
+ labels = {
+ monitoring_infos.PCOLLECTION_LABEL : 'ref_PCollection_PCollection_1'}
self.assert_has_counter(
counters, monitoring_infos.ELEMENT_COUNT_URN, labels, 1)
- # Create/Read, "out" output.
+ # Create output.
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_1'}
+ 'ref_PCollection_PCollection_3'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems)
@@ -902,7 +906,7 @@
# GenerateTwoOutputs, main output.
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_2'}
+ 'ref_PCollection_PCollection_4'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems)
@@ -915,7 +919,7 @@
# GenerateTwoOutputs, "SecondOutput" output.
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_3'}
+ 'ref_PCollection_PCollection_5'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, 2 * num_source_elems)
@@ -928,7 +932,7 @@
# GenerateTwoOutputs, "ThirdOutput" output.
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_4'}
+ 'ref_PCollection_PCollection_6'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems)
@@ -943,7 +947,7 @@
# outputs.
# Flatten/Read, main output.
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_5'}
+ 'ref_PCollection_PCollection_7'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, 4 * num_source_elems)
@@ -956,7 +960,7 @@
# PassThrough, main output
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_6'}
+ 'ref_PCollection_PCollection_8'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, 4 * num_source_elems)
@@ -969,7 +973,7 @@
# PassThrough2, main output
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_7'}
+ 'ref_PCollection_PCollection_9'}
self.assert_has_counter(
counters,
monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems)
@@ -1014,7 +1018,8 @@
namespace = split[0]
name = ':'.join(split[1:])
assert_counter_exists(
- all_metrics_via_montoring_infos, namespace, name, step='Create/Read')
+ all_metrics_via_montoring_infos, namespace, name,
+ step='Create/Impulse')
assert_counter_exists(
all_metrics_via_montoring_infos, namespace, name, step='MyStep')
@@ -1027,7 +1032,8 @@
p = self.create_pipeline()
_ = (p
- | beam.Create([0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS])
+ | beam.Create(
+ [0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS], reshuffle=False)
| beam.Map(time.sleep)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
@@ -1051,13 +1057,13 @@
# Test the DEPRECATED legacy metrics
pregbk_metrics, postgbk_metrics = list(
res._metrics_by_stage.values())
- if 'Create/Read' not in pregbk_metrics.ptransforms:
+ if 'Create/Map(decode)' not in pregbk_metrics.ptransforms:
# The metrics above are actually unordered. Swap.
pregbk_metrics, postgbk_metrics = postgbk_metrics, pregbk_metrics
self.assertEqual(
4,
- pregbk_metrics.ptransforms['Create/Read']
- .processed_elements.measured.output_element_counts['out'])
+ pregbk_metrics.ptransforms['Create/Map(decode)']
+ .processed_elements.measured.output_element_counts['None'])
self.assertEqual(
4,
pregbk_metrics.ptransforms['Map(sleep)']
@@ -1089,20 +1095,20 @@
self.assertEqual(2, len(res._monitoring_infos_by_stage))
pregbk_mis, postgbk_mis = list(res._monitoring_infos_by_stage.values())
- if not has_mi_for_ptransform(pregbk_mis, 'Create/Read'):
+ if not has_mi_for_ptransform(pregbk_mis, 'Create/Map(decode)'):
# The monitoring infos above are actually unordered. Swap.
pregbk_mis, postgbk_mis = postgbk_mis, pregbk_mis
# pregbk monitoring infos
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_1'}
+ 'ref_PCollection_PCollection_3'}
self.assert_has_counter(
pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4)
self.assert_has_distribution(
pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_2'}
+ 'ref_PCollection_PCollection_4'}
self.assert_has_counter(
pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4)
self.assert_has_distribution(
@@ -1115,14 +1121,14 @@
# postgbk monitoring infos
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_6'}
+ 'ref_PCollection_PCollection_8'}
self.assert_has_counter(
postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=1)
self.assert_has_distribution(
postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels)
labels = {monitoring_infos.PCOLLECTION_LABEL :
- 'ref_PCollection_PCollection_7'}
+ 'ref_PCollection_PCollection_9'}
self.assert_has_counter(
postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=5)
self.assert_has_distribution(
@@ -1406,7 +1412,7 @@
with self.create_pipeline() as p:
grouped = (
p
- | beam.Create(elements)
+ | beam.Create(elements, reshuffle=False)
| 'SDF' >> beam.ParDo(EnumerateSdf()))
flat = grouped | beam.FlatMap(lambda x: x)
assert_that(flat, equal_to(expected))
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py
index b97f683..2bbafbb 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service.py
@@ -98,16 +98,36 @@
provision_info,
self._artifact_staging_endpoint)
+ def get_bind_address(self):
+ """Return the address used to open the port on the gRPC server.
+
+ This is often, but not always, the same as the service address. For
+ example, to make the service accessible to external machines, override this
+ to return '[::]' and override `get_service_address()` to return a publicly
+ accessible host name.
+ """
+ return self.get_service_address()
+
+ def get_service_address(self):
+ """Return the host name at which this server will be accessible.
+
+ In particular, this is provided to the client upon connection as the
+ artifact staging endpoint.
+ """
+ return 'localhost'
+
def start_grpc_server(self, port=0):
self._server = grpc.server(UnboundedThreadPoolExecutor())
- port = self._server.add_insecure_port('localhost:%d' % port)
+ port = self._server.add_insecure_port(
+ '%s:%d' % (self.get_bind_address(), port))
beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
self._artifact_service, self._server)
+ hostname = self.get_service_address()
self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
- url='localhost:%d' % port)
+ url='%s:%d' % (hostname, port))
self._server.start()
- _LOGGER.info('Grpc server started on port %s', port)
+ _LOGGER.info('Grpc server started at %s on port %d', hostname, port)
return port
def stop(self, timeout=1):
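A hypothetical subclass showing how the two new hooks compose, assuming LocalJobServicer is the enclosing class in this module (the host name is illustrative):

    from apache_beam.runners.portability.local_job_service import LocalJobServicer

    class PublicJobServicer(LocalJobServicer):
      """Job service reachable from other machines."""

      def get_bind_address(self):
        # Bind on all interfaces so external machines can connect.
        return '[::]'

      def get_service_address(self):
        # Advertise a resolvable host name that clients use as the
        # artifact staging endpoint.
        return 'beam-job-server.example.com'  # assumed host name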
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index c62f194..fd2528d 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -978,7 +978,7 @@
transform_id)
output_coder = factory.get_only_output_coder(transform_proto)
return DataInputOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
consumers,
factory.counter_factory,
@@ -1000,7 +1000,7 @@
transform_id)
output_coder = factory.get_only_input_coder(transform_proto)
return DataOutputOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
transform_proto.unique_name,
consumers,
factory.counter_factory,
@@ -1019,7 +1019,7 @@
[factory.get_only_output_coder(transform_proto)])
return factory.augment_oldstyle_op(
operations.ReadOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler),
@@ -1036,7 +1036,7 @@
[WindowedValueCoder(source.default_output_coder())])
return factory.augment_oldstyle_op(
operations.ReadOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler),
@@ -1048,7 +1048,7 @@
python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
def create(factory, transform_id, transform_proto, parameter, consumers):
return operations.ImpulseReadOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
factory.counter_factory,
factory.state_sampler,
consumers,
@@ -1227,7 +1227,7 @@
result = factory.augment_oldstyle_op(
operation_cls(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
spec,
factory.counter_factory,
factory.state_sampler,
@@ -1276,7 +1276,7 @@
def create(factory, transform_id, transform_proto, unused_parameter, consumers):
return factory.augment_oldstyle_op(
operations.FlattenOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerFlatten(
None, [factory.get_only_output_coder(transform_proto)]),
factory.counter_factory,
@@ -1294,7 +1294,7 @@
[], {}))
return factory.augment_oldstyle_op(
operations.PGBKCVOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerPartialGroupByKey(
serialized_combine_fn,
None,
@@ -1310,7 +1310,7 @@
beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
return _create_combine_phase_operation(
- factory, transform_proto, payload, consumers, 'merge')
+ factory, transform_id, transform_proto, payload, consumers, 'merge')
@BeamTransformFactory.register_urn(
@@ -1318,7 +1318,7 @@
beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
return _create_combine_phase_operation(
- factory, transform_proto, payload, consumers, 'extract')
+ factory, transform_id, transform_proto, payload, consumers, 'extract')
@BeamTransformFactory.register_urn(
@@ -1326,17 +1326,17 @@
beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
return _create_combine_phase_operation(
- factory, transform_proto, payload, consumers, 'all')
+ factory, transform_id, transform_proto, payload, consumers, 'all')
def _create_combine_phase_operation(
- factory, transform_proto, payload, consumers, phase):
+ factory, transform_id, transform_proto, payload, consumers, phase):
serialized_combine_fn = pickler.dumps(
(beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
[], {}))
return factory.augment_oldstyle_op(
operations.CombineOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerCombineFn(
serialized_combine_fn,
phase,
@@ -1352,7 +1352,7 @@
def create(factory, transform_id, transform_proto, unused_parameter, consumers):
return factory.augment_oldstyle_op(
operations.FlattenOperation(
- transform_proto.unique_name,
+ common.NameContext(transform_proto.unique_name, transform_id),
operation_specs.WorkerFlatten(
None,
[factory.get_only_output_coder(transform_proto)]),
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
index 08dac3a..12f162b 100644
--- a/sdks/python/apache_beam/runners/worker/log_handler.py
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -25,12 +25,13 @@
import sys
import threading
import time
-from builtins import range
+import traceback
import grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
+from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
@@ -54,7 +55,8 @@
logging.ERROR: beam_fn_api_pb2.LogEntry.Severity.ERROR,
logging.WARNING: beam_fn_api_pb2.LogEntry.Severity.WARN,
logging.INFO: beam_fn_api_pb2.LogEntry.Severity.INFO,
- logging.DEBUG: beam_fn_api_pb2.LogEntry.Severity.DEBUG
+ logging.DEBUG: beam_fn_api_pb2.LogEntry.Severity.DEBUG,
+ -float('inf'): beam_fn_api_pb2.LogEntry.Severity.DEBUG,
}
def __init__(self, log_service_descriptor):
@@ -81,16 +83,37 @@
self._log_channel)
return self._logging_stub.Logging(self._write_log_entries())
+ def map_log_level(self, level):
+ try:
+ return self.LOG_LEVEL_MAP[level]
+ except KeyError:
+ return max(
+ beam_level for python_level, beam_level in self.LOG_LEVEL_MAP.items()
+ if python_level <= level)
+
def emit(self, record):
log_entry = beam_fn_api_pb2.LogEntry()
- log_entry.severity = self.LOG_LEVEL_MAP[record.levelno]
+ log_entry.severity = self.map_log_level(record.levelno)
log_entry.message = self.format(record)
log_entry.thread = record.threadName
- log_entry.log_location = record.module + '.' + record.funcName
+ log_entry.log_location = '%s:%s' % (
+ record.pathname or record.module, record.lineno or record.funcName)
(fraction, seconds) = math.modf(record.created)
nanoseconds = 1e9 * fraction
log_entry.timestamp.seconds = int(seconds)
log_entry.timestamp.nanos = int(nanoseconds)
+ if record.exc_info:
+ log_entry.trace = ''.join(traceback.format_exception(*record.exc_info))
+ instruction_id = statesampler.get_current_instruction_id()
+ if instruction_id:
+ log_entry.instruction_id = instruction_id
+ tracker = statesampler.get_current_tracker()
+ if tracker:
+ current_state = tracker.current_state()
+ if (current_state
+ and current_state.name_context
+ and current_state.name_context.transform_id):
+ log_entry.transform_id = current_state.name_context.transform_id
try:
self._log_entry_queue.put(log_entry, block=False)
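The map_log_level fallback above picks the highest Beam severity whose Python level does not exceed the record's level, so custom log levels degrade gracefully. A self-contained sketch of that logic, with integer ranks standing in for the beam_fn_api_pb2 severity enum:

    import logging

    # Illustrative ranks mirroring the ordering of LogEntry.Severity values.
    DEBUG, INFO, WARN, ERROR, CRITICAL = 2, 3, 5, 6, 7

    LOG_LEVEL_MAP = {
        logging.CRITICAL: CRITICAL,
        logging.ERROR: ERROR,
        logging.WARNING: WARN,
        logging.INFO: INFO,
        logging.DEBUG: DEBUG,
        -float('inf'): DEBUG,  # catch-all floor for arbitrarily low levels
    }

    def map_log_level(level):
      try:
        return LOG_LEVEL_MAP[level]
      except KeyError:
        return max(beam_level
                   for python_level, beam_level in LOG_LEVEL_MAP.items()
                   if python_level <= level)

    print(map_log_level(logging.INFO + 5) == INFO)  # True: 25 falls back to INFO
    print(map_log_level(5) == DEBUG)                # True: floor entry applies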
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
index a651409..c79ccf9 100644
--- a/sdks/python/apache_beam/runners/worker/log_handler_test.py
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -18,6 +18,7 @@
from __future__ import absolute_import
import logging
+import re
import unittest
from builtins import range
@@ -26,7 +27,9 @@
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import endpoints_pb2
+from apache_beam.runners.common import NameContext
from apache_beam.runners.worker import log_handler
+from apache_beam.runners.worker import statesampler
from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
_LOGGER = logging.getLogger(__name__)
@@ -83,14 +86,58 @@
log_entry.severity)
self.assertEqual('%s: %s' % (msg, num_received_log_entries),
log_entry.message)
- self.assertEqual(u'log_handler_test._verify_fn_log_handler',
- log_entry.log_location)
+ self.assertTrue(
+ re.match(r'.*/log_handler_test.py:\d+', log_entry.log_location),
+ log_entry.log_location)
self.assertGreater(log_entry.timestamp.seconds, 0)
self.assertGreaterEqual(log_entry.timestamp.nanos, 0)
num_received_log_entries += 1
self.assertEqual(num_received_log_entries, num_log_entries)
+ def assertContains(self, haystack, needle):
+ self.assertTrue(
+ needle in haystack, 'Expected %r to contain %r.' % (haystack, needle))
+
+ def test_exc_info(self):
+ try:
+ raise ValueError('some message')
+ except ValueError:
+ _LOGGER.error('some error', exc_info=True)
+
+ self.fn_log_handler.close()
+
+ log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
+ self.assertContains(log_entry.message, 'some error')
+ self.assertContains(log_entry.trace, 'some message')
+ self.assertContains(log_entry.trace, 'log_handler_test.py')
+
+ def test_context(self):
+ try:
+ with statesampler.instruction_id('A'):
+ tracker = statesampler.for_test()
+ with tracker.scoped_state(NameContext('name', 'tid'), 'stage'):
+ _LOGGER.info('message a')
+ with statesampler.instruction_id('B'):
+ _LOGGER.info('message b')
+ _LOGGER.info('message c')
+
+ self.fn_log_handler.close()
+ a, b, c = sum(
+ [list(logs.log_entries)
+ for logs in self.test_logging_service.log_records_received], [])
+
+ self.assertEqual(a.instruction_id, 'A')
+ self.assertEqual(b.instruction_id, 'B')
+ self.assertEqual(c.instruction_id, '')
+
+ self.assertEqual(a.transform_id, 'tid')
+ self.assertEqual(b.transform_id, '')
+ self.assertEqual(c.transform_id, '')
+
+ finally:
+ statesampler.set_current_tracker(None)
+
# Test cases.
data = {
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index e0534ff..00a0ac2 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -39,6 +39,7 @@
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
@@ -132,17 +133,18 @@
_LOGGER.info('Done consuming work.')
def _execute(self, task, request):
- try:
- response = task()
- except Exception: # pylint: disable=broad-except
- traceback_string = traceback.format_exc()
- print(traceback_string, file=sys.stderr)
- _LOGGER.error(
- 'Error processing instruction %s. Original traceback is\n%s\n',
- request.instruction_id, traceback_string)
- response = beam_fn_api_pb2.InstructionResponse(
- instruction_id=request.instruction_id, error=traceback_string)
- self._responses.put(response)
+ with statesampler.instruction_id(request.instruction_id):
+ try:
+ response = task()
+ except Exception: # pylint: disable=broad-except
+ traceback_string = traceback.format_exc()
+ print(traceback_string, file=sys.stderr)
+ _LOGGER.error(
+ 'Error processing instruction %s. Original traceback is\n%s\n',
+ request.instruction_id, traceback_string)
+ response = beam_fn_api_pb2.InstructionResponse(
+ instruction_id=request.instruction_id, error=traceback_string)
+ self._responses.put(response)
def _request_register(self, request):
# registration request is handled synchronously
diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py
index 707ee1f..e57815e 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
+import contextlib
import threading
from collections import namedtuple
@@ -49,6 +50,25 @@
return None
+_INSTRUCTION_IDS = threading.local()
+
+
+def get_current_instruction_id():
+ try:
+ return _INSTRUCTION_IDS.instruction_id
+ except AttributeError:
+ return None
+
+
+@contextlib.contextmanager
+def instruction_id(id):
+ try:
+ _INSTRUCTION_IDS.instruction_id = id
+ yield
+ finally:
+ _INSTRUCTION_IDS.instruction_id = None
+
+
def for_test():
set_current_tracker(StateSampler('test', CounterFactory()))
return get_current_tracker()
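A minimal usage sketch of the new thread-local instruction id (assuming apache_beam is installed):

    from apache_beam.runners.worker import statesampler

    print(statesampler.get_current_instruction_id())    # None outside a scope
    with statesampler.instruction_id('instruction-1'):
      print(statesampler.get_current_instruction_id())  # 'instruction-1'
    print(statesampler.get_current_instruction_id())    # None again

Note that the finally clause resets the id to None rather than restoring an outer value, so scopes are not meant to nest; the test_context case in log_handler_test.py above exercises exactly this behavior.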
diff --git a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
index fdda05c..cac0c74 100644
--- a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
+++ b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
@@ -55,24 +55,6 @@
- {window: [20, 29], values: [25], timestamp: 25, late: false}
---
-name: timestamp_combiner_earliest_separate_bundles
-window_fn: FixedWindows(10)
-trigger_fn: Default
-timestamp_combiner: OUTPUT_AT_EARLIEST
-transcript:
- - input: [1]
- - input: [2]
- - input: [3]
- - input: [10]
- - input: [11]
- - input: [25]
- - watermark: 100
- - expect:
- - {window: [0, 9], values: [1, 2, 3], timestamp: 1, final: false}
- - {window: [10, 19], values: [10, 11], timestamp: 10}
- - {window: [20, 29], values: [25], timestamp: 25, late: false}
-
----
name: timestamp_combiner_latest
window_fn: FixedWindows(10)
trigger_fn: Default
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index c8bc9ff..15d1770 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -169,7 +169,6 @@
assert_that(
records,
equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
label='assert per window')
p.run()
@@ -177,8 +176,9 @@
def test_gbk_execution_after_watermark_trigger(self):
test_stream = (TestStream()
.advance_watermark_to(10)
- .add_elements(['a'])
+ .add_elements([TimestampedValue('a', 11)])
.advance_watermark_to(20)
+ .add_elements([TimestampedValue('b', 21)])
.advance_watermark_to_infinity())
options = PipelineOptions()
@@ -199,15 +199,18 @@
# assert per window
expected_window_to_elements = {
- window.IntervalWindow(15, 30): [
+ window.IntervalWindow(0, 15): [
('k', ['a']),
- ('k', []),
+ ('k', [])
+ ],
+ window.IntervalWindow(15, 30): [
+ ('k', ['b']),
+ ('k', [])
],
}
assert_that(
records,
equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
label='assert per window')
p.run()
@@ -247,7 +250,6 @@
assert_that(
records,
equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
label='assert per window')
p.run()
@@ -272,12 +274,12 @@
elm=beam.DoFn.ElementParam,
ts=beam.DoFn.TimestampParam,
side=beam.DoFn.SideInputParam):
- yield (elm, ts, side)
+ yield (elm, ts, sorted(side))
records = (main_stream # pylint: disable=unused-variable
| beam.ParDo(RecordFn(), beam.pvalue.AsList(side)))
- assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 4])]))
+ assert_that(records, equal_to([('e', Timestamp(10), [1, 2, 4])]))
p.run()
@@ -349,7 +351,6 @@
assert_that(
records,
equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
label='assert per window')
p.run()
@@ -403,7 +404,6 @@
assert_that(
records,
equal_to_per_window(expected_window_to_elements),
- use_global_window=False,
label='assert per window')
p.run()
diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py
index b52e61b..5b6bc85 100644
--- a/sdks/python/apache_beam/testing/util.py
+++ b/sdks/python/apache_beam/testing/util.py
@@ -39,6 +39,7 @@
__all__ = [
'assert_that',
'equal_to',
+ 'equal_to_per_window',
'is_empty',
'is_not_empty',
'matches_all',
@@ -85,30 +86,67 @@
return InAnyOrder(iterable)
+class _EqualToPerWindowMatcher(object):
+ def __init__(self, expected_window_to_elements):
+ self._expected_window_to_elements = expected_window_to_elements
+
+ def __call__(self, value):
+ # Short-hand.
+ _expected = self._expected_window_to_elements
+
+ # Match the given windowed value to an expected window. Fails if the window
+ # doesn't exist or the element wasn't found in the window.
+ def match(windowed_value):
+ actual = windowed_value.value
+ window_key = windowed_value.windows[0]
+ try:
+ expected = _expected[window_key]
+ except KeyError:
+ raise BeamAssertException(
+ 'Failed assert: window {} not found in any expected ' \
+ 'windows {}'.format(window_key, list(_expected.keys())))
+
+ # Remove any matched elements from the window. This is used later on to
+ # assert that all elements in the window were matched with actual
+ # elements.
+ try:
+ _expected[window_key].remove(actual)
+ except ValueError:
+ raise BeamAssertException(
+ 'Failed assert: element {} not found in window ' \
+ '{}:{}'.format(actual, window_key, _expected[window_key]))
+
+ # Run the matcher for each window and value pair. Fails if the
+ # windowed_value is not a TestWindowedValue.
+ for windowed_value in value:
+ if not isinstance(windowed_value, TestWindowedValue):
+ raise BeamAssertException(
+ 'Failed assert: Received element {} is not of type ' \
+ 'TestWindowedValue. Did you forget to set reify_windows=True ' \
+ 'on the assertion?'.format(windowed_value))
+ match(windowed_value)
+
+ # Finally, some elements may not have been matched. Assert that we removed
+ # all the elements that we received from the expected list. If the list is
+ # non-empty, then there are unmatched elements.
+ for win in _expected:
+ if _expected[win]:
+ raise BeamAssertException(
+ 'Failed assert: unmatched elements {} in window {}'.format(
+ _expected[win], win))
def equal_to_per_window(expected_window_to_elements):
- """Matcher used by assert_that to check on values for specific windows.
+ """Matcher used by assert_that to check to assert expected windows.
+
+ The 'assert_that' statement must have reify_windows=True. The matcher checks
+ each emitted element against the expected elements of its window, and fails
+ if any expected element goes unmatched.
Arguments:
expected_window_to_elements: A dictionary where the keys are the windows
to check and the values are the elements associated with each window.
"""
- def matcher(elements):
- actual_elements_in_window, window = elements
- if window in expected_window_to_elements:
- expected_elements_in_window = list(
- expected_window_to_elements[window])
- sorted_expected = sorted(expected_elements_in_window)
- sorted_actual = sorted(actual_elements_in_window)
- if sorted_expected != sorted_actual:
- # Results for the same window don't necessarily come all
- # at once. Hence the same actual window may contain only
- # subsets of the expected elements for the window.
- # For example, in the presence of early triggers.
- if all(elem in sorted_expected for elem in sorted_actual) is False:
- raise BeamAssertException(
- 'Failed assert: %r not in %r' % (sorted_actual, sorted_expected))
- return matcher
+
+ return _EqualToPerWindowMatcher(expected_window_to_elements)
# Note that equal_to checks if expected and actual are permutations of each
@@ -214,6 +252,10 @@
pvalue.PCollection), ('%s is not a supported type for Beam assert'
% type(actual))
+ if isinstance(matcher, _EqualToPerWindowMatcher):
+ reify_windows = True
+ use_global_window = True
+
class ReifyTimestampWindow(DoFn):
def process(self, element, timestamp=DoFn.TimestampParam,
window=DoFn.WindowParam):
@@ -239,6 +281,8 @@
keyed_actual = pcoll | "ToVoidKey" >> Map(lambda v: (None, v))
+ # This is a CoGroupByKey so that the matcher always runs, even if the
+ # PCollection is empty.
plain_actual = ((keyed_singleton, keyed_actual)
| "Group" >> CoGroupByKey()
| "Unkey" >> Map(lambda k_values: k_values[1][1]))
diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py
index 1fd1da6..72c9205 100644
--- a/sdks/python/apache_beam/testing/util_test.py
+++ b/sdks/python/apache_beam/testing/util_test.py
@@ -21,14 +21,20 @@
import unittest
+import apache_beam as beam
from apache_beam import Create
+from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.testing.util import equal_to_per_window
from apache_beam.testing.util import is_empty
from apache_beam.testing.util import is_not_empty
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
+from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import IntervalWindow
from apache_beam.utils.timestamp import MIN_TIMESTAMP
@@ -110,6 +116,96 @@
with TestPipeline() as p:
assert_that(p | Create([]), is_not_empty())
+ def test_equal_to_per_window_passes(self):
+ start = int(MIN_TIMESTAMP.micros // 1e6) - 5
+ end = start + 20
+ expected = {
+ window.IntervalWindow(start, end): [('k', [1])],
+ }
+ with TestPipeline(options=StandardOptions(streaming=True)) as p:
+ assert_that((p
+ | Create([1])
+ | beam.WindowInto(
+ FixedWindows(20),
+ trigger=trigger.AfterWatermark(),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('k', x))
+ | beam.GroupByKey()),
+ equal_to_per_window(expected),
+ reify_windows=True)
+
+ def test_equal_to_per_window_fail_unmatched_window(self):
+ with self.assertRaises(BeamAssertException):
+ expected = {
+ window.IntervalWindow(50, 100): [('k', [1])],
+ }
+ with TestPipeline(options=StandardOptions(streaming=True)) as p:
+ assert_that((p
+ | Create([1])
+ | beam.WindowInto(
+ FixedWindows(20),
+ trigger=trigger.AfterWatermark(),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('k', x))
+ | beam.GroupByKey()),
+ equal_to_per_window(expected),
+ reify_windows=True)
+
+ def test_equal_to_per_window_fail_unmatched_element(self):
+ with self.assertRaises(BeamAssertException):
+ start = int(MIN_TIMESTAMP.micros // 1e6) - 5
+ end = start + 20
+ expected = {
+ window.IntervalWindow(start, end): [('k', [1]), ('k', [2])],
+ }
+ with TestPipeline(options=StandardOptions(streaming=True)) as p:
+ assert_that((p
+ | Create([1])
+ | beam.WindowInto(
+ FixedWindows(20),
+ trigger=trigger.AfterWatermark(),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('k', x))
+ | beam.GroupByKey()),
+ equal_to_per_window(expected),
+ reify_windows=True)
+
+ def test_equal_to_per_window_succeeds_no_reify_windows(self):
+ start = int(MIN_TIMESTAMP.micros // 1e6) - 5
+ end = start + 20
+ expected = {
+ window.IntervalWindow(start, end): [('k', [1])],
+ }
+ with TestPipeline(options=StandardOptions(streaming=True)) as p:
+ assert_that((p
+ | Create([1])
+ | beam.WindowInto(
+ FixedWindows(20),
+ trigger=trigger.AfterWatermark(),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('k', x))
+ | beam.GroupByKey()),
+ equal_to_per_window(expected))
+
+ def test_equal_to_per_window_fail_unexpected_element(self):
+ with self.assertRaises(BeamAssertException):
+ start = int(MIN_TIMESTAMP.micros // 1e6) - 5
+ end = start + 20
+ expected = {
+ window.IntervalWindow(start, end): [('k', [1])],
+ }
+ with TestPipeline(options=StandardOptions(streaming=True)) as p:
+ assert_that((p
+ | Create([1, 2])
+ | beam.WindowInto(
+ FixedWindows(20),
+ trigger=trigger.AfterWatermark(),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('k', x))
+ | beam.GroupByKey()),
+ equal_to_per_window(expected),
+ reify_windows=True)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 8794d2a..3169d53 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -38,7 +38,6 @@
from apache_beam.coders import typecoders
from apache_beam.internal import pickler
from apache_beam.internal import util
-from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
@@ -2511,38 +2510,32 @@
def expand(self, pbegin):
assert isinstance(pbegin, pvalue.PBegin)
- # Must guard against this as some legacy runners don't implement impulse.
- debug_options = pbegin.pipeline._options.view_as(DebugOptions)
- fn_api = (debug_options.experiments
- and 'beam_fn_api' in debug_options.experiments)
- if fn_api:
- coder = typecoders.registry.get_coder(self.get_output_type())
- serialized_values = [coder.encode(v) for v in self.values]
- reshuffle = self.reshuffle
- # Avoid the "redistributing" reshuffle for 0 and 1 element Creates.
- # These special cases are often used in building up more complex
- # transforms (e.g. Write).
+ coder = typecoders.registry.get_coder(self.get_output_type())
+ serialized_values = [coder.encode(v) for v in self.values]
+ reshuffle = self.reshuffle
+ # Avoid the "redistributing" reshuffle for 0- and 1-element Creates.
+ # These special cases are often used in building up more complex
+ # transforms (e.g. Write).
- class MaybeReshuffle(PTransform):
- def expand(self, pcoll):
- if len(serialized_values) > 1 and reshuffle:
- from apache_beam.transforms.util import Reshuffle
- return pcoll | Reshuffle()
- else:
- return pcoll
- return (
- pbegin
- | Impulse()
- | FlatMap(lambda _: serialized_values)
- | MaybeReshuffle()
- | Map(coder.decode).with_output_types(self.get_output_type()))
- else:
- self.pipeline = pbegin.pipeline
- from apache_beam.io import iobase
- coder = typecoders.registry.get_coder(self.get_output_type())
- source = self._create_source_from_iterable(self.values, coder)
- return (pbegin.pipeline
- | iobase.Read(source).with_output_types(self.get_output_type()))
+ class MaybeReshuffle(PTransform):
+ def expand(self, pcoll):
+ if len(serialized_values) > 1 and reshuffle:
+ from apache_beam.transforms.util import Reshuffle
+ return pcoll | Reshuffle()
+ else:
+ return pcoll
+ return (
+ pbegin
+ | Impulse()
+ | FlatMap(lambda _: serialized_values).with_output_types(bytes)
+ | MaybeReshuffle().with_output_types(bytes)
+ | Map(coder.decode).with_output_types(self.get_output_type()))
+
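+ # Fallback for runners that do not implement the Impulse primitive.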
+ def as_read(self):
+ from apache_beam.io import iobase
+ coder = typecoders.registry.get_coder(self.get_output_type())
+ source = self._create_source_from_iterable(self.values, coder)
+ return iobase.Read(source).with_output_types(self.get_output_type())
def get_windowing(self, unused_inputs):
return Windowing(GlobalWindows())
@@ -2558,6 +2551,7 @@
return _CreateSource(serialized_values, coder)
+@typehints.with_output_types(bytes)
class Impulse(PTransform):
"""Impulse primitive."""
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 380708d..0c9459f 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -506,9 +506,10 @@
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.transforms.core import Create
# pylint: enable=wrong-import-order, wrong-import-position
- replacements = {id(v): p | 'CreatePInput%s' % ix >> Create(v)
- for ix, v in enumerate(pvalues)
- if not isinstance(v, pvalue.PValue) and v is not None}
+ replacements = {
+ id(v): p | 'CreatePInput%s' % ix >> Create(v, reshuffle=False)
+ for ix, v in enumerate(pvalues)
+ if not isinstance(v, pvalue.PValue) and v is not None}
pvalueish = _SetInputPValues().visit(pvalueish, replacements)
self.pipeline = p
result = p.apply(self, pvalueish, label)
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index ad87082..ffb245c 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -84,29 +84,29 @@
str(PTransform()))
pa = TestPipeline()
- res = pa | 'ALabel' >> beam.Create([1, 2])
- self.assertEqual('AppliedPTransform(ALabel/Read, Read)',
+ res = pa | 'ALabel' >> beam.Impulse()
+ self.assertEqual('AppliedPTransform(ALabel, Impulse)',
str(res.producer))
pc = TestPipeline()
- res = pc | beam.Create([1, 2])
+ res = pc | beam.Impulse()
inputs_tr = res.producer.transform
inputs_tr.inputs = ('ci',)
self.assertEqual(
- """<Read(PTransform) label=[Read] inputs=('ci',)>""",
+ "<Impulse(PTransform) label=[Impulse] inputs=('ci',)>",
str(inputs_tr))
pd = TestPipeline()
- res = pd | beam.Create([1, 2])
+ res = pd | beam.Impulse()
side_tr = res.producer.transform
side_tr.side_inputs = (4,)
self.assertEqual(
- '<Read(PTransform) label=[Read] side_inputs=(4,)>',
+ '<Impulse(PTransform) label=[Impulse] side_inputs=(4,)>',
str(side_tr))
inputs_tr.side_inputs = ('cs',)
self.assertEqual(
- """<Read(PTransform) label=[Read] """
+ """<Impulse(PTransform) label=[Impulse] """
"""inputs=('ci',) side_inputs=('cs',)>""",
str(inputs_tr))
@@ -495,7 +495,7 @@
pipeline = TestPipeline()
pcoll = pipeline | 'start' >> beam.Create(
[(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)])
- result = pcoll | 'Group' >> beam.GroupByKey()
+ result = pcoll | 'Group' >> beam.GroupByKey() | _SortLists
assert_that(result, equal_to([(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]))
pipeline.run()
@@ -562,7 +562,7 @@
created = pipeline | 'A' >> beam.Create(contents)
partitioned = created | 'B' >> beam.Partition(lambda x, n: len(x) % n, 3)
flattened = partitioned | 'C' >> beam.Flatten()
- grouped = flattened | 'D' >> beam.GroupByKey()
+ grouped = flattened | 'D' >> beam.GroupByKey() | _SortLists
assert_that(grouped, equal_to([('aa', [1, 2]), ('bb', [2])]))
pipeline.run()
@@ -654,7 +654,7 @@
[('a', 1), ('a', 2), ('b', 3), ('c', 4)])
pcoll_2 = pipeline | 'Start 2' >> beam.Create(
[('a', 5), ('a', 6), ('c', 7), ('c', 8)])
- result = (pcoll_1, pcoll_2) | beam.CoGroupByKey()
+ result = (pcoll_1, pcoll_2) | beam.CoGroupByKey() | _SortLists
assert_that(result, equal_to([('a', ([1, 2], [5, 6])),
('b', ([3], [])),
('c', ([4], [7, 8]))]))
@@ -667,6 +667,7 @@
pcoll_2 = pipeline | 'Start 2' >> beam.Create(
[('a', 5), ('a', 6), ('c', 7), ('c', 8)])
result = [pc for pc in (pcoll_1, pcoll_2)] | beam.CoGroupByKey()
+ result |= _SortLists
assert_that(result, equal_to([('a', ([1, 2], [5, 6])),
('b', ([3], [])),
('c', ([4], [7, 8]))]))
@@ -679,6 +680,7 @@
pcoll_2 = pipeline | 'Start 2' >> beam.Create(
[('a', 5), ('a', 6), ('c', 7), ('c', 8)])
result = {'X': pcoll_1, 'Y': pcoll_2} | beam.CoGroupByKey()
+ result |= _SortLists
assert_that(result, equal_to([('a', {'X': [1, 2], 'Y': [5, 6]}),
('b', {'X': [3], 'Y': []}),
('c', {'X': [4], 'Y': [7, 8]})]))
@@ -760,8 +762,9 @@
([1, 2, 3], [100]) | beam.Flatten())
join_input = ([('k', 'a')],
[('k', 'b'), ('k', 'c')])
- self.assertCountEqual([('k', (['a'], ['b', 'c']))],
- join_input | beam.CoGroupByKey())
+ self.assertCountEqual(
+ [('k', (['a'], ['b', 'c']))],
+ join_input | beam.CoGroupByKey() | _SortLists)
def test_multi_input_ptransform(self):
class DisjointUnion(PTransform):
@@ -908,7 +911,9 @@
def check_label(self, ptransform, expected_label):
pipeline = TestPipeline()
pipeline | 'Start' >> beam.Create([('a', 1)]) | ptransform
- actual_label = sorted(pipeline.applied_labels - {'Start', 'Start/Read'})[0]
+ actual_label = sorted(
+ label for label in pipeline.applied_labels
+ if not label.startswith('Start'))[0]
self.assertEqual(expected_label, re.sub(r'\d{3,}', '#', actual_label))
def test_default_labels(self):
@@ -1365,11 +1370,12 @@
| 'T' >> beam.Create(['t', 'e', 's', 't', 'i', 'n', 'g'])
.with_output_types(str)
| 'GenKeys' >> beam.Map(group_with_upper_ord)
- | 'O' >> beam.GroupByKey())
+ | 'O' >> beam.GroupByKey()
+ | _SortLists)
assert_that(result, equal_to([(1, ['g']),
- (3, ['s', 'i', 'n']),
- (4, ['t', 'e', 't'])]))
+ (3, ['i', 'n', 's']),
+ (4, ['e', 't', 't'])]))
self.p.run()
def test_pipeline_checking_satisfied_but_run_time_types_violate(self):
@@ -1415,7 +1421,8 @@
result = (self.p
| 'Nums' >> beam.Create(range(5)).with_output_types(int)
| 'IsEven' >> beam.Map(is_even_as_key)
- | 'Parity' >> beam.GroupByKey())
+ | 'Parity' >> beam.GroupByKey()
+ | _SortLists)
assert_that(result, equal_to([(False, [1, 3]), (True, [0, 2, 4])]))
self.p.run()
@@ -1429,7 +1436,7 @@
# passed instead.
with self.assertRaises(typehints.TypeCheckError) as e:
(self.p
- | beam.Create([1, 2, 3])
+ | beam.Create([1, 1, 1])
| ('ToInt' >> beam.FlatMap(lambda x: [int(x)])
.with_input_types(str).with_output_types(int)))
self.p.run()
@@ -1474,7 +1481,7 @@
int)).get_type_hints())
with self.assertRaises(typehints.TypeCheckError) as e:
(self.p
- | beam.Create([1, 2, 3])
+ | beam.Create([1, 1, 1])
| ('ToInt' >> beam.FlatMap(lambda x: [float(x)])
.with_input_types(int).with_output_types(int))
)
@@ -1681,7 +1688,7 @@
with self.assertRaises(typehints.TypeCheckError) as e:
(self.p
- | beam.Create(range(3)).with_output_types(int)
+ | beam.Create([0]).with_output_types(int)
| ('SortJoin' >> beam.CombineGlobally(lambda s: ''.join(sorted(s)))
.with_input_types(str).with_output_types(str)))
self.p.run()
@@ -2238,5 +2245,19 @@
p.run()
+def _sort_lists(result):
+ if isinstance(result, list):
+ return sorted(result)
+ elif isinstance(result, tuple):
+ return tuple(_sort_lists(e) for e in result)
+ elif isinstance(result, dict):
+ return {k: _sort_lists(v) for k, v in result.items()}
+ else:
+ return result
+
+
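+# Normalizes GroupByKey results so assertions do not depend on value order.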
+_SortLists = beam.Map(_sort_lists)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 6f59f21..2a76c2f 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -215,6 +215,16 @@
pass
@abstractmethod
+ def has_ontime_pane(self):
+ """Whether this trigger creates an empty pane even if there are no elements.
+
+ Returns:
+ True if this trigger guarantees that there will always be an ON_TIME pane
+ even if there are no elements in that pane.
+ """
+ pass
+
+ @abstractmethod
def on_fire(self, watermark, window, context):
"""Called when a trigger actually fires.
@@ -280,6 +290,10 @@
context.clear_timer('', TimeDomain.WATERMARK)
def should_fire(self, time_domain, watermark, window, context):
+ if watermark >= window.end:
+ # Explicitly clear the timer so that late elements are not emitted again
+ # when the timer is fired.
+ context.clear_timer('', TimeDomain.WATERMARK)
return watermark >= window.end
def on_fire(self, watermark, window, context):
@@ -302,6 +316,9 @@
return beam_runner_api_pb2.Trigger(
default=beam_runner_api_pb2.Trigger.Default())
+ def has_ontime_pane(self):
+ return True
+
class AfterProcessingTime(TriggerFn):
"""Fire exactly once after a specified delay from processing time.
@@ -351,6 +368,9 @@
after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime(
timestamp_transforms=[delay_proto]))
+ def has_ontime_pane(self):
+ return False
+
class AfterWatermark(TriggerFn):
"""Fire exactly once when the watermark passes the end of the window.
@@ -406,6 +426,9 @@
return self.late.should_fire(time_domain, watermark,
window, NestedContext(context, 'late'))
elif watermark >= window.end:
+ # Explicitly clear the timer so that late elements are not emitted again
+ # when the timer is fired.
+ context.clear_timer('', TimeDomain.WATERMARK)
return True
elif self.early:
return self.early.should_fire(time_domain, watermark,
@@ -461,6 +484,9 @@
early_firings=early_proto,
late_firings=late_proto))
+ def has_ontime_pane(self):
+ return True
+
class AfterCount(TriggerFn):
"""Fire when there are at least count elements in this window pane.
@@ -509,6 +535,8 @@
element_count=beam_runner_api_pb2.Trigger.ElementCount(
element_count=self.count))
+ def has_ontime_pane(self):
+ return False
class Repeatedly(TriggerFn):
"""Repeatedly invoke the given trigger, never finishing."""
@@ -552,6 +580,9 @@
repeat=beam_runner_api_pb2.Trigger.Repeat(
subtrigger=self.underlying.to_runner_api(context)))
+ def has_ontime_pane(self):
+ return self.underlying.has_ontime_pane()
+
class _ParallelTriggerFn(with_metaclass(ABCMeta, TriggerFn)):
@@ -630,6 +661,8 @@
else:
raise NotImplementedError(self)
+ def has_ontime_pane(self):
+ return any(t.has_ontime_pane() for t in self.triggers)
class AfterAny(_ParallelTriggerFn):
"""Fires when any subtrigger fires.
@@ -717,6 +750,8 @@
subtrigger.to_runner_api(context)
for subtrigger in self.triggers]))
+ def has_ontime_pane(self):
+ return any(t.has_ontime_pane() for t in self.triggers)
class OrFinally(AfterAny):
@@ -968,18 +1003,21 @@
"""Breaks a series of bundle and timer firings into window (pane)s."""
@abstractmethod
- def process_elements(self, state, windowed_values, output_watermark):
+ def process_elements(self, state, windowed_values, output_watermark,
+ input_watermark=MIN_TIMESTAMP):
pass
@abstractmethod
- def process_timer(self, window_id, name, time_domain, timestamp, state):
+ def process_timer(self, window_id, name, time_domain, timestamp, state,
+ input_watermark=None):
pass
- def process_entire_key(
- self, key, windowed_values, output_watermark=MIN_TIMESTAMP):
+ def process_entire_key(self, key, windowed_values,
+ unused_output_watermark=None,
+ unused_input_watermark=None):
state = InMemoryUnmergedState()
for wvalue in self.process_elements(
- state, windowed_values, output_watermark):
+ state, windowed_values, MIN_TIMESTAMP, MIN_TIMESTAMP):
yield wvalue.with_value((key, wvalue.value))
while state.timers:
fired = state.get_and_clear_timers()
@@ -1039,14 +1077,17 @@
index=0,
nonspeculative_index=0)
- def process_elements(self, state, windowed_values, unused_output_watermark):
+ def process_elements(self, state, windowed_values,
+ unused_output_watermark,
+ unused_input_watermark=MIN_TIMESTAMP):
yield WindowedValue(
_UnwindowedValues(windowed_values),
MIN_TIMESTAMP,
self.GLOBAL_WINDOW_TUPLE,
self.ONLY_FIRING)
- def process_timer(self, window_id, name, time_domain, timestamp, state):
+ def process_timer(self, window_id, name, time_domain, timestamp, state,
+ input_watermark=None):
raise TypeError('Triggers never set or called for batch default windowing.')
@@ -1057,15 +1098,19 @@
self.phased_combine_fn = phased_combine_fn
self.underlying = underlying
- def process_elements(self, state, windowed_values, output_watermark):
+ def process_elements(self, state, windowed_values, output_watermark,
+ input_watermark=MIN_TIMESTAMP):
uncombined = self.underlying.process_elements(state, windowed_values,
- output_watermark)
+ output_watermark,
+ input_watermark)
for output in uncombined:
yield output.with_value(self.phased_combine_fn.apply(output.value))
- def process_timer(self, window_id, name, time_domain, timestamp, state):
+ def process_timer(self, window_id, name, time_domain, timestamp, state,
+ input_watermark=None):
uncombined = self.underlying.process_timer(window_id, name, time_domain,
- timestamp, state)
+ timestamp, state,
+ input_watermark)
for output in uncombined:
yield output.with_value(self.phased_combine_fn.apply(output.value))
@@ -1094,7 +1139,8 @@
self.accumulation_mode = windowing.accumulation_mode
self.is_merging = True
- def process_elements(self, state, windowed_values, output_watermark):
+ def process_elements(self, state, windowed_values, output_watermark,
+ input_watermark=MIN_TIMESTAMP):
if self.is_merging:
state = MergeableStateAdapter(state)
@@ -1155,14 +1201,17 @@
self.trigger_fn.on_element(value, window, context)
# Maybe fire this window.
- watermark = MIN_TIMESTAMP
- if self.trigger_fn.should_fire(TimeDomain.WATERMARK, watermark,
+ if self.trigger_fn.should_fire(TimeDomain.WATERMARK, input_watermark,
window, context):
- finished = self.trigger_fn.on_fire(watermark, window, context)
- yield self._output(window, finished, state, output_watermark, False)
+ finished = self.trigger_fn.on_fire(input_watermark, window, context)
+ yield self._output(window, finished, state, input_watermark,
+ output_watermark, False)
def process_timer(self, window_id, unused_name, time_domain, timestamp,
- state):
+ state, input_watermark=None):
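+ # Fall back to the firing timestamp when no input watermark is given.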
+ if input_watermark is None:
+ input_watermark = timestamp
+
if self.is_merging:
state = MergeableStateAdapter(state)
window = state.get_window(window_id)
@@ -1175,16 +1224,17 @@
if self.trigger_fn.should_fire(time_domain, timestamp,
window, context):
finished = self.trigger_fn.on_fire(timestamp, window, context)
- yield self._output(window, finished, state, timestamp,
- time_domain == TimeDomain.WATERMARK)
+ yield self._output(window, finished, state, input_watermark,
+ timestamp, time_domain == TimeDomain.WATERMARK)
else:
raise Exception('Unexpected time domain: %s' % time_domain)
- def _output(self, window, finished, state, watermark, maybe_ontime):
+ def _output(self, window, finished, state, input_watermark, output_watermark,
+ maybe_ontime):
"""Output window and clean up if appropriate."""
index = state.get_state(window, self.INDEX)
state.add_state(window, self.INDEX, 1)
- if watermark <= window.max_timestamp():
+ if output_watermark <= window.max_timestamp():
nonspeculative_index = -1
timing = windowed_value.PaneInfoTiming.EARLY
if state.get_state(window, self.NONSPECULATIVE_INDEX):
@@ -1220,6 +1270,10 @@
if timestamp is None:
# If no watermark hold was set, output at end of window.
timestamp = window.max_timestamp()
+ elif input_watermark < window.end and self.trigger_fn.has_ontime_pane():
+ # Hold the watermark in case there is an empty pane that needs to be fired
+ # at the end of the window.
+ pass
else:
state.clear_state(window, self.WATERMARK_HOLD)
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index d1e5433..58b29e0 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -23,6 +23,7 @@
import json
import os.path
import pickle
+import random
import unittest
from builtins import range
from builtins import zip
@@ -122,7 +123,8 @@
state = InMemoryUnmergedState()
for bundle in bundles:
- for wvalue in driver.process_elements(state, bundle, MIN_TIMESTAMP):
+ for wvalue in driver.process_elements(state, bundle, MIN_TIMESTAMP,
+ MIN_TIMESTAMP):
window, = wvalue.windows
self.assertEqual(window.max_timestamp(), wvalue.timestamp)
actual_panes[window].append(set(wvalue.value))
@@ -131,13 +133,14 @@
for timer_window, (name, time_domain, timestamp) in (
state.get_and_clear_timers()):
for wvalue in driver.process_timer(
- timer_window, name, time_domain, timestamp, state):
+ timer_window, name, time_domain, timestamp, state, MIN_TIMESTAMP):
window, = wvalue.windows
self.assertEqual(window.max_timestamp(), wvalue.timestamp)
actual_panes[window].append(set(wvalue.value))
for bundle in late_bundles:
- for wvalue in driver.process_elements(state, bundle, MAX_TIMESTAMP):
+ for wvalue in driver.process_elements(state, bundle, MAX_TIMESTAMP,
+ MAX_TIMESTAMP):
window, = wvalue.windows
self.assertEqual(window.max_timestamp(), wvalue.timestamp)
actual_panes[window].append(set(wvalue.value))
@@ -146,7 +149,7 @@
for timer_window, (name, time_domain, timestamp) in (
state.get_and_clear_timers()):
for wvalue in driver.process_timer(
- timer_window, name, time_domain, timestamp, state):
+ timer_window, name, time_domain, timestamp, state, MAX_TIMESTAMP):
window, = wvalue.windows
self.assertEqual(window.max_timestamp(), wvalue.timestamp)
actual_panes[window].append(set(wvalue.value))
@@ -395,7 +398,7 @@
for k in range(10))
with self.assertRaises(TypeError):
pickle.dumps(unpicklable)
- for unwindowed in driver.process_elements(None, unpicklable, None):
+ for unwindowed in driver.process_elements(None, unpicklable, None, None):
self.assertEqual(pickle.loads(pickle.dumps(unwindowed)).value,
list(range(10)))
@@ -644,7 +647,9 @@
vs, windows=[window], timestamp=t, pane_info=p)))
-def _windowed_value_info_check(actual, expected):
+def _windowed_value_info_check(actual, expected, key=None):
+
+ key_string = ' for %s' % key if key else ''
def format(panes):
return '\n[%s]\n' % '\n '.join(str(pane) for pane in sorted(
@@ -652,12 +657,12 @@
if len(actual) > len(expected):
raise AssertionError(
- 'Unexpected output: expected %s but got %s' % (
- format(expected), format(actual)))
+ 'Unexpected output%s: expected %s but got %s' % (
+ key_string, format(expected), format(actual)))
elif len(expected) > len(actual):
raise AssertionError(
- 'Unmatched output: expected %s but got %s' % (
- format(expected), format(actual)))
+ 'Unmatched output%s: expected %s but got %s' % (
+ key_string, format(expected), format(actual)))
else:
def diff(actual, expected):
@@ -670,8 +675,8 @@
diffs = [diff(output, pane) for pane in expected]
if all(diffs):
raise AssertionError(
- 'Unmatched output: %s not found in %s (diffs in %s)' % (
- output, format(expected), diffs))
+ 'Unmatched output%s: %s not found in %s (diffs in %s)' % (
+ key_string, output, format(expected), diffs))
class _ConcatCombineFn(beam.CombineFn):
@@ -757,6 +762,19 @@
if runner_name in spec.get('broken_on', ()):
self.skipTest('Known to be broken on %s' % runner_name)
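+ # Order-agnostic cases replay the same input under several orderings
+ # and bundlings; each variant gets its own key below.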
+ is_order_agnostic = (
+ isinstance(trigger_fn, DefaultTrigger)
+ and accumulation_mode == AccumulationMode.ACCUMULATING)
+
+ if is_order_agnostic:
+ reshuffle_seed = random.randrange(1 << 20)
+ keys = [
+ u'original', u'reversed', u'reshuffled(%s)' % reshuffle_seed,
+ u'one-element-bundles', u'one-element-bundles-reversed',
+ u'two-element-bundles']
+ else:
+ keys = [u'key1', u'key2']
+
# Elements are encoded as a json strings to allow other languages to
# decode elements while executing the test stream.
# TODO(BEAM-8600): Eliminate these gymnastics.
@@ -767,7 +785,28 @@
else:
test_stream.add_elements([json.dumps(('expect', []))])
if action == 'input':
- test_stream.add_elements([json.dumps(('input', e)) for e in params])
+ def keyed(key, values):
+ return [json.dumps(('input', (key, v))) for v in values]
+ if is_order_agnostic:
+ # Must match keys above.
+ test_stream.add_elements(keyed('original', params))
+ test_stream.add_elements(keyed('reversed', reversed(params)))
+ r = random.Random(reshuffle_seed)
+ reshuffled = list(params)
+ r.shuffle(reshuffled)
+ test_stream.add_elements(keyed(
+ 'reshuffled(%s)' % reshuffle_seed, reshuffled))
+ for v in params:
+ test_stream.add_elements(keyed('one-element-bundles', [v]))
+ for v in reversed(params):
+ test_stream.add_elements(
+ keyed('one-element-bundles-reversed', [v]))
+ for ix in range(0, len(params), 2):
+ test_stream.add_elements(
+ keyed('two-element-bundles', params[ix:ix+2]))
+ else:
+ for key in keys:
+ test_stream.add_elements(keyed(key, params))
elif action == 'watermark':
test_stream.advance_watermark_to(params)
elif action == 'clock':
@@ -806,7 +845,7 @@
beam.transforms.userstate.BagStateSpec(
'expected',
beam.coders.FastPrimitivesCoder()))):
- _, (action, data) = element
+ key, (action, data) = element
if self.allow_out_of_order:
if action == 'expect' and not list(seen.read()):
@@ -831,7 +870,7 @@
elif action == 'expect':
actual = list(seen.read())
seen.clear()
- _windowed_value_info_check(actual, data)
+ _windowed_value_info_check(actual, data, key)
else:
raise ValueError('Unexpected action: %s' % action)
@@ -842,11 +881,9 @@
# a branch of expected results.
inputs, expected = (
inputs_and_expected
- | beam.FlatMapTuple(
- lambda tag, value: [
- beam.pvalue.TaggedOutput(tag, ('key1', value)),
- beam.pvalue.TaggedOutput(tag, ('key2', value)),
- ]).with_outputs('input', 'expect'))
+ | beam.MapTuple(
+ lambda tag, value: beam.pvalue.TaggedOutput(tag, value),
+ ).with_outputs('input', 'expect'))
# Process the inputs with the given windowing to produce actual outputs.
outputs = (
@@ -865,7 +902,8 @@
| 'Global' >> beam.WindowInto(beam.transforms.window.GlobalWindows()))
# Feed both the expected and actual outputs to Check() for comparison.
tagged_expected = (
- expected | beam.MapTuple(lambda key, value: (key, ('expect', value))))
+ expected | beam.FlatMap(
+ lambda value: [(key, ('expect', value)) for key in keys]))
tagged_outputs = (
outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
# pylint: disable=expression-not-assigned
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 21ef0ec..601a1d4 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -521,7 +521,7 @@
('key', 2),
('key', 3),
('key', 4),
- ('key', 3)])
+ ('key', 3)], reshuffle=False)
actual_values = (values
| beam.ParDo(SetStatefulDoFn()))
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 7a87e60..c7f273a 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -25,6 +25,7 @@
import contextlib
import random
import re
+import sys
import time
import typing
import warnings
@@ -34,6 +35,7 @@
from builtins import zip
from future.utils import itervalues
+from past.builtins import long
from apache_beam import coders
from apache_beam import typehints
@@ -648,7 +650,7 @@
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]
- ungrouped = pcoll | Map(reify_timestamps)
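+ # The reified values no longer match the declared element types, so use Any.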
+ ungrouped = pcoll | Map(reify_timestamps).with_output_types(typing.Any)
# TODO(BEAM-8104) Using global window as one of the standard window.
# This is to mitigate the Dataflow Java Runner Harness limitation to
@@ -660,7 +662,7 @@
timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
result = (ungrouped
| GroupByKey()
- | FlatMap(restore_timestamps))
+ | FlatMap(restore_timestamps).with_output_types(typing.Any))
result._windowing = windowing_saved
return result
@@ -680,10 +682,16 @@
"""
def expand(self, pcoll):
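+ # On Python 2, random.getrandbits(32) can exceed sys.maxint and yield a long.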
+ if sys.version_info >= (3,):
+ KeyedT = typing.Tuple[int, T]
+ else:
+ KeyedT = typing.Tuple[long, T] # pylint: disable=long-builtin
return (pcoll
| 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t))
+ .with_input_types(T).with_output_types(KeyedT)
| ReshufflePerKey()
- | 'RemoveRandomKeys' >> Map(lambda t: t[1]))
+ | 'RemoveRandomKeys' >> Map(lambda t: t[1])
+ .with_input_types(KeyedT).with_output_types(T))
def to_runner_api_parameter(self, unused_context):
return common_urns.composites.RESHUFFLE.urn, None
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 74829e5..58cf243 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -100,7 +100,7 @@
with TestPipeline() as p:
res = (
p
- | beam.Create(range(47))
+ | beam.Create(range(47), reshuffle=False)
| beam.Map(lambda t: window.TimestampedValue(t, t))
| beam.WindowInto(window.FixedWindows(30))
| util.BatchElements(
@@ -351,7 +351,8 @@
after_gbk = (pipeline
| beam.Create(data)
- | beam.GroupByKey())
+ | beam.GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
after_reshuffle = after_gbk | beam.Reshuffle()
assert_that(after_reshuffle, equal_to(expected_result),
@@ -435,7 +436,8 @@
before_reshuffle = (pipeline
| beam.Create(data)
| beam.WindowInto(GlobalWindows())
- | beam.GroupByKey())
+ | beam.GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(before_reshuffle, equal_to(expected_data),
label='before_reshuffle')
after_reshuffle = before_reshuffle | beam.Reshuffle()
@@ -452,7 +454,8 @@
| beam.Create(data)
| beam.WindowInto(SlidingWindows(
size=window_size, period=1))
- | beam.GroupByKey())
+ | beam.GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(before_reshuffle, equal_to(expected_data),
label='before_reshuffle')
after_reshuffle = before_reshuffle | beam.Reshuffle()
@@ -471,7 +474,8 @@
before_reshuffle = (pipeline
| beam.Create(data)
| beam.WindowInto(GlobalWindows())
- | beam.GroupByKey())
+ | beam.GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(before_reshuffle, equal_to(expected_data),
label='before_reshuffle')
after_reshuffle = before_reshuffle | beam.Reshuffle()
diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py
index a405948..30430cc 100644
--- a/sdks/python/apache_beam/transforms/window_test.py
+++ b/sdks/python/apache_beam/transforms/window_test.py
@@ -196,6 +196,7 @@
result = (pcoll
| 'w' >> WindowInto(SlidingWindows(period=2, size=4))
| GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs)))
| reify_windows)
expected = [('key @ [-2.0, 2.0)', [1]),
('key @ [0.0, 4.0)', [1, 2, 3]),
@@ -222,7 +223,8 @@
| Map(lambda x_t: TimestampedValue(x_t[0], x_t[1]))
| 'w' >> WindowInto(FixedWindows(5))
| Map(lambda v: ('key', v))
- | GroupByKey())
+ | GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(result, equal_to([('key', [0, 1, 2, 3, 4]),
('key', [5, 6, 7, 8, 9])]))
@@ -237,7 +239,8 @@
| 'rewindow' >> WindowInto(FixedWindows(5))
| 'rewindow2' >> WindowInto(FixedWindows(5))
| Map(lambda v: ('key', v))
- | GroupByKey())
+ | GroupByKey()
+ | beam.MapTuple(lambda k, vs: (k, sorted(vs))))
assert_that(result, equal_to([('key', sorted([0, 1, 2, 3, 4] * 3)),
('key', sorted([5, 6, 7, 8, 9] * 3))]))
diff --git a/website/src/documentation/dsls/sql/extensions/create-external-table.md b/website/src/documentation/dsls/sql/extensions/create-external-table.md
index 81d7dae..e331eff 100644
--- a/website/src/documentation/dsls/sql/extensions/create-external-table.md
+++ b/website/src/documentation/dsls/sql/extensions/create-external-table.md
@@ -308,6 +308,43 @@
Only simple types are supported.
+## MongoDB
+
+### Syntax
+
+```
+CREATE EXTERNAL TABLE [ IF NOT EXISTS ] tableName (tableElement [, tableElement ]*)
+TYPE mongodb
+LOCATION 'mongodb://[HOST]:[PORT]/[DATABASE]/[COLLECTION]'
+```
+* `LOCATION`: Location of the collection.
+ * `HOST`: Location of the MongoDB server. Can be `localhost` or an IP
+ address. When authentication is required, the username and password can
+ be specified as follows: `username:password@localhost`.
+ * `PORT`: Port on which the MongoDB server is listening.
+ * `DATABASE`: Database to connect to.
+ * `COLLECTION`: Collection within the database.
+
+### Read Mode
+
+Read Mode supports reading from a collection.
+
+### Write Mode
+
+Write Mode supports writing to a collection.
+
+### Schema
+
+Only simple types are supported. MongoDB documents are mapped to Beam SQL types via the [`JsonToRow`](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/transforms/JsonToRow.html) transform.
+
+### Example
+
+```
+CREATE EXTERNAL TABLE users (id INTEGER, username VARCHAR)
+TYPE mongodb
+LOCATION 'mongodb://localhost:27017/apache/users'
+```
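+
+To connect to a collection that requires authentication, embed the
+credentials in the location (replace `username` and `password` with real
+values):
+
+```
+CREATE EXTERNAL TABLE users (id INTEGER, username VARCHAR)
+TYPE mongodb
+LOCATION 'mongodb://username:password@localhost:27017/apache/users'
+```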
+
## Text
TextIO is experimental in Beam SQL. Read Mode and Write Mode do not currently