calculate trainee type at instantiation; support subclassing with self-referencing BaseTrainer class; allow a backing invoker to be specified for a stub; move thenStub() to WhenObject because returning a stub doesn't make sense for any of the other fluent constructs; raw type warning; add special WhenClass construct to be properly documented in a future commit

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/proxy/branches/version-2.0-work@1515495 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/stub/src/main/java/org/apache/commons/proxy2/stub/BaseTrainer.java b/stub/src/main/java/org/apache/commons/proxy2/stub/BaseTrainer.java
new file mode 100644
index 0000000..e9acc30
--- /dev/null
+++ b/stub/src/main/java/org/apache/commons/proxy2/stub/BaseTrainer.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.proxy2.stub;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.Validate;
+import org.apache.commons.lang3.reflect.TypeUtils;
+import org.apache.commons.proxy2.ObjectProvider;
+import org.apache.commons.proxy2.ProxyUtils;
+import org.apache.commons.proxy2.interceptor.InterceptorUtils;
+import org.apache.commons.proxy2.interceptor.matcher.ArgumentMatcher;
+import org.apache.commons.proxy2.interceptor.matcher.argument.ArgumentMatcherUtils;
+
+public abstract class BaseTrainer<S extends BaseTrainer<S, T>, T>
+{
+//----------------------------------------------------------------------------------------------------------------------
+// Fields
+//----------------------------------------------------------------------------------------------------------------------
+    public final Class<T> traineeType;
+
+//----------------------------------------------------------------------------------------------------------------------
+// Constructors
+//----------------------------------------------------------------------------------------------------------------------
+
+    /**
+     * Create a new {@link BaseTrainer} instance. This constructor should only be called
+     * by classes that explicitly assign the T parameter in the class definition.
+     * This should include basically any runtime-usable class.
+     */
+    protected BaseTrainer()
+    {
+        this(null);
+    }
+
+    protected BaseTrainer(Class<T> traineeType)
+    {
+        super();
+        if (traineeType != null)
+        {
+            this.traineeType = traineeType;
+            return;
+        }
+        @SuppressWarnings("unchecked")
+        final Class<T> resolvedVariable =
+            (Class<T>) TypeUtils.getRawType(BaseTrainer.class.getTypeParameters()[1], getClass());
+        Validate.isTrue(resolvedVariable != null, "Trainee type was not specified and could not be calculated for %s",
+            getClass());
+        this.traineeType = resolvedVariable;
+    }
+
+//----------------------------------------------------------------------------------------------------------------------
+// Abstract Methods
+//----------------------------------------------------------------------------------------------------------------------
+
+    protected abstract void train(T trainee);
+
+//----------------------------------------------------------------------------------------------------------------------
+// Other Methods
+//----------------------------------------------------------------------------------------------------------------------
+
+    protected <R> R any(Class<R> type)
+    {
+        record(ArgumentMatcherUtils.any());
+        return null;
+    }
+
+    private void record(ArgumentMatcher<?> matcher)
+    {
+        trainingContext().record(matcher);
+    }
+
+    protected <R> R eq(R value)
+    {
+        record(ArgumentMatcherUtils.eq(value));
+        return value;
+    }
+
+    protected <R> R isInstance(Class<R> type)
+    {
+        record(ArgumentMatcherUtils.isA(type));
+        return ProxyUtils.nullValue(type);
+    }
+
+    protected void thenThrow(Exception e)
+    {
+        trainingContext().then(InterceptorUtils.throwing(e));
+    }
+
+    protected void thenThrow(ObjectProvider<? extends Exception> provider)
+    {
+        trainingContext().then(InterceptorUtils.throwing(provider));
+    }
+
+    protected TrainingContext trainingContext()
+    {
+        return TrainingContext.getCurrent();
+    }
+
+    protected <R> WhenObject<R> when(R expression)
+    {
+        return new WhenObject<R>();
+    }
+
+    protected WhenClass when(Class<?> expression)
+    {
+        return new WhenClass();
+    }
+
+    protected WhenByteArray when(byte[] expression)
+    {
+        return new WhenByteArray();
+    }
+
+    protected WhenBooleanArray when(boolean[] expression)
+    {
+        return new WhenBooleanArray();
+    }
+
+    protected WhenIntArray when(int[] expression)
+    {
+        return new WhenIntArray();
+    }
+
+    protected WhenShortArray when(short[] expresssion)
+    {
+        return new WhenShortArray();
+    }
+
+    protected WhenLongArray when(long[] expression)
+    {
+        return new WhenLongArray();
+    }
+
+    protected WhenFloatArray when(float[] expression)
+    {
+        return new WhenFloatArray();
+    }
+
+    protected WhenDoubleArray when(double[] expression)
+    {
+        return new WhenDoubleArray();
+    }
+
+    protected <R> WhenObjectArray<R> when(R[] expression)
+    {
+        return new WhenObjectArray<R>();
+    }
+
+    protected WhenCharArray when(char[] expression)
+    {
+        return new WhenCharArray();
+    }
+
+    @SuppressWarnings("unchecked")
+    protected S self()
+    {
+        return (S) this;
+    }
+
+//----------------------------------------------------------------------------------------------------------------------
+// Inner Classes
+//----------------------------------------------------------------------------------------------------------------------
+
+    protected abstract class BaseWhen<R>
+    {
+        protected S thenThrow(Exception e)
+        {
+            trainingContext().then(InterceptorUtils.throwing(e));
+            return self();
+        }
+
+        protected S thenThrow(ObjectProvider<? extends Exception> provider)
+        {
+            trainingContext().then(InterceptorUtils.throwing(provider));
+            return self();
+        }
+
+        protected S thenAnswer(ObjectProvider<? extends R> provider)
+        {
+            trainingContext().then(InterceptorUtils.provider(provider));
+            return self();
+        }
+    }
+
+    protected class WhenBooleanArray extends BaseWhen<boolean[]>
+    {
+        protected S thenReturn(boolean... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenByteArray extends BaseWhen<byte[]>
+    {
+        protected S thenReturn(byte... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenCharArray extends BaseWhen<char[]>
+    {
+        protected S thenReturn(char... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenDoubleArray extends BaseWhen<double[]>
+    {
+        protected S thenReturn(double... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenFloatArray extends BaseWhen<float[]>
+    {
+        protected S thenReturn(float... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenIntArray extends BaseWhen<int[]>
+    {
+        protected S thenReturn(int... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenLongArray extends BaseWhen<long[]>
+    {
+        protected S thenReturn(long... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenObject<R> extends BaseWhen<R>
+    {
+        protected S thenReturn(R value)
+        {
+            trainingContext().then(InterceptorUtils.constant(value));
+            return self();
+        }
+
+        protected S thenStub(BaseTrainer<?, R> trainer)
+        {
+            final R trainee = trainingContext().push(trainer.traineeType);
+            trainer.train(trainee);
+            trainingContext().then(InterceptorUtils.constant(trainingContext().pop()));
+            return self();
+        }
+    }
+
+    protected class WhenClass extends BaseWhen<Class<?>>
+    {
+        protected S thenReturn(Class<?> value)
+        {
+            trainingContext().then(InterceptorUtils.constant(value));
+            return self();
+        }
+    }
+
+    protected class WhenObjectArray<R> extends BaseWhen<R[]>
+    {
+        protected S thenReturn(R... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+
+    protected class WhenShortArray extends BaseWhen<short[]>
+    {
+        protected S thenReturn(short... values)
+        {
+            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
+            return self();
+        }
+    }
+}
diff --git a/stub/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java b/stub/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
index 1381a81..339c885 100644
--- a/stub/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
+++ b/stub/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
@@ -18,6 +18,7 @@
 package org.apache.commons.proxy2.stub;
 
 import org.apache.commons.lang3.builder.Builder;
+import org.apache.commons.proxy2.Invoker;
 import org.apache.commons.proxy2.ObjectProvider;
 import org.apache.commons.proxy2.ProxyFactory;
 import org.apache.commons.proxy2.interceptor.SwitchInterceptor;
@@ -30,9 +31,9 @@
 // Fields
 //----------------------------------------------------------------------------------------------------------------------
 
+    protected final Class<T> type;
     private final ProxyFactory proxyFactory;
     private final T target;
-    private final Class<T> type;
     private final SwitchInterceptor switchInterceptor = new SwitchInterceptor();
 
 //----------------------------------------------------------------------------------------------------------------------
@@ -41,11 +42,16 @@
 
     public StubBuilder(ProxyFactory proxyFactory, Class<T> type)
     {
-        this.proxyFactory = proxyFactory;
-        this.type = type;
-        this.target = proxyFactory.createInvokerProxy(NullInvoker.INSTANCE, type);
+        this(proxyFactory, type, NullInvoker.INSTANCE);
     }
 
+    public StubBuilder(ProxyFactory proxyFactory, Class<T> type, Invoker invoker)
+    {
+        this.proxyFactory = proxyFactory;
+        this.type = type;
+        this.target = proxyFactory.createInvokerProxy(invoker, type);
+    }
+    
     public StubBuilder(ProxyFactory proxyFactory, Class<T> type, ObjectProvider<? extends T> provider)
     {
         this.proxyFactory = proxyFactory;
@@ -69,7 +75,7 @@
         return proxyFactory.createInterceptorProxy(target, switchInterceptor, type);
     }
 
-    public StubBuilder<T> train(Trainer<T> trainer)
+    public StubBuilder<T> train(BaseTrainer<?, T> trainer)
     {
         try
         {
diff --git a/stub/src/main/java/org/apache/commons/proxy2/stub/Trainer.java b/stub/src/main/java/org/apache/commons/proxy2/stub/Trainer.java
index 146e745..d0962bc 100644
--- a/stub/src/main/java/org/apache/commons/proxy2/stub/Trainer.java
+++ b/stub/src/main/java/org/apache/commons/proxy2/stub/Trainer.java
@@ -17,240 +17,14 @@
 
 package org.apache.commons.proxy2.stub;
 
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.commons.lang3.reflect.TypeUtils;
-import org.apache.commons.proxy2.ObjectProvider;
-import org.apache.commons.proxy2.ProxyUtils;
-import org.apache.commons.proxy2.interceptor.InterceptorUtils;
-import org.apache.commons.proxy2.interceptor.matcher.ArgumentMatcher;
-import org.apache.commons.proxy2.interceptor.matcher.argument.ArgumentMatcherUtils;
+public abstract class Trainer<T> extends BaseTrainer<Trainer<T>, T> {
 
-public abstract class Trainer<T>
-{
-//----------------------------------------------------------------------------------------------------------------------
-// Abstract Methods
-//----------------------------------------------------------------------------------------------------------------------
-
-    protected abstract void train(T trainee);
-
-//----------------------------------------------------------------------------------------------------------------------
-// Other Methods
-//----------------------------------------------------------------------------------------------------------------------
-
-    protected <R> R any(Class<R> type)
-    {
-        record(ArgumentMatcherUtils.any());
-        return null;
+    protected Trainer() {
+        super();
     }
 
-    private void record(ArgumentMatcher matcher)
-    {
-        trainingContext().record(matcher);
+    protected Trainer(Class<T> traineeType) {
+        super(traineeType);
     }
 
-    protected <R> R eq(R value)
-    {
-        record(ArgumentMatcherUtils.eq(value));
-        return value;
-    }
-
-    @SuppressWarnings("unchecked")
-    public Class<T> getTraineeType()
-    {
-        return (Class<T>) TypeUtils.getRawType(Trainer.class.getTypeParameters()[0], getClass());
-    }
-
-    protected <R> R isInstance(Class<R> type)
-    {
-        record(ArgumentMatcherUtils.isA(type));
-        return ProxyUtils.nullValue(type);
-    }
-
-    protected void thenThrow(Exception e)
-    {
-        trainingContext().then(InterceptorUtils.throwing(e));
-    }
-
-    protected void thenThrow(ObjectProvider<? extends Exception> provider)
-    {
-        trainingContext().then(InterceptorUtils.throwing(provider));
-    }
-
-    private TrainingContext trainingContext()
-    {
-        return TrainingContext.getCurrent();
-    }
-
-    protected <R> WhenObject<R> when(R expression)
-    {
-        return new WhenObject<R>();
-    }
-
-    protected WhenByteArray when(byte[] expression)
-    {
-        return new WhenByteArray();
-    }
-
-    protected WhenBooleanArray when(boolean[] expression)
-    {
-        return new WhenBooleanArray();
-    }
-
-    protected WhenIntArray when(int[] expression)
-    {
-        return new WhenIntArray();
-    }
-
-    protected WhenShortArray when(short[] expresssion)
-    {
-        return new WhenShortArray();
-    }
-
-    protected WhenLongArray when(long[] expression)
-    {
-        return new WhenLongArray();
-    }
-
-    protected WhenFloatArray when(float[] expression)
-    {
-        return new WhenFloatArray();
-    }
-
-    protected WhenDoubleArray when(double[] expression)
-    {
-        return new WhenDoubleArray();
-    }
-
-    protected <R> WhenObjectArray<R> when(R[] expression)
-    {
-        return new WhenObjectArray<R>();
-    }
-
-    protected WhenCharArray when(char[] expression)
-    {
-        return new WhenCharArray();
-    }
-
-//----------------------------------------------------------------------------------------------------------------------
-// Inner Classes
-//----------------------------------------------------------------------------------------------------------------------
-
-    protected abstract class BaseWhen<R>
-    {
-        protected Trainer<T> thenStub(Trainer<R> trainer)
-        {
-            R trainee = trainingContext().push(trainer.getTraineeType());
-            trainer.train(trainee);
-            trainingContext().then(InterceptorUtils.constant(trainingContext().pop()));
-            return Trainer.this;
-        }
-
-        protected Trainer<T> thenThrow(Exception e)
-        {
-            trainingContext().then(InterceptorUtils.throwing(e));
-            return Trainer.this;
-        }
-
-        protected Trainer<T> thenThrow(ObjectProvider<? extends Exception> provider)
-        {
-            trainingContext().then(InterceptorUtils.throwing(provider));
-            return Trainer.this;
-        }
-
-        protected <R> Trainer<T> thenAnswer(ObjectProvider<? extends R> provider)
-        {
-            trainingContext().then(InterceptorUtils.provider(provider));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenBooleanArray extends BaseWhen<boolean[]>
-    {
-        protected Trainer<T> thenReturn(boolean... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenByteArray extends BaseWhen<byte[]>
-    {
-        protected Trainer<T> thenReturn(byte... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenCharArray extends BaseWhen<char[]>
-    {
-        protected Trainer<T> thenReturn(char... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenDoubleArray extends BaseWhen<double[]>
-    {
-        protected Trainer<T> thenReturn(double... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenFloatArray extends BaseWhen<float[]>
-    {
-        protected Trainer<T> thenReturn(float... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenIntArray extends BaseWhen<int[]>
-    {
-        protected Trainer<T> thenReturn(int... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenLongArray extends BaseWhen<long[]>
-    {
-        protected Trainer<T> thenReturn(long... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenObject<R> extends BaseWhen
-    {
-        protected Trainer<T> thenReturn(R value)
-        {
-            trainingContext().then(InterceptorUtils.constant(value));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenObjectArray<R> extends BaseWhen<R[]>
-    {
-        protected Trainer<T> thenReturn(R... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
-
-    protected class WhenShortArray extends BaseWhen<short[]>
-    {
-        protected Trainer<T> thenReturn(short... values)
-        {
-            trainingContext().then(InterceptorUtils.constant(ArrayUtils.clone(values)));
-            return Trainer.this;
-        }
-    }
 }