blob: f8f3bf48095af10b14046d70061ad042844c2af8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.codehaus.groovy.classgen;
import org.codehaus.groovy.ast.MethodNode;
import org.codehaus.groovy.ast.VariableScope;
import org.codehaus.groovy.ast.expr.Expression;
import org.codehaus.groovy.ast.stmt.BlockStatement;
import org.codehaus.groovy.ast.stmt.BreakStatement;
import org.codehaus.groovy.ast.stmt.CaseStatement;
import org.codehaus.groovy.ast.stmt.CatchStatement;
import org.codehaus.groovy.ast.stmt.EmptyStatement;
import org.codehaus.groovy.ast.stmt.ExpressionStatement;
import org.codehaus.groovy.ast.stmt.IfStatement;
import org.codehaus.groovy.ast.stmt.ReturnStatement;
import org.codehaus.groovy.ast.stmt.Statement;
import org.codehaus.groovy.ast.stmt.SwitchStatement;
import org.codehaus.groovy.ast.stmt.SynchronizedStatement;
import org.codehaus.groovy.ast.stmt.ThrowStatement;
import org.codehaus.groovy.ast.stmt.TryCatchStatement;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static org.codehaus.groovy.ast.tools.GeneralUtils.nullX;
/**
* Utility class to add return statements.
* <p>
* Extracted from Verifier as it can be useful for some AST transformations.
*/
public class ReturnAdder {
@FunctionalInterface
public interface ReturnStatementListener {
/**
* Implement this method in order to be notified whenever a return statement is generated.
*/
void returnStatementAdded(ReturnStatement returnStatement);
}
private static final ReturnStatementListener DEFAULT_LISTENER = returnStatement -> {
};
private final ReturnStatementListener listener;
/**
* If set to 'true', then returns are effectively added. This is useful whenever you just want
* to check what returns are produced without eventually adding them.
*/
private final boolean doAdd;
public ReturnAdder() {
this.listener = DEFAULT_LISTENER;
this.doAdd = true;
}
public ReturnAdder(final ReturnStatementListener listener) {
this.listener = Objects.requireNonNull(listener);
this.doAdd = false;
}
/**
* @deprecated Use {@link #visitMethod(MethodNode)} instead.
*/
@Deprecated
public static void addReturnIfNeeded(final MethodNode node) {
new ReturnAdder().visitMethod(node);
}
/**
* Adds return statements to given method whenever an implicit return is detected.
*/
public void visitMethod(final MethodNode node) {
if (!node.isVoidMethod()) {
Statement code = node.getCode();
if (code != null) { // happens with @interface methods
code = addReturnsIfNeeded(code, node.getVariableScope());
if (doAdd) node.setCode(code);
}
}
}
private Statement addReturnsIfNeeded(final Statement statement, final VariableScope scope) {
if (statement instanceof ReturnStatement || statement instanceof ThrowStatement
|| statement instanceof BytecodeSequence) {
return statement;
}
if (statement instanceof EmptyStatement || statement == null) {
ReturnStatement returnStatement = new ReturnStatement(nullX());
listener.returnStatementAdded(returnStatement);
return returnStatement;
}
if (statement instanceof ExpressionStatement) {
Expression expression = ((ExpressionStatement) statement).getExpression();
ReturnStatement returnStatement = new ReturnStatement(expression);
returnStatement.copyStatementLabels(statement);
returnStatement.setSourcePosition(statement);
listener.returnStatementAdded(returnStatement);
return returnStatement;
}
if (statement instanceof SynchronizedStatement) {
SynchronizedStatement syncStatement = (SynchronizedStatement) statement;
Statement code = addReturnsIfNeeded(syncStatement.getCode(), scope);
if (doAdd) syncStatement.setCode(code);
return syncStatement;
}
if (statement instanceof IfStatement) {
IfStatement ifElseStatement = (IfStatement) statement;
Statement ifBlock = addReturnsIfNeeded(ifElseStatement.getIfBlock(), scope);
Statement elseBlock = addReturnsIfNeeded(ifElseStatement.getElseBlock(), scope);
if (doAdd) {
ifElseStatement.setIfBlock(ifBlock);
ifElseStatement.setElseBlock(elseBlock);
}
return ifElseStatement;
}
if (statement instanceof SwitchStatement) {
SwitchStatement switchStatement = (SwitchStatement) statement;
for (CaseStatement caseStatement : switchStatement.getCaseStatements()) {
Statement code = adjustSwitchCaseCode(caseStatement.getCode(), scope, false);
if (doAdd) caseStatement.setCode(code);
}
Statement defaultStatement = adjustSwitchCaseCode(switchStatement.getDefaultStatement(), scope, true);
if (doAdd) switchStatement.setDefaultStatement(defaultStatement);
return switchStatement;
}
if (statement instanceof TryCatchStatement) {
TryCatchStatement tryCatchFinally = (TryCatchStatement) statement;
boolean[] missesReturn = new boolean[1];
new ReturnAdder(returnStatement -> missesReturn[0] = true)
.addReturnsIfNeeded(tryCatchFinally.getFinallyStatement(), scope);
boolean hasFinally = !(tryCatchFinally.getFinallyStatement() instanceof EmptyStatement);
// if there is no missing return in the finally block and the block exists
// there is nothing to do
if (hasFinally && !missesReturn[0]) return tryCatchFinally;
// add returns to try and catch blocks
Statement tryStatement = addReturnsIfNeeded(tryCatchFinally.getTryStatement(), scope);
if (doAdd) tryCatchFinally.setTryStatement(tryStatement);
for (CatchStatement catchStatement : tryCatchFinally.getCatchStatements()) {
Statement code = addReturnsIfNeeded(catchStatement.getCode(), scope);
if (doAdd) catchStatement.setCode(code);
}
return tryCatchFinally;
}
if (statement instanceof BlockStatement) {
BlockStatement blockStatement = (BlockStatement) statement;
if (blockStatement.isEmpty()) {
ReturnStatement returnStatement = new ReturnStatement(nullX());
returnStatement.copyStatementLabels(blockStatement);
returnStatement.setSourcePosition(blockStatement);
listener.returnStatementAdded(returnStatement);
return returnStatement;
} else {
List<Statement> statements = blockStatement.getStatements();
int lastIndex = statements.size() - 1;
Statement last = addReturnsIfNeeded(statements.get(lastIndex), blockStatement.getVariableScope());
if (doAdd) statements.set(lastIndex, last);
if (!returns(last)) {
ReturnStatement returnStatement = new ReturnStatement(nullX());
listener.returnStatementAdded(returnStatement);
if (doAdd) statements.add(returnStatement);
}
return blockStatement;
}
}
List<Statement> statements = new ArrayList<>(2);
statements.add(statement);
ReturnStatement returnStatement = new ReturnStatement(nullX());
listener.returnStatementAdded(returnStatement);
statements.add(returnStatement);
BlockStatement blockStatement = new BlockStatement(statements, new VariableScope(scope));
blockStatement.setSourcePosition(statement);
return blockStatement;
}
private Statement adjustSwitchCaseCode(final Statement statement, final VariableScope scope, final boolean defaultCase) {
if (statement instanceof BlockStatement) {
List<Statement> statements = ((BlockStatement) statement).getStatements();
if (!statements.isEmpty()) {
int lastIndex = statements.size() - 1;
Statement last = statements.get(lastIndex);
if (last instanceof BreakStatement) {
if (doAdd) {
statements.remove(lastIndex);
return addReturnsIfNeeded(statement, scope);
} else {
BlockStatement newBlock = new BlockStatement();
for (int i = 0; i < lastIndex; i += 1) {
newBlock.addStatement(statements.get(i));
}
return addReturnsIfNeeded(newBlock, scope);
}
} else if (defaultCase) {
return addReturnsIfNeeded(statement, scope);
}
}
}
return statement;
}
private static boolean returns(final Statement statement) {
return statement instanceof ReturnStatement
|| statement instanceof BlockStatement
|| statement instanceof IfStatement
|| statement instanceof ExpressionStatement
|| statement instanceof EmptyStatement
|| statement instanceof TryCatchStatement
|| statement instanceof ThrowStatement
|| statement instanceof SynchronizedStatement
|| statement instanceof BytecodeSequence;
}
}