add the ability to add additional types when stubbing

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/proxy/branches/version-2.0-work@1525272 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java b/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
index d6976a1..dd73a05 100644
--- a/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
+++ b/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
@@ -168,7 +168,7 @@
 
         MapAnnotationTrainer(Map<String, ?> members)
         {
-            super(type);
+            super(annotationType);
             this.members = members;
         }
 
@@ -220,19 +220,24 @@
         return new AnnotationBuilder<A>(type, target);
     }
 
+    private final Class<A> annotationType;
+
     private AnnotationBuilder(Class<A> type, Invoker invoker)
     {
         super(PROXY_FACTORY, type, invoker);
+        this.annotationType = type;
     }
 
     private AnnotationBuilder(Class<A> type, ObjectProvider<? extends A> provider)
     {
         super(PROXY_FACTORY, type, provider);
+        this.annotationType = type;
     }
 
     private AnnotationBuilder(Class<A> type, A target)
     {
         super(PROXY_FACTORY, type, target);
+        this.annotationType = type;
     }
 
     public AnnotationBuilder<A> withMembers(Map<String, ?> members)
@@ -241,19 +246,19 @@
     }
 
     @Override
-    public AnnotationBuilder<A> train(BaseTrainer<?, ? super A> trainer)
+    public <O> AnnotationBuilder<A> train(BaseTrainer<?, O> trainer)
     {
         return (AnnotationBuilder<A>) super.train(trainer);
     }
 
     @Override
     public A build() {
-        train(new AnnotationTrainer<A>(type)
+        train(new AnnotationTrainer<A>(annotationType)
         {
             @Override
             protected void train(A trainee)
             {
-                when(trainee.annotationType()).thenReturn(type);
+                when(trainee.annotationType()).thenReturn(annotationType);
             }
         });
         return super.build();
diff --git a/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java b/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
index 7eba7af..6d7e25d 100644
--- a/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
+++ b/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
@@ -17,6 +17,12 @@
 
 package org.apache.commons.proxy2.stub;
 
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.Validate;
 import org.apache.commons.lang3.builder.Builder;
 import org.apache.commons.proxy2.Invoker;
 import org.apache.commons.proxy2.ObjectProvider;
@@ -31,10 +37,10 @@
 // Fields
 //----------------------------------------------------------------------------------------------------------------------
 
-    protected final Class<T> type;
     private final ProxyFactory proxyFactory;
     private final T target;
     private final SwitchInterceptor switchInterceptor = new SwitchInterceptor();
+    private final Set<Class<?>> proxyTypes = new HashSet<Class<?>>();
 
 //----------------------------------------------------------------------------------------------------------------------
 // Constructors
@@ -48,22 +54,22 @@
     public StubBuilder(ProxyFactory proxyFactory, Class<T> type, Invoker invoker)
     {
         this.proxyFactory = proxyFactory;
-        this.type = type;
         this.target = proxyFactory.createInvokerProxy(invoker, type);
+        this.proxyTypes.add(Validate.notNull(type));
     }
     
     public StubBuilder(ProxyFactory proxyFactory, Class<T> type, ObjectProvider<? extends T> provider)
     {
         this.proxyFactory = proxyFactory;
-        this.type = type;
         this.target = proxyFactory.createDelegatorProxy(provider, type);
+        this.proxyTypes.add(Validate.notNull(type));
     }
 
     public StubBuilder(ProxyFactory proxyFactory, Class<T> type, T target)
     {
         this.proxyFactory = proxyFactory;
-        this.type = type;
         this.target = proxyFactory.createDelegatorProxy(new ConstantProvider<T>(target), type);
+        this.proxyTypes.add(Validate.notNull(type));
     }
 
 //----------------------------------------------------------------------------------------------------------------------
@@ -72,16 +78,18 @@
 
     public T build()
     {
-        return proxyFactory.createInterceptorProxy(target, switchInterceptor, type);
+        return proxyFactory.createInterceptorProxy(target, switchInterceptor,
+                proxyTypes.toArray(ArrayUtils.EMPTY_CLASS_ARRAY));
     }
 
-    public StubBuilder<T> train(BaseTrainer<?, ? super T> trainer)
+    public <O> StubBuilder<T> train(BaseTrainer<?, O> trainer)
     {
         try
         {
             TrainingContext trainingContext = TrainingContext.set(proxyFactory);
-            T trainee = trainingContext.push(type, switchInterceptor);
+            final O trainee = trainingContext.push(trainer.traineeType, switchInterceptor);
             trainer.train(trainee);
+            proxyTypes.add(trainer.traineeType);
         }
         finally
         {
@@ -89,4 +97,10 @@
         }
         return this;
     }
+
+    public StubBuilder<T> addProxyTypes(Class<?>... proxyTypes)
+    {
+        Collections.addAll(this.proxyTypes, Validate.noNullElements(proxyTypes));
+        return this;
+    }
 }
diff --git a/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java b/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
index a78141f..22c9384 100644
--- a/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
+++ b/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
@@ -19,9 +19,15 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Iterator;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.proxy2.ObjectProvider;
 import org.apache.commons.proxy2.provider.BeanProvider;
+import org.apache.commons.proxy2.provider.ObjectProviderUtils;
 import org.junit.Test;
 
 public class StubBuilderTest extends AbstractStubTestCase
@@ -94,6 +100,33 @@
         assertEquals("Bar", stub.one("Foo"));
     }
 
+    @Test
+    public void testAdditionalInterfaces() {
+        StubBuilder<StubInterface> builder = new StubBuilder<StubInterface>(proxyFactory, StubInterface.class,
+                ObjectProviderUtils.constant(new SimpleStub()));
+        builder.train(new Trainer<Iterable<String>>()
+        {
+
+            @Override
+            protected void train(Iterable<String> trainee)
+            {
+                when(trainee.iterator()).thenAnswer(new ObjectProvider<Iterator<String>>()
+                {
+                    @Override
+                    public Iterator<String> getObject()
+                    {
+                        return Arrays.asList("foo", "bar", "baz").iterator();
+                    }
+                });
+            }
+        });
+        builder.addProxyTypes(Cloneable.class, Marker.class);
+        StubInterface stub = builder.build();
+        assertTrue(stub instanceof Iterable<?>);
+        assertTrue(stub instanceof Cloneable);
+        assertTrue(stub instanceof Marker);
+    }
+
 //----------------------------------------------------------------------------------------------------------------------
 // Inner Classes
 //----------------------------------------------------------------------------------------------------------------------
@@ -190,4 +223,8 @@
             return null;
         }
     }
+
+    public interface Marker
+    {
+    }
 }