[BEAM-8157] Add integration test for non standard key coders
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 24c6b87..992aa08 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -45,7 +45,10 @@
from apache_beam.runners.portability.portable_runner import PortableRunner
from apache_beam.runners.worker import worker_pool_main
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
from apache_beam.transforms import environments
+from apache_beam.transforms import userstate
class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -188,6 +191,46 @@
def test_metrics(self):
self.skipTest('Metrics not supported.')
+ def test_pardo_state_with_custom_key_coder(self):
+ """Tests that state requests work correctly when the key coder is an
+ SDK-specific coder, i.e. non standard coder. This is additionally enforced
+ by Java's ProcessBundleDescriptorsTest and by Flink's
+ ExecutableStageDoFnOperator which detects invalid encoding by checking for
+ the correct key group of the encoded key."""
+ index_state_spec = userstate.CombiningValueStateSpec('index', sum)
+
+ # Test params
+ # Ensure decent amount of elements to serve all partitions
+ n = 200
+ duplicates = 1
+
+ split = n // (duplicates + 1)
+ inputs = [(i % split, str(i % split)) for i in range(0, n)]
+
+ # Use a DoFn which has to use FastPrimitivesCoder because the type cannot
+ # be inferred
+ class Input(beam.DoFn):
+ def process(self, impulse):
+ for i in inputs:
+ yield i
+
+ class AddIndex(beam.DoFn):
+ def process(self, kv,
+ index=beam.DoFn.StateParam(index_state_spec)):
+ k, v = kv
+ index.add(1)
+ yield k, v, index.read()
+
+ expected = [(i % split, str(i % split), i // split + 1)
+ for i in range(0, n)]
+
+ with self.create_pipeline() as p:
+ assert_that(p
+ | beam.Impulse()
+ | beam.ParDo(Input())
+ | beam.ParDo(AddIndex()),
+ equal_to(expected))
+
# Inherits all other tests from fn_api_runner_test.FnApiRunnerTest