SLING-9834 - [Sling Models] Caching bug with reused Servlet requests

* use ServletRequest attributes to cache sling models when the adaptable is a ServletRequest
diff --git a/pom.xml b/pom.xml
index 7f8aab9..6d9ad5e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -209,6 +209,11 @@
             <artifactId>slf4j-simple</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.sling</groupId>
+            <artifactId>org.apache.sling.commons.testing</artifactId>
+            <version>2.0.26</version>
+        </dependency>
         <!-- for testing the annotations -->
         <dependency>
             <groupId>org.apache.felix</groupId>
diff --git a/src/main/java/org/apache/sling/models/impl/ModelAdapterFactory.java b/src/main/java/org/apache/sling/models/impl/ModelAdapterFactory.java
index c61e230..86cc2ad 100644
--- a/src/main/java/org/apache/sling/models/impl/ModelAdapterFactory.java
+++ b/src/main/java/org/apache/sling/models/impl/ModelAdapterFactory.java
@@ -121,6 +121,8 @@
 
     private static final Object REQUEST_MARKER_VALUE = new Object();
 
+    private static final String REQUEST_CACHE_ATTRIBUTE = ModelAdapterFactory.class.getName() + ".AdapterCache";
+
     private static class DisposalCallbackRegistryImpl implements DisposalCallbackRegistry, Disposable {
 
         private List<DisposalCallback> callbacks = new ArrayList<>();
@@ -343,7 +345,25 @@
         // throw exception here
         throw new ModelClassException("Could not yet find an adapter factory for the model " + requestedType + " from adaptable " + adaptable.getClass());
     }
