blob: 7fb631e2a86a0904b548a47612ceecda17f81fc4 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import pyarrow as pa
import re
from pypaimon.api import Schema
from pypaimon.pynative.common.data_field import DataType, DataField
def convert_pa_schema_to_data_fields(pa_schema: pa.Schema) -> list[DataField]:
fields = []
for i, field in enumerate(pa_schema):
field: pa.Field
type_name = str(field.type)
if type_name.startswith('int'):
type_name = 'INT'
elif type_name.startswith('float'):
type_name = 'FLOAT'
elif type_name.startswith('double'):
type_name = 'DOUBLE'
elif type_name.startswith('bool'):
type_name = 'BOOLEAN'
elif type_name.startswith('string'):
type_name = 'STRING'
elif type_name.startswith('binary'):
type_name = 'BINARY'
elif type_name.startswith('date'):
type_name = 'DATE'
elif type_name.startswith('timestamp'):
type_name = 'TIMESTAMP'
elif type_name.startswith('decimal'):
match = re.match(r'decimal\((\d+),\s*(\d+)\)', type_name)
if match:
precision, scale = map(int, match.groups())
type_name = f'DECIMAL({precision},{scale})'
else:
type_name = 'DECIMAL(38,18)'
elif type_name.startswith('list'):
type_name = 'ARRAY'
elif type_name.startswith('struct'):
type_name = 'STRUCT'
elif type_name.startswith('map'):
type_name = 'MAP'
data_type = DataType(type_name, field.nullable)
data_field = DataField(
id=i,
name=field.name,
type=data_type,
description=field.metadata.get(b'description', b'').decode
('utf-8') if field.metadata and b'description' in field.metadata else None
)
fields.append(data_field)
return fields
def convert_data_fields_to_pa_schema(fields: list[DataField]) -> pa.Schema:
"""Convert a list of DataField to PyArrow Schema."""
pa_fields = []
for field in fields:
type_name = field.type.type_name.upper()
if type_name == 'INT':
type_name = pa.int32()
elif type_name == 'BIGINT':
type_name = pa.int64()
elif type_name == 'FLOAT':
type_name = pa.float32()
elif type_name == 'DOUBLE':
type_name = pa.float64()
elif type_name == 'BOOLEAN':
type_name = pa.bool_()
elif type_name == 'STRING':
type_name = pa.string()
elif type_name == 'BINARY':
type_name = pa.binary()
elif type_name == 'DATE':
type_name = pa.date32()
elif type_name == 'TIMESTAMP':
type_name = pa.timestamp('ms')
elif type_name.startswith('DECIMAL'):
match = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', type_name)
if match:
precision, scale = map(int, match.groups())
type_name = pa.decimal128(precision, scale)
else:
type_name = pa.decimal128(38, 18)
elif type_name == 'ARRAY':
# TODO: support arra / struct / map element type
type_name = pa.list_(pa.string())
elif type_name == 'STRUCT':
type_name = pa.struct([])
elif type_name == 'MAP':
type_name = pa.map_(pa.string(), pa.string())
else:
raise ValueError(f"Unsupported data type: {type_name}")
metadata = {}
if field.description:
metadata[b'description'] = field.description.encode('utf-8')
pa_fields.append(pa.field(field.name, type_name, nullable=field.type.nullable, metadata=metadata))
return pa.schema(pa_fields)
def get_highest_field_id(fields: list) -> int:
return max(field.id for field in fields)
def check_schema_for_external_table(exists_schema: Schema, new_schema: Schema):
"""Check if the new schema is compatible with the existing schema for external table."""
if ((not new_schema.pa_schema or new_schema.pa_schema.equals(exists_schema.pa_schema))
and (not new_schema.partition_keys or new_schema.partition_keys == exists_schema.partition_keys)
and (not new_schema.primary_keys or new_schema.primary_keys == exists_schema.primary_keys)):
exists_options = exists_schema.options
new_options = new_schema.options
for key, value in new_options.items():
if (key != 'owner' and key != 'path'
and (key not in exists_options or exists_options[key] != value)):
raise ValueError(
f"New schema's options are not equal to the exists schema's, "
f"new schema: {new_options}, exists schema: {exists_options}")
else:
raise ValueError(
f"New schema is not equal to the exists schema, "
f"new schema: {new_schema}, exists schema: {exists_schema}")