GH-2394: Fix for query exec when abort is called before exec
diff --git a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
index 6a52755..e2078a2 100644
--- a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
+++ b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
@@ -89,6 +89,7 @@
private long timeout2 = TIMEOUT_UNSET;
private final AlarmClock alarmClock = AlarmClock.get();
private long queryStartTime = -1; // Unset
+ private AtomicBoolean cancelSignal = new AtomicBoolean(false);
protected QueryExecDataset(Query query, String queryString, DatasetGraph datasetGraph, Context cxt,
QueryEngineFactory qeFactory,
@@ -142,6 +143,7 @@
@Override
public void abort() {
+ cancelSignal.set(true);
synchronized (lockTimeout) {
// This is called asynchronously to the execution.
// synchronized is for coordination with other calls of
@@ -166,11 +168,9 @@
}
private RowSet execute() {
- execInit();
startQueryIterator();
- Iterator<Binding> iter = queryIterator;
List<Var> vars = query.getResultVars().stream().map(Var::alloc).collect(Collectors.toList());
- return RowSetStream.create(vars, iter);
+ return RowSetStream.create(vars, queryIterator);
}
// -- Construct
@@ -357,12 +357,6 @@
}
class TimeoutCallback implements Runnable {
- private final AtomicBoolean cancelSignal;
-
- public TimeoutCallback(AtomicBoolean cancelSignal) {
- this.cancelSignal = cancelSignal;
- }
-
@Override
public void run() {
synchronized (lockTimeout) {
@@ -379,11 +373,9 @@
}
private class QueryIteratorTimer2 extends QueryIteratorWrapper {
- private final AtomicBoolean cancelSignal;
- public QueryIteratorTimer2(QueryIterator qIter, AtomicBoolean cancelSignal) {
+ public QueryIteratorTimer2(QueryIterator qIter) {
super(qIter);
- this.cancelSignal = cancelSignal;
}
long yieldCount = 0;
@@ -400,7 +392,7 @@
// So nearly not needed.
synchronized(lockTimeout)
{
- TimeoutCallback callback = new TimeoutCallback(cancelSignal);
+ TimeoutCallback callback = new TimeoutCallback();
expectedCallback.set(callback);
// Lock against calls of .abort() or of timeout1Callback.
@@ -437,10 +429,32 @@
queryStartTime = System.currentTimeMillis();
}
- /** Start the query iterator, setting timeouts as needed. */
+ /** Wrapper for starting the query iterator but also dealing with cancellation */
private void startQueryIterator() {
- if ( queryIterator != null )
+ synchronized (lockTimeout) {
+ if (cancelSignal.get()) {
+ // Fail before starting the iterator if cancelled already
+ throw new QueryCancelledException();
+ }
+
+ startQueryIteratorActual();
+
+ if (cancelSignal.get()) {
+ queryIterator.cancel();
+
+ // Fail now if cancelled already
+ throw new QueryCancelledException();
+ }
+ }
+ }
+
+ /** Start the query iterator, setting timeouts as needed. */
+ private void startQueryIteratorActual() {
+ if ( queryIterator != null ) {
Log.warn(this, "Query iterator has already been started");
+ return;
+ }
+
execInit();
/* Timeouts:
@@ -465,19 +479,15 @@
// This applies to the time to first result because to get the first result, the
// queryIterator must have been built. So it does not apply for the second
// stage of N,-1 or N,M.
+ context.set(ARQConstants.symCancelQuery, cancelSignal);
+ TimeoutCallback callback = new TimeoutCallback() ;
+ expectedCallback.set(callback) ;
if ( !isTimeoutSet(timeout1) && isTimeoutSet(timeout2) ) {
// Case -1,N
- AtomicBoolean cancelSignal = new AtomicBoolean(false);
- context.set(ARQConstants.symCancelQuery, cancelSignal);
- TimeoutCallback callback = new TimeoutCallback(cancelSignal) ;
- expectedCallback.set(callback) ;
timeout2Alarm = alarmClock.add(callback, timeout2) ;
// Start the query.
queryIterator = getPlan().iterator();
- // Timeout when off.
- if ( cancelSignal.get() )
- queryIterator.cancel();
// But don't add resetter.
return ;
}
@@ -489,23 +499,11 @@
// Subcase 2: ! isTimeoutSet(timeout2)
// Add timeout to first row.
- AtomicBoolean cancelSignal = new AtomicBoolean(false);
- context.set(ARQConstants.symCancelQuery, cancelSignal);
-
- TimeoutCallback callback = new TimeoutCallback(cancelSignal) ;
timeout1Alarm = alarmClock.add(callback, timeout1) ;
- expectedCallback.set(callback) ;
queryIterator = getPlan().iterator();
// Add the timeout1->timeout2 resetter wrapper.
- queryIterator = new QueryIteratorTimer2(queryIterator, cancelSignal);
-
- // Minor optimization - timeout has already occurred. The first call of hasNext() or next()
- // will throw QueryCancelledExcetion anyway. This just makes it a bit earlier
- // in the case when the timeout (timeout1) is so short it's gone off already.
-
- if ( cancelSignal.get() )
- queryIterator.cancel();
+ queryIterator = new QueryIteratorTimer2(queryIterator);
}
private Plan getPlan() {
diff --git a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
index a29f21c..679ce33 100644
--- a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
+++ b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
@@ -38,7 +38,7 @@
*/
public class QueryExecutionCompat extends QueryExecutionAdapter {
private final QueryExecMod qExecBuilder;
- private QueryExec qExecHere = null;
+ private volatile QueryExec qExecHere = null;
private final Dataset datasetHere;
private final Query queryHere;
@@ -60,9 +60,14 @@
}
private void execution() {
- // Delay until used so setTimeout and initalBindings work.
- if ( qExecHere == null )
- qExecHere = qExecBuilder.build();
+ if ( qExecHere == null) {
+ // Synchronized because there may be an async call to abort()
+ synchronized (this) {
+ // Delay until used so setTimeout and initalBindings work.
+ if ( qExecHere == null )
+ qExecHere = qExecBuilder.build();
+ }
+ }
}
@Override
diff --git a/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java b/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
index c52356b..1815e1a 100644
--- a/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
+++ b/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
@@ -20,23 +20,45 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.jena.graph.Graph;
+import org.apache.jena.graph.NodeFactory;
import org.apache.jena.query.* ;
import org.apache.jena.rdf.model.Model ;
+import org.apache.jena.rdf.model.ModelFactory;
import org.apache.jena.rdf.model.Property ;
import org.apache.jena.rdf.model.Resource ;
+import org.apache.jena.sparql.core.DatasetGraph;
+import org.apache.jena.sparql.core.DatasetGraphFactory;
+import org.apache.jena.sparql.exec.QueryExec;
import org.apache.jena.sparql.function.FunctionRegistry ;
import org.apache.jena.sparql.function.library.wait ;
import org.apache.jena.sparql.graph.GraphFactory ;
+import org.apache.jena.sparql.sse.SSE;
import org.junit.AfterClass ;
+import org.junit.Assert;
import org.junit.BeforeClass ;
import org.junit.Test ;
public class TestQueryExecutionCancel {
private static final String ns = "http://example/ns#" ;
-
+
static Model m = GraphFactory.makeJenaDefaultModel() ;
static Resource r1 = m.createResource() ;
static Property p1 = m.createProperty(ns+"p1") ;
@@ -47,10 +69,10 @@
m.add(r1, p2, "X2") ; // NB Capital
m.add(r1, p3, "y1") ;
}
-
+
@BeforeClass public static void beforeClass() { FunctionRegistry.get().put(ns + "wait", wait.class) ; }
@AfterClass public static void afterClass() { FunctionRegistry.get().remove(ns + "wait") ; }
-
+
@Test(expected=QueryCancelledException.class)
public void test_Cancel_API_1()
{
@@ -63,7 +85,7 @@
assertFalse("Results not expected after cancel.", rs.hasNext()) ;
}
}
-
+
@Test(expected=QueryCancelledException.class)
public void test_Cancel_API_2()
{
@@ -75,8 +97,8 @@
rs.nextSolution();
assertFalse("Results not expected after cancel.", rs.hasNext()) ;
}
- }
-
+ }
+
@Test public void test_Cancel_API_3() throws InterruptedException
{
// Don't qExec.close on this thread.
@@ -88,9 +110,9 @@
synchronized (qExec) { qExec.notify() ; }
assertEquals (1, thread.getCount()) ;
}
-
+
@Test public void test_Cancel_API_4() throws InterruptedException
- {
+ {
// Don't qExec.close on this thread.
QueryExecution qExec = makeQExec("PREFIX ex: <" + ns + "> SELECT * { ?s ?p ?o } ORDER BY ex:wait(100)") ;
CancelThreadRunner thread = new CancelThreadRunner(qExec);
@@ -101,6 +123,14 @@
assertEquals (1, thread.getCount()) ;
}
+ @Test(expected = QueryCancelledException.class)
+ public void test_Cancel_API_5() {
+ try (QueryExecution qe = QueryExecutionFactory.create("SELECT * { ?s ?p ?o }", m)) {
+ qe.abort();
+ ResultSetFormatter.consume(qe.execSelect());
+ }
+ }
+
private QueryExecution makeQExec(String queryString)
{
Query q = QueryFactory.create(queryString) ;
@@ -108,39 +138,203 @@
return qExec ;
}
- class CancelThreadRunner extends Thread
+ class CancelThreadRunner extends Thread
{
- private QueryExecution qExec = null ;
- private int count = 0 ;
+ private QueryExecution qExec = null ;
+ private int count = 0 ;
- public CancelThreadRunner(QueryExecution qExec)
- {
- this.qExec = qExec ;
- }
-
- @Override
- public void run()
- {
- try
+ public CancelThreadRunner(QueryExecution qExec)
+ {
+ this.qExec = qExec ;
+ }
+
+ @Override
+ public void run()
+ {
+ try
{
ResultSet rs = qExec.execSelect() ;
- while ( rs.hasNext() )
+ while ( rs.hasNext() )
{
rs.nextSolution() ;
count++ ;
synchronized (qExec) { qExec.notify() ; }
synchronized (qExec) { qExec.wait() ; }
}
- }
+ }
catch (QueryCancelledException e) {}
- catch (InterruptedException e) {
+ catch (InterruptedException e) {
e.printStackTrace();
- } finally { qExec.close() ; }
- }
-
- public int getCount()
- {
- return count ;
- }
+ } finally { qExec.close() ; }
+ }
+
+ public int getCount()
+ {
+ return count ;
+ }
+ }
+
+ @Test
+ public void test_cancel_select_1() {
+ cancellationTest("SELECT * {}", QueryExec::select);
+ }
+
+ @Test
+ public void test_cancel_select_2() {
+ cancellationTest("SELECT * {}", QueryExec::select, Iterator::hasNext);
+ }
+
+ @Test
+ public void test_cancel_ask() {
+ cancellationTest("ASK {}", QueryExec::ask);
+ }
+
+ @Test
+ public void test_cancel_construct() {
+ cancellationTest("CONSTRUCT WHERE {}", QueryExec::construct);
+ }
+
+ @Test
+ public void test_cancel_describe() {
+ cancellationTest("DESCRIBE * {}", QueryExec::describe);
+ }
+
+ @Test
+ public void test_cancel_construct_dataset() {
+ cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructDataset);
+ }
+
+ @Test
+ public void test_cancel_construct_triples_1() {
+ cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructTriples, Iterator::hasNext);
+ }
+
+ @Test
+ public void test_cancel_construct_triples_2() {
+ cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructTriples);
+ }
+
+ @Test
+ public void test_cancel_construct_quads_1() {
+ cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructQuads, Iterator::hasNext);
+ }
+
+ @Test
+ public void test_cancel_construct_quads_2() {
+ cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructQuads);
+ }
+
+ @Test
+ public void test_cancel_json() {
+ cancellationTest("JSON {\":a\": \"b\"} WHERE {}", exec->exec.execJson().get(0));
+ }
+
+ static <T> void cancellationTest(String queryString, Function<QueryExec, Iterator<T>> itFactory, Consumer<Iterator<T>> itConsumer) {
+ cancellationTest(queryString, itFactory::apply);
+ cancellationTestForIterator(queryString, itFactory, itConsumer);
+ }
+
+ /** Abort the query exec and expect all execution methods to fail */
+ static void cancellationTest(String queryString, Consumer<QueryExec> execAction) {
+ DatasetGraph dsg = DatasetGraphFactory.createTxnMem();
+ dsg.add(SSE.parseQuad("(_ :s :p :o)"));
+ try(QueryExec aExec = QueryExec.dataset(dsg).query(queryString).build()) {
+ aExec.abort();
+ assertThrows(QueryCancelledException.class, ()-> execAction.accept(aExec));
+ }
+ }
+
+ /** Obtain an iterator and only afterwards abort the query exec.
+ * Operations on the iterator are now expected to fail. */
+ static <T> void cancellationTestForIterator(String queryString, Function<QueryExec, Iterator<T>> itFactory, Consumer<Iterator<T>> itConsumer) {
+ DatasetGraph dsg = DatasetGraphFactory.createTxnMem();
+ dsg.add(SSE.parseQuad("(_ :s :p :o)"));
+ try(QueryExec aExec = QueryExec.dataset(dsg).query(queryString).build()) {
+ Iterator<T> it = itFactory.apply(aExec);
+ aExec.abort();
+ assertThrows(QueryCancelledException.class, ()-> itConsumer.accept(it));
+ }
+ }
+
+ /**
+ * Test that creates iterators over a billion result rows and attempts to cancel them.
+ * If this test hangs then it is likely that something went wrong in the cancellation machinery.
+ */
+ @Test(timeout = 10000)
+ public void test_cancel_concurrent_1() {
+ int maxCancelDelayInMillis = 100;
+
+ int cpuCount = Runtime.getRuntime().availableProcessors();
+ // Spend at most roughly 1 second per cpu (10 tasks a max 100ms)
+ int taskCount = cpuCount * 10;
+
+ // Create a model with 1000 triples
+ Graph graph = GraphFactory.createDefaultGraph();
+ IntStream.range(0, 1000)
+ .mapToObj(i -> NodeFactory.createURI("http://www.example.org/r" + i))
+ .forEach(node -> graph.add(node, node, node));
+ Model model = ModelFactory.createModelForGraph(graph);
+
+ // Create a query that creates 3 cross joins - resulting in one billion result rows
+ Query query = QueryFactory.create("SELECT * { ?a ?b ?c . ?d ?e ?f . ?g ?h ?i . }");
+ Callable<QueryExecution> qeFactory = () -> QueryExecutionFactory.create(query, model);
+
+ runConcurrentAbort(taskCount, maxCancelDelayInMillis, qeFactory, TestQueryExecutionCancel::doCount);
+ }
+
+ private static final int doCount(QueryExecution qe) {
+ try (QueryExecution qe2 = qe) {
+ ResultSet rs = qe2.execSelect();
+ int size = ResultSetFormatter.consume(rs);
+ return size;
+ }
+ }
+
+ /**
+ * Reusable method that creates a parallel stream that starts query executions
+ * and schedules cancel tasks on a separate thread pool.
+ */
+ public static void runConcurrentAbort(int taskCount, int maxCancelDelay, Callable<QueryExecution> qeFactory, Function<QueryExecution, ?> processor) {
+ Random cancelDelayRandom = new Random();
+ ExecutorService executorService = Executors.newCachedThreadPool();
+ try {
+ List<Integer> list = IntStream.range(0, taskCount).boxed().collect(Collectors.toList());
+ list
+ .parallelStream()
+ .forEach(i -> {
+ QueryExecution qe;
+ try {
+ qe = qeFactory.call();
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to build a query execution", e);
+ }
+ Future<?> future = executorService.submit(() -> processor.apply(qe));
+ int delayToAbort = cancelDelayRandom.nextInt(maxCancelDelay);
+ try {
+ Thread.sleep(delayToAbort);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ // System.out.println("Abort: " + qe);
+ qe.abort();
+ try {
+ // System.out.println("Waiting for: " + qe);
+ future.get();
+ } catch (ExecutionException e) {
+ Throwable cause = e.getCause();
+ if (!(cause instanceof QueryCancelledException)) {
+ // Unexpected exception - print out the stack trace
+ e.printStackTrace();
+ }
+ Assert.assertEquals(QueryCancelledException.class, cause.getClass());
+ } catch (InterruptedException e) {
+ // Ignored
+ } finally {
+ // System.out.println("Completed: " + qe);
+ }
+ });
+ } finally {
+ executorService.shutdownNow();
+ }
}
}