blob: cc360a48ede26d818a80355faca06d729d3f4e35 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.sql.resources;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import org.apache.druid.client.indexing.TaskPayloadResponse;
import org.apache.druid.client.indexing.TaskStatusResponse;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.discovery.NodeRole;
import org.apache.druid.error.DruidException;
import org.apache.druid.error.ErrorResponse;
import org.apache.druid.error.Forbidden;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.error.NotFound;
import org.apache.druid.error.QueryExceptionCompat;
import org.apache.druid.frame.channel.FrameChannelSequence;
import org.apache.druid.guice.annotations.MSQ;
import org.apache.druid.indexer.TaskStatusPlus;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.guice.MultiStageQuery;
import org.apache.druid.msq.indexing.MSQControllerTask;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
import org.apache.druid.msq.indexing.destination.MSQDestination;
import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.shuffle.input.DurableStorageInputChannelFactory;
import org.apache.druid.msq.sql.MSQTaskQueryMaker;
import org.apache.druid.msq.sql.MSQTaskSqlEngine;
import org.apache.druid.msq.sql.SqlStatementState;
import org.apache.druid.msq.sql.entity.ColumnNameAndTypes;
import org.apache.druid.msq.sql.entity.PageInformation;
import org.apache.druid.msq.sql.entity.ResultSetInformation;
import org.apache.druid.msq.sql.entity.SqlStatementResult;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.msq.util.SqlStatementResourceHelper;
import org.apache.druid.query.ExecutionMode;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryException;
import org.apache.druid.rpc.HttpResponseException;
import org.apache.druid.rpc.indexing.OverlordClient;
import org.apache.druid.server.QueryResponse;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizationUtils;
import org.apache.druid.server.security.AuthorizerMapper;
import org.apache.druid.server.security.ForbiddenException;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.sql.DirectStatement;
import org.apache.druid.sql.HttpStatement;
import org.apache.druid.sql.SqlRowTransformer;
import org.apache.druid.sql.SqlStatementFactory;
import org.apache.druid.sql.http.ResultFormat;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.sql.http.SqlResource;
import org.apache.druid.storage.NilStorageConnector;
import org.apache.druid.storage.StorageConnector;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * JAX-RS resource rooted at {@code /druid/v2/sql/statements/}.
 *
 * Endpoints declared in this class:
 * <ul>
 * <li>{@code GET /enabled} - lets clients check that the resource is enabled</li>
 * <li>{@code POST /} - submits a SQL statement</li>
 * <li>{@code GET /{id}} - returns the status of a statement</li>
 * <li>{@code GET /{id}/results} - streams the results of a successful statement</li>
 * <li>{@code DELETE /{id}} - cancels a statement</li>
 * </ul>
 */
@Path("/druid/v2/sql/statements/")
public class SqlStatementResource
{
// Query context key under which createModifiedSqlQuery stores the result format requested at submission time.
// Users may not set this key themselves; getPreferredResultFormat reads it back when serving results.
public static final String RESULT_FORMAT = "__resultFormat";
private static final Logger log = new Logger(SqlStatementResource.class);
// Factory used by doPost to create the HTTP statement for a submitted query.
private final SqlStatementFactory msqSqlStatementFactory;
// Mapper handed to result-format writers and to the SqlStatementResourceHelper calls that need one.
private final ObjectMapper jsonMapper;
// Client used to fetch task status, payload and reports, and to cancel tasks.
private final OverlordClient overlordClient;
// Storage connector used to read durable-storage results; a NilStorageConnector is treated as "not configured".
private final StorageConnector storageConnector;
// Used to authorize access to statements submitted by a different user.
private final AuthorizerMapper authorizerMapper;
/**
 * Injection constructor; stores every collaborator as-is.
 *
 * @param msqSqlStatementFactory statement factory bound with the {@code @MSQ} annotation
 * @param jsonMapper             JSON mapper used when writing responses
 * @param overlordClient         client used for all overlord calls
 * @param storageConnector       storage connector bound with the {@code @MultiStageQuery} annotation
 * @param authorizerMapper       mapper used for authorization checks
 */
@Inject
public SqlStatementResource(
    final @MSQ SqlStatementFactory msqSqlStatementFactory,
    final ObjectMapper jsonMapper,
    final OverlordClient overlordClient,
    final @MultiStageQuery StorageConnector storageConnector,
    final AuthorizerMapper authorizerMapper
)
{
  // Security and storage collaborators.
  this.authorizerMapper = authorizerMapper;
  this.storageConnector = storageConnector;

  // Query submission and overlord communication.
  this.overlordClient = overlordClient;
  this.msqSqlStatementFactory = msqSqlStatementFactory;

  // Serialization.
  this.jsonMapper = jsonMapper;
}
/**
 * Lets clients such as the web console probe whether this resource is enabled.
 *
 * @param request the incoming HTTP request
 * @return an OK response whose entity is a single-entry map {@code "enabled" -> true}
 */
@GET
@Path("/enabled")
@Produces(MediaType.APPLICATION_JSON)
public Response isEnabled(@Context final HttpServletRequest request)
{
  // Every authenticated user may call this API, so no resource-level check is performed.
  AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request);

