Python support for directly using Java transforms using constructor and builder methods.
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 4b8838f..1af6972 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -29,6 +29,7 @@
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardRequirements
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardResourceHints
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardSideInputTypes
+from apache_beam.portability.api.external_transforms_pb2 import ExpansionMethods
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfo
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfoSpecs
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfoTypeUrns
@@ -66,3 +67,5 @@
requirements = StandardRequirements.Enum
displayData = StandardDisplayData.DisplayData
+
+java_class_lookup = ExpansionMethods.JAVA_CLASS_LOOKUP
\ No newline at end of file
diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py
index 305cb97..d1e8ab7 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -25,6 +25,7 @@
import copy
import functools
import threading
+from collections import OrderedDict
from typing import Dict
import grpc
@@ -36,7 +37,9 @@
from apache_beam.portability.api import beam_expansion_api_pb2
from apache_beam.portability.api import beam_expansion_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api.external_transforms_pb2 import BuilderMethod
from apache_beam.portability.api.external_transforms_pb2 import ExternalConfigurationPayload
+from apache_beam.portability.api.external_transforms_pb2 import JavaClassLookupPayload
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import artifact_service
from apache_beam.transforms import ptransform
@@ -144,6 +147,95 @@
return self._tuple_instance
+class JavaClassLookupPayloadBuilder(PayloadBuilder):
+ """
+ Builds a payload for directly instantiating a Java transform using a
+ constructor and builder methods.
+ """
+
+ def __init__(self, class_name):
+ """
+ :param class_name: fully qualified name of the transform class.
+ """
+ if not class_name:
+ raise ValueError('Class name must not be empty')
+
+ self._class_name = class_name
+ self._constructor_method = None
+ self._constructor_params = None
+ self._builder_methods_and_params = OrderedDict()
+
+ def _get_schema_proto_and_payload(self, **kwargs):
+ named_fields = []
+ for key, value in kwargs.items():
+ if not key:
+ raise ValueError('Parameter name cannot be empty')
+ if value is None:
+ raise ValueError('Received value None for key %s. None values are currently not supported' % key)
+ named_fields.append((key, convert_to_typing_type(instance_to_type(value))))
+
+ schema_proto = named_fields_to_schema(named_fields)
+ row = named_tuple_from_schema(schema_proto)(**kwargs)
+ schema = named_tuple_to_schema(type(row))
+
+ payload = RowCoder(schema).encode(row)
+ return (schema_proto, payload)
+
+ def build(self):
+ constructor_params = self._constructor_params or {}
+ constructor_schema, constructor_payload = self._get_schema_proto_and_payload(**constructor_params)
+ payload = JavaClassLookupPayload(
+ class_name=self._class_name,
+ constructor_schema=constructor_schema,
+ constructor_payload=constructor_payload)
+ if self._constructor_method:
+ payload.constructor_method = self._constructor_method
+
+ for builder_method_name, params in self._builder_methods_and_params.items():
+ builder_method_schema, builder_method_payload = self._get_schema_proto_and_payload(**params)
+ builder_method = BuilderMethod(
+ name=builder_method_name,
+ schema=builder_method_schema,
+ payload=builder_method_payload)
+ builder_method.name = builder_method_name
+ payload.builder_methods.append(builder_method)
+ return payload
+
+ def add_constructor(self, **kwargs):
+ """
+ Specifies the Java constructor to use.
+
+ :param kwargs: parameter names and values of the constructor.
+ """
+ if self._constructor_method or self._constructor_params:
+ raise ValueError('Constructor or constructor method can only be specified once')
+
+ self._constructor_params = kwargs
+
+ def add_constructor_method(self, method_name, **kwargs):
+ """
+ Specifies the Java constructor method to use.
+
+ :param method_name: name of the constructor method.
+ :param kwargs: parameter names and values of the constructor method.
+ """
+ if self._constructor_method or self._constructor_params:
+ raise ValueError('Constructor or constructor method can only be specified once')
+
+ self._constructor_method = method_name
+ self._constructor_params = kwargs
+
+ def add_builder_method(self, method_name, **kwargs):
+ """
+ Specifies a Java builder method to be invoked after instantiating the Java
+ transform class. Specified builder method will be applied in order.
+
+ :param method_name: name of the builder method.
+ :param kwargs: parameter names and values of the builder method.
+ """
+ self._builder_methods_and_params[method_name] = kwargs
+
+
class AnnotationBasedPayloadBuilder(SchemaBasedPayloadBuilder):
"""
Build a payload based on an external transform's type annotations.
@@ -212,6 +304,8 @@
or an address (as a str) to a grpc server that provides this method.
"""
expansion_service = expansion_service or DEFAULT_EXPANSION_SERVICE
+ if not urn and isinstance(payload, JavaClassLookupPayloadBuilder):
+ urn = common_urns.java_class_lookup
self._urn = urn
self._payload = (
payload.payload() if isinstance(payload, PayloadBuilder) else payload)
diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py
index c357f7a..69da239 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -28,17 +28,21 @@
from apache_beam import Pipeline
from apache_beam.coders import RowCoder
from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.portability.api.external_transforms_pb2 import BuilderMethod
from apache_beam.portability.api.external_transforms_pb2 import ExternalConfigurationPayload
+from apache_beam.portability.api.external_transforms_pb2 import JavaClassLookupPayload
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import expansion_service
from apache_beam.runners.portability.expansion_service_test import FibTransform
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.external import AnnotationBasedPayloadBuilder
+from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
+from apache_beam.utils import proto_utils
# Protect against environments where apitools library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
@@ -405,6 +409,71 @@
return get_payload(DataclassTransform(**values))
+class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
+
+ def _verify_row(self, schema, row_payload, expected_values):
+ row = RowCoder(schema).decode(row_payload)
+
+ for attr_name, expected_value in expected_values.items():
+ self.assertTrue(hasattr(row,attr_name))
+ value = getattr(row, attr_name)
+ self.assertEqual(expected_value, value)
+
+ def test_build_payload_with_constructor(self):
+ payload_builder = JavaClassLookupPayloadBuilder('dummy_class_name')
+
+ payload_builder.add_constructor(str_field='abc', int_field=123)
+ payload_bytes = payload_builder.payload()
+ payload_from_bytes = proto_utils.parse_Bytes(payload_bytes, JavaClassLookupPayload)
+ self.assertTrue(isinstance(payload_from_bytes, JavaClassLookupPayload))
+ self._verify_row(
+ payload_from_bytes.constructor_schema,
+ payload_from_bytes.constructor_payload,
+ {'str_field': 'abc', 'int_field': 123})
+
+ def test_build_payload_with_constructor_method(self):
+ payload_builder = JavaClassLookupPayloadBuilder('dummy_class_name')
+ payload_builder.add_constructor_method('dummy_constructor_method', str_field='abc', int_field=123)
+ payload_bytes = payload_builder.payload()
+ payload_from_bytes = proto_utils.parse_Bytes(payload_bytes, JavaClassLookupPayload)
+ self.assertTrue(isinstance(payload_from_bytes, JavaClassLookupPayload))
+ self.assertEqual('dummy_constructor_method',
+ payload_from_bytes.constructor_method)
+ self._verify_row(
+ payload_from_bytes.constructor_schema,
+ payload_from_bytes.constructor_payload,
+ {'str_field': 'abc', 'int_field': 123})
+
+ def test_build_payload_with_builder_methods(self):
+ payload_builder = JavaClassLookupPayloadBuilder('dummy_class_name')
+ payload_builder.add_constructor(str_field='abc', int_field=123)
+ payload_builder.add_builder_method('builder_method1', str_field1='abc1', int_field1=1234)
+ payload_builder.add_builder_method('builder_method2', str_field2='abc2', int_field2=5678)
+ payload_bytes = payload_builder.payload()
+ payload_from_bytes = proto_utils.parse_Bytes(payload_bytes, JavaClassLookupPayload)
+ self.assertTrue(isinstance(payload_from_bytes, JavaClassLookupPayload))
+ self._verify_row(
+ payload_from_bytes.constructor_schema,
+ payload_from_bytes.constructor_payload,
+ {'str_field': 'abc', 'int_field': 123})
+ self.assertEqual(2, len(payload_from_bytes.builder_methods))
+ builder_method = payload_from_bytes.builder_methods[0]
+ self.assertTrue(isinstance(builder_method, BuilderMethod))
+ self.assertEqual('builder_method1', builder_method.name)
+
+ self._verify_row(
+ builder_method.schema,
+ builder_method.payload,
+ {'str_field1': 'abc1', 'int_field1': 1234})
+
+ builder_method = payload_from_bytes.builder_methods[1]
+ self.assertTrue(isinstance(builder_method, BuilderMethod))
+ self.assertEqual('builder_method2', builder_method.name)
+ self._verify_row(
+ builder_method.schema,
+ builder_method.payload,
+ {'str_field2': 'abc2', 'int_field2': 5678})
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)