| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Tests common to all coder implementations.""" |
| # pytype: skip-file |
| |
| import base64 |
| import collections |
| import enum |
| import logging |
| import math |
| import os |
| import pickle |
| import subprocess |
| import sys |
| import textwrap |
| import unittest |
| from decimal import Decimal |
| from typing import Any |
| from typing import List |
| from typing import NamedTuple |
| |
| import pytest |
| from parameterized import param |
| from parameterized import parameterized |
| |
| from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message |
| from apache_beam.coders import coders |
| from apache_beam.coders import typecoders |
| from apache_beam.internal import pickler |
| from apache_beam.runners import pipeline_context |
| from apache_beam.transforms import userstate |
| from apache_beam.transforms import window |
| from apache_beam.transforms.window import GlobalWindow |
| from apache_beam.typehints import sharded_key_type |
| from apache_beam.typehints import typehints |
| from apache_beam.utils import timestamp |
| from apache_beam.utils import windowed_value |
| from apache_beam.utils.sharded_key import ShardedKey |
| from apache_beam.utils.timestamp import MIN_TIMESTAMP |
| |
| from . import observable |
| |
| try: |
| import dataclasses |
| except ImportError: |
| dataclasses = None # type: ignore |
| |
| try: |
| import dill |
| except ImportError: |
| dill = None |
| |
# Named-tuple fixtures. MyNamedTuple's type name ('A') intentionally differs
# from the module attribute it is bound to, so it cannot be located by
# reference under its own name.
MyNamedTuple = collections.namedtuple('A', 'x y')  # type: ignore[name-match]
AnotherNamedTuple = collections.namedtuple('AnotherNamedTuple', 'x y')
MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
| |
| |
class MyEnum(enum.Enum):
  # Enum fixture whose members mix an explicit int, an auto()-generated
  # value and a str value.
  E1 = 5
  E2 = enum.auto()
  E3 = 'abc'
| |
| |
# Enum fixtures built with the functional API (members I1..I3 / F1..F3).
MyIntEnum = enum.IntEnum('MyIntEnum', ['I1', 'I2', 'I3'])
MyIntFlag = enum.IntFlag('MyIntFlag', ['F1', 'F2', 'F3'])
MyFlag = enum.Flag('MyFlag', ['F1', 'F2', 'F3'])  # pylint: disable=too-many-function-args
| |
| |
class DefinesGetState:
  """Fixture whose pickled state is its raw ``value``; no ``__setstate__``."""
  def __init__(self, value):
    self.value = value

  def __getstate__(self):
    return self.value

  def __eq__(self, other):
    # Only instances of exactly the same type can compare equal.
    if type(other) is not type(self):
      return False
    return other.value == self.value
| |
| |
class DefinesGetAndSetState(DefinesGetState):
  """Like DefinesGetState, but restores its state via ``__setstate__``."""
  def __setstate__(self, state):
    self.value = state
| |
| |
# Kept at module level (rather than inside a test) so it can be pickled.
class CustomCoder(coders.Coder):
  """Toy coder: writes ``x + 1`` as UTF-8 text and subtracts 1 on decode."""
  def encode(self, x):
    shifted = x + 1
    return str(shifted).encode('utf-8')

  def decode(self, encoded):
    shifted = int(encoded)
    return shifted - 1
| |
| |
if dataclasses:
  # Immutable dataclass (frozen=True also makes it hashable); field `a`
  # may hold any value, including a mutable one.
  @dataclasses.dataclass(frozen=True)
  class FrozenDataClass:
    a: Any
    b: int

  # Mutable counterpart; the deterministic-coder tests expect it to be
  # rejected.
  @dataclasses.dataclass(frozen=False)
  class UnFrozenDataClass:
    x: int
    y: int
| |
| |
| # These tests need to all be run in the same process due to the asserts |
| # in tearDownClass. |
| @pytest.mark.no_xdist |
| @pytest.mark.uses_dill |
| class CodersTest(unittest.TestCase): |
| |
# setUpClass/tearDownClass and _observe below record every coder used, so
# that each defined coder is verified in both nested and unnested context.

# Values covering Python's built-in types. The first list is safe for
# deterministic coders; the second adds dicts, sets and a function.
test_values_deterministic: List[Any] = [
    None,
    1,
    -1,
    1.5,
    b'str\0str',
    'unicode\0\u0101',
    (),
    (1, 2, 3),
    [],
    [1, 2, 3],
    True,
    False,
]
test_values = [
    *test_values_deterministic,
    {},
    {'a': 'b'},
    {0: {}, 1: len},
    set(),
    {'a', 'b'},
    len,
]
| |
@classmethod
def setUpClass(cls):
  """Reset the records of coder types exercised by this test class."""
  # seen: every coder type checked directly; seen_nested: component types
  # discovered inside TupleCoders.
  cls.seen, cls.seen_nested = set(), set()
| |
@classmethod
def tearDownClass(cls):
  """Assert coverage of the coders module after all tests have run.

  Every concrete Coder subclass in ``coders`` (minus the exclusions below)
  must have been exercised, and every coder seen nested inside a TupleCoder
  must itself be one of those standard coders.
  """
  standard = {
      c
      for c in coders.__dict__.values()
      if isinstance(c, type) and issubclass(c, coders.Coder) and
      'Base' not in c.__name__
  }
  # Coder types excluded from the coverage requirement.
  standard -= {
      coders.Coder,
      coders.AvroGenericCoder,
      coders.DeterministicProtoCoder,
      coders.FastCoder,
      coders.ListLikeCoder,
      coders.ProtoCoder,
      coders.ProtoPlusCoder,
      coders.BigEndianShortCoder,
      coders.SinglePrecisionFloatCoder,
      coders.ToBytesCoder,
      coders.BigIntegerCoder,  # tested in DecimalCoder
      coders.TimestampPrefixingOpaqueWindowCoder,
  }
  if not dill:
    # Dill-backed coders cannot be exercised without dill installed.
    standard -= {coders.DillCoder, coders.DeterministicFastPrimitivesCoder}
  # These appear nested in tuple tests but are not in `standard`
  # (CustomCoder is defined in this test module).
  cls.seen_nested -= {coders.ProtoCoder, coders.ProtoPlusCoder, CustomCoder}
  unexercised = standard - cls.seen
  assert not unexercised, str(unexercised)
  unexpected_nested = cls.seen_nested - standard
  assert not unexpected_nested, str(unexpected_nested)
| |
def tearDown(self):
  """Clear any compatibility-version override a test installed."""
  registry = typecoders.registry
  registry.update_compatibility_version = None
| |
@classmethod
def _observe(cls, coder):
  """Record the coder's own type, then the types of its nested components."""
  coder_type = type(coder)
  cls.seen.add(coder_type)
  cls._observe_nested(coder)
