RANGER-4302: RangerCache updated to support value loader to use callers context in Ranger admin - #2
diff --git a/agents-common/src/main/java/org/apache/ranger/plugin/util/RangerCache.java b/agents-common/src/main/java/org/apache/ranger/plugin/util/RangerCache.java
index da8725b..9769aaa 100644
--- a/agents-common/src/main/java/org/apache/ranger/plugin/util/RangerCache.java
+++ b/agents-common/src/main/java/org/apache/ranger/plugin/util/RangerCache.java
@@ -93,28 +93,7 @@
     public long getValueRefreshLoadTimeoutMs() { return valueRefreshLoadTimeoutMs; }
 
     public V get(K key) {
-        final long        startTime = System.currentTimeMillis();
-        final CachedValue value     = cache.computeIfAbsent(key, f -> new CachedValue(key));
-        final long        timeoutMs = value.isInitialized() ? valueRefreshLoadTimeoutMs : valueInitLoadTimeoutMs;
-        final V           ret;
-
-        if (timeoutMs >= 0) {
-            final long timeTaken = System.currentTimeMillis() - startTime;
-
-            if (timeoutMs <= timeTaken) {
-                ret = value.getCurrentValue();
-
-                if (LOG.isDebugEnabled()) {
-                    LOG.debug("key={}: cache-lookup={}ms took longer than timeout={}ms. Using current value {}", key, timeTaken, timeoutMs, ret);
-                }
-            } else {
-                ret = value.getValue(timeoutMs - timeTaken);
-            }
-        } else {
-            ret = value.getValue();
-        }
-
-        return ret;
+        return get(key, null);
     }
 
     public Set<K> getKeys() {
@@ -147,6 +126,31 @@
         return value != null;
     }
 
