Merge pull request #14769: [BEAM-12325] Remove use of beam_fn_api experiment from Dataflow tests
diff --git a/.test-infra/jenkins/CommonJobProperties.groovy b/.test-infra/jenkins/CommonJobProperties.groovy
index 851fc0b..295cc89 100644
--- a/.test-infra/jenkins/CommonJobProperties.groovy
+++ b/.test-infra/jenkins/CommonJobProperties.groovy
@@ -183,6 +183,11 @@
context.switches("-Dorg.gradle.jvmargs=-Xms2g")
context.switches("-Dorg.gradle.jvmargs=-Xmx4g")
+ // Disable file system watching for CI builds
+ // Builds are performed on a clean clone and files aren't modified, so
+ // there's no value in watching for changes.
+ context.switches("-Dorg.gradle.vfs.watch=false")
+
// Include dependency licenses when build docker images on Jenkins, see https://s.apache.org/zt68q
context.switches("-Pdocker-pull-licenses")
}
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
index 0e3e628..1ac4562 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
@@ -27,7 +27,7 @@
description('Runs the ValidatesRunner suite on the Dataflow runner.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
previousNames(/beam_PostCommit_Java_ValidatesRunner_Dataflow_Gradle/)
// Publish all test results to Jenkins
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
index af9e25e..6ba9685 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
@@ -28,7 +28,7 @@
def JAVA_11_HOME = '/usr/lib/jvm/java-11-openjdk-amd64'
def JAVA_8_HOME = '/usr/lib/jvm/java-8-openjdk-amd64'
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
publishers {
archiveJunit('**/build/test-results/**/*.xml')
}
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
index 0de085d..3b3d8ed 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
@@ -27,7 +27,7 @@
description('Runs Java ValidatesRunner suite on the Dataflow runner V2 forcing streaming mode.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 330)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 450)
// Publish all test results to Jenkins
publishers {
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
index e847928..5631649 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
index c219de6..e164514 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner v2.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
index 40e7383..5cba1d4 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
@@ -42,7 +42,7 @@
static def alpn_api_version = "1.1.2.v20150522"
static def npn_api_version = "1.1.1.v20141010"
static def jboss_marshalling_version = "1.4.11.Final"
- static def jboss_modules_version = "1.11.0.Final"
+ static def jboss_modules_version = "1.1.0.Beta1"
/** Returns the list of compile time dependencies. */
static List<String> dependencies() {
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
index a011a78..0bd33bd 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
@@ -51,7 +51,6 @@
// Registry retains mappings from go types to Schemas and LogicalTypes.
type Registry struct {
- lastShortID int64
typeToSchema map[reflect.Type]*pipepb.Schema
idToType map[string]reflect.Type
syntheticToUser map[reflect.Type]reflect.Type
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
index abb6f10..087d8c1 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
@@ -26,7 +26,9 @@
package schema
import (
+ "bytes"
"fmt"
+ "hash/fnv"
"reflect"
"strings"
@@ -72,8 +74,19 @@
defaultRegistry.RegisterType(ut)
}
-func getUUID() string {
- return uuid.New().String()
+// getUUID generates a UUID using the string form of the type name.
+func getUUID(ut reflect.Type) string {
+ // String produces non-empty output for pointer and slice types.
+ typename := ut.String()
+ hasher := fnv.New128a()
+ if n, err := hasher.Write([]byte(typename)); err != nil || n != len(typename) {
+ panic(fmt.Sprintf("unable to generate schema uuid for %s, wrote out %d bytes, want %d: err %v", typename, n, len(typename), err))
+ }
+ id, err := uuid.NewRandomFromReader(bytes.NewBuffer(hasher.Sum(nil)))
+ if err != nil {
+ panic(fmt.Sprintf("unable to generate schema uuid for type %s: %v", typename, err))
+ }
+ return id.String()
}
// Registered returns whether the given type has been registered with
@@ -350,7 +363,7 @@
if lID != "" {
schm.Options = append(schm.Options, logicalOption(lID))
}
- schm.Id = getUUID()
+ schm.Id = getUUID(ot)
r.typeToSchema[ot] = schm
r.idToType[schm.GetId()] = ot
return schm, nil
@@ -365,7 +378,7 @@
// Cache the pointer type here with its own id.
pt := reflect.PtrTo(t)
schm = proto.Clone(schm).(*pipepb.Schema)
- schm.Id = getUUID()
+ schm.Id = getUUID(pt)
schm.Options = append(schm.Options, &pipepb.Option{
Name: optGoNillable,
})
@@ -454,7 +467,7 @@
schm := ftype.GetRowType().GetSchema()
schm = proto.Clone(schm).(*pipepb.Schema)
schm.Options = append(schm.Options, logicalOption(lID))
- schm.Id = getUUID()
+ schm.Id = getUUID(t)
r.typeToSchema[t] = schm
r.idToType[schm.GetId()] = t
return schm, nil
@@ -483,7 +496,7 @@
schm := &pipepb.Schema{
Fields: fields,
- Id: getUUID(),
+ Id: getUUID(t),
}
r.idToType[schm.GetId()] = t
r.typeToSchema[t] = schm
diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
index 9ea5099..1ba1ac4 100644
--- a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
+++ b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
@@ -24,53 +24,105 @@
"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
+ "github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
pipepb "github.com/apache/beam/sdks/go/pkg/beam/model/pipeline_v1"
)
// ResolveArtifacts acquires all dependencies for a cross-language transform
func ResolveArtifacts(ctx context.Context, edges []*graph.MultiEdge, p *pipepb.Pipeline) {
- path, err := filepath.Abs("/tmp/artifacts")
+ _, err := ResolveArtifactsWithConfig(ctx, edges, ResolveConfig{})
if err != nil {
panic(err)
}
+}
+
+// ResolveConfig contains fields for configuring the behavior for resolving
+// artifacts.
+type ResolveConfig struct {
+ // SdkPath replaces the default filepath for dependencies, but only in the
+ // external environment proto to be used by the SDK Harness during pipeline
+ // execution. This is used to specify alternate staging directories, such
+ // as for staging artifacts remotely.
+ //
+ // Setting an SdkPath does not change staging behavior otherwise. All
+ // artifacts still get staged to the default local filepath, and it is the
+ // user's responsibility to stage those local artifacts to the SdkPath.
+ SdkPath string
+
+ // JoinFn is a function for combining SdkPath and individual artifact names.
+ // If not specified, it defaults to using filepath.Join.
+ JoinFn func(path, name string) string
+}
+
+func defaultJoinFn(path, name string) string {
+ return filepath.Join(path, "/", name)
+}
+
+// ResolveArtifactsWithConfig acquires all dependencies for cross-language
+// transforms, but with some additional configuration to behavior. By default,
+// this function performs the following steps for each cross-language transform
+// in the list of edges:
+// 1. Retrieves a list of dependencies needed from the expansion service.
+// 2. Retrieves each dependency as an artifact and stages it to a default
+// local filepath.
+// 3. Adds the dependencies to the transform's stored environment proto.
+// The changes that can be configured are documented in ResolveConfig.
+//
+// This returns a map of "local path" to "sdk path". By default these are
+// identical, unless ResolveConfig.SdkPath has been set.
+func ResolveArtifactsWithConfig(ctx context.Context, edges []*graph.MultiEdge, cfg ResolveConfig) (paths map[string]string, err error) {
+ tmpPath, err := filepath.Abs("/tmp/artifacts")
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving remote artifacts")
+ }
+ if cfg.JoinFn == nil {
+ cfg.JoinFn = defaultJoinFn
+ }
+ paths = make(map[string]string)
for _, e := range edges {
if e.Op == graph.External {
components, err := graphx.ExpandedComponents(e.External.Expanded)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for edge %v", e.Name())
}
envs := components.Environments
for eid, env := range envs {
-
if strings.HasPrefix(eid, "go") {
continue
}
deps := env.GetDependencies()
- resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", path)
+ resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", tmpPath)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for env %v in edge %v", eid, e.Name())
}
var resolvedDeps []*pipepb.ArtifactInformation
for _, a := range resolvedArtifacts {
- name, sha256 := artifact.MustExtractFilePayload(a)
- fullPath := filepath.Join(path, "/", name)
+ name, _ := artifact.MustExtractFilePayload(a)
+ fullTmpPath := filepath.Join(tmpPath, "/", name)
+ fullSdkPath := fullTmpPath
+ if len(cfg.SdkPath) > 0 {
+ fullSdkPath = cfg.JoinFn(cfg.SdkPath, name)
+ }
resolvedDeps = append(resolvedDeps,
&pipepb.ArtifactInformation{
TypeUrn: "beam:artifact:type:file:v1",
TypePayload: protox.MustEncode(
&pipepb.ArtifactFilePayload{
- Path: fullPath,
- Sha256: sha256,
+ Path: fullSdkPath,
},
),
RoleUrn: a.RoleUrn,
RolePayload: a.RolePayload,
},
)
+ paths[fullTmpPath] = fullSdkPath
}
env.Dependencies = resolvedDeps
}
}
}
+ return paths, nil
}
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index c46ff4c..cd7be52 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -178,11 +178,23 @@
}
// (1) Build and submit
+ // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
+ id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
+
+ modelURL := gcsx.Join(*stagingLocation, id, "model")
+ workerURL := gcsx.Join(*stagingLocation, id, "worker")
+ jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
+ xlangURL := gcsx.Join(*stagingLocation, id, "xlang")
edges, _, err := p.Build()
if err != nil {
return nil, err
}
+ artifactURLs, err := dataflowlib.ResolveXLangArtifacts(ctx, edges, opts.Project, xlangURL)
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving cross-language artifacts")
+ }
+ opts.ArtifactURLs = artifactURLs
environment, err := graphx.CreateEnvironment(ctx, jobopts.GetEnvironmentUrn(ctx), getContainerImage)
if err != nil {
return nil, errors.WithContext(err, "creating environment for model pipeline")
@@ -196,13 +208,6 @@
return nil, errors.WithContext(err, "applying container image overrides")
}
- // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
- id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
-
- modelURL := gcsx.Join(*stagingLocation, id, "model")
- workerURL := gcsx.Join(*stagingLocation, id, "worker")
- jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
-
if *dryRun {
log.Info(ctx, "Dry-run: not submitting job!")
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 511b962..390082b 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -56,6 +56,7 @@
WorkerRegion string
WorkerZone string
ContainerImage string
+ ArtifactURLs []string // Additional packages for workers.
// Autoscaling settings
Algorithm string
@@ -128,6 +129,15 @@
experiments = append(experiments, "use_staged_dataflow_worker_jar")
}
+ for _, url := range opts.ArtifactURLs {
+ name := url[strings.LastIndexAny(url, "/")+1:]
+ pkg := &df.Package{
+ Name: name,
+ Location: url,
+ }
+ packages = append(packages, pkg)
+ }
+
ipConfiguration := "WORKER_IP_UNSPECIFIED"
if opts.NoUsePublicIPs {
ipConfiguration = "WORKER_IP_PRIVATE"
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
index 67a2bde..49ca5bf 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
@@ -22,6 +22,8 @@
"os"
"cloud.google.com/go/storage"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/graph"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
)
@@ -54,3 +56,28 @@
_, err = gcsx.Upload(ctx, client, project, bucket, obj, r)
return err
}
+
+// ResolveXLangArtifacts resolves cross-language artifacts with a given GCS
+// URL as a destination, and then stages all local artifacts to that URL. This
+// function returns a list of staged artifact URLs.
+func ResolveXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge, project, url string) ([]string, error) {
+ cfg := xlangx.ResolveConfig{
+ SdkPath: url,
+ JoinFn: func(url, name string) string {
+ return gcsx.Join(url, "/", name)
+ },
+ }
+ paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
+ if err != nil {
+ return nil, err
+ }
+ var urls []string
+ for local, remote := range paths {
+ err := StageFile(ctx, project, remote, local)
+ if err != nil {
+ return nil, errors.WithContextf(err, "staging file to %v", remote)
+ }
+ urls = append(urls, remote)
+ }
+ return urls, nil
+}
diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go
index f01a71f..af14b38 100644
--- a/sdks/go/pkg/beam/runners/universal/universal.go
+++ b/sdks/go/pkg/beam/runners/universal/universal.go
@@ -79,6 +79,9 @@
getEnvCfg = srv.EnvironmentConfig
}
+ // Fetch all dependencies for cross-language transforms
+ xlangx.ResolveArtifacts(ctx, edges, nil)
+
environment, err := graphx.CreateEnvironment(ctx, envUrn, getEnvCfg)
if err != nil {
return nil, errors.WithContextf(err, "generating model pipeline")
@@ -88,9 +91,6 @@
return nil, errors.WithContextf(err, "generating model pipeline")
}
- // Fetch all dependencies for cross-language transforms
- xlangx.ResolveArtifacts(ctx, edges, pipeline)
-
log.Info(ctx, proto.MarshalTextString(pipeline))
opt := &runnerlib.JobOptions{
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index c756e05..7db20ac 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -78,12 +78,6 @@
var dataflowFilters = []string{
// TODO(BEAM-11576): TestFlattenDup failing on this runner.
"TestFlattenDup",
- // TODO(BEAM-11418): These tests require implementing x-lang artifact
- // staging on Dataflow.
- "TestXLang_CoGroupBy",
- "TestXLang_Multi",
- "TestXLang_Partition",
- "TestXLang_Prefix",
}
// CheckFilters checks if an integration test is filtered to be skipped, either
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
index 3b782d9..c489d12 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -73,10 +73,9 @@
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
- Iterator<AccumT> it = accumulators.iterator();
- AccumT first = it.next();
- it.remove();
- return getAggregateFn().mergeAccumulators(first, accumulators);
+ AccumT first = accumulators.iterator().next();
+ Iterable<AccumT> rest = new SkipFirstElementIterable<>(accumulators);
+ return getAggregateFn().mergeAccumulators(first, rest);
}
@Override
@@ -99,4 +98,20 @@
public TypeVariable<?> getAccumTVariable() {
return AggregateFn.class.getTypeParameters()[1];
}
+
+ /** Wrapper {@link Iterable} which always skips its first element. */
+ private static class SkipFirstElementIterable<T> implements Iterable<T> {
+ private final Iterable<T> all;
+
+ SkipFirstElementIterable(Iterable<T> all) {
+ this.all = all;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ Iterator<T> it = all.iterator();
+ it.next();
+ return it;
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 34664ac..7b580ef 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -27,6 +27,7 @@
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
+import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.BiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableBiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -39,9 +40,6 @@
import org.joda.time.base.AbstractInstant;
/** Utility methods for Calcite related operations. */
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
public class CalciteUtils {
private static final long UNLIMITED_ARRAY_SIZE = -1L;
@@ -73,7 +71,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(SqlTypes.DATE.getIdentifier())
|| logicalId.equals(SqlTypes.TIME.getIdentifier())
|| logicalId.equals(TimeWithLocalTzType.IDENTIFIER)
@@ -88,7 +88,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(CharType.IDENTIFIER);
}
return false;
@@ -210,7 +212,12 @@
+ "so it cannot be converted to a %s",
sqlTypeName, Schema.FieldType.class.getSimpleName()));
default:
- return CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ FieldType fieldType = CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ if (fieldType == null) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + sqlTypeName);
+ }
+ return fieldType;
}
}
@@ -234,7 +241,12 @@
return FieldType.row(toSchema(calciteType));
default:
- return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ try {
+ return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + calciteType, e);
+ }
}
}
@@ -254,16 +266,22 @@
switch (fieldType.getTypeName()) {
case ARRAY:
case ITERABLE:
+ FieldType collectionElementType = fieldType.getCollectionElementType();
+ Preconditions.checkArgumentNotNull(collectionElementType);
return dataTypeFactory.createArrayType(
- toRelDataType(dataTypeFactory, fieldType.getCollectionElementType()),
- UNLIMITED_ARRAY_SIZE);
+ toRelDataType(dataTypeFactory, collectionElementType), UNLIMITED_ARRAY_SIZE);
case MAP:
- RelDataType componentKeyType = toRelDataType(dataTypeFactory, fieldType.getMapKeyType());
- RelDataType componentValueType =
- toRelDataType(dataTypeFactory, fieldType.getMapValueType());
+ FieldType mapKeyType = fieldType.getMapKeyType();
+ FieldType mapValueType = fieldType.getMapValueType();
+ Preconditions.checkArgumentNotNull(mapKeyType);
+ Preconditions.checkArgumentNotNull(mapValueType);
+ RelDataType componentKeyType = toRelDataType(dataTypeFactory, mapKeyType);
+ RelDataType componentValueType = toRelDataType(dataTypeFactory, mapValueType);
return dataTypeFactory.createMapType(componentKeyType, componentValueType);
case ROW:
- return toCalciteRowType(fieldType.getRowSchema(), dataTypeFactory);
+ Schema schema = fieldType.getRowSchema();
+ Preconditions.checkArgumentNotNull(schema);
+ return toCalciteRowType(schema, dataTypeFactory);
default:
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
index 21ab8d0..cf3f40d 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
@@ -19,12 +19,14 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -41,6 +43,13 @@
assertThat(coder, instanceOf(VarLongCoder.class));
}
+ @Test
+ public void mergeAccumulators() {
+ LazyAggregateCombineFn<Long, Long, Long> combiner = new LazyAggregateCombineFn<>(new Sum());
+ long merged = combiner.mergeAccumulators(ImmutableList.of(1L, 1L));
+ assertEquals(2L, merged);
+ }
+
public static class Sum implements AggregateFn<Long, Long, Long> {
@Override
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 50b6ab2..e76ee7f 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -30,13 +30,17 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
/** Tests for conversion from Beam schema to Calcite data type. */
public class CalciteUtilsTest {
RelDataTypeFactory dataTypeFactory;
+ @Rule public ExpectedException thrown = ExpectedException.none();
+
@Before
public void setUp() {
dataTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
@@ -166,4 +170,12 @@
assertEquals(schema, out);
}
+
+ @Test
+ public void testFieldTypeNotFound() {
+ RelDataType relDataType = dataTypeFactory.createUnknownType();
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN");
+ CalciteUtils.toFieldType(relDataType);
+ }
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
index 6413fd2..82c4b74 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
@@ -729,6 +729,17 @@
Map<String, String> argsMap =
((Map<String, Object>) MAPPER.convertValue(pipeline.getOptions(), Map.class).get("options"))
.entrySet().stream()
+ .filter(
+ (entry) -> {
+ if (entry.getValue() instanceof List) {
+ if (!((List) entry.getValue()).isEmpty()) {
+ throw new IllegalArgumentException("Cannot encode list arguments");
+ }
+ // We can encode empty lists, just omit them.
+ return false;
+ }
+ return true;
+ })
.collect(Collectors.toMap(Map.Entry::getKey, entry -> toArg(entry.getValue())));
InMemoryMetaStore inMemoryMetaStore = new InMemoryMetaStore();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 231d2b8..0ca4581 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -272,12 +272,14 @@
finalizeBundleHandler,
metricsShortIds);
+ BeamFnStatusClient beamFnStatusClient = null;
if (statusApiServiceDescriptor != null) {
- new BeamFnStatusClient(
- statusApiServiceDescriptor,
- channelFactory::forDescriptor,
- processBundleHandler.getBundleProcessorCache(),
- options);
+ beamFnStatusClient =
+ new BeamFnStatusClient(
+ statusApiServiceDescriptor,
+ channelFactory::forDescriptor,
+ processBundleHandler.getBundleProcessorCache(),
+ options);
}
// TODO(BEAM-9729): Remove once runners no longer send this instruction.
@@ -337,6 +339,9 @@
executorService,
handlers);
control.waitForTermination();
+ if (beamFnStatusClient != null) {
+ beamFnStatusClient.close();
+ }
processBundleHandler.shutdown();
} finally {
System.out.println("Shutting SDK harness down.");
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 9af7f58..296767b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -48,8 +48,6 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
/**
* The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -216,7 +214,7 @@
private final String pTransformId;
private final SimpleExecutionState state;
private final Counter unboundedElementCountCounter;
- private final SampleByteSizeDistribution<T> unboundSampledByteSizeDistribution;
+ private final SampleByteSizeDistribution<T> unboundedSampledByteSizeDistribution;
private final Coder<T> coder;
private final MetricsContainer metricsContainer;
@@ -239,7 +237,7 @@
MonitoringInfoMetricName sampledByteSizeMetricName =
MonitoringInfoMetricName.named(Urns.SAMPLED_BYTE_SIZE, labels);
- this.unboundSampledByteSizeDistribution =
+ this.unboundedSampledByteSizeDistribution =
new SampleByteSizeDistribution<>(
unboundMetricContainer.getDistribution(sampledByteSizeMetricName));
@@ -252,7 +250,7 @@
// Increment the counter for each window the element occurs in.
this.unboundedElementCountCounter.inc(input.getWindows().size());
// TODO(BEAM-11879): Consider updating size per window when we have window optimization.
- this.unboundSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
+ this.unboundedSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
// Wrap the consumer with extra logic to set the metric container with the appropriate
// PTransform context. This ensures that user metrics obtain the pTransform ID when they are
// created. Also use the ExecutionStateTracker and enter an appropriate state to track the
@@ -262,6 +260,7 @@
this.delegate.accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
@@ -321,6 +320,7 @@
consumerAndMetadata.getConsumer().accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
}
@@ -365,28 +365,33 @@
}
final Distribution distribution;
+ ByteSizeObserver byteCountObserver;
public SampleByteSizeDistribution(Distribution distribution) {
this.distribution = distribution;
+ this.byteCountObserver = null;
}
public void tryUpdate(T value, Coder<T> coder) throws Exception {
if (shouldSampleElement()) {
// First try using byte size observer
- ByteSizeObserver observer = new ByteSizeObserver();
- coder.registerByteSizeObserver(value, observer);
+ byteCountObserver = new ByteSizeObserver();
+ coder.registerByteSizeObserver(value, byteCountObserver);
- if (!observer.getIsLazy()) {
- observer.advance();
- this.distribution.update(observer.observedSize);
- } else {
- // TODO(BEAM-11841): Optimize calculation of element size for iterables.
- // Coder byte size observation is lazy (requires iteration for observation) so fall back
- // to counting output stream
- CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
- coder.encode(value, os);
- this.distribution.update(os.getCount());
+ if (!byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
+ } else {
+ byteCountObserver = null;
+ }
+ }
+
+ public void finishLazyUpdate() {
+ // Advance lazy ElementByteSizeObservers, if any.
+ if (byteCountObserver != null && byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
index 4c01c04..e059471 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
@@ -25,6 +25,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -42,9 +44,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class BeamFnStatusClient {
+public class BeamFnStatusClient implements AutoCloseable {
+ private static final Object COMPLETED = new Object();
private final StreamObserver<WorkerStatusResponse> outboundObserver;
private final BundleProcessorCache processBundleCache;
+ private final ManagedChannel channel;
+ private final CompletableFuture<Object> inboundObserverCompletion;
private static final Logger LOG = LoggerFactory.getLogger(BeamFnStatusClient.class);
private final MemoryMonitor memoryMonitor;
@@ -53,11 +58,12 @@
Function<ApiServiceDescriptor, ManagedChannel> channelFactory,
BundleProcessorCache processBundleCache,
PipelineOptions options) {
- BeamFnWorkerStatusGrpc.BeamFnWorkerStatusStub stub =
- BeamFnWorkerStatusGrpc.newStub(channelFactory.apply(apiServiceDescriptor));
- this.outboundObserver = stub.workerStatus(new InboundObserver());
+ this.channel = channelFactory.apply(apiServiceDescriptor);
+ this.outboundObserver =
+ BeamFnWorkerStatusGrpc.newStub(channel).workerStatus(new InboundObserver());
this.processBundleCache = processBundleCache;
this.memoryMonitor = MemoryMonitor.fromOptions(options);
+ this.inboundObserverCompletion = new CompletableFuture<>();
Thread thread = new Thread(memoryMonitor);
thread.setDaemon(true);
thread.setPriority(Thread.MIN_PRIORITY);
@@ -65,6 +71,22 @@
thread.start();
}
+ @Override
+ public void close() throws Exception {
+ try {
+ Object completion = inboundObserverCompletion.get(1, TimeUnit.MINUTES);
+ if (completion != COMPLETED) {
+ LOG.warn("InboundObserver for BeamFnStatusClient completed with exception.");
+ }
+ } finally {
+ // Shut the channel down
+ channel.shutdown();
+ if (!channel.awaitTermination(10, TimeUnit.SECONDS)) {
+ channel.shutdownNow();
+ }
+ }
+ }
+
/**
* Class representing the execution state of a thread.
*
@@ -222,9 +244,12 @@
@Override
public void onError(Throwable t) {
LOG.error("Error getting SDK harness status", t);
+ inboundObserverCompletion.completeExceptionally(t);
}
@Override
- public void onCompleted() {}
+ public void onCompleted() {
+ inboundObserverCompletion.complete(COMPLETED);
+ }
}
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 90baa5e..708ca8b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -24,6 +24,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@@ -32,7 +33,9 @@
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import org.apache.beam.fn.harness.HandlesSplits;
import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
@@ -44,15 +47,19 @@
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
+import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@@ -329,5 +336,108 @@
verify(consumerA1).trySplit(0.3);
}
+ @Test
+ public void testLazyByteSizeEstimation() throws Exception {
+ final String pCollectionA = "pCollectionA";
+ final String pTransformIdA = "pTransformIdA";
+
+ MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+ PCollectionConsumerRegistry consumers =
+ new PCollectionConsumerRegistry(
+ metricsContainerRegistry, mock(ExecutionStateTracker.class));
+ FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
+
+ consumers.register(
+ pCollectionA, pTransformIdA, consumerA1, IterableCoder.of(StringUtf8Coder.of()));
+
+ FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
+ (FnDataReceiver<WindowedValue<Iterable<String>>>)
+ (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+ String elementValue = "elem";
+ long elementByteSize = StringUtf8Coder.of().getEncodedElementByteSize(elementValue);
+ WindowedValue<Iterable<String>> element =
+ valueInGlobalWindow(
+ new TestElementByteSizeObservableIterable<>(
+ Arrays.asList(elementValue, elementValue), elementByteSize));
+ int numElements = 10;
+ // Mock doing work on the iterable items
+ doAnswer(
+ (Answer<Void>)
+ invocation -> {
+ Object[] args = invocation.getArguments();
+ WindowedValue<Iterable<String>> arg = (WindowedValue<Iterable<String>>) args[0];
+ Iterator it = arg.getValue().iterator();
+ while (it.hasNext()) {
+ it.next();
+ }
+ return null;
+ })
+ .when(consumerA1)
+ .accept(element);
+
+ for (int i = 0; i < numElements; i++) {
+ wrapperConsumer.accept(element);
+ }
+
+ // Check that the underlying consumers are each invoked per element.
+ verify(consumerA1, times(numElements)).accept(element);
+ assertThat(consumers.keySet(), contains(pCollectionA));
+
+ List<MonitoringInfo> expected = new ArrayList<>();
+
+ SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+ builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+ builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+ builder.setInt64SumValue(numElements);
+ expected.add(builder.build());
+
+ builder = new SimpleMonitoringInfoBuilder();
+ builder.setUrn(Urns.SAMPLED_BYTE_SIZE);
+ builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+ long expectedBytes =
+ (elementByteSize + 1) * 2
+ + 5; // Additional 5 bytes: the 4-byte size prefix plus hasNext = false (1 byte).
+ builder.setInt64DistributionValue(
+ DistributionData.create(
+ numElements * expectedBytes, numElements, expectedBytes, expectedBytes));
+ expected.add(builder.build());
+ // Restrict the comparison to monitoring infos carrying a PCOLLECTION label.
+ Iterable<MonitoringInfo> result =
+ Iterables.filter(
+ metricsContainerRegistry.getMonitoringInfos(),
+ monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+
+ assertThat(result, containsInAnyOrder(expected.toArray()));
+ }
+
+ private class TestElementByteSizeObservableIterable<T>
+ extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> {
+ private List<T> elements;
+ private long elementByteSize;
+
+ public TestElementByteSizeObservableIterable(List<T> elements, long elementByteSize) {
+ this.elements = elements;
+ this.elementByteSize = elementByteSize;
+ }
+
+ @Override
+ protected ElementByteSizeObservableIterator createIterator() {
+ return new ElementByteSizeObservableIterator() {
+ private int index = 0;
+
+ @Override
+ public boolean hasNext() {
+ return index < elements.size();
+ }
+
+ @Override
+ public Object next() {
+ notifyValueReturned(elementByteSize);
+ return elements.get(index++);
+ }
+ };
+ }
+ }
+
private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
index 4193ba6..fc1fe94 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -109,7 +109,10 @@
.getReferencedTables();
if (referencedTables != null && !referencedTables.isEmpty()) {
TableReference referencedTable = referencedTables.get(0);
- effectiveLocation = tableService.getTable(referencedTable).getLocation();
+ effectiveLocation =
+ tableService
+ .getDataset(referencedTable.getProjectId(), referencedTable.getDatasetId())
+ .getLocation();
}
}
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index 84743f8c..820418c 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -43,7 +43,7 @@
configurations.testRuntimeClasspath {
resolutionStrategy {
- def log4j_version = "2.4.1"
+ def log4j_version = "2.8.2"
// Beam's build system forces a uniform log4j version resolution for all modules, however for
// the HCatalog case the current version of log4j produces NoClassDefFoundError so we need to
// force an old version on the tests runtime classpath
diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py
index cc30cec..ee3c2a3 100644
--- a/sdks/python/apache_beam/dataframe/schemas.py
+++ b/sdks/python/apache_beam/dataframe/schemas.py
@@ -281,7 +281,8 @@
ctor = element_type_from_dataframe(proxy, include_indexes=include_indexes)
return beam.ParDo(
- _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor))
+ _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor)
+ ).with_output_types(ctor)
elif isinstance(proxy, pd.Series):
# Raise a TypeError if proxy has an unknown type
output_type = _dtype_to_fieldtype(proxy.dtype)
diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py
index 8b1159c..af30e25 100644
--- a/sdks/python/apache_beam/dataframe/schemas_test.py
+++ b/sdks/python/apache_beam/dataframe/schemas_test.py
@@ -36,6 +36,8 @@
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
Simple = typing.NamedTuple(
'Simple', [('name', unicode), ('id', int), ('height', float)])
@@ -65,41 +67,59 @@
# dtype. For example:
# pd.Series([b'abc'], dtype=bytes).dtype != 'S'
# pd.Series([b'abc'], dtype=bytes).astype(bytes).dtype == 'S'
+# (test data, pandas_type, column_name, beam_type)
COLUMNS = [
- ([375, 24, 0, 10, 16], np.int32, 'i32'),
- ([375, 24, 0, 10, 16], np.int64, 'i64'),
- ([375, 24, None, 10, 16], pd.Int32Dtype(), 'i32_nullable'),
- ([375, 24, None, 10, 16], pd.Int64Dtype(), 'i64_nullable'),
- ([375., 24., None, 10., 16.], np.float64, 'f64'),
- ([375., 24., None, 10., 16.], np.float32, 'f32'),
- ([True, False, True, True, False], bool, 'bool'),
- (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any'),
- ([True, False, True, None, False], pd.BooleanDtype(), 'bool_nullable'),
+ ([375, 24, 0, 10, 16], np.int32, 'i32', np.int32),
+ ([375, 24, 0, 10, 16], np.int64, 'i64', np.int64),
+ ([375, 24, None, 10, 16],
+ pd.Int32Dtype(),
+ 'i32_nullable',
+ typing.Optional[np.int32]),
+ ([375, 24, None, 10, 16],
+ pd.Int64Dtype(),
+ 'i64_nullable',
+ typing.Optional[np.int64]),
+ ([375., 24., None, 10., 16.],
+ np.float64,
+ 'f64',
+ typing.Optional[np.float64]),
+ ([375., 24., None, 10., 16.],
+ np.float32,
+ 'f32',
+ typing.Optional[np.float32]),
+ ([True, False, True, True, False], bool, 'bool', bool),
+ (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any', typing.Any),
+ ([True, False, True, None, False],
+ pd.BooleanDtype(),
+ 'bool_nullable',
+ typing.Optional[bool]),
(['Falcon', 'Ostrich', None, 'Aardvark', 'Elephant'],
pd.StringDtype(),
- 'strdtype'),
-] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str]]
+ 'strdtype',
+ typing.Optional[str]),
+] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str, typing.Any]]
-NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name in COLUMNS])
-for arr, dtype, name in COLUMNS:
+NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name, _ in COLUMNS])
+for arr, dtype, name, _ in COLUMNS:
NICE_TYPES_DF[name] = pd.Series(arr, dtype=dtype, name=name).astype(dtype)
NICE_TYPES_PROXY = NICE_TYPES_DF[:0]
-SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr) for arr,
- dtype,
- name in COLUMNS]
+SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr, beam_type)
+ for (arr, dtype, name, beam_type) in COLUMNS]
_TEST_ARRAYS = [
- arr for arr, _, _ in COLUMNS
+ arr for (arr, _, _, _) in COLUMNS
] # type: typing.List[typing.List[typing.Any]]
DF_RESULT = list(zip(*_TEST_ARRAYS))
-INDEX_DF_TESTS = [
- (NICE_TYPES_DF.set_index([name for _, _, name in COLUMNS[:i]]), DF_RESULT)
- for i in range(1, len(COLUMNS) + 1)
-]
+BEAM_SCHEMA = typing.NamedTuple( # type: ignore
+ 'BEAM_SCHEMA', [(name, beam_type) for _, _, name, beam_type in COLUMNS])
+INDEX_DF_TESTS = [(
+ NICE_TYPES_DF.set_index([name for _, _, name, _ in COLUMNS[:i]]),
+ DF_RESULT,
+ BEAM_SCHEMA) for i in range(1, len(COLUMNS) + 1)]
-NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT)]
+NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT, BEAM_SCHEMA)]
PD_VERSION = tuple(int(n) for n in pd.__version__.split('.'))
@@ -203,8 +223,18 @@
proxy=schemas.generate_proxy(Animal)))
assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)]))
+ def assert_typehints_equal(self, left, right):
+ left = typehints.normalize(left)
+ right = typehints.normalize(right)
+
+ if match_is_named_tuple(left):
+ self.assertTrue(match_is_named_tuple(right))
+ self.assertEqual(left.__annotations__, right.__annotations__)
+ else:
+ self.assertEqual(left, right)
+
@parameterized.expand(SERIES_TESTS + NOINDEX_DF_TESTS)
- def test_unbatch_no_index(self, df_or_series, rows):
+ def test_unbatch_no_index(self, df_or_series, rows, beam_type):
proxy = df_or_series[:0]
with TestPipeline() as p:
@@ -212,10 +242,15 @@
p | beam.Create([df_or_series[::2], df_or_series[1::2]])
| schemas.UnbatchPandas(proxy))
+ # Verify that the unbatched PCollection has the expected typehint
+ # TODO(BEAM-8538): typehints should support NamedTuple so we can use
+ # typehints.is_consistent_with here instead
+ self.assert_typehints_equal(res.element_type, beam_type)
+
assert_that(res, equal_to(rows))
@parameterized.expand(SERIES_TESTS + INDEX_DF_TESTS)
- def test_unbatch_with_index(self, df_or_series, rows):
+ def test_unbatch_with_index(self, df_or_series, rows, _):
proxy = df_or_series[:0]
with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
index 3c4406c..93510c8 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp.py
@@ -19,8 +19,6 @@
functionality.
"""
-from __future__ import absolute_import
-
import logging
from google.cloud import dlp_v2
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
index fbba610..4ada679 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_it_test.py
@@ -17,8 +17,6 @@
"""Integration tests for Google Cloud Video Intelligence API transforms."""
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
index ef25555..111e5be 100644
--- a/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
+++ b/sdks/python/apache_beam/ml/gcp/cloud_dlp_test.py
@@ -17,8 +17,6 @@
"""Unit tests for Google Cloud Video Intelligence API transforms."""
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
index 5263a60..7817eb9 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py
@@ -15,8 +15,6 @@
# limitations under the License.
#
-from __future__ import absolute_import
-
from typing import Mapping
from typing import Optional
from typing import Sequence
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
index ef72359..e639517 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test.py
@@ -18,8 +18,6 @@
"""Unit tests for Google Cloud Natural Language API transform."""
-from __future__ import absolute_import
-
import unittest
import mock
diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
index 4cf58e4..932bc685 100644
--- a/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml_test_it.py
@@ -16,8 +16,6 @@
#
# pytype: skip-file
-from __future__ import absolute_import
-
import unittest
from nose.plugins.attrib import attr
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
index 67ff496..bc0aa08 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py
@@ -17,15 +17,10 @@
"""A connector for sending API requests to the GCP Video Intelligence API."""
-from __future__ import absolute_import
-
from typing import Optional
from typing import Tuple
from typing import Union
-from future.utils import binary_type
-from future.utils import text_type
-
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn
@@ -55,8 +50,8 @@
ref: https://cloud.google.com/video-intelligence/docs
Sends each element to the GCP Video Intelligence API. Element is a
- Union[text_type, binary_type] of either an URI (e.g. a GCS URI) or
- binary_type base64-encoded video data.
+ Union[str, bytes] of either a URI (e.g. a GCS URI) or
+ bytes base64-encoded video data.
Accepts an `AsDict` side input that maps each video to a video context.
"""
def __init__(
@@ -118,8 +113,7 @@
@typehints.with_input_types(
- Union[text_type, binary_type],
- Optional[videointelligence.types.VideoContext])
+ Union[str, bytes], Optional[videointelligence.types.VideoContext])
class _VideoAnnotateFn(DoFn):
"""A DoFn that sends each input element to the GCP Video Intelligence API
service and outputs an element with the return result of the API
@@ -138,7 +132,7 @@
self._client = get_videointelligence_client()
def _annotate_video(self, element, video_context):
- if isinstance(element, text_type): # Is element an URI to a GCS bucket
+ if isinstance(element, str): # Is element an URI to a GCS bucket
response = self._client.annotate_video(
input_uri=element,
features=self.features,
@@ -171,11 +165,11 @@
Sends each element to the GCP Video Intelligence API.
Element is a tuple of
- (Union[text_type, binary_type],
+ (Union[str, bytes],
Optional[videointelligence.types.VideoContext])
where the former is either an URI (e.g. a GCS URI) or
- binary_type base64-encoded video data
+ bytes base64-encoded video data
"""
def __init__(self, features, location_id=None, metadata=None, timeout=120):
"""
@@ -208,8 +202,7 @@
@typehints.with_input_types(
- Tuple[Union[text_type, binary_type],
- Optional[videointelligence.types.VideoContext]])
+ Tuple[Union[str, bytes], Optional[videointelligence.types.VideoContext]])
class _VideoAnnotateFnWithContext(_VideoAnnotateFn):
"""A DoFn that unpacks each input tuple to element, video_context variables
and sends these to the GCP Video Intelligence API service and outputs
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
index b0bb441..3215ceb 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test.py
@@ -19,9 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
index a2d8216..d934520 100644
--- a/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml_test_it.py
@@ -19,9 +19,6 @@
"""An integration test that labels entities appearing in a video and checks
if some expected entities were properly recognized."""
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import unittest
import hamcrest as hc
diff --git a/sdks/python/apache_beam/ml/gcp/visionml.py b/sdks/python/apache_beam/ml/gcp/visionml.py
index 13884a7..0fb45ce 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml.py
@@ -20,16 +20,11 @@
A connector for sending API requests to the GCP Vision API.
"""
-from __future__ import absolute_import
-
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
-from future.utils import binary_type
-from future.utils import text_type
-
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn
@@ -65,8 +60,8 @@
Batches elements together using ``util.BatchElements`` PTransform and sends
each batch of elements to the GCP Vision API.
- Element is a Union[text_type, binary_type] of either an URI (e.g. a GCS URI)
- or binary_type base64-encoded image data.
+ Element is a Union[str, bytes] of either a URI (e.g. a GCS URI)
+ or bytes base64-encoded image data.
Accepts an `AsDict` side input that maps each image to an image context.
"""
@@ -158,7 +153,7 @@
metadata=self.metadata)))
@typehints.with_input_types(
- Union[text_type, binary_type], Optional[vision.types.ImageContext])
+ Union[str, bytes], Optional[vision.types.ImageContext])
@typehints.with_output_types(List[vision.types.AnnotateImageRequest])
def _create_image_annotation_pairs(self, element, context_side_input):
if context_side_input: # If we have a side input image context, use that
@@ -166,10 +161,10 @@
else:
image_context = None
- if isinstance(element, text_type):
+ if isinstance(element, str):
image = vision.types.Image(
source=vision.types.ImageSource(image_uri=element))
- else: # Typehint checks only allows text_type or binary_type
+ else: # Typehint checks only allow str or bytes
image = vision.types.Image(content=element)
request = vision.types.AnnotateImageRequest(
@@ -185,10 +180,10 @@
Element is a tuple of::
- (Union[text_type, binary_type],
+ (Union[str, bytes],
Optional[``vision.types.ImageContext``])
- where the former is either an URI (e.g. a GCS URI) or binary_type
+ where the former is either an URI (e.g. a GCS URI) or bytes
base64-encoded image data.
"""
def __init__(
@@ -249,14 +244,14 @@
metadata=self.metadata)))
@typehints.with_input_types(
- Tuple[Union[text_type, binary_type], Optional[vision.types.ImageContext]])
+ Tuple[Union[str, bytes], Optional[vision.types.ImageContext]])
@typehints.with_output_types(List[vision.types.AnnotateImageRequest])
def _create_image_annotation_pairs(self, element, **kwargs):
element, image_context = element # Unpack (image, image_context) tuple
- if isinstance(element, text_type):
+ if isinstance(element, str):
image = vision.types.Image(
source=vision.types.ImageSource(image_uri=element))
- else: # Typehint checks only allows text_type or binary_type
+ else: # Typehint checks only allow str or bytes
image = vision.types.Image(content=element)
request = vision.types.AnnotateImageRequest(
diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test.py b/sdks/python/apache_beam/ml/gcp/visionml_test.py
index d4c6c20..f038442 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml_test.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml_test.py
@@ -20,9 +20,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
index a81acd0..14af3cb8 100644
--- a/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
+++ b/sdks/python/apache_beam/ml/gcp/visionml_test_it.py
@@ -16,8 +16,6 @@
#
# pytype: skip-file
-from __future__ import absolute_import
-
import unittest
from nose.plugins.attrib import attr
diff --git a/sdks/python/apache_beam/options/__init__.py b/sdks/python/apache_beam/options/__init__.py
index f4f43cb..cce3aca 100644
--- a/sdks/python/apache_beam/options/__init__.py
+++ b/sdks/python/apache_beam/options/__init__.py
@@ -14,4 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from __future__ import absolute_import
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 449c7d3..073014c 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -19,13 +19,9 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import argparse
import json
import logging
-from builtins import list
-from builtins import object
from typing import Any
from typing import Callable
from typing import Dict
diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py
index 08c4145..8299a54 100644
--- a/sdks/python/apache_beam/options/pipeline_options_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_test.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import json
import logging
import unittest
diff --git a/sdks/python/apache_beam/options/pipeline_options_validator.py b/sdks/python/apache_beam/options/pipeline_options_validator.py
index 97df45f..2196ca6 100644
--- a/sdks/python/apache_beam/options/pipeline_options_validator.py
+++ b/sdks/python/apache_beam/options/pipeline_options_validator.py
@@ -21,13 +21,8 @@
"""
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import re
-from builtins import object
-
-from past.builtins import unicode
from apache_beam.internal import pickler
from apache_beam.options.pipeline_options import DebugOptions
@@ -220,8 +215,7 @@
'Transform name mapping option is only useful when '
'--update and --streaming is specified')
for _, (key, value) in enumerate(view.transform_name_mapping.items()):
- if not isinstance(key, (str, unicode)) \
- or not isinstance(value, (str, unicode)):
+ if not isinstance(key, str) or not isinstance(value, str):
errors.extend(
self._validate_error(
self.ERR_INVALID_TRANSFORM_NAME_MAPPING, key, value))
diff --git a/sdks/python/apache_beam/options/pipeline_options_validator_test.py b/sdks/python/apache_beam/options/pipeline_options_validator_test.py
index 0eea8d7..54b3a69 100644
--- a/sdks/python/apache_beam/options/pipeline_options_validator_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_validator_test.py
@@ -19,11 +19,8 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import unittest
-from builtins import object
from hamcrest import assert_that
from hamcrest import contains_string
diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py
index a300df4..5a5d363 100644
--- a/sdks/python/apache_beam/options/value_provider.py
+++ b/sdks/python/apache_beam/options/value_provider.py
@@ -24,9 +24,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
-from builtins import object
from functools import wraps
from typing import Set
diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py
index 189501b..21e05b3 100644
--- a/sdks/python/apache_beam/options/value_provider_test.py
+++ b/sdks/python/apache_beam/options/value_provider_test.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
import logging
import unittest
diff --git a/sdks/python/apache_beam/portability/__init__.py b/sdks/python/apache_beam/portability/__init__.py
index 9fbf215..0bce5d6 100644
--- a/sdks/python/apache_beam/portability/__init__.py
+++ b/sdks/python/apache_beam/portability/__init__.py
@@ -16,4 +16,3 @@
#
"""For internal use only; no backwards-compatibility guarantees."""
-from __future__ import absolute_import
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 18cc249..4b8838f 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -19,8 +19,6 @@
# pytype: skip-file
-from __future__ import absolute_import
-
from apache_beam.portability.api.beam_runner_api_pb2_urns import BeamConstants
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardArtifacts
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardCoders
diff --git a/sdks/python/apache_beam/portability/utils.py b/sdks/python/apache_beam/portability/utils.py
index 176c6db..4d23ff1 100644
--- a/sdks/python/apache_beam/portability/utils.py
+++ b/sdks/python/apache_beam/portability/utils.py
@@ -16,8 +16,6 @@
#
"""For internal use only; no backwards-compatibility guarantees."""
-from __future__ import absolute_import
-
from typing import TYPE_CHECKING
from typing import NamedTuple
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6d816a..e1f6635 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -123,6 +123,11 @@
try:
_load_main_session(semi_persistent_directory)
+ except CorruptMainSessionException:
+ exception_details = traceback.format_exc()
+ _LOGGER.error(
+ 'Could not load main session: %s', exception_details, exc_info=True)
+ raise
except Exception: # pylint: disable=broad-except
exception_details = traceback.format_exc()
_LOGGER.error(
@@ -245,12 +250,29 @@
return 0
+class CorruptMainSessionException(Exception):
+ """
+ Used to crash this worker if a main session file was provided but
+ is not valid.
+ """
+ pass
+
+
def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
if semi_persistent_directory:
session_file = os.path.join(
semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
if os.path.isfile(session_file):
+      # If the expected session file is present but empty, it's likely that
+      # the user code run by this worker will crash at runtime.
+ # This can happen if the worker fails to download the main session.
+ # Raise a fatal error and crash this worker, forcing a restart.
+ if os.path.getsize(session_file) == 0:
+ raise CorruptMainSessionException(
+ 'Session file found, but empty: %s. Functions defined in __main__ '
+ '(interactive session) will almost certainly fail.' %
+ (session_file, ))
pickler.load_session(session_file)
else:
_LOGGER.warning(
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 2b738bf..d77727e 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -107,7 +107,7 @@
return getattr(user_type, '__origin__', None) is expected_origin
-def _match_is_named_tuple(user_type):
+def match_is_named_tuple(user_type):
return (
_safe_issubclass(user_type, typing.Tuple) and
hasattr(user_type, '_field_types'))
@@ -234,7 +234,7 @@
# We just convert it to Any for now.
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
- match=_match_is_named_tuple, arity=0, beam_type=typehints.Any),
+ match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Tuple),
arity=-1,
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 6a299bd..5daf68a 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -70,10 +70,10 @@
from apache_beam.typehints import row_type
from apache_beam.typehints.native_type_compatibility import _get_args
from apache_beam.typehints.native_type_compatibility import _match_is_exactly_mapping
-from apache_beam.typehints.native_type_compatibility import _match_is_named_tuple
from apache_beam.typehints.native_type_compatibility import _match_is_optional
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.utils import proto_utils
from apache_beam.utils.timestamp import Timestamp
@@ -148,7 +148,7 @@
def typing_to_runner_api(type_):
- if _match_is_named_tuple(type_):
+ if match_is_named_tuple(type_):
schema = None
if hasattr(type_, _BEAM_SCHEMA_ID):
schema = SCHEMA_REGISTRY.get_schema_by_id(getattr(type_, _BEAM_SCHEMA_ID))
@@ -287,7 +287,7 @@
# TODO(BEAM-10722): Make sure beam.Row generated schemas are registered and
# de-duped
return named_fields_to_schema(element_type._fields)
- elif _match_is_named_tuple(element_type):
+ elif match_is_named_tuple(element_type):
return named_tuple_to_schema(element_type)
else:
raise TypeError(