blob: 99687b97ebf271c8c8a1f08378a83be1c8c20de2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import typing
from typing import Optional
from pyfory.type_util import TypeVisitor, infer_field
from pyfory.format._format import (
Schema,
DataType,
TypeId,
boolean,
int8,
int16,
int32,
int64,
float32,
float64,
utf8,
binary,
date32,
timestamp,
list_,
map_,
struct,
field,
schema,
)
# Cache of the classes resolved by `get_cls_by_schema`, keyed by `id(schema)`.
__type_map__ = {}
# Schemas registered by `get_cls_by_schema`, keyed by `id(schema)`. Holding a
# reference keeps each schema alive, which ensures its `id(schema)` doesn't get
# reused by another object while the cache entry exists.
__schemas__ = {}
def get_cls_by_schema(schema):
    """Return the class associated with ``schema``, creating and caching it on first use.

    The cache is keyed by ``id(schema)``. When the schema exposes non-empty
    ``metadata`` with a ``b"cls"`` entry, that entry is treated as a dotted
    ``module.ClassName`` path and the class is imported. Otherwise a record
    class named ``"Record<id>"`` is generated from the schema's field names.
    """
    key = id(schema)
    try:
        return __type_map__[key]
    except KeyError:
        pass  # first time this schema object is seen

    # For Fory Schema, we don't have metadata support yet.
    # Try to get the class name from the schema if available.
    metadata = getattr(schema, "metadata", None)
    qualified_name = metadata.get(b"cls", b"").decode() if metadata else ""

    if not qualified_name:
        from pyfory.type_util import record_class_factory

        field_names = [schema.field(index).name for index in range(schema.num_fields)]
        resolved = record_class_factory(f"Record{key}", field_names)
    else:
        import importlib

        module_path, attr_name = qualified_name.rsplit(".", 1)
        resolved = getattr(importlib.import_module(module_path), attr_name)

    __type_map__[key] = resolved
    # Keep the schema referenced so its id stays unique while cached.
    __schemas__[key] = schema
    return resolved
def remove_schema(schema):
    """Unregister ``schema`` together with the class cached for it.

    Both registries are keyed by ``id(schema)``. Once the schema reference is
    released the object may be garbage collected and its id handed to a
    different schema, so the cached class must be evicted at the same time;
    otherwise ``get_cls_by_schema`` could return the stale class for an
    unrelated schema.

    Args:
        schema: A schema previously passed to ``get_cls_by_schema``.

    Raises:
        KeyError: If ``schema`` is not currently registered.
    """
    id_ = id(schema)
    # Evict the class first so no stale entry survives, then release the
    # schema reference (raises KeyError for unknown schemas, as before).
    __type_map__.pop(id_, None)
    __schemas__.pop(id_)
def reset():
    """Empty both the class cache and the schema registry in place."""
    for registry in (__type_map__, __schemas__):
        registry.clear()
# Types listed in the "not supported" error message raised by the type visitor.
_supported_types = {bool, int, float, str, bytes, typing.List, typing.Dict}
# Human-readable "<module>.<name>" form of the types above; falls back to the
# object itself when it has no `__name__`.
_supported_types_str = [
    "{}.{}".format(supported.__module__, getattr(supported, "__name__", supported))
    for supported in _supported_types
]
# Python type -> factory function producing the corresponding Fory data type.
_supported_types_mapping = {
    # scalars
    bool: boolean,
    int: int64,
    float: float64,
    str: utf8,
    bytes: binary,
    # containers (builtin and `typing` spellings)
    list: list_,
    dict: map_,
    typing.List: list_,
    typing.Dict: map_,
    # temporal types
    datetime.date: date32,
    datetime.datetime: timestamp,
}
# Add pyfory type annotations support: each pyfory annotation type maps to the
# Fory data type factory of the same width.
from pyfory.types import (
    int8 as int8_type,
    int16 as int16_type,
    int32 as int32_type,
    int64 as int64_type,
    float32 as float32_type,
    float64 as float64_type,
)

_supported_types_mapping.update(
    [
        (int8_type, int8),
        (int16_type, int16),
        (int32_type, int32),
        (int64_type, int64),
        (float32_type, float32),
        (float64_type, float64),
    ]
)
# Add numpy types if available; numpy is optional, so a failed import is ignored.
try:
    import numpy as np

    _supported_types_mapping.update(
        [
            (np.int8, int8),
            (np.int16, int16),
            (np.int32, int32),
            (np.int64, int64),
            (np.float32, float32),
            (np.float64, float64),
        ]
    )
except ImportError:
    pass
