GROOVY-10033, GROOVY-10047: STC: method refs for ctor and static calls
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index b546226..8119ffb 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -2787,6 +2787,9 @@
         if (nExpressions > 0 && expressions.get(0) instanceof MapExpression && params.length > 0) {
             checkNamedParamsAnnotation(params[0], (MapExpression) expressions.get(0));
         }
+        if (visitClosures) {
+            inferMethodReferenceType(receiver, arguments, selectedMethod);
+        }
     }
 
     private void checkNamedParamsAnnotation(final Parameter param, final MapExpression args) {
@@ -3623,8 +3626,6 @@
                     }
                 }
             }
-
-            inferMethodReferenceType(call, receiver, argumentList);
         } finally {
             typeCheckingContext.popEnclosingMethodCall();
             extension.afterMethodCall(call);
@@ -3642,28 +3643,23 @@
         return Closure.OWNER_FIRST;
     }
 
-    private void inferMethodReferenceType(final MethodCallExpression call, final ClassNode receiver, final ArgumentListExpression argumentList) {
-        if (call == null) return;
+    private void inferMethodReferenceType(final ClassNode receiver, final ArgumentListExpression argumentList, final MethodNode selectedMethod) {
         if (receiver == null) return;
         if (argumentList == null) return;
+        if (selectedMethod == null) return;
 
-        List<Expression> argumentExpressionList = argumentList.getExpressions();
-        if (argumentExpressionList == null) return;
-
-        boolean noMethodReferenceParams = argumentExpressionList.stream().noneMatch(e -> e instanceof MethodReferenceExpression);
-        if (noMethodReferenceParams) {
+        List<Expression> argumentExpressions = argumentList.getExpressions();
+        if (argumentExpressions == null || argumentExpressions.stream()
+                .noneMatch(e -> e instanceof MethodReferenceExpression)) {
             return;
         }
 
-        MethodNode selectedMethod = call.getNodeMetaData(DIRECT_METHOD_CALL_TARGET);
-        if (selectedMethod == null) return;
-
         Parameter[] parameters = selectedMethod.getParameters();
 
         List<Integer> methodReferenceParamIndexList = new LinkedList<>();
         List<Expression> newArgumentExpressionList = new LinkedList<>();
-        for (int i = 0, n = argumentExpressionList.size(); i < n; i++) {
-            Expression argumentExpression = argumentExpressionList.get(i);
+        for (int i = 0, n = argumentExpressions.size(); i < n; i += 1) {
+            Expression argumentExpression = argumentExpressions.get(i);
             if (!(argumentExpression instanceof MethodReferenceExpression)) {
                 newArgumentExpressionList.add(argumentExpression);
                 continue;
@@ -3686,7 +3682,7 @@
         for (Integer methodReferenceParamIndex : methodReferenceParamIndexList) {
             LambdaExpression lambdaExpression = (LambdaExpression) newArgumentExpressionList.get(methodReferenceParamIndex);
             ClassNode[] argumentTypes = lambdaExpression.getNodeMetaData(CLOSURE_ARGUMENTS);
-            argumentExpressionList.get(methodReferenceParamIndex).putNodeMetaData(CLOSURE_ARGUMENTS, argumentTypes);
+            argumentExpressions.get(methodReferenceParamIndex).putNodeMetaData(CLOSURE_ARGUMENTS, argumentTypes);
         }
     }
 
diff --git a/src/test/groovy/transform/stc/MethodReferenceTest.groovy b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
index 0825849..e88c3b3 100644
--- a/src/test/groovy/transform/stc/MethodReferenceTest.groovy
+++ b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
@@ -51,7 +51,38 @@
             @CompileStatic
             void p() {
                 def result = [1, 2, 3].stream().map(Integer::toString).collect(Collectors.toList())
-                assert result ['1', '2', '3']
+                assert result == ['1', '2', '3']
+            }
+
+            p()
+        '''
+    }
+
+    @Test // class::instanceMethod -- GROOVY-10047
+    void testFunctionCI3() {
+        assertScript shell, '''
+            import static java.util.stream.Collectors.toMap
+
+            @CompileStatic
+            void p() {
+                List<String> list = ['a','bc','def']
+                Function<String,String> self = str -> str // help for toMap
+                def map = list.stream().collect(toMap(self, String::length))
+                assert map == [a: 1, bc: 2, 'def': 3]
+            }
+
+            p()
+        '''
+
+        assertScript shell, '''
+            import static java.util.stream.Collectors.toMap
+
+            @CompileStatic
+            void p() {
+                List<String> list = ['a','bc','def']
+                // TODO: inference for T in toMap(Function<? super T,...>, Function<? super T,...>)
+                def map = list.stream().collect(toMap(Function.<String>identity(), String::length))
+                assert map == [a: 1, bc: 2, 'def': 3]
             }
 
             p()
@@ -59,7 +90,7 @@
     }
 
     @Test // class::instanceMethod
-    void testFunctionCI3() {
+    void testFunctionCI4() {
         def err = shouldFail shell, '''
             @CompileStatic
             void p() {
@@ -73,7 +104,7 @@
     }
 
     @Test // class::instanceMethod -- GROOVY-9814
-    void testFunctionCI4() {
+    void testFunctionCI5() {
         assertScript shell, '''
             @CompileStatic
             class One { String id }
@@ -291,6 +322,22 @@
         '''
     }
 
+    @Test // class::new -- GROOVY-10033
+    void testFunctionCN2() {
+        assertScript shell, '''
+            class C {
+                C(Function<String,String> f) {
+                }
+            }
+            @CompileStatic
+            void p() {
+                new C(String::toLowerCase)
+            }
+
+            p()
+        '''
+    }
+
     @Test // class::staticMethod
     void testFunctionCS() {
         assertScript shell, '''
@@ -304,9 +351,25 @@
         '''
     }
 
-    @Test // class::staticMethod -- GROOVY-9799
+    @Test // class::staticMethod
     void testFunctionCS2() {
         assertScript shell, '''
+            import static java.util.stream.Collectors.toMap
+
+            @CompileStatic
+            void p() {
+                List<String> list = ['x','y','z']
+                def map = list.stream().collect(toMap(Function.identity(), Collections::singletonList))
+                assert map == [x: ['x'], y: ['y'], z: ['z']]
+            }
+
+            p()
+        '''
+    }
+
+    @Test // class::staticMethod -- GROOVY-9799
+    void testFunctionCS3() {
+        assertScript shell, '''
             class C {
                 String x
             }