blob: c4c0d9ce4c4d99b7247f395ad953ee93d244bce3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.avro.protobuf;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.avro.Conversion;
import org.apache.avro.Schema;
import org.apache.avro.Schema.Field;
import org.apache.avro.generic.GenericData;
import org.apache.avro.specific.SpecificData;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.util.ClassUtils;
import org.apache.avro.util.internal.Accessor;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.DescriptorProtos.FileOptions;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
/** Utilities for serializing Protobuf data in Avro format. */
public class ProtobufData extends GenericData {
private static final ProtobufData INSTANCE = new ProtobufData();
protected ProtobufData() {
}
/** Return the singleton instance. */
public static ProtobufData get() {
return INSTANCE;
}
@Override
public DatumReader createDatumReader(Schema schema) {
return new ProtobufDatumReader(schema, schema, this);
}
@Override
public DatumWriter createDatumWriter(Schema schema) {
return new ProtobufDatumWriter(schema, this);
}
@Override
public void setField(Object r, String n, int pos, Object value) {
setField(r, n, pos, value, getRecordState(r, getSchema(r.getClass())));
}
@Override
public Object getField(Object r, String name, int pos) {
return getField(r, name, pos, getRecordState(r, getSchema(r.getClass())));
}
@Override
protected void setField(Object record, String name, int position, Object value, Object state) {
Builder b = (Builder) record;
FieldDescriptor f = ((FieldDescriptor[]) state)[position];
switch (f.getType()) {
case MESSAGE:
if (value == null) {
b.clearField(f);
break;
}
default:
b.setField(f, value);
}
}
@Override
protected Object getField(Object record, String name, int pos, Object state) {
Message m = (Message) record;
FieldDescriptor f = ((FieldDescriptor[]) state)[pos];
switch (f.getType()) {
case MESSAGE:
if (!f.isRepeated() && !m.hasField(f))
return null;
default:
return m.getField(f);
}
}
private final Map<Descriptor, FieldDescriptor[]> fieldCache = new ConcurrentHashMap<>();
@Override
protected Object getRecordState(Object r, Schema s) {
Descriptor d = ((MessageOrBuilder) r).getDescriptorForType();
FieldDescriptor[] fields = fieldCache.get(d);
if (fields == null) { // cache miss
fields = new FieldDescriptor[s.getFields().size()];
for (Field f : s.getFields())
fields[f.pos()] = d.findFieldByName(f.name());
fieldCache.put(d, fields); // update cache
}
return fields;
}
@Override
protected boolean isRecord(Object datum) {
return datum instanceof Message;
}
@Override
public Object newRecord(Object old, Schema schema) {
try {
Class c = SpecificData.get().getClass(schema);
if (c == null)
return newRecord(old, schema); // punt to generic
if (c.isInstance(old))
return old; // reuse instance
return c.getMethod("newBuilder").invoke(null);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
protected boolean isArray(Object datum) {
return datum instanceof List;
}
@Override
protected boolean isBytes(Object datum) {
return datum instanceof ByteString;
}
@Override
protected Schema getRecordSchema(Object record) {
Descriptor descriptor = ((Message) record).getDescriptorForType();
Schema schema = schemaCache.get(descriptor);
if (schema == null) {
schema = getSchema(descriptor);
schemaCache.put(descriptor, schema);
}
return schema;
}
private final Map<Object, Schema> schemaCache = new ConcurrentHashMap<>();
/** Return a record schema given a protobuf message class. */
public Schema getSchema(Class c) {
Schema schema = schemaCache.get(c);
if (schema == null) { // cache miss
try {
Object descriptor = c.getMethod("getDescriptor").invoke(null);
if (c.isEnum())
schema = getSchema((EnumDescriptor) descriptor);
else
schema = getSchema((Descriptor) descriptor);
} catch (Exception e) {
throw new RuntimeException(e);
}
schemaCache.put(c, schema); // update cache
}
return schema;
}
private static final ThreadLocal<Map<Descriptor, Schema>> SEEN = ThreadLocal.withInitial(IdentityHashMap::new);
public Schema getSchema(Descriptor descriptor) {
Map<Descriptor, Schema> seen = SEEN.get();
if (seen.containsKey(descriptor)) // stop recursion
return seen.get(descriptor);
boolean first = seen.isEmpty();
Conversion conversion = getConversionByDescriptor(descriptor);
if (conversion != null) {
Schema converted = conversion.getRecommendedSchema();
seen.put(descriptor, converted);
return converted;
}
try {
Schema result = Schema.createRecord(descriptor.getName(), null,
getNamespace(descriptor.getFile(), descriptor.getContainingType()), false);
seen.put(descriptor, result);
List<Field> fields = new ArrayList<>(descriptor.getFields().size());
for (FieldDescriptor f : descriptor.getFields())
fields.add(Accessor.createField(f.getName(), getSchema(f), null, getDefault(f)));
result.setFields(fields);
return result;
} finally {
if (first)
seen.clear();
}
}
public String getNamespace(FileDescriptor fd, Descriptor containing) {
FileOptions o = fd.getOptions();
String p = o.hasJavaPackage() ? o.getJavaPackage() : fd.getPackage();
String outer = "";
if (!o.getJavaMultipleFiles()) {
if (o.hasJavaOuterClassname()) {
outer = o.getJavaOuterClassname();
} else {
outer = new File(fd.getName()).getName();
outer = outer.substring(0, outer.lastIndexOf('.'));
outer = toCamelCase(outer);
}
}
StringBuilder inner = new StringBuilder();
while (containing != null) {
if (inner.length() == 0) {
inner.insert(0, containing.getName());
} else {
inner.insert(0, containing.getName() + "$");
}
containing = containing.getContainingType();
}
String d1 = (!outer.isEmpty() || inner.length() != 0 ? "." : "");
String d2 = (!outer.isEmpty() && inner.length() != 0 ? "$" : "");
return p + d1 + outer + d2 + inner;
}
private static String toCamelCase(String s) {
String[] parts = s.split("_");
StringBuilder camelCaseString = new StringBuilder(s.length());
for (String part : parts) {
camelCaseString.append(cap(part));
}
return camelCaseString.toString();
}
private static String cap(String s) {
return s.substring(0, 1).toUpperCase() + s.substring(1).toLowerCase();
}
private static final Schema NULL = Schema.create(Schema.Type.NULL);
public Schema getSchema(FieldDescriptor f) {
Schema s = getNonRepeatedSchema(f);
if (f.isRepeated())
s = Schema.createArray(s);
return s;
}
private Schema getNonRepeatedSchema(FieldDescriptor f) {
Schema result;
switch (f.getType()) {
case BOOL:
return Schema.create(Schema.Type.BOOLEAN);
case FLOAT:
return Schema.create(Schema.Type.FLOAT);
case DOUBLE:
return Schema.create(Schema.Type.DOUBLE);
case STRING:
Schema s = Schema.create(Schema.Type.STRING);
GenericData.setStringType(s, GenericData.StringType.String);
return s;
case BYTES:
return Schema.create(Schema.Type.BYTES);
case INT32:
case UINT32:
case SINT32:
case FIXED32:
case SFIXED32:
return Schema.create(Schema.Type.INT);
case INT64:
case UINT64:
case SINT64:
case FIXED64:
case SFIXED64:
return Schema.create(Schema.Type.LONG);
case ENUM:
return getSchema(f.getEnumType());
case MESSAGE:
result = getSchema(f.getMessageType());
if (f.isOptional())
// wrap optional record fields in a union with null
result = Schema.createUnion(Arrays.asList(NULL, result));
return result;
case GROUP: // groups are deprecated
default:
throw new RuntimeException("Unexpected type: " + f.getType());
}
}
public Schema getSchema(EnumDescriptor d) {
List<String> symbols = new ArrayList<>(d.getValues().size());
for (EnumValueDescriptor e : d.getValues()) {
symbols.add(e.getName());
}
return Schema.createEnum(d.getName(), null, getNamespace(d.getFile(), d.getContainingType()), symbols);
}
private static final JsonFactory FACTORY = new JsonFactory();
private static final ObjectMapper MAPPER = new ObjectMapper(FACTORY);
private static final JsonNodeFactory NODES = JsonNodeFactory.instance;
private JsonNode getDefault(FieldDescriptor f) {
if (f.isRequired()) // no default
return null;
if (f.isRepeated()) // empty array as repeated fields' default value
return NODES.arrayNode();
if (f.hasDefaultValue()) { // parse spec'd default value
Object value = f.getDefaultValue();
switch (f.getType()) {
case ENUM:
value = ((EnumValueDescriptor) value).getName();
break;
}
String json = toString(value);
try {
return MAPPER.readTree(FACTORY.createParser(json));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
switch (f.getType()) { // generate default for type
case BOOL:
return NODES.booleanNode(false);
case FLOAT:
case DOUBLE:
case INT32:
case UINT32:
case SINT32:
case FIXED32:
case SFIXED32:
case INT64:
case UINT64:
case SINT64:
case FIXED64:
case SFIXED64:
return NODES.numberNode(0);
case STRING:
case BYTES:
return NODES.textNode("");
case ENUM:
return NODES.textNode(f.getEnumType().getValues().get(0).getName());
case MESSAGE:
return NODES.nullNode();
case GROUP: // groups are deprecated
default:
throw new RuntimeException("Unexpected type: " + f.getType());
}
}
/**
* Get Conversion from protobuf descriptor via protobuf classname.
*
* @param descriptor protobuf descriptor
* @return Conversion | null
*/
private Conversion getConversionByDescriptor(Descriptor descriptor) {
String namespace = getNamespace(descriptor.getFile(), descriptor.getContainingType());
String name = descriptor.getName();
String dot = namespace.endsWith("$") ? "" : "."; // back-compatibly handle $
try {
Class clazz = ClassUtils.forName(getClassLoader(), namespace + dot + name);
return getConversionByClass(clazz);
} catch (ClassNotFoundException e) {
return null;
}
}
}