Merge pull request #14762 from ibzib/BEAM-12312
[BEAM-12312] Don't rely on remove() in LazyAggregateCombineFn#mergeAccumulators
diff --git a/.test-infra/jenkins/CommonJobProperties.groovy b/.test-infra/jenkins/CommonJobProperties.groovy
index 851fc0b..295cc89 100644
--- a/.test-infra/jenkins/CommonJobProperties.groovy
+++ b/.test-infra/jenkins/CommonJobProperties.groovy
@@ -183,6 +183,11 @@
context.switches("-Dorg.gradle.jvmargs=-Xms2g")
context.switches("-Dorg.gradle.jvmargs=-Xmx4g")
+ // Disable file system watching for CI builds
+ // Builds are performed on a clean clone and files aren't modified, so
+ // there's no value in watching for changes.
+ context.switches("-Dorg.gradle.vfs.watch=false")
+
// Include dependency licenses when build docker images on Jenkins, see https://s.apache.org/zt68q
context.switches("-Pdocker-pull-licenses")
}
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
index 0e3e628..1ac4562 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
@@ -27,7 +27,7 @@
description('Runs the ValidatesRunner suite on the Dataflow runner.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
previousNames(/beam_PostCommit_Java_ValidatesRunner_Dataflow_Gradle/)
// Publish all test results to Jenkins
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
index af9e25e..6ba9685 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
@@ -28,7 +28,7 @@
def JAVA_11_HOME = '/usr/lib/jvm/java-11-openjdk-amd64'
def JAVA_8_HOME = '/usr/lib/jvm/java-8-openjdk-amd64'
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
publishers {
archiveJunit('**/build/test-results/**/*.xml')
}
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
index 0de085d..3b3d8ed 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
@@ -27,7 +27,7 @@
description('Runs Java ValidatesRunner suite on the Dataflow runner V2 forcing streaming mode.')
- commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 330)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 450)
// Publish all test results to Jenkins
publishers {
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
index e847928..5631649 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
index c219de6..e164514 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
@@ -26,7 +26,7 @@
description('Runs Python ValidatesRunner suite on the Dataflow runner v2.')
// Set common parameters.
- commonJobProperties.setTopLevelMainJobProperties(delegate)
+ commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
publishers {
archiveJunit('**/nosetests*.xml')
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
index 40e7383..5cba1d4 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
@@ -42,7 +42,7 @@
static def alpn_api_version = "1.1.2.v20150522"
static def npn_api_version = "1.1.1.v20141010"
static def jboss_marshalling_version = "1.4.11.Final"
- static def jboss_modules_version = "1.11.0.Final"
+ static def jboss_modules_version = "1.1.0.Beta1"
/** Returns the list of compile time dependencies. */
static List<String> dependencies() {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
index aae35ab..e32edf7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
@@ -26,7 +26,6 @@
import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.POutput;
-import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.PValues;
import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
@@ -41,17 +40,16 @@
private ReplacementOutputs() {}
public static Map<PCollection<?>, ReplacementOutput> singleton(
- Map<TupleTag<?>, PCollection<?>> original, PValue replacement) {
+ Map<TupleTag<?>, PCollection<?>> original, POutput replacement) {
Entry<TupleTag<?>, PCollection<?>> originalElement =
Iterables.getOnlyElement(original.entrySet());
- TupleTag<?> replacementTag = Iterables.getOnlyElement(replacement.expand().entrySet()).getKey();
- PCollection<?> replacementCollection =
- (PCollection<?>) Iterables.getOnlyElement(replacement.expand().entrySet()).getValue();
+ Entry<TupleTag<?>, PCollection<?>> replacementElement =
+ Iterables.getOnlyElement(PValues.expandOutput(replacement).entrySet());
return Collections.singletonMap(
- replacementCollection,
+ replacementElement.getValue(),
ReplacementOutput.of(
TaggedPValue.of(originalElement.getKey(), originalElement.getValue()),
- TaggedPValue.of(replacementTag, replacementCollection)));
+ TaggedPValue.of(replacementElement.getKey(), replacementElement.getValue())));
}
public static Map<PCollection<?>, ReplacementOutput> tagged(
diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
index 9ea5099..1ba1ac4 100644
--- a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
+++ b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
@@ -24,53 +24,105 @@
"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
+ "github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
pipepb "github.com/apache/beam/sdks/go/pkg/beam/model/pipeline_v1"
)
// ResolveArtifacts acquires all dependencies for a cross-language transform
func ResolveArtifacts(ctx context.Context, edges []*graph.MultiEdge, p *pipepb.Pipeline) {
- path, err := filepath.Abs("/tmp/artifacts")
+ _, err := ResolveArtifactsWithConfig(ctx, edges, ResolveConfig{})
if err != nil {
panic(err)
}
+}
+
+// ResolveConfig contains fields for configuring the behavior for resolving
+// artifacts.
+type ResolveConfig struct {
+ // SdkPath replaces the default filepath for dependencies, but only in the
+ // external environment proto to be used by the SDK Harness during pipeline
+ // execution. This is used to specify alternate staging directories, such
+ // as for staging artifacts remotely.
+ //
+ // Setting an SdkPath does not change staging behavior otherwise. All
+ // artifacts still get staged to the default local filepath, and it is the
+ // user's responsibility to stage those local artifacts to the SdkPath.
+ SdkPath string
+
+ // JoinFn is a function for combining SdkPath and individual artifact names.
+ // If not specified, it defaults to using filepath.Join.
+ JoinFn func(path, name string) string
+}
+
+func defaultJoinFn(path, name string) string {
+ return filepath.Join(path, "/", name)
+}
+
+// ResolveArtifactsWithConfig acquires all dependencies for cross-language
+// transforms, with behavior that can be adjusted through cfg. By default,
+// this function performs the following steps for each cross-language transform
+// in the list of edges:
+// 1. Retrieves a list of dependencies needed from the expansion service.
+// 2. Retrieves each dependency as an artifact and stages it to a default
+// local filepath.
+// 3. Adds the dependencies to the transform's stored environment proto.
+// The changes that can be configured are documented in ResolveConfig.
+//
+// This returns a map of "local path" to "sdk path". By default these are
+// identical, unless ResolveConfig.SdkPath has been set.
+func ResolveArtifactsWithConfig(ctx context.Context, edges []*graph.MultiEdge, cfg ResolveConfig) (paths map[string]string, err error) {
+ tmpPath, err := filepath.Abs("/tmp/artifacts")
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving remote artifacts")
+ }
+ if cfg.JoinFn == nil {
+ cfg.JoinFn = defaultJoinFn
+ }
+ paths = make(map[string]string)
for _, e := range edges {
if e.Op == graph.External {
components, err := graphx.ExpandedComponents(e.External.Expanded)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for edge %v", e.Name())
}
envs := components.Environments
for eid, env := range envs {
-
if strings.HasPrefix(eid, "go") {
continue
}
deps := env.GetDependencies()
- resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", path)
+ resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", tmpPath)
if err != nil {
- panic(err)
+ return nil, errors.WithContextf(err,
+ "resolving remote artifacts for env %v in edge %v", eid, e.Name())
}
var resolvedDeps []*pipepb.ArtifactInformation
for _, a := range resolvedArtifacts {
- name, sha256 := artifact.MustExtractFilePayload(a)
- fullPath := filepath.Join(path, "/", name)
+ name, _ := artifact.MustExtractFilePayload(a)
+ fullTmpPath := filepath.Join(tmpPath, "/", name)
+ fullSdkPath := fullTmpPath
+ if len(cfg.SdkPath) > 0 {
+ fullSdkPath = cfg.JoinFn(cfg.SdkPath, name)
+ }
resolvedDeps = append(resolvedDeps,
&pipepb.ArtifactInformation{
TypeUrn: "beam:artifact:type:file:v1",
TypePayload: protox.MustEncode(
&pipepb.ArtifactFilePayload{
- Path: fullPath,
- Sha256: sha256,
+ Path: fullSdkPath,
},
),
RoleUrn: a.RoleUrn,
RolePayload: a.RolePayload,
},
)
+ paths[fullTmpPath] = fullSdkPath
}
env.Dependencies = resolvedDeps
}
}
}
+ return paths, nil
}
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index c46ff4c..cd7be52 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -178,11 +178,23 @@
}
// (1) Build and submit
+ // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
+ id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
+
+ modelURL := gcsx.Join(*stagingLocation, id, "model")
+ workerURL := gcsx.Join(*stagingLocation, id, "worker")
+ jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
+ xlangURL := gcsx.Join(*stagingLocation, id, "xlang")
edges, _, err := p.Build()
if err != nil {
return nil, err
}
+ artifactURLs, err := dataflowlib.ResolveXLangArtifacts(ctx, edges, opts.Project, xlangURL)
+ if err != nil {
+ return nil, errors.WithContext(err, "resolving cross-language artifacts")
+ }
+ opts.ArtifactURLs = artifactURLs
environment, err := graphx.CreateEnvironment(ctx, jobopts.GetEnvironmentUrn(ctx), getContainerImage)
if err != nil {
return nil, errors.WithContext(err, "creating environment for model pipeline")
@@ -196,13 +208,6 @@
return nil, errors.WithContext(err, "applying container image overrides")
}
- // NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
- id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
-
- modelURL := gcsx.Join(*stagingLocation, id, "model")
- workerURL := gcsx.Join(*stagingLocation, id, "worker")
- jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
-
if *dryRun {
log.Info(ctx, "Dry-run: not submitting job!")
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 511b962..390082b 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -56,6 +56,7 @@
WorkerRegion string
WorkerZone string
ContainerImage string
+ ArtifactURLs []string // Additional packages for workers.
// Autoscaling settings
Algorithm string
@@ -128,6 +129,15 @@
experiments = append(experiments, "use_staged_dataflow_worker_jar")
}
+ for _, url := range opts.ArtifactURLs {
+ name := url[strings.LastIndexAny(url, "/")+1:]
+ pkg := &df.Package{
+ Name: name,
+ Location: url,
+ }
+ packages = append(packages, pkg)
+ }
+
ipConfiguration := "WORKER_IP_UNSPECIFIED"
if opts.NoUsePublicIPs {
ipConfiguration = "WORKER_IP_PRIVATE"
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
index 67a2bde..49ca5bf 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
@@ -22,6 +22,8 @@
"os"
"cloud.google.com/go/storage"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/graph"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
)
@@ -54,3 +56,28 @@
_, err = gcsx.Upload(ctx, client, project, bucket, obj, r)
return err
}
+
+// ResolveXLangArtifacts resolves cross-language artifacts with a given GCS
+// URL as a destination, and then stages all local artifacts to that URL. This
+// function returns a list of staged artifact URLs.
+func ResolveXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge, project, url string) ([]string, error) {
+ cfg := xlangx.ResolveConfig{
+ SdkPath: url,
+ JoinFn: func(url, name string) string {
+ return gcsx.Join(url, "/", name)
+ },
+ }
+ paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
+ if err != nil {
+ return nil, err
+ }
+ var urls []string
+ for local, remote := range paths {
+ err := StageFile(ctx, project, remote, local)
+ if err != nil {
+ return nil, errors.WithContextf(err, "staging file to %v", remote)
+ }
+ urls = append(urls, remote)
+ }
+ return urls, nil
+}
diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go
index f01a71f..af14b38 100644
--- a/sdks/go/pkg/beam/runners/universal/universal.go
+++ b/sdks/go/pkg/beam/runners/universal/universal.go
@@ -79,6 +79,9 @@
getEnvCfg = srv.EnvironmentConfig
}
+ // Fetch all dependencies for cross-language transforms
+ xlangx.ResolveArtifacts(ctx, edges, nil)
+
environment, err := graphx.CreateEnvironment(ctx, envUrn, getEnvCfg)
if err != nil {
return nil, errors.WithContextf(err, "generating model pipeline")
@@ -88,9 +91,6 @@
return nil, errors.WithContextf(err, "generating model pipeline")
}
- // Fetch all dependencies for cross-language transforms
- xlangx.ResolveArtifacts(ctx, edges, pipeline)
-
log.Info(ctx, proto.MarshalTextString(pipeline))
opt := &runnerlib.JobOptions{
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index c756e05..7db20ac 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -78,12 +78,6 @@
var dataflowFilters = []string{
// TODO(BEAM-11576): TestFlattenDup failing on this runner.
"TestFlattenDup",
- // TODO(BEAM-11418): These tests require implementing x-lang artifact
- // staging on Dataflow.
- "TestXLang_CoGroupBy",
- "TestXLang_Multi",
- "TestXLang_Partition",
- "TestXLang_Prefix",
}
// CheckFilters checks if an integration test is filtered to be skipped, either
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
index 92c3e05..b04266c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
@@ -411,12 +411,12 @@
checkState(
this.outputs == null, "Tried to specify more than one output for %s", getFullName());
checkNotNull(output, "Tried to set the output of %s to null", getFullName());
- this.outputs = PValues.fullyExpand(output.expand());
+ this.outputs = PValues.expandOutput(output);
// Validate that a primitive transform produces only primitive output, and a composite
// transform does not produce primitive output.
Set<Node> outputProducers = new HashSet<>();
- for (PCollection<?> outputValue : PValues.fullyExpand(output.expand()).values()) {
+ for (PCollection<?> outputValue : PValues.expandOutput(output).values()) {
outputProducers.add(getProducer(outputValue));
}
if (outputProducers.contains(this) && (!parts.isEmpty() || outputProducers.size() > 1)) {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 34664ac..7b580ef 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -27,6 +27,7 @@
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
+import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.BiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableBiMap;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -39,9 +40,6 @@
import org.joda.time.base.AbstractInstant;
/** Utility methods for Calcite related operations. */
-@SuppressWarnings({
- "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
public class CalciteUtils {
private static final long UNLIMITED_ARRAY_SIZE = -1L;
@@ -73,7 +71,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(SqlTypes.DATE.getIdentifier())
|| logicalId.equals(SqlTypes.TIME.getIdentifier())
|| logicalId.equals(TimeWithLocalTzType.IDENTIFIER)
@@ -88,7 +88,9 @@
}
if (fieldType.getTypeName().isLogicalType()) {
- String logicalId = fieldType.getLogicalType().getIdentifier();
+ Schema.LogicalType logicalType = fieldType.getLogicalType();
+ Preconditions.checkArgumentNotNull(logicalType);
+ String logicalId = logicalType.getIdentifier();
return logicalId.equals(CharType.IDENTIFIER);
}
return false;
@@ -210,7 +212,12 @@
+ "so it cannot be converted to a %s",
sqlTypeName, Schema.FieldType.class.getSimpleName()));
default:
- return CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ FieldType fieldType = CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+ if (fieldType == null) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + sqlTypeName);
+ }
+ return fieldType;
}
}
@@ -234,7 +241,12 @@
return FieldType.row(toSchema(calciteType));
default:
- return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ try {
+ return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(
+ "Cannot find a matching Beam FieldType for Calcite type: " + calciteType, e);
+ }
}
}
@@ -254,16 +266,22 @@
switch (fieldType.getTypeName()) {
case ARRAY:
case ITERABLE:
+ FieldType collectionElementType = fieldType.getCollectionElementType();
+ Preconditions.checkArgumentNotNull(collectionElementType);
return dataTypeFactory.createArrayType(
- toRelDataType(dataTypeFactory, fieldType.getCollectionElementType()),
- UNLIMITED_ARRAY_SIZE);
+ toRelDataType(dataTypeFactory, collectionElementType), UNLIMITED_ARRAY_SIZE);
case MAP:
- RelDataType componentKeyType = toRelDataType(dataTypeFactory, fieldType.getMapKeyType());
- RelDataType componentValueType =
- toRelDataType(dataTypeFactory, fieldType.getMapValueType());
+ FieldType mapKeyType = fieldType.getMapKeyType();
+ FieldType mapValueType = fieldType.getMapValueType();
+ Preconditions.checkArgumentNotNull(mapKeyType);
+ Preconditions.checkArgumentNotNull(mapValueType);
+ RelDataType componentKeyType = toRelDataType(dataTypeFactory, mapKeyType);
+ RelDataType componentValueType = toRelDataType(dataTypeFactory, mapValueType);
return dataTypeFactory.createMapType(componentKeyType, componentValueType);
case ROW:
- return toCalciteRowType(fieldType.getRowSchema(), dataTypeFactory);
+ Schema schema = fieldType.getRowSchema();
+ Preconditions.checkArgumentNotNull(schema);
+ return toCalciteRowType(schema, dataTypeFactory);
default:
return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 50b6ab2..e76ee7f 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -30,13 +30,17 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
/** Tests for conversion from Beam schema to Calcite data type. */
public class CalciteUtilsTest {
RelDataTypeFactory dataTypeFactory;
+ @Rule public ExpectedException thrown = ExpectedException.none();
+
@Before
public void setUp() {
dataTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
@@ -166,4 +170,12 @@
assertEquals(schema, out);
}
+
+ @Test
+ public void testFieldTypeNotFound() {
+ RelDataType relDataType = dataTypeFactory.createUnknownType();
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN");
+ CalciteUtils.toFieldType(relDataType);
+ }
}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
index 6413fd2..82c4b74 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
@@ -729,6 +729,17 @@
Map<String, String> argsMap =
((Map<String, Object>) MAPPER.convertValue(pipeline.getOptions(), Map.class).get("options"))
.entrySet().stream()
+ .filter(
+ (entry) -> {
+ if (entry.getValue() instanceof List) {
+ if (!((List) entry.getValue()).isEmpty()) {
+ throw new IllegalArgumentException("Cannot encode list arguments");
+ }
+ // Empty lists are handled by omitting the argument entirely.
+ return false;
+ }
+ return true;
+ })
.collect(Collectors.toMap(Map.Entry::getKey, entry -> toArg(entry.getValue())));
InMemoryMetaStore inMemoryMetaStore = new InMemoryMetaStore();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 9af7f58..296767b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -48,8 +48,6 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
/**
* The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -216,7 +214,7 @@
private final String pTransformId;
private final SimpleExecutionState state;
private final Counter unboundedElementCountCounter;
- private final SampleByteSizeDistribution<T> unboundSampledByteSizeDistribution;
+ private final SampleByteSizeDistribution<T> unboundedSampledByteSizeDistribution;
private final Coder<T> coder;
private final MetricsContainer metricsContainer;
@@ -239,7 +237,7 @@
MonitoringInfoMetricName sampledByteSizeMetricName =
MonitoringInfoMetricName.named(Urns.SAMPLED_BYTE_SIZE, labels);
- this.unboundSampledByteSizeDistribution =
+ this.unboundedSampledByteSizeDistribution =
new SampleByteSizeDistribution<>(
unboundMetricContainer.getDistribution(sampledByteSizeMetricName));
@@ -252,7 +250,7 @@
// Increment the counter for each window the element occurs in.
this.unboundedElementCountCounter.inc(input.getWindows().size());
// TODO(BEAM-11879): Consider updating size per window when we have window optimization.
- this.unboundSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
+ this.unboundedSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
// Wrap the consumer with extra logic to set the metric container with the appropriate
// PTransform context. This ensures that user metrics obtain the pTransform ID when they are
// created. Also use the ExecutionStateTracker and enter an appropriate state to track the
@@ -262,6 +260,7 @@
this.delegate.accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
@@ -321,6 +320,7 @@
consumerAndMetadata.getConsumer().accept(input);
}
}
+ this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
}
}
}
@@ -365,28 +365,33 @@
}
final Distribution distribution;
+ ByteSizeObserver byteCountObserver;
public SampleByteSizeDistribution(Distribution distribution) {
this.distribution = distribution;
+ this.byteCountObserver = null;
}
public void tryUpdate(T value, Coder<T> coder) throws Exception {
if (shouldSampleElement()) {
// First try using byte size observer
- ByteSizeObserver observer = new ByteSizeObserver();
- coder.registerByteSizeObserver(value, observer);
+ byteCountObserver = new ByteSizeObserver();
+ coder.registerByteSizeObserver(value, byteCountObserver);
- if (!observer.getIsLazy()) {
- observer.advance();
- this.distribution.update(observer.observedSize);
- } else {
- // TODO(BEAM-11841): Optimize calculation of element size for iterables.
- // Coder byte size observation is lazy (requires iteration for observation) so fall back
- // to counting output stream
- CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
- coder.encode(value, os);
- this.distribution.update(os.getCount());
+ if (!byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
+ } else {
+ byteCountObserver = null;
+ }
+ }
+
+ public void finishLazyUpdate() {
+ // Advance lazy ElementByteSizeObservers, if any.
+ if (byteCountObserver != null && byteCountObserver.getIsLazy()) {
+ byteCountObserver.advance();
+ this.distribution.update(byteCountObserver.observedSize);
}
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 90baa5e..708ca8b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -24,6 +24,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@@ -32,7 +33,9 @@
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import org.apache.beam.fn.harness.HandlesSplits;
import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
@@ -44,15 +47,19 @@
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
+import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@@ -329,5 +336,108 @@
verify(consumerA1).trySplit(0.3);
}
+ /**
+  * Registers a consumer for a PCollection of {@code Iterable<String>} elements, feeds it {@link
+  * TestElementByteSizeObservableIterable} elements that the consumer fully iterates, and checks
+  * the element count and sampled byte size monitoring infos reported for that PCollection.
+  */
+ @Test
+ public void testLazyByteSizeEstimation() throws Exception {
+   final String pCollectionA = "pCollectionA";
+   final String pTransformIdA = "pTransformIdA";
+
+   MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+   PCollectionConsumerRegistry consumers =
+       new PCollectionConsumerRegistry(
+           metricsContainerRegistry, mock(ExecutionStateTracker.class));
+   FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
+
+   consumers.register(
+       pCollectionA, pTransformIdA, consumerA1, IterableCoder.of(StringUtf8Coder.of()));
+
+   FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
+       (FnDataReceiver<WindowedValue<Iterable<String>>>)
+           (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+   String elementValue = "elem";
+   long elementByteSize = StringUtf8Coder.of().getEncodedElementByteSize(elementValue);
+   // Every accepted element is the same two-item observable iterable.
+   WindowedValue<Iterable<String>> element =
+       valueInGlobalWindow(
+           new TestElementByteSizeObservableIterable<>(
+               Arrays.asList(elementValue, elementValue), elementByteSize));
+   int numElements = 10;
+   // Mock doing work on the iterable items: the consumer walks the iterable to its end.
+   doAnswer(
+           (Answer<Void>)
+               invocation -> {
+                 Object[] args = invocation.getArguments();
+                 WindowedValue<Iterable<String>> arg = (WindowedValue<Iterable<String>>) args[0];
+                 Iterator<String> it = arg.getValue().iterator();
+                 while (it.hasNext()) {
+                   it.next();
+                 }
+                 return null;
+               })
+       .when(consumerA1)
+       .accept(element);
+
+   for (int i = 0; i < numElements; i++) {
+     wrapperConsumer.accept(element);
+   }
+
+   // Check that the underlying consumers are each invoked per element.
+   verify(consumerA1, times(numElements)).accept(element);
+   assertThat(consumers.keySet(), contains(pCollectionA));
+
+   List<MonitoringInfo> expected = new ArrayList<>();
+
+   // Expected element count: one per accepted element.
+   SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+   builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+   builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+   builder.setInt64SumValue(numElements);
+   expected.add(builder.build());
+
+   // Expected byte size distribution: every element contributes the same number of bytes.
+   builder = new SimpleMonitoringInfoBuilder();
+   builder.setUrn(Urns.SAMPLED_BYTE_SIZE);
+   builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+   long expectedBytes =
+       (elementByteSize + 1) * 2
+           + 5; // Additional 5 bytes are due to size and hasNext = false (1 byte).
+   builder.setInt64DistributionValue(
+       DistributionData.create(
+           numElements * expectedBytes, numElements, expectedBytes, expectedBytes));
+   expected.add(builder.build());
+   // Only compare the monitoring infos that carry a PCOLLECTION label.
+   Iterable<MonitoringInfo> result =
+       Iterables.filter(
+           metricsContainerRegistry.getMonitoringInfos(),
+           monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+
+   assertThat(result, containsInAnyOrder(expected.toArray()));
+ }
+
+ /**
+  * Test iterable backed by a fixed list. Each iterator it creates calls {@code
+  * notifyValueReturned} with the same configured byte size for every element it returns.
+  *
+  * <p>Declared {@code static} because it uses no state of the enclosing test instance.
+  */
+ private static class TestElementByteSizeObservableIterable<T>
+     extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> {
+   private final List<T> elements;
+   private final long elementByteSize;
+
+   /**
+    * @param elements the values handed out by each iterator, in order
+    * @param elementByteSize the byte size passed to {@code notifyValueReturned} per element
+    */
+   public TestElementByteSizeObservableIterable(List<T> elements, long elementByteSize) {
+     this.elements = elements;
+     this.elementByteSize = elementByteSize;
+   }
+
+   @Override
+   protected ElementByteSizeObservableIterator<T> createIterator() {
+     return new ElementByteSizeObservableIterator<T>() {
+       // Position of the next element to hand out.
+       private int index = 0;
+
+       @Override
+       public boolean hasNext() {
+         return index < elements.size();
+       }
+
+       @Override
+       public T next() {
+         // Report the size first, then return the element and advance.
+         notifyValueReturned(elementByteSize);
+         return elements.get(index++);
+       }
+     };
+   }
+ }
+
private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
index 4193ba6..fc1fe94 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -109,7 +109,10 @@
.getReferencedTables();
if (referencedTables != null && !referencedTables.isEmpty()) {
TableReference referencedTable = referencedTables.get(0);
- effectiveLocation = tableService.getTable(referencedTable).getLocation();
+ effectiveLocation =
+ tableService
+ .getDataset(referencedTable.getProjectId(), referencedTable.getDatasetId())
+ .getLocation();
}
}
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index 84743f8c..820418c 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -43,7 +43,7 @@
configurations.testRuntimeClasspath {
resolutionStrategy {
- def log4j_version = "2.4.1"
+ def log4j_version = "2.8.2"
// Beam's build system forces a uniform log4j version resolution for all modules, however for
// the HCatalog case the current version of log4j produces NoClassDefFoundError so we need to
// force an old version on the tests runtime classpath
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6d816a..e1f6635 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -123,6 +123,11 @@
try:
_load_main_session(semi_persistent_directory)
+ except CorruptMainSessionException:
+ exception_details = traceback.format_exc()
+ _LOGGER.error(
+ 'Could not load main session: %s', exception_details, exc_info=True)
+ raise
except Exception: # pylint: disable=broad-except
exception_details = traceback.format_exc()
_LOGGER.error(
@@ -245,12 +250,29 @@
return 0
+class CorruptMainSessionException(Exception):
+ """
+ Used to crash this worker if a main session file was provided but
+ is not valid.
+ """
+ pass
+
+
def _load_main_session(semi_persistent_directory):
"""Loads a pickled main session from the path specified."""
if semi_persistent_directory:
session_file = os.path.join(
semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
if os.path.isfile(session_file):
+ # If the expected session file is present but empty, the user code run
+ # by this worker will likely crash at runtime.
+ # This can happen if the worker fails to download the main session.
+ # Raise a fatal error and crash this worker, forcing a restart.
+ if os.path.getsize(session_file) == 0:
+ raise CorruptMainSessionException(
+ 'Session file found, but empty: %s. Functions defined in __main__ '
+ '(interactive session) will almost certainly fail.' %
+ (session_file, ))
pickler.load_session(session_file)
else:
_LOGGER.warning(