-
+    
+    @SuppressWarnings("unchecked")
+    private Map<Class<?>, SoftReference<Object>> getOrCreateCache(final Object adaptable) {
+        Map<Class<?>, SoftReference<Object>> adaptableCache;
+        if (adaptable instanceof ServletRequest) {
+            ServletRequest request = (ServletRequest) adaptable;
+            adaptableCache = (Map<Class<?>, SoftReference<Object>>) request.getAttribute(REQUEST_CACHE_ATTRIBUTE);
+            if (adaptableCache == null) {
+                adaptableCache = Collections.synchronizedMap(new WeakHashMap<Class<?>, SoftReference<Object>>());
+                request.setAttribute(REQUEST_CACHE_ATTRIBUTE, adaptableCache);
+            }
+        } else {
+            adaptableCache = adapterCache.computeIfAbsent(adaptable, k -> {
+                return Collections.synchronizedMap(new WeakHashMap<Class<?>, SoftReference<Object>>());
+            });
+        }
+        return adaptableCache;
+    }
+    
     @SuppressWarnings("unchecked")
     private <ModelType> Result<ModelType> internalCreateModel(final Object adaptable, final Class<ModelType> requestedType) {
         Result<ModelType> result;
@@ -365,21 +385,15 @@
             boolean isAdaptable = false;
 
             Model modelAnnotation = modelClass.getModelAnnotation();
-            Object cacheKey = adaptable;
+            Map<Class<?>, SoftReference<Object>> adaptableCache = null;
 
             if (modelAnnotation.cache()) {
-                if (adaptable instanceof ServletRequestWrapper) {
-                    cacheKey = unwrapRequest((ServletRequest) adaptable);
-                }
-
-                Map<Class<?>, SoftReference<Object>> adaptableCache = adapterCache.get(cacheKey);
-                if (adaptableCache != null) {
-                    SoftReference<Object> softReference = adaptableCache.get(requestedType);
-                    if (softReference != null) {
-                        ModelType cachedObject = (ModelType) softReference.get();
-                        if (cachedObject != null) {
-                            return new Result<>(cachedObject);
-                        }
+                adaptableCache = getOrCreateCache(adaptable);
+                SoftReference<Object> softReference = adaptableCache.get(requestedType);
+                if (softReference != null) {
+                    ModelType cachedObject = (ModelType) softReference.get();
+                    if (cachedObject != null) {
+                        return new Result<>(cachedObject);
                     }
                 }
             }
@@ -403,12 +417,7 @@
                     if (handlerResult.wasSuccessful()) {
                         ModelType model = (ModelType) Proxy.newProxyInstance(modelClass.getType().getClassLoader(), new Class<?>[] { modelClass.getType() }, handlerResult.getValue());
 
-                        if (modelAnnotation.cache()) {
-                            Map<Class<?>, SoftReference<Object>> adaptableCache = adapterCache.get(cacheKey);
-                            if (adaptableCache == null) {
-                                adaptableCache = Collections.synchronizedMap(new WeakHashMap<Class<?>, SoftReference<Object>>());
-                                adapterCache.put(cacheKey, adaptableCache);
-                            }
+                        if (modelAnnotation.cache() && adaptableCache != null) {
                             adaptableCache.put(requestedType, new SoftReference<Object>(model));
                         }
 
@@ -420,12 +429,7 @@
                     try {
                         result = createObject(adaptable, modelClass);
 
-                        if (result.wasSuccessful() && modelAnnotation.cache()) {
-                            Map<Class<?>, SoftReference<Object>> adaptableCache = adapterCache.get(cacheKey);
-                            if (adaptableCache == null) {
-                                adaptableCache = Collections.synchronizedMap(new WeakHashMap<Class<?>, SoftReference<Object>>());
-                                adapterCache.put(cacheKey, adaptableCache);
-                            }
+                        if (result.wasSuccessful() && modelAnnotation.cache() && adaptableCache != null) {
                             adaptableCache.put(requestedType, new SoftReference<Object>(result.getValue()));
                         }
                     } catch (Exception e) {
diff --git a/src/test/java/org/apache/sling/models/impl/CachingTest.java b/src/test/java/org/apache/sling/models/impl/CachingTest.java
index 2af2fa5..13fef17 100644
--- a/src/test/java/org/apache/sling/models/impl/CachingTest.java
+++ b/src/test/java/org/apache/sling/models/impl/CachingTest.java
@@ -16,44 +16,58 @@
  */
 package org.apache.sling.models.impl;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
-import javax.servlet.ServletRequestWrapper;
-
-import org.apache.sling.api.SlingHttpServletRequest;
+import org.apache.sling.api.resource.Resource;
+import org.apache.sling.api.resource.ValueMap;
+import org.apache.sling.api.wrappers.SlingHttpServletRequestWrapper;
+import org.apache.sling.api.wrappers.ValueMapDecorator;
+import org.apache.sling.commons.testing.sling.MockSlingHttpServletRequest;
 import org.apache.sling.models.impl.injectors.RequestAttributeInjector;
+import org.apache.sling.models.impl.injectors.ValueMapInjector;
 import org.apache.sling.models.testmodels.classes.CachedModel;
 import org.apache.sling.models.testmodels.classes.UncachedModel;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
+import org.mockito.Spy;
 import org.mockito.runners.MockitoJUnitRunner;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
 @RunWith(MockitoJUnitRunner.class)
 public class CachingTest {
 
-    @Mock
-    private SlingHttpServletRequest request;
+    @Spy
+    private MockRequest request;
+
+    private SlingHttpServletRequestWrapper requestWrapper;
 
     @Mock
-    private ServletRequestWrapper requestWrapper;
-
+    private Resource resource;
+    
     private ModelAdapterFactory factory;
 
     @Before
     public void setup() {
         factory = AdapterFactoryTest.createModelAdapterFactory();
         factory.bindInjector(new RequestAttributeInjector(), new ServicePropertiesMap(0, 0));
+        factory.bindInjector(new ValueMapInjector(), new ServicePropertiesMap(1, 1));
         factory.adapterImplementations.addClassesAsAdapterAndImplementation(CachedModel.class, UncachedModel.class,
                 org.apache.sling.models.testmodels.interfaces.CachedModel.class, org.apache.sling.models.testmodels.interfaces.UncachedModel.class);
 
         when(request.getAttribute("testValue")).thenReturn("test");
-        when(requestWrapper.getRequest()).thenReturn(request);
+        requestWrapper = new SlingHttpServletRequestWrapper(request);
+        
+        ValueMap vm = new ValueMapDecorator(Collections.singletonMap("testValue", "test"));
+        when(resource.adaptTo(ValueMap.class)).thenReturn(vm);
     }
 
     @Test
@@ -67,6 +81,18 @@
 
         verify(request, times(1)).getAttribute("testValue");
     }
+    
+    @Test
+    public void testCachedClassWithResource() {
+        CachedModel cached1 = factory.getAdapter(resource, CachedModel.class);
+        CachedModel cached2 = factory.getAdapter(resource, CachedModel.class);
+
+        assertTrue(cached1 == cached2);
+        assertEquals("test", cached1.getTestValue());
+        assertEquals("test", cached2.getTestValue());
+
+        verify(resource, times(1)).adaptTo(ValueMap.class);
+    }
 
     @Test
     public void testNoCachedClass() {
@@ -79,6 +105,18 @@
 
         verify(request, times(2)).getAttribute("testValue");
     }
+    
+    @Test
+    public void testNoCachedClassWithResource() {
+        UncachedModel uncached1 = factory.getAdapter(resource, UncachedModel.class);
+        UncachedModel uncached2 = factory.getAdapter(resource, UncachedModel.class);
+
+        assertTrue(uncached1 != uncached2);
+        assertEquals("test", uncached1.getTestValue());
+        assertEquals("test", uncached2.getTestValue());
+
+        verify(resource, times(2)).adaptTo(ValueMap.class);
+    }
 
     @Test
     public void testCachedInterface() {
@@ -114,6 +152,11 @@
         assertEquals("test", cached2.getTestValue());
 
         verify(request, times(1)).getAttribute("testValue");
+        
+        // If we clear the request attributes, the sling model is no longer cached
+        request.clearAttributes();
+        CachedModel cached3 = factory.getAdapter(request, CachedModel.class);
+        assertTrue(cached1 != cached3);
     }
 
     @Test
@@ -127,4 +170,29 @@
 
         verify(request, times(1)).getAttribute("testValue");
     }
+    
+    // MockSlingHttpServletRequest doesn't implement set and get attributes
+    private static class MockRequest extends MockSlingHttpServletRequest {
+
+        private Map<String, Object> attributes = new HashMap<>();
+        
+        MockRequest() {
+            super(null, null, null, null, null);
+        }
+        
+        @Override
+        public void setAttribute(String name, Object o) {
+            attributes.put(name, o);
+        }
+        
+        @Override
+        public Object getAttribute(String name) {
+            return attributes.get(name);
+        }
+        
+        public void clearAttributes() {
+            attributes.clear();
+        }
+    }
 }
+
diff --git a/src/test/java/org/apache/sling/models/testmodels/classes/CachedModel.java b/src/test/java/org/apache/sling/models/testmodels/classes/CachedModel.java
index 2e02526..2a10933 100644
--- a/src/test/java/org/apache/sling/models/testmodels/classes/CachedModel.java
+++ b/src/test/java/org/apache/sling/models/testmodels/classes/CachedModel.java
@@ -16,12 +16,13 @@
  */
 package org.apache.sling.models.testmodels.classes;
 
-import org.apache.sling.api.SlingHttpServletRequest;
-import org.apache.sling.models.annotations.Model;
-
 import javax.inject.Inject;
 
-@Model(adaptables = SlingHttpServletRequest.class, cache = true)
+import org.apache.sling.api.SlingHttpServletRequest;
+import org.apache.sling.api.resource.Resource;
+import org.apache.sling.models.annotations.Model;
+
+@Model(adaptables = {SlingHttpServletRequest.class, Resource.class}, cache = true)
 public class CachedModel {
 
     @Inject
diff --git a/src/test/java/org/apache/sling/models/testmodels/classes/UncachedModel.java b/src/test/java/org/apache/sling/models/testmodels/classes/UncachedModel.java
index bd39352..06ae166 100644
--- a/src/test/java/org/apache/sling/models/testmodels/classes/UncachedModel.java
+++ b/src/test/java/org/apache/sling/models/testmodels/classes/UncachedModel.java
@@ -17,12 +17,13 @@
 package org.apache.sling.models.testmodels.classes;
 
 
-import org.apache.sling.api.SlingHttpServletRequest;
-import org.apache.sling.models.annotations.Model;
-
 import javax.inject.Inject;
 
-@Model(adaptables = SlingHttpServletRequest.class)
+import org.apache.sling.api.SlingHttpServletRequest;
+import org.apache.sling.api.resource.Resource;
+import org.apache.sling.models.annotations.Model;
+
+@Model(adaptables = {SlingHttpServletRequest.class, Resource.class})
 public class UncachedModel {
 
     @Inject