add Auth Context to store request header(Authorization) (#76)

diff --git a/pom.xml b/pom.xml
index 244ce64..e2a8639 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
 
     <groupId>com.baidu.hugegraph</groupId>
     <artifactId>hugegraph-common</artifactId>
-    <version>1.8.8</version>
+    <version>1.8.9</version>
 
     <name>hugegraph-common</name>
     <url>https://github.com/hugegraph/hugegraph-common</url>
@@ -287,7 +287,7 @@
                         <manifestEntries>
                             <!-- Must be on one line, otherwise the automatic
                                  upgrade script cannot replace the version number -->
-                            <Implementation-Version>1.8.8.0</Implementation-Version>
+                            <Implementation-Version>1.8.9.0</Implementation-Version>
                         </manifestEntries>
                     </archive>
                 </configuration>
diff --git a/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java b/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
index 3507169..1048fc1 100644
--- a/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
+++ b/src/main/java/com/baidu/hugegraph/rest/AbstractRestClient.java
@@ -46,7 +46,9 @@
 import javax.ws.rs.core.Response;
 import javax.ws.rs.core.Variant;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.http.HttpHeaders;
 import org.apache.http.config.Registry;
 import org.apache.http.config.RegistryBuilder;
 import org.apache.http.conn.socket.ConnectionSocketFactory;
@@ -234,15 +236,20 @@
 
     @Override
     public RestResult get(String path) {
-        Response response = this.request(() -> {
-            return this.target.path(path).request().get();
-        });
-        checkStatus(response, Response.Status.OK);
-        return new RestResult(response);
+        return this.get(path, null, ImmutableMap.of());
     }
 
     @Override
     public RestResult get(String path, Map<String, Object> params) {
+        return this.get(path, null, params);
+    }
+
+    @Override
+    public RestResult get(String path, String id) {
+        return this.get(path, id, ImmutableMap.of());
+    }
+
+    private RestResult get(String path, String id, Map<String, Object> params) {
         Ref<WebTarget> target = Refs.of(this.target);
         for (String key : params.keySet()) {
             Object value = params.get(key);
@@ -254,41 +261,44 @@
                 target.set(target.get().queryParam(key, value));
             }
         }
-        Response response = this.request(() -> {
-            return target.get().path(path).request().get();
-        });
-        checkStatus(response, Response.Status.OK);
-        return new RestResult(response);
-    }
 
-    @Override
-    public RestResult get(String path, String id) {
         Response response = this.request(() -> {
-            return this.target.path(path).path(encode(id)).request().get();
+            WebTarget webTarget = target.get();
+            Builder builder = id == null ? webTarget.path(path).request() :
+                              webTarget.path(path).path(encode(id)).request();
+            this.attachAuthToRequest(builder);
+            return builder.get();
         });
+
         checkStatus(response, Response.Status.OK);
         return new RestResult(response);
     }
 
     @Override
     public RestResult delete(String path, Map<String, Object> params) {
-        Ref<WebTarget> target = Refs.of(this.target);
-        for (String key : params.keySet()) {
-            target.set(target.get().queryParam(key, params.get(key)));
-        }
-        Response response = this.request(() -> {
-            return target.get().path(path).request().delete();
-        });
-        checkStatus(response, Response.Status.NO_CONTENT,
-                    Response.Status.ACCEPTED);
-        return new RestResult(response);
+        return this.delete(path, null, params);
     }
 
     @Override
     public RestResult delete(String path, String id) {
+        return this.delete(path, id, ImmutableMap.of());
+    }
+
+    private RestResult delete(String path, String id,
+                              Map<String, Object> params) {
+        Ref<WebTarget> target = Refs.of(this.target);
+        for (String key : params.keySet()) {
+            target.set(target.get().queryParam(key, params.get(key)));
+        }
+
         Response response = this.request(() -> {
-            return this.target.path(path).path(encode(id)).request().delete();
+            WebTarget webTarget = target.get();
+            Builder builder = id == null ? webTarget.path(path).request() :
+                              webTarget.path(path).path(encode(id)).request();
+            this.attachAuthToRequest(builder);
+            return builder.delete();
         });
+
         checkStatus(response, Response.Status.NO_CONTENT,
                     Response.Status.ACCEPTED);
         return new RestResult(response);
@@ -303,6 +313,29 @@
         this.client.close();
     }
 
