# blob: 6b0f9b098db7c227d449237393d7c5e3db06a65c
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Tests common to all coder implementations."""
import decimal
import logging
import unittest
from pyflink.fn_execution.coders import BigIntCoder, TinyIntCoder, BooleanCoder, \
SmallIntCoder, IntCoder, FloatCoder, DoubleCoder, BinaryCoder, CharCoder, DateCoder, \
TimeCoder, TimestampCoder, BasicArrayCoder, MapCoder, DecimalCoder, FlattenRowCoder, RowCoder, \
LocalZonedTimestampCoder, BigDecimalCoder, TupleCoder, PrimitiveArrayCoder
from pyflink.testing.test_case_utils import PyFlinkTestCase
# Detect whether the compiled (Cython) coder implementation is importable; the
# pure-Python coder tests below are skipped when it is.
try:
    from pyflink.fn_execution import coder_impl_fast  # noqa # pylint: disable=unused-import
except ImportError:
    have_cython = False
else:
    have_cython = True
@unittest.skipIf(have_cython,
                 "Found cython implementation, we don't need to test non-compiled implementation")
class CodersTest(PyFlinkTestCase):
    """Round-trip tests for the pure-Python (non-Cython) coder implementations.

    Each test encodes sample values with a coder's implementation and checks
    that decoding yields the original value. The whole class is skipped when
    the compiled ``coder_impl_fast`` module can be imported.
    """

    def check_coder(self, coder, *values):
        """Assert that every value in ``values`` survives an encode/decode round trip.

        :param coder: the coder under test; its ``get_impl()`` result is used
            to ``encode`` and then ``decode`` each value.
        :param values: sample values to round-trip.

        Floats are compared with a relative tolerance of 1e-6, since a coder
        may narrow the value when encoding (presumably ``FloatCoder`` -> 32-bit).
        All other values must compare equal exactly.
        """
        coder_impl = coder.get_impl()
        for v in values:
            decoded = coder_impl.decode(coder_impl.encode(v))
            if isinstance(v, float):
                # Use this class's own helper rather than importing one from an
                # unrelated test module. Use a unittest assertion so the check
                # is not stripped when running under ``python -O``.
                self.assertTrue(self.float_equal(v, decoded, 1e-6),
                                "%r != %r within rel_tol=1e-6" % (v, decoded))
            else:
                self.assertEqual(v, decoded)

    @staticmethod
    def float_equal(a, b, rel_tol=1e-09, abs_tol=0.0):
        """Return True if ``a`` and ``b`` are equal within the given tolerances.

        The check is ``|a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)``.
        """
        return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)

    def test_bigint_coder(self):
        coder = BigIntCoder()
        self.check_coder(coder, 1, 100, -100, -1000)

    def test_tinyint_coder(self):
        coder = TinyIntCoder()
        self.check_coder(coder, 1, 10, 127, -128)

    def test_boolean_coder(self):
        coder = BooleanCoder()
        self.check_coder(coder, True, False)

    def test_smallint_coder(self):
        coder = SmallIntCoder()
        self.check_coder(coder, 32767, -32768, 0)

    def test_int_coder(self):
        coder = IntCoder()
        self.check_coder(coder, -2147483648, 2147483647)

    def test_float_coder(self):
        coder = FloatCoder()
        self.check_coder(coder, 1.02, 1.32)

    def test_double_coder(self):
        coder = DoubleCoder()
        self.check_coder(coder, -12.02, 1.98932)

    def test_binary_coder(self):
        coder = BinaryCoder()
        self.check_coder(coder, b'pyflink')

    def test_char_coder(self):
        coder = CharCoder()
        # Includes a character outside the Basic Multilingual Plane.
        self.check_coder(coder, 'flink', '🐿')

    def test_date_coder(self):
        import datetime
        coder = DateCoder()
        self.check_coder(coder, datetime.date(2019, 9, 10))

    def test_time_coder(self):
        import datetime
        coder = TimeCoder()
        self.check_coder(coder, datetime.time(hour=11, minute=11, second=11, microsecond=123000))

    def test_timestamp_coder(self):
        import datetime
        coder = TimestampCoder(3)
        self.check_coder(coder, datetime.datetime(2019, 9, 10, 18, 30, 20, 123000))
        coder = TimestampCoder(6)
        self.check_coder(coder, datetime.datetime(2019, 9, 10, 18, 30, 20, 123456))

    def test_local_zoned_timestamp_coder(self):
        import datetime
        import pytz
        timezone = pytz.timezone("Asia/Shanghai")
        coder = LocalZonedTimestampCoder(3, timezone)
        self.check_coder(coder,
                         timezone.localize(datetime.datetime(2019, 9, 10, 18, 30, 20, 123000)))
        coder = LocalZonedTimestampCoder(6, timezone)
        self.check_coder(coder,
                         timezone.localize(datetime.datetime(2019, 9, 10, 18, 30, 20, 123456)))

    def test_array_coder(self):
        element_coder = BigIntCoder()
        coder = BasicArrayCoder(element_coder)
        self.check_coder(coder, [1, 2, 3, None])

    def test_primitive_array_coder(self):
        element_coder = CharCoder()
        coder = PrimitiveArrayCoder(element_coder)
        self.check_coder(coder, ['hi', 'hello', 'flink'])

    def test_map_coder(self):
        key_coder = CharCoder()
        value_coder = BigIntCoder()
        coder = MapCoder(key_coder, value_coder)
        self.check_coder(coder, {'flink': 1, 'pyflink': 2, 'coder': None})

    def test_decimal_coder(self):
        coder = DecimalCoder(38, 18)
        self.check_coder(coder, decimal.Decimal('0.00001'), decimal.Decimal('1.23E-8'))
        coder = DecimalCoder(4, 3)
        # Round-trip a 4-digit value while the caller's context precision is
        # only 2. This checks that the coder is not limited by, and does not
        # change, the caller's decimal context. ``localcontext`` restores the
        # previous context on exit, so later tests in this thread are unaffected.
        with decimal.localcontext() as ctx:
            ctx.prec = 2
            self.check_coder(coder, decimal.Decimal('1.001'))
            self.assertEqual(decimal.getcontext().prec, 2)

    def test_flatten_row_coder(self):
        field_coder = BigIntCoder()
        field_count = 10
        coder = FlattenRowCoder([field_coder for _ in range(field_count)]).get_impl()
        v = [None if i % 2 == 0 else i for i in range(field_count)]
        # decode() yields rows; materialize them for comparison.
        result = list(coder.decode(coder.encode(v)))
        self.assertEqual([v], result)

    def test_row_coder(self):
        from pyflink.common import Row, RowKind
        field_coder = BigIntCoder()
        field_count = 10
        field_names = ['f{}'.format(i) for i in range(field_count)]
        coder = RowCoder([field_coder for _ in range(field_count)], field_names)
        v = Row(**{field_names[i]: None if i % 2 == 0 else i for i in range(field_count)})
        # The row kind must survive the round trip along with the field values.
        for row_kind in (RowKind.INSERT, RowKind.UPDATE_BEFORE,
                         RowKind.UPDATE_AFTER, RowKind.DELETE):
            v.set_row_kind(row_kind)
            self.check_coder(coder, v)

    def test_basic_decimal_coder(self):
        basic_dec_coder = BigDecimalCoder()
        # Decimal(1.200) is built from a float, so it holds the exact binary
        # expansion of 1.2 (many digits). Decimal('1.200') is the short literal.
        float_based = decimal.Decimal(1.200)
        string_based = decimal.Decimal('1.200')
        self.check_coder(basic_dec_coder, float_based, string_based)

    def test_tuple_coder(self):
        field_coders = [IntCoder(), CharCoder(), CharCoder()]
        tuple_coder = TupleCoder(field_coders=field_coders)
        data = (1, "Hello", "Hi")
        self.check_coder(tuple_coder, data)
if __name__ == '__main__':
    # When run as a script, raise the root logger to INFO before running the tests.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    unittest.main()