Convert DecryptResponseRunnable to a Callable, this closes apache/incubator-pirk#68
diff --git a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java
index 2231160..97b93fd 100644
--- a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java
+++ b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java
@@ -29,9 +29,12 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import org.apache.pirk.encryption.Paillier;
import org.apache.pirk.querier.wideskies.Querier;
@@ -48,6 +51,8 @@
public class DecryptResponse
{
private static final Logger logger = LoggerFactory.getLogger(DecryptResponse.class);
+
+ private static final BigInteger TWO_BI = BigInteger.valueOf(2);
private final Response response;
@@ -87,25 +92,24 @@
Paillier paillier = querier.getPaillier();
List<String> selectors = querier.getSelectors();
- HashMap<Integer,String> embedSelectorMap = querier.getEmbedSelectorMap();
+ Map<Integer,String> embedSelectorMap = querier.getEmbedSelectorMap();
// Perform decryption on the encrypted columns
- ArrayList<BigInteger> rElements = decryptElements(response.getResponseElements(), paillier);
+ List<BigInteger> rElements = decryptElements(response.getResponseElements(), paillier);
logger.debug("rElements.size() = " + rElements.size());
// Pull the necessary parameters
int dataPartitionBitSize = queryInfo.getDataPartitionBitSize();
// Initialize the result map and masks-- removes initialization checks from code below
- HashMap<String,BigInteger> selectorMaskMap = new HashMap<>();
+ Map<String,BigInteger> selectorMaskMap = new HashMap<>();
int selectorNum = 0;
- BigInteger twoBI = BigInteger.valueOf(2);
for (String selector : selectors)
{
resultMap.put(selector, new ArrayList<>());
// 2^{selectorNum*dataPartitionBitSize}(2^{dataPartitionBitSize} - 1)
- BigInteger mask = twoBI.pow(selectorNum * dataPartitionBitSize).multiply((twoBI.pow(dataPartitionBitSize).subtract(BigInteger.ONE)));
+ BigInteger mask = TWO_BI.pow(selectorNum * dataPartitionBitSize).multiply((TWO_BI.pow(dataPartitionBitSize).subtract(BigInteger.ONE)));
logger.debug("selector = " + selector + " mask = " + mask.toString(2));
selectorMaskMap.put(selector, mask);
@@ -120,7 +124,7 @@
}
int elementsPerThread = selectors.size() / numThreads; // Integral division.
- ArrayList<DecryptResponseRunnable> runnables = new ArrayList<>();
+ List<Future<Map<String,List<QueryResponseJSON>>>> futures = new ArrayList<>();
for (int i = 0; i < numThreads; ++i)
{
// Grab the range of the thread and create the corresponding partition of selectors
@@ -137,33 +141,30 @@
}
// Create the runnable and execute
- // selectorMaskMap and rElements are synchronized, pirWatchlist is copied, selectors is partitioned
- DecryptResponseRunnable runDec = new DecryptResponseRunnable(rElements, selectorsPartition, selectorMaskMap, queryInfo.clone(), embedSelectorMap);
- runnables.add(runDec);
- es.execute(runDec);
- }
-
- // Allow threads to complete
- es.shutdown(); // previously submitted tasks are executed, but no new tasks will be accepted
- boolean finished = es.awaitTermination(1, TimeUnit.DAYS); // waits until all tasks complete or until the specified timeout
-
- if (!finished)
- {
- throw new PIRException("Decryption threads did not finish in the alloted time");
+ DecryptResponseRunnable<Map<String,List<QueryResponseJSON>>> runDec = new DecryptResponseRunnable<>(rElements, selectorsPartition, selectorMaskMap, queryInfo.clone(), embedSelectorMap);
+ futures.add(es.submit(runDec));
}
// Pull all decrypted elements and add to resultMap
- for (DecryptResponseRunnable runner : runnables)
+ try
{
- resultMap.putAll(runner.getResultMap());
+ for (Future<Map<String,List<QueryResponseJSON>>> future : futures)
+ {
+ resultMap.putAll(future.get(1, TimeUnit.DAYS));
+ }
+ } catch (TimeoutException | ExecutionException e)
+ {
+ throw new PIRException("Exception in decryption threads.", e);
}
+
+ es.shutdown();
}
// Method to perform basic decryption of each raw response element - does not
// extract and reconstruct the data elements
- private ArrayList<BigInteger> decryptElements(TreeMap<Integer,BigInteger> elements, Paillier paillier)
+ private List<BigInteger> decryptElements(TreeMap<Integer,BigInteger> elements, Paillier paillier)
{
- ArrayList<BigInteger> decryptedElements = new ArrayList<>();
+ List<BigInteger> decryptedElements = new ArrayList<>();
for (BigInteger encElement : elements.values())
{
diff --git a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java
similarity index 83%
rename from src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java
rename to src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java
index 531ec6a..7b197d8 100644
--- a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java
+++ b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java
@@ -24,12 +24,14 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
+import java.util.concurrent.Callable;
import org.apache.pirk.query.wideskies.QueryInfo;
import org.apache.pirk.query.wideskies.QueryUtils;
import org.apache.pirk.schema.query.QuerySchema;
import org.apache.pirk.schema.query.QuerySchemaRegistry;
import org.apache.pirk.schema.response.QueryResponseJSON;
+import org.apache.pirk.utils.PIRException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -39,11 +41,10 @@
* NOTE: rElements and selectorMaskMap are joint access objects, for now
*
*/
-public class DecryptResponseRunnable implements Runnable
+class DecryptResponseRunnable<V> implements Callable<Map<String,List<QueryResponseJSON>>>
{
private static final Logger logger = LoggerFactory.getLogger(DecryptResponseRunnable.class);
- private final Map<String,List<QueryResponseJSON>> resultMap = new HashMap<>(); // selector -> ArrayList of hits
private final List<BigInteger> rElements;
private final TreeMap<Integer,String> selectors;
private final Map<String,BigInteger> selectorMaskMap;
@@ -61,13 +62,8 @@
embedSelectorMap = embedSelectorMapInput;
}
- public Map<String,List<QueryResponseJSON>> getResultMap()
- {
- return resultMap;
- }
-
@Override
- public void run()
+ public Map<String,List<QueryResponseJSON>> call() throws PIRException
{
// Pull the necessary parameters
int dataPartitionBitSize = queryInfo.getDataPartitionBitSize();
@@ -76,19 +72,18 @@
QuerySchema qSchema = QuerySchemaRegistry.get(queryInfo.getQueryType());
String selectorName = qSchema.getSelectorName();
- // Initialize - removes checks below
+ // Result is a map of (selector -> List of hits).
+ Map<String,List<QueryResponseJSON>> resultMap = new HashMap<>();
for (String selector : selectors.values())
{
resultMap.put(selector, new ArrayList<QueryResponseJSON>());
}
- logger.debug("numResults = " + rElements.size() + " numPartitionsPerDataElement = " + numPartitionsPerDataElement);
-
// Pull the hits for each selector
- int hits = 0;
int maxHitsPerSelector = rElements.size() / numPartitionsPerDataElement; // Max number of data hits in the response elements for a given selector
- logger.debug("numHits = " + maxHitsPerSelector);
- while (hits < maxHitsPerSelector)
+ logger.debug("numResults = " + rElements.size() + " numPartitionsPerDataElement = " + numPartitionsPerDataElement + " maxHits = " + maxHitsPerSelector);
+
+ for (int hits = 0; hits < maxHitsPerSelector; hits++)
{
int selectorIndex = selectors.firstKey();
while (selectorIndex <= selectors.lastKey())
@@ -96,10 +91,9 @@
String selector = selectors.get(selectorIndex);
logger.debug("selector = " + selector);
- ArrayList<BigInteger> parts = new ArrayList<>();
- int partNum = 0;
+ List<BigInteger> parts = new ArrayList<>();
boolean zeroElement = true;
- while (partNum < numPartitionsPerDataElement)
+ for (int partNum = 0; partNum < numPartitionsPerDataElement; partNum++)
{
BigInteger part = (rElements.get(hits * numPartitionsPerDataElement + partNum)).and(selectorMaskMap.get(selector)); // pull off the correct bits
@@ -114,14 +108,7 @@
logger.debug("partNum = " + partNum + " part = " + part.intValue());
- if (zeroElement)
- {
- if (!part.equals(BigInteger.ZERO))
- {
- zeroElement = false;
- }
- }
- ++partNum;
+ zeroElement = zeroElement && part.equals(BigInteger.ZERO);
}
logger.debug("parts.size() = " + parts.size());
@@ -129,15 +116,7 @@
if (!zeroElement)
{
// Convert biHit to the appropriate QueryResponseJSON object, based on the queryType
- QueryResponseJSON qrJOSN = null;
- try
- {
- qrJOSN = QueryUtils.extractQueryResponseJSON(queryInfo, qSchema, parts);
- } catch (Exception e)
- {
- e.printStackTrace();
- throw new RuntimeException(e);
- }
+ QueryResponseJSON qrJOSN = QueryUtils.extractQueryResponseJSON(queryInfo, qSchema, parts);
qrJOSN.setMapping(selectorName, selector);
logger.debug("selector = " + selector + " qrJOSN = " + qrJOSN.getJSONString());
@@ -166,7 +145,8 @@
++selectorIndex;
}
- ++hits;
}
+
+ return resultMap;
}
}