| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from abc import ABCMeta, abstractmethod |
| from typing import List |
| |
| try: |
| import importlib.metadata as importlib_metadata |
| except ImportError: |
| import importlib_metadata |
| |
| import pyarrow as pa |
| |
| from ._internal import ( |
| AggregateUDF, |
| Config, |
| DataFrame, |
| SessionContext, |
| SessionConfig, |
| RuntimeConfig, |
| ScalarUDF, |
| ) |
| |
| from .common import ( |
| DFField, |
| DFSchema, |
| ) |
| |
| from .expr import ( |
| Alias, |
| Analyze, |
| Expr, |
| Filter, |
| Limit, |
| Like, |
| ILike, |
| Projection, |
| SimilarTo, |
| ScalarVariable, |
| Sort, |
| TableScan, |
| GetIndexedField, |
| Not, |
| IsNotNull, |
| IsTrue, |
| IsFalse, |
| IsUnknown, |
| IsNotTrue, |
| IsNotFalse, |
| IsNotUnknown, |
| Negative, |
| ScalarFunction, |
| BuiltinScalarFunction, |
| InList, |
| Exists, |
| Subquery, |
| InSubquery, |
| ScalarSubquery, |
| GroupingSet, |
| Placeholder, |
| Case, |
| Cast, |
| TryCast, |
| Between, |
| Explain, |
| CreateMemoryTable, |
| SubqueryAlias, |
| Extension, |
| CreateView, |
| Distinct, |
| DropTable, |
| Repartition, |
| Partitioning, |
| ) |
| |
| __version__ = importlib_metadata.version(__name__) |
| |
| __all__ = [ |
| "Config", |
| "DataFrame", |
| "SessionContext", |
| "SessionConfig", |
| "RuntimeConfig", |
| "Expr", |
| "AggregateUDF", |
| "ScalarUDF", |
| "column", |
| "literal", |
| "TableScan", |
| "Projection", |
| "DFSchema", |
| "DFField", |
| "Analyze", |
| "Sort", |
| "Limit", |
| "Filter", |
| "Like", |
| "ILike", |
| "SimilarTo", |
| "ScalarVariable", |
| "Alias", |
| "GetIndexedField", |
| "Not", |
| "IsNotNull", |
| "IsTrue", |
| "IsFalse", |
| "IsUnknown", |
| "IsNotTrue", |
| "IsNotFalse", |
| "IsNotUnknown", |
| "Negative", |
| "ScalarFunction", |
| "BuiltinScalarFunction", |
| "InList", |
| "Exists", |
| "Subquery", |
| "InSubquery", |
| "ScalarSubquery", |
| "GroupingSet", |
| "Placeholder", |
| "Case", |
| "Cast", |
| "TryCast", |
| "Between", |
| "Explain", |
| "SubqueryAlias", |
| "Extension", |
| "CreateMemoryTable", |
| "CreateView", |
| "Distinct", |
| "DropTable", |
| "Repartition", |
| "Partitioning", |
| ] |
| |
| |
| class Accumulator(metaclass=ABCMeta): |
| @abstractmethod |
| def state(self) -> List[pa.Scalar]: |
| pass |
| |
| @abstractmethod |
| def update(self, values: pa.Array) -> None: |
| pass |
| |
| @abstractmethod |
| def merge(self, states: pa.Array) -> None: |
| pass |
| |
| @abstractmethod |
| def evaluate(self) -> pa.Scalar: |
| pass |
| |
| |
| def column(value): |
| return Expr.column(value) |
| |
| |
| col = column |
| |
| |
| def literal(value): |
| if not isinstance(value, pa.Scalar): |
| value = pa.scalar(value) |
| return Expr.literal(value) |
| |
| |
| lit = literal |
| |
| |
| def udf(func, input_types, return_type, volatility, name=None): |
| """ |
| Create a new User Defined Function |
| """ |
| if not callable(func): |
| raise TypeError("`func` argument must be callable") |
| if name is None: |
| name = func.__qualname__.lower() |
| return ScalarUDF( |
| name=name, |
| func=func, |
| input_types=input_types, |
| return_type=return_type, |
| volatility=volatility, |
| ) |
| |
| |
| def udaf(accum, input_type, return_type, state_type, volatility, name=None): |
| """ |
| Create a new User Defined Aggregate Function |
| """ |
| if not issubclass(accum, Accumulator): |
| raise TypeError( |
| "`accum` must implement the abstract base class Accumulator" |
| ) |
| if name is None: |
| name = accum.__qualname__.lower() |
| return AggregateUDF( |
| name=name, |
| accumulator=accum, |
| input_type=input_type, |
| return_type=return_type, |
| state_type=state_type, |
| volatility=volatility, |
| ) |