  final Map<String, Boolean> payload = ImmutableMap.of("enabled", true);
  return Response.ok(payload).build();
}
/**
 * Submits a SQL statement.
 *
 * The submitted query is first copied with its result format recorded in the query context, then validated by
 * {@code contextChecks}, planned and run. When the resulting row signature equals
 * {@link MSQTaskSqlEngine#TASK_STRUCT_FIELD_NAMES} the response is the status of the created task; otherwise
 * (used for EXPLAIN) the rows are streamed back in the requested result format.
 *
 * Every failure, including one raised while preparing the statement, is converted into an error response
 * instead of propagating out of this method.
 *
 * @param sqlQuery the query to execute
 * @param req      the incoming HTTP request
 * @return the statement status, the streamed rows, or an error response
 */
@POST
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response doPost(final SqlQuery sqlQuery, @Context final HttpServletRequest req)
{
  final SqlQuery modifiedQuery;
  final HttpStatement stmt;
  try {
    // Both calls can reject the request (e.g. a user-supplied RESULT_FORMAT context key). Nothing has been
    // opened yet, so there is nothing to close or report on; just turn the failure into a response.
    modifiedQuery = createModifiedSqlQuery(sqlQuery);
    stmt = msqSqlStatementFactory.httpStatement(modifiedQuery, req);
  }
  catch (DruidException e) {
    return buildNonOkResponse(e);
  }

  final String sqlQueryId = stmt.sqlQueryId();
  final String currThreadName = Thread.currentThread().getName();
  boolean isDebug = false;
  try {
    QueryContext queryContext = QueryContext.of(modifiedQuery.getContext());
    isDebug = queryContext.isDebug();
    contextChecks(queryContext);

    Thread.currentThread().setName(StringUtils.format("statement_sql[%s]", sqlQueryId));

    final DirectStatement.ResultSet plan = stmt.plan();
    // in case the engine is async, the query is not run yet. We just return the taskID in case of non explain queries.
    final QueryResponse<Object[]> response = plan.run();
    final Sequence<Object[]> sequence = response.getResults();
    final SqlRowTransformer rowTransformer = plan.createRowTransformer();

    final boolean isTaskStruct = MSQTaskSqlEngine.TASK_STRUCT_FIELD_NAMES.equals(rowTransformer.getFieldList());

    if (isTaskStruct) {
      return buildTaskResponse(sequence, stmt.query().authResult());
    } else {
      // Used for EXPLAIN
      return buildStandardResponse(sequence, modifiedQuery, sqlQueryId, rowTransformer);
    }
  }
  catch (DruidException e) {
    stmt.reporter().failed(e);
    return buildNonOkResponse(e);
  }
  catch (QueryException queryException) {
    stmt.reporter().failed(queryException);
    final DruidException underlyingException = DruidException.fromFailure(new QueryExceptionCompat(queryException));
    return buildNonOkResponse(underlyingException);
  }
  catch (ForbiddenException e) {
    log.debug("Got forbidden request for reason [%s]", e.getErrorMessage());
    return buildNonOkResponse(Forbidden.exception());
  }
  // Calcite throws java.lang.AssertionError at various points in planning/validation.
  catch (AssertionError | Exception e) {
    stmt.reporter().failed(e);
    if (isDebug) {
      log.warn(e, "Failed to handle query [%s]", sqlQueryId);
    } else {
      log.noStackTrace().warn(e, "Failed to handle query [%s]", sqlQueryId);
    }
    // Same message as before, but the original failure is kept as the cause, matching the other endpoints.
    return buildNonOkResponse(
        DruidException.forPersona(DruidException.Persona.DEVELOPER)
                      .ofCategory(DruidException.Category.UNCATEGORIZED)
                      .build(e, "%s", e.getMessage())
    );
  }
  finally {
    stmt.close();
    Thread.currentThread().setName(currThreadName);
  }
}
/**
 * Returns the status of a previously submitted statement, including its result set information.
 *
 * @param queryId id of the statement
 * @param req     the incoming HTTP request
 * @return an OK response carrying the {@link SqlStatementResult}, or an error response when the statement is
 * not found, the caller is not permitted to read it, or the lookup fails
 */
@GET
@Path("/{id}")
@Produces(MediaType.APPLICATION_JSON)
public Response doGetStatus(
    @PathParam("id") final String queryId, @Context final HttpServletRequest req
)
{
  try {
    AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
    final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req);

    final SqlStatementResult status =
        getStatementStatus(queryId, authenticationResult, true, Action.READ)
            .orElseThrow(() -> queryNotFoundException(queryId));

    return Response.ok().entity(status).build();
  }
  catch (DruidException e) {
    return buildNonOkResponse(e);
  }
  catch (ForbiddenException e) {
    log.debug("Got forbidden request for reason [%s]", e.getErrorMessage());
    return buildNonOkResponse(Forbidden.exception());
  }
  catch (Exception e) {
    log.warn(e, "Failed to handle query [%s]", queryId);
    final DruidException wrapped =
        DruidException.forPersona(DruidException.Persona.DEVELOPER)
                      .ofCategory(DruidException.Category.UNCATEGORIZED)
                      .build(e, "Failed to handle query [%s]", queryId);
    return buildNonOkResponse(wrapped);
  }
}
/**
 * Streams the results of a successfully completed statement.
 *
 * The request is rejected when {@code page} is negative, when the task is unknown or is not an MSQ controller
 * task, when the caller lacks permission, or when the statement is still running or has failed. Statements
 * without a signature and ingestion statements produce an empty OK response.
 *
 * The {@code resultFormat} parameter is validated before any result channel is opened, and the {@link Closer}
 * is released if opening the results fails, so a bad request cannot leave resources open.
 *
 * @param queryId      id of the statement
 * @param page         optional page to return; when null all pages are returned
 * @param resultFormat optional result format; when null the format stored at submission time is used
 * @param req          the incoming HTTP request
 * @return a streaming OK response, an empty OK response when there is nothing to return, or an error response
 */
