blob: 6ac982835cb336fc7c5d7d4ea07cbfb98d2832ac [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import logging
import typing
import unittest
from itertools import chain
import numpy as np
from numpy.testing import assert_array_equal
import apache_beam as beam
from apache_beam.coders import RowCoder
from apache_beam.coders import coder_impl
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.internal import pickler
from apache_beam.portability.api import schema_pb2
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.utils.timestamp import Timestamp
class Person(typing.NamedTuple):
  # Main row fixture type: required, optional, list, mapping and Timestamp
  # fields side by side.
  name: str
  age: np.int32
  address: typing.Optional[str]
  aliases: typing.List[str]
  knows_javascript: bool
  payload: typing.Optional[bytes]
  custom_metadata: typing.Mapping[str, int]
  favorite_time: Timestamp
class NullablePerson(typing.NamedTuple):
  # Same leading fields as Person, but `name` and `favorite_time` are
  # optional here and an extra optional `one_more_field` is appended.
  name: typing.Optional[str]
  age: np.int32
  address: typing.Optional[str]
  aliases: typing.List[str]
  knows_javascript: bool
  payload: typing.Optional[bytes]
  custom_metadata: typing.Mapping[str, int]
  favorite_time: typing.Optional[Timestamp]
  one_more_field: typing.Optional[str]
class People(typing.NamedTuple):
  # A row whose fields are themselves Person rows; `partner` may be None.
  primary: Person
  partner: typing.Optional[Person]
# Register RowCoder in the coder registry for both fixture row types.
for _row_type in (Person, People):
  coders_registry.register_coder(_row_type, RowCoder)
del _row_type
class RowCoderTest(unittest.TestCase):
  """Round-trip, schema-evolution and batch tests for RowCoder."""

  # Named separately because the pipeline test filters on this exact row.
  JON_SNOW = Person(
      name="Jon Snow",
      age=np.int32(23),
      address=None,
      aliases=["crow", "wildling"],
      knows_javascript=False,
      payload=None,
      custom_metadata={},
      favorite_time=Timestamp.from_rfc3339('2016-03-18T23:22:59.123456Z'),
  )

  # Fixture rows shared by the round-trip tests below.
  PEOPLE = [
      JON_SNOW,
      Person(
          name="Daenerys Targaryen",
          age=np.int32(25),
          address="Westeros",
          aliases=["Mother of Dragons"],
          knows_javascript=False,
          payload=None,
          custom_metadata={"dragons": 3},
          favorite_time=Timestamp.from_rfc3339('1970-04-26T17:46:40Z'),
      ),
      Person(
          name="Michael Bluth",
          age=np.int32(30),
          address=None,
          aliases=[],
          knows_javascript=True,
          payload=b"I've made a huge mistake",
          custom_metadata={},
          favorite_time=Timestamp.from_rfc3339('2020-08-12T15:51:00.032Z'),
      ),
  ]

  @staticmethod
  def _atomic(atomic_type, **type_kwargs):
    # Shorthand for a FieldType built from one schema_pb2 atomic type; extra
    # keyword arguments (e.g. nullable=True) are forwarded unchanged.
    return schema_pb2.FieldType(atomic_type=atomic_type, **type_kwargs)

  @staticmethod
  def _field(name, field_type, **field_kwargs):
    # Shorthand for a named schema Field; extra keyword arguments
    # (e.g. encoding_position=1) are forwarded unchanged.
    return schema_pb2.Field(name=name, type=field_type, **field_kwargs)

  def test_row_accepts_trailing_zeros_truncated(self):
    coder = RowCoder(typing_to_runner_api(NullablePerson).row_type.schema)
    original = NullablePerson(
        name=None,
        age=np.int32(25),
        address="Westeros",
        aliases=["Mother of Dragons"],
        knows_javascript=False,
        payload=None,
        custom_metadata={"dragons": 3},
        favorite_time=None,
        one_more_field="NotNull")
    encoded = coder.encode(original)

    # Hand-written header: 9 fields, 1 null byte, fields 0, 5 and 7 null.
    null_bits = (1 << 0) | (1 << 5) | (1 << 7)
    header = bytes([9, 1, null_bits])
    # Everything after the first four bytes of the regular encoding is kept
    # (presumably those four bytes are the coder's own, longer header).
    truncated = header + encoded[4:]

    self.assertEqual(original, coder.decode(truncated))

  def test_create_row_coder_from_named_tuple(self):
    # The coder looked up in the registry must agree byte-for-byte with one
    # built directly from the schema derived from Person.
    from_schema = RowCoder(typing_to_runner_api(Person).row_type.schema)
    from_registry = coders_registry.get_coder(Person)

    for person in self.PEOPLE:
      expected_bytes = from_schema.encode(person)
      actual_bytes = from_registry.encode(person)
      self.assertEqual(expected_bytes, actual_bytes)

      roundtripped = from_registry.decode(from_registry.encode(person))
      self.assertEqual(person, roundtripped)

  def test_create_row_coder_from_nested_named_tuple(self):
    from_schema = RowCoder(typing_to_runner_api(People).row_type.schema)
    from_registry = coders_registry.get_coder(People)

    # Every primary paired with every possible partner, including no partner.
    couples = [
        People(primary=primary, partner=partner) for primary in self.PEOPLE
        for partner in self.PEOPLE + [None]
    ]
    for couple in couples:
      expected_bytes = from_schema.encode(couple)
      actual_bytes = from_registry.encode(couple)
      self.assertEqual(expected_bytes, actual_bytes)

      roundtripped = from_registry.decode(from_registry.encode(couple))
      self.assertEqual(couple, roundtripped)

  def test_create_row_coder_from_schema(self):
    # A hand-written schema proto with the same field names as Person; a
    # coder built from it must round-trip every fixture row.
    atomic, field = self._atomic, self._field
    string_type = atomic(schema_pb2.STRING)
    int64_type = atomic(schema_pb2.INT64)

    aliases_type = schema_pb2.FieldType(
        array_type=schema_pb2.ArrayType(element_type=string_type))
    metadata_type = schema_pb2.FieldType(
        map_type=schema_pb2.MapType(
            key_type=string_type,
            value_type=int64_type,
        ))
    micros_instant_schema = schema_pb2.Schema(
        id="micros_instant",
        fields=[
            field("seconds", int64_type),
            field("micros", int64_type),
        ])
    favorite_time_type = schema_pb2.FieldType(
        logical_type=schema_pb2.LogicalType(
            urn="beam:logical_type:micros_instant:v1",
            representation=schema_pb2.FieldType(
                row_type=schema_pb2.RowType(schema=micros_instant_schema))))

    schema = schema_pb2.Schema(
        id="person",
        fields=[
            field("name", string_type),
            field("age", atomic(schema_pb2.INT32)),
            field("address", atomic(schema_pb2.STRING, nullable=True)),
            field("aliases", aliases_type),
            field("knows_javascript", atomic(schema_pb2.BOOLEAN)),
            field("payload", atomic(schema_pb2.BYTES, nullable=True)),
            field("custom_metadata", metadata_type),
            field("favorite_time", favorite_time_type),
        ])

    coder = RowCoder(schema)
    for person in self.PEOPLE:
      self.assertEqual(person, coder.decode(coder.encode(person)))

  @unittest.skip(
      "https://github.com/apache/beam/issues/19696 - Overflow behavior in "
      "VarIntCoder is currently inconsistent")
  def test_overflows(self):
    IntTester = typing.NamedTuple(
        'IntTester',
        [
            # TODO(https://github.com/apache/beam/issues/19815): Test int8 and
            # int16 here as well when those types are supported
            # ('i8', typing.Optional[np.int8]),
            # ('i16', typing.Optional[np.int16]),
            ('i32', typing.Optional[np.int32]),
            ('i64', typing.Optional[np.int64]),
        ])
    coder = RowCoder.from_type_hint(IntTester, None)

    i32_min, i32_max = -2**31, 2**31 - 1
    i64_min, i64_max = -2**63, 2**63 - 1

    # The extreme representable values must encode without raising.
    in_range = (
        [IntTester(i32=value, i64=None) for value in (i32_min, i32_max)] +
        [IntTester(i32=None, i64=value) for value in (i64_min, i64_max)])
    for row in in_range:
      coder.encode(row)

    # One step past either extreme must be rejected.
    out_of_range = (
        [IntTester(i32=value, i64=None)
         for value in (i32_min - 1, i32_max + 1)] +
        [IntTester(i32=None, i64=value)
         for value in (i64_min - 1, i64_max + 1)])
    for row in out_of_range:
      with self.assertRaises(OverflowError):
        coder.encode(row)

  def test_none_in_non_nullable_field_throws(self):
    Test = typing.NamedTuple('Test', [('foo', str)])
    coder = RowCoder.from_type_hint(Test, None)
    with self.assertRaises(ValueError):
      coder.encode(Test(foo=None))

  def test_schema_remove_column(self):
    # The new schema lacks the last field of the old schema.
    Old = typing.NamedTuple('Old', [("field1", str), ("field2", str)])
    New = typing.NamedTuple('New', [("field1", str)])
    writer = RowCoder.from_type_hint(Old, None)
    reader = RowCoder.from_type_hint(New, None)

    encoded = writer.encode(Old("foo", "bar"))
    self.assertEqual(New("foo"), reader.decode(encoded))

  def test_schema_add_column(self):
    # The new schema appends one optional field the old schema did not have.
    Old = typing.NamedTuple('Old', [("field1", str)])
    New = typing.NamedTuple(
        'New', [("field1", str), ("field2", typing.Optional[str])])
    writer = RowCoder.from_type_hint(Old, None)
    reader = RowCoder.from_type_hint(New, None)

    encoded = writer.encode(Old("bar"))
    self.assertEqual(New("bar", None), reader.decode(encoded))

  def test_schema_add_column_with_null_value(self):
    # As above, but the encoded old row already carries a null in field1.
    Old = typing.NamedTuple(
        'Old', [("field1", typing.Optional[str]), ("field2", str)])
    New = typing.NamedTuple(
        'New',
        [
            ("field1", typing.Optional[str]),
            ("field2", str),
            ("field3", typing.Optional[str]),
        ])
    writer = RowCoder.from_type_hint(Old, None)
    reader = RowCoder.from_type_hint(New, None)

    encoded = writer.encode(Old(None, "baz"))
    self.assertEqual(New(None, "baz", None), reader.decode(encoded))

  def test_row_coder_picklable(self):
    # Coders are occasionally pickled; a RowCoder must survive the trip and
    # still compare equal to the original.
    coder = coders_registry.get_coder(Person)
    pickled = pickler.dumps(coder)
    self.assertEqual(pickler.loads(pickled), coder)

  def test_row_coder_in_pipeine(self):
    with TestPipeline() as pipeline:
      everyone = pipeline | beam.Create(self.PEOPLE)
      only_jon = everyone | beam.Filter(
          lambda person: person.name == "Jon Snow")
      assert_that(only_jon, equal_to([self.JON_SNOW]))

  def test_row_coder_nested_struct(self):
    Pair = typing.NamedTuple('Pair', [('left', Person), ('right', Person)])
    pair = Pair(left=self.PEOPLE[0], right=self.PEOPLE[1])
    coder = RowCoder(typing_to_runner_api(Pair).row_type.schema)

    encoded = coder.encode(pair)
    self.assertEqual(pair, coder.decode(encoded))

  def test_encoding_position_reorder_fields(self):
    atomic, field = self._atomic, self._field

    schema1 = schema_pb2.Schema(
        id="reorder_test_schema1",
        fields=[
            field("f_int32", atomic(schema_pb2.INT32)),
            field("f_str", atomic(schema_pb2.STRING)),
        ])
    # Same two fields declared in the opposite order, with explicit encoding
    # positions that match schema1's declaration order.
    schema2 = schema_pb2.Schema(
        id="reorder_test_schema2",
        encoding_positions_set=True,
        fields=[
            field("f_str", atomic(schema_pb2.STRING), encoding_position=1),
            field("f_int32", atomic(schema_pb2.INT32), encoding_position=0),
        ])

    RowSchema1 = named_tuple_from_schema(schema1)
    RowSchema2 = named_tuple_from_schema(schema2)

    encoded = RowCoder(schema1).encode(RowSchema1(42, "Hello World!"))
    decoded = RowCoder(schema2).decode(encoded)
    self.assertEqual(RowSchema2(f_int32=42, f_str="Hello World!"), decoded)

  def test_encoding_position_add_fields_and_reorder(self):
    atomic, field = self._atomic, self._field

    old_schema = schema_pb2.Schema(
        id="add_test_old",
        fields=[
            field("f_int32", atomic(schema_pb2.INT32)),
            field("f_str", atomic(schema_pb2.STRING)),
        ])
    # A nullable field is declared first but given the last encoding
    # position; the two original fields keep positions 0 and 1.
    new_schema = schema_pb2.Schema(
        encoding_positions_set=True,
        id="add_test_new",
        fields=[
            field(
                "f_new_str",
                atomic(schema_pb2.STRING, nullable=True),
                encoding_position=2),
            field("f_int32", atomic(schema_pb2.INT32), encoding_position=0),
            field("f_str", atomic(schema_pb2.STRING), encoding_position=1),
        ])

    Old = named_tuple_from_schema(old_schema)
    New = named_tuple_from_schema(new_schema)

    encoded = RowCoder(old_schema).encode(Old(42, "Hello World!"))
    decoded = RowCoder(new_schema).decode(encoded)
    self.assertEqual(
        New(f_new_str=None, f_int32=42, f_str="Hello World!"), decoded)

  def test_row_coder_fail_early_bad_schema(self):
    schema_proto = schema_pb2.Schema(
        fields=[
            self._field("type_with_no_typeinfo", schema_pb2.FieldType()),
        ],
        id='bad-schema')

    # Construction itself must fail, and the error must name the bad field.
    with self.assertRaisesRegex(ValueError, "type_with_no_typeinfo"):
      RowCoder(schema_proto)

  def test_batch_encode_decode(self):
    impl = RowCoder(typing_to_runner_api(Person).row_type.schema).get_impl()
    people_count = len(self.PEOPLE)

    # Reference encoding: one row at a time into a single stream.
    seq_out = coder_impl.create_OutputStream()
    for person in self.PEOPLE:
      impl.encode_to_stream(person, seq_out, False)

    # Batch encoding: one object array per field, keyed by field name.
    columnar = {
        field: np.array([getattr(person, field) for person in self.PEOPLE],
                        ndmin=1,
                        dtype=object)
        for field in Person._fields
    }
    batch_out = coder_impl.create_OutputStream()
    impl.encode_batch_to_stream(columnar, batch_out)

    sequential_bytes = seq_out.get()
    batch_bytes = batch_out.get()
    if sequential_bytes != batch_bytes:
      # Print both encodings in fixed-size chunks to help locate the first
      # difference before the assertion below fails.
      chunk = 25
      longest = max(len(sequential_bytes), len(batch_bytes))
      for start in range(0, longest, chunk):
        stop = start + chunk
        print(start, sequential_bytes[start:stop] == batch_bytes[start:stop])
        print(sequential_bytes[start:stop])
        print(batch_bytes[start:stop])
    self.assertEqual(sequential_bytes, batch_bytes)

    # Decode into destinations that are smaller than, equal to and larger
    # than the number of encoded rows.
    for capacity in (people_count - 1, people_count, people_count + 1):
      dest = {
          field: np.ndarray((capacity, ), dtype=column.dtype)
          for field,
          column in columnar.items()
      }
      expected_rows = min(capacity, people_count)
      decoded_rows = impl.decode_batch_from_stream(
          dest, coder_impl.create_InputStream(sequential_bytes))
      self.assertEqual(expected_rows, decoded_rows)
      for field, column in columnar.items():
        assert_array_equal(column[:expected_rows], dest[field][:expected_rows])
if __name__ == "__main__":
  # Set the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()