QPID-8219: [Broker-J] Cache authentication results for the same remote hosts and credentials
diff --git a/broker-core/src/main/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacher.java b/broker-core/src/main/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacher.java
index c5cb157..b034916 100644
--- a/broker-core/src/main/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacher.java
+++ b/broker-core/src/main/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacher.java
@@ -19,6 +19,8 @@
package org.apache.qpid.server.security.auth.manager;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
@@ -117,7 +119,20 @@
if (connectionPrincipals != null && !connectionPrincipals.isEmpty())
{
SocketConnectionPrincipal connectionPrincipal = connectionPrincipals.iterator().next();
- md.update(connectionPrincipal.getRemoteAddress().toString().getBytes(UTF8));
+ SocketAddress remoteAddress = connectionPrincipal.getRemoteAddress();
+ String address;
+ if (remoteAddress instanceof InetSocketAddress)
+ {
+ address = ((InetSocketAddress) remoteAddress).getHostString();
+ }
+ else
+ {
+ address = remoteAddress.toString();
+ }
+ if (address != null)
+ {
+ md.update(address.getBytes(UTF8));
+ }
}
for (String part : content)
diff --git a/broker-core/src/test/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacherTest.java b/broker-core/src/test/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacherTest.java
index 27921a8..659fc91 100644
--- a/broker-core/src/test/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacherTest.java
+++ b/broker-core/src/test/java/org/apache/qpid/server/security/auth/manager/AuthenticationResultCacherTest.java
@@ -115,33 +115,36 @@
@Test
- public void testCacheMissDifferentAddress() throws Exception
+ public void testCacheMissDifferentRemoteAddressHosts() throws Exception
{
- Subject.doAs(_subject, new PrivilegedAction<Void>()
- {
- @Override
- public Void run()
- {
- AuthenticationResult result;
- result = _authenticationResultCacher.getOrLoad(new String[]{"credentials"}, _loader);
- assertEquals("Unexpected AuthenticationResult", _successfulAuthenticationResult, result);
- assertEquals("Unexpected number of loads before cache hit", (long) 1, (long) _loadCallCount);
- return null;
- }
- });
+ final String credentials = "credentials";
+ assertGetOrLoad(credentials, _successfulAuthenticationResult, 1);
+ when(_connection.getRemoteSocketAddress()).thenReturn(new InetSocketAddress("example2.com", 8888));
+ assertGetOrLoad(credentials, _successfulAuthenticationResult, 2);
+ }
+ @Test
+ public void testCacheHitDifferentRemoteAddressPorts() throws Exception
+ {
+ final int expectedHitCount = 1;
+ final AuthenticationResult expectedResult = _successfulAuthenticationResult;
+ final String credentials = "credentials";
+
+ assertGetOrLoad(credentials, expectedResult, expectedHitCount);
when(_connection.getRemoteSocketAddress()).thenReturn(new InetSocketAddress("example.com", 8888));
- Subject.doAs(_subject, new PrivilegedAction<Void>()
- {
- @Override
- public Void run()
- {
- AuthenticationResult result;
- result = _authenticationResultCacher.getOrLoad(new String[]{"credentials"}, _loader);
- assertEquals("Unexpected AuthenticationResult", _successfulAuthenticationResult, result);
- assertEquals("Unexpected number of loads before cache hit", (long) 2, (long) _loadCallCount);
- return null;
- }
+ assertGetOrLoad(credentials, expectedResult, expectedHitCount);
+ }
+
+ private void assertGetOrLoad(final String credentials,
+ final AuthenticationResult expectedResult,
+ final int expectedHitCount)
+ {
+ Subject.doAs(_subject, (PrivilegedAction<Void>) () -> {
+ AuthenticationResult result;
+ result = _authenticationResultCacher.getOrLoad(new String[]{credentials}, _loader);
+ assertEquals("Unexpected AuthenticationResult", expectedResult, result);
+ assertEquals("Unexpected number of loads before cache hit", (long)expectedHitCount, (long) _loadCallCount);
+ return null;
});
}
}