Merge pull request #14762 from ibzib/BEAM-12312

[BEAM-12312] Don't rely on remove() in LazyAggregateCombineFn#mergeAc…
diff --git a/.test-infra/jenkins/CommonJobProperties.groovy b/.test-infra/jenkins/CommonJobProperties.groovy
index 851fc0b..295cc89 100644
--- a/.test-infra/jenkins/CommonJobProperties.groovy
+++ b/.test-infra/jenkins/CommonJobProperties.groovy
@@ -183,6 +183,11 @@
     context.switches("-Dorg.gradle.jvmargs=-Xms2g")
     context.switches("-Dorg.gradle.jvmargs=-Xmx4g")
 
+    // Disable file system watching for CI builds
+    // Builds are performed on a clean clone and files aren't modified, so
+    // there's no value in watching for changes.
+    context.switches("-Dorg.gradle.vfs.watch=false")
+
     // Include dependency licenses when building docker images on Jenkins, see https://s.apache.org/zt68q
     context.switches("-Pdocker-pull-licenses")
   }
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
index 0e3e628..1ac4562 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow.groovy
@@ -27,7 +27,7 @@
 
       description('Runs the ValidatesRunner suite on the Dataflow runner.')
 
-      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
       previousNames(/beam_PostCommit_Java_ValidatesRunner_Dataflow_Gradle/)
 
       // Publish all test results to Jenkins
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
index af9e25e..6ba9685 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_Java11.groovy
@@ -28,7 +28,7 @@
       def JAVA_11_HOME = '/usr/lib/jvm/java-11-openjdk-amd64'
       def JAVA_8_HOME = '/usr/lib/jvm/java-8-openjdk-amd64'
 
-      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 270)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 420)
       publishers {
         archiveJunit('**/build/test-results/**/*.xml')
       }
diff --git a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
index 0de085d..3b3d8ed 100644
--- a/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.groovy
@@ -27,7 +27,7 @@
 
       description('Runs Java ValidatesRunner suite on the Dataflow runner V2 forcing streaming mode.')
 
-      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 330)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 450)
 
       // Publish all test results to Jenkins
       publishers {
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
index e847928..5631649 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
@@ -26,7 +26,7 @@
       description('Runs Python ValidatesRunner suite on the Dataflow runner.')
 
       // Set common parameters.
-      commonJobProperties.setTopLevelMainJobProperties(delegate)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
 
       publishers {
         archiveJunit('**/nosetests*.xml')
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
index c219de6..e164514 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
@@ -26,7 +26,7 @@
       description('Runs Python ValidatesRunner suite on the Dataflow runner v2.')
 
       // Set common parameters.
-      commonJobProperties.setTopLevelMainJobProperties(delegate)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
 
       publishers {
         archiveJunit('**/nosetests*.xml')
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
index 40e7383..5cba1d4 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
@@ -42,7 +42,7 @@
   static def alpn_api_version = "1.1.2.v20150522"
   static def npn_api_version = "1.1.1.v20141010"
   static def jboss_marshalling_version = "1.4.11.Final"
-  static def jboss_modules_version = "1.11.0.Final"
+  static def jboss_modules_version = "1.1.0.Beta1"
 
   /** Returns the list of compile time dependencies. */
   static List<String> dependencies() {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
index aae35ab..e32edf7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ReplacementOutputs.java
@@ -26,7 +26,6 @@
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.POutput;
-import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.PValues;
 import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
@@ -41,17 +40,16 @@
   private ReplacementOutputs() {}
 
   public static Map<PCollection<?>, ReplacementOutput> singleton(
-      Map<TupleTag<?>, PCollection<?>> original, PValue replacement) {
+      Map<TupleTag<?>, PCollection<?>> original, POutput replacement) {
     Entry<TupleTag<?>, PCollection<?>> originalElement =
         Iterables.getOnlyElement(original.entrySet());
-    TupleTag<?> replacementTag = Iterables.getOnlyElement(replacement.expand().entrySet()).getKey();
-    PCollection<?> replacementCollection =
-        (PCollection<?>) Iterables.getOnlyElement(replacement.expand().entrySet()).getValue();
+    Entry<TupleTag<?>, PCollection<?>> replacementElement =
+        Iterables.getOnlyElement(PValues.expandOutput(replacement).entrySet());
     return Collections.singletonMap(
-        replacementCollection,
+        replacementElement.getValue(),
         ReplacementOutput.of(
             TaggedPValue.of(originalElement.getKey(), originalElement.getValue()),
-            TaggedPValue.of(replacementTag, replacementCollection)));
+            TaggedPValue.of(replacementElement.getKey(), replacementElement.getValue())));
   }
 
   public static Map<PCollection<?>, ReplacementOutput> tagged(
diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
index 9ea5099..1ba1ac4 100644
--- a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
+++ b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
@@ -24,53 +24,105 @@
 	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
 	pipepb "github.com/apache/beam/sdks/go/pkg/beam/model/pipeline_v1"
 )
 
 // ResolveArtifacts acquires all dependencies for a cross-language transform
 func ResolveArtifacts(ctx context.Context, edges []*graph.MultiEdge, p *pipepb.Pipeline) {
-	path, err := filepath.Abs("/tmp/artifacts")
+	_, err := ResolveArtifactsWithConfig(ctx, edges, ResolveConfig{})
 	if err != nil {
 		panic(err)
 	}
+}
+
+// ResolveConfig contains fields for configuring the behavior for resolving
+// artifacts.
+type ResolveConfig struct {
+	// SdkPath replaces the default filepath for dependencies, but only in the
+	// external environment proto to be used by the SDK Harness during pipeline
+	// execution. This is used to specify alternate staging directories, such
+	// as for staging artifacts remotely.
+	//
+	// Setting an SdkPath does not change staging behavior otherwise. All
+	// artifacts still get staged to the default local filepath, and it is the
+	// user's responsibility to stage those local artifacts to the SdkPath.
+	SdkPath string
+
+	// JoinFn is a function for combining SdkPath and individual artifact names.
+	// If not specified, it defaults to using filepath.Join.
+	JoinFn func(path, name string) string
+}
+
+func defaultJoinFn(path, name string) string {
+	return filepath.Join(path, "/", name)
+}
+
+// ResolveArtifactsWithConfig acquires all dependencies for cross-language
+// transforms, with configurable resolution behavior. By default,
+// this function performs the following steps for each cross-language transform
+// in the list of edges:
+//   1. Retrieves a list of dependencies needed from the expansion service.
+//   2. Retrieves each dependency as an artifact and stages it to a default
+//      local filepath.
+//   3. Adds the dependencies to the transform's stored environment proto.
+// The changes that can be configured are documented in ResolveConfig.
+//
+// This returns a map of "local path" to "sdk path". By default these are
+// identical, unless ResolveConfig.SdkPath has been set.
+func ResolveArtifactsWithConfig(ctx context.Context, edges []*graph.MultiEdge, cfg ResolveConfig) (paths map[string]string, err error) {
+	tmpPath, err := filepath.Abs("/tmp/artifacts")
+	if err != nil {
+		return nil, errors.WithContext(err, "resolving remote artifacts")
+	}
+	if cfg.JoinFn == nil {
+		cfg.JoinFn = defaultJoinFn
+	}
+	paths = make(map[string]string)
 	for _, e := range edges {
 		if e.Op == graph.External {
 			components, err := graphx.ExpandedComponents(e.External.Expanded)
 			if err != nil {
-				panic(err)
+				return nil, errors.WithContextf(err,
+					"resolving remote artifacts for edge %v", e.Name())
 			}
 			envs := components.Environments
 			for eid, env := range envs {
-
 				if strings.HasPrefix(eid, "go") {
 					continue
 				}
 				deps := env.GetDependencies()
-				resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", path)
+				resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", tmpPath)
 				if err != nil {
-					panic(err)
+					return nil, errors.WithContextf(err,
+						"resolving remote artifacts for env %v in edge %v", eid, e.Name())
 				}
 
 				var resolvedDeps []*pipepb.ArtifactInformation
 				for _, a := range resolvedArtifacts {
-					name, sha256 := artifact.MustExtractFilePayload(a)
-					fullPath := filepath.Join(path, "/", name)
+					name, _ := artifact.MustExtractFilePayload(a)
+					fullTmpPath := filepath.Join(tmpPath, "/", name)
+					fullSdkPath := fullTmpPath
+					if len(cfg.SdkPath) > 0 {
+						fullSdkPath = cfg.JoinFn(cfg.SdkPath, name)
+					}
 					resolvedDeps = append(resolvedDeps,
 						&pipepb.ArtifactInformation{
 							TypeUrn: "beam:artifact:type:file:v1",
 							TypePayload: protox.MustEncode(
 								&pipepb.ArtifactFilePayload{
-									Path:   fullPath,
-									Sha256: sha256,
+									Path: fullSdkPath,
 								},
 							),
 							RoleUrn:     a.RoleUrn,
 							RolePayload: a.RolePayload,
 						},
 					)
+					paths[fullTmpPath] = fullSdkPath
 				}
 				env.Dependencies = resolvedDeps
 			}
 		}
 	}
+	return paths, nil
 }
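
The ResolveConfig/ResolveArtifactsWithConfig API added above separates where artifacts are materialized locally from the path recorded in the environment proto. Below is a minimal sketch (not part of this change) of how a runner might use it; the staging prefix is hypothetical and the caller is still responsible for uploading each local file to its remote counterpart.

```go
package example

import (
	"context"
	"log"

	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
	"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
)

// resolveForRemoteStaging rewrites dependency paths in the environment proto
// to point at a remote staging prefix while still materializing them locally.
func resolveForRemoteStaging(ctx context.Context, edges []*graph.MultiEdge) (map[string]string, error) {
	cfg := xlangx.ResolveConfig{
		SdkPath: "gs://my-bucket/staging/xlang", // hypothetical staging prefix
		JoinFn: func(path, name string) string {
			// Join GCS-style rather than with the default filepath.Join.
			return gcsx.Join(path, "/", name)
		},
	}
	// The returned map is "local path" -> "sdk path"; with no SdkPath set the
	// two would be identical.
	paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
	if err != nil {
		return nil, err
	}
	for local, remote := range paths {
		log.Printf("stage %v to %v", local, remote)
	}
	return paths, nil
}
```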
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index c46ff4c..cd7be52 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -178,11 +178,23 @@
 	}
 
 	// (1) Build and submit
+	// NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
+	id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
+
+	modelURL := gcsx.Join(*stagingLocation, id, "model")
+	workerURL := gcsx.Join(*stagingLocation, id, "worker")
+	jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
+	xlangURL := gcsx.Join(*stagingLocation, id, "xlang")
 
 	edges, _, err := p.Build()
 	if err != nil {
 		return nil, err
 	}
+	artifactURLs, err := dataflowlib.ResolveXLangArtifacts(ctx, edges, opts.Project, xlangURL)
+	if err != nil {
+		return nil, errors.WithContext(err, "resolving cross-language artifacts")
+	}
+	opts.ArtifactURLs = artifactURLs
 	environment, err := graphx.CreateEnvironment(ctx, jobopts.GetEnvironmentUrn(ctx), getContainerImage)
 	if err != nil {
 		return nil, errors.WithContext(err, "creating environment for model pipeline")
@@ -196,13 +208,6 @@
 		return nil, errors.WithContext(err, "applying container image overrides")
 	}
 
-	// NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
-	id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
-
-	modelURL := gcsx.Join(*stagingLocation, id, "model")
-	workerURL := gcsx.Join(*stagingLocation, id, "worker")
-	jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
-
 	if *dryRun {
 		log.Info(ctx, "Dry-run: not submitting job!")
 
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 511b962..390082b 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -56,6 +56,7 @@
 	WorkerRegion        string
 	WorkerZone          string
 	ContainerImage      string
+	ArtifactURLs        []string // Additional packages for workers.
 
 	// Autoscaling settings
 	Algorithm     string
@@ -128,6 +129,15 @@
 		experiments = append(experiments, "use_staged_dataflow_worker_jar")
 	}
 
+	for _, url := range opts.ArtifactURLs {
+		name := url[strings.LastIndexAny(url, "/")+1:]
+		pkg := &df.Package{
+			Name:     name,
+			Location: url,
+		}
+		packages = append(packages, pkg)
+	}
+
 	ipConfiguration := "WORKER_IP_UNSPECIFIED"
 	if opts.NoUsePublicIPs {
 		ipConfiguration = "WORKER_IP_PRIVATE"
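For reference, each staged artifact URL above is turned into a worker package whose name is the URL's final path segment. A small self-contained sketch of that slicing, using a hypothetical URL (not part of the diff):

```go
package example

import (
	"fmt"
	"strings"

	df "google.golang.org/api/dataflow/v1b3"
)

// packageFromURL mirrors the loop above: the package name is everything after
// the last "/" in the staged URL.
func packageFromURL(url string) *df.Package {
	name := url[strings.LastIndexAny(url, "/")+1:]
	return &df.Package{Name: name, Location: url}
}

func examplePackageFromURL() {
	// Hypothetical staged artifact URL.
	p := packageFromURL("gs://my-bucket/staging/xlang/expansion-service.jar")
	fmt.Println(p.Name) // prints "expansion-service.jar"
}
```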
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
index 67a2bde..49ca5bf 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
@@ -22,6 +22,8 @@
 	"os"
 
 	"cloud.google.com/go/storage"
+	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
+	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
 	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
 	"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
 )
@@ -54,3 +56,28 @@
 	_, err = gcsx.Upload(ctx, client, project, bucket, obj, r)
 	return err
 }
+
+// ResolveXLangArtifacts resolves cross-language artifacts with a given GCS
+// URL as a destination, and then stages all local artifacts to that URL. This
+// function returns a list of staged artifact URLs.
+func ResolveXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge, project, url string) ([]string, error) {
+	cfg := xlangx.ResolveConfig{
+		SdkPath: url,
+		JoinFn: func(url, name string) string {
+			return gcsx.Join(url, "/", name)
+		},
+	}
+	paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
+	if err != nil {
+		return nil, err
+	}
+	var urls []string
+	for local, remote := range paths {
+		err := StageFile(ctx, project, remote, local)
+		if err != nil {
+			return nil, errors.WithContextf(err, "staging file to %v", remote)
+		}
+		urls = append(urls, remote)
+	}
+	return urls, nil
+}
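
Tying the pieces together, a hedged sketch of how a runner wires the new helper into job submission, assuming the JobOptions struct shown in the job.go change; project and staging prefix are hypothetical:

```go
package example

import (
	"context"

	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
	"github.com/apache/beam/sdks/go/pkg/beam/runners/dataflow/dataflowlib"
)

// stageXLang resolves cross-language dependencies, uploads them under the
// given staging prefix, and records the resulting URLs on the job options so
// they are attached to workers as packages (see the job.go change above).
func stageXLang(ctx context.Context, edges []*graph.MultiEdge, opts *dataflowlib.JobOptions) error {
	urls, err := dataflowlib.ResolveXLangArtifacts(ctx, edges,
		"my-gcp-project", "gs://my-bucket/staging/xlang")
	if err != nil {
		return err
	}
	opts.ArtifactURLs = urls
	return nil
}
```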
diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go
index f01a71f..af14b38 100644
--- a/sdks/go/pkg/beam/runners/universal/universal.go
+++ b/sdks/go/pkg/beam/runners/universal/universal.go
@@ -79,6 +79,9 @@
 		getEnvCfg = srv.EnvironmentConfig
 	}
 
+	// Fetch all dependencies for cross-language transforms
+	xlangx.ResolveArtifacts(ctx, edges, nil)
+
 	environment, err := graphx.CreateEnvironment(ctx, envUrn, getEnvCfg)
 	if err != nil {
 		return nil, errors.WithContextf(err, "generating model pipeline")
@@ -88,9 +91,6 @@
 		return nil, errors.WithContextf(err, "generating model pipeline")
 	}
 
-	// Fetch all dependencies for cross-language transforms
-	xlangx.ResolveArtifacts(ctx, edges, pipeline)
-
 	log.Info(ctx, proto.MarshalTextString(pipeline))
 
 	opt := &runnerlib.JobOptions{
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index c756e05..7db20ac 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -78,12 +78,6 @@
 var dataflowFilters = []string{
 	// TODO(BEAM-11576): TestFlattenDup failing on this runner.
 	"TestFlattenDup",
-	// TODO(BEAM-11418): These tests require implementing x-lang artifact
-	// staging on Dataflow.
-	"TestXLang_CoGroupBy",
-	"TestXLang_Multi",
-	"TestXLang_Partition",
-	"TestXLang_Prefix",
 }
 
 // CheckFilters checks if an integration test is filtered to be skipped, either
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
index 92c3e05..b04266c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java
@@ -411,12 +411,12 @@
       checkState(
           this.outputs == null, "Tried to specify more than one output for %s", getFullName());
       checkNotNull(output, "Tried to set the output of %s to null", getFullName());
-      this.outputs = PValues.fullyExpand(output.expand());
+      this.outputs = PValues.expandOutput(output);
 
       // Validate that a primitive transform produces only primitive output, and a composite
       // transform does not produce primitive output.
       Set<Node> outputProducers = new HashSet<>();
-      for (PCollection<?> outputValue : PValues.fullyExpand(output.expand()).values()) {
+      for (PCollection<?> outputValue : PValues.expandOutput(output).values()) {
         outputProducers.add(getProducer(outputValue));
       }
       if (outputProducers.contains(this) && (!parts.isEmpty() || outputProducers.size() > 1)) {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 34664ac..7b580ef 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -27,6 +27,7 @@
 import org.apache.beam.sdk.schemas.Schema.TypeName;
 import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
 import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
+import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.BiMap;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableBiMap;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -39,9 +40,6 @@
 import org.joda.time.base.AbstractInstant;
 
 /** Utility methods for Calcite related operations. */
-@SuppressWarnings({
-  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
 public class CalciteUtils {
   private static final long UNLIMITED_ARRAY_SIZE = -1L;
 
@@ -73,7 +71,9 @@
     }
 
     if (fieldType.getTypeName().isLogicalType()) {
-      String logicalId = fieldType.getLogicalType().getIdentifier();
+      Schema.LogicalType logicalType = fieldType.getLogicalType();
+      Preconditions.checkArgumentNotNull(logicalType);
+      String logicalId = logicalType.getIdentifier();
       return logicalId.equals(SqlTypes.DATE.getIdentifier())
           || logicalId.equals(SqlTypes.TIME.getIdentifier())
           || logicalId.equals(TimeWithLocalTzType.IDENTIFIER)
@@ -88,7 +88,9 @@
     }
 
     if (fieldType.getTypeName().isLogicalType()) {
-      String logicalId = fieldType.getLogicalType().getIdentifier();
+      Schema.LogicalType logicalType = fieldType.getLogicalType();
+      Preconditions.checkArgumentNotNull(logicalType);
+      String logicalId = logicalType.getIdentifier();
       return logicalId.equals(CharType.IDENTIFIER);
     }
     return false;
@@ -210,7 +212,12 @@
                     + "so it cannot be converted to a %s",
                 sqlTypeName, Schema.FieldType.class.getSimpleName()));
       default:
-        return CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+        FieldType fieldType = CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+        if (fieldType == null) {
+          throw new IllegalArgumentException(
+              "Cannot find a matching Beam FieldType for Calcite type: " + sqlTypeName);
+        }
+        return fieldType;
     }
   }
 
@@ -234,7 +241,12 @@
         return FieldType.row(toSchema(calciteType));
 
       default:
-        return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+        try {
+          return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+        } catch (IllegalArgumentException e) {
+          throw new IllegalArgumentException(
+              "Cannot find a matching Beam FieldType for Calcite type: " + calciteType, e);
+        }
     }
   }
 
@@ -254,16 +266,22 @@
     switch (fieldType.getTypeName()) {
       case ARRAY:
       case ITERABLE:
+        FieldType collectionElementType = fieldType.getCollectionElementType();
+        Preconditions.checkArgumentNotNull(collectionElementType);
         return dataTypeFactory.createArrayType(
-            toRelDataType(dataTypeFactory, fieldType.getCollectionElementType()),
-            UNLIMITED_ARRAY_SIZE);
+            toRelDataType(dataTypeFactory, collectionElementType), UNLIMITED_ARRAY_SIZE);
       case MAP:
-        RelDataType componentKeyType = toRelDataType(dataTypeFactory, fieldType.getMapKeyType());
-        RelDataType componentValueType =
-            toRelDataType(dataTypeFactory, fieldType.getMapValueType());
+        FieldType mapKeyType = fieldType.getMapKeyType();
+        FieldType mapValueType = fieldType.getMapValueType();
+        Preconditions.checkArgumentNotNull(mapKeyType);
+        Preconditions.checkArgumentNotNull(mapValueType);
+        RelDataType componentKeyType = toRelDataType(dataTypeFactory, mapKeyType);
+        RelDataType componentValueType = toRelDataType(dataTypeFactory, mapValueType);
         return dataTypeFactory.createMapType(componentKeyType, componentValueType);
       case ROW:
-        return toCalciteRowType(fieldType.getRowSchema(), dataTypeFactory);
+        Schema schema = fieldType.getRowSchema();
+        Preconditions.checkArgumentNotNull(schema);
+        return toCalciteRowType(schema, dataTypeFactory);
       default:
         return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
     }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 50b6ab2..e76ee7f 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -30,13 +30,17 @@
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 /** Tests for conversion from Beam schema to Calcite data type. */
 public class CalciteUtilsTest {
 
   RelDataTypeFactory dataTypeFactory;
 
+  @Rule public ExpectedException thrown = ExpectedException.none();
+
   @Before
   public void setUp() {
     dataTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
@@ -166,4 +170,12 @@
 
     assertEquals(schema, out);
   }
+
+  @Test
+  public void testFieldTypeNotFound() {
+    RelDataType relDataType = dataTypeFactory.createUnknownType();
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN");
+    CalciteUtils.toFieldType(relDataType);
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
index 6413fd2..82c4b74 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
@@ -729,6 +729,17 @@
     Map<String, String> argsMap =
         ((Map<String, Object>) MAPPER.convertValue(pipeline.getOptions(), Map.class).get("options"))
             .entrySet().stream()
+                .filter(
+                    (entry) -> {
+                      if (entry.getValue() instanceof List) {
+                        if (!((List) entry.getValue()).isEmpty()) {
+                          throw new IllegalArgumentException("Cannot encode list arguments");
+                        }
+                        // We can encode empty lists, just omit them.
+                        return false;
+                      }
+                      return true;
+                    })
                 .collect(Collectors.toMap(Map.Entry::getKey, entry -> toArg(entry.getValue())));
 
     InMemoryMetaStore inMemoryMetaStore = new InMemoryMetaStore();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 9af7f58..296767b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -48,8 +48,6 @@
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 
 /**
  * The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -216,7 +214,7 @@
     private final String pTransformId;
     private final SimpleExecutionState state;
     private final Counter unboundedElementCountCounter;
-    private final SampleByteSizeDistribution<T> unboundSampledByteSizeDistribution;
+    private final SampleByteSizeDistribution<T> unboundedSampledByteSizeDistribution;
     private final Coder<T> coder;
     private final MetricsContainer metricsContainer;
 
@@ -239,7 +237,7 @@
 
       MonitoringInfoMetricName sampledByteSizeMetricName =
           MonitoringInfoMetricName.named(Urns.SAMPLED_BYTE_SIZE, labels);
-      this.unboundSampledByteSizeDistribution =
+      this.unboundedSampledByteSizeDistribution =
           new SampleByteSizeDistribution<>(
               unboundMetricContainer.getDistribution(sampledByteSizeMetricName));
 
@@ -252,7 +250,7 @@
       // Increment the counter for each window the element occurs in.
       this.unboundedElementCountCounter.inc(input.getWindows().size());
       // TODO(BEAM-11879): Consider updating size per window when we have window optimization.
-      this.unboundSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
+      this.unboundedSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
       // Wrap the consumer with extra logic to set the metric container with the appropriate
       // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
       // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
@@ -262,6 +260,7 @@
           this.delegate.accept(input);
         }
       }
+      this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
     }
   }
 
@@ -321,6 +320,7 @@
             consumerAndMetadata.getConsumer().accept(input);
           }
         }
+        this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
       }
     }
   }
@@ -365,28 +365,33 @@
     }
 
     final Distribution distribution;
+    ByteSizeObserver byteCountObserver;
 
     public SampleByteSizeDistribution(Distribution distribution) {
       this.distribution = distribution;
+      this.byteCountObserver = null;
     }
 
     public void tryUpdate(T value, Coder<T> coder) throws Exception {
       if (shouldSampleElement()) {
         // First try using byte size observer
-        ByteSizeObserver observer = new ByteSizeObserver();
-        coder.registerByteSizeObserver(value, observer);
+        byteCountObserver = new ByteSizeObserver();
+        coder.registerByteSizeObserver(value, byteCountObserver);
 
-        if (!observer.getIsLazy()) {
-          observer.advance();
-          this.distribution.update(observer.observedSize);
-        } else {
-          // TODO(BEAM-11841): Optimize calculation of element size for iterables.
-          // Coder byte size observation is lazy (requires iteration for observation) so fall back
-          // to counting output stream
-          CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-          coder.encode(value, os);
-          this.distribution.update(os.getCount());
+        if (!byteCountObserver.getIsLazy()) {
+          byteCountObserver.advance();
+          this.distribution.update(byteCountObserver.observedSize);
         }
+      } else {
+        byteCountObserver = null;
+      }
+    }
+
+    public void finishLazyUpdate() {
+      // Advance lazy ElementByteSizeObservers, if any.
+      if (byteCountObserver != null && byteCountObserver.getIsLazy()) {
+        byteCountObserver.advance();
+        this.distribution.update(byteCountObserver.observedSize);
       }
     }
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 90baa5e..708ca8b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -24,6 +24,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -32,7 +33,9 @@
 import static org.powermock.api.mockito.PowerMockito.mockStatic;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
@@ -44,15 +47,19 @@
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
+import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -329,5 +336,108 @@
     verify(consumerA1).trySplit(0.3);
   }
 
+  @Test
+  public void testLazyByteSizeEstimation() throws Exception {
+    final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
+
+    consumers.register(
+        pCollectionA, pTransformIdA, consumerA1, IterableCoder.of(StringUtf8Coder.of()));
+
+    FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<Iterable<String>>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+    String elementValue = "elem";
+    long elementByteSize = StringUtf8Coder.of().getEncodedElementByteSize(elementValue);
+    WindowedValue<Iterable<String>> element =
+        valueInGlobalWindow(
+            new TestElementByteSizeObservableIterable<>(
+                Arrays.asList(elementValue, elementValue), elementByteSize));
+    int numElements = 10;
+    // Mock doing work on the iterable items
+    doAnswer(
+            (Answer<Void>)
+                invocation -> {
+                  Object[] args = invocation.getArguments();
+                  WindowedValue<Iterable<String>> arg = (WindowedValue<Iterable<String>>) args[0];
+                  Iterator it = arg.getValue().iterator();
+                  while (it.hasNext()) {
+                    it.next();
+                  }
+                  return null;
+                })
+        .when(consumerA1)
+        .accept(element);
+
+    for (int i = 0; i < numElements; i++) {
+      wrapperConsumer.accept(element);
+    }
+
+    // Check that the underlying consumers are each invoked per element.
+    verify(consumerA1, times(numElements)).accept(element);
+    assertThat(consumers.keySet(), contains(pCollectionA));
+
+    List<MonitoringInfo> expected = new ArrayList<>();
+
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    builder.setInt64SumValue(numElements);
+    expected.add(builder.build());
+
+    builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(Urns.SAMPLED_BYTE_SIZE);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    long expectedBytes =
+        (elementByteSize + 1) * 2
+            + 5; // Additional 5 bytes: 4 for the encoded size and 1 for the final hasNext = false marker.
+    builder.setInt64DistributionValue(
+        DistributionData.create(
+            numElements * expectedBytes, numElements, expectedBytes, expectedBytes));
+    expected.add(builder.build());
+    // Only compare the monitoring infos that are labeled with a PCollection.
+    Iterable<MonitoringInfo> result =
+        Iterables.filter(
+            metricsContainerRegistry.getMonitoringInfos(),
+            monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+
+    assertThat(result, containsInAnyOrder(expected.toArray()));
+  }
+
+  private class TestElementByteSizeObservableIterable<T>
+      extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> {
+    private List<T> elements;
+    private long elementByteSize;
+
+    public TestElementByteSizeObservableIterable(List<T> elements, long elementByteSize) {
+      this.elements = elements;
+      this.elementByteSize = elementByteSize;
+    }
+
+    @Override
+    protected ElementByteSizeObservableIterator createIterator() {
+      return new ElementByteSizeObservableIterator() {
+        private int index = 0;
+
+        @Override
+        public boolean hasNext() {
+          return index < elements.size();
+        }
+
+        @Override
+        public Object next() {
+          notifyValueReturned(elementByteSize);
+          return elements.get(index++);
+        }
+      };
+    }
+  }
+
   private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
index 4193ba6..fc1fe94 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -109,7 +109,10 @@
               .getReferencedTables();
       if (referencedTables != null && !referencedTables.isEmpty()) {
         TableReference referencedTable = referencedTables.get(0);
-        effectiveLocation = tableService.getTable(referencedTable).getLocation();
+        effectiveLocation =
+            tableService
+                .getDataset(referencedTable.getProjectId(), referencedTable.getDatasetId())
+                .getLocation();
       }
     }
 
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index 84743f8c..820418c 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -43,7 +43,7 @@
 
 configurations.testRuntimeClasspath {
   resolutionStrategy {
-    def log4j_version = "2.4.1"
+    def log4j_version = "2.8.2"
     // Beam's build system forces a uniform log4j version resolution for all modules, however for
     // the HCatalog case the current version of log4j produces NoClassDefFoundError so we need to
     // force an old version on the tests runtime classpath
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6d816a..e1f6635 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -123,6 +123,11 @@
 
   try:
     _load_main_session(semi_persistent_directory)
+  except CorruptMainSessionException:
+    exception_details = traceback.format_exc()
+    _LOGGER.error(
+        'Could not load main session: %s', exception_details, exc_info=True)
+    raise
   except Exception:  # pylint: disable=broad-except
     exception_details = traceback.format_exc()
     _LOGGER.error(
@@ -245,12 +250,29 @@
   return 0
 
 
+class CorruptMainSessionException(Exception):
+  """
+  Used to crash this worker if a main session file was provided but
+  is not valid.
+  """
+  pass
+
+
 def _load_main_session(semi_persistent_directory):
   """Loads a pickled main session from the path specified."""
   if semi_persistent_directory:
     session_file = os.path.join(
         semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
     if os.path.isfile(session_file):
+      # If the expected session file is present but empty, the user code run
+      # by this worker will likely crash at runtime.
+      # This can happen if the worker fails to download the main session.
+      # Raise a fatal error and crash this worker, forcing a restart.
+      if os.path.getsize(session_file) == 0:
+        raise CorruptMainSessionException(
+            'Session file found, but empty: %s. Functions defined in __main__ '
+            '(interactive session) will almost certainly fail.' %
+            (session_file, ))
       pickler.load_session(session_file)
     else:
       _LOGGER.warning(