IMPALA-14083: Connected user and session user mismatch when cookie based
authentication is used with SPNEGO
IMPALA-11298 allowed comparing the short user names of the connected
user and the session user to support proxy clients like Hue, which could
potentially use different physical hosts for queries/requests from the
same session.
When cookie based authentication is used, the 'kerberos_user_short' is
not set on the ConnectionContext and as a result 'connected_user_short'
is not set in SessionState. This can cause a mismatch when comparing
short user names from ConnectionContext and SessionState. This happens
because the original connection authenticated using SPNEGO will have
'kerberos_user_short' in the ConnectionContext, while the other
connections authenticated using cookies won't have 'kerberos_user_short'
set in the ConnectionContext.
This patch addresses this issue by setting 'kerberos_user_short' in
ConnectionContext when the connection is authenticated with an auth
cookie that was generated after SPNEGO authentication. The short user
name is derived from the username stored in the 'impala.auth' cookie,
whose value now also records the authentication mechanism as
'a=<AUTH_MECHANISM>'.
Testing:
- Added a SpnegoAuthTest which simulates a Knox-like proxy client that
uses SPNEGO to connect to Impala and also uses authentication cookies.
The test runs concurrent SQL clients, similar to real-world scenarios.
Without the fix, the test fails with the error:
The user authorized on the connection '<username>' does not match the
session username ''
Change-Id: Id7223e449c32484bfd2295f7a9e728b7c02637e9
Reviewed-on: http://gerrit.cloudera.org:8080/22986
Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
Reviewed-by: Jason Fehr <jfehr@cloudera.com>
diff --git a/be/src/rpc/authentication-util.cc b/be/src/rpc/authentication-util.cc
index 05ddc09..8538289 100644
--- a/be/src/rpc/authentication-util.cc
+++ b/be/src/rpc/authentication-util.cc
@@ -51,6 +51,7 @@
static const string USERNAME_KEY = "u=";
static const string TIMESTAMP_KEY = "t=";
static const string RAND_KEY = "r=";
+static const string AUTH_MECH_KEY = "a=";
// Cookies generated and processed by the HTTP server will be of the form:
// COOKIE_NAME=<cookie>
@@ -68,7 +69,7 @@
Status AuthenticateCookie(
const AuthenticationHash& hash, const string& cookie_header,
- string* username, string* rand) {
+ string* username, string* authMech, string* rand) {
// The 'Cookie' header allows sending multiple name/value pairs separated by ';'.
vector<string> cookies = strings::Split(cookie_header, ";");
if (cookies.size() > MAX_COOKIES_TO_CHECK) {
@@ -107,9 +108,9 @@
return Status("The signature is incorrect.");
}
- // Split the cookie value into username, timestamp, and random number.
+ // Split the cookie value into username, timestamp, random number and auth mechanism.
vector<string> cookie_value_split = Split(cookie_value, COOKIE_SEPARATOR);
- if (cookie_value_split.size() != 3) {
+ if (cookie_value_split.size() != 4) {
return Status("The cookie value has an invalid format.");
}
string timestamp;
@@ -133,6 +134,9 @@
return Status("The cookie rand value has an invalid format.");
}
}
+ if (!TryStripPrefixString(cookie_value_split[3], AUTH_MECH_KEY, authMech)) {
+ return Status("The cookie authMech value has an invalid format.");
+ }
// We've successfully authenticated.
return Status::OK();
} else {
@@ -144,7 +148,7 @@
}
string GenerateCookie(const string& username, const AuthenticationHash& hash,
- std::string* srand) {
+ const std::string& authMech, std::string* srand) {
// Its okay to use rand() here even though its a weak RNG because being able to guess
// the random numbers generated won't help an attacker. The important thing is that
// we're using a strong RNG to create the key and a strong HMAC function.
@@ -154,7 +158,8 @@
*srand = cookie_rand_s;
}
string cookie_value = StrCat(USERNAME_KEY, username, COOKIE_SEPARATOR, TIMESTAMP_KEY,
- MonotonicMillis(), COOKIE_SEPARATOR, RAND_KEY, cookie_rand_s);
+ MonotonicMillis(), COOKIE_SEPARATOR, RAND_KEY, cookie_rand_s, COOKIE_SEPARATOR,
+ AUTH_MECH_KEY, authMech);
uint8_t signature[AuthenticationHash::HashLen()];
Status compute_status =
hash.Compute(reinterpret_cast<const uint8_t*>(cookie_value.data()),
diff --git a/be/src/rpc/authentication-util.h b/be/src/rpc/authentication-util.h
index eec4976..37a62b3 100644
--- a/be/src/rpc/authentication-util.h
+++ b/be/src/rpc/authentication-util.h
@@ -23,17 +23,30 @@
class AuthenticationHash;
+// HTTP Auth Mechanisms used for generating Auth Cookies
+inline const std::string HTTP_AUTH_MECH_HTPASSWD = "HTPASSWD";
+inline const std::string HTTP_AUTH_MECH_LDAP = "LDAP";
+inline const std::string HTTP_AUTH_MECH_TRUSTED_DOMAIN = "TRUSTED_DOMAIN";
+inline const std::string HTTP_AUTH_MECH_TRUSTED_HEADER = "TRUSTED_HEADER";
+inline const std::string HTTP_AUTH_MECH_SPNEGO = "SPNEGO";
+inline const std::string HTTP_AUTH_MECH_SAML = "SAML";
+inline const std::string HTTP_AUTH_MECH_JWT = "JWT";
+inline const std::string HTTP_AUTH_MECH_OAUTH = "OAUTH";
+
// Takes a single 'key=value' pair from a 'Cookie' header and attempts to verify its
// signature with 'hash'. If verification is successful and the cookie is still valid,
-// sets 'username' and 'rand' (if specified) to the corresponding values and returns OK.
+// sets 'username', 'authMech' and 'rand' (if specified) to the corresponding values
+// and returns OK.
Status AuthenticateCookie(
const AuthenticationHash& hash, const std::string& cookie_header,
- std::string* username, std::string* rand = nullptr);
+ std::string* username, std::string* authMech, std::string* rand = nullptr);
-// Generates and returns a cookie containing the username set on 'connection_context' and
-// a signature generated with 'hash'. If specified, sets 'rand' to the 'r=' cookie value.
+// Generates and returns a cookie containing the username set on 'connection_context',
+// a signature generated with 'hash' and the authentication mechanism used for
+// authenticating the given user with username. If specified, sets 'rand' to the 'r='
+// cookie value.
std::string GenerateCookie(const std::string& username, const AuthenticationHash& hash,
- std::string* rand = nullptr);
+ const std::string& authMech, std::string* rand = nullptr);
// Returns a empty cookie. Returned in a 'Set-Cookie' when cookie auth fails to indicate
// to the client that the cookie should be deleted.
diff --git a/be/src/rpc/authentication.cc b/be/src/rpc/authentication.cc
index 7f17393..8a879d5 100644
--- a/be/src/rpc/authentication.cc
+++ b/be/src/rpc/authentication.cc
@@ -662,9 +662,18 @@
bool CookieAuth(ThriftServer::ConnectionContext* connection_context,
const AuthenticationHash& hash, const std::string& cookie_header) {
string username;
- Status cookie_status = AuthenticateCookie(hash, cookie_header, &username);
+ string authMech;
+ Status cookie_status = AuthenticateCookie(hash, cookie_header, &username, &authMech);
if (cookie_status.ok()) {
connection_context->username = username;
+ if (authMech == HTTP_AUTH_MECH_SPNEGO) {
+ connection_context->kerberos_user_principal = username;
+ connection_context->kerberos_user_short =
+ GetShortUsernameFromKerberosPrincipal(username);
+ VLOG(2) << "Connection authenticated with "
+ << "short username \"" << connection_context->kerberos_user_short << "\" "
+ << "parsed from principal \"" << username << "\" ";
+ }
return true;
}
@@ -723,7 +732,9 @@
}
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(connection_context->username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(connection_context->username, hash,
+ HTTP_AUTH_MECH_TRUSTED_DOMAIN)));
return true;
}
@@ -734,7 +745,9 @@
}
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(connection_context->username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(connection_context->username, hash,
+ HTTP_AUTH_MECH_TRUSTED_HEADER)));
return true;
}
@@ -755,7 +768,8 @@
connection_context->username = username;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(username, hash, HTTP_AUTH_MECH_LDAP)));
if (!FLAGS_test_cookie.empty()) {
connection_context->return_headers.push_back(
Substitute("Set-Cookie: $0", FLAGS_test_cookie));
@@ -803,7 +817,7 @@
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash, HTTP_AUTH_MECH_JWT)));
return true;
}
@@ -845,7 +859,7 @@
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash, HTTP_AUTH_MECH_OAUTH)));
return true;
}
@@ -913,7 +927,8 @@
connection_context->kerberos_user_short = short_user;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(username, hash, HTTP_AUTH_MECH_SPNEGO)));
}
}
} else {
@@ -1035,7 +1050,7 @@
connection_context->username = username;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash, HTTP_AUTH_MECH_SAML)));
return true;
}
diff --git a/be/src/util/webserver-test.cc b/be/src/util/webserver-test.cc
index 7123dba..ad8d44c 100644
--- a/be/src/util/webserver-test.cc
+++ b/be/src/util/webserver-test.cc
@@ -460,20 +460,37 @@
const filesystem::path& path() { return path_; }
string token() {
const char* rand_key = "&r=";
+ const char* auth_mech_key = "&a=";
string rand, line;
ifstream cookie_file(path_.string());
while (cookie_file) {
getline(cookie_file, line);
size_t rand_idx = line.rfind(rand_key);
- if (rand_idx != string::npos) {
- // Relies on the random value being the last element in the cookie.
- rand = line.substr(rand_idx + strlen(rand_key));
+ size_t auth_mech_idx = line.rfind(auth_mech_key);
+ if (rand_idx != string::npos && auth_mech_idx != string::npos) {
+ // Relies on the random value being followed by auth mech in the cookie.
+ size_t rand_val_idx = rand_idx + strlen(rand_key);
+ rand = line.substr(rand_val_idx, auth_mech_idx - rand_val_idx);
break;
}
}
return rand;
}
-
+ string auth_mech() {
+ const char* auth_mech_key = "&a=";
+ string authmech, line;
+ ifstream cookie_file(path_.string());
+ while (cookie_file) {
+ getline(cookie_file, line);
+ size_t auth_mech_idx = line.rfind(auth_mech_key);
+ if (auth_mech_idx != string::npos) {
+ // Relies on auth mech being the last element in the cookie.
+ authmech = line.substr(auth_mech_idx + strlen(auth_mech_key));
+ break;
+ }
+ }
+ return authmech;
+ }
private:
const filesystem::path dir_, path_;
};
@@ -531,6 +548,8 @@
// curl does not do the initial attempt without authentication, so there is no
// additional failed auth attempt.
CheckAuthMetrics(&metrics, 1, (curl_7_64_or_above ? 1 : 2), 1, 0);
+ // Validate authentication mechanism stored in the cookie
+ ASSERT_EQ(cookie.auth_mech(), "SPNEGO");
webserver.Stop();
MetricGroup metrics2("webserver-test");
@@ -578,6 +597,8 @@
CookieJar cookie;
// GET with SPNEGO succeeds and returns a cookie.
ASSERT_EQ(system(curl("--negotiate -u : -c " + cookie.path().string()).c_str()), 0);
+ // Validate authentication mechanism stored in the cookie.
+ ASSERT_EQ(cookie.auth_mech(), "SPNEGO");
// Verify we got a cookie and can read the random token.
string token = cookie.token();
ASSERT_FALSE(token.empty());
@@ -629,13 +650,14 @@
// GET with user and password succeeds and returns a cookie.
ASSERT_EQ(system(curl(Substitute("--digest -u test:test -c $0",
cookie.path().string())).c_str()), 0);
+ // Validate authentication mechanism stored in the cookie
+ ASSERT_EQ(cookie.auth_mech(), "HTPASSWD");
// Verify we got a cookie and can read the random token.
string token = cookie.token();
ASSERT_FALSE(token.empty());
// Post with the cookie fails due to CSRF protection.
ASSERT_EQ(curl_status_code(Substitute("--digest -u test:test -b $0 -d ''",
cookie.path().string()).c_str()), "403");
-
// Include the cookie's random token as csrf_token and request should succeed.
ASSERT_EQ(system(curl(Substitute("--digest -u test:test -b $0 -d 'csrf_token=$1'",
cookie.path().string(), token)).c_str()), 0);
diff --git a/be/src/util/webserver.cc b/be/src/util/webserver.cc
index 13923a1..70b190d 100644
--- a/be/src/util/webserver.cc
+++ b/be/src/util/webserver.cc
@@ -738,9 +738,11 @@
if (!authenticated && use_cookies_) {
const char* cookie_header = sq_get_header(connection, "Cookie");
string username;
+ string auth_mech;
if (cookie_header != nullptr) {
Status cookie_status =
- AuthenticateCookie(hash_, cookie_header, &username, &cookie_rand_value);
+ AuthenticateCookie(hash_, cookie_header, &username, &auth_mech,
+ &cookie_rand_value);
if (cookie_status.ok()) {
authenticated = true;
cookie_authenticated = true;
@@ -760,7 +762,8 @@
// as browsers automatically include HTPASSWD credentials in requests, so add and use
// cookies to avoid requiring the custom header.
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers, &cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers, HTTP_AUTH_MECH_HTPASSWD,
+ &cookie_rand_value);
}
// Connections originating from trusted domains should not require authentication.
@@ -788,7 +791,8 @@
if (TrustedDomainCheck(origin, connection, request_info)) {
total_trusted_domain_check_success_->Increment(1);
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers, &cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_TRUSTED_DOMAIN, &cookie_rand_value);
}
}
}
@@ -801,7 +805,8 @@
if (GetUsernameFromAuthHeader(connection, request_info, err_msg)) {
total_trusted_auth_header_check_success_->Increment(1);
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers, &cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_TRUSTED_HEADER, &cookie_rand_value);
} else {
LOG(ERROR) << "Found trusted auth header but " << err_msg;
}
@@ -814,7 +819,8 @@
HandleSpnego(connection, request_info, &response_headers);
if (spnego_result == SQ_CONTINUE_HANDLING) {
// Spnego negotiation was successful.
- AddCookie(request_info->remote_user, &response_headers, &cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_SPNEGO, &cookie_rand_value);
} else {
// Spnego negotiation is incomplete or failed, stop processing the request.
return spnego_result;
@@ -825,7 +831,8 @@
if (basic_status.ok()) {
// Basic auth was successful.
total_basic_auth_success_->Increment(1);
- AddCookie(request_info->remote_user, &response_headers, &cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_LDAP, &cookie_rand_value);
} else {
total_basic_auth_failure_->Increment(1);
if (!sq_get_header(connection, "Authorization")) {
@@ -1158,7 +1165,7 @@
}
void Webserver::AddCookie(const char* user, vector<string>* response_headers,
- string* cookie_rand_value) {
+ const string& authMech, string* cookie_rand_value) {
if (use_cookies_) {
// If cookie auth failed and we generated a 'delete cookie' header, remove it.
auto eq = [](const string& header) { return header.rfind("Set-Cookie", 0) == 0; };
@@ -1168,7 +1175,7 @@
}
// Generate a cookie to return.
response_headers->push_back(Substitute("Set-Cookie: $0",
- GenerateCookie(user, hash_, cookie_rand_value)));
+ GenerateCookie(user, hash_, authMech, cookie_rand_value)));
}
}
diff --git a/be/src/util/webserver.h b/be/src/util/webserver.h
index bfa71f1..eb34c31 100644
--- a/be/src/util/webserver.h
+++ b/be/src/util/webserver.h
@@ -221,7 +221,8 @@
// Adds a 'Set-Cookie' header to 'response_headers', if cookie support is enabled.
// Returns the random value portion of the cookie in 'rand' for use in CSRF prevention.
- void AddCookie(const char* user, vector<string>* response_headers, string* rand);
+ void AddCookie(const char* user, vector<string>* response_headers,
+ const string& authMech, string* rand);
// Get username from Authorization header.
bool GetUsernameFromAuthHeader(struct sq_connection* connection,
diff --git a/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java b/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
index e43d6ed..d622028 100644
--- a/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
+++ b/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
@@ -53,6 +53,10 @@
this.testFolder = testFolder;
}
+ public String getTestFolderPath() throws IOException {
+ return testFolder.getRoot().getCanonicalPath();
+ }
+
@Override
protected void before() throws Throwable {
testFolder.create();
diff --git a/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java b/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java
new file mode 100644
index 0000000..09aec13
--- /dev/null
+++ b/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java
@@ -0,0 +1,415 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.impala.customcluster;
+
+import static org.apache.impala.testutil.LdapUtil.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.directory.server.core.annotations.CreateDS;
+import org.apache.directory.server.core.annotations.CreatePartition;
+import org.apache.directory.server.annotations.CreateLdapServer;
+import org.apache.directory.server.annotations.CreateTransport;
+import org.apache.directory.server.core.annotations.ApplyLdifFiles;
+import org.apache.directory.server.core.integ.CreateLdapServerRule;
+import org.apache.hive.service.rpc.thrift.*;
+import org.apache.impala.testutil.WebClient;
+import org.apache.thrift.transport.THttpClient;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.ietf.jgss.*;
+import org.junit.ClassRule;
+import org.junit.rules.TemporaryFolder;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@CreateDS(name = "myDS",
+ partitions = { @CreatePartition(name = "test", suffix = "dc=myorg,dc=com") })
+@CreateLdapServer(
+ transports = { @CreateTransport(protocol = "LDAP", address = "localhost") })
+@ApplyLdifFiles({"users.ldif"})
+/**
+ * Tests that hiveserver2 operations over the http interface work as expected when
+ * SPNEGO authentication is being used.
+ */
+public class SpnegoAuthTest {
+ private static final Logger LOG = LoggerFactory.getLogger(SpnegoAuthTest.class);
+
+ @ClassRule
+ public static CreateLdapServerRule serverRule = new CreateLdapServerRule();
+ @ClassRule
+ public static KerberosKdcEnvironment kerberosKdcEnvironment =
+ new KerberosKdcEnvironment(new TemporaryFolder());
+
+ WebClient client_ = new WebClient();
+
+ protected Map<String, String> getLdapFlags() {
+ String ldapUri = String.format("ldap://localhost:%s",
+ serverRule.getLdapServer().getPort());
+ String passwordCommand = String.format("'echo -n %s'", TEST_PASSWORD_1);
+ return ImmutableMap.<String, String>builder()
+ .put("enable_ldap_auth", "true")
+ .put("ldap_uri", ldapUri)
+ .put("ldap_bind_pattern", "cn=#UID,ou=Users,dc=myorg,dc=com")
+ .put("ldap_passwords_in_clear_ok", "true")
+ .put("ldap_bind_dn", TEST_USER_DN_1)
+ .put("ldap_bind_password_cmd", passwordCommand)
+ .build();
+ }
+
+ protected int startImpalaCluster(String args) throws IOException, InterruptedException {
+ return kerberosKdcEnvironment.startImpalaClusterWithArgs(args);
+ }
+
+ public static String flagsToArgs(Map<String, String> flags) {
+ return flags.entrySet().stream()
+ .map(entry -> "--" + entry.getKey() + "=" + entry.getValue() + " ")
+ .collect(Collectors.joining());
+ }
+
+ @SafeVarargs
+ public static Map<String, String> mergeFlags(Map<String, String>... flags) {
+ return Arrays.stream(flags)
+ .filter(Objects::nonNull)
+ .flatMap(map -> map.entrySet().stream())
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ }
+
+ static void verifySuccess(TStatus status) throws Exception {
+ if (status.getStatusCode() == TStatusCode.SUCCESS_STATUS
+ || status.getStatusCode() == TStatusCode.SUCCESS_WITH_INFO_STATUS) {
+ return;
+ }
+ throw new Exception(status.toString());
+ }
+
+ /**
+ * Executes 'query', fetches the results and closes the 'query'. Expects there to be
+   * exactly one string returned, which must be equal to 'expectedResult'.
+ */
+ static void execAndFetch(TCLIService.Iface client,
+ TSessionHandle sessionHandle, String query, String expectedResult)
+ throws Exception {
+ TOperationHandle handle = null;
+ try {
+ TExecuteStatementReq execReq = new TExecuteStatementReq(sessionHandle, query);
+ TExecuteStatementResp execResp = client.ExecuteStatement(execReq);
+ verifySuccess(execResp.getStatus());
+ handle = execResp.getOperationHandle();
+
+ TFetchResultsReq fetchReq = new TFetchResultsReq(
+ handle, TFetchOrientation.FETCH_NEXT, 1000);
+ TFetchResultsResp fetchResp = client.FetchResults(fetchReq);
+ verifySuccess(fetchResp.getStatus());
+ List<TColumn> columns = fetchResp.getResults().getColumns();
+ assertEquals(1, columns.size());
+ if (expectedResult != null) {
+ assertEquals(expectedResult, columns.get(0).getStringVal().getValues().get(0));
+ }
+ } finally {
+ if (handle != null) {
+ TCloseOperationReq closeReq = new TCloseOperationReq(handle);
+ TCloseOperationResp closeResp = client.CloseOperation(closeReq);
+ verifySuccess(closeResp.getStatus());
+ }
+ }
+ }
+
+ private void verifyNegotiateAuthMetrics(
+ long expectedBasicAuthSuccess, long expectedBasicAuthFailure) throws Exception {
+ long actualBasicAuthSuccess = (long) client_.getMetric(
+ "impala.thrift-server.hiveserver2-http-frontend.total-negotiate-auth-success");
+ assertEquals(expectedBasicAuthSuccess, actualBasicAuthSuccess);
+ long actualBasicAuthFailure = (long) client_.getMetric(
+ "impala.thrift-server.hiveserver2-http-frontend.total-negotiate-auth-failure");
+ assertEquals(expectedBasicAuthFailure, actualBasicAuthFailure);
+ }
+
+ private void verifyCookieAuthMetrics(
+ long expectedCookieAuthSuccess, long expectedCookieAuthFailure) throws Exception {
+ long actualCookieAuthSuccess = (long) client_.getMetric(
+ "impala.thrift-server.hiveserver2-http-frontend.total-cookie-auth-success");
+ assertEquals(expectedCookieAuthSuccess, actualCookieAuthSuccess);
+ long actualCookieAuthFailure = (long) client_.getMetric(
+ "impala.thrift-server.hiveserver2-http-frontend.total-cookie-auth-failure");
+ assertEquals(expectedCookieAuthFailure, actualCookieAuthFailure);
+ }
+
+ @Test
+ /**
+ * Tests Authentication flow using a proxy client such as Knox, which uses SPNEGO Auth
+ * to connect to Impala and impersonates other users. Initial Authentication is done
+ * through SPNEGO and follow on requests are authenticated using Auth cookies. The test
+ * uses multiple clients sharing the same Auth cookie similar to what a proxy client
+ * would do and as a result adds coverage for interesting scenarios where OpenSession
+ * RPC could also use Auth Cookies.
+ */
+ public void testImpersonation() throws Exception, Throwable {
+ Map<String, String> flags = mergeFlags(
+ // enable Kerberos authentication
+ kerberosKdcEnvironment.getKerberosAuthFlags(),
+ getLdapFlags(),
+ // custom LDAP filters
+ ImmutableMap.of(
+ "ldap_group_filter", String.format("%s,another-group", TEST_USER_GROUP),
+ "ldap_user_filter", String.format("%s,%s,another-user",
+ TEST_USER_1, TEST_USER_3),
+ "ldap_group_dn_pattern", GROUP_DN_PATTERN,
+ "ldap_group_membership_key", "uniqueMember",
+ "ldap_group_class_key", "groupOfUniqueNames",
+ "allow_custom_ldap_filters_with_kerberos_auth", "true",
+ // set proxy user: allow TEST_USER_4 to act as a proxy user for others
+ "authorized_proxy_user_config", String.format("%s=*", TEST_USER_4)
+ )
+ );
+ // Start Impala with configured flags.
+ int ret = startImpalaCluster(flagsToArgs(flags));
+ assertEquals(0, ret); // cluster should start up
+
+ // Open a session and authenticate using SPNEGO.
+ THttpClientWithHeaders transport =
+ new THttpClientWithHeaders("http://localhost:28000");
+ Map<String, String> headers = new HashMap<String, String>();
+ // Authenticate as the proxy user 'Test4Ldap'
+ headers.put("Authorization", "Negotiate " + getSpnegoToken(TEST_USER_4));
+ transport.setCustomHeaders(headers);
+ transport.open();
+ TCLIService.Iface client = new TCLIService.Client(new TBinaryProtocol(transport));
+
+ // Open a session without specifying a 'doas', should fail as the proxy user won't
+ // pass the filters.
+ TOpenSessionReq openReq = new TOpenSessionReq();
+ TOpenSessionResp openResp = client.OpenSession(openReq);
+ assertEquals(TStatusCode.ERROR_STATUS, openResp.getStatus().getStatusCode());
+ int negotiateAuthFailureCount = 0;
+ int negotiateAuthSuccessCount = 1;
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount, negotiateAuthFailureCount);
+ int cookieAuthFailureCount = 0;
+ int cookieAuthSuccessCount = 0;
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+
+ // SPNEGO doesn't like replay tokens, so use new tokens.
+ headers.remove("Authorization");
+ headers.put("Authorization", "Negotiate " + getSpnegoToken(TEST_USER_4));
+ // Open a session with a 'doas' that will pass both filters, should succeed.
+ Map<String, String> config = new HashMap<String, String>();
+ config.put("impala.doas.user", TEST_USER_1);
+ openReq.setConfiguration(config);
+ openResp = client.OpenSession(openReq);
+ assertEquals(TStatusCode.SUCCESS_STATUS, openResp.getStatus().getStatusCode());
+ negotiateAuthSuccessCount++;
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount, negotiateAuthFailureCount);
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+
+ // Use Auth Cookie for the remaining sessions and connections.
+ Map<String, List<String>> responseHeaders = transport.getResponseHeaders();
+ List<String> cookies = responseHeaders.get("Set-Cookie");
+ if (cookies != null) {
+ for (String cookie : cookies) {
+ String authMech = extractCookieAuthMech(cookie);
+ assertNotNull(authMech);
+ assertEquals("SPNEGO", authMech);
+ headers.put("Cookie", cookie);
+ }
+ } else {
+ fail("'Set-Cookie' cookie not returned from Impala");
+ }
+
+ // Simulate 4 concurrent clients, with each running 100 exec and fetch RPCs.
+ final int numClients = 4;
+ final int numQueries = 100;
+ ExecutorService executor = Executors.newFixedThreadPool(numClients);
+ List<Future<Void>> futures = new ArrayList<>();
+ for (int i = 0; i < numClients; i++) {
+ final int clientId = i;
+ Future<Void> future = executor.submit(() -> {
+ simulateClient(headers, config, clientId, numQueries);
+ return null;
+ });
+ futures.add(future);
+ }
+
+ executor.shutdown();
+ executor.awaitTermination(5, TimeUnit.MINUTES);
+
+ // Check for exceptions from client threads
+ for (int i = 0; i < futures.size(); i++) {
+ try {
+ futures.get(i).get();
+ } catch (ExecutionException e) {
+ Throwable cause = e.getCause();
+ System.err.println("Client " + i + " failed: " + cause.getMessage());
+ cause.printStackTrace();
+ fail("Client " + i + " failed: " + cause.getMessage());
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ System.err.println("Main thread interrupted.");
+ }
+ }
+ // Each client uses one OpenSession RPC using cookie based authentication.
+ // Each query runs one Exec, one Fetch and one Close RPC, using 3 cookie based
+ // authentications per query.
+ cookieAuthSuccessCount += numClients * (1 + numQueries * 3);
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount, negotiateAuthFailureCount);
+ }
+
+ /**
+ * Generates and returns Base64 encoded SPNEGO token for the input user.
+ */
+ private static String getSpnegoToken(String user) throws Exception {
+ // Create a test user principal and generate Kerberos credentials cache (ccache)
+ String ccacheFilePath =
+ kerberosKdcEnvironment.createUserPrincipalAndCredentialsCache(user);
+ File spngeoTokenFile =
+ new File(kerberosKdcEnvironment.getTestFolderPath() + "/spngeoToken.bin");
+    // Using ProcessBuilder to generate SPNEGO token because apparently some of the Java
+ // security classes are initialized much earlier and cannot read required kerberos
+ // config setup by the test.
+ ProcessBuilder pb = new ProcessBuilder(
+ "java", "-cp", System.getProperty("java.class.path"),
+ "-Djava.security.krb5.conf=" + kerberosKdcEnvironment.getKrb5ConfigPath(),
+ "-Dsun.security.krb5.debug=true",
+ "-Djava.security.debug=gssloginconfig,configfile,configparser,logincontext,JGSS",
+ "-Djavax.security.auth.useSubjectCredsOnly=false",
+ "org.apache.impala.customcluster.SpnegoTokenGenerator",
+ spngeoTokenFile.getCanonicalPath());
+
+ Map<String, String> env = pb.environment();
+ env.put("KRB5CCNAME", "FILE:" + ccacheFilePath);
+
+ pb.inheritIO();
+ Process process = pb.start();
+ int exitCode = process.waitFor();
+ // Non zero exit code indicates token generation failed.
+ assertEquals(0, exitCode);
+
+ byte[] token = readTokenFromFile(spngeoTokenFile.getCanonicalPath());
+ String base64Token = Base64.getEncoder().encodeToString(token);
+ return base64Token;
+ }
+
+  /**
+   * Reads the token file generated by SpnegoTokenGenerator; always closes the stream.
+   */
+  private static byte[] readTokenFromFile(String path) throws IOException {
+    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+    try (InputStream is = new FileInputStream(path)) {
+      byte[] temp = new byte[4096];
+      int bytesRead;
+      while ((bytesRead = is.read(temp)) != -1) {
+        buffer.write(temp, 0, bytesRead);
+      }
+    }
+    return buffer.toByteArray();
+  }
+
+ /**
+ * Simulates a client opening session and running a number of queries within a session.
+ */
+ private static void simulateClient(Map<String, String> headers,
+ Map<String, String> config, int clientId, int numQueries) throws Exception {
+ // Create and open the transport
+ THttpClientWithHeaders transport =
+ new THttpClientWithHeaders("http://localhost:28000");
+ transport.setCustomHeaders(headers);
+ transport.open();
+
+ // Create client stub
+ TCLIService.Iface client = new TCLIService.Client(new TBinaryProtocol(transport));
+
+ // Open a session
+ TOpenSessionReq openReq = new TOpenSessionReq();
+ openReq.setConfiguration(config);
+ TOpenSessionResp openResp = client.OpenSession(openReq);
+
+ if (openResp.getStatus().getStatusCode() != TStatusCode.SUCCESS_STATUS) {
+ throw new RuntimeException("Failed to open session for client " + clientId);
+ }
+
+ System.out.println("Client " + clientId + " opened session successfully.");
+
+ // Execute queries
+ for (int i = 0; i < numQueries; i++) {
+ execAndFetch(client, openResp.getSessionHandle(),
+ "select logged_in_user()", "Test1Ldap");
+ int sleepMillis = ThreadLocalRandom.current().nextInt(10, 100);
+ Thread.sleep(sleepMillis);
+ }
+
+ // Close transport
+ transport.close();
+ System.out.println("Client " + clientId + " finished.");
+ }
+
+ /**
+ * Extracts auth mechanism from cookie's value.
+ */
+ private static String extractCookieAuthMech(String cookie) throws Exception {
+ if (cookie == null || cookie.isEmpty()) {
+ return null;
+ }
+ // Expect cookie:
+ // impala.auth=<base64signature>&<cookie_value>;HttpOnly;Max-Age=86400;Secure.
+ String[] cookieFields = cookie.split(";");
+ if (cookieFields.length == 0) {
+ return null;
+ }
+
+    // The first token is impala.auth=<base64signature>&<cookie_value>, with a
+    // cookie_value like u=Test4Ldap@myorg.com&t=549158755&r=1800557187&a=SPNEGO.
+ String[] cookieValueFields = cookieFields[0].trim().split("&");
+ assertEquals(5, cookieValueFields.length);
+ String[] authMech = cookieValueFields[4].trim().split("=");
+ assertEquals(2, authMech.length);
+ assertEquals("a", authMech[0]);
+ return authMech[1];
+ }
+}
diff --git a/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java b/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java
new file mode 100644
index 0000000..25526d5
--- /dev/null
+++ b/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.impala.customcluster;
+
+import org.ietf.jgss.*;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+ /**
+  * Command-line helper that initiates a Kerberos V5 GSS security context towards the
+  * service principal impala/localhost@myorg.com and writes the resulting initial token
+  * to the file named by the first argument. Any failure is reported on stderr and ends
+  * the process with exit status 1.
+  */
+ public class SpnegoTokenGenerator {
+   /**
+    * @param args args[0] is the path of the file the token bytes are written to.
+    */
+   public static void main(String[] args) {
+     try {
+       if (args.length < 1) {
+         System.err.println("Missing argument: <output-token-file>");
+         System.exit(1);
+       }
+       String outputPath = args[0];
+       // OID for Kerberos V5
+       Oid krb5Oid = new Oid("1.2.840.113554.1.2.2");
+
+       // Create GSSManager
+       GSSManager manager = GSSManager.getInstance();
+
+       // This OID corresponds to NT_KRB5_PRINCIPAL
+       Oid krb5PrincipalOid = new Oid("1.2.840.113554.1.2.2.1");
+
+       // Full service principal with realm
+       String servicePrincipal = "impala/localhost@myorg.com";
+
+       // Create GSSName with full principal name
+       GSSName serverName = manager.createName(servicePrincipal, krb5PrincipalOid);
+
+       // Create security context
+       GSSContext context = manager.createContext(
+           serverName,
+           krb5Oid,
+           null, // use default credentials from ccache
+           GSSContext.DEFAULT_LIFETIME
+       );
+
+       context.requestMutualAuth(true);
+       context.requestCredDeleg(false);
+
+       // Initiate the context, which triggers ticket acquisition
+       byte[] token = context.initSecContext(new byte[0], 0, 0);
+       if (token == null) {
+         System.err.println("Failed to obtain SPNEGO token.");
+         System.exit(1);
+       }
+
+       // try-with-resources closes the stream before the catch block runs, so the file
+       // is closed on every path and a failure while closing is reported as well.
+       try (FileOutputStream fos = new FileOutputStream(outputPath)) {
+         fos.write(token);
+       } catch (IOException e) {
+         System.err.println("Failed to write token to file: " + e.getMessage());
+         e.printStackTrace();
+         System.exit(1);
+       }
+       // Only announce success once the file has been written and closed.
+       System.out.println("Token written to " + outputPath);
+       context.dispose();
+     }
+     catch (Exception e) {
+       e.printStackTrace();
+       System.exit(1);
+     }
+   }
+ }
diff --git a/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java b/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java
new file mode 100644
index 0000000..2666b18
--- /dev/null
+++ b/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// This code is copied from apache/thrift and modified to return
+// HTTP response headers from transport.
+// Original Source:
+// https://github.com/apache/thrift/blob/v0.16.0/lib/java/src/org/apache/thrift/transport/
+// THttpClient.java
+
+package org.apache.impala.customcluster;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import java.io.IOException;
+
+import java.net.URL;
+import java.net.HttpURLConnection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.http.HttpEntity;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpResponse;
+import org.apache.http.HttpStatus;
+import org.apache.http.client.HttpClient;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.apache.http.params.CoreConnectionPNames;
+import org.apache.thrift.TConfiguration;
+import org.apache.thrift.transport.TEndpointTransport;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
+
+/**
+ * HTTP implementation of the TTransport interface. Used for working with a
+ * Thrift web services implementation (using for example TServlet).
+ *
+ * This class offers two implementations of the HTTP transport.
+ * One uses HttpURLConnection instances, the other HttpClient from Apache
+ * Http Components.
+ * The chosen implementation depends on the constructor used to
+ * create the THttpClient instance.
+ * Using the THttpClient(String url) constructor or passing null as the
+ * HttpClient to THttpClient(String url, HttpClient client) will create an
+ * instance which will use HttpURLConnection.
+ *
+ * When using HttpClient, the following configuration leads to 5-15%
+ * better performance than the HttpURLConnection implementation:
+ *
+ * http.protocol.version=HttpVersion.HTTP_1_1
+ * http.protocol.content-charset=UTF-8
+ * http.protocol.expect-continue=false
+ * http.connection.stalecheck=false
+ *
+ * Also note that under high load, the HttpURLConnection implementation
+ * may exhaust the open file descriptor limit.
+ *
+ * @see <a href="https://issues.apache.org/jira/browse/THRIFT-970">THRIFT-970</a>
+ */
+
+public class THttpClientWithHeaders extends TEndpointTransport {
+
+ private URL url_ = null;
+
+ private final ByteArrayOutputStream requestBuffer_ = new ByteArrayOutputStream();
+
+ private InputStream inputStream_ = null;
+
+ private int connectTimeout_ = 0;
+
+ private int readTimeout_ = 0;
+
+ private Map<String,String> customHeaders_ = null;
+
+ // Used for storing response headers. This is not in the
+ // THttpClient.java class in the apache/thrift repository.
+ private Map<String,List<String>> responseHeaders_ = null;
+
+ private final HttpHost host;
+
+ private final HttpClient client;
+
+ /* Not compatible with thrift 0.11.0 which is used when USE_APACHE_COMPONENTS=true.
+ public static class Factory extends TTransportFactory {
+
+ private final String url;
+ private final HttpClient client;
+
+ public Factory(String url) {
+ this.url = url;
+ this.client = null;
+ }
+
+ public Factory(String url, HttpClient client) {
+ this.url = url;
+ this.client = client;
+ }
+
+ @Override
+ public TTransport getTransport(TTransport trans) {
+ try {
+ if (null != client) {
+ return new THttpClientWithHeaders(trans.getConfiguration(), url, client);
+ } else {
+ return new THttpClientWithHeaders(trans.getConfiguration(), url);
+ }
+ } catch (TTransportException tte) {
+ return null;
+ }
+ }
+ }*/
+
+ public THttpClientWithHeaders(TConfiguration config, String url)
+ throws TTransportException {
+ super(config);
+ try {
+ url_ = new URL(url);
+ this.client = null;
+ this.host = null;
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ }
+ }
+
+ public THttpClientWithHeaders(String url) throws TTransportException {
+ super(new TConfiguration());
+ try {
+ url_ = new URL(url);
+ this.client = null;
+ this.host = null;
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ }
+ }
+
+ public THttpClientWithHeaders(TConfiguration config, String url, HttpClient client)
+ throws TTransportException {
+ super(config);
+ try {
+ url_ = new URL(url);
+ this.client = client;
+ this.host = new HttpHost(url_.getHost(), -1 == url_.getPort()
+ ? url_.getDefaultPort()
+ : url_.getPort(), url_.getProtocol());
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ }
+ }
+
+ public THttpClientWithHeaders(String url, HttpClient client)
+ throws TTransportException {
+ super(new TConfiguration());
+ try {
+ url_ = new URL(url);
+ this.client = client;
+ this.host = new HttpHost(url_.getHost(), -1 == url_.getPort()
+ ? url_.getDefaultPort()
+ : url_.getPort(), url_.getProtocol());
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ }
+ }
+
+ public void setConnectTimeout(int timeout) {
+ connectTimeout_ = timeout;
+ if (null != this.client) {
+ // WARNING, this modifies the HttpClient params, this might have an impact elsewhere
+ // if the same HttpClient is used for something else.
+ client.getParams().setParameter(
+ CoreConnectionPNames.CONNECTION_TIMEOUT, connectTimeout_);
+ }
+ }
+
+ public void setReadTimeout(int timeout) {
+ readTimeout_ = timeout;
+ if (null != this.client) {
+ // WARNING, this modifies the HttpClient params, this might have an impact elsewhere
+ // if the same HttpClient is used for something else.
+ client.getParams().setParameter(CoreConnectionPNames.SO_TIMEOUT, readTimeout_);
+ }
+ }
+
+ public void setCustomHeaders(Map<String,String> headers) {
+ customHeaders_ = headers;
+ }
+
+ public void setCustomHeader(String key, String value) {
+ if (customHeaders_ == null) {
+ customHeaders_ = new HashMap<String, String>();
+ }
+ customHeaders_.put(key, value);
+ }
+
+ public void open() {}
+
+ public void close() {
+ if (null != inputStream_) {
+ try {
+ inputStream_.close();
+ } catch (IOException ioe) {
+ }
+ inputStream_ = null;
+ }
+ }
+
+ public boolean isOpen() {
+ return true;
+ }
+
+ public int read(byte[] buf, int off, int len) throws TTransportException {
+ if (inputStream_ == null) {
+ throw new TTransportException("Response buffer is empty, no request.");
+ }
+
+ checkReadBytesAvailable(len);
+
+ try {
+ int ret = inputStream_.read(buf, off, len);
+ if (ret == -1) {
+ throw new TTransportException("No more data available.");
+ }
+ countConsumedMessageBytes(ret);
+
+ return ret;
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ }
+ }
+
+ public void write(byte[] buf, int off, int len) {
+ requestBuffer_.write(buf, off, len);
+ }
+
+ /**
+ * copy from org.apache.http.util.EntityUtils#consume. Android has it's own httpcore
+ * that doesn't have a consume.
+ */
+ private static void consume(final HttpEntity entity) throws IOException {
+ if (entity == null) {
+ return;
+ }
+ if (entity.isStreaming()) {
+ InputStream instream = entity.getContent();
+ if (instream != null) {
+ instream.close();
+ }
+ }
+ }
+
+ private void flushUsingHttpClient() throws TTransportException {
+
+ if (null == this.client) {
+ throw new TTransportException("Null HttpClient, aborting.");
+ }
+
+ // Extract request and reset buffer
+ byte[] data = requestBuffer_.toByteArray();
+ requestBuffer_.reset();
+
+ HttpPost post = null;
+
+ InputStream is = null;
+
+ try {
+ // Set request to path + query string
+ post = new HttpPost(this.url_.getFile());
+
+ //
+ // Headers are added to the HttpPost instance, not
+ // to HttpClient.
+ //
+
+ post.setHeader("Content-Type", "application/x-thrift");
+ post.setHeader("Accept", "application/x-thrift");
+ post.setHeader("User-Agent", "Java/THttpClient/HC");
+
+ if (null != customHeaders_) {
+ for (Map.Entry<String, String> header : customHeaders_.entrySet()) {
+ post.setHeader(header.getKey(), header.getValue());
+ }
+ }
+
+ post.setEntity(new ByteArrayEntity(data));
+
+ HttpResponse response = this.client.execute(this.host, post);
+ int responseCode = response.getStatusLine().getStatusCode();
+
+ //
+ // Retrieve the inputstream BEFORE checking the status code so
+ // resources get freed in the finally clause.
+ //
+
+ is = response.getEntity().getContent();
+
+ if (responseCode != HttpStatus.SC_OK) {
+ throw new TTransportException("HTTP Response code: " + responseCode);
+ }
+
+ // Read the responses into a byte array so we can release the connection
+ // early. This implies that the whole content will have to be read in
+ // memory, and that momentarily we might use up twice the memory (while the
+ // thrift struct is being read up the chain).
+ // Proceeding differently might lead to exhaustion of connections and thus
+ // to app failure.
+
+ byte[] buf = new byte[1024];
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+
+ int len = 0;
+ do {
+ len = is.read(buf);
+ if (len > 0) {
+ baos.write(buf, 0, len);
+ }
+ } while (-1 != len);
+
+ try {
+ // Indicate we're done with the content.
+ consume(response.getEntity());
+ } catch (IOException ioe) {
+ // We ignore this exception, it might only mean the server has no
+ // keep-alive capability.
+ }
+
+ inputStream_ = new ByteArrayInputStream(baos.toByteArray());
+ } catch (IOException ioe) {
+ // Abort method so the connection gets released back to the connection manager
+ if (null != post) {
+ post.abort();
+ }
+ throw new TTransportException(ioe);
+ } finally {
+ resetConsumedMessageSize(-1);
+ if (null != is) {
+ // Close the entity's input stream, this will release the underlying connection
+ try {
+ is.close();
+ } catch (IOException ioe) {
+ throw new TTransportException(ioe);
+ }
+ }
+ if (post != null) {
+ post.releaseConnection();
+ }
+ }
+ }
+
+ public void flush() throws TTransportException {
+
+ if (null != this.client) {
+ flushUsingHttpClient();
+ return;
+ }
+
+ // Extract request and reset buffer
+ byte[] data = requestBuffer_.toByteArray();
+ requestBuffer_.reset();
+
+ try {
+ // Create connection object
+ HttpURLConnection connection = (HttpURLConnection)url_.openConnection();
+
+ // Timeouts, only if explicitly set
+ if (connectTimeout_ > 0) {
+ connection.setConnectTimeout(connectTimeout_);
+ }
+ if (readTimeout_ > 0) {
+ connection.setReadTimeout(readTimeout_);
+ }
+
+ // Make the request
+ connection.setRequestMethod("POST");
+ connection.setRequestProperty("Content-Type", "application/x-thrift");
+ connection.setRequestProperty("Accept", "application/x-thrift");
+ connection.setRequestProperty("User-Agent", "Java/THttpClient");
+ if (customHeaders_ != null) {
+ for (Map.Entry<String, String> header : customHeaders_.entrySet()) {
+ connection.setRequestProperty(header.getKey(), header.getValue());
+ }
+ }
+ connection.setDoOutput(true);
+ connection.connect();
+ connection.getOutputStream().write(data);
+
+ int responseCode = connection.getResponseCode();
+ if (responseCode != HttpURLConnection.HTTP_OK) {
+ throw new TTransportException("HTTP Response code: " + responseCode);
+ }
+
+ // Read the responses
+ inputStream_ = connection.getInputStream();
+ // Capture the response headers.
+ // This is not in the THttpClient.java class in the apache/thrift repository.
+ responseHeaders_ = connection.getHeaderFields();
+
+ } catch (IOException iox) {
+ throw new TTransportException(iox);
+ } finally {
+ resetConsumedMessageSize(-1);
+ }
+ }
+
+ // Getter function for HTTP response headers. This is not in the
+ // THttpClient.java class in the apache/thrift repository.
+ public Map<String, List<String>> getResponseHeaders() {
+ return responseHeaders_;
+ }
+}