blob: 47a17c7849931bf786533f708e93df3c5c6f3c45 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gora.compiler;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.avro.Protocol;
import org.apache.avro.Schema;
import org.apache.avro.Protocol.Message;
import org.apache.avro.Schema.Field;
import org.apache.avro.specific.SpecificData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Generate specific Java interfaces and classes for protocols and schemas. */
public class GoraCompiler {
/** Root directory below which all generated sources are written. */
private final File dest;
/** Writer for the file currently being generated; replaced by startFile for every new file. */
private Writer out;
/** Schemas for which a source file still has to be generated. */
private final Set<Schema> queue = new HashSet<Schema>();
private static final Logger log = LoggerFactory.getLogger(GoraCompiler.class);

/**
 * Creates a compiler writing below the given directory. Private because instances
 * are only created by the static compileProtocol / compileSchema entry points.
 */
private GoraCompiler(File dest) {
  this.dest = dest;
}
/**
 * Generates a Java interface for a protocol and Java classes for the types it declares.
 *
 * @param src the source Avro protocol file
 * @param dest the directory to place generated files in
 * @throws IOException if reading the protocol or writing a generated file fails
 */
public static void compileProtocol(File src, File dest) throws IOException {
  Protocol parsed = Protocol.parse(src);
  GoraCompiler generator = new GoraCompiler(dest);
  // All types must be queued before the interface is written: the file header
  // emits an import for every queued type that lives in another namespace.
  for (Schema declared : parsed.getTypes()) {
    generator.enqueue(declared);
  }
  generator.compileInterface(parsed);
  generator.compile();
}
/**
 * Generates Java classes for a schema and for every record, enum and fixed
 * type reachable from it.
 *
 * @param src the source Avro schema file
 * @param dest the directory to place generated files in
 * @throws IOException if reading the schema or writing a generated file fails
 */
public static void compileSchema(File src, File dest) throws IOException {
  log.info("Compiling " + src + " to " + dest );
  Schema root = Schema.parse(src);
  GoraCompiler generator = new GoraCompiler(dest);
  generator.enqueue(root);
  generator.compile();
}
/**
 * Upper-cases the first character of {@code s} and leaves the rest untouched.
 * The result is used as the suffix of generated accessor names (get/set/addTo/...).
 *
 * <p>Uses {@link Character#toUpperCase(char)}, like the toUpperCase helper in this
 * class, so the result does not depend on the JVM's default locale (under a Turkish
 * locale {@code String.toUpperCase()} would turn "id" into "İd").
 *
 * @param s the name to convert; must not be null
 * @return {@code s} with its first character upper-cased, or {@code s} itself when empty
 */
private static String camelCasify(String s) {
  if (s.length() == 0) {
    return s; // nothing to capitalise; avoids an index error on substring/charAt
  }
  return Character.toUpperCase(s.charAt(0)) + s.substring(1);
}
/**
 * Converts a camel-case identifier to upper case, inserting an underscore wherever
 * a lower-case character is directly followed by an upper-case letter
 * (for example {@code baseUrl} becomes {@code BASE_URL}).
 */
private static String toUpperCase(String s) {
  StringBuilder result = new StringBuilder();
  char previous = 0;
  for (int pos = 0; pos < s.length(); pos++) {
    char current = s.charAt(pos);
    boolean wordBoundary = pos > 0
        && Character.isLowerCase(previous)
        && Character.isUpperCase(current)
        && Character.isLetter(current);
    if (wordBoundary) {
      result.append('_');
    }
    result.append(Character.toUpperCase(current));
    previous = current;
  }
  return result.toString();
}
/**
 * Queues every record, enum and fixed schema reachable from {@code schema}.
 * Arrays, maps and unions are only traversed; primitive types are ignored.
 */
private void enqueue(Schema schema) throws IOException {
  if (queue.contains(schema)) {
    return; // already queued
  }
  switch (schema.getType()) {
    case ENUM:
    case FIXED:
      queue.add(schema);
      break;
    case RECORD:
      // Queue the record before descending so that a field referring back
      // to this record hits the contains() check above and stops.
      queue.add(schema);
      for (Field member : schema.getFields()) {
        enqueue(member.schema());
      }
      break;
    case ARRAY:
      enqueue(schema.getElementType());
      break;
    case MAP:
      enqueue(schema.getValueType());
      break;
    case UNION:
      for (Schema branch : schema.getTypes()) {
        enqueue(branch);
      }
      break;
    case STRING:
    case BYTES:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case BOOLEAN:
    case NULL:
      break; // nothing to generate for primitives
    default:
      throw new RuntimeException("Unknown type: "+schema);
  }
}
/** Writes one source file for each schema that enqueue collected. */
private void compile() throws IOException {
  for (Schema pending : queue) {
    compile(pending);
  }
}
/**
 * Writes the Java interface for a protocol: one method per message, each declared
 * to throw AvroRemoteException followed by the error types returned by errors().
 * The output file is closed even when writing fails part-way.
 */
private void compileInterface(Protocol protocol) throws IOException {
  String interfaceName = protocol.getName();
  startFile(interfaceName, protocol.getNamespace());
  try {
    line(0, "public interface "+interfaceName+" {");
    out.append("\n");
    Map<String,Message> messages = protocol.getMessages();
    for (Map.Entry<String,Message> entry : messages.entrySet()) {
      Message message = entry.getValue();
      String returnType = unbox(message.getResponse());
      String arguments = params(message.getRequest());
      line(1, returnType+" "+entry.getKey()+"("+arguments+")");
      String declaredErrors = errors(message.getErrors());
      line(2,"throws AvroRemoteException"+declaredErrors+";");
    }
    line(0, "}");
  } finally {
    out.close();
  }
}
/**
 * Opens the output file for a generated type, stores its writer in {@code out}
 * and writes the common file header.
 *
 * @param name simple name of the type; capitalised and suffixed with ".java" to form the file name
 * @param space namespace of the type, mapped to a directory below the output root;
 *        may be null, in which case the file is written directly into the output root
 * @throws IOException if the directory cannot be created, the file cannot be opened
 *         or the header cannot be written
 */
private void startFile(String name, String space) throws IOException {
  // header() accepts a null namespace (it just omits the package line), so a schema
  // without a namespace must not fail here with a NullPointerException.
  File dir = (space == null)
      ? dest
      : new File(dest, space.replace('.', File.separatorChar));
  // isDirectory() is re-checked after a failed mkdirs() because the directory may
  // have been created concurrently between exists() and mkdirs().
  if (!dir.exists() && !dir.mkdirs() && !dir.isDirectory()) {
    throw new IOException("Unable to create " + dir);
  }
  out = new OutputStreamWriter(new FileOutputStream(new File(dir, cap(name) + ".java")));
  boolean headerWritten = false;
  try {
    header(space);
    headerWritten = true;
  } finally {
    if (!headerWritten) {
      // Callers only enter their own try/finally after this method returns,
      // so the writer has to be released here when the header fails.
      try {
        out.close();
      } catch (IOException ignored) {
        // the failure from header() is the one worth reporting
      }
    }
  }
}
/**
 * Fully qualified names imported by every generated file, in output order.
 * Each name is listed exactly once.
 */
private static final String[] HEADER_IMPORTS = {
  "java.nio.ByteBuffer",
  "java.util.Map",
  "java.util.HashMap",
  "org.apache.avro.Protocol",
  "org.apache.avro.Schema",
  "org.apache.avro.AvroRuntimeException",
  "org.apache.avro.util.Utf8",
  "org.apache.avro.ipc.AvroRemoteException",
  "org.apache.avro.generic.GenericArray",
  "org.apache.avro.specific.FixedSize",
  "org.apache.avro.specific.SpecificExceptionBase",
  "org.apache.avro.specific.SpecificRecordBase",
  "org.apache.avro.specific.SpecificRecord",
  "org.apache.avro.specific.SpecificFixed",
  "org.apache.gora.persistency.StateManager",
  "org.apache.gora.persistency.impl.PersistentBase",
  "org.apache.gora.persistency.impl.StateManagerImpl",
  "org.apache.gora.persistency.StatefulHashMap",
  "org.apache.gora.persistency.ListGenericArray",
};

/**
 * Writes the start of a generated file: the package declaration (omitted when
 * {@code namespace} is null), the fixed imports, one import for every queued
 * schema whose namespace differs from {@code namespace}, and a
 * {@code @SuppressWarnings("all")} annotation for the type that follows.
 *
 * @param namespace package of the type being generated, or null for none
 * @throws IOException if writing to the current output file fails
 */
private void header(String namespace) throws IOException {
  if (namespace != null) {
    line(0, "package "+namespace+";\n");
  }
  for (String imported : HEADER_IMPORTS) {
    line(0, "import "+imported+";");
  }
  for (Schema s : queue) {
    String other = s.getNamespace();
    boolean samePackage = (namespace == null)
        ? (other == null)
        : namespace.equals(other);
    if (!samePackage) {
      line(0, "import "+SpecificData.get().getClassName(s)+";");
    }
  }
  line(0, "");
  line(0, "@SuppressWarnings(\"all\")");
}
/**
 * Renders the fields of a request schema as a Java parameter list,
 * i.e. "type name" pairs separated by ", ".
 */
private String params(Schema request) throws IOException {
  StringBuilder list = new StringBuilder();
  String separator = "";
  for (Field parameter : request.getFields()) {
    list.append(separator)
        .append(unbox(parameter.schema()))
        .append(" ")
        .append(parameter.name());
    separator = ", ";
  }
  return list.toString();
}
/**
 * Renders the error types of a message as a continuation of a throws clause:
 * each name is preceded by ", ". The first member of the union is skipped
 * (presumably the implicit system error type; confirm against Avro's Message.getErrors()).
 */
private String errors(Schema errs) throws IOException {
  List<Schema> all = errs.getTypes();
  StringBuilder clause = new StringBuilder();
  for (Schema declared : all.subList(1, all.size())) {
    clause.append(", ").append(declared.getName());
  }
  return clause.toString();
}
/**
 * Writes the source file for a single schema.
 *
 * <ul>
 * <li>RECORD: a class extending PersistentBase, containing the schema constant, a
 * Field enum, the _ALL_FIELDS array, constructors, newInstance, index-based
 * get/put and bean-style accessors.</li>
 * <li>ENUM: a Java enum listing the schema's symbols.</li>
 * <li>FIXED: an empty SpecificFixed subclass annotated with the fixed size.</li>
 * <li>Any other type: nothing beyond what startFile has already written.</li>
 * </ul>
 *
 * The output file is closed in all cases, including when writing fails.
 *
 * @param schema the schema to generate a file for
 * @throws IOException if the file cannot be created or written
 */
private void compile(Schema schema) throws IOException {
startFile(schema.getName(), schema.getNamespace());
try {
switch (schema.getType()) {
case RECORD:
String type = type(schema);
line(0, "public class "+ type
+" extends PersistentBase {");
// schema definition
line(1, "public static final Schema _SCHEMA = Schema.parse(\""
+esc(schema)+"\");");
//field information
// One enum constant per field, carrying the field's position and original name.
line(1, "public static enum Field {");
// i is the position of the field in schema order; the same numbering is used
// for the case labels of the generated get/put methods and by the accessors.
int i=0;
for (Field field : schema.getFields()) {
line(2,toUpperCase(field.name())+"("+(i++)+ ",\"" + field.name() + "\"),");
}
line(2, ";");
line(2, "private int index;");
line(2, "private String name;");
line(2, "Field(int index, String name) {this.index=index;this.name=name;}");
line(2, "public int getIndex() {return index;}");
line(2, "public String getName() {return name;}");
line(2, "public String toString() {return name;}");
line(1, "};");
// Array of all field names, registered with PersistentBase in a static initializer.
StringBuilder builder = new StringBuilder(
"public static final String[] _ALL_FIELDS = {");
for (Field field : schema.getFields()) {
builder.append("\"").append(field.name()).append("\",");
}
builder.append("};");
line(1, builder.toString());
line(1, "static {");
line(2, "PersistentBase.registerFields("+type+".class, _ALL_FIELDS);");
line(1, "}");
// field declarations
for (Field field : schema.getFields()) {
line(1,"private "+unbox(field.schema())+" "+field.name()+";");
}
//constructors
line(1, "public " + type + "() {");
line(2, "this(new StateManagerImpl());");
line(1, "}");
line(1, "public " + type + "(StateManager stateManager) {");
line(2, "super(stateManager);");
// Only array and map fields are initialised in the generated constructor;
// fields of every other type get no initialiser here.
for (Field field : schema.getFields()) {
Schema fieldSchema = field.schema();
switch (fieldSchema.getType()) {
case ARRAY:
String valueType = type(fieldSchema.getElementType());
line(2, field.name()+" = new ListGenericArray<"+valueType+">(getSchema()" +
".getField(\""+field.name()+"\").schema());");
break;
case MAP:
valueType = type(fieldSchema.getValueType());
line(2, field.name()+" = new StatefulHashMap<Utf8,"+valueType+">();");
}
}
line(1, "}");
//newInstance(StateManager)
line(1, "public " + type + " newInstance(StateManager stateManager) {");
line(2, "return new " + type + "(stateManager);" );
line(1, "}");
// schema method
line(1, "public Schema getSchema() { return _SCHEMA; }");
// get method
line(1, "public Object get(int _field) {");
line(2, "switch (_field) {");
i = 0;
for (Field field : schema.getFields()) {
line(2, "case "+(i++)+": return "+field.name()+";");
}
line(2, "default: throw new AvroRuntimeException(\"Bad index\");");
line(2, "}");
line(1, "}");
// put method
// The generated put returns early when isFieldEqual reports no change,
// otherwise marks the field dirty before assigning the cast value.
line(1, "@SuppressWarnings(value=\"unchecked\")");
line(1, "public void put(int _field, Object _value) {");
line(2, "if(isFieldEqual(_field, _value)) return;");
line(2, "getStateManager().setDirty(this, _field);");
line(2, "switch (_field) {");
i = 0;
for (Field field : schema.getFields()) {
line(2, "case "+i+":"+field.name()+" = ("+
type(field.schema())+")_value; break;");
i++;
}
line(2, "default: throw new AvroRuntimeException(\"Bad index\");");
line(2, "}");
line(1, "}");
// java bean style getters and setters
// Fields whose type has no case below (UNION, NULL) get no bean accessors;
// i is still advanced for them so later fields keep the right index.
i = 0;
for (Field field : schema.getFields()) {
String camelKey = camelCasify(field.name());
Schema fieldSchema = field.schema();
switch (fieldSchema.getType()) {
case INT:case LONG:case FLOAT:case DOUBLE:
case BOOLEAN:case BYTES:case STRING: case ENUM: case RECORD:
case FIXED:
// Plain getter/setter pair delegating to the index-based get/put.
String unboxed = unbox(fieldSchema);
String fieldType = type(fieldSchema);
line(1, "public "+unboxed+" get" +camelKey+"() {");
line(2, "return ("+fieldType+") get("+i+");");
line(1, "}");
line(1, "public void set"+camelKey+"("+unboxed+" value) {");
line(2, "put("+i+", value);");
line(1, "}");
break;
case ARRAY:
// Getter for the whole array plus an addTo method that marks the field dirty.
unboxed = unbox(fieldSchema.getElementType());
fieldType = type(fieldSchema.getElementType());
line(1, "public GenericArray<"+fieldType+"> get"+camelKey+"() {");
line(2, "return (GenericArray<"+fieldType+">) get("+i+");");
line(1, "}");
line(1, "public void addTo"+camelKey+"("+unboxed+" element) {");
line(2, "getStateManager().setDirty(this, "+i+");");
line(2, field.name()+".add(element);");
line(1, "}");
break;
case MAP:
// Getter for the whole map plus getFrom / putTo / removeFrom for single entries;
// the mutating methods mark the field dirty.
unboxed = unbox(fieldSchema.getValueType());
fieldType = type(fieldSchema.getValueType());
line(1, "public Map<Utf8, "+fieldType+"> get"+camelKey+"() {");
line(2, "return (Map<Utf8, "+fieldType+">) get("+i+");");
line(1, "}");
line(1, "public "+fieldType+" getFrom"+camelKey+"(Utf8 key) {");
line(2, "if ("+field.name()+" == null) { return null; }");
line(2, "return "+field.name()+".get(key);");
line(1, "}");
line(1, "public void putTo"+camelKey+"(Utf8 key, "+unboxed+" value) {");
line(2, "getStateManager().setDirty(this, "+i+");");
line(2, field.name()+".put(key, value);");
line(1, "}");
line(1, "public "+fieldType+" removeFrom"+camelKey+"(Utf8 key) {");
line(2, "if ("+field.name()+" == null) { return null; }");
line(2, "getStateManager().setDirty(this, "+i+");");
line(2, "return "+field.name()+".remove(key);");
line(1, "}");
}
i++;
}
line(0, "}");
break;
case ENUM:
// All symbols are written on a single line, separated by ", ".
line(0, "public enum "+type(schema)+" { ");
StringBuilder b = new StringBuilder();
int count = 0;
for (String symbol : schema.getEnumSymbols()) {
b.append(symbol);
if (++count < schema.getEnumSymbols().size())
b.append(", ");
}
line(1, b.toString());
line(0, "}");
break;
case FIXED:
line(0, "@FixedSize("+schema.getFixedSize()+")");
line(0, "public class "+type(schema)+" extends SpecificFixed {}");
break;
// No type declaration is generated for the remaining schema types.
case MAP: case ARRAY: case UNION: case STRING: case BYTES:
case INT: case LONG: case FLOAT: case DOUBLE: case BOOLEAN: case NULL:
break;
default: throw new RuntimeException("Unknown type: "+schema);
}
} finally {
out.close();
}
}
/** The null schema, used to recognise two-branch unions of the form "X or null". */
private static final Schema NULL_SCHEMA = Schema.create(Schema.Type.NULL);

/**
 * Returns the Java reference type name used in generated code for a schema.
 * Named types map to their simple name, arrays to GenericArray, maps to
 * Map keyed by Utf8, and primitives to their wrapper classes. A union of exactly
 * two branches, one of them null, maps to the type of the other branch; every
 * other union maps to Object.
 *
 * @throws RuntimeException if the schema type is not one of the handled types
 */
public static String type(Schema schema) {
  switch (schema.getType()) {
    case NULL:
      return "Void";
    case BOOLEAN:
      return "Boolean";
    case INT:
      return "Integer";
    case LONG:
      return "Long";
    case FLOAT:
      return "Float";
    case DOUBLE:
      return "Double";
    case BYTES:
      return "ByteBuffer";
    case STRING:
      return "Utf8";
    case RECORD:
    case ENUM:
    case FIXED:
      return schema.getName();
    case ARRAY:
      return "GenericArray<"+type(schema.getElementType())+">";
    case MAP:
      return "Map<Utf8,"+type(schema.getValueType())+">";
    case UNION: {
      List<Schema> branches = schema.getTypes();
      boolean optional = (branches.size() == 2) && branches.contains(NULL_SCHEMA);
      if (!optional) {
        return "Object";
      }
      // elide the null branch and use the type of the remaining one
      Schema first = branches.get(0);
      return type(first.equals(NULL_SCHEMA) ? branches.get(1) : first);
    }
    default:
      throw new RuntimeException("Unknown type: "+schema);
  }
}
/**
 * Returns the Java type name for a schema, using the primitive keyword for
 * int, long, float, double and boolean and falling back to type(schema)
 * for everything else.
 */
public static String unbox(Schema schema) {
  Schema.Type kind = schema.getType();
  if (kind == Schema.Type.INT) {
    return "int";
  }
  if (kind == Schema.Type.LONG) {
    return "long";
  }
  if (kind == Schema.Type.FLOAT) {
    return "float";
  }
  if (kind == Schema.Type.DOUBLE) {
    return "double";
  }
  if (kind == Schema.Type.BOOLEAN) {
    return "boolean";
  }
  return type(schema);
}
/**
 * Writes one line to the current output file: {@code indent} copies of the
 * indentation unit, then {@code text}, then a newline.
 */
private void line(int indent, String text) throws IOException {
  int remaining = indent;
  while (remaining-- > 0) {
    out.append(" ");
  }
  out.append(text).append("\n");
}
/**
 * Upper-cases the first character of {@code name}; used to derive the file name
 * of a generated type.
 *
 * <p>Uses {@link Character#toUpperCase(char)} so the result is independent of the
 * JVM's default locale (with a Turkish locale {@code String.toUpperCase()} maps
 * 'i' to a dotted capital I, which would produce a wrong file name).
 *
 * @param name the name to capitalise; must not be null
 * @return the capitalised name, or {@code name} itself when it is empty
 */
static String cap(String name) {
  if (name.length() == 0) {
    return name; // nothing to capitalise
  }
  return Character.toUpperCase(name.charAt(0)) + name.substring(1);
}
/**
 * Escapes the string form of {@code o} so it can be embedded between double
 * quotes in generated Java source.
 *
 * <p>Backslashes are escaped first and double quotes second. Doing it in this
 * order matters: schema JSON may itself contain escapes such as {@code \"} or
 * {@code \n}, and their backslashes must survive as literal backslashes in the
 * Java string so the JSON is parsed back unchanged.
 *
 * @param o the object whose {@code toString()} is escaped; must not be null
 * @return the escaped text, without surrounding quotes
 */
private static String esc(Object o) {
  return o.toString()
      .replace("\\", "\\\\")
      .replace("\"", "\\\"");
}
/**
 * Command-line entry point: compiles the schema file given as first argument
 * into the directory given as second argument. Prints a usage message and
 * exits with status 1 when fewer than two arguments are supplied.
 */
public static void main(String[] args) throws Exception {
  boolean enoughArguments = args.length >= 2;
  if (!enoughArguments) {
    System.err.println("Usage: Compiler <schema file> <output dir>");
    System.exit(1);
  }
  File schemaFile = new File(args[0]);
  File outputDir = new File(args[1]);
  compileSchema(schemaFile, outputDir);
}
}