DRILL-8253: Support limit pushdown in Kafka scan (#2580)

* Addressed review comments

Co-authored-by: luoc <luocooong@qq.com>
diff --git a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaGroupScan.java b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaGroupScan.java
index e025e65..45c7589 100644
--- a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaGroupScan.java
+++ b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaGroupScan.java
@@ -30,6 +30,7 @@
 import org.apache.drill.exec.ExecConstants;
 import org.apache.drill.shaded.guava.com.google.common.collect.Sets;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.drill.common.PlanStringBuilder;
 import org.apache.drill.common.exceptions.ExecutionSetupException;
 import org.apache.drill.common.exceptions.UserException;
 import org.apache.drill.common.expression.SchemaPath;
@@ -76,7 +77,8 @@
   private final KafkaStoragePlugin kafkaStoragePlugin;
   private final KafkaScanSpec kafkaScanSpec;
 
-  private List<SchemaPath> columns;
+  private final List<SchemaPath> columns;
+  private final int records;
   private ListMultimap<Integer, PartitionScanWork> assignments;
   private List<EndpointAffinity> affinities;
 
@@ -86,18 +88,21 @@
   public KafkaGroupScan(@JsonProperty("userName") String userName,
                         @JsonProperty("kafkaStoragePluginConfig") KafkaStoragePluginConfig kafkaStoragePluginConfig,
                         @JsonProperty("columns") List<SchemaPath> columns,
+                        @JsonProperty("records") int records,
                         @JsonProperty("kafkaScanSpec") KafkaScanSpec scanSpec,
                         @JacksonInject StoragePluginRegistry pluginRegistry) throws ExecutionSetupException {
     this(userName,
         pluginRegistry.resolve(kafkaStoragePluginConfig, KafkaStoragePlugin.class),
         columns,
+        records,
         scanSpec);
   }
 
-  public KafkaGroupScan(KafkaStoragePlugin kafkaStoragePlugin, KafkaScanSpec kafkaScanSpec, List<SchemaPath> columns) {
+  public KafkaGroupScan(KafkaStoragePlugin kafkaStoragePlugin, KafkaScanSpec kafkaScanSpec, List<SchemaPath> columns, int records) {
     super(StringUtils.EMPTY);
     this.kafkaStoragePlugin = kafkaStoragePlugin;
     this.columns = columns;
+    this.records = records;
     this.kafkaScanSpec = kafkaScanSpec;
     init();
   }
@@ -105,10 +110,12 @@
   public KafkaGroupScan(String userName,
                         KafkaStoragePlugin kafkaStoragePlugin,
                         List<SchemaPath> columns,
+                        int records,
                         KafkaScanSpec kafkaScanSpec) {
     super(userName);
     this.kafkaStoragePlugin = kafkaStoragePlugin;
     this.columns = columns;
+    this.records = records;
     this.kafkaScanSpec = kafkaScanSpec;
     init();
   }
@@ -117,6 +124,27 @@
     super(that);
     this.kafkaStoragePlugin = that.kafkaStoragePlugin;
     this.columns = that.columns;
+    this.records = that.records;
+    this.kafkaScanSpec = that.kafkaScanSpec;
+    this.assignments = that.assignments;
+    this.partitionWorkMap = that.partitionWorkMap;
+  }
+
+  public KafkaGroupScan(KafkaGroupScan that, List<SchemaPath> columns) {
+    super(that);
+    this.kafkaStoragePlugin = that.kafkaStoragePlugin;
+    this.columns = columns;
+    this.records = that.records;
+    this.kafkaScanSpec = that.kafkaScanSpec;
+    this.assignments = that.assignments;
+    this.partitionWorkMap = that.partitionWorkMap;
+  }
+
+  public KafkaGroupScan(KafkaGroupScan that, int records) {
+    super(that);
+    this.kafkaStoragePlugin = that.kafkaStoragePlugin;
+    this.columns = that.columns;
+    this.records = records;
     this.kafkaScanSpec = that.kafkaScanSpec;
     this.assignments = that.assignments;
     this.partitionWorkMap = that.partitionWorkMap;
@@ -263,6 +291,20 @@
   }
 
   @Override
+  public GroupScan applyLimit(int maxRecords) {
+    if (maxRecords > records) { // pass the limit value into sub-scan
+      return new KafkaGroupScan(this, maxRecords);
+    } else { // stop the transform
+      return null;
+    }
+  }
+
+  @Override
+  public boolean supportsLimitPushdown() {
+    return true;
+  }
+
+  @Override
   public KafkaSubScan getSpecificScan(int minorFragmentId) {
     List<PartitionScanWork> workList = assignments.get(minorFragmentId);
 
@@ -270,7 +312,7 @@
       .map(PartitionScanWork::getPartitionScanSpec)
       .collect(Collectors.toList());
 
-    return new KafkaSubScan(getUserName(), kafkaStoragePlugin, columns, scanSpecList);
+    return new KafkaSubScan(getUserName(), kafkaStoragePlugin, columns, records, scanSpecList);
   }
 
   @Override
@@ -314,9 +356,7 @@
 
   @Override
   public GroupScan clone(List<SchemaPath> columns) {
-    KafkaGroupScan clone = new KafkaGroupScan(this);
-    clone.columns = columns;
-    return clone;
+    return new KafkaGroupScan(this, columns);
   }
 
   public GroupScan cloneWithNewSpec(List<KafkaPartitionScanSpec> partitionScanSpecList) {
@@ -348,6 +388,11 @@
   }
 
   @JsonProperty
+  public int getRecords() {
+    return records;
+  }
+
+  @JsonProperty
   public KafkaScanSpec getKafkaScanSpec() {
     return kafkaScanSpec;
   }
@@ -359,7 +404,11 @@
 
   @Override
   public String toString() {
-    return String.format("KafkaGroupScan [KafkaScanSpec=%s, columns=%s]", kafkaScanSpec, columns);
+    return new PlanStringBuilder("")
+        .field("scanSpec", kafkaScanSpec)
+        .field("columns", columns)
+        .field("records", records)
+        .toString();
   }
 
   @JsonIgnore
diff --git a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaRecordReader.java b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaRecordReader.java
index 551b62f..0d38ea1 100644
--- a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaRecordReader.java
+++ b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaRecordReader.java
@@ -62,6 +62,7 @@
       }
     };
     negotiator.setErrorContext(errorContext);
