#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utilities to deal with types. This is mostly focused on python3.
"""
import datetime
import decimal
from inspect import getfullargspec, isclass
from typing import (  # noqa: F401
    Any,
    Callable,
    Generic,
    List,
    Optional,
    Tuple,
    Union,
)
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype, pandas_dtype
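
# pandas gained its nullable "extension" dtypes incrementally (Int*Dtype in 0.24,
# BooleanDtype/StringDtype in 1.0, Float*Dtype in 1.2), so probe for each group
# separately and record what this pandas installation actually provides.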
try:
    from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype

    extension_dtypes_available = True
    extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype)  # type: Tuple

    try:
        from pandas import BooleanDtype, StringDtype

        extension_object_dtypes_available = True
        extension_dtypes += (BooleanDtype, StringDtype)
    except ImportError:
        extension_object_dtypes_available = False

    try:
        from pandas import Float32Dtype, Float64Dtype

        extension_float_dtypes_available = True
        extension_dtypes += (Float32Dtype, Float64Dtype)
    except ImportError:
        extension_float_dtypes_available = False

except ImportError:
    extension_dtypes_available = False
    extension_object_dtypes_available = False
    extension_float_dtypes_available = False
    extension_dtypes = ()
import pyarrow as pa
import pyspark.sql.types as types
from pyspark.sql.pandas.types import to_arrow_type, from_arrow_type
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Dtype, T
from pyspark.pandas.typedef.string_typehints import resolve_string_type_hint


# A column of data, with the data type.
class SeriesType(Generic[T]):
    def __init__(self, dtype: Dtype, spark_type: types.DataType):
        self.dtype = dtype
        self.spark_type = spark_type

    def __repr__(self) -> str:
        return "SeriesType[{}]".format(self.spark_type)


class DataFrameType(object):
    def __init__(
        self, dtypes: List[Dtype], spark_types: List[types.DataType], names: List[Optional[str]]
    ):
        from pyspark.pandas.internal import InternalField
        from pyspark.pandas.utils import name_like_string
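        # Wrap each (name, dtype, spark_type) triple in an InternalField; columns
        # without a name fall back to positional defaults "c0", "c1", ....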
        self.fields = [
            InternalField(
                dtype=dtype,
                struct_field=types.StructField(
                    name=(name_like_string(name) if name is not None else ("c%s" % i)),
                    dataType=spark_type,
                ),
            )
            for i, (name, dtype, spark_type) in enumerate(zip(names, dtypes, spark_types))
        ]

    @property
    def dtypes(self) -> List[Dtype]:
        return [field.dtype for field in self.fields]

    @property
    def spark_type(self) -> types.StructType:
        return types.StructType([field.struct_field for field in self.fields])

    def __repr__(self) -> str:
        return "DataFrameType[{}]".format(self.spark_type)


# The type is a scalar type that is furthermore understood by Spark.
class ScalarType(object):
    def __init__(self, dtype: Dtype, spark_type: types.DataType):
        self.dtype = dtype
        self.spark_type = spark_type

    def __repr__(self) -> str:
        return "ScalarType[{}]".format(self.spark_type)


# The type is left unspecified or we do not know about this type.
class UnknownType(object):
    def __init__(self, tpe: Any):
        self.tpe = tpe

    def __repr__(self) -> str:
        return "UnknownType[{}]".format(self.tpe)


class NameTypeHolder(object):
    name = None
    tpe = None
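

# Illustrative sketch (not executed): the DataFrame/Series __class_getitem__ hooks
# turn a hint such as ps.DataFrame["a": float, "b": int] into NameTypeHolder
# subclasses, conceptually:
#
#   name_type = type("NameType", (NameTypeHolder,), {})
#   name_type.name = "a"
#   name_type.tpe = float
#
# infer_return_type below then reads `.name` and `.tpe` back off these holders.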


def as_spark_type(
    tpe: Union[str, type, Dtype], *, raise_error: bool = True
) -> Optional[types.DataType]:
    """
    Given a Python type, returns the equivalent Spark type.

    Accepts:
    - the built-in types in Python
    - the built-in types in numpy
    - numpy/pandas dtype objects and their string aliases (e.g. "int64")
    - Python 3's typing hints (e.g. List[int])
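
    A few illustrative conversions (doctest-style, consistent with the examples
    in pandas_on_spark_type below):

    >>> as_spark_type(int)
    LongType
    >>> as_spark_type(np.float32)
    FloatType
    >>> as_spark_type(List[bool])
    ArrayType(BooleanType,true)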
"""
if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
pass
# ArrayType
elif tpe in (np.ndarray,):
return types.ArrayType(types.StringType())
elif hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, list): # type: ignore
element_type = as_spark_type(tpe.__args__[0], raise_error=raise_error) # type: ignore
if element_type is None:
return None
return types.ArrayType(element_type)
# BinaryType
elif tpe in (bytes, np.character, np.bytes_, np.string_):
return types.BinaryType()
# BooleanType
elif tpe in (bool, np.bool_, "bool", "?"):
return types.BooleanType()
# DateType
elif tpe in (datetime.date,):
return types.DateType()
# NumericType
elif tpe in (np.int8, np.byte, "int8", "byte", "b"):
return types.ByteType()
elif tpe in (decimal.Decimal,):
# TODO: considering about the precision & scale for decimal type.
return types.DecimalType(38, 18)
elif tpe in (float, np.float_, np.float64, "float", "float64", "double"):
return types.DoubleType()
elif tpe in (np.float32, "float32", "f"):
return types.FloatType()
elif tpe in (np.int32, "int32", "i"):
return types.IntegerType()
elif tpe in (int, np.int64, "int", "int64", "long"):
return types.LongType()
elif tpe in (np.int16, "int16", "short"):
return types.ShortType()
# StringType
elif tpe in (str, np.unicode_, "str", "U"):
return types.StringType()
# TimestampType
elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
return types.TimestampType()
# categorical types
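    # (a pandas categorical is carried over as its integer codes; the category
    # values themselves live only in the pandas-side metadata, hence LongType)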
    elif isinstance(tpe, CategoricalDtype) or (isinstance(tpe, str) and tpe == "category"):
        return types.LongType()
    # extension types
    elif extension_dtypes_available:
        # IntegralType
        if isinstance(tpe, Int8Dtype) or (isinstance(tpe, str) and tpe == "Int8"):
            return types.ByteType()
        elif isinstance(tpe, Int16Dtype) or (isinstance(tpe, str) and tpe == "Int16"):
            return types.ShortType()
        elif isinstance(tpe, Int32Dtype) or (isinstance(tpe, str) and tpe == "Int32"):
            return types.IntegerType()
        elif isinstance(tpe, Int64Dtype) or (isinstance(tpe, str) and tpe == "Int64"):
            return types.LongType()

        if extension_object_dtypes_available:
            # BooleanType
            if isinstance(tpe, BooleanDtype) or (isinstance(tpe, str) and tpe == "boolean"):
                return types.BooleanType()
            # StringType
            elif isinstance(tpe, StringDtype) or (isinstance(tpe, str) and tpe == "string"):
                return types.StringType()

        if extension_float_dtypes_available:
            # FractionalType
            if isinstance(tpe, Float32Dtype) or (isinstance(tpe, str) and tpe == "Float32"):
                return types.FloatType()
            elif isinstance(tpe, Float64Dtype) or (isinstance(tpe, str) and tpe == "Float64"):
                return types.DoubleType()

    if raise_error:
        raise TypeError("Type %s was not understood." % tpe)
    else:
        return None


def spark_type_to_pandas_dtype(
    spark_type: types.DataType, *, use_extension_dtypes: bool = False
) -> Dtype:
    """Convert the given Spark DataType to a pandas dtype.
    if use_extension_dtypes and extension_dtypes_available:
        # IntegralType
        if isinstance(spark_type, types.ByteType):
            return Int8Dtype()
        elif isinstance(spark_type, types.ShortType):
            return Int16Dtype()
        elif isinstance(spark_type, types.IntegerType):
            return Int32Dtype()
        elif isinstance(spark_type, types.LongType):
            return Int64Dtype()

        if extension_object_dtypes_available:
            # BooleanType
            if isinstance(spark_type, types.BooleanType):
                return BooleanDtype()
            # StringType
            elif isinstance(spark_type, types.StringType):
                return StringDtype()

        # FractionalType
        if extension_float_dtypes_available:
            if isinstance(spark_type, types.FloatType):
                return Float32Dtype()
            elif isinstance(spark_type, types.DoubleType):
                return Float64Dtype()

    if isinstance(
        spark_type,
        (
            types.DateType,
            types.NullType,
            types.ArrayType,
            types.MapType,
            types.StructType,
            types.UserDefinedType,
        ),
    ):
        return np.dtype("object")
    elif isinstance(spark_type, types.TimestampType):
        return np.dtype("datetime64[ns]")
    else:
        return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())


def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]:
    """
    Convert the given Python type into a pandas-only dtype object or a numpy
    dtype object, together with its corresponding Spark DataType.

    Parameters
    ----------
    tpe : object to be converted

    Returns
    -------
    tuple of np.dtype or a pandas dtype, and Spark DataType

    Raises
    ------
    TypeError if not a dtype

    Examples
    --------
    >>> pandas_on_spark_type(int)
    (dtype('int64'), LongType)
    >>> pandas_on_spark_type(str)
    (dtype('<U'), StringType)
    >>> pandas_on_spark_type(datetime.date)
    (dtype('O'), DateType)
    >>> pandas_on_spark_type(datetime.datetime)
    (dtype('<M8[ns]'), TimestampType)
    >>> pandas_on_spark_type(List[bool])
    (dtype('O'), ArrayType(BooleanType,true))
    """
    try:
        dtype = pandas_dtype(tpe)
        spark_type = as_spark_type(dtype)
    except TypeError:
        spark_type = as_spark_type(tpe)
        dtype = spark_type_to_pandas_dtype(spark_type)
    return dtype, spark_type


def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
    """Infer Spark DataType from pandas Series dtype.

    :param pser: :class:`pandas.Series` to be inferred
    :param dtype: the Series' dtype
    :return: the inferred Spark data type
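
    A quick doctest-style illustration:

    >>> infer_pd_series_spark_type(pd.Series([1, 2, 3]), np.dtype("int64"))
    LongType
    >>> infer_pd_series_spark_type(pd.Series(["a", None]), np.dtype("object"))
    StringType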
"""
if dtype == np.dtype("object"):
if len(pser) == 0 or pser.isnull().all():
return types.NullType()
elif hasattr(pser.iloc[0], "__UDT__"):
return pser.iloc[0].__UDT__
else:
return from_arrow_type(pa.Array.from_pandas(pser).type)
elif isinstance(dtype, CategoricalDtype):
if isinstance(pser.dtype, CategoricalDtype):
return as_spark_type(pser.cat.codes.dtype)
else:
# `pser` must already be converted to codes.
return as_spark_type(pser.dtype)
else:
return as_spark_type(dtype)


def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
    """
    Infer the return type from the return type annotation of the given function.

    The returned type class indicates both the dtypes (a pandas-only dtype object
    or a numpy dtype object) and their corresponding Spark DataType.

    >>> def func() -> int:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> ps.Series[int]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> ps.DataFrame[float, str]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('<U')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> ps.DataFrame[float]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> 'int':
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> 'ps.Series[int]':
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> 'ps.DataFrame[float, str]':
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('<U')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> 'ps.DataFrame[float]':
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> ps.DataFrame['a': float, 'b': int]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> def func() -> "ps.DataFrame['a': float, 'b': int]":
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ps.DataFrame[pdf.dtypes]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))

    >>> pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})
    >>> def func() -> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField((x, a),LongType,true),StructField((y, b),LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical([3, 4, 5])})
    >>> def func() -> ps.DataFrame[pdf.dtypes]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
    >>> inferred.spark_type
    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))

    >>> def func() -> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
    >>> inferred.spark_type
    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))

    >>> def func() -> ps.Series[pdf.b.dtype]:
    ...     pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    CategoricalDtype(categories=[3, 4, 5], ordered=False)
    >>> inferred.spark_type
    LongType
    """
    # We should re-import to make sure the class 'SeriesType' is not treated as a class
    # within this module locally. See Series.__class_getitem__ which imports this class
    # canonically.
    from pyspark.pandas.typedef import SeriesType, NameTypeHolder

    spec = getfullargspec(f)
    tpe = spec.annotations.get("return", None)
    if isinstance(tpe, str):
        # This type hint can happen when the given hints are strings, to avoid
        # forward references.
        tpe = resolve_string_type_hint(tpe)

    if hasattr(tpe, "__origin__") and (
        tpe.__origin__ == ps.DataFrame or tpe.__origin__ == ps.Series
    ):
        # When the Python version is lower than 3.7, unwrap it to Tuple/SeriesType type hints.
        tpe = tpe.__args__[0]

    if hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, SeriesType):
        tpe = tpe.__args__[0]
        if issubclass(tpe, NameTypeHolder):
            tpe = tpe.tpe
        dtype, spark_type = pandas_on_spark_type(tpe)
        return SeriesType(dtype, spark_type)

    # Note that DataFrame type hints will create a Tuple.
    # Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
    # Check if the name is Tuple.
    name = getattr(tpe, "_name", getattr(tpe, "__name__", None))
    if name == "Tuple":
        tuple_type = tpe
        if hasattr(tuple_type, "__tuple_params__"):
            # Python 3.5.0 to 3.5.2 has '__tuple_params__' instead.
            # See https://github.com/python/cpython/blob/v3.5.2/Lib/typing.py
            parameters = getattr(tuple_type, "__tuple_params__")
        else:
            parameters = getattr(tuple_type, "__args__")
        dtypes, spark_types = zip(
            *(
                pandas_on_spark_type(p.tpe)
                if isclass(p) and issubclass(p, NameTypeHolder)
                else pandas_on_spark_type(p)
                for p in parameters
            )
        )
        names = [
            p.name if isclass(p) and issubclass(p, NameTypeHolder) else None for p in parameters
        ]
        return DataFrameType(list(dtypes), list(spark_types), names)

    tpes = pandas_on_spark_type(tpe)
    if tpes is None:
        return UnknownType(tpe)
    else:
        return ScalarType(*tpes)


def _test() -> None:
    import doctest
    import sys
    import pyspark.pandas.typedef.typehints

    globs = pyspark.pandas.typedef.typehints.__dict__.copy()
    (failure_count, test_count) = doctest.testmod(
        pyspark.pandas.typedef.typehints,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
    )
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()