| |
@classmethod
def _observe_nested(cls, coder):
  """Recursively record component coder types of (nested) TupleCoders."""
  if not isinstance(coder, coders.TupleCoder):
    return
  for component in coder.coders():
    cls.seen_nested.add(type(component))
    cls._observe_nested(component)
| |
def check_coder(self, coder, *values, **kwargs):
  """Round-trip ``values`` through ``coder`` and through copies of it.

  Args:
    coder: the coder under test; its type is recorded for coverage checks.
    *values: values that must survive encode/decode unchanged.
    **kwargs: optional ``context`` (PipelineContext used for the runner-API
      round trip; a fresh one is created when omitted or None) and
      ``test_size_estimation`` (default True) to also verify that the size
      estimates match the actual encoded length.
  """
  context = kwargs.pop('context', None)
  test_size_estimation = kwargs.pop('test_size_estimation', True)
  assert not kwargs, 'Unexpected keyword arguments: %s' % sorted(kwargs)
  if context is None:
    # Built lazily so callers that supply their own context don't pay for
    # constructing an unused default one.
    context = pipeline_context.PipelineContext()
  self._observe(coder)
  for v in values:
    self.assertEqual(v, coder.decode(coder.encode(v)))
    if test_size_estimation:
      self.assertEqual(coder.estimate_size(v), len(coder.encode(v)))
      self.assertEqual(
          coder.estimate_size(v), coder.get_impl().estimate_size(v))
      self.assertEqual(
          coder.get_impl().get_estimated_size_and_observables(v),
          (coder.get_impl().estimate_size(v), []))
  # One copy via pickling, one via the runner-API proto; they must be able
  # to decode each other's output.
  copy1 = pickler.loads(pickler.dumps(coder))
  copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
  for v in values:
    self.assertEqual(v, copy1.decode(copy2.encode(v)))
    if coder.is_deterministic():
      self.assertEqual(copy1.encode(v), copy2.encode(v))
| |
def test_custom_coder(self):
  """A user-defined coder works both on its own and inside a TupleCoder."""
  self.check_coder(CustomCoder(), 1, -10, 5)
  pair_coder = coders.TupleCoder((CustomCoder(), coders.BytesCoder()))
  self.check_coder(pair_coder, (1, b'a'), (-10, b'b'), (5, b'c'))
| |
def test_pickle_coder(self):
  """PickleCoder round-trips every built-in test value."""
  self.check_coder(coders.PickleCoder(), *self.test_values)
| |
def test_cloudpickle_pickle_coder(self):
  """CloudpickleCoder handles closure cells, bare and inside a tuple."""
  cell = (lambda x: lambda: x)(0).__closure__[0]
  self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell)
  nested = coders.TupleCoder(
      (coders.VarIntCoder(), coders.CloudpickleCoder()))
  self.check_coder(nested, (1, cell))
| |
def test_memoizing_pickle_coder(self):
  """_MemoizingPickleCoder round-trips every built-in test value."""
  self.check_coder(coders._MemoizingPickleCoder(), *self.test_values)
