Merge pull request #14704: [BEAM-12253] Change Read.UnboundedSourceAsSDFRestrictionTracker.getSplitBacklog to use the reader cache
diff --git a/.test-infra/jenkins/CommonJobProperties.groovy b/.test-infra/jenkins/CommonJobProperties.groovy
index 40f143a..851fc0b 100644
--- a/.test-infra/jenkins/CommonJobProperties.groovy
+++ b/.test-infra/jenkins/CommonJobProperties.groovy
@@ -49,7 +49,7 @@
// Discard old builds. Build records are only kept up to this number of days.
context.logRotator {
- daysToKeep(14)
+ daysToKeep(30)
}
// Source code management.
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/MorePipelineTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/MorePipelineTest.java
index 267b4d5..369aad9 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/MorePipelineTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/MorePipelineTest.java
@@ -25,18 +25,14 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
-import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.GenerateSequence;
-import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverride;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
@@ -103,22 +99,13 @@
@Override
public PCollectionView<List<T>> expand(PCollection<T> input) {
+ PCollection<KV<Void, T>> materializationInput =
+ input.apply(new View.VoidKeyToMultimapMaterialization<>());
Coder<T> inputCoder = input.getCoder();
- PCollection<KV<Long, PCollectionViews.ValueOrMetadata<T, OffsetRange>>> materializationInput =
- input
- .apply("IndexElements", ParDo.of(new View.ToListViewDoFn<>()))
- .setCoder(
- KvCoder.of(
- BigEndianLongCoder.of(),
- PCollectionViews.ValueOrMetadataCoder.create(
- inputCoder, OffsetRange.Coder.of())));
PCollectionView<List<T>> view =
- PCollectionViews.listView(
+ PCollectionViews.listViewUsingVoidKey(
materializationInput,
- (TupleTag<
- Materializations.MultimapView<
- Long, PCollectionViews.ValueOrMetadata<T, OffsetRange>>>)
- originalView.getTagInternal(),
+ (TupleTag<Materializations.MultimapView<Void, T>>) originalView.getTagInternal(),
(PCollectionViews.TypeDescriptorSupplier<T>) inputCoder::getEncodedTypeDescriptor,
materializationInput.getWindowingStrategy());
materializationInput.apply(View.CreatePCollectionView.of(view));
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IsmSideInputReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IsmSideInputReader.java
index 35d865f..37fa6fc 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IsmSideInputReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IsmSideInputReader.java
@@ -69,17 +69,11 @@
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PCollectionViews.HasDefaultValue;
import org.apache.beam.sdk.values.PCollectionViews.IterableViewFn;
-import org.apache.beam.sdk.values.PCollectionViews.IterableViewFn2;
import org.apache.beam.sdk.values.PCollectionViews.ListViewFn;
-import org.apache.beam.sdk.values.PCollectionViews.ListViewFn2;
import org.apache.beam.sdk.values.PCollectionViews.MapViewFn;
-import org.apache.beam.sdk.values.PCollectionViews.MapViewFn2;
import org.apache.beam.sdk.values.PCollectionViews.MultimapViewFn;
-import org.apache.beam.sdk.values.PCollectionViews.MultimapViewFn2;
import org.apache.beam.sdk.values.PCollectionViews.SingletonViewFn;
-import org.apache.beam.sdk.values.PCollectionViews.SingletonViewFn2;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
@@ -110,13 +104,7 @@
private static final Object NULL_PLACE_HOLDER = new Object();
private static final ImmutableList<Class<? extends ViewFn>> KNOWN_SINGLETON_VIEW_TYPES =
- ImmutableList.of(
- SingletonViewFn.class,
- SingletonViewFn2.class,
- MapViewFn.class,
- MapViewFn2.class,
- MultimapViewFn.class,
- MultimapViewFn2.class);
+ ImmutableList.of(SingletonViewFn.class, MapViewFn.class, MultimapViewFn.class);
/**
* Limit the number of concurrent initializations.
@@ -314,7 +302,7 @@
// We handle the singleton case separately since a null value may be returned.
// We use a null place holder to represent this, and when we detect it, we translate
// back to null for the user.
- if (viewFn instanceof SingletonViewFn || viewFn instanceof SingletonViewFn2) {
+ if (viewFn instanceof SingletonViewFn) {
ViewT rval =
executionContext
.<PCollectionViewWindow<ViewT>, ViewT>getLogicalReferenceCache()
@@ -323,7 +311,7 @@
() -> {
@SuppressWarnings("unchecked")
ViewT viewT =
- getSingletonForWindow(tag, (HasDefaultValue<ViewT>) viewFn, window);
+ getSingletonForWindow(tag, (SingletonViewFn<ViewT>) viewFn, window);
@SuppressWarnings("unchecked")
ViewT nullPlaceHolder = (ViewT) NULL_PLACE_HOLDER;
return viewT == null ? nullPlaceHolder : viewT;
@@ -331,10 +319,7 @@
return rval == NULL_PLACE_HOLDER ? null : rval;
} else if (singletonMaterializedTags.contains(tag)) {
checkArgument(
- viewFn instanceof MapViewFn
- || viewFn instanceof MapViewFn2
- || viewFn instanceof MultimapViewFn
- || viewFn instanceof MultimapViewFn2,
+ viewFn instanceof MapViewFn || viewFn instanceof MultimapViewFn,
"Unknown view type stored as singleton. Expected one of %s, got %s",
KNOWN_SINGLETON_VIEW_TYPES,
viewFn.getClass().getName());
@@ -351,19 +336,15 @@
.get(
PCollectionViewWindow.of(view, window),
() -> {
- if (viewFn instanceof IterableViewFn
- || viewFn instanceof IterableViewFn2
- || viewFn instanceof ListViewFn
- || viewFn instanceof ListViewFn2) {
+ if (viewFn instanceof IterableViewFn || viewFn instanceof ListViewFn) {
@SuppressWarnings("unchecked")
ViewT viewT = (ViewT) getListForWindow(tag, window);
return viewT;
- } else if (viewFn instanceof MapViewFn || viewFn instanceof MapViewFn2) {
+ } else if (viewFn instanceof MapViewFn) {
@SuppressWarnings("unchecked")
ViewT viewT = (ViewT) getMapForWindow(tag, window);
return viewT;
- } else if (viewFn instanceof MultimapViewFn
- || viewFn instanceof MultimapViewFn2) {
+ } else if (viewFn instanceof MultimapViewFn) {
@SuppressWarnings("unchecked")
ViewT viewT = (ViewT) getMultimapForWindow(tag, window);
return viewT;
@@ -394,7 +375,7 @@
* </ul>
*/
private <T, W extends BoundedWindow> T getSingletonForWindow(
- TupleTag<?> viewTag, HasDefaultValue<T> viewFn, W window) throws IOException {
+ TupleTag<?> viewTag, SingletonViewFn<T> viewFn, W window) throws IOException {
@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"unchecked"
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
index a3b5665..63b4191 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.transforms;
+import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.io.IOException;
@@ -46,6 +47,7 @@
import org.apache.beam.sdk.transforms.CombineWithContext.Context;
import org.apache.beam.sdk.transforms.CombineWithContext.RequiresContextInternal;
import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
+import org.apache.beam.sdk.transforms.View.VoidKeyToMultimapMaterialization;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
@@ -1306,21 +1308,43 @@
@Override
public PCollectionView<OutputT> expand(PCollection<InputT> input) {
+ // TODO(BEAM-10097): Make this the default expansion for all portable runners.
+ if (hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+ && (hasExperiment(input.getPipeline().getOptions(), "use_runner_v2")
+ || hasExperiment(input.getPipeline().getOptions(), "use_unified_worker"))) {
+ PCollection<OutputT> combined =
+ input.apply(
+ "CombineValues",
+ Combine.<InputT, OutputT>globally(fn).withoutDefaults().withFanout(fanout));
+ Coder<OutputT> outputCoder = combined.getCoder();
+ PCollectionView<OutputT> view =
+ PCollectionViews.singletonView(
+ combined,
+ (TypeDescriptorSupplier<OutputT>)
+ () -> outputCoder != null ? outputCoder.getEncodedTypeDescriptor() : null,
+ input.getWindowingStrategy(),
+ insertDefault,
+ insertDefault ? fn.defaultValue() : null,
+ combined.getCoder());
+ combined.apply("CreatePCollectionView", CreatePCollectionView.of(view));
+ return view;
+ }
+
PCollection<OutputT> combined =
- input.apply(
- "CombineValues",
- Combine.<InputT, OutputT>globally(fn).withoutDefaults().withFanout(fanout));
+ input.apply(Combine.<InputT, OutputT>globally(fn).withoutDefaults().withFanout(fanout));
+ PCollection<KV<Void, OutputT>> materializationInput =
+ combined.apply(new VoidKeyToMultimapMaterialization<>());
Coder<OutputT> outputCoder = combined.getCoder();
PCollectionView<OutputT> view =
- PCollectionViews.singletonView(
- combined,
+ PCollectionViews.singletonViewUsingVoidKey(
+ materializationInput,
(TypeDescriptorSupplier<OutputT>)
() -> outputCoder != null ? outputCoder.getEncodedTypeDescriptor() : null,
input.getWindowingStrategy(),
insertDefault,
insertDefault ? fn.defaultValue() : null,
combined.getCoder());
- combined.apply("CreatePCollectionView", CreatePCollectionView.of(view));
+ materializationInput.apply(CreatePCollectionView.of(view));
return view;
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
index e81f0b8..904575c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.transforms;
+import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment;
+
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -27,6 +29,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -257,16 +260,33 @@
* Long#MIN_VALUE} key is used to store all known {@link OffsetRange ranges} allowing us to
* compute such an ordering.
*/
+
+ // TODO(BEAM-10097): Make this the default expansion for all portable runners.
+ if (hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+ && (hasExperiment(input.getPipeline().getOptions(), "use_runner_v2")
+ || hasExperiment(input.getPipeline().getOptions(), "use_unified_worker"))) {
+ Coder<T> inputCoder = input.getCoder();
+ PCollection<KV<Long, ValueOrMetadata<T, OffsetRange>>> materializationInput =
+ input
+ .apply("IndexElements", ParDo.of(new ToListViewDoFn<>()))
+ .setCoder(
+ KvCoder.of(
+ BigEndianLongCoder.of(),
+ ValueOrMetadataCoder.create(inputCoder, OffsetRange.Coder.of())));
+ PCollectionView<List<T>> view =
+ PCollectionViews.listView(
+ materializationInput,
+ (TypeDescriptorSupplier<T>) inputCoder::getEncodedTypeDescriptor,
+ input.getWindowingStrategy());
+ materializationInput.apply(CreatePCollectionView.of(view));
+ return view;
+ }
+
+ PCollection<KV<Void, T>> materializationInput =
+ input.apply(new VoidKeyToMultimapMaterialization<>());
Coder<T> inputCoder = input.getCoder();
- PCollection<KV<Long, ValueOrMetadata<T, OffsetRange>>> materializationInput =
- input
- .apply("IndexElements", ParDo.of(new ToListViewDoFn<>()))
- .setCoder(
- KvCoder.of(
- BigEndianLongCoder.of(),
- ValueOrMetadataCoder.create(inputCoder, OffsetRange.Coder.of())));
PCollectionView<List<T>> view =
- PCollectionViews.listView(
+ PCollectionViews.listViewUsingVoidKey(
materializationInput,
(TypeDescriptorSupplier<T>) inputCoder::getEncodedTypeDescriptor,
materializationInput.getWindowingStrategy());
@@ -280,8 +300,8 @@
* range for each window seen. We use random offset ranges to minimize the chance that two ranges
* overlap increasing the odds that each "key" represents a single index.
*/
- @Internal
- public static class ToListViewDoFn<T> extends DoFn<T, KV<Long, ValueOrMetadata<T, OffsetRange>>> {
+ private static class ToListViewDoFn<T>
+ extends DoFn<T, KV<Long, ValueOrMetadata<T, OffsetRange>>> {
private Map<BoundedWindow, OffsetRange> windowsToOffsets = new HashMap<>();
private OffsetRange generateRange(BoundedWindow window) {
@@ -330,19 +350,29 @@
throw new IllegalStateException("Unable to create a side-input view from input", e);
}
+ // TODO(BEAM-10097): Make this the default expansion for all portable runners.
+ if (hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+ && (hasExperiment(input.getPipeline().getOptions(), "use_runner_v2")
+ || hasExperiment(input.getPipeline().getOptions(), "use_unified_worker"))) {
+ Coder<T> inputCoder = input.getCoder();
+ PCollectionView<Iterable<T>> view =
+ PCollectionViews.iterableView(
+ input,
+ (TypeDescriptorSupplier<T>) inputCoder::getEncodedTypeDescriptor,
+ input.getWindowingStrategy());
+ input.apply(CreatePCollectionView.of(view));
+ return view;
+ }
+
+ PCollection<KV<Void, T>> materializationInput =
+ input.apply(new VoidKeyToMultimapMaterialization<>());
Coder<T> inputCoder = input.getCoder();
- // HACK to work around https://issues.apache.org/jira/browse/BEAM-12228:
- // There are bugs in "composite" vs "primitive" transform distinction
- // in TransformHierachy. This noop transform works around them and should be zero
- // cost.
- PCollection<T> materializationInput =
- input.apply(MapElements.via(new SimpleFunction<T, T>(x -> x) {}));
PCollectionView<Iterable<T>> view =
- PCollectionViews.iterableView(
+ PCollectionViews.iterableViewUsingVoidKey(
materializationInput,
(TypeDescriptorSupplier<T>) inputCoder::getEncodedTypeDescriptor,
materializationInput.getWindowingStrategy());
- input.apply(CreatePCollectionView.of(view));
+ materializationInput.apply(CreatePCollectionView.of(view));
return view;
}
}
@@ -478,22 +508,35 @@
throw new IllegalStateException("Unable to create a side-input view from input", e);
}
+ // TODO(BEAM-10097): Make this the default expansion for all portable runners.
+ if (hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+ && (hasExperiment(input.getPipeline().getOptions(), "use_runner_v2")
+ || hasExperiment(input.getPipeline().getOptions(), "use_unified_worker"))) {
+ KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
+ Coder<K> keyCoder = kvCoder.getKeyCoder();
+ Coder<V> valueCoder = kvCoder.getValueCoder();
+ PCollectionView<Map<K, Iterable<V>>> view =
+ PCollectionViews.multimapView(
+ input,
+ (TypeDescriptorSupplier<K>) keyCoder::getEncodedTypeDescriptor,
+ (TypeDescriptorSupplier<V>) valueCoder::getEncodedTypeDescriptor,
+ input.getWindowingStrategy());
+ input.apply(CreatePCollectionView.of(view));
+ return view;
+ }
+
KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
Coder<K> keyCoder = kvCoder.getKeyCoder();
Coder<V> valueCoder = kvCoder.getValueCoder();
- // HACK to work around https://issues.apache.org/jira/browse/BEAM-12228:
- // There are bugs in "composite" vs "primitive" transform distinction
- // in TransformHierachy. This noop transform works around them and should be zero
- // cost.
- PCollection<KV<K, V>> materializationInput =
- input.apply(MapElements.via(new SimpleFunction<KV<K, V>, KV<K, V>>(x -> x) {}));
+ PCollection<KV<Void, KV<K, V>>> materializationInput =
+ input.apply(new VoidKeyToMultimapMaterialization<>());
PCollectionView<Map<K, Iterable<V>>> view =
- PCollectionViews.multimapView(
+ PCollectionViews.multimapViewUsingVoidKey(
materializationInput,
(TypeDescriptorSupplier<K>) keyCoder::getEncodedTypeDescriptor,
(TypeDescriptorSupplier<V>) valueCoder::getEncodedTypeDescriptor,
materializationInput.getWindowingStrategy());
- input.apply(CreatePCollectionView.of(view));
+ materializationInput.apply(CreatePCollectionView.of(view));
return view;
}
}
@@ -524,19 +567,37 @@
throw new IllegalStateException("Unable to create a side-input view from input", e);
}
+ // TODO(BEAM-10097): Make this the default expansion for all portable runners.
+ if (hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
+ && (hasExperiment(input.getPipeline().getOptions(), "use_runner_v2")
+ || hasExperiment(input.getPipeline().getOptions(), "use_unified_worker"))) {
+ KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
+ Coder<K> keyCoder = kvCoder.getKeyCoder();
+ Coder<V> valueCoder = kvCoder.getValueCoder();
+
+ PCollectionView<Map<K, V>> view =
+ PCollectionViews.mapView(
+ input,
+ (TypeDescriptorSupplier<K>) keyCoder::getEncodedTypeDescriptor,
+ (TypeDescriptorSupplier<V>) valueCoder::getEncodedTypeDescriptor,
+ input.getWindowingStrategy());
+ input.apply(CreatePCollectionView.of(view));
+ return view;
+ }
+
KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
Coder<K> keyCoder = kvCoder.getKeyCoder();
Coder<V> valueCoder = kvCoder.getValueCoder();
- PCollection<KV<K, V>> materializationInput =
- input.apply(MapElements.via(new SimpleFunction<KV<K, V>, KV<K, V>>(x -> x) {}));
+ PCollection<KV<Void, KV<K, V>>> materializationInput =
+ input.apply(new VoidKeyToMultimapMaterialization<>());
PCollectionView<Map<K, V>> view =
- PCollectionViews.mapView(
+ PCollectionViews.mapViewUsingVoidKey(
materializationInput,
(TypeDescriptorSupplier<K>) keyCoder::getEncodedTypeDescriptor,
(TypeDescriptorSupplier<V>) valueCoder::getEncodedTypeDescriptor,
materializationInput.getWindowingStrategy());
- input.apply(CreatePCollectionView.of(view));
+ materializationInput.apply(CreatePCollectionView.of(view));
return view;
}
}
@@ -545,11 +606,34 @@
// Internal details below
/**
+ * A {@link PTransform} which converts all values into {@link KV}s with {@link Void} keys.
+ *
+ * <p>TODO(BEAM-10097): Replace this materialization with specializations that optimize the
+ * various SDK requested views.
+ */
+ @Internal
+ public static class VoidKeyToMultimapMaterialization<T>
+ extends PTransform<PCollection<T>, PCollection<KV<Void, T>>> {
+
+ private static class VoidKeyToMultimapMaterializationDoFn<T> extends DoFn<T, KV<Void, T>> {
+ @ProcessElement
+ public void processElement(@Element T element, OutputReceiver<KV<Void, T>> r) {
+ r.output(KV.of((Void) null, element));
+ }
+ }
+
+ @Override
+ public PCollection<KV<Void, T>> expand(PCollection<T> input) {
+ PCollection output = input.apply(ParDo.of(new VoidKeyToMultimapMaterializationDoFn<>()));
+ output.setCoder(KvCoder.of(VoidCoder.of(), input.getCoder()));
+ return output;
+ }
+ }
+
+ /**
* <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
*
- * <p>Placeholder transform for runners to have a hook to materialize a {@link PCollection} as a
- * side input. The metadata included in the {@link PCollectionView} is how the {@link PCollection}
- * will be read as a side input.
+ * <p>Creates a primitive {@link PCollectionView}.
*
* @param <ElemT> The type of the elements of the input PCollection
* @param <ViewT> The type associated with the {@link PCollectionView} used as a side input
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java
index 360c1af..df88e21 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java
@@ -120,7 +120,6 @@
*/
@Deprecated
public static <T, W extends BoundedWindow> PCollectionView<T> singletonViewUsingVoidKey(
- TupleTag<MultimapView<Void, T>> tag,
PCollection<KV<Void, T>> pCollection,
TypeDescriptorSupplier<T> typeDescriptorSupplier,
WindowingStrategy<?, W> windowingStrategy,
@@ -129,7 +128,6 @@
Coder<T> defaultValueCoder) {
return new SimplePCollectionView<>(
pCollection,
- tag,
new SingletonViewFn<>(hasDefault, defaultValue, defaultValueCoder, typeDescriptorSupplier),
windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
windowingStrategy);
@@ -158,13 +156,11 @@
*/
@Deprecated
public static <T, W extends BoundedWindow> PCollectionView<Iterable<T>> iterableViewUsingVoidKey(
- TupleTag<MultimapView<Void, T>> tag,
PCollection<KV<Void, T>> pCollection,
TypeDescriptorSupplier<T> typeDescriptorSupplier,
WindowingStrategy<?, W> windowingStrategy) {
return new SimplePCollectionView<>(
pCollection,
- tag,
new IterableViewFn<>(typeDescriptorSupplier),
windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
windowingStrategy);
@@ -188,35 +184,16 @@
/**
* Returns a {@code PCollectionView<List<T>>} capable of processing elements windowed using the
* provided {@link WindowingStrategy}.
- */
- public static <T, W extends BoundedWindow> PCollectionView<List<T>> listView(
- PCollection<KV<Long, ValueOrMetadata<T, OffsetRange>>> pCollection,
- TupleTag<Materializations.MultimapView<Long, ValueOrMetadata<T, OffsetRange>>> tag,
- TypeDescriptorSupplier<T> typeDescriptorSupplier,
- WindowingStrategy<?, W> windowingStrategy) {
- return new SimplePCollectionView<>(
- pCollection,
- tag,
- new ListViewFn2<>(typeDescriptorSupplier),
- windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
- windowingStrategy);
- }
-
- /**
- * Returns a {@code PCollectionView<List<T>>} capable of processing elements windowed using the
- * provided {@link WindowingStrategy}.
*
* @deprecated See {@link #listView}.
*/
@Deprecated
public static <T, W extends BoundedWindow> PCollectionView<List<T>> listViewUsingVoidKey(
- TupleTag<MultimapView<Void, T>> tag,
PCollection<KV<Void, T>> pCollection,
TypeDescriptorSupplier<T> typeDescriptorSupplier,
WindowingStrategy<?, W> windowingStrategy) {
return new SimplePCollectionView<>(
pCollection,
- tag,
new ListViewFn<>(typeDescriptorSupplier),
windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
windowingStrategy);
@@ -266,14 +243,12 @@
*/
@Deprecated
public static <K, V, W extends BoundedWindow> PCollectionView<Map<K, V>> mapViewUsingVoidKey(
- TupleTag<MultimapView<Void, KV<K, V>>> tag,
PCollection<KV<Void, KV<K, V>>> pCollection,
TypeDescriptorSupplier<K> keyTypeDescriptorSupplier,
TypeDescriptorSupplier<V> valueTypeDescriptorSupplier,
WindowingStrategy<?, W> windowingStrategy) {
return new SimplePCollectionView<>(
pCollection,
- tag,
new MapViewFn<>(keyTypeDescriptorSupplier, valueTypeDescriptorSupplier),
windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
windowingStrategy);
@@ -304,14 +279,12 @@
@Deprecated
public static <K, V, W extends BoundedWindow>
PCollectionView<Map<K, Iterable<V>>> multimapViewUsingVoidKey(
- TupleTag<MultimapView<Void, KV<K, V>>> tag,
PCollection<KV<Void, KV<K, V>>> pCollection,
TypeDescriptorSupplier<K> keyTypeDescriptorSupplier,
TypeDescriptorSupplier<V> valueTypeDescriptorSupplier,
WindowingStrategy<?, W> windowingStrategy) {
return new SimplePCollectionView<>(
pCollection,
- tag,
new MultimapViewFn<>(keyTypeDescriptorSupplier, valueTypeDescriptorSupplier),
windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
windowingStrategy);
@@ -339,9 +312,7 @@
* <p>{@link SingletonViewFn} is meant to be removed in the future and replaced with this class.
*/
@Experimental(Kind.CORE_RUNNERS_ONLY)
- @Internal
- public static class SingletonViewFn2<T> extends ViewFn<IterableView<T>, T>
- implements HasDefaultValue<T> {
+ private static class SingletonViewFn2<T> extends ViewFn<IterableView<T>, T> {
private byte @Nullable [] encodedDefaultValue;
private transient @Nullable T defaultValue;
private @Nullable Coder<T> valueCoder;
@@ -379,7 +350,6 @@
*
* @throws NoSuchElementException if no default was specified.
*/
- @Override
public T getDefaultValue() {
if (!hasDefault) {
throw new NoSuchElementException("Empty PCollection accessed as a singleton view.");
@@ -423,11 +393,6 @@
}
}
- @Internal
- public interface HasDefaultValue<T> {
- T getDefaultValue();
- }
-
/**
* Implementation which is able to adapt a multimap materialization to a {@code T}.
*
@@ -437,8 +402,7 @@
*/
@Deprecated
@Experimental(Kind.CORE_RUNNERS_ONLY)
- public static class SingletonViewFn<T> extends ViewFn<MultimapView<Void, T>, T>
- implements HasDefaultValue<T> {
+ public static class SingletonViewFn<T> extends ViewFn<MultimapView<Void, T>, T> {
private byte @Nullable [] encodedDefaultValue;
private transient @Nullable T defaultValue;
private @Nullable Coder<T> valueCoder;
@@ -476,7 +440,6 @@
*
* @throws NoSuchElementException if no default was specified.
*/
- @Override
public T getDefaultValue() {
if (!hasDefault) {
throw new NoSuchElementException("Empty PCollection accessed as a singleton view.");
@@ -530,8 +493,7 @@
* <p>{@link IterableViewFn} is meant to be removed in the future and replaced with this class.
*/
@Experimental(Kind.CORE_RUNNERS_ONLY)
- @Internal
- public static class IterableViewFn2<T> extends ViewFn<IterableView<T>, Iterable<T>> {
+ private static class IterableViewFn2<T> extends ViewFn<IterableView<T>, Iterable<T>> {
private TypeDescriptorSupplier<T> typeDescriptorSupplier;
public IterableViewFn2(TypeDescriptorSupplier<T> typeDescriptorSupplier) {
@@ -597,7 +559,7 @@
*/
@Experimental(Kind.CORE_RUNNERS_ONLY)
@VisibleForTesting
- public static class ListViewFn2<T>
+ static class ListViewFn2<T>
extends ViewFn<MultimapView<Long, ValueOrMetadata<T, OffsetRange>>, List<T>> {
private TypeDescriptorSupplier<T> typeDescriptorSupplier;
@@ -1041,8 +1003,7 @@
* <p>{@link MultimapViewFn} is meant to be removed in the future and replaced with this class.
*/
@Experimental(Kind.CORE_RUNNERS_ONLY)
- @Internal
- public static class MultimapViewFn2<K, V>
+ private static class MultimapViewFn2<K, V>
extends ViewFn<MultimapView<K, V>, Map<K, Iterable<V>>> {
private TypeDescriptorSupplier<K> keyTypeDescriptorSupplier;
private TypeDescriptorSupplier<V> valueTypeDescriptorSupplier;
@@ -1130,8 +1091,7 @@
*
* <p>{@link MapViewFn} is meant to be removed in the future and replaced with this class.
*/
- @Internal
- public static class MapViewFn2<K, V> extends ViewFn<MultimapView<K, V>, Map<K, V>> {
+ private static class MapViewFn2<K, V> extends ViewFn<MultimapView<K, V>, Map<K, V>> {
private TypeDescriptorSupplier<K> keyTypeDescriptorSupplier;
private TypeDescriptorSupplier<V> valueTypeDescriptorSupplier;
@@ -1319,13 +1279,7 @@
@Override
public String toString() {
- return MoreObjects.toStringHelper(this)
- .add("tag", tag)
- .add("viewFn", viewFn)
- .add("coder", coder)
- .add("windowMappingFn", windowMappingFn)
- .add("pCollection", pCollection)
- .toString();
+ return MoreObjects.toStringHelper(this).add("tag", tag).toString();
}
@Override
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java
index 61b4bf8..b340d91 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PCollectionViewTesting.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.testing;
+import static org.apache.beam.sdk.options.ExperimentalOptions.hasExperiment;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -40,42 +42,82 @@
// materializations will differ but test code should not worry about what these look like if
// they are relying on the ViewFn to "undo" the conversion.
- if (View.AsSingleton.class.equals(viewTransformClass.getClass())) {
- for (Object value : values) {
- rval.add(value);
- }
- } else if (View.AsIterable.class.equals(viewTransformClass.getClass())) {
- for (Object value : values) {
- rval.add(value);
- }
- } else if (View.AsList.class.equals(viewTransformClass.getClass())) {
- if (values.length > 0) {
- rval.add(
- KV.of(
- Long.MIN_VALUE, ValueOrMetadata.createMetadata(new OffsetRange(0, values.length))));
- for (int i = 0; i < values.length; ++i) {
- rval.add(KV.of((long) i, ValueOrMetadata.create(values[i])));
+ // TODO(BEAM-10097): Make this the default case once all portable runners can support
+ // the iterable access pattern.
+ if (hasExperiment(options, "beam_fn_api")
+ && (hasExperiment(options, "use_runner_v2")
+ || hasExperiment(options, "use_unified_worker"))) {
+ if (View.AsSingleton.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(value);
}
- }
- } else if (View.AsMap.class.equals(viewTransformClass.getClass())) {
- for (Object value : values) {
- rval.add(value);
- }
- } else if (View.AsMultimap.class.equals(viewTransformClass.getClass())) {
- for (Object value : values) {
- rval.add(value);
+ } else if (View.AsIterable.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(value);
+ }
+ } else if (View.AsList.class.equals(viewTransformClass.getClass())) {
+ if (values.length > 0) {
+ rval.add(
+ KV.of(
+ Long.MIN_VALUE,
+ ValueOrMetadata.createMetadata(new OffsetRange(0, values.length))));
+ for (int i = 0; i < values.length; ++i) {
+ rval.add(KV.of((long) i, ValueOrMetadata.create(values[i])));
+ }
+ }
+ } else if (View.AsMap.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(value);
+ }
+ } else if (View.AsMultimap.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(value);
+ }
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "Unknown type of view %s. Supported views are %s.",
+ viewTransformClass.getClass(),
+ ImmutableSet.of(
+ View.AsSingleton.class,
+ View.AsIterable.class,
+ View.AsList.class,
+ View.AsMap.class,
+ View.AsMultimap.class)));
}
} else {
- throw new IllegalArgumentException(
- String.format(
- "Unknown type of view %s. Supported views are %s.",
- viewTransformClass.getClass(),
- ImmutableSet.of(
- View.AsSingleton.class,
- View.AsIterable.class,
- View.AsList.class,
- View.AsMap.class,
- View.AsMultimap.class)));
+ if (View.AsSingleton.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(KV.of(null, value));
+ }
+ } else if (View.AsIterable.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(KV.of(null, value));
+ }
+ } else if (View.AsList.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(KV.of(null, value));
+ }
+ } else if (View.AsMap.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(KV.of(null, value));
+ }
+ } else if (View.AsMultimap.class.equals(viewTransformClass.getClass())) {
+ for (Object value : values) {
+ rval.add(KV.of(null, value));
+ }
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "Unknown type of view %s. Supported views are %s.",
+ viewTransformClass.getClass(),
+ ImmutableSet.of(
+ View.AsSingleton.class,
+ View.AsIterable.class,
+ View.AsList.class,
+ View.AsMap.class,
+ View.AsMultimap.class)));
+ }
}
return Collections.unmodifiableList(rval);
}
diff --git a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java
index ec328d8..0a48131 100644
--- a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java
+++ b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/sources/UnboundedEventSource.java
@@ -98,7 +98,7 @@
* Current backlog, as estimated number of event bytes we are behind, or null if unknown.
* Reported to callers.
*/
- private @Nullable Long backlogBytes;
+ private long backlogBytes;
/** Wallclock time (ms since epoch) we last reported the backlog, or -1 if never reported. */
private long lastReportedBacklogWallclock;
@@ -127,6 +127,7 @@
lastReportedBacklogWallclock = -1;
pendingEventWallclockTime = -1;
timestampAtLastReportedBacklogMs = -1;
+ updateBacklog(System.currentTimeMillis(), 0);
}
public EventReader(GeneratorConfig config) {
@@ -146,9 +147,7 @@
while (pendingEvent == null) {
if (!generator.hasNext() && heldBackEvents.isEmpty()) {
// No more events, EVER.
- if (isRateLimited) {
- updateBacklog(System.currentTimeMillis(), 0);
- }
+ updateBacklog(System.currentTimeMillis(), 0);
if (watermark < BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) {
watermark = BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis();
LOG.trace("stopped unbounded generator {}", generator);
@@ -177,9 +176,7 @@
}
} else {
// Waiting for held-back event to fire.
- if (isRateLimited) {
- updateBacklog(now, 0);
- }
+ updateBacklog(now, 0);
return false;
}
@@ -199,6 +196,8 @@
return false;
}
updateBacklog(now, now - pendingEventWallclockTime);
+ } else {
+ updateBacklog(now, 0);
}
// This event is ready to fire.
@@ -210,20 +209,26 @@
private void updateBacklog(long now, long newBacklogDurationMs) {
backlogDurationMs = newBacklogDurationMs;
long interEventDelayUs = generator.currentInterEventDelayUs();
- if (interEventDelayUs != 0) {
+ if (isRateLimited && interEventDelayUs > 0) {
long backlogEvents = (backlogDurationMs * 1000 + interEventDelayUs - 1) / interEventDelayUs;
backlogBytes = generator.getCurrentConfig().estimatedBytesForEvents(backlogEvents);
+ } else {
+ double fractionRemaining = 1.0 - generator.getFractionConsumed();
+ backlogBytes =
+ Math.max(
+ 0L,
+ (long) (generator.getCurrentConfig().getEstimatedSizeBytes() * fractionRemaining));
}
if (lastReportedBacklogWallclock < 0
|| now - lastReportedBacklogWallclock > BACKLOG_PERIOD.getMillis()) {
- double timeDialation = Double.NaN;
+ double timeDilation = Double.NaN;
if (pendingEvent != null
&& lastReportedBacklogWallclock >= 0
&& timestampAtLastReportedBacklogMs >= 0) {
long wallclockProgressionMs = now - lastReportedBacklogWallclock;
long eventTimeProgressionMs =
pendingEvent.getTimestamp().getMillis() - timestampAtLastReportedBacklogMs;
- timeDialation = (double) eventTimeProgressionMs / (double) wallclockProgressionMs;
+ timeDilation = (double) eventTimeProgressionMs / (double) wallclockProgressionMs;
}
LOG.debug(
"unbounded generator backlog now {}ms ({} bytes) at {}us interEventDelay "
@@ -231,7 +236,7 @@
backlogDurationMs,
backlogBytes,
interEventDelayUs,
- timeDialation);
+ timeDilation);
lastReportedBacklogWallclock = now;
if (pendingEvent != null) {
timestampAtLastReportedBacklogMs = pendingEvent.getTimestamp().getMillis();
@@ -277,7 +282,7 @@
@Override
public long getSplitBacklogBytes() {
- return backlogBytes == null ? BACKLOG_UNKNOWN : backlogBytes;
+ return backlogBytes;
}
@Override
diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py
index 80ee149..54a15c5 100644
--- a/sdks/python/apache_beam/dataframe/expressions.py
+++ b/sdks/python/apache_beam/dataframe/expressions.py
@@ -348,10 +348,16 @@
preserves_partition_by: The level of partitioning preserved.
"""
if (not _get_allow_non_parallel() and
- requires_partition_by == partitionings.Singleton()):
+ isinstance(requires_partition_by, partitionings.Singleton)):
+ reason = requires_partition_by.reason or (
+ f"Encountered non-parallelizable form of {name!r}.")
+
raise NonParallelOperation(
- "Using non-parallel form of %s "
- "outside of allow_non_parallel_operations block." % name)
+ f"{reason}\n"
+ "Consider using an allow_non_parallel_operations block if you're "
+ "sure you want to do this. See "
+ "https://s.apache.org/dataframe-non-parallel-operations for more "
+ "information.")
args = tuple(args)
if proxy is None:
proxy = func(*(arg.proxy() for arg in args))
@@ -406,4 +412,6 @@
class NonParallelOperation(Exception):
- pass
+ def __init__(self, msg):
+ super(NonParallelOperation, self).__init__(self, msg)
+ self.msg = msg
diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index 1f892f4..25355fe 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -109,7 +109,11 @@
if index is not None and errors == 'raise':
# In order to raise an error about missing index values, we'll
# need to collect the entire dataframe.
- requires = partitionings.Singleton()
+ requires = partitionings.Singleton(
+ reason=(
+ "drop(errors='raise', axis='index') is not currently "
+ "parallelizable. This requires collecting all data on a single "
+ f"node in order to detect if one of {index!r} is missing."))
else:
requires = partitionings.Arbitrary()
@@ -142,24 +146,26 @@
def fillna(self, value, method, axis, limit, **kwargs):
# Default value is None, but is overriden with index.
axis = axis or 'index'
- if method is not None and axis in (0, 'index'):
- raise frame_base.WontImplementError(
- f"fillna(method={method!r}) is not supported because it is "
- "order-sensitive. Only fillna(method=None) is supported.",
- reason="order-sensitive")
+
+ if axis in (0, 'index'):
+ if method is not None:
+ raise frame_base.WontImplementError(
+ f"fillna(method={method!r}, axis={axis!r}) is not supported "
+ "because it is order-sensitive. Only fillna(method=None) is "
+ f"supported with axis={axis!r}.",
+ reason="order-sensitive")
+ if limit is not None:
+ raise frame_base.WontImplementError(
+ f"fillna(limit={method!r}, axis={axis!r}) is not supported because "
+ "it is order-sensitive. Only fillna(limit=None) is supported with "
+ f"axis={axis!r}.",
+ reason="order-sensitive")
+
if isinstance(value, frame_base.DeferredBase):
value_expr = value._expr
else:
value_expr = expressions.ConstantExpression(value)
- if limit is not None and method is None:
- # If method is not None (and axis is 'columns'), we can do limit in
- # a distributed way. Otherwise the limit is global, so it requires
- # Singleton partitioning.
- requires = partitionings.Singleton()
- else:
- requires = partitionings.Arbitrary()
-
return frame_base.DeferredFrame.wrap(
# yapf: disable
expressions.ComputedExpression(
@@ -169,7 +175,7 @@
value, method=method, axis=axis, limit=limit, **kwargs),
[self._expr, value_expr],
preserves_partition_by=partitionings.Arbitrary(),
- requires_partition_by=requires))
+ requires_partition_by=partitionings.Arbitrary()))
@frame_base.args_to_kwargs(pd.DataFrame)
@frame_base.populate_defaults(pd.DataFrame)
@@ -523,7 +529,11 @@
if errors == "ignore":
# We need all data in order to ignore errors and propagate the original
# data.
- requires = partitionings.Singleton()
+ requires = partitionings.Singleton(
+ reason=(
+ f"where(errors={errors!r}) is currently not parallelizable, "
+ "because all data must be collected on one node to determine if "
+ "the original data should be propagated instead."))
actual_args['errors'] = errors
@@ -668,10 +678,8 @@
reason="order-sensitive")
if verify_integrity:
- # verifying output has a unique index requires global index.
- # TODO(BEAM-11839): Attach an explanation to the Singleton partitioning
- # requirement, and include it in raised errors.
- requires = partitionings.Singleton()
+ # We can verify the index is non-unique within index partitioned data.
+ requires = partitionings.Index()
else:
requires = partitionings.Arbitrary()
@@ -750,7 +758,12 @@
right = other._expr
right_is_series = False
else:
- raise frame_base.WontImplementError('non-deferred result')
+ raise frame_base.WontImplementError(
+ "other must be a DeferredDataFrame or DeferredSeries instance. "
+ "Passing a concrete list or numpy array is not supported. Those "
+ "types have no index and must be joined based on the order of the "
+ "data.",
+ reason="order-sensitive")
dots = expressions.ComputedExpression(
'dot',
@@ -838,6 +851,10 @@
return x._corr_aligned(y, min_periods)
else:
+ reason = (
+ f"Encountered corr(method={method!r}) which cannot be "
+ "parallelized. Only corr(method='pearson') is currently "
+ "parallelizable.")
# The rank-based correlations are not obviously parallelizable, though
# perhaps an approximation could be done with a knowledge of quantiles
# and custom partitioning.
@@ -847,9 +864,7 @@
lambda df,
other: df.corr(other, method=method, min_periods=min_periods),
[self._expr, other._expr],
- # TODO(BEAM-11839): Attach an explanation to the Singleton
- # partitioning requirement, and include it in raised errors.
- requires_partition_by=partitionings.Singleton()))
+ requires_partition_by=partitionings.Singleton(reason=reason)))
def _corr_aligned(self, other, min_periods):
std_x = self.std()
@@ -958,9 +973,16 @@
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'aggregate',
- lambda s: s.agg(func, *args, **kwargs), [intermediate],
+ lambda s: s.agg(func, *args, **kwargs),
+ [intermediate],
preserves_partition_by=partitionings.Arbitrary(),
- requires_partition_by=partitionings.Singleton()))
+ # TODO(BEAM-11839): This reason should be more specific. It's
+ # actually incorrect for the args/kwargs case above.
+ requires_partition_by=partitionings.Singleton(
+ reason=(
+ f"Aggregation function {func!r} cannot currently be "
+ "parallelized, it requires collecting all data for "
+ "this Series on a single node."))))
agg = aggregate
@@ -1119,7 +1141,10 @@
if limit is None:
requires_partition_by = partitionings.Arbitrary()
else:
- requires_partition_by = partitionings.Singleton()
+ requires_partition_by = partitionings.Singleton(
+ reason=(
+ f"replace(limit={limit!r}) cannot currently be parallelized, it "
+ "requires collecting all data on a single node."))
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'replace',
@@ -1154,7 +1179,8 @@
'unique',
lambda df: pd.Series(df.unique()), [self._expr],
preserves_partition_by=partitionings.Singleton(),
- requires_partition_by=partitionings.Singleton()))
+ requires_partition_by=partitionings.Singleton(
+ reason="unique() cannot currently be parallelized.")))
def update(self, other):
self._expr = expressions.ComputedExpression(
@@ -1242,7 +1268,8 @@
elif _is_integer_slice(key):
# This depends on the contents of the index.
raise frame_base.WontImplementError(
- 'Use iloc or loc with integer slices.')
+ "Integer slices are not supported as they are ambiguous. Please "
+ "use iloc or loc with integer slices.")
else:
return self.loc[key]
@@ -1278,7 +1305,10 @@
@frame_base.populate_defaults(pd.DataFrame)
def align(self, other, join, axis, copy, level, method, **kwargs):
if not copy:
- raise frame_base.WontImplementError('align(copy=False)')
+ raise frame_base.WontImplementError(
+ "align(copy=False) is not supported because it might be an inplace "
+ "operation depending on the data. Please prefer the default "
+ "align(copy=True).")
if method is not None:
raise frame_base.WontImplementError(
f"align(method={method!r}) is not supported because it is "
@@ -1289,7 +1319,9 @@
if level is not None:
# Could probably get by partitioning on the used levels.
- requires_partition_by = partitionings.Singleton()
+ requires_partition_by = partitionings.Singleton(reason=(
+ f"align(level={level}) is not currently parallelizable. Only "
+ "align(level=None) can be parallelized."))
elif axis in ('columns', 1):
requires_partition_by = partitionings.Arbitrary()
else:
@@ -1314,16 +1346,21 @@
"append(ignore_index=True) is order sensitive because it requires "
"generating a new index based on the order of the data.",
reason="order-sensitive")
+
if verify_integrity:
- raise frame_base.WontImplementError(
- "append(verify_integrity=True) produces an execution time error")
+ # We can verify the index is non-unique within index partitioned data.
+ requires = partitionings.Index()
+ else:
+ requires = partitionings.Arbitrary()
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'append',
- lambda s, other: s.append(other, sort=sort, **kwargs),
+ lambda s, other: s.append(other, sort=sort,
+ verify_integrity=verify_integrity,
+ **kwargs),
[self._expr, other._expr],
- requires_partition_by=partitionings.Arbitrary(),
+ requires_partition_by=requires,
preserves_partition_by=partitionings.Arbitrary()
)
)
@@ -1391,8 +1428,6 @@
preserves_partition_by=preserves,
requires_partition_by=partitionings.Arbitrary()))
-
-
def aggregate(self, func, axis=0, *args, **kwargs):
if axis is None:
# Aggregate across all elements by first aggregating across columns,
@@ -1414,6 +1449,7 @@
'aggregate',
lambda df: df.agg(func, *args, **kwargs),
[self._expr],
+ # TODO(BEAM-11839): Provide a reason for this Singleton
requires_partition_by=partitionings.Singleton()))
else:
# In the general case, compute the aggregation of each column separately,
@@ -1499,12 +1535,15 @@
proxy=proxy))
else:
+ reason = (f"Encountered corr(method={method!r}) which cannot be "
+ "parallelized. Only corr(method='pearson') is currently "
+ "parallelizable.")
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'corr',
lambda df: df.corr(method=method, min_periods=min_periods),
[self._expr],
- requires_partition_by=partitionings.Singleton()))
+ requires_partition_by=partitionings.Singleton(reason=reason)))
@frame_base.args_to_kwargs(pd.DataFrame)
@frame_base.populate_defaults(pd.DataFrame)
@@ -1653,8 +1692,12 @@
'mode',
lambda df: df.mode(*args, **kwargs),
[self._expr],
- #TODO(robertwb): Approximate?
- requires_partition_by=partitionings.Singleton(),
+ #TODO(BEAM-12181): Can we add an approximate implementation?
+ requires_partition_by=partitionings.Singleton(reason=(
+ "mode(axis='index') cannot currently be parallelized. See "
+ "BEAM-12181 tracking the possble addition of an approximate, "
+ "parallelizable implementation of mode."
+ )),
preserves_partition_by=partitionings.Singleton()))
@frame_base.args_to_kwargs(pd.DataFrame)
@@ -1662,8 +1705,12 @@
@frame_base.maybe_inplace
def dropna(self, axis, **kwargs):
# TODO(robertwb): This is a common pattern. Generalize?
- if axis == 1 or axis == 'columns':
- requires_partition_by = partitionings.Singleton()
+ if axis in (1, 'columns'):
+ requires_partition_by = partitionings.Singleton(reason=(
+ "dropna(axis=1) cannot currently be parallelized. It requires "
+ "checking all values in each column for NaN values, to determine "
+ "if that column should be dropped."
+ ))
else:
requires_partition_by = partitionings.Arbitrary()
return frame_base.DeferredFrame.wrap(
@@ -1913,8 +1960,11 @@
requires_partition_by = partitionings.Arbitrary()
preserves_partition_by = partitionings.Index()
else:
- # TODO: This could be implemented in a distributed fashion
- requires_partition_by = partitionings.Singleton()
+ # TODO(BEAM-9547): This could be implemented in a distributed fashion,
+ # perhaps by deferring to a distributed drop_duplicates
+ requires_partition_by = partitionings.Singleton(reason=(
+ "nunique(axis='index') is not currently parallelizable."
+ ))
preserves_partition_by = partitionings.Singleton()
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
@@ -1941,22 +1991,31 @@
@frame_base.args_to_kwargs(pd.DataFrame)
@frame_base.populate_defaults(pd.DataFrame)
def quantile(self, q, axis, **kwargs):
- if axis in (1, 'columns') and isinstance(q, list):
- raise frame_base.WontImplementError(
- "quantile(axis=columns) with multiple q values is not supported "
- "because it transposes the input DataFrame. Note computing "
- "an individual quantile across columns (e.g. "
- f"df.quantile(q={q[0]!r}, axis={axis!r}) is supported.",
- reason="non-deferred-columns")
+ if axis in (1, 'columns'):
+ if isinstance(q, list):
+ raise frame_base.WontImplementError(
+ "quantile(axis=columns) with multiple q values is not supported "
+ "because it transposes the input DataFrame. Note computing "
+ "an individual quantile across columns (e.g. "
+ f"df.quantile(q={q[0]!r}, axis={axis!r}) is supported.",
+ reason="non-deferred-columns")
+ else:
+ requires = partitionings.Arbitrary()
+ else: # axis='index'
+ # TODO(BEAM-12167): Provide an option for approximate distributed
+ # quantiles
+ requires = partitionings.Singleton(reason=(
+ "Computing quantiles across index cannot currently be parallelized. "
+ "See BEAM-12167 tracking the possible addition of an approximate, "
+ "parallelizable implementation of quantile."
+ ))
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'quantile',
lambda df: df.quantile(q=q, axis=axis, **kwargs),
[self._expr],
- # TODO(BEAM-12167): Provide an option for approximate distributed
- # quantiles
- requires_partition_by=partitionings.Singleton(),
+ requires_partition_by=requires,
preserves_partition_by=partitionings.Singleton()))
@frame_base.args_to_kwargs(pd.DataFrame)
@@ -1978,8 +2037,15 @@
preserves_partition_by = partitionings.Index()
if kwargs.get('errors', None) == 'raise' and rename_index:
- # Renaming index with checking requires global index.
- requires_partition_by = partitionings.Singleton()
+ # TODO: We could do this in parallel by creating a ConstantExpression
+ # with a series created from the mapper dict. Then Index() partitioning
+ # would co-locate the necessary index values and we could raise
+ # individually within each partition. Execution time errors are
+ # discouraged anyway so probably not worth the effort.
+ requires_partition_by = partitionings.Singleton(reason=(
+ "rename(errors='raise', axis='index') requires collecting all "
+ "data on a single node in order to detect missing index values."
+ ))
else:
requires_partition_by = partitionings.Arbitrary()
@@ -2014,7 +2080,9 @@
if limit is None:
requires_partition_by = partitionings.Arbitrary()
else:
- requires_partition_by = partitionings.Singleton()
+ requires_partition_by = partitionings.Singleton(reason=(
+ f"replace(limit={limit!r}) cannot currently be parallelized, it "
+ "requires collecting all data on a single node."))
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'replace',
@@ -2032,8 +2100,11 @@
if level is not None and not isinstance(level, (tuple, list)):
level = [level]
if level is None or len(level) == self._expr.proxy().index.nlevels:
- # TODO: Could do distributed re-index with offsets.
- requires_partition_by = partitionings.Singleton()
+ # TODO(BEAM-12182): Could do distributed re-index with offsets.
+ requires_partition_by = partitionings.Singleton(reason=(
+ "reset_index(level={level!r}) drops the entire index and creates a "
+ "new one, so it cannot currently be parallelized (BEAM-12182)."
+ ))
else:
requires_partition_by = partitionings.Arbitrary()
return frame_base.DeferredFrame.wrap(
@@ -2070,20 +2141,37 @@
@frame_base.args_to_kwargs(pd.DataFrame)
@frame_base.populate_defaults(pd.DataFrame)
- def shift(self, axis, **kwargs):
- if 'freq' in kwargs:
- raise frame_base.WontImplementError('data-dependent')
- if axis == 1 or axis == 'columns':
- requires_partition_by = partitionings.Arbitrary()
+ def shift(self, axis, freq, **kwargs):
+ if axis in (1, 'columns'):
+ preserves = partitionings.Arbitrary()
+ proxy = None
else:
- requires_partition_by = partitionings.Singleton()
+ if freq is None or 'fill_value' in kwargs:
+ fill_value = kwargs.get('fill_value', 'NOT SET')
+ raise frame_base.WontImplementError(
+ f"shift(axis={axis!r}) is only supported with freq defined, and "
+ f"fill_value undefined (got freq={freq!r},"
+ f"fill_value={fill_value!r}). Other configurations are sensitive "
+ "to the order of the data because they require populating shifted "
+ "rows with `fill_value`.",
+ reason="order-sensitive")
+ # proxy generation fails in pandas <1.2
+ # Seems due to https://github.com/pandas-dev/pandas/issues/14811,
+ # bug with shift on empty indexes.
+ # Fortunately the proxy should be identical to the input.
+ proxy = self._expr.proxy().copy()
+
+ # index is modified, so no partitioning is preserved.
+ preserves = partitionings.Singleton()
+
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'shift',
- lambda df: df.shift(axis=axis, **kwargs),
+ lambda df: df.shift(axis=axis, freq=freq, **kwargs),
[self._expr],
- preserves_partition_by=partitionings.Singleton(),
- requires_partition_by=requires_partition_by))
+ proxy=proxy,
+ preserves_partition_by=preserves,
+ requires_partition_by=partitionings.Arbitrary()))
shape = property(frame_base.wont_implement_method(
pd.DataFrame, 'shape', reason="non-deferred-result"))
@@ -2388,7 +2476,10 @@
df.groupby(level=list(range(df.index.nlevels)), **groupby_kwargs),
**kwargs),
[pre_agg],
- requires_partition_by=(partitionings.Singleton()
+ requires_partition_by=(partitionings.Singleton(reason=(
+ "Aggregations grouped by a categorical column are not currently "
+ "parallelizable (BEAM-11190)."
+ ))
if is_categorical_grouping
else partitionings.Index()),
preserves_partition_by=partitionings.Arbitrary())
@@ -2416,7 +2507,10 @@
**groupby_kwargs),
), **kwargs),
[self._ungrouped],
- requires_partition_by=(partitionings.Singleton()
+ requires_partition_by=(partitionings.Singleton(reason=(
+ "Aggregations grouped by a categorical column are not currently "
+ "parallelizable (BEAM-11190)."
+ ))
if is_categorical_grouping
else partitionings.Index()),
preserves_partition_by=partitionings.Arbitrary())
@@ -2633,7 +2727,10 @@
def cat(self, others, join, **kwargs):
if others is None:
# Concatenate series into a single String
- requires = partitionings.Singleton()
+ requires = partitionings.Singleton(reason=(
+ "cat(others=None) concatenates all data in a Series into a single "
+ "string, so it requires collecting all data on a single node."
+ ))
func = lambda df: df.str.cat(join=join, **kwargs)
args = [self._expr]
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index b692f08..1cf1dfb 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -45,12 +45,15 @@
class DeferredFrameTest(unittest.TestCase):
- def _run_error_test(self, func, *args):
+ def _run_error_test(
+ self, func, *args, construction_time=True, distributed=True):
"""Verify that func(*args) raises the same exception in pandas and in Beam.
- Note that for Beam this only checks for exceptions that are raised during
- expression generation (i.e. construction time). Execution time exceptions
- are not helpful."""
+ Note that by default this only checks for exceptions that the Beam DataFrame
+ API raises during expression generation (i.e. construction time).
+ Exceptions raised while the pipeline is executing are less helpful, but
+ are sometimes unavoidable (e.g. data validation exceptions). To check for
+ these exceptions, use construction_time=False."""
deferred_args = _get_deferred_args(*args)
# Get expected error
@@ -64,14 +67,29 @@
f"returned:\n{expected}")
# Get actual error
- try:
- _ = func(*deferred_args)._expr
- except Exception as e:
- actual = e
- else:
- raise AssertionError(
- "Expected an error:\n{expected_error}\nbut Beam successfully "
- "generated an expression.")
+ if construction_time:
+ try:
+ _ = func(*deferred_args)._expr
+ except Exception as e:
+ actual = e
+ else:
+ raise AssertionError(
+ f"Expected an error:\n{expected_error}\nbut Beam successfully "
+ f"generated an expression.")
+ else: # not construction_time
+ # Check for an error raised during pipeline execution
+ expr = func(*deferred_args)._expr
+ session_type = (
+ expressions.PartitioningSession
+ if distributed else expressions.Session)
+ try:
+ result = session_type({}).evaluate(expr)
+ except Exception as e:
+ actual = e
+ else:
+ raise AssertionError(
+ f"Expected an error:\n{expected_error}\nbut Beam successfully "
+ f"Computed the result:\n{result}.")
# Verify
if (not isinstance(actual, type(expected_error)) or
@@ -99,8 +117,15 @@
deferred_args = _get_deferred_args(*args)
if nonparallel:
# First run outside a nonparallel block to confirm this raises as expected
- with self.assertRaises(expressions.NonParallelOperation):
- _ = func(*deferred_args)
+ with self.assertRaises(expressions.NonParallelOperation) as raised:
+ func(*deferred_args)
+
+ if raised.exception.msg.startswith(
+ "Encountered non-parallelizable form of"):
+ raise AssertionError(
+ "Default NonParallelOperation raised, please specify a reason in "
+ "the Singleton() partitioning requirement for this operation."
+ ) from raised.exception
# Re-run in an allow non parallel block to get an expression to verify
with beam.dataframe.allow_non_parallel_operations():
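
The startswith check above only passes when the Singleton requirement carries a reason; otherwise NonParallelOperation falls back to the generic "Encountered non-parallelizable form of ..." message. A hedged sketch of the behaviour being asserted, reusing the wrap helper and the quantile operation exercised by the tests below::

    import pandas as pd

    import apache_beam as beam
    from apache_beam.dataframe import expressions
    from apache_beam.dataframe import frame_base

    df = frame_base.DeferredFrame.wrap(
        expressions.ConstantExpression(pd.DataFrame({'a': [1, 2, 3, 4]})))

    try:
        df.quantile(0.1)  # requires Singleton partitioning
    except expressions.NonParallelOperation as e:
        print(e.msg)  # should name the reason, not the generic fallback

    # The same operation is permitted inside an allow block.
    with beam.dataframe.allow_non_parallel_operations():
        expr = df.quantile(0.1)._expr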
@@ -722,13 +747,14 @@
lambda x: (x.foo + x.bar).median()),
df)
- def test_quantile_axis_columns(self):
+ def test_quantile(self):
df = pd.DataFrame(
np.array([[1, 1], [2, 10], [3, 100], [4, 100]]), columns=['a', 'b'])
- with beam.dataframe.allow_non_parallel_operations():
- self._run_test(lambda df: df.quantile(0.1, axis='columns'), df)
+ self._run_test(lambda df: df.quantile(0.1), df, nonparallel=True)
+ self._run_test(lambda df: df.quantile([0.1, 0.9]), df, nonparallel=True)
+ self._run_test(lambda df: df.quantile(0.1, axis='columns'), df)
with self.assertRaisesRegex(frame_base.WontImplementError,
r"df\.quantile\(q=0\.1, axis='columns'\)"):
self._run_test(lambda df: df.quantile([0.1, 0.5], axis='columns'), df)
@@ -742,6 +768,7 @@
lambda df: df.groupby('foo', dropna=False).bar.count(), GROUPBY_DF)
def test_dataframe_melt(self):
+
df = pd.DataFrame({
'A': {
0: 'a', 1: 'b', 2: 'c'
@@ -784,6 +811,40 @@
id_vars=[('A', 'D')], value_vars=[('B', 'E')], ignore_index=False),
df)
+ def test_fillna_columns(self):
+ df = pd.DataFrame(
+ [[np.nan, 2, np.nan, 0], [3, 4, np.nan, 1], [np.nan, np.nan, np.nan, 5],
+ [np.nan, 3, np.nan, 4], [3, np.nan, np.nan, 4]],
+ columns=list('ABCD'))
+
+ self._run_test(lambda df: df.fillna(method='ffill', axis='columns'), df)
+ self._run_test(
+ lambda df: df.fillna(method='ffill', axis='columns', limit=1), df)
+ self._run_test(
+ lambda df: df.fillna(method='bfill', axis='columns', limit=1), df)
+
+ # Intended behavior is unclear here. See
+ # https://github.com/pandas-dev/pandas/issues/40989
+ # self._run_test(lambda df: df.fillna(axis='columns', value=100,
+ # limit=2), df)
+
+ def test_append_verify_integrity(self):
+ df1 = pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(10))
+ df2 = pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(9, 19))
+
+ self._run_error_test(
+ lambda s1,
+ s2: s1.append(s2, verify_integrity=True),
+ df1['A'],
+ df2['A'],
+ construction_time=False)
+ self._run_error_test(
+ lambda df1,
+ df2: df1.append(df2, verify_integrity=True),
+ df1,
+ df2,
+ construction_time=False)
+
class AllowNonParallelTest(unittest.TestCase):
def _use_non_parallel_operation(self):
diff --git a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
index f720112..fcf18fa 100644
--- a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
+++ b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
@@ -40,7 +40,10 @@
'pandas.core.generic.NDFrame.first': ['*'],
'pandas.core.generic.NDFrame.head': ['*'],
'pandas.core.generic.NDFrame.last': ['*'],
- 'pandas.core.generic.NDFrame.shift': ['*'],
+ 'pandas.core.generic.NDFrame.shift': [
+ 'df.shift(periods=3)',
+ 'df.shift(periods=3, fill_value=0)',
+ ],
'pandas.core.generic.NDFrame.tail': ['*'],
'pandas.core.generic.NDFrame.take': ['*'],
'pandas.core.generic.NDFrame.values': ['*'],
@@ -189,8 +192,8 @@
'pandas.core.frame.DataFrame.transpose': ['*'],
'pandas.core.frame.DataFrame.shape': ['*'],
'pandas.core.frame.DataFrame.shift': [
- 'df.shift(periods=3, freq="D")',
- 'df.shift(periods=3, freq="infer")'
+ 'df.shift(periods=3)',
+ 'df.shift(periods=3, fill_value=0)',
],
'pandas.core.frame.DataFrame.unstack': ['*'],
'pandas.core.frame.DataFrame.memory_usage': ['*'],
@@ -395,7 +398,10 @@
],
'pandas.core.series.Series.pop': ['*'],
'pandas.core.series.Series.searchsorted': ['*'],
- 'pandas.core.series.Series.shift': ['*'],
+ 'pandas.core.series.Series.shift': [
+ 'df.shift(periods=3)',
+ 'df.shift(periods=3, fill_value=0)',
+ ],
'pandas.core.series.Series.take': ['*'],
'pandas.core.series.Series.to_dict': ['*'],
'pandas.core.series.Series.unique': ['*'],
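
Each entry in these dicts maps a fully qualified pandas doc path to the doctest snippets that should be skipped for that method, with '*' standing for every example. A hedged illustration of how such an entry is consumed, assuming the doctests.testmod runner used by this test module accepts a skip mapping of this shape::

    import pandas

    from apache_beam.dataframe import doctests

    # Skip only the two shift examples for Series, run everything else.
    result = doctests.testmod(
        pandas.core.series,
        skip={
            'pandas.core.series.Series.shift': [
                'df.shift(periods=3)',
                'df.shift(periods=3, fill_value=0)',
            ],
        })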
diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py
index ef58023..afb71ba 100644
--- a/sdks/python/apache_beam/dataframe/partitionings.py
+++ b/sdks/python/apache_beam/dataframe/partitionings.py
@@ -151,6 +151,13 @@
class Singleton(Partitioning):
"""A partitioning of all the data into a single partition.
"""
+ def __init__(self, reason=None):
+ self._reason = reason
+
+ @property
+ def reason(self):
+ return self._reason
+
def __eq__(self, other):
return type(self) == type(other)
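
With this constructor, Singleton gains an optional, purely informational reason: the __eq__ above still compares only the types, so a reason-carrying instance remains equal to a bare Singleton(). A small sketch of that contract::

    from apache_beam.dataframe import partitionings

    generic = partitionings.Singleton()
    explained = partitionings.Singleton(reason=(
        "Aggregations grouped by a categorical column are not currently "
        "parallelizable (BEAM-11190)."))

    assert generic == explained    # equality ignores the reason
    assert generic.reason is None
    print(explained.reason)        # later surfaced via NonParallelOperation.msg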
diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py
deleted file mode 100644
index 6e50cd2..0000000
--- a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py
+++ /dev/null
@@ -1,585 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""A connector for sending API requests to the GCP Recommendations AI
-API (https://cloud.google.com/recommendations).
-"""
-
-from __future__ import absolute_import
-
-from typing import Sequence
-from typing import Tuple
-
-from google.api_core.retry import Retry
-
-from apache_beam import pvalue
-from apache_beam.metrics import Metrics
-from apache_beam.options.pipeline_options import GoogleCloudOptions
-from apache_beam.transforms import DoFn
-from apache_beam.transforms import ParDo
-from apache_beam.transforms import PTransform
-from apache_beam.transforms.util import GroupIntoBatches
-from cachetools.func import ttl_cache
-
-# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
-try:
- from google.cloud import recommendationengine
-except ImportError:
- raise ImportError(
- 'Google Cloud Recommendation AI not supported for this execution '
- 'environment (could not import google.cloud.recommendationengine).')
-# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
-
-__all__ = [
- 'CreateCatalogItem',
- 'WriteUserEvent',
- 'ImportCatalogItems',
- 'ImportUserEvents',
- 'PredictUserEvent'
-]
-
-FAILED_CATALOG_ITEMS = "failed_catalog_items"
-
-
-@ttl_cache(maxsize=128, ttl=3600)
-def get_recommendation_prediction_client():
- """Returns a Recommendation AI - Prediction Service client."""
- _client = recommendationengine.PredictionServiceClient()
- return _client
-
-
-@ttl_cache(maxsize=128, ttl=3600)
-def get_recommendation_catalog_client():
- """Returns a Recommendation AI - Catalog Service client."""
- _client = recommendationengine.CatalogServiceClient()
- return _client
-
-
-@ttl_cache(maxsize=128, ttl=3600)
-def get_recommendation_user_event_client():
- """Returns a Recommendation AI - UserEvent Service client."""
- _client = recommendationengine.UserEventServiceClient()
- return _client
-
-
-class CreateCatalogItem(PTransform):
- """Creates catalogitem information.
- The ``PTranform`` returns a PCollectionTuple with a PCollections of
- successfully and failed created CatalogItems.
-
- Example usage::
-
- pipeline | CreateCatalogItem(
- project='example-gcp-project',
- catalog_name='my-catalog')
- """
- def __init__(
- self,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = "default_catalog"):
- """Initializes a :class:`CreateCatalogItem` transform.
-
- Args:
- project (str): Optional. GCP project name in which the catalog
- data will be imported.
- retry: Optional. Designation of what
- errors, if any, should be retried.
- timeout (float): Optional. The amount of time, in seconds, to wait
- for the request to complete.
- metadata: Optional. Strings which
- should be sent along with the request as metadata.
- catalog_name (str): Optional. Name of the catalog.
- Default: 'default_catalog'
- """
- self.project = project
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.catalog_name = catalog_name
-
- def expand(self, pcoll):
- if self.project is None:
- self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
- if self.project is None:
- raise ValueError(
- """GCP project name needs to be specified in "project" pipeline
- option""")
- return pcoll | ParDo(
- _CreateCatalogItemFn(
- self.project,
- self.retry,
- self.timeout,
- self.metadata,
- self.catalog_name))
-
-
-class _CreateCatalogItemFn(DoFn):
- def __init__(
- self,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = None):
- self._client = None
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
- self.counter = Metrics.counter(self.__class__, "api_calls")
-
- def setup(self):
- if self._client is None:
- self._client = get_recommendation_catalog_client()
-
- def process(self, element):
- catalog_item = recommendationengine.CatalogItem(element)
- request = recommendationengine.CreateCatalogItemRequest(
- parent=self.parent, catalog_item=catalog_item)
-
- try:
- created_catalog_item = self._client.create_catalog_item(
- request=request,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata)
-
- self.counter.inc()
- yield recommendationengine.CatalogItem.to_dict(created_catalog_item)
- except Exception:
- yield pvalue.TaggedOutput(
- FAILED_CATALOG_ITEMS,
- recommendationengine.CatalogItem.to_dict(catalog_item))
-
-
-class ImportCatalogItems(PTransform):
- """Imports catalogitems in bulk.
- The `PTransform` returns a PCollectionTuple with PCollections of
- successfully and failed imported CatalogItems.
-
- Example usage::
-
- pipeline
- | ImportCatalogItems(
- project='example-gcp-project',
- catalog_name='my-catalog')
- """
- def __init__(
- self,
- max_batch_size: int = 5000,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = "default_catalog"):
- """Initializes a :class:`ImportCatalogItems` transform
-
- Args:
- batch_size (int): Required. Maximum number of catalogitems per
- request.
- project (str): Optional. GCP project name in which the catalog
- data will be imported.
- retry: Optional. Designation of what
- errors, if any, should be retried.
- timeout (float): Optional. The amount of time, in seconds, to wait
- for the request to complete.
- metadata: Optional. Strings which
- should be sent along with the request as metadata.
- catalog_name (str): Optional. Name of the catalog.
- Default: 'default_catalog'
- """
- self.max_batch_size = max_batch_size
- self.project = project
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.catalog_name = catalog_name
-
- def expand(self, pcoll):
- if self.project is None:
- self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
- if self.project is None:
- raise ValueError(
- 'GCP project name needs to be specified in "project" pipeline option')
- return (
- pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
- _ImportCatalogItemsFn(
- self.project,
- self.retry,
- self.timeout,
- self.metadata,
- self.catalog_name)))
-
-
-class _ImportCatalogItemsFn(DoFn):
- def __init__(
- self,
- project=None,
- retry=None,
- timeout=120,
- metadata=None,
- catalog_name=None):
- self._client = None
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
- self.counter = Metrics.counter(self.__class__, "api_calls")
-
- def setup(self):
- if self._client is None:
- self.client = get_recommendation_catalog_client()
-
- def process(self, element):
- catalog_items = [recommendationengine.CatalogItem(e) for e in element[1]]
- catalog_inline_source = recommendationengine.CatalogInlineSource(
- {"catalog_items": catalog_items})
- input_config = recommendationengine.InputConfig(
- catalog_inline_source=catalog_inline_source)
-
- request = recommendationengine.ImportCatalogItemsRequest(
- parent=self.parent, input_config=input_config)
-
- try:
- operation = self._client.import_catalog_items(
- request=request,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata)
- self.counter.inc(len(catalog_items))
- yield operation.result()
- except Exception:
- yield pvalue.TaggedOutput(FAILED_CATALOG_ITEMS, catalog_items)
-
-
-class WriteUserEvent(PTransform):
- """Write user event information.
- The `PTransform` returns a PCollectionTuple with PCollections of
- successfully and failed written UserEvents.
-
- Example usage::
-
- pipeline
- | WriteUserEvent(
- project='example-gcp-project',
- catalog_name='my-catalog',
- event_store='my_event_store')
- """
- def __init__(
- self,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = "default_catalog",
- event_store: str = "default_event_store"):
- """Initializes a :class:`WriteUserEvent` transform.
-
- Args:
- project (str): Optional. GCP project name in which the catalog
- data will be imported.
- retry: Optional. Designation of what
- errors, if any, should be retried.
- timeout (float): Optional. The amount of time, in seconds, to wait
- for the request to complete.
- metadata: Optional. Strings which
- should be sent along with the request as metadata.
- catalog_name (str): Optional. Name of the catalog.
- Default: 'default_catalog'
- event_store (str): Optional. Name of the event store.
- Default: 'default_event_store'
- """
- self.project = project
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.catalog_name = catalog_name
- self.event_store = event_store
-
- def expand(self, pcoll):
- if self.project is None:
- self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
- if self.project is None:
- raise ValueError(
- 'GCP project name needs to be specified in "project" pipeline option')
- return pcoll | ParDo(
- _WriteUserEventFn(
- self.project,
- self.retry,
- self.timeout,
- self.metadata,
- self.catalog_name,
- self.event_store))
-
-
-class _WriteUserEventFn(DoFn):
- FAILED_USER_EVENTS = "failed_user_events"
-
- def __init__(
- self,
- project=None,
- retry=None,
- timeout=120,
- metadata=None,
- catalog_name=None,
- event_store=None):
- self._client = None
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.parent = f"projects/{project}/locations/global/catalogs/"\
- f"{catalog_name}/eventStores/{event_store}"
- self.counter = Metrics.counter(self.__class__, "api_calls")
-
- def setup(self):
- if self._client is None:
- self._client = get_recommendation_user_event_client()
-
- def process(self, element):
- user_event = recommendationengine.UserEvent(element)
- request = recommendationengine.WriteUserEventRequest(
- parent=self.parent, user_event=user_event)
-
- try:
- created_user_event = self._client.write_user_event(request)
- self.counter.inc()
- yield recommendationengine.UserEvent.to_dict(created_user_event)
- except Exception:
- yield pvalue.TaggedOutput(
- self.FAILED_USER_EVENTS,
- recommendationengine.UserEvent.to_dict(user_event))
-
-
-class ImportUserEvents(PTransform):
- """Imports userevents in bulk.
- The `PTransform` returns a PCollectionTuple with PCollections of
- successfully and failed imported UserEvents.
-
- Example usage::
-
- pipeline
- | ImportUserEvents(
- project='example-gcp-project',
- catalog_name='my-catalog',
- event_store='my_event_store')
- """
- def __init__(
- self,
- max_batch_size: int = 5000,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = "default_catalog",
- event_store: str = "default_event_store"):
- """Initializes a :class:`WriteUserEvent` transform.
-
- Args:
- batch_size (int): Required. Maximum number of catalogitems
- per request.
- project (str): Optional. GCP project name in which the catalog
- data will be imported.
- retry: Optional. Designation of what
- errors, if any, should be retried.
- timeout (float): Optional. The amount of time, in seconds, to wait
- for the request to complete.
- metadata: Optional. Strings which
- should be sent along with the request as metadata.
- catalog_name (str): Optional. Name of the catalog.
- Default: 'default_catalog'
- event_store (str): Optional. Name of the event store.
- Default: 'default_event_store'
- """
- self.max_batch_size = max_batch_size
- self.project = project
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.catalog_name = catalog_name
- self.event_store = event_store
-
- def expand(self, pcoll):
- if self.project is None:
- self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
- if self.project is None:
- raise ValueError(
- 'GCP project name needs to be specified in "project" pipeline option')
- return (
- pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
- _ImportUserEventsFn(
- self.project,
- self.retry,
- self.timeout,
- self.metadata,
- self.catalog_name,
- self.event_store)))
-
-
-class _ImportUserEventsFn(DoFn):
- FAILED_USER_EVENTS = "failed_user_events"
-
- def __init__(
- self,
- project=None,
- retry=None,
- timeout=120,
- metadata=None,
- catalog_name=None,
- event_store=None):
- self._client = None
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.parent = f"projects/{project}/locations/global/catalogs/"\
- f"{catalog_name}/eventStores/{event_store}"
- self.counter = Metrics.counter(self.__class__, "api_calls")
-
- def setup(self):
- if self._client is None:
- self.client = get_recommendation_user_event_client()
-
- def process(self, element):
-
- user_events = [recommendationengine.UserEvent(e) for e in element[1]]
- user_event_inline_source = recommendationengine.UserEventInlineSource(
- {"user_events": user_events})
- input_config = recommendationengine.InputConfig(
- user_event_inline_source=user_event_inline_source)
-
- request = recommendationengine.ImportUserEventsRequest(
- parent=self.parent, input_config=input_config)
-
- try:
- operation = self._client.write_user_event(request)
- self.counter.inc(len(user_events))
- yield recommendationengine.PredictResponse.to_dict(operation.result())
- except Exception:
- yield pvalue.TaggedOutput(self.FAILED_USER_EVENTS, user_events)
-
-
-class PredictUserEvent(PTransform):
- """Make a recommendation prediction.
- The `PTransform` returns a PCollection
-
- Example usage::
-
- pipeline
- | PredictUserEvent(
- project='example-gcp-project',
- catalog_name='my-catalog',
- event_store='my_event_store',
- placement_id='recently_viewed_default')
- """
- def __init__(
- self,
- project: str = None,
- retry: Retry = None,
- timeout: float = 120,
- metadata: Sequence[Tuple[str, str]] = None,
- catalog_name: str = "default_catalog",
- event_store: str = "default_event_store",
- placement_id: str = None):
- """Initializes a :class:`PredictUserEvent` transform.
-
- Args:
- project (str): Optional. GCP project name in which the catalog
- data will be imported.
- retry: Optional. Designation of what
- errors, if any, should be retried.
- timeout (float): Optional. The amount of time, in seconds, to wait
- for the request to complete.
- metadata: Optional. Strings which
- should be sent along with the request as metadata.
- catalog_name (str): Optional. Name of the catalog.
- Default: 'default_catalog'
- event_store (str): Optional. Name of the event store.
- Default: 'default_event_store'
- placement_id (str): Required. ID of the recommendation engine
- placement. This id is used to identify the set of models that
- will be used to make the prediction.
- """
- self.project = project
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.placement_id = placement_id
- self.catalog_name = catalog_name
- self.event_store = event_store
- if placement_id is None:
- raise ValueError('placement_id must be specified')
- else:
- self.placement_id = placement_id
-
- def expand(self, pcoll):
- if self.project is None:
- self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
- if self.project is None:
- raise ValueError(
- 'GCP project name needs to be specified in "project" pipeline option')
- return pcoll | ParDo(
- _PredictUserEventFn(
- self.project,
- self.retry,
- self.timeout,
- self.metadata,
- self.catalog_name,
- self.event_store,
- self.placement_id))
-
-
-class _PredictUserEventFn(DoFn):
- FAILED_PREDICTIONS = "failed_predictions"
-
- def __init__(
- self,
- project=None,
- retry=None,
- timeout=120,
- metadata=None,
- catalog_name=None,
- event_store=None,
- placement_id=None):
- self._client = None
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.name = f"projects/{project}/locations/global/catalogs/"\
- f"{catalog_name}/eventStores/{event_store}/placements/"\
- f"{placement_id}"
- self.counter = Metrics.counter(self.__class__, "api_calls")
-
- def setup(self):
- if self._client is None:
- self._client = get_recommendation_prediction_client()
-
- def process(self, element):
- user_event = recommendationengine.UserEvent(element)
- request = recommendationengine.PredictRequest(
- name=self.name, user_event=user_event)
-
- try:
- prediction = self._client.predict(request)
- self.counter.inc()
- yield [
- recommendationengine.PredictResponse.to_dict(p)
- for p in prediction.pages
- ]
- except Exception:
- yield pvalue.TaggedOutput(self.FAILED_PREDICTIONS, user_event)
diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py
deleted file mode 100644
index 2f688d9..0000000
--- a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py
+++ /dev/null
@@ -1,207 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Unit tests for Recommendations AI transforms."""
-
-from __future__ import absolute_import
-
-import unittest
-
-import mock
-
-import apache_beam as beam
-from apache_beam.metrics import MetricsFilter
-
-# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
-try:
- from google.cloud import recommendationengine
- from apache_beam.ml.gcp import recommendations_ai
-except ImportError:
- recommendationengine = None
-# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
-
-
-@unittest.skipIf(
- recommendationengine is None,
- "Recommendations AI dependencies not installed.")
-class RecommendationsAICatalogItemTest(unittest.TestCase):
- def setUp(self):
- self._mock_client = mock.Mock()
- self._mock_client.create_catalog_item.return_value = (
- recommendationengine.CatalogItem())
- self.m2 = mock.Mock()
- self.m2.result.return_value = None
- self._mock_client.import_catalog_items.return_value = self.m2
-
- self._catalog_item = {
- "id": "12345",
- "title": "Sample laptop",
- "description": "Indisputably the most fantastic laptop ever created.",
- "language_code": "en",
- "category_hierarchies": [{
- "categories": ["Electronic", "Computers"]
- }]
- }
-
- def test_CreateCatalogItem(self):
- expected_counter = 1
- with mock.patch.object(recommendations_ai,
- 'get_recommendation_catalog_client',
- return_value=self._mock_client):
- p = beam.Pipeline()
-
- _ = (
- p | "Create data" >> beam.Create([self._catalog_item])
- | "Create CatalogItem" >>
- recommendations_ai.CreateCatalogItem(project="test"))
-
- result = p.run()
- result.wait_until_finish()
-
- read_filter = MetricsFilter().with_name('api_calls')
- query_result = result.metrics().query(read_filter)
- if query_result['counters']:
- read_counter = query_result['counters'][0]
- self.assertTrue(read_counter.result == expected_counter)
-
- def test_ImportCatalogItems(self):
- expected_counter = 1
- with mock.patch.object(recommendations_ai,
- 'get_recommendation_catalog_client',
- return_value=self._mock_client):
- p = beam.Pipeline()
-
- _ = (
- p | "Create data" >> beam.Create([
- (self._catalog_item["id"], self._catalog_item),
- (self._catalog_item["id"], self._catalog_item)
- ]) | "Create CatalogItems" >>
- recommendations_ai.ImportCatalogItems(project="test"))
-
- result = p.run()
- result.wait_until_finish()
-
- read_filter = MetricsFilter().with_name('api_calls')
- query_result = result.metrics().query(read_filter)
- if query_result['counters']:
- read_counter = query_result['counters'][0]
- self.assertTrue(read_counter.result == expected_counter)
-
-
-@unittest.skipIf(
- recommendationengine is None,
- "Recommendations AI dependencies not installed.")
-class RecommendationsAIUserEventTest(unittest.TestCase):
- def setUp(self):
- self._mock_client = mock.Mock()
- self._mock_client.write_user_event.return_value = (
- recommendationengine.UserEvent())
- self.m2 = mock.Mock()
- self.m2.result.return_value = None
- self._mock_client.import_user_events.return_value = self.m2
-
- self._user_event = {
- "event_type": "page-visit", "user_info": {
- "visitor_id": "1"
- }
- }
-
- def test_CreateUserEvent(self):
- expected_counter = 1
- with mock.patch.object(recommendations_ai,
- 'get_recommendation_user_event_client',
- return_value=self._mock_client):
- p = beam.Pipeline()
-
- _ = (
- p | "Create data" >> beam.Create([self._user_event])
- | "Create UserEvent" >>
- recommendations_ai.WriteUserEvent(project="test"))
-
- result = p.run()
- result.wait_until_finish()
-
- read_filter = MetricsFilter().with_name('api_calls')
- query_result = result.metrics().query(read_filter)
- if query_result['counters']:
- read_counter = query_result['counters'][0]
- self.assertTrue(read_counter.result == expected_counter)
-
- def test_ImportUserEvents(self):
- expected_counter = 1
- with mock.patch.object(recommendations_ai,
- 'get_recommendation_user_event_client',
- return_value=self._mock_client):
- p = beam.Pipeline()
-
- _ = (
- p | "Create data" >> beam.Create([
- (self._user_event["user_info"]["visitor_id"], self._user_event),
- (self._user_event["user_info"]["visitor_id"], self._user_event)
- ]) | "Create UserEvents" >>
- recommendations_ai.ImportUserEvents(project="test"))
-
- result = p.run()
- result.wait_until_finish()
-
- read_filter = MetricsFilter().with_name('api_calls')
- query_result = result.metrics().query(read_filter)
- if query_result['counters']:
- read_counter = query_result['counters'][0]
- self.assertTrue(read_counter.result == expected_counter)
-
-
-@unittest.skipIf(
- recommendationengine is None,
- "Recommendations AI dependencies not installed.")
-class RecommendationsAIPredictTest(unittest.TestCase):
- def setUp(self):
- self._mock_client = mock.Mock()
- self._mock_client.predict.return_value = [
- recommendationengine.PredictResponse()
- ]
-
- self._user_event = {
- "event_type": "page-visit", "user_info": {
- "visitor_id": "1"
- }
- }
-
- def test_Predict(self):
- expected_counter = 1
- with mock.patch.object(recommendations_ai,
- 'get_recommendation_prediction_client',
- return_value=self._mock_client):
- p = beam.Pipeline()
-
- _ = (
- p | "Create data" >> beam.Create([self._user_event])
- | "Prediction UserEvents" >> recommendations_ai.PredictUserEvent(
- project="test", placement_id="recently_viewed_default"))
-
- result = p.run()
- result.wait_until_finish()
-
- read_filter = MetricsFilter().with_name('api_calls')
- query_result = result.metrics().query(read_filter)
- if query_result['counters']:
- read_counter = query_result['counters'][0]
- self.assertTrue(read_counter.result == expected_counter)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py
deleted file mode 100644
index 19e6b9e..0000000
--- a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py
+++ /dev/null
@@ -1,107 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Integration tests for Recommendations AI transforms."""
-
-from __future__ import absolute_import
-
-import random
-import unittest
-
-from nose.plugins.attrib import attr
-
-import apache_beam as beam
-from apache_beam.testing.test_pipeline import TestPipeline
-from apache_beam.testing.util import assert_that
-from apache_beam.testing.util import equal_to
-from apache_beam.testing.util import is_not_empty
-
-# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
-try:
- from google.cloud import recommendationengine
- from apache_beam.ml.gcp import recommendations_ai
-except ImportError:
- recommendationengine = None
-# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
-
-
-def extract_id(response):
- yield response["id"]
-
-
-def extract_event_type(response):
- yield response["event_type"]
-
-
-def extract_prediction(response):
- yield response[0]["results"]
-
-
-@attr('IT')
-@unittest.skipIf(
- recommendationengine is None,
- "Recommendations AI dependencies not installed.")
-class RecommendationAIIT(unittest.TestCase):
- def test_create_catalog_item(self):
-
- CATALOG_ITEM = {
- "id": str(int(random.randrange(100000))),
- "title": "Sample laptop",
- "description": "Indisputably the most fantastic laptop ever created.",
- "language_code": "en",
- "category_hierarchies": [{
- "categories": ["Electronic", "Computers"]
- }]
- }
-
- with TestPipeline(is_integration_test=True) as p:
- output = (
- p | 'Create data' >> beam.Create([CATALOG_ITEM])
- | 'Create CatalogItem' >>
- recommendations_ai.CreateCatalogItem(project=p.get_option('project'))
- | beam.ParDo(extract_id) | beam.combiners.ToList())
-
- assert_that(output, equal_to([[CATALOG_ITEM["id"]]]))
-
- def test_create_user_event(self):
- USER_EVENT = {"event_type": "page-visit", "user_info": {"visitor_id": "1"}}
-
- with TestPipeline(is_integration_test=True) as p:
- output = (
- p | 'Create data' >> beam.Create([USER_EVENT]) | 'Create UserEvent' >>
- recommendations_ai.WriteUserEvent(project=p.get_option('project'))
- | beam.ParDo(extract_event_type) | beam.combiners.ToList())
-
- assert_that(output, equal_to([[USER_EVENT["event_type"]]]))
-
- def test_create_predict(self):
- USER_EVENT = {"event_type": "page-visit", "user_info": {"visitor_id": "1"}}
-
- with TestPipeline(is_integration_test=True) as p:
- output = (
- p | 'Create data' >> beam.Create([USER_EVENT])
- | 'Predict UserEvent' >> recommendations_ai.PredictUserEvent(
- project=p.get_option('project'),
- placement_id="recently_viewed_default")
- | beam.ParDo(extract_prediction))
-
- assert_that(output, is_not_empty())
-
-
-if __name__ == '__main__':
- print(recommendationengine.CatalogItem.__module__)
- unittest.main()
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 7117a55..775f321 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -200,7 +200,6 @@
'google-cloud-language>=1.3.0,<2',
'google-cloud-videointelligence>=1.8.0,<2',
'google-cloud-vision>=0.38.0,<2',
- 'google-cloud-recommendations-ai>=0.1.0,<=0.2.0'
]
INTERACTIVE_BEAM = [