Merge pull request #10628: [BEAM-9144] Beam's own Avro TimeConversion class in beam-sdk-java-core
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 29de08a..8939bf2 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -373,6 +373,7 @@
def google_cloud_bigdataoss_version = "1.9.16"
def google_cloud_core_version = "1.61.0"
def google_cloud_spanner_version = "1.6.0"
+ def google_http_clients_version = "1.34.0"
def grpc_version = "1.25.0"
def guava_version = "25.1-jre"
def hadoop_version = "2.8.5"
@@ -457,10 +458,10 @@
google_cloud_dataflow_java_proto_library_all: "com.google.cloud.dataflow:google-cloud-dataflow-java-proto-library-all:0.5.160304",
google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:1.6.3",
google_cloud_spanner : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version",
- google_http_client : "com.google.http-client:google-http-client:$google_clients_version",
- google_http_client_jackson : "com.google.http-client:google-http-client-jackson:$google_clients_version",
- google_http_client_jackson2 : "com.google.http-client:google-http-client-jackson2:$google_clients_version",
- google_http_client_protobuf : "com.google.http-client:google-http-client-protobuf:$google_clients_version",
+ google_http_client : "com.google.http-client:google-http-client:$google_http_clients_version",
+ google_http_client_jackson : "com.google.http-client:google-http-client-jackson:1.29.2",
+ google_http_client_jackson2 : "com.google.http-client:google-http-client-jackson2:$google_http_clients_version",
+ google_http_client_protobuf : "com.google.http-client:google-http-client-protobuf:$google_http_clients_version",
google_oauth_client : "com.google.oauth-client:google-oauth-client:$google_clients_version",
google_oauth_client_java6 : "com.google.oauth-client:google-oauth-client-java6:$google_clients_version",
grpc_all : "io.grpc:grpc-all:$grpc_version",
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index f1a1394..30e91930 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -83,6 +83,7 @@
compile library.java.slf4j_api
compile library.java.vendored_grpc_1_26_0
testCompile library.java.hamcrest_core
+ testCompile library.java.guava_testlib
testCompile library.java.junit
testCompile project(path: ":sdks:java:io:google-cloud-platform", configuration: "testRuntime")
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObject.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObject.java
index e341004..8e8589d 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObject.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObject.java
@@ -22,6 +22,7 @@
import com.google.api.client.json.GenericJson;
import com.google.api.client.util.Key;
import java.util.Map;
+import java.util.Objects;
import javax.annotation.Nullable;
/**
@@ -182,4 +183,18 @@
public CloudObject clone() {
return (CloudObject) super.clone();
}
+
+ @Override
+ public boolean equals(Object otherObject) {
+ if (!(otherObject instanceof CloudObject)) {
+ return false;
+ }
+ CloudObject other = (CloudObject) otherObject;
+ return Objects.equals(className, other.className) && super.equals(otherObject);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(className, super.hashCode());
+ }
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/OutputReference.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/OutputReference.java
index f8b7784..7c1b9e4f 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/OutputReference.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/OutputReference.java
@@ -21,6 +21,7 @@
import com.google.api.client.json.GenericJson;
import com.google.api.client.util.Key;
+import java.util.Objects;
/**
* A representation used by {@link com.google.api.services.dataflow.model.Step}s to reference the
@@ -40,4 +41,21 @@
this.stepName = checkNotNull(stepName);
this.outputName = checkNotNull(outputName);
}
+
+ @Override
+ public boolean equals(Object otherObject) {
+ if (!(otherObject instanceof OutputReference)) {
+ return false;
+ }
+ OutputReference other = (OutputReference) otherObject;
+ return Objects.equals(type, other.type)
+ && Objects.equals(stepName, other.stepName)
+ && Objects.equals(outputName, other.outputName)
+ && super.equals(other);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(type, stepName, outputName, super.hashCode());
+ }
}
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectTest.java
new file mode 100644
index 0000000..118bab8
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectTest.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.util;
+
+import com.google.common.testing.EqualsTester;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class CloudObjectTest {
+
+ @Test
+ public void testEquality() {
+ new EqualsTester()
+ .addEqualityGroup(CloudObject.forFloat(1.0), CloudObject.forFloat(1.0))
+ .addEqualityGroup(CloudObject.forInteger(3L), CloudObject.forInteger(3L))
+ .addEqualityGroup(CloudObject.forFloat(3.0))
+ .addEqualityGroup(CloudObject.forString("foo"), CloudObject.forString("foo"))
+ .addEqualityGroup(CloudObject.forClassName("foo.Bar"), CloudObject.forClassName("foo.Bar"))
+ .addEqualityGroup(
+ CloudObject.fromSpec(ImmutableMap.of(PropertyNames.OBJECT_TYPE_NAME, "ValuesDoFn")),
+ CloudObject.fromSpec(ImmutableMap.of(PropertyNames.OBJECT_TYPE_NAME, "ValuesDoFn")))
+ .testEquals();
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/OutputReferenceTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/OutputReferenceTest.java
new file mode 100644
index 0000000..f0817b5
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/OutputReferenceTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.util;
+
+import com.google.common.testing.EqualsTester;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class OutputReferenceTest {
+
+ @Test
+ public void testEquality() {
+ new EqualsTester()
+ .addEqualityGroup(new OutputReference("sA", "oA"), new OutputReference("sA", "oA"))
+ .addEqualityGroup(new OutputReference("sB", "oB"), new OutputReference("sB", "oB"))
+ .testEquals();
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/GenericJsonMatcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/GenericJsonMatcherTest.java
index cac9fe3..146fb17 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/GenericJsonMatcherTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/GenericJsonMatcherTest.java
@@ -53,7 +53,9 @@
assertThat(actual, is(jsonOf(expected)));
} catch (AssertionError ex) {
assertEquals(
- "\nExpected: is {\"foo\":\"expected\"}\n but: was <{foo=actual}>", ex.getMessage());
+ "\nExpected: is {\"foo\":\"expected\"}\n"
+ + " but: was <GenericData{classInfo=[], {foo=actual}}>",
+ ex.getMessage());
// pass
return;
diff --git a/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java b/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
index 9612487..57617c0 100644
--- a/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
+++ b/sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
@@ -114,6 +114,7 @@
.setConnectionConfiguration(RedisConnectionConfiguration.create())
.setKeyPattern("*")
.setBatchSize(1000)
+ .setOutputParallelization(true)
.build();
}
@@ -125,6 +126,7 @@
return new AutoValue_RedisIO_ReadAll.Builder()
.setConnectionConfiguration(RedisConnectionConfiguration.create())
.setBatchSize(1000)
+ .setOutputParallelization(true)
.build();
}
@@ -150,7 +152,9 @@
abstract int batchSize();
- abstract Builder builder();
+ abstract boolean outputParallelization();
+
+ abstract Builder toBuilder();
@AutoValue.Builder
abstract static class Builder {
@@ -162,41 +166,53 @@
abstract Builder setBatchSize(int batchSize);
+ abstract Builder setOutputParallelization(boolean outputParallelization);
+
abstract Read build();
}
public Read withEndpoint(String host, int port) {
checkArgument(host != null, "host can not be null");
checkArgument(0 < port && port < 65536, "port must be a positive integer less than 65536");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withHost(host).withPort(port))
.build();
}
public Read withAuth(String auth) {
checkArgument(auth != null, "auth can not be null");
- return builder().setConnectionConfiguration(connectionConfiguration().withAuth(auth)).build();
+ return toBuilder()
+ .setConnectionConfiguration(connectionConfiguration().withAuth(auth))
+ .build();
}
public Read withTimeout(int timeout) {
checkArgument(timeout >= 0, "timeout can not be negative");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withTimeout(timeout))
.build();
}
public Read withKeyPattern(String keyPattern) {
checkArgument(keyPattern != null, "keyPattern can not be null");
- return builder().setKeyPattern(keyPattern).build();
+ return toBuilder().setKeyPattern(keyPattern).build();
}
public Read withConnectionConfiguration(RedisConnectionConfiguration connection) {
checkArgument(connection != null, "connection can not be null");
- return builder().setConnectionConfiguration(connection).build();
+ return toBuilder().setConnectionConfiguration(connection).build();
}
public Read withBatchSize(int batchSize) {
- return builder().setBatchSize(batchSize).build();
+ return toBuilder().setBatchSize(batchSize).build();
+ }
+
+ /**
+ * Whether to reshuffle the resulting PCollection so results are distributed to all workers. The
+ * default is to parallelize and should only be changed if this is known to be unnecessary.
+ */
+ public Read withOutputParallelization(boolean outputParallelization) {
+ return toBuilder().setOutputParallelization(outputParallelization).build();
}
@Override
@@ -214,7 +230,8 @@
.apply(
RedisIO.readAll()
.withConnectionConfiguration(connectionConfiguration())
- .withBatchSize(batchSize()));
+ .withBatchSize(batchSize())
+ .withOutputParallelization(outputParallelization()));
}
}
@@ -228,14 +245,18 @@
abstract int batchSize();
- abstract ReadAll.Builder builder();
+ abstract boolean outputParallelization();
+
+ abstract Builder toBuilder();
@AutoValue.Builder
abstract static class Builder {
@Nullable
- abstract ReadAll.Builder setConnectionConfiguration(RedisConnectionConfiguration connection);
+ abstract Builder setConnectionConfiguration(RedisConnectionConfiguration connection);
- abstract ReadAll.Builder setBatchSize(int batchSize);
+ abstract Builder setBatchSize(int batchSize);
+
+ abstract Builder setOutputParallelization(boolean outputParallelization);
abstract ReadAll build();
}
@@ -243,44 +264,57 @@
public ReadAll withEndpoint(String host, int port) {
checkArgument(host != null, "host can not be null");
checkArgument(port > 0, "port can not be negative or 0");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withHost(host).withPort(port))
.build();
}
public ReadAll withAuth(String auth) {
checkArgument(auth != null, "auth can not be null");
- return builder().setConnectionConfiguration(connectionConfiguration().withAuth(auth)).build();
+ return toBuilder()
+ .setConnectionConfiguration(connectionConfiguration().withAuth(auth))
+ .build();
}
public ReadAll withTimeout(int timeout) {
checkArgument(timeout >= 0, "timeout can not be negative");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withTimeout(timeout))
.build();
}
public ReadAll withConnectionConfiguration(RedisConnectionConfiguration connection) {
checkArgument(connection != null, "connection can not be null");
- return builder().setConnectionConfiguration(connection).build();
+ return toBuilder().setConnectionConfiguration(connection).build();
}
public ReadAll withBatchSize(int batchSize) {
- return builder().setBatchSize(batchSize).build();
+ return toBuilder().setBatchSize(batchSize).build();
+ }
+
+ /**
+ * Whether to reshuffle the resulting PCollection so results are distributed to all workers. The
+ * default is to parallelize and should only be changed if this is known to be unnecessary.
+ */
+ public ReadAll withOutputParallelization(boolean outputParallelization) {
+ return toBuilder().setOutputParallelization(outputParallelization).build();
}
@Override
public PCollection<KV<String, String>> expand(PCollection<String> input) {
checkArgument(connectionConfiguration() != null, "withConnectionConfiguration() is required");
-
- return input
- .apply(ParDo.of(new ReadFn(connectionConfiguration(), batchSize())))
- .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
- .apply(new Reparallelize());
+ PCollection<KV<String, String>> output =
+ input
+ .apply(ParDo.of(new ReadFn(connectionConfiguration(), batchSize())))
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));
+ if (outputParallelization()) {
+ output = output.apply(new Reparallelize());
+ }
+ return output;
}
}
- abstract static class BaseReadFn<T> extends DoFn<String, T> {
+ private abstract static class BaseReadFn<T> extends DoFn<String, T> {
protected final RedisConnectionConfiguration connectionConfiguration;
transient Jedis jedis;
@@ -307,9 +341,9 @@
}
@ProcessElement
- public void processElement(ProcessContext processContext) throws Exception {
+ public void processElement(ProcessContext c) {
ScanParams scanParams = new ScanParams();
- scanParams.match(processContext.element());
+ scanParams.match(c.element());
String cursor = ScanParams.SCAN_POINTER_START;
boolean finished = false;
@@ -317,7 +351,7 @@
ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
List<String> keys = scanResult.getResult();
for (String k : keys) {
- processContext.output(k);
+ c.output(k);
}
cursor = scanResult.getCursor();
if (cursor.equals(ScanParams.SCAN_POINTER_START)) {
@@ -326,42 +360,52 @@
}
}
}
+
/** A {@link DoFn} requesting Redis server to get key/value pairs. */
private static class ReadFn extends BaseReadFn<KV<String, String>> {
@Nullable transient Multimap<BoundedWindow, String> bundles = null;
@Nullable AtomicInteger batchCount = null;
private final int batchSize;
- @StartBundle
- public void startBundle(StartBundleContext context) {
- bundles = ArrayListMultimap.create();
- batchCount = new AtomicInteger();
- }
-
ReadFn(RedisConnectionConfiguration connectionConfiguration, int batchSize) {
super(connectionConfiguration);
this.batchSize = batchSize;
}
- private int getBatchSize() {
- return batchSize;
+ @StartBundle
+ public void startBundle() {
+ bundles = ArrayListMultimap.create();
+ batchCount = new AtomicInteger();
}
@ProcessElement
- public void processElement(ProcessContext processContext, BoundedWindow window)
- throws Exception {
- String key = processContext.element();
+ public void processElement(ProcessContext c, BoundedWindow window) {
+ String key = c.element();
bundles.put(window, key);
if (batchCount.incrementAndGet() > getBatchSize()) {
Multimap<BoundedWindow, KV<String, String>> kvs = fetchAndFlush();
for (BoundedWindow w : kvs.keySet()) {
for (KV<String, String> kv : kvs.get(w)) {
- processContext.output(kv);
+ c.output(kv);
}
}
}
}
+ @FinishBundle
+ public void finishBundle(FinishBundleContext context) {
+ Multimap<BoundedWindow, KV<String, String>> kvs = fetchAndFlush();
+ for (BoundedWindow w : kvs.keySet()) {
+ for (KV<String, String> kv : kvs.get(w)) {
+ context.output(kv, w.maxTimestamp(), w);
+ }
+ }
+ }
+
+ private int getBatchSize() {
+ return batchSize;
+ }
+
private Multimap<BoundedWindow, KV<String, String>> fetchAndFlush() {
Multimap<BoundedWindow, KV<String, String>> kvs = ArrayListMultimap.create();
for (BoundedWindow w : bundles.keySet()) {
@@ -378,16 +422,6 @@
batchCount.set(0);
return kvs;
}
-
- @FinishBundle
- public void finishBundle(FinishBundleContext context) throws Exception {
- Multimap<BoundedWindow, KV<String, String>> kvs = fetchAndFlush();
- for (BoundedWindow w : kvs.keySet()) {
- for (KV<String, String> kv : kvs.get(w)) {
- context.output(kv, w.maxTimestamp(), w);
- }
- }
- }
}
private static class Reparallelize
@@ -395,8 +429,7 @@
@Override
public PCollection<KV<String, String>> expand(PCollection<KV<String, String>> input) {
- // reparallelize mimics the same behavior as in JdbcIO
- // breaking fusion
+ // reparallelize mimics the same behavior as in JdbcIO, used to break fusion
PCollectionView<Iterable<KV<String, String>>> empty =
input
.apply("Consume", Filter.by(SerializableFunctions.constant(false)))
@@ -407,8 +440,8 @@
ParDo.of(
new DoFn<KV<String, String>, KV<String, String>>() {
@ProcessElement
- public void processElement(ProcessContext context) {
- context.output(context.element());
+ public void processElement(ProcessContext c) {
+ c.output(c.element());
}
})
.withSideInputs(empty));
@@ -468,7 +501,7 @@
@Nullable
abstract Long expireTime();
- abstract Builder builder();
+ abstract Builder toBuilder();
@AutoValue.Builder
abstract static class Builder {
@@ -486,37 +519,39 @@
public Write withEndpoint(String host, int port) {
checkArgument(host != null, "host can not be null");
checkArgument(port > 0, "port can not be negative or 0");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withHost(host).withPort(port))
.build();
}
public Write withAuth(String auth) {
checkArgument(auth != null, "auth can not be null");
- return builder().setConnectionConfiguration(connectionConfiguration().withAuth(auth)).build();
+ return toBuilder()
+ .setConnectionConfiguration(connectionConfiguration().withAuth(auth))
+ .build();
}
public Write withTimeout(int timeout) {
checkArgument(timeout >= 0, "timeout can not be negative");
- return builder()
+ return toBuilder()
.setConnectionConfiguration(connectionConfiguration().withTimeout(timeout))
.build();
}
public Write withConnectionConfiguration(RedisConnectionConfiguration connection) {
checkArgument(connection != null, "connection can not be null");
- return builder().setConnectionConfiguration(connection).build();
+ return toBuilder().setConnectionConfiguration(connection).build();
}
public Write withMethod(Method method) {
checkArgument(method != null, "method can not be null");
- return builder().setMethod(method).build();
+ return toBuilder().setMethod(method).build();
}
public Write withExpireTime(Long expireTimeMillis) {
checkArgument(expireTimeMillis != null, "expireTimeMillis can not be null");
checkArgument(expireTimeMillis > 0, "expireTimeMillis can not be negative or 0");
- return builder().setExpireTime(expireTimeMillis).build();
+ return toBuilder().setExpireTime(expireTimeMillis).build();
}
@Override
@@ -555,8 +590,8 @@
}
@ProcessElement
- public void processElement(ProcessContext processContext) {
- KV<String, String> record = processContext.element();
+ public void processElement(ProcessContext c) {
+ KV<String, String> record = c.element();
writeRecord(record);
diff --git a/sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java b/sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
index bcb3fca..badf039 100644
--- a/sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
+++ b/sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
@@ -175,7 +175,7 @@
}
@Test
- public void testWriteUsingINCRBY() throws Exception {
+ public void testWriteUsingINCRBY() {
String key = "key_incr";
List<String> values = Arrays.asList("0", "1", "2", "-3", "2", "4", "0", "5");
List<KV<String, String>> data = buildConstantKeyList(key, values);
@@ -190,7 +190,7 @@
}
@Test
- public void testWriteUsingDECRBY() throws Exception {
+ public void testWriteUsingDECRBY() {
String key = "key_decr";
List<String> values = Arrays.asList("-10", "1", "2", "-3", "2", "4", "0", "5");
diff --git a/sdks/python/apache_beam/internal/gcp/auth.py b/sdks/python/apache_beam/internal/gcp/auth.py
index 5f248b3..3921f73 100644
--- a/sdks/python/apache_beam/internal/gcp/auth.py
+++ b/sdks/python/apache_beam/internal/gcp/auth.py
@@ -128,7 +128,8 @@
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/devstorage.full_control',
'https://www.googleapis.com/auth/userinfo.email',
- 'https://www.googleapis.com/auth/datastore'
+ 'https://www.googleapis.com/auth/datastore',
+ 'https://www.googleapis.com/auth/spanner'
]
try:
credentials = GoogleCredentials.get_application_default()
diff --git a/sdks/python/apache_beam/io/gcp/experimental/__init__.py b/sdks/python/apache_beam/io/gcp/experimental/__init__.py
new file mode 100644
index 0000000..f4f43cb
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/experimental/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
new file mode 100644
index 0000000..21a2f8f
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
@@ -0,0 +1,583 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Google Cloud Spanner IO
+
+Experimental; no backwards-compatibility guarantees.
+
+This is an experimental module for reading and writing data from Google Cloud
+Spanner. Visit: https://cloud.google.com/spanner for more details.
+
+To read from Cloud Spanner, apply the ReadFromSpanner transform. It returns a
+PCollection where each element represents an individual row returned by the
+read operation. Both the Query and Read APIs are supported.
+
+ReadFromSpanner relies on ReadOperation objects, which are exposed by the
+SpannerIO API. A ReadOperation holds the immutable data needed to execute
+batch and naive reads on Cloud Spanner, which makes for more convenient
+programming.
+
+ReadFromSpanner reads from Cloud Spanner given either an 'sql' param in the
+constructor or a 'table' name with 'columns' as a list. For example:::
+
+ records = (pipeline
+ | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ sql='Select * from users'))
+
+ records = (pipeline
+ | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ table='users', columns=['id', 'name', 'email']))
+
+You can also perform multiple reads by providing a list of ReadOperations
+to the ReadFromSpanner transform constructor. ReadOperation exposes two class
+methods: use 'query' to perform SQL-based reads and 'table' to read from a
+table. For example:::
+
+ read_operations = [
+ ReadOperation.table(table='customers', columns=['name',
+ 'email']),
+ ReadOperation.table(table='vendors', columns=['name',
+ 'email']),
+ ]
+ all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ read_operations=read_operations)
+
+ ...OR...
+
+ read_operations = [
+      ReadOperation.query(sql='Select name, email from customers'),
+ ReadOperation.query(
+ sql='Select * from users where id <= @user_id',
+ params={'user_id': 100},
+          param_types={'user_id': param_types.INT64}
+ ),
+ ]
+  # `param_types` values are instances of `google.cloud.spanner.param_types`
+ all_users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ read_operations=read_operations)
+
+For more information, please review the docs on class ReadOperation.
+
+Users can also provide ReadOperations in the form of a PCollection via the
+pipeline. For example:::
+
+ users = (pipeline
+ | beam.Create([ReadOperation...])
+ | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME))
+
+Users may also create a Cloud Spanner transaction using the transform called
+`create_transaction`, which is available in the SpannerIO API.
+
+The transform is guaranteed to be executed on a consistent snapshot of data,
+utilizing the power of read-only transactions. Staleness of data can be
+controlled by providing the `read_timestamp` or `exact_staleness` param values
+in the constructor.
+
+This transform requires the root of the pipeline (PBegin) and returns a
+PTransform which is later passed to the `ReadFromSpanner` constructor.
+`ReadFromSpanner` passes this transaction PTransform as a singleton side input
+to the `_NaiveSpannerReadDoFn` containing 'session_id' and 'transaction_id'.
+For example:::
+
+ transaction = (pipeline | create_transaction(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ DB_NAME))
+
+ users = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ sql='Select * from users', transaction=transaction)
+
+ tweets = pipeline | ReadFromSpanner(PROJECT_ID, INSTANCE_ID, DB_NAME,
+ sql='Select * from tweets', transaction=transaction)
+
+For further details of this transform, please review the docs on the
+:meth:`create_transaction` method available in the SpannerIO API.
+
+ReadFromSpanner takes this transform in the constructor and passes it to the
+read pipeline as a singleton side input.
+"""
+from __future__ import absolute_import
+
+import typing
+from collections import namedtuple
+
+from apache_beam import Create
+from apache_beam import DoFn
+from apache_beam import ParDo
+from apache_beam import Reshuffle
+from apache_beam.pvalue import AsSingleton
+from apache_beam.pvalue import PBegin
+from apache_beam.transforms import PTransform
+from apache_beam.transforms import ptransform_fn
+from apache_beam.transforms.display import DisplayDataItem
+from apache_beam.typehints import with_input_types
+from apache_beam.typehints import with_output_types
+from apache_beam.utils.annotations import experimental
+
+try:
+ from google.cloud.spanner import Client
+ from google.cloud.spanner import KeySet
+ from google.cloud.spanner_v1.database import BatchSnapshot
+except ImportError:
+ Client = None
+ KeySet = None
+ BatchSnapshot = None
+
+__all__ = ['create_transaction', 'ReadFromSpanner', 'ReadOperation']
+
+
+class _SPANNER_TRANSACTION(namedtuple("SPANNER_TRANSACTION", ["transaction"])):
+ """
+ Holds the spanner transaction details.
+ """
+
+ __slots__ = ()
+
+
+class ReadOperation(namedtuple("ReadOperation", ["is_sql", "is_table",
+ "read_operation", "kwargs"])):
+ """
+ Encapsulates a spanner read operation.
+ """
+
+ __slots__ = ()
+
+ @classmethod
+ def query(cls, sql, params=None, param_types=None):
+ """
+    A convenience method to construct a ReadOperation from a SQL query.
+
+ Args:
+ sql: SQL query statement
+ params: (optional) values for parameter replacement. Keys must match the
+ names used in sql
+ param_types: (optional) maps explicit types for one or more param values;
+ required if parameters are passed.
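+
+    For example, a sketch of a parameterized query read (assuming
+    `param_types` is imported from `google.cloud.spanner`)::
+
+      ReadOperation.query(
+          sql='SELECT * FROM users WHERE id <= @user_id',
+          params={'user_id': 100},
+          param_types={'user_id': param_types.INT64})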
+ """
+
+ if params:
+ assert param_types is not None
+
+ return cls(
+ is_sql=True,
+ is_table=False,
+ read_operation="process_query_batch",
+ kwargs={'sql': sql, 'params': params, 'param_types': param_types}
+ )
+
+ @classmethod
+ def table(cls, table, columns, index="", keyset=None):
+ """
+    A convenience method to construct a ReadOperation from a table.
+
+ Args:
+ table: name of the table from which to fetch data.
+ columns: names of columns to be retrieved.
+ index: (optional) name of index to use, rather than the table's primary
+ key.
+ keyset: (optional) `KeySet` keys / ranges identifying rows to be
+ retrieved.
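+
+    For example, a sketch of a table read (the table, column, and index
+    names are illustrative)::
+
+      ReadOperation.table(
+          table='users', columns=['id', 'name'], index='users_by_id')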
+ """
+ keyset = keyset or KeySet(all_=True)
+ if not isinstance(keyset, KeySet):
+ raise ValueError("keyset must be an instance of class "
+ "google.cloud.spanner.KeySet")
+ return cls(
+ is_sql=False,
+ is_table=True,
+ read_operation="process_read_batch",
+ kwargs={'table': table, 'columns': columns, 'index': index,
+ 'keyset': keyset}
+ )
+
+
+class _BeamSpannerConfiguration(namedtuple(
+ "_BeamSpannerConfiguration", ["project", "instance", "database",
+ "credentials", "pool",
+ "snapshot_read_timestamp",
+ "snapshot_exact_staleness"])):
+ """
+  A namedtuple holding the immutable data of the connection configuration for
+  Cloud Spanner.
+ """
+
+ @property
+ def snapshot_options(self):
+ snapshot_options = {}
+ if self.snapshot_exact_staleness:
+ snapshot_options['exact_staleness'] = self.snapshot_exact_staleness
+ if self.snapshot_read_timestamp:
+ snapshot_options['read_timestamp'] = self.snapshot_read_timestamp
+ return snapshot_options
+
+@with_input_types(ReadOperation, typing.Dict[typing.Any, typing.Any])
+@with_output_types(typing.List[typing.Any])
+class _NaiveSpannerReadDoFn(DoFn):
+
+ def __init__(self, spanner_configuration):
+ """
+    A naive version of Spanner read which uses the transaction API of
+    Cloud Spanner.
+    https://googleapis.dev/python/spanner/latest/transaction-api.html
+    Naive reads perform single reads, whereas batch reads use the Spanner
+    partitioning query to create batches.
+
+ Args:
+ spanner_configuration: (_BeamSpannerConfiguration) Connection details to
+ connect with cloud spanner.
+ """
+ self._spanner_configuration = spanner_configuration
+ self._snapshot = None
+ self._session = None
+
+ def _get_session(self):
+ if self._session is None:
+ session = self._session = self._database.session()
+ session.create()
+ return self._session
+
+ def _close_session(self):
+ if self._session is not None:
+ self._session.delete()
+
+ def setup(self):
+ # setting up client to connect with cloud spanner
+ spanner_client = Client(self._spanner_configuration.project)
+ instance = spanner_client.instance(self._spanner_configuration.instance)
+ self._database = instance.database(self._spanner_configuration.database,
+ pool=self._spanner_configuration.pool)
+
+ def process(self, element, spanner_transaction):
+    # `spanner_transaction` should be an instance of the _SPANNER_TRANSACTION
+    # object.
+ if not isinstance(spanner_transaction, _SPANNER_TRANSACTION):
+ raise ValueError("Invalid transaction object: %s. It should be instance "
+ "of SPANNER_TRANSACTION object created by "
+ "spannerio.create_transaction transform."
+ % type(spanner_transaction))
+
+ transaction_info = spanner_transaction.transaction
+
+    # We use a batch snapshot to reuse the same transaction passed through
+    # the side input.
+ self._snapshot = BatchSnapshot.from_dict(self._database, transaction_info)
+
+    # Get the transaction from the snapshot's session to run the read
+    # operation.
+ with self._get_session().transaction() as transaction:
+ if element.is_sql is True:
+ transaction_read = transaction.execute_sql
+ elif element.is_table is True:
+ transaction_read = transaction.read
+ else:
+        raise ValueError("ReadOperation is improperly configured: %s" % str(
+            element))
+
+ for row in transaction_read(**element.kwargs):
+ yield row
+
+
+@with_input_types(ReadOperation)
+@with_output_types(typing.Dict[typing.Any, typing.Any])
+class _CreateReadPartitions(DoFn):
+ """
+ A DoFn to create partitions. Uses the Partitioning API (PartitionRead /
+ PartitionQuery) request to start a partitioned query operation. Returns a
+ list of batch information needed to perform the actual queries.
+
+  If the element is a :class:`ReadOperation` for a SQL query, the
+  `PartitionQuery` API is used to create partitions, returning mappings of the
+  information used to perform the actual partitioned reads via
+  :meth:`process_query_batch`.
+
+  If the element is a :class:`ReadOperation` for a table read, the
+  `PartitionRead` API is used to create partitions, returning mappings of the
+  information used to perform the actual partitioned reads via
+  :meth:`process_read_batch`.
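+
+  Each emitted element is a mapping shaped like the following (an
+  illustrative sketch)::
+
+    {'is_sql': True, 'is_table': False,
+     'read_operation': 'process_query_batch',
+     'partitions': <partition>, 'transaction_info': <snapshot dict>}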
+ """
+
+ def __init__(self, spanner_configuration):
+ self._spanner_configuration = spanner_configuration
+
+ def setup(self):
+ spanner_client = Client(project=self._spanner_configuration.project,
+ credentials=self._spanner_configuration.credentials)
+ instance = spanner_client.instance(self._spanner_configuration.instance)
+ self._database = instance.database(self._spanner_configuration.database,
+ pool=self._spanner_configuration.pool)
+ self._snapshot = self._database.batch_snapshot(**self._spanner_configuration
+ .snapshot_options)
+ self._snapshot_dict = self._snapshot.to_dict()
+
+ def process(self, element):
+ if element.is_sql is True:
+ partitioning_action = self._snapshot.generate_query_batches
+ elif element.is_table is True:
+ partitioning_action = self._snapshot.generate_read_batches
+ else:
+      raise ValueError("ReadOperation is improperly configured: %s" % str(
+          element))
+
+ for p in partitioning_action(**element.kwargs):
+ yield {"is_sql": element.is_sql, "is_table": element.is_table,
+ "read_operation": element.read_operation, "partitions": p,
+ "transaction_info": self._snapshot_dict}
+
+
+@with_input_types(int)
+@with_output_types(typing.Dict[typing.Any, typing.Any])
+class _CreateTransactionFn(DoFn):
+ """
+  A DoFn to create a Cloud Spanner transaction.
+  It connects to the database and returns the transaction_id and session_id
+  using the batch_snapshot.to_dict() method available in the Google Cloud
+  Spanner SDK.
+
+ https://googleapis.dev/python/spanner/latest/database-api.html?highlight=
+ batch_snapshot#google.cloud.spanner_v1.database.BatchSnapshot.to_dict
+ """
+
+ def __init__(self, project_id, instance_id, database_id, credentials,
+ pool, read_timestamp,
+ exact_staleness):
+ self._project_id = project_id
+ self._instance_id = instance_id
+ self._database_id = database_id
+ self._credentials = credentials
+ self._pool = pool
+
+ self._snapshot_options = {}
+ if read_timestamp:
+ self._snapshot_options['read_timestamp'] = read_timestamp
+ if exact_staleness:
+ self._snapshot_options['exact_staleness'] = exact_staleness
+ self._snapshot = None
+
+ def setup(self):
+ self._spanner_client = Client(project=self._project_id,
+ credentials=self._credentials)
+ self._instance = self._spanner_client.instance(self._instance_id)
+ self._database = self._instance.database(self._database_id, pool=self._pool)
+
+ def process(self, element, *args, **kwargs):
+ self._snapshot = self._database.batch_snapshot(**self._snapshot_options)
+ return [_SPANNER_TRANSACTION(self._snapshot.to_dict())]
+
+
+@ptransform_fn
+def create_transaction(pbegin, project_id, instance_id, database_id,
+ credentials=None, pool=None, read_timestamp=None,
+ exact_staleness=None):
+ """
+ A PTransform method to create a batch transaction.
+
+ Args:
+ pbegin: Root of the pipeline
+ project_id: Cloud spanner project id. Be sure to use the Project ID,
+ not the Project Number.
+ instance_id: Cloud spanner instance id.
+ database_id: Cloud spanner database id.
+ credentials: (optional) The authorization credentials to attach to requests.
+ These credentials identify this application to the service.
+ If none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ pool: (optional) session pool to be used by database. If not passed,
+ Spanner Cloud SDK uses the BurstyPool by default.
+ `google.cloud.spanner.BurstyPool`. Ref:
+ https://googleapis.dev/python/spanner/latest/database-api.html?#google.
+ cloud.spanner_v1.database.Database
+ read_timestamp: (optional) An instance of the `datetime.datetime` object to
+ execute all reads at the given timestamp.
+ exact_staleness: (optional) An instance of the `datetime.timedelta`
+ object. These timestamp bounds execute reads at a user-specified
+ timestamp.
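+
+  For example, a sketch mirroring the module-level docs (assuming `datetime`
+  is imported)::
+
+    transaction = (pipeline | create_transaction(
+        PROJECT_ID, INSTANCE_ID, DB_NAME,
+        exact_staleness=datetime.timedelta(seconds=5)))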
+ """
+
+ assert isinstance(pbegin, PBegin)
+
+ return (pbegin | Create([1]) | ParDo(_CreateTransactionFn(
+ project_id, instance_id, database_id, credentials,
+ pool, read_timestamp,
+ exact_staleness)))
+
+@with_input_types(typing.Dict[typing.Any, typing.Any])
+@with_output_types(typing.List[typing.Any])
+class _ReadFromPartitionFn(DoFn):
+ """
+ A DoFn to perform reads from the partition.
+ """
+
+ def __init__(self, spanner_configuration):
+ self._spanner_configuration = spanner_configuration
+
+ def setup(self):
+ spanner_client = Client(self._spanner_configuration.project)
+ instance = spanner_client.instance(self._spanner_configuration.instance)
+ self._database = instance.database(self._spanner_configuration.database,
+ pool=self._spanner_configuration.pool)
+ self._snapshot = self._database.batch_snapshot(**self._spanner_configuration
+ .snapshot_options)
+
+ def process(self, element):
+ self._snapshot = BatchSnapshot.from_dict(
+ self._database,
+ element['transaction_info']
+ )
+
+ if element['is_sql'] is True:
+ read_action = self._snapshot.process_query_batch
+ elif element['is_table'] is True:
+ read_action = self._snapshot.process_read_batch
+ else:
+      raise ValueError("ReadOperation is improperly configured: %s" % str(
+          element))
+
+ for row in read_action(element['partitions']):
+ yield row
+
+ def teardown(self):
+ if self._snapshot:
+ self._snapshot.close()
+
+
+@experimental(extra_message="No backwards-compatibility guarantees.")
+class ReadFromSpanner(PTransform):
+ """
+  A PTransform to perform reads from Cloud Spanner.
+  ReadFromSpanner uses the Batch API to perform all read operations.
+ """
+
+ def __init__(self, project_id, instance_id, database_id, pool=None,
+ read_timestamp=None, exact_staleness=None, credentials=None,
+ sql=None, params=None, param_types=None, # with_query
+ table=None, columns=None, index="", keyset=None, # with_table
+ read_operations=None, # for read all
+ transaction=None
+ ):
+ """
+ A PTransform that uses Spanner Batch API to perform reads.
+
+ Args:
+ project_id: Cloud spanner project id. Be sure to use the Project ID,
+ not the Project Number.
+ instance_id: Cloud spanner instance id.
+ database_id: Cloud spanner database id.
+ pool: (optional) session pool to be used by database. If not passed,
+ Spanner Cloud SDK uses the BurstyPool by default.
+ `google.cloud.spanner.BurstyPool`. Ref:
+ https://googleapis.dev/python/spanner/latest/database-api.html?#google.
+ cloud.spanner_v1.database.Database
+ read_timestamp: (optional) An instance of the `datetime.datetime` object
+ to execute all reads at the given timestamp. By default, set to `None`.
+ exact_staleness: (optional) An instance of the `datetime.timedelta`
+ object. These timestamp bounds execute reads at a user-specified
+ timestamp. By default, set to `None`.
+ credentials: (optional) The authorization credentials to attach to
+ requests. These credentials identify this application to the service.
+ If none are specified, the client will attempt to ascertain
+ the credentials from the environment. By default, set to `None`.
+ sql: (optional) SQL query statement.
+ params: (optional) Values for parameter replacement. Keys must match the
+ names used in sql. By default, set to `None`.
+ param_types: (optional) maps explicit types for one or more param values;
+ required if params are passed. By default, set to `None`.
+ table: (optional) Name of the table from which to fetch data. By
+ default, set to `None`.
+ columns: (optional) List of names of columns to be retrieved; required if
+ the table is passed. By default, set to `None`.
+      index: (optional) name of index to use, rather than the table's primary
+        key. By default, set to an empty string.
+ keyset: (optional) keys / ranges identifying rows to be retrieved. By
+ default, set to `None`.
+ read_operations: (optional) List of the objects of :class:`ReadOperation`
+ to perform read all. By default, set to `None`.
+      transaction: (optional) PTransform of the :meth:`create_transaction` to
+        perform a naive read on Cloud Spanner. By default, set to `None`.
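+
+    For example, a sketch of the two single-read forms (mirroring the
+    module-level docs)::
+
+      rows = pipeline | ReadFromSpanner(
+          PROJECT_ID, INSTANCE_ID, DB_NAME, sql='SELECT * FROM users')
+
+      rows = pipeline | ReadFromSpanner(
+          PROJECT_ID, INSTANCE_ID, DB_NAME,
+          table='users', columns=['id', 'name'])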
+ """
+ self._configuration = _BeamSpannerConfiguration(
+ project=project_id, instance=instance_id, database=database_id,
+ credentials=credentials, pool=pool,
+ snapshot_read_timestamp=read_timestamp,
+ snapshot_exact_staleness=exact_staleness
+ )
+
+ self._read_operations = read_operations
+ self._transaction = transaction
+
+ if self._read_operations is None:
+ if table is not None:
+ if columns is None:
+ raise ValueError("Columns are required with the table name.")
+ self._read_operations = [ReadOperation.table(
+ table=table, columns=columns, index=index, keyset=keyset)]
+ elif sql is not None:
+ self._read_operations = [ReadOperation.query(
+ sql=sql, params=params, param_types=param_types)]
+
+ def expand(self, pbegin):
+    if self._read_operations is not None and isinstance(pbegin, PBegin):
+ pcoll = pbegin.pipeline | Create(self._read_operations)
+ elif not isinstance(pbegin, PBegin):
+ if self._read_operations is not None:
+ raise ValueError("Read operation in the constructor only works with "
+ "the root of the pipeline.")
+ pcoll = pbegin
+ else:
+      raise ValueError("ReadFromSpanner requires a read operation: pass "
+                       "read_operations, sql, or table with columns.")
+
+ if self._transaction is None:
+ # reading as batch read using the spanner partitioning query to create
+ # batches.
+ p = (pcoll
+ | 'Generate Partitions' >> ParDo(_CreateReadPartitions(
+ spanner_configuration=self._configuration))
+ | 'Reshuffle' >> Reshuffle()
+ | 'Read From Partitions' >> ParDo(_ReadFromPartitionFn(
+ spanner_configuration=self._configuration)))
+ else:
+ # reading as naive read, in which we don't make batches and execute the
+ # queries as a single read.
+ p = (pcoll
+ | 'Reshuffle' >> Reshuffle().with_input_types(ReadOperation)
+ | 'Perform Read' >> ParDo(
+ _NaiveSpannerReadDoFn(spanner_configuration=self._configuration),
+ AsSingleton(self._transaction)))
+ return p
+
+ def display_data(self):
+ res = dict()
+ sql = []
+ table = []
+ if self._read_operations is not None:
+ for ro in self._read_operations:
+ if ro.is_sql is True:
+ sql.append(ro.kwargs)
+ elif ro.is_table is True:
+ table.append(ro.kwargs)
+
+ if sql:
+ res['sql'] = DisplayDataItem(str(sql), label='Sql')
+ if table:
+ res['table'] = DisplayDataItem(str(table), label='Table')
+
+ if self._transaction:
+ res['transaction'] = DisplayDataItem(str(self._transaction),
+ label='transaction')
+
+ return res
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
new file mode 100644
index 0000000..be838f4
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
@@ -0,0 +1,338 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+
+import datetime
+import logging
+import random
+import string
+import unittest
+
+import mock
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+# Protect against environments where spanner library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
+try:
+ from google.cloud import spanner
+ from apache_beam.io.gcp.experimental.spannerio import (create_transaction,
+ ReadOperation,
+ ReadFromSpanner) # pylint: disable=unused-import
+except ImportError:
+ spanner = None
+# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
+
+
+MAX_DB_NAME_LENGTH = 30
+TEST_PROJECT_ID = 'apache-beam-testing'
+TEST_INSTANCE_ID = 'beam-test'
+TEST_DATABASE_PREFIX = 'spanner-testdb-'
+FAKE_TRANSACTION_INFO = {"session_id": "qwerty", "transaction_id": "qwerty"}
+FAKE_ROWS = [[1, 'Alice'], [2, 'Bob'], [3, 'Carl'], [4, 'Dan'], [5, 'Evan'],
+ [6, 'Floyd']]
+
+
+def _generate_database_name():
+ mask = string.ascii_lowercase + string.digits
+ length = MAX_DB_NAME_LENGTH - 1 - len(TEST_DATABASE_PREFIX)
+ return TEST_DATABASE_PREFIX + ''.join(random.choice(mask) for i in range(
+ length))
+
+
+def _generate_test_data():
+ mask = string.ascii_lowercase + string.digits
+ length = 100
+ return [('users', ['Key', 'Value'], [(x, ''.join(
+ random.choice(mask) for _ in range(length))) for x in range(1, 5)])]
+
+
+@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
+@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
+@mock.patch('apache_beam.io.gcp.experimental.spannerio.BatchSnapshot')
+class SpannerReadTest(unittest.TestCase):
+
+ def test_read_with_query_batch(self, mock_batch_snapshot_class,
+ mock_client_class):
+ mock_snapshot = mock.MagicMock()
+
+ mock_snapshot.generate_query_batches.return_value = [
+ {'query': {'sql': 'SELECT * FROM users'},
+ 'partition': 'test_partition'} for _ in range(3)]
+ mock_snapshot.process_query_batch.side_effect = [
+ FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]
+
+ ro = [ReadOperation.query("Select * from users")]
+ pipeline = TestPipeline()
+
+ read = (pipeline
+ | 'read' >> ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID,
+ _generate_database_name(),
+ sql="SELECT * FROM users"))
+
+ readall = (pipeline
+ | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ read_operations=ro))
+
+ readpipeline = (pipeline
+ | 'create reads' >> beam.Create(ro)
+ | 'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name()))
+
+ pipeline.run()
+ assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
+ assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
+ assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
+
+ def test_read_with_table_batch(self, mock_batch_snapshot_class,
+ mock_client_class):
+ mock_snapshot = mock.MagicMock()
+ mock_snapshot.generate_read_batches.return_value = [{
+ 'read': {'table': 'users', 'keyset': {'all': True},
+ 'columns': ['Key', 'Value'], 'index': ''},
+ 'partition': 'test_partition'} for _ in range(3)]
+ mock_snapshot.process_read_batch.side_effect = [
+ FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]
+
+ ro = [ReadOperation.table("users", ["Key", "Value"])]
+ pipeline = TestPipeline()
+
+ read = (pipeline
+ | 'read' >> ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID,
+ _generate_database_name(),
+ table="users",
+ columns=["Key", "Value"]))
+
+ readall = (pipeline
+ | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ read_operations=ro))
+
+ readpipeline = (pipeline
+ | 'create reads' >> beam.Create(ro)
+ | 'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name()))
+
+ pipeline.run()
+ assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
+ assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
+ assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
+
+ with self.assertRaises(ValueError):
+ # Test the exception raised when user passes the read operations in the
+ # constructor and also in the pipeline.
+ _ = (pipeline | 'reads error' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ table="users"
+ ))
+ pipeline.run()
+
+ def test_read_with_index(self, mock_batch_snapshot_class,
+ mock_client_class):
+ mock_snapshot = mock.MagicMock()
+ mock_snapshot.generate_read_batches.return_value = [{
+ 'read': {'table': 'users', 'keyset': {'all': True},
+ 'columns': ['Key', 'Value'], 'index': ''},
+ 'partition': 'test_partition'} for _ in range(3)]
+ mock_snapshot.process_read_batch.side_effect = [
+ FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]
+ ro = [ReadOperation.table("users", ["Key", "Value"], index="Key")]
+ pipeline = TestPipeline()
+ read = (pipeline
+ | 'read' >> ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID,
+ _generate_database_name(),
+ table="users",
+ columns=["Key", "Value"]))
+ readall = (pipeline
+ | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ read_operations=ro))
+ readpipeline = (pipeline
+ | 'create reads' >> beam.Create(ro)
+ | 'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name()))
+ pipeline.run()
+ assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
+ assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
+ assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
+ with self.assertRaises(ValueError):
+ # Test the exception raised when user passes the read operations in the
+ # constructor and also in the pipeline.
+ _ = (pipeline | 'reads error' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ table="users"
+ ))
+ pipeline.run()
+
+ def test_read_with_transaction(self, mock_batch_snapshot_class,
+ mock_client_class):
+ mock_client = mock.MagicMock()
+ mock_instance = mock.MagicMock()
+ mock_database = mock.MagicMock()
+ mock_snapshot = mock.MagicMock()
+
+ mock_client_class.return_value = mock_client
+ mock_client.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_database
+ mock_database.batch_snapshot.return_value = mock_snapshot
+ mock_batch_snapshot_class.return_value = mock_snapshot
+ mock_batch_snapshot_class.from_dict.return_value = mock_snapshot
+ mock_snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO
+
+ mock_session = mock.MagicMock()
+ mock_transaction_ctx = mock.MagicMock()
+ mock_transaction = mock.MagicMock()
+
+ mock_snapshot._get_session.return_value = mock_session
+ mock_session.transaction.return_value = mock_transaction
+ mock_transaction.__enter__.return_value = mock_transaction_ctx
+ mock_transaction_ctx.execute_sql.return_value = FAKE_ROWS
+
+ ro = [ReadOperation.query("Select * from users")]
+ p = TestPipeline()
+
+ transaction = (p | create_transaction(
+ project_id=TEST_PROJECT_ID, instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ exact_staleness=datetime.timedelta(seconds=10)))
+
+ read_query = (p | 'with query' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ transaction=transaction,
+ sql="Select * from users"
+ ))
+
+ read_table = (p | 'with table' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ transaction=transaction,
+ table="users",
+ columns=["Key", "Value"]
+ ))
+
+ read_indexed_table = (p | 'with index' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ transaction=transaction,
+ table="users",
+ index="Key",
+ columns=["Key", "Value"]
+ ))
+
+ read = (p | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ transaction=transaction,
+ read_operations=ro))
+
+ read_pipeline = (p
+ | 'create read operations' >> beam.Create(ro)
+ | 'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ transaction=transaction))
+
+ p.run()
+
+ assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')
+ assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')
+ assert_that(read_indexed_table, equal_to(FAKE_ROWS),
+ label='checkTableIndex')
+ assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')
+ assert_that(read_pipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
+
+ with self.assertRaises(ValueError):
+ # Test the exception raised when user passes the read operations in the
+ # constructor and also in the pipeline.
+ _ = (p
+ | 'create read operations2' >> beam.Create(ro)
+ | 'reads with error' >> ReadFromSpanner(TEST_PROJECT_ID,
+ TEST_INSTANCE_ID,
+ _generate_database_name(),
+ transaction=transaction,
+ read_operations=ro))
+ p.run()
+
+ def test_invalid_transaction(self, mock_batch_snapshot_class,
+ mock_client_class):
+ with self.assertRaises(ValueError):
+ p = TestPipeline()
+ transaction = (p | beam.Create([{"invalid": "transaction"}]))
+ _ = (p | 'with query' >> ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ transaction=transaction,
+ sql="Select * from users"
+ ))
+ p.run()
+
+ def test_display_data(self, *args):
+ dd_sql = ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ sql="Select * from users"
+ ).display_data()
+
+ dd_table = ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ table="users",
+ columns=['id', 'name']
+ ).display_data()
+
+ dd_transaction = ReadFromSpanner(
+ project_id=TEST_PROJECT_ID,
+ instance_id=TEST_INSTANCE_ID,
+ database_id=_generate_database_name(),
+ table="users",
+ columns=['id', 'name'],
+ transaction={"transaction_id": "test123", "session_id": "test456"}
+ ).display_data()
+
+ self.assertTrue("sql" in dd_sql)
+ self.assertTrue("table" in dd_table)
+ self.assertTrue("table" in dd_transaction)
+ self.assertTrue("transaction" in dd_transaction)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
diff --git a/sdks/python/apache_beam/io/hdfs_integration_test/Dockerfile b/sdks/python/apache_beam/io/hdfs_integration_test/Dockerfile
index 788b8d2..1c78181 100644
--- a/sdks/python/apache_beam/io/hdfs_integration_test/Dockerfile
+++ b/sdks/python/apache_beam/io/hdfs_integration_test/Dockerfile
@@ -24,22 +24,13 @@
WORKDIR /app
ENV HDFSCLI_CONFIG /app/sdks/python/apache_beam/io/hdfs_integration_test/hdfscli.cfg
-RUN pip install --no-cache-dir holdup gsutil
-RUN gsutil cp gs://dataflow-samples/shakespeare/kinglear.txt .
-# Install Beam and dependencies.
-ADD sdks/python /app/sdks/python
-ADD model /app/model
-RUN cd sdks/python && \
- python setup.py sdist && \
- pip install --no-cache-dir $(ls dist/apache-beam-*.tar.gz | tail -n1)[gcp]
+# Add Beam SDK sources.
+COPY sdks/python /app/sdks/python
+COPY model /app/model
+
+# This step should look like setupVirtualenv minus virtualenv creation.
+RUN pip install --no-cache-dir tox==3.11.1 -r sdks/python/build-requirements.txt
# Run wordcount, and write results to HDFS.
-CMD holdup -t 45 http://namenode:50070 http://datanode:50075 && \
- echo "Waiting for safe mode to end." && \
- sleep 45 && \
- hdfscli -v -v -v upload -f kinglear.txt / && \
- python -m apache_beam.examples.wordcount \
- --input hdfs://kinglear* \
- --output hdfs://py-wordcount-integration \
- --hdfs_host namenode --hdfs_port 50070 --hdfs_user root
+CMD cd sdks/python && tox -e hdfs_integration_test
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index d584f6b..b6902ff 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -118,6 +118,10 @@
return outputs
+  # We must mark this method as not a test, or else its name matches the
+  # pattern nose uses to discover tests.
+ apply_TestStream.__test__ = False
+
def run_pipeline(self, pipeline, options):
from apache_beam.pipeline import PipelineVisitor
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 65bfdbd..a396d17 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -204,6 +204,8 @@
'google-cloud-bigtable>=0.31.1,<1.1.0',
# [BEAM-4543] googledatastore is not supported in Python 3.
'proto-google-cloud-datastore-v1>=0.90.0,<=0.90.4; python_version < "3.0"',
+    'google-cloud-spanner>=1.7.1,<1.8.0',
+ 'grpcio-gcp>=0.2.2,<1',
]
INTERACTIVE_BEAM = [
diff --git a/sdks/python/test-suites/direct/py2/build.gradle b/sdks/python/test-suites/direct/py2/build.gradle
index 436cd7e..acb7fee 100644
--- a/sdks/python/test-suites/direct/py2/build.gradle
+++ b/sdks/python/test-suites/direct/py2/build.gradle
@@ -77,11 +77,10 @@
}
task hdfsIntegrationTest {
- dependsOn 'installGcpTest'
doLast {
exec {
executable 'sh'
- args '-c', ". ${envdir}/bin/activate && ${rootDir}/sdks/python/apache_beam/io/hdfs_integration_test/hdfs_integration_test.sh python:2"
+ args '-c', "${rootDir}/sdks/python/apache_beam/io/hdfs_integration_test/hdfs_integration_test.sh python:2"
}
}
}
diff --git a/sdks/python/test-suites/direct/py37/build.gradle b/sdks/python/test-suites/direct/py37/build.gradle
index c5c30fb..1f54f90 100644
--- a/sdks/python/test-suites/direct/py37/build.gradle
+++ b/sdks/python/test-suites/direct/py37/build.gradle
@@ -56,11 +56,10 @@
}
task hdfsIntegrationTest {
- dependsOn 'installGcpTest'
doLast {
exec {
executable 'sh'
- args '-c', ". ${envdir}/bin/activate && ${pythonDir}/apache_beam/io/hdfs_integration_test/hdfs_integration_test.sh python:3.7"
+ args '-c', "${pythonDir}/apache_beam/io/hdfs_integration_test/hdfs_integration_test.sh python:3.7"
}
}
}
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index c78bddf..e6908d9 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -315,3 +315,30 @@
coverage report --skip-covered
# Generate report in xml format
coverage xml
+
+[testenv:hdfs_integration_test]
+# Used by hdfs_integration_test.sh. Do not run this directly, as it depends on
+# nodes defined in hdfs_integration_test/docker-compose.yml.
+deps =
+ -r build-requirements.txt
+ gsutil==4.47
+ holdup==1.8.0
+extras =
+ gcp
+whitelist_externals =
+ echo
+ sleep
+passenv = HDFSCLI_CONFIG
+commands =
+ holdup -t 45 http://namenode:50070 http://datanode:50075
+ echo "Waiting for safe mode to end."
+ sleep 45
+ gsutil cp gs://dataflow-samples/shakespeare/kinglear.txt .
+ hdfscli -v -v -v upload -f kinglear.txt /
+ python -m apache_beam.examples.wordcount \
+ --input hdfs://kinglear* \
+ --output hdfs://py-wordcount-integration \
+ --hdfs_host namenode --hdfs_port 50070 --hdfs_user root
+# Disable pip check. TODO: remove this once gsutil does not conflict with
+# apache_beam (oauth2client).
+commands_pre =