| |
@parameterized.expand([
    param(compat_version=None),
    param(compat_version="2.67.0"),
    param(compat_version="2.68.0"),
])
def test_deterministic_coder(self, compat_version):
  """Test in-process determinism for all special deterministic types.

  - In SDK version <= 2.67.0 dill is used to encode "special types"
  - In SDK version 2.68.0 cloudpickle is used to encode "special types" with
    absolute filepaths in code objects and dynamic functions.
  - In SDK version >=2.69.0 cloudpickle is used to encode "special types"
    with relative filepaths in code objects and dynamic functions.
  """

  typecoders.registry.update_compatibility_version = compat_version
  coder = coders.FastPrimitivesCoder()
  if not dill and compat_version == "2.67.0":
    # Without dill, the 2.67.0 path must refuse to build a deterministic
    # coder; nothing else can be checked for this parameter.
    with self.assertRaises(RuntimeError):
      coder.as_deterministic_coder(step_label="step")
    self.skipTest('Dill not installed')
  deterministic_coder = coder.as_deterministic_coder(step_label="step")

  # Built-in values: unnested, each wrapped in a 1-tuple, and all together.
  self.check_coder(deterministic_coder, *self.test_values_deterministic)
  for v in self.test_values_deterministic:
    self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
  self.check_coder(
      coders.TupleCoder(
          (deterministic_coder, ) * len(self.test_values_deterministic)),
      tuple(self.test_values_deterministic))

  # Dicts are accepted unless their keys mix types (e.g. int and str) —
  # presumably because such keys cannot be put in a canonical order.
  self.check_coder(deterministic_coder, {})
  self.check_coder(deterministic_coder, {2: 'x', 1: 'y'})
  with self.assertRaises(TypeError):
    self.check_coder(deterministic_coder, {1: 'x', 'y': 2})
  self.check_coder(deterministic_coder, [1, {}])
  with self.assertRaises(TypeError):
    self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}])

  # Deterministic and non-deterministic components mixed in one tuple.
  self.check_coder(
      coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}]))

  self.check_coder(deterministic_coder, test_message.MessageA(field1='value'))

  # Only exercise MyNamedTuple on the dill path (compat_version 2.67.0).
  # MyNamedTuple is an "anonymous" named tuple (its type name does not match
  # its module attribute), so it is not pickleable by reference; dill
  # monkey-patches __reduce__ for such types, and because this test is
  # parameterized, the patched type gets clobbered across runs.
  if compat_version == "2.67.0":
    self.check_coder(
        deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])

  self.check_coder(
      deterministic_coder,
      [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])

  if dataclasses is not None:
    # Frozen dataclasses are accepted; mutable ones are rejected, including
    # when nested inside a frozen dataclass or a named tuple.
    self.check_coder(deterministic_coder, FrozenDataClass(1, 2))

    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
    with self.assertRaises(TypeError):
      self.check_coder(
          deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
    with self.assertRaises(TypeError):
      self.check_coder(
          deterministic_coder,
          AnotherNamedTuple(UnFrozenDataClass(1, 2), 3))

  self.check_coder(deterministic_coder, list(MyEnum))
  self.check_coder(deterministic_coder, list(MyIntEnum))
  self.check_coder(deterministic_coder, list(MyIntFlag))
  self.check_coder(deterministic_coder, list(MyFlag))

  # Objects with __getstate__ need a matching __setstate__, and their state
  # must itself be deterministically encodable.
  self.check_coder(
      deterministic_coder,
      [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))])

  with self.assertRaises(TypeError):
    self.check_coder(deterministic_coder, DefinesGetState(1))
  with self.assertRaises(TypeError):
    self.check_coder(
        deterministic_coder, DefinesGetAndSetState({
            1: 'x', 'y': 2
        }))
| |
@parameterized.expand([
    param(compat_version=None),
    param(compat_version="2.67.0"),
    param(compat_version="2.68.0"),
])
def test_deterministic_map_coder_is_update_compatible(self, compat_version):
  """Test in-process determinism for map coder including when a component
  coder uses DeterministicFastPrimitivesCoder for "special types".

  - In SDK version <= 2.67.0 dill is used to encode "special types"
  - In SDK version 2.68.0 cloudpickle is used to encode "special types" with
    absolute filepaths in code objects and dynamic functions.
  - In SDK version >=2.69.0 cloudpickle is used to encode "special types"
    with relative filepaths in code objects and dynamic functions.
  """
  typecoders.registry.update_compatibility_version = compat_version
  # Named-tuple keys and values are among the "special types".
  values = [{
      MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i)
      for i in range(10)
  }]

  coder = coders.MapCoder(
      coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder())

  if not dill and compat_version == "2.67.0":
    with self.assertRaises(RuntimeError):
      coder.as_deterministic_coder(step_label="step")
    self.skipTest('Dill not installed')

  deterministic_coder = coder.as_deterministic_coder(step_label="step")

  # Of the parameters above, only 2.67.0 expects the original (V1) key coder.
  expected_key_coder_type = (
      coders.DeterministicFastPrimitivesCoderV2
      if compat_version in (None, "2.68.0") else
      coders.DeterministicFastPrimitivesCoder)
  # assertIsInstance (unlike a bare assert) reports the actual type on
  # failure.
  self.assertIsInstance(
      deterministic_coder._key_coder, expected_key_coder_type)

  self.check_coder(deterministic_coder, *values)
| |
def test_dill_coder(self):
  """DillCoder handles closure cells; it must fail fast without dill."""
  if not dill:
    with self.assertRaises(RuntimeError):
      coders.DillCoder()
    self.skipTest('Dill not installed')

  cell = (lambda x: lambda: x)(0).__closure__[0]
  self.check_coder(coders.DillCoder(), 'a', 1, cell)
  nested = coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder()))
  self.check_coder(nested, (1, cell))
| |
def test_fast_primitives_coder(self):
  """FastPrimitivesCoder (with a `len` fallback) on all built-in values."""
  fast_coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
  self.check_coder(fast_coder, *self.test_values)
  for value in self.test_values:
    wrapped = coders.TupleCoder((fast_coder, ))
    self.check_coder(wrapped, (value, ))
| |
def test_fast_primitives_coder_large_int(self):
  """FastPrimitivesCoder round-trips an integer far beyond 64 bits."""
  googol = 10**100
  self.check_coder(coders.FastPrimitivesCoder(), googol)
| |
def test_fake_deterministic_fast_primitives_coder(self):
  """FakeDeterministicFastPrimitivesCoder on all built-in values."""
  fake_coder = coders.FakeDeterministicFastPrimitivesCoder(
      coders.PickleCoder())
  self.check_coder(fake_coder, *self.test_values)
  for value in self.test_values:
    wrapped = coders.TupleCoder((fake_coder, ))
    self.check_coder(wrapped, (value, ))
