blob: 9990d8362037a4644cb5a2d7bab1fbe2d21cef35 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Contains everything around the name mapping.
More information can be found here:
https://iceberg.apache.org/spec/#name-mapping-serialization
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import ChainMap
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, Iterator, List, TypeVar, Union
from pydantic import Field, conlist, field_validator, model_serializer
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
class MappedField(IcebergBaseModel):
    """One entry of a name mapping: a field ID, the names that map to it, and child mappings.

    Serialized with the keys ``field-id``, ``names`` and, only when non-empty,
    ``fields``.
    """

    field_id: int = Field(alias="field-id")
    # Required field with no default. Assigning a constrained-list *type* as the
    # default value would make the field optional and leave a non-list object in
    # ``names`` whenever it is omitted, which breaks ``__str__`` and serialization.
    # No length constraint is applied: an empty list is accepted.
    names: List[str]
    fields: List[MappedField] = Field(default_factory=list)

    @field_validator("fields", mode="before")
    @classmethod
    def convert_null_to_empty_List(cls, v: Any) -> Any:
        """Replace a null (or otherwise falsy) ``fields`` input with an empty list."""
        return v or []

    @model_serializer
    def ser_model(self) -> Dict[str, Any]:
        """Serialize the field, leaving out ``fields`` when it is empty."""
        serialized: Dict[str, Any] = {
            "field-id": self.field_id,
            "names": self.names,
        }
        if self.fields:
            serialized["fields"] = self.fields
        return serialized

    def __len__(self) -> int:
        """Return the number of child fields."""
        return len(self.fields)

    def __str__(self) -> str:
        """Render as ``([name, ...] -> field_id child child ...)``."""
        names = ", ".join(self.names)
        children = ", ".join(str(child) for child in self.fields)
        # Children are appended after a single space, and only when there are any.
        suffix = f" {children}" if children else ""
        return f"([{names}] -> {self.field_id}{suffix})"
class NameMapping(IcebergRootModel[List[MappedField]]):
    """Root model holding the top-level list of mapped fields."""

    root: List[MappedField]

    @cached_property
    def _field_by_name(self) -> Dict[str, MappedField]:
        """Index the mapping by name using ``_IndexByName``; cached per instance."""
        indexer = _IndexByName()
        return visit_name_mapping(self, indexer)

    def find(self, *names: str) -> MappedField:
        """Return the mapped field registered under the dot-joined ``names``.

        Raises:
            ValueError: If no field is indexed under the joined name.
        """
        dotted = ".".join(names)
        try:
            match = self._field_by_name[dotted]
        except KeyError as e:
            raise ValueError(f"Could not find field with name: {dotted}") from e
        return match

    def __len__(self) -> int:
        """Return the number of top-level mappings."""
        return len(self.root)

    def __iter__(self) -> Iterator[MappedField]:
        """Iterate over the top-level mapped fields."""
        return iter(self.root)

    def __str__(self) -> str:
        """Render the mapping as a bracketed block with one mapped field per line."""
        if not self.root:
            return "[]"
        body = "\n ".join(str(entry) for entry in self.root)
        return "[\n " + body + "\n]"
# Result type produced by a visitor at every level of the traversal.
T = TypeVar("T")


class NameMappingVisitor(Generic[T], ABC):
    """Abstract visitor over a name mapping; every hook returns a value of type ``T``."""

    @abstractmethod
    def mapping(self, nm: NameMapping, field_results: T) -> T:
        """Visit a NameMapping, given the result computed for its list of fields."""

    @abstractmethod
    def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
        """Visit a List[MappedField], given one result per field in the list."""

    @abstractmethod
    def field(self, field: MappedField, field_result: T) -> T:
        """Visit a MappedField, given the result computed for its child fields."""
class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
    """Visitor that builds a dictionary from (dotted) names to mapped fields."""

    def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Return the index built for the root list unchanged."""
        return field_results

    def fields(self, struct: List[MappedField], field_results: List[Dict[str, MappedField]]) -> Dict[str, MappedField]:
        """Merge the per-field indexes into one plain dictionary."""
        layered = ChainMap(*field_results)
        return dict(layered)

    def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Index ``field`` under each of its names and prefix its children's keys with them."""
        index: Dict[str, MappedField] = {}
        # Child entries first: every child key is re-registered once per parent name.
        for child_key, child in field_result.items():
            for parent_name in field.names:
                index[f"{parent_name}.{child_key}"] = child
        # Then the field itself, under each of its own names.
        index.update(dict.fromkeys(field.names, field))
        return index
@singledispatch
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
    """Traverse the name mapping in post-order, dispatching on the type of ``obj``.

    This is the fallback used for types without a registered handler.

    Raises:
        NotImplementedError: Always, naming the object that could not be visited.
    """
    message = f"Cannot visit non-type: {obj}"
    raise NotImplementedError(message)
@visit_name_mapping.register(NameMapping)
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
    """Visit the root list of the mapping, then pass its result to ``visitor.mapping``."""
    root_result = visit_name_mapping(obj.root, visitor)
    return visitor.mapping(obj, root_result)
@visit_name_mapping.register(list)
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
    """Visit each mapped field after its children, then combine via ``visitor.fields``."""
    results = []
    for mapped in fields:
        # Children are visited before the field that owns them.
        child_result = visit_name_mapping(mapped.fields, visitor)
        results.append(visitor.field(mapped, child_result))
    return visitor.fields(fields, results)
def parse_mapping_from_json(mapping: str) -> NameMapping:
    """Validate the JSON string ``mapping`` into a ``NameMapping``."""
    parsed = NameMapping.model_validate_json(mapping)
    return parsed
class _CreateMapping(SchemaVisitor[List[MappedField]]):
    """Schema visitor that produces the list of mapped fields for each visited node."""

    def schema(self, schema: Schema, struct_result: List[MappedField]) -> List[MappedField]:
        """Return the mapped fields produced for the schema's struct."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[List[MappedField]]) -> List[MappedField]:
        """Create one mapped field per struct field, named after it, with its children attached."""
        mapped: List[MappedField] = []
        for nested, children in zip(struct.fields, field_results):
            mapped.append(MappedField(field_id=nested.field_id, names=[nested.name], fields=children))
        return mapped

    def field(self, field: NestedField, field_result: List[MappedField]) -> List[MappedField]:
        """Pass through the result computed for the field's type."""
        return field_result

    def list(self, list_type: ListType, element_result: List[MappedField]) -> List[MappedField]:
        """Map a list type to a single ``element`` entry carrying the element ID."""
        element = MappedField(field_id=list_type.element_id, names=["element"], fields=element_result)
        return [element]

    def map(self, map_type: MapType, key_result: List[MappedField], value_result: List[MappedField]) -> List[MappedField]:
        """Map a map type to a ``key`` entry followed by a ``value`` entry."""
        key_field = MappedField(field_id=map_type.key_id, names=["key"], fields=key_result)
        value_field = MappedField(field_id=map_type.value_id, names=["value"], fields=value_result)
        return [key_field, value_field]

    def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
        """Primitives have no children, so they contribute no mapped fields."""
        return []
def create_mapping_from_schema(schema: Schema) -> NameMapping:
    """Build a ``NameMapping`` by visiting ``schema`` with ``_CreateMapping``."""
    mapped_fields = visit(schema, _CreateMapping())
    return NameMapping(mapped_fields)