add the ability to add additional types when stubbing
git-svn-id: https://svn.apache.org/repos/asf/commons/proper/proxy/branches/version-2.0-work@1525272 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java b/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
index d6976a1..dd73a05 100644
--- a/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
+++ b/core/src/main/java/org/apache/commons/proxy2/stub/AnnotationBuilder.java
@@ -168,7 +168,7 @@
MapAnnotationTrainer(Map<String, ?> members)
{
- super(type);
+ super(annotationType);
this.members = members;
}
@@ -220,19 +220,24 @@
return new AnnotationBuilder<A>(type, target);
}
+ private final Class<A> annotationType;
+
private AnnotationBuilder(Class<A> type, Invoker invoker)
{
super(PROXY_FACTORY, type, invoker);
+ this.annotationType = type;
}
private AnnotationBuilder(Class<A> type, ObjectProvider<? extends A> provider)
{
super(PROXY_FACTORY, type, provider);
+ this.annotationType = type;
}
private AnnotationBuilder(Class<A> type, A target)
{
super(PROXY_FACTORY, type, target);
+ this.annotationType = type;
}
public AnnotationBuilder<A> withMembers(Map<String, ?> members)
@@ -241,19 +246,19 @@
}
@Override
- public AnnotationBuilder<A> train(BaseTrainer<?, ? super A> trainer)
+ public <O> AnnotationBuilder<A> train(BaseTrainer<?, O> trainer)
{
return (AnnotationBuilder<A>) super.train(trainer);
}
@Override
public A build() {
- train(new AnnotationTrainer<A>(type)
+ train(new AnnotationTrainer<A>(annotationType)
{
@Override
protected void train(A trainee)
{
- when(trainee.annotationType()).thenReturn(type);
+ when(trainee.annotationType()).thenReturn(annotationType);
}
});
return super.build();
diff --git a/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java b/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
index 7eba7af..6d7e25d 100644
--- a/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
+++ b/core/src/main/java/org/apache/commons/proxy2/stub/StubBuilder.java
@@ -17,6 +17,12 @@
package org.apache.commons.proxy2.stub;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.builder.Builder;
import org.apache.commons.proxy2.Invoker;
import org.apache.commons.proxy2.ObjectProvider;
@@ -31,10 +37,10 @@
// Fields
//----------------------------------------------------------------------------------------------------------------------
- protected final Class<T> type;
private final ProxyFactory proxyFactory;
private final T target;
private final SwitchInterceptor switchInterceptor = new SwitchInterceptor();
+ private final Set<Class<?>> proxyTypes = new HashSet<Class<?>>();
//----------------------------------------------------------------------------------------------------------------------
// Constructors
@@ -48,22 +54,22 @@
public StubBuilder(ProxyFactory proxyFactory, Class<T> type, Invoker invoker)
{
this.proxyFactory = proxyFactory;
- this.type = type;
this.target = proxyFactory.createInvokerProxy(invoker, type);
+ this.proxyTypes.add(Validate.notNull(type));
}
public StubBuilder(ProxyFactory proxyFactory, Class<T> type, ObjectProvider<? extends T> provider)
{
this.proxyFactory = proxyFactory;
- this.type = type;
this.target = proxyFactory.createDelegatorProxy(provider, type);
+ this.proxyTypes.add(Validate.notNull(type));
}
public StubBuilder(ProxyFactory proxyFactory, Class<T> type, T target)
{
this.proxyFactory = proxyFactory;
- this.type = type;
this.target = proxyFactory.createDelegatorProxy(new ConstantProvider<T>(target), type);
+ this.proxyTypes.add(Validate.notNull(type));
}
//----------------------------------------------------------------------------------------------------------------------
@@ -72,16 +78,18 @@
public T build()
{
- return proxyFactory.createInterceptorProxy(target, switchInterceptor, type);
+ return proxyFactory.createInterceptorProxy(target, switchInterceptor,
+ proxyTypes.toArray(ArrayUtils.EMPTY_CLASS_ARRAY));
}
- public StubBuilder<T> train(BaseTrainer<?, ? super T> trainer)
+ public <O> StubBuilder<T> train(BaseTrainer<?, O> trainer)
{
try
{
TrainingContext trainingContext = TrainingContext.set(proxyFactory);
- T trainee = trainingContext.push(type, switchInterceptor);
+ final O trainee = trainingContext.push(trainer.traineeType, switchInterceptor);
trainer.train(trainee);
+ proxyTypes.add(trainer.traineeType);
}
finally
{
@@ -89,4 +97,10 @@
}
return this;
}
+
+ public StubBuilder<T> addProxyTypes(Class<?>... proxyTypes)
+ {
+ Collections.addAll(this.proxyTypes, Validate.noNullElements(proxyTypes));
+ return this;
+ }
}
diff --git a/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java b/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
index a78141f..22c9384 100644
--- a/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
+++ b/test/src/test/java/org/apache/commons/proxy2/stub/StubBuilderTest.java
@@ -19,9 +19,15 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Iterator;
import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.proxy2.ObjectProvider;
import org.apache.commons.proxy2.provider.BeanProvider;
+import org.apache.commons.proxy2.provider.ObjectProviderUtils;
import org.junit.Test;
public class StubBuilderTest extends AbstractStubTestCase
@@ -94,6 +100,33 @@
assertEquals("Bar", stub.one("Foo"));
}
+ @Test
+ public void testAdditionalInterfaces() {
+ StubBuilder<StubInterface> builder = new StubBuilder<StubInterface>(proxyFactory, StubInterface.class,
+ ObjectProviderUtils.constant(new SimpleStub()));
+ builder.train(new Trainer<Iterable<String>>()
+ {
+
+ @Override
+ protected void train(Iterable<String> trainee)
+ {
+ when(trainee.iterator()).thenAnswer(new ObjectProvider<Iterator<String>>()
+ {
+ @Override
+ public Iterator<String> getObject()
+ {
+ return Arrays.asList("foo", "bar", "baz").iterator();
+ }
+ });
+ }
+ });
+ builder.addProxyTypes(Cloneable.class, Marker.class);
+ StubInterface stub = builder.build();
+ assertTrue(stub instanceof Iterable<?>);
+ assertTrue(stub instanceof Cloneable);
+ assertTrue(stub instanceof Marker);
+ }
+
//----------------------------------------------------------------------------------------------------------------------
// Inner Classes
//----------------------------------------------------------------------------------------------------------------------
@@ -190,4 +223,8 @@
return null;
}
}
+
+ public interface Marker
+ {
+ }
}