| |
def test_bytes_coder(self):
  """BytesCoder: short, NUL-only and long byte strings."""
  samples = (b'a', b'\0', b'z' * 1000)
  self.check_coder(coders.BytesCoder(), *samples)
| |
def test_bool_coder(self):
  """BooleanCoder round-trips both truth values."""
  self.check_coder(coders.BooleanCoder(), *(True, False))
| |
def test_varint_coder(self):
  """VarIntCoder: small ints, the multi-byte boundary and large values."""
  self.check_coder(coders.VarIntCoder(), *range(-10, 10))
  # A second byte is needed from 128 upwards.
  self.check_coder(coders.VarIntCoder(), *range(120, 140))
  # Alternating-sign values of magnitude e**k, up to ~2**63.
  max_int64 = 0x7fffffffffffffff
  large_values = [
      int(math.pow(-1, k) * math.exp(k))
      for k in range(int(math.log(max_int64)))
  ]
  self.check_coder(coders.VarIntCoder(), *large_values)
| |
def test_varint32_coder(self):
  """VarInt32Coder: small ints, the multi-byte boundary and large values."""
  # Small ints.
  self.check_coder(coders.VarInt32Coder(), *range(-10, 10))
  # Multi-byte encoding starts at 128
  self.check_coder(coders.VarInt32Coder(), *range(120, 140))
  # Large values: alternating-sign e**k, all within the signed 32-bit range
  # (the largest magnitude is e**20, below 2**31).
  MAX_32_BIT_INT = 0x7fffffff
  self.check_coder(
      coders.VarInt32Coder(),
      *[
          int(math.pow(-1, k) * math.exp(k))
          for k in range(0, int(math.log(MAX_32_BIT_INT)))
      ])
| |
def test_float_coder(self):
  """FloatCoder: linear and exponential ranges, infinities, and tuples."""
  float_coder = coders.FloatCoder()
  steps = range(-100, 100)
  self.check_coder(float_coder, *[float(0.1 * x) for x in steps])
  self.check_coder(float_coder, *[float(2**(0.1 * x)) for x in steps])
  self.check_coder(float_coder, float('-Inf'), float('Inf'))
  pair_coder = coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder()))
  self.check_coder(pair_coder, (0, 1), (-100, 100), (0.5, 0.25))
| |
def test_singleton_coder(self):
  """SingletonCoder always yields its fixed value, alone or in a tuple."""
  first, second = 'anything', 'something else'
  for value in (first, second):
    self.check_coder(coders.SingletonCoder(value), value)
  pair_coder = coders.TupleCoder(
      (coders.SingletonCoder(first), coders.SingletonCoder(second)))
  self.check_coder(pair_coder, (first, second))
| |
def test_interval_window_coder(self):
  """IntervalWindowCoder over a grid of start/end values and in a tuple."""
  windows = [
      window.IntervalWindow(start, end)
      for start in [-2**52, 0, 2**52]
      for end in range(-100, 100)
  ]
  self.check_coder(coders.IntervalWindowCoder(), *windows)
  single = coders.TupleCoder((coders.IntervalWindowCoder(), ))
  self.check_coder(single, (window.IntervalWindow(0, 10), ))
| |
def test_paneinfo_window_coder(self):
  """PaneInfoCoder round-trips ten EARLY panes with varying flags."""
  panes = [
      windowed_value.PaneInfo(
          is_first=index == 0,
          is_last=index == 9,
          timing=windowed_value.PaneInfoTiming.EARLY,
          index=index,
          nonspeculative_index=-1) for index in range(10)
  ]
  self.check_coder(coders.PaneInfoCoder(), *panes)
| |
def test_timestamp_coder(self):
  """TimestampCoder across small, medium and very large magnitudes."""
  ts_cls = timestamp.Timestamp
  self.check_coder(
      coders.TimestampCoder(), *[ts_cls(micros=x) for x in (-1000, 0, 1000)])
  self.check_coder(
      coders.TimestampCoder(),
      ts_cls(micros=-1234567000),
      ts_cls(micros=1234567000))
  self.check_coder(
      coders.TimestampCoder(),
      ts_cls(micros=-1234567890123456000),
      ts_cls(micros=1234567890123456000))
  pair_coder = coders.TupleCoder(
      (coders.TimestampCoder(), coders.BytesCoder()))
  self.check_coder(pair_coder, (ts_cls.of(27), b'abc'))
| |
def test_timer_coder(self):
  """_TimerCoder round-trips a cleared timer and a fully populated one."""
  timer_coder = coders._TimerCoder(
      coders.StrUtf8Coder(), coders.GlobalWindowCoder())
  cleared = userstate.Timer(
      user_key="key",
      dynamic_timer_tag="tag",
      windows=(GlobalWindow(), ),
      clear_bit=True,
      fire_timestamp=None,
      hold_timestamp=None,
      paneinfo=None)
  scheduled = userstate.Timer(
      user_key="key",
      dynamic_timer_tag="tag",
      windows=(GlobalWindow(), ),
      clear_bit=False,
      fire_timestamp=timestamp.Timestamp.of(123),
      hold_timestamp=timestamp.Timestamp.of(456),
      paneinfo=windowed_value.PANE_INFO_UNKNOWN)
  self.check_coder(timer_coder, cleared, scheduled)