@GET
@Path("/{id}/results")
@Produces(MediaType.APPLICATION_JSON)
public Response doGetResults(
    @PathParam("id") final String queryId,
    @QueryParam("page") Long page,
    @QueryParam("resultFormat") String resultFormat,
    @Context final HttpServletRequest req
)
{
  try {
    AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
    final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req);

    if (page != null && page < 0) {
      throw DruidException.forPersona(DruidException.Persona.USER)
                          .ofCategory(DruidException.Category.INVALID_INPUT)
                          .build(
                              "Page cannot be negative. Please pass a positive number."
                          );
    }

    TaskStatusResponse taskResponse = contactOverlord(overlordClient.taskStatus(queryId), queryId);
    if (taskResponse == null) {
      throw queryNotFoundException(queryId);
    }

    TaskStatusPlus statusPlus = taskResponse.getStatus();
    if (statusPlus == null || !MSQControllerTask.TYPE.equals(statusPlus.getType())) {
      throw queryNotFoundException(queryId);
    }

    MSQControllerTask msqControllerTask = getMSQControllerTaskAndCheckPermission(
        queryId,
        authenticationResult,
        Action.READ
    );
    throwIfQueryIsNotSuccessful(queryId, statusPlus);

    Optional<List<ColumnNameAndTypes>> signature = SqlStatementResourceHelper.getSignature(msqControllerTask);
    if (!signature.isPresent() || MSQControllerTask.isIngestion(msqControllerTask.getQuerySpec())) {
      // Since it's not a select query, nothing to return.
      return Response.ok().build();
    }

    // Validate the requested format up front: failing here after the yielder was opened would leak it.
    final ResultFormat preferredFormat = getPreferredResultFormat(resultFormat, msqControllerTask.getQuerySpec());

    // returning results
    final Closer closer = Closer.create();
    final Optional<Yielder<Object[]>> results;
    try {
      results = getResultYielder(queryId, page, msqControllerTask, closer);
    }
    catch (RuntimeException | Error e) {
      // Anything already registered with the closer must be released before the failure is reported.
      try {
        closer.close();
      }
      catch (IOException | RuntimeException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }

    if (!results.isPresent()) {
      // no results, return empty
      closer.close();
      return Response.ok().build();
    }

    // From here on resultPusher owns the closer and closes it once streaming ends.
    return Response.ok((StreamingOutput) outputStream -> resultPusher(
        queryId,
        signature,
        closer,
        results,
        new CountingOutputStream(outputStream),
        preferredFormat
    )).build();
  }
  catch (DruidException e) {
    return buildNonOkResponse(e);
  }
  catch (ForbiddenException e) {
    log.debug("Got forbidden request for reason [%s]", e.getErrorMessage());
    return buildNonOkResponse(Forbidden.exception());
  }
  catch (Exception e) {
    log.warn(e, "Failed to handle query [%s]", queryId);
    return buildNonOkResponse(DruidException.forPersona(DruidException.Persona.DEVELOPER)
                                            .ofCategory(DruidException.Category.UNCATEGORIZED)
                                            .build(e, "Failed to handle query [%s]", queryId));
  }
}
/**
 * Queries can be canceled while in any {@link SqlStatementState}. Canceling a query that has already completed
 * is a no-op.
 *
 * @param queryId queryId
 * @param req     httpServletRequest
 * @return HTTP 202 if a cancellation was requested for an accepted or running query, HTTP 200 if the query has
 * already succeeded or failed, a not-found error response if the query ID does not exist, and a forbidden error
 * response if the caller is not permitted to cancel it.
 */
@DELETE
@Path("/{id}")
@Produces(MediaType.APPLICATION_JSON)
public Response deleteQuery(@PathParam("id") final String queryId, @Context final HttpServletRequest req)
{
  try {
    AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
    final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req);

    // Results are not needed to decide whether to cancel, hence withResults = false.
    final SqlStatementResult statement =
        getStatementStatus(queryId, authenticationResult, false, Action.WRITE)
            .orElseThrow(() -> queryNotFoundException(queryId));

    switch (statement.getState()) {
      case ACCEPTED:
      case RUNNING:
        // The value returned by cancelTask is not inspected; the request is only acknowledged.
        overlordClient.cancelTask(queryId);
        return Response.status(Response.Status.ACCEPTED).build();

      case SUCCESS:
      case FAILED:
        // we would also want to clean up the results in the future.
        return Response.ok().build();

      default:
        throw new ISE("Illegal State[%s] encountered", statement.getState());
    }
  }
  catch (DruidException e) {
    return buildNonOkResponse(e);
  }
  catch (ForbiddenException e) {
    log.debug("Got forbidden request for reason [%s]", e.getErrorMessage());
    return buildNonOkResponse(Forbidden.exception());
  }
  catch (Exception e) {
    log.warn(e, "Failed to handle query [%s]", queryId);
    final DruidException wrapped =
        DruidException.forPersona(DruidException.Persona.DEVELOPER)
                      .ofCategory(DruidException.Category.UNCATEGORIZED)
                      .build(e, "Failed to handle query [%s]", queryId);
    return buildNonOkResponse(wrapped);
  }
}
/**
 * Builds a streaming response that writes every row of {@code sequence} in the result format of
 * {@code sqlQuery}.
 *
 * The yielder over the sequence is opened eagerly. It is closed by the streaming callback once writing ends
 * (successfully or not), or right here if building the response fails before streaming starts.
 *
 * @param sequence       rows to write
 * @param sqlQuery       query supplying the result format and header options
 * @param sqlQueryId     id used in the error log message
 * @param rowTransformer supplies the field list, row type and per-field value transformation
 * @return the streaming response
 * @throws IOException if closing the yielder fails
 */
