blob: f198f16c5120b940b3b3dc908eee79d92a8d3ce0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.tinkerpop.gremlin.process.traversal.dsl;
import com.squareup.javapoet.ArrayTypeName;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import org.apache.tinkerpop.gremlin.process.remote.RemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.Traversal;
import org.apache.tinkerpop.gremlin.process.traversal.TraversalStrategies;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.AddEdgeStartStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.AddVertexStartStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.map.GraphStep;
import org.apache.tinkerpop.gremlin.process.traversal.step.sideEffect.InjectStep;
import org.apache.tinkerpop.gremlin.process.traversal.util.DefaultTraversal;
import org.apache.tinkerpop.gremlin.structure.Edge;
import org.apache.tinkerpop.gremlin.structure.Graph;
import org.apache.tinkerpop.gremlin.structure.Vertex;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
import javax.annotation.processing.ProcessingEnvironment;
import javax.annotation.processing.RoundEnvironment;
import javax.annotation.processing.SupportedAnnotationTypes;
import javax.annotation.processing.SupportedSourceVersion;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.type.TypeVariable;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import javax.tools.Diagnostic;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * A custom Java annotation processor for the {@link GremlinDsl} annotation that helps to generate DSLs classes.
 * <p/>
 * For each public interface annotated with {@link GremlinDsl} a prefix is derived by removing the trailing
 * {@code TraversalDsl} from the interface name (e.g. {@code SocialTraversalDsl} yields {@code Social}), and four
 * source files are written to the package given on the annotation (or, when none is given, the package of the
 * annotated interface):
 * <ul>
 *     <li>{@code <Prefix>Traversal} - an interface extending the annotated DSL interface</li>
 *     <li>{@code Default<Prefix>Traversal} - a {@link DefaultTraversal} implementing that interface</li>
 *     <li>{@code <Prefix>TraversalSource} - a traversal source that spawns {@code Default<Prefix>Traversal}
 *     instances</li>
 *     <li>{@code __} - the anonymous traversal class for the DSL</li>
 * </ul>
 *
 * @author Stephen Mallette (http://stephen.genoprime.com)
 */
@SupportedAnnotationTypes("org.apache.tinkerpop.gremlin.process.traversal.dsl.GremlinDsl")
@SupportedSourceVersion(SourceVersion.RELEASE_8)
public class GremlinDslProcessor extends AbstractProcessor {

    /**
     * Suffix expected at the end of the name of a {@link GremlinDsl} annotated interface. The same number of
     * trailing characters is removed from that name to form the prefix of the generated class names.
     */
    private static final String DSL_SUFFIX = "TraversalDsl";

    private Messager messager;
    private Elements elementUtils;
    private Filer filer;
    private Types typeUtils;

    @Override
    public synchronized void init(final ProcessingEnvironment processingEnv) {
        super.init(processingEnv);
        messager = processingEnv.getMessager();
        elementUtils = processingEnv.getElementUtils();
        filer = processingEnv.getFiler();
        typeUtils = processingEnv.getTypeUtils();
    }

    /**
     * Generates the DSL classes for each element annotated with {@link GremlinDsl}. A failure while processing one
     * element is reported as a compiler error attached to that element and does not prevent the remaining
     * elements from being processed.
     */
    @Override
    public boolean process(final Set<? extends TypeElement> annotations, final RoundEnvironment roundEnv) {
        for (final Element dslElement : roundEnv.getElementsAnnotatedWith(GremlinDsl.class)) {
            try {
                validateDSL(dslElement);
                final Context ctx = new Context((TypeElement) dslElement);

                // creates the "Traversal" interface using an extension of the GraphTraversal class that has the
                // GremlinDsl annotation on it
                generateTraversalInterface(ctx);

                // create the "DefaultTraversal" class which implements the above generated "Traversal" and can then
                // be used by the "TraversalSource" generated below to spawn new traversal instances.
                generateDefaultTraversal(ctx);

                // create the "TraversalSource" class which is used to spawn traversals from a Graph instance. It will
                // spawn instances of the "DefaultTraversal" generated above.
                generateTraversalSource(ctx);

                // create anonymous traversal for DSL
                generateAnonymousTraversal(ctx);
            } catch (Exception ex) {
                // getMessage() may be null (e.g. for a NullPointerException) so fall back to toString() to keep
                // the reported error meaningful
                final String message = null == ex.getMessage() ? ex.toString() : ex.getMessage();
                messager.printMessage(Diagnostic.Kind.ERROR, message, dslElement);
            }
        }

        return true;
    }

    /**
     * Generates the DSL {@code __} class: a private constructor, a {@code start()} method, a static method for each
     * method of the annotated DSL interface and a static method for each method of TinkerPop's {@link __} (other
     * than {@code start}), all returning the DSL traversal type.
     */
    private void generateAnonymousTraversal(final Context ctx) throws IOException {
        final TypeSpec.Builder anonymousClass = TypeSpec.classBuilder("__")
                .addModifiers(Modifier.PUBLIC, Modifier.FINAL);

        // this class is just static methods - it should not be instantiated
        anonymousClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PRIVATE)
                .build());

        // add start() method
        anonymousClass.addMethod(MethodSpec.methodBuilder("start")
                .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
                .addTypeVariable(TypeVariableName.get("A"))
                .addStatement("return new $N<>()", ctx.defaultTraversalClazz)
                .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("A"), TypeVariableName.get("A")))
                .build());

        // process the methods of the GremlinDsl annotated class
        for (ExecutableElement templateMethod : findMethodsOfElement(ctx.annotatedDslType, null)) {
            final Optional<GremlinDsl.AnonymousMethod> methodAnnotation = Optional.ofNullable(templateMethod.getAnnotation(GremlinDsl.AnonymousMethod.class));
            final String[] returnTypeParameters = methodAnnotation.map(GremlinDsl.AnonymousMethod::returnTypeParameters).orElse(new String[0]);
            final String[] methodTypeParameters = methodAnnotation.map(GremlinDsl.AnonymousMethod::methodTypeParameters).orElse(new String[0]);
            final String methodName = templateMethod.getSimpleName().toString();

            // either use the direct return type of the DSL specification or override it with specification from
            // GremlinDsl.AnonymousMethod
            final TypeName returnType = returnTypeParameters.length > 0 ?
                    getOverridenReturnTypeDefinition(ctx.traversalClassName, returnTypeParameters) :
                    getReturnTypeDefinition(ctx.traversalClassName, templateMethod);

            final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(methodName)
                    .addModifiers(Modifier.STATIC, Modifier.PUBLIC)
                    .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()))
                    .returns(returnType);

            // either use the method type parameter specified from the GremlinDsl.AnonymousMethod or just infer them
            // from the DSL specification. "inferring" relies on convention and sometimes doesn't work for all cases.
            final String startGeneric;
            if (methodTypeParameters.length > 0) {
                startGeneric = methodTypeParameters[0];
                Stream.of(methodTypeParameters).map(TypeVariableName::get).forEach(methodToAdd::addTypeVariable);
            } else {
                startGeneric = "S";
                templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

                // might have to deal with an "S" (in __ it's usually an "A") - how to make this less bound to that
                // convention? "S" is declared at most once, and not at all when the method already declares it,
                // since a repeated type variable would make the generated method fail to compile.
                final boolean returnUsesS = getTypeArguments(templateMethod).stream()
                        .anyMatch(rtm -> rtm instanceof TypeVariable && ((TypeVariable) rtm).asElement().getSimpleName().contentEquals("S"));
                final boolean declaresS = templateMethod.getTypeParameters().stream()
                        .anyMatch(tp -> tp.getSimpleName().contentEquals("S"));
                if (returnUsesS && !declaresS)
                    methodToAdd.addTypeVariable(TypeVariableName.get("S"));
            }

            addMethodBody(methodToAdd, templateMethod, "return __.<" + startGeneric + ">start().$L(", ")", methodName);
            anonymousClass.addMethod(methodToAdd.build());
        }

        // use methods from __ to template them into the DSL __
        final Element anonymousTraversal = elementUtils.getTypeElement(__.class.getCanonicalName());
        final Predicate<ExecutableElement> ignore = ee -> ee.getSimpleName().contentEquals("start");
        for (ExecutableElement templateMethod : findMethodsOfElement(anonymousTraversal, ignore)) {
            final String methodName = templateMethod.getSimpleName().toString();

            final TypeName returnType = getReturnTypeDefinition(ctx.traversalClassName, templateMethod);
            final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(methodName)
                    .addModifiers(Modifier.STATIC, Modifier.PUBLIC)
                    .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()))
                    .returns(returnType);

            templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

            if (methodName.equals("__")) {
                // copy the parameters of __.__ and delegate to inject(). NOTE: the generated statement refers to the
                // parameter by the name "starts", so this assumes that is what __.__ calls it.
                for (VariableElement param : templateMethod.getParameters()) {
                    methodToAdd.addParameter(ParameterSpec.get(param));
                }
                methodToAdd.varargs(true);
                methodToAdd.addStatement("return inject(starts)");
            } else {
                if (templateMethod.getTypeParameters().isEmpty()) {
                    // no method generics to start from - use the first type argument of the return type instead
                    final List<? extends TypeMirror> types = getTypeArguments(templateMethod);
                    addMethodBody(methodToAdd, templateMethod, "return __.<$T>start().$L(", ")", types.get(0), methodName);
                } else {
                    addMethodBody(methodToAdd, templateMethod, "return __.<A>start().$L(", ")", methodName);
                }
            }

            anonymousClass.addMethod(methodToAdd.build());
        }

        final JavaFile anonymousJavaFile = JavaFile.builder(ctx.packageName, anonymousClass.build()).build();
        anonymousJavaFile.writeTo(filer);
    }

    /**
     * Generates the "TraversalSource" class. It extends the traversal source configured on {@link GremlinDsl}. It
     * overrides the methods that return a {@link GraphTraversalSource} so they return the DSL traversal source. It
     * overrides the methods declared by a user supplied traversal source so they return the DSL traversal. When
     * default method generation is enabled, it also adds the standard start steps ({@code addV}, {@code addE},
     * {@code V}, {@code E}, {@code inject}).
     */
    private void generateTraversalSource(final Context ctx) throws IOException {
        final TypeElement graphTraversalSourceElement = ctx.traversalSourceDslType;
        final TypeSpec.Builder traversalSourceClass = TypeSpec.classBuilder(ctx.traversalSourceClazz)
                .addModifiers(Modifier.PUBLIC)
                .superclass(TypeName.get(graphTraversalSourceElement.asType()));

        // add the required constructors for instantiation
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addStatement("super($N)", "graph")
                .build());
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addParameter(TraversalStrategies.class, "strategies")
                .addStatement("super($N, $N)", "graph", "strategies")
                .build());
        traversalSourceClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(RemoteConnection.class, "connection")
                .addStatement("super($N)", "connection")
                .build());

        // override methods to return the DSL TraversalSource. find GraphTraversalSource class somewhere in the hierarchy
        final Element tinkerPopsGraphTraversalSource = findClassAsElement(graphTraversalSourceElement, GraphTraversalSource.class);
        final Predicate<ExecutableElement> ignore = e -> !(e.getReturnType().getKind() == TypeKind.DECLARED && ((DeclaredType) e.getReturnType()).asElement().getSimpleName().contentEquals(GraphTraversalSource.class.getSimpleName()));
        for (ExecutableElement elementOfGraphTraversalSource : findMethodsOfElement(tinkerPopsGraphTraversalSource, ignore)) {
            // first copy/override methods that return a GraphTraversalSource so that we can instead return
            // the DSL TraversalSource class.
            traversalSourceClass.addMethod(constructMethod(elementOfGraphTraversalSource, ctx.traversalSourceClassName, "", Modifier.PUBLIC));
        }

        // override methods that return GraphTraversal that come from the user defined extension of GraphTraversalSource
        if (!graphTraversalSourceElement.getSimpleName().contentEquals(GraphTraversalSource.class.getSimpleName())) {
            for (ExecutableElement templateMethod : findMethodsOfElement(graphTraversalSourceElement, null)) {
                final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(templateMethod.getSimpleName().toString())
                        .addModifiers(Modifier.PUBLIC)
                        .addAnnotation(Override.class);

                methodToAdd.addStatement("$T clone = this.clone()", ctx.traversalSourceClassName);
                addMethodBody(methodToAdd, templateMethod, "return new $T(clone, super.$L(", ").asAdmin())",
                        ctx.defaultTraversalClassName, templateMethod.getSimpleName());
                methodToAdd.returns(getReturnTypeDefinition(ctx.traversalClassName, templateMethod));

                traversalSourceClass.addMethod(methodToAdd.build());
            }
        }

        if (ctx.generateDefaultMethods) {
            // override methods that return GraphTraversal
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, (String) null))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Vertex.class), ClassName.get(Vertex.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(String.class, "label")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV, label)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, label))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Vertex.class), ClassName.get(Vertex.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addV")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Traversal.class, "vertexLabelTraversal")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addV, vertexLabelTraversal)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, vertexLabelTraversal))", ctx.traversalClassName, AddVertexStartStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Vertex.class), ClassName.get(Vertex.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addE")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(String.class, "label")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addE, label)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, label))", ctx.traversalClassName, AddEdgeStartStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Edge.class), ClassName.get(Edge.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("addE")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Traversal.class, "edgeLabelTraversal")
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.addE, edgeLabelTraversal)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, edgeLabelTraversal))", ctx.traversalClassName, AddEdgeStartStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Edge.class), ClassName.get(Edge.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("V")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Object[].class, "vertexIds")
                    .varargs(true)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.V, vertexIds)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, $T.class, true, vertexIds))", ctx.traversalClassName, GraphStep.class, Vertex.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Vertex.class), ClassName.get(Vertex.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("E")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(Object[].class, "edgeIds")
                    .varargs(true)
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.E, edgeIds)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, $T.class, true, edgeIds))", ctx.traversalClassName, GraphStep.class, Edge.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, ClassName.get(Edge.class), ClassName.get(Edge.class)))
                    .build());
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("inject")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addParameter(ArrayTypeName.of(TypeVariableName.get("S")), "starts")
                    .varargs(true)
                    .addTypeVariable(TypeVariableName.get("S"))
                    .addStatement("$N clone = this.clone()", ctx.traversalSourceClazz)
                    .addStatement("clone.getBytecode().addStep($T.inject, starts)", GraphTraversal.Symbols.class)
                    .addStatement("$N traversal = new $N(clone)", ctx.defaultTraversalClazz, ctx.defaultTraversalClazz)
                    .addStatement("return ($T) traversal.asAdmin().addStep(new $T(traversal, starts))", ctx.traversalClassName, InjectStep.class)
                    .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("S"), TypeVariableName.get("S")))
                    .build());
            // "__" here is the DSL __ generated into the same package by generateAnonymousTraversal()
            traversalSourceClass.addMethod(MethodSpec.methodBuilder("getAnonymousTraversalClass")
                    .addModifiers(Modifier.PUBLIC)
                    .addAnnotation(Override.class)
                    .addStatement("return $T.of(__.class)", Optional.class)
                    .returns(ParameterizedTypeName.get(Optional.class, Class.class))
                    .build());
        }

        final JavaFile traversalSourceJavaFile = JavaFile.builder(ctx.packageName, traversalSourceClass.build()).build();
        traversalSourceJavaFile.writeTo(filer);
    }

    /**
     * Walks up the superclass chain of {@code element} (following the first direct supertype at each level) until
     * an element whose simple name matches that of {@code clazz} is found.
     *
     * @throws IllegalStateException if the top of the hierarchy is reached without finding a match
     */
    private Element findClassAsElement(final Element element, final Class<?> clazz) {
        Element current = element;
        while (!current.getSimpleName().contentEquals(clazz.getSimpleName())) {
            final List<? extends TypeMirror> supertypes = typeUtils.directSupertypes(current.asType());
            if (supertypes.isEmpty())
                throw new IllegalStateException(String.format("%s does not extend %s",
                        element.getSimpleName(), clazz.getCanonicalName()));
            current = typeUtils.asElement(supertypes.get(0));
        }
        return current;
    }

    /**
     * Generates the "DefaultTraversal" class. It extends {@link DefaultTraversal}, implements the generated
     * "Traversal" interface, and narrows the return types of {@code iterate()}, {@code asAdmin()} and
     * {@code clone()}.
     */
    private void generateDefaultTraversal(final Context ctx) throws IOException {
        final TypeSpec.Builder defaultTraversalClass = TypeSpec.classBuilder(ctx.defaultTraversalClazz)
                .addModifiers(Modifier.PUBLIC)
                .addTypeVariables(Arrays.asList(TypeVariableName.get("S"), TypeVariableName.get("E")))
                .superclass(TypeName.get(elementUtils.getTypeElement(DefaultTraversal.class.getCanonicalName()).asType()))
                .addSuperinterface(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("S"), TypeVariableName.get("E")));

        // add the required constructors for instantiation
        defaultTraversalClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addStatement("super()")
                .build());
        defaultTraversalClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(Graph.class, "graph")
                .addStatement("super($N)", "graph")
                .build());
        defaultTraversalClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(ctx.traversalSourceClassName, "traversalSource")
                .addStatement("super($N)", "traversalSource")
                .build());
        defaultTraversalClass.addMethod(MethodSpec.constructorBuilder()
                .addModifiers(Modifier.PUBLIC)
                .addParameter(ctx.traversalSourceClassName, "traversalSource")
                .addParameter(ctx.graphTraversalAdminClassName, "traversal")
                .addStatement("super($N, $N.asAdmin())", "traversalSource", "traversal")
                .build());

        // add the override
        defaultTraversalClass.addMethod(MethodSpec.methodBuilder("iterate")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .addStatement("return ($T) super.iterate()", ctx.traversalClassName)
                .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("S"), TypeVariableName.get("E")))
                .build());
        defaultTraversalClass.addMethod(MethodSpec.methodBuilder("asAdmin")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .addStatement("return ($T) super.asAdmin()", GraphTraversal.Admin.class)
                .returns(ParameterizedTypeName.get(ctx.graphTraversalAdminClassName, TypeVariableName.get("S"), TypeVariableName.get("E")))
                .build());
        defaultTraversalClass.addMethod(MethodSpec.methodBuilder("clone")
                .addModifiers(Modifier.PUBLIC)
                .addAnnotation(Override.class)
                .addStatement("return ($T) super.clone()", ctx.defaultTraversalClassName)
                .returns(ParameterizedTypeName.get(ctx.defaultTraversalClassName, TypeVariableName.get("S"), TypeVariableName.get("E")))
                .build());

        final JavaFile defaultTraversalJavaFile = JavaFile.builder(ctx.packageName, defaultTraversalClass.build()).build();
        defaultTraversalJavaFile.writeTo(filer);
    }

    /**
     * Generates the "Traversal" interface. It extends the annotated DSL interface. It re-declares the DSL methods
     * and the {@link GraphTraversal} methods (except {@code asAdmin} and {@code iterate}) as default methods
     * returning the DSL traversal type, and adds its own {@code iterate()}.
     */
    private void generateTraversalInterface(final Context ctx) throws IOException {
        final TypeSpec.Builder traversalInterface = TypeSpec.interfaceBuilder(ctx.traversalClazz)
                .addModifiers(Modifier.PUBLIC)
                .addTypeVariables(Arrays.asList(TypeVariableName.get("S"), TypeVariableName.get("E")))
                .addSuperinterface(TypeName.get(ctx.annotatedDslType.asType()));

        // process the methods of the GremlinDsl annotated class
        for (ExecutableElement templateMethod : findMethodsOfElement(ctx.annotatedDslType, null)) {
            traversalInterface.addMethod(constructMethod(templateMethod, ctx.traversalClassName, ctx.dslName,
                    Modifier.PUBLIC, Modifier.DEFAULT));
        }

        // process the methods of GraphTraversal
        final TypeElement graphTraversalElement = elementUtils.getTypeElement(GraphTraversal.class.getCanonicalName());
        final Predicate<ExecutableElement> ignore = e -> e.getSimpleName().contentEquals("asAdmin") || e.getSimpleName().contentEquals("iterate");
        for (ExecutableElement templateMethod : findMethodsOfElement(graphTraversalElement, ignore)) {
            traversalInterface.addMethod(constructMethod(templateMethod, ctx.traversalClassName, ctx.dslName,
                    Modifier.PUBLIC, Modifier.DEFAULT));
        }

        // there are weird things with generics that require this method to be implemented if it isn't already present
        // in the GremlinDsl annotated class extending from GraphTraversal
        traversalInterface.addMethod(MethodSpec.methodBuilder("iterate")
                .addModifiers(Modifier.PUBLIC, Modifier.DEFAULT)
                .addAnnotation(Override.class)
                .addStatement("$T.super.iterate()", ClassName.get(ctx.annotatedDslType))
                .addStatement("return this")
                .returns(ParameterizedTypeName.get(ctx.traversalClassName, TypeVariableName.get("S"), TypeVariableName.get("E")))
                .build());

        final JavaFile traversalJavaFile = JavaFile.builder(ctx.packageName, traversalInterface.build()).build();
        traversalJavaFile.writeTo(filer);
    }

    /**
     * Builds an {@code @Override} of {@code element}. Its name, type variables, parameters and thrown types are
     * copied, its return type is re-parameterized onto {@code returnClazz}, and its body delegates to
     * {@code super} (or {@code parent.super} when {@code parent} is non-empty), casting the result to
     * {@code returnClazz}.
     */
    private MethodSpec constructMethod(final Element element, final ClassName returnClazz, final String parent,
                                       final Modifier... modifiers) {
        final ExecutableElement templateMethod = (ExecutableElement) element;
        final String methodName = templateMethod.getSimpleName().toString();

        final TypeName returnType = getReturnTypeDefinition(returnClazz, templateMethod);
        final MethodSpec.Builder methodToAdd = MethodSpec.methodBuilder(methodName)
                .addModifiers(modifiers)
                .addAnnotation(Override.class)
                .addExceptions(templateMethod.getThrownTypes().stream().map(TypeName::get).collect(Collectors.toList()))
                .returns(returnType);

        templateMethod.getTypeParameters().forEach(tp -> methodToAdd.addTypeVariable(TypeVariableName.get(tp)));

        final String parentCall = parent.isEmpty() ? "" : parent + ".";
        final String body = "return ($T) " + parentCall + "super.$L(";
        addMethodBody(methodToAdd, templateMethod, body, ")", returnClazz, methodName);

        return methodToAdd.build();
    }

    /**
     * Copies the parameters of {@code templateMethod} onto {@code methodToAdd} and adds a single statement formed
     * from {@code startBody}, the comma separated parameter names and {@code endBody}, formatted with
     * {@code statementArgs}. The generated method is made varargs when the last parameter is an array.
     */
    private void addMethodBody(final MethodSpec.Builder methodToAdd, final ExecutableElement templateMethod,
                               final String startBody, final String endBody, final Object... statementArgs) {
        final List<? extends VariableElement> parameters = templateMethod.getParameters();
        final StringBuilder body = new StringBuilder(startBody);

        final int numberOfParams = parameters.size();
        for (int ix = 0; ix < numberOfParams; ix++) {
            final VariableElement param = parameters.get(ix);
            methodToAdd.addParameter(ParameterSpec.get(param));
            body.append(param.getSimpleName());
            if (ix < numberOfParams - 1) {
                body.append(",");
            }
        }

        body.append(endBody);

        // treat a final array as a varargs param
        if (!parameters.isEmpty() && parameters.get(parameters.size() - 1).asType().getKind() == TypeKind.ARRAY)
            methodToAdd.varargs(true);

        methodToAdd.addStatement(body.toString(), statementArgs);
    }

    /**
     * Builds {@code returnClazz} parameterized with type arguments given as strings, as supplied through
     * {@link GremlinDsl.AnonymousMethod#returnTypeParameters()}.
     */
    private TypeName getOverridenReturnTypeDefinition(final ClassName returnClazz, final String[] typeValues) {
        final TypeName[] typeArguments = Stream.of(typeValues).map(this::toOverriddenTypeArgument).toArray(TypeName[]::new);
        return ParameterizedTypeName.get(returnClazz, typeArguments);
    }

    /**
     * Converts one string type value into a {@link TypeName}. A value loadable via {@code Class.forName} becomes a
     * class name. A value of the form {@code "X extends Y"} becomes type variable {@code X} bounded by {@code Y}
     * (a class if loadable, otherwise a type variable). Anything else becomes a type variable of that name.
     */
    private TypeName toOverriddenTypeArgument(final String typeValue) {
        try {
            return ClassName.get(Class.forName(typeValue));
        } catch (ClassNotFoundException cnfe) {
            if (typeValue.contains("extends")) {
                final String[] sides = typeValue.split(" extends ");
                TypeName bound;
                try {
                    bound = ClassName.get(Class.forName(sides[1]));
                } catch (Exception ex) {
                    bound = TypeVariableName.get(sides[1]);
                }
                // TypeVariableName is immutable - the bound has to be supplied when the instance is created
                return TypeVariableName.get(sides[0], bound);
            } else {
                return TypeVariableName.get(typeValue);
            }
        }
    }

    /**
     * Builds {@code returnClazz} parameterized with the type arguments of {@code templateMethod}'s return type, or
     * the plain {@code returnClazz} when that return type has no type arguments.
     */
    private TypeName getReturnTypeDefinition(final ClassName returnClazz, final ExecutableElement templateMethod) {
        final List<? extends TypeMirror> returnTypeArguments = getTypeArguments(templateMethod);

        // build a return type with appropriate generic declarations (if such declarations are present)
        return returnTypeArguments.isEmpty() ?
                returnClazz :
                ParameterizedTypeName.get(returnClazz, returnTypeArguments.stream().map(TypeName::get).toArray(TypeName[]::new));
    }

    /**
     * Checks that {@code dslElement} is a public interface whose name is long enough to have the
     * {@code TraversalDsl} suffix removed. A warning is issued when the name does not actually end with that suffix
     * (compared case-insensitively), as the derived class names would then be unexpected.
     */
    private void validateDSL(final Element dslElement) throws ProcessorException {
        if (dslElement.getKind() != ElementKind.INTERFACE)
            throw new ProcessorException(dslElement, "Only interfaces can be annotated with @%s", GremlinDsl.class.getSimpleName());

        final TypeElement typeElement = (TypeElement) dslElement;
        if (!typeElement.getModifiers().contains(Modifier.PUBLIC))
            throw new ProcessorException(dslElement, "The interface %s is not public.", typeElement.getQualifiedName());

        // the generated class names are derived by removing DSL_SUFFIX.length() characters from the end of the
        // interface name - a shorter name cannot be processed at all
        final String name = typeElement.getSimpleName().toString();
        if (name.length() < DSL_SUFFIX.length())
            throw new ProcessorException(dslElement, "The interface name %s must end with " + DSL_SUFFIX + ".", typeElement.getQualifiedName());

        if (!name.regionMatches(true, name.length() - DSL_SUFFIX.length(), DSL_SUFFIX, 0, DSL_SUFFIX.length()))
            messager.printMessage(Diagnostic.Kind.WARNING, String.format(
                    "The interface name %s does not end with %s - generated class names are derived by removing its last %s characters",
                    typeElement.getQualifiedName(), DSL_SUFFIX, DSL_SUFFIX.length()), dslElement);
    }

    /**
     * Lists the methods declared directly by {@code element} (inherited methods are not included), excluding
     * those matched by {@code ignore}; a {@code null} {@code ignore} excludes nothing.
     */
    private List<ExecutableElement> findMethodsOfElement(final Element element, final Predicate<ExecutableElement> ignore) {
        final Predicate<ExecutableElement> test = null == ignore ? ee -> false : ignore;
        return element.getEnclosedElements().stream()
                .filter(ee -> ee.getKind() == ElementKind.METHOD)
                .map(ee -> (ExecutableElement) ee)
                .filter(ee -> !test.test(ee)).collect(Collectors.toList());
    }

    /**
     * Returns the type arguments of {@code templateMethod}'s return type.
     *
     * @throws IllegalStateException if the return type is not a declared type (e.g. a primitive or a type variable)
     */
    private List<? extends TypeMirror> getTypeArguments(final ExecutableElement templateMethod) {
        final TypeMirror returnType = templateMethod.getReturnType();
        // instanceof rather than a TypeKind check so that not-yet-generated (error) types, which are also
        // DeclaredTypes, are still accepted
        if (!(returnType instanceof DeclaredType))
            throw new IllegalStateException(String.format("Method %s.%s must return a traversal type but returns %s",
                    templateMethod.getEnclosingElement().getSimpleName(), templateMethod.getSimpleName(), returnType));
        return ((DeclaredType) returnType).getTypeArguments();
    }

    /**
     * Names and settings derived from one {@link GremlinDsl} annotated interface that the generators share.
     */
    private class Context {
        private final TypeElement annotatedDslType;
        private final String packageName;
        private final String dslName;
        private final String traversalClazz;
        private final ClassName traversalClassName;
        private final String traversalSourceClazz;
        private final ClassName traversalSourceClassName;
        private final String defaultTraversalClazz;
        private final ClassName defaultTraversalClassName;
        private final ClassName graphTraversalAdminClassName;
        private final TypeElement traversalSourceDslType;
        private final boolean generateDefaultMethods;

        public Context(final TypeElement dslElement) {
            annotatedDslType = dslElement;

            // gets the annotation on the dsl class/interface
            final GremlinDsl gremlinDslAnnotation = dslElement.getAnnotation(GremlinDsl.class);
            generateDefaultMethods = gremlinDslAnnotation.generateDefaultMethods();

            traversalSourceDslType = elementUtils.getTypeElement(gremlinDslAnnotation.traversalSource());
            if (null == traversalSourceDslType)
                throw new IllegalStateException(String.format("Could not find the traversalSource class %s specified on @%s of %s",
                        gremlinDslAnnotation.traversalSource(), GremlinDsl.class.getSimpleName(), dslElement.getQualifiedName()));

            packageName = getPackageName(dslElement, gremlinDslAnnotation);

            // create the Traversal implementation interface
            dslName = dslElement.getSimpleName().toString();
            final String dslPrefix = dslName.substring(0, dslName.length() - DSL_SUFFIX.length()); // chop off "TraversalDsl"
            traversalClazz = dslPrefix + "Traversal";
            traversalClassName = ClassName.get(packageName, traversalClazz);
            traversalSourceClazz = dslPrefix + "TraversalSource";
            traversalSourceClassName = ClassName.get(packageName, traversalSourceClazz);
            defaultTraversalClazz = "Default" + traversalClazz;
            defaultTraversalClassName = ClassName.get(packageName, defaultTraversalClazz);
            graphTraversalAdminClassName = ClassName.get(GraphTraversal.Admin.class);
        }

        /**
         * Uses the package name given on the annotation or, when that is empty, the package of the annotated
         * interface.
         */
        private String getPackageName(final Element dslElement, final GremlinDsl gremlinDslAnnotation) {
            return gremlinDslAnnotation.packageName().isEmpty() ?
                    elementUtils.getPackageOf(dslElement).getQualifiedName().toString() :
                    gremlinDslAnnotation.packageName();
        }
    }
}