#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
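
"""Unit tests for apache_beam.testing.datatype_inference."""
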
from __future__ import absolute_import
import logging
import unittest
from collections import OrderedDict
import numpy as np
from parameterized import parameterized
from past.builtins import unicode
from apache_beam.testing import datatype_inference
from apache_beam.typehints import typehints
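
# pyarrow is an optional dependency; the pyarrow-specific test data and tests
# below are skipped when it is not installed.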
try:
import pyarrow as pa
except ImportError:
pa = None
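
# Each test case pairs input rows ("data") with the schemas that the schema
# inference functions are expected to derive from those rows.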
TEST_DATA = [
{
"name": "empty",
"data": [],
"type_schema": OrderedDict([]),
"pyarrow_schema": pa.schema([]) if pa is not None else None,
"avro_schema": {
"namespace": "example.avro",
"name": "User",
"type": "record",
"fields": [],
},
},
{
"name":
"main",
"data": [
OrderedDict([
("a", 1),
("b", 0.12345),
("c", u"Hello World!!"),
("d", np.array([1, 2, 3])),
("e", b"some bytes"),
]),
OrderedDict([
("a", -5),
("b", 1234.567),
("e", b"more bytes"),
]),
OrderedDict([
("a", 100000),
("c", u"XoXoX"),
("d", np.array([4, 5, 6])),
("e", b""),
]),
],
"type_schema":
OrderedDict([
("a", int),
("b", float),
("c", unicode),
("d", np.ndarray),
("e", bytes),
]),
"pyarrow_schema":
pa.schema([
("a", pa.int64()),
("b", pa.float64()),
("c", pa.string()),
("d", pa.list_(pa.int64())),
("e", pa.binary()),
]) if pa is not None else None,
"avro_schema": {
"namespace":
"example.avro",
"name":
"User",
"type":
"record",
"fields": [
{"name": "a", "type": "int"},
{"name": "b", "type": "double"},
{"name": "c", "type": "string"},
{"name": "d", "type": "bytes"},
{"name": "e", "type": "bytes"},
],
},
},
]


def nullify_data_and_schemas(test_data):
"""Add a row with all columns set to None and adjust the schemas accordingly.
"""

  def nullify_avro_schema(schema):
"""Add a 'null' type to every field."""
schema = schema.copy()
new_fields = []
for field in schema["fields"]:
if isinstance(field["type"], str):
new_fields.append(
{"name": field["name"], "type": sorted([field["type"], "null"])})
else:
new_fields.append(
{"name": field["name"], "type": sorted(field["type"] + ["null"])})
schema["fields"] = new_fields
return schema

  def get_columns_in_order(test_data):
    """Get a list of columns while trying to maintain original order.

    .. note::
      Columns which do not appear until later rows are added to the end,
      even if they precede some columns which have already been added.
    """
_seen = set()
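    # set.add() always returns None, so `not _seen.add(c)` is always truthy;
    # it merely records c as seen while `c not in _seen` filters duplicates.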
columns = [
c for test_case in test_data for row in test_case["data"]
for c in row
if c not in _seen and not _seen.add(c)
]
return columns

  nullified_test_data = []
  columns = get_columns_in_order(test_data)
for test_case in test_data:
if not test_case["data"]:
continue
test_case = test_case.copy()
test_case["name"] = test_case["name"] + "_nullified"
test_case["data"] = test_case["data"] + [
OrderedDict([(c, None) for c in columns])
]
test_case["type_schema"] = OrderedDict([
(k, typehints.Union[v, type(None)])
for k, v in test_case["type_schema"].items()
])
test_case["avro_schema"] = nullify_avro_schema(test_case["avro_schema"])
nullified_test_data.append(test_case)
return nullified_test_data
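

# Extend TEST_DATA with a nullified variant of every non-empty test case.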
TEST_DATA += nullify_data_and_schemas(TEST_DATA)


class DatatypeInferenceTest(unittest.TestCase):
@parameterized.expand([
(d["name"], d["data"], d["type_schema"]) for d in TEST_DATA
])
def test_infer_typehints_schema(self, _, data, schema):
typehints_schema = datatype_inference.infer_typehints_schema(data)
self.assertEqual(typehints_schema, schema)

  @parameterized.expand([
(d["name"], d["data"], d["pyarrow_schema"]) for d in TEST_DATA
])
@unittest.skipIf(pa is None, "PyArrow is not installed")
def test_infer_pyarrow_schema(self, _, data, schema):
pyarrow_schema = datatype_inference.infer_pyarrow_schema(data)
self.assertEqual(pyarrow_schema, schema)

  @parameterized.expand([
(d["name"], d["data"], d["avro_schema"]) for d in TEST_DATA
])
def test_infer_avro_schema(self, _, data, schema):
schema = schema.copy() # Otherwise, it would be mutated by `.pop()`
avro_schema = datatype_inference.infer_avro_schema(data, use_fastavro=False)
avro_schema = avro_schema.to_json()
fields1 = avro_schema.pop("fields")
fields2 = schema.pop("fields")
self.assertDictEqual(avro_schema, schema)
    self.assertEqual(len(fields1), len(fields2))
    for field1, field2 in zip(fields1, fields2):
self.assertDictEqual(field1, field2)

  @parameterized.expand([
(d["name"], d["data"], d["avro_schema"]) for d in TEST_DATA
])
def test_infer_fastavro_schema(self, _, data, schema):
from fastavro import parse_schema
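    # parse_schema normalizes the expected schema into fastavro's parsed
    # representation so it can be compared like-for-like with the inferred one.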
schema = parse_schema(schema)
avro_schema = datatype_inference.infer_avro_schema(data, use_fastavro=True)
fields1 = avro_schema.pop("fields")
fields2 = schema.pop("fields")
self.assertDictEqual(avro_schema, schema)
    self.assertEqual(len(fields1), len(fields2))
    for field1, field2 in zip(fields1, fields2):
self.assertDictEqual(field1, field2)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()