blob: 759e480eba90ce4ff641c330be6556d51817b4b6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.store.http;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.apache.drill.common.logical.OAuthConfig;
import org.apache.drill.common.logical.StoragePluginConfig.AuthMode;
import org.apache.drill.common.logical.security.CredentialsProvider;
import org.apache.drill.common.logical.security.PlainCredentialsProvider;
import org.apache.drill.common.types.TypeProtos.DataMode;
import org.apache.drill.common.types.TypeProtos.MinorType;
import org.apache.drill.common.util.DrillFileUtils;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.oauth.PersistentTokenTable;
import org.apache.drill.exec.physical.rowSet.DirectRowSet;
import org.apache.drill.exec.physical.rowSet.RowSet;
import org.apache.drill.exec.physical.rowSet.RowSetBuilder;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.record.metadata.TupleMetadata;
import org.apache.drill.exec.store.security.oauth.OAuthTokenCredentials;
import com.google.common.io.Files;
import org.apache.drill.test.ClusterFixtureBuilder;
import org.apache.drill.test.ClusterTest;
import org.apache.drill.test.rowSet.RowSetUtilities;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class TestOAuthProcess extends ClusterTest {
private static final Logger logger = LoggerFactory.getLogger(TestOAuthProcess.class);
private static final int MOCK_SERVER_PORT = 47779;
private static final int TIMEOUT = 30;
private static final String CONNECTION_NAME = "localOauth";
private final OkHttpClient httpClient = new OkHttpClient.Builder()
.connectTimeout(TIMEOUT, TimeUnit.SECONDS)
.writeTimeout(TIMEOUT, TimeUnit.SECONDS)
.readTimeout(TIMEOUT, TimeUnit.SECONDS).build();
private static String ACCESS_TOKEN_RESPONSE;
private static String REFRESH_TOKEN_RESPONSE;
private static String TEST_JSON_RESPONSE_WITH_DATATYPES;
private static String hostname;
@BeforeClass
public static void setup() throws Exception {
ACCESS_TOKEN_RESPONSE = Files.asCharSource(DrillFileUtils.getResourceAsFile("/data/oauth_access_token_response.json"),
StandardCharsets.UTF_8).read();
REFRESH_TOKEN_RESPONSE = Files.asCharSource(DrillFileUtils.getResourceAsFile("/data/token_refresh.json"),
StandardCharsets.UTF_8).read();
TEST_JSON_RESPONSE_WITH_DATATYPES = Files.asCharSource(DrillFileUtils.getResourceAsFile("/data/response2.json"),
StandardCharsets.UTF_8).read();
ClusterFixtureBuilder builder = new ClusterFixtureBuilder(dirTestWatcher)
.configProperty(ExecConstants.HTTP_ENABLE, true)
.configProperty(ExecConstants.HTTP_PORT_HUNT, true);
startCluster(builder);
int portNumber = cluster.drillbit().getWebServerPort();
hostname = "http://localhost:" + portNumber + "/storage/" + CONNECTION_NAME;
Map<String, String> creds = new HashMap<>();
creds.put("clientID", "12345");
creds.put("clientSecret", "54321");
creds.put("accessToken", null);
creds.put("refreshToken", null);
creds.put(OAuthTokenCredentials.TOKEN_URI, "http://localhost:" + MOCK_SERVER_PORT + "/get_access_token");
CredentialsProvider credentialsProvider = new PlainCredentialsProvider(creds);
HttpApiConfig connectionConfig = HttpApiConfig.builder()
.url("http://localhost:" + MOCK_SERVER_PORT + "/getdata")
.method("get")
.requireTail(false)
.inputType("json")
.build();
OAuthConfig oAuthConfig = OAuthConfig.builder()
.callbackURL(hostname + "/update_oauth2_authtoken")
.build();
Map<String, HttpApiConfig> configs = new HashMap<>();
configs.put("test", connectionConfig);
// Add storage plugin for test OAuth
HttpStoragePluginConfig mockStorageConfigWithWorkspace =
new HttpStoragePluginConfig(false, false, configs, TIMEOUT, 1000, null, null, "", 80, "", "", "",
oAuthConfig, credentialsProvider, AuthMode.SHARED_USER.name());
mockStorageConfigWithWorkspace.setEnabled(true);
cluster.defineStoragePlugin("localOauth", mockStorageConfigWithWorkspace);
}
@Test
public void testAccessToken() {
String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
Request request = new Request.Builder().url(url).build();
try (MockWebServer server = startServer()) {
server.enqueue(new MockResponse().setResponseCode(200).setBody(ACCESS_TOKEN_RESPONSE));
Response response = httpClient.newCall(request).execute();
// Verify that the request succeeded w/o error
assertEquals(200, response.code());
// Verify that the access and refresh tokens were saved
PersistentTokenTable tokenTable = ((HttpStoragePlugin) cluster
.storageRegistry()
.getPlugin("localOauth"))
.getTokenTable();
assertEquals("you_have_access", tokenTable.getAccessToken());
assertEquals("refresh_me", tokenTable.getRefreshToken());
assertEquals("3600", tokenTable.getExpiresIn());
} catch (Exception e) {
logger.error(e.getMessage());
fail();
}
}
@Test
public void testGetDataWithAuthentication() {
String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
Request request = new Request.Builder().url(url).build();
try (MockWebServer server = startServer()) {
server.enqueue(new MockResponse().setResponseCode(200).setBody(ACCESS_TOKEN_RESPONSE));
Response response = httpClient.newCall(request).execute();
// Verify that the request succeeded w/o error
assertEquals(200, response.code());
// Verify that the access and refresh tokens were saved
PersistentTokenTable tokenTable = ((HttpStoragePlugin) cluster.storageRegistry()
.getPlugin("localOauth"))
.getTokenRegistry()
.getTokenTable("localOauth");
assertEquals("you_have_access", tokenTable.getAccessToken());
assertEquals("refresh_me", tokenTable.getRefreshToken());
assertEquals("3600", tokenTable.getExpiresIn());
// Now execute a query and get query results.
server.enqueue(new MockResponse()
.setResponseCode(200)
.setBody(TEST_JSON_RESPONSE_WITH_DATATYPES));
String sql = "SELECT * FROM localOauth.test";
DirectRowSet results = queryBuilder().sql(sql).rowSet();
TupleMetadata expectedSchema = new SchemaBuilder()
.add("col_1", MinorType.FLOAT8, DataMode.OPTIONAL)
.add("col_2", MinorType.BIGINT, DataMode.OPTIONAL)
.add("col_3", MinorType.VARCHAR, DataMode.OPTIONAL)
.build();
RowSet expected = new RowSetBuilder(client.allocator(), expectedSchema)
.addRow(1.0, 2, "3.0")
.addRow(4.0, 5, "6.0")
.build();
RowSetUtilities.verify(expected, results);
} catch (Exception e) {
logger.error(e.getMessage());
fail();
}
}
@Test
public void testGetDataWithTokenRefresh() {
String url = hostname + "/update_oauth2_authtoken?code=ABCDEF";
Request request = new Request.Builder().url(url).build();
try (MockWebServer server = startServer()) {
server.enqueue(new MockResponse().setResponseCode(200).setBody(ACCESS_TOKEN_RESPONSE));
Response response = httpClient.newCall(request).execute();
// Verify that the request succeeded w/o error
assertEquals(200, response.code());
// Verify that the access and refresh tokens were saved
PersistentTokenTable tokenTable = ((HttpStoragePlugin) cluster.storageRegistry().getPlugin("localOauth")).getTokenRegistry().getTokenTable("localOauth");
assertEquals("you_have_access", tokenTable.getAccessToken());
assertEquals("refresh_me", tokenTable.getRefreshToken());
assertEquals("3600", tokenTable.getExpiresIn());
// Now execute a query and get a refresh token
// The API should return a 401 error. This should trigger Drill to automatically
// fire off a second call with the refresh token and then a third request with the
// new access token to obtain the actual data.
server.enqueue(new MockResponse().setResponseCode(401).setBody("Access Denied"));
server.enqueue(new MockResponse().setResponseCode(200).setBody(REFRESH_TOKEN_RESPONSE));
server.enqueue(new MockResponse()
.setResponseCode(200)
.setBody(TEST_JSON_RESPONSE_WITH_DATATYPES));
String sql = "SELECT * FROM localOauth.test";
DirectRowSet results = queryBuilder().sql(sql).rowSet();
// Verify that the access and refresh tokens were saved
assertEquals("token 2.0", tokenTable.getAccessToken());
assertEquals("refresh 2.0", tokenTable.getRefreshToken());
assertEquals("3800", tokenTable.getExpiresIn());
TupleMetadata expectedSchema = new SchemaBuilder()
.add("col_1", MinorType.FLOAT8, DataMode.OPTIONAL)
.add("col_2", MinorType.BIGINT, DataMode.OPTIONAL)
.add("col_3", MinorType.VARCHAR, DataMode.OPTIONAL)
.build();
RowSet expected = new RowSetBuilder(client.allocator(), expectedSchema)
.addRow(1.0, 2, "3.0")
.addRow(4.0, 5, "6.0")
.build();
RowSetUtilities.verify(expected, results);
} catch (Exception e) {
logger.debug(e.getMessage());
fail();
}
}
/**
* Helper function to start the MockHTTPServer
* @return Started Mock server
* @throws IOException If the server cannot start, throws IOException
*/
public static MockWebServer startServer () throws IOException {
MockWebServer server = new MockWebServer();
server.start(MOCK_SERVER_PORT);
return server;
}
}