def infer_schema(clz, types_path=None) -> Schema:
    """Build a Fory schema from the type hints of ``clz``.

    One field is inferred per type hint, in sorted field-name order.

    Args:
        clz: The class whose type hints describe the schema.
        types_path: Optional sequence of types already being inferred; it is
            copied and forwarded to every field inference.

    Returns:
        The schema made of the inferred fields.
    """
    types_path = list(types_path or [])
    from pyfory.type_util import get_type_hints

    hints = get_type_hints(clz)
    inferred_fields = []
    for name in sorted(hints.keys()):
        inferred = infer_field(name, hints[name], ForyTypeVisitor(), types_path=types_path)
        inferred_fields.append(inferred)
    # TODO: Add metadata support to Fory Schema
    return schema(inferred_fields)
class ForyTypeVisitor(TypeVisitor):
    """Type visitor that turns Python type annotations into Fory fields.

    Every ``visit_*`` method receives the name of the field being inferred and
    returns a field built with ``field(...)``. ``types_path`` is forwarded to
    every nested inference so that the error raised for an unsupported type can
    report where it was found.
    """

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return a list field whose value type is inferred from ``elem_type``."""
        # Infer type recursively for type such as List[Dict[str, str]]
        elem_field = infer_field("item", elem_type, self, types_path=types_path)
        return field(field_name, list_(elem_field.type))

    def visit_set(self, field_name, elem_type, types_path=None):
        """Return a list field for a set annotation (sets are mapped to the list type)."""
        # Infer type recursively for type such as Set[Dict[str, str]]
        elem_field = infer_field("item", elem_type, self, types_path=types_path)
        return field(field_name, list_(elem_field.type))

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return a map field whose key and value types are inferred recursively."""
        # Infer type recursively for type such as Dict[str, Dict[str, str]]
        key_field = infer_field("key", key_type, self, types_path=types_path)
        value_field = infer_field("value", value_type, self, types_path=types_path)
        return field(field_name, map_(key_field.type, value_field.type))

    def visit_customized(self, field_name, type_, types_path=None):
        """Return a struct field built from the schema inferred for the class ``type_``."""
        # type_ is a pojo. Forward types_path like the other visit_* methods do,
        # so errors raised while inferring nested members keep the full path.
        pojo_schema = infer_schema(type_, types_path=types_path)
        fields = [pojo_schema.field(i) for i in range(pojo_schema.num_fields)]
        # TODO: Add metadata support
        return field(field_name, struct(fields))

    def visit_other(self, field_name, type_, types_path=None):
        """Return a field for a leaf type registered in ``_supported_types_mapping``.

        Raises:
            TypeError: If ``type_`` has no registered Fory type.
        """
        # Single lookup; registered values are factory callables, never None.
        fory_type_func = _supported_types_mapping.get(type_)
        if fory_type_func is None:
            raise TypeError(
                f"Type {type_} not supported, currently only compositions of {_supported_types_str} are supported. types_path is {types_path}"
            )
        return field(field_name, fory_type_func())
def infer_data_type(clz) -> Optional[DataType]:
    """Infer the Fory data type for ``clz``.

    Returns:
        The inferred data type, or ``None`` when inference raises ``TypeError``.
    """
    try:
        inferred = infer_field("", clz, ForyTypeVisitor())
        return inferred.type
    except TypeError:
        return None
def get_type_id(clz) -> Optional[int]:
    """Return the type id of the Fory data type inferred for ``clz``.

    Returns:
        The ``id`` of the inferred data type, or ``None`` when
        ``infer_data_type`` could not infer one.
    """
    type_ = infer_data_type(clz)
    # Compare against the None sentinel explicitly instead of relying on the
    # truthiness of the data type object.
    return None if type_ is None else type_.id
def compute_schema_hash(schema: Schema):
    """Compute a hash of ``schema`` by folding every field type, in order, into a seed of 17."""
    digest = 17
    field_types = (schema.field(index).type for index in range(schema.num_fields))
    for field_type in field_types:
        digest = _compute_hash(digest, field_type)
    return digest
def _compute_hash(hash_: int, type_: DataType):
    """Fold ``type_`` and, recursively, its nested types into ``hash_``.

    The step is ``hash_ * 31 + int(type_.id)``. While that value exceeds
    ``2**63 - 1``, the running hash is shifted right by two bits and the step
    is retried. List, map and struct types then contribute their child types
    in order.
    """
    candidate = hash_ * 31 + int(type_.id)
    while candidate > 2**63 - 1:
        # Shrink the running hash until the next step fits in a signed 64-bit value.
        hash_ >>= 2
        candidate = hash_ * 31 + int(type_.id)
    hash_ = candidate

    kind = type_.id
    if kind == TypeId.LIST:
        children = [type_.value_type]
    elif kind == TypeId.MAP:
        children = [type_.key_type, type_.item_type]
    elif kind == TypeId.STRUCT:
        children = [type_.field(index).type for index in range(type_.num_fields)]
    else:
        assert type_.num_fields == 0, f"field type should not be nested, but got type {type_}."
        children = []

    for child in children:
        hash_ = _compute_hash(hash_, child)
    return hash_
