| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.avro.protobuf; |
| |
| import java.util.List; |
| import java.util.Arrays; |
| import java.util.ArrayList; |
| import java.util.Map; |
| import java.util.IdentityHashMap; |
| import java.util.concurrent.ConcurrentHashMap; |
| |
| import java.io.IOException; |
| import java.io.File; |
| |
| import org.apache.avro.Schema; |
| import org.apache.avro.Schema.Field; |
| import org.apache.avro.generic.GenericData; |
| import org.apache.avro.specific.SpecificData; |
| import org.apache.avro.io.DatumReader; |
| import org.apache.avro.io.DatumWriter; |
| |
| import com.google.protobuf.ByteString; |
| import com.google.protobuf.Message; |
| import com.google.protobuf.Message.Builder; |
| import com.google.protobuf.MessageOrBuilder; |
| import com.google.protobuf.Descriptors.Descriptor; |
| import com.google.protobuf.Descriptors.FieldDescriptor; |
| import com.google.protobuf.Descriptors.EnumDescriptor; |
| import com.google.protobuf.Descriptors.EnumValueDescriptor; |
| import com.google.protobuf.Descriptors.FileDescriptor; |
| import com.google.protobuf.DescriptorProtos.FileOptions; |
| |
| import org.apache.avro.util.ClassUtils; |
| import org.codehaus.jackson.JsonFactory; |
| import org.codehaus.jackson.JsonNode; |
| import org.codehaus.jackson.map.ObjectMapper; |
| import org.codehaus.jackson.node.JsonNodeFactory; |
| |
| /** Utilities for serializing Protobuf data in Avro format. */ |
| public class ProtobufData extends GenericData { |
| private static final String PROTOBUF_TYPE = "protobuf"; |
| |
| private static final ProtobufData INSTANCE = new ProtobufData(); |
| |
| protected ProtobufData() {} |
| |
| /** Return the singleton instance. */ |
| public static ProtobufData get() { return INSTANCE; } |
| |
| @Override |
| public DatumReader createDatumReader(Schema schema) { |
| return new ProtobufDatumReader(schema, schema, this); |
| } |
| |
| @Override |
| public DatumWriter createDatumWriter(Schema schema) { |
| return new ProtobufDatumWriter(schema, this); |
| } |
| |
| @Override |
| public void setField(Object r, String n, int pos, Object o) { |
| setField(r, n, pos, o, getRecordState(r, getSchema(r.getClass()))); |
| } |
| |
| @Override |
| public Object getField(Object r, String name, int pos) { |
| return getField(r, name, pos, getRecordState(r, getSchema(r.getClass()))); |
| } |
| |
| @Override |
| protected void setField(Object r, String n, int pos, Object o, Object state) { |
| Builder b = (Builder)r; |
| FieldDescriptor f = ((FieldDescriptor[])state)[pos]; |
| switch (f.getType()) { |
| case MESSAGE: |
| if (o == null) { |
| b.clearField(f); |
| break; |
| } |
| default: |
| b.setField(f, o); |
| } |
| } |
| |
| @Override |
| protected Object getField(Object record, String name, int pos, Object state) { |
| Message m = (Message)record; |
| FieldDescriptor f = ((FieldDescriptor[])state)[pos]; |
| switch (f.getType()) { |
| case MESSAGE: |
| if (!f.isRepeated() && !m.hasField(f)) |
| return null; |
| default: |
| return m.getField(f); |
| } |
| } |
| |
| private final Map<Descriptor,FieldDescriptor[]> fieldCache = |
| new ConcurrentHashMap<Descriptor,FieldDescriptor[]>(); |
| |
| @Override |
| protected Object getRecordState(Object r, Schema s) { |
| Descriptor d = ((MessageOrBuilder)r).getDescriptorForType(); |
| FieldDescriptor[] fields = fieldCache.get(d); |
| if (fields == null) { // cache miss |
| fields = new FieldDescriptor[s.getFields().size()]; |
| for (Field f : s.getFields()) |
| fields[f.pos()] = d.findFieldByName(f.name()); |
| fieldCache.put(d, fields); // update cache |
| } |
| return fields; |
| } |
| |
| @Override |
| protected boolean isRecord(Object datum) { |
| return datum instanceof Message; |
| } |
| |
| @Override |
| public Object newRecord(Object old, Schema schema) { |
| try { |
| Class c = ClassUtils.forName(SpecificData.getClassName(schema)); |
| if (c == null) |
| return newRecord(old, schema); // punt to generic |
| if (c.isInstance(old)) |
| return old; // reuse instance |
| return c.getMethod("newBuilder").invoke(null); |
| |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| @Override |
| protected boolean isArray(Object datum) { |
| return datum instanceof List; |
| } |
| |
| @Override |
| protected boolean isBytes(Object datum) { |
| return datum instanceof ByteString; |
| } |
| |
| @Override |
| protected Schema getRecordSchema(Object record) { |
| return getSchema(((Message)record).getDescriptorForType()); |
| } |
| |
| private final Map<Class,Schema> schemaCache |
| = new ConcurrentHashMap<Class,Schema>(); |
| |
| /** Return a record schema given a protobuf message class. */ |
| public Schema getSchema(Class c) { |
| Schema schema = schemaCache.get(c); |
| |
| if (schema == null) { // cache miss |
| try { |
| Object descriptor = c.getMethod("getDescriptor").invoke(null); |
| if (c.isEnum()) |
| schema = getSchema((EnumDescriptor)descriptor); |
| else |
| schema = getSchema((Descriptor)descriptor); |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| schemaCache.put(c, schema); // update cache |
| } |
| return schema; |
| } |
| |
| private static final ThreadLocal<Map<Descriptor,Schema>> SEEN |
| = new ThreadLocal<Map<Descriptor,Schema>>() { |
| protected Map<Descriptor,Schema> initialValue() { |
| return new IdentityHashMap<Descriptor,Schema>(); |
| } |
| }; |
| |
| private Schema getSchema(Descriptor descriptor) { |
| Map<Descriptor,Schema> seen = SEEN.get(); |
| if (seen.containsKey(descriptor)) // stop recursion |
| return seen.get(descriptor); |
| boolean first = seen.isEmpty(); |
| try { |
| Schema result = |
| Schema.createRecord(descriptor.getName(), null, |
| getNamespace(descriptor.getFile(), |
| descriptor.getContainingType()), |
| false); |
| |
| seen.put(descriptor, result); |
| |
| List<Field> fields = new ArrayList<Field>(); |
| for (FieldDescriptor f : descriptor.getFields()) |
| fields.add(new Field(f.getName(), getSchema(f), null, getDefault(f))); |
| result.setFields(fields); |
| return result; |
| |
| } finally { |
| if (first) |
| seen.clear(); |
| } |
| } |
| |
| private String getNamespace(FileDescriptor fd, Descriptor containing) { |
| FileOptions o = fd.getOptions(); |
| String p = o.hasJavaPackage() |
| ? o.getJavaPackage() |
| : fd.getPackage(); |
| String outer; |
| if (o.hasJavaOuterClassname()) { |
| outer = o.getJavaOuterClassname(); |
| } else { |
| outer = new File(fd.getName()).getName(); |
| outer = outer.substring(0, outer.lastIndexOf('.')); |
| outer = toCamelCase(outer); |
| } |
| String inner = ""; |
| while (containing != null) { |
| inner = containing.getName() + "$" + inner; |
| containing = containing.getContainingType(); |
| } |
| return p + "." + outer + "$" + inner; |
| } |
| |
| private static String toCamelCase(String s){ |
| String[] parts = s.split("_"); |
| String camelCaseString = ""; |
| for (String part : parts) { |
| camelCaseString = camelCaseString + cap(part); |
| } |
| return camelCaseString; |
| } |
| |
| private static String cap(String s) { |
| return s.substring(0, 1).toUpperCase() + s.substring(1).toLowerCase(); |
| } |
| |
| private static final Schema NULL = Schema.create(Schema.Type.NULL); |
| |
| private Schema getSchema(FieldDescriptor f) { |
| Schema s = getNonRepeatedSchema(f); |
| if (f.isRepeated()) |
| s = Schema.createArray(s); |
| return s; |
| } |
| |
| private Schema getNonRepeatedSchema(FieldDescriptor f) { |
| Schema result; |
| switch (f.getType()) { |
| case BOOL: |
| return Schema.create(Schema.Type.BOOLEAN); |
| case FLOAT: |
| return Schema.create(Schema.Type.FLOAT); |
| case DOUBLE: |
| return Schema.create(Schema.Type.DOUBLE); |
| case STRING: |
| Schema s = Schema.create(Schema.Type.STRING); |
| GenericData.setStringType(s, GenericData.StringType.String); |
| return s; |
| case BYTES: |
| return Schema.create(Schema.Type.BYTES); |
| case INT32: case UINT32: case SINT32: case FIXED32: case SFIXED32: |
| return Schema.create(Schema.Type.INT); |
| case INT64: case UINT64: case SINT64: case FIXED64: case SFIXED64: |
| return Schema.create(Schema.Type.LONG); |
| case ENUM: |
| return getSchema(f.getEnumType()); |
| case MESSAGE: |
| result = getSchema(f.getMessageType()); |
| if (f.isOptional()) |
| // wrap optional record fields in a union with null |
| result = Schema.createUnion(Arrays.asList(new Schema[] {NULL, result})); |
| return result; |
| case GROUP: // groups are deprecated |
| default: |
| throw new RuntimeException("Unexpected type: "+f.getType()); |
| } |
| } |
| |
| private Schema getSchema(EnumDescriptor d) { |
| List<String> symbols = new ArrayList<String>(); |
| for (EnumValueDescriptor e : d.getValues()) { |
| symbols.add(e.getName()); |
| } |
| return Schema.createEnum(d.getName(), null, |
| getNamespace(d.getFile(), d.getContainingType()), |
| symbols); |
| } |
| |
| private static final JsonFactory FACTORY = new JsonFactory(); |
| private static final ObjectMapper MAPPER = new ObjectMapper(FACTORY); |
| private static final JsonNodeFactory NODES = JsonNodeFactory.instance; |
| |
| private JsonNode getDefault(FieldDescriptor f) { |
| if (f.isRequired() || f.isRepeated()) // no default |
| return null; |
| |
| if (f.hasDefaultValue()) { // parse spec'd default value |
| Object value = f.getDefaultValue(); |
| switch (f.getType()) { |
| case ENUM: |
| value = ((EnumValueDescriptor)value).getName(); |
| break; |
| } |
| String json = toString(value); |
| try { |
| return MAPPER.readTree(FACTORY.createJsonParser(json)); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| switch (f.getType()) { // generate default for type |
| case BOOL: |
| return NODES.booleanNode(false); |
| case FLOAT: case DOUBLE: |
| case INT32: case UINT32: case SINT32: case FIXED32: case SFIXED32: |
| case INT64: case UINT64: case SINT64: case FIXED64: case SFIXED64: |
| return NODES.numberNode(0); |
| case STRING: case BYTES: |
| return NODES.textNode(""); |
| case ENUM: |
| return NODES.textNode(f.getEnumType().getValues().get(0).getName()); |
| case MESSAGE: |
| return NODES.nullNode(); |
| case GROUP: // groups are deprecated |
| default: |
| throw new RuntimeException("Unexpected type: "+f.getType()); |
| } |
| |
| } |
| |
| } |