add Auth Context to store request header(Authorization) (#76)
diff --git a/pom.xml b/pom.xml
index 244ce64..e2a8639 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
<groupId>com.baidu.hugegraph</groupId>
<artifactId>hugegraph-common</artifactId>
- <version>1.8.8</version>
+ <version>1.8.9</version>
<name>hugegraph-common</name>
<url>https://github.com/hugegraph/hugegraph-common</url>
@@ -287,7 +287,7 @@
<manifestEntries>
<!-- Must be on one line, otherwise the automatic
upgrade script cannot replace the version number -->
- <Implementation-Version>1.8.8.0</Implementation-Version>
+ <Implementation-Version>1.8.9.0</Implementation-Version>
</manifestEntries>
</archive>
</configuration>
diff --git a/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java b/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
index 3507169..1048fc1 100644
--- a/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
+++ b/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
@@ -46,7 +46,9 @@
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Variant;
+import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.http.HttpHeaders;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
@@ -234,15 +236,20 @@
@Override
public RestResult get(String path) {
- Response response = this.request(() -> {
- return this.target.path(path).request().get();
- });
- checkStatus(response, Response.Status.OK);
- return new RestResult(response);
+ return this.get(path, null, ImmutableMap.of());
}
@Override
public RestResult get(String path, Map<String, Object> params) {
+ return this.get(path, null, params);
+ }
+
+ @Override
+ public RestResult get(String path, String id) {
+ return this.get(path, id, ImmutableMap.of());
+ }
+
+ private RestResult get(String path, String id, Map<String, Object> params) {
Ref<WebTarget> target = Refs.of(this.target);
for (String key : params.keySet()) {
Object value = params.get(key);
@@ -254,41 +261,44 @@
target.set(target.get().queryParam(key, value));
}
}
- Response response = this.request(() -> {
- return target.get().path(path).request().get();
- });
- checkStatus(response, Response.Status.OK);
- return new RestResult(response);
- }
- @Override
- public RestResult get(String path, String id) {
Response response = this.request(() -> {
- return this.target.path(path).path(encode(id)).request().get();
+ WebTarget webTarget = target.get();
+ Builder builder = id == null ? webTarget.path(path).request() :
+ webTarget.path(path).path(encode(id)).request();
+ this.attachAuthToRequest(builder);
+ return builder.get();
});
+
checkStatus(response, Response.Status.OK);
return new RestResult(response);
}
@Override
public RestResult delete(String path, Map<String, Object> params) {
- Ref<WebTarget> target = Refs.of(this.target);
- for (String key : params.keySet()) {
- target.set(target.get().queryParam(key, params.get(key)));
- }
- Response response = this.request(() -> {
- return target.get().path(path).request().delete();
- });
- checkStatus(response, Response.Status.NO_CONTENT,
- Response.Status.ACCEPTED);
- return new RestResult(response);
+ return this.delete(path, null, params);
}
@Override
public RestResult delete(String path, String id) {
+ return this.delete(path, id, ImmutableMap.of());
+ }
+
+ private RestResult delete(String path, String id,
+ Map<String, Object> params) {
+ Ref<WebTarget> target = Refs.of(this.target);
+ for (String key : params.keySet()) {
+ target.set(target.get().queryParam(key, params.get(key)));
+ }
+
Response response = this.request(() -> {
- return this.target.path(path).path(encode(id)).request().delete();
+ WebTarget webTarget = target.get();
+ Builder builder = id == null ? webTarget.path(path).request() :
+ webTarget.path(path).path(encode(id)).request();
+ this.attachAuthToRequest(builder);
+ return builder.delete();
});
+
checkStatus(response, Response.Status.NO_CONTENT,
Response.Status.ACCEPTED);
return new RestResult(response);
@@ -303,6 +313,29 @@
this.client.close();
}
+ private final ThreadLocal<String> authContext =
+ new InheritableThreadLocal<>();
+
+ public void setAuthContext(String auth) {
+ this.authContext.set(auth);
+ }
+
+ public void resetAuthContext() {
+ this.authContext.remove();
+ }
+
+ public String getAuthContext() {
+ return this.authContext.get();
+ }
+
+ private void attachAuthToRequest(Builder builder) {
+ // Add auth header
+ String auth = this.getAuthContext();
+ if (StringUtils.isNotEmpty(auth)) {
+ builder.header(HttpHeaders.AUTHORIZATION, auth);
+ }
+ }
+
private Pair<Builder, Entity<?>> buildRequest(
String path, String id, Object object,
MultivaluedMap<String, Object> headers,
@@ -323,6 +356,8 @@
builder = builder.headers(headers);
encoding = (String) headers.getFirst("Content-Encoding");
}
+ // Add auth header
+ this.attachAuthToRequest(builder);
/*
* We should specify the encoding of the entity object manually,
diff --git a/src/main/java/com/baidu/hugegraph/version/CommonVersion.java b/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
index 679ccff..773eb36 100644
--- a/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
+++ b/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
@@ -27,5 +27,5 @@
// The second parameter of Version.of() is for all-in-one JAR
public static final Version VERSION = Version.of(CommonVersion.class,
- "1.8.8");
+ "1.8.9");
}
diff --git a/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java b/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
index 536ee94..4821f1d 100644
--- a/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
+++ b/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
@@ -122,6 +122,23 @@
}
@Test
+ public void testRowLockWithMultiThreadsLockOneKey() {
+ RowLock<Integer> lock = new RowLock<>();
+ Set<String> names = new HashSet<>(THREADS_NUM);
+
+ Assert.assertEquals(0, names.size());
+
+ Integer key = 1;
+ runWithThreads(THREADS_NUM, () -> {
+ lock.lock(key);
+ names.add(Thread.currentThread().getName());
+ lock.unlock(key);
+ });
+
+ Assert.assertEquals(THREADS_NUM, names.size());
+ }
+
+ @Test
public void testRowLockWithMultiThreadsWithRandomKey() {
RowLock<Integer> lock = new RowLock<>();
Set<String> names = new HashSet<>(THREADS_NUM);
diff --git a/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java b/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
index ab3c419..a9f4d97 100644
--- a/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
+++ b/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
@@ -22,6 +22,7 @@
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
+import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -31,15 +32,20 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
+import javax.ws.rs.client.Entity;
+import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.client.Invocation.Builder;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import org.apache.http.HttpClientConnection;
+import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.conn.routing.HttpRoute;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.pool.PoolStats;
+import org.glassfish.jersey.client.ClientRequest;
import org.glassfish.jersey.internal.util.collection.ImmutableMultivaluedMap;
import org.junit.Test;
import org.mockito.Mockito;
@@ -404,4 +410,160 @@
Assert.assertNotNull(cleanExecutor);
Assert.assertTrue(cleanExecutor.isShutdown());
}
+
+ @Test
+ public void testAuthContext() {
+ RestClientImpl client = new RestClientImpl("/test", 1000, 10, 5, 200);
+ Assert.assertNull(client.getAuthContext());
+
+ String token = UUID.randomUUID().toString();
+ client.setAuthContext(token);
+ Assert.assertEquals(token, client.getAuthContext());
+
+ client.resetAuthContext();
+ Assert.assertNull(client.getAuthContext());
+ }
+
+ private static class MockRestClientImpl extends AbstractRestClient {
+
+ public MockRestClientImpl(String url, int timeout) {
+ super(url, timeout);
+ }
+
+ @Override
+ protected void checkStatus(Response response,
+ Response.Status... statuses) {
+ // pass
+ }
+ }
+
+ @Test
+ public void testRequest() {
+ MockRestClientImpl client = new MockRestClientImpl("test", 1000);
+
+ WebTarget target = Mockito.mock(WebTarget.class);
+ Builder builder = Mockito.mock(Builder.class);
+
+ Mockito.when(target.path("test")).thenReturn(target);
+ Mockito.when(target.path("test")
+ .path(AbstractRestClient.encode("id")))
+ .thenReturn(target);
+ Mockito.when(target.path("test").request()).thenReturn(builder);
+ Mockito.when(target.path("test")
+ .path(AbstractRestClient.encode("id"))
+ .request())
+ .thenReturn(builder);
+
+ Response response = Mockito.mock(Response.class);
+ Mockito.when(response.getStatus()).thenReturn(200);
+ Mockito.when(response.getHeaders())
+ .thenReturn(new MultivaluedHashMap<>());
+ Mockito.when(response.readEntity(String.class)).thenReturn("content");
+
+ Mockito.when(builder.delete()).thenReturn(response);
+ Mockito.when(builder.get()).thenReturn(response);
+ Mockito.when(builder.put(Mockito.any())).thenReturn(response);
+ Mockito.when(builder.post(Mockito.any())).thenReturn(response);
+
+ Whitebox.setInternalState(client, "target", target);
+
+ RestResult result;
+
+ // Test delete
+ client.setAuthContext("token1");
+ result = client.delete("test", ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token1");
+ client.resetAuthContext();
+
+ client.setAuthContext("token2");
+ result = client.delete("test", "id");
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token2");
+ client.resetAuthContext();
+
+ // Test get
+ client.setAuthContext("token3");
+ result = client.get("test");
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token3");
+ client.resetAuthContext();
+
+ client.setAuthContext("token4");
+ result = client.get("test", ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token4");
+ client.resetAuthContext();
+
+ client.setAuthContext("token5");
+ result = client.get("test", "id");
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token5");
+ client.resetAuthContext();
+
+ // Test put
+ client.setAuthContext("token6");
+ result = client.post("test", new Object());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token6");
+ client.resetAuthContext();
+
+ client.setAuthContext("token7");
+ result = client.post("test", new Object(), new MultivaluedHashMap<>());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token7");
+ client.resetAuthContext();
+
+ client.setAuthContext("token8");
+ result = client.post("test", new Object(), ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token8");
+ client.resetAuthContext();
+
+ client.setAuthContext("token9");
+ result = client.post("test", new Object(), new MultivaluedHashMap<>(),
+ ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token9");
+ client.resetAuthContext();
+
+ // Test post
+ client.setAuthContext("token10");
+ result = client.post("test", new Object());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token10");
+ client.resetAuthContext();
+
+ client.setAuthContext("token11");
+ result = client.post("test", new Object(), new MultivaluedHashMap<>());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token11");
+ client.resetAuthContext();
+
+ client.setAuthContext("token12");
+ result = client.post("test", new Object(), ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token12");
+ client.resetAuthContext();
+
+ client.setAuthContext("token13");
+ result = client.post("test", new Object(), new MultivaluedHashMap<>(),
+ ImmutableMap.of());
+ Assert.assertEquals(200, result.status());
+ Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+ "token13");
+ client.resetAuthContext();
+ }
}