def from_arrow_schema(arrow_schema) -> Schema:
    """Convert an Arrow Schema to a Fory Schema.

    This is for compatibility with code that uses PyArrow schemas.

    Args:
        arrow_schema: A PyArrow Schema object.

    Returns:
        A Fory Schema object with the same structure.

    Raises:
        ImportError: If pyarrow is not available.
        TypeError: If the schema contains an Arrow type that has no Fory
            counterpart.
    """
    try:
        from pyarrow import types as pa_types
    except ImportError as err:
        raise ImportError("pyarrow is required for Arrow schema conversion") from err

    # Arrow leaf types paired with the parameterless Fory factory they map to.
    # The factories are the module-level imports; no per-call re-import needed.
    # Large string/binary collapse onto utf8/binary, and timestamps are built
    # with the factory's defaults.
    leaf_converters = (
        (pa_types.is_boolean, boolean),
        (pa_types.is_int8, int8),
        (pa_types.is_int16, int16),
        (pa_types.is_int32, int32),
        (pa_types.is_int64, int64),
        (pa_types.is_float32, float32),
        (pa_types.is_float64, float64),
        (pa_types.is_string, utf8),
        (pa_types.is_large_string, utf8),
        (pa_types.is_binary, binary),
        (pa_types.is_large_binary, binary),
        (pa_types.is_date32, date32),
        (pa_types.is_timestamp, timestamp),
    )

    def convert_field(arrow_field):
        # Preserve the field name and nullability; convert only the type.
        return field(arrow_field.name, convert_type(arrow_field.type), nullable=arrow_field.nullable)

    def convert_type(arrow_type) -> DataType:
        for is_match, make_type in leaf_converters:
            if is_match(arrow_type):
                return make_type()
        if pa_types.is_list(arrow_type) or pa_types.is_large_list(arrow_type):
            return list_(convert_type(arrow_type.value_type))
        if pa_types.is_map(arrow_type):
            return map_(convert_type(arrow_type.key_type), convert_type(arrow_type.item_type))
        if pa_types.is_struct(arrow_type):
            return struct([convert_field(arrow_type.field(i)) for i in range(arrow_type.num_fields)])
        raise TypeError(f"Unsupported Arrow type for Fory conversion: {arrow_type}")

    return schema([convert_field(arrow_schema.field(i)) for i in range(len(arrow_schema))])
def to_arrow_schema(fory_schema: Schema):
    """Convert a Fory Schema to an Arrow Schema.

    This is for compatibility with ArrowWriter which requires Arrow schemas.

    Args:
        fory_schema: A Fory Schema object.

    Returns:
        An Arrow Schema object with the same structure.

    Raises:
        ImportError: If pyarrow is not available.
        TypeError: If the schema contains a Fory type that has no Arrow
            counterpart.
    """
    try:
        import pyarrow as pa
    except ImportError as err:
        raise ImportError("pyarrow is required for Arrow schema conversion") from err

    def convert_field(fory_field):
        # Preserve the field name and nullability; convert only the type.
        return pa.field(fory_field.name, convert_type(fory_field.type), nullable=fory_field.nullable)

    def convert_type(fory_type: DataType):
        type_id = fory_type.id
        # Leaf types first, compared in the same order as before.
        if type_id == TypeId.BOOL:
            return pa.bool_()
        if type_id == TypeId.INT8:
            return pa.int8()
        if type_id == TypeId.INT16:
            return pa.int16()
        if type_id == TypeId.INT32:
            return pa.int32()
        if type_id == TypeId.INT64:
            return pa.int64()
        if type_id == TypeId.FLOAT32:
            return pa.float32()
        if type_id == TypeId.FLOAT64:
            return pa.float64()
        if type_id == TypeId.STRING:
            return pa.string()
        if type_id == TypeId.BINARY:
            return pa.binary()
        if type_id == TypeId.DATE:
            return pa.date32()
        if type_id == TypeId.TIMESTAMP:
            # Timestamps are always emitted with microsecond resolution.
            return pa.timestamp("us")
        # Nested types recurse into their children.
        if type_id == TypeId.LIST:
            return pa.list_(convert_type(fory_type.value_type))
        if type_id == TypeId.MAP:
            return pa.map_(convert_type(fory_type.key_type), convert_type(fory_type.item_type))
        if type_id == TypeId.STRUCT:
            return pa.struct([convert_field(fory_type.field(i)) for i in range(fory_type.num_fields)])
        raise TypeError(f"Unsupported Fory type for Arrow conversion: {fory_type}")

    return pa.schema([convert_field(fory_schema.field(i)) for i in range(fory_schema.num_fields)])