Merge pull request #276 from Jargon9/develop_consume_first

[ISSUE #272] Support consuming source topics from the first offset
diff --git a/core/src/main/java/org/apache/rocketmq/streams/core/common/Constant.java b/core/src/main/java/org/apache/rocketmq/streams/core/common/Constant.java
index 9b64c00..5b011a0 100644
--- a/core/src/main/java/org/apache/rocketmq/streams/core/common/Constant.java
+++ b/core/src/main/java/org/apache/rocketmq/streams/core/common/Constant.java
@@ -18,8 +18,6 @@
 
 package org.apache.rocketmq.streams.core.common;
 
-import java.nio.charset.StandardCharsets;
-
 public class Constant {
 
     public static final String SHUFFLE_KEY_CLASS_NAME = "shuffle.key.class.name";
@@ -51,4 +49,12 @@
     public static final String STATIC_TOPIC_BROKER_NAME = "__syslo__global__";
 
     public static final String WATERMARK_KEY = "watermark_key";
+
+    /**
+     * Consumer offset treated as "no consumption progress yet": when consuming from the first
+     * offset is configured, a source queue whose reported consumer offset equals this value may
+     * be reset to its first offset; any other value means the queue has already been consumed.
+     */
+    public static final long DEFAULT_CONSUME_OFFSET = 0L;
+
 }
diff --git a/core/src/main/java/org/apache/rocketmq/streams/core/metadata/StreamConfig.java b/core/src/main/java/org/apache/rocketmq/streams/core/metadata/StreamConfig.java
index bd09568..0d318f3 100644
--- a/core/src/main/java/org/apache/rocketmq/streams/core/metadata/StreamConfig.java
+++ b/core/src/main/java/org/apache/rocketmq/streams/core/metadata/StreamConfig.java
@@ -24,6 +24,17 @@
     public static final String ROCKETMQ_STREAMS_STATE_CONSUMER_GROUP = "__state_group";
     public static final String COMMIT_STATE_INTERNAL_MS = "commitStateIntervalMillisecond";
 
+    /**
+     * Property key selecting where source topics start being consumed from. The value is
+     * expected to be a {@link org.apache.rocketmq.common.consumer.ConsumeFromWhere}; with
+     * {@code CONSUME_FROM_FIRST_OFFSET}, source queues that have no consumption progress yet are
+     * reset to their first offset when a worker starts. Absent means {@code CONSUME_FROM_LAST_OFFSET}.
+     */
+    public static final String ROCKETMQ_STREAMS_CONSUMER_FROM_WHERE = "consume_from_where";
+
+    /** Misspelled alias of {@link #ROCKETMQ_STREAMS_CONSUMER_FROM_WHERE}, kept for compatibility; prefer the former. */
+    public static final String ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE = ROCKETMQ_STREAMS_CONSUMER_FROM_WHERE;
+
     public static Integer STREAMS_PARALLEL_THREAD_NUM = 1;
     public static Integer SHUFFLE_TOPIC_QUEUE_NUM = 8;
     public static Integer SCHEDULED_THREAD_NUM = 2;
diff --git a/core/src/main/java/org/apache/rocketmq/streams/core/running/MessageQueueListenerWrapper.java b/core/src/main/java/org/apache/rocketmq/streams/core/running/MessageQueueListenerWrapper.java
index fbac1fe..7e6f087 100644
--- a/core/src/main/java/org/apache/rocketmq/streams/core/running/MessageQueueListenerWrapper.java
+++ b/core/src/main/java/org/apache/rocketmq/streams/core/running/MessageQueueListenerWrapper.java
@@ -24,18 +24,15 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.Collections;
 import java.util.HashSet;
-import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.BiFunction;
 
 class MessageQueueListenerWrapper implements MessageQueueListener {
     private static final Logger logger = LoggerFactory.getLogger(MessageQueueListenerWrapper.class.getName());
     private final MessageQueueListener originListener;
     private final TopologyBuilder topologyBuilder;
-
     private final ConcurrentHashMap<String, Set<MessageQueue>> ownedMapping = new ConcurrentHashMap<>();
     private final ConcurrentHashMap<String, Processor<?>> mq2Processor = new ConcurrentHashMap<>();
 
diff --git a/core/src/main/java/org/apache/rocketmq/streams/core/running/RocketMQClient.java b/core/src/main/java/org/apache/rocketmq/streams/core/running/RocketMQClient.java
index 313c27e..d6b71b6 100644
--- a/core/src/main/java/org/apache/rocketmq/streams/core/running/RocketMQClient.java
+++ b/core/src/main/java/org/apache/rocketmq/streams/core/running/RocketMQClient.java
@@ -37,14 +37,12 @@
         this.nameSrvAddr = nameSrvAddr;
     }
 
     public DefaultLitePullConsumer pullConsumer(String groupName, Set<String> topics) throws MQClientException {
         DefaultLitePullConsumer pullConsumer = new DefaultLitePullConsumer(groupName);
         pullConsumer.setNamesrvAddr(nameSrvAddr);
-        pullConsumer.setConsumeFromWhere(ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET);
         pullConsumer.setAutoCommit(false);
         pullConsumer.setPullBatchSize(1000);
 
-
         for (String topic : topics) {
             pullConsumer.subscribe(topic, SUB_ALL);
             logger.debug("subscribe topic:{}, groupName:{}", topic, groupName);
diff --git a/core/src/main/java/org/apache/rocketmq/streams/core/running/WorkerThread.java b/core/src/main/java/org/apache/rocketmq/streams/core/running/WorkerThread.java
index f8c5318..61d6af3 100644
--- a/core/src/main/java/org/apache/rocketmq/streams/core/running/WorkerThread.java
+++ b/core/src/main/java/org/apache/rocketmq/streams/core/running/WorkerThread.java
@@ -22,8 +22,13 @@
 import org.apache.rocketmq.client.exception.MQClientException;
 import org.apache.rocketmq.client.producer.DefaultMQProducer;
 import org.apache.rocketmq.common.MixAll;
+import org.apache.rocketmq.common.admin.ConsumeStats;
+import org.apache.rocketmq.common.admin.OffsetWrapper;
+import org.apache.rocketmq.common.consumer.ConsumeFromWhere;
 import org.apache.rocketmq.common.message.MessageExt;
 import org.apache.rocketmq.common.message.MessageQueue;
+import org.apache.rocketmq.common.protocol.body.ClusterInfo;
+import org.apache.rocketmq.common.protocol.route.BrokerData;
 import org.apache.rocketmq.streams.core.common.Constant;
 import org.apache.rocketmq.streams.core.exception.DataProcessThrowable;
 import org.apache.rocketmq.streams.core.exception.RStreamsException;
@@ -46,11 +51,14 @@
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
+import static org.apache.rocketmq.streams.core.common.Constant.*;
+import static org.apache.rocketmq.streams.core.metadata.StreamConfig.ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE;
 import static org.apache.rocketmq.streams.core.metadata.StreamConfig.ROCKETMQ_STREAMS_CONSUMER_GROUP;
 
 public class WorkerThread extends Thread {
@@ -79,7 +87,6 @@
 
         Set<String> topicNames = topologyBuilder.getSourceTopic();
 
-
         DefaultLitePullConsumer unionConsumer = rocketMQClient.pullConsumer(groupName, topicNames);
 
         MessageQueueListener originListener = unionConsumer.getMessageQueueListener();
@@ -92,7 +99,7 @@
         RocksDBStore rocksDBStore = new RocksDBStore(threadName);
         RocketMQStore store = new RocketMQStore(producer, rocksDBStore, mqAdmin, this.properties);
 
-        this.planetaryEngine = new PlanetaryEngine<>(unionConsumer, producer, store, mqAdmin, wrapper);
+        this.planetaryEngine = new PlanetaryEngine<>(unionConsumer, producer, store, mqAdmin, wrapper, topicNames);
     }
 
     @Override
@@ -101,6 +108,7 @@
             this.planetaryEngine.start();
             logger.info("worker thread=[{}], start task success, jobId:{}", this.getName(), jobId);
 
+            this.planetaryEngine.maybeResetOffsetToFirst(); // returns immediately unless consume-from-where is CONSUME_FROM_FIRST_OFFSET
             this.planetaryEngine.runInLoop();
         } catch (Throwable e) {
             logger.error("worker thread=[{}], error:{}.", this.getName(), e);
@@ -125,11 +133,13 @@
         private final IdleWindowScaner idleWindowScaner;
         private volatile boolean stop = false;
 
+        private final Set<String> sourceTopicSet; // topology source topics, scanned by maybeResetOffsetToFirst()
+
         private final HashSet<MessageQueue> mq2Commit = new HashSet<>();
 
 
         public PlanetaryEngine(DefaultLitePullConsumer unionConsumer, DefaultMQProducer producer, StateStore stateStore,
-                               DefaultMQAdminExt mqAdmin, MessageQueueListenerWrapper wrapper) {
+                               DefaultMQAdminExt mqAdmin, MessageQueueListenerWrapper wrapper, Set<String> sourceTopicSet) {
             this.unionConsumer = unionConsumer;
             this.producer = producer;
             this.mqAdmin = mqAdmin;
@@ -144,6 +154,8 @@
                     return e;
                 }
             });
+            this.sourceTopicSet = sourceTopicSet;
+
             Integer idleTime = (Integer) WorkerThread.this.properties.getOrDefault(StreamConfig.IDLE_TIME_TO_FIRE_WINDOW, 2000);
             int commitInterval = (Integer) WorkerThread.this.properties.getOrDefault(StreamConfig.COMMIT_STATE_INTERNAL_MS, 2 * 1000);
             this.idleWindowScaner = new IdleWindowScaner(idleTime, executor);
@@ -239,6 +251,98 @@
             }
         }
 