| |
def test_tuple_coder(self):
  """TupleCoder: exact bytes, unnested round trips, and nested tuples."""
  kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
  # Exact encoding: varint 4 followed by b'abc'.
  self.assertEqual(b'\x04abc', kv_coder.encode((4, b'abc')))
  self.check_coder(kv_coder, (1, b'a'), (-2, b'a' * 100), (300, b'abc\0' * 5))
  inner = coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder()))
  outer = coders.TupleCoder(
      (inner, coders.StrUtf8Coder(), coders.BooleanCoder()))
  self.check_coder(
      outer,
      ((1, 2), 'a', True),
      ((-2, 5), 'a\u0101' * 100, False),
      ((300, 1), 'abc\0' * 5, True))
| |
def test_tuple_sequence_coder(self):
  """TupleSequenceCoder handles empty, short and long int tuples."""
  int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder())
  self.check_coder(int_tuple_coder, (1, -1, 0), (), tuple(range(1000)))
  nested = coders.TupleCoder((coders.VarIntCoder(), int_tuple_coder))
  self.check_coder(nested, (1, (1, 2, 3)))
| |
def test_base64_pickle_coder(self):
  """Base64PickleCoder round-trips str, int, float and tuple values."""
  samples = ('a', 1, 1.5, (1, 2, 3))
  self.check_coder(coders.Base64PickleCoder(), *samples)
| |
def test_utf8_coder(self):
  """StrUtf8Coder round-trips ASCII, Latin-1 range and NUL-containing text."""
  samples = ('a', 'ab\u00FF', '\u0101\0')
  self.check_coder(coders.StrUtf8Coder(), *samples)
| |
def test_iterable_coder(self):
  """IterableCoder round-trips lists, bare and as a tuple component."""
  iterable_coder = coders.IterableCoder(coders.VarIntCoder())
  self.check_coder(iterable_coder, [1], [-1, 0, 100])
  nested = coders.TupleCoder(
      (coders.VarIntCoder(), coders.IterableCoder(coders.VarIntCoder())))
  self.check_coder(nested, (1, [1, 2, 3]))
| |
def test_iterable_coder_unknown_length(self):
  """Generators of various sizes round-trip through IterableCoder."""
  # Empty, single element, several elements, and enough elements to
  # overflow the underlying stream buffer.
  for count in (0, 1, 100, 80000):
    self._test_iterable_coder_of_unknown_length(count)
| |
def _test_iterable_coder_of_unknown_length(self, count):
  """Check IterableCoder preserves the contents AND order of a generator.

  A generator (which has no ``len``) is passed to ``encode`` so the
  unknown-length encoding path is exercised.

  Args:
    count: number of consecutive integers, starting at 0, to encode.
  """
  iterable_coder = coders.IterableCoder(coders.VarIntCoder())
  encoded = iterable_coder.encode(i for i in range(count))
  # assertEqual (not assertCountEqual) so that reordering is caught too.
  self.assertEqual(list(range(count)), list(iterable_coder.decode(encoded)))
| |
def test_list_coder(self):
  """ListCoder round-trips lists, bare and as a tuple component."""
  list_coder = coders.ListCoder(coders.VarIntCoder())
  self.check_coder(list_coder, [1], [-1, 0, 100])
  pair_coder = coders.TupleCoder((coders.VarIntCoder(), list_coder))
  self.check_coder(pair_coder, (1, [1, 2, 3]))
| |
def test_windowedvalue_coder_paneinfo(self):
  """WindowedValueCoder round-trips a variety of PaneInfo values."""
  coder = coders.WindowedValueCoder(
      coders.VarIntCoder(), coders.GlobalWindowCoder())
  timing = windowed_value.PaneInfoTiming
  panes = [
      windowed_value.PANE_INFO_UNKNOWN,
      windowed_value.PaneInfo(True, True, timing.EARLY, 0, -1),
      windowed_value.PaneInfo(True, False, timing.ON_TIME, 0, 0),
      windowed_value.PaneInfo(True, False, timing.ON_TIME, 10, 0),
      windowed_value.PaneInfo(False, True, timing.ON_TIME, 0, 23),
      windowed_value.PaneInfo(False, True, timing.ON_TIME, 12, 23),
      windowed_value.PaneInfo(False, False, timing.LATE, 0, 123),
  ]
  wvs = [
      windowed_value.WindowedValue(123, 234, (GlobalWindow(), ), pane)
      for pane in panes
  ]

  # Unnested.
  self.check_coder(
      coder,
      windowed_value.WindowedValue(
          123, 234, (GlobalWindow(), ), windowed_value.PANE_INFO_UNKNOWN))
  for wv in wvs:
    self.check_coder(coder, wv)

  # Nested: every ordered pair of the values above.
  for first in wvs:
    for second in wvs:
      self.check_coder(coders.TupleCoder((coder, coder)), (first, second))
| |
def test_windowed_value_coder(self):
  """WindowedValueCoder: exact bytes, MIN_TIMESTAMP decoding, round trips."""
  global_coder = coders.WindowedValueCoder(
      coders.VarIntCoder(), coders.GlobalWindowCoder())
  # Exact encoding of the value 1 in the global window.
  self.assertEqual(
      b'\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
      global_coder.encode(window.GlobalWindows.windowed_value(1)))

  # Decoding a large timestamp yields MIN_TIMESTAMP.
  self.assertEqual(
      global_coder.decode(b'\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
      windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(), )))

  # Unnested, with the default window coder.
  self.check_coder(
      coders.WindowedValueCoder(coders.VarIntCoder()),
      windowed_value.WindowedValue(3, -100, ()),
      windowed_value.WindowedValue(-1, 100, (1, 2, 3)))

  # Global window.
  self.check_coder(
      coders.WindowedValueCoder(
          coders.VarIntCoder(), coders.GlobalWindowCoder()),
      window.GlobalWindows.windowed_value(1))

  # Nested inside a tuple.
  pair_coder = coders.TupleCoder((
      coders.WindowedValueCoder(coders.FloatCoder()),
      coders.WindowedValueCoder(coders.StrUtf8Coder())))
  self.check_coder(
      pair_coder,
      (
          windowed_value.WindowedValue(1.5, 0, ()),
          windowed_value.WindowedValue("abc", 10, ('window', ))))
