ORC-622: Refactor separating BatchReader from TreeReader.

Fixes #503

Signed-off-by: Owen O'Malley <omalley@apache.org>
diff --git a/java/core/src/java/org/apache/orc/impl/ConvertTreeReaderFactory.java b/java/core/src/java/org/apache/orc/impl/ConvertTreeReaderFactory.java
index ead8f65..e6d3863 100644
--- a/java/core/src/java/org/apache/orc/impl/ConvertTreeReaderFactory.java
+++ b/java/core/src/java/org/apache/orc/impl/ConvertTreeReaderFactory.java
@@ -43,12 +43,12 @@
 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr;
-import org.apache.hadoop.hive.ql.util.TimestampUtils;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
 import org.apache.orc.OrcProto;
 import org.apache.orc.TypeDescription;
 import org.apache.orc.TypeDescription.Category;
 import org.apache.orc.impl.reader.StripePlanner;
+import org.apache.orc.impl.reader.tree.TypeReader;
 import org.threeten.extra.chrono.HybridChronology;
 
 /**
@@ -61,9 +61,9 @@
    */
   public static class ConvertTreeReader extends TreeReader {
 
-    TreeReader fromReader;
+    TypeReader fromReader;
 
-    ConvertTreeReader(int columnId, TreeReader fromReader) throws IOException {
+    ConvertTreeReader(int columnId, TypeReader fromReader) throws IOException {
       super(columnId, null);
       this.fromReader = fromReader;
     }
@@ -249,13 +249,13 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       // Pass-thru.
       fromReader.checkEncoding(encoding);
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       // Pass-thru.
       fromReader.startStripe(planner);
     }
@@ -273,7 +273,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       // Pass-thru.
       fromReader.skipRows(items);
     }
@@ -367,7 +367,7 @@
     }
   }
 
