GH-2394: Fix for query exec when abort is called before exec
diff --git a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
index 6a52755..e2078a2 100644
--- a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
+++ b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecDataset.java
@@ -89,6 +89,7 @@
     private long                     timeout2         = TIMEOUT_UNSET;
     private final AlarmClock         alarmClock       = AlarmClock.get();
     private long                     queryStartTime   = -1; // Unset
+    private AtomicBoolean            cancelSignal     = new AtomicBoolean(false);
 
     protected QueryExecDataset(Query query, String queryString, DatasetGraph datasetGraph, Context cxt,
                                QueryEngineFactory qeFactory,
@@ -142,6 +143,7 @@
 
     @Override
     public void abort() {
+        cancelSignal.set(true);
         synchronized (lockTimeout) {
             // This is called asynchronously to the execution.
             // synchronized is for coordination with other calls of
@@ -166,11 +168,9 @@
     }
 
     private RowSet execute() {
-        execInit();
         startQueryIterator();
-        Iterator<Binding> iter = queryIterator;
         List<Var> vars = query.getResultVars().stream().map(Var::alloc).collect(Collectors.toList());
-        return RowSetStream.create(vars, iter);
+        return RowSetStream.create(vars, queryIterator);
     }
 
     // -- Construct
@@ -357,12 +357,6 @@
     }
 
     class TimeoutCallback implements Runnable {
-        private final AtomicBoolean cancelSignal;
-
-        public TimeoutCallback(AtomicBoolean cancelSignal) {
-            this.cancelSignal = cancelSignal;
-        }
-
         @Override
         public void run() {
             synchronized (lockTimeout) {
@@ -379,11 +373,9 @@
     }
 
     private class QueryIteratorTimer2 extends QueryIteratorWrapper {
-        private final AtomicBoolean cancelSignal;
 
-        public QueryIteratorTimer2(QueryIterator qIter, AtomicBoolean cancelSignal) {
+        public QueryIteratorTimer2(QueryIterator qIter) {
             super(qIter);
-            this.cancelSignal = cancelSignal;
         }
 
         long yieldCount = 0;
@@ -400,7 +392,7 @@
                 // So nearly not needed.
                 synchronized(lockTimeout)
                 {
-                    TimeoutCallback callback = new TimeoutCallback(cancelSignal);
+                    TimeoutCallback callback = new TimeoutCallback();
                     expectedCallback.set(callback);
                     // Lock against calls of .abort() or of timeout1Callback.
 
@@ -437,10 +429,32 @@
             queryStartTime = System.currentTimeMillis();
     }
 
-    /** Start the query iterator, setting timeouts as needed. */
+    /** Wrapper for starting the query iterator but also dealing with cancellation */
     private void startQueryIterator() {
-        if ( queryIterator != null )
+        synchronized (lockTimeout) {
+            if (cancelSignal.get()) {
+                // Fail before starting the iterator if cancelled already
+                throw new QueryCancelledException();
+            }
+
+            startQueryIteratorActual();
+
+            if (cancelSignal.get()) {
+                queryIterator.cancel();
+
+                // Fail now if cancelled already
+                throw new QueryCancelledException();
+            }
+        }
+    }
+
+    /** Start the query iterator, setting timeouts as needed. */
+    private void startQueryIteratorActual() {
+        if ( queryIterator != null ) {
             Log.warn(this, "Query iterator has already been started");
+            return;
+        }
+
         execInit();
 
         /* Timeouts:
@@ -465,19 +479,15 @@
         // This applies to the time to first result because to get the first result, the
         // queryIterator must have been built. So it does not apply for the second
         // stage of N,-1 or N,M.
+        context.set(ARQConstants.symCancelQuery, cancelSignal);
+        TimeoutCallback callback = new TimeoutCallback() ;
+        expectedCallback.set(callback) ;
 
         if ( !isTimeoutSet(timeout1) && isTimeoutSet(timeout2) ) {
             // Case -1,N
-            AtomicBoolean cancelSignal = new AtomicBoolean(false);
-            context.set(ARQConstants.symCancelQuery, cancelSignal);
-            TimeoutCallback callback = new TimeoutCallback(cancelSignal) ;
-            expectedCallback.set(callback) ;
             timeout2Alarm = alarmClock.add(callback, timeout2) ;
             // Start the query.
             queryIterator = getPlan().iterator();
-            // Timeout when off.
-            if ( cancelSignal.get() )
-                queryIterator.cancel();
             // But don't add resetter.
             return ;
         }
@@ -489,23 +499,11 @@
         //   Subcase 2: ! isTimeoutSet(timeout2)
         // Add timeout to first row.
 
-        AtomicBoolean cancelSignal = new AtomicBoolean(false);
-        context.set(ARQConstants.symCancelQuery, cancelSignal);
-
-        TimeoutCallback callback = new TimeoutCallback(cancelSignal) ;
         timeout1Alarm = alarmClock.add(callback, timeout1) ;
-        expectedCallback.set(callback) ;
 
         queryIterator = getPlan().iterator();
         // Add the timeout1->timeout2 resetter wrapper.
-        queryIterator = new QueryIteratorTimer2(queryIterator, cancelSignal);
-
-        // Minor optimization - timeout has already occurred. The first call of hasNext() or next()
-        // will throw QueryCancelledExcetion anyway. This just makes it a bit earlier
-        // in the case when the timeout (timeout1) is so short it's gone off already.
-
-        if ( cancelSignal.get() )
-            queryIterator.cancel();
+        queryIterator = new QueryIteratorTimer2(queryIterator);
     }
 
     private Plan getPlan() {
diff --git a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
index a29f21c..679ce33 100644
--- a/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
+++ b/jena-arq/src/main/java/org/apache/jena/sparql/exec/QueryExecutionCompat.java
@@ -38,7 +38,7 @@
  */
 public class QueryExecutionCompat extends QueryExecutionAdapter {
     private final QueryExecMod qExecBuilder;
-    private QueryExec qExecHere = null;
+    private volatile QueryExec qExecHere = null;
     private final Dataset datasetHere;
     private final Query queryHere;
 
@@ -60,9 +60,14 @@
     }
 
     private void execution() {
-        // Delay until used so setTimeout and initalBindings work.
-        if ( qExecHere == null )
-            qExecHere = qExecBuilder.build();
+        if ( qExecHere == null) {
+            // Synchronized because there may be an async call to abort()
+            synchronized (this) {
+                // Delay until used so setTimeout and initalBindings work.
+                if ( qExecHere == null )
+                    qExecHere = qExecBuilder.build();
+            }
+        }
     }
 
     @Override
diff --git a/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java b/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
index c52356b..1815e1a 100644
--- a/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
+++ b/jena-arq/src/test/java/org/apache/jena/sparql/api/TestQueryExecutionCancel.java
@@ -20,23 +20,45 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.jena.graph.Graph;
+import org.apache.jena.graph.NodeFactory;
 import org.apache.jena.query.* ;
 import org.apache.jena.rdf.model.Model ;
+import org.apache.jena.rdf.model.ModelFactory;
 import org.apache.jena.rdf.model.Property ;
 import org.apache.jena.rdf.model.Resource ;
+import org.apache.jena.sparql.core.DatasetGraph;
+import org.apache.jena.sparql.core.DatasetGraphFactory;
+import org.apache.jena.sparql.exec.QueryExec;
 import org.apache.jena.sparql.function.FunctionRegistry ;
 import org.apache.jena.sparql.function.library.wait ;
 import org.apache.jena.sparql.graph.GraphFactory ;
+import org.apache.jena.sparql.sse.SSE;
 import org.junit.AfterClass ;
+import org.junit.Assert;
 import org.junit.BeforeClass ;
 import org.junit.Test ;
 
 public class TestQueryExecutionCancel {
 
     private static final String ns = "http://example/ns#" ;
-    
+
     static Model m = GraphFactory.makeJenaDefaultModel() ;
     static Resource r1 = m.createResource() ;
     static Property p1 = m.createProperty(ns+"p1") ;
@@ -47,10 +69,10 @@
         m.add(r1, p2, "X2") ; // NB Capital
         m.add(r1, p3, "y1") ;
     }
-    
+
     @BeforeClass public static void beforeClass() { FunctionRegistry.get().put(ns + "wait", wait.class) ; }
     @AfterClass  public static void afterClass() { FunctionRegistry.get().remove(ns + "wait") ; }
-    
+
     @Test(expected=QueryCancelledException.class)
     public void test_Cancel_API_1()
     {
@@ -63,7 +85,7 @@
             assertFalse("Results not expected after cancel.", rs.hasNext()) ;
         }
     }
-    
+
     @Test(expected=QueryCancelledException.class)
     public void test_Cancel_API_2()
     {
@@ -75,8 +97,8 @@
             rs.nextSolution();
             assertFalse("Results not expected after cancel.", rs.hasNext()) ;
         }
-    }    
-    
+    }
+
     @Test public void test_Cancel_API_3() throws InterruptedException
     {
         // Don't qExec.close on this thread.
@@ -88,9 +110,9 @@
         synchronized (qExec) { qExec.notify() ; }
         assertEquals (1, thread.getCount()) ;
     }
-    
+
     @Test public void test_Cancel_API_4() throws InterruptedException
-    { 
+    {
         // Don't qExec.close on this thread.
         QueryExecution qExec = makeQExec("PREFIX ex: <" + ns + "> SELECT * { ?s ?p ?o } ORDER BY ex:wait(100)") ;
         CancelThreadRunner thread = new CancelThreadRunner(qExec);
@@ -101,6 +123,14 @@
         assertEquals (1, thread.getCount()) ;
     }
 
+    @Test(expected = QueryCancelledException.class)
+    public void test_Cancel_API_5() {
+        try (QueryExecution qe = QueryExecutionFactory.create("SELECT * { ?s ?p ?o }", m)) {
+            qe.abort();
+            ResultSetFormatter.consume(qe.execSelect());
+        }
+    }
+
     private QueryExecution makeQExec(String queryString)
     {
         Query q = QueryFactory.create(queryString) ;
@@ -108,39 +138,203 @@
         return qExec ;
     }
 
-    class CancelThreadRunner extends Thread 
+    class CancelThreadRunner extends Thread
     {
-    	private QueryExecution qExec = null ;
-    	private int count = 0 ;
+        private QueryExecution qExec = null ;
+        private int count = 0 ;
 
-    	public CancelThreadRunner(QueryExecution qExec) 
-    	{
-    		this.qExec = qExec ;
-    	}
-    	
-    	@Override
-    	public void run() 
-    	{
-            try 
+        public CancelThreadRunner(QueryExecution qExec)
+        {
+            this.qExec = qExec ;
+        }
+
+        @Override
+        public void run()
+        {
+            try
             {
                 ResultSet rs = qExec.execSelect() ;
-                while ( rs.hasNext() ) 
+                while ( rs.hasNext() )
                 {
                     rs.nextSolution() ;
                     count++ ;
                     synchronized (qExec) { qExec.notify() ; }
                     synchronized (qExec) { qExec.wait() ; }
                 }
-    		} 
+            }
             catch (QueryCancelledException e) {}
-            catch (InterruptedException e) { 
+            catch (InterruptedException e) {
                 e.printStackTrace();
-    		} finally { qExec.close() ; }
-    	}
-    	
-    	public int getCount() 
-    	{
-    		return count ;
-    	}
+            } finally { qExec.close() ; }
+        }
+
+        public int getCount()
+        {
+            return count ;
+        }
+    }
+
+    @Test
+    public void test_cancel_select_1() {
+        cancellationTest("SELECT * {}", QueryExec::select);
+    }
+
+    @Test
+    public void test_cancel_select_2() {
+        cancellationTest("SELECT * {}", QueryExec::select, Iterator::hasNext);
+    }
+
+    @Test
+    public void test_cancel_ask() {
+        cancellationTest("ASK {}", QueryExec::ask);
+    }
+
+    @Test
+    public void test_cancel_construct() {
+        cancellationTest("CONSTRUCT WHERE {}", QueryExec::construct);
+    }
+
+    @Test
+    public void test_cancel_describe() {
+        cancellationTest("DESCRIBE * {}", QueryExec::describe);
+    }
+
+    @Test
+    public void test_cancel_construct_dataset() {
+        cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructDataset);
+    }
+
+    @Test
+    public void test_cancel_construct_triples_1() {
+        cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructTriples, Iterator::hasNext);
+    }
+
+    @Test
+    public void test_cancel_construct_triples_2() {
+        cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructTriples);
+    }
+
+    @Test
+    public void test_cancel_construct_quads_1() {
+        cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructQuads, Iterator::hasNext);
+    }
+
+    @Test
+    public void test_cancel_construct_quads_2() {
+        cancellationTest("CONSTRUCT{} WHERE{}", QueryExec::constructQuads);
+    }
+
+    @Test
+    public void test_cancel_json() {
+        cancellationTest("JSON {\":a\": \"b\"} WHERE {}", exec->exec.execJson().get(0));
+    }
+
+    static <T> void cancellationTest(String queryString, Function<QueryExec, Iterator<T>> itFactory, Consumer<Iterator<T>> itConsumer) {
+        cancellationTest(queryString, itFactory::apply);
+        cancellationTestForIterator(queryString, itFactory, itConsumer);
+    }
+
+    /** Abort the query exec and expect all execution methods to fail */
+    static void cancellationTest(String queryString, Consumer<QueryExec> execAction) {
+        DatasetGraph dsg = DatasetGraphFactory.createTxnMem();
+        dsg.add(SSE.parseQuad("(_ :s :p :o)"));
+        try(QueryExec aExec = QueryExec.dataset(dsg).query(queryString).build()) {
+            aExec.abort();
+            assertThrows(QueryCancelledException.class, ()-> execAction.accept(aExec));
+        }
+    }
+
+    /** Obtain an iterator and only afterwards abort the query exec.
+     * Operations on the iterator are now expected to fail. */
+    static <T> void cancellationTestForIterator(String queryString, Function<QueryExec, Iterator<T>> itFactory, Consumer<Iterator<T>> itConsumer) {
+        DatasetGraph dsg = DatasetGraphFactory.createTxnMem();
+        dsg.add(SSE.parseQuad("(_ :s :p :o)"));
+        try(QueryExec aExec = QueryExec.dataset(dsg).query(queryString).build()) {
+            Iterator<T> it = itFactory.apply(aExec);
+            aExec.abort();
+            assertThrows(QueryCancelledException.class, ()-> itConsumer.accept(it));
+        }
+    }
+
+    /**
+     * Test that creates iterators over a billion result rows and attempts to cancel them.
+     * If this test hangs then it is likely that something went wrong in the cancellation machinery.
+     */
+    @Test(timeout = 10000)
+    public void test_cancel_concurrent_1() {
+        int maxCancelDelayInMillis = 100;
+
+        int cpuCount = Runtime.getRuntime().availableProcessors();
+        // Spend at most roughly 1 second per cpu (10 tasks a max 100ms)
+        int taskCount = cpuCount * 10;
+
+        // Create a model with 1000 triples
+        Graph graph = GraphFactory.createDefaultGraph();
+        IntStream.range(0, 1000)
+            .mapToObj(i -> NodeFactory.createURI("http://www.example.org/r" + i))
+            .forEach(node -> graph.add(node, node, node));
+        Model model = ModelFactory.createModelForGraph(graph);
+
+        // Create a query that creates 3 cross joins - resulting in one billion result rows
+        Query query = QueryFactory.create("SELECT * { ?a ?b ?c . ?d ?e ?f . ?g ?h ?i . }");
+        Callable<QueryExecution> qeFactory = () -> QueryExecutionFactory.create(query, model);
+
+        runConcurrentAbort(taskCount, maxCancelDelayInMillis, qeFactory, TestQueryExecutionCancel::doCount);
+    }
+
+    private static final int doCount(QueryExecution qe) {
+        try (QueryExecution qe2 = qe) {
+            ResultSet rs = qe2.execSelect();
+            int size = ResultSetFormatter.consume(rs);
+            return size;
+        }
+    }
+
+    /**
+     * Reusable method that creates a parallel stream that starts query executions
+     * and schedules cancel tasks on a separate thread pool.
+     */
+    public static void runConcurrentAbort(int taskCount, int maxCancelDelay, Callable<QueryExecution> qeFactory, Function<QueryExecution, ?> processor) {
+        Random cancelDelayRandom = new Random();
+        ExecutorService executorService = Executors.newCachedThreadPool();
+        try {
+            List<Integer> list = IntStream.range(0, taskCount).boxed().collect(Collectors.toList());
+            list
+                .parallelStream()
+                .forEach(i -> {
+                    QueryExecution qe;
+                    try {
+                        qe = qeFactory.call();
+                    } catch (Exception e) {
+                        throw new RuntimeException("Failed to build a query execution", e);
+                    }
+                    Future<?> future = executorService.submit(() -> processor.apply(qe));
+                    int delayToAbort = cancelDelayRandom.nextInt(maxCancelDelay);
+                    try {
+                        Thread.sleep(delayToAbort);
+                    } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                    }
+                    // System.out.println("Abort: " + qe);
+                    qe.abort();
+                    try {
+                        // System.out.println("Waiting for: " + qe);
+                        future.get();
+                    } catch (ExecutionException e) {
+                        Throwable cause = e.getCause();
+                        if (!(cause instanceof QueryCancelledException)) {
+                            // Unexpected exception - print out the stack trace
+                            e.printStackTrace();
+                        }
+                        Assert.assertEquals(QueryCancelledException.class, cause.getClass());
+                    } catch (InterruptedException e) {
+                        // Ignored
+                    } finally {
+                        // System.out.println("Completed: " + qe);
+                    }
+                });
+        } finally {
+            executorService.shutdownNow();
+        }
     }
 }