| |
def test_param_windowed_value_coder(self):
  """ParamWindowedValueCoder takes timestamp/windows/pane from a payload."""
  from apache_beam.transforms.window import IntervalWindow
  from apache_beam.utils.windowed_value import PaneInfo

  # pylint: disable=too-many-function-args
  def make_wv(value):
    # Every value shares the metadata baked into `payload` below.
    return windowed_value.WindowedValue(
        value, 1, (IntervalWindow(11, 21), ), PaneInfo(True, False, 1, 2, 3))

  template = windowed_value.create(
      b'',
      # 1 second, expressed in microseconds.
      1000 * 1000,
      (IntervalWindow(11, 21), ),
      PaneInfo(True, False, 1, 2, 3))
  payload = coders.WindowedValueCoder(
      coders.BytesCoder(), coders.IntervalWindowCoder()).encode(template)

  coder = coders.ParamWindowedValueCoder(
      payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()])
  # Only the element itself is written.
  self.assertEqual(
      b'\x01', coder.encode(window.GlobalWindows.windowed_value(1)))

  # Unnested.
  self.check_coder(
      coders.ParamWindowedValueCoder(
          payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()]),
      make_wv(3),
      make_wv(1))

  # Nested inside a tuple.
  pair_coder = coders.TupleCoder((
      coders.ParamWindowedValueCoder(
          payload, [coders.FloatCoder(), coders.IntervalWindowCoder()]),
      coders.ParamWindowedValueCoder(
          payload, [coders.StrUtf8Coder(), coders.IntervalWindowCoder()])))
  self.check_coder(pair_coder, (make_wv(1.5), make_wv("abc")))
| |
@parameterized.expand([
    param(compat_version=None),
    param(compat_version="2.67.0"),
    param(compat_version="2.68.0"),
])
def test_cross_process_encoding_of_special_types_is_deterministic(
    self, compat_version):
  """Test cross-process determinism for all special deterministic types

  - In SDK version <= 2.67.0 dill is used to encode "special types"
  - In SDK version 2.68.0 cloudpickle is used to encode "special types" with
    absolute filepaths in code objects and dynamic functions.
  - In SDK version >=2.69.0 cloudpickle is used to encode "special types"
    with relative filepaths in code objects and dynamic functions.

  The same encoding script runs in two fresh interpreter processes; the
  encoded bytes from both runs must be identical.
  """
  # compat_version 2.67.0 selects the dill-based encoding.
  is_using_dill = compat_version == "2.67.0"
  if is_using_dill:
    pytest.importorskip("dill")

  # NOTE(review): sys.executable may also be an empty string when the
  # interpreter path is unknown; only None is handled here.
  if sys.executable is None:
    self.skipTest('No Python interpreter found')
  typecoders.registry.update_compatibility_version = compat_version

  # The script is an f-string; only the `compat_version = {...}` line is
  # interpolated, producing either a quoted version string or None.
  # pylint: disable=line-too-long
  script = textwrap.dedent(
      f'''\
      import pickle
      import sys
      import collections
      import enum
      import logging

      from apache_beam.coders import coders
      from apache_beam.coders import typecoders
      from apache_beam.coders.coders_test_common import MyNamedTuple
      from apache_beam.coders.coders_test_common import MyTypedNamedTuple
      from apache_beam.coders.coders_test_common import MyEnum
      from apache_beam.coders.coders_test_common import MyIntEnum
      from apache_beam.coders.coders_test_common import MyIntFlag
      from apache_beam.coders.coders_test_common import MyFlag
      from apache_beam.coders.coders_test_common import DefinesGetState
      from apache_beam.coders.coders_test_common import DefinesGetAndSetState
      from apache_beam.coders.coders_test_common import FrozenDataClass


      from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message

      logging.basicConfig(
          level=logging.INFO,
          format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
          stream=sys.stderr,
          force=True
      )

      # Test cases for all special deterministic types
      # NOTE: When this script run in a subprocess the module is considered
      # __main__. Dill cannot pickle enums in __main__ because it
      # needs to define a way to create the type if it does not exist
      # in the session, and reaches recursion depth limits.
      test_cases = [
          ("proto_message", test_message.MessageA(field1='value')),
          ("named_tuple_simple", MyNamedTuple(1, 2)),
          ("typed_named_tuple", MyTypedNamedTuple(1, 'a')),
          ("named_tuple_list", [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]),
          ("enum_single", MyEnum.E1),
          ("enum_list", list(MyEnum)),
          ("int_enum_list", list(MyIntEnum)),
          ("int_flag_list", list(MyIntFlag)),
          ("flag_list", list(MyFlag)),
          ("getstate_setstate_simple", DefinesGetAndSetState(1)),
          ("getstate_setstate_complex", DefinesGetAndSetState((1, 2, 3))),
          ("getstate_setstate_list", [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))]),
      ]


      test_cases.extend([
          ("frozen_dataclass", FrozenDataClass(1, 2)),
          ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
      ])

      compat_version = {'"'+ compat_version +'"' if compat_version else None}
      typecoders.registry.update_compatibility_version = compat_version
      coder = coders.FastPrimitivesCoder()
      deterministic_coder = coder.as_deterministic_coder("step")

      results = dict()
      for test_name, value in test_cases:
          try:
              encoded = deterministic_coder.encode(value)
              results[test_name] = encoded
          except Exception as e:
              logging.warning("Encoding failed with %s", e)
              sys.exit(1)

      sys.stdout.buffer.write(pickle.dumps(results))


      ''')

  def run_subprocess():
    # Runs the script in a fresh interpreter and returns the dict of
    # test_name -> encoded bytes it pickled to stdout.
    result = subprocess.run([sys.executable, '-c', script],
                            capture_output=True,
                            timeout=30,
                            check=False)

    self.assertEqual(
        0, result.returncode, f"Subprocess failed: {result.stderr}")
    return pickle.loads(result.stdout)

  results1 = run_subprocess()
  results2 = run_subprocess()

  coder = coders.FastPrimitivesCoder()
  deterministic_coder = coder.as_deterministic_coder("step")

  for test_name in results1:

    data1 = results1[test_name]
    data2 = results2[test_name]

    # Core property: byte-identical output across processes.
    self.assertEqual(
        data1, data2, f"Cross-process encoding differs for {test_name}")
    self.assertGreater(len(data1), 1)

    # Best effort: if this process cannot decode the bytes, only the
    # byte-equality check above applies to this case.
    try:
      decoded1 = deterministic_coder.decode(data1)
      decoded2 = deterministic_coder.decode(data2)
    except Exception as e:
      logging.warning("Could not decode %s data due to %s", test_name, e)
      continue

    if test_name == "named_tuple_simple" and not is_using_dill:
      # The absence of a compat_version means we are using the most recent
      # implementation of the coder, which uses relative paths.
      should_have_relative_path = not compat_version
      named_tuple_type = type(decoded1)
      self.assertEqual(
          os.path.isabs(named_tuple_type._make.__code__.co_filename),
          not should_have_relative_path)
      self.assertEqual(
          os.path.isabs(
              named_tuple_type.__getnewargs__.__globals__['__file__']),
          not should_have_relative_path)

    self.assertEqual(
        decoded1, decoded2, f"Cross-process decoding differs for {test_name}")
    self.assertIsInstance(
        decoded1,
        type(decoded2),
        f"Cross-process decoding differs for {test_name}")
