blob: 9dd1bc22ad1c85db4c663cb8bdeebdbb8c73f874 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for schemas."""
from __future__ import absolute_import
import itertools
import sys
import unittest
from typing import ByteString
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
import numpy as np
from past.builtins import unicode
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints.schemas import typing_from_runner_api
from apache_beam.typehints.schemas import typing_to_runner_api
# True when the interpreter's major version is greater than 2. The tests below
# use it to gate expectations that differ between Python 2 and Python 3
# (handling of bytes/str).
IS_PYTHON_3 = sys.version_info.major > 2
class SchemaTest(unittest.TestCase):
  """Tests for Runner API Schema proto to/from typing conversions.

  There are two main tests: test_typing_survives_proto_roundtrip, and
  test_proto_survives_typing_roundtrip. These are both necessary because
  Schemas are cached by ID, so performing just one of them wouldn't
  necessarily exercise all code paths.
  """

  def test_typing_survives_proto_roundtrip(self):
    """Convert typing types to proto and back, expecting the original type."""
    all_nonoptional_primitives = [
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.float32,
        np.float64,
        unicode,
        bool,
    ]

    # The bytes type cannot survive a roundtrip to/from proto in Python 2.
    # In order to use BYTES a user type has to use typing.ByteString (because
    # bytes == str, and we map str to STRING).
    if IS_PYTHON_3:
      all_nonoptional_primitives.extend([bytes])

    all_optional_primitives = [
        Optional[typ] for typ in all_nonoptional_primitives
    ]
    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [Sequence[typ] for typ in all_primitives]

    # Every (key, value) combination of the primitive types.
    basic_map_types = [
        Mapping[key_type, value_type]
        for key_type, value_type in itertools.product(
            all_primitives, all_primitives)
    ]

    selected_schemas = [
        NamedTuple(
            'AllPrimitives',
            [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]),
        NamedTuple(
            'ComplexSchema',
            [
                ('id', np.int64),
                ('name', unicode),
                ('optional_map',
                 Optional[Mapping[unicode, Optional[np.float64]]]),
                ('optional_array', Optional[Sequence[np.float32]]),
                ('array_optional', Sequence[Optional[bool]]),
            ]),
    ]

    test_cases = (
        all_primitives + basic_array_types + basic_map_types +
        selected_schemas)

    for test_case in test_cases:
      self.assertEqual(
          test_case, typing_from_runner_api(typing_to_runner_api(test_case)))

  def test_proto_survives_typing_roundtrip(self):
    """Convert proto types to typing and back, expecting the original proto."""
    # Enum values are compared with != rather than `is not`: identity
    # comparison of integers only works by accident of small-int caching.
    atomic_types = [
        typ for typ in schema_pb2.AtomicType.values()
        if typ != schema_pb2.UNSPECIFIED
    ]

    all_nonoptional_primitives = [
        schema_pb2.FieldType(atomic_type=typ) for typ in atomic_types
    ]

    # The bytes type cannot survive a roundtrip to/from proto in Python 2.
    # In order to use BYTES a user type has to use typing.ByteString (because
    # bytes == str, and we map str to STRING).
    if not IS_PYTHON_3:
      all_nonoptional_primitives.remove(
          schema_pb2.FieldType(atomic_type=schema_pb2.BYTES))

    all_optional_primitives = [
        schema_pb2.FieldType(nullable=True, atomic_type=typ)
        for typ in atomic_types
    ]

    all_primitives = all_nonoptional_primitives + all_optional_primitives

    basic_array_types = [
        schema_pb2.FieldType(array_type=schema_pb2.ArrayType(element_type=typ))
        for typ in all_primitives
    ]

    # Every (key, value) combination of the primitive field types.
    basic_map_types = [
        schema_pb2.FieldType(
            map_type=schema_pb2.MapType(
                key_type=key_type, value_type=value_type))
        for key_type, value_type in itertools.product(
            all_primitives, all_primitives)
    ]

    # Row types carry explicit, fixed schema ids so that the roundtripped
    # proto can be compared for full equality.
    selected_schemas = [
        schema_pb2.FieldType(
            row_type=schema_pb2.RowType(
                schema=schema_pb2.Schema(
                    id='32497414-85e8-46b7-9c90-9a9cc62fe390',
                    fields=[
                        schema_pb2.Field(name='field%d' % i, type=typ)
                        for i, typ in enumerate(all_primitives)
                    ]))),
        schema_pb2.FieldType(
            row_type=schema_pb2.RowType(
                schema=schema_pb2.Schema(
                    id='dead1637-3204-4bcb-acf8-99675f338600',
                    fields=[
                        schema_pb2.Field(
                            name='id',
                            type=schema_pb2.FieldType(
                                atomic_type=schema_pb2.INT64)),
                        schema_pb2.Field(
                            name='name',
                            type=schema_pb2.FieldType(
                                atomic_type=schema_pb2.STRING)),
                        schema_pb2.Field(
                            name='optional_map',
                            type=schema_pb2.FieldType(
                                nullable=True,
                                map_type=schema_pb2.MapType(
                                    key_type=schema_pb2.FieldType(
                                        atomic_type=schema_pb2.STRING),
                                    value_type=schema_pb2.FieldType(
                                        atomic_type=schema_pb2.DOUBLE)))),
                        schema_pb2.Field(
                            name='optional_array',
                            type=schema_pb2.FieldType(
                                nullable=True,
                                array_type=schema_pb2.ArrayType(
                                    element_type=schema_pb2.FieldType(
                                        atomic_type=schema_pb2.FLOAT)))),
                        schema_pb2.Field(
                            name='array_optional',
                            type=schema_pb2.FieldType(
                                array_type=schema_pb2.ArrayType(
                                    element_type=schema_pb2.FieldType(
                                        nullable=True,
                                        atomic_type=schema_pb2.BYTES)))),
                    ]))),
    ]

    test_cases = (
        all_primitives + basic_array_types + basic_map_types +
        selected_schemas)

    for test_case in test_cases:
      self.assertEqual(
          test_case, typing_to_runner_api(typing_from_runner_api(test_case)))

  def test_unknown_primitive_raise_valueerror(self):
    """np.uint32 is expected to be rejected with ValueError."""
    self.assertRaises(ValueError, lambda: typing_to_runner_api(np.uint32))

  def test_unknown_atomic_raise_valueerror(self):
    """An UNSPECIFIED atomic type is expected to be rejected with ValueError."""
    self.assertRaises(
        ValueError,
        lambda: typing_from_runner_api(
            schema_pb2.FieldType(atomic_type=schema_pb2.UNSPECIFIED)))

  @unittest.skipIf(IS_PYTHON_3, 'str is acceptable in python 3')
  def test_str_raises_error_py2(self):
    """On Python 2, str must be rejected both bare and as a row field."""
    # assertRaises needs the expected exception type as its first argument.
    # Called with only a lambda, it treats the lambda as the exception class
    # and returns an unused context manager, so the conversion never runs and
    # the test passes vacuously.
    # NOTE(review): ValueError is assumed here, matching the other
    # unsupported-type tests in this class -- confirm against
    # typing_to_runner_api.
    self.assertRaises(ValueError, lambda: typing_to_runner_api(str))
    self.assertRaises(
        ValueError,
        lambda: typing_to_runner_api(
            NamedTuple('Test', [('int', int), ('str', str)])))

  def test_int_maps_to_int64(self):
    """The builtin int converts to the INT64 atomic type."""
    self.assertEqual(
        schema_pb2.FieldType(atomic_type=schema_pb2.INT64),
        typing_to_runner_api(int))

  def test_float_maps_to_float64(self):
    """The builtin float converts to the DOUBLE atomic type."""
    self.assertEqual(
        schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE),
        typing_to_runner_api(float))

  def test_trivial_example(self):
    """A simple NamedTuple converts to a row type with the expected fields."""
    MyCuteClass = NamedTuple(
        'MyCuteClass',
        [
            ('name', unicode),
            ('age', Optional[int]),
            ('interests', List[unicode]),
            ('height', float),
            ('blob', ByteString),
        ])

    expected = schema_pb2.FieldType(
        row_type=schema_pb2.RowType(
            schema=schema_pb2.Schema(
                fields=[
                    schema_pb2.Field(
                        name='name',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.STRING),
                    ),
                    schema_pb2.Field(
                        name='age',
                        type=schema_pb2.FieldType(
                            nullable=True, atomic_type=schema_pb2.INT64)),
                    schema_pb2.Field(
                        name='interests',
                        type=schema_pb2.FieldType(
                            array_type=schema_pb2.ArrayType(
                                element_type=schema_pb2.FieldType(
                                    atomic_type=schema_pb2.STRING)))),
                    schema_pb2.Field(
                        name='height',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.DOUBLE)),
                    schema_pb2.Field(
                        name='blob',
                        type=schema_pb2.FieldType(
                            atomic_type=schema_pb2.BYTES)),
                ])))

    # Only test that the fields are equal. If we attempt to test the entire type
    # or the entire schema, the generated id will break equality.
    self.assertEqual(
        expected.row_type.schema.fields,
        typing_to_runner_api(MyCuteClass).row_type.schema.fields)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()