################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
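"""Tests for PyFlink data types: schema inference, type merging, type
verification, Python-to-Java type conversion, and serializers."""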
import array
import ctypes
import datetime
import pickle
import sys
import tempfile
import unittest
from pyflink.pyflink_gateway_server import on_windows
from pyflink.serializers import BatchedSerializer, PickleSerializer
from pyflink.java_gateway import get_gateway
from pyflink.table.types import (_infer_schema_from_data, _infer_type,
_array_signed_int_typecode_ctype_mappings,
_array_unsigned_int_typecode_ctype_mappings,
_array_type_mappings, _merge_type,
_create_type_verifier, UserDefinedType, DataTypes, Row, RowField,
RowType, ArrayType, BigIntType, VarCharType, MapType, DataType,
_to_java_type, _from_java_type, ZonedTimestampType,
LocalZonedTimestampType)
from pyflink.testing.test_case_utils import PyFlinkTestCase
class ExamplePointUDT(UserDefinedType):
"""
User-defined type (UDT) for ExamplePoint.
"""
@classmethod
def sql_type(cls):
return DataTypes.ARRAY(DataTypes.DOUBLE(False))
@classmethod
def module(cls):
return 'pyflink.table.tests.test_types'
@classmethod
def java_udt(cls):
return 'org.apache.flink.table.types.python.ExamplePointUserDefinedType'
def serialize(self, obj):
return [obj.x, obj.y]
def deserialize(self, datum):
return ExamplePoint(datum[0], datum[1])
class ExamplePoint:
"""
An example class to demonstrate UDT in Java and Python.
"""
__UDT__ = ExamplePointUDT()
def __init__(self, x, y):
self.x = x
self.y = y
def __repr__(self):
return "ExamplePoint(%s,%s)" % (self.x, self.y)
def __str__(self):
return "(%s,%s)" % (self.x, self.y)
def __eq__(self, other):
return isinstance(other, self.__class__) and other.x == self.x and other.y == self.y
class PythonOnlyUDT(UserDefinedType):
"""
User-defined type (UDT) for PythonOnlyPoint.
"""
@classmethod
def sql_type(cls):
return DataTypes.ARRAY(DataTypes.DOUBLE(False))
@classmethod
def module(cls):
return '__main__'
def serialize(self, obj):
return [obj.x, obj.y]
def deserialize(self, datum):
return PythonOnlyPoint(datum[0], datum[1])
class PythonOnlyPoint(ExamplePoint):
"""
An example class to demonstrate a UDT available only in Python.
"""
__UDT__ = PythonOnlyUDT()
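# A minimal sketch of how the UDT machinery above round-trips a value
# through its SQL representation (uses only the classes defined here):
#
#   point = PythonOnlyPoint(1.0, 2.0)
#   datum = PythonOnlyPoint.__UDT__.serialize(point)  # -> [1.0, 2.0]
#   assert PythonOnlyPoint.__UDT__.deserialize(datum) == point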
class UTCOffsetTimezone(datetime.tzinfo):
"""
A tzinfo implementation that represents a timezone as a fixed UTC offset.
"""
def __init__(self, offset=0):
self.OFFSET = datetime.timedelta(hours=offset)
def utcoffset(self, dt):
return self.OFFSET
def dst(self, dt):
return self.OFFSET
class TypesTests(PyFlinkTestCase):
def test_infer_schema(self):
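# Each sample value below must be inferred as the data type at the same
# position in `expected`.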
from decimal import Decimal
class A(object):
def __init__(self):
self.a = 1
from collections import namedtuple
Point = namedtuple('Point', 'x y')
data = [
True,
1,
"a",
u"a",
datetime.date(1970, 1, 1),
datetime.time(0, 0, 0),
datetime.datetime(1970, 1, 1, 0, 0),
1.0,
array.array("d", [1]),
[1],
(1,),
Point(1.0, 5.0),
{"a": 1},
bytearray(1),
Decimal(1),
Row(a=1),
Row("a")(1),
A(),
]
expected = [
'BooleanType(true)',
'BigIntType(true)',
'VarCharType(2147483647, true)',
'VarCharType(2147483647, true)',
'DateType(true)',
'TimeType(0, true)',
'LocalZonedTimestampType(6, true)',
'DoubleType(true)',
"ArrayType(DoubleType(false), true)",
"ArrayType(BigIntType(true), true)",
'RowType(RowField(_1, BigIntType(true), ...))',
'RowType(RowField(x, DoubleType(true), ...),RowField(y, DoubleType(true), ...))',
'MapType(VarCharType(2147483647, false), BigIntType(true), true)',
'VarBinaryType(2147483647, true)',
'DecimalType(38, 18, true)',
'RowType(RowField(a, BigIntType(true), ...))',
'RowType(RowField(a, BigIntType(true), ...))',
'RowType(RowField(a, BigIntType(true), ...))',
]
schema = _infer_schema_from_data([data])
self.assertEqual(expected, [repr(f.data_type) for f in schema.fields])
def test_infer_schema_nulltype(self):
elements = [Row(c1=[], c2={}, c3=None),
Row(c1=[Row(a=1, b='s')], c2={"key": Row(c=1.0, d="2")}, c3="")]
schema = _infer_schema_from_data(elements)
self.assertIsInstance(schema, RowType)
self.assertEqual(3, len(schema.fields))
# first column is array
self.assertIsInstance(schema.fields[0].data_type, ArrayType)
# element type of first column is struct
self.assertIsInstance(schema.fields[0].data_type.element_type, RowType)
self.assertIsInstance(schema.fields[0].data_type.element_type.fields[0].data_type, BigIntType)
self.assertIsInstance(schema.fields[0].data_type.element_type.fields[1].data_type, VarCharType)
# second column is map
self.assertIsInstance(schema.fields[1].data_type, MapType)
self.assertIsInstance(schema.fields[1].data_type.key_type, VarCharType)
self.assertIsInstance(schema.fields[1].data_type.value_type, RowType)
# third column is varchar
self.assertIsInstance(schema.fields[2].data_type, VarCharType)
def test_infer_schema_not_enough_names(self):
schema = _infer_schema_from_data([["a", "b"]], ["col1"])
self.assertEqual(schema.names, ['col1', '_2'])
def test_infer_schema_fails(self):
with self.assertRaises(TypeError):
_infer_schema_from_data([[1, 1], ["x", 1]], names=["a", "b"])
def test_infer_nested_schema(self):
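# Schema inference must recurse into arrays and maps nested inside rows.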
NestedRow = Row("f1", "f2")
data1 = [NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]
schema1 = _infer_schema_from_data(data1)
expected1 = [
'ArrayType(BigIntType(true), true)',
'MapType(VarCharType(2147483647, false), DoubleType(true), true)'
]
self.assertEqual(expected1, [repr(f.data_type) for f in schema1.fields])
data2 = [NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]
schema2 = _infer_schema_from_data(data2)
expected2 = [
'ArrayType(ArrayType(BigIntType(true), true), true)',
'ArrayType(BigIntType(true), true)'
]
self.assertEqual(expected2, [repr(f.data_type) for f in schema2.fields])
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.as_dict()['l'][0].a)
self.assertEqual(1.0, row.as_dict()['d']['key'].c)
def test_udt(self):
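# Values carrying a __UDT__ attribute must be inferred as their UDT, and
# the verifier must reject plain lists that merely match the sql_type.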
p = ExamplePoint(1.0, 2.0)
self.assertEqual(_infer_type(p), ExamplePointUDT())
_create_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
self.assertRaises(ValueError, lambda: _create_type_verifier(ExamplePointUDT())([1.0, 2.0]))
p = PythonOnlyPoint(1.0, 2.0)
self.assertEqual(_infer_type(p), PythonOnlyUDT())
_create_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
self.assertRaises(ValueError, lambda: _create_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
def test_nested_udt_in_df(self):
expected_schema = DataTypes.ROW() \
.add("_1", DataTypes.BIGINT()).add("_2", DataTypes.ARRAY(PythonOnlyUDT()))
data = (1, [PythonOnlyPoint(float(1), float(2))])
self.assertEqual(expected_schema, _infer_type(data))
expected_schema = DataTypes.ROW().add("_1", DataTypes.BIGINT()).add(
"_2", DataTypes.MAP(DataTypes.BIGINT(False), PythonOnlyUDT()))
p = (1, {1: PythonOnlyPoint(1, float(2))})
self.assertEqual(expected_schema, _infer_type(p))
def test_struct_type(self):
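# A row type built incrementally via add() must equal one built from an
# explicit field list, and must support iteration, len(), and indexing
# by name, position, and slice.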
row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
.add("f2", DataTypes.STRING(nullable=True))
row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)),
DataTypes.FIELD("f2", DataTypes.STRING(nullable=True), None)])
self.assertEqual(row1.field_names(), row2.names)
self.assertEqual(row1, row2)
row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
.add("f2", DataTypes.STRING(nullable=True))
row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))])
self.assertNotEqual(row1.field_names(), row2.names)
self.assertNotEqual(row1, row2)
row1 = (DataTypes.ROW().add(DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)))
.add("f2", DataTypes.STRING(nullable=True)))
row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)),
DataTypes.FIELD("f2", DataTypes.STRING(nullable=True))])
self.assertEqual(row1.field_names(), row2.names)
self.assertEqual(row1, row2)
row1 = (DataTypes.ROW().add(DataTypes.FIELD("f1", DataTypes.STRING(nullable=True)))
.add("f2", DataTypes.STRING(nullable=True)))
row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.STRING(nullable=True))])
self.assertNotEqual(row1.field_names(), row2.names)
self.assertNotEqual(row1, row2)
# Catch exception raised during improper construction
self.assertRaises(ValueError, lambda: DataTypes.ROW().add("name"))
row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
.add("f2", DataTypes.STRING(nullable=True))
for field in row1:
self.assertIsInstance(field, RowField)
row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
.add("f2", DataTypes.STRING(nullable=True))
self.assertEqual(len(row1), 2)
row1 = DataTypes.ROW().add("f1", DataTypes.STRING(nullable=True)) \
.add("f2", DataTypes.STRING(nullable=True))
self.assertIs(row1["f1"], row1.fields[0])
self.assertIs(row1[0], row1.fields[0])
self.assertEqual(row1[0:1], DataTypes.ROW(row1.fields[0:1]))
self.assertRaises(KeyError, lambda: row1["f9"])
self.assertRaises(IndexError, lambda: row1[9])
self.assertRaises(TypeError, lambda: row1[9.9])
def test_infer_bigint_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
schema = _infer_schema_from_data(longrow)
self.assertEqual(DataTypes.BIGINT(), schema.fields[1].data_type)
self.assertEqual(DataTypes.BIGINT(), _infer_type(1))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 10))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 20))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 31 - 1))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 31))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 61))
self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 71))
def test_merge_type(self):
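# _merge_type unifies two inferred types: NULL merges into any type,
# while incompatible element, key, value, or field types raise TypeError.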
self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.NULL()), DataTypes.BIGINT())
self.assertEqual(_merge_type(DataTypes.NULL(), DataTypes.BIGINT()), DataTypes.BIGINT())
self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.BIGINT()), DataTypes.BIGINT())
self.assertEqual(_merge_type(
DataTypes.ARRAY(DataTypes.BIGINT()),
DataTypes.ARRAY(DataTypes.BIGINT())
), DataTypes.ARRAY(DataTypes.BIGINT()))
with self.assertRaises(TypeError):
_merge_type(DataTypes.ARRAY(DataTypes.BIGINT()), DataTypes.ARRAY(DataTypes.DOUBLE()))
self.assertEqual(_merge_type(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())
), DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT()))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT()),
DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
self.assertEqual(_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
DataTypes.FIELD('f2', DataTypes.STRING())])
), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
DataTypes.FIELD('f2', DataTypes.STRING())]))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.DOUBLE()),
DataTypes.FIELD('f2', DataTypes.STRING())]))
self.assertEqual(_merge_type(
DataTypes.ROW([DataTypes.FIELD(
'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
DataTypes.ROW([DataTypes.FIELD(
'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))])
), DataTypes.ROW([DataTypes.FIELD(
'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ROW(
[DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ROW(
[DataTypes.FIELD('f2', DataTypes.STRING())]))]))
self.assertEqual(_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())])
), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.DOUBLE())),
DataTypes.FIELD('f2', DataTypes.STRING())]))
self.assertEqual(_merge_type(
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())])
), DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
DataTypes.FIELD('f2', DataTypes.STRING())]),
DataTypes.ROW([
DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())),
DataTypes.FIELD('f2', DataTypes.STRING())]))
self.assertEqual(_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())))]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())))])
), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())))]))
with self.assertRaises(TypeError):
_merge_type(
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())))]),
DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT())))])
)
def test_array_types(self):
# This test needs to make sure that the JVM type selected is at least
# as large as the corresponding Python type. This is necessary because
# Python's array types depend on the C implementation on the machine.
# Therefore there is no machine-independent correspondence between
# Python's array types and JVM types.
# See: https://docs.python.org/3/library/array.html
def assert_collect_success(typecode, value, element_type):
self.assertEqual(element_type,
str(_infer_type(array.array(typecode, [value])).element_type))
# supported string types
#
# String typecodes in Python's array module are "u" for Py_UNICODE and "c" for char.
# "u" is slated for removal in Python 4, and "c" is not supported in Python 3.
supported_string_types = []
if sys.version_info[0] < 4:
supported_string_types += ['u']
# test unicode
assert_collect_success('u', u'a', 'CHAR')
# supported float and double
#
# Test max, min, and precision for float and double, assuming IEEE 754
# floating-point format.
supported_fractional_types = ['f', 'd']
assert_collect_success('f', ctypes.c_float(1e+38).value, 'FLOAT')
assert_collect_success('f', ctypes.c_float(1e-38).value, 'FLOAT')
assert_collect_success('f', ctypes.c_float(1.123456).value, 'FLOAT')
assert_collect_success('d', sys.float_info.max, 'DOUBLE')
assert_collect_success('d', sys.float_info.min, 'DOUBLE')
assert_collect_success('d', sys.float_info.epsilon, 'DOUBLE')
def get_int_data_type(size):
if size <= 8:
return "TINYINT"
if size <= 16:
return "SMALLINT"
if size <= 32:
return "INT"
if size <= 64:
return "BIGINT"
# supported signed int types
#
# The size of C types varies across implementations, so we need to make
# sure that there is no overflow error on the platform running this test.
supported_signed_int_types = list(
set(_array_signed_int_typecode_ctype_mappings.keys()).intersection(
set(_array_type_mappings.keys())))
for t in supported_signed_int_types:
ctype = _array_signed_int_typecode_ctype_mappings[t]
max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
assert_collect_success(t, max_val - 1, get_int_data_type(ctypes.sizeof(ctype) * 8))
assert_collect_success(t, -max_val, get_int_data_type(ctypes.sizeof(ctype) * 8))
# supported unsigned int types
#
# The JVM does not have unsigned types, so we need to be very careful
# to make sure that there is no overflow error.
supported_unsigned_int_types = list(
set(_array_unsigned_int_typecode_ctype_mappings.keys()).intersection(
set(_array_type_mappings.keys())))
for t in supported_unsigned_int_types:
ctype = _array_unsigned_int_typecode_ctype_mappings[t]
max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
assert_collect_success(t, max_val, get_int_data_type(ctypes.sizeof(ctype) * 8 + 1))
# all supported types
#
# Make sure the types tested above:
# 1. are all supported types
# 2. cover all supported types
supported_types = (supported_string_types +
supported_fractional_types +
supported_signed_int_types +
supported_unsigned_int_types)
self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
# all unsupported types
#
# The keys of _array_type_mappings form the complete list of supported
# typecodes; any typecode not present is considered unsupported.
all_types = set(array.typecodes)
unsupported_types = all_types - set(supported_types)
# test unsupported types
for t in unsupported_types:
with self.assertRaises(TypeError):
_infer_schema_from_data([Row(myarray=array.array(t))])
def test_data_type_eq(self):
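# Data types must survive a pickle round trip and still compare equal.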
lt = DataTypes.BIGINT()
lt2 = pickle.loads(pickle.dumps(DataTypes.BIGINT()))
self.assertEqual(lt, lt2)
def test_decimal_type(self):
t1 = DataTypes.DECIMAL(10, 0)
t2 = DataTypes.DECIMAL(10, 2)
self.assertIsNot(t2, t1)
self.assertNotEqual(t1, t2)
def test_datetype_equal_zero(self):
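# The SQL DATE value 0 must map back to the epoch date 1970-01-01.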
dt = DataTypes.DATE()
self.assertEqual(dt.from_sql_type(0), datetime.date(1970, 1, 1))
@unittest.skipIf(on_windows(), "Windows x64 systems only support datetimes no later "
"than time.ctime(32536799999), so this test can't run "
"on Windows")
def test_timestamp_microsecond(self):
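# to_sql_type must keep microsecond precision: datetime.max carries
# 999999 microseconds.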
tst = DataTypes.TIMESTAMP()
self.assertEqual(tst.to_sql_type(datetime.datetime.max) % 1000000, 999999)
@unittest.skipIf(on_windows(), "Windows x64 systems only support datetimes no later "
"than time.ctime(32536799999), so this test can't run "
"on Windows")
def test_local_zoned_timestamp_type(self):
lztst = DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE()
ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
self.assertEqual(0, lztst.to_sql_type(ts))
import pytz
# suppose the timezone of the data is +9:00
timezone = pytz.timezone("Asia/Tokyo")
orig_epoch = LocalZonedTimestampType.EPOCH_ORDINAL
try:
# suppose the local timezone is +8:00
LocalZonedTimestampType.EPOCH_ORDINAL = 28800000000
ts_tokyo = timezone.localize(ts)
self.assertEqual(-3600000000, lztst.to_sql_type(ts_tokyo))
finally:
LocalZonedTimestampType.EPOCH_ORDINAL = orig_epoch
if sys.version_info >= (3, 6):
ts2 = lztst.from_sql_type(0)
self.assertEqual(ts.astimezone(), ts2.astimezone())
def test_zoned_timestamp_type(self):
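# A zoned timestamp is encoded as (wall-clock microseconds since epoch,
# UTC offset in seconds): (0, 3600) for 1970-01-01T00:00:00+01:00.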
ztst = ZonedTimestampType()
ts = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=UTCOffsetTimezone(1))
self.assertEqual((0, 3600), ztst.to_sql_type(ts))
ts2 = ztst.from_sql_type((0, 3600))
self.assertEqual(ts, ts2)
def test_day_time_interval_type(self):
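# A day-time interval is encoded as total microseconds:
# 1 day + 10 s = 86400 s + 10 s = 86410 s = 86410000000 microseconds.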
interval_type = DataTypes.INTERVAL(DataTypes.DAY(), DataTypes.SECOND())
td = datetime.timedelta(days=1, seconds=10)
self.assertEqual(86410000000, interval_type.to_sql_type(td))
td2 = interval_type.from_sql_type(86410000000)
self.assertEqual(td, td2)
def test_empty_row(self):
row = Row()
self.assertEqual(len(row), 0)
def test_invalid_create_row(self):
row_class = Row("c1", "c2")
self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
def test_nullable(self):
t = DataType(nullable=False)
self.assertEqual(t._nullable, False)
t_nullable = t.nullable()
self.assertEqual(t_nullable._nullable, True)
def test_not_null(self):
t = DataType(nullable=True)
self.assertEqual(t._nullable, True)
t_notnull = t.not_null()
self.assertEqual(t_notnull._nullable, False)
class DataTypeVerificationTests(PyFlinkTestCase):
def test_verify_type_exception_msg(self):
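# Not-null violations surface as ValueError; structural mismatches
# surface as TypeError.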
self.assertRaises(
ValueError,
lambda: _create_type_verifier(
DataTypes.STRING(nullable=False), name="test_name")(None))
schema = DataTypes.ROW(
[DataTypes.FIELD('a', DataTypes.ROW([DataTypes.FIELD('b', DataTypes.INT())]))])
self.assertRaises(
TypeError,
lambda: _create_type_verifier(schema)([["data"]]))
def test_verify_type_ok_nullable(self):
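# None must pass verification against any nullable type.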
obj = None
types = [DataTypes.INT(), DataTypes.FLOAT(), DataTypes.STRING(), DataTypes.ROW([])]
for data_type in types:
try:
_create_type_verifier(data_type)(obj)
except (TypeError, ValueError):
self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
def test_verify_type_not_nullable(self):
import array
import datetime
import decimal
schema = DataTypes.ROW([
DataTypes.FIELD('s', DataTypes.STRING(nullable=False)),
DataTypes.FIELD('i', DataTypes.INT(True))])
class MyObj:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
# obj, data_type
success_spec = [
# String
("", DataTypes.STRING()),
(u"", DataTypes.STRING()),
# UDT
(ExamplePoint(1.0, 2.0), ExamplePointUDT()),
# Boolean
(True, DataTypes.BOOLEAN()),
# TinyInt
(-(2 ** 7), DataTypes.TINYINT()),
(2 ** 7 - 1, DataTypes.TINYINT()),
# SmallInt
(-(2 ** 15), DataTypes.SMALLINT()),
(2 ** 15 - 1, DataTypes.SMALLINT()),
# Int
(-(2 ** 31), DataTypes.INT()),
(2 ** 31 - 1, DataTypes.INT()),
# BigInt
(2 ** 64, DataTypes.BIGINT()),
# Float & Double
(1.0, DataTypes.FLOAT()),
(1.0, DataTypes.DOUBLE()),
# Decimal
(decimal.Decimal("1.0"), DataTypes.DECIMAL(10, 0)),
# Binary
(bytearray([1]), DataTypes.BINARY(1)),
# Date/Time/Timestamp
(datetime.date(2000, 1, 2), DataTypes.DATE()),
(datetime.datetime(2000, 1, 2, 3, 4), DataTypes.DATE()),
(datetime.time(1, 1, 2), DataTypes.TIME()),
(datetime.datetime(2000, 1, 2, 3, 4), DataTypes.TIMESTAMP()),
# Array
([], DataTypes.ARRAY(DataTypes.INT())),
(["1", None], DataTypes.ARRAY(DataTypes.STRING(nullable=True))),
([1, 2], DataTypes.ARRAY(DataTypes.INT())),
((1, 2), DataTypes.ARRAY(DataTypes.INT())),
(array.array('h', [1, 2]), DataTypes.ARRAY(DataTypes.INT())),
# Map
({}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
({"a": 1}, DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())),
({"a": None}, DataTypes.MAP(DataTypes.STRING(nullable=False), DataTypes.INT(True))),
# Struct
({"s": "a", "i": 1}, schema),
({"s": "a", "i": None}, schema),
({"s": "a"}, schema),
({"s": "a", "f": 1.0}, schema),
(Row(s="a", i=1), schema),
(Row(s="a", i=None), schema),
(Row(s="a", i=1, f=1.0), schema),
(["a", 1], schema),
(["a", None], schema),
(("a", 1), schema),
(MyObj(s="a", i=1), schema),
(MyObj(s="a", i=None), schema),
(MyObj(s="a"), schema),
]
# obj, data_type, exception class
failure_spec = [
# Char/VarChar (match anything but None)
(None, DataTypes.VARCHAR(1), ValueError),
(None, DataTypes.CHAR(1), ValueError),
# VarChar (length exceeds maximum length)
("abc", DataTypes.VARCHAR(1), ValueError),
# Char (length exceeds the defined length)
("abc", DataTypes.CHAR(1), ValueError),
# UDT
(ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
# Boolean
(1, DataTypes.BOOLEAN(), TypeError),
("True", DataTypes.BOOLEAN(), TypeError),
([1], DataTypes.BOOLEAN(), TypeError),
# TinyInt
(-(2 ** 7) - 1, DataTypes.TINYINT(), ValueError),
(2 ** 7, DataTypes.TINYINT(), ValueError),
("1", DataTypes.TINYINT(), TypeError),
(1.0, DataTypes.TINYINT(), TypeError),
# SmallInt
(-(2 ** 15) - 1, DataTypes.SMALLINT(), ValueError),
(2 ** 15, DataTypes.SMALLINT(), ValueError),
# Int
(-(2 ** 31) - 1, DataTypes.INT(), ValueError),
(2 ** 31, DataTypes.INT(), ValueError),
# Float & Double
(1, DataTypes.FLOAT(), TypeError),
(1, DataTypes.DOUBLE(), TypeError),
# Decimal
(1.0, DataTypes.DECIMAL(10, 0), TypeError),
(1, DataTypes.DECIMAL(10, 0), TypeError),
("1.0", DataTypes.DECIMAL(10, 0), TypeError),
# Binary
(1, DataTypes.BINARY(1), TypeError),
# VarBinary (length exceeds maximum length)
(bytearray([1, 2]), DataTypes.VARBINARY(1), ValueError),
# Binary (length exceeds the defined length)
(bytearray([1, 2]), DataTypes.BINARY(1), ValueError),
# Date/Time/Timestamp
("2000-01-02", DataTypes.DATE(), TypeError),
("10:01:02", DataTypes.TIME(), TypeError),
(946811040, DataTypes.TIMESTAMP(), TypeError),
# Array
(["1", None], DataTypes.ARRAY(DataTypes.VARCHAR(1, nullable=False)), ValueError),
([1, "2"], DataTypes.ARRAY(DataTypes.INT()), TypeError),
# Map
({"a": 1}, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()), TypeError),
({"a": "1"}, DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT()), TypeError),
({"a": None}, DataTypes.MAP(DataTypes.VARCHAR(1), DataTypes.INT(False)), ValueError),
# Struct
({"s": "a", "i": "1"}, schema, TypeError),
(Row(s="a"), schema, ValueError), # Row can't have missing field
(Row(s="a", i="1"), schema, TypeError),
(["a"], schema, ValueError),
(["a", "1"], schema, TypeError),
(MyObj(s="a", i="1"), schema, TypeError),
(MyObj(s=None, i="1"), schema, ValueError),
]
# Check success cases
for obj, data_type in success_spec:
try:
_create_type_verifier(data_type.not_null())(obj)
except (TypeError, ValueError):
self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
# Check failure cases
for obj, data_type, exp in failure_spec:
msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
with self.assertRaises(exp, msg=msg):
_create_type_verifier(data_type.not_null())(obj)
class DataTypeConvertTests(PyFlinkTestCase):
def test_basic_type(self):
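# Every basic type must survive a round trip through the Java type system.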
test_types = [DataTypes.STRING(),
DataTypes.BOOLEAN(),
DataTypes.BYTES(),
DataTypes.TINYINT(),
DataTypes.SMALLINT(),
DataTypes.INT(),
DataTypes.BIGINT(),
DataTypes.FLOAT(),
DataTypes.DOUBLE(),
DataTypes.DATE(),
DataTypes.TIME(),
DataTypes.TIMESTAMP(3)]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
def test_atomic_type_with_data_type_with_parameters(self):
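# Parameterized Java types marked notNull() must convert to the
# equivalent not-null Python types.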
gateway = get_gateway()
JDataTypes = gateway.jvm.DataTypes
java_types = [JDataTypes.TIME(3).notNull(),
JDataTypes.TIMESTAMP(3).notNull(),
JDataTypes.VARBINARY(100).notNull(),
JDataTypes.BINARY(2).notNull(),
JDataTypes.VARCHAR(30).notNull(),
JDataTypes.CHAR(50).notNull(),
JDataTypes.DECIMAL(20, 10).notNull()]
converted_python_types = [_from_java_type(item) for item in java_types]
expected = [DataTypes.TIME(3, False),
DataTypes.TIMESTAMP(3).not_null(),
DataTypes.VARBINARY(100, False),
DataTypes.BINARY(2, False),
DataTypes.VARCHAR(30, False),
DataTypes.CHAR(50, False),
DataTypes.DECIMAL(20, 10, False)]
self.assertEqual(converted_python_types, expected)
# Legacy type tests
Types = gateway.jvm.org.apache.flink.table.api.Types
BlinkBigDecimalTypeInfo = \
gateway.jvm.org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
java_types = [Types.STRING(),
Types.DECIMAL(),
BlinkBigDecimalTypeInfo(12, 5)]
converted_python_types = [_from_java_type(item) for item in java_types]
expected = [DataTypes.VARCHAR(2147483647),
DataTypes.DECIMAL(38, 18),
DataTypes.DECIMAL(12, 5)]
self.assertEqual(converted_python_types, expected)
def test_array_type(self):
# nullable/not_null flag will be lost during the conversion.
test_types = [DataTypes.ARRAY(DataTypes.BIGINT()),
DataTypes.ARRAY(DataTypes.BIGINT()),
DataTypes.ARRAY(DataTypes.STRING()),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.BIGINT())),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
def test_multiset_type(self):
test_types = [DataTypes.MULTISET(DataTypes.BIGINT()),
DataTypes.MULTISET(DataTypes.STRING()),
DataTypes.MULTISET(DataTypes.MULTISET(DataTypes.BIGINT())),
DataTypes.MULTISET(DataTypes.MULTISET(DataTypes.STRING()))]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
def test_map_type(self):
test_types = [DataTypes.MAP(DataTypes.BIGINT(), DataTypes.BIGINT()),
DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()),
DataTypes.MAP(DataTypes.STRING(),
DataTypes.MAP(DataTypes.STRING(), DataTypes.BIGINT())),
DataTypes.MAP(DataTypes.STRING(),
DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()))]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
def test_row_type(self):
test_types = [DataTypes.ROW([DataTypes.FIELD("a", DataTypes.INT()),
DataTypes.FIELD("b",
DataTypes.ROW(
[DataTypes.FIELD("c",
DataTypes.STRING())]))])]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
def test_list_view_type(self):
test_types = [DataTypes.LIST_VIEW(DataTypes.BIGINT()),
DataTypes.LIST_VIEW(DataTypes.STRING())]
java_types = [_to_java_type(item) for item in test_types]
converted_python_types = [_from_java_type(item) for item in java_types]
self.assertEqual(test_types, converted_python_types)
class DataSerializerTests(PyFlinkTestCase):
def test_java_pickle_deserializer(self):
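# Pickle Python tuples into a temp file and verify that the Java side
# (PythonBridgeUtils.readPythonObjects) reads them back intact.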
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
serializer = PickleSerializer()
data = [(1, 2), (3, 4), (5, 6), (7, 8)]
try:
serializer.dump_to_stream(data, temp_file)
finally:
temp_file.close()
gateway = get_gateway()
result = [tuple(int_pair) for int_pair in
list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, False))]
self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
def test_java_batch_deserializer(self):
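# Same round trip as above, but written with BatchedSerializer (batch
# size 2); the boolean flag passed to readPythonObjects appears to
# signal batched input.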
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp())
serializer = BatchedSerializer(PickleSerializer(), 2)
data = [(1, 2), (3, 4), (5, 6), (7, 8)]
try:
serializer.dump_to_stream(data, temp_file)
finally:
temp_file.close()
gateway = get_gateway()
result = [tuple(int_pair) for int_pair in
list(gateway.jvm.PythonBridgeUtils.readPythonObjects(temp_file.name, True))]
self.assertEqual(result, [(1, 2), (3, 4), (5, 6), (7, 8)])
if __name__ == "__main__":
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)