CRUNCH-609: Improved KafkaRecordReader to keep retrying when the range of offsets has not been fully consumed.
diff --git a/crunch-kafka/src/main/java/org/apache/crunch/kafka/KafkaRecordsIterable.java b/crunch-kafka/src/main/java/org/apache/crunch/kafka/KafkaRecordsIterable.java
index 8fec7f8..7525488 100644
--- a/crunch-kafka/src/main/java/org/apache/crunch/kafka/KafkaRecordsIterable.java
+++ b/crunch-kafka/src/main/java/org/apache/crunch/kafka/KafkaRecordsIterable.java
@@ -18,7 +18,6 @@
 package org.apache.crunch.kafka;
 
 import org.apache.crunch.Pair;
-import org.apache.crunch.kafka.inputformat.KafkaRecordReader;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
diff --git a/crunch-kafka/src/main/java/org/apache/crunch/kafka/inputformat/KafkaRecordReader.java b/crunch-kafka/src/main/java/org/apache/crunch/kafka/inputformat/KafkaRecordReader.java
index 1420519..ad73217 100644
--- a/crunch-kafka/src/main/java/org/apache/crunch/kafka/inputformat/KafkaRecordReader.java
+++ b/crunch-kafka/src/main/java/org/apache/crunch/kafka/inputformat/KafkaRecordReader.java
@@ -17,6 +17,7 @@
  */
 package org.apache.crunch.kafka.inputformat;
 
+import org.apache.crunch.CrunchRuntimeException;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.RecordReader;
@@ -34,11 +35,11 @@
 import java.util.Collections;
 import java.util.Iterator;
 
+import static org.apache.crunch.kafka.KafkaSource.CONSUMER_POLL_TIMEOUT_DEFAULT;
+import static org.apache.crunch.kafka.KafkaSource.CONSUMER_POLL_TIMEOUT_KEY;
 import static org.apache.crunch.kafka.KafkaUtils.KAFKA_RETRY_ATTEMPTS_DEFAULT;
 import static org.apache.crunch.kafka.KafkaUtils.KAFKA_RETRY_ATTEMPTS_KEY;
 import static org.apache.crunch.kafka.KafkaUtils.getKafkaConnectionProperties;
