Merge pull request #14746 from ibzib/BEAM-12302
[BEAM-12302] [zetasql] Support/test TIMESTAMP as UDF argument/return …
diff --git a/.test-infra/jenkins/CommonJobProperties.groovy b/.test-infra/jenkins/CommonJobProperties.groovy
index 851fc0b..295cc89 100644
--- a/.test-infra/jenkins/CommonJobProperties.groovy
+++ b/.test-infra/jenkins/CommonJobProperties.groovy
@@ -183,6 +183,11 @@
context.switches("-Dorg.gradle.jvmargs=-Xms2g")
context.switches("-Dorg.gradle.jvmargs=-Xmx4g")
+ // Disable file system watching for CI builds
+ // Builds are performed on a clean clone and files aren't modified, so
+ // there's no value in watching for changes.
+ context.switches("-Dorg.gradle.vfs.watch=false")
+
// Include dependency licenses when building docker images on Jenkins, see https://s.apache.org/zt68q
context.switches("-Pdocker-pull-licenses")
}
diff --git a/.test-infra/jenkins/job_PerformanceTests_KafkaIO_IT.groovy b/.test-infra/jenkins/job_PerformanceTests_KafkaIO_IT.groovy
index 59cb4e9..87767f8 100644
--- a/.test-infra/jenkins/job_PerformanceTests_KafkaIO_IT.groovy
+++ b/.test-infra/jenkins/job_PerformanceTests_KafkaIO_IT.groovy
@@ -75,7 +75,7 @@
bigQueryTable : 'kafkaioit_results_sdf_wrapper',
influxMeasurement : 'kafkaioit_results_sdf_wrapper',
// TODO(BEAM-11779) remove shuffle_mode=appliance with runner v2 once issue is resolved.
- experiments : 'beam_fn_api,use_runner_v2,shuffle_mode=appliance,use_unified_worker',
+ experiments : 'use_runner_v2,shuffle_mode=appliance,use_unified_worker',
]
Map dataflowRunnerV2SdfPipelineOptions = pipelineOptions + [
@@ -90,7 +90,7 @@
bigQueryTable : 'kafkaioit_results_runner_v2',
influxMeasurement : 'kafkaioit_results_runner_v2',
// TODO(BEAM-11779) remove shuffle_mode=appliance with runner v2 once issue is resolved.
- experiments : 'beam_fn_api,use_runner_v2,shuffle_mode=appliance,use_unified_worker',
+ experiments : 'use_runner_v2,shuffle_mode=appliance,use_unified_worker',
]
steps {
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
index 0e3e628..1ac4562 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
@@ -27,7 +27,7 @@
description('Runs the ValidatesRunner suite on the Dataflow runner.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
previousNames(/beam_PostCommit_Java_ValidatesRunner_Dataflow_Gradle/)
// Publish all test results to Jenkins
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
index af9e25e..6ba9685 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
@@ -28,7 +28,7 @@
def JAVA_11_HOME = '/usr/lib/jvm/java-11-openjdk-amd64'
def JAVA_8_HOME = '/usr/lib/jvm/java-8-openjdk-amd64'
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
publishers {
archiveJunit('**/build/test-results/**/*.xml')
}
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
index 0de085d..3b3d8ed 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
@@ -27,7 +27,7 @@
description('Runs Java ValidatesRunner suite on the Dataflow runner V2 forcing streaming mode.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 330)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 450)
// Publish all test results to Jenkins
publishers {
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesContainer_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesContainer_Dataflow.groovy
index e722bdc..bc43ecb 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesContainer_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesContainer_Dataflow.groovy
@@ -31,7 +31,7 @@
commonJobProperties.setTopLevelMainJobProperties(delegate)
publishers {
- archiveJunit('**/nosetests*.xml')
+ archiveJunit('**/pytest*.xml')
}
// Execute shell command to test Python SDK.
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
index e847928..5631649 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
index c219de6..e164514 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner v2.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/CHANGES.md b/CHANGES.md
index 0fae45a..af72444 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -118,24 +118,20 @@
* Fixed X (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
-# [2.29.0] - Release branch cut, any updates should be cherry-picked
+# [2.29.0] - 2021-04-29
## Highlights
-* New highly anticipated feature X added to Python SDK ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
-* New highly anticipated feature Y added to Java SDK ([BEAM-Y](https://issues.apache.org/jira/browse/BEAM-Y)).
* Spark Classic and Portable runners officially support Spark 3 ([BEAM-7093](https://issues.apache.org/jira/browse/BEAM-7093)).
* Official Java 11 support for most runners (Dataflow, Flink, Spark) ([BEAM-2530](https://issues.apache.org/jira/browse/BEAM-2530)).
* DataFrame API now supports GroupBy.apply ([BEAM-11628](https://issues.apache.org/jira/browse/BEAM-11628)).
## I/Os
-* Support for X source added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
* Added support for S3 filesystem on AWS SDK V2 (Java) ([BEAM-7637](https://issues.apache.org/jira/browse/BEAM-7637))
## New Features / Improvements
-* X feature added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
* DataFrame API now supports pandas 1.2.x ([BEAM-11531](https://issues.apache.org/jira/browse/BEAM-11531)).
* Multiple DataFrame API bugfixes ([BEAM-12071](https://issues.apache.org/jira/browse/BEAM-12071), [BEAM-11929](https://issues.apache.org/jira/browse/BEAM-11929))
@@ -145,17 +141,11 @@
To restore the old behavior, one can register `FakeDeterministicFastPrimitivesCoder` with
`beam.coders.registry.register_fallback_coder(beam.coders.coders.FakeDeterministicFastPrimitivesCoder())`
or use the `allow_non_deterministic_key_coders` pipeline option.
-* X behavior was changed ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
## Deprecations
-* X behavior is deprecated and will be removed in X versions ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
* Support for Flink 1.8 and 1.9 will be removed in the next release (2.30.0) ([BEAM-11948](https://issues.apache.org/jira/browse/BEAM-11948)).
-## Known Issues
-
-* Fixed X (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
-
# [2.28.0] - 2021-02-22
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 3bde968..d2e2898 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -417,7 +417,6 @@
// a dependency version which should match across multiple
// Maven artifacts.
def activemq_version = "5.14.5"
- def antlr_version = "4.9.2"
def autovalue_version = "1.8.1"
def aws_java_sdk_version = "1.11.974"
def aws_java_sdk2_version = "2.15.31"
@@ -442,7 +441,6 @@
def jaxb_api_version = "2.3.3"
def jsr305_version = "3.0.2"
def kafka_version = "2.4.1"
- def log4j_version = "2.14.1"
def nemo_version = "0.1"
def netty_version = "4.1.52.Final"
def postgres_version = "42.2.16"
@@ -468,8 +466,8 @@
activemq_junit : "org.apache.activemq.tooling:activemq-junit:$activemq_version",
activemq_kahadb_store : "org.apache.activemq:activemq-kahadb-store:$activemq_version",
activemq_mqtt : "org.apache.activemq:activemq-mqtt:$activemq_version",
- antlr : "org.antlr:antlr4:$antlr_version",
- antlr_runtime : "org.antlr:antlr4-runtime:$antlr_version",
+ antlr : "org.antlr:antlr4:4.7",
+ antlr_runtime : "org.antlr:antlr4-runtime:4.7",
args4j : "args4j:args4j:2.33",
auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version",
avro : "org.apache.avro:avro:1.8.2",
@@ -597,8 +595,6 @@
junit : "junit:junit:4.13.1",
kafka : "org.apache.kafka:kafka_2.11:$kafka_version",
kafka_clients : "org.apache.kafka:kafka-clients:$kafka_version",
- log4j_api : "org.apache.logging.log4j:log4j-api:$log4j_version",
- log4j_core : "org.apache.logging.log4j:log4j-core:$log4j_version",
mockito_core : "org.mockito:mockito-core:3.7.7",
mongo_java_driver : "org.mongodb:mongo-java-driver:3.12.7",
nemo_compiler_frontend_beam : "org.apache.nemo:nemo-compiler-frontend-beam:$nemo_version",
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
index 40e7383..5cba1d4 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
@@ -42,7 +42,7 @@
static def alpn_api_version = "1.1.2.v20150522"
static def npn_api_version = "1.1.1.v20141010"
static def jboss_marshalling_version = "1.4.11.Final"
- static def jboss_modules_version = "1.11.0.Final"
+ static def jboss_modules_version = "1.1.0.Beta1"
/** Returns the list of compile time dependencies. */
static List<String> dependencies() {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
index aae35ab..e32edf7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
@@ -26,7 +26,6 @@
import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.POutput;
-import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.PValues;
import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
@@ -41,17 +40,16 @@
private ReplacementOutputs() {}
public static Map<PCollection<?>, ReplacementOutput> singleton(
- Map<TupleTag<?>, PCollection<?>> original, PValue replacement) {
+ Map<TupleTag<?>, PCollection<?>> original, POutput replacement) {
Entry<TupleTag<?>, PCollection<?>> originalElement =
Iterables.getOnlyElement(original.entrySet());
- TupleTag<?> replacementTag = Iterables.getOnlyElement(replacement.expand().entrySet()).getKey();
- PCollection<?> replacementCollection =
- (PCollection<?>) Iterables.getOnlyElement(replacement.expand().entrySet()).getValue();
+ Entry<TupleTag<?>, PCollection<?>> replacementElement =
+ Iterables.getOnlyElement(PValues.expandOutput(replacement).entrySet());
return Collections.singletonMap(
- replacementCollection,
+ replacementElement.getValue(),
ReplacementOutput.of(
TaggedPValue.of(originalElement.getKey(), originalElement.getValue()),
- TaggedPValue.of(replacementTag, replacementCollection)));
+ TaggedPValue.of(replacementElement.getKey(), replacementElement.getValue())));
}
public static Map<PCollection<?>, ReplacementOutput> tagged(
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index 43f507d..e8cb290 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -151,7 +151,7 @@
"--tempRoot=${dataflowValidatesTempRoot}",
"--sdkContainerImage=${dockerImageContainer}:${dockerTag}",
// TODO(BEAM-11779) remove shuffle_mode=appliance with runner v2 once issue is resolved.
- "--experiments=beam_fn_api,use_unified_worker,use_runner_v2,shuffle_mode=appliance",
+ "--experiments=use_unified_worker,use_runner_v2,shuffle_mode=appliance",
]
def commonLegacyExcludeCategories = [
diff --git a/runners/google-cloud-dataflow-java/examples/build.gradle b/runners/google-cloud-dataflow-java/examples/build.gradle
index b128bc1..b52469b 100644
--- a/runners/google-cloud-dataflow-java/examples/build.gradle
+++ b/runners/google-cloud-dataflow-java/examples/build.gradle
@@ -41,7 +41,7 @@
def gcsTempRoot = project.findProperty('gcsTempRoot') ?: 'gs://temp-storage-for-end-to-end-tests/'
def dockerImageName = project(':runners:google-cloud-dataflow-java').ext.dockerImageName
// If -PuseExecutableStage is set, the use_executable_stage_bundle_execution will be enabled.
-def fnapiExperiments = project.hasProperty('useExecutableStage') ? 'beam_fn_api,beam_fn_api_use_deprecated_read,use_executable_stage_bundle_execution' : "beam_fn_api,beam_fn_api_use_deprecated_read"
+def fnapiExperiments = project.hasProperty('useExecutableStage') ? 'beam_fn_api_use_deprecated_read,use_executable_stage_bundle_execution' : "beam_fn_api,beam_fn_api_use_deprecated_read"
def commonConfig = { dataflowWorkerJar, workerHarnessContainerImage = '', additionalOptions = [] ->
// return the preevaluated configuration closure
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
index 12c9852..f6e683e 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
@@ -76,7 +76,7 @@
@Test
public void testFnApiSingleOutputOverrideNonCrashing() throws Exception {
- DataflowPipelineOptions options = buildPipelineOptions("--experiments=beam_fn_api");
+ DataflowPipelineOptions options = buildPipelineOptions();
options.setRunner(DataflowRunner.class);
Pipeline pipeline = Pipeline.create(options);
@@ -113,7 +113,7 @@
+ "exposes a way to know when the replacement is not required by checking that the "
+ "preceding ParDos to a GBK are key preserving.")
public void testFnApiMultiOutputOverrideNonCrashing() throws Exception {
- DataflowPipelineOptions options = buildPipelineOptions("--experiments=beam_fn_api");
+ DataflowPipelineOptions options = buildPipelineOptions();
options.setRunner(DataflowRunner.class);
Pipeline pipeline = Pipeline.create(options);
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/logging/BeamFnLoggingServiceTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/logging/BeamFnLoggingServiceTest.java
index 001a2a0..06e09fc 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/logging/BeamFnLoggingServiceTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/logging/BeamFnLoggingServiceTest.java
@@ -115,7 +115,7 @@
}
}
- @Test
+ @Test(timeout = 5000)
public void testMultipleClientsFailingIsHandledGracefullyByServer() throws Exception {
Collection<Callable<Void>> tasks = new ArrayList<>();
ConcurrentLinkedQueue<BeamFnApi.LogEntry> logs = new ConcurrentLinkedQueue<>();
diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index 67a0e80..0b3880a 100644
--- a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -173,12 +173,8 @@
private TranslationContext translatePipeline(Pipeline pipeline) {
PipelineTranslator.detectTranslationMode(pipeline, options);
- // Default to using the primitive versions of Read.Bounded and Read.Unbounded if we are
- // executing an unbounded pipeline or the user specifically requested it.
- if (options.isStreaming()
- || ExperimentalOptions.hasExperiment(
- pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
- || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+ if (!ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_sdf_read")) {
+ // Default to using the primitive versions of Read.Bounded and Read.Unbounded.
pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index 5369409..131dcec 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -161,12 +161,8 @@
// visit the pipeline to determine the translation mode
detectTranslationMode(pipeline);
- // Default to using the primitive versions of Read.Bounded and Read.Unbounded if we are
- // executing an unbounded pipeline or the user specifically requested it.
- if (pipelineOptions.isStreaming()
- || ExperimentalOptions.hasExperiment(
- pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
- || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+ if (!ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_sdf_read")) {
+ // Default to using the primitive versions of Read.Bounded and Read.Unbounded.
pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
index 37d9d54..2e52a87 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
@@ -81,15 +81,13 @@
public SparkPipelineResult run(Pipeline pipeline) {
boolean isStreaming =
options.isStreaming() || options.as(TestSparkPipelineOptions.class).isForceStreaming();
- // Default to using the primitive versions of Read.Bounded and Read.Unbounded if we are
- // executing an unbounded pipeline or the user specifically requested it.
- if (isStreaming
- || ExperimentalOptions.hasExperiment(
- pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
- || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+
+ if (!ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_sdf_read")) {
+ // Default to using the primitive versions of Read.Bounded and Read.Unbounded.
pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
}
+
JavaSparkContext jsc = new JavaSparkContext("local[1]", "Debug_Pipeline");
JavaStreamingContext jssc =
new JavaStreamingContext(jsc, new org.apache.spark.streaming.Duration(1000));
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
index 6b2782d..c9bb83d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java
@@ -81,7 +81,7 @@
.apply(TextIO.write().to("!!PLACEHOLDER-OUTPUT-DIR!!").withNumShards(3).withSuffix(".txt"));
final String expectedPipeline =
- "_.<org.apache.beam.sdk.io.Read$Bounded>\n"
+ "sparkContext.<readFrom(org.apache.beam.sdk.transforms.Create$Values$CreateSource)>()\n"
+ "_.mapPartitions("
+ "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n"
+ "_.mapPartitions(new org.apache.beam.sdk.transforms.Contextful())\n"
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
index a011a78..0bd33bd 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
@@ -51,7 +51,6 @@
// Registry retains mappings from go types to Schemas and LogicalTypes.
type Registry struct {
- lastShortID int64
typeToSchema map[reflect.Type]*pipepb.Schema
idToType map[string]reflect.Type
syntheticToUser map[reflect.Type]reflect.Type
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
index abb6f10..087d8c1 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
@@ -26,7 +26,9 @@
package schema
import (
+ "bytes"
"fmt"
+ "hash/fnv"
"reflect"
"strings"
@@ -72,8 +74,19 @@
defaultRegistry.RegisterType(ut)
}
-func getUUID() string {
- return uuid.New().String()
+// getUUID generates a UUID using the string form of the type name.
+func getUUID(ut reflect.Type) string {
+ // String produces non-empty output for pointer and slice types.
+ typename := ut.String()
+ hasher := fnv.New128a()
+ if n, err := hasher.Write([]byte(typename)); err != nil || n != len(typename) {
+ panic(fmt.Sprintf("unable to generate schema uuid for %s, wrote out %d bytes, want %d: err %v", typename, n, len(typename), err))
+ }
+ id, err := uuid.NewRandomFromReader(bytes.NewBuffer(hasher.Sum(nil)))
+ if err != nil {
+ panic(fmt.Sprintf("unable to generate schema uuid for type %s: %v", typename, err))
+ }
+ return id.String()
}
// Registered returns whether the given type has been registered with
@@ -350,7 +363,7 @@
if lID != "" {
schm.Options = append(schm.Options, logicalOption(lID))
}
- schm.Id = getUUID()
+ schm.Id = getUUID(ot)
r.typeToSchema[ot] = schm
r.idToType[schm.GetId()] = ot
return schm, nil
@@ -365,7 +378,7 @@
// Cache the pointer type here with its own id.
pt := reflect.PtrTo(t)
schm = proto.Clone(schm).(*pipepb.Schema)
- schm.Id = getUUID()
+ schm.Id = getUUID(pt)
schm.Options = append(schm.Options, &pipepb.Option{
Name: optGoNillable,
})
@@ -454,7 +467,7 @@
schm := ftype.GetRowType().GetSchema()
schm = proto.Clone(schm).(*pipepb.Schema)
schm.Options = append(schm.Options, logicalOption(lID))
- schm.Id = getUUID()
+ schm.Id = getUUID(t)
r.typeToSchema[t] = schm
r.idToType[schm.GetId()] = t
return schm, nil
@@ -483,7 +496,7 @@
schm := &pipepb.Schema{
Fields: fields,
- Id: getUUID(),
+ Id: getUUID(t),
}
r.idToType[schm.GetId()] = t
r.typeToSchema[t] = schm
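The getUUID change above makes Go schema IDs deterministic: the type's string form is hashed with FNV-128a and the digest is used as the "random" source for the UUID, so re-registering the same type yields the same ID. A minimal standalone sketch of that scheme, assuming the github.com/google/uuid package; the deterministicID helper is hypothetical and only mirrors the hashing shown in the diff:

    package schemaidsketch

    import (
        "bytes"
        "hash/fnv"
        "reflect"

        "github.com/google/uuid"
    )

    // deterministicID mirrors getUUID above: hash the type's string form with
    // FNV-128a and feed the 16-byte digest to uuid.NewRandomFromReader, so the
    // same type always maps to the same schema ID across runs.
    func deterministicID(t reflect.Type) string {
        h := fnv.New128a()
        h.Write([]byte(t.String())) // hash.Hash.Write never returns an error
        id, err := uuid.NewRandomFromReader(bytes.NewBuffer(h.Sum(nil)))
        if err != nil {
            panic(err)
        }
        return id.String()
    }

Because reflect.PtrTo(t).String() differs from t.String(), the pointer type cached in the second hunk still receives its own distinct ID.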
diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
index 9ea5099..1ba1ac4 100644
--- a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
+++ b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
@@ -24,53 +24,105 @@
"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
+ "github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
pipepb "github.com/apache/beam/sdks/go/pkg/beam/model/pipeline_v1"
)
// ResolveArtifacts acquires all dependencies for a cross-language transform
func ResolveArtifacts(ctx context.Context, edges []*graph.MultiEdge, p *pipepb.Pipeline) {
- path, err := filepath.Abs("/tmp/artifacts")
+ _, err := ResolveArtifactsWithConfig(ctx, edges, ResolveConfig{})
if err != nil {
panic(err)
}
+}
+
+// ResolveConfig contains fields for configuring the behavior for resolving
+// artifacts.
+type ResolveConfig struct {
+ // SdkPath replaces the default filepath for dependencies, but only in the
+ // external environment proto to be used by the SDK Harness during pipeline
+ // execution. This is used to specify alternate staging directories, such
+ // as for staging artifacts remotely.
+ //
+ // Setting an SdkPath does not change staging behavior otherwise. All
+ // artifacts still get staged to the default local filepath, and it is the
+ // user's responsibility to stage those local artifacts to the SdkPath.
+ SdkPath string
+
+ // JoinFn is a function for combining SdkPath and individual artifact names.
+ // If not specified, it defaults to using filepath.Join.
+ JoinFn func(path, name string) string
+}
+
+func defaultJoinFn(path, name string) string {
+ return filepath.Join(path, "/", name)
+}
+
+// ResolveArtifactsWithConfig acquires all dependencies for cross-language
+// transforms, but with some additional configuration of its behavior. By default,
+// this function performs the following steps for each cross-language transform
+// in the list of edges:
+// 1. Retrieves a list of dependencies needed from the expansion service.
+// 2. Retrieves each dependency as an artifact and stages it to a default
+// local filepath.
+// 3. Adds the dependencies to the transform's stored environment proto.
+// The changes that can be configured are documented in ResolveConfig.
+//
+// This returns a map of "local path" to "sdk path". By default these are
+// identical, unless ResolveConfig.SdkPath has been set.
+func ResolveArtifactsWithConfig(ctx context.Context, edges []*graph.MultiEdge, cfg ResolveConfig) (paths map[string]string, err error) {
+ tmpPath, err := filepath.Abs("/tmp/artifacts")
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving remote artifacts")
+ }
+ if cfg.JoinFn == nil {
+ cfg.JoinFn = defaultJoinFn
+ }
+ paths = make(map[string]string)
for _, e := range edges {
if e.Op == graph.External {
components, err := graphx.ExpandedComponents(e.External.Expanded)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for edge %v", e.Name())
}
envs := components.Environments
for eid, env := range envs {
-
if strings.HasPrefix(eid, "go") {
continue
}
deps := env.GetDependencies()
- resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", path)
+ resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", tmpPath)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for env %v in edge %v", eid, e.Name())
}
var resolvedDeps []*pipepb.ArtifactInformation
for _, a := range resolvedArtifacts {
- name, sha256 := artifact.MustExtractFilePayload(a)
- fullPath := filepath.Join(path, "/", name)
+ name, _ := artifact.MustExtractFilePayload(a)
+ fullTmpPath := filepath.Join(tmpPath, "/", name)
+ fullSdkPath := fullTmpPath
+ if len(cfg.SdkPath) > 0 {
+ fullSdkPath = cfg.JoinFn(cfg.SdkPath, name)
+ }
resolvedDeps = append(resolvedDeps,
&pipepb.ArtifactInformation{
TypeUrn: "beam:artifact:type:file:v1",
TypePayload: protox.MustEncode(
&pipepb.ArtifactFilePayload{
- Path: fullPath,
- Sha256: sha256,
+ Path: fullSdkPath,
},
),
RoleUrn: a.RoleUrn,
RolePayload: a.RolePayload,
},
)
+ paths[fullTmpPath] = fullSdkPath
}
env.Dependencies = resolvedDeps
}
}
}
+ return paths, nil
}
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index c46ff4c..cd7be52 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -178,11 +178,23 @@
}
// (1) Build and submit
+ // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
+ id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
+
+ modelURL := gcsx.Join(*stagingLocation, id, "model")
+ workerURL := gcsx.Join(*stagingLocation, id, "worker")
+ jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
+ xlangURL := gcsx.Join(*stagingLocation, id, "xlang")
edges, _, err := p.Build()
if err != nil {
return nil, err
}
+ artifactURLs, err := dataflowlib.ResolveXLangArtifacts(ctx, edges, opts.Project, xlangURL)
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving cross-language artifacts")
+ }
+ opts.ArtifactURLs = artifactURLs
environment, err := graphx.CreateEnvironment(ctx, jobopts.GetEnvironmentUrn(ctx), getContainerImage)
if err != nil {
return nil, errors.WithContext(err, "creating environment for model pipeline")
@@ -196,13 +208,6 @@
return nil, errors.WithContext(err, "applying container image overrides")
}
- // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
- id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
-
- modelURL := gcsx.Join(*stagingLocation, id, "model")
- workerURL := gcsx.Join(*stagingLocation, id, "worker")
- jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
-
if *dryRun {
log.Info(ctx, "Dry-run: not submitting job!")
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 511b962..390082b 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -56,6 +56,7 @@
WorkerRegion string
WorkerZone string
ContainerImage string
+ ArtifactURLs []string // Additional packages for workers.
// Autoscaling settings
Algorithm string
@@ -128,6 +129,15 @@
experiments = append(experiments, "use_staged_dataflow_worker_jar")
}
+ for _, url := range opts.ArtifactURLs {
+ name := url[strings.LastIndexAny(url, "/")+1:]
+ pkg := &df.Package{
+ Name: name,
+ Location: url,
+ }
+ packages = append(packages, pkg)
+ }
+
ipConfiguration := "WORKER_IP_UNSPECIFIED"
if opts.NoUsePublicIPs {
ipConfiguration = "WORKER_IP_PRIVATE"
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
index 67a2bde..49ca5bf 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
@@ -22,6 +22,8 @@
"os"
"cloud.google.com/go/storage"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/graph"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
)
@@ -54,3 +56,28 @@
_, err = gcsx.Upload(ctx, client, project, bucket, obj, r)
return err
}
+
+// ResolveXLangArtifacts resolves cross-language artifacts with a given GCS
+// URL as a destination, and then stages all local artifacts to that URL. This
+// function returns a list of staged artifact URLs.
+func ResolveXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge, project, url string) ([]string, error) {
+ cfg := xlangx.ResolveConfig{
+ SdkPath: url,
+ JoinFn: func(url, name string) string {
+ return gcsx.Join(url, "/", name)
+ },
+ }
+ paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
+ if err != nil {
+ return nil, err
+ }
+ var urls []string
+ for local, remote := range paths {
+ err := StageFile(ctx, project, remote, local)
+ if err != nil {
+ return nil, errors.WithContextf(err, "staging file to %v", remote)
+ }
+ urls = append(urls, remote)
+ }
+ return urls, nil
+}
diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go
index f01a71f..af14b38 100644
--- a/sdks/go/pkg/beam/runners/universal/universal.go
+++ b/sdks/go/pkg/beam/runners/universal/universal.go
@@ -79,6 +79,9 @@
getEnvCfg = srv.EnvironmentConfig
}
+ // Fetch all dependencies for cross-language transforms
+ xlangx.ResolveArtifacts(ctx, edges, nil)
+
environment, err := graphx.CreateEnvironment(ctx, envUrn, getEnvCfg)
if err != nil {
return nil, errors.WithContextf(err, "generating model pipeline")
@@ -88,9 +91,6 @@
return nil, errors.WithContextf(err, "generating model pipeline")
}
- // Fetch all dependencies for cross-language transforms
- xlangx.ResolveArtifacts(ctx, edges, pipeline)
-
log.Info(ctx, proto.MarshalTextString(pipeline))
opt := &runnerlib.JobOptions{
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index c756e05..7db20ac 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -78,12 +78,6 @@
var dataflowFilters = []string{
// TODO(BEAM-11576): TestFlattenDup failing on this runner.
"TestFlattenDup",
- // TODO(BEAM-11418): These tests require implementing x-lang artifact
- // staging on Dataflow.
- "TestXLang_CoGroupBy",
- "TestXLang_Multi",
- "TestXLang_Partition",
- "TestXLang_Prefix",
}
// CheckFilters checks if an integration test is filtered to be skipped, either
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
index 92c3e05..b04266c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
@@ -411,12 +411,12 @@
checkState(
this.outputs == null, "Tried to specify more than one output for %s", getFullName());
checkNotNull(output, "Tried to set the output of %s to null", getFullName());
- this.outputs = PValues.fullyExpand(output.expand());
+ this.outputs = PValues.expandOutput(output);
// Validate that a primitive transform produces only primitive output, and a composite
// transform does not produce primitive output.
Set<Node> outputProducers = new HashSet<>();
- for (PCollection<?> outputValue : PValues.fullyExpand(output.expand()).values()) {
+ for (PCollection<?> outputValue : PValues.expandOutput(output).values()) {
outputProducers.add(getProducer(outputValue));
}
if (outputProducers.contains(this) && (!parts.isEmpty() || outputProducers.size() > 1)) {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
index 3b782d9..c489d12 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -73,10 +73,9 @@
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- Iterator<AccumT> it = accumulators.iterator();
- AccumT first = it.next();
- it.remove();
- return getAggregateFn().mergeAccumulators(first, accumulators);
+ AccumT first = accumulators.iterator().next();
+ Iterable<AccumT> rest = new SkipFirstElementIterable<>(accumulators);
+ return getAggregateFn().mergeAccumulators(first, rest);
}
@Override
@@ -99,4 +98,20 @@
public TypeVariable<?> getAccumTVariable() {
return AggregateFn.class.getTypeParameters()[1];
}
+
+ /** Wrapper {@link Iterable} which always skips its first element. */
+ private static class SkipFirstElementIterable<T> implements Iterable<T> {
+ private final Iterable<T> all;
+
+ SkipFirstElementIterable(Iterable<T> all) {
+ this.all = all;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ Iterator<T> it = all.iterator();
+ it.next();
+ return it;
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 34664ac..7b580ef 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -27,6 +27,7 @@
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
+import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.BiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableBiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -39,9 +40,6 @@
import org.joda.time.base.AbstractInstant;
/** Utility methods for Calcite related operations. */
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
public class CalciteUtils {
private static final long UNLIMITED_ARRAY_SIZE = -1L;
@@ -73,7 +71,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(SqlTypes.DATE.getIdentifier())
|| logicalId.equals(SqlTypes.TIME.getIdentifier())
|| logicalId.equals(TimeWithLocalTzType.IDENTIFIER)
@@ -88,7 +88,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(CharType.IDENTIFIER);
}
return false;
@@ -210,7 +212,12 @@
+ "so it cannot be converted to a %s",
sqlTypeName, Schema.FieldType.class.getSimpleName()));
default:
- return CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ FieldType fieldType = CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ if (fieldType == null) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + sqlTypeName);
+ }
+ return fieldType;
}
}
@@ -234,7 +241,12 @@
return FieldType.row(toSchema(calciteType));
default:
- return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ try {
+ return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + calciteType, e);
+ }
}
}
@@ -254,16 +266,22 @@
switch (fieldType.getTypeName()) {
case ARRAY:
case ITERABLE:
+ FieldType collectionElementType = fieldType.getCollectionElementType();
+ Preconditions.checkArgumentNotNull(collectionElementType);
return dataTypeFactory.createArrayType(
- toRelDataType(dataTypeFactory, fieldType.getCollectionElementType()),
- UNLIMITED_ARRAY_SIZE);
+ toRelDataType(dataTypeFactory, collectionElementType), UNLIMITED_ARRAY_SIZE);
case MAP:
- RelDataType componentKeyType = toRelDataType(dataTypeFactory, fieldType.getMapKeyType());
- RelDataType componentValueType =
- toRelDataType(dataTypeFactory, fieldType.getMapValueType());
+ FieldType mapKeyType = fieldType.getMapKeyType();
+ FieldType mapValueType = fieldType.getMapValueType();
+ Preconditions.checkArgumentNotNull(mapKeyType);
+ Preconditions.checkArgumentNotNull(mapValueType);
+ RelDataType componentKeyType = toRelDataType(dataTypeFactory, mapKeyType);
+ RelDataType componentValueType = toRelDataType(dataTypeFactory, mapValueType);
return dataTypeFactory.createMapType(componentKeyType, componentValueType);
case ROW:
- return toCalciteRowType(fieldType.getRowSchema(), dataTypeFactory);
+ Schema schema = fieldType.getRowSchema();
+ Preconditions.checkArgumentNotNull(schema);
+ return toCalciteRowType(schema, dataTypeFactory);
default:
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
index 21ab8d0..cf3f40d 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
@@ -19,12 +19,14 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -41,6 +43,13 @@
assertThat(coder, instanceOf(VarLongCoder.class));
}
+ @Test
+ public void mergeAccumulators() {
+ LazyAggregateCombineFn<Long, Long, Long> combiner = new LazyAggregateCombineFn<>(new Sum());
+ long merged = combiner.mergeAccumulators(ImmutableList.of(1L, 1L));
+ assertEquals(2L, merged);
+ }
+
public static class Sum implements AggregateFn<Long, Long, Long> {
@Override
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 50b6ab2..e76ee7f 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -30,13 +30,17 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
/** Tests for conversion from Beam schema to Calcite data type. */
public class CalciteUtilsTest {
RelDataTypeFactory dataTypeFactory;
+ @Rule public ExpectedException thrown = ExpectedException.none();
+
@Before
public void setUp() {
dataTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
@@ -166,4 +170,12 @@
assertEquals(schema, out);
}
+
+ @Test
+ public void testFieldTypeNotFound() {
+ RelDataType relDataType = dataTypeFactory.createUnknownType();
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN");
+ CalciteUtils.toFieldType(relDataType);
+ }
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
index 6413fd2..82c4b74 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
@@ -729,6 +729,17 @@
Map<String, String> argsMap =
((Map<String, Object>) MAPPER.convertValue(pipeline.getOptions(), Map.class).get("options"))
.entrySet().stream()
+ .filter(
+ (entry) -> {
+ if (entry.getValue() instanceof List) {
+ if (!((List) entry.getValue()).isEmpty()) {
+ throw new IllegalArgumentException("Cannot encode list arguments");
+ }
+ // We can encode empty lists, just omit them.
+ return false;
+ }
+ return true;
+ })
.collect(Collectors.toMap(Map.Entry::getKey, entry -> toArg(entry.getValue())));
InMemoryMetaStore inMemoryMetaStore = new InMemoryMetaStore();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 231d2b8..0ca4581 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -272,12 +272,14 @@
finalizeBundleHandler,
metricsShortIds);
+ BeamFnStatusClient beamFnStatusClient = null;
if (statusApiServiceDescriptor != null) {
- new BeamFnStatusClient(
- statusApiServiceDescriptor,
- channelFactory::forDescriptor,
- processBundleHandler.getBundleProcessorCache(),
- options);
+ beamFnStatusClient =
+ new BeamFnStatusClient(
+ statusApiServiceDescriptor,
+ channelFactory::forDescriptor,
+ processBundleHandler.getBundleProcessorCache(),
+ options);
}
// TODO(BEAM-9729): Remove once runners no longer send this instruction.
@@ -337,6 +339,9 @@
executorService,
handlers);
control.waitForTermination();
+ if (beamFnStatusClient != null) {
+ beamFnStatusClient.close();
+ }
processBundleHandler.shutdown();
} finally {
System.out.println("Shutting SDK harness down.");
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 9af7f58..296767b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -48,8 +48,6 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
/**
* The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -216,7 +214,7 @@
private final String pTransformId;
private final SimpleExecutionState state;
private final Counter unboundedElementCountCounter;
- private final SampleByteSizeDistribution<T> unboundSampledByteSizeDistribution;
+ private final SampleByteSizeDistribution<T> unboundedSampledByteSizeDistribution;
private final Coder<T> coder;
private final MetricsContainer metricsContainer;
@@ -239,7 +237,7 @@
MonitoringInfoMetricName sampledByteSizeMetricName =
MonitoringInfoMetricName.named(Urns.SAMPLED_BYTE_SIZE, labels);
- this.unboundSampledByteSizeDistribution =
+ this.unboundedSampledByteSizeDistribution =
new SampleByteSizeDistribution<>(
unboundMetricContainer.getDistribution(sampledByteSizeMetricName));
@@ -252,7 +250,7 @@
// Increment the counter for each window the element occurs in.
this.unboundedElementCountCounter.inc(input.getWindows().size());
// TODO(BEAM-11879): Consider updating size per window when we have window optimization.
- this.unboundSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
+ this.unboundedSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
// Wrap the consumer with extra logic to set the metric container with the appropriate
// PTransform context. This ensures that user metrics obtain the pTransform ID when they are
// created. Also use the ExecutionStateTracker and enter an appropriate state to track the
@@ -262,6 +260,7 @@
this.delegate.accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
@@ -321,6 +320,7 @@
consumerAndMetadata.getConsumer().accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
}
@@ -365,28 +365,33 @@
}
final Distribution distribution;
+ ByteSizeObserver byteCountObserver;
public SampleByteSizeDistribution(Distribution distribution) {
this.distribution = distribution;
+ this.byteCountObserver = null;
}
public void tryUpdate(T value, Coder<T> coder) throws Exception {
if (shouldSampleElement()) {
// First try using byte size observer
- ByteSizeObserver observer = new ByteSizeObserver();
- coder.registerByteSizeObserver(value, observer);
+ byteCountObserver = new ByteSizeObserver();
+ coder.registerByteSizeObserver(value, byteCountObserver);
- if (!observer.getIsLazy()) {
- observer.advance();
- this.distribution.update(observer.observedSize);
- } else {
- // TODO(BEAM-11841): Optimize calculation of element size for iterables.
- // Coder byte size observation is lazy (requires iteration for observation) so fall back
- // to counting output stream
- CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
- coder.encode(value, os);
- this.distribution.update(os.getCount());
+ if (!byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
+ } else {
+ byteCountObserver = null;
+ }
+ }
+
+ public void finishLazyUpdate() {
+ // Advance lazy ElementByteSizeObservers, if any.
+ if (byteCountObserver != null && byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
index 4c01c04..e059471 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
@@ -25,6 +25,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -42,9 +44,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class BeamFnStatusClient {
+public class BeamFnStatusClient implements AutoCloseable {
+ private static final Object COMPLETED = new Object();
private final StreamObserver<WorkerStatusResponse> outboundObserver;
private final BundleProcessorCache processBundleCache;
+ private final ManagedChannel channel;
+ private final CompletableFuture<Object> inboundObserverCompletion;
private static final Logger LOG = LoggerFactory.getLogger(BeamFnStatusClient.class);
private final MemoryMonitor memoryMonitor;
@@ -53,11 +58,12 @@
Function<ApiServiceDescriptor, ManagedChannel> channelFactory,
BundleProcessorCache processBundleCache,
PipelineOptions options) {
- BeamFnWorkerStatusGrpc.BeamFnWorkerStatusStub stub =
- BeamFnWorkerStatusGrpc.newStub(channelFactory.apply(apiServiceDescriptor));
- this.outboundObserver = stub.workerStatus(new InboundObserver());
+ this.channel = channelFactory.apply(apiServiceDescriptor);
+ this.outboundObserver =
+ BeamFnWorkerStatusGrpc.newStub(channel).workerStatus(new InboundObserver());
this.processBundleCache = processBundleCache;
this.memoryMonitor = MemoryMonitor.fromOptions(options);
+ this.inboundObserverCompletion = new CompletableFuture<>();
Thread thread = new Thread(memoryMonitor);
thread.setDaemon(true);
thread.setPriority(Thread.MIN_PRIORITY);
@@ -65,6 +71,22 @@
thread.start();
}
+ @Override
+ public void close() throws Exception {
+ try {
+ Object completion = inboundObserverCompletion.get(1, TimeUnit.MINUTES);
+ if (completion != COMPLETED) {
+ LOG.warn("InboundObserver for BeamFnStatusClient completed with exception.");
+ }
+ } finally {
+ // Shut the channel down
+ channel.shutdown();
+ if (!channel.awaitTermination(10, TimeUnit.SECONDS)) {
+ channel.shutdownNow();
+ }
+ }
+ }
+
/**
* Class representing the execution state of a thread.
*
@@ -222,9 +244,12 @@
@Override
public void onError(Throwable t) {
LOG.error("Error getting SDK harness status", t);
+ inboundObserverCompletion.completeExceptionally(t);
}
@Override
- public void onCompleted() {}
+ public void onCompleted() {
+ inboundObserverCompletion.complete(COMPLETED);
+ }
}
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 90baa5e..708ca8b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -24,6 +24,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@@ -32,7 +33,9 @@
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import org.apache.beam.fn.harness.HandlesSplits;
import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
@@ -44,15 +47,19 @@
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
+import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@@ -329,5 +336,108 @@
verify(consumerA1).trySplit(0.3);
}
+ @Test
+ public void testLazyByteSizeEstimation() throws Exception {
+ final String pCollectionA = "pCollectionA";
+ final String pTransformIdA = "pTransformIdA";
+
+ MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+ PCollectionConsumerRegistry consumers =
+ new PCollectionConsumerRegistry(
+ metricsContainerRegistry, mock(ExecutionStateTracker.class));
+ FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
+
+ consumers.register(
+ pCollectionA, pTransformIdA, consumerA1, IterableCoder.of(StringUtf8Coder.of()));
+
+ FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
+ (FnDataReceiver<WindowedValue<Iterable<String>>>)
+ (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+ String elementValue = "elem";
+ long elementByteSize = StringUtf8Coder.of().getEncodedElementByteSize(elementValue);
+ WindowedValue<Iterable<String>> element =
+ valueInGlobalWindow(
+ new TestElementByteSizeObservableIterable<>(
+ Arrays.asList(elementValue, elementValue), elementByteSize));
+ int numElements = 10;
+ // Mock doing work on the iterable items
+ doAnswer(
+ (Answer<Void>)
+ invocation -> {
+ Object[] args = invocation.getArguments();
+ WindowedValue<Iterable<String>> arg = (WindowedValue<Iterable<String>>) args[0];
+ Iterator it = arg.getValue().iterator();
+ while (it.hasNext()) {
+ it.next();
+ }
+ return null;
+ })
+ .when(consumerA1)
+ .accept(element);
+
+ for (int i = 0; i < numElements; i++) {
+ wrapperConsumer.accept(element);
+ }
+
+ // Check that the underlying consumers are each invoked per element.
+ verify(consumerA1, times(numElements)).accept(element);
+ assertThat(consumers.keySet(), contains(pCollectionA));
+
+ List<MonitoringInfo> expected = new ArrayList<>();
+
+ SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+ builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+ builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+ builder.setInt64SumValue(numElements);
+ expected.add(builder.build());
+
+ builder = new SimpleMonitoringInfoBuilder();
+ builder.setUrn(Urns.SAMPLED_BYTE_SIZE);
+ builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+ long expectedBytes =
+ (elementByteSize + 1) * 2
+ + 5; // Additional 5 bytes are due to size and hasNext = false (1 byte).
+ builder.setInt64DistributionValue(
+ DistributionData.create(
+ numElements * expectedBytes, numElements, expectedBytes, expectedBytes));
+ expected.add(builder.build());
+ // Only compare the monitoring infos that carry a PCOLLECTION label.
+ Iterable<MonitoringInfo> result =
+ Iterables.filter(
+ metricsContainerRegistry.getMonitoringInfos(),
+ monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+
+ assertThat(result, containsInAnyOrder(expected.toArray()));
+ }
+
+ private class TestElementByteSizeObservableIterable<T>
+ extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> {
+ private List<T> elements;
+ private long elementByteSize;
+
+ public TestElementByteSizeObservableIterable(List<T> elements, long elementByteSize) {
+ this.elements = elements;
+ this.elementByteSize = elementByteSize;
+ }
+
+ @Override
+ protected ElementByteSizeObservableIterator createIterator() {
+ return new ElementByteSizeObservableIterator() {
+ private int index = 0;
+
+ @Override
+ public boolean hasNext() {
+ return index < elements.size();
+ }
+
+ @Override
+ public Object next() {
+ notifyValueReturned(elementByteSize);
+ return elements.get(index++);
+ }
+ };
+ }
+ }
+
private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
}
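A rough sketch of the byte-size arithmetic the new testLazyByteSizeEstimation encodes, mirrored in Python. The framing numbers are assumptions taken from the test's own comment (a 1-byte length prefix per UTF-8 element, plus 5 bytes for the iterable size header and the hasNext=false terminator), not a statement about the coder internals:

    # Hedged sketch of the expectedBytes arithmetic in testLazyByteSizeEstimation.
    element_byte_size = len("elem".encode("utf-8"))   # 4 bytes for the sample value
    num_elements = 10
    expected_bytes = (element_byte_size + 1) * 2 + 5  # two "elem" values per iterable -> 15
    expected_distribution = dict(
        sum=num_elements * expected_bytes,  # 150
        count=num_elements,                 # 10
        min=expected_bytes,                 # 15
        max=expected_bytes)                 # 15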
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
index 179c341..dd281bb 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
@@ -27,6 +27,7 @@
description = "Apache Beam :: SDKs :: Java :: IO :: Elasticsearch-Tests :: 2.x"
ext.summary = "Tests of ElasticsearchIO on Elasticsearch 2.x"
+def log4j_version = "2.14.1"
def elastic_search_version = "2.4.1"
dependencies {
@@ -40,8 +41,8 @@
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:7.9.2"
testCompile "org.elasticsearch:elasticsearch:$elastic_search_version"
- testRuntimeOnly library.java.log4j_api
- testRuntimeOnly library.java.log4j_core
+ testRuntimeOnly "org.apache.logging.log4j:log4j-api:$log4j_version"
+ testRuntimeOnly "org.apache.logging.log4j:log4j-core:$log4j_version"
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
index e5090ef..718b903 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
@@ -32,14 +32,15 @@
systemProperty "tests.security.manager", "false"
}
+def log4j_version = "2.14.1"
def elastic_search_version = "5.6.3"
configurations.all {
resolutionStrategy {
// Make sure the log4j versions for api and core match instead of taking the default
// Gradle rule of using the latest.
- force library.java.log4j_api
- force library.java.log4j_core
+ force "org.apache.logging.log4j:log4j-core:$log4j_version"
+ force "org.apache.logging.log4j:log4j-api:$log4j_version"
}
}
@@ -58,8 +59,8 @@
testCompile library.java.hamcrest_library
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
- testRuntimeOnly library.java.log4j_api
- testRuntimeOnly library.java.log4j_core
+ testRuntimeOnly "org.apache.logging.log4j:log4j-api:$log4j_version"
+ testRuntimeOnly "org.apache.logging.log4j:log4j-core:$log4j_version"
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
index 71cbe00..6eee3a0 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
@@ -32,14 +32,15 @@
systemProperty "tests.security.manager", "false"
}
+def log4j_version = "2.14.1"
def elastic_search_version = "6.4.0"
configurations.all {
resolutionStrategy {
// Make sure the log4j versions for api and core match instead of taking the default
// Gradle rule of using the latest.
- force library.java.log4j_api
- force library.java.log4j_core
+ force "org.apache.logging.log4j:log4j-core:$log4j_version"
+ force "org.apache.logging.log4j:log4j-api:$log4j_version"
}
}
@@ -58,8 +59,8 @@
testCompile library.java.hamcrest_library
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
- testRuntimeOnly library.java.log4j_api
- testRuntimeOnly library.java.log4j_core
+ testRuntimeOnly "org.apache.logging.log4j:log4j-api:$log4j_version"
+ testRuntimeOnly "org.apache.logging.log4j:log4j-core:$log4j_version"
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/build.gradle
index 127408a..ead50d7 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-7/build.gradle
@@ -32,14 +32,15 @@
systemProperty "tests.security.manager", "false"
}
+def log4j_version = "2.14.1"
def elastic_search_version = "7.9.2"
configurations.all {
resolutionStrategy {
// Make sure the log4j versions for api and core match instead of taking the default
// Gradle rule of using the latest.
- force library.java.log4j_api
- force library.java.log4j_core
+ force "org.apache.logging.log4j:log4j-core:$log4j_version"
+ force "org.apache.logging.log4j:log4j-api:$log4j_version"
}
}
@@ -58,8 +59,8 @@
testCompile library.java.hamcrest_library
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
- testRuntimeOnly library.java.log4j_api
- testRuntimeOnly library.java.log4j_core
+ testRuntimeOnly "org.apache.logging.log4j:log4j-api:$log4j_version"
+ testRuntimeOnly "org.apache.logging.log4j:log4j-core:$log4j_version"
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
index 6a86075..396d3a1 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
@@ -25,14 +25,15 @@
description = "Apache Beam :: SDKs :: Java :: IO :: Elasticsearch-Tests :: Common"
ext.summary = "Common test classes for ElasticsearchIO"
+def log4j_version = "2.14.1"
def elastic_search_version = "7.9.2"
configurations.all {
resolutionStrategy {
// Make sure the log4j versions for api and core match instead of taking the default
// Gradle rule of using the latest.
- force library.java.log4j_api
- force library.java.log4j_core
+ force "org.apache.logging.log4j:log4j-core:$log4j_version"
+ force "org.apache.logging.log4j:log4j-api:$log4j_version"
}
}
@@ -46,8 +47,8 @@
testCompile library.java.hamcrest_library
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
- testRuntimeOnly library.java.log4j_api
- testRuntimeOnly library.java.log4j_core
+ testRuntimeOnly "org.apache.logging.log4j:log4j-api:$log4j_version"
+ testRuntimeOnly "org.apache.logging.log4j:log4j-core:$log4j_version"
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
index 4193ba6..fc1fe94 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -109,7 +109,10 @@
.getReferencedTables();
if (referencedTables != null && !referencedTables.isEmpty()) {
TableReference referencedTable = referencedTables.get(0);
- effectiveLocation = tableService.getTable(referencedTable).getLocation();
+ effectiveLocation =
+ tableService
+ .getDataset(referencedTable.getProjectId(), referencedTable.getDatasetId())
+ .getLocation();
}
}
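The Java change above resolves the query location from the dataset that contains the first referenced table, rather than from the table itself. Purely as an illustration (this is not SDK code, and the helper name is made up), the equivalent lookup with the Python google-cloud-bigquery client would look roughly like:

    from google.cloud import bigquery

    # Illustrative analogue of the Java change: derive the location from the
    # referenced table's dataset rather than from the table metadata.
    client = bigquery.Client()

    def query_location(table_ref: bigquery.TableReference) -> str:
        dataset = client.get_dataset(f"{table_ref.project}.{table_ref.dataset_id}")
        return dataset.location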
diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle
index 16787e5..d5075fe 100644
--- a/sdks/java/io/hadoop-format/build.gradle
+++ b/sdks/java/io/hadoop-format/build.gradle
@@ -96,7 +96,6 @@
testCompile library.java.hamcrest_library
testCompile library.java.testcontainers_postgresql
testCompile library.java.netty_all
- testRuntimeOnly library.java.log4j_core
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index 84743f8c..820418c 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -43,7 +43,7 @@
configurations.testRuntimeClasspath {
resolutionStrategy {
- def log4j_version = "2.4.1"
+ def log4j_version = "2.8.2"
// Beam's build system forces a uniform log4j version resolution for all modules, however for
// the HCatalog case the current version of log4j produces NoClassDefFoundError so we need to
// force an old version on the tests runtime classpath
diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index 25355fe..42df75b 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -37,6 +37,7 @@
import itertools
import math
import re
+import warnings
from typing import List
from typing import Optional
@@ -948,41 +949,87 @@
to_string = frame_base.wont_implement_method(
pd.Series, 'to_string', reason="non-deferred-result")
- def aggregate(self, func, axis=0, *args, **kwargs):
+ @frame_base.args_to_kwargs(pd.Series)
+ @frame_base.populate_defaults(pd.Series)
+ def aggregate(self, func, axis, *args, **kwargs):
+ if kwargs.get('skipna', False):
+ # Eagerly generate a proxy to make sure skipna is a valid argument
+ # for this aggregation method
+ _ = self._expr.proxy().aggregate(func, axis, *args, **kwargs)
+ kwargs.pop('skipna')
+ return self.dropna().aggregate(func, axis, *args, **kwargs)
+
if isinstance(func, list) and len(func) > 1:
- # Aggregate each column separately, then stick them all together.
+ # level arg is ignored for multiple aggregations
+ _ = kwargs.pop('level', None)
+
+ # Aggregate with each method separately, then stick them all together.
rows = [self.agg([f], *args, **kwargs) for f in func]
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'join_aggregate',
lambda *rows: pd.concat(rows), [row._expr for row in rows]))
else:
- # We're only handling a single column.
+ # We're only handling a single column. It could be 'func' or ['func'],
+ # which produce different results. 'func' produces a scalar, ['func']
+ # produces a single element Series.
base_func = func[0] if isinstance(func, list) else func
- if _is_associative(base_func) and not args and not kwargs:
+
+ if (_is_numeric(base_func) and
+ not pd.core.dtypes.common.is_numeric_dtype(self.dtype)):
+ warnings.warn(
+ f"Performing a numeric aggregation, {base_func!r}, on "
+ f"Series {self._expr.proxy().name!r} with non-numeric type "
+ f"{self.dtype!r}. This can result in runtime errors or surprising "
+ "results.")
+
+ if 'level' in kwargs:
+ # Defer to groupby.agg for level= mode
+ return self.groupby(
+ level=kwargs.pop('level'), axis=axis).agg(func, *args, **kwargs)
+
+ singleton_reason = None
+ if 'min_count' in kwargs:
+ # Eagerly generate a proxy to make sure min_count is a valid argument
+ # for this aggregation method
+ _ = self._expr.proxy().agg(func, axis, *args, **kwargs)
+
+ singleton_reason = (
+ "Aggregation with min_count= requires collecting all data on a "
+ "single node.")
+
+ agg_kwargs = kwargs.copy()
+ if ((_is_associative(base_func) or _is_liftable_with_sum(base_func)) and
+ singleton_reason is None):
intermediate = expressions.ComputedExpression(
'pre_aggregate',
- lambda s: s.agg([base_func], *args, **kwargs), [self._expr],
+ # Coerce to a Series; even if the result is a scalar we still want a
+ # Series so we can combine and do the final aggregation next.
+ lambda s: pd.Series(s.agg(func, *args, **kwargs)),
+ [self._expr],
requires_partition_by=partitionings.Arbitrary(),
preserves_partition_by=partitionings.Singleton())
allow_nonparallel_final = True
+ if _is_associative(base_func):
+ agg_func = func
+ else:
+ agg_func = ['sum'] if isinstance(func, list) else 'sum'
else:
intermediate = self._expr
allow_nonparallel_final = None # i.e. don't change the value
+ agg_func = func
+ singleton_reason = (
+ f"Aggregation function {func!r} cannot currently be "
+ "parallelized, it requires collecting all data for "
+ "this Series on a single node.")
with expressions.allow_non_parallel_operations(allow_nonparallel_final):
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'aggregate',
- lambda s: s.agg(func, *args, **kwargs),
- [intermediate],
- preserves_partition_by=partitionings.Arbitrary(),
- # TODO(BEAM-11839): This reason should be more specific. It's
- # actually incorrect for the args/kwargs case above.
+ lambda s: s.agg(agg_func, *args, **agg_kwargs), [intermediate],
+ preserves_partition_by=partitionings.Singleton(),
requires_partition_by=partitionings.Singleton(
- reason=(
- f"Aggregation function {func!r} cannot currently be "
- "parallelized, it requires collecting all data for "
- "this Series on a single node."))))
+ reason=singleton_reason)))
agg = aggregate
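The new skipna and level handling in Series.aggregate reduces to plain pandas identities: skipna drops NaNs up front, and level= defers to a groupby over that index level. A minimal sketch of those identities, assuming a pandas 1.x Series with a MultiIndex (where the reduction methods still accept level=):

    import numpy as np
    import pandas as pd

    s = pd.Series([1.0, np.nan, 3.0],
                  index=pd.MultiIndex.from_tuples(
                      [("a", 1), ("a", 2), ("b", 1)], names=["group", "n"]))

    # skipna is handled by dropping NaNs before aggregating.
    assert s.sum(skipna=True) == s.dropna().sum()

    # level= is handled by deferring to a groupby over that index level.
    pd.testing.assert_series_equal(
        s.max(level="group"), s.groupby(level="group").max())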
@@ -994,6 +1041,7 @@
all = frame_base._agg_method('all')
any = frame_base._agg_method('any')
+ # TODO(BEAM-12074): Document that Series.count(level=) will drop NaN's
count = frame_base._agg_method('count')
min = frame_base._agg_method('min')
max = frame_base._agg_method('max')
@@ -1428,7 +1476,86 @@
preserves_partition_by=preserves,
requires_partition_by=partitionings.Arbitrary()))
+ @frame_base.args_to_kwargs(pd.DataFrame)
+ @frame_base.populate_defaults(pd.DataFrame)
+ def insert(self, value, **kwargs):
+ if isinstance(value, list):
+ raise frame_base.WontImplementError(
+ "insert(value=list) is not supported because it joins the input "
+ "list to the deferred DataFrame based on the order of the data.",
+ reason="order-sensitive")
+
+ if isinstance(value, pd.core.generic.NDFrame):
+ value = frame_base.DeferredFrame.wrap(
+ expressions.ConstantExpression(value))
+
+ if isinstance(value, frame_base.DeferredFrame):
+ def func_zip(df, value):
+ df = df.copy()
+ df.insert(value=value, **kwargs)
+ return df
+
+ inserted = frame_base.DeferredFrame.wrap(
+ expressions.ComputedExpression(
+ 'insert',
+ func_zip,
+ [self._expr, value._expr],
+ requires_partition_by=partitionings.Index(),
+ preserves_partition_by=partitionings.Arbitrary()))
+ else:
+ def func_elementwise(df):
+ df = df.copy()
+ df.insert(value=value, **kwargs)
+ return df
+ inserted = frame_base.DeferredFrame.wrap(
+ expressions.ComputedExpression(
+ 'insert',
+ func_elementwise,
+ [self._expr],
+ requires_partition_by=partitionings.Arbitrary(),
+ preserves_partition_by=partitionings.Arbitrary()))
+
+ self._expr = inserted._expr
+
+ @frame_base.args_to_kwargs(pd.DataFrame)
+ @frame_base.populate_defaults(pd.DataFrame)
def aggregate(self, func, axis=0, *args, **kwargs):
+ if 'numeric_only' in kwargs and kwargs['numeric_only']:
+ # Eagerly generate a proxy to make sure numeric_only is a valid argument
+ # for this aggregation method
+ _ = self._expr.proxy().agg(func, axis, *args, **kwargs)
+
+ projected = self[[name for name, dtype in self.dtypes.items()
+ if pd.core.dtypes.common.is_numeric_dtype(dtype)]]
+ kwargs.pop('numeric_only')
+ return projected.agg(func, axis, *args, **kwargs)
+
+ if 'bool_only' in kwargs and kwargs['bool_only']:
+ # Eagerly generate a proxy to make sure bool_only is a valid argument
+ # for this aggregation method
+ _ = self._expr.proxy().agg(func, axis, *args, **kwargs)
+
+ projected = self[[name for name, dtype in self.dtypes.items()
+ if pd.core.dtypes.common.is_bool_dtype(dtype)]]
+ kwargs.pop('bool_only')
+ return projected.agg(func, axis, *args, **kwargs)
+
+ nonnumeric_columns = [name for (name, dtype) in self.dtypes.items()
+ if not pd.core.dtypes.common.is_numeric_dtype(dtype)]
+ if _is_numeric(func) and len(nonnumeric_columns):
+ if 'numeric_only' in kwargs and kwargs['numeric_only'] is False:
+ # User has opted in to execution with non-numeric columns, they
+ # will accept runtime errors
+ pass
+ else:
+ raise frame_base.WontImplementError(
+ f"Numeric aggregation ({func!r}) on a DataFrame containing "
+ f"non-numeric columns ({*nonnumeric_columns,!r} is not supported, "
+ "unless `numeric_only=` is specified.\n"
+ "Use `numeric_only=True` to only aggregate over numeric columns.\n"
+ "Use `numeric_only=False` to aggregate over all columns. Note this "
+ "is not recommended, as it could result in execution time errors.")
+
if axis is None:
# Aggregate across all elements by first aggregating across columns,
# then across rows.
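numeric_only= and bool_only= are implemented above by projecting the frame to the matching columns and re-dispatching the aggregation. Roughly the same projection in plain pandas, for comparison:

    import pandas as pd
    from pandas.api.types import is_numeric_dtype

    df = pd.DataFrame({"foo": [1, 2, 3], "bar": [1.5, 2.5, 3.5], "str": list("abc")})

    # numeric_only=True amounts to selecting the numeric columns up front and
    # aggregating only those, which is what the deferred DataFrame now does.
    numeric_cols = [name for name, dtype in df.dtypes.items()
                    if is_numeric_dtype(dtype)]
    print(df[numeric_cols].agg("sum"))  # foo 6, bar 7.5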
@@ -1442,14 +1569,13 @@
lambda df: df.agg(func, axis=1, *args, **kwargs),
[self._expr],
requires_partition_by=partitionings.Arbitrary()))
- elif len(self._expr.proxy().columns) == 0 or args or kwargs:
- # For these corner cases, just colocate everything.
+ elif len(self._expr.proxy().columns) == 0:
+ # For this corner case, just colocate everything.
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'aggregate',
lambda df: df.agg(func, *args, **kwargs),
[self._expr],
- # TODO(BEAM-11839): Provide a reason for this Singleton
requires_partition_by=partitionings.Singleton()))
else:
# In the general case, compute the aggregation of each column separately,
@@ -1460,15 +1586,19 @@
else:
col_names = list(func.keys())
aggregated_cols = []
+ has_lists = any(isinstance(f, list) for f in func.values())
for col in col_names:
funcs = func[col]
- if not isinstance(funcs, list):
+ if has_lists and not isinstance(funcs, list):
+ # If any of the columns do multiple aggregations, they all must use
+ # "list" style output
funcs = [funcs]
aggregated_cols.append(self[col].agg(funcs, *args, **kwargs))
# The final shape is different depending on whether any of the columns
# were aggregated by a list of aggregators.
with expressions.allow_non_parallel_operations():
- if any(isinstance(funcs, list) for funcs in func.values()):
+ if (any(isinstance(funcs, list) for funcs in func.values()) or
+ 'level' in kwargs):
return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'join_aggregate',
@@ -1481,7 +1611,7 @@
expressions.ComputedExpression(
'join_aggregate',
lambda *cols: pd.Series(
- {col: value[0] for col, value in zip(col_names, cols)}),
+ {col: value for col, value in zip(col_names, cols)}),
[col._expr for col in aggregated_cols],
requires_partition_by=partitionings.Singleton(),
proxy=self._expr.proxy().agg(func, *args, **kwargs)))
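With the change above, a dict of per-column aggregations only switches to "list style" output when at least one column actually requests multiple functions. A quick sketch of the pandas behavior being mirrored:

    import pandas as pd

    df = pd.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]})

    # Scalar-style funcs for every column produce a Series of results...
    print(df.agg({"foo": "sum", "bar": "max"}))    # foo 6, bar 6

    # ...but once any column uses a list, pandas returns a DataFrame indexed by
    # function name, so the remaining scalar funcs are coerced to lists to match.
    print(df.agg({"foo": ["sum"], "bar": "max"}))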
@@ -2321,18 +2451,23 @@
self._grouping_indexes,
projection=name)
- def agg(self, fn):
- if not callable(fn):
- # TODO: Add support for strings in (UN)LIFTABLE_AGGREGATIONS. Test by
- # running doctests for pandas.core.groupby.generic
- raise NotImplementedError('GroupBy.agg currently only supports callable '
- 'arguments')
- return DeferredDataFrame(
- expressions.ComputedExpression(
- 'agg',
- lambda gb: gb.agg(fn), [self._expr],
- requires_partition_by=partitionings.Index(),
- preserves_partition_by=partitionings.Singleton()))
+ def agg(self, fn, *args, **kwargs):
+ if _is_associative(fn):
+ return _liftable_agg(fn)(self, *args, **kwargs)
+ elif _is_liftable_with_sum(fn):
+ return _liftable_agg(fn, postagg_meth='sum')(self, *args, **kwargs)
+ elif _is_unliftable(fn):
+ return _unliftable_agg(fn)(self, *args, **kwargs)
+ elif callable(fn):
+ return DeferredDataFrame(
+ expressions.ComputedExpression(
+ 'agg',
+ lambda gb: gb.agg(fn, *args, **kwargs), [self._expr],
+ requires_partition_by=partitionings.Index(),
+ preserves_partition_by=partitionings.Singleton()))
+ else:
+ raise NotImplementedError(f"GroupBy.agg(func={fn!r})")
+
def apply(self, fn, *args, **kwargs):
if self._grouping_columns and not self._projection:
@@ -2440,16 +2575,19 @@
def _liftable_agg(meth, postagg_meth=None):
- name, agg_func = frame_base.name_and_func(meth)
+ agg_name, _ = frame_base.name_and_func(meth)
if postagg_meth is None:
- post_agg_name, post_agg_func = name, agg_func
+ post_agg_name = agg_name
else:
- post_agg_name, post_agg_func = frame_base.name_and_func(postagg_meth)
+ post_agg_name, _ = frame_base.name_and_func(postagg_meth)
def wrapper(self, *args, **kwargs):
assert isinstance(self, DeferredGroupBy)
+ if 'min_count' in kwargs:
+ return _unliftable_agg(meth)(self, *args, **kwargs)
+
to_group = self._ungrouped.proxy().index
is_categorical_grouping = any(to_group.get_level_values(i).is_categorical()
for i in self._grouping_indexes)
@@ -2461,20 +2599,24 @@
project = _maybe_project_func(self._projection)
pre_agg = expressions.ComputedExpression(
- 'pre_combine_' + name,
- lambda df: agg_func(project(
- df.groupby(level=list(range(df.index.nlevels)),
- **preagg_groupby_kwargs),
- ), **kwargs),
+ 'pre_combine_' + agg_name,
+ lambda df: getattr(
+ project(
+ df.groupby(level=list(range(df.index.nlevels)),
+ **preagg_groupby_kwargs)
+ ),
+ agg_name)(**kwargs),
[self._ungrouped],
requires_partition_by=partitionings.Arbitrary(),
preserves_partition_by=partitionings.Arbitrary())
+
post_agg = expressions.ComputedExpression(
'post_combine_' + post_agg_name,
- lambda df: post_agg_func(
- df.groupby(level=list(range(df.index.nlevels)), **groupby_kwargs),
- **kwargs),
+ lambda df: getattr(
+ df.groupby(level=list(range(df.index.nlevels)),
+ **groupby_kwargs),
+ post_agg_name)(**kwargs),
[pre_agg],
requires_partition_by=(partitionings.Singleton(reason=(
"Aggregations grouped by a categorical column are not currently "
@@ -2489,7 +2631,7 @@
def _unliftable_agg(meth):
- name, agg_func = frame_base.name_and_func(meth)
+ agg_name, _ = frame_base.name_and_func(meth)
def wrapper(self, *args, **kwargs):
assert isinstance(self, DeferredGroupBy)
@@ -2501,11 +2643,11 @@
groupby_kwargs = self._kwargs
project = _maybe_project_func(self._projection)
post_agg = expressions.ComputedExpression(
- name,
- lambda df: agg_func(project(
+ agg_name,
+ lambda df: getattr(project(
df.groupby(level=list(range(df.index.nlevels)),
**groupby_kwargs),
- ), **kwargs),
+ ), agg_name)(**kwargs),
[self._ungrouped],
requires_partition_by=(partitionings.Singleton(reason=(
"Aggregations grouped by a categorical column are not currently "
@@ -2529,13 +2671,27 @@
for meth in UNLIFTABLE_AGGREGATIONS:
setattr(DeferredGroupBy, meth, _unliftable_agg(meth))
-
-def _is_associative(agg_func):
- return agg_func in LIFTABLE_AGGREGATIONS or (
- getattr(agg_func, '__name__', None) in LIFTABLE_AGGREGATIONS
+def _check_str_or_np_builtin(agg_func, func_list):
+ return agg_func in func_list or (
+ getattr(agg_func, '__name__', None) in func_list
and agg_func.__module__ in ('numpy', 'builtins'))
+def _is_associative(agg_func):
+ return _check_str_or_np_builtin(agg_func, LIFTABLE_AGGREGATIONS)
+
+def _is_liftable_with_sum(agg_func):
+ return _check_str_or_np_builtin(agg_func, LIFTABLE_WITH_SUM_AGGREGATIONS)
+
+def _is_unliftable(agg_func):
+ return _check_str_or_np_builtin(agg_func, UNLIFTABLE_AGGREGATIONS)
+
+NUMERIC_AGGREGATIONS = ['max', 'min', 'prod', 'sum', 'mean', 'median', 'std',
+ 'var']
+
+def _is_numeric(agg_func):
+ return _check_str_or_np_builtin(agg_func, NUMERIC_AGGREGATIONS)
+
@populate_not_implemented(DataFrameGroupBy)
class _DeferredGroupByCols(frame_base.DeferredFrame):
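The classification helpers added above accept either the string name of an aggregation or the corresponding numpy/Python builtin. A standalone copy of the matching rule, with an abbreviated function list, just to show what it accepts:

    import numpy as np

    NUMERIC = ["max", "min", "prod", "sum", "mean", "median", "std", "var"]

    def check_str_or_np_builtin(agg_func, func_list):
        # Match the literal string, or a numpy/builtins callable whose
        # __name__ appears in the list (e.g. np.sum, the builtin max).
        return agg_func in func_list or (
            getattr(agg_func, "__name__", None) in func_list
            and agg_func.__module__ in ("numpy", "builtins"))

    print(check_str_or_np_builtin("sum", NUMERIC))   # True
    print(check_str_or_np_builtin(max, NUMERIC))     # True: builtins.max
    print(check_str_or_np_builtin(np.sum, NUMERIC))  # True, assuming np.sum reports __module__ == 'numpy'
    print(check_str_or_np_builtin(len, NUMERIC))     # False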
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index 1cf1dfb..fcf5963 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -98,6 +98,18 @@
f'Expected {expected_error!r} to be raised, but got {actual!r}'
) from actual
+ def _run_inplace_test(self, func, arg, **kwargs):
+ """Verify an inplace operation performed by func.
+
+ Checks that func performs the same inplace operation on arg, in pandas and
+ in Beam."""
+ def wrapper(df):
+ df = df.copy()
+ func(df)
+ return df
+
+ self._run_test(wrapper, arg, **kwargs)
+
def _run_test(self, func, *args, distributed=True, nonparallel=False):
"""Verify that func(*args) produces the same result in pandas and in Beam.
@@ -160,7 +172,10 @@
else:
# Expectation is not a pandas object
if isinstance(expected, float):
- cmp = lambda x: np.isclose(expected, x)
+ if np.isnan(expected):
+ cmp = np.isnan
+ else:
+ cmp = lambda x: np.isclose(expected, x)
else:
cmp = expected.__eq__
self.assertTrue(
@@ -182,13 +197,12 @@
def test_set_column(self):
def new_column(df):
df['NewCol'] = df['Speed']
- return df
df = pd.DataFrame({
'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
'Speed': [380., 370., 24., 26.]
})
- self._run_test(new_column, df)
+ self._run_inplace_test(new_column, df)
def test_str_split(self):
s = pd.Series([
@@ -209,13 +223,12 @@
def test_set_column_from_index(self):
def new_column(df):
df['NewCol'] = df.index
- return df
df = pd.DataFrame({
'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
'Speed': [380., 370., 24., 26.]
})
- self._run_test(new_column, df)
+ self._run_inplace_test(new_column, df)
def test_tz_localize_ambiguous_series(self):
# This replicates a tz_localize doctest:
@@ -703,11 +716,7 @@
self._run_test(lambda df: df.eval('foo = a + b - c'), df)
self._run_test(lambda df: df.query('a > b + c'), df)
- def eval_inplace(df):
- df.eval('foo = a + b - c', inplace=True)
- return df.foo
-
- self._run_test(eval_inplace, df)
+ self._run_inplace_test(lambda df: df.eval('foo = a + b - c'), df)
# Verify that attempting to access locals raises a useful error
deferred_df = frame_base.DeferredFrame.wrap(
@@ -723,9 +732,8 @@
def change_index_names(df):
df.index.names = ['A', None]
- return df
- self._run_test(change_index_names, df)
+ self._run_inplace_test(change_index_names, df)
@parameterized.expand((x, ) for x in [
0,
@@ -845,6 +853,212 @@
df2,
construction_time=False)
+ def test_series_agg_level(self):
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.count(level=0),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.max(level=0), GROUPBY_DF)
+
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.median(level=0),
+ GROUPBY_DF)
+
+ self._run_test(
+ lambda df: df.set_index(['foo', 'group']).bar.count(level=1),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.max(level=1), GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.max(level='foo'),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.median(level=1),
+ GROUPBY_DF)
+
+ def test_dataframe_agg_level(self):
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).count(level=0), GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).max(
+ level=0, numeric_only=False),
+ GROUPBY_DF)
+ # pandas implementation doesn't respect numeric_only argument here
+ # (https://github.com/pandas-dev/pandas/issues/40788), it
+ # always acts as if numeric_only=True. Our implementation respects it so we
+ # need to make it explicit.
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).sum(
+ level=0, numeric_only=True),
+ GROUPBY_DF)
+
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo'])[['bar']].count(level=1),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).count(level=1), GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).max(
+ level=1, numeric_only=False),
+ GROUPBY_DF)
+ # sum with str columns is order-sensitive
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).sum(
+ level=1, numeric_only=True),
+ GROUPBY_DF)
+
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).median(
+ level=0, numeric_only=True),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.drop('str', axis=1).set_index(['foo', 'group']).median(
+ level=1, numeric_only=True),
+ GROUPBY_DF)
+
+ def test_series_agg_multifunc_level(self):
+ # level= is ignored for multiple agg fns
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).bar.agg(['min', 'max'],
+ level=0),
+ GROUPBY_DF)
+
+ def test_dataframe_agg_multifunc_level(self):
+ # level= is ignored for multiple agg fns
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).agg(['min', 'max'], level=0),
+ GROUPBY_DF)
+
+ @parameterized.expand([(True, ), (False, )])
+ @unittest.skipIf(
+ PD_VERSION < (1, 2),
+ "pandas 1.1.0 produces different dtypes for these examples")
+ def test_dataframe_agg_numeric_only(self, numeric_only):
+ # Note other aggregation functions can fail on this input with
+ # numeric_only={False,None}. These are the only ones that actually work for
+ # the string inputs.
+ self._run_test(lambda df: df.max(numeric_only=numeric_only), GROUPBY_DF)
+ self._run_test(lambda df: df.min(numeric_only=numeric_only), GROUPBY_DF)
+
+ @unittest.skip(
+ "pandas implementation doesn't respect numeric_only= with "
+ "level= (https://github.com/pandas-dev/pandas/issues/40788)")
+ def test_dataframe_agg_level_numeric_only(self):
+ self._run_test(
+ lambda df: df.set_index('foo').sum(level=0, numeric_only=True),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index('foo').max(level=0, numeric_only=True),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index('foo').mean(level=0, numeric_only=True),
+ GROUPBY_DF)
+ self._run_test(
+ lambda df: df.set_index('foo').median(level=0, numeric_only=True),
+ GROUPBY_DF)
+
+ def test_dataframe_agg_bool_only(self):
+ df = pd.DataFrame({
+ 'all': [True for i in range(10)],
+ 'any': [i % 3 == 0 for i in range(10)],
+ 'int': range(10)
+ })
+
+ self._run_test(lambda df: df.all(), df)
+ self._run_test(lambda df: df.any(), df)
+ self._run_test(lambda df: df.all(bool_only=True), df)
+ self._run_test(lambda df: df.any(bool_only=True), df)
+
+ @unittest.skip(
+ "pandas doesn't implement bool_only= with level= "
+ "(https://github.com/pandas-dev/pandas/blob/"
+ "v1.2.3/pandas/core/generic.py#L10573)")
+ def test_dataframe_agg_level_bool_only(self):
+ df = pd.DataFrame({
+ 'all': [True for i in range(10)],
+ 'any': [i % 3 == 0 for i in range(10)],
+ 'int': range(10)
+ })
+
+ self._run_test(lambda df: df.set_index('int', drop=False).all(level=0), df)
+ self._run_test(lambda df: df.set_index('int', drop=False).any(level=0), df)
+ self._run_test(
+ lambda df: df.set_index('int', drop=False).all(level=0, bool_only=True),
+ df)
+ self._run_test(
+ lambda df: df.set_index('int', drop=False).any(level=0, bool_only=True),
+ df)
+
+ def test_series_agg_np_size(self):
+ self._run_test(
+ lambda df: df.set_index(['group', 'foo']).agg(np.size), GROUPBY_DF)
+
+ def test_df_agg_invalid_kwarg_raises(self):
+ self._run_error_test(lambda df: df.agg('mean', bool_only=True), GROUPBY_DF)
+ self._run_error_test(
+ lambda df: df.agg('any', numeric_only=True), GROUPBY_DF)
+ self._run_error_test(
+ lambda df: df.agg('median', min_count=3, numeric_only=True), GROUPBY_DF)
+
+ def test_series_agg_method_invalid_kwarg_raises(self):
+ self._run_error_test(lambda df: df.foo.median(min_count=3), GROUPBY_DF)
+ self._run_error_test(
+ lambda df: df.foo.agg('median', min_count=3), GROUPBY_DF)
+
+ @unittest.skipIf(
+ PD_VERSION < (1, 3),
+ (
+ "DataFrame.agg raises a different exception from the "
+ "aggregation methods. Fixed in "
+ "https://github.com/pandas-dev/pandas/pull/40543."))
+ def test_df_agg_method_invalid_kwarg_raises(self):
+ self._run_error_test(lambda df: df.mean(bool_only=True), GROUPBY_DF)
+ self._run_error_test(lambda df: df.any(numeric_only=True), GROUPBY_DF)
+ self._run_error_test(
+ lambda df: df.median(min_count=3, numeric_only=True), GROUPBY_DF)
+
+ def test_agg_min_count(self):
+ df = pd.DataFrame({
+ 'good': [1, 2, 3, np.nan],
+ 'bad': [np.nan, np.nan, np.nan, 4],
+ },
+ index=['a', 'b', 'a', 'b'])
+
+ self._run_test(lambda df: df.sum(level=0, min_count=2), df)
+
+ self._run_test(lambda df: df.sum(min_count=3), df, nonparallel=True)
+ self._run_test(lambda df: df.sum(min_count=1), df, nonparallel=True)
+ self._run_test(lambda df: df.good.sum(min_count=2), df, nonparallel=True)
+ self._run_test(lambda df: df.bad.sum(min_count=2), df, nonparallel=True)
+
+ def test_groupby_sum_min_count(self):
+ df = pd.DataFrame({
+ 'good': [1, 2, 3, np.nan],
+ 'bad': [np.nan, np.nan, np.nan, 4],
+ 'group': ['a', 'b', 'a', 'b']
+ })
+
+ self._run_test(lambda df: df.groupby('group').sum(min_count=2), df)
+
+ def test_dataframe_sum_nonnumeric_raises(self):
+ # Attempting a numeric aggregation with the str column present should
+ # raise, and suggest the numeric_only argument
+ with self.assertRaisesRegex(frame_base.WontImplementError, 'numeric_only'):
+ self._run_test(lambda df: df.sum(), GROUPBY_DF)
+
+ # numeric_only=True should work
+ self._run_test(lambda df: df.sum(numeric_only=True), GROUPBY_DF)
+ # projecting only numeric columns should too
+ self._run_test(lambda df: df[['foo', 'bar']].sum(), GROUPBY_DF)
+
+ def test_insert(self):
+ df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
+
+ self._run_inplace_test(lambda df: df.insert(1, 'C', df.A * 2), df)
+ self._run_inplace_test(
+ lambda df: df.insert(0, 'foo', pd.Series([8], index=[1])), df)
+ self._run_inplace_test(lambda df: df.insert(2, 'bar', value='q'), df)
+
class AllowNonParallelTest(unittest.TestCase):
def _use_non_parallel_operation(self):
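The min_count tests above exercise a pandas option that cannot be recovered from partial sums alone (it needs the global count of non-NA values), which is why the deferred implementation treats it as nonparallel. The pandas behavior, for reference:

    import numpy as np
    import pandas as pd

    bad = pd.Series([np.nan, np.nan, np.nan, 4.0])

    # With only one non-NA value, sum(min_count=2) is NaN rather than 4.0.
    print(bad.sum(min_count=1))  # 4.0
    print(bad.sum(min_count=2))  # nan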
diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py
index cc30cec..ee3c2a3 100644
--- a/sdks/python/apache_beam/dataframe/schemas.py
+++ b/sdks/python/apache_beam/dataframe/schemas.py
@@ -281,7 +281,8 @@
ctor = element_type_from_dataframe(proxy, include_indexes=include_indexes)
return beam.ParDo(
- _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor))
+ _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor)
+ ).with_output_types(ctor)
elif isinstance(proxy, pd.Series):
# Raise a TypeError if proxy has an unknown type
output_type = _dtype_to_fieldtype(proxy.dtype)
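Attaching the generated element type with with_output_types is what lets the unbatched PCollection report a NamedTuple element_type downstream. A minimal sketch of the general pattern (the Row and ToRows names below are illustrative, not Beam internals):

    import typing
    import apache_beam as beam

    class Row(typing.NamedTuple):
        name: str
        height: float

    class ToRows(beam.DoFn):
        def process(self, batch):
            for name, height in batch:
                yield Row(name, height)

    with beam.Pipeline() as p:
        rows = (
            p
            | beam.Create([[("Falcon", 375.0), ("Parrot", 25.0)]])
            # Declaring the output type here mirrors what the schemas.py change
            # does for UnbatchPandas: rows.element_type becomes the named tuple.
            | beam.ParDo(ToRows()).with_output_types(Row))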
diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py
index 8b1159c..af30e25 100644
--- a/sdks/python/apache_beam/dataframe/schemas_test.py
+++ b/sdks/python/apache_beam/dataframe/schemas_test.py
@@ -36,6 +36,8 @@
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
Simple = typing.NamedTuple(
'Simple', [('name', unicode), ('id', int), ('height', float)])
@@ -65,41 +67,59 @@
# dtype. For example:
# pd.Series([b'abc'], dtype=bytes).dtype != 'S'
# pd.Series([b'abc'], dtype=bytes).astype(bytes).dtype == 'S'
+# (test data, pandas_type, column_name, beam_type)
COLUMNS = [
- ([375, 24, 0, 10, 16], np.int32, 'i32'),
- ([375, 24, 0, 10, 16], np.int64, 'i64'),
- ([375, 24, None, 10, 16], pd.Int32Dtype(), 'i32_nullable'),
- ([375, 24, None, 10, 16], pd.Int64Dtype(), 'i64_nullable'),
- ([375., 24., None, 10., 16.], np.float64, 'f64'),
- ([375., 24., None, 10., 16.], np.float32, 'f32'),
- ([True, False, True, True, False], bool, 'bool'),
- (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any'),
- ([True, False, True, None, False], pd.BooleanDtype(), 'bool_nullable'),
+ ([375, 24, 0, 10, 16], np.int32, 'i32', np.int32),
+ ([375, 24, 0, 10, 16], np.int64, 'i64', np.int64),
+ ([375, 24, None, 10, 16],
+ pd.Int32Dtype(),
+ 'i32_nullable',
+ typing.Optional[np.int32]),
+ ([375, 24, None, 10, 16],
+ pd.Int64Dtype(),
+ 'i64_nullable',
+ typing.Optional[np.int64]),
+ ([375., 24., None, 10., 16.],
+ np.float64,
+ 'f64',
+ typing.Optional[np.float64]),
+ ([375., 24., None, 10., 16.],
+ np.float32,
+ 'f32',
+ typing.Optional[np.float32]),
+ ([True, False, True, True, False], bool, 'bool', bool),
+ (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any', typing.Any),
+ ([True, False, True, None, False],
+ pd.BooleanDtype(),
+ 'bool_nullable',
+ typing.Optional[bool]),
(['Falcon', 'Ostrich', None, 'Aardvark', 'Elephant'],
pd.StringDtype(),
- 'strdtype'),
-] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str]]
+ 'strdtype',
+ typing.Optional[str]),
+] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str, typing.Any]]
-NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name in COLUMNS])
-for arr, dtype, name in COLUMNS:
+NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name, _ in COLUMNS])
+for arr, dtype, name, _ in COLUMNS:
NICE_TYPES_DF[name] = pd.Series(arr, dtype=dtype, name=name).astype(dtype)
NICE_TYPES_PROXY = NICE_TYPES_DF[:0]
-SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr) for arr,
- dtype,
- name in COLUMNS]
+SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr, beam_type)
+ for (arr, dtype, name, beam_type) in COLUMNS]
_TEST_ARRAYS = [
- arr for arr, _, _ in COLUMNS
+ arr for (arr, _, _, _) in COLUMNS
] # type: typing.List[typing.List[typing.Any]]
DF_RESULT = list(zip(*_TEST_ARRAYS))
-INDEX_DF_TESTS = [
- (NICE_TYPES_DF.set_index([name for _, _, name in COLUMNS[:i]]), DF_RESULT)
- for i in range(1, len(COLUMNS) + 1)
-]
+BEAM_SCHEMA = typing.NamedTuple( # type: ignore
+ 'BEAM_SCHEMA', [(name, beam_type) for _, _, name, beam_type in COLUMNS])
+INDEX_DF_TESTS = [(
+ NICE_TYPES_DF.set_index([name for _, _, name, _ in COLUMNS[:i]]),
+ DF_RESULT,
+ BEAM_SCHEMA) for i in range(1, len(COLUMNS) + 1)]
-NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT)]
+NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT, BEAM_SCHEMA)]
PD_VERSION = tuple(int(n) for n in pd.__version__.split('.'))
@@ -203,8 +223,18 @@
proxy=schemas.generate_proxy(Animal)))
assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)]))
+ def assert_typehints_equal(self, left, right):
+ left = typehints.normalize(left)
+ right = typehints.normalize(right)
+
+ if match_is_named_tuple(left):
+ self.assertTrue(match_is_named_tuple(right))
+ self.assertEqual(left.__annotations__, right.__annotations__)
+ else:
+ self.assertEqual(left, right)
+
@parameterized.expand(SERIES_TESTS + NOINDEX_DF_TESTS)
- def test_unbatch_no_index(self, df_or_series, rows):
+ def test_unbatch_no_index(self, df_or_series, rows, beam_type):
proxy = df_or_series[:0]
with TestPipeline() as p:
@@ -212,10 +242,15 @@
p | beam.Create([df_or_series[::2], df_or_series[1::2]])
| schemas.UnbatchPandas(proxy))
+ # Verify that the unbatched PCollection has the expected typehint
+ # TODO(BEAM-8538): typehints should support NamedTuple so we can use
+ # typehints.is_consistent_with here instead
+ self.assert_typehints_equal(res.element_type, beam_type)
+
assert_that(res, equal_to(rows))
@parameterized.expand(SERIES_TESTS + INDEX_DF_TESTS)
- def test_unbatch_with_index(self, df_or_series, rows):
+ def test_unbatch_with_index(self, df_or_series, rows, _):
proxy = df_or_series[:0]
with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/examples/wordcount_it_test.py b/sdks/python/apache_beam/examples/wordcount_it_test.py
index 24c455b..242dcff 100644
--- a/sdks/python/apache_beam/examples/wordcount_it_test.py
+++ b/sdks/python/apache_beam/examples/wordcount_it_test.py
@@ -24,6 +24,7 @@
import time
import unittest
+import pytest
from hamcrest.core.core.allof import all_of
from nose.plugins.attrib import attr
@@ -50,18 +51,19 @@
def test_wordcount_it(self):
self._run_wordcount_it(wordcount.run)
- @attr('IT', 'ValidatesContainer')
+ @attr('IT')
+ @pytest.mark.it_validatescontainer
def test_wordcount_fnapi_it(self):
self._run_wordcount_it(wordcount.run, experiment='beam_fn_api')
- @attr('ValidatesContainer')
+ @pytest.mark.it_validatescontainer
def test_wordcount_it_with_prebuilt_sdk_container_local_docker(self):
self._run_wordcount_it(
wordcount.run,
experiment='beam_fn_api',
prebuild_sdk_container_engine='local_docker')
- @attr('ValidatesContainer')
+ @pytest.mark.it_validatescontainer
def test_wordcount_it_with_prebuilt_sdk_container_cloud_build(self):
self._run_wordcount_it(
wordcount.run,
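The nose 'ValidatesContainer' attribute is replaced with a pytest marker, so the container-validation suite now selects these tests by marker rather than by nose attrib. In general, assuming the marker is registered in the project's pytest configuration:

    import pytest

    @pytest.mark.it_validatescontainer
    def test_runs_against_prebuilt_container():
        ...

    # The suite then selects these tests with:  pytest -m it_validatescontainer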
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
index 3c4406c..93510c8 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
@@ -19,8 +19,6 @@
functionality.
"""
-from __future__ import absolute_import
-
import logging
from google.cloud import dlp_v2
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
index fbba610..4ada679 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
@@ -17,8 +17,6 @@
"""Integration tests for Google Cloud Video Intelligence API transforms."""
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
index ef25555..111e5be 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
@@ -17,8 +17,6 @@
"""Unit tests for Google Cloud Video Intelligence API transforms."""
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
index 5263a60..7817eb9 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
@@ -15,8 +15,6 @@
# limitations under the License.
#
-from __future__ import absolute_import
-
from typing import Mapping
from typing import Optional
from typing import Sequence
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
index ef72359..e639517 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
@@ -18,8 +18,6 @@
"""Unit tests for Google Cloud Natural Language API transform."""
-from __future__ import absolute_import
-
import unittest
import mock
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
index 4cf58e4..932bc685 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
@@ -16,8 +16,6 @@
#
# pytype: skip-file
-from __future__ import absolute_import
-
import unittest
from nose.plugins.attrib import attr
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
index 67ff496..bc0aa08 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
@@ -17,15 +17,10 @@
"""A connector for sending API requests to the GCP Video Intelligence API."""
-from __future__ import absolute_import
-
from typing import Optional
from typing import Tuple
from typing import Union
-from future.utils import binary_type
-from future.utils import text_type
-
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn
@@ -55,8 +50,8 @@
ref: https://cloud.google.com/video-intelligence/docs
Sends each element to the GCP Video Intelligence API. Element is a
- Union[text_type, binary_type] of either an URI (e.g. a GCS URI) or
- binary_type base64-encoded video data.
+ Union[str, bytes] of either an URI (e.g. a GCS URI) or
+ bytes base64-encoded video data.
Accepts an `AsDict` side input that maps each video to a video context.
"""
def __init__(
@@ -118,8 +113,7 @@
@typehints.with_input_types(
- Union[text_type, binary_type],
- Optional[videointelligence.types.VideoContext])
+ Union[str, bytes], Optional[videointelligence.types.VideoContext])
class _VideoAnnotateFn(DoFn):
"""A DoFn that sends each input element to the GCP Video Intelligence API
service and outputs an element with the return result of the API
@@ -138,7 +132,7 @@
self._client = get_videointelligence_client()
def _annotate_video(self, element, video_context):
- if isinstance(element, text_type): # Is element an URI to a GCS bucket
+ if isinstance(element, str): # Is element an URI to a GCS bucket
response = self._client.annotate_video(
input_uri=element,
features=self.features,
@@ -171,11 +165,11 @@
Sends each element to the GCP Video Intelligence API.
Element is a tuple of
- (Union[text_type, binary_type],
+ (Union[str, bytes],
Optional[videointelligence.types.VideoContext])
where the former is either an URI (e.g. a GCS URI) or
- binary_type base64-encoded video data
+ bytes base64-encoded video data
"""
def __init__(self, features, location_id=None, metadata=None, timeout=120):
"""
@@ -208,8 +202,7 @@
@typehints.with_input_types(
- Tuple[Union[text_type, binary_type],
- Optional[videointelligence.types.VideoContext]])
+ Tuple[Union[str, bytes], Optional[videointelligence.types.VideoContext]])
class _VideoAnnotateFnWithContext(_VideoAnnotateFn):
"""A DoFn that unpacks each input tuple to element, video_context variables
and sends these to the GCP Video Intelligence API service and outputs
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
index b0bb441..3215ceb 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
@@ -19,9 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
index a2d8216..d934520 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
@@ -19,9 +19,6 @@
"""An integration test that labels entities appearing in a video and checks
if some expected entities were properly recognized."""
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import unittest
import hamcrest as hc
diff --git a/sdks/python/apache_beam/ml/gcp/visionml.py b/sdks/python/apache_beam/ml/gcp/visionml.py
index 13884a7..0fb45ce 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml.py
@@ -20,16 +20,11 @@
A connector for sending API requests to the GCP Vision API.
"""
-from __future__ import absolute_import
-
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
-from future.utils import binary_type
-from future.utils import text_type
-
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn
@@ -65,8 +60,8 @@
Batches elements together using ``util.BatchElements`` PTransform and sends
each batch of elements to the GCP Vision API.
- Element is a Union[text_type, binary_type] of either an URI (e.g. a GCS URI)
- or binary_type base64-encoded image data.
+ Element is a Union[str, bytes] of either an URI (e.g. a GCS URI)
+ or bytes base64-encoded image data.
Accepts an `AsDict` side input that maps each image to an image context.
"""
@@ -158,7 +153,7 @@
metadata=self.metadata)))
@typehints.with_input_types(
- Union[text_type, binary_type], Optional[vision.types.ImageContext])
+ Union[str, bytes], Optional[vision.types.ImageContext])
@typehints.with_output_types(List[vision.types.AnnotateImageRequest])
def _create_image_annotation_pairs(self, element, context_side_input):
if context_side_input: # If we have a side input image context, use that
@@ -166,10 +161,10 @@
else:
image_context = None
- if isinstance(element, text_type):
+ if isinstance(element, str):
image = vision.types.Image(
source=vision.types.ImageSource(image_uri=element))
- else: # Typehint checks only allows text_type or binary_type
+ else: # Typehint checks only allow str or bytes
image = vision.types.Image(content=element)
request = vision.types.AnnotateImageRequest(
@@ -185,10 +180,10 @@
Element is a tuple of::
- (Union[text_type, binary_type],
+ (Union[str, bytes],
Optional[``vision.types.ImageContext``])
- where the former is either an URI (e.g. a GCS URI) or binary_type
+ where the former is either an URI (e.g. a GCS URI) or bytes
base64-encoded image data.
"""
def __init__(
@@ -249,14 +244,14 @@
metadata=self.metadata)))
@typehints.with_input_types(
- Tuple[Union[text_type, binary_type], Optional[vision.types.ImageContext]])
+ Tuple[Union[str, bytes], Optional[vision.types.ImageContext]])
@typehints.with_output_types(List[vision.types.AnnotateImageRequest])
def _create_image_annotation_pairs(self, element, **kwargs):
element, image_context = element # Unpack (image, image_context) tuple
- if isinstance(element, text_type):
+ if isinstance(element, str):
image = vision.types.Image(
source=vision.types.ImageSource(image_uri=element))
- else: # Typehint checks only allows text_type or binary_type
+ else: # Typehint checks only allow str or bytes
image = vision.types.Image(content=element)
request = vision.types.AnnotateImageRequest(
diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test.py b/sdks/python/apache_beam/ml/gcp/visionml_test.py
index d4c6c20..f038442 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml_test.py
@@ -20,9 +20,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
index a81acd0..14af3cb8 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
@@ -16,8 +16,6 @@
#
# pytype: skip-file
-from __future__ import absolute_import
-
import unittest
from nose.plugins.attrib import attr
diff --git a/sdks/python/apache_beam/options/__init__.py b/sdks/python/apache_beam/options/__init__.py
index f4f43cb..cce3aca 100644
--- a/sdks/python/apache_beam/options/__init__.py
+++ b/sdks/python/apache_beam/options/__init__.py
@@ -14,4 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from __future__ import absolute_import
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 449c7d3..073014c 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -19,13 +19,9 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import argparse
import json
import logging
-from builtins import list
-from builtins import object
from typing import Any
from typing import Callable
from typing import Dict
diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py
index 08c4145..8299a54 100644
--- a/sdks/python/apache_beam/options/pipeline_options_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_test.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import json
import logging
import unittest
diff --git a/sdks/python/apache_beam/options/pipeline_options_validator.py b/sdks/python/apache_beam/options/pipeline_options_validator.py
index 97df45f..2196ca6 100644
--- a/sdks/python/apache_beam/options/pipeline_options_validator.py
+++ b/sdks/python/apache_beam/options/pipeline_options_validator.py
@@ -21,13 +21,8 @@
"""
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import re
-from builtins import object
-
-from past.builtins import unicode
from apache_beam.internal import pickler
from apache_beam.options.pipeline_options import DebugOptions
@@ -220,8 +215,7 @@
'Transform name mapping option is only useful when '
'--update and --streaming is specified')
for _, (key, value) in enumerate(view.transform_name_mapping.items()):
- if not isinstance(key, (str, unicode)) \
- or not isinstance(value, (str, unicode)):
+ if not isinstance(key, str) or not isinstance(value, str):
errors.extend(
self._validate_error(
self.ERR_INVALID_TRANSFORM_NAME_MAPPING, key, value))
diff --git a/sdks/python/apache_beam/options/pipeline_options_validator_test.py b/sdks/python/apache_beam/options/pipeline_options_validator_test.py
index 0eea8d7..54b3a69 100644
--- a/sdks/python/apache_beam/options/pipeline_options_validator_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_validator_test.py
@@ -19,11 +19,8 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import unittest
-from builtins import object
from hamcrest import assert_that
from hamcrest import contains_string
diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py
index a300df4..5a5d363 100644
--- a/sdks/python/apache_beam/options/value_provider.py
+++ b/sdks/python/apache_beam/options/value_provider.py
@@ -24,9 +24,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
-from builtins import object
from functools import wraps
from typing import Set
diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py
index 189501b..21e05b3 100644
--- a/sdks/python/apache_beam/options/value_provider_test.py
+++ b/sdks/python/apache_beam/options/value_provider_test.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index e296a30..35d38a7 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -299,6 +299,9 @@
original_transform_node.full_label,
original_transform_node.inputs)
+ replacement_transform_node.resource_hints = (
+ original_transform_node.resource_hints)
+
# Transform execution could depend on order in which nodes are
# considered. Hence we insert the replacement transform node to same
# index as the original transform node. Note that this operation
diff --git a/sdks/python/apache_beam/portability/__init__.py b/sdks/python/apache_beam/portability/__init__.py
index 9fbf215..0bce5d6 100644
--- a/sdks/python/apache_beam/portability/__init__.py
+++ b/sdks/python/apache_beam/portability/__init__.py
@@ -16,4 +16,3 @@
#
"""For internal use only; no backwards-compatibility guarantees."""
-from __future__ import absolute_import
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 18cc249..4b8838f 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
from apache_beam.portability.api.beam_runner_api_pb2_urns import BeamConstants
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardArtifacts
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardCoders
diff --git a/sdks/python/apache_beam/portability/utils.py b/sdks/python/apache_beam/portability/utils.py
index 176c6db..4d23ff1 100644
--- a/sdks/python/apache_beam/portability/utils.py
+++ b/sdks/python/apache_beam/portability/utils.py
@@ -16,8 +16,6 @@
#
"""For internal use only; no backwards-compatibility guarantees."""
-from __future__ import absolute_import
-
from typing import TYPE_CHECKING
from typing import NamedTuple
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py
index b51489a..0ee0e0f 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_exercise_metrics_pipeline_test.py
@@ -22,6 +22,7 @@
import argparse
import unittest
+import pytest
from nose.plugins.attrib import attr
import apache_beam as beam
@@ -50,7 +51,8 @@
dataflow_exercise_metrics_pipeline.metric_matchers())
self.assertFalse(errors, str(errors))
- @attr('IT', 'ValidatesContainer')
+ @attr('IT')
+ @pytest.mark.it_validatescontainer
def test_metrics_fnapi_it(self):
result = self.run_pipeline(experiment='beam_fn_api')
errors = metric_result_matchers.verify_all(
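This test now carries both the nose @attr('IT') annotation and a registered pytest marker, keeping it selectable from either harness while the suite migrates from nose to pytest. A hedged sketch of the pattern, assuming nose and pytest are both installed as in this test suite:

import pytest
from nose.plugins.attrib import attr


@attr('IT')                          # selected by nose:   --attr IT
@pytest.mark.it_validatescontainer   # selected by pytest: -m it_validatescontainer
def test_example_it():
    pass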
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6d816a..e1f6635 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -123,6 +123,11 @@
try:
_load_main_session(semi_persistent_directory)
+ except CorruptMainSessionException:
+ exception_details = traceback.format_exc()
+ _LOGGER.error(
+ 'Could not load main session: %s', exception_details, exc_info=True)
+ raise
except Exception: # pylint: disable=broad-except
exception_details = traceback.format_exc()
_LOGGER.error(
@@ -245,12 +250,29 @@
return 0
+class CorruptMainSessionException(Exception):
+ """
+ Used to crash this worker if a main session file was provided but
+ is not valid.
+ """
+ pass
+
+
def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
if semi_persistent_directory:
session_file = os.path.join(
semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
if os.path.isfile(session_file):
+ # If the expected session file is present but empty, the user code run by
+ # this worker will almost certainly crash at runtime.
+ # This can happen if the worker fails to download the main session.
+ # Raise a fatal error and crash this worker, forcing a restart.
+ if os.path.getsize(session_file) == 0:
+ raise CorruptMainSessionException(
+ 'Session file found, but empty: %s. Functions defined in __main__ '
+ '(interactive session) will almost certainly fail.' %
+ (session_file, ))
pickler.load_session(session_file)
else:
_LOGGER.warning(
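The guard above turns an empty staged main session file into a fatal error instead of a warning, so the worker crashes and restarts rather than failing later inside user __main__ code. A standalone sketch of the check, using a hypothetical helper name (the real logic lives in _load_main_session):

import os


class CorruptMainSessionException(Exception):
    """Raised when a staged main session file exists but is empty."""


def check_main_session(session_file):
    # An empty file usually means the main session download failed; raising
    # here forces a worker restart instead of obscure failures in user code.
    if os.path.isfile(session_file) and os.path.getsize(session_file) == 0:
        raise CorruptMainSessionException(
            'Session file found, but empty: %s' % session_file)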
diff --git a/sdks/python/apache_beam/testing/test_pipeline.py b/sdks/python/apache_beam/testing/test_pipeline.py
index a8ccf6a..2ba273e 100644
--- a/sdks/python/apache_beam/testing/test_pipeline.py
+++ b/sdks/python/apache_beam/testing/test_pipeline.py
@@ -59,6 +59,10 @@
pcoll = ...
assert_that(pcoll, equal_to(...))
"""
+ # Command line options read in by pytest.
+ # If this is not None, it is used as the default value for --test-pipeline-options.
+ pytest_test_pipeline_options = None
+
def __init__(
self,
runner=None,
@@ -142,8 +146,9 @@
default=False,
help='whether not to use test-runner-api')
known, unused_argv = parser.parse_known_args(argv)
-
- if self.is_integration_test and not known.test_pipeline_options:
+ test_pipeline_options = known.test_pipeline_options or \
+ TestPipeline.pytest_test_pipeline_options
+ if self.is_integration_test and not test_pipeline_options:
# Skip the integration test when '--test-pipeline-options' is not
# specified, since nose runs integration tests when unit tests are run
# via 'setup.py test'.
@@ -152,8 +157,8 @@
'is not specified')
self.not_use_test_runner_api = known.not_use_test_runner_api
- return shlex.split(known.test_pipeline_options) \
- if known.test_pipeline_options else []
+ return shlex.split(test_pipeline_options) \
+ if test_pipeline_options else []
def get_full_options_as_args(self, **extra_opts):
"""Get full pipeline options as an argument list.
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 2b738bf..d77727e 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -107,7 +107,7 @@
return getattr(user_type, '__origin__', None) is expected_origin
-def _match_is_named_tuple(user_type):
+def match_is_named_tuple(user_type):
return (
_safe_issubclass(user_type, typing.Tuple) and
hasattr(user_type, '_field_types'))
@@ -234,7 +234,7 @@
# We just convert it to Any for now.
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
- match=_match_is_named_tuple, arity=0, beam_type=typehints.Any),
+ match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Tuple),
arity=-1,
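Renaming _match_is_named_tuple to match_is_named_tuple makes the matcher importable outside this module (see schemas.py below). A quick illustration of what it matches; note the underlying _field_types check is Python-version sensitive, so results may differ on newer interpreters:

import typing

from apache_beam.typehints.native_type_compatibility import match_is_named_tuple


class Point(typing.NamedTuple):
    x: int
    y: int


print(match_is_named_tuple(Point))                   # expected: True
print(match_is_named_tuple(typing.Tuple[int, int]))  # expected: False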
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 6a299bd..5daf68a 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -70,10 +70,10 @@
from apache_beam.typehints import row_type
from apache_beam.typehints.native_type_compatibility import _get_args
from apache_beam.typehints.native_type_compatibility import _match_is_exactly_mapping
-from apache_beam.typehints.native_type_compatibility import _match_is_named_tuple
from apache_beam.typehints.native_type_compatibility import _match_is_optional
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.utils import proto_utils
from apache_beam.utils.timestamp import Timestamp
@@ -148,7 +148,7 @@
def typing_to_runner_api(type_):
- if _match_is_named_tuple(type_):
+ if match_is_named_tuple(type_):
schema = None
if hasattr(type_, _BEAM_SCHEMA_ID):
schema = SCHEMA_REGISTRY.get_schema_by_id(getattr(type_, _BEAM_SCHEMA_ID))
@@ -287,7 +287,7 @@
# TODO(BEAM-10722): Make sure beam.Row generated schemas are registered and
# de-duped
return named_fields_to_schema(element_type._fields)
- elif _match_is_named_tuple(element_type):
+ elif match_is_named_tuple(element_type):
return named_tuple_to_schema(element_type)
else:
raise TypeError(
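With the matcher public, typing_to_runner_api and the element-type inference above both route NamedTuple classes through the same check before converting them to schemas. A hedged usage sketch built on named_tuple_to_schema, the existing entry point shown in this hunk:

import typing

from apache_beam.typehints.schemas import named_tuple_to_schema


class UserRow(typing.NamedTuple):
    id: int
    name: str


# Expected to yield a portable Schema proto with an int64 'id' field and a
# string 'name' field, with the class recognized via match_is_named_tuple.
print(named_tuple_to_schema(UserRow))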
diff --git a/sdks/python/conftest.py b/sdks/python/conftest.py
index b0a35cf..3e2d5ca 100644
--- a/sdks/python/conftest.py
+++ b/sdks/python/conftest.py
@@ -21,6 +21,7 @@
import sys
from apache_beam.options import pipeline_options
+from apache_beam.testing.test_pipeline import TestPipeline
MAX_SUPPORTED_PYTHON_VERSION = (3, 8)
@@ -40,5 +41,11 @@
def pytest_configure(config):
+ """Saves options added in pytest_addoption for later use.
+ This is necessary since pytest-xdist workers do not have the same sys.argv as
+ the main pytest invocation. xdist does, however, seem to pickle TestPipeline.
+ """
+ TestPipeline.pytest_test_pipeline_options = config.getoption(
+ 'test_pipeline_options', default='')
# Enable optional type checks on all tests.
pipeline_options.enable_all_additional_type_checks()
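A note on why this hand-off works under pytest-xdist: conftest hooks such as pytest_configure also run in each worker process, so the value stashed on TestPipeline is visible there even though workers never receive --test-pipeline-options in sys.argv. An illustrative (hypothetical) test that consumes it:

import pytest

from apache_beam.testing.test_pipeline import TestPipeline


@pytest.mark.it_validatescontainer
def test_uses_saved_pipeline_options():
    # Falls back to TestPipeline.pytest_test_pipeline_options when the
    # command-line flag is absent (see the test_pipeline.py hunk above).
    pipeline = TestPipeline(is_integration_test=True)
    print(pipeline.get_full_options_as_args())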
diff --git a/sdks/python/container/run_validatescontainer.sh b/sdks/python/container/run_validatescontainer.sh
index ad5ecb7..4eab5d9 100755
--- a/sdks/python/container/run_validatescontainer.sh
+++ b/sdks/python/container/run_validatescontainer.sh
@@ -71,7 +71,7 @@
echo "Must set Python version with one of 'python36', 'python37' and 'python38' from commandline."
exit 1
fi
-XUNIT_FILE="nosetests-$IMAGE_NAME.xml"
+XUNIT_FILE="pytest-$IMAGE_NAME.xml"
# Verify in the root of the repository
test -d sdks/python/container
@@ -118,14 +118,14 @@
# Run ValidatesRunner tests on Google Cloud Dataflow service
echo ">>> RUNNING DATAFLOW RUNNER VALIDATESCONTAINER TEST"
-python setup.py nosetests \
- --attr ValidatesContainer \
- --nologcapture \
- --processes=1 \
- --process-timeout=900 \
- --with-xunitmp \
- --xunitmp-file=$XUNIT_FILE \
- --ignore-files '.*py3\d?\.py$' \
+pytest -o junit_suite_name=$IMAGE_NAME \
+ -m="it_validatescontainer" \
+ --show-capture=no \
+ --numprocesses=1 \
+ --timeout=900 \
+ --junitxml=$XUNIT_FILE \
+ --ignore-glob '.*py3\d?\.py$' \
+ --log-cli-level=INFO \
--test-pipeline-options=" \
--runner=TestDataflowRunner \
--project=$PROJECT \
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 00d8032..4837bca 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -24,9 +24,10 @@
python_functions =
# Discover tests using filenames.
# See conftest.py for extra collection rules.
-python_files = test_*.py *_test.py *_test_py3*.py
+python_files = test_*.py *_test.py *_test_py3*.py *_test_it.py
markers =
+ it_validatescontainer: collect for ValidatesContainer integration test runs
# Tests using this marker conflict with the xdist plugin in some way, such
# as enabling save_main_session.
no_xdist: run without pytest-xdist plugin
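Registering it_validatescontainer here lets run_validatescontainer.sh select tests with -m it_validatescontainer without unknown-marker warnings. A minimal sketch of a test module picked up by that filter (file and test names are hypothetical; collection patterns come from python_files above):

# test_container_example_it.py
import pytest


@pytest.mark.it_validatescontainer
def test_runs_in_container_suite():
    pass


def test_excluded_from_container_suite():
    # Not marked, so 'pytest -m it_validatescontainer' deselects it.
    pass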
diff --git a/sdks/python/scripts/run_integration_test.sh b/sdks/python/scripts/run_integration_test.sh
index 519ee3d..c5baa89 100755
--- a/sdks/python/scripts/run_integration_test.sh
+++ b/sdks/python/scripts/run_integration_test.sh
@@ -78,6 +78,7 @@
WORKER_JAR=""
KMS_KEY_NAME="projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test"
SUITE=""
+COLLECT_MARKERS=
# Default test (nose) options.
# Run WordCountIT.test_wordcount_it by default if no test options are
@@ -163,6 +164,16 @@
shift # past argument
shift # past value
;;
+ --pytest)
+ PYTEST="$2"
+ shift # past argument
+ shift # past value
+ ;;
+ --collect)
+ COLLECT_MARKERS="-m=$2"
+ shift # past argument
+ shift # past value
+ ;;
*) # unknown option
echo "Unknown option: $1"
exit 1
@@ -270,11 +281,23 @@
# Run tests and validate that jobs finish successfully.
echo ">>> RUNNING integration tests with pipeline options: $PIPELINE_OPTS"
-echo ">>> test options: $TEST_OPTS"
-# TODO(BEAM-3713): Pass $SUITE once migrated to pytest. xunitmp doesn't support
-# suite names.
-python setup.py nosetests \
- --test-pipeline-options="$PIPELINE_OPTS" \
- --with-xunitmp --xunitmp-file=$XUNIT_FILE \
- --ignore-files '.*py3\d?\.py$' \
- $TEST_OPTS
+if [[ "$PYTEST" = true ]]; then
+ echo ">>> pytest options: $TEST_OPTS"
+ echo ">>> collect markers: $COLLECT_MARKERS"
+ ARGS="-o junit_suite_name=$SUITE --junitxml=pytest_$SUITE.xml $TEST_OPTS"
+ # Handle markers as an argument independent of $TEST_OPTS to prevent errors with space-separated flags
+ if [ -z "$COLLECT_MARKERS" ]; then
+ pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS"
+ else
+ pytest $ARGS --test-pipeline-options="$PIPELINE_OPTS" "$COLLECT_MARKERS"
+ fi
+else
+ echo ">>> test options: $TEST_OPTS"
+ # TODO(BEAM-3713): Pass $SUITE once migrated to pytest. xunitmp doesn't
+ # support suite names.
+ python setup.py nosetests \
+ --test-pipeline-options="$PIPELINE_OPTS" \
+ --with-xunitmp --xunitmp-file=$XUNIT_FILE \
+ --ignore-files '.*py3\d?\.py$' \
+ $TEST_OPTS
+fi
\ No newline at end of file