[ISSUE #35] Use `LitePullConsumer` model instead of default pull consumer(#46)

diff --git a/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java b/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
index 22903c5..50a0883 100644
--- a/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
+++ b/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
@@ -64,6 +64,9 @@
     public static final ConfigOption<Long> OPTIONAL_PARTITION_DISCOVERY_INTERVAL_MS =
             ConfigOptions.key("partitionDiscoveryIntervalMs").longType().defaultValue(30000L);
 
+    public static final ConfigOption<Long> OPTIONAL_CONSUMER_POLL_MS =
+            ConfigOptions.key("consumer.timeout").longType().defaultValue(3000L);
+
     public static final ConfigOption<Boolean> OPTIONAL_USE_NEW_API =
             ConfigOptions.key("useNewApi").booleanType().defaultValue(true);
 
diff --git a/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQConfig.java b/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQConfig.java
index 936beb8..ecf7a9e 100644
--- a/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQConfig.java
+++ b/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQConfig.java
@@ -20,7 +20,7 @@
 import org.apache.rocketmq.acl.common.SessionCredentials;
 import org.apache.rocketmq.client.AccessChannel;
 import org.apache.rocketmq.client.ClientConfig;
-import org.apache.rocketmq.client.consumer.DefaultMQPullConsumer;
+import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.producer.DefaultMQProducer;
 import org.apache.rocketmq.common.protocol.heartbeat.MessageModel;
 
@@ -59,8 +59,12 @@
     public static final int DEFAULT_PRODUCER_RETRY_TIMES = 3;
 
     public static final String PRODUCER_TIMEOUT = "producer.timeout";
+
+    public static final String CONSUMER_TIMEOUT = "consumer.timeout";
     public static final int DEFAULT_PRODUCER_TIMEOUT = 3000; // 3 seconds
 
+    public static final int DEFAULT_CONSUMER_TIMEOUT = 3000; // 3 seconds
+
     // Consumer related config
     public static final String CONSUMER_GROUP = "consumer.group"; // Required
     public static final String CONSUMER_TOPIC = "consumer.topic"; // Required
@@ -142,9 +146,9 @@
      * Build Consumer Configs.
      *
      * @param props Properties
-     * @param consumer DefaultMQPullConsumer
+     * @param consumer DefaultLitePullConsumer
      */
-    public static void buildConsumerConfigs(Properties props, DefaultMQPullConsumer consumer) {
+    public static void buildConsumerConfigs(Properties props, DefaultLitePullConsumer consumer) {
         buildCommonConfigs(props, consumer);
         consumer.setMessageModel(MessageModel.CLUSTERING);
         consumer.setPersistConsumerOffsetInterval(
diff --git a/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQSourceFunction.java b/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQSourceFunction.java
index b078056..29272d8 100644
--- a/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQSourceFunction.java
+++ b/src/main/java/org/apache/rocketmq/flink/legacy/RocketMQSourceFunction.java
@@ -17,9 +17,8 @@
 
 package org.apache.rocketmq.flink.legacy;
 
-import org.apache.rocketmq.client.consumer.DefaultMQPullConsumer;
+import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.consumer.MessageSelector;
-import org.apache.rocketmq.client.consumer.PullResult;
 import org.apache.rocketmq.client.exception.MQClientException;
 import org.apache.rocketmq.common.message.MessageExt;
 import org.apache.rocketmq.common.message.MessageQueue;
@@ -55,6 +54,7 @@
 import org.apache.flink.shaded.curator5.com.google.common.collect.Lists;
 import org.apache.flink.shaded.curator5.com.google.common.util.concurrent.ThreadFactoryBuilder;
 
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.map.LinkedMap;
 import org.apache.commons.lang.Validate;
 import org.apache.commons.lang3.StringUtils;
@@ -78,7 +78,9 @@
 import java.util.concurrent.locks.ReentrantLock;
 
 import static org.apache.rocketmq.flink.legacy.RocketMQConfig.CONSUMER_BATCH_SIZE;
+import static org.apache.rocketmq.flink.legacy.RocketMQConfig.CONSUMER_TIMEOUT;
 import static org.apache.rocketmq.flink.legacy.RocketMQConfig.DEFAULT_CONSUMER_BATCH_SIZE;
+import static org.apache.rocketmq.flink.legacy.RocketMQConfig.DEFAULT_CONSUMER_TIMEOUT;
 import static org.apache.rocketmq.flink.legacy.common.util.RocketMQUtils.getInteger;
 
 /**
@@ -94,7 +96,9 @@
     private static final Logger log = LoggerFactory.getLogger(RocketMQSourceFunction.class);
     private static final String OFFSETS_STATE_NAME = "topic-partition-offset-states";
     private RunningChecker runningChecker;
-    private transient DefaultMQPullConsumer consumer;
+
+    private transient DefaultLitePullConsumer consumer;
+
     private KeyValueDeserializationSchema<OUT> schema;
     private transient ListState<Tuple2<MessageQueue, Long>> unionOffsetStates;
     private Map<MessageQueue, Long> offsetTable;
@@ -203,7 +207,7 @@
         executor = Executors.newCachedThreadPool(threadFactory);
 
         int indexOfThisSubTask = getRuntimeContext().getIndexOfThisSubtask();
-        consumer = new DefaultMQPullConsumer(group, RocketMQConfig.buildAclRPCHook(props));
+        consumer = new DefaultLitePullConsumer(group, RocketMQConfig.buildAclRPCHook(props));
         RocketMQConfig.buildConsumerConfigs(props, consumer);
 
         // set unique instance name, avoid exception:
@@ -241,7 +245,7 @@
         int taskNumber = ctx.getNumberOfParallelSubtasks();
         int taskIndex = ctx.getIndexOfThisSubtask();
         log.info("Source run, NumberOfTotalTask={}, IndexOfThisSubTask={}", taskNumber, taskIndex);
-        Collection<MessageQueue> totalQueues = consumer.fetchSubscribeMessageQueues(topic);
+        Collection<MessageQueue> totalQueues = consumer.fetchMessageQueues(topic);
         messageQueues =
                 RocketMQUtils.allocate(totalQueues, taskNumber, ctx.getIndexOfThisSubtask());
         // If the job recovers from the state, the state has already contained the offsets of last
@@ -265,6 +269,12 @@
                 5,
                 5,
                 TimeUnit.SECONDS);
+        if (StringUtils.isEmpty(sql)) {
+            consumer.subscribe(topic, tag);
+        } else {
+            // pull with sql do not support block pull.
+            consumer.subscribe(topic, MessageSelector.bySql(sql));
+        }
         for (MessageQueue mq : messageQueues) {
             this.executor.execute(
                     () ->
@@ -272,103 +282,64 @@
                                     () -> {
                                         while (runningChecker.isRunning()) {
                                             try {
-                                                long offset = offsetTable.get(mq);
-                                                PullResult pullResult;
-                                                if (StringUtils.isEmpty(sql)) {
-                                                    pullResult =
-                                                            consumer.pullBlockIfNotFound(
-                                                                    mq, tag, offset, pullBatchSize);
-                                                } else {
-                                                    // pull with sql do not support block pull.
-                                                    pullResult =
-                                                            consumer.pull(
-                                                                    mq,
-                                                                    MessageSelector.bySql(sql),
-                                                                    offset,
-                                                                    pullBatchSize);
-                                                }
-
+                                                Long offset = offsetTable.get(mq);
+                                                consumer.setPullBatchSize(pullBatchSize);
+                                                consumer.seek(mq, offset);
                                                 boolean found = false;
-                                                switch (pullResult.getPullStatus()) {
-                                                    case FOUND:
-                                                        List<MessageExt> messages =
-                                                                pullResult.getMsgFoundList();
-                                                        long fetchTime = System.currentTimeMillis();
-                                                        for (MessageExt msg : messages) {
-                                                            byte[] key =
-                                                                    msg.getKeys() != null
-                                                                            ? msg.getKeys()
-                                                                                    .getBytes(
-                                                                                            StandardCharsets
-                                                                                                    .UTF_8)
-                                                                            : null;
-                                                            byte[] value = msg.getBody();
-                                                            OUT data =
-                                                                    schema.deserializeKeyAndValue(
-                                                                            key, value);
+                                                List<MessageExt> messages =
+                                                        consumer.poll(
+                                                                getInteger(
+                                                                        props,
+                                                                        CONSUMER_TIMEOUT,
+                                                                        DEFAULT_CONSUMER_TIMEOUT));
+                                                if (CollectionUtils.isNotEmpty(messages)) {
+                                                    long fetchTime = System.currentTimeMillis();
+                                                    for (MessageExt msg : messages) {
+                                                        byte[] key =
+                                                                msg.getKeys() != null
+                                                                        ? msg.getKeys()
+                                                                                .getBytes(
+                                                                                        StandardCharsets
+                                                                                                .UTF_8)
+                                                                        : null;
+                                                        byte[] value = msg.getBody();
+                                                        OUT data =
+                                                                schema.deserializeKeyAndValue(
+                                                                        key, value);
 
-                                                            // output and state update are atomic
-                                                            synchronized (checkPointLock) {
-                                                                log.debug(
-                                                                        msg.getMsgId()
-                                                                                + "_"
-                                                                                + msg
-                                                                                        .getBrokerName()
-                                                                                + " "
-                                                                                + msg.getQueueId()
-                                                                                + " "
-                                                                                + msg
-                                                                                        .getQueueOffset());
-                                                                context.collectWithTimestamp(
-                                                                        data,
-                                                                        msg.getBornTimestamp());
-                                                                long emitTime =
-                                                                        System.currentTimeMillis();
-
-                                                                // update max eventTime per queue
-                                                                // waterMarkPerQueue.extractTimestamp(mq, msg.getBornTimestamp());
-                                                                waterMarkForAll.extractTimestamp(
-                                                                        msg.getBornTimestamp());
-                                                                tpsMetric.markEvent();
-                                                                long eventTime =
-                                                                        msg.getStoreTimestamp();
-                                                                fetchDelay.report(
-                                                                        Math.abs(
-                                                                                fetchTime
-                                                                                        - eventTime));
-                                                                emitDelay.report(
-                                                                        Math.abs(
-                                                                                emitTime
-                                                                                        - eventTime));
-                                                            }
+                                                        // output and state update are atomic
+                                                        synchronized (checkPointLock) {
+                                                            log.debug(
+                                                                    msg.getMsgId()
+                                                                            + "_"
+                                                                            + msg.getBrokerName()
+                                                                            + " "
+                                                                            + msg.getQueueId()
+                                                                            + " "
+                                                                            + msg.getQueueOffset());
+                                                            context.collectWithTimestamp(
+                                                                    data, msg.getBornTimestamp());
+                                                            long emitTime =
+                                                                    System.currentTimeMillis();
+                                                            // update max eventTime per queue
+                                                            // waterMarkPerQueue.extractTimestamp(mq, msg.getBornTimestamp());
+                                                            waterMarkForAll.extractTimestamp(
+                                                                    msg.getBornTimestamp());
+                                                            tpsMetric.markEvent();
+                                                            long eventTime =
+                                                                    msg.getStoreTimestamp();
+                                                            fetchDelay.report(
+                                                                    Math.abs(
+                                                                            fetchTime - eventTime));
+                                                            emitDelay.report(
+                                                                    Math.abs(emitTime - eventTime));
                                                         }
-                                                        found = true;
-                                                        break;
-                                                    case NO_MATCHED_MSG:
-                                                        log.debug(
-                                                                "No matched message after offset {} for queue {}",
-                                                                offset,
-                                                                mq);
-                                                        break;
-                                                    case NO_NEW_MSG:
-                                                        log.debug(
-                                                                "No new message after offset {} for queue {}",
-                                                                offset,
-                                                                mq);
-                                                        break;
-                                                    case OFFSET_ILLEGAL:
-                                                        log.warn(
-                                                                "Offset {} is illegal for queue {}",
-                                                                offset,
-                                                                mq);
-                                                        break;
-                                                    default:
-                                                        break;
+                                                    }
+                                                    found = true;
                                                 }
-
                                                 synchronized (checkPointLock) {
                                                     updateMessageQueueOffset(
-                                                            mq, pullResult.getNextBeginOffset());
+                                                            mq, consumer.committed(mq));
                                                 }
 
                                                 if (!found) {
@@ -405,13 +376,15 @@
             long offset;
             switch (startMode) {
                 case LATEST:
-                    offset = consumer.maxOffset(mq);
+                    consumer.seekToEnd(mq);
+                    offset = consumer.committed(mq);
                     break;
                 case EARLIEST:
-                    offset = consumer.minOffset(mq);
+                    consumer.seekToBegin(mq);
+                    offset = consumer.committed(mq);
                     break;
                 case GROUP_OFFSETS:
-                    offset = consumer.fetchConsumeOffset(mq, false);
+                    offset = consumer.committed(mq);
                     // the min offset return if consumer group first join,return a negative number
                     // if
                     // catch exception when fetch from broker.
@@ -419,7 +392,8 @@
                     if (offset <= 0) {
                         switch (offsetResetStrategy) {
                             case LATEST:
-                                offset = consumer.maxOffset(mq);
+                                consumer.seekToEnd(mq);
+                                offset = consumer.committed(mq);
                                 log.info(
                                         "current consumer thread:{} has no committed offset,use Strategy:{} instead",
                                         mq,
@@ -430,7 +404,8 @@
                                         "current consumer thread:{} has no committed offset,use Strategy:{} instead",
                                         mq,
                                         offsetResetStrategy);
-                                offset = consumer.minOffset(mq);
+                                consumer.seekToBegin(mq);
+                                offset = consumer.committed(mq);
                                 break;
                             default:
                                 break;
@@ -438,7 +413,7 @@
                     }
                     break;
                 case TIMESTAMP:
-                    offset = consumer.searchOffset(mq, specificTimeStamp);
+                    offset = consumer.offsetForTimestamp(mq, specificTimeStamp);
                     break;
                 case SPECIFIC_OFFSETS:
                     if (specificStartupOffsets == null) {
@@ -449,7 +424,7 @@
                     if (specificOffset != null) {
                         offset = specificOffset;
                     } else {
-                        offset = consumer.fetchConsumeOffset(mq, false);
+                        offset = consumer.committed(mq);
                     }
                     break;
                 default:
@@ -514,8 +489,7 @@
     private void updateMessageQueueOffset(MessageQueue mq, long offset) throws MQClientException {
         offsetTable.put(mq, offset);
         if (!enableCheckpoint) {
-            consumer.updateConsumeOffset(mq, offset);
-            consumer.getOffsetStore().persist(consumer.queueWithNamespace(mq));
+            consumer.getOffsetStore().updateOffset(mq, offset, false);
         }
     }
 
@@ -589,8 +563,7 @@
             // Discovers topic route change when snapshot
             RetryUtil.call(
                     () -> {
-                        Collection<MessageQueue> totalQueues =
-                                consumer.fetchSubscribeMessageQueues(topic);
+                        Collection<MessageQueue> totalQueues = consumer.fetchMessageQueues(topic);
                         int taskNumber = getRuntimeContext().getNumberOfParallelSubtasks();
                         int taskIndex = getRuntimeContext().getIndexOfThisSubtask();
                         List<MessageQueue> newQueues =
@@ -700,7 +673,7 @@
         }
 
         for (Map.Entry<MessageQueue, Long> entry : offsets.entrySet()) {
-            consumer.updateConsumeOffset(entry.getKey(), entry.getValue());
+            consumer.getOffsetStore().updateOffset(entry.getKey(), entry.getValue(), false);
             consumer.getOffsetStore().persist(consumer.queueWithNamespace(entry.getKey()));
         }
     }
diff --git a/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java b/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java
index 27c69f1..8d98d2e 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java
@@ -60,6 +60,7 @@
 
     private final String consumerOffsetMode;
     private final long consumerOffsetTimestamp;
+    private final long pollTime;
     private final String topic;
     private final String consumerGroup;
     private final String nameServerAddress;
@@ -79,6 +80,7 @@
     private final RocketMQDeserializationSchema<OUT> deserializationSchema;
 
     public RocketMQSource(
+            long pollTime,
             String topic,
             String consumerGroup,
             String nameServerAddress,
@@ -97,6 +99,7 @@
         Validate.isTrue(
                 !(StringUtils.isNotEmpty(tag) && StringUtils.isNotEmpty(sql)),
                 "Consumer tag and sql can not set value at the same time");
+        this.pollTime = pollTime;
         this.topic = topic;
         this.consumerGroup = consumerGroup;
         this.nameServerAddress = nameServerAddress;
@@ -140,6 +143,7 @@
         Supplier<SplitReader<Tuple3<OUT, Long, Long>, RocketMQPartitionSplit>> splitReaderSupplier =
                 () ->
                         new RocketMQPartitionSplitReader<>(
+                                pollTime,
                                 topic,
                                 consumerGroup,
                                 nameServerAddress,
diff --git a/src/main/java/org/apache/rocketmq/flink/source/enumerator/RocketMQSourceEnumerator.java b/src/main/java/org/apache/rocketmq/flink/source/enumerator/RocketMQSourceEnumerator.java
index 38aa132..bf489bb 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/enumerator/RocketMQSourceEnumerator.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/enumerator/RocketMQSourceEnumerator.java
@@ -20,7 +20,7 @@
 
 import org.apache.rocketmq.acl.common.AclClientRPCHook;
 import org.apache.rocketmq.acl.common.SessionCredentials;
-import org.apache.rocketmq.client.consumer.DefaultMQPullConsumer;
+import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.exception.MQClientException;
 import org.apache.rocketmq.common.message.MessageQueue;
 import org.apache.rocketmq.flink.source.split.RocketMQPartitionSplit;
@@ -100,7 +100,8 @@
     private final Map<Integer, Set<RocketMQPartitionSplit>> pendingPartitionSplitAssignment;
 
     // Lazily instantiated or mutable fields.
-    private DefaultMQPullConsumer consumer;
+    private DefaultLitePullConsumer consumer;
+
     private boolean noMoreNewPartitionSplits = false;
 
     public RocketMQSourceEnumerator(
@@ -233,7 +234,8 @@
         Set<Tuple3<String, String, Integer>> newPartitions = new HashSet<>();
         Set<Tuple3<String, String, Integer>> removedPartitions =
                 new HashSet<>(Collections.unmodifiableSet(discoveredPartitions));
-        Set<MessageQueue> messageQueues = consumer.fetchSubscribeMessageQueues(topic);
+
+        Collection<MessageQueue> messageQueues = consumer.fetchMessageQueues(topic);
         Set<RocketMQPartitionSplit> result = new HashSet<>();
         for (MessageQueue messageQueue : messageQueues) {
             Tuple3<String, String, Integer> topicPartition =
@@ -337,16 +339,16 @@
             } else {
                 switch (consumerOffsetMode) {
                     case CONSUMER_OFFSET_EARLIEST:
-                        offset = consumer.minOffset(mq);
-                        break;
+                        consumer.seekToBegin(mq);
+                        return -1;
                     case CONSUMER_OFFSET_LATEST:
-                        offset = consumer.maxOffset(mq);
-                        break;
+                        consumer.seekToEnd(mq);
+                        return -1;
                     case CONSUMER_OFFSET_TIMESTAMP:
-                        offset = consumer.searchOffset(mq, consumerOffsetTimestamp);
+                        offset = consumer.offsetForTimestamp(mq, consumerOffsetTimestamp);
                         break;
                     default:
-                        offset = consumer.fetchConsumeOffset(mq, false);
+                        offset = consumer.committed(mq);
                         if (offset < 0) {
                             throw new IllegalArgumentException(
                                     "Unknown value for CONSUMER_OFFSET_RESET_TO.");
@@ -364,11 +366,10 @@
                     && !StringUtils.isNullOrWhitespaceOnly(secretKey)) {
                 AclClientRPCHook aclClientRPCHook =
                         new AclClientRPCHook(new SessionCredentials(accessKey, secretKey));
-                consumer = new DefaultMQPullConsumer(consumerGroup, aclClientRPCHook);
+                consumer = new DefaultLitePullConsumer(consumerGroup, aclClientRPCHook);
             } else {
-                consumer = new DefaultMQPullConsumer(consumerGroup);
+                consumer = new DefaultLitePullConsumer(consumerGroup);
             }
-
             consumer.setNamesrvAddr(nameServerAddress);
             consumer.setInstanceName(
                     String.join(
diff --git a/src/main/java/org/apache/rocketmq/flink/source/reader/RocketMQPartitionSplitReader.java b/src/main/java/org/apache/rocketmq/flink/source/reader/RocketMQPartitionSplitReader.java
index ca9c3f1..72fd96e 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/reader/RocketMQPartitionSplitReader.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/reader/RocketMQPartitionSplitReader.java
@@ -20,16 +20,13 @@
 
 import org.apache.rocketmq.acl.common.AclClientRPCHook;
 import org.apache.rocketmq.acl.common.SessionCredentials;
-import org.apache.rocketmq.client.consumer.DefaultMQPullConsumer;
+import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.consumer.MessageSelector;
-import org.apache.rocketmq.client.consumer.PullResult;
-import org.apache.rocketmq.client.exception.MQBrokerException;
 import org.apache.rocketmq.client.exception.MQClientException;
 import org.apache.rocketmq.common.message.MessageExt;
 import org.apache.rocketmq.common.message.MessageQueue;
 import org.apache.rocketmq.flink.source.reader.deserializer.RocketMQDeserializationSchema;
 import org.apache.rocketmq.flink.source.split.RocketMQPartitionSplit;
-import org.apache.rocketmq.remoting.exception.RemotingException;
 
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds;
@@ -57,8 +54,6 @@
 import java.util.Map;
 import java.util.Set;
 
-import static org.apache.rocketmq.client.consumer.PullStatus.FOUND;
-
 /**
  * A {@link SplitReader} implementation that reads records from RocketMQ partitions.
  *
@@ -75,6 +70,8 @@
     private final long startTime;
     private final long startOffset;
 
+    private final long pollTime;
+
     private final String accessKey;
     private final String secretKey;
 
@@ -83,13 +80,14 @@
     private final Map<Tuple3<String, String, Integer>, Long> stoppingTimestamps;
     private final SimpleCollector<T> collector;
 
-    private DefaultMQPullConsumer consumer;
+    private DefaultLitePullConsumer consumer;
 
     private volatile boolean wakeup = false;
 
     private static final int MAX_MESSAGE_NUMBER_PER_BLOCK = 64;
 
     public RocketMQPartitionSplitReader(
+            long pollTime,
             String topic,
             String consumerGroup,
             String nameServerAddress,
@@ -101,6 +99,7 @@
             long startTime,
             long startOffset,
             RocketMQDeserializationSchema<T> deserializationSchema) {
+        this.pollTime = pollTime;
         this.topic = topic;
         this.tag = tag;
         this.sql = sql;
@@ -120,9 +119,9 @@
     public RecordsWithSplitIds<Tuple3<T, Long, Long>> fetch() throws IOException {
         RocketMQPartitionSplitRecords<Tuple3<T, Long, Long>> recordsBySplits =
                 new RocketMQPartitionSplitRecords<>();
-        Set<MessageQueue> messageQueues;
+        Collection<MessageQueue> messageQueues;
         try {
-            messageQueues = consumer.fetchSubscribeMessageQueues(topic);
+            messageQueues = consumer.fetchMessageQueues(topic);
         } catch (MQClientException e) {
             LOG.error(
                     String.format(
@@ -144,7 +143,7 @@
                     try {
                         messageOffset =
                                 startTime > 0
-                                        ? consumer.searchOffset(messageQueue, startTime)
+                                        ? consumer.offsetForTimestamp(messageQueue, startTime)
                                         : startOffset;
                     } catch (MQClientException e) {
                         LOG.warn(
@@ -157,7 +156,7 @@
                     }
                     messageOffset = messageOffset > -1 ? messageOffset : 0;
                 }
-                PullResult pullResult = null;
+                List<MessageExt> messageExts = null;
                 try {
                     if (wakeup) {
                         LOG.info(
@@ -173,25 +172,11 @@
                         recordsBySplits.prepareForRead();
                         return recordsBySplits;
                     }
-                    if (StringUtils.isNotEmpty(sql)) {
-                        pullResult =
-                                consumer.pull(
-                                        messageQueue,
-                                        MessageSelector.bySql(sql),
-                                        messageOffset,
-                                        MAX_MESSAGE_NUMBER_PER_BLOCK);
-                    } else {
-                        pullResult =
-                                consumer.pull(
-                                        messageQueue,
-                                        tag,
-                                        messageOffset,
-                                        MAX_MESSAGE_NUMBER_PER_BLOCK);
-                    }
-                } catch (MQClientException
-                        | RemotingException
-                        | MQBrokerException
-                        | InterruptedException e) {
+
+                    consumer.setPullBatchSize(MAX_MESSAGE_NUMBER_PER_BLOCK);
+                    consumer.seek(messageQueue, messageOffset);
+                    messageExts = consumer.poll(pollTime);
+                } catch (MQClientException e) {
                     LOG.warn(
                             String.format(
                                     "Pull RocketMQ messages of topic[%s] broker[%s] queue[%d] tag[%s] sql[%s] from offset[%d] exception.",
@@ -203,10 +188,23 @@
                                     messageOffset),
                             e);
                 }
-                startingOffsets.put(
-                        topicPartition,
-                        pullResult == null ? messageOffset : pullResult.getNextBeginOffset());
-                if (pullResult != null && pullResult.getPullStatus() == FOUND) {
+                try {
+                    startingOffsets.put(
+                            topicPartition,
+                            messageExts == null ? messageOffset : consumer.committed(messageQueue));
+                } catch (MQClientException e) {
+                    LOG.warn(
+                            String.format(
+                                    "Pull RocketMQ messages of topic[%s] broker[%s] queue[%d] tag[%s] sql[%s] from offset[%d] exception.",
+                                    messageQueue.getTopic(),
+                                    messageQueue.getBrokerName(),
+                                    messageQueue.getQueueId(),
+                                    tag,
+                                    sql,
+                                    messageOffset),
+                            e);
+                }
+                if (messageExts != null) {
                     Collection<Tuple3<T, Long, Long>> recordsForSplit =
                             recordsBySplits.recordsForSplit(
                                     messageQueue.getTopic()
@@ -214,7 +212,7 @@
                                             + messageQueue.getBrokerName()
                                             + "-"
                                             + messageQueue.getQueueId());
-                    for (MessageExt messageExt : pullResult.getMsgFoundList()) {
+                    for (MessageExt messageExt : messageExts) {
                         long stoppingTimestamp = getStoppingTimestamp(topicPartition);
                         long storeTimestamp = messageExt.getStoreTimestamp();
                         if (storeTimestamp > stoppingTimestamp) {
@@ -320,9 +318,9 @@
             if (StringUtils.isNotBlank(accessKey) && StringUtils.isNotBlank(secretKey)) {
                 AclClientRPCHook aclClientRPCHook =
                         new AclClientRPCHook(new SessionCredentials(accessKey, secretKey));
-                consumer = new DefaultMQPullConsumer(consumerGroup, aclClientRPCHook);
+                consumer = new DefaultLitePullConsumer(consumerGroup, aclClientRPCHook);
             } else {
-                consumer = new DefaultMQPullConsumer(consumerGroup);
+                consumer = new DefaultLitePullConsumer(consumerGroup);
             }
             consumer.setNamesrvAddr(nameServerAddress);
             consumer.setInstanceName(
@@ -333,6 +331,11 @@
                             consumerGroup,
                             "" + System.nanoTime()));
             consumer.start();
+            if (StringUtils.isNotEmpty(sql)) {
+                consumer.subscribe(topic, MessageSelector.bySql(sql));
+            } else {
+                consumer.subscribe(topic, tag);
+            }
         } catch (MQClientException e) {
             LOG.error("Failed to initial RocketMQ consumer.", e);
             consumer.shutdown();
diff --git a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
index 6db5075..8b4fd52 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
@@ -44,6 +44,7 @@
 import static org.apache.rocketmq.flink.common.RocketMQOptions.NAME_SERVER_ADDRESS;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_ACCESS_KEY;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_COLUMN_ERROR_DEBUG;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_CONSUMER_POLL_MS;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_ENCODING;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_END_TIME;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_FIELD_DELIMITER;
@@ -104,6 +105,7 @@
         optionalOptions.add(OPTIONAL_ACCESS_KEY);
         optionalOptions.add(OPTIONAL_SECRET_KEY);
         optionalOptions.add(OPTIONAL_SCAN_STARTUP_MODE);
+        optionalOptions.add(OPTIONAL_CONSUMER_POLL_MS);
         return optionalOptions;
     }
 
@@ -182,6 +184,7 @@
                 configuration.getLong(
                         RocketMQOptions.OPTIONAL_OFFSET_FROM_TIMESTAMP, System.currentTimeMillis());
         return new RocketMQScanTableSource(
+                configuration.getLong(OPTIONAL_CONSUMER_POLL_MS),
                 descriptorProperties,
                 physicalSchema,
                 topic,
diff --git a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
index dc92a47..3eb68df 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
@@ -73,10 +73,12 @@
     private final long startMessageOffset;
     private final long startTime;
     private final boolean useNewApi;
+    private final long pollTime;
 
     private List<String> metadataKeys;
 
     public RocketMQScanTableSource(
+            long pollTime,
             DescriptorProperties properties,
             TableSchema schema,
             String topic,
@@ -93,6 +95,7 @@
             String consumerOffsetMode,
             long consumerOffsetTimestamp,
             boolean useNewApi) {
+        this.pollTime = pollTime;
         this.properties = properties;
         this.schema = schema;
         this.topic = topic;
@@ -122,6 +125,7 @@
         if (useNewApi) {
             return SourceProvider.of(
                     new RocketMQSource<>(
+                            pollTime,
                             topic,
                             consumerGroup,
                             nameServerAddress,
@@ -162,6 +166,7 @@
     public DynamicTableSource copy() {
         RocketMQScanTableSource tableSource =
                 new RocketMQScanTableSource(
+                        pollTime,
                         properties,
                         schema,
                         topic,
diff --git a/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java b/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
index 7ce124d..9c5042c 100644
--- a/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
+++ b/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
@@ -18,8 +18,7 @@
 
 package org.apache.rocketmq.flink.legacy;
 
-import org.apache.rocketmq.client.consumer.DefaultMQPullConsumer;
-import org.apache.rocketmq.client.consumer.MQPullConsumerScheduleService;
+import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.consumer.PullResult;
 import org.apache.rocketmq.client.consumer.PullStatus;
 import org.apache.rocketmq.common.message.MessageExt;
@@ -42,10 +41,7 @@
 
 import static org.apache.rocketmq.flink.legacy.common.util.TestUtils.setFieldValue;
 import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyBoolean;
-import static org.mockito.Matchers.anyInt;
 import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -55,8 +51,7 @@
 public class RocketMQSourceTest {
 
     private RocketMQSourceFunction rocketMQSource;
-    private MQPullConsumerScheduleService pullConsumerScheduleService;
-    private DefaultMQPullConsumer consumer;
+    private DefaultLitePullConsumer consumer;
     private KeyValueDeserializationSchema deserializationSchema;
     private String topic = "tpc";
 
@@ -71,12 +66,8 @@
         setFieldValue(rocketMQSource, "offsetTable", new ConcurrentHashMap<>());
         setFieldValue(rocketMQSource, "restoredOffsets", new ConcurrentHashMap<>());
 
-        pullConsumerScheduleService = new MQPullConsumerScheduleService("g");
-
-        consumer = mock(DefaultMQPullConsumer.class);
-        pullConsumerScheduleService.setDefaultMQPullConsumer(consumer);
+        consumer = mock(DefaultLitePullConsumer.class);
         setFieldValue(rocketMQSource, "consumer", consumer);
-        setFieldValue(rocketMQSource, "pullConsumerScheduleService", pullConsumerScheduleService);
     }
 
     @Test
@@ -89,9 +80,8 @@
         msgFoundList.add(messageExt);
         PullResult pullResult = new PullResult(PullStatus.FOUND, 3, 1, 5, msgFoundList);
 
-        when(consumer.fetchConsumeOffset(any(MessageQueue.class), anyBoolean())).thenReturn(2L);
-        when(consumer.pull(any(MessageQueue.class), anyString(), anyLong(), anyInt()))
-                .thenReturn(pullResult);
+        when(consumer.committed(any(MessageQueue.class))).thenReturn(2L);
+        when(consumer.poll(anyLong())).thenReturn(pullResult.getMsgFoundList());
 
         SourceContext context = mock(SourceContext.class);
         when(context.getCheckpointLock()).thenReturn(new Object());
@@ -101,7 +91,6 @@
         // schedule the pull task
         Set<MessageQueue> set = new HashSet();
         set.add(new MessageQueue(topic, "brk", 1));
-        pullConsumerScheduleService.putTask(topic, set);
 
         MessageExt msg = pullResult.getMsgFoundList().get(0);