Parallelize TransactionImpl.readUnread() (#1080)

When a transaction only writes to a row+col and then has a collision,
Fluo reads that row+col after the collision to look for orphaned
locks. This commit parallelizes that behavior by reading all row+cols
at once.

It does this by using a ParallelSnapshotScanner instead of a
SnapshotScanner.

Fixes #948
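
For intuition, here is a minimal sketch of the access-pattern change, using
hypothetical `Store`, `Row`, `Col`, and `RowCol` stand-ins rather than Fluo's
internal types: the old readUnread() issued one scan per rejected row, while
the new code gathers all rejected row+cols and hands them to a single
parallel read.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-ins, not Fluo's internal types; they only illustrate
// the access-pattern change described above.
class ReadUnreadSketch {

  interface Row {}

  interface Col {}

  static class RowCol {
    final Row row;
    final Col col;

    RowCol(Row row, Col col) {
      this.row = row;
      this.col = col;
    }
  }

  interface Store {
    // Each call models one round trip to the server.
    void read(Collection<RowCol> cells);
  }

  // Before: one scan per rejected row, issued serially.
  static void readSequentially(Store store, Map<Row, Set<Col>> rejected) {
    for (Map.Entry<Row, Set<Col>> e : rejected.entrySet()) {
      List<RowCol> cells = new ArrayList<>();
      for (Col c : e.getValue()) {
        cells.add(new RowCol(e.getKey(), c));
      }
      store.read(cells); // one round trip per rejected row
    }
  }

  // After: gather every rejected row+col and issue a single batched read,
  // analogous to switching from SnapshotScanner to ParallelSnapshotScanner.
  static void readBatched(Store store, Map<Row, Set<Col>> rejected) {
    Collection<RowCol> toRead = new ArrayList<>();
    rejected.forEach((row, cols) -> cols.forEach(col -> toRead.add(new RowCol(row, col))));
    store.read(toRead); // one parallelized read covering all rejected cells
  }
}
```

In the actual change below, the batched read also forwards a `writeLocksSeen`
consumer so that write locks encountered during the parallel scan are still
reported back to checkForOrphanedLocks().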
diff --git a/modules/core/src/main/java/org/apache/fluo/core/impl/ParallelSnapshotScanner.java b/modules/core/src/main/java/org/apache/fluo/core/impl/ParallelSnapshotScanner.java
index 9847580..4d2b37c 100644
--- a/modules/core/src/main/java/org/apache/fluo/core/impl/ParallelSnapshotScanner.java
+++ b/modules/core/src/main/java/org/apache/fluo/core/impl/ParallelSnapshotScanner.java
@@ -24,6 +24,7 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
 import org.apache.accumulo.core.client.BatchScanner;
@@ -53,9 +54,11 @@
   private Function<ByteSequence, Bytes> rowConverter;
   private Function<Key, Column> columnConverter;
   private Map<Bytes, Set<Column>> readLocksSeen;
+  private Consumer<Entry<Key, Value>> writeLocksSeen;
 
   ParallelSnapshotScanner(Collection<Bytes> rows, Set<Column> columns, Environment env,
-      long startTs, TxStats stats, Map<Bytes, Set<Column>> readLocksSeen) {
+      long startTs, TxStats stats, Map<Bytes, Set<Column>> readLocksSeen,
+      Consumer<Entry<Key, Value>> writeLocksSeen) {
     this.rows = rows;
     this.columns = columns;
     this.env = env;
@@ -64,10 +67,11 @@
     this.rowConverter = new CachedBytesConverter(rows);
     this.columnConverter = new CachedColumnConverter(columns);
     this.readLocksSeen = readLocksSeen;
+    this.writeLocksSeen = writeLocksSeen;
   }
 
   ParallelSnapshotScanner(Collection<RowColumn> cells, Environment env, long startTs, TxStats stats,
-      Map<Bytes, Set<Column>> readLocksSeen) {
+      Map<Bytes, Set<Column>> readLocksSeen, Consumer<Entry<Key, Value>> writeLocksSeen) {
     for (RowColumn rc : cells) {
       byte[] r = rc.getRow().toArray();
       byte[] cf = rc.getColumn().getFamily().toArray();
@@ -87,6 +91,7 @@
     this.rowConverter = ByteUtil::toBytes;
     this.columnConverter = ColumnUtil::convert;
     this.readLocksSeen = readLocksSeen;
+    this.writeLocksSeen = writeLocksSeen;
   }
 
   private BatchScanner setupBatchScanner() {
@@ -180,6 +185,7 @@
         switch (colType) {
           case LOCK:
             locks.add(entry);
+            writeLocksSeen.accept(entry);
             break;
           case DATA:
             ret.computeIfAbsent(row, k -> new HashMap<>()).put(col,
diff --git a/modules/core/src/main/java/org/apache/fluo/core/impl/SnapshotScanner.java b/modules/core/src/main/java/org/apache/fluo/core/impl/SnapshotScanner.java
index f8f83f7..8397c15 100644
--- a/modules/core/src/main/java/org/apache/fluo/core/impl/SnapshotScanner.java
+++ b/modules/core/src/main/java/org/apache/fluo/core/impl/SnapshotScanner.java
@@ -75,7 +75,6 @@
   private final Environment env;
   private final TxStats stats;
   private final Opts config;
-  private Consumer<Entry<Key, Value>> locksSeen;
 
   static final long INITIAL_WAIT_TIME = 50;
   // TODO make configurable
@@ -155,8 +154,6 @@
 
       // read ahead a little bit looking for other locks to resolve
 
-      locksSeen.accept(lockEntry);
-
       long startTime = System.currentTimeMillis();
       long waitTime = INITIAL_WAIT_TIME;
 
@@ -174,7 +171,6 @@
 
           if (ColumnType.from(entry.getKey()) == ColumnType.LOCK) {
             locks.add(entry);
-            locksSeen.accept(lockEntry);
           }
 
           amountRead += entry.getKey().getSize() + entry.getValue().getSize();
@@ -241,13 +237,11 @@
     }
   }
 
-  SnapshotScanner(Environment env, Opts config, long startTs, TxStats stats,
-      Consumer<Entry<Key, Value>> locksSeen) {
+  SnapshotScanner(Environment env, Opts config, long startTs, TxStats stats) {
     this.env = env;
     this.config = config;
     this.startTs = startTs;
     this.stats = stats;
-    this.locksSeen = locksSeen;
   }
 
   @Override
diff --git a/modules/core/src/main/java/org/apache/fluo/core/impl/TransactionImpl.java b/modules/core/src/main/java/org/apache/fluo/core/impl/TransactionImpl.java
index b898492..8201a12 100644
--- a/modules/core/src/main/java/org/apache/fluo/core/impl/TransactionImpl.java
+++ b/modules/core/src/main/java/org/apache/fluo/core/impl/TransactionImpl.java
@@ -187,8 +187,7 @@
   @Override
   public Map<Column, Bytes> get(Bytes row, Set<Column> columns) {
     checkIfOpen();
-    return getImpl(row, columns, kve -> {
-    });
+    return getImpl(row, columns);
   }
 
   @Override
@@ -202,7 +201,8 @@
     env.getSharedResources().getVisCache().validate(columns);
 
     ParallelSnapshotScanner pss =
-        new ParallelSnapshotScanner(rows, columns, env, startTs, stats, readLocksSeen);
+        new ParallelSnapshotScanner(rows, columns, env, startTs, stats, readLocksSeen, kve -> {
+        });
 
     Map<Bytes, Map<Column, Bytes>> ret = pss.scan();
 
@@ -216,29 +216,11 @@
   @Override
   public Map<RowColumn, Bytes> get(Collection<RowColumn> rowColumns) {
     checkIfOpen();
-
-    if (rowColumns.isEmpty()) {
-      return Collections.emptyMap();
-    }
-
-    ParallelSnapshotScanner pss =
-        new ParallelSnapshotScanner(rowColumns, env, startTs, stats, readLocksSeen);
-
-    Map<Bytes, Map<Column, Bytes>> scan = pss.scan();
-    Map<RowColumn, Bytes> ret = new HashMap<>();
-
-    for (Entry<Bytes, Map<Column, Bytes>> entry : scan.entrySet()) {
-      updateColumnsRead(entry.getKey(), entry.getValue().keySet());
-      for (Entry<Column, Bytes> colVal : entry.getValue().entrySet()) {
-        ret.put(new RowColumn(entry.getKey(), colVal.getKey()), colVal.getValue());
-      }
-    }
-
-    return ret;
+    return getImpl(rowColumns, kve -> {
+    });
   }
 
-  private Map<Column, Bytes> getImpl(Bytes row, Set<Column> columns,
-      Consumer<Entry<Key, Value>> locksSeen) {
+  private Map<Column, Bytes> getImpl(Bytes row, Set<Column> columns) {
 
     // TODO push visibility filtering to server side?
 
@@ -270,7 +252,7 @@
     Map<Column, Bytes> ret = new HashMap<>();
     Set<Column> readLockCols = null;
 
-    for (Entry<Key, Value> kve : new SnapshotScanner(env, opts, startTs, stats, locksSeen)) {
+    for (Entry<Key, Value> kve : new SnapshotScanner(env, opts, startTs, stats)) {
 
       Column col = ColumnUtil.convert(kve.getKey());
       if (shouldCopy && !columns.contains(col)) {
@@ -293,6 +275,28 @@
     return ret;
   }
 
+  private Map<RowColumn, Bytes> getImpl(Collection<RowColumn> rowColumns,
+      Consumer<Entry<Key, Value>> writeLocksSeen) {
+    if (rowColumns.isEmpty()) {
+      return Collections.emptyMap();
+    }
+
+    ParallelSnapshotScanner pss =
+        new ParallelSnapshotScanner(rowColumns, env, startTs, stats, readLocksSeen, writeLocksSeen);
+
+    Map<Bytes, Map<Column, Bytes>> scan = pss.scan();
+    Map<RowColumn, Bytes> ret = new HashMap<>();
+
+    for (Entry<Bytes, Map<Column, Bytes>> entry : scan.entrySet()) {
+      updateColumnsRead(entry.getKey(), entry.getValue().keySet());
+      for (Entry<Column, Bytes> colVal : entry.getValue().entrySet()) {
+        ret.put(new RowColumn(entry.getKey(), colVal.getKey()), colVal.getValue());
+      }
+    }
+
+    return ret;
+  }
+
   @Override
   public ScannerBuilder scanner() {
     checkIfOpen();
@@ -542,7 +546,7 @@
    * This function helps handle the following case
    *
    * <OL>
-   * <LI>TX1 locls r1 col1
+   * <LI>TX1 locks r1 col1
    * <LI>TX1 fails before unlocking
    * <LI>TX2 attempts to write r1:col1 w/o reading it
    * </OL>
@@ -554,28 +558,29 @@
    *
    * @param cd Commit data
    */
-  private void readUnread(CommitData cd, Consumer<Entry<Key, Value>> locksSeen) {
-    // TODO make async
+  private void readUnread(CommitData cd, Consumer<Entry<Key, Value>> writeLocksSeen) {
     // TODO need to keep track of ranges read (not ranges passed in, but actual data read... user
     // may not iterate over entire range
-    Map<Bytes, Set<Column>> columnsToRead = new HashMap<>();
+    Collection<RowColumn> rowColumnsToRead = new ArrayList<>();
 
     for (Entry<Bytes, Set<Column>> entry : cd.getRejected().entrySet()) {
       Set<Column> rowColsRead = columnsRead.get(entry.getKey());
       if (rowColsRead == null) {
-        columnsToRead.put(entry.getKey(), entry.getValue());
+        for (Column column : entry.getValue()) {
+          rowColumnsToRead.add(new RowColumn(entry.getKey(), column));
+        }
       } else {
         HashSet<Column> colsToRead = new HashSet<>(entry.getValue());
         colsToRead.removeAll(rowColsRead);
         if (!colsToRead.isEmpty()) {
-          columnsToRead.put(entry.getKey(), colsToRead);
+          for (Column column : colsToRead) {
+            rowColumnsToRead.add(new RowColumn(entry.getKey(), column));
+          }
         }
       }
     }
 
-    for (Entry<Bytes, Set<Column>> entry : columnsToRead.entrySet()) {
-      getImpl(entry.getKey(), entry.getValue(), locksSeen);
-    }
+    getImpl(rowColumnsToRead, writeLocksSeen);
   }
 
   private void checkForOrphanedReadLocks(CommitData cd, Map<Bytes, Set<Column>> locksResolved)
@@ -642,15 +647,15 @@
 
   private void checkForOrphanedLocks(CommitData cd) throws Exception {
 
-    Map<Bytes, Set<Column>> locksSeen = new HashMap<>();
+    Map<Bytes, Set<Column>> writeLocksSeen = new HashMap<>();
 
     readUnread(cd, kve -> {
       Bytes row = ByteUtil.toBytes(kve.getKey().getRowData());
       Column col = ColumnUtil.convert(kve.getKey());
-      locksSeen.computeIfAbsent(row, k -> new HashSet<>()).add(col);
+      writeLocksSeen.computeIfAbsent(row, k -> new HashSet<>()).add(col);
     });
 
-    checkForOrphanedReadLocks(cd, locksSeen);
+    checkForOrphanedReadLocks(cd, writeLocksSeen);
   }
 
   private boolean checkForAckCollision(ConditionalMutation cm) {
@@ -1532,8 +1537,6 @@
   }
 
   public SnapshotScanner newSnapshotScanner(Span span, Collection<Column> columns) {
-    return new SnapshotScanner(env, new SnapshotScanner.Opts(span, columns, false), startTs, stats,
-        kve -> {
-        });
+    return new SnapshotScanner(env, new SnapshotScanner.Opts(span, columns, false), startTs, stats);
   }
 }