blob: b1b5084a70c0f3ad7612433afed47a81c1a15383 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.lifecycle.Lifecycle;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.easymock.Capture;
import org.easymock.EasyMock;
import org.easymock.IAnswer;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
public class ChainedExecutionQueryRunnerTest
{
private final Lock neverRelease = new ReentrantLock();
@Before
public void setup()
{
neverRelease.lock();
}
@Test(timeout = 60_000L)
public void testQueryCancellation() throws Exception
{
ExecutorService exec = PrioritizedExecutorService.create(
new Lifecycle(), new DruidProcessingConfig()
{
@Override
public String getFormatString()
{
return "test";
}
@Override
public int getNumThreads()
{
return 2;
}
}
);
final CountDownLatch queriesStarted = new CountDownLatch(2);
final CountDownLatch queriesInterrupted = new CountDownLatch(2);
final CountDownLatch queryIsRegistered = new CountDownLatch(1);
Capture<ListenableFuture> capturedFuture = EasyMock.newCapture();
QueryWatcher watcher = EasyMock.createStrictMock(QueryWatcher.class);
watcher.registerQueryFuture(
EasyMock.anyObject(),
EasyMock.and(EasyMock.anyObject(), EasyMock.capture(capturedFuture))
);
EasyMock.expectLastCall()
.andAnswer(
new IAnswer<Void>()
{
@Override
public Void answer()
{
queryIsRegistered.countDown();
return null;
}
}
)
.once();
EasyMock.replay(watcher);
ArrayBlockingQueue<DyingQueryRunner> interrupted = new ArrayBlockingQueue<>(3);
Set<DyingQueryRunner> runners = Sets.newHashSet(
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted),
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted),
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted)
);
ChainedExecutionQueryRunner chainedRunner = new ChainedExecutionQueryRunner<>(
new ForwardingQueryProcessingPool(exec),
watcher,
Lists.newArrayList(
runners
)
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("test")
.intervals("2014/2015")
.aggregators(Collections.singletonList(new CountAggregatorFactory("count")))
.build();
final Sequence seq = chainedRunner.run(QueryPlus.wrap(query));
Future resultFuture = Executors.newFixedThreadPool(1).submit(
new Runnable()
{
@Override
public void run()
{
seq.toList();
}
}
);
// wait for query to register and start
queryIsRegistered.await();
queriesStarted.await();
// cancel the query
Assert.assertTrue(capturedFuture.hasCaptured());
ListenableFuture future = capturedFuture.getValue();
future.cancel(true);
QueryInterruptedException cause = null;
try {
resultFuture.get();
}
catch (ExecutionException e) {
Assert.assertTrue(e.getCause() instanceof QueryInterruptedException);
cause = (QueryInterruptedException) e.getCause();
}
queriesInterrupted.await();
Assert.assertNotNull(cause);
Assert.assertTrue(future.isCancelled());
DyingQueryRunner interrupted1 = interrupted.poll();
synchronized (interrupted1) {
Assert.assertTrue("runner 1 started", interrupted1.hasStarted);
Assert.assertTrue("runner 1 interrupted", interrupted1.interrupted);
}
DyingQueryRunner interrupted2 = interrupted.poll();
synchronized (interrupted2) {
Assert.assertTrue("runner 2 started", interrupted2.hasStarted);
Assert.assertTrue("runner 2 interrupted", interrupted2.interrupted);
}
runners.remove(interrupted1);
runners.remove(interrupted2);
DyingQueryRunner remainingRunner = runners.iterator().next();
synchronized (remainingRunner) {
Assert.assertTrue("runner 3 should be interrupted or not have started",
!remainingRunner.hasStarted || remainingRunner.interrupted);
}
Assert.assertFalse("runner 1 not completed", interrupted1.hasCompleted);
Assert.assertFalse("runner 2 not completed", interrupted2.hasCompleted);
Assert.assertFalse("runner 3 not completed", remainingRunner.hasCompleted);
EasyMock.verify(watcher);
}
@Test(timeout = 60_000L)
public void testQueryTimeout() throws Exception
{
ExecutorService exec = PrioritizedExecutorService.create(
new Lifecycle(), new DruidProcessingConfig()
{
@Override
public String getFormatString()
{
return "test";
}
@Override
public int getNumThreads()
{
return 2;
}
}
);
final CountDownLatch queriesStarted = new CountDownLatch(2);
final CountDownLatch queriesInterrupted = new CountDownLatch(2);
final CountDownLatch queryIsRegistered = new CountDownLatch(1);
Capture<ListenableFuture> capturedFuture = Capture.newInstance();
QueryWatcher watcher = EasyMock.createStrictMock(QueryWatcher.class);
watcher.registerQueryFuture(
EasyMock.anyObject(),
EasyMock.and(EasyMock.anyObject(), EasyMock.capture(capturedFuture))
);
EasyMock.expectLastCall()
.andAnswer(
new IAnswer<Void>()
{
@Override
public Void answer()
{
queryIsRegistered.countDown();
return null;
}
}
)
.once();
EasyMock.replay(watcher);
ArrayBlockingQueue<DyingQueryRunner> interrupted = new ArrayBlockingQueue<>(3);
Set<DyingQueryRunner> runners = Sets.newHashSet(
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted),
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted),
new DyingQueryRunner(queriesStarted, queriesInterrupted, interrupted)
);
ChainedExecutionQueryRunner chainedRunner = new ChainedExecutionQueryRunner<>(
new ForwardingQueryProcessingPool(exec),
watcher,
Lists.newArrayList(
runners
)
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("test")
.intervals("2014/2015")
.aggregators(Collections.singletonList(new CountAggregatorFactory("count")))
.context(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 100, "queryId", "test"))
.build();
final Sequence seq = chainedRunner.run(QueryPlus.wrap(query));
Future resultFuture = Executors.newFixedThreadPool(1).submit(
new Runnable()
{
@Override
public void run()
{
seq.toList();
}
}
);
// wait for query to register and start
queryIsRegistered.await();
queriesStarted.await();
Assert.assertTrue(capturedFuture.hasCaptured());
ListenableFuture future = capturedFuture.getValue();
// wait for query to time out
QueryTimeoutException cause = null;
try {
resultFuture.get();
}
catch (ExecutionException e) {
Assert.assertTrue(e.getCause() instanceof QueryTimeoutException);
Assert.assertEquals("Query timeout", ((QueryTimeoutException) e.getCause()).getErrorCode());
cause = (QueryTimeoutException) e.getCause();
}
queriesInterrupted.await();
Assert.assertNotNull(cause);
Assert.assertTrue(future.isCancelled());
DyingQueryRunner interrupted1 = interrupted.poll();
synchronized (interrupted1) {
Assert.assertTrue("runner 1 started", interrupted1.hasStarted);
Assert.assertTrue("runner 1 interrupted", interrupted1.interrupted);
}
DyingQueryRunner interrupted2 = interrupted.poll();
synchronized (interrupted2) {
Assert.assertTrue("runner 2 started", interrupted2.hasStarted);
Assert.assertTrue("runner 2 interrupted", interrupted2.interrupted);
}
runners.remove(interrupted1);
runners.remove(interrupted2);
DyingQueryRunner remainingRunner = runners.iterator().next();
synchronized (remainingRunner) {
Assert.assertTrue("runner 3 should be interrupted or not have started",
!remainingRunner.hasStarted || remainingRunner.interrupted);
}
Assert.assertFalse("runner 1 not completed", interrupted1.hasCompleted);
Assert.assertFalse("runner 2 not completed", interrupted2.hasCompleted);
Assert.assertFalse("runner 3 not completed", remainingRunner.hasCompleted);
EasyMock.verify(watcher);
}
@Test
public void testSubmittedTaskType()
{
QueryProcessingPool queryProcessingPool = Mockito.mock(QueryProcessingPool.class);
QueryWatcher watcher = EasyMock.createStrictMock(QueryWatcher.class);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("test")
.intervals("2014/2015")
.aggregators(Collections.singletonList(new CountAggregatorFactory("count")))
.context(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 100, "queryId", "test"))
.build();
List<QueryRunner<Result<TimeseriesResultValue>>> runners = Arrays.asList(
Mockito.mock(QueryRunner.class),
Mockito.mock(QueryRunner.class)
);
ChainedExecutionQueryRunner<Result<TimeseriesResultValue>> chainedRunner = new ChainedExecutionQueryRunner<>(
queryProcessingPool,
watcher,
runners
);
Mockito.when(queryProcessingPool.submitRunnerTask(ArgumentMatchers.any())).thenReturn(Futures.immediateFuture(Collections.singletonList(123)));
chainedRunner.run(QueryPlus.wrap(query)).toList();
ArgumentCaptor<PrioritizedQueryRunnerCallable> captor = ArgumentCaptor.forClass(PrioritizedQueryRunnerCallable.class);
Mockito.verify(queryProcessingPool, Mockito.times(2)).submitRunnerTask(captor.capture());
List<QueryRunner> actual = captor.getAllValues().stream().map(PrioritizedQueryRunnerCallable::getRunner).collect(Collectors.toList());
Assert.assertEquals(runners, actual);
}
private class DyingQueryRunner implements QueryRunner<Integer>
{
private final CountDownLatch start;
private final CountDownLatch stop;
private final Queue<DyingQueryRunner> interruptedRunners;
private volatile boolean hasStarted = false;
private volatile boolean hasCompleted = false;
private volatile boolean interrupted = false;
public DyingQueryRunner(CountDownLatch start, CountDownLatch stop, Queue<DyingQueryRunner> interruptedRunners)
{
this.start = start;
this.stop = stop;
this.interruptedRunners = interruptedRunners;
}
@Override
public Sequence<Integer> run(QueryPlus<Integer> queryPlus, ResponseContext responseContext)
{
// do a lot of work
synchronized (this) {
try {
hasStarted = true;
start.countDown();
neverRelease.lockInterruptibly();
}
catch (InterruptedException e) {
interrupted = true;
interruptedRunners.offer(this);
stop.countDown();
throw new QueryInterruptedException(e);
}
}
hasCompleted = true;
stop.countDown();
return Sequences.simple(Collections.singletonList(123));
}
}
}