-  private static TreeReader createFromInteger(int columnId,
+  private static TypeReader createFromInteger(int columnId,
                                               TypeDescription fileType,
                                               Context context) throws IOException {
     switch (fileType.getCategory()) {
@@ -1744,7 +1744,7 @@
     }
   }
 
-  private static TreeReader createAnyIntegerConvertTreeReader(int columnId,
+  private static TypeReader createAnyIntegerConvertTreeReader(int columnId,
                                                               TypeDescription fileType,
                                                               TypeDescription readerType,
                                                               Context context) throws IOException {
@@ -1798,7 +1798,7 @@
     }
   }
 
-  private static TreeReader createDoubleConvertTreeReader(int columnId,
+  private static TypeReader createDoubleConvertTreeReader(int columnId,
                                                           TypeDescription fileType,
                                                           TypeDescription readerType,
                                                           Context context) throws IOException {
@@ -1845,7 +1845,7 @@
     }
   }
 
-  private static TreeReader createDecimalConvertTreeReader(int columnId,
+  private static TypeReader createDecimalConvertTreeReader(int columnId,
                                                            TypeDescription fileType,
                                                            TypeDescription readerType,
                                                            Context context) throws IOException {
@@ -1891,7 +1891,7 @@
     }
   }
 
-  private static TreeReader createStringConvertTreeReader(int columnId,
+  private static TypeReader createStringConvertTreeReader(int columnId,
                                                           TypeDescription fileType,
                                                           TypeDescription readerType,
                                                           Context context) throws IOException {
@@ -1946,7 +1946,7 @@
     }
   }
 
-  private static TreeReader createTimestampConvertTreeReader(int columnId,
+  private static TypeReader createTimestampConvertTreeReader(int columnId,
                                                              TypeDescription fileType,
                                                              TypeDescription readerType,
                                                              Context context) throws IOException {
@@ -1995,7 +1995,7 @@
     }
   }
 
-  private static TreeReader createDateConvertTreeReader(int columnId,
+  private static TypeReader createDateConvertTreeReader(int columnId,
                                                         TypeDescription readerType,
                                                         Context context) throws IOException {
 
@@ -2036,7 +2036,7 @@
     }
   }
 
-  private static TreeReader createBinaryConvertTreeReader(int columnId,
+  private static TypeReader createBinaryConvertTreeReader(int columnId,
                                                           TypeDescription readerType,
                                                           Context context) throws IOException {
 
@@ -2199,7 +2199,7 @@
    * @return
    * @throws IOException
    */
-  public static TreeReader createConvertTreeReader(TypeDescription readerType,
+  public static TypeReader createConvertTreeReader(TypeDescription readerType,
                                                    Context context) throws IOException {
     final SchemaEvolution evolution = context.getSchemaEvolution();
 
diff --git a/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java b/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java
index f5a65ba..947889e 100644
--- a/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java
+++ b/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java
@@ -48,6 +48,7 @@
 import org.apache.orc.TypeDescription;
 import org.apache.orc.impl.reader.ReaderEncryption;
 import org.apache.orc.impl.reader.StripePlanner;
+import org.apache.orc.impl.reader.tree.BatchReader;
 import org.apache.orc.util.BloomFilter;
 import org.apache.orc.util.BloomFilterIO;
 import org.slf4j.Logger;
@@ -78,7 +79,7 @@
   private int currentStripe = -1;
   private long rowBaseInStripe = 0;
   private long rowCountInStripe = 0;
-  private final TreeReaderFactory.TreeReader reader;
+  private final BatchReader reader;
   private final OrcIndex indexes;
   private final SargApplier sargApp;
   // an array about which row groups aren't skipped
@@ -227,8 +228,7 @@
           .setProlepticGregorian(fileReader.writerUsedProlepticGregorian(),
               fileReader.options.getConvertToProlepticGregorian())
           .setEncryption(encryption);
-    reader = TreeReaderFactory.createTreeReader(evolution.getReaderSchema(),
-        readerContext);
+    reader = TreeReaderFactory.createRootReader(evolution.getReaderSchema(), readerContext);
 
     int columns = evolution.getFileSchema().getMaximumId() + 1;
     indexes = new OrcIndex(new OrcProto.RowIndex[columns],
@@ -1109,7 +1109,7 @@
    * @throws IOException
    */
   private boolean advanceToNextRow(
-      TreeReaderFactory.TreeReader reader, long nextRow, boolean canAdvanceStripe)
+    BatchReader reader, long nextRow, boolean canAdvanceStripe)
       throws IOException {
     long nextRowInStripe = nextRow - rowBaseInStripe;
     // check for row skipping
@@ -1166,8 +1166,6 @@
       rowInStripe += batchSize;
       reader.setVectorColumnCount(batch.getDataColumnCount());
       reader.nextBatch(batch, batchSize);
-      batch.selectedInUse = false;
-      batch.size = batchSize;
       advanceToNextRow(reader, rowInStripe + rowBaseInStripe, true);
       return batch.size  != 0;
     } catch (IOException e) {
@@ -1261,7 +1259,7 @@
     }
   }
 
-  private void seekToRowEntry(TreeReaderFactory.TreeReader reader, int rowEntry)
+  private void seekToRowEntry(BatchReader reader, int rowEntry)
       throws IOException {
     OrcProto.RowIndex[] rowIndices = indexes.getRowGroupIndex();
     PositionProvider[] index = new PositionProvider[rowIndices.length];
diff --git a/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java b/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
index cac4067..d17de68 100644
--- a/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
+++ b/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
@@ -48,6 +48,10 @@
 import org.apache.orc.OrcProto;
 import org.apache.orc.impl.reader.ReaderEncryption;
 import org.apache.orc.impl.reader.StripePlanner;
+import org.apache.orc.impl.reader.tree.BatchReader;
+import org.apache.orc.impl.reader.tree.PrimitiveBatchReader;
+import org.apache.orc.impl.reader.tree.StructBatchReader;
+import org.apache.orc.impl.reader.tree.TypeReader;
 import org.apache.orc.impl.writer.TimestampTreeWriter;
 
 /**
@@ -160,10 +164,9 @@
     }
   }
 
-  public abstract static class TreeReader {
+  public abstract static class TreeReader implements TypeReader {
     protected final int columnId;
     protected BitFieldReader present = null;
-    protected int vectorColumnCount;
     protected final Context context;
 
     static final long[] powerOfTenTable = {
@@ -200,14 +203,9 @@
       } else {
         present = new BitFieldReader(in);
       }
-      vectorColumnCount = -1;
     }
 
-    void setVectorColumnCount(int vectorColumnCount) {
-      this.vectorColumnCount = vectorColumnCount;
-    }
-
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
             columnId);
@@ -230,7 +228,7 @@
       }
     }
 
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       checkEncoding(planner.getEncoding(columnId));
       InStream in = planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.PRESENT));
@@ -271,21 +269,6 @@
       }
     }
 
-    abstract void skipRows(long rows) throws IOException;
-
-    /**
-     * Called at the top level to read into the given batch.
-     * @param batch the batch to read into
-     * @param batchSize the number of rows to read
-     * @throws IOException
-     */
-    public void nextBatch(VectorizedRowBatch batch,
-                          int batchSize) throws IOException {
-      batch.cols[0].reset();
-      batch.cols[0].ensureSize(batchSize, false);
-      nextVector(batch.cols[0], null, batchSize);
-    }
-
     /**
      * Populates the isNull vector array in the previousVector object based on
      * the present stream values. This function is called from all the child
@@ -334,6 +317,7 @@
       return present;
     }
 
+    @Override
     public int getColumnId() {
       return columnId;
     }
@@ -351,7 +335,7 @@
     }
 
     @Override
-    void skipRows(long rows) {
+    public void skipRows(long rows) {
       // PASS
     }
 
@@ -388,7 +372,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       reader = new BitFieldReader(planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.DATA)));
@@ -406,7 +390,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
 
@@ -437,7 +421,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       reader = new RunLengthByteReader(planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.DATA)));
@@ -468,7 +452,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
   }
@@ -491,7 +475,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -500,7 +484,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -533,7 +517,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
   }
@@ -556,7 +540,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -565,7 +549,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -598,7 +582,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
   }
@@ -622,7 +606,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -631,7 +615,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -664,7 +648,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
   }
@@ -684,7 +668,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -754,7 +738,7 @@
     }
 
     @Override
-    protected void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       for (int i = 0; i < items; ++i) {
         utils.readFloat(stream);
@@ -777,7 +761,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name =
           new StreamName(columnId,
@@ -847,7 +831,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long len = items * 8;
       while (len > 0) {
@@ -877,7 +861,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -886,7 +870,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -921,7 +905,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long lengthToSkip = 0;
       for (int i = 0; i < items; ++i) {
@@ -991,7 +975,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -1000,7 +984,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       OrcProto.ColumnEncoding.Kind kind = planner.getEncoding(columnId).getKind();
       data = createIntegerReader(kind,
@@ -1103,7 +1087,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       data.skip(items);
       nanos.skip(items);
@@ -1134,7 +1118,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -1143,7 +1127,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -1187,7 +1171,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
   }
@@ -1228,7 +1212,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -1237,7 +1221,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       valueStream = planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.DATA));
@@ -1339,7 +1323,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       HiveDecimalWritable scratchDecWritable = new HiveDecimalWritable();
       for (int i = 0; i < items; i++) {
@@ -1378,7 +1362,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
             columnId);
@@ -1386,7 +1370,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       InStream stream = planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.DATA));
@@ -1442,7 +1426,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       valueReader.skip(items);
     }
@@ -1454,7 +1438,7 @@
    * dictionary encoding was used.
    */
   public static class StringTreeReader extends TreeReader {
-    protected TreeReader reader;
+    protected TypeReader reader;
 
     StringTreeReader(int columnId, Context context) throws IOException {
       super(columnId, context);
@@ -1483,12 +1467,12 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       reader.checkEncoding(encoding);
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       // For each stripe, checks the encoding and initializes the appropriate
       // reader
       switch (planner.getEncoding(columnId).getKind()) {
@@ -1525,7 +1509,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skipRows(items);
     }
   }
@@ -1633,7 +1617,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT &&
           encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -1642,7 +1626,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       StreamName name = new StreamName(columnId,
           OrcProto.Stream.Kind.DATA);
@@ -1680,7 +1664,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long lengthToSkip = 0;
       for (int i = 0; i < items; ++i) {
@@ -1737,7 +1721,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DICTIONARY &&
           encoding.getKind() != OrcProto.ColumnEncoding.Kind.DICTIONARY_V2) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -1746,7 +1730,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
 
       // read the dictionary blob
@@ -1890,7 +1874,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       reader.skip(countNonNulls(items));
     }
 
@@ -2011,7 +1995,7 @@
   }
 
   public static class StructTreeReader extends TreeReader {
-    protected final TreeReader[] fields;
+    public final TypeReader[] fields;
 
     protected StructTreeReader(int columnId,
                                TypeDescription readerSchema,
@@ -2026,14 +2010,14 @@
       }
     }
 
-    public TreeReader[] getChildReaders() {
+    public TypeReader[] getChildReaders() {
       return fields;
     }
 
     protected StructTreeReader(int columnId, InStream present,
                                Context context,
                                OrcProto.ColumnEncoding encoding,
-                               TreeReader[] childReaders) throws IOException {
+                               TypeReader[] childReaders) throws IOException {
       super(columnId, present, context);
       if (encoding != null) {
         checkEncoding(encoding);
@@ -2044,7 +2028,7 @@
     @Override
     public void seek(PositionProvider[] index) throws IOException {
       super.seek(index);
-      for (TreeReader kid : fields) {
+      for (TypeReader kid : fields) {
         if (kid != null) {
           kid.seek(index);
         }
@@ -2052,20 +2036,6 @@
     }
 
     @Override
-    public void nextBatch(VectorizedRowBatch batch,
-                          int batchSize) throws IOException {
-      for(int i=0; i < fields.length &&
-          (vectorColumnCount == -1 || i < vectorColumnCount); ++i) {
-        ColumnVector colVector = batch.cols[i];
-        if (colVector != null) {
-          colVector.reset();
-          colVector.ensureSize((int) batchSize, false);
-          fields[i].nextVector(colVector, null, batchSize);
-        }
-      }
-    }
-
-    @Override
     public void nextVector(ColumnVector previousVector,
                            boolean[] isNull,
                            final int batchSize) throws IOException {
@@ -2085,9 +2055,9 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
-      for (TreeReader field : fields) {
+      for (TypeReader field : fields) {
         if (field != null) {
           field.startStripe(planner);
         }
@@ -2095,9 +2065,9 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
-      for (TreeReader field : fields) {
+      for (TypeReader field : fields) {
         if (field != null) {
           field.skipRows(items);
         }
@@ -2106,7 +2076,7 @@
   }
 
   public static class UnionTreeReader extends TreeReader {
-    protected final TreeReader[] fields;
+    protected final TypeReader[] fields;
     protected RunLengthByteReader tags;
 
     protected UnionTreeReader(int fileColumn,
@@ -2125,7 +2095,7 @@
     protected UnionTreeReader(int columnId, InStream present,
                               Context context,
                               OrcProto.ColumnEncoding encoding,
-                              TreeReader[] childReaders) throws IOException {
+                              TypeReader[] childReaders) throws IOException {
       super(columnId, present, context);
       if (encoding != null) {
         checkEncoding(encoding);
@@ -2137,7 +2107,7 @@
     public void seek(PositionProvider[] index) throws IOException {
       super.seek(index);
       tags.seek(index[columnId]);
-      for (TreeReader kid : fields) {
+      for (TypeReader kid : fields) {
         kid.seek(index);
       }
     }
@@ -2165,11 +2135,11 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       tags = new RunLengthByteReader(planner.getStream(new StreamName(columnId,
           OrcProto.Stream.Kind.DATA)));
-      for (TreeReader field : fields) {
+      for (TypeReader field : fields) {
         if (field != null) {
           field.startStripe(planner);
         }
@@ -2177,7 +2147,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long[] counts = new long[fields.length];
       for (int i = 0; i < items; ++i) {
@@ -2190,7 +2160,7 @@
   }
 
   public static class ListTreeReader extends TreeReader {
-    protected final TreeReader elementReader;
+    protected final TypeReader elementReader;
     protected IntegerReader lengths = null;
 
     protected ListTreeReader(int fileColumn,
@@ -2206,7 +2176,7 @@
                              Context context,
                              InStream data,
                              OrcProto.ColumnEncoding encoding,
-                             TreeReader elementReader) throws IOException {
+                             TypeReader elementReader) throws IOException {
       super(columnId, present, context);
       if (data != null && encoding != null) {
         checkEncoding(encoding);
@@ -2247,7 +2217,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -2256,7 +2226,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       lengths = createIntegerReader(planner.getEncoding(columnId).getKind(),
           planner.getStream(new StreamName(columnId,
@@ -2267,7 +2237,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long childSkip = 0;
       for (long i = 0; i < items; ++i) {
@@ -2278,8 +2248,8 @@
   }
 
   public static class MapTreeReader extends TreeReader {
-    protected final TreeReader keyReader;
-    protected final TreeReader valueReader;
+    protected final TypeReader keyReader;
+    protected final TypeReader valueReader;
     protected IntegerReader lengths = null;
 
     protected MapTreeReader(int fileColumn,
@@ -2297,8 +2267,8 @@
                             Context context,
                             InStream data,
                             OrcProto.ColumnEncoding encoding,
-                            TreeReader keyReader,
-                            TreeReader valueReader) throws IOException {
+                            TypeReader keyReader,
+                            TypeReader valueReader) throws IOException {
       super(columnId, present, context);
       if (data != null && encoding != null) {
         checkEncoding(encoding);
@@ -2342,7 +2312,7 @@
     }
 
     @Override
-    void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
+    public void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException {
       if ((encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT) &&
           (encoding.getKind() != OrcProto.ColumnEncoding.Kind.DIRECT_V2)) {
         throw new IOException("Unknown encoding " + encoding + " in column " +
@@ -2351,7 +2321,7 @@
     }
 
     @Override
-    void startStripe(StripePlanner planner) throws IOException {
+    public void startStripe(StripePlanner planner) throws IOException {
       super.startStripe(planner);
       lengths = createIntegerReader(planner.getEncoding(columnId).getKind(),
           planner.getStream(new StreamName(columnId,
@@ -2365,7 +2335,7 @@
     }
 
     @Override
-    void skipRows(long items) throws IOException {
+    public void skipRows(long items) throws IOException {
       items = countNonNulls(items);
       long childSkip = 0;
       for (long i = 0; i < items; ++i) {
@@ -2376,7 +2346,7 @@
     }
   }
 
-  public static TreeReader createTreeReader(TypeDescription readerType,
+  public static TypeReader createTreeReader(TypeDescription readerType,
                                             Context context
                                             ) throws IOException {
     OrcFile.Version version = context.getFileFormat();
@@ -2446,4 +2416,14 @@
             readerTypeCategory);
     }
   }
+
+  public static BatchReader createRootReader(TypeDescription readerType, Context context)
+          throws IOException {
+    TypeReader reader = createTreeReader(readerType, context);
+    if (reader instanceof StructTreeReader) {
+      return new StructBatchReader((StructTreeReader) reader);
+    } else {
+      return new PrimitiveBatchReader(reader);
+    }
+  }
 }
diff --git a/java/core/src/java/org/apache/orc/impl/reader/StripePlanner.java b/java/core/src/java/org/apache/orc/impl/reader/StripePlanner.java
index 463a70c..d8d08fb 100644
--- a/java/core/src/java/org/apache/orc/impl/reader/StripePlanner.java
+++ b/java/core/src/java/org/apache/orc/impl/reader/StripePlanner.java
@@ -494,15 +494,14 @@
     return result;
   }
 
-  private static class StreamInformation {
-    final OrcProto.Stream.Kind kind;
-    final int column;
-    final long offset;
-    final long length;
-    BufferChunk firstChunk;
+  public static class StreamInformation {
+    public final OrcProto.Stream.Kind kind;
+    public final int column;
+    public final long offset;
+    public final long length;
+    public BufferChunk firstChunk;
 
-    StreamInformation(OrcProto.Stream.Kind kind,
-                      int column, long offset, long length) {
+    public StreamInformation(OrcProto.Stream.Kind kind, int column, long offset, long length) {
       this.kind = kind;
       this.column = column;
       this.offset = offset;
diff --git a/java/core/src/java/org/apache/orc/impl/reader/tree/BatchReader.java b/java/core/src/java/org/apache/orc/impl/reader/tree/BatchReader.java
new file mode 100644
index 0000000..710d2cd
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/tree/BatchReader.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl.reader.tree;
+
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.orc.impl.PositionProvider;
+import org.apache.orc.impl.reader.StripePlanner;
+
+import java.io.IOException;
+
+
+/**
+ * The top level interface that the reader uses to read the columns from the
+ * ORC file.
+ */
+public abstract class BatchReader {
+  // The row type reader
+  public final TypeReader rootType;
+
+  protected int vectorColumnCount = -1;
+
+  public BatchReader(TypeReader rootType) {
+    this.rootType = rootType;
+  }
+
+  public void startStripe(StripePlanner planner) throws IOException {
+    rootType.startStripe(planner);
+  }
+
+  public void setVectorColumnCount(int vectorColumnCount) {
+    this.vectorColumnCount = vectorColumnCount;
+  }
+
+  /**
+   * Read the next batch of data from the file.
+   * @param batch     the batch to read into
+   * @param batchSize the number of rows to read
+   * @throws IOException errors reading the file
+   */
+  public abstract void nextBatch(VectorizedRowBatch batch,
+                                 int batchSize) throws IOException;
+
+  protected void resetBatch(VectorizedRowBatch batch, int batchSize) {
+    batch.selectedInUse = false;
+    batch.size = batchSize;
+  }
+
+  public void skipRows(long rows) throws IOException {
+    rootType.skipRows(rows);
+  }
+
+  public void seek(PositionProvider[] index) throws IOException {
+    rootType.seek(index);
+  }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/tree/PrimitiveBatchReader.java b/java/core/src/java/org/apache/orc/impl/reader/tree/PrimitiveBatchReader.java
new file mode 100644
index 0000000..669b6a5
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/tree/PrimitiveBatchReader.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl.reader.tree;
+
+import java.io.IOException;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+
+public class PrimitiveBatchReader extends BatchReader {
+
+  public PrimitiveBatchReader(TypeReader rowReader) {
+    super(rowReader);
+  }
+
+  @Override
+  public void nextBatch(VectorizedRowBatch batch,
+                        int batchSize) throws IOException {
+  batch.cols[0].reset();
+  batch.cols[0].ensureSize(batchSize, false);
+  rootType.nextVector(batch.cols[0], null, batchSize);
+  resetBatch(batch, batchSize);
+  }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/tree/StructBatchReader.java b/java/core/src/java/org/apache/orc/impl/reader/tree/StructBatchReader.java
new file mode 100644
index 0000000..9503436
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/tree/StructBatchReader.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl.reader.tree;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.orc.impl.TreeReaderFactory;
+
+import java.io.IOException;
+
+public class StructBatchReader extends BatchReader {
+
+  public StructBatchReader(TreeReaderFactory.StructTreeReader rowReader) {
+    super(rowReader);
+  }
+
+  @Override
+  public void nextBatch(VectorizedRowBatch batch, int batchSize) throws IOException {
+    TypeReader[] children = ((TreeReaderFactory.StructTreeReader) rootType).fields;
+    for (int i = 0; i < children.length &&
+                    (vectorColumnCount == -1 || i < vectorColumnCount); ++i) {
+      ColumnVector colVector = batch.cols[i];
+      if (colVector != null) {
+        colVector.reset();
+        colVector.ensureSize(batchSize, false);
+        children[i].nextVector(colVector, null, batchSize);
+      }
+    }
+    resetBatch(batch, batchSize);
+  }
+}
diff --git a/java/core/src/java/org/apache/orc/impl/reader/tree/TypeReader.java b/java/core/src/java/org/apache/orc/impl/reader/tree/TypeReader.java
new file mode 100644
index 0000000..5932bee
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/reader/tree/TypeReader.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.orc.impl.reader.tree;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.orc.OrcProto;
+import org.apache.orc.impl.PositionProvider;
+import org.apache.orc.impl.reader.StripePlanner;
+
+import java.io.IOException;
+
+public interface TypeReader {
+  void checkEncoding(OrcProto.ColumnEncoding encoding) throws IOException;
+
+  void startStripe(StripePlanner planner) throws IOException;
+
+  void seek(PositionProvider[] index) throws IOException;
+
+  void seek(PositionProvider index) throws IOException;
+
+  void skipRows(long rows) throws IOException;
+
+  void nextVector(ColumnVector previous,
+                  boolean[] isNull,
+                  int batchSize) throws IOException;
+
+  int getColumnId();
+}
diff --git a/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java b/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java
index a0981a6..be7b616 100644
--- a/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java
+++ b/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java
@@ -53,6 +53,9 @@
 import org.apache.orc.Writer;
 import org.apache.orc.impl.reader.ReaderEncryption;
 import org.apache.orc.impl.reader.StripePlanner;
+import org.apache.orc.impl.reader.tree.BatchReader;
+import org.apache.orc.impl.reader.tree.StructBatchReader;
+import org.apache.orc.impl.reader.tree.TypeReader;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -1641,13 +1644,14 @@
 
     TreeReaderFactory.Context treeContext =
         new TreeReaderFactory.ReaderContext().setSchemaEvolution(evo);
-    TreeReaderFactory.TreeReader reader =
-        TreeReaderFactory.createTreeReader(readType, treeContext);
+    BatchReader reader =
+        TreeReaderFactory.createRootReader(readType, treeContext);
 
     // check to make sure the tree reader is built right
-    assertEquals(TreeReaderFactory.StructTreeReader.class, reader.getClass());
-    TreeReaderFactory.TreeReader[] children =
-        ((TreeReaderFactory.StructTreeReader) reader).getChildReaders();
+    assertEquals(StructBatchReader.class, reader.getClass());
+    assertEquals(TreeReaderFactory.StructTreeReader.class, reader.rootType.getClass());
+    TypeReader[] children =
+        ((TreeReaderFactory.StructTreeReader) reader.rootType).getChildReaders();
     assertEquals(3, children.length);
     assertEquals(TreeReaderFactory.NullTreeReader.class, children[0].getClass());
     assertEquals(TreeReaderFactory.StringTreeReader.class, children[1].getClass());
@@ -1694,12 +1698,12 @@
 
     TreeReaderFactory.Context treeContext =
         new TreeReaderFactory.ReaderContext().setSchemaEvolution(evo);
-    TreeReaderFactory.TreeReader reader =
+    TypeReader reader =
         TreeReaderFactory.createTreeReader(readType, treeContext);
 
     // check to make sure the tree reader is built right
     assertEquals(TreeReaderFactory.StructTreeReader.class, reader.getClass());
-    TreeReaderFactory.TreeReader[] children =
+    TypeReader[] children =
         ((TreeReaderFactory.StructTreeReader) reader).getChildReaders();
     assertEquals(2, children.length);
     assertEquals(TreeReaderFactory.IntTreeReader.class, children[0].getClass());