/* | |
* Licensed to the Apache Software Foundation (ASF) under one | |
* or more contributor license agreements. See the NOTICE file | |
* distributed with this work for additional information | |
* regarding copyright ownership. The ASF licenses this file | |
* to you under the Apache License, Version 2.0 (the | |
* "License"); you may not use this file except in compliance | |
* with the License. You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, | |
* software distributed under the License is distributed on an | |
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
* KIND, either express or implied. See the License for the | |
* specific language governing permissions and limitations | |
* under the License. | |
*/ | |
package org.codehaus.groovy.transform; | |
import groovy.transform.EqualsAndHashCode; | |
import org.codehaus.groovy.ast.ASTNode; | |
import org.codehaus.groovy.ast.AnnotatedNode; | |
import org.codehaus.groovy.ast.AnnotationNode; | |
import org.codehaus.groovy.ast.ClassHelper; | |
import org.codehaus.groovy.ast.ClassNode; | |
import org.codehaus.groovy.ast.FieldNode; | |
import org.codehaus.groovy.ast.MethodNode; | |
import org.codehaus.groovy.ast.Parameter; | |
import org.codehaus.groovy.ast.PropertyNode; | |
import org.codehaus.groovy.ast.expr.BinaryExpression; | |
import org.codehaus.groovy.ast.expr.CastExpression; | |
import org.codehaus.groovy.ast.expr.Expression; | |
import org.codehaus.groovy.ast.expr.VariableExpression; | |
import org.codehaus.groovy.ast.stmt.BlockStatement; | |
import org.codehaus.groovy.ast.stmt.Statement; | |
import org.codehaus.groovy.ast.tools.GenericsUtils; | |
import org.codehaus.groovy.control.CompilePhase; | |
import org.codehaus.groovy.control.SourceUnit; | |
import org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport; | |
import org.codehaus.groovy.util.HashCodeHelper; | |
import java.util.ArrayList; | |
import java.util.List; | |
import static org.codehaus.groovy.ast.ClassHelper.make; | |
import static org.codehaus.groovy.ast.tools.GeneralUtils.*; | |
import static org.codehaus.groovy.ast.tools.GenericsUtils.makeClassSafe; | |
@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) | |
public class EqualsAndHashCodeASTTransformation extends AbstractASTTransformation { | |
static final Class MY_CLASS = EqualsAndHashCode.class; | |
static final ClassNode MY_TYPE = make(MY_CLASS); | |
static final String MY_TYPE_NAME = "@" + MY_TYPE.getNameWithoutPackage(); | |
private static final ClassNode HASHUTIL_TYPE = make(HashCodeHelper.class); | |
private static final ClassNode OBJECT_TYPE = makeClassSafe(Object.class); | |
public void visit(ASTNode[] nodes, SourceUnit source) { | |
init(nodes, source); | |
AnnotatedNode parent = (AnnotatedNode) nodes[1]; | |
AnnotationNode anno = (AnnotationNode) nodes[0]; | |
if (!MY_TYPE.equals(anno.getClassNode())) return; | |
if (parent instanceof ClassNode) { | |
ClassNode cNode = (ClassNode) parent; | |
if (!checkNotInterface(cNode, MY_TYPE_NAME)) return; | |
boolean callSuper = memberHasValue(anno, "callSuper", true); | |
boolean cacheHashCode = memberHasValue(anno, "cache", true); | |
boolean useCanEqual = !memberHasValue(anno, "useCanEqual", false); | |
if (callSuper && cNode.getSuperClass().getName().equals("java.lang.Object")) { | |
addError("Error during " + MY_TYPE_NAME + " processing: callSuper=true but '" + cNode.getName() + "' has no super class.", anno); | |
} | |
boolean includeFields = memberHasValue(anno, "includeFields", true); | |
List<String> excludes = getMemberStringList(anno, "excludes"); | |
List<String> includes = getMemberStringList(anno, "includes"); | |
final boolean allNames = memberHasValue(anno, "allNames", true); | |
if (!checkIncludeExcludeUndefinedAware(anno, excludes, includes, MY_TYPE_NAME)) return; | |
if (!checkPropertyList(cNode, includes, "includes", anno, MY_TYPE_NAME, includeFields)) return; | |
if (!checkPropertyList(cNode, excludes, "excludes", anno, MY_TYPE_NAME, includeFields)) return; | |
createHashCode(cNode, cacheHashCode, includeFields, callSuper, excludes, includes, allNames); | |
createEquals(cNode, includeFields, callSuper, useCanEqual, excludes, includes, allNames); | |
} | |
} | |
public static void createHashCode(ClassNode cNode, boolean cacheResult, boolean includeFields, boolean callSuper, List<String> excludes, List<String> includes) { | |
createHashCode(cNode, cacheResult, includeFields, callSuper, excludes, includes, false); | |
} | |
public static void createHashCode(ClassNode cNode, boolean cacheResult, boolean includeFields, boolean callSuper, List<String> excludes, List<String> includes, boolean allNames) { | |
// make a public method if none exists otherwise try a private method with leading underscore | |
boolean hasExistingHashCode = hasDeclaredMethod(cNode, "hashCode", 0); | |
if (hasExistingHashCode && hasDeclaredMethod(cNode, "_hashCode", 0)) return; | |
final BlockStatement body = new BlockStatement(); | |
// TODO use pList and fList | |
if (cacheResult) { | |
final FieldNode hashField = cNode.addField("$hash$code", ACC_PRIVATE | ACC_SYNTHETIC, ClassHelper.int_TYPE, null); | |
final Expression hash = varX(hashField); | |
body.addStatement(ifS( | |
isZeroX(hash), | |
calculateHashStatements(cNode, hash, includeFields, callSuper, excludes, includes, allNames) | |
)); | |
body.addStatement(returnS(hash)); | |
} else { | |
body.addStatement(calculateHashStatements(cNode, null, includeFields, callSuper, excludes, includes, allNames)); | |
} | |
cNode.addMethod(new MethodNode( | |
hasExistingHashCode ? "_hashCode" : "hashCode", | |
hasExistingHashCode ? ACC_PRIVATE : ACC_PUBLIC, | |
ClassHelper.int_TYPE, | |
Parameter.EMPTY_ARRAY, | |
ClassNode.EMPTY_ARRAY, | |
body)); | |
} | |
private static Statement calculateHashStatements(ClassNode cNode, Expression hash, boolean includeFields, boolean callSuper, List<String> excludes, List<String> includes, boolean allNames) { | |
final List<PropertyNode> pList = getInstanceProperties(cNode); | |
final List<FieldNode> fList = new ArrayList<FieldNode>(); | |
if (includeFields) { | |
fList.addAll(getInstanceNonPropertyFields(cNode)); | |
} | |
final BlockStatement body = new BlockStatement(); | |
// def _result = HashCodeHelper.initHash() | |
final Expression result = varX("_result"); | |
body.addStatement(declS(result, callX(HASHUTIL_TYPE, "initHash"))); | |
for (PropertyNode pNode : pList) { | |
if (shouldSkip(pNode.getName(), excludes, includes, allNames)) continue; | |
// _result = HashCodeHelper.updateHash(_result, getProperty()) // plus self-reference checking | |
Expression getter = getterThisX(cNode, pNode); | |
final Expression current = callX(HASHUTIL_TYPE, "updateHash", args(result, getter)); | |
body.addStatement(ifS( | |
notX(sameX(getter, varX("this"))), | |
assignS(result, current))); | |
} | |
for (FieldNode fNode : fList) { | |
if (shouldSkip(fNode.getName(), excludes, includes, allNames)) continue; | |
// _result = HashCodeHelper.updateHash(_result, field) // plus self-reference checking | |
final Expression fieldExpr = varX(fNode); | |
final Expression current = callX(HASHUTIL_TYPE, "updateHash", args(result, fieldExpr)); | |
body.addStatement(ifS( | |
notX(sameX(fieldExpr, varX("this"))), | |
assignS(result, current))); | |
} | |
if (callSuper) { | |
// _result = HashCodeHelper.updateHash(_result, super.hashCode()) | |
final Expression current = callX(HASHUTIL_TYPE, "updateHash", args(result, callSuperX("hashCode"))); | |
body.addStatement(assignS(result, current)); | |
} | |
// $hash$code = _result | |
if (hash != null) { | |
body.addStatement(assignS(hash, result)); | |
} else { | |
body.addStatement(returnS(result)); | |
} | |
return body; | |
} | |
private static void createCanEqual(ClassNode cNode) { | |
boolean hasExistingCanEqual = hasDeclaredMethod(cNode, "canEqual", 1); | |
if (hasExistingCanEqual && hasDeclaredMethod(cNode, "_canEqual", 1)) return; | |
final BlockStatement body = new BlockStatement(); | |
VariableExpression other = varX("other"); | |
body.addStatement(returnS(isInstanceOfX(other, GenericsUtils.nonGeneric(cNode)))); | |
cNode.addMethod(new MethodNode( | |
hasExistingCanEqual ? "_canEqual" : "canEqual", | |
hasExistingCanEqual ? ACC_PRIVATE : ACC_PUBLIC, | |
ClassHelper.boolean_TYPE, | |
params(param(OBJECT_TYPE, other.getName())), | |
ClassNode.EMPTY_ARRAY, | |
body)); | |
} | |
public static void createEquals(ClassNode cNode, boolean includeFields, boolean callSuper, boolean useCanEqual, List<String> excludes, List<String> includes) { | |
createEquals(cNode, includeFields, callSuper, useCanEqual, excludes, includes, false); | |
} | |
public static void createEquals(ClassNode cNode, boolean includeFields, boolean callSuper, boolean useCanEqual, List<String> excludes, List<String> includes, boolean allNames) { | |
if (useCanEqual) createCanEqual(cNode); | |
// make a public method if none exists otherwise try a private method with leading underscore | |
boolean hasExistingEquals = hasDeclaredMethod(cNode, "equals", 1); | |
if (hasExistingEquals && hasDeclaredMethod(cNode, "_equals", 1)) return; | |
final BlockStatement body = new BlockStatement(); | |
VariableExpression other = varX("other"); | |
// some short circuit cases for efficiency | |
body.addStatement(ifS(equalsNullX(other), returnS(constX(Boolean.FALSE, true)))); | |
body.addStatement(ifS(sameX(varX("this"), other), returnS(constX(Boolean.TRUE, true)))); | |
if (useCanEqual) { | |
body.addStatement(ifS(notX(isInstanceOfX(other, GenericsUtils.nonGeneric(cNode))), returnS(constX(Boolean.FALSE,true)))); | |
} else { | |
body.addStatement(ifS(notX(hasClassX(other, GenericsUtils.nonGeneric(cNode))), returnS(constX(Boolean.FALSE,true)))); | |
} | |
VariableExpression otherTyped = varX("otherTyped", GenericsUtils.nonGeneric(cNode)); | |
CastExpression castExpression = new CastExpression(GenericsUtils.nonGeneric(cNode), other); | |
castExpression.setStrict(true); | |
body.addStatement(declS(otherTyped, castExpression)); | |
if (useCanEqual) { | |
body.addStatement(ifS(notX(callX(otherTyped, "canEqual", varX("this"))), returnS(constX(Boolean.FALSE,true)))); | |
} | |
List<PropertyNode> pList = getInstanceProperties(cNode); | |
for (PropertyNode pNode : pList) { | |
if (shouldSkip(pNode.getName(), excludes, includes, allNames)) continue; | |
boolean canBeSelf = StaticTypeCheckingSupport.implementsInterfaceOrIsSubclassOf( | |
pNode.getOriginType(), cNode | |
); | |
if (!canBeSelf) { | |
body.addStatement(ifS(notX(hasEqualPropertyX(otherTyped.getOriginType(), pNode, otherTyped)), returnS(constX(Boolean.FALSE, true)))); | |
} else { | |
body.addStatement( | |
ifS(notX(hasSamePropertyX(pNode, otherTyped)), | |
ifElseS(differentSelfRecursivePropertyX(pNode, otherTyped), | |
returnS(constX(Boolean.FALSE, true)), | |
ifS(notX(bothSelfRecursivePropertyX(pNode, otherTyped)), | |
ifS(notX(hasEqualPropertyX(otherTyped.getOriginType(), pNode, otherTyped)), returnS(constX(Boolean.FALSE, true)))) | |
) | |
) | |
); | |
} | |
} | |
List<FieldNode> fList = new ArrayList<FieldNode>(); | |
if (includeFields) { | |
fList.addAll(getInstanceNonPropertyFields(cNode)); | |
} | |
for (FieldNode fNode : fList) { | |
if (shouldSkip(fNode.getName(), excludes, includes, allNames)) continue; | |
body.addStatement( | |
ifS(notX(hasSameFieldX(fNode, otherTyped)), | |
ifElseS(differentSelfRecursiveFieldX(fNode, otherTyped), | |
returnS(constX(Boolean.FALSE,true)), | |
ifS(notX(bothSelfRecursiveFieldX(fNode, otherTyped)), | |
ifS(notX(hasEqualFieldX(fNode, otherTyped)), returnS(constX(Boolean.FALSE,true))))) | |
)); | |
} | |
if (callSuper) { | |
body.addStatement(ifS( | |
notX(isTrueX(callSuperX("equals", other))), | |
returnS(constX(Boolean.FALSE,true)) | |
)); | |
} | |
// default | |
body.addStatement(returnS(constX(Boolean.TRUE,true))); | |
cNode.addMethod(new MethodNode( | |
hasExistingEquals ? "_equals" : "equals", | |
hasExistingEquals ? ACC_PRIVATE : ACC_PUBLIC, | |
ClassHelper.boolean_TYPE, | |
params(param(OBJECT_TYPE, other.getName())), | |
ClassNode.EMPTY_ARRAY, | |
body)); | |
} | |
private static BinaryExpression differentSelfRecursivePropertyX(PropertyNode pNode, Expression other) { | |
String getterName = getGetterName(pNode); | |
Expression selfGetter = callThisX(getterName); | |
Expression otherGetter = callX(other, getterName); | |
return orX( | |
andX(sameX(selfGetter, varX("this")), notX(sameX(otherGetter, other))), | |
andX(notX(sameX(selfGetter, varX("this"))), sameX(otherGetter, other)) | |
); | |
} | |
private static BinaryExpression bothSelfRecursivePropertyX(PropertyNode pNode, Expression other) { | |
String getterName = getGetterName(pNode); | |
Expression selfGetter = callThisX(getterName); | |
Expression otherGetter = callX(other, getterName); | |
return andX( | |
sameX(selfGetter, varX("this")), | |
sameX(otherGetter, other) | |
); | |
} | |
private static BinaryExpression differentSelfRecursiveFieldX(FieldNode fNode, Expression other) { | |
final Expression fieldExpr = varX(fNode); | |
final Expression otherExpr = propX(other, fNode.getName()); | |
return orX( | |
andX(sameX(fieldExpr, varX("this")), notX(sameX(otherExpr, other))), | |
andX(notX(sameX(fieldExpr, varX("this"))), sameX(otherExpr, other)) | |
); | |
} | |
private static BinaryExpression bothSelfRecursiveFieldX(FieldNode fNode, Expression other) { | |
final Expression fieldExpr = varX(fNode); | |
final Expression otherExpr = propX(other, fNode.getName()); | |
return andX( | |
sameX(fieldExpr, varX("this")), | |
sameX(otherExpr, other) | |
); | |
} | |
} |