blob: c5b98c17eb4a1e61ae5058272095489b31116004 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hudi.sync.common.util;
import org.apache.hudi.common.util.ValidationUtils;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
/**
* Convert the parquet schema to spark schema' json string.
* This code is refer to org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter
* in spark project.
*/
public class Parquet2SparkSchemaUtils {
public static String convertToSparkSchemaJson(GroupType parquetSchema) {
String fieldsJsonString = parquetSchema.getFields().stream().map(field -> {
switch (field.getRepetition()) {
case OPTIONAL:
return "{\"name\":\"" + field.getName() + "\",\"type\":" + convertFieldType(field)
+ ",\"nullable\":true,\"metadata\":{}}";
case REQUIRED:
return "{\"name\":\"" + field.getName() + "\",\"type\":" + convertFieldType(field)
+ ",\"nullable\":false,\"metadata\":{}}";
case REPEATED:
String arrayType = arrayType(field, false);
return "{\"name\":\"" + field.getName() + "\",\"type\":" + arrayType
+ ",\"nullable\":false,\"metadata\":{}}";
default:
throw new UnsupportedOperationException("Unsupport convert " + field + " to spark sql type");
}
}).reduce((a, b) -> a + "," + b).orElse("");
return "{\"type\":\"struct\",\"fields\":[" + fieldsJsonString + "]}";
}
private static String convertFieldType(Type field) {
if (field instanceof PrimitiveType) {
return "\"" + convertPrimitiveType((PrimitiveType) field) + "\"";
} else {
assert field instanceof GroupType;
return convertGroupField((GroupType) field);
}
}
private static String convertPrimitiveType(PrimitiveType field) {
PrimitiveType.PrimitiveTypeName typeName = field.getPrimitiveTypeName();
OriginalType originalType = field.getOriginalType();
switch (typeName) {
case BOOLEAN: return "boolean";
case FLOAT: return "float";
case DOUBLE: return "double";
case INT32:
if (originalType == null) {
return "integer";
}
switch (originalType) {
case INT_8: return "byte";
case INT_16: return "short";
case INT_32: return "integer";
case DATE: return "date";
case DECIMAL:
return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
+ field.getDecimalMetadata().getScale() + ")";
default: throw new UnsupportedOperationException("Unsupport convert " + typeName + " to spark sql type");
}
case INT64:
if (originalType == null) {
return "long";
}
switch (originalType) {
case INT_64: return "long";
case DECIMAL:
return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
+ field.getDecimalMetadata().getScale() + ")";
case TIMESTAMP_MICROS:
case TIMESTAMP_MILLIS:
return "timestamp";
default:
throw new UnsupportedOperationException("Unsupport convert " + typeName + " to spark sql type");
}
case INT96: return "timestamp";
case BINARY:
if (originalType == null) {
return "binary";
}
switch (originalType) {
case UTF8:
case ENUM:
case JSON:
return "string";
case BSON: return "binary";
case DECIMAL:
return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
+ field.getDecimalMetadata().getScale() + ")";
default:
throw new UnsupportedOperationException("Unsupport convert " + typeName + " to spark sql type");
}
case FIXED_LEN_BYTE_ARRAY:
switch (originalType) {
case DECIMAL:
return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
+ field.getDecimalMetadata().getScale() + ")";
default:
throw new UnsupportedOperationException("Unsupport convert " + typeName + " to spark sql type");
}
default:
throw new UnsupportedOperationException("Unsupport convert " + typeName + " to spark sql type");
}
}
private static String convertGroupField(GroupType field) {
if (field.getOriginalType() == null) {
return convertToSparkSchemaJson(field);
}
switch (field.getOriginalType()) {
case LIST:
ValidationUtils.checkArgument(field.getFieldCount() == 1, "Illegal List type: " + field);
Type repeatedType = field.getType(0);
if (isElementType(repeatedType, field.getName())) {
return arrayType(repeatedType, false);
} else {
Type elementType = repeatedType.asGroupType().getType(0);
boolean optional = elementType.isRepetition(OPTIONAL);
return arrayType(elementType, optional);
}
case MAP:
case MAP_KEY_VALUE:
GroupType keyValueType = field.getType(0).asGroupType();
Type keyType = keyValueType.getType(0);
Type valueType = keyValueType.getType(1);
boolean valueOptional = valueType.isRepetition(OPTIONAL);
return "{\"type\":\"map\", \"keyType\":" + convertFieldType(keyType)
+ ",\"valueType\":" + convertFieldType(valueType)
+ ",\"valueContainsNull\":" + valueOptional + "}";
default:
throw new UnsupportedOperationException("Unsupport convert " + field + " to spark sql type");
}
}
private static String arrayType(Type elementType, boolean containsNull) {
return "{\"type\":\"array\", \"elementType\":" + convertFieldType(elementType) + ",\"containsNull\":" + containsNull + "}";
}
private static boolean isElementType(Type repeatedType, String parentName) {
return repeatedType.isPrimitive() || repeatedType.asGroupType().getFieldCount() > 1
|| repeatedType.getName().equals("array") || repeatedType.getName().equals(parentName + "_tuple");
}
}