+    negotiator.limit(maxRecords);
 
     messageReader = MessageReaderFactory.getMessageReader(readOptions.getMessageReader());
     messageReader.init(negotiator, readOptions, plugin);
@@ -86,10 +87,6 @@
   }
 
   private boolean nextLine(RowSetLoader rowWriter) {
-    if (rowWriter.limitReached(maxRecords)) {
-      return false;
-    }
-
     if (currentOffset >= subScanSpec.getEndOffset() || !msgItr.hasNext()) {
       return false;
     }
diff --git a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaScanBatchCreator.java b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaScanBatchCreator.java
index 7d2ebcc..4adfdca 100644
--- a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaScanBatchCreator.java
+++ b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaScanBatchCreator.java
@@ -61,7 +61,7 @@
     builder.setUserName(subScan.getUserName());
 
     List<ManagedReader<SchemaNegotiator>> readers = subScan.getPartitionSubScanSpecList().stream()
-        .map(scanSpec -> new KafkaRecordReader(scanSpec, options, subScan.getKafkaStoragePlugin(), -1))
+        .map(scanSpec -> new KafkaRecordReader(scanSpec, options, subScan.getKafkaStoragePlugin(), subScan.getRecords()))
         .collect(Collectors.toList());
     ManagedScanFramework.ReaderFactory readerFactory = new BasicScanFactory(readers.iterator());
     builder.setReaderFactory(readerFactory);
diff --git a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaStoragePlugin.java b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaStoragePlugin.java
index 257c0bf..d80b959 100644
--- a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaStoragePlugin.java
+++ b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaStoragePlugin.java
@@ -76,7 +76,7 @@
     KafkaScanSpec kafkaScanSpec = selection.getListWith(new ObjectMapper(),
         new TypeReference<KafkaScanSpec>() {
         });
-    return new KafkaGroupScan(this, kafkaScanSpec, null);
+    return new KafkaGroupScan(this, kafkaScanSpec, null, -1);
   }
 
   public void registerToClose(AutoCloseable autoCloseable) {
diff --git a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaSubScan.java b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaSubScan.java
index c55f1e7..c05ddd0 100644
--- a/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaSubScan.java
+++ b/contrib/storage-kafka/src/main/java/org/apache/drill/exec/store/kafka/KafkaSubScan.java
@@ -44,6 +44,7 @@
 
   private final KafkaStoragePlugin kafkaStoragePlugin;
   private final List<SchemaPath> columns;
+  private final int records;
   private final List<KafkaPartitionScanSpec> partitionSubScanSpecList;
 
   @JsonCreator
@@ -51,21 +52,25 @@
                       @JsonProperty("userName") String userName,
                       @JsonProperty("kafkaStoragePluginConfig") KafkaStoragePluginConfig kafkaStoragePluginConfig,
                       @JsonProperty("columns") List<SchemaPath> columns,
+                      @JsonProperty("records") int records,
                       @JsonProperty("partitionSubScanSpecList") LinkedList<KafkaPartitionScanSpec> partitionSubScanSpecList)
       throws ExecutionSetupException {
     this(userName,
         registry.resolve(kafkaStoragePluginConfig, KafkaStoragePlugin.class),
         columns,
+        records,
         partitionSubScanSpecList);
   }
 
   public KafkaSubScan(String userName,
                       KafkaStoragePlugin kafkaStoragePlugin,
                       List<SchemaPath> columns,
+                      int records,
                       List<KafkaPartitionScanSpec> partitionSubScanSpecList) {
     super(userName);
     this.kafkaStoragePlugin = kafkaStoragePlugin;
     this.columns = columns;
+    this.records = records;
     this.partitionSubScanSpecList = partitionSubScanSpecList;
   }
 