private Response buildStandardResponse(
    Sequence<Object[]> sequence,
    SqlQuery sqlQuery,
    String sqlQueryId,
    SqlRowTransformer rowTransformer
) throws IOException
{
  final Yielder<Object[]> initialYielder = Yielders.each(sequence);

  try {
    final StreamingOutput body = outputStream -> {
      final CountingOutputStream countingStream = new CountingOutputStream(outputStream);
      Yielder<Object[]> current = initialYielder;

      try (final ResultFormat.Writer writer = sqlQuery.getResultFormat().createFormatter(countingStream, jsonMapper)) {
        writer.writeResponseStart();

        if (sqlQuery.includeHeader()) {
          writer.writeHeader(
              rowTransformer.getRowType(),
              sqlQuery.includeTypesHeader(),
              sqlQuery.includeSqlTypesHeader()
          );
        }

        for (; !current.isDone(); current = current.next(null)) {
          final Object[] row = current.get();
          writer.writeRowStart();
          for (int column = 0; column < rowTransformer.getFieldList().size(); column++) {
            final Object value = rowTransformer.transform(row, column);
            writer.writeRowField(rowTransformer.getFieldList().get(column), value);
          }
          writer.writeRowEnd();
        }

        writer.writeResponseEnd();
      }
      catch (Exception e) {
        log.error(e, "Unable to send SQL response [%s]", sqlQueryId);
        throw new RuntimeException(e);
      }
      finally {
        // Close whichever yielder the iteration reached.
        current.close();
      }
    };

    final Response.ResponseBuilder builder = Response.ok(body);
    if (sqlQuery.includeHeader()) {
      builder.header(SqlResource.SQL_HEADER_RESPONSE_HEADER, SqlResource.SQL_HEADER_VALUE);
    }
    return builder.build();
  }
  catch (Throwable e) {
    // make sure to close yielder if anything happened before starting to serialize the response.
    initialYielder.close();
    throw e;
  }
}
/**
 * Builds the response for a submitted task: the sequence must hold exactly one row with exactly one column,
 * whose string value is used as the task id to look up.
 *
 * @param sequence             rows returned by running the plan
 * @param authenticationResult identity of the caller, used for the permission check of the status lookup
 * @return an OK response with the statement status, or an error response when no task is found for the id
 * @throws RE if the sequence does not consist of a single row with a single column
 */
private Response buildTaskResponse(Sequence<Object[]> sequence, AuthenticationResult authenticationResult)
{
  final List<Object[]> allRows = sequence.toList();
  final int numRows = allRows.size();
  if (numRows != 1) {
    throw new RE("Expected a single row but got [%d] rows. Please check broker logs for more information.", numRows);
  }

  final Object[] onlyRow = allRows.get(0);
  if (onlyRow == null || onlyRow.length != 1) {
    throw new RE(
        "Expected a single column but got [%s] columns. Please check broker logs for more information.",
        onlyRow == null ? 0 : onlyRow.length
    );
  }

  final String taskId = String.valueOf(onlyRow[0]);
  final Optional<SqlStatementResult> status = getStatementStatus(taskId, authenticationResult, true, Action.READ);

  if (!status.isPresent()) {
    final DruidException taskMissing =
        DruidException.forPersona(DruidException.Persona.DEVELOPER)
                      .ofCategory(DruidException.Category.DEFENSIVE).build(
                          "Unable to find associated task for query id [%s]. Contact cluster admin to check overlord logs for [%s]",
                          taskId,
                          taskId
                      );
    return buildNonOkResponse(taskMissing);
  }

  return Response.status(Response.Status.OK).entity(status.get()).build();
}
/**
 * Converts a {@link DruidException} into a response that uses the exception's status code and carries an
 * {@link ErrorResponse} entity.
 *
 * @param exception the failure to report
 * @return the error response
 */
private Response buildNonOkResponse(DruidException exception)
{
  final Response.ResponseBuilder builder = Response.status(exception.getStatusCode());
  return builder.entity(new ErrorResponse(exception)).build();
}
/**
 * Assembles the result set information of a statement from its task report.
 *
 * @param queryId           id of the statement
 * @param dataSource        data source reported in the result
 * @param sqlStatementState current state of the statement
 * @param msqDestination    destination of the query, used to build the page list and to detect select queries
 * @return empty unless the state is {@link SqlStatementState#SUCCESS}; otherwise the row/size totals summed
 * over the pages (null when there is no page list) and, for task-report or durable-storage destinations only,
 * the rows found in the report together with the page list
 */
