GROOVY-10033, GROOVY-10047: STC: method refs for ctor and static calls
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index b546226..8119ffb 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -2787,6 +2787,9 @@
if (nExpressions > 0 && expressions.get(0) instanceof MapExpression && params.length > 0) {
checkNamedParamsAnnotation(params[0], (MapExpression) expressions.get(0));
}
+ if (visitClosures) {
+ inferMethodReferenceType(receiver, arguments, selectedMethod);
+ }
}
private void checkNamedParamsAnnotation(final Parameter param, final MapExpression args) {
@@ -3623,8 +3626,6 @@
}
}
}
-
- inferMethodReferenceType(call, receiver, argumentList);
} finally {
typeCheckingContext.popEnclosingMethodCall();
extension.afterMethodCall(call);
@@ -3642,28 +3643,23 @@
return Closure.OWNER_FIRST;
}
- private void inferMethodReferenceType(final MethodCallExpression call, final ClassNode receiver, final ArgumentListExpression argumentList) {
- if (call == null) return;
+ private void inferMethodReferenceType(final ClassNode receiver, final ArgumentListExpression argumentList, final MethodNode selectedMethod) {
if (receiver == null) return;
if (argumentList == null) return;
+ if (selectedMethod == null) return;
- List<Expression> argumentExpressionList = argumentList.getExpressions();
- if (argumentExpressionList == null) return;
-
- boolean noMethodReferenceParams = argumentExpressionList.stream().noneMatch(e -> e instanceof MethodReferenceExpression);
- if (noMethodReferenceParams) {
+ List<Expression> argumentExpressions = argumentList.getExpressions();
+ if (argumentExpressions == null || argumentExpressions.stream()
+ .noneMatch(e -> e instanceof MethodReferenceExpression)) {
return;
}
- MethodNode selectedMethod = call.getNodeMetaData(DIRECT_METHOD_CALL_TARGET);
- if (selectedMethod == null) return;
-
Parameter[] parameters = selectedMethod.getParameters();
List<Integer> methodReferenceParamIndexList = new LinkedList<>();
List<Expression> newArgumentExpressionList = new LinkedList<>();
- for (int i = 0, n = argumentExpressionList.size(); i < n; i++) {
- Expression argumentExpression = argumentExpressionList.get(i);
+ for (int i = 0, n = argumentExpressions.size(); i < n; i += 1) {
+ Expression argumentExpression = argumentExpressions.get(i);
if (!(argumentExpression instanceof MethodReferenceExpression)) {
newArgumentExpressionList.add(argumentExpression);
continue;
@@ -3686,7 +3682,7 @@
for (Integer methodReferenceParamIndex : methodReferenceParamIndexList) {
LambdaExpression lambdaExpression = (LambdaExpression) newArgumentExpressionList.get(methodReferenceParamIndex);
ClassNode[] argumentTypes = lambdaExpression.getNodeMetaData(CLOSURE_ARGUMENTS);
- argumentExpressionList.get(methodReferenceParamIndex).putNodeMetaData(CLOSURE_ARGUMENTS, argumentTypes);
+ argumentExpressions.get(methodReferenceParamIndex).putNodeMetaData(CLOSURE_ARGUMENTS, argumentTypes);
}
}
diff --git a/src/test/groovy/transform/stc/MethodReferenceTest.groovy b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
index 0825849..e88c3b3 100644
--- a/src/test/groovy/transform/stc/MethodReferenceTest.groovy
+++ b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
@@ -51,7 +51,38 @@
@CompileStatic
void p() {
def result = [1, 2, 3].stream().map(Integer::toString).collect(Collectors.toList())
- assert result ['1', '2', '3']
+ assert result == ['1', '2', '3']
+ }
+
+ p()
+ '''
+ }
+
+ @Test // class::instanceMethod -- GROOVY-10047
+ void testFunctionCI3() {
+ assertScript shell, '''
+ import static java.util.stream.Collectors.toMap
+
+ @CompileStatic
+ void p() {
+ List<String> list = ['a','bc','def']
+ Function<String,String> self = str -> str // help for toMap
+ def map = list.stream().collect(toMap(self, String::length))
+ assert map == [a: 1, bc: 2, 'def': 3]
+ }
+
+ p()
+ '''
+
+ assertScript shell, '''
+ import static java.util.stream.Collectors.toMap
+
+ @CompileStatic
+ void p() {
+ List<String> list = ['a','bc','def']
+ // TODO: inference for T in toMap(Function<? super T,...>, Function<? super T,...>)
+ def map = list.stream().collect(toMap(Function.<String>identity(), String::length))
+ assert map == [a: 1, bc: 2, 'def': 3]
}
p()
@@ -59,7 +90,7 @@
}
@Test // class::instanceMethod
- void testFunctionCI3() {
+ void testFunctionCI4() {
def err = shouldFail shell, '''
@CompileStatic
void p() {
@@ -73,7 +104,7 @@
}
@Test // class::instanceMethod -- GROOVY-9814
- void testFunctionCI4() {
+ void testFunctionCI5() {
assertScript shell, '''
@CompileStatic
class One { String id }
@@ -291,6 +322,22 @@
'''
}
+ @Test // class::new -- GROOVY-10033
+ void testFunctionCN2() {
+ assertScript shell, '''
+ class C {
+ C(Function<String,String> f) {
+ }
+ }
+ @CompileStatic
+ void p() {
+ new C(String::toLowerCase)
+ }
+
+ p()
+ '''
+ }
+
@Test // class::staticMethod
void testFunctionCS() {
assertScript shell, '''
@@ -304,9 +351,25 @@
'''
}
- @Test // class::staticMethod -- GROOVY-9799
+ @Test // class::staticMethod
void testFunctionCS2() {
assertScript shell, '''
+ import static java.util.stream.Collectors.toMap
+
+ @CompileStatic
+ void p() {
+ List<String> list = ['x','y','z']
+ def map = list.stream().collect(toMap(Function.identity(), Collections::singletonList))
+ assert map == [x: ['x'], y: ['y'], z: ['z']]
+ }
+
+ p()
+ '''
+ }
+
+ @Test // class::staticMethod -- GROOVY-9799
+ void testFunctionCS3() {
+ assertScript shell, '''
class C {
String x
}