Merge pull request #15477 from echauchot/BEAM-5172-ignore-es-flaky-tests
[BEAM-5172] Temporarily ignore testSplit and testSizes tests until a fix lands because they are flaky.
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 57bae57..4cdf4ae 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -444,20 +444,20 @@
def checkerframework_version = "3.10.0"
def classgraph_version = "4.8.104"
def errorprone_version = "2.3.4"
- def google_clients_version = "1.31.0"
+ def google_clients_version = "1.32.1"
def google_cloud_bigdataoss_version = "2.2.2"
def google_cloud_pubsublite_version = "0.13.2"
def google_code_gson_version = "2.8.6"
def google_oauth_clients_version = "1.31.0"
// Try to keep grpc_version consistent with gRPC version in google_cloud_platform_libraries_bom
- def grpc_version = "1.37.0"
- def guava_version = "30.1-jre"
+ def grpc_version = "1.40.1"
+ def guava_version = "30.1.1-jre"
def hadoop_version = "2.10.1"
def hamcrest_version = "2.1"
def influxdb_version = "2.19"
- def httpclient_version = "4.5.10"
- def httpcore_version = "4.4.12"
- def jackson_version = "2.12.1"
+ def httpclient_version = "4.5.13"
+ def httpcore_version = "4.4.14"
+ def jackson_version = "2.12.4"
def jaxb_api_version = "2.3.3"
def jsr305_version = "3.0.2"
def kafka_version = "2.4.1"
@@ -466,7 +466,7 @@
def postgres_version = "42.2.16"
def powermock_version = "2.0.9"
// Try to keep protobuf_version consistent with the protobuf version in google_cloud_platform_libraries_bom
- def protobuf_version = "3.15.8"
+ def protobuf_version = "3.17.3"
def quickcheck_version = "0.8"
def slf4j_version = "1.7.30"
def spark_version = "2.4.8"
@@ -519,7 +519,7 @@
cassandra_driver_core : "com.datastax.cassandra:cassandra-driver-core:$cassandra_driver_version",
cassandra_driver_mapping : "com.datastax.cassandra:cassandra-driver-mapping:$cassandra_driver_version",
classgraph : "io.github.classgraph:classgraph:$classgraph_version",
- commons_codec : "commons-codec:commons-codec:1.14",
+ commons_codec : "commons-codec:commons-codec:1.15",
commons_compress : "org.apache.commons:commons-compress:1.21",
commons_csv : "org.apache.commons:commons-csv:1.8",
commons_io : "commons-io:commons-io:2.6",
@@ -530,23 +530,23 @@
gax : "com.google.api:gax", // google_cloud_platform_libraries_bom sets version
gax_grpc : "com.google.api:gax-grpc", // google_cloud_platform_libraries_bom sets version
gax_httpjson : "com.google.api:gax-httpjson", // google_cloud_platform_libraries_bom sets version
- google_api_client : "com.google.api-client:google-api-client:1.31.1", // 1.31.1 is required to run 1.31.0 of google_clients_version below.
+ google_api_client : "com.google.api-client:google-api-client:$google_clients_version", // for the libraries using $google_clients_version below.
google_api_client_jackson2 : "com.google.api-client:google-api-client-jackson2:$google_clients_version",
google_api_client_java6 : "com.google.api-client:google-api-client-java6:$google_clients_version",
google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version
- google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20210410-$google_clients_version",
- google_api_services_clouddebugger : "com.google.apis:google-api-services-clouddebugger:v2-rev20210326-$google_clients_version",
- google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20210331-$google_clients_version",
- google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20210408-$google_clients_version",
- google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20210603-$google_clients_version",
- google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20210322-$google_clients_version",
+ google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20210813-$google_clients_version",
+ google_api_services_clouddebugger : "com.google.apis:google-api-services-clouddebugger:v2-rev20210813-$google_clients_version",
+ google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20210815-$google_clients_version",
+ google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20210818-$google_clients_version",
+ google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20210806-$google_clients_version",
+ google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20210809-$google_clients_version",
google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20210127-$google_clients_version",
google_auth_library_credentials : "com.google.auth:google-auth-library-credentials", // google_cloud_platform_libraries_bom sets version
google_auth_library_oauth2_http : "com.google.auth:google-auth-library-oauth2-http", // google_cloud_platform_libraries_bom sets version
google_cloud_bigquery : "com.google.cloud:google-cloud-bigquery", // google_cloud_platform_libraries_bom sets version
- google_cloud_bigquery_storage : "com.google.cloud:google-cloud-bigquerystorage:1.21.1",
- google_cloud_bigtable_client_core : "com.google.cloud.bigtable:bigtable-client-core:1.19.1",
- google_cloud_bigtable_emulator : "com.google.cloud:google-cloud-bigtable-emulator:0.125.2",
+ google_cloud_bigquery_storage : "com.google.cloud:google-cloud-bigquerystorage", // google_cloud_platform_libraries_bom sets version
+ google_cloud_bigtable_client_core : "com.google.cloud.bigtable:bigtable-client-core:1.23.1",
+ google_cloud_bigtable_emulator : "com.google.cloud:google-cloud-bigtable-emulator:0.137.1",
google_cloud_core : "com.google.cloud:google-cloud-core", // google_cloud_platform_libraries_bom sets version
google_cloud_core_grpc : "com.google.cloud:google-cloud-core-grpc", // google_cloud_platform_libraries_bom sets version
google_cloud_datacatalog_v1beta1 : "com.google.cloud:google-cloud-datacatalog", // google_cloud_platform_libraries_bom sets version
@@ -558,7 +558,7 @@
// The GCP Libraries BOM dashboard shows the versions set by the BOM:
// https://storage.googleapis.com/cloud-opensource-java-dashboard/com.google.cloud/libraries-bom/20.0.0/artifact_details.html
// Update libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml
- google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:20.0.0",
+ google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:22.0.0",
google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version
google_code_gson : "com.google.code.gson:gson:$google_code_gson_version",
// google-http-client's version is explicitly declared for sdks/java/maven-archetypes/examples
diff --git a/model/pipeline/src/main/proto/metrics.proto b/model/pipeline/src/main/proto/metrics.proto
index 10166e1..8f819b6 100644
--- a/model/pipeline/src/main/proto/metrics.proto
+++ b/model/pipeline/src/main/proto/metrics.proto
@@ -417,6 +417,9 @@
GCS_PROJECT_ID = 17 [(label_props) = { name: "GCS_PROJECT_ID"}];
DATASTORE_PROJECT = 18 [(label_props) = { name: "DATASTORE_PROJECT" }];
DATASTORE_NAMESPACE = 19 [(label_props) = { name: "DATASTORE_NAMESPACE" }];
+ BIGTABLE_PROJECT_ID = 20 [(label_props) = { name: "BIGTABLE_PROJECT_ID"}];
+ INSTANCE_ID = 21 [(label_props) = { name: "INSTANCE_ID"}];
+ TABLE_ID = 22 [(label_props) = { name: "TABLE_ID"}];
}
// A set of key and value labels which define the scope of the metric. For
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java
new file mode 100644
index 0000000..c61152e
--- /dev/null
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/Concatenate.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
+ *
+ * <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this is
+ * expected to crash!
+ *
+ * <p>This is copied from the Dataflow runner code.
+ *
+ * @param <T> the type of elements to concatenate.
+ */
+public class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
+ @Override
+ public List<T> createAccumulator() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public List<T> addInput(List<T> accumulator, T input) {
+ accumulator.add(input);
+ return accumulator;
+ }
+
+ @Override
+ public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
+ List<T> result = createAccumulator();
+ for (List<T> accumulator : accumulators) {
+ result.addAll(accumulator);
+ }
+ return result;
+ }
+
+ @Override
+ public List<T> extractOutput(List<T> accumulator) {
+ return accumulator;
+ }
+
+ @Override
+ public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
+ return ListCoder.of(inputCoder);
+ }
+
+ @Override
+ public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
+ return ListCoder.of(inputCoder);
+ }
+}
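The Concatenate combiner above was previously duplicated across the Flink and Samza runners (see the removals further down); it is now shared from runners/core-java. A minimal sketch of using it directly as a CombineFn, under the same fits-in-memory caveat the javadoc states (the pipeline below is illustrative, not part of this PR):

```java
import java.util.List;
import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;

public class ConcatenateExample {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
    // Combine.globally with Concatenate gathers every element into one List.
    // As the javadoc warns, this only works when the PCollection fits in memory.
    PCollection<List<String>> all =
        p.apply(Create.of("a", "b", "c"))
            .apply(Combine.globally(new Concatenate<String>()));
    p.run().waitUntilFinish();
  }
}
```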
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
index 800413c..3133cc1 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
@@ -33,6 +33,16 @@
projectId, datasetId, tableId);
}
+ public static String bigtableTableID(String project, String instance, String table) {
+ return String.format("projects/%s/instances/%s/tables/%s", project, instance, table);
+ }
+
+ public static String bigtableResource(String projectId, String instanceId, String tableId) {
+ return String.format(
+ "//bigtable.googleapis.com/projects/%s/instances/%s/tables/%s",
+ projectId, instanceId, tableId);
+ }
+
public static String datastoreResource(String projectId, String namespace) {
return String.format(
"//bigtable.googleapis.com/projects/%s/namespaces/%s", projectId, namespace);
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
index 88e55db..57ee58a 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
@@ -80,6 +80,9 @@
public static final String BIGQUERY_QUERY_NAME = "BIGQUERY_QUERY_NAME";
public static final String DATASTORE_PROJECT = "DATASTORE_PROJECT";
public static final String DATASTORE_NAMESPACE = "DATASTORE_NAMESPACE";
+ public static final String BIGTABLE_PROJECT_ID = "BIGTABLE_PROJECT_ID";
+ public static final String INSTANCE_ID = "INSTANCE_ID";
+ public static final String TABLE_ID = "TABLE_ID";
static {
// Note: One benefit of defining these strings above, instead of pulling them in from
@@ -109,6 +112,10 @@
checkArgument(DATASTORE_PROJECT.equals(extractLabel(MonitoringInfoLabels.DATASTORE_PROJECT)));
checkArgument(
DATASTORE_NAMESPACE.equals(extractLabel(MonitoringInfoLabels.DATASTORE_NAMESPACE)));
+ checkArgument(
+ BIGTABLE_PROJECT_ID.equals(extractLabel(MonitoringInfoLabels.BIGTABLE_PROJECT_ID)));
+ checkArgument(INSTANCE_ID.equals(extractLabel(MonitoringInfoLabels.INSTANCE_ID)));
+ checkArgument(TABLE_ID.equals(extractLabel(MonitoringInfoLabels.TABLE_ID)));
}
}
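The new label keys are meant to be attached to Bigtable-related MonitoringInfos. A small sketch, assuming the added constants sit in the nested Labels class alongside the existing ones (the values are illustrative):

```java
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;

public class BigtableLabelsExample {
  public static Map<String, String> bigtableLabels() {
    // Keys come from the constants added in this PR; values are made up.
    Map<String, String> labels = new HashMap<>();
    labels.put(MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, "my-project");
    labels.put(MonitoringInfoConstants.Labels.INSTANCE_ID, "my-instance");
    labels.put(MonitoringInfoConstants.Labels.TABLE_ID, "my-table");
    return labels;
  }
}
```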
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
index 8f34413..7920fa4 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/CreateStreamingFlinkView.java
@@ -18,14 +18,11 @@
package org.apache.beam.runners.flink;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.Combine;
@@ -59,51 +56,6 @@
}
/**
- * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
- *
- * <p>For internal use by {@link CreateStreamingFlinkView}. This combiner requires that the input
- * {@link PCollection} fits in memory. For a large {@link PCollection} this is expected to crash!
- *
- * @param <T> the type of elements to concatenate.
- */
- private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
- @Override
- public List<T> createAccumulator() {
- return new ArrayList<>();
- }
-
- @Override
- public List<T> addInput(List<T> accumulator, T input) {
- accumulator.add(input);
- return accumulator;
- }
-
- @Override
- public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
- List<T> result = createAccumulator();
- for (List<T> accumulator : accumulators) {
- result.addAll(accumulator);
- }
- return result;
- }
-
- @Override
- public List<T> extractOutput(List<T> accumulator) {
- return accumulator;
- }
-
- @Override
- public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
-
- @Override
- public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
- }
-
- /**
* Creates a primitive {@link PCollectionView}.
*
* <p>For internal use only by runner implementors.
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index d798846..d3d69d0 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -26,7 +26,6 @@
import com.google.auto.service.AutoService;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -39,6 +38,7 @@
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
+import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.runners.core.construction.NativeTransforms;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
@@ -62,10 +62,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.join.UnionCoder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -73,7 +70,6 @@
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
@@ -559,53 +555,6 @@
Iterables.getOnlyElement(transform.getTransform().getOutputsMap().values()), dataSource);
}
- /**
- * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
- *
- * <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this is
- * expected to crash!
- *
- * <p>This is copied from the dataflow runner code.
- *
- * @param <T> the type of elements to concatenate.
- */
- private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
- @Override
- public List<T> createAccumulator() {
- return new ArrayList<>();
- }
-
- @Override
- public List<T> addInput(List<T> accumulator, T input) {
- accumulator.add(input);
- return accumulator;
- }
-
- @Override
- public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
- List<T> result = createAccumulator();
- for (List<T> accumulator : accumulators) {
- result.addAll(accumulator);
- }
- return result;
- }
-
- @Override
- public List<T> extractOutput(List<T> accumulator) {
- return accumulator;
- }
-
- @Override
- public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
-
- @Override
- public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
- }
-
private static void urnNotFound(
PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
throw new IllegalArgumentException(
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index c2826e0..af5259e 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -21,12 +21,12 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
@@ -48,10 +48,8 @@
import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.runners.AppliedPTransform;
@@ -59,7 +57,6 @@
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
@@ -425,53 +422,6 @@
}
}
- /**
- * Combiner that combines {@code T}s into a single {@code List<T>} containing all inputs.
- *
- * <p>For internal use to translate {@link GroupByKey}. For a large {@link PCollection} this is
- * expected to crash!
- *
- * <p>This is copied from the dataflow runner code.
- *
- * @param <T> the type of elements to concatenate.
- */
- private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
- @Override
- public List<T> createAccumulator() {
- return new ArrayList<>();
- }
-
- @Override
- public List<T> addInput(List<T> accumulator, T input) {
- accumulator.add(input);
- return accumulator;
- }
-
- @Override
- public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
- List<T> result = createAccumulator();
- for (List<T> accumulator : accumulators) {
- result.addAll(accumulator);
- }
- return result;
- }
-
- @Override
- public List<T> extractOutput(List<T> accumulator) {
- return accumulator;
- }
-
- @Override
- public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
-
- @Override
- public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
- }
-
private static class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
implements FlinkBatchPipelineTranslator.BatchTransformTranslator<
PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
index 8644e73..1f8fbcc 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTransformOverride.java
@@ -17,12 +17,8 @@
*/
package org.apache.beam.runners.samza.translation;
-import java.util.ArrayList;
-import java.util.List;
+import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
@@ -67,41 +63,4 @@
return input;
}
}
-
- private static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {
- @Override
- public List<T> createAccumulator() {
- return new ArrayList<>();
- }
-
- @Override
- public List<T> addInput(List<T> accumulator, T input) {
- accumulator.add(input);
- return accumulator;
- }
-
- @Override
- public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
- List<T> result = createAccumulator();
- for (List<T> accumulator : accumulators) {
- result.addAll(accumulator);
- }
- return result;
- }
-
- @Override
- public List<T> extractOutput(List<T> accumulator) {
- return accumulator;
- }
-
- @Override
- public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
-
- @Override
- public Coder<List<T>> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
- return ListCoder.of(inputCoder);
- }
- }
}
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index 5b49e14..5a2e0f9 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -321,7 +321,7 @@
forkEvery 1
maxParallelForks 4
// Increase memory heap in order to avoid OOM errors
- jvmArgs '-Xmx4g'
+ jvmArgs '-Xmx7g'
useJUnit {
includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner'
// Unbounded
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
index 6391ba4..4fe26d7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -17,26 +17,27 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import java.io.Serializable;
-import org.apache.beam.runners.core.InMemoryStateInternals;
-import org.apache.beam.runners.core.StateInternals;
-import org.apache.beam.runners.core.StateInternalsFactory;
-import org.apache.beam.runners.core.SystemReduceFn;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.core.Concatenate;
import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
+import scala.Tuple2;
class GroupByKeyTranslatorBatch<K, V>
implements TransformTranslator<
@@ -48,43 +49,62 @@
AbstractTranslationContext context) {
@SuppressWarnings("unchecked")
- final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput();
- Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection);
- WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy();
- KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
- Coder<V> valueCoder = kvCoder.getValueCoder();
+ final PCollection<KV<K, V>> input = (PCollection<KV<K, V>>) context.getInput();
+ @SuppressWarnings("unchecked")
+ final PCollection<KV<K, List<V>>> output = (PCollection<KV<K, List<V>>>) context.getOutput();
+ final Combine.CombineFn<V, List<V>, List<V>> combineFn = new Concatenate<>();
- // group by key only
- Coder<K> keyCoder = kvCoder.getKeyCoder();
- KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
- input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+ WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
- // group also by windows
- WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
- WindowedValue.FullWindowedValueCoder.of(
- KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
- windowingStrategy.getWindowFn().windowCoder());
- Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
- groupByKeyOnly.flatMapGroups(
- new GroupAlsoByWindowViaOutputBufferFn<>(
- windowingStrategy,
- new InMemoryStateInternalsFactory<>(),
- SystemReduceFn.buffering(valueCoder),
- context.getSerializableOptions()),
- EncoderHelpers.fromBeamCoder(outputCoder));
+ Dataset<WindowedValue<KV<K, V>>> inputDataset = context.getDataset(input);
- context.putDataset(context.getOutput(), output);
- }
+ KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
+ Coder<K> keyCoder = inputCoder.getKeyCoder();
+ KvCoder<K, List<V>> outputKVCoder = (KvCoder<K, List<V>>) output.getCoder();
+ Coder<List<V>> outputCoder = outputKVCoder.getValueCoder();
- /**
- * In-memory state internals factory.
- *
- * @param <K> State key type.
- */
- static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
- @Override
- public StateInternals stateInternalsForKey(K key) {
- return InMemoryStateInternals.forKey(key);
+ KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupedDataset =
+ inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+
+ Coder<List<V>> accumulatorCoder = null;
+ try {
+ accumulatorCoder =
+ combineFn.getAccumulatorCoder(
+ input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
+ } catch (CannotProvideCoderException e) {
+ throw new RuntimeException(e);
}
+
+ Dataset<Tuple2<K, Iterable<WindowedValue<List<V>>>>> combinedDataset =
+ groupedDataset.agg(
+ new AggregatorCombiner<K, V, List<V>, List<V>, BoundedWindow>(
+ combineFn, windowingStrategy, accumulatorCoder, outputCoder)
+ .toColumn());
+
+ // expand the list into separate elements and put the key back into the elements
+ WindowedValue.WindowedValueCoder<KV<K, List<V>>> wvCoder =
+ WindowedValue.FullWindowedValueCoder.of(
+ outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
+ Dataset<WindowedValue<KV<K, List<V>>>> outputDataset =
+ combinedDataset.flatMap(
+ (FlatMapFunction<
+ Tuple2<K, Iterable<WindowedValue<List<V>>>>, WindowedValue<KV<K, List<V>>>>)
+ tuple2 -> {
+ K key = tuple2._1();
+ Iterable<WindowedValue<List<V>>> windowedValues = tuple2._2();
+ List<WindowedValue<KV<K, List<V>>>> result = new ArrayList<>();
+ for (WindowedValue<List<V>> windowedValue : windowedValues) {
+ KV<K, List<V>> kv = KV.of(key, windowedValue.getValue());
+ result.add(
+ WindowedValue.of(
+ kv,
+ windowedValue.getTimestamp(),
+ windowedValue.getWindows(),
+ windowedValue.getPane()));
+ }
+ return result.iterator();
+ },
+ EncoderHelpers.fromBeamCoder(wvCoder));
+ context.putDataset(output, outputDataset);
}
}
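The rewritten translator above now funnels GroupByKey through the shared Concatenate combiner and an AggregatorCombiner instead of the removed in-memory state internals. Pipeline-facing behavior is unchanged; a minimal sketch of a pipeline that exercises this code path (the wiring is illustrative and assumes the Spark structured-streaming runner is selected via options):

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class GbkExample {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    // A plain GroupByKey; on this runner its batch translation now goes
    // through Concatenate + AggregatorCombiner under the hood.
    PCollection<KV<String, Iterable<Integer>>> grouped =
        p.apply(Create.of(KV.of("a", 1), KV.of("a", 2), KV.of("b", 3)))
            .apply(GroupByKey.create());
    p.run().waitUntilFinish();
  }
}
```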
diff --git a/sdks/go/examples/xlang/transforms.go b/sdks/go/examples/xlang/transforms.go
index acad214..3d410e8 100644
--- a/sdks/go/examples/xlang/transforms.go
+++ b/sdks/go/examples/xlang/transforms.go
@@ -20,7 +20,6 @@
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
- "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)
@@ -51,7 +50,7 @@
pl := beam.CrossLanguagePayload(prefixPayload{Data: prefix})
outT := beam.UnnamedOutput(typex.New(reflectx.String))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:prefix", pl, addr, beam.UnnamedInput(col), outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func CoGroupByKey(s beam.Scope, addr string, col1, col2 beam.PCollection) beam.PCollection {
@@ -59,21 +58,21 @@
namedInputs := map[string]beam.PCollection{"col1": col1, "col2": col2}
outT := beam.UnnamedOutput(typex.NewCoGBK(typex.New(reflectx.Int64), typex.New(reflectx.String)))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:cgbk", nil, addr, namedInputs, outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func CombinePerKey(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
s = s.Scope("XLangTest.CombinePerKey")
outT := beam.UnnamedOutput(typex.NewKV(typex.New(reflectx.String), typex.New(reflectx.Int64)))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:compk", nil, addr, beam.UnnamedInput(col), outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func CombineGlobally(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
s = s.Scope("XLangTest.CombineGlobally")
outT := beam.UnnamedOutput(typex.New(reflectx.Int64))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:comgl", nil, addr, beam.UnnamedInput(col), outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func Flatten(s beam.Scope, addr string, col1, col2 beam.PCollection) beam.PCollection {
@@ -81,14 +80,14 @@
namedInputs := map[string]beam.PCollection{"col1": col1, "col2": col2}
outT := beam.UnnamedOutput(typex.New(reflectx.Int64))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:flatten", nil, addr, namedInputs, outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func GroupByKey(s beam.Scope, addr string, col beam.PCollection) beam.PCollection {
s = s.Scope("XLangTest.GroupByKey")
outT := beam.UnnamedOutput(typex.NewCoGBK(typex.New(reflectx.String), typex.New(reflectx.Int64)))
outs := beam.CrossLanguage(s, "beam:transforms:xlang:test:gbk", nil, addr, beam.UnnamedInput(col), outT)
- return outs[graph.UnnamedOutputTag]
+ return outs[beam.UnnamedOutputTag()]
}
func Multi(s beam.Scope, addr string, main1, main2, side beam.PCollection) (mainOut, sideOut beam.PCollection) {
@@ -112,5 +111,5 @@
s = s.Scope("XLang.Count")
outT := beam.UnnamedOutput(typex.NewKV(typex.New(reflectx.String), typex.New(reflectx.Int64)))
c := beam.CrossLanguage(s, "beam:transforms:xlang:count", nil, addr, beam.UnnamedInput(col), outT)
- return c[graph.UnnamedOutputTag]
+ return c[beam.UnnamedOutputTag()]
}
diff --git a/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go b/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
index f2317f0..8781cff 100644
--- a/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
+++ b/sdks/go/pkg/beam/io/xlang/kafkaio/kafka.go
@@ -48,7 +48,6 @@
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
- "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)
@@ -131,7 +130,7 @@
pl := beam.CrossLanguagePayload(rpl)
outT := beam.UnnamedOutput(typex.NewKV(typex.New(rcfg.key), typex.New(rcfg.val)))
out := beam.CrossLanguage(s, readURN, pl, addr, nil, outT)
- return out[graph.UnnamedOutputTag]
+ return out[beam.UnnamedOutputTag()]
}
type readOption func(*readConfig)
diff --git a/sdks/go/pkg/beam/xlang.go b/sdks/go/pkg/beam/xlang.go
index 4ec0162..b9c6ec7 100644
--- a/sdks/go/pkg/beam/xlang.go
+++ b/sdks/go/pkg/beam/xlang.go
@@ -22,32 +22,44 @@
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
)
-// xlang exposes an API to execute cross-language transforms within the Go SDK.
-// It is experimental and likely to change. It exposes convenient wrappers
-// around the core functions to pass in any combination of named/unnamed
-// inputs/outputs.
-
// UnnamedInput is a helper function for passing single unnamed inputs to
-// `beam.CrossLanguage`.
+// beam.CrossLanguage.
//
// Example:
-// beam.CrossLanguage(s, urn, payload, addr, UnnamedInput(input), outputs);
+// beam.CrossLanguage(s, urn, payload, addr, UnnamedInput(input), outputs)
func UnnamedInput(col PCollection) map[string]PCollection {
return map[string]PCollection{graph.UnnamedInputTag: col}
}
// UnnamedOutput is a helper function for passing single unnamed output types to
-// `beam.CrossLanguage`.
+// beam.CrossLanguage. The associated output can be accessed with beam.UnnamedOutputTag.
//
// Example:
-// beam.CrossLanguage(s, urn, payload, addr, inputs, UnnamedOutput(output));
+// resultMap := beam.CrossLanguage(s, urn, payload, addr, inputs, UnnamedOutput(output));
+// result := resultMap[beam.UnnamedOutputTag()]
func UnnamedOutput(t FullType) map[string]FullType {
return map[string]FullType{graph.UnnamedOutputTag: t}
}
-// CrossLanguagePayload encodes a native Go struct into a payload for
-// cross-language transforms. To find the expected structure of a payload,
-// consult the documentation in the SDK performing the expansion.
+// UnnamedOutputTag provides the output tag used for an output passed to beam.UnnamedOutput.
+// Needed to retrieve the unnamed output PCollection from the result of beam.CrossLanguage.
+func UnnamedOutputTag() string {
+ return graph.UnnamedOutputTag
+}
+
+// CrossLanguagePayload encodes a native Go struct into a payload for cross-language transforms.
+// Payloads are []byte-encoded ExternalConfigurationPayload protobufs. To fill the contents of
+// the protobuf, the provided struct is converted to a row-encoded representation with an
+// accompanying schema, so the input struct must be compatible with schemas.
+//
+// See https://beam.apache.org/documentation/programming-guide/#schemas for basic information on
+// schemas, and pkg/beam/core/runtime/graphx/schema for details on schemas in the Go SDK.
+//
+// Example:
+// type stringPayload struct {
+// Data string
+// }
+// encodedPl := beam.CrossLanguagePayload(stringPayload{Data: "foo"})
func CrossLanguagePayload(pl interface{}) []byte {
bytes, err := xlangx.EncodeStructPayload(pl)
if err != nil {
@@ -56,8 +68,73 @@
return bytes
}
-// CrossLanguage executes a cross-language transform that uses named inputs and
-// returns named outputs.
+// CrossLanguage is a low-level transform for executing cross-language transforms written in other
+// SDKs. Because this is low-level, it is recommended to use one of the higher-level IO-specific
+// wrappers where available. These can be found in the pkg/beam/io/xlang subdirectory.
+// CrossLanguage is useful for executing cross-language transforms which do not have any existing
+// IO wrappers.
+//
+// Usage requires an address for an expansion service accessible during pipeline construction, a
+// URN identifying the desired transform, an optional payload with configuration information, and
+// input and output names. It outputs a map of named output PCollections.
+//
+// For more information on expansion services and other aspects of cross-language transforms in
+// general, refer to the Beam programming guide: https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines
+//
+// Payload
+//
+// Payloads are configuration data that some cross-language transforms require for expansion.
+// Consult the documentation of the transform in the source SDK to find out what payload data it
+// requires. If no payload is required, pass in nil.
+//
+// CrossLanguage accepts payloads as a []byte containing an encoded ExternalConfigurationPayload
+// protobuf. The helper function beam.CrossLanguagePayload is the recommended way to easily encode
+// a standard Go struct for use as a payload.
+//
+// Inputs and Outputs
+//
+// Like most transforms, any input PCollections must be provided. Unlike most transforms, output
+// types must be provided because Go cannot infer output types from external transforms.
+//
+// Inputs and outputs to a cross-language transform may be either named or unnamed. Named
+// inputs/outputs are used when there is more than one input/output, and are provided as maps with
+// names as keys. Unnamed inputs/outputs are used when there is only one, and a map can be quickly
+// constructed with the UnnamedInput and UnnamedOutput methods.
+//
+// An example of defining named inputs and outputs:
+//
+// namedInputs := map[string]beam.PCollection{"pcol1": pcol1, "pcol2": pcol2}
+// namedOutputTypes := map[string]typex.FullType{
+// "main": typex.New(reflectx.String),
+// "side": typex.New(reflectx.Int64),
+// }
+//
+// CrossLanguage outputs a map of PCollections with associated names. These names will match those
+// from provided named outputs. If the beam.UnnamedOutput method was used, the PCollection can be
+// retrieved with beam.UnnamedOutputTag().
+//
+// An example of retrieving named outputs from a call to CrossLanguage:
+//
+// outputs := beam.CrossLanguage(...)
+// mainPcol := outputs["main"]
+// sidePcol := outputs["side"]
+//
+// Example
+//
+// This example shows using CrossLanguage to execute the Prefix cross-language transform using an
+// expansion service running on localhost:8099. Prefix requires a payload containing a prefix to
+// prepend to every input string.
+//
+// type prefixPayload struct {
+// Data string
+// }
+// encodedPl := beam.CrossLanguagePayload(prefixPayload{Data: "foo"})
+// urn := "beam:transforms:xlang:test:prefix"
+// expansionAddr := "localhost:8099"
+// outputType := beam.UnnamedOutput(typex.New(reflectx.String))
+// input := beam.UnnamedInput(inputPcol)
+// outs := beam.CrossLanguage(s, urn, encodedPl, expansionAddr, input, outputType)
+//    outPcol := outs[beam.UnnamedOutputTag()]
func CrossLanguage(
s Scope,
urn string,
@@ -86,7 +163,9 @@
return mapNodeToPCollection(namedOutputs)
}
-// TryCrossLanguage coordinates the core functions required to execute the cross-language transform
+// TryCrossLanguage coordinates the core functions required to execute the cross-language transform.
+// This is mainly intended for internal use. For the general-use entry point, see
+// beam.CrossLanguage.
func TryCrossLanguage(s Scope, ext *graph.ExternalTransform, ins []*graph.Inbound, outs []*graph.Outbound) (map[string]*graph.Node, error) {
// Adding an edge in the graph corresponding to the ExternalTransform
edge, isBoundedUpdater := graph.NewCrossLanguage(s.real, s.scope, ext, ins, outs)
diff --git a/sdks/java/container/license_scripts/dep_urls_java.yaml b/sdks/java/container/license_scripts/dep_urls_java.yaml
index 36439e8..14dbb3e 100644
--- a/sdks/java/container/license_scripts/dep_urls_java.yaml
+++ b/sdks/java/container/license_scripts/dep_urls_java.yaml
@@ -41,7 +41,7 @@
'1.1.6':
type: "3-Clause BSD"
libraries-bom:
- '20.0.0':
+ '22.0.0':
license: "https://raw.githubusercontent.com/GoogleCloudPlatform/cloud-opensource-java/master/LICENSE"
type: "Apache License 2.0"
paranamer:
@@ -53,6 +53,6 @@
'1.5':
license: "https://git.tukaani.org/?p=xz-java.git;a=blob;f=COPYING;h=c1d404dc7a6f06a0437bf1055fedaa4a4c89d728;hb=9f1f97a26f090ffec6568c004a38c6534aa82b94"
jackson-bom:
- '2.12.1':
+ '2.12.4':
license: "https://raw.githubusercontent.com/FasterXML/jackson-bom/master/LICENSE"
type: "Apache License 2.0"
diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
similarity index 97%
rename from sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
rename to sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
index 5244108..e4e6de7 100644
--- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
+++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProviderTest.java
@@ -68,9 +68,9 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Tests for {@link JavaCLassLookupTransformProvider}. */
+/** Tests for {@link JavaClassLookupTransformProvider}. */
@RunWith(JUnit4.class)
-public class JavaCLassLookupTransformProviderTest {
+public class JavaClassLookupTransformProviderTest {
private static final String TEST_URN = "test:beam:transforms:count";
@@ -466,7 +466,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -485,7 +485,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethod");
payloadBuilder.setConstructorMethod("from");
Row constructorRow =
@@ -505,7 +505,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -547,7 +547,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor");
Row constructorRow =
Row.withSchema(
@@ -571,7 +571,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -606,7 +606,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithWrapperTypes");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -636,7 +636,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypes");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -681,7 +681,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithArray");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -718,7 +718,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithList");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -758,7 +758,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeArray");
Schema complexTypeSchema =
Schema.builder()
@@ -824,7 +824,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeList");
Schema complexTypeSchema =
Schema.builder()
@@ -889,7 +889,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
payloadBuilder.setConstructorMethod("from");
Row constructorRow =
@@ -934,7 +934,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
payloadBuilder.setConstructorMethod("from");
Row constructorRow =
@@ -977,7 +977,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations");
payloadBuilder.setConstructorMethod("create_transform");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
@@ -1019,7 +1019,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$UnavailableClass");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$UnavailableClass");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
.withFieldValue("strField1", "test_str_1")
@@ -1042,7 +1042,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("incorrectField", FieldType.STRING)))
.withFieldValue("incorrectField", "test_str_1")
@@ -1065,7 +1065,7 @@
ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
ExternalTransforms.JavaClassLookupPayload.newBuilder();
payloadBuilder.setClassName(
- "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
+ "org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
Row constructorRow =
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
.withFieldValue("strField1", "test_str_1")
diff --git a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
index ad11523..dd76f47 100644
--- a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
+++ b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
@@ -19,19 +19,19 @@
version: v1
allowedClasses:
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethod
allowedConstructorMethods:
- from
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods
allowedBuilderMethods:
- withStrField2
- withIntField1
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod
allowedBuilderMethods:
- withFields
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods
allowedConstructorMethods:
- from
allowedBuilderMethods:
@@ -39,28 +39,28 @@
- withIntField1
- strField2
- intField1
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations
allowedConstructorMethods:
- create_transform
allowedBuilderMethods:
- abc
- xyz
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithWrapperTypes
allowedBuilderMethods:
- withDoubleWrapperField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypes
allowedBuilderMethods:
- withComplexTypeField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithArray
allowedBuilderMethods:
- withStrArrayField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithList
allowedBuilderMethods:
- withStrListField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeArray
allowedBuilderMethods:
- withComplexTypeArrayField
-- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList
+- className: org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProviderTest$DummyTransformWithComplexTypeList
allowedBuilderMethods:
- withComplexTypeListField
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
index e94dc2c..bd919dd 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
@@ -21,13 +21,9 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import com.datastax.driver.core.Cluster;
-import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.ConsistencyLevel;
import com.datastax.driver.core.PlainTextAuthProvider;
import com.datastax.driver.core.QueryOptions;
-import com.datastax.driver.core.ResultSet;
-import com.datastax.driver.core.ResultSetFuture;
-import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.SocketOptions;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
@@ -36,32 +32,32 @@
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.Iterator;
import java.util.List;
-import java.util.NoSuchElementException;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -89,6 +85,11 @@
*
* }</pre>
*
+ * <p>Alternatively, one may use {@code CassandraIO.<Person>readAll()
+ * .withCoder(SerializableCoder.of(Person.class))} to query a subset of the Cassandra database by
+ * creating a PCollection of {@code CassandraIO.Read<Person>}, each with its own query or
+ * {@code RingRange}.
+ *
* <h3>Writing to Apache Cassandra</h3>
*
* <p>{@code CassandraIO} provides a sink to write a collection of entities to Apache Cassandra.
@@ -137,6 +138,11 @@
return new AutoValue_CassandraIO_Read.Builder<T>().build();
}
+ /** Provide a {@link ReadAll} {@link PTransform} to read data from a Cassandra database. */
+ public static <T> ReadAll<T> readAll() {
+ return new AutoValue_CassandraIO_ReadAll.Builder<T>().build();
+ }
+
/** Provide a {@link Write} {@link PTransform} to write data to a Cassandra database. */
public static <T> Write<T> write() {
return Write.<T>builder(MutationType.WRITE).build();
@@ -186,6 +192,9 @@
abstract @Nullable SerializableFunction<Session, Mapper> mapperFactoryFn();
+ @Nullable
+ abstract ValueProvider<Set<RingRange>> ringRanges();
+
abstract Builder<T> builder();
/** Specify the hosts of the Apache Cassandra instances. */
@@ -371,6 +380,14 @@
return builder().setMapperFactoryFn(mapperFactory).build();
}
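+  /**
+   * Specify a set of token ranges ({@link RingRange}s) that this read should be restricted to.
+   */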
+ public Read<T> withRingRanges(Set<RingRange> ringRange) {
+ return withRingRanges(ValueProvider.StaticValueProvider.of(ringRange));
+ }
+
+ public Read<T> withRingRanges(ValueProvider<Set<RingRange>> ringRange) {
+ return builder().setRingRanges(ringRange).build();
+ }
+
@Override
public PCollection<T> expand(PBegin input) {
checkArgument((hosts() != null && port() != null), "WithHosts() and withPort() are required");
@@ -379,7 +396,69 @@
checkArgument(entity() != null, "withEntity() is required");
checkArgument(coder() != null, "withCoder() is required");
- return input.apply(org.apache.beam.sdk.io.Read.from(new CassandraSource<>(this, null)));
+ PCollection<Read<T>> splits =
+ input
+ .apply(Create.of(this))
+ .apply("Create Splits", ParDo.of(new SplitFn<T>()))
+ .setCoder(SerializableCoder.of(new TypeDescriptor<Read<T>>() {}));
+
+ return splits.apply("ReadAll", CassandraIO.<T>readAll().withCoder(coder()));
+ }
+
+ private static class SplitFn<T> extends DoFn<Read<T>, Read<T>> {
+ @ProcessElement
+ public void process(
+ @Element CassandraIO.Read<T> read, OutputReceiver<Read<T>> outputReceiver) {
+ Set<RingRange> ringRanges = getRingRanges(read);
+ for (RingRange rr : ringRanges) {
+        Set<RingRange> subset = ImmutableSet.of(rr);
+        outputReceiver.output(read.withRingRanges(subset));
+ }
+ }
+
+ private static <T> Set<RingRange> getRingRanges(Read<T> read) {
+ try (Cluster cluster =
+ getCluster(
+ read.hosts(),
+ read.port(),
+ read.username(),
+ read.password(),
+ read.localDc(),
+ read.consistencyLevel(),
+ read.connectTimeout(),
+ read.readTimeout())) {
+ if (isMurmur3Partitioner(cluster)) {
+ LOG.info("Murmur3Partitioner detected, splitting");
+ Integer splitCount;
+ if (read.minNumberOfSplits() != null && read.minNumberOfSplits().get() != null) {
+ splitCount = read.minNumberOfSplits().get();
+ } else {
+ splitCount = cluster.getMetadata().getAllHosts().size();
+ }
+ List<BigInteger> tokens =
+ cluster.getMetadata().getTokenRanges().stream()
+ .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
+ .collect(Collectors.toList());
+ SplitGenerator splitGenerator =
+ new SplitGenerator(cluster.getMetadata().getPartitioner());
+
+ return splitGenerator.generateSplits(splitCount, tokens).stream()
+ .flatMap(List::stream)
+ .collect(Collectors.toSet());
+
+ } else {
+          LOG.warn(
+              "Only Murmur3Partitioner is supported for splitting, using a unique source for "
+                  + "the read");
+ String partitioner = cluster.getMetadata().getPartitioner();
+ RingRange totalRingRange =
+ RingRange.of(
+ SplitGenerator.getRangeMin(partitioner),
+ SplitGenerator.getRangeMax(partitioner));
+ return Collections.singleton(totalRingRange);
+ }
+ }
+ }
}
@AutoValue.Builder
@@ -418,6 +497,8 @@
abstract Optional<SerializableFunction<Session, Mapper>> mapperFactoryFn();
+ abstract Builder<T> setRingRanges(ValueProvider<Set<RingRange>> ringRange);
+
abstract Read<T> autoBuild();
public Read<T> build() {
@@ -429,390 +510,6 @@
}
}
- @VisibleForTesting
- static class CassandraSource<T> extends BoundedSource<T> {
- final Read<T> spec;
- final List<String> splitQueries;
- // split source ached size - can't be calculated when already split
- Long estimatedSize;
- private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
-
- CassandraSource(Read<T> spec, List<String> splitQueries) {
- this(spec, splitQueries, null);
- }
-
- private CassandraSource(Read<T> spec, List<String> splitQueries, Long estimatedSize) {
- this.estimatedSize = estimatedSize;
- this.spec = spec;
- this.splitQueries = splitQueries;
- }
-
- @Override
- public Coder<T> getOutputCoder() {
- return spec.coder();
- }
-
- @Override
- public BoundedReader<T> createReader(PipelineOptions pipelineOptions) {
- return new CassandraReader(this);
- }
-
- @Override
- public List<BoundedSource<T>> split(
- long desiredBundleSizeBytes, PipelineOptions pipelineOptions) {
- try (Cluster cluster =
- getCluster(
- spec.hosts(),
- spec.port(),
- spec.username(),
- spec.password(),
- spec.localDc(),
- spec.consistencyLevel(),
- spec.connectTimeout(),
- spec.readTimeout())) {
- if (isMurmur3Partitioner(cluster)) {
- LOG.info("Murmur3Partitioner detected, splitting");
- return splitWithTokenRanges(
- spec, desiredBundleSizeBytes, getEstimatedSizeBytes(pipelineOptions), cluster);
- } else {
- LOG.warn(
- "Only Murmur3Partitioner is supported for splitting, using a unique source for "
- + "the read");
- return Collections.singletonList(
- new CassandraIO.CassandraSource<>(spec, Collections.singletonList(buildQuery(spec))));
- }
- }
- }
-
- private static String buildQuery(Read spec) {
- return (spec.query() == null)
- ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
- : spec.query().get().toString();
- }
-
- /**
- * Compute the number of splits based on the estimated size and the desired bundle size, and
- * create several sources.
- */
- private List<BoundedSource<T>> splitWithTokenRanges(
- CassandraIO.Read<T> spec,
- long desiredBundleSizeBytes,
- long estimatedSizeBytes,
- Cluster cluster) {
- long numSplits =
- getNumSplits(desiredBundleSizeBytes, estimatedSizeBytes, spec.minNumberOfSplits());
- LOG.info("Number of desired splits is {}", numSplits);
-
- SplitGenerator splitGenerator = new SplitGenerator(cluster.getMetadata().getPartitioner());
- List<BigInteger> tokens =
- cluster.getMetadata().getTokenRanges().stream()
- .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
- .collect(Collectors.toList());
- List<List<RingRange>> splits = splitGenerator.generateSplits(numSplits, tokens);
- LOG.info("{} splits were actually generated", splits.size());
-
- final String partitionKey =
- cluster.getMetadata().getKeyspace(spec.keyspace().get()).getTable(spec.table().get())
- .getPartitionKey().stream()
- .map(ColumnMetadata::getName)
- .collect(Collectors.joining(","));
-
- List<TokenRange> tokenRanges =
- getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
- final long estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges) / splits.size();
-
- List<BoundedSource<T>> sources = new ArrayList<>();
- for (List<RingRange> split : splits) {
- List<String> queries = new ArrayList<>();
- for (RingRange range : split) {
- if (range.isWrapping()) {
- // A wrapping range is one that overlaps from the end of the partitioner range and its
- // start (ie : when the start token of the split is greater than the end token)
- // We need to generate two queries here : one that goes from the start token to the end
- // of
- // the partitioner range, and the other from the start of the partitioner range to the
- // end token of the split.
- queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), null));
- // Generation of the second query of the wrapping range
- queries.add(generateRangeQuery(spec, partitionKey, null, range.getEnd()));
- } else {
- queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), range.getEnd()));
- }
- }
- sources.add(new CassandraIO.CassandraSource<>(spec, queries, estimatedSize));
- }
- return sources;
- }
-
- private static String generateRangeQuery(
- Read spec, String partitionKey, BigInteger rangeStart, BigInteger rangeEnd) {
- final String rangeFilter =
- Joiner.on(" AND ")
- .skipNulls()
- .join(
- rangeStart == null
- ? null
- : String.format("(token(%s) >= %d)", partitionKey, rangeStart),
- rangeEnd == null
- ? null
- : String.format("(token(%s) < %d)", partitionKey, rangeEnd));
- final String query =
- (spec.query() == null)
- ? buildQuery(spec) + " WHERE " + rangeFilter
- : buildQuery(spec) + " AND " + rangeFilter;
- LOG.debug("CassandraIO generated query : {}", query);
- return query;
- }
-
- private static long getNumSplits(
- long desiredBundleSizeBytes,
- long estimatedSizeBytes,
- @Nullable ValueProvider<Integer> minNumberOfSplits) {
- long numSplits =
- desiredBundleSizeBytes > 0 ? (estimatedSizeBytes / desiredBundleSizeBytes) : 1;
- if (numSplits <= 0) {
- LOG.warn("Number of splits is less than 0 ({}), fallback to 1", numSplits);
- numSplits = 1;
- }
- return minNumberOfSplits != null ? Math.max(numSplits, minNumberOfSplits.get()) : numSplits;
- }
-
- /**
- * Returns cached estimate for split or if missing calculate size for whole table. Highly
- * innacurate if query is specified.
- *
- * @param pipelineOptions
- * @return
- */
- @Override
- public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
- if (estimatedSize != null) {
- return estimatedSize;
- } else {
- try (Cluster cluster =
- getCluster(
- spec.hosts(),
- spec.port(),
- spec.username(),
- spec.password(),
- spec.localDc(),
- spec.consistencyLevel(),
- spec.connectTimeout(),
- spec.readTimeout())) {
- if (isMurmur3Partitioner(cluster)) {
- try {
- List<TokenRange> tokenRanges =
- getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
- this.estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges);
- return this.estimatedSize;
- } catch (Exception e) {
- LOG.warn("Can't estimate the size", e);
- return 0L;
- }
- } else {
- LOG.warn("Only Murmur3 partitioner is supported, can't estimate the size");
- return 0L;
- }
- }
- }
- }
-
- @VisibleForTesting
- static long getEstimatedSizeBytesFromTokenRanges(List<TokenRange> tokenRanges) {
- long size = 0L;
- for (TokenRange tokenRange : tokenRanges) {
- size = size + tokenRange.meanPartitionSize * tokenRange.partitionCount;
- }
- return Math.round(size / getRingFraction(tokenRanges));
- }
-
- @Override
- public void populateDisplayData(DisplayData.Builder builder) {
- super.populateDisplayData(builder);
- if (spec.hosts() != null) {
- builder.add(DisplayData.item("hosts", spec.hosts().toString()));
- }
- if (spec.port() != null) {
- builder.add(DisplayData.item("port", spec.port()));
- }
- builder.addIfNotNull(DisplayData.item("keyspace", spec.keyspace()));
- builder.addIfNotNull(DisplayData.item("table", spec.table()));
- builder.addIfNotNull(DisplayData.item("username", spec.username()));
- builder.addIfNotNull(DisplayData.item("localDc", spec.localDc()));
- builder.addIfNotNull(DisplayData.item("consistencyLevel", spec.consistencyLevel()));
- }
- // ------------- CASSANDRA SOURCE UTIL METHODS ---------------//
-
- /**
- * Gets the list of token ranges that a table occupies on a give Cassandra node.
- *
- * <p>NB: This method is compatible with Cassandra 2.1.5 and greater.
- */
- private static List<TokenRange> getTokenRanges(Cluster cluster, String keyspace, String table) {
- try (Session session = cluster.newSession()) {
- ResultSet resultSet =
- session.execute(
- "SELECT range_start, range_end, partitions_count, mean_partition_size FROM "
- + "system.size_estimates WHERE keyspace_name = ? AND table_name = ?",
- keyspace,
- table);
-
- ArrayList<TokenRange> tokenRanges = new ArrayList<>();
- for (Row row : resultSet) {
- TokenRange tokenRange =
- new TokenRange(
- row.getLong("partitions_count"),
- row.getLong("mean_partition_size"),
- new BigInteger(row.getString("range_start")),
- new BigInteger(row.getString("range_end")));
- tokenRanges.add(tokenRange);
- }
- // The table may not contain the estimates yet
- // or have partitions_count and mean_partition_size fields = 0
- // if the data was just inserted and the amount of data in the table was small.
- // This is very common situation during tests,
- // when we insert a few rows and immediately query them.
- // However, for tiny data sets the lack of size estimates is not a problem at all,
- // because we don't want to split tiny data anyways.
- // Therefore, we're not issuing a warning if the result set was empty
- // or mean_partition_size and partitions_count = 0.
- return tokenRanges;
- }
- }
-
- /** Compute the percentage of token addressed compared with the whole tokens in the cluster. */
- @VisibleForTesting
- static double getRingFraction(List<TokenRange> tokenRanges) {
- double ringFraction = 0;
- for (TokenRange tokenRange : tokenRanges) {
- ringFraction =
- ringFraction
- + (distance(tokenRange.rangeStart, tokenRange.rangeEnd).doubleValue()
- / SplitGenerator.getRangeSize(MURMUR3PARTITIONER).doubleValue());
- }
- return ringFraction;
- }
-
- /**
- * Check if the current partitioner is the Murmur3 (default in Cassandra version newer than 2).
- */
- @VisibleForTesting
- static boolean isMurmur3Partitioner(Cluster cluster) {
- return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
- }
-
- /** Measure distance between two tokens. */
- @VisibleForTesting
- static BigInteger distance(BigInteger left, BigInteger right) {
- return (right.compareTo(left) > 0)
- ? right.subtract(left)
- : right.subtract(left).add(SplitGenerator.getRangeSize(MURMUR3PARTITIONER));
- }
-
- /**
- * Represent a token range in Cassandra instance, wrapping the partition count, size and token
- * range.
- */
- @VisibleForTesting
- static class TokenRange {
- private final long partitionCount;
- private final long meanPartitionSize;
- private final BigInteger rangeStart;
- private final BigInteger rangeEnd;
-
- TokenRange(
- long partitionCount, long meanPartitionSize, BigInteger rangeStart, BigInteger rangeEnd) {
- this.partitionCount = partitionCount;
- this.meanPartitionSize = meanPartitionSize;
- this.rangeStart = rangeStart;
- this.rangeEnd = rangeEnd;
- }
- }
-
- private class CassandraReader extends BoundedSource.BoundedReader<T> {
- private final CassandraIO.CassandraSource<T> source;
- private Cluster cluster;
- private Session session;
- private Iterator<T> iterator;
- private T current;
-
- CassandraReader(CassandraSource<T> source) {
- this.source = source;
- }
-
- @Override
- public boolean start() {
- LOG.debug("Starting Cassandra reader");
- cluster =
- getCluster(
- source.spec.hosts(),
- source.spec.port(),
- source.spec.username(),
- source.spec.password(),
- source.spec.localDc(),
- source.spec.consistencyLevel(),
- source.spec.connectTimeout(),
- source.spec.readTimeout());
- session = cluster.connect(source.spec.keyspace().get());
- LOG.debug("Queries: " + source.splitQueries);
- List<ResultSetFuture> futures = new ArrayList<>();
- for (String query : source.splitQueries) {
- futures.add(session.executeAsync(query));
- }
-
- final Mapper<T> mapper = getMapper(session, source.spec.entity());
-
- for (ResultSetFuture result : futures) {
- if (iterator == null) {
- iterator = mapper.map(result.getUninterruptibly());
- } else {
- iterator = Iterators.concat(iterator, mapper.map(result.getUninterruptibly()));
- }
- }
-
- return advance();
- }
-
- @Override
- public boolean advance() {
- if (iterator.hasNext()) {
- current = iterator.next();
- return true;
- }
- current = null;
- return false;
- }
-
- @Override
- public void close() {
- LOG.debug("Closing Cassandra reader");
- if (session != null) {
- session.close();
- }
- if (cluster != null) {
- cluster.close();
- }
- }
-
- @Override
- public T getCurrent() throws NoSuchElementException {
- if (current == null) {
- throw new NoSuchElementException();
- }
- return current;
- }
-
- @Override
- public CassandraIO.CassandraSource<T> getCurrentSource() {
- return source;
- }
-
- private Mapper<T> getMapper(Session session, Class<T> enitity) {
- return source.spec.mapperFactoryFn().apply(session);
- }
- }
- }
-
/** Specify the mutation type: either write or delete. */
public enum MutationType {
WRITE,
@@ -1179,7 +876,7 @@
}
/** Get a Cassandra cluster using hosts and port. */
- private static Cluster getCluster(
+ static Cluster getCluster(
ValueProvider<List<String>> hosts,
ValueProvider<Integer> port,
ValueProvider<String> username,
@@ -1301,4 +998,53 @@
}
}
}
+
+ /**
+ * A {@link PTransform} to read data from Apache Cassandra. See {@link CassandraIO} for more
+ * information on usage and configuration.
+ */
+ @AutoValue
+ public abstract static class ReadAll<T> extends PTransform<PCollection<Read<T>>, PCollection<T>> {
+ @AutoValue.Builder
+ abstract static class Builder<T> {
+
+ abstract Builder<T> setCoder(Coder<T> coder);
+
+ abstract ReadAll<T> autoBuild();
+
+ public ReadAll<T> build() {
+ return autoBuild();
+ }
+ }
+
+ @Nullable
+ abstract Coder<T> coder();
+
+ abstract Builder<T> builder();
+
+ /** Specify the {@link Coder} used to serialize the entity in the {@link PCollection}. */
+ public ReadAll<T> withCoder(Coder<T> coder) {
+ checkArgument(coder != null, "coder can not be null");
+ return builder().setCoder(coder).build();
+ }
+
+ @Override
+ public PCollection<T> expand(PCollection<Read<T>> input) {
+ checkArgument(coder() != null, "withCoder() is required");
+ return input
+ .apply("Reshuffle", Reshuffle.viaRandomKey())
+ .apply("Read", ParDo.of(new ReadFn<>()))
+ .setCoder(this.coder());
+ }
+ }
+
+ /**
+ * Check if the current partitioner is the Murmur3 (default in Cassandra version newer than 2).
+ */
+  @VisibleForTesting
+  static boolean isMurmur3Partitioner(Cluster cluster) {
+ return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
+ }
+
+ private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
}
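Note on the new read path: Read#expand now creates one Read spec per RingRange in SplitFn and
hands them to ReadAll, so a pipeline can also build those specs itself. A minimal sketch of that
usage, assuming a Person entity class annotated for the DataStax object mapper and placeholder
connection settings (the host, keyspace, table and entity names are illustrative, not from this
patch):

  import java.util.Collections;
  import org.apache.beam.sdk.Pipeline;
  import org.apache.beam.sdk.coders.SerializableCoder;
  import org.apache.beam.sdk.io.cassandra.CassandraIO;
  import org.apache.beam.sdk.transforms.Create;
  import org.apache.beam.sdk.values.PCollection;

  public class CassandraReadAllExample {
    public static void main(String[] args) {
      Pipeline pipeline = Pipeline.create();
      // One Read spec per query; ReadAll executes each element on the workers.
      CassandraIO.Read<Person> physicists =
          CassandraIO.<Person>read()
              .withHosts(Collections.singletonList("127.0.0.1"))
              .withPort(9042)
              .withKeyspace("beam_ks")
              .withTable("scientist")
              .withQuery("SELECT * FROM beam_ks.scientist WHERE person_department='phys'")
              .withEntity(Person.class)
              .withCoder(SerializableCoder.of(Person.class));
      PCollection<Person> rows =
          pipeline
              .apply(Create.of(physicists))
              .apply(CassandraIO.<Person>readAll().withCoder(SerializableCoder.of(Person.class)));
      pipeline.run().waitUntilFinish();
    }
  }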
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
new file mode 100644
index 0000000..5091ac4
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Session;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.options.ValueProvider;
+
+@SuppressWarnings({
+ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class ConnectionManager {
+
+ private static final ConcurrentHashMap<String, Cluster> clusterMap =
+ new ConcurrentHashMap<String, Cluster>();
+ private static final ConcurrentHashMap<String, Session> sessionMap =
+ new ConcurrentHashMap<String, Session>();
+
+ static {
+ Runtime.getRuntime()
+ .addShutdownHook(
+ new Thread(
+ () -> {
+ for (Session session : sessionMap.values()) {
+ if (!session.isClosed()) {
+ session.close();
+ }
+ }
+ }));
+ }
+
+ private static String readToClusterHash(Read<?> read) {
+ return Objects.requireNonNull(read.hosts()).get().stream().reduce(",", (a, b) -> a + b)
+ + Objects.requireNonNull(read.port()).get()
+ + safeVPGet(read.localDc())
+ + safeVPGet(read.consistencyLevel());
+ }
+
+ private static String readToSessionHash(Read<?> read) {
+ return readToClusterHash(read) + read.keyspace().get();
+ }
+
+ static Session getSession(Read<?> read) {
+ Cluster cluster =
+ clusterMap.computeIfAbsent(
+ readToClusterHash(read),
+ k ->
+ CassandraIO.getCluster(
+ Objects.requireNonNull(read.hosts()),
+ Objects.requireNonNull(read.port()),
+ read.username(),
+ read.password(),
+ read.localDc(),
+ read.consistencyLevel(),
+ read.connectTimeout(),
+ read.readTimeout()));
+ return sessionMap.computeIfAbsent(
+ readToSessionHash(read),
+ k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
+ }
+
+ private static String safeVPGet(ValueProvider<String> s) {
+ return s != null ? s.get() : "";
+ }
+}
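The manager above memoizes Cluster and Session objects in static ConcurrentHashMaps keyed by the
concatenated connection settings, so every ReadFn bundle in the same JVM reuses one session per
(cluster settings, keyspace) pair, and a JVM shutdown hook closes the open sessions. A stand-in
sketch of that cache-by-key pattern, with Strings in place of Cluster/Session (all types and
values here are illustrative only):

  import java.util.concurrent.ConcurrentHashMap;

  public class ConnectionCacheSketch {
    private static final ConcurrentHashMap<String, String> clusters = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, String> sessions = new ConcurrentHashMap<>();

    static String getSession(String hosts, int port, String keyspace) {
      String clusterKey = hosts + port;          // analogue of readToClusterHash(read)
      String sessionKey = clusterKey + keyspace; // analogue of readToSessionHash(read)
      String cluster = clusters.computeIfAbsent(clusterKey, k -> "cluster@" + k);
      return sessions.computeIfAbsent(sessionKey, k -> cluster + "/session@" + k);
    }

    public static void main(String[] args) {
      String s1 = getSession("127.0.0.1", 9042, "beam_ks");
      String s2 = getSession("127.0.0.1", 9042, "beam_ks");
      System.out.println(s1 == s2); // true: the second lookup is served from the cache
    }
  }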
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
index 8f6d578..92ec2c5 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
@@ -34,7 +34,7 @@
})
class DefaultObjectMapper<T> implements Mapper<T>, Serializable {
- private transient com.datastax.driver.mapping.Mapper<T> mapper;
+ private final transient com.datastax.driver.mapping.Mapper<T> mapper;
DefaultObjectMapper(com.datastax.driver.mapping.Mapper mapper) {
this.mapper = mapper;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
index 7976665..ef75ff3 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
@@ -34,7 +34,7 @@
class DefaultObjectMapperFactory<T> implements SerializableFunction<Session, Mapper> {
private transient MappingManager mappingManager;
- Class<T> entity;
+ final Class<T> entity;
DefaultObjectMapperFactory(Class<T> entity) {
this.entity = entity;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
new file mode 100644
index 0000000..193cdf0
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.ColumnMetadata;
+import com.datastax.driver.core.PreparedStatement;
+import com.datastax.driver.core.ResultSet;
+import com.datastax.driver.core.Session;
+import com.datastax.driver.core.Token;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@SuppressWarnings({
+ "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+class ReadFn<T> extends DoFn<Read<T>, T> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ReadFn.class);
+
+ @ProcessElement
+ public void processElement(@Element Read<T> read, OutputReceiver<T> receiver) {
+ try {
+ Session session = ConnectionManager.getSession(read);
+ Mapper<T> mapper = read.mapperFactoryFn().apply(session);
+ String partitionKey =
+ session.getCluster().getMetadata().getKeyspace(read.keyspace().get())
+ .getTable(read.table().get()).getPartitionKey().stream()
+ .map(ColumnMetadata::getName)
+ .collect(Collectors.joining(","));
+
+ String query = generateRangeQuery(read, partitionKey, read.ringRanges() != null);
+ PreparedStatement preparedStatement = session.prepare(query);
+ Set<RingRange> ringRanges =
+ read.ringRanges() == null ? Collections.emptySet() : read.ringRanges().get();
+
+ for (RingRange rr : ringRanges) {
+ Token startToken = session.getCluster().getMetadata().newToken(rr.getStart().toString());
+ Token endToken = session.getCluster().getMetadata().newToken(rr.getEnd().toString());
+ ResultSet rs =
+ session.execute(preparedStatement.bind().setToken(0, startToken).setToken(1, endToken));
+ Iterator<T> iter = mapper.map(rs);
+ while (iter.hasNext()) {
+ T n = iter.next();
+ receiver.output(n);
+ }
+ }
+
+ if (read.ringRanges() == null) {
+ ResultSet rs = session.execute(preparedStatement.bind());
+ Iterator<T> iter = mapper.map(rs);
+ while (iter.hasNext()) {
+ receiver.output(iter.next());
+ }
+ }
+ } catch (Exception ex) {
+      LOG.error("Error while reading from Cassandra", ex);
+ }
+ }
+
+ private Session getSession(Read<T> read) {
+ Cluster cluster =
+ CassandraIO.getCluster(
+ read.hosts(),
+ read.port(),
+ read.username(),
+ read.password(),
+ read.localDc(),
+ read.consistencyLevel(),
+ read.connectTimeout(),
+ read.readTimeout());
+
+ return cluster.connect(read.keyspace().get());
+ }
+
+ private static String generateRangeQuery(
+ Read<?> spec, String partitionKey, Boolean hasRingRange) {
+ final String rangeFilter =
+ (hasRingRange)
+ ? Joiner.on(" AND ")
+ .skipNulls()
+ .join(
+ String.format("(token(%s) >= ?)", partitionKey),
+ String.format("(token(%s) < ?)", partitionKey))
+ : "";
+ final String combinedQuery = buildInitialQuery(spec, hasRingRange) + rangeFilter;
+ LOG.debug("CassandraIO generated query : {}", combinedQuery);
+ return combinedQuery;
+ }
+
+  private static String buildInitialQuery(Read<?> spec, Boolean hasRingRange) {
+    return (spec.query() == null)
+        ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
+            + (hasRingRange ? " WHERE " : "")
+        : spec.query().get() + (hasRingRange ? " AND " : "");
+  }
+}
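ReadFn prepares a single statement whose token bounds are bind parameters, then executes it once
per RingRange via setToken(0, start) / setToken(1, end). A runnable sketch of the query shape it
builds when ring ranges are present; the keyspace, table and partition-key names are hypothetical:

  public class RangeQuerySketch {
    public static void main(String[] args) {
      String partitionKey = "person_department"; // as resolved from the table metadata
      String initial = String.format("SELECT * FROM %s.%s WHERE ", "beam_ks", "scientist");
      String rangeFilter =
          String.format("(token(%s) >= ?)", partitionKey)
              + " AND "
              + String.format("(token(%s) < ?)", partitionKey);
      // Prints:
      // SELECT * FROM beam_ks.scientist WHERE (token(person_department) >= ?)
      //   AND (token(person_department) < ?)
      System.out.println(initial + rangeFilter);
    }
  }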
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
index b5f94d7..c83e47f 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
@@ -17,23 +17,31 @@
*/
package org.apache.beam.sdk.io.cassandra;
+import java.io.Serializable;
import java.math.BigInteger;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
/** Models a Cassandra token range. */
-final class RingRange {
+@Experimental(Kind.SOURCE_SINK)
+@SuppressWarnings({
+ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public final class RingRange implements Serializable {
private final BigInteger start;
private final BigInteger end;
- RingRange(BigInteger start, BigInteger end) {
+ private RingRange(BigInteger start, BigInteger end) {
this.start = start;
this.end = end;
}
- BigInteger getStart() {
+ public BigInteger getStart() {
return start;
}
- BigInteger getEnd() {
+ public BigInteger getEnd() {
return end;
}
@@ -55,4 +63,34 @@
public String toString() {
return String.format("(%s,%s]", start.toString(), end.toString());
}
+
+ public static RingRange of(BigInteger start, BigInteger end) {
+ return new RingRange(start, end);
+ }
+
+ @Override
+ public boolean equals(@Nullable Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ RingRange ringRange = (RingRange) o;
+
+ if (getStart() != null
+ ? !getStart().equals(ringRange.getStart())
+ : ringRange.getStart() != null) {
+ return false;
+ }
+ return getEnd() != null ? getEnd().equals(ringRange.getEnd()) : ringRange.getEnd() == null;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = getStart() != null ? getStart().hashCode() : 0;
+ result = 31 * result + (getEnd() != null ? getEnd().hashCode() : 0);
+ return result;
+ }
}
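RingRange is now a public, Serializable value type with a static of() factory and value-based
equals/hashCode, which lets ranges be deduplicated in Sets and carried inside a serialized Read
spec. A small runnable sketch of those semantics:

  import java.math.BigInteger;
  import java.util.Arrays;
  import java.util.HashSet;
  import java.util.Set;
  import org.apache.beam.sdk.io.cassandra.RingRange;

  public class RingRangeSketch {
    public static void main(String[] args) {
      RingRange a = RingRange.of(BigInteger.valueOf(-10), BigInteger.ZERO);
      RingRange b = RingRange.of(BigInteger.valueOf(-10), BigInteger.ZERO);
      Set<RingRange> ranges = new HashSet<>(Arrays.asList(a, b));
      System.out.println(ranges.size()); // 1: equal ranges collapse via equals/hashCode
      System.out.println(a);             // (-10,0]: toString prints the (start,end] interval
    }
  }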
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
index de49421..bc1205a 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
@@ -39,22 +39,22 @@
this.partitioner = partitioner;
}
- private static BigInteger getRangeMin(String partitioner) {
+ static BigInteger getRangeMin(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return BigInteger.ZERO;
} else if (partitioner.endsWith("Murmur3Partitioner")) {
- return new BigInteger("2").pow(63).negate();
+ return BigInteger.valueOf(2).pow(63).negate();
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
}
}
- private static BigInteger getRangeMax(String partitioner) {
+ static BigInteger getRangeMax(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
- return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
+ return BigInteger.valueOf(2).pow(127).subtract(BigInteger.ONE);
} else if (partitioner.endsWith("Murmur3Partitioner")) {
- return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
+ return BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
@@ -84,7 +84,7 @@
BigInteger start = ringTokens.get(i);
BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);
- if (!inRange(start) || !inRange(stop)) {
+ if (!isInRange(start) || !isInRange(stop)) {
throw new RuntimeException(
String.format("Tokens (%s,%s) not in range of %s", start, stop, partitioner));
}
@@ -127,7 +127,7 @@
// Append the splits between the endpoints
for (int j = 0; j < splitCount; j++) {
- splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 1)));
+ splits.add(RingRange.of(endpointTokens.get(j), endpointTokens.get(j + 1)));
LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), endpointTokens.get(j + 1));
}
}
@@ -144,7 +144,7 @@
return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
}
- private boolean inRange(BigInteger token) {
+ private boolean isInRange(BigInteger token) {
return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
}
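getRangeMin/getRangeMax above encode the Murmur3 token ring as [-2^63, 2^63 - 1]; switching from
new BigInteger("2") to BigInteger.valueOf(2) avoids a string parse without changing the values.
A runnable check of those bounds:

  import java.math.BigInteger;

  public class TokenBoundsSketch {
    public static void main(String[] args) {
      BigInteger min = BigInteger.valueOf(2).pow(63).negate();
      BigInteger max = BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
      System.out.println(min); // -9223372036854775808 (Long.MIN_VALUE)
      System.out.println(max); // 9223372036854775807 (Long.MAX_VALUE)
      System.out.println(max.subtract(min).add(BigInteger.ONE)); // 2^64 tokens in the ring
    }
  }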
diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
index a52808b..131ce83 100644
--- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
+++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
@@ -18,23 +18,18 @@
package org.apache.beam.sdk.io.cassandra;
import static junit.framework.TestCase.assertTrue;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.distance;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getEstimatedSizeBytesFromTokenRanges;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getRingFraction;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.isMurmur3Partitioner;
-import static org.apache.beam.sdk.testing.SourceTestUtils.readFromSource;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Metadata;
+import com.datastax.driver.core.ProtocolVersion;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
+import com.datastax.driver.core.TypeCodec;
import com.datastax.driver.core.exceptions.NoHostAvailableException;
-import com.datastax.driver.core.querybuilder.QueryBuilder;
+import com.datastax.driver.mapping.annotations.ClusteringColumn;
import com.datastax.driver.mapping.annotations.Column;
import com.datastax.driver.mapping.annotations.Computed;
import com.datastax.driver.mapping.annotations.PartitionKey;
@@ -44,10 +39,14 @@
import java.io.IOException;
import java.io.Serializable;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
@@ -61,13 +60,8 @@
import javax.management.remote.JMXConnectorFactory;
import javax.management.remote.JMXServiceURL;
import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.TokenRange;
import org.apache.beam.sdk.io.common.NetworkTestHelper;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.SourceTestUtils;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
@@ -82,7 +76,6 @@
import org.apache.cassandra.service.StorageServiceMBean;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.AfterClass;
-import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Rule;
@@ -99,7 +92,7 @@
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
public class CassandraIOTest implements Serializable {
- private static final long NUM_ROWS = 20L;
+ private static final long NUM_ROWS = 22L;
private static final String CASSANDRA_KEYSPACE = "beam_ks";
private static final String CASSANDRA_HOST = "127.0.0.1";
private static final String CASSANDRA_TABLE = "scientist";
@@ -190,39 +183,44 @@
LOG.info("Create Cassandra tables");
session.execute(
String.format(
- "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
- + "(person_id));",
+ "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+ + "((person_department), person_id));",
CASSANDRA_KEYSPACE, CASSANDRA_TABLE));
session.execute(
String.format(
- "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
- + "(person_id));",
+ "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+ + "((person_department), person_id));",
CASSANDRA_KEYSPACE, CASSANDRA_TABLE_WRITE));
LOG.info("Insert records");
- String[] scientists = {
- "Einstein",
- "Darwin",
- "Copernicus",
- "Pasteur",
- "Curie",
- "Faraday",
- "Newton",
- "Bohr",
- "Galilei",
- "Maxwell"
+ String[][] scientists = {
+ new String[] {"phys", "Einstein"},
+ new String[] {"bio", "Darwin"},
+ new String[] {"phys", "Copernicus"},
+ new String[] {"bio", "Pasteur"},
+ new String[] {"bio", "Curie"},
+ new String[] {"phys", "Faraday"},
+ new String[] {"math", "Newton"},
+ new String[] {"phys", "Bohr"},
+ new String[] {"phys", "Galileo"},
+ new String[] {"math", "Maxwell"},
+ new String[] {"logic", "Russel"},
};
for (int i = 0; i < NUM_ROWS; i++) {
int index = i % scientists.length;
- session.execute(
+ String insertStr =
String.format(
- "INSERT INTO %s.%s(person_id, person_name) values("
+ "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+ + "'"
+ + scientists[index][0]
+ + "', "
+ i
+ ", '"
- + scientists[index]
+ + scientists[index][1]
+ "');",
CASSANDRA_KEYSPACE,
- CASSANDRA_TABLE));
+ CASSANDRA_TABLE);
+ session.execute(insertStr);
}
flushMemTablesAndRefreshSizeEstimates();
}
@@ -277,25 +275,6 @@
}
@Test
- public void testEstimatedSizeBytes() throws Exception {
- PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
- CassandraIO.Read<Scientist> read =
- CassandraIO.<Scientist>read()
- .withHosts(Collections.singletonList(CASSANDRA_HOST))
- .withPort(cassandraPort)
- .withKeyspace(CASSANDRA_KEYSPACE)
- .withTable(CASSANDRA_TABLE);
- CassandraIO.CassandraSource<Scientist> source = new CassandraIO.CassandraSource<>(read, null);
- long estimatedSizeBytes = source.getEstimatedSizeBytes(pipelineOptions);
- // the size is non determanistic in Cassandra backend: checks that estimatedSizeBytes >= 12960L
- // -20% && estimatedSizeBytes <= 12960L +20%
- assertThat(
- "wrong estimated size in " + CASSANDRA_KEYSPACE + "/" + CASSANDRA_TABLE,
- estimatedSizeBytes,
- greaterThan(0L));
- }
-
- @Test
public void testRead() throws Exception {
PCollection<Scientist> output =
pipeline.apply(
@@ -304,6 +283,7 @@
.withPort(cassandraPort)
.withKeyspace(CASSANDRA_KEYSPACE)
.withTable(CASSANDRA_TABLE)
+ .withMinNumberOfSplits(50)
.withCoder(SerializableCoder.of(Scientist.class))
.withEntity(Scientist.class));
@@ -321,9 +301,113 @@
PAssert.that(mapped.apply("Count occurrences per scientist", Count.perKey()))
.satisfies(
input -> {
+ int count = 0;
for (KV<String, Long> element : input) {
+ count++;
assertEquals(element.getKey(), NUM_ROWS / 10, element.getValue().longValue());
}
+ assertEquals(11, count);
+ return null;
+ });
+
+ pipeline.run();
+ }
+
+ private CassandraIO.Read<Scientist> getReadWithRingRange(RingRange... rr) {
+ return CassandraIO.<Scientist>read()
+ .withHosts(Collections.singletonList(CASSANDRA_HOST))
+ .withPort(cassandraPort)
+ .withRingRanges(new HashSet<>(Arrays.asList(rr)))
+ .withKeyspace(CASSANDRA_KEYSPACE)
+ .withTable(CASSANDRA_TABLE)
+ .withCoder(SerializableCoder.of(Scientist.class))
+ .withEntity(Scientist.class);
+ }
+
+ private CassandraIO.Read<Scientist> getReadWithQuery(String query) {
+ return CassandraIO.<Scientist>read()
+ .withHosts(Collections.singletonList(CASSANDRA_HOST))
+ .withPort(cassandraPort)
+ .withQuery(query)
+ .withKeyspace(CASSANDRA_KEYSPACE)
+ .withTable(CASSANDRA_TABLE)
+ .withCoder(SerializableCoder.of(Scientist.class))
+ .withEntity(Scientist.class);
+ }
+
+ @Test
+ public void testReadAllQuery() {
+ String physQuery =
+ String.format(
+ "SELECT * From %s.%s WHERE person_department='phys' AND person_id=0;",
+ CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+ String mathQuery =
+ String.format(
+ "SELECT * From %s.%s WHERE person_department='math' AND person_id=6;",
+ CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+ PCollection<Scientist> output =
+ pipeline
+ .apply(Create.of(getReadWithQuery(physQuery), getReadWithQuery(mathQuery)))
+ .apply(
+ CassandraIO.<Scientist>readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+ PCollection<String> mapped =
+ output.apply(
+ MapElements.via(
+ new SimpleFunction<Scientist, String>() {
+ @Override
+ public String apply(Scientist scientist) {
+ return scientist.name;
+ }
+ }));
+ PAssert.that(mapped).containsInAnyOrder("Einstein", "Newton");
+ PAssert.thatSingleton(output.apply("count", Count.globally())).isEqualTo(2L);
+ pipeline.run();
+ }
+
+ @Test
+ public void testReadAllRingRange() {
+ RingRange physRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("phys", ProtocolVersion.V3));
+
+ RingRange mathRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("math", ProtocolVersion.V3));
+
+ RingRange logicRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("logic", ProtocolVersion.V3));
+
+ PCollection<Scientist> output =
+ pipeline
+ .apply(Create.of(getReadWithRingRange(physRR), getReadWithRingRange(mathRR, logicRR)))
+ .apply(
+ CassandraIO.<Scientist>readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+ PCollection<KV<String, Integer>> mapped =
+ output.apply(
+ MapElements.via(
+ new SimpleFunction<Scientist, KV<String, Integer>>() {
+ @Override
+ public KV<String, Integer> apply(Scientist scientist) {
+ return KV.of(scientist.department, scientist.id);
+ }
+ }));
+
+ PAssert.that(mapped.apply("Count occurrences per department", Count.perKey()))
+ .satisfies(
+ input -> {
+ HashMap<String, Long> map = new HashMap<>();
+ for (KV<String, Long> element : input) {
+ map.put(element.getKey(), element.getValue());
+ }
+ assertEquals(3, map.size()); // do we have all three departments
+ assertEquals(map.get("phys"), 10L, 0L);
+ assertEquals(map.get("math"), 4L, 0L);
+ assertEquals(map.get("logic"), 2L, 0L);
return null;
});
@@ -339,8 +423,9 @@
.withPort(cassandraPort)
.withKeyspace(CASSANDRA_KEYSPACE)
.withTable(CASSANDRA_TABLE)
+ .withMinNumberOfSplits(20)
.withQuery(
- "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10")
+ "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10 AND person_department='logic'")
.withCoder(SerializableCoder.of(Scientist.class))
.withEntity(Scientist.class));
@@ -365,6 +450,7 @@
ScientistWrite scientist = new ScientistWrite();
scientist.id = i;
scientist.name = "Name " + i;
+ scientist.department = "bio";
data.add(scientist);
}
@@ -485,52 +571,6 @@
assertEquals(1, counter.intValue());
}
- @Test
- public void testSplit() throws Exception {
- PipelineOptions options = PipelineOptionsFactory.create();
- CassandraIO.Read<Scientist> read =
- CassandraIO.<Scientist>read()
- .withHosts(Collections.singletonList(CASSANDRA_HOST))
- .withPort(cassandraPort)
- .withKeyspace(CASSANDRA_KEYSPACE)
- .withTable(CASSANDRA_TABLE)
- .withEntity(Scientist.class)
- .withCoder(SerializableCoder.of(Scientist.class));
-
- // initialSource will be read without splitting (which does not happen in production)
- // so we need to provide splitQueries to avoid NPE in source.reader.start()
- String splitQuery = QueryBuilder.select().from(CASSANDRA_KEYSPACE, CASSANDRA_TABLE).toString();
- CassandraIO.CassandraSource<Scientist> initialSource =
- new CassandraIO.CassandraSource<>(read, Collections.singletonList(splitQuery));
- int desiredBundleSizeBytes = 2048;
- long estimatedSize = initialSource.getEstimatedSizeBytes(options);
- List<BoundedSource<Scientist>> splits = initialSource.split(desiredBundleSizeBytes, options);
- SourceTestUtils.assertSourcesEqualReferenceSource(initialSource, splits, options);
- float expectedNumSplitsloat =
- (float) initialSource.getEstimatedSizeBytes(options) / desiredBundleSizeBytes;
- long sum = 0;
-
- for (BoundedSource<Scientist> subSource : splits) {
- sum += subSource.getEstimatedSizeBytes(options);
- }
-
- // due to division and cast estimateSize != sum but will be close. Exact equals checked below
- assertEquals((long) (estimatedSize / splits.size()) * splits.size(), sum);
-
- int expectedNumSplits = (int) Math.ceil(expectedNumSplitsloat);
- assertEquals("Wrong number of splits", expectedNumSplits, splits.size());
- int emptySplits = 0;
- for (BoundedSource<Scientist> subSource : splits) {
- if (readFromSource(subSource, options).isEmpty()) {
- emptySplits += 1;
- }
- }
- assertThat(
- "There are too many empty splits, parallelism is sub-optimal",
- emptySplits,
- lessThan((int) (ACCEPTABLE_EMPTY_SPLITS_PERCENTAGE * splits.size())));
- }
-
private List<Row> getRows(String table) {
ResultSet result =
session.execute(
@@ -545,6 +585,7 @@
Scientist einstein = new Scientist();
einstein.id = 0;
+ einstein.department = "phys";
einstein.name = "Einstein";
pipeline
.apply(Create.of(einstein))
@@ -561,7 +602,8 @@
// re-insert suppressed doc to make the test autonomous
session.execute(
String.format(
- "INSERT INTO %s.%s(person_id, person_name) values("
+ "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+ + "'phys', "
+ einstein.id
+ ", '"
+ einstein.name
@@ -570,58 +612,6 @@
CASSANDRA_TABLE));
}
- @Test
- public void testValidPartitioner() {
- Assert.assertTrue(isMurmur3Partitioner(cluster));
- }
-
- @Test
- public void testDistance() {
- BigInteger distance = distance(new BigInteger("10"), new BigInteger("100"));
- assertEquals(BigInteger.valueOf(90), distance);
-
- distance = distance(new BigInteger("100"), new BigInteger("10"));
- assertEquals(new BigInteger("18446744073709551526"), distance);
- }
-
- @Test
- public void testRingFraction() {
- // simulate a first range taking "half" of the available tokens
- List<TokenRange> tokenRanges = new ArrayList<>();
- tokenRanges.add(new TokenRange(1, 1, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
- assertEquals(0.5, getRingFraction(tokenRanges), 0);
-
- // add a second range to cover all tokens available
- tokenRanges.add(new TokenRange(1, 1, new BigInteger("0"), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(1.0, getRingFraction(tokenRanges), 0);
- }
-
- @Test
- public void testEstimatedSizeBytesFromTokenRanges() {
- List<TokenRange> tokenRanges = new ArrayList<>();
- // one partition containing all tokens, the size is actually the size of the partition
- tokenRanges.add(
- new TokenRange(
- 1, 1000, BigInteger.valueOf(Long.MIN_VALUE), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(1000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
- // one partition with half of the tokens, we estimate the size to the double of this partition
- tokenRanges = new ArrayList<>();
- tokenRanges.add(
- new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
- assertEquals(2000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
- // we have three partitions covering all tokens, the size is the sum of partition size *
- // partition count
- tokenRanges = new ArrayList<>();
- tokenRanges.add(
- new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("-3")));
- tokenRanges.add(new TokenRange(1, 1000, new BigInteger("-2"), new BigInteger("10000")));
- tokenRanges.add(
- new TokenRange(2, 3000, new BigInteger("10001"), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(8000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
- }
-
/** Simple Cassandra entity used in read tests. */
@Table(name = CASSANDRA_TABLE, keyspace = CASSANDRA_KEYSPACE)
static class Scientist implements Serializable {
@@ -632,10 +622,14 @@
@Computed("writetime(person_name)")
Long nameTs;
- @PartitionKey()
+ @ClusteringColumn()
@Column(name = "person_id")
int id;
+ @PartitionKey
+ @Column(name = "person_department")
+ String department;
+
@Override
public String toString() {
return id + ":" + name;
@@ -650,7 +644,9 @@
return false;
}
Scientist scientist = (Scientist) o;
- return id == scientist.id && Objects.equal(name, scientist.name);
+ return id == scientist.id
+ && Objects.equal(name, scientist.name)
+ && Objects.equal(department, scientist.department);
}
@Override
@@ -659,6 +655,11 @@
}
}
+ private static RingRange fromEncodedKey(Metadata metadata, ByteBuffer... bb) {
+ BigInteger bi = BigInteger.valueOf((long) metadata.newToken(bb).getValue());
+ return RingRange.of(bi, bi.add(BigInteger.valueOf(1L)));
+ }
+
private static final String CASSANDRA_TABLE_WRITE = "scientist_write";
/** Simple Cassandra entity used in write tests. */
@Table(name = CASSANDRA_TABLE_WRITE, keyspace = CASSANDRA_KEYSPACE)
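The reworked test schema makes person_department the partition key, and fromEncodedKey turns a
serialized key into its Murmur3 token wrapped in a width-one RingRange, so each ReadAll element
targets exactly the partitions hashing to that token. A sketch of that range construction with a
hypothetical token value (the test derives the real one from Metadata.newToken):

  import java.math.BigInteger;
  import org.apache.beam.sdk.io.cassandra.RingRange;

  public class SingleTokenRangeSketch {
    public static void main(String[] args) {
      BigInteger token = BigInteger.valueOf(1234567890L); // hypothetical Murmur3 token
      RingRange oneToken = RingRange.of(token, token.add(BigInteger.ONE));
      System.out.println(oneToken); // (1234567890,1234567891]
    }
  }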
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 3ec4d39..215a66b 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -131,7 +131,7 @@
compile "org.threeten:threetenbp:1.4.4"
- testCompile "org.apache.arrow:arrow-memory-netty:4.0.0"
+ testCompile library.java.arrow_memory_netty
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
testCompile project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntime")
testCompile project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntime")
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
index 484fe3b..dfe797e 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
@@ -86,7 +86,6 @@
private final SerializableFunction<ElementT, TableRow> toTableRow;
private final SerializableFunction<ElementT, TableRow> toFailsafeTableRow;
private final Set<String> allowedMetricUrns;
- private @Nullable DatasetService datasetService;
/** Tracks bytes written, exposed as "ByteCount" Counter. */
private Counter byteCounter = SinkMetrics.bytesWritten();
@@ -222,6 +221,15 @@
/** The list of unique ids for each BigQuery table row. */
private transient Map<String, List<String>> uniqueIdsForTableRows;
+ private transient @Nullable DatasetService datasetService;
+
+ private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+ if (datasetService == null) {
+ datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+ }
+ return datasetService;
+ }
+
/** Prepares a target BigQuery table. */
@StartBundle
public void startBundle() {
@@ -257,10 +265,10 @@
tableRows.entrySet()) {
TableReference tableReference = BigQueryHelpers.parseTableSpec(entry.getKey());
flushRows(
+ getDatasetService(options),
tableReference,
entry.getValue(),
uniqueIdsForTableRows.get(entry.getKey()),
- options,
failedInserts,
successfulInserts);
}
@@ -272,6 +280,18 @@
}
reportStreamingApiLogging(options);
}
+
+ @Teardown
+ public void onTeardown() {
+ try {
+ if (datasetService != null) {
+ datasetService.close();
+ datasetService = null;
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
}
// The max duration input records are allowed to be buffered in the state, if using ViaStateful.
@@ -325,13 +345,22 @@
// shuffling.
private class InsertBatchedElements
extends DoFn<KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>>, Void> {
+ private transient @Nullable DatasetService datasetService;
+
+ private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+ if (datasetService == null) {
+ datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+ }
+ return datasetService;
+ }
+
@ProcessElement
public void processElement(
@Element KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>> input,
BoundedWindow window,
ProcessContext context,
MultiOutputReceiver out)
- throws InterruptedException {
+ throws InterruptedException, IOException {
List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows = new ArrayList<>();
List<String> uniqueIds = new ArrayList<>();
for (TableRowInfo<ElementT> row : input.getValue()) {
@@ -347,7 +376,13 @@
TableReference tableReference = BigQueryHelpers.parseTableSpec(input.getKey().getKey());
List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
- flushRows(tableReference, tableRows, uniqueIds, options, failedInserts, successfulInserts);
+ flushRows(
+ getDatasetService(options),
+ tableReference,
+ tableRows,
+ uniqueIds,
+ failedInserts,
+ successfulInserts);
for (ValueInSingleWindow<ErrorT> row : failedInserts) {
out.get(failedOutputTag).output(row.getValue());
@@ -357,44 +392,43 @@
}
reportStreamingApiLogging(options);
}
- }
- @Teardown
- public void onTeardown() {
- try {
- if (datasetService != null) {
- datasetService.close();
- datasetService = null;
+ @Teardown
+ public void onTeardown() {
+ try {
+ if (datasetService != null) {
+ datasetService.close();
+ datasetService = null;
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
}
- } catch (Exception e) {
- throw new RuntimeException(e);
}
}
/** Writes the accumulated rows into BigQuery with streaming API. */
private void flushRows(
+ DatasetService datasetService,
TableReference tableReference,
List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows,
List<String> uniqueIds,
- BigQueryOptions options,
List<ValueInSingleWindow<ErrorT>> failedInserts,
List<ValueInSingleWindow<TableRow>> successfulInserts)
throws InterruptedException {
if (!tableRows.isEmpty()) {
try {
long totalBytes =
- getDatasetService(options)
- .insertAll(
- tableReference,
- tableRows,
- uniqueIds,
- retryPolicy,
- failedInserts,
- errorContainer,
- skipInvalidRows,
- ignoreUnknownValues,
- ignoreInsertIds,
- successfulInserts);
+ datasetService.insertAll(
+ tableReference,
+ tableRows,
+ uniqueIds,
+ retryPolicy,
+ failedInserts,
+ errorContainer,
+ skipInvalidRows,
+ ignoreUnknownValues,
+ ignoreInsertIds,
+ successfulInserts);
byteCounter.inc(totalBytes);
} catch (IOException e) {
throw new RuntimeException(e);
@@ -402,13 +436,6 @@
}
}
- private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
- if (datasetService == null) {
- datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
- }
- return datasetService;
- }
-
private void reportStreamingApiLogging(BigQueryOptions options) {
MetricsContainer processWideContainer = MetricsEnvironment.getProcessWideContainer();
if (processWideContainer instanceof MetricsLogger) {
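Passing the DatasetService into flushRows, instead of having flushRows resolve it from the pipeline options itself, keeps creation and closing in one place: the DoFn that owns the transient field. A hedged sketch of that dependency-passing shape, with hypothetical names:

import java.io.IOException;
import java.util.List;

class FlushSketch {
  /** Hypothetical sink; the caller creates and closes it. */
  interface RowSink extends AutoCloseable {
    long insertAll(List<String> rows) throws IOException;
  }

  // The sink is a parameter, so this method never manages its lifecycle.
  static void flush(RowSink sink, List<String> rows) {
    if (rows.isEmpty()) {
      return;
    }
    try {
      long totalBytes = sink.insertAll(rows);
      System.out.println("flushed bytes: " + totalBytes); // stand-in for byteCounter.inc
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }
}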
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 1d3d894..8b9b705 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -965,49 +965,53 @@
// earlier stages of the pipeline or if a query depends on earlier stages of a pipeline.
// For these cases the withoutValidation method can be used to disable the check.
if (getValidate()) {
- if (table != null) {
- checkArgument(table.isAccessible(), "Cannot call validate if table is dynamically set.");
- }
- if (table != null && table.get().getProjectId() != null) {
- // Check for source table presence for early failure notification.
- DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
- BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
- BigQueryHelpers.verifyTablePresence(datasetService, table.get());
- } else if (getQuery() != null) {
- checkArgument(
- getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
- JobService jobService = getBigQueryServices().getJobService(bqOptions);
- try {
- jobService.dryRunQuery(
- bqOptions.getBigQueryProject() == null
- ? bqOptions.getProject()
- : bqOptions.getBigQueryProject(),
- new JobConfigurationQuery()
- .setQuery(getQuery().get())
- .setFlattenResults(getFlattenResults())
- .setUseLegacySql(getUseLegacySql()),
- getQueryLocation());
- } catch (Exception e) {
- throw new IllegalArgumentException(
- String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+ try (DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions)) {
+ if (table != null) {
+ checkArgument(
+ table.isAccessible(), "Cannot call validate if table is dynamically set.");
}
+ if (table != null && table.get().getProjectId() != null) {
+ // Check for source table presence for early failure notification.
+ BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
+ BigQueryHelpers.verifyTablePresence(datasetService, table.get());
+ } else if (getQuery() != null) {
+ checkArgument(
+ getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
+ JobService jobService = getBigQueryServices().getJobService(bqOptions);
+ try {
+ jobService.dryRunQuery(
+ bqOptions.getBigQueryProject() == null
+ ? bqOptions.getProject()
+ : bqOptions.getBigQueryProject(),
+ new JobConfigurationQuery()
+ .setQuery(getQuery().get())
+ .setFlattenResults(getFlattenResults())
+ .setUseLegacySql(getUseLegacySql()),
+ getQueryLocation());
+ } catch (Exception e) {
+ throw new IllegalArgumentException(
+ String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+ }
- DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
- // If the user provided a temp dataset, check if the dataset exists before launching the
- // query
- if (getQueryTempDataset() != null) {
- // The temp table is only used for dataset and project id validation, not for table name
- // validation
- TableReference tempTable =
- new TableReference()
- .setProjectId(
- bqOptions.getBigQueryProject() == null
- ? bqOptions.getProject()
- : bqOptions.getBigQueryProject())
- .setDatasetId(getQueryTempDataset())
- .setTableId("dummy table");
- BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+ // If the user provided a temp dataset, check if the dataset exists before launching the
+ // query
+ if (getQueryTempDataset() != null) {
+ // The temp table is only used for dataset and project id validation, not for table
+ // name validation
+ TableReference tempTable =
+ new TableReference()
+ .setProjectId(
+ bqOptions.getBigQueryProject() == null
+ ? bqOptions.getProject()
+ : bqOptions.getBigQueryProject())
+ .setDatasetId(getQueryTempDataset())
+ .setTableId("dummy table");
+ BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+ }
}
+ } catch (Exception e) {
+ throw new RuntimeException(e);
}
}
}
@@ -1401,15 +1405,17 @@
options.getJobName(), jobUuid, JobType.QUERY),
queryTempDataset);
- DatasetService datasetService = getBigQueryServices().getDatasetService(options);
- LOG.info("Deleting temporary table with query results {}", tempTable);
- datasetService.deleteTable(tempTable);
- // Delete dataset only if it was created by Beam
- boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
- if (datasetCreatedByBeam) {
- LOG.info(
- "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
- datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+ try (DatasetService datasetService =
+ getBigQueryServices().getDatasetService(options)) {
+ LOG.info("Deleting temporary table with query results {}", tempTable);
+ datasetService.deleteTable(tempTable);
+ // Delete dataset only if it was created by Beam
+ boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
+ if (datasetCreatedByBeam) {
+ LOG.info(
+ "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
+ datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+ }
}
}
};
@@ -2484,17 +2490,20 @@
// The user specified a table.
if (getJsonTableRef() != null && getJsonTableRef().isAccessible() && getValidate()) {
TableReference table = getTableWithDefaultProject(options).get();
- DatasetService datasetService = getBigQueryServices().getDatasetService(options);
- // Check for destination table presence and emptiness for early failure notification.
- // Note that a presence check can fail when the table or dataset is created by an earlier
- // stage of the pipeline. For these cases the #withoutValidation method can be used to
- // disable the check.
- BigQueryHelpers.verifyDatasetPresence(datasetService, table);
- if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
- BigQueryHelpers.verifyTablePresence(datasetService, table);
- }
- if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
- BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+ try (DatasetService datasetService = getBigQueryServices().getDatasetService(options)) {
+ // Check for destination table presence and emptiness for early failure notification.
+ // Note that a presence check can fail when the table or dataset is created by an earlier
+ // stage of the pipeline. For these cases the #withoutValidation method can be used to
+ // disable the check.
+ BigQueryHelpers.verifyDatasetPresence(datasetService, table);
+ if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
+ BigQueryHelpers.verifyTablePresence(datasetService, table);
+ }
+ if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
+ BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
}
}
}
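The validation paths above lean on DatasetService now being AutoCloseable: try-with-resources guarantees the underlying client is released even when a verification step throws, with checked exceptions from close() wrapped in RuntimeException. The idiom in isolation, using a hypothetical closeable client rather than the real service:

import java.io.IOException;

class ValidateSketch {
  /** Hypothetical stand-in for an AutoCloseable DatasetService. */
  interface CloseableClient extends AutoCloseable {
    void verify(String table) throws IOException;
  }

  static void validate(CloseableClient newClient, String table) {
    // close() runs whether verify succeeds or throws.
    try (CloseableClient client = newClient) {
      client.verify(table);
    } catch (Exception e) {
      throw new RuntimeException(e); // same wrapping as the change above
    }
  }
}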
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
index aeee3d2..eb9f8c4 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java
@@ -38,10 +38,14 @@
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
+import java.util.HashMap;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.ServiceCallMetric;
import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource;
import org.apache.beam.sdk.io.range.ByteKeyRange;
import org.apache.beam.sdk.values.KV;
@@ -130,12 +134,40 @@
String tableNameSr =
session.getOptions().getInstanceName().toTableNameStr(source.getTableId().get());
+ HashMap<String, String> baseLabels = new HashMap<>();
+ baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+ baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+ baseLabels.put(MonitoringInfoConstants.Labels.METHOD, "google.bigtable.v2.ReadRows");
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.RESOURCE,
+ GcpResourceIdentifiers.bigtableResource(
+ session.getOptions().getProjectId(),
+ session.getOptions().getInstanceId(),
+ source.getTableId().get()));
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, session.getOptions().getProjectId());
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.INSTANCE_ID, session.getOptions().getInstanceId());
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.TABLE_ID,
+ GcpResourceIdentifiers.bigtableTableID(
+ session.getOptions().getProjectId(),
+ session.getOptions().getInstanceId(),
+ source.getTableId().get()));
+ ServiceCallMetric serviceCallMetric =
+ new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
ReadRowsRequest.Builder requestB =
ReadRowsRequest.newBuilder().setRows(rowSet).setTableName(tableNameSr);
if (source.getRowFilter() != null) {
requestB.setFilter(source.getRowFilter());
}
- results = session.getDataClient().readRows(requestB.build());
+ try {
+ results = session.getDataClient().readRows(requestB.build());
+ serviceCallMetric.call("ok");
+ } catch (StatusRuntimeException e) {
+ serviceCallMetric.call(e.getStatus().getCode().value());
+ throw e;
+ }
return advance();
}
@@ -182,10 +214,12 @@
static class BigtableWriterImpl implements Writer {
private BigtableSession session;
private BulkMutation bulkMutation;
+ private BigtableTableName tableName;
BigtableWriterImpl(BigtableSession session, BigtableTableName tableName) {
this.session = session;
bulkMutation = session.createBulkMutation(tableName);
+ this.tableName = tableName;
}
@Override
@@ -231,6 +265,28 @@
.addAllMutations(record.getValue())
.build();
+ HashMap<String, String> baseLabels = new HashMap<>();
+ baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+ baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+ baseLabels.put(MonitoringInfoConstants.Labels.METHOD, "google.bigtable.v2.MutateRows");
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.RESOURCE,
+ GcpResourceIdentifiers.bigtableResource(
+ session.getOptions().getProjectId(),
+ session.getOptions().getInstanceId(),
+ tableName.getTableId()));
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, session.getOptions().getProjectId());
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.INSTANCE_ID, session.getOptions().getInstanceId());
+ baseLabels.put(
+ MonitoringInfoConstants.Labels.TABLE_ID,
+ GcpResourceIdentifiers.bigtableTableID(
+ session.getOptions().getProjectId(),
+ session.getOptions().getInstanceId(),
+ tableName.getTableId()));
+ ServiceCallMetric serviceCallMetric =
+ new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
CompletableFuture<MutateRowResponse> result = new CompletableFuture<>();
Futures.addCallback(
new VendoredListenableFutureAdapter<>(bulkMutation.add(request)),
@@ -238,10 +294,17 @@
@Override
public void onSuccess(MutateRowResponse mutateRowResponse) {
result.complete(mutateRowResponse);
+ serviceCallMetric.call("ok");
}
@Override
public void onFailure(Throwable throwable) {
+ if (throwable instanceof StatusRuntimeException) {
+ serviceCallMetric.call(
+ ((StatusRuntimeException) throwable).getStatus().getCode().value());
+ } else {
+ serviceCallMetric.call("unknown");
+ }
result.completeExceptionally(throwable);
}
},
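Reads and writes share the same instrumentation shape: build the label map once per call site, construct a ServiceCallMetric, then report "ok" on success or the gRPC status code on failure before rethrowing. Condensed into a helper, where the Runnable rpc is a hypothetical stand-in for the actual client call:

import io.grpc.StatusRuntimeException;
import java.util.HashMap;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.core.metrics.ServiceCallMetric;

class MetricSketch {
  static void instrumentedCall(Runnable rpc, HashMap<String, String> baseLabels) {
    ServiceCallMetric metric =
        new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
    try {
      rpc.run();         // the remote call being measured
      metric.call("ok"); // count one successful request
    } catch (StatusRuntimeException e) {
      metric.call(e.getStatus().getCode().value()); // count by gRPC status code
      throw e;           // preserve the original failure for the caller
    }
  }
}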
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
index 69be079..b983cc3 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImplTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.gcp.bigtable;
+import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;
@@ -42,9 +43,15 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.HashMap;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO.BigtableSource;
import org.apache.beam.sdk.io.range.ByteKey;
import org.apache.beam.sdk.io.range.ByteKeyRange;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -61,8 +68,12 @@
@RunWith(JUnit4.class)
public class BigtableServiceImplTest {
+ private static final String PROJECT_ID = "project";
+ private static final String INSTANCE_ID = "instance";
+ private static final String TABLE_ID = "table";
+
private static final BigtableTableName TABLE_NAME =
- new BigtableInstanceName("project", "instance").toTableName("table");
+ new BigtableInstanceName(PROJECT_ID, INSTANCE_ID).toTableName(TABLE_ID);
@Mock private BigtableSession mockSession;
@@ -76,10 +87,13 @@
public void setup() {
MockitoAnnotations.initMocks(this);
BigtableOptions options =
- new BigtableOptions.Builder().setProjectId("project").setInstanceId("instance").build();
+ new BigtableOptions.Builder().setProjectId(PROJECT_ID).setInstanceId(INSTANCE_ID).build();
when(mockSession.getOptions()).thenReturn(options);
when(mockSession.createBulkMutation(eq(TABLE_NAME))).thenReturn(mockBulkMutation);
when(mockSession.getDataClient()).thenReturn(mockBigtableDataClient);
+ // Set up the ProcessWideContainer so tests can verify that metrics are set.
+ MetricsContainerImpl container = new MetricsContainerImpl(null);
+ MetricsEnvironment.setProcessWideContainer(container);
}
/**
@@ -94,7 +108,7 @@
ByteKey start = ByteKey.copyFrom("a".getBytes(StandardCharsets.UTF_8));
ByteKey end = ByteKey.copyFrom("b".getBytes(StandardCharsets.UTF_8));
when(mockBigtableSource.getRanges()).thenReturn(Arrays.asList(ByteKeyRange.of(start, end)));
- when(mockBigtableSource.getTableId()).thenReturn(StaticValueProvider.of("table_name"));
+ when(mockBigtableSource.getTableId()).thenReturn(StaticValueProvider.of(TABLE_ID));
@SuppressWarnings("unchecked")
ResultScanner<Row> mockResultScanner = Mockito.mock(ResultScanner.class);
Row expectedRow = Row.newBuilder().setKey(ByteString.copyFromUtf8("a")).build();
@@ -109,6 +123,7 @@
underTest.close();
verify(mockResultScanner, times(1)).close();
+ verifyMetricWasSet("google.bigtable.v2.ReadRows", "ok", 1);
}
/**
@@ -140,4 +155,27 @@
underTest.close();
verify(mockBulkMutation, times(1)).flush();
}
+
+ private void verifyMetricWasSet(String method, String status, long count) {
+ // Verify the metric as reported.
+ HashMap<String, String> labels = new HashMap<>();
+ labels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+ labels.put(MonitoringInfoConstants.Labels.SERVICE, "BigTable");
+ labels.put(MonitoringInfoConstants.Labels.METHOD, method);
+ labels.put(
+ MonitoringInfoConstants.Labels.RESOURCE,
+ GcpResourceIdentifiers.bigtableResource(PROJECT_ID, INSTANCE_ID, TABLE_ID));
+ labels.put(MonitoringInfoConstants.Labels.BIGTABLE_PROJECT_ID, PROJECT_ID);
+ labels.put(MonitoringInfoConstants.Labels.INSTANCE_ID, INSTANCE_ID);
+ labels.put(
+ MonitoringInfoConstants.Labels.TABLE_ID,
+ GcpResourceIdentifiers.bigtableTableID(PROJECT_ID, INSTANCE_ID, TABLE_ID));
+ labels.put(MonitoringInfoConstants.Labels.STATUS, status);
+
+ MonitoringInfoMetricName name =
+ MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
+ MetricsContainerImpl container =
+ (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
+ assertEquals(count, (long) container.getCounter(name).getCumulative());
+ }
}
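The test recipe above is reusable for any process-wide metric: install a fresh MetricsContainerImpl as the process-wide container in setup, exercise the code under test, then look the counter up by its MonitoringInfo name. Condensed into a helper, assuming the labels map is built as in verifyMetricWasSet:

import static org.junit.Assert.assertEquals;

import java.util.HashMap;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
import org.apache.beam.sdk.metrics.MetricsEnvironment;

class MetricAssertSketch {
  static void assertRequestCount(HashMap<String, String> labels, long expected) {
    MonitoringInfoMetricName name =
        MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
    MetricsContainerImpl container =
        (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
    // The counter is cumulative across all calls recorded in this container.
    assertEquals(expected, (long) container.getCounter(name).getCumulative());
  }
}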
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 2cab8b4..c0e36f7 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -1900,6 +1900,17 @@
.withMaxRetries(retryConfiguration.getMaxAttempts());
}
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ spec.populateDisplayData(builder);
+ builder.add(
+ DisplayData.item(
+ "query", preparedStatement == null ? "null" : preparedStatement.toString()));
+ builder.add(
+ DisplayData.item("dataSource", dataSource == null ? "null" : dataSource.toString()));
+ builder.add(DisplayData.item("spec", spec == null ? "null" : spec.toString()));
+ }
+
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
T record = context.element();
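populateDisplayData feeds the key/value pairs that runner UIs show for a transform. The null guards matter because display data can be collected at graph-construction time, before @Setup has initialized the prepared statement or data source. A minimal sketch of the override, with a hypothetical field:

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.checkerframework.checker.nullness.qual.Nullable;

class DisplaySketch extends DoFn<String, Void> {
  private transient @Nullable Object dataSource; // hypothetical field, set later in @Setup

  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    // Guard against fields that are not initialized at graph-construction time.
    builder.add(
        DisplayData.item("dataSource", dataSource == null ? "null" : dataSource.toString()));
  }
}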
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 886e56e..9ed0b25 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -208,8 +208,6 @@
def load_pcoder(self, *labels):
saved_pcoder = self._saved_pcoders.get(self._path(*labels), None)
- # TODO(BEAM-12506): Get rid of the SafeFastPrimitivesCoder for
- # WindowedValueHolder.
if saved_pcoder is None or isinstance(saved_pcoder,
coders.FastPrimitivesCoder):
return self._default_pcoder
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 054c9a6..fc8a8aa 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -390,8 +390,6 @@
def load_pcoder(self, *labels):
saved_pcoder = self._saved_pcoders.get(
os.path.join(self._cache_dir, *labels), None)
- # TODO(BEAM-12506): Get rid of the SafeFastPrimitivesCoder for
- # WindowedValueHolder.
if saved_pcoder is None or isinstance(saved_pcoder,
coders.FastPrimitivesCoder):
return self._default_pcoder
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
index cee3d34..1dc42e0 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
@@ -227,7 +227,11 @@
endpoint=test_stream_service.endpoint)
sql_source = {}
for tag, output in output_pcolls.items():
- sql_source[tag_to_name[tag]] = output
+ name = tag_to_name[tag]
+ # Must mark the element_type to avoid introducing pickled Python coder
+ # to the Java expansion service.
+ output.element_type = name_to_pcoll[name].element_type
+ sql_source[name] = output
return sql_source