@SuppressWarnings("ReassignedVariable")
private Optional<ResultSetInformation> getResultSetInformation(
    String queryId,
    String dataSource,
    SqlStatementState sqlStatementState,
    MSQDestination msqDestination
)
{
  if (sqlStatementState != SqlStatementState.SUCCESS) {
    return Optional.empty();
  }

  final MSQTaskReportPayload reportPayload = SqlStatementResourceHelper.getPayload(
      contactOverlord(overlordClient.taskReportAsMap(queryId), queryId)
  );
  final Optional<List<PageInformation>> pages =
      SqlStatementResourceHelper.populatePageList(reportPayload, msqDestination);

  // getting the total number of rows, size from page information.
  Long totalRows = null;
  Long totalBytes = null;
  if (pages.isPresent()) {
    long rowCount = 0L;
    long byteCount = 0L;
    for (PageInformation pageInfo : pages.get()) {
      // Pages that do not report a value contribute nothing to the totals.
      if (pageInfo.getNumRows() != null) {
        rowCount += pageInfo.getNumRows();
      }
      if (pageInfo.getSizeInBytes() != null) {
        byteCount += pageInfo.getSizeInBytes();
      }
    }
    totalRows = rowCount;
    totalBytes = byteCount;
  }

  final boolean isSelectQuery = msqDestination instanceof TaskReportMSQDestination
                                || msqDestination instanceof DurableStorageMSQDestination;

  List<Object[]> collectedRows = null;
  if (isSelectQuery) {
    collectedRows = new ArrayList<>();
    if (reportPayload.getResults() != null) {
      Yielder<Object[]> rowYielder = reportPayload.getResults().getResultYielder();
      try {
        while (rowYielder != null && !rowYielder.isDone()) {
          collectedRows.add(rowYielder.get());
          rowYielder = rowYielder.next(null);
        }
      }
      finally {
        if (rowYielder != null) {
          try {
            rowYielder.close();
          }
          catch (IOException e) {
            // A failed close is only logged; the rows collected so far are still returned.
            log.warn(e, StringUtils.format("Unable to close yielder for query[%s]", queryId));
          }
        }
      }
    }
  }

  final List<PageInformation> reportedPages = isSelectQuery ? pages.orElse(null) : null;
  return Optional.of(
      new ResultSetInformation(
          totalRows,
          totalBytes,
          null,
          dataSource,
          collectedRows,
          reportedPages
      )
  );
}
/**
 * Looks up the status of a statement on the overlord and verifies that the caller may access it.
 *
 * @param queryId              id of the statement
 * @param authenticationResult identity of the caller
 * @param withResults          whether result set information should be attached to the status
 * @param forAction            action the caller needs on the STATE resource when not the owner of the query
 * @return empty when the overlord has no status for the id or the task is not an MSQ controller task;
 * otherwise the status, which for failed statements is the exception payload built from the task report
 * @throws DruidException     if the overlord cannot be contacted or reports the query as not found
 * @throws ForbiddenException if the caller is not permitted to access the statement
 */
private Optional<SqlStatementResult> getStatementStatus(
    String queryId,
    AuthenticationResult authenticationResult,
    boolean withResults,
    Action forAction
) throws DruidException
{
  final TaskStatusResponse taskResponse = contactOverlord(overlordClient.taskStatus(queryId), queryId);
  final TaskStatusPlus statusPlus = taskResponse == null ? null : taskResponse.getStatus();
  if (statusPlus == null || !MSQControllerTask.TYPE.equals(statusPlus.getType())) {
    return Optional.empty();
  }

  // since we need the controller payload for auth checks.
  final MSQControllerTask controllerTask =
      getMSQControllerTaskAndCheckPermission(queryId, authenticationResult, forAction);
  final SqlStatementState state = SqlStatementResourceHelper.getSqlStatementState(statusPlus);

  if (SqlStatementState.FAILED == state) {
    return SqlStatementResourceHelper.getExceptionPayload(
        queryId,
        taskResponse,
        statusPlus,
        state,
        contactOverlord(overlordClient.taskReportAsMap(queryId), queryId),
        jsonMapper
    );
  }

  final Optional<List<ColumnNameAndTypes>> signature = SqlStatementResourceHelper.getSignature(controllerTask);

  final ResultSetInformation resultSetInformation;
  if (withResults) {
    resultSetInformation = getResultSetInformation(
        queryId,
        controllerTask.getDataSource(),
        state,
        controllerTask.getQuerySpec().getDestination()
    ).orElse(null);
  } else {
    resultSetInformation = null;
  }

  return Optional.of(new SqlStatementResult(
      queryId,
      state,
      taskResponse.getStatus().getCreatedTime(),
      signature.orElse(null),
      taskResponse.getStatus().getDuration(),
      resultSetInformation,
      null
  ));
}
/**
 * Fetches the controller task payload from the overlord and checks that the requesting user has the necessary
 * permissions. A user has the necessary permissions if one of the following criteria is satisfied:
 * 1. The user is the one who submitted the query.
 * 2. The user belongs to a role containing the READ or WRITE permission over the STATE resource. For endpoints
 * like GET, the user needs READ permission on the STATE resource, while for endpoints like DELETE, the user
 * needs WRITE permission on the STATE resource. (Note: the POST API does not need to check the state
 * permissions since the currentUser always equals the queryUser.)
 *
 * @param queryId              id of the statement
 * @param authenticationResult identity of the caller
 * @param forAction            action required on the STATE resource when the caller is not the owner
 * @return the controller task of the statement
 * @throws ForbiddenException if neither criterion is satisfied
 */
