Register MapCoder, some comments/cleanup. (#15471)
* Register Mapcoder, some comments/cleanup.
* Fix incorrect type declarations in spannerio.
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 370110d..05f7a9d 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -531,6 +531,13 @@
return coder_impl.MapCoderImpl(
self._key_coder.get_impl(), self._value_coder.get_impl())
+ @classmethod
+ def from_type_hint(cls, typehint, registry):
+ # type: (typehints.DictConstraint, CoderRegistry) -> MapCoder
+ return cls(
+ registry.get_coder(typehint.key_type),
+ registry.get_coder(typehint.value_type))
+
def to_type_hint(self):
return typehints.Dict[self._key_coder.to_type_hint(),
self._value_coder.to_type_hint()]
diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py
index 8b13b46..4d67f8b 100644
--- a/sdks/python/apache_beam/coders/row_coder.py
+++ b/sdks/python/apache_beam/coders/row_coder.py
@@ -87,10 +87,10 @@
def from_runner_api_parameter(schema, components, unused_context):
return RowCoder(schema)
- @staticmethod
- def from_type_hint(type_hint, registry):
+ @classmethod
+ def from_type_hint(cls, type_hint, registry):
schema = schema_from_element_type(type_hint)
- return RowCoder(schema)
+ return cls(schema)
@staticmethod
def from_payload(payload):
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 4fe5f3b..03b6cee 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -87,12 +87,14 @@
def register_standard_coders(self, fallback_coder):
"""Register coders for all basic and composite types."""
+ # Coders without subclasses.
self._register_coder_internal(int, coders.VarIntCoder)
self._register_coder_internal(float, coders.FloatCoder)
self._register_coder_internal(bytes, coders.BytesCoder)
self._register_coder_internal(bool, coders.BooleanCoder)
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
+ self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [coders.ProtoCoder, coders.FastPrimitivesCoder]
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
index bb8c5b8..d50b3a8 100644
--- a/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio.py
@@ -302,7 +302,7 @@
return snapshot_options
-@with_input_types(ReadOperation, typing.Dict[typing.Any, typing.Any])
+@with_input_types(ReadOperation, _SPANNER_TRANSACTION)
@with_output_types(typing.List[typing.Any])
class _NaiveSpannerReadDoFn(DoFn):
def __init__(self, spanner_configuration):
@@ -422,7 +422,7 @@
@with_input_types(int)
-@with_output_types(typing.Dict[typing.Any, typing.Any])
+@with_output_types(_SPANNER_TRANSACTION)
class _CreateTransactionFn(DoFn):
"""
A DoFn to create the transaction of cloud spanner.
diff --git a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
index 4cdf294..e3d1965 100644
--- a/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
+++ b/sdks/python/apache_beam/io/gcp/experimental/spannerio_test.py
@@ -19,6 +19,7 @@
import logging
import random
import string
+import typing
import unittest
import mock
@@ -336,7 +337,10 @@
self, mock_batch_snapshot_class, mock_client_class):
with self.assertRaises(ValueError):
p = TestPipeline()
- transaction = (p | beam.Create([{"invalid": "transaction"}]))
+ transaction = (
+ p | beam.Create([{
+ "invalid": "transaction"
+ }]).with_output_types(typing.Any))
_ = (
p | 'with query' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index d0aaa2f..287ec49 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -300,6 +300,7 @@
original_transform_node.full_label,
original_transform_node.main_inputs)
+ # TODO(BEAM-12854): Merge rather than override.
replacement_transform_node.resource_hints = (
original_transform_node.resource_hints)
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 28024fa..63bbece 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -178,7 +178,8 @@
pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)}
restore_tags = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags))
- result = pcolls_dict | _CoGBKImpl(pipeline=self.pipeline)
+ result = (
+ pcolls_dict | 'CoGroupByKeyImpl' >> _CoGBKImpl(pipeline=self.pipeline))
if restore_tags:
return result | 'RestoreTags' >> MapTuple(
lambda k, vs: (k, restore_tags(vs)))