blob: fa8239ba4ab03a7e5d68b280a0447f4e7ab12d9f [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.avro.protobuf;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos.FileOptions;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.MessageOrBuilder;

import org.apache.avro.Schema;
import org.apache.avro.Schema.Field;
import org.apache.avro.generic.GenericData;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.specific.SpecificData;
import org.apache.avro.util.ClassUtils;
import org.codehaus.jackson.JsonFactory;
import org.codehaus.jackson.JsonNode;
import org.codehaus.jackson.map.ObjectMapper;
import org.codehaus.jackson.node.JsonNodeFactory;
/**
 * Utilities for serializing Protobuf data in Avro format.
 *
 * <p>Avro record schemas are derived from protobuf descriptors. Field values
 * are read from protobuf {@link Message}s and set on protobuf {@link Builder}s;
 * {@link #newRecord} obtains builders via the generated class's static
 * {@code newBuilder()} method.
 */
public class ProtobufData extends GenericData {
  private static final String PROTOBUF_TYPE = "protobuf";

  private static final ProtobufData INSTANCE = new ProtobufData();

  protected ProtobufData() {}

  /** Return the singleton instance. */
  public static ProtobufData get() { return INSTANCE; }

  @Override
  public DatumReader createDatumReader(Schema schema) {
    return new ProtobufDatumReader(schema, schema, this);
  }

  @Override
  public DatumWriter createDatumWriter(Schema schema) {
    return new ProtobufDatumWriter(schema, this);
  }

  @Override
  public void setField(Object r, String n, int pos, Object o) {
    setField(r, n, pos, o, getRecordState(r, getSchema(r.getClass())));
  }

  @Override
  public Object getField(Object r, String name, int pos) {
    return getField(r, name, pos, getRecordState(r, getSchema(r.getClass())));
  }

  /**
   * Sets field {@code pos} on the protobuf builder {@code r}. A {@code null}
   * value for a message-typed field clears the field instead of being passed
   * to {@code Builder.setField}.
   *
   * @param state the {@code FieldDescriptor[]} returned by
   *              {@link #getRecordState}
   */
  @Override
  protected void setField(Object r, String n, int pos, Object o, Object state) {
    Builder b = (Builder)r;
    FieldDescriptor f = ((FieldDescriptor[])state)[pos];
    switch (f.getType()) {
    case MESSAGE:
      if (o == null) {
        b.clearField(f);
        break;
      }
      // non-null message value: deliberately fall through and set it
    default:
      b.setField(f, o);
    }
  }

  /**
   * Reads field {@code pos} from the protobuf message {@code record}. An unset,
   * non-repeated message-typed field is reported as {@code null}, matching the
   * ["null", record] union that {@link #getSchema(FieldDescriptor)} builds for
   * optional message fields.
   *
   * @param state the {@code FieldDescriptor[]} returned by
   *              {@link #getRecordState}
   */
  @Override
  protected Object getField(Object record, String name, int pos, Object state) {
    Message m = (Message)record;
    FieldDescriptor f = ((FieldDescriptor[])state)[pos];
    switch (f.getType()) {
    case MESSAGE:
      if (!f.isRepeated() && !m.hasField(f))
        return null;
      // set or repeated message field: deliberately fall through
    default:
      return m.getField(f);
    }
  }

  private final Map<Descriptor,FieldDescriptor[]> fieldCache =
    new ConcurrentHashMap<Descriptor,FieldDescriptor[]>();

  /**
   * Returns an array mapping each Avro field position in {@code s} to the
   * protobuf field of {@code r}'s message type that has the same name.
   * Results are cached per descriptor.
   *
   * <p>NOTE(review): the cache key ignores {@code s}, so if one message type
   * is used with two schemas whose field lists differ, whichever mapping is
   * computed first is reused for both. Entries are presumably {@code null}
   * for schema fields with no matching protobuf field — confirm.
   */
  @Override
  protected Object getRecordState(Object r, Schema s) {
    Descriptor d = ((MessageOrBuilder)r).getDescriptorForType();
    FieldDescriptor[] fields = fieldCache.get(d);
    if (fields == null) {                         // cache miss
      fields = new FieldDescriptor[s.getFields().size()];
      for (Field f : s.getFields())
        fields[f.pos()] = d.findFieldByName(f.name());
      // get/put is not atomic: concurrent misses may each compute and store
      fieldCache.put(d, fields);
    }
    return fields;
  }

  @Override
  protected boolean isRecord(Object datum) {
    return datum instanceof Message;
  }

  /**
   * Returns a fresh builder for the generated protobuf class named by
   * {@code schema}, or {@code old} itself when it is already an instance of
   * that class. Falls back to {@link GenericData#newRecord} if no class is
   * returned.
   *
   * @throws RuntimeException wrapping any class-loading or reflection failure
   */
  @Override
  public Object newRecord(Object old, Schema schema) {
    try {
      Class c = ClassUtils.forName(SpecificData.getClassName(schema));
      if (c == null) {
        // Punt to the generic implementation. This must be super.newRecord:
        // calling this.newRecord here would recurse until the stack overflows.
        return super.newRecord(old, schema);
      }
      if (c.isInstance(old))
        return old;                               // reuse instance
      return c.getMethod("newBuilder").invoke(null);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  protected boolean isArray(Object datum) {
    return datum instanceof List;
  }

  @Override
  protected boolean isBytes(Object datum) {
    return datum instanceof ByteString;
  }

  @Override
  protected Schema getRecordSchema(Object record) {
    return getSchema(((Message)record).getDescriptorForType());
  }

  private final Map<Class,Schema> schemaCache
    = new ConcurrentHashMap<Class,Schema>();

  /**
   * Return a record schema given a protobuf message class (or an enum schema
   * given a protobuf enum class). The class's static {@code getDescriptor()}
   * method is invoked reflectively; results are cached per class.
   *
   * @throws RuntimeException wrapping any reflection failure
   */
  public Schema getSchema(Class c) {
    Schema schema = schemaCache.get(c);

    if (schema == null) {                         // cache miss
      try {
        Object descriptor = c.getMethod("getDescriptor").invoke(null);
        if (c.isEnum())
          schema = getSchema((EnumDescriptor)descriptor);
        else
          schema = getSchema((Descriptor)descriptor);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
      schemaCache.put(c, schema);                 // update cache
    }
    return schema;
  }

  /**
   * Per-thread map of record schemas under construction, used to terminate
   * recursion for self-referencing message types.
   */
  private static final ThreadLocal<Map<Descriptor,Schema>> SEEN
    = new ThreadLocal<Map<Descriptor,Schema>>() {
    protected Map<Descriptor,Schema> initialValue() {
      return new IdentityHashMap<Descriptor,Schema>();
    }
  };

  /**
   * Builds the record schema for a message descriptor. A descriptor that is
   * already being built on this thread resolves to the (possibly still
   * field-less) record created for it, so recursive types reference the same
   * Schema object. The outermost call clears the map when it finishes.
   */
  private Schema getSchema(Descriptor descriptor) {
    Map<Descriptor,Schema> seen = SEEN.get();
    if (seen.containsKey(descriptor))             // stop recursion
      return seen.get(descriptor);
    boolean first = seen.isEmpty();
    try {
      Schema result =
        Schema.createRecord(descriptor.getName(), null,
                            getNamespace(descriptor.getFile(),
                                         descriptor.getContainingType()),
                            false);

      // register before visiting fields so recursive references resolve here
      seen.put(descriptor, result);

      List<Field> fields = new ArrayList<Field>();
      for (FieldDescriptor f : descriptor.getFields())
        fields.add(new Field(f.getName(), getSchema(f), null, getDefault(f)));
      result.setFields(fields);
      return result;

    } finally {
      if (first)
        seen.clear();
    }
  }

  /**
   * Returns the Avro namespace for a type declared in {@code fd} and nested
   * inside {@code containing} ({@code null} for top-level types).
   *
   * <p>The namespace spells the enclosing Java scope of the generated class:
   * the Java package, then — unless {@code java_multiple_files} is set — the
   * outer class, then any containing message classes, with each class
   * followed by '$' (e.g. {@code com.example.Outer$Parent$}). With no enclosing
   * class the namespace is just the package. No leading '.' is emitted when
   * the package is empty.
   *
   * <p>NOTE(review): the trailing '$' relies on SpecificData.getClassName
   * appending the record name directly after a '$' — confirm.
   */
  private String getNamespace(FileDescriptor fd, Descriptor containing) {
    FileOptions o = fd.getOptions();
    String p = o.hasJavaPackage()
      ? o.getJavaPackage()
      : fd.getPackage();

    StringBuilder nested = new StringBuilder();
    if (!o.getJavaMultipleFiles()) {
      // all types in this file are generated inside a single outer class
      nested.append(getOuterClassName(fd, o)).append('$');
    }
    // prepend containing messages so the outermost one comes first
    int innerStart = nested.length();
    for (Descriptor c = containing; c != null; c = c.getContainingType())
      nested.insert(innerStart, c.getName() + "$");

    if (p.length() == 0)
      return nested.toString();                   // default package
    if (nested.length() == 0)
      return p;                                   // top-level class in package
    return p + "." + nested;
  }

  /**
   * Returns the Java outer class name for {@code fd}: the explicit
   * {@code java_outer_classname} option, or else the file's base name with
   * directories and extension removed, camel-cased by {@link #toCamelCase}.
   */
  private static String getOuterClassName(FileDescriptor fd, FileOptions o) {
    if (o.hasJavaOuterClassname())
      return o.getJavaOuterClassname();
    String base = new File(fd.getName()).getName();  // drop directories
    int dot = base.lastIndexOf('.');
    if (dot >= 0)                                 // names may lack an extension
      base = base.substring(0, dot);
    return toCamelCase(base);
  }

  /**
   * Converts an underscore-separated name to camel case, e.g.
   * {@code "my_file"} becomes {@code "MyFile"}. Each part is capitalized and
   * the rest of the part lowercased; empty parts (from a leading or doubled
   * '_') are skipped.
   */
  private static String toCamelCase(String s) {
    StringBuilder camelCase = new StringBuilder();
    for (String part : s.split("_")) {
      if (part.length() > 0)
        camelCase.append(cap(part));
    }
    return camelCase.toString();
  }

  /**
   * Upper-cases the first character of a non-empty {@code s} and lower-cases
   * the rest, independent of the default locale.
   */
  private static String cap(String s) {
    return s.substring(0, 1).toUpperCase(Locale.ROOT)
      + s.substring(1).toLowerCase(Locale.ROOT);
  }

  private static final Schema NULL = Schema.create(Schema.Type.NULL);

  /** Returns the field's schema, wrapped in an array when it is repeated. */
  private Schema getSchema(FieldDescriptor f) {
    Schema s = getNonRepeatedSchema(f);
    if (f.isRepeated())
      s = Schema.createArray(s);
    return s;
  }

  /**
   * Maps a protobuf field type to an Avro schema, ignoring repetition.
   * Optional message fields become a ["null", record] union.
   *
   * @throws RuntimeException for GROUP and any unrecognized type
   */
  private Schema getNonRepeatedSchema(FieldDescriptor f) {
    Schema result;
    switch (f.getType()) {
    case BOOL:
      return Schema.create(Schema.Type.BOOLEAN);
    case FLOAT:
      return Schema.create(Schema.Type.FLOAT);
    case DOUBLE:
      return Schema.create(Schema.Type.DOUBLE);
    case STRING:
      Schema s = Schema.create(Schema.Type.STRING);
      GenericData.setStringType(s, GenericData.StringType.String);
      return s;
    case BYTES:
      return Schema.create(Schema.Type.BYTES);
    case INT32: case UINT32: case SINT32: case FIXED32: case SFIXED32:
      return Schema.create(Schema.Type.INT);
    case INT64: case UINT64: case SINT64: case FIXED64: case SFIXED64:
      return Schema.create(Schema.Type.LONG);
    case ENUM:
      return getSchema(f.getEnumType());
    case MESSAGE:
      result = getSchema(f.getMessageType());
      if (f.isOptional())
        // wrap optional record fields in a union with null
        result = Schema.createUnion(Arrays.asList(new Schema[] {NULL, result}));
      return result;
    case GROUP:                                   // groups are deprecated
    default:
      throw new RuntimeException("Unexpected type: "+f.getType());
    }
  }

  /** Builds an Avro enum schema whose symbols are the enum's value names. */
  private Schema getSchema(EnumDescriptor d) {
    List<String> symbols = new ArrayList<String>();
    for (EnumValueDescriptor e : d.getValues()) {
      symbols.add(e.getName());
    }
    return Schema.createEnum(d.getName(), null,
                             getNamespace(d.getFile(), d.getContainingType()),
                             symbols);
  }

  private static final JsonFactory FACTORY = new JsonFactory();
  private static final ObjectMapper MAPPER = new ObjectMapper(FACTORY);
  private static final JsonNodeFactory NODES = JsonNodeFactory.instance;

  /**
   * Returns the Avro JSON default for a field, or {@code null} (no default)
   * for required and repeated fields. A default declared in the .proto file
   * is used when present; otherwise a type-appropriate zero value is
   * generated (false, 0, "", the first enum symbol, or null for messages).
   *
   * @throws RuntimeException for GROUP/unrecognized types or if a declared
   *         default cannot be converted to JSON
   */
  private JsonNode getDefault(FieldDescriptor f) {
    if (f.isRequired() || f.isRepeated())         // no default
      return null;

    if (f.hasDefaultValue()) {                    // use the declared default
      Object value = f.getDefaultValue();
      try {
        switch (f.getType()) {
        case ENUM:
          value = ((EnumValueDescriptor)value).getName();
          break;
        case BYTES:
          // Avro encodes bytes defaults as a JSON string whose code points
          // 0-255 are the byte values, i.e. the bytes decoded as ISO-8859-1.
          return NODES.textNode(((ByteString)value).toString("ISO-8859-1"));
        default:
          break;
        }
        return MAPPER.readTree(FACTORY.createJsonParser(toString(value)));
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    switch (f.getType()) {                        // generate default for type
    case BOOL:
      return NODES.booleanNode(false);
    case FLOAT: case DOUBLE:
    case INT32: case UINT32: case SINT32: case FIXED32: case SFIXED32:
    case INT64: case UINT64: case SINT64: case FIXED64: case SFIXED64:
      return NODES.numberNode(0);
    case STRING: case BYTES:
      return NODES.textNode("");
    case ENUM:
      return NODES.textNode(f.getEnumType().getValues().get(0).getName());
    case MESSAGE:
      return NODES.nullNode();
    case GROUP:                                   // groups are deprecated
    default:
      throw new RuntimeException("Unexpected type: "+f.getType());
    }
  }
}