private MSQControllerTask getMSQControllerTaskAndCheckPermission(
    String queryId,
    AuthenticationResult authenticationResult,
    Action forAction
) throws ForbiddenException
{
  final TaskPayloadResponse payloadResponse = contactOverlord(overlordClient.taskPayload(queryId), queryId);
  SqlStatementResourceHelper.isMSQPayload(payloadResponse, queryId);
  final MSQControllerTask controllerTask = (MSQControllerTask) payloadResponse.getPayload();

  // String.valueOf never yields null: a missing user entry in the context becomes the text "null".
  final Object submittedBy = controllerTask.getQuerySpec()
                                           .getQuery()
                                           .getContext()
                                           .get(MSQTaskQueryMaker.USER_KEY);
  final String queryUser = String.valueOf(submittedBy);
  final String currentUser = authenticationResult.getIdentity();

  // Criterion 1: the caller owns the query. equals(null) is false, so a null identity never matches.
  if (queryUser.equals(currentUser)) {
    return controllerTask;
  }

  // Criterion 2: the caller holds the requested permission on the STATE resource.
  final ResourceAction requiredPermission = new ResourceAction(Resource.STATE_RESOURCE, forAction);
  final Access access = AuthorizationUtils.authorizeAllResourceActions(
      authenticationResult,
      Collections.singletonList(requiredPermission),
      authorizerMapper
  );
  if (!access.isAllowed()) {
    throw new ForbiddenException(StringUtils.format(
        "The current user[%s] cannot view query id[%s] since the query is owned by another user",
        currentUser,
        queryId
    ));
  }

  return controllerTask;
}
/**
 * Creates a new sqlQuery from the user submitted sqlQuery after performing the required modifications: the
 * result format of the request is recorded in the query context under {@link #RESULT_FORMAT}.
 *
 * @param sqlQuery the query as submitted by the user
 * @return a copy of the query whose context additionally contains the result format
 * @throws DruidException if the submitted context already contains {@link #RESULT_FORMAT}
 */
private SqlQuery createModifiedSqlQuery(SqlQuery sqlQuery)
{
  final Map<String, Object> submittedContext = sqlQuery.getContext();
  if (submittedContext.containsKey(RESULT_FORMAT)) {
    throw InvalidInput.exception("Query context parameter [%s] is not allowed", RESULT_FORMAT);
  }

  final ImmutableMap.Builder<String, Object> contextBuilder = ImmutableMap.builder();
  contextBuilder.putAll(submittedContext);
  contextBuilder.put(RESULT_FORMAT, sqlQuery.getResultFormat().toString());
  final Map<String, Object> contextWithFormat = contextBuilder.build();

  return new SqlQuery(
      sqlQuery.getQuery(),
      sqlQuery.getResultFormat(),
      sqlQuery.includeHeader(),
      sqlQuery.includeTypesHeader(),
      sqlQuery.includeSqlTypesHeader(),
      contextWithFormat,
      sqlQuery.getParameters()
  );
}
/**
 * Picks the format in which results are written.
 *
 * @param resultFormatParam value of the {@code resultFormat} request parameter, may be null
 * @param msqSpec           spec of the query, whose context may hold the format stored under
 *                          {@link #RESULT_FORMAT}
 * @return the format named by the request parameter when one is given; otherwise the format stored in the
 * query context, falling back to {@link ResultFormat#DEFAULT_RESULT_FORMAT}
 */
private ResultFormat getPreferredResultFormat(String resultFormatParam, MSQSpec msqSpec)
{
  // An explicit request parameter takes precedence over what was stored at submission time.
  if (resultFormatParam != null) {
    return QueryContexts.getAsEnum(
        "resultFormat",
        resultFormatParam,
        ResultFormat.class
    );
  }

  return QueryContexts.getAsEnum(
      RESULT_FORMAT,
      msqSpec.getQuery().context().get(RESULT_FORMAT),
      ResultFormat.class,
      ResultFormat.DEFAULT_RESULT_FORMAT
  );
}
/**
 * Opens a yielder over the result rows of a statement.
 *
 * For a {@link TaskReportMSQDestination} the rows come from the task report, which holds a single page, so any
 * {@code page} greater than zero is rejected. A report without a results section, or without a yielder, is
 * treated as "no results". For a {@link DurableStorageMSQDestination} the rows are read from durable storage,
 * either from the selected page or from all pages when {@code page} is null.
 *
 * @param queryId           id of the statement
 * @param page              page to read, or null for all pages
 * @param msqControllerTask controller task of the statement
 * @param closer            closer handed to the durable-storage channel factory and attached as baggage to
 *                          the returned sequence
 * @return the yielder, or empty when there are no results
 * @throws DruidException if the page is invalid, durable storage is not configured, or the destination is not
 *                        supported
 */