| |
def test_proto_coder(self):
  """ProtoCoder round-trips proto2 messages, bare and within a tuple."""
  # For instructions on how these test proto messages were generated,
  # see coders_test.py
  ma = test_message.MessageA()
  ma.field2.add().field1 = True
  ma.field1 = 'hello world'

  mb = test_message.MessageA(field1='beam')

  proto_coder = coders.ProtoCoder(type(ma))
  self.check_coder(proto_coder, ma)
  pair_coder = coders.TupleCoder((proto_coder, coders.BytesCoder()))
  self.check_coder(pair_coder, (ma, b'a'), (mb, b'b'))
| |
def test_global_window_coder(self):
  """GlobalWindowCoder encodes the single global window as zero bytes."""
  coder = coders.GlobalWindowCoder()
  gw = window.GlobalWindow()
  self.assertEqual(b'', coder.encode(gw))
  self.assertEqual(gw, coder.decode(b''))
  self.check_coder(coder, gw)
  self.check_coder(coders.TupleCoder((coder, coder)), (gw, gw))
| |
def test_length_prefix_coder(self):
  """LengthPrefixCoder writes a varint length before the inner encoding."""
  coder = coders.LengthPrefixCoder(coders.BytesCoder())
  expected_encodings = [
      (b'', b'\x00'),
      (b'a', b'\x01a'),
      (b'bc', b'\x02bc'),
      (b'z' * 16383, b'\xff\x7f' + b'z' * 16383),
  ]
  for raw, encoded in expected_encodings:
    self.assertEqual(encoded, coder.encode(raw))
  self.check_coder(coder, b'', b'a', b'bc', b'def')
  self.check_coder(
      coders.TupleCoder((coder, coder)), (b'', b'a'), (b'bc', b'def'))
| |
def test_nested_observables(self):
  """Observables nested in WindowedValue and tuple coders are reported."""
  class FakeObservableIterator(observable.ObservableMixin):
    def __iter__(self):
      return iter([1, 2, 3])

  # Coder for elements produced by the observable iterator.
  elem_coder = coders.VarIntCoder()
  iter_coder = coders.TupleSequenceCoder(elem_coder)
  observ = FakeObservableIterator()

  # Observable inside a WindowedValue.
  wv_coder = coders.WindowedValueCoder(iter_coder)
  wv = windowed_value.WindowedValue(observ, 0, ())
  _, wv_observables = wv_coder.get_impl().get_estimated_size_and_observables(
      wv)
  self.assertEqual(wv_observables, [(observ, elem_coder.get_impl())])

  # Observable inside a tuple.
  tuple_coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder))
  _, tuple_observables = (
      tuple_coder.get_impl().get_estimated_size_and_observables(
          ('123', observ)))
  self.assertEqual(tuple_observables, [(observ, elem_coder.get_impl())])
