blob: 25b3b4f429bd38ef80b0187ca6723e4187b15db6 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import dataclasses
import enum
import functools
import logging
import re
import sys
from importlib import import_module
from types import ModuleType
from typing import Any, TypeVar, Union, cast
import attr
import airflow.serialization.serializers
from airflow.configuration import conf
from airflow.stats import Stats
from airflow.utils.module_loading import import_string, iter_namespace, qualname
log = logging.getLogger(__name__)
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"
DEFAULT_VERSION = 0
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
_extra_allowed: set[str] = set()
_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple) # dict is treated specially.
def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
"""Encodes o so it can be understood by the deserializer."""
return {CLASSNAME: cls, VERSION: version, DATA: data}
def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
classname = d[CLASSNAME]
version = d[VERSION]
if not isinstance(classname, str) or not isinstance(version, int):
raise ValueError(f"cannot decode {d!r}")
data = d.get(DATA)
return classname, version, data
def serialize(o: object, depth: int = 0) -> U | None:
"""Serialize an object into a representation consisting only built-in types.
Primitives (int, float, bool, str) are returned as-is. Built-in collections
are iterated over, where it is assumed that keys in a dict can be represented
as str.
Values that are not of a built-in type are serialized if a serializer is
found for them. The order in which serializers are used is
1. A ``serialize`` function provided by the object.
2. A registered serializer in the namespace of ``airflow.serialization.serializers``
3. Annotations from attr or dataclass.
Limitations: attr and dataclass objects can lose type information for nested objects
as they do not store this when calling ``asdict``. This means that at deserialization values
will be deserialized as a dict as opposed to reinstating the object. Provide
your own serializer to work around this.
:param o: The object to serialize.
:param depth: Private tracker for nested serialization.
:raise TypeError: A serializer cannot be found.
:raise RecursionError: The object is too nested for the function to handle.
:return: A representation of ``o`` that consists of only built-in types.
"""
if depth == MAX_RECURSION_DEPTH:
raise RecursionError("maximum recursion depth reached for serialization")
# None remains None
if o is None:
return o
# primitive types are returned as is
if isinstance(o, _primitives):
if isinstance(o, enum.Enum):
return o.value
return o
if isinstance(o, list):
return [serialize(d, depth + 1) for d in o]
if isinstance(o, dict):
if CLASSNAME in o or SCHEMA_ID in o:
raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")
return {str(k): serialize(v, depth + 1) for k, v in o.items()}
cls = type(o)
qn = qualname(o)
# custom serializers
dct = {
CLASSNAME: qn,
VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
}
# if there is a builtin serializer available use that
if qn in _serializers:
data, classname, version, is_serialized = _serializers[qn].serialize(o)
if is_serialized:
return encode(classname, version, serialize(data, depth + 1))
# object / class brings their own
if hasattr(o, "serialize"):
data = getattr(o, "serialize")()
# if we end up with a structure, ensure its values are serialized
if isinstance(data, dict):
data = serialize(data, depth + 1)
dct[DATA] = data
return dct
# dataclasses
if dataclasses.is_dataclass(cls):
# fixme: unfortunately using asdict with nested dataclasses it looses information
data = dataclasses.asdict(o) # type: ignore[call-overload]
dct[DATA] = serialize(data, depth + 1)
return dct
# attr annotated
if attr.has(cls):
# Only include attributes which we can pass back to the classes constructor
data = attr.asdict(cast(attr.AttrsInstance, o), recurse=True, filter=lambda a, v: a.init)
dct[DATA] = serialize(data, depth + 1)
return dct
raise TypeError(f"cannot serialize object of type {cls}")
def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
"""
Deserializes an object of primitive type T into an object. Uses an allow
list to determine if a class can be loaded.
:param o: primitive to deserialize into an arbitrary object.
:param full: if False it will return a stringified representation
of an object and will not load any classes
:param type_hint: if set it will be used to help determine what
object to deserialize in. It does not override if another
specification is found
:return: object
"""
if o is None:
return o
if isinstance(o, _primitives):
return o
# tuples, sets are included here for backwards compatibility
if isinstance(o, _builtin_collections):
col = [deserialize(d) for d in o]
if isinstance(o, tuple):
return tuple(col)
if isinstance(o, set):
return set(col)
return col
if not isinstance(o, dict):
# if o is not a dict, then it's already deserialized
# in this case we should return it as is
return o
o = _convert(o)
# plain dict and no type hint
if CLASSNAME not in o and not type_hint or VERSION not in o:
return {str(k): deserialize(v, full) for k, v in o.items()}
# custom deserialization starts here
cls: Any
version = 0
value: Any = None
classname = ""
if type_hint:
cls = type_hint
classname = qualname(cls)
version = 0 # type hinting always sets version to 0
value = o
if CLASSNAME in o and VERSION in o:
classname, version, value = decode(o)
if not classname:
raise TypeError("classname cannot be empty")
# only return string representation
if not full:
return _stringify(classname, version, value)
if not _match(classname) and classname not in _extra_allowed:
raise ImportError(
f"{classname} was not found in allow list for deserialization imports. "
f"To allow it, add it to allowed_deserialization_classes in the configuration"
)
cls = import_string(classname)
# registered deserializer
if classname in _deserializers:
return _deserializers[classname].deserialize(classname, version, deserialize(value))
# class has deserialization function
if hasattr(cls, "deserialize"):
return getattr(cls, "deserialize")(deserialize(value), version)
# attr or dataclass
if attr.has(cls) or dataclasses.is_dataclass(cls):
class_version = getattr(cls, "__version__", 0)
if int(version) > class_version:
raise TypeError(
"serialized version of %s is newer than module version (%s > %s)",
classname,
version,
class_version,
)
return cls(**deserialize(value))
# no deserializer available
raise TypeError(f"No deserializer found for {classname}")
def _convert(old: dict) -> dict:
"""Converts an old style serialization to new style."""
if OLD_TYPE in old and OLD_DATA in old:
return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA][OLD_DATA]}
return old
def _match(classname: str) -> bool:
return any(p.match(classname) is not None for p in _get_patterns())
def _stringify(classname: str, version: int, value: T | None) -> str:
"""Convert a previously serialized object in a somewhat human-readable format.
This function is not designed to be exact, and will not extensively traverse
the whole tree of an object.
"""
if classname in _stringifiers:
return _stringifiers[classname].stringify(classname, version, value)
s = f"{classname}@version={version}("
if isinstance(value, _primitives):
s += f"{value})"
elif isinstance(value, _builtin_collections):
# deserialized values can be != str
s += ",".join(str(deserialize(value, full=False)))
elif isinstance(value, dict):
for k, v in value.items():
s += f"{k}={deserialize(v, full=False)},"
s = s[:-1] + ")"
return s
def _register():
"""Register builtin serializers and deserializers for types that don't have any themselves."""
_serializers.clear()
_deserializers.clear()
_stringifiers.clear()
with Stats.timer("serde.load_serializers") as timer:
for _, name, _ in iter_namespace(airflow.serialization.serializers):
name = import_module(name)
for s in getattr(name, "serializers", ()):
if not isinstance(s, str):
s = qualname(s)
if s in _serializers and _serializers[s] != name:
raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
log.debug("registering %s for serialization", s)
_serializers[s] = name
for d in getattr(name, "deserializers", ()):
if not isinstance(d, str):
d = qualname(d)
if d in _deserializers and _deserializers[d] != name:
raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}")
log.debug("registering %s for deserialization", d)
_deserializers[d] = name
_extra_allowed.add(d)
for c in getattr(name, "stringifiers", ()):
if not isinstance(c, str):
c = qualname(c)
if c in _deserializers and _deserializers[c] != name:
raise AttributeError(f"duplicate {c} for stringifiers in {name} and {_stringifiers[c]}")
log.debug("registering %s for stringifying", c)
_stringifiers[c] = name
log.debug("loading serializers took %.3f seconds", timer.duration)
@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[re.Pattern]:
patterns = conf.get("core", "allowed_deserialization_classes").split()
return [re.compile(re.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
_register()