private Optional<Yielder<Object[]>> getResultYielder(
    String queryId,
    Long page,
    MSQControllerTask msqControllerTask,
    Closer closer
)
{
  final MSQDestination destination = msqControllerTask.getQuerySpec().getDestination();

  if (destination instanceof TaskReportMSQDestination) {
    // Results from task report are only present as one page.
    if (page != null && page > 0) {
      throw InvalidInput.exception(
          "Page number [%d] is out of the range of results", page
      );
    }

    MSQTaskReportPayload msqTaskReportPayload = SqlStatementResourceHelper.getPayload(
        contactOverlord(overlordClient.taskReportAsMap(queryId), queryId)
    );

    // The results section can be absent (getResultSetInformation guards against it as well).
    if (msqTaskReportPayload.getResults() == null) {
      return Optional.empty();
    }
    // Ask for the yielder exactly once so that no second, never-closed yielder can be created.
    return Optional.ofNullable(msqTaskReportPayload.getResults().getResultYielder());

  } else if (destination instanceof DurableStorageMSQDestination) {

    MSQTaskReportPayload msqTaskReportPayload = SqlStatementResourceHelper.getPayload(
        contactOverlord(overlordClient.taskReportAsMap(queryId), queryId)
    );

    List<PageInformation> pages =
        SqlStatementResourceHelper.populatePageList(
            msqTaskReportPayload,
            destination
        ).orElse(null);

    if (pages == null || pages.isEmpty()) {
      return Optional.empty();
    }

    final StageDefinition finalStage = Objects.requireNonNull(SqlStatementResourceHelper.getFinalStage(
        msqTaskReportPayload)).getStageDefinition();

    // get all results
    final Long selectedPageId;
    if (page != null) {
      selectedPageId = getPageInformationForPageId(pages, page).getId();
    } else {
      selectedPageId = null;
    }
    checkForDurableStorageConnectorImpl();
    final DurableStorageInputChannelFactory standardImplementation = DurableStorageInputChannelFactory.createStandardImplementation(
        msqControllerTask.getId(),
        storageConnector,
        closer,
        true
    );
    return Optional.of(Yielders.each(
        Sequences.concat(pages.stream()
                              .filter(pageInformation -> selectedPageId == null
                                                         || selectedPageId.equals(pageInformation.getId()))
                              .map(pageInformation -> {
                                try {
                                  if (pageInformation.getWorker() == null || pageInformation.getPartition() == null) {
                                    throw DruidException.defensive(
                                        "Worker or partition number is null for page id [%d]",
                                        pageInformation.getId()
                                    );
                                  }
                                  return new FrameChannelSequence(standardImplementation.openChannel(
                                      finalStage.getId(),
                                      pageInformation.getWorker(),
                                      pageInformation.getPartition()
                                  ));
                                }
                                catch (Exception e) {
                                  throw new RuntimeException(e);
                                }
                              })
                              .collect(Collectors.toList()))
                 .flatMap(frame -> SqlStatementResourceHelper.getResultSequence(
                              msqControllerTask,
                              finalStage,
                              frame,
                              jsonMapper
                          )
                 )
                 .withBaggage(closer)));

  } else {
    throw DruidException.forPersona(DruidException.Persona.DEVELOPER)
                        .ofCategory(DruidException.Category.UNCATEGORIZED)
                        .build(
                            "MSQ select destination[%s] not supported. Please reach out to druid slack community for more help.",
                            msqControllerTask.getQuerySpec().getDestination().toString()
                        );
  }
}
/**
 * Finds the page with the given id.
 *
 * @param pages  pages to search
 * @param pageId id of the wanted page
 * @return the first page whose id equals {@code pageId}
 * @throws DruidException if no page has that id
 */
private PageInformation getPageInformationForPageId(List<PageInformation> pages, Long pageId)
{
  for (PageInformation pageInfo : pages) {
    // Compare by value: pageId is a boxed Long, and == on two boxed values would compare references.
    if (Objects.equals(pageId, pageInfo.getId())) {
      return pageInfo;
    }
  }
  throw InvalidInput.exception("Invalid page id [%d] passed.", pageId);
}
/**
 * Writes the results of a statement to the output stream in the given format and always closes
 * {@code closer} afterwards.
 *
 * A failure is logged once and rethrown wrapped in a single {@link ISE}.
 *
 * @param queryId      id of the statement, used in error messages
 * @param signature    column names and types of the result; must be present
 * @param closer       resources to release once streaming ends
 * @param results      yielder over the result rows; must be present
 * @param os           stream to write to
 * @param resultFormat format in which the rows are written
 * @throws IOException if closing {@code closer} fails
 */
private void resultPusher(
    String queryId,
    Optional<List<ColumnNameAndTypes>> signature,
    Closer closer,
    Optional<Yielder<Object[]>> results,
    CountingOutputStream os,
    ResultFormat resultFormat
) throws IOException
{
  // The catch clause of a try-with-resources also sees failures from creating and closing the writer,
  // so a single try covers everything; an outer try would only log and wrap the same failure again.
  try (final ResultFormat.Writer writer = resultFormat.createFormatter(os, jsonMapper)) {
    Yielder<Object[]> yielder = results.get();
    List<ColumnNameAndTypes> rowSignature = signature.get();
    resultPusherInternal(writer, yielder, rowSignature);
  }
  catch (Exception e) {
    log.error(e, "Unable to stream results back for query[%s]", queryId);
    throw new ISE(e, "Unable to stream results back for query[%s]", queryId);
  }
  finally {
    closer.close();
  }
}
/**
 * Writes every row of {@code yielder} through {@code writer}, then closes the yielder.
 *
 * Each row is written with the column names of {@code rowSignature}; when a row and the signature differ in
 * length only the first {@code min(signature size, row length)} fields are written. The yielder is closed even
 * when writing fails.
 *
 * @param writer       destination of the rows
 * @param yielder      source of the rows
 * @param rowSignature column names and types of the rows
 * @throws IOException if writing or closing fails
 */
@VisibleForTesting
static void resultPusherInternal(
    ResultFormat.Writer writer,
    Yielder<Object[]> yielder,
    List<ColumnNameAndTypes> rowSignature
) throws IOException
{
  Yielder<Object[]> current = yielder;
  try {
    writer.writeResponseStart();

    while (!current.isDone()) {
      writer.writeRowStart();
      Object[] row = current.get();
      for (int i = 0; i < Math.min(rowSignature.size(), row.length); i++) {
        writer.writeRowField(
            rowSignature.get(i).getColName(),
            row[i]
        );
      }
      writer.writeRowEnd();
      current = current.next(null);
    }
    writer.writeResponseEnd();
  }
  finally {
    // Close whichever yielder the iteration reached, on success and on failure alike.
    current.close();
  }
}
/**
 * Rejects requests for statements whose results cannot be read yet.
 *
 * @param queryId    id of the statement, used in the error message
 * @param statusPlus task status from which the statement state is derived
 * @throws DruidException if the statement is accepted, running or failed
 */