-import static org.apache.crunch.kafka.KafkaSource.CONSUMER_POLL_TIMEOUT_DEFAULT;
-import static org.apache.crunch.kafka.KafkaSource.CONSUMER_POLL_TIMEOUT_KEY;
 
 /**
  * A {@link RecordReader} for pulling data from Kafka.
@@ -56,11 +57,15 @@
   private long consumerPollTimeout;
   private long maxNumberOfRecords;
   private long startingOffset;
+  private long currentOffset;
   private int maxNumberAttempts;
 
   @Override
   public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException {
     consumer = new KafkaConsumer<>(getKafkaConnectionProperties(taskAttemptContext.getConfiguration()));
+    if(!(inputSplit instanceof KafkaInputSplit)){
+      throw new CrunchRuntimeException("InputSplit for RecordReader is not valid split type.");
+    }
     KafkaInputSplit split = (KafkaInputSplit) inputSplit;
     TopicPartition topicPartition = split.getTopicPartition();
     consumer.assign(Collections.singletonList(topicPartition));
@@ -70,9 +75,10 @@
     startingOffset = split.getStartingOffset();
     consumer.seek(topicPartition,startingOffset);
 
+    currentOffset = startingOffset - 1;
     endingOffset = split.getEndingOffset();
 
-    maxNumberOfRecords = endingOffset - split.getStartingOffset();
+    maxNumberOfRecords = endingOffset - startingOffset;
     if(LOG.isInfoEnabled()) {
       LOG.info("Reading data from {} between {} and {}", new Object[]{topicPartition, startingOffset, endingOffset});
     }
@@ -84,16 +90,25 @@
 
   @Override
   public boolean nextKeyValue() throws IOException, InterruptedException {
-    recordIterator = getRecords();
-    record = recordIterator.hasNext() ? recordIterator.next() : null;
-    if(LOG.isDebugEnabled()){
-      if(record != null) {
+    if(hasPendingData()) {
+      recordIterator = getRecords();
+      record = recordIterator.hasNext() ? recordIterator.next() : null;
+      if (record != null) {
         LOG.debug("nextKeyValue: Retrieved record with offset {}", record.offset());
-      }else{
-        LOG.debug("nextKeyValue: Retrieved null record");
+        long oldOffset = currentOffset;
+        currentOffset = record.offset();
+        LOG.debug("Current offset will be updated to be [{}]", currentOffset);
+        if (LOG.isWarnEnabled() && (currentOffset - oldOffset > 1)) {
+          LOG.warn("Offset increment was larger than expected value of one, old {} new {}", oldOffset, currentOffset);
+        }
+        return true;
+      } else {
+        LOG.warn("nextKeyValue: Retrieved null record last offset was {} and ending offset is {}", currentOffset,
+                endingOffset);
       }
     }
-    return record != null && record.offset() < endingOffset;
+    record = null;
+    return false;
   }
 
   @Override
@@ -109,39 +124,53 @@
   @Override
   public float getProgress() throws IOException, InterruptedException {
     //not most accurate but gives reasonable estimate
-    return record == null ? 0.0f : ((float) (record.offset()- startingOffset)) / maxNumberOfRecords;
+    return ((float) (currentOffset - startingOffset +1)) / maxNumberOfRecords;
+  }
+
+  private boolean hasPendingData(){
+    //offset range is exclusive at the end which means the ending offset is one higher
+    // than the actual physical last offset
+    return currentOffset < endingOffset-1;
   }
 
   private Iterator<ConsumerRecord<K, V>> getRecords() {
     if (recordIterator == null || !recordIterator.hasNext()) {
       ConsumerRecords<K, V> records = null;
       int numTries = 0;
-      boolean notSuccess = false;
-      while(!notSuccess && numTries < maxNumberAttempts) {
+      boolean success = false;
+      while(!success && numTries < maxNumberAttempts) {
         try {
-          records = consumer.poll(consumerPollTimeout);
-          notSuccess = true;
+          records = getConsumer().poll(consumerPollTimeout);
         } catch (RetriableException re) {
           numTries++;
           if (numTries < maxNumberAttempts) {
-            LOG.warn("Error pulling messages from Kafka. Retrying with attempt {}", numTries, re);
+            LOG.warn("Error pulling messages from Kafka. Retrying with attempt {}", numTries+1, re);
           } else {
             LOG.error("Error pulling messages from Kafka. Exceeded maximum number of attempts {}", maxNumberAttempts, re);
             throw re;
           }
         }
+        if((records == null || records.isEmpty()) && hasPendingData()){
+          LOG.warn("No records retrieved but pending offsets to consume therefore polling again.");
+        }else{
+          success = true;
+        }
       }
 
-      if(LOG.isDebugEnabled() && records != null){
-        LOG.debug("No records retrieved from Kafka therefore nothing to iterate over.");
+      if(records == null || records.isEmpty()){
+        LOG.info("No records retrieved from Kafka therefore nothing to iterate over.");
       }else{
-        LOG.debug("Retrieved records from Kafka to iterate over.");
+        LOG.info("Retrieved records from Kafka to iterate over.");
       }
       return records != null ? records.iterator() : ConsumerRecords.<K, V>empty().iterator();
     }
     return recordIterator;
   }
 
+  protected Consumer<K,V> getConsumer(){
+    return consumer;
+  }
+
   @Override
   public void close() throws IOException {
     LOG.debug("Closing the record reader.");
diff --git a/crunch-kafka/src/test/java/org/apache/crunch/kafka/inputformat/KafkaRecordReaderIT.java b/crunch-kafka/src/test/java/org/apache/crunch/kafka/inputformat/KafkaRecordReaderIT.java
index ba5b65b..15970c1 100644
--- a/crunch-kafka/src/test/java/org/apache/crunch/kafka/inputformat/KafkaRecordReaderIT.java
+++ b/crunch-kafka/src/test/java/org/apache/crunch/kafka/inputformat/KafkaRecordReaderIT.java
@@ -22,14 +22,18 @@
 import org.apache.crunch.kafka.ClusterTest;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.mapred.TaskAttemptContext;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.common.TopicPartition;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
+import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TestName;
 import org.junit.runner.RunWith;
+import org.mockito.Matchers;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
 
@@ -54,6 +58,9 @@
   @Mock
   private TaskAttemptContext context;
 
+  @Mock
+  private Consumer<String, String> consumer;
+
   @Rule
   public TestName testName = new TestName();
   private Properties consumerProps;
@@ -77,6 +84,7 @@
     consumerProps = ClusterTest.getConsumerProperties();
     config = ClusterTest.getConsumerConfig();
     when(context.getConfiguration()).thenReturn(config);
+    when(consumer.poll(Matchers.anyLong())).thenReturn(null);
   }
 
   @Test
@@ -119,4 +127,219 @@
     //validate the same number of unique keys was read as were written.
     assertThat(keysRead.size(), is(keys.size()));
   }
+
+  @Test
+  public void pollReturnsNullAtStart() throws IOException, InterruptedException {
+    List<String> keys = ClusterTest.writeData(ClusterTest.getProducerProperties(), topic, "batch", 10, 10);
+
+    Map<TopicPartition, Long> startOffsets = getBrokerOffsets(consumerProps, OffsetRequest.EarliestTime(), topic);
+    Map<TopicPartition, Long> endOffsets = getBrokerOffsets(consumerProps, OffsetRequest.LatestTime(), topic);
+
+    Map<TopicPartition, Pair<Long, Long>> offsets = new HashMap<>();
+    for (Map.Entry<TopicPartition, Long> entry : startOffsets.entrySet()) {
+      Long endingOffset = endOffsets.get(entry.getKey());
+      offsets.put(entry.getKey(), Pair.of(entry.getValue(), endingOffset));
+    }
+
+    KafkaInputFormat.writeOffsetsToConfiguration(offsets, config);
+
+    Set<String> keysRead = new HashSet<>();
+    //read all data from all splits
+    for (Map.Entry<TopicPartition, Pair<Long, Long>> partitionInfo : offsets.entrySet()) {
+      KafkaInputSplit split = new KafkaInputSplit(partitionInfo.getKey().topic(), partitionInfo.getKey().partition(),
+              partitionInfo.getValue().first(), partitionInfo.getValue().second());
+
+      KafkaRecordReader<String, String> recordReader = new NullAtStartKafkaRecordReader<>(consumer, 3);
+      recordReader.initialize(split, context);
+
+      int numRecordsFound = 0;
+      while (recordReader.nextKeyValue()) {
+        keysRead.add(recordReader.getCurrentKey());
+        assertThat(keys, hasItem(recordReader.getCurrentKey()));
+        assertThat(recordReader.getCurrentValue(), is(notNullValue()));
+        numRecordsFound++;
+      }
+      recordReader.close();
+
+      //assert that it encountered a partitions worth of data
+      assertThat(((long) numRecordsFound), is(partitionInfo.getValue().second() - partitionInfo.getValue().first()));
+    }
+
+    //validate the same number of unique keys was read as were written.
+    assertThat(keysRead.size(), is(keys.size()));
+  }
+
+  @Test
+  public void pollReturnsEmptyAtStart() throws IOException, InterruptedException {
+    List<String> keys = ClusterTest.writeData(ClusterTest.getProducerProperties(), topic, "batch", 10, 10);
+
+    Map<TopicPartition, Long> startOffsets = getBrokerOffsets(consumerProps, OffsetRequest.EarliestTime(), topic);
+    Map<TopicPartition, Long> endOffsets = getBrokerOffsets(consumerProps, OffsetRequest.LatestTime(), topic);
+
+    Map<TopicPartition, Pair<Long, Long>> offsets = new HashMap<>();
+    for (Map.Entry<TopicPartition, Long> entry : startOffsets.entrySet()) {
+      Long endingOffset = endOffsets.get(entry.getKey());
+      offsets.put(entry.getKey(), Pair.of(entry.getValue(), endingOffset));
+    }
+
+    KafkaInputFormat.writeOffsetsToConfiguration(offsets, config);
+
+    Set<String> keysRead = new HashSet<>();
+    //read all data from all splits
+    for (Map.Entry<TopicPartition, Pair<Long, Long>> partitionInfo : offsets.entrySet()) {
+      KafkaInputSplit split = new KafkaInputSplit(partitionInfo.getKey().topic(), partitionInfo.getKey().partition(),
+              partitionInfo.getValue().first(), partitionInfo.getValue().second());
+
+      when(consumer.poll(Matchers.anyLong())).thenReturn(ConsumerRecords.<String, String>empty());
+      KafkaRecordReader<String, String> recordReader = new NullAtStartKafkaRecordReader<>(consumer, 3);
+      recordReader.initialize(split, context);
+
+      int numRecordsFound = 0;
+      while (recordReader.nextKeyValue()) {
+        keysRead.add(recordReader.getCurrentKey());
+        assertThat(keys, hasItem(recordReader.getCurrentKey()));
+        assertThat(recordReader.getCurrentValue(), is(notNullValue()));
+        numRecordsFound++;
+      }
+      recordReader.close();
+
+      //assert that it encountered a partitions worth of data
+      assertThat(((long) numRecordsFound), is(partitionInfo.getValue().second() - partitionInfo.getValue().first()));
+    }
+
+    //validate the same number of unique keys was read as were written.
+    assertThat(keysRead.size(), is(keys.size()));
+  }
+
+  @Test
+  public void pollReturnsNullInMiddle() throws IOException, InterruptedException {
+    List<String> keys = ClusterTest.writeData(ClusterTest.getProducerProperties(), topic, "batch", 10, 10);
+
+    Map<TopicPartition, Long> startOffsets = getBrokerOffsets(consumerProps, OffsetRequest.EarliestTime(), topic);
+    Map<TopicPartition, Long> endOffsets = getBrokerOffsets(consumerProps, OffsetRequest.LatestTime(), topic);
+
+    Map<TopicPartition, Pair<Long, Long>> offsets = new HashMap<>();
+    for (Map.Entry<TopicPartition, Long> entry : startOffsets.entrySet()) {
+      Long endingOffset = endOffsets.get(entry.getKey());
+      offsets.put(entry.getKey(), Pair.of(entry.getValue(), endingOffset));
+    }
+
+    KafkaInputFormat.writeOffsetsToConfiguration(offsets, config);
+
+    Set<String> keysRead = new HashSet<>();
+    //read all data from all splits
+    for (Map.Entry<TopicPartition, Pair<Long, Long>> partitionInfo : offsets.entrySet()) {
+      KafkaInputSplit split = new KafkaInputSplit(partitionInfo.getKey().topic(), partitionInfo.getKey().partition(),
+              partitionInfo.getValue().first(), partitionInfo.getValue().second());
+
+      KafkaRecordReader<String, String> recordReader = new InjectableKafkaRecordReader<>(consumer, 1);
+      recordReader.initialize(split, context);
+
+      int numRecordsFound = 0;
+      while (recordReader.nextKeyValue()) {
+        keysRead.add(recordReader.getCurrentKey());
+        assertThat(keys, hasItem(recordReader.getCurrentKey()));
+        assertThat(recordReader.getCurrentValue(), is(notNullValue()));
+        numRecordsFound++;
+      }
+      recordReader.close();
+
+      //assert that it encountered a partitions worth of data
+      assertThat(((long) numRecordsFound), is(partitionInfo.getValue().second() - partitionInfo.getValue().first()));
+    }
+
+    //validate the same number of unique keys was read as were written.
+    assertThat(keysRead.size(), is(keys.size()));
+  }
+
+  @Test
+  public void pollReturnsEmptyInMiddle() throws IOException, InterruptedException {
+    List<String> keys = ClusterTest.writeData(ClusterTest.getProducerProperties(), topic, "batch", 10, 10);
+
+    Map<TopicPartition, Long> startOffsets = getBrokerOffsets(consumerProps, OffsetRequest.EarliestTime(), topic);
+    Map<TopicPartition, Long> endOffsets = getBrokerOffsets(consumerProps, OffsetRequest.LatestTime(), topic);
+
+    Map<TopicPartition, Pair<Long, Long>> offsets = new HashMap<>();
+    for (Map.Entry<TopicPartition, Long> entry : startOffsets.entrySet()) {
+      Long endingOffset = endOffsets.get(entry.getKey());
+      offsets.put(entry.getKey(), Pair.of(entry.getValue(), endingOffset));
+    }
+
+    KafkaInputFormat.writeOffsetsToConfiguration(offsets, config);
+
+    Set<String> keysRead = new HashSet<>();
+    //read all data from all splits
+    for (Map.Entry<TopicPartition, Pair<Long, Long>> partitionInfo : offsets.entrySet()) {
+      KafkaInputSplit split = new KafkaInputSplit(partitionInfo.getKey().topic(), partitionInfo.getKey().partition(),
+              partitionInfo.getValue().first(), partitionInfo.getValue().second());
+
+      when(consumer.poll(Matchers.anyLong())).thenReturn(ConsumerRecords.<String, String>empty());
+      KafkaRecordReader<String, String> recordReader = new InjectableKafkaRecordReader<>(consumer, 1);
+      recordReader.initialize(split, context);
+
+      int numRecordsFound = 0;
+      while (recordReader.nextKeyValue()) {
+        keysRead.add(recordReader.getCurrentKey());
+        assertThat(keys, hasItem(recordReader.getCurrentKey()));
+        assertThat(recordReader.getCurrentValue(), is(notNullValue()));
+        numRecordsFound++;
+      }
+      recordReader.close();
+
+      //assert that it encountered a partitions worth of data
+      assertThat(((long) numRecordsFound), is(partitionInfo.getValue().second() - partitionInfo.getValue().first()));
+    }
+
+    //validate the same number of unique keys was read as were written.
+    assertThat(keysRead.size(), is(keys.size()));
+  }
+
+
+  private static class NullAtStartKafkaRecordReader<K, V> extends KafkaRecordReader<K, V>{
+
+    private final Consumer consumer;
+    private final int callAttempts;
+
+    private int attempts;
+
+    public NullAtStartKafkaRecordReader(Consumer consumer, int callAttempts){
+      this.consumer = consumer;
+      this.callAttempts = callAttempts;
+      attempts = 0;
+    }
+
+    @Override
+    protected Consumer<K, V> getConsumer() {
+      if(attempts > callAttempts){
+        return super.getConsumer();
+      }
+      attempts++;
+      return consumer;
+    }
+  }
+
+  private static class InjectableKafkaRecordReader<K, V> extends KafkaRecordReader<K, V>{
+
+    private final Consumer consumer;
+    private final int failAttempt;
+
+    private int attempts;
+
+    public InjectableKafkaRecordReader(Consumer consumer, int failAttempt){
+      this.consumer = consumer;
+      this.failAttempt = failAttempt;
+      attempts = 0;
+    }
+
+    @Override
+    protected Consumer<K, V> getConsumer() {
+      if(attempts == failAttempt){
+        attempts++;
+        return consumer;
+      }
+      attempts++;
+      return super.getConsumer();
+    }
+  }
+
 }