[BEAM-7372] cleanup py2 codepath from apache_beam/ml, options and portability (#14753)

[BEAM-7372] cleanup py2 codepath from apache_beam/ml, options and portability
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
index e847928..5631649 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow.groovy
@@ -26,7 +26,7 @@
       description('Runs Python ValidatesRunner suite on the Dataflow runner.')
 
       // Set common parameters.
-      commonJobProperties.setTopLevelMainJobProperties(delegate)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
 
       publishers {
         archiveJunit('**/nosetests*.xml')
diff --git a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
index c219de6..e164514 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_ValidatesRunner_Dataflow_V2.groovy
@@ -26,7 +26,7 @@
       description('Runs Python ValidatesRunner suite on the Dataflow runner v2.')
 
       // Set common parameters.
-      commonJobProperties.setTopLevelMainJobProperties(delegate)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 200)
 
       publishers {
         archiveJunit('**/nosetests*.xml')
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
index 40e7383..5cba1d4 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_26_0.groovy
@@ -42,7 +42,7 @@
   static def alpn_api_version = "1.1.2.v20150522"
   static def npn_api_version = "1.1.1.v20141010"
   static def jboss_marshalling_version = "1.4.11.Final"
-  static def jboss_modules_version = "1.11.0.Final"
+  static def jboss_modules_version = "1.1.0.Beta1"
 
   /** Returns the list of compile time dependencies. */
   static List<String> dependencies() {
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
index a011a78..0bd33bd 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/logicaltypes.go
@@ -51,7 +51,6 @@
 
 // Registry retains mappings from go types to Schemas and LogicalTypes.
 type Registry struct {
-	lastShortID     int64
 	typeToSchema    map[reflect.Type]*pipepb.Schema
 	idToType        map[string]reflect.Type
 	syntheticToUser map[reflect.Type]reflect.Type
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
index abb6f10..087d8c1 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
@@ -26,7 +26,9 @@
 package schema
 
 import (
+	"bytes"
 	"fmt"
+	"hash/fnv"
 	"reflect"
 	"strings"
 
@@ -72,8 +74,19 @@
 	defaultRegistry.RegisterType(ut)
 }
 
-func getUUID() string {
-	return uuid.New().String()
+// getUUID generates a UUID using the string form of the type name.
+func getUUID(ut reflect.Type) string {
+	// String produces non-empty output for pointer and slice types.
+	typename := ut.String()
+	hasher := fnv.New128a()
+	if n, err := hasher.Write([]byte(typename)); err != nil || n != len(typename) {
+		panic(fmt.Sprintf("unable to generate schema uuid for %s, wrote out %d bytes, want %d: err %v", typename, n, len(typename), err))
+	}
+	id, err := uuid.NewRandomFromReader(bytes.NewBuffer(hasher.Sum(nil)))
+	if err != nil {
+		panic(fmt.Sprintf("unable to genereate schema uuid for type %s: %v", typename, err))
+	}
+	return id.String()
 }
 
 // Registered returns whether the given type has been registered with
@@ -350,7 +363,7 @@
 		if lID != "" {
 			schm.Options = append(schm.Options, logicalOption(lID))
 		}
-		schm.Id = getUUID()
+		schm.Id = getUUID(ot)
 		r.typeToSchema[ot] = schm
 		r.idToType[schm.GetId()] = ot
 		return schm, nil
@@ -365,7 +378,7 @@
 	// Cache the pointer type here with its own id.
 	pt := reflect.PtrTo(t)
 	schm = proto.Clone(schm).(*pipepb.Schema)
-	schm.Id = getUUID()
+	schm.Id = getUUID(pt)
 	schm.Options = append(schm.Options, &pipepb.Option{
 		Name: optGoNillable,
 	})
@@ -454,7 +467,7 @@
 		schm := ftype.GetRowType().GetSchema()
 		schm = proto.Clone(schm).(*pipepb.Schema)
 		schm.Options = append(schm.Options, logicalOption(lID))
-		schm.Id = getUUID()
+		schm.Id = getUUID(t)
 		r.typeToSchema[t] = schm
 		r.idToType[schm.GetId()] = t
 		return schm, nil
@@ -483,7 +496,7 @@
 
 	schm := &pipepb.Schema{
 		Fields: fields,
-		Id:     getUUID(),
+		Id:     getUUID(t),
 	}
 	r.idToType[schm.GetId()] = t
 	r.typeToSchema[t] = schm
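
Note (not part of the patch): the new getUUID makes schema ids deterministic per type instead of random per call, which is what lets the pointer and logical-type clones above get stable ids. A minimal standalone sketch of the same technique, assuming only the github.com/google/uuid package that the SDK already imports:

```go
package main

import (
	"bytes"
	"fmt"
	"hash/fnv"
	"reflect"

	"github.com/google/uuid"
)

// typeUUID hashes the type's string form with FNV-128a and feeds the
// 16-byte digest to the reader-based UUID constructor, so the same type
// always maps to the same id across runs and across SDK workers.
func typeUUID(t reflect.Type) string {
	h := fnv.New128a()
	h.Write([]byte(t.String())) // hash.Hash.Write never returns an error
	id, err := uuid.NewRandomFromReader(bytes.NewBuffer(h.Sum(nil)))
	if err != nil {
		panic(err)
	}
	return id.String()
}

func main() {
	type row struct{ A int }
	fmt.Println(typeUUID(reflect.TypeOf(row{})))  // stable across runs
	fmt.Println(typeUUID(reflect.TypeOf(&row{}))) // pointer type gets its own id
}
```
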
diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
index 9ea5099..1ba1ac4 100644
--- a/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
+++ b/sdks/go/pkg/beam/core/runtime/xlangx/resolve.go
@@ -24,53 +24,105 @@
 	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/graphx"
 	"github.com/apache/beam/sdks/go/pkg/beam/core/util/protox"
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
 	pipepb "github.com/apache/beam/sdks/go/pkg/beam/model/pipeline_v1"
 )
 
 // ResolveArtifacts acquires all dependencies for a cross-language transform
 func ResolveArtifacts(ctx context.Context, edges []*graph.MultiEdge, p *pipepb.Pipeline) {
-	path, err := filepath.Abs("/tmp/artifacts")
+	_, err := ResolveArtifactsWithConfig(ctx, edges, ResolveConfig{})
 	if err != nil {
 		panic(err)
 	}
+}
+
+// ResolveConfig contains fields for configuring the behavior for resolving
+// artifacts.
+type ResolveConfig struct {
+	// SdkPath replaces the default filepath for dependencies, but only in the
+	// external environment proto to be used by the SDK Harness during pipeline
+	// execution. This is used to specify alternate staging directories, such
+	// as for staging artifacts remotely.
+	//
+	// Setting an SdkPath does not change staging behavior otherwise. All
+	// artifacts still get staged to the default local filepath, and it is the
+	// user's responsibility to stage those local artifacts to the SdkPath.
+	SdkPath string
+
+	// JoinFn is a function for combining SdkPath and individual artifact names.
+	// If not specified, it defaults to using filepath.Join.
+	JoinFn func(path, name string) string
+}
+
+func defaultJoinFn(path, name string) string {
+	return filepath.Join(path, "/", name)
+}
+
+// ResolveArtifactsWithConfig acquires all dependencies for cross-language
+// transforms, with configurable behavior. By default,
+// this function performs the following steps for each cross-language transform
+// in the list of edges:
+//   1. Retrieves a list of dependencies needed from the expansion service.
+//   2. Retrieves each dependency as an artifact and stages it to a default
+//      local filepath.
+//   3. Adds the dependencies to the transform's stored environment proto.
+// The changes that can be configured are documented in ResolveConfig.
+//
+// This returns a map of "local path" to "sdk path". By default these are
+// identical, unless ResolveConfig.SdkPath has been set.
+func ResolveArtifactsWithConfig(ctx context.Context, edges []*graph.MultiEdge, cfg ResolveConfig) (paths map[string]string, err error) {
+	tmpPath, err := filepath.Abs("/tmp/artifacts")
+	if err != nil {
+		return nil, errors.WithContext(err, "resolving remote artifacts")
+	}
+	if cfg.JoinFn == nil {
+		cfg.JoinFn = defaultJoinFn
+	}
+	paths = make(map[string]string)
 	for _, e := range edges {
 		if e.Op == graph.External {
 			components, err := graphx.ExpandedComponents(e.External.Expanded)
 			if err != nil {
-				panic(err)
+				return nil, errors.WithContextf(err,
+					"resolving remote artifacts for edge %v", e.Name())
 			}
 			envs := components.Environments
 			for eid, env := range envs {
-
 				if strings.HasPrefix(eid, "go") {
 					continue
 				}
 				deps := env.GetDependencies()
-				resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", path)
+				resolvedArtifacts, err := artifact.Materialize(ctx, e.External.ExpansionAddr, deps, "", tmpPath)
 				if err != nil {
-					panic(err)
+					return nil, errors.WithContextf(err,
+						"resolving remote artifacts for env %v in edge %v", eid, e.Name())
 				}
 
 				var resolvedDeps []*pipepb.ArtifactInformation
 				for _, a := range resolvedArtifacts {
-					name, sha256 := artifact.MustExtractFilePayload(a)
-					fullPath := filepath.Join(path, "/", name)
+					name, _ := artifact.MustExtractFilePayload(a)
+					fullTmpPath := filepath.Join(tmpPath, "/", name)
+					fullSdkPath := fullTmpPath
+					if len(cfg.SdkPath) > 0 {
+						fullSdkPath = cfg.JoinFn(cfg.SdkPath, name)
+					}
 					resolvedDeps = append(resolvedDeps,
 						&pipepb.ArtifactInformation{
 							TypeUrn: "beam:artifact:type:file:v1",
 							TypePayload: protox.MustEncode(
 								&pipepb.ArtifactFilePayload{
-									Path:   fullPath,
-									Sha256: sha256,
+									Path: fullSdkPath,
 								},
 							),
 							RoleUrn:     a.RoleUrn,
 							RolePayload: a.RolePayload,
 						},
 					)
+					paths[fullTmpPath] = fullSdkPath
 				}
 				env.Dependencies = resolvedDeps
 			}
 		}
 	}
+	return paths, nil
 }
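
Note (not part of the patch): a hedged sketch of how a runner could use the new hooks. stageToRemote is a hypothetical stand-in for runner-specific staging, and the bucket path is made up; the real Dataflow wiring is in stage.go below. Assumes a package that already imports context, graph, and xlangx from the Beam Go SDK:

```go
// stageXLangArtifacts resolves xlang dependencies locally, then copies
// them to the remote path that workers will read from.
func stageXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge) error {
	cfg := xlangx.ResolveConfig{
		// Workers read dependencies from this path at execution time.
		SdkPath: "gs://my-bucket/staging/xlang", // hypothetical bucket
		// URLs need plain string joining; filepath.Join would mangle the scheme.
		JoinFn: func(path, name string) string {
			return path + "/" + name
		},
	}
	paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
	if err != nil {
		return err
	}
	// Artifacts are only materialized locally; per the ResolveConfig
	// contract above, copying them to SdkPath is the caller's job.
	for local, remote := range paths {
		if err := stageToRemote(ctx, local, remote); err != nil { // hypothetical helper
			return err
		}
	}
	return nil
}
```
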
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index c46ff4c..cd7be52 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -178,11 +178,23 @@
 	}
 
 	// (1) Build and submit
+	// NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
+	id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
+
+	modelURL := gcsx.Join(*stagingLocation, id, "model")
+	workerURL := gcsx.Join(*stagingLocation, id, "worker")
+	jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
+	xlangURL := gcsx.Join(*stagingLocation, id, "xlang")
 
 	edges, _, err := p.Build()
 	if err != nil {
 		return nil, err
 	}
+	artifactURLs, err := dataflowlib.ResolveXLangArtifacts(ctx, edges, opts.Project, xlangURL)
+	if err != nil {
+		return nil, errors.WithContext(err, "resolving cross-language artifacts")
+	}
+	opts.ArtifactURLs = artifactURLs
 	environment, err := graphx.CreateEnvironment(ctx, jobopts.GetEnvironmentUrn(ctx), getContainerImage)
 	if err != nil {
 		return nil, errors.WithContext(err, "creating environment for model pipeline")
@@ -196,13 +208,6 @@
 		return nil, errors.WithContext(err, "applying container image overrides")
 	}
 
-	// NOTE(herohde) 10/8/2018: the last segment of the names must be "worker" and "dataflow-worker.jar".
-	id := fmt.Sprintf("go-%v-%v", atomic.AddInt32(&unique, 1), time.Now().UnixNano())
-
-	modelURL := gcsx.Join(*stagingLocation, id, "model")
-	workerURL := gcsx.Join(*stagingLocation, id, "worker")
-	jarURL := gcsx.Join(*stagingLocation, id, "dataflow-worker.jar")
-
 	if *dryRun {
 		log.Info(ctx, "Dry-run: not submitting job!")
 
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 511b962..390082b 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -56,6 +56,7 @@
 	WorkerRegion        string
 	WorkerZone          string
 	ContainerImage      string
+	ArtifactURLs        []string // Additional packages for workers.
 
 	// Autoscaling settings
 	Algorithm     string
@@ -128,6 +129,15 @@
 		experiments = append(experiments, "use_staged_dataflow_worker_jar")
 	}
 
+	for _, url := range opts.ArtifactURLs {
+		name := url[strings.LastIndexAny(url, "/")+1:]
+		pkg := &df.Package{
+			Name:     name,
+			Location: url,
+		}
+		packages = append(packages, pkg)
+	}
+
 	ipConfiguration := "WORKER_IP_UNSPECIFIED"
 	if opts.NoUsePublicIPs {
 		ipConfiguration = "WORKER_IP_PRIVATE"
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
index 67a2bde..49ca5bf 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/stage.go
@@ -22,6 +22,8 @@
 	"os"
 
 	"cloud.google.com/go/storage"
+	"github.com/apache/beam/sdks/go/pkg/beam/core/graph"
+	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/xlangx"
 	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
 	"github.com/apache/beam/sdks/go/pkg/beam/util/gcsx"
 )
@@ -54,3 +56,28 @@
 	_, err = gcsx.Upload(ctx, client, project, bucket, obj, r)
 	return err
 }
+
+// ResolveXLangArtifacts resolves cross-language artifacts with a given GCS
+// URL as a destination, and then stages all local artifacts to that URL. This
+// function returns a list of staged artifact URLs.
+func ResolveXLangArtifacts(ctx context.Context, edges []*graph.MultiEdge, project, url string) ([]string, error) {
+	cfg := xlangx.ResolveConfig{
+		SdkPath: url,
+		JoinFn: func(url, name string) string {
+			return gcsx.Join(url, "/", name)
+		},
+	}
+	paths, err := xlangx.ResolveArtifactsWithConfig(ctx, edges, cfg)
+	if err != nil {
+		return nil, err
+	}
+	var urls []string
+	for local, remote := range paths {
+		err := StageFile(ctx, project, remote, local)
+		if err != nil {
+			return nil, errors.WithContextf(err, "staging file to %v", remote)
+		}
+		urls = append(urls, remote)
+	}
+	return urls, nil
+}
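
Note (not part of the patch): a quick demonstration of why ResolveXLangArtifacts passes a gcsx.Join-based JoinFn instead of relying on the filepath.Join default. filepath.Join runs Clean on its result, which collapses the double slash in a URL scheme (shown here on a Unix-like system):

```go
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	// filepath.Join cleans the path, collapsing "//" and breaking the URL.
	fmt.Println(filepath.Join("gs://bucket/staging", "a.jar")) // gs:/bucket/staging/a.jar
	// Joining as a string (as gcsx.Join does above) preserves URL semantics.
	fmt.Println("gs://bucket/staging" + "/" + "a.jar") // gs://bucket/staging/a.jar
}
```
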
diff --git a/sdks/go/pkg/beam/runners/universal/universal.go b/sdks/go/pkg/beam/runners/universal/universal.go
index f01a71f..af14b38 100644
--- a/sdks/go/pkg/beam/runners/universal/universal.go
+++ b/sdks/go/pkg/beam/runners/universal/universal.go
@@ -79,6 +79,9 @@
 		getEnvCfg = srv.EnvironmentConfig
 	}
 
+	// Fetch all dependencies for cross-language transforms
+	xlangx.ResolveArtifacts(ctx, edges, nil)
+
 	environment, err := graphx.CreateEnvironment(ctx, envUrn, getEnvCfg)
 	if err != nil {
 		return nil, errors.WithContextf(err, "generating model pipeline")
@@ -88,9 +91,6 @@
 		return nil, errors.WithContextf(err, "generating model pipeline")
 	}
 
-	// Fetch all dependencies for cross-language transforms
-	xlangx.ResolveArtifacts(ctx, edges, pipeline)
-
 	log.Info(ctx, proto.MarshalTextString(pipeline))
 
 	opt := &runnerlib.JobOptions{
diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go
index c756e05..7db20ac 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -78,12 +78,6 @@
 var dataflowFilters = []string{
 	// TODO(BEAM-11576): TestFlattenDup failing on this runner.
 	"TestFlattenDup",
-	// TODO(BEAM-11418): These tests require implementing x-lang artifact
-	// staging on Dataflow.
-	"TestXLang_CoGroupBy",
-	"TestXLang_Multi",
-	"TestXLang_Partition",
-	"TestXLang_Prefix",
 }
 
 // CheckFilters checks if an integration test is filtered to be skipped, either
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
index 3b782d9..c489d12 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFn.java
@@ -73,10 +73,9 @@
 
   @Override
   public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
-    Iterator<AccumT> it = accumulators.iterator();
-    AccumT first = it.next();
-    it.remove();
-    return getAggregateFn().mergeAccumulators(first, accumulators);
+    AccumT first = accumulators.iterator().next();
+    Iterable<AccumT> rest = new SkipFirstElementIterable<>(accumulators);
+    return getAggregateFn().mergeAccumulators(first, rest);
   }
 
   @Override
@@ -99,4 +98,20 @@
   public TypeVariable<?> getAccumTVariable() {
     return AggregateFn.class.getTypeParameters()[1];
   }
+
+  /** Wrapper {@link Iterable} which always skips its first element. */
+  private static class SkipFirstElementIterable<T> implements Iterable<T> {
+    private final Iterable<T> all;
+
+    SkipFirstElementIterable(Iterable<T> all) {
+      this.all = all;
+    }
+
+    @Override
+    public Iterator<T> iterator() {
+      Iterator<T> it = all.iterator();
+      it.next();
+      return it;
+    }
+  }
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 34664ac..7b580ef 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -27,6 +27,7 @@
 import org.apache.beam.sdk.schemas.Schema.TypeName;
 import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType;
 import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
+import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.BiMap;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableBiMap;
 import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
@@ -39,9 +40,6 @@
 import org.joda.time.base.AbstractInstant;
 
 /** Utility methods for Calcite related operations. */
-@SuppressWarnings({
-  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
 public class CalciteUtils {
   private static final long UNLIMITED_ARRAY_SIZE = -1L;
 
@@ -73,7 +71,9 @@
     }
 
     if (fieldType.getTypeName().isLogicalType()) {
-      String logicalId = fieldType.getLogicalType().getIdentifier();
+      Schema.LogicalType logicalType = fieldType.getLogicalType();
+      Preconditions.checkArgumentNotNull(logicalType);
+      String logicalId = logicalType.getIdentifier();
       return logicalId.equals(SqlTypes.DATE.getIdentifier())
           || logicalId.equals(SqlTypes.TIME.getIdentifier())
           || logicalId.equals(TimeWithLocalTzType.IDENTIFIER)
@@ -88,7 +88,9 @@
     }
 
     if (fieldType.getTypeName().isLogicalType()) {
-      String logicalId = fieldType.getLogicalType().getIdentifier();
+      Schema.LogicalType logicalType = fieldType.getLogicalType();
+      Preconditions.checkArgumentNotNull(logicalType);
+      String logicalId = logicalType.getIdentifier();
       return logicalId.equals(CharType.IDENTIFIER);
     }
     return false;
@@ -210,7 +212,12 @@
                     + "so it cannot be converted to a %s",
                 sqlTypeName, Schema.FieldType.class.getSimpleName()));
       default:
-        return CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+        FieldType fieldType = CALCITE_TO_BEAM_TYPE_MAPPING.get(sqlTypeName);
+        if (fieldType == null) {
+          throw new IllegalArgumentException(
+              "Cannot find a matching Beam FieldType for Calcite type: " + sqlTypeName);
+        }
+        return fieldType;
     }
   }
 
@@ -234,7 +241,12 @@
         return FieldType.row(toSchema(calciteType));
 
       default:
-        return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+        try {
+          return toFieldType(calciteType.getSqlTypeName()).withNullable(calciteType.isNullable());
+        } catch (IllegalArgumentException e) {
+          throw new IllegalArgumentException(
+              "Cannot find a matching Beam FieldType for Calcite type: " + calciteType, e);
+        }
     }
   }
 
@@ -254,16 +266,22 @@
     switch (fieldType.getTypeName()) {
       case ARRAY:
       case ITERABLE:
+        FieldType collectionElementType = fieldType.getCollectionElementType();
+        Preconditions.checkArgumentNotNull(collectionElementType);
         return dataTypeFactory.createArrayType(
-            toRelDataType(dataTypeFactory, fieldType.getCollectionElementType()),
-            UNLIMITED_ARRAY_SIZE);
+            toRelDataType(dataTypeFactory, collectionElementType), UNLIMITED_ARRAY_SIZE);
       case MAP:
-        RelDataType componentKeyType = toRelDataType(dataTypeFactory, fieldType.getMapKeyType());
-        RelDataType componentValueType =
-            toRelDataType(dataTypeFactory, fieldType.getMapValueType());
+        FieldType mapKeyType = fieldType.getMapKeyType();
+        FieldType mapValueType = fieldType.getMapValueType();
+        Preconditions.checkArgumentNotNull(mapKeyType);
+        Preconditions.checkArgumentNotNull(mapValueType);
+        RelDataType componentKeyType = toRelDataType(dataTypeFactory, mapKeyType);
+        RelDataType componentValueType = toRelDataType(dataTypeFactory, mapValueType);
         return dataTypeFactory.createMapType(componentKeyType, componentValueType);
       case ROW:
-        return toCalciteRowType(fieldType.getRowSchema(), dataTypeFactory);
+        Schema schema = fieldType.getRowSchema();
+        Preconditions.checkArgumentNotNull(schema);
+        return toCalciteRowType(schema, dataTypeFactory);
       default:
         return dataTypeFactory.createSqlType(toSqlTypeName(fieldType));
     }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
index 21ab8d0..cf3f40d 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/LazyAggregateCombineFnTest.java
@@ -19,12 +19,14 @@
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
 
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -41,6 +43,13 @@
     assertThat(coder, instanceOf(VarLongCoder.class));
   }
 
+  @Test
+  public void mergeAccumulators() {
+    LazyAggregateCombineFn<Long, Long, Long> combiner = new LazyAggregateCombineFn<>(new Sum());
+    long merged = combiner.mergeAccumulators(ImmutableList.of(1L, 1L));
+    assertEquals(2L, merged);
+  }
+
   public static class Sum implements AggregateFn<Long, Long, Long> {
 
     @Override
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
index 50b6ab2..e76ee7f 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java
@@ -30,13 +30,17 @@
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 /** Tests for conversion from Beam schema to Calcite data type. */
 public class CalciteUtilsTest {
 
   RelDataTypeFactory dataTypeFactory;
 
+  @Rule public ExpectedException thrown = ExpectedException.none();
+
   @Before
   public void setUp() {
     dataTypeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
@@ -166,4 +170,12 @@
 
     assertEquals(schema, out);
   }
+
+  @Test
+  public void testFieldTypeNotFound() {
+    RelDataType relDataType = dataTypeFactory.createUnknownType();
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN");
+    CalciteUtils.toFieldType(relDataType);
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
index 6413fd2..82c4b74 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubTableProviderIT.java
@@ -729,6 +729,17 @@
     Map<String, String> argsMap =
         ((Map<String, Object>) MAPPER.convertValue(pipeline.getOptions(), Map.class).get("options"))
             .entrySet().stream()
+                .filter(
+                    (entry) -> {
+                      if (entry.getValue() instanceof List) {
+                        if (!((List) entry.getValue()).isEmpty()) {
+                          throw new IllegalArgumentException("Cannot encode list arguments");
+                        }
+                        // We can encode empty lists, just omit them.
+                        return false;
+                      }
+                      return true;
+                    })
                 .collect(Collectors.toMap(Map.Entry::getKey, entry -> toArg(entry.getValue())));
 
     InMemoryMetaStore inMemoryMetaStore = new InMemoryMetaStore();
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 231d2b8..0ca4581 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -272,12 +272,14 @@
               finalizeBundleHandler,
               metricsShortIds);
 
+      BeamFnStatusClient beamFnStatusClient = null;
       if (statusApiServiceDescriptor != null) {
-        new BeamFnStatusClient(
-            statusApiServiceDescriptor,
-            channelFactory::forDescriptor,
-            processBundleHandler.getBundleProcessorCache(),
-            options);
+        beamFnStatusClient =
+            new BeamFnStatusClient(
+                statusApiServiceDescriptor,
+                channelFactory::forDescriptor,
+                processBundleHandler.getBundleProcessorCache(),
+                options);
       }
 
       // TODO(BEAM-9729): Remove once runners no longer send this instruction.
@@ -337,6 +339,9 @@
               executorService,
               handlers);
       control.waitForTermination();
+      if (beamFnStatusClient != null) {
+        beamFnStatusClient.close();
+      }
       processBundleHandler.shutdown();
     } finally {
       System.out.println("Shutting SDK harness down.");
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 9af7f58..296767b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -48,8 +48,6 @@
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 
 /**
  * The {@code PCollectionConsumerRegistry} is used to maintain a collection of consuming
@@ -216,7 +214,7 @@
     private final String pTransformId;
     private final SimpleExecutionState state;
     private final Counter unboundedElementCountCounter;
-    private final SampleByteSizeDistribution<T> unboundSampledByteSizeDistribution;
+    private final SampleByteSizeDistribution<T> unboundedSampledByteSizeDistribution;
     private final Coder<T> coder;
     private final MetricsContainer metricsContainer;
 
@@ -239,7 +237,7 @@
 
       MonitoringInfoMetricName sampledByteSizeMetricName =
           MonitoringInfoMetricName.named(Urns.SAMPLED_BYTE_SIZE, labels);
-      this.unboundSampledByteSizeDistribution =
+      this.unboundedSampledByteSizeDistribution =
           new SampleByteSizeDistribution<>(
               unboundMetricContainer.getDistribution(sampledByteSizeMetricName));
 
@@ -252,7 +250,7 @@
       // Increment the counter for each window the element occurs in.
       this.unboundedElementCountCounter.inc(input.getWindows().size());
       // TODO(BEAM-11879): Consider updating size per window when we have window optimization.
-      this.unboundSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
+      this.unboundedSampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
       // Wrap the consumer with extra logic to set the metric container with the appropriate
       // PTransform context. This ensures that user metrics obtain the pTransform ID when they are
       // created. Also use the ExecutionStateTracker and enter an appropriate state to track the
@@ -262,6 +260,7 @@
           this.delegate.accept(input);
         }
       }
+      this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
     }
   }
 
@@ -321,6 +320,7 @@
             consumerAndMetadata.getConsumer().accept(input);
           }
         }
+        this.unboundedSampledByteSizeDistribution.finishLazyUpdate();
       }
     }
   }
@@ -365,28 +365,33 @@
     }
 
     final Distribution distribution;
+    ByteSizeObserver byteCountObserver;
 
     public SampleByteSizeDistribution(Distribution distribution) {
       this.distribution = distribution;
+      this.byteCountObserver = null;
     }
 
     public void tryUpdate(T value, Coder<T> coder) throws Exception {
       if (shouldSampleElement()) {
         // First try using byte size observer
-        ByteSizeObserver observer = new ByteSizeObserver();
-        coder.registerByteSizeObserver(value, observer);
+        byteCountObserver = new ByteSizeObserver();
+        coder.registerByteSizeObserver(value, byteCountObserver);
 
-        if (!observer.getIsLazy()) {
-          observer.advance();
-          this.distribution.update(observer.observedSize);
-        } else {
-          // TODO(BEAM-11841): Optimize calculation of element size for iterables.
-          // Coder byte size observation is lazy (requires iteration for observation) so fall back
-          // to counting output stream
-          CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-          coder.encode(value, os);
-          this.distribution.update(os.getCount());
+        if (!byteCountObserver.getIsLazy()) {
+          byteCountObserver.advance();
+          this.distribution.update(byteCountObserver.observedSize);
         }
+      } else {
+        byteCountObserver = null;
+      }
+    }
+
+    public void finishLazyUpdate() {
+      // Advance lazy ElementByteSizeObservers, if any.
+      if (byteCountObserver != null && byteCountObserver.getIsLazy()) {
+        byteCountObserver.advance();
+        this.distribution.update(byteCountObserver.observedSize);
       }
     }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
index 4c01c04..e059471 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
@@ -25,6 +25,8 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.StringJoiner;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -42,9 +44,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class BeamFnStatusClient {
+public class BeamFnStatusClient implements AutoCloseable {
+  private static final Object COMPLETED = new Object();
   private final StreamObserver<WorkerStatusResponse> outboundObserver;
   private final BundleProcessorCache processBundleCache;
+  private final ManagedChannel channel;
+  private final CompletableFuture<Object> inboundObserverCompletion;
   private static final Logger LOG = LoggerFactory.getLogger(BeamFnStatusClient.class);
   private final MemoryMonitor memoryMonitor;
 
@@ -53,11 +58,12 @@
       Function<ApiServiceDescriptor, ManagedChannel> channelFactory,
       BundleProcessorCache processBundleCache,
       PipelineOptions options) {
-    BeamFnWorkerStatusGrpc.BeamFnWorkerStatusStub stub =
-        BeamFnWorkerStatusGrpc.newStub(channelFactory.apply(apiServiceDescriptor));
-    this.outboundObserver = stub.workerStatus(new InboundObserver());
+    this.channel = channelFactory.apply(apiServiceDescriptor);
+    this.outboundObserver =
+        BeamFnWorkerStatusGrpc.newStub(channel).workerStatus(new InboundObserver());
     this.processBundleCache = processBundleCache;
     this.memoryMonitor = MemoryMonitor.fromOptions(options);
+    this.inboundObserverCompletion = new CompletableFuture<>();
     Thread thread = new Thread(memoryMonitor);
     thread.setDaemon(true);
     thread.setPriority(Thread.MIN_PRIORITY);
@@ -65,6 +71,22 @@
     thread.start();
   }
 
+  @Override
+  public void close() throws Exception {
+    try {
+      Object completion = inboundObserverCompletion.get(1, TimeUnit.MINUTES);
+      if (completion != COMPLETED) {
+        LOG.warn("InboundObserver for BeamFnStatusClient completed with exception.");
+      }
+    } finally {
+      // Shut the channel down
+      channel.shutdown();
+      if (!channel.awaitTermination(10, TimeUnit.SECONDS)) {
+        channel.shutdownNow();
+      }
+    }
+  }
+
   /**
    * Class representing the execution state of a thread.
    *
@@ -222,9 +244,12 @@
     @Override
     public void onError(Throwable t) {
       LOG.error("Error getting SDK harness status", t);
+      inboundObserverCompletion.completeExceptionally(t);
     }
 
     @Override
-    public void onCompleted() {}
+    public void onCompleted() {
+      inboundObserverCompletion.complete(COMPLETED);
+    }
   }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index 90baa5e..708ca8b 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -24,6 +24,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -32,7 +33,9 @@
 import static org.powermock.api.mockito.PowerMockito.mockStatic;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
@@ -44,15 +47,19 @@
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
+import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -329,5 +336,108 @@
     verify(consumerA1).trySplit(0.3);
   }
 
+  @Test
+  public void testLazyByteSizeEstimation() throws Exception {
+    final String pCollectionA = "pCollectionA";
+    final String pTransformIdA = "pTransformIdA";
+
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class);
+
+    consumers.register(
+        pCollectionA, pTransformIdA, consumerA1, IterableCoder.of(StringUtf8Coder.of()));
+
+    FnDataReceiver<WindowedValue<Iterable<String>>> wrapperConsumer =
+        (FnDataReceiver<WindowedValue<Iterable<String>>>)
+            (FnDataReceiver) consumers.getMultiplexingConsumer(pCollectionA);
+    String elementValue = "elem";
+    long elementByteSize = StringUtf8Coder.of().getEncodedElementByteSize(elementValue);
+    WindowedValue<Iterable<String>> element =
+        valueInGlobalWindow(
+            new TestElementByteSizeObservableIterable<>(
+                Arrays.asList(elementValue, elementValue), elementByteSize));
+    int numElements = 10;
+    // Mock doing work on the iterable items
+    doAnswer(
+            (Answer<Void>)
+                invocation -> {
+                  Object[] args = invocation.getArguments();
+                  WindowedValue<Iterable<String>> arg = (WindowedValue<Iterable<String>>) args[0];
+                  Iterator it = arg.getValue().iterator();
+                  while (it.hasNext()) {
+                    it.next();
+                  }
+                  return null;
+                })
+        .when(consumerA1)
+        .accept(element);
+
+    for (int i = 0; i < numElements; i++) {
+      wrapperConsumer.accept(element);
+    }
+
+    // Check that the underlying consumers are each invoked per element.
+    verify(consumerA1, times(numElements)).accept(element);
+    assertThat(consumers.keySet(), contains(pCollectionA));
+
+    List<MonitoringInfo> expected = new ArrayList<>();
+
+    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    builder.setInt64SumValue(numElements);
+    expected.add(builder.build());
+
+    builder = new SimpleMonitoringInfoBuilder();
+    builder.setUrn(Urns.SAMPLED_BYTE_SIZE);
+    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, pCollectionA);
+    long expectedBytes =
+        (elementByteSize + 1) * 2
+            + 5; // Additional 5 bytes are due to size and hasNext = false (1 byte).
+    builder.setInt64DistributionValue(
+        DistributionData.create(
+            numElements * expectedBytes, numElements, expectedBytes, expectedBytes));
+    expected.add(builder.build());
+    // Only compare the monitoring infos that carry a PCollection label.
+    Iterable<MonitoringInfo> result =
+        Iterables.filter(
+            metricsContainerRegistry.getMonitoringInfos(),
+            monitoringInfo -> monitoringInfo.containsLabels(Labels.PCOLLECTION));
+
+    assertThat(result, containsInAnyOrder(expected.toArray()));
+  }
+
+  private class TestElementByteSizeObservableIterable<T>
+      extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> {
+    private List<T> elements;
+    private long elementByteSize;
+
+    public TestElementByteSizeObservableIterable(List<T> elements, long elementByteSize) {
+      this.elements = elements;
+      this.elementByteSize = elementByteSize;
+    }
+
+    @Override
+    protected ElementByteSizeObservableIterator createIterator() {
+      return new ElementByteSizeObservableIterator() {
+        private int index = 0;
+
+        @Override
+        public boolean hasNext() {
+          return index < elements.size();
+        }
+
+        @Override
+        public Object next() {
+          notifyValueReturned(elementByteSize);
+          return elements.get(index++);
+        }
+      };
+    }
+  }
+
   private abstract static class SplittingReceiver<T> implements FnDataReceiver<T>, HandlesSplits {}
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
index 4193ba6..fc1fe94 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQueryHelper.java
@@ -109,7 +109,10 @@
               .getReferencedTables();
       if (referencedTables != null && !referencedTables.isEmpty()) {
         TableReference referencedTable = referencedTables.get(0);
-        effectiveLocation = tableService.getTable(referencedTable).getLocation();
+        effectiveLocation =
+            tableService
+                .getDataset(referencedTable.getProjectId(), referencedTable.getDatasetId())
+                .getLocation();
       }
     }
 
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index 84743f8c..820418c 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -43,7 +43,7 @@
 
 configurations.testRuntimeClasspath {
   resolutionStrategy {
-    def log4j_version = "2.4.1"
+    def log4j_version = "2.8.2"
     // Beam's build system forces a uniform log4j version resolution for all modules, however for
     // the HCatalog case the current version of log4j produces NoClassDefFoundError so we need to
     // force an old version on the tests runtime classpath
diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py
index cc30cec..ee3c2a3 100644
--- a/sdks/python/apache_beam/dataframe/schemas.py
+++ b/sdks/python/apache_beam/dataframe/schemas.py
@@ -281,7 +281,8 @@
     ctor = element_type_from_dataframe(proxy, include_indexes=include_indexes)
 
     return beam.ParDo(
-        _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor))
+        _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor)
+    ).with_output_types(ctor)
   elif isinstance(proxy, pd.Series):
     # Raise a TypeError if proxy has an unknown type
     output_type = _dtype_to_fieldtype(proxy.dtype)
diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py
index 8b1159c..af30e25 100644
--- a/sdks/python/apache_beam/dataframe/schemas_test.py
+++ b/sdks/python/apache_beam/dataframe/schemas_test.py
@@ -36,6 +36,8 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 
 Simple = typing.NamedTuple(
     'Simple', [('name', unicode), ('id', int), ('height', float)])
@@ -65,41 +67,59 @@
 # dtype. For example:
 #   pd.Series([b'abc'], dtype=bytes).dtype != 'S'
 #   pd.Series([b'abc'], dtype=bytes).astype(bytes).dtype == 'S'
+# (test data, pandas_type, column_name, beam_type)
 COLUMNS = [
-    ([375, 24, 0, 10, 16], np.int32, 'i32'),
-    ([375, 24, 0, 10, 16], np.int64, 'i64'),
-    ([375, 24, None, 10, 16], pd.Int32Dtype(), 'i32_nullable'),
-    ([375, 24, None, 10, 16], pd.Int64Dtype(), 'i64_nullable'),
-    ([375., 24., None, 10., 16.], np.float64, 'f64'),
-    ([375., 24., None, 10., 16.], np.float32, 'f32'),
-    ([True, False, True, True, False], bool, 'bool'),
-    (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any'),
-    ([True, False, True, None, False], pd.BooleanDtype(), 'bool_nullable'),
+    ([375, 24, 0, 10, 16], np.int32, 'i32', np.int32),
+    ([375, 24, 0, 10, 16], np.int64, 'i64', np.int64),
+    ([375, 24, None, 10, 16],
+     pd.Int32Dtype(),
+     'i32_nullable',
+     typing.Optional[np.int32]),
+    ([375, 24, None, 10, 16],
+     pd.Int64Dtype(),
+     'i64_nullable',
+     typing.Optional[np.int64]),
+    ([375., 24., None, 10., 16.],
+     np.float64,
+     'f64',
+     typing.Optional[np.float64]),
+    ([375., 24., None, 10., 16.],
+     np.float32,
+     'f32',
+     typing.Optional[np.float32]),
+    ([True, False, True, True, False], bool, 'bool', bool),
+    (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any', typing.Any),
+    ([True, False, True, None, False],
+     pd.BooleanDtype(),
+     'bool_nullable',
+     typing.Optional[bool]),
     (['Falcon', 'Ostrich', None, 'Aardvark', 'Elephant'],
      pd.StringDtype(),
-     'strdtype'),
-]  # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str]]
+     'strdtype',
+     typing.Optional[str]),
+]  # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str, typing.Any]]
 
-NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name in COLUMNS])
-for arr, dtype, name in COLUMNS:
+NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name, _ in COLUMNS])
+for arr, dtype, name, _ in COLUMNS:
   NICE_TYPES_DF[name] = pd.Series(arr, dtype=dtype, name=name).astype(dtype)
 
 NICE_TYPES_PROXY = NICE_TYPES_DF[:0]
 
-SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr) for arr,
-                dtype,
-                name in COLUMNS]
+SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr, beam_type)
+                for (arr, dtype, name, beam_type) in COLUMNS]
 
 _TEST_ARRAYS = [
-    arr for arr, _, _ in COLUMNS
+    arr for (arr, _, _, _) in COLUMNS
 ]  # type: typing.List[typing.List[typing.Any]]
 DF_RESULT = list(zip(*_TEST_ARRAYS))
-INDEX_DF_TESTS = [
-    (NICE_TYPES_DF.set_index([name for _, _, name in COLUMNS[:i]]), DF_RESULT)
-    for i in range(1, len(COLUMNS) + 1)
-]
+BEAM_SCHEMA = typing.NamedTuple(  # type: ignore
+    'BEAM_SCHEMA', [(name, beam_type) for _, _, name, beam_type in COLUMNS])
+INDEX_DF_TESTS = [(
+    NICE_TYPES_DF.set_index([name for _, _, name, _ in COLUMNS[:i]]),
+    DF_RESULT,
+    BEAM_SCHEMA) for i in range(1, len(COLUMNS) + 1)]
 
-NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT)]
+NOINDEX_DF_TESTS = [(NICE_TYPES_DF, DF_RESULT, BEAM_SCHEMA)]
 
 PD_VERSION = tuple(int(n) for n in pd.__version__.split('.'))
 
@@ -203,8 +223,18 @@
                 proxy=schemas.generate_proxy(Animal)))
         assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)]))
 
+  def assert_typehints_equal(self, left, right):
+    left = typehints.normalize(left)
+    right = typehints.normalize(right)
+
+    if match_is_named_tuple(left):
+      self.assertTrue(match_is_named_tuple(right))
+      self.assertEqual(left.__annotations__, right.__annotations__)
+    else:
+      self.assertEqual(left, right)
+
   @parameterized.expand(SERIES_TESTS + NOINDEX_DF_TESTS)
-  def test_unbatch_no_index(self, df_or_series, rows):
+  def test_unbatch_no_index(self, df_or_series, rows, beam_type):
     proxy = df_or_series[:0]
 
     with TestPipeline() as p:
@@ -212,10 +242,15 @@
           p | beam.Create([df_or_series[::2], df_or_series[1::2]])
           | schemas.UnbatchPandas(proxy))
 
+      # Verify that the unbatched PCollection has the expected typehint
+      # TODO(BEAM-8538): typehints should support NamedTuple so we can use
+      # typehints.is_consistent_with here instead
+      self.assert_typehints_equal(res.element_type, beam_type)
+
       assert_that(res, equal_to(rows))
 
   @parameterized.expand(SERIES_TESTS + INDEX_DF_TESTS)
-  def test_unbatch_with_index(self, df_or_series, rows):
+  def test_unbatch_with_index(self, df_or_series, rows, _):
     proxy = df_or_series[:0]
 
     with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index c6d816a..e1f6635 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -123,6 +123,11 @@
 
   try:
     _load_main_session(semi_persistent_directory)
+  except CorruptMainSessionException:
+    exception_details = traceback.format_exc()
+    _LOGGER.error(
+        'Could not load main session: %s', exception_details, exc_info=True)
+    raise
   except Exception:  # pylint: disable=broad-except
     exception_details = traceback.format_exc()
     _LOGGER.error(
@@ -245,12 +250,29 @@
   return 0
 
 
+class CorruptMainSessionException(Exception):
+  """
+  Used to crash this worker if a main session file was provided but
+  is not valid.
+  """
+  pass
+
+
 def _load_main_session(semi_persistent_directory):
   """Loads a pickled main session from the path specified."""
   if semi_persistent_directory:
     session_file = os.path.join(
         semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
     if os.path.isfile(session_file):
+      # If the expected session file is present but empty, it's likely that
+      # the user code run by this worker will crash at runtime.
+      # This can happen if the worker fails to download the main session.
+      # Raise a fatal error and crash this worker, forcing a restart.
+      if os.path.getsize(session_file) == 0:
+        raise CorruptMainSessionException(
+            'Session file found, but empty: %s. Functions defined in __main__ '
+            '(interactive session) will almost certainly fail.' %
+            (session_file, ))
       pickler.load_session(session_file)
     else:
       _LOGGER.warning(
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 2b738bf..d77727e 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -107,7 +107,7 @@
   return getattr(user_type, '__origin__', None) is expected_origin
 
 
-def _match_is_named_tuple(user_type):
+def match_is_named_tuple(user_type):
   return (
       _safe_issubclass(user_type, typing.Tuple) and
       hasattr(user_type, '_field_types'))
@@ -234,7 +234,7 @@
       # We just convert it to Any for now.
       # This MUST appear before the entry for the normal Tuple.
       _TypeMapEntry(
-          match=_match_is_named_tuple, arity=0, beam_type=typehints.Any),
+          match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
       _TypeMapEntry(
           match=_match_issubclass(typing.Tuple),
           arity=-1,
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 6a299bd..5daf68a 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -70,10 +70,10 @@
 from apache_beam.typehints import row_type
 from apache_beam.typehints.native_type_compatibility import _get_args
 from apache_beam.typehints.native_type_compatibility import _match_is_exactly_mapping
-from apache_beam.typehints.native_type_compatibility import _match_is_named_tuple
 from apache_beam.typehints.native_type_compatibility import _match_is_optional
 from apache_beam.typehints.native_type_compatibility import _safe_issubclass
 from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.utils import proto_utils
 from apache_beam.utils.timestamp import Timestamp
 
@@ -148,7 +148,7 @@
 
 
 def typing_to_runner_api(type_):
-  if _match_is_named_tuple(type_):
+  if match_is_named_tuple(type_):
     schema = None
     if hasattr(type_, _BEAM_SCHEMA_ID):
       schema = SCHEMA_REGISTRY.get_schema_by_id(getattr(type_, _BEAM_SCHEMA_ID))
@@ -287,7 +287,7 @@
     # TODO(BEAM-10722): Make sure beam.Row generated schemas are registered and
     # de-duped
     return named_fields_to_schema(element_type._fields)
-  elif _match_is_named_tuple(element_type):
+  elif match_is_named_tuple(element_type):
     return named_tuple_to_schema(element_type)
   else:
     raise TypeError(