+    private final ThreadLocal<String> authContext =
+                                      new InheritableThreadLocal<>();
+
+    public void setAuthContext(String auth) {
+        this.authContext.set(auth);
+    }
+
+    public void resetAuthContext() {
+        this.authContext.remove();
+    }
+
+    public String getAuthContext() {
+        return this.authContext.get();
+    }
+
+    private void attachAuthToRequest(Builder builder) {
+        // Add auth header
+        String auth = this.getAuthContext();
+        if (StringUtils.isNotEmpty(auth)) {
+            builder.header(HttpHeaders.AUTHORIZATION, auth);
+        }
+    }
+
     private Pair<Builder, Entity<?>> buildRequest(
                                      String path, String id, Object object,
                                      MultivaluedMap<String, Object> headers,
@@ -323,6 +356,8 @@
             builder = builder.headers(headers);
             encoding = (String) headers.getFirst("Content-Encoding");
         }
+        // Add auth header
+        this.attachAuthToRequest(builder);
 
         /*
          * We should specify the encoding of the entity object manually,
diff --git a/src/main/java/com/baidu/hugegraph/version/CommonVersion.java b/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
index 679ccff..773eb36 100644
--- a/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
+++ b/src/main/java/com/baidu/hugegraph/version/CommonVersion.java
@@ -27,5 +27,5 @@
 
     // The second parameter of Version.of() is for all-in-one JAR
     public static final Version VERSION = Version.of(CommonVersion.class,
-                                                     "1.8.8");
+                                                     "1.8.9");
 }
diff --git a/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java b/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
index 536ee94..4821f1d 100644
--- a/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
+++ b/src/test/java/com/baidu/hugegraph/unit/concurrent/RowLockTest.java
@@ -122,6 +122,23 @@
     }
 
     @Test
+    public void testRowLockWithMultiThreadsLockOneKey() {
+        RowLock<Integer> lock = new RowLock<>();
+        Set<String> names = new HashSet<>(THREADS_NUM);
+
+        Assert.assertEquals(0, names.size());
+
+        Integer key = 1;
+        runWithThreads(THREADS_NUM, () -> {
+            lock.lock(key);
+            names.add(Thread.currentThread().getName());
+            lock.unlock(key);
+        });
+
+        Assert.assertEquals(THREADS_NUM, names.size());
+    }
+
+    @Test
     public void testRowLockWithMultiThreadsWithRandomKey() {
         RowLock<Integer> lock = new RowLock<>();
         Set<String> names = new HashSet<>(THREADS_NUM);
diff --git a/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java b/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
index ab3c419..a9f4d97 100644
--- a/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
+++ b/src/test/java/com/baidu/hugegraph/unit/rest/RestClientTest.java
@@ -22,6 +22,7 @@
 import java.security.NoSuchAlgorithmException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.UUID;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
@@ -31,15 +32,20 @@
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLSession;
 import javax.net.ssl.SSLSessionContext;
+import javax.ws.rs.client.Entity;
+import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.client.Invocation.Builder;
 import javax.ws.rs.core.MultivaluedHashMap;
 import javax.ws.rs.core.MultivaluedMap;
 import javax.ws.rs.core.Response;
 
 import org.apache.http.HttpClientConnection;
+import org.apache.http.HttpHeaders;
 import org.apache.http.HttpHost;
 import org.apache.http.conn.routing.HttpRoute;
 import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
 import org.apache.http.pool.PoolStats;
+import org.glassfish.jersey.client.ClientRequest;
 import org.glassfish.jersey.internal.util.collection.ImmutableMultivaluedMap;
 import org.junit.Test;
 import org.mockito.Mockito;
@@ -404,4 +410,160 @@
         Assert.assertNotNull(cleanExecutor);
         Assert.assertTrue(cleanExecutor.isShutdown());
     }
+
+    @Test
+    public void testAuthContext() {
+        RestClientImpl client = new RestClientImpl("/test", 1000, 10, 5, 200);
+        Assert.assertNull(client.getAuthContext());
+
+        String token = UUID.randomUUID().toString();
+        client.setAuthContext(token);
+        Assert.assertEquals(token, client.getAuthContext());
+
+        client.resetAuthContext();
+        Assert.assertNull(client.getAuthContext());
+    }
+
+    private static class MockRestClientImpl extends AbstractRestClient {
+
+        public MockRestClientImpl(String url, int timeout) {
+            super(url, timeout);
+        }
+
+        @Override
+        protected void checkStatus(Response response,
+                                   Response.Status... statuses) {
+            // pass
+        }
+    }
+
+    @Test
+    public void testRequest() {
+        MockRestClientImpl client = new MockRestClientImpl("test", 1000);
+
+        WebTarget target = Mockito.mock(WebTarget.class);
+        Builder builder = Mockito.mock(Builder.class);
+
+        Mockito.when(target.path("test")).thenReturn(target);
+        Mockito.when(target.path("test")
+                           .path(AbstractRestClient.encode("id")))
+               .thenReturn(target);
+        Mockito.when(target.path("test").request()).thenReturn(builder);
+        Mockito.when(target.path("test")
+                           .path(AbstractRestClient.encode("id"))
+                           .request())
+               .thenReturn(builder);
+
+        Response response = Mockito.mock(Response.class);
+        Mockito.when(response.getStatus()).thenReturn(200);
+        Mockito.when(response.getHeaders())
+               .thenReturn(new MultivaluedHashMap<>());
+        Mockito.when(response.readEntity(String.class)).thenReturn("content");
+
+        Mockito.when(builder.delete()).thenReturn(response);
+        Mockito.when(builder.get()).thenReturn(response);
+        Mockito.when(builder.put(Mockito.any())).thenReturn(response);
+        Mockito.when(builder.post(Mockito.any())).thenReturn(response);
+
+        Whitebox.setInternalState(client, "target", target);
+
+        RestResult result;
+
+        // Test delete
+        client.setAuthContext("token1");
+        result = client.delete("test", ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token1");
+        client.resetAuthContext();
+
+        client.setAuthContext("token2");
+        result = client.delete("test", "id");
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token2");
+        client.resetAuthContext();
+
+        // Test get
+        client.setAuthContext("token3");
+        result = client.get("test");
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token3");
+        client.resetAuthContext();
+
+        client.setAuthContext("token4");
+        result = client.get("test", ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token4");
+        client.resetAuthContext();
+
+        client.setAuthContext("token5");
+        result = client.get("test", "id");
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token5");
+        client.resetAuthContext();
+
+        // Test put
+        client.setAuthContext("token6");
+        result = client.post("test", new Object());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token6");
+        client.resetAuthContext();
+
+        client.setAuthContext("token7");
+        result = client.post("test", new Object(), new MultivaluedHashMap<>());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token7");
+        client.resetAuthContext();
+
+        client.setAuthContext("token8");
+        result = client.post("test", new Object(), ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token8");
+        client.resetAuthContext();
+
+        client.setAuthContext("token9");
+        result = client.post("test", new Object(), new MultivaluedHashMap<>(),
+                             ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token9");
+        client.resetAuthContext();
+
+        // Test post
+        client.setAuthContext("token10");
+        result = client.post("test", new Object());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token10");
+        client.resetAuthContext();
+
+        client.setAuthContext("token11");
+        result = client.post("test", new Object(), new MultivaluedHashMap<>());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token11");
+        client.resetAuthContext();
+
+        client.setAuthContext("token12");
+        result = client.post("test", new Object(), ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token12");
+        client.resetAuthContext();
+
+        client.setAuthContext("token13");
+        result = client.post("test", new Object(), new MultivaluedHashMap<>(),
+                             ImmutableMap.of());
+        Assert.assertEquals(200, result.status());
+        Mockito.verify(builder).header(HttpHeaders.AUTHORIZATION,
+                                       "token13");
+        client.resetAuthContext();
+    }
 }