blob: 881e57f5a01556761dbd0f4473afb6d8710b051e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for coders that must be consistent across all Beam SDKs.
"""
# pytype: skip-file
from __future__ import absolute_import
from __future__ import print_function
import json
import logging
import math
import os.path
import sys
import unittest
from builtins import map
from typing import Dict
from typing import Tuple
import yaml
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import schema_pb2
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.window import IntervalWindow
from apache_beam.typehints import schemas
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import PaneInfo
from apache_beam.utils.windowed_value import PaneInfoTiming
# Absolute, normalized path to the cross-SDK coder test spec shipped with the
# portability API artifacts; every Beam SDK runs the same cases from it.
STANDARD_CODERS_YAML = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), '../portability/api/standard_coders.yaml'))
def _load_test_cases(test_yaml):
  """Load test data from yaml file and return an iterable of test cases.

  Each yielded item is a ``[name, spec]`` pair, where ``spec`` is the parsed
  yaml document annotated with its document index under ``'index'``.
  See ``standard_coders.yaml`` for more details.
  """
  if not os.path.exists(test_yaml):
    raise ValueError('Could not find the test spec: %s' % test_yaml)
  with open(test_yaml, 'rb') as spec_file:
    documents = yaml.load_all(spec_file, Loader=yaml.SafeLoader)
    for index, case in enumerate(documents):
      case['index'] = index
      # Unnamed cases fall back to the coder kind from the urn, e.g.
      # 'beam:coder:bytes:v1' -> 'bytes'.
      default_name = case['coder']['urn'].split(':')[-2]
      yield [case.get('name', default_name), case]
def parse_float(s):
  """Parse *s* as a float, forcing any NaN result to a positive sign.

  On Windows, float('NaN') carries the opposite sign bit from other
  platforms; the tests only need a consistent representation.
  """
  value = float(s)
  return abs(value) if math.isnan(value) else value
def value_parser_from_schema(schema):
  """Return a function that parses a yaml example dict into a row value.

  Args:
    schema: a ``schema_pb2.Schema`` proto describing the row's fields.

  Returns:
    A callable taking a dict keyed by field name and returning an instance
    of the named tuple generated from ``schema``. Raises ValueError if the
    dict contains keys not present in the schema.
  """
  def attribute_parser_from_type(type_):
    """Return a parser for a single value of the given FieldType proto."""
    # TODO: This should be exhaustive
    type_info = type_.WhichOneof("type_info")
    if type_info == "atomic_type":
      return schemas.ATOMIC_TYPE_TO_PRIMITIVE[type_.atomic_type]
    elif type_info == "array_type":
      element_parser = attribute_parser_from_type(type_.array_type.element_type)
      return lambda x: list(map(element_parser, x))
    elif type_info == "map_type":
      # Bug fix: key/value types live on map_type, not array_type. Reading
      # them from the unset array_type oneof yielded default (empty)
      # FieldTypes, so the recursion returned None and map parsing crashed.
      key_parser = attribute_parser_from_type(type_.map_type.key_type)
      value_parser = attribute_parser_from_type(type_.map_type.value_type)
      return lambda x: dict(
          (key_parser(k), value_parser(v)) for k, v in x.items())

  parsers = [(field.name, attribute_parser_from_type(field.type))
             for field in schema.fields]
  constructor = schemas.named_tuple_from_schema(schema)

  def value_parser(x):
    # Consume fields in schema order; null (None) values pass through
    # without invoking the field parser.
    result = []
    for name, parser in parsers:
      value = x.pop(name)
      result.append(None if value is None else parser(value))
    if x:
      raise ValueError(
          "Test data contains attributes that don't exist in the schema: {}".
          format(', '.join(x.keys())))
    return constructor(*result)

  return value_parser
class StandardCodersTest(unittest.TestCase):
  """Round-trips the cross-SDK coder cases from standard_coders.yaml.

  Each yaml document supplies a coder spec plus (encoded bytes, json value)
  example pairs; this test checks the Python coder implementations encode
  and decode them as specified.
  """

  # Maps a coder urn to a callable converting the yaml example value into
  # the Python object the coder encodes. Parsers for composite coders take
  # their component parsers as extra positional arguments.
  _urn_to_json_value_parser = {
      'beam:coder:bytes:v1': lambda x: x.encode('utf-8'),
      'beam:coder:bool:v1': lambda x: x,
      'beam:coder:string_utf8:v1': lambda x: x,
      'beam:coder:varint:v1': lambda x: x,
      'beam:coder:kv:v1': lambda x,
      key_parser,
      value_parser: (key_parser(x['key']), value_parser(x['value'])),
      # Spec gives end and span in millis; Timestamp takes micros.
      'beam:coder:interval_window:v1': lambda x: IntervalWindow(
          start=Timestamp(micros=(x['end'] - x['span']) * 1000),
          end=Timestamp(micros=x['end'] * 1000)),
      'beam:coder:iterable:v1': lambda x,
      parser: list(map(parser, x)),
      'beam:coder:global_window:v1': lambda x: window.GlobalWindow(),
      'beam:coder:windowed_value:v1': lambda x,
      value_parser,
      window_parser: windowed_value.create(
          value_parser(x['value']),
          x['timestamp'] * 1000,
          tuple([window_parser(w) for w in x['windows']])),
      'beam:coder:param_windowed_value:v1': lambda x,
      value_parser,
      window_parser: windowed_value.create(
          value_parser(x['value']),
          x['timestamp'] * 1000,
          tuple([window_parser(w) for w in x['windows']]),
          PaneInfo(
              x['pane']['is_first'],
              x['pane']['is_last'],
              PaneInfoTiming.from_string(x['pane']['timing']),
              x['pane']['index'],
              x['pane']['on_time_index'])),
      # A timer with clearBit set carries no timestamps or pane info.
      'beam:coder:timer:v1': lambda x,
      value_parser,
      window_parser: userstate.Timer(
          user_key=value_parser(x['userKey']),
          dynamic_timer_tag=x['dynamicTimerTag'],
          clear_bit=x['clearBit'],
          windows=tuple([window_parser(w) for w in x['windows']]),
          fire_timestamp=None,
          hold_timestamp=None,
          paneinfo=None) if x['clearBit'] else userstate.Timer(
          user_key=value_parser(x['userKey']),
          dynamic_timer_tag=x['dynamicTimerTag'],
          clear_bit=x['clearBit'],
          fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
          hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
          windows=tuple([window_parser(w) for w in x['windows']]),
          paneinfo=PaneInfo(
              x['pane']['is_first'],
              x['pane']['is_last'],
              PaneInfoTiming.from_string(x['pane']['timing']),
              x['pane']['index'],
              x['pane']['on_time_index'])),
      'beam:coder:double:v1': parse_float,
  }

  def test_standard_coders(self):
    """Runs every case loaded from the shared yaml spec."""
    for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
      logging.info('Executing %s test.', name)
      self._run_standard_coder(name, spec)

  def _run_standard_coder(self, name, spec):
    """Encodes/decodes all examples of one spec and checks the results."""
    def assert_equal(actual, expected):
      """Handle nan values which self.assertEqual fails on."""
      if (isinstance(actual, float) and isinstance(expected, float) and
          math.isnan(actual) and math.isnan(expected)):
        return
      self.assertEqual(actual, expected)

    coder = self.parse_coder(spec['coder'])
    parse_value = self.json_value_parser(spec['coder'])
    # A spec may pin nested-ness; otherwise exercise both contexts.
    nested_list = [spec['nested']] if 'nested' in spec else [True, False]
    for nested in nested_list:
      for expected_encoded, json_value in spec['examples'].items():
        value = parse_value(json_value)
        # Encoded bytes are stored in the yaml as latin1-escaped strings.
        expected_encoded = expected_encoded.encode('latin1')
        if not spec['coder'].get('non_deterministic', False):
          actual_encoded = encode_nested(coder, value, nested)
          if self.fix and actual_encoded != expected_encoded:
            # In --fix mode, record the mismatch so tearDownClass can
            # rewrite the yaml spec with the actual encoding.
            self.to_fix[spec['index'], expected_encoded] = actual_encoded
          else:
            self.assertEqual(expected_encoded, actual_encoded)
            decoded = decode_nested(coder, expected_encoded, nested)
            assert_equal(decoded, value)
        else:
          # Only verify decoding for a non-deterministic coder
          self.assertEqual(
              decode_nested(coder, expected_encoded, nested), value)

  def parse_coder(self, spec):
    """Builds a real coder object from a yaml coder spec via the protos."""
    context = pipeline_context.PipelineContext()
    # Any unique id will do; derive one from the spec text itself.
    coder_id = str(hash(str(spec)))
    component_ids = [
        context.coders.get_id(self.parse_coder(c))
        for c in spec.get('components', ())
    ]
    context.coders.put_proto(
        coder_id,
        beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
            component_coder_ids=component_ids))
    return context.coders.get_by_id(coder_id)

  def json_value_parser(self, coder_spec):
    """Returns a callable mapping a yaml example value to a Python value."""
    # TODO: integrate this with the logic for the other parsers
    if coder_spec['urn'] == 'beam:coder:row:v1':
      # Row values are parsed according to the schema in the payload.
      schema = schema_pb2.Schema.FromString(
          coder_spec['payload'].encode('latin1'))
      return value_parser_from_schema(schema)
    component_parsers = [
        self.json_value_parser(c) for c in coder_spec.get('components', ())
    ]
    return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
        x, *component_parsers)

  # Used when --fix is passed.
  fix = False
  # Maps (yaml document index, expected encoding) -> actual encoding.
  to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]

  @classmethod
  def tearDownClass(cls):
    """In --fix mode, rewrites standard_coders.yaml with actual encodings."""
    if cls.fix and cls.to_fix:
      print("FIXING", len(cls.to_fix), "TESTS")
      doc_sep = '\n---\n'
      docs = open(STANDARD_CODERS_YAML).read().split(doc_sep)

      def quote(s):
        # json escaping approximates the yaml escapes used by the spec;
        # \u0000 is spelled \0 there.
        return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

      for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
        print(quote(expected_encoded), "->", quote(actual_encoded))
        # Replace only the example keys (they are followed by ':').
        docs[doc_ix] = docs[doc_ix].replace(
            quote(expected_encoded) + ':', quote(actual_encoded) + ':')
      open(STANDARD_CODERS_YAML, 'w').write(doc_sep.join(docs))
def encode_nested(coder, value, nested=True):
  """Encode *value* with *coder*'s impl and return the resulting bytes."""
  impl = coder.get_impl()
  stream = coder_impl.create_OutputStream()
  impl.encode_to_stream(value, stream, nested)
  return stream.get()
def decode_nested(coder, encoded, nested=True):
  """Decode *encoded* bytes with *coder*'s impl and return the value."""
  stream = coder_impl.create_InputStream(encoded)
  return coder.get_impl().decode_from_stream(stream, nested)
if __name__ == '__main__':
  # A --fix flag (not understood by unittest) switches the suite into
  # spec-rewriting mode; strip it before handing argv to unittest.
  try:
    sys.argv.remove('--fix')
  except ValueError:
    pass
  else:
    StandardCodersTest.fix = True
  unittest.main()