private static void throwIfQueryIsNotSuccessful(String queryId, TaskStatusPlus statusPlus)
{
  final SqlStatementState state = SqlStatementResourceHelper.getSqlStatementState(statusPlus);

  final boolean stillInProgress = state == SqlStatementState.RUNNING || state == SqlStatementState.ACCEPTED;
  if (stillInProgress) {
    throw DruidException.forPersona(DruidException.Persona.USER)
                        .ofCategory(DruidException.Category.INVALID_INPUT)
                        .build(
                            "Query[%s] is currently in [%s] state. Please wait for it to complete.",
                            queryId,
                            state
                        );
  }

  if (state == SqlStatementState.FAILED) {
    throw DruidException.forPersona(DruidException.Persona.USER)
                        .ofCategory(DruidException.Category.INVALID_INPUT)
                        .build(
                            "Query[%s] failed. Check the status api for more details.",
                            queryId
                        );
  }

  // Any other state needs no action.
}
/**
 * Validates the query context of a submitted statement: the execution mode must be set to
 * {@link ExecutionMode#ASYNC}, and a durable-storage select destination requires a configured storage
 * connector.
 *
 * @param queryContext context of the submitted query
 * @throws DruidException if any of the checks fails
 */
private void contextChecks(QueryContext queryContext)
{
  final ExecutionMode requestedMode =
      queryContext.getEnum(QueryContexts.CTX_EXECUTION_MODE, ExecutionMode.class, null);

  if (requestedMode == null) {
    throw InvalidInput.exception(
        "Execution mode is not provided to the sql statement api. "
        + "Please set [%s] to [%s] in the query context",
        QueryContexts.CTX_EXECUTION_MODE,
        ExecutionMode.ASYNC
    );
  }

  if (requestedMode != ExecutionMode.ASYNC) {
    throw InvalidInput.exception(
        "The sql statement api currently does not support the provided execution mode [%s]. "
        + "Please set [%s] to [%s] in the query context",
        requestedMode,
        QueryContexts.CTX_EXECUTION_MODE,
        ExecutionMode.ASYNC
    );
  }

  final MSQSelectDestination destination = MultiStageQueryContext.getSelectDestination(queryContext);
  if (MSQSelectDestination.DURABLESTORAGE.equals(destination)) {
    checkForDurableStorageConnectorImpl();
  }
}
/**
 * Verifies that a usable storage connector is available for durable-storage results.
 *
 * @throws DruidException if the injected connector is a {@link NilStorageConnector}
 */
private void checkForDurableStorageConnectorImpl()
{
  if (!(storageConnector instanceof NilStorageConnector)) {
    return;
  }

  final String message = StringUtils.format(
      "The sql statement api cannot read from the select destination [%s] provided "
      + "in the query context [%s] since it is not configured on the %s. It is recommended to configure durable storage "
      + "as it allows the user to fetch large result sets. Please contact your cluster admin to "
      + "configure durable storage.",
      MSQSelectDestination.DURABLESTORAGE.getName(),
      MultiStageQueryContext.CTX_SELECT_DESTINATION,
      NodeRole.BROKER.getJsonName()
  );

  throw DruidException.forPersona(DruidException.Persona.USER)
                      .ofCategory(DruidException.Category.INVALID_INPUT)
                      .build(message);
}
/**
 * Waits for the result of an overlord call and translates failures into {@link DruidException}s.
 *
 * @param future  pending overlord call
 * @param queryId id of the statement, used in error and log messages
 * @param <T>     type of the overlord response
 * @return the value of the future
 * @throws DruidException a not-found exception when the failure was caused by an HTTP 404 response from the
 *                        overlord; otherwise an uncategorized exception that keeps the original failure as its
 *                        cause
 */
private <T> T contactOverlord(final ListenableFuture<T> future, String queryId)
{
  try {
    return FutureUtils.getUnchecked(future, true);
  }
  catch (RuntimeException e) {
    if (e.getCause() instanceof HttpResponseException) {
      HttpResponseException httpResponseException = (HttpResponseException) e.getCause();
      if (httpResponseException.getResponse() != null && httpResponseException.getResponse().getResponse().getStatus()
                                                                              .equals(HttpResponseStatus.NOT_FOUND)) {
        log.info(httpResponseException, "Query details not found for queryId [%s]", queryId);
        // since we get a 404, we mark the request as a NotFound. This code path is generally triggered when user passes a `queryId` which is not found in the overlord.
        throw queryNotFoundException(queryId);
      }
    }
    // Keep the original failure as the cause so it is not lost, and pass the message as a format
    // argument rather than splicing it into the format string.
    throw DruidException.forPersona(DruidException.Persona.DEVELOPER)
                        .ofCategory(DruidException.Category.UNCATEGORIZED)
                        .build(e, "Unable to contact overlord %s", e.getMessage());
  }
}
/**
 * Creates the exception reported when a statement cannot be found.
 *
 * @param queryId id of the statement that was looked up
 * @return a not-found exception naming the query id and the expected task type
 */
private static DruidException queryNotFoundException(String queryId)
{
  final String messageTemplate =
      "Query [%s] was not found. The query details are no longer present or might not be of the type [%s]. Verify that the id is correct.";
  return NotFound.exception(messageTemplate, queryId, MSQControllerTask.TYPE);
}
}