+    protected V get(K key, Object context) {
+        final long        startTime = System.currentTimeMillis();
+        final CachedValue value     = cache.computeIfAbsent(key, f -> new CachedValue(key));
+        final long        timeoutMs = value.isInitialized() ? valueRefreshLoadTimeoutMs : valueInitLoadTimeoutMs;
+        final V           ret;
+
+        if (timeoutMs >= 0) {
+            final long timeTaken = System.currentTimeMillis() - startTime;
+
+            if (timeoutMs <= timeTaken) {
+                ret = value.getCurrentValue();
+
+                if (LOG.isDebugEnabled()) {
+                    LOG.debug("key={}: cache-lookup={}ms took longer than timeout={}ms. Using current value {}", key, timeTaken, timeoutMs, ret);
+                }
+            } else {
+                ret = value.getValue(timeoutMs - timeTaken);
+            }
+        } else {
+            ret = value.getValue(context);
+        }
+
+        return ret;
+    }
+
     public static class RefreshableValue<V> {
         private final V    value;
         private       long nextRefreshTimeMs = -1;
@@ -165,7 +169,7 @@
     }
 
     public static abstract class ValueLoader<K, V> {
-        public abstract RefreshableValue<V> load(K key, RefreshableValue<V> currentValue) throws Exception;
+        public abstract RefreshableValue<V> load(K key, RefreshableValue<V> currentValue, Object context) throws Exception;
     }
 
     private class CachedValue {
@@ -185,17 +189,17 @@
 
         public K getKey() { return key; }
 
-        public V getValue() {
-            refreshIfNeeded();
+        public V getValue(Object context) {
+            refreshIfNeeded(context);
 
             return getCurrentValue();
         }
 
-        public V getValue(long timeoutMs) {
+        public V getValue(long timeoutMs, Object context) {
             if (timeoutMs < 0) {
-                refreshIfNeeded();
+                refreshIfNeeded(context);
             } else {
-                refreshIfNeeded(timeoutMs);
+                refreshIfNeeded(timeoutMs, context);
             }
 
             return getCurrentValue();
@@ -217,7 +221,7 @@
             return value != null;
         }
 
-        private void refreshIfNeeded() {
+        private void refreshIfNeeded(Object context) {
             if (needsRefresh()) {
                 try (AutoClosableLock ignored = new AutoClosableLock(lock)) {
                     if (needsRefresh()) {
@@ -228,7 +232,7 @@
                                 LOG.debug("refreshIfNeeded(key={}): using caller thread", key);
                             }
 
-                            refreshValue();
+                            refreshValue(context);
                         } else { // wait for the refresher to complete
                             try {
                                 future.get();
@@ -243,7 +247,7 @@
             }
         }
 
-        private void refreshIfNeeded(long timeoutMs) {
+        private void refreshIfNeeded(long timeoutMs, Object context) {
             if (needsRefresh()) {
                 long startTime = System.currentTimeMillis();
 
@@ -253,7 +257,7 @@
                             Future<?> future = this.refresher;
 
                             if (future == null) {
-                                future = this.refresher = loaderThreadPool.submit(this::refreshValue);
+                                future = this.refresher = loaderThreadPool.submit(new RefreshWithContext(context));
 
                                 if (LOG.isDebugEnabled()) {
                                     LOG.debug("refresher scheduled for key {}", key);
@@ -287,7 +291,7 @@
             }
         }
 
-        private Boolean refreshValue() {
+        private Boolean refreshValue(Object context) {
             long                startTime = System.currentTimeMillis();
             boolean             isSuccess = false;
             RefreshableValue<V> newValue  = null;
@@ -296,7 +300,7 @@
                 ValueLoader<K, V> loader = RangerCache.this.loader;
 
                 if (loader != null) {
-                    newValue  = loader.load(key, value);
+                    newValue  = loader.load(key, value, context);
                     isSuccess = true;
                 }
             } catch (KeyNotFoundException excp) {
@@ -319,7 +323,7 @@
                      if (!isRemoved) {
                          ScheduledExecutorService scheduledExecutor = ((ScheduledExecutorService) loaderThreadPool);
 
-                         scheduledExecutor.schedule(this::refreshValue, valueValidityPeriodMs, TimeUnit.MILLISECONDS);
+                         scheduledExecutor.schedule(new RefreshWithContext(context), valueValidityPeriodMs, TimeUnit.MILLISECONDS);
                      } else {
                          if (LOG.isDebugEnabled()) {
                              LOG.debug("key {} was removed. Not scheduling next refresh ", key);
@@ -338,6 +342,19 @@
                 this.value.setNextRefreshTimeMs(System.currentTimeMillis() + valueValidityPeriodMs);
             }
         }
+
+        private class RefreshWithContext implements Callable<Boolean> {
+            private final Object context;
+
+            public RefreshWithContext(Object context) {
+                this.context = context;
+            }
+
+            @Override
+            public Boolean call() {
+                return refreshValue(context);
+            }
+        }
     }
 
     private ThreadFactory createThreadFactory() {
diff --git a/agents-common/src/test/java/org/apache/ranger/plugin/util/RangerCacheTest.java b/agents-common/src/test/java/org/apache/ranger/plugin/util/RangerCacheTest.java
index 8b89496..f9bc426 100644
--- a/agents-common/src/test/java/org/apache/ranger/plugin/util/RangerCacheTest.java
+++ b/agents-common/src/test/java/org/apache/ranger/plugin/util/RangerCacheTest.java
@@ -412,7 +412,7 @@
             }
 
             @Override
-            public RefreshableValue<List<String>> load(String userName, RefreshableValue<List<String>> currVal) throws Exception {
+            public RefreshableValue<List<String>> load(String userName, RefreshableValue<List<String>> currVal, Object context) throws Exception {
                 long startTimeMs = System.currentTimeMillis();
 
                 UserStats userStats = stats.get(userName);
diff --git a/security-admin/src/main/java/org/apache/ranger/biz/GdsDBStore.java b/security-admin/src/main/java/org/apache/ranger/biz/GdsDBStore.java
index 0112c34..d9c056a 100755
--- a/security-admin/src/main/java/org/apache/ranger/biz/GdsDBStore.java
+++ b/security-admin/src/main/java/org/apache/ranger/biz/GdsDBStore.java
@@ -26,7 +26,6 @@
 import org.apache.ranger.common.db.RangerTransactionSynchronizationAdapter;
 import org.apache.ranger.db.*;
 import org.apache.ranger.entity.*;
-import org.apache.ranger.plugin.model.RangerGds;
 import org.apache.ranger.plugin.model.RangerGds.*;
 import org.apache.ranger.plugin.model.RangerPolicy;
 import org.apache.ranger.plugin.model.RangerPolicy.RangerPolicyItem;
diff --git a/security-admin/src/main/java/org/apache/ranger/util/RangerAdminCache.java b/security-admin/src/main/java/org/apache/ranger/util/RangerAdminCache.java
index 569c113..2d5da7d 100644
--- a/security-admin/src/main/java/org/apache/ranger/util/RangerAdminCache.java
+++ b/security-admin/src/main/java/org/apache/ranger/util/RangerAdminCache.java
@@ -21,6 +21,8 @@
 
 import org.apache.ranger.authorization.hadoop.config.RangerAdminConfig;
 import org.apache.ranger.plugin.util.RangerCache;
+import org.apache.ranger.security.context.RangerContextHolder;
+import org.apache.ranger.security.context.RangerSecurityContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.transaction.PlatformTransactionManager;
@@ -48,6 +50,11 @@
         super(name, loader, loaderThreadsCount, refreshMode, valueValidityPeriodMs, valueInitLoadTimeoutMs, valueRefreshLoadTimeoutMs);
     }
 
+    @Override
+    public V get(K key)  {
+        return super.get(key, RangerContextHolder.getSecurityContext());
+    }
+
     private static int getLoaderThreadPoolSize(String cacheName) {
         return RangerAdminConfig.getInstance().getInt(PROP_PREFIX + cacheName + PROP_LOADER_THREAD_POOL_SIZE, DEFAULT_ADMIN_CACHE_LOADER_THREADS_COUNT);
     }
@@ -70,16 +77,28 @@
         }
 
         @Override
-        final public RefreshableValue<V> load(K key, RefreshableValue<V> currentValue) throws Exception {
+        final public RefreshableValue<V> load(K key, RefreshableValue<V> currentValue, Object context) throws Exception {
             Exception[] ex = new Exception[1];
 
             RefreshableValue<V> ret = txTemplate.execute(status -> {
+                RangerSecurityContext currentContext = null;
+
                 try {
+                    if (context instanceof RangerSecurityContext) {
+                        currentContext = RangerContextHolder.getSecurityContext();
+
+                        RangerContextHolder.setSecurityContext((RangerSecurityContext) context);
+                    }
+
                     return dbLoad(key, currentValue);
                 } catch (Exception excp) {
                     LOG.error("RangerDBLoaderCache.load(): failed to load for key={}", key, excp);
 
                     ex[0] = excp;
+                } finally {
+                    if (context instanceof RangerSecurityContext) {
+                        RangerContextHolder.setSecurityContext(currentContext);
+                    }
                 }
 
                 return null;