Use `VisitorWithPartner` for name-mapping (#1014)
* Use `VisitorWithPartner` for name-mapping
This will correctly handle fields with `.` in the name.
* Fix versions in deprecation
Co-authored-by: Sung Yun <107272191+sungwy@users.noreply.github.com>
* Use full path in error
---------
Co-authored-by: Sung Yun <107272191+sungwy@users.noreply.github.com>
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 719d289..b2cb167 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -130,7 +130,7 @@
visit_with_partner,
)
from pyiceberg.table.metadata import TableMetadata
-from pyiceberg.table.name_mapping import NameMapping
+from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
@@ -818,14 +818,14 @@
) -> Schema:
has_ids = visit_pyarrow(schema, _HasIds())
if has_ids:
- visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+ return visit_pyarrow(schema, _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
elif name_mapping is not None:
- visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+ schema_without_ids = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+ return apply_name_mapping(schema_without_ids, name_mapping)
else:
raise ValueError(
"Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
)
- return visit_pyarrow(schema, visitor)
def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
@@ -1002,17 +1002,13 @@
"""Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided."""
_field_names: List[str]
- _name_mapping: Optional[NameMapping]
- def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
+ def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
self._field_names = []
- self._name_mapping = name_mapping
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
def _field_id(self, field: pa.Field) -> int:
- if self._name_mapping:
- return self._name_mapping.find(*self._field_names).field_id
- elif (field_id := _get_field_id(field)) is not None:
+ if (field_id := _get_field_id(field)) is not None:
return field_id
else:
raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index cb9f72b..eaf5fc8 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -30,9 +30,10 @@
from pydantic import Field, conlist, field_validator, model_serializer
-from pyiceberg.schema import Schema, SchemaVisitor, visit
+from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
-from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
+from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
+from pyiceberg.utils.deprecated import deprecated
class MappedField(IcebergBaseModel):
@@ -74,6 +75,11 @@
def _field_by_name(self) -> Dict[str, MappedField]:
return visit_name_mapping(self, _IndexByName())
+ @deprecated(
+ deprecated_in="0.8.0",
+ removed_in="0.9.0",
+ help_message="Please use `apply_name_mapping` instead",
+ )
def find(self, *names: str) -> MappedField:
name = ".".join(names)
try:
@@ -248,3 +254,127 @@
def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))
+
+
+class NameMappingAccessor(PartnerAccessor[MappedField]):
+ def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
+ return partner
+
+ def field_partner(
+ self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
+ ) -> Optional[MappedField]:
+ if partner_struct is not None:
+ if isinstance(partner_struct, MappedField):
+ partner_struct = partner_struct.fields
+
+ for field in partner_struct:
+ if field_name in field.names:
+ return field
+
+ return None
+
+ def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
+ if partner_list is not None:
+ for field in partner_list.fields:
+ if "element" in field.names:
+ return field
+ return None
+
+ def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+ if partner_map is not None:
+ for field in partner_map.fields:
+ if "key" in field.names:
+ return field
+ return None
+
+ def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+ if partner_map is not None:
+ for field in partner_map.fields:
+ if "value" in field.names:
+ return field
+ return None
+
+
+class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]):
+ current_path: List[str]
+
+ def __init__(self) -> None:
+ # For keeping track where we are in case when a field cannot be found
+ self.current_path = []
+
+ def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+ self.current_path.append(field.name)
+
+ def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+ self.current_path.pop()
+
+ def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+ self.current_path.append("element")
+
+ def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+ self.current_path.pop()
+
+ def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+ self.current_path.append("key")
+
+ def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+ self.current_path.pop()
+
+ def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+ self.current_path.append("value")
+
+ def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+ self.current_path.pop()
+
+ def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
+ return Schema(*struct_result.fields, schema_id=schema.schema_id)
+
+ def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
+ return StructType(*field_results)
+
+ def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
+ if field_partner is None:
+ raise ValueError(f"Field missing from NameMapping: {'.'.join(self.current_path)}")
+
+ return NestedField(
+ field_id=field_partner.field_id,
+ name=field.name,
+ field_type=field_result,
+ required=field.required,
+ doc=field.doc,
+ initial_default=field.initial_default,
+ initial_write=field.write_default,
+ )
+
+ def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
+ if list_partner is None:
+ raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+ element_id = next(field for field in list_partner.fields if "element" in field.names).field_id
+ return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)
+
+ def map(
+ self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
+ ) -> IcebergType:
+ if map_partner is None:
+ raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+ key_id = next(field for field in map_partner.fields if "key" in field.names).field_id
+ value_id = next(field for field in map_partner.fields if "value" in field.names).field_id
+ return MapType(
+ key_id=key_id,
+ key_type=key_result,
+ value_id=value_id,
+ value_type=value_result,
+ value_required=map_type.value_required,
+ )
+
+ def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
+ if primitive_partner is None:
+ raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
+
+ return primitive
+
+
+def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema:
+ return visit_with_partner(schema_without_ids, name_mapping, NameMappingProjectionVisitor(), NameMappingAccessor()) # type: ignore
diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py
index 3c50a24..647644f 100644
--- a/tests/table/test_name_mapping.py
+++ b/tests/table/test_name_mapping.py
@@ -20,11 +20,12 @@
from pyiceberg.table.name_mapping import (
MappedField,
NameMapping,
+ apply_name_mapping,
create_mapping_from_schema,
parse_mapping_from_json,
update_mapping,
)
-from pyiceberg.types import NestedField, StringType
+from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, MapType, NestedField, StringType, StructType
@pytest.fixture(scope="session")
@@ -321,3 +322,52 @@
MappedField(field_id=18, names=["add_18"]),
])
assert update_mapping(table_name_mapping_nested, updates, adds) == expected
+
+
+def test_mapping_using_by_visitor(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None:
+ schema_without_ids = Schema(
+ NestedField(field_id=0, name="foo", field_type=StringType(), required=False),
+ NestedField(field_id=0, name="bar", field_type=IntegerType(), required=True),
+ NestedField(field_id=0, name="baz", field_type=BooleanType(), required=False),
+ NestedField(
+ field_id=0,
+ name="qux",
+ field_type=ListType(element_id=0, element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ NestedField(
+ field_id=0,
+ name="quux",
+ field_type=MapType(
+ key_id=0,
+ key_type=StringType(),
+ value_id=0,
+ value_type=MapType(key_id=0, key_type=StringType(), value_id=0, value_type=IntegerType(), value_required=True),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=0,
+ name="location",
+ field_type=ListType(
+ element_id=0,
+ element_type=StructType(
+ NestedField(field_id=0, name="latitude", field_type=FloatType(), required=False),
+ NestedField(field_id=0, name="longitude", field_type=FloatType(), required=False),
+ ),
+ element_required=True,
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=0,
+ name="person",
+ field_type=StructType(
+ NestedField(field_id=0, name="name", field_type=StringType(), required=False),
+ NestedField(field_id=0, name="age", field_type=IntegerType(), required=True),
+ ),
+ required=False,
+ ),
+ )
+ assert apply_name_mapping(schema_without_ids, table_name_mapping_nested).fields == table_schema_nested.fields