@@ -77,7 +82,7 @@
   @Override
   public PhysicalOperator getNewWithChildren(List<PhysicalOperator> children) {
     Preconditions.checkArgument(children.isEmpty());
-    return new KafkaSubScan(getUserName(), kafkaStoragePlugin, columns, partitionSubScanSpecList);
+    return new KafkaSubScan(getUserName(), kafkaStoragePlugin, columns, records, partitionSubScanSpecList);
   }
 
   @Override
@@ -96,6 +101,11 @@
   }
 
   @JsonProperty
+  public int getRecords() {
+    return records;
+  }
+
+  @JsonProperty
   public List<KafkaPartitionScanSpec> getPartitionSubScanSpecList() {
     return partitionSubScanSpecList;
   }
diff --git a/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/KafkaQueriesTest.java b/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/KafkaQueriesTest.java
index e86423e..e913036 100644
--- a/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/KafkaQueriesTest.java
+++ b/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/KafkaQueriesTest.java
@@ -64,6 +64,16 @@
   }
 
   @Test
+  public void testResultLimit() throws Exception {
+    String queryString = String.format(TestQueryConstants.MSG_LIMIT_QUERY, TestQueryConstants.JSON_TOPIC);
+    queryBuilder()
+      .sql(queryString)
+      .planMatcher()
+      .include("Scan", "records=3")
+      .match();
+  }
+
+  @Test
   public void testResultCount() {
     String queryString = String.format(TestQueryConstants.MSG_SELECT_QUERY, TestQueryConstants.JSON_TOPIC);
     runKafkaSQLVerifyCount(queryString, TestKafkaSuit.NUM_JSON_MSG);
diff --git a/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/TestQueryConstants.java b/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/TestQueryConstants.java
index b3163ad..b370b9b 100644
--- a/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/TestQueryConstants.java
+++ b/contrib/storage-kafka/src/test/java/org/apache/drill/exec/store/kafka/TestQueryConstants.java
@@ -40,6 +40,7 @@
   // Queries
   String MSG_COUNT_QUERY = "select count(*) from kafka.`%s`";
   String MSG_SELECT_QUERY = "select * from kafka.`%s`";
+  String MSG_LIMIT_QUERY = "select * from kafka.`%s` limit 3";
   String MIN_OFFSET_QUERY = "select MIN(kafkaMsgOffset) as minOffset from kafka.`%s`";
   String MAX_OFFSET_QUERY = "select MAX(kafkaMsgOffset) as maxOffset from kafka.`%s`";