| |
def test_state_backed_iterable_coder(self):
  """StateBackedIterableCoder spills elements to (fake) state and back."""
  # pylint: disable=global-variable-undefined
  # `state` is a module global so that it can be pickled by reference.
  global state
  state = {}

  def write_state(values, element_coder_impl):
    token = b'state_token_%d' % len(state)
    state[token] = [element_coder_impl.encode(e) for e in values]
    return token

  def read_state(token, element_coder_impl):
    return [element_coder_impl.decode(s) for s in state[token]]

  coder = coders.StateBackedIterableCoder(
      coders.VarIntCoder(),
      read_state=read_state,
      write_state=write_state,
      write_state_threshold=1)
  # check_coder is deliberately not used here; see
  # https://github.com/cloudpipe/cloudpickle/issues/452
  self._observe(coder)
  self.assertEqual([1, 2, 3], coder.decode(coder.encode([1, 2, 3])))
  # The state store must actually have been written to.
  self.assertNotEqual(state, {})
  pair_coder = coders.TupleCoder((coder, coder))
  self._observe(pair_coder)
  self.assertEqual(([1], [2, 3]),
                   pair_coder.decode(pair_coder.encode(([1], [2, 3]))))
| |
| def test_nullable_coder(self): |
| self.check_coder(coders.NullableCoder(coders.VarIntCoder()), None, 2 * 64) |
| |
| def test_map_coder(self): |
| values = [ |
| { |
| 1: "one", 300: "three hundred" |
| }, # force yapf to be nice |
| {}, |
| { |
| i: str(i) |
| for i in range(5000) |
| }, |
| ] |
| map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder()) |
| self.check_coder(map_coder, *values) |
| self.check_coder(map_coder.as_deterministic_coder("label"), *values) |
| |
| def test_sharded_key_coder(self): |
| key_and_coders = [(b'', b'\x00', coders.BytesCoder()), |
| (b'key', b'\x03key', coders.BytesCoder()), |
| ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()), |
| (('k', 1), |
| b'\x01\x6b\x01', |
| coders.TupleCoder( |
| (coders.StrUtf8Coder(), coders.VarIntCoder())))] |
| |
| for key, bytes_repr, key_coder in key_and_coders: |
| coder = coders.ShardedKeyCoder(key_coder) |
| |
| # Test str repr |
| self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder) |
| |
| self.assertEqual(b'\x00' + bytes_repr, coder.encode(ShardedKey(key, b''))) |
| self.assertEqual( |
| b'\x03123' + bytes_repr, coder.encode(ShardedKey(key, b'123'))) |
| |
| # Test unnested |
| self.check_coder(coder, ShardedKey(key, b'')) |
| self.check_coder(coder, ShardedKey(key, b'123')) |
| |
| # Test type hints |
| self.assertTrue( |
| isinstance( |
| coder.to_type_hint(), sharded_key_type.ShardedKeyTypeConstraint)) |
| key_type = coder.to_type_hint().key_type |
| if isinstance(key_type, typehints.TupleConstraint): |
| self.assertEqual(key_type.tuple_types, (type(key[0]), type(key[1]))) |
| else: |
| self.assertEqual(key_type, type(key)) |
| self.assertEqual( |
| coders.ShardedKeyCoder.from_type_hint( |
| coder.to_type_hint(), typecoders.CoderRegistry()), |
| coder) |
| |
| for other_key, _, other_key_coder in key_and_coders: |
| other_coder = coders.ShardedKeyCoder(other_key_coder) |
| # Test nested |
| self.check_coder( |
| coders.TupleCoder((coder, other_coder)), |
| (ShardedKey(key, b''), ShardedKey(other_key, b''))) |
| self.check_coder( |
| coders.TupleCoder((coder, other_coder)), |
| (ShardedKey(key, b'123'), ShardedKey(other_key, b''))) |
| |
| def test_timestamp_prefixing_window_coder(self): |
| self.check_coder( |
| coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()), |
| *[ |
| window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52] |
| for y in range(-100, 100) |
| ]) |
| self.check_coder( |
| coders.TupleCoder(( |
| coders.TimestampPrefixingWindowCoder( |
| coders.IntervalWindowCoder()), )), |
| (window.IntervalWindow(0, 10), )) |
| |
| def test_timestamp_prefixing_opaque_window_coder(self): |
| sdk_coder = coders.TimestampPrefixingWindowCoder( |
| coders.LengthPrefixCoder(coders.PickleCoder())) |
| safe_coder = coders.TimestampPrefixingOpaqueWindowCoder() |
| for w in [window.IntervalWindow(1, 123), window.GlobalWindow()]: |
| round_trip = sdk_coder.decode( |
| safe_coder.encode(safe_coder.decode(sdk_coder.encode(w)))) |
| self.assertEqual(w, round_trip) |
| |
| def test_decimal_coder(self): |
| test_coder = coders.DecimalCoder() |
| |
| test_values = [ |
| Decimal("-10.5"), |
| Decimal("-1"), |
| Decimal(), |
| Decimal("1"), |
| Decimal("13.258"), |
| ] |
| |
| test_encodings = ("AZc", "AP8", "AAA", "AAE", "AzPK") |
| |
| self.check_coder(test_coder, *test_values) |
| |
| for idx, value in enumerate(test_values): |
| self.assertEqual( |
| test_encodings[idx], |
| base64.b64encode(test_coder.encode(value)).decode().rstrip("=")) |
| |
| def test_OrderedUnionCoder(self): |
| test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()), |
| (int, coders.VarIntCoder()), |
| fallback_coder=coders.FloatCoder()) |
| self.check_coder(test_coder, 's') |
| self.check_coder(test_coder, 123) |
| self.check_coder(test_coder, 1.5) |
| |
| |
if __name__ == '__main__':
  # Run the test module directly, with INFO-level logging on the root logger.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()