+        /**
+         * Resets the consumer group's offset to the first (minimum) offset of every source queue
+         * that has no consumption progress yet, but only when
+         * {@link StreamConfig#ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE} resolves to
+         * {@link ConsumeFromWhere#CONSUME_FROM_FIRST_OFFSET}.
+         *
+         * <p>Internal shuffle/state topics are never reset. A queue whose consumer offset differs
+         * from {@link Constant#DEFAULT_CONSUME_OFFSET} is treated as already consumed and is left
+         * untouched; the remaining queues of the same topic are still examined.
+         *
+         * <p>NOTE(review): this runs after {@code start()}; confirm the running consumer cannot
+         * later commit its own pull offsets over this broker-side reset.
+         *
+         * @throws RStreamsException if a queue's broker is missing from the cluster info
+         * @throws Exception         admin-client failures; a failed queue reset is logged first
+         */
+        void maybeResetOffsetToFirst() throws Exception {
+            if (resolveConsumeFromWhere() != ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET) {
+                return;
+            }
+
+            String group = unionConsumer.getConsumerGroup();
+            // Looked up lazily and at most once: it is the same for every queue.
+            ClusterInfo clusterInfo = null;
+
+            for (String topic : sourceTopicSet) {
+                // Offsets of internal topics must never be reset.
+                if (topic.endsWith(Constant.SHUFFLE_TOPIC_SUFFIX) || topic.endsWith(Constant.STATE_TOPIC_SUFFIX)) {
+                    continue;
+                }
+                ConsumeStats consumeStats = mqAdmin.examineConsumeStats(group, topic);
+                for (Map.Entry<MessageQueue, OffsetWrapper> entry : consumeStats.getOffsetTable().entrySet()) {
+                    MessageQueue messageQueue = entry.getKey();
+                    // Progress already exists: skip this queue only and keep checking the rest of the topic.
+                    if (entry.getValue().getConsumerOffset() != Constant.DEFAULT_CONSUME_OFFSET) {
+                        continue;
+                    }
+                    try {
+                        if (clusterInfo == null) {
+                            clusterInfo = mqAdmin.examineBrokerClusterInfo();
+                        }
+                        resetQueueToMinOffset(clusterInfo, group, messageQueue);
+                    } catch (Exception e) {
+                        logger.error("reset messageQueue:{} consumer offset to first failed.", messageQueue, e);
+                        throw e;
+                    }
+                }
+            }
+        }
+
+        /**
+         * Reads {@link StreamConfig#ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE}, accepting either a
+         * {@link ConsumeFromWhere} value or its enum constant name as a String (as when loaded from a
+         * properties file). An absent key means {@link ConsumeFromWhere#CONSUME_FROM_LAST_OFFSET}.
+         *
+         * @throws IllegalArgumentException if the value is an unknown name or has any other type
+         */
+        private ConsumeFromWhere resolveConsumeFromWhere() {
+            Object value = properties.get(ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE);
+            if (value == null) {
+                return ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET;
+            }
+            if (value instanceof ConsumeFromWhere) {
+                return (ConsumeFromWhere) value;
+            }
+            if (value instanceof String) {
+                return ConsumeFromWhere.valueOf(((String) value).trim());
+            }
+            throw new IllegalArgumentException("unsupported " + ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE + ": " + value);
+        }
+
+        /**
+         * Sets {@code group}'s offset on {@code messageQueue} to the queue's minimum offset, on every
+         * broker address registered for the queue's broker name in {@code clusterInfo}.
+         *
+         * @throws RStreamsException if the queue's broker is not present in {@code clusterInfo}
+         */
+        private void resetQueueToMinOffset(ClusterInfo clusterInfo, String group, MessageQueue messageQueue) throws Exception {
+            long minOffset = mqAdmin.minOffset(messageQueue);
+            String brokerName = messageQueue.getBrokerName();
+            BrokerData brokerData = clusterInfo.getBrokerAddrTable().get(brokerName);
+            if (brokerData == null) {
+                String msg = String.format("get broker error, have no broker info (name:%s)", brokerName);
+                logger.error(msg);
+                throw new RStreamsException(msg);
+            }
+            for (String brokerAddress : brokerData.getBrokerAddrs().values()) {
+                mqAdmin.resetOffsetByQueueId(brokerAddress, group, messageQueue.getTopic(),
+                        messageQueue.getQueueId(), minOffset);
+            }
+        }
+
         long prepareTime(MessageExt messageExt, SourceSupplier.SourceProcessor<K, V> processor) {
             TimeType type = (TimeType) properties.get(StreamConfig.TIME_TYPE);
 
diff --git a/example/src/main/java/org/apache/rocketmq/streams/examples/WordCountFromFirstOffset.java b/example/src/main/java/org/apache/rocketmq/streams/examples/WordCountFromFirstOffset.java
new file mode 100644
index 0000000..e34f5ef
--- /dev/null
+++ b/example/src/main/java/org/apache/rocketmq/streams/examples/WordCountFromFirstOffset.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.rocketmq.streams.examples;
+
+import org.apache.rocketmq.common.MixAll;
+import org.apache.rocketmq.streams.core.RocketMQStream;
+import org.apache.rocketmq.streams.core.function.ValueMapperAction;
+import org.apache.rocketmq.streams.core.metadata.StreamConfig;
+import org.apache.rocketmq.streams.core.rstream.StreamBuilder;
+import org.apache.rocketmq.streams.core.topology.TopologyBuilder;
+import org.apache.rocketmq.streams.core.util.Pair;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Properties;
+import java.util.stream.Collectors;
+
+import static org.apache.rocketmq.common.consumer.ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET;
+
+/**
+ * Word-count example that reads {@code sourceTopic} from its first offset rather than its latest
+ * one, by setting {@link StreamConfig#ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE} to
+ * {@code CONSUME_FROM_FIRST_OFFSET}.
+ *
+ * <p>Only queues on which the job's consumer group has no consumption progress yet are rewound,
+ * so a restart does not re-read messages that were already consumed.
+ */
+public class WordCountFromFirstOffset {
+    public static void main(String[] args) {
+        StreamBuilder builder = new StreamBuilder("wordCount");
+
+        builder.source("sourceTopic", total -> {
+            String value = new String(total, StandardCharsets.UTF_8);
+            return new Pair<>(null, value);
+        })
+                .flatMap((ValueMapperAction<String, List<String>>) value -> {
+                    // Locale.ROOT keeps lower-casing locale-independent; text starting with a
+                    // non-word character splits into a leading "" token, which is dropped.
+                    String[] splits = value.toLowerCase(Locale.ROOT).split("\\W+");
+                    return Arrays.stream(splits)
+                            .filter(word -> !word.isEmpty())
+                            .collect(Collectors.toList());
+                })
+                .keyBy(value -> value)
+                .count()
+                .toRStream()
+                .print();
+
+        TopologyBuilder topologyBuilder = builder.build();
+
+        Properties properties = new Properties();
+        properties.put(MixAll.NAMESRV_ADDR_PROPERTY, "127.0.0.1:9876");
+        // Start from the first offset of each source queue that has no consumption progress yet.
+        properties.put(StreamConfig.ROCKETMQ_STREAMS_CONSUMER_FORM_WHERE, CONSUME_FROM_FIRST_OFFSET);
+
+        RocketMQStream rocketMQStream = new RocketMQStream(topologyBuilder, properties);
+
+        Runtime.getRuntime().addShutdownHook(new Thread("wordcount-shutdown-hook") {
+            @Override
+            public void run() {
+                rocketMQStream.stop();
+            }
+        });
+
+        rocketMQStream.start();
+    }
+}