Implementation of Simple Consumer for Python Client (#588)
* finish simple_consumer
* fix style issues
* delete private info
* convert comments to English
* add state enum & change_invisible_duration
* extract example
* add more tests
* fix style issue
diff --git a/python/examples/simple_consumer_example.py b/python/examples/simple_consumer_example.py
new file mode 100644
index 0000000..07ff20b
--- /dev/null
+++ b/python/examples/simple_consumer_example.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+
+from rocketmq.client_config import ClientConfig
+from rocketmq.filter_expression import FilterExpression
+from rocketmq.log import logger
+from rocketmq.protocol.definition_pb2 import Resource
+from rocketmq.rpc_client import Endpoints
+from rocketmq.session_credentials import (SessionCredentials,
+ SessionCredentialsProvider)
+from rocketmq.simple_consumer import SimpleConsumer
+
+
+async def test():
+ credentials = SessionCredentials("username", "password")
+ credentials_provider = SessionCredentialsProvider(credentials)
+ client_config = ClientConfig(
+ endpoints=Endpoints("endpoint"),
+ session_credentials_provider=credentials_provider,
+ ssl_enabled=True,
+ )
+ topic = Resource()
+ topic.name = "normal_topic"
+
+ consumer_group = "yourConsumerGroup"
+ subscription = {topic.name: FilterExpression("*")}
+ simple_consumer = (await SimpleConsumer.Builder()
+ .set_client_config(client_config)
+ .set_consumer_group(consumer_group)
+ .set_await_duration(15)
+ .set_subscription_expression(subscription)
+ .build())
+ logger.info(simple_consumer)
+ # while True:
+ message_views = await simple_consumer.receive(16, 15)
+ logger.info(message_views)
+ for message in message_views:
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+
+if __name__ == "__main__":
+ asyncio.run(test())
diff --git a/python/rocketmq/client.py b/python/rocketmq/client.py
index 509d991..fd6a7f7 100644
--- a/python/rocketmq/client.py
+++ b/python/rocketmq/client.py
@@ -15,7 +15,6 @@
import asyncio
import threading
-from typing import Set
from protocol import definition_pb2, service_pb2
from protocol.definition_pb2 import Code as ProtoCode
@@ -60,7 +59,7 @@
"""
Main client class which handles interaction with the server.
"""
- def __init__(self, client_config: ClientConfig, topics: Set[str]):
+ def __init__(self, client_config: ClientConfig):
"""
Initialization method for the Client class.
@@ -70,7 +69,6 @@
self.client_config = client_config
self.client_id = ClientIdEncoder.generate()
self.endpoints = client_config.endpoints
- self.topics = topics
#: A cache to store topic routes.
self.topic_route_cache = {}
@@ -83,13 +81,16 @@
#: A dictionary to store isolated items.
self.isolated = dict()
+ def get_topics(self):
+ raise NotImplementedError("This method should be implemented by the subclass.")
+
async def start(self):
"""
Start method which initiates fetching of topic routes and schedules heartbeats.
"""
# get topic route
logger.debug(f"Begin to start the rocketmq client, client_id={self.client_id}")
- for topic in self.topics:
+ for topic in self.get_topics():
self.topic_route_cache[topic] = await self.fetch_topic_route(topic)
scheduler = ScheduleWithFixedDelay(self.heartbeat, 3, 12)
scheduler_sync_settings = ScheduleWithFixedDelay(self.sync_settings, 3, 12)
@@ -489,6 +490,22 @@
request, metadata, timeout_seconds
)
+ async def receive_message(
+ self,
+ endpoints: Endpoints,
+ request: service_pb2.ReceiveMessageRequest,
+ timeout_seconds: int,
+ ):
+ rpc_client = self.__get_rpc_client(
+ endpoints, self.__client.client_config.ssl_enabled
+ )
+ metadata = Signature.sign(self.__client.client_config, self.__client.client_id)
+
+ response = await rpc_client.receive_message(
+ request, metadata, timeout_seconds
+ )
+ return response
+
def telemetry(
self,
endpoints: Endpoints,
diff --git a/python/rocketmq/consumer.py b/python/rocketmq/consumer.py
new file mode 100644
index 0000000..d81d897
--- /dev/null
+++ b/python/rocketmq/consumer.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List
+
+from filter_expression import ExpressionType
+from google.protobuf.duration_pb2 import Duration
+from message import MessageView
+from rocketmq.client import Client
+from rocketmq.protocol.definition_pb2 import \
+ FilterExpression as ProtoFilterExpression
+from rocketmq.protocol.definition_pb2 import FilterType
+from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol.service_pb2 import \
+ ReceiveMessageRequest as ProtoReceiveMessageRequest
+
+
+class ReceiveMessageResult:
+ def __init__(self, endpoints, messages: List['MessageView']):
+ self.endpoints = endpoints
+ self.messages = messages
+
+
+class Consumer(Client):
+ CONSUMER_GROUP_REGEX = re.compile(r"^[%a-zA-Z0-9_-]+$")
+
+ def __init__(self, client_config, consumer_group):
+ super().__init__(client_config)
+ self.consumer_group = consumer_group
+
+ async def receive_message(self, request, mq, await_duration):
+ tolerance = self.client_config.request_timeout
+ timeout = tolerance + await_duration
+ results = await self.client_manager.receive_message(mq.broker.endpoints, request, timeout)
+
+ messages = [MessageView.from_protobuf(message, mq) for message in results]
+ return ReceiveMessageResult(mq.broker.endpoints, messages)
+
+ @staticmethod
+ def _wrap_filter_expression(filter_expression):
+ filter_type = FilterType.TAG
+ if filter_expression.type == ExpressionType.Sql92:
+ filter_type = FilterType.SQL
+ return ProtoFilterExpression(
+ type=filter_type,
+ expression=filter_expression.expression
+ )
+
+ def wrap_receive_message_request(self, batch_size, mq, filter_expression, await_duration, invisible_duration):
+ group = ProtoResource()
+ group.name = self.consumer_group
+ return ProtoReceiveMessageRequest(
+ group=group,
+ message_queue=mq.to_protobuf(),
+ filter_expression=self._wrap_filter_expression(filter_expression),
+ long_polling_timeout=Duration(seconds=await_duration),
+ batch_size=batch_size,
+ auto_renew=False,
+ invisible_duration=Duration(seconds=invisible_duration)
+ )
diff --git a/python/rocketmq/definition.py b/python/rocketmq/definition.py
index 3d63748..498fc6d 100644
--- a/python/rocketmq/definition.py
+++ b/python/rocketmq/definition.py
@@ -62,7 +62,7 @@
:return: The protobuf representation of the broker.
"""
return ProtoBroker(
- Name=self.name, Id=self.id, Endpoints=self.endpoints.to_protobuf()
+ name=self.name, id=self.id, endpoints=self.endpoints.to_protobuf()
)
@@ -76,8 +76,8 @@
:param resource: The resource object.
"""
if resource is not None:
- self.namespace = resource.ResourceNamespace
- self.name = resource.Name
+ self.namespace = resource.resource_namespace
+ self.name = resource.name
else:
self.namespace = ""
self.name = name
@@ -87,7 +87,10 @@
:return: The protobuf representation of the resource.
"""
- return ProtoResource(ResourceNamespace=self.namespace, Name=self.name)
+ resource = ProtoResource()
+ resource.name = self.name
+ resource.resource_namespace = self.namespace
+ return resource
def __str__(self):
return f"{self.namespace}.{self.name}" if self.namespace else self.name
@@ -219,7 +222,7 @@
:param message_queue: The initial message queue to be encapsulated.
"""
- self._topic_resource = Resource(message_queue.topic)
+ self._topic_resource = Resource(message_queue.topic.name, message_queue.topic)
self.queue_id = message_queue.id
self.permission = PermissionHelper.from_protobuf(message_queue.permission)
self.accept_message_types = [
diff --git a/python/rocketmq/filter_expression.py b/python/rocketmq/filter_expression.py
new file mode 100644
index 0000000..9e3e511
--- /dev/null
+++ b/python/rocketmq/filter_expression.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class ExpressionType(Enum):
+ Tag = 1
+ Sql92 = 2
+
+
+class FilterExpression:
+ def __init__(self, expression, expression_type=ExpressionType.Tag):
+ self._expression = expression
+ self._type = expression_type
+
+ @property
+ def type(self):
+ return self._type
+
+ @property
+ def expression(self):
+ return self._expression
diff --git a/python/rocketmq/message.py b/python/rocketmq/message.py
index f20e865..5076da9 100644
--- a/python/rocketmq/message.py
+++ b/python/rocketmq/message.py
@@ -14,7 +14,14 @@
# limitations under the License.
-from rocketmq.message_id import MessageId
+import binascii
+import gzip
+import hashlib
+from typing import Dict, List
+
+from rocketmq.definition import MessageQueue
+from rocketmq.protocol.definition_pb2 import DigestType as ProtoDigestType
+from rocketmq.protocol.definition_pb2 import Encoding as ProtoEncoding
class Message:
@@ -70,16 +77,21 @@
class MessageView:
def __init__(
self,
- message_id: MessageId,
+ message_id: str,
topic: str,
body: bytes,
- properties: map,
tag: str,
- keys: str,
message_group: str,
delivery_timestamp: int,
+ keys: List[str],
+ properties: Dict[str, str],
born_host: str,
+ born_time: int,
delivery_attempt: int,
+ message_queue: MessageQueue,
+ receipt_handle: str,
+ offset: int,
+ corrupted: bool
):
self.__message_id = message_id
self.__topic = topic
@@ -91,12 +103,29 @@
self.__delivery_timestamp = delivery_timestamp
self.__born_host = born_host
self.__delivery_attempt = delivery_attempt
+ self.__receipt_handle = receipt_handle
+ self.__born_time = born_time
+ self.__message_queue = message_queue
+ self.__offset = offset
+ self.__corrupted = corrupted
+
+ @property
+ def message_queue(self):
+ return self.__message_queue
+
+ @property
+ def receipt_handle(self):
+ return self.__receipt_handle
@property
def topic(self):
return self.__topic
@property
+ def body(self):
+ return self.__body
+
+ @property
def message_id(self):
return self.__message_id
@@ -123,3 +152,55 @@
@property
def delivery_timestamp(self):
return self.__delivery_timestamp
+
+ @classmethod
+ def from_protobuf(cls, message, message_queue=None):
+ topic = message.topic.name
+ system_properties = message.system_properties
+ message_id = system_properties.message_id
+ body_digest = system_properties.body_digest
+ check_sum = body_digest.checksum
+ raw = message.body
+ corrupted = False
+ digest_type = body_digest.type
+
+ # Digest Type check
+ if digest_type == ProtoDigestType.CRC32:
+ expected_check_sum = format(binascii.crc32(raw) & 0xFFFFFFFF, '08X')
+ if not expected_check_sum == check_sum:
+ corrupted = True
+ elif digest_type == ProtoDigestType.MD5:
+ expected_check_sum = hashlib.md5(raw).hexdigest()
+ if not expected_check_sum == check_sum:
+ corrupted = True
+ elif digest_type == ProtoDigestType.SHA1:
+ expected_check_sum = hashlib.sha1(raw).hexdigest()
+ if not expected_check_sum == check_sum:
+ corrupted = True
+ elif digest_type in [ProtoDigestType.unspecified, None]:
+ print(f"Unsupported message body digest algorithm, digestType={digest_type}, topic={topic}, messageId={message_id}")
+
+ # Body Encoding check
+ body_encoding = system_properties.body_encoding
+ body = raw
+ if body_encoding == ProtoEncoding.GZIP:
+ body = gzip.decompress(message.body)
+ elif body_encoding in [ProtoEncoding.IDENTITY, None]:
+ pass
+ else:
+ print(f"Unsupported message encoding algorithm, topic={topic}, messageId={message_id}, bodyEncoding={body_encoding}")
+
+ tag = system_properties.tag
+ message_group = system_properties.message_group
+ delivery_time = system_properties.delivery_timestamp
+ keys = list(system_properties.keys)
+
+ born_host = system_properties.born_host
+ born_time = system_properties.born_timestamp
+ delivery_attempt = system_properties.delivery_attempt
+ queue_offset = system_properties.queue_offset
+ properties = {key: value for key, value in message.user_properties.items()}
+ receipt_handle = system_properties.receipt_handle
+
+ return cls(message_id, topic, body, tag, message_group, delivery_time, keys, properties, born_host,
+ born_time, delivery_attempt, message_queue, receipt_handle, queue_offset, corrupted)
diff --git a/python/rocketmq/producer.py b/python/rocketmq/producer.py
index 9e10a3d..378bf8e 100644
--- a/python/rocketmq/producer.py
+++ b/python/rocketmq/producer.py
@@ -179,7 +179,8 @@
:param client_config: The configuration for the client.
:param topics: The set of topics to which the producer can send messages.
"""
- super().__init__(client_config, topics)
+ super().__init__(client_config)
+ self.publish_topics = topics
retry_policy = ExponentialBackoffRetryPolicy.immediately_retry_policy(10)
#: Set up the publishing settings with the given parameters.
self.publish_settings = PublishingSettings(
@@ -196,6 +197,9 @@
"""Provide an asynchronous context manager for the producer."""
await self.shutdown()
+ def get_topics(self):
+ return self.publish_topics
+
async def start(self):
"""Start the RocketMQ producer and log the operation."""
logger.info(f"Begin to start the rocketmq producer, client_id={self.client_id}")
@@ -364,7 +368,7 @@
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
- endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+ endpoints=Endpoints("endpoint"),
session_credentials_provider=credentials_provider,
ssl_enabled=True,
)
@@ -375,6 +379,7 @@
msg.body = b"My Normal Message Body"
sysperf = SystemProperties()
sysperf.message_id = MessageIdCodec.next_message_id()
+ sysperf.message_group = "yourConsumerGroup"
msg.system_properties.CopyFrom(sysperf)
producer = Producer(client_config, topics={"normal_topic"})
message = Message(topic.name, msg.body)
@@ -388,7 +393,7 @@
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
- endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+ endpoints=Endpoints("endpoint"),
session_credentials_provider=credentials_provider,
ssl_enabled=True,
)
@@ -417,7 +422,7 @@
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
- endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+ endpoints=Endpoints("endpoint"),
session_credentials_provider=credentials_provider,
ssl_enabled=True,
)
@@ -431,7 +436,7 @@
msg.system_properties.CopyFrom(sysperf)
logger.debug(f"{msg}")
producer = Producer(client_config, topics={"fifo_topic"})
- message = Message(topic.name, msg.body, message_group="yourMessageGroup")
+ message = Message(topic.name, msg.body, message_group="yourConsumerGroup")
await producer.start()
await asyncio.sleep(10)
send_receipt = await producer.send(message)
@@ -442,7 +447,7 @@
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
- endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+ endpoints=Endpoints("endpoint"),
session_credentials_provider=credentials_provider,
ssl_enabled=True,
)
@@ -469,7 +474,7 @@
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
- endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+ endpoints=Endpoints("endpoint"),
session_credentials_provider=credentials_provider,
ssl_enabled=True,
)
diff --git a/python/rocketmq/rpc_client.py b/python/rocketmq/rpc_client.py
index 6c1107a..d907632 100644
--- a/python/rocketmq/rpc_client.py
+++ b/python/rocketmq/rpc_client.py
@@ -23,9 +23,8 @@
import certifi
from grpc import aio, ssl_channel_credentials
-from protocol import service_pb2
from rocketmq.log import logger
-from rocketmq.protocol import service_pb2_grpc
+from rocketmq.protocol import service_pb2, service_pb2_grpc
from rocketmq.protocol.definition_pb2 import Address as ProtoAddress
from rocketmq.protocol.definition_pb2 import \
AddressScheme as ProtoAddressScheme
diff --git a/python/rocketmq/simple_consumer.py b/python/rocketmq/simple_consumer.py
new file mode 100644
index 0000000..a85eb80
--- /dev/null
+++ b/python/rocketmq/simple_consumer.py
@@ -0,0 +1,423 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import random
+import re
+import threading
+from datetime import timedelta
+from threading import Lock
+from typing import Dict
+
+import rocketmq
+from google.protobuf.duration_pb2 import Duration
+from rocketmq.client_config import ClientConfig
+from rocketmq.consumer import Consumer
+from rocketmq.definition import PermissionHelper
+from rocketmq.filter_expression import FilterExpression
+from rocketmq.log import logger
+from rocketmq.message import MessageView
+from rocketmq.protocol.definition_pb2 import Resource
+from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol.service_pb2 import \
+ AckMessageEntry as ProtoAckMessageEntry
+from rocketmq.protocol.service_pb2 import \
+ AckMessageRequest as ProtoAckMessageRequest
+from rocketmq.protocol.service_pb2 import \
+ ChangeInvisibleDurationRequest as ProtoChangeInvisibleDurationRequest
+from rocketmq.rpc_client import Endpoints
+from rocketmq.session_credentials import (SessionCredentials,
+ SessionCredentialsProvider)
+from rocketmq.simple_subscription_settings import SimpleSubscriptionSettings
+from rocketmq.state import State
+from utils import get_positive_mod
+
+
+class SubscriptionLoadBalancer:
+ """This class serves as a load balancer for message subscription.
+ It keeps track of a rotating index to help distribute the load evenly.
+ """
+
+ def __init__(self, topic_route_data):
+ #: current index for message queue selection
+ self._index = random.randint(0, 10000) # assuming a range of 0-10000
+ #: thread lock to ensure atomic update to the index
+ self._index_lock = threading.Lock()
+
+ #: filter the message queues which are readable and from the master broker
+ self._message_queues = [
+ mq for mq in topic_route_data.message_queues
+ if PermissionHelper().is_readable(mq.permission)
+ and mq.broker.id == rocketmq.utils.master_broker_id
+ ]
+
+ def update(self, topic_route_data):
+ """Updates the message queues based on the new topic route data."""
+ self._index += 1
+ self._message_queues = [
+ mq for mq in topic_route_data.message_queues
+ if PermissionHelper().is_readable(mq.permission)
+ and mq.broker.id == rocketmq.utils.master_broker_id
+ ]
+ return self
+
+ def take_message_queue(self):
+ """Fetches the next message queue based on the current index."""
+ with self._index_lock:
+ index = get_positive_mod(self._index, len(self._message_queues))
+ self._index += 1
+ return self._message_queues[index]
+
+
+class SimpleConsumer(Consumer):
+ """The SimpleConsumer class extends the Client class and is used to consume
+ messages from specific topics in RocketMQ.
+ """
+
+ def __init__(self, client_config: ClientConfig, consumer_group: str, await_duration: int, subscription_expressions: Dict[str, FilterExpression]):
+ """Create a new SimpleConsumer.
+
+ :param client_config: The configuration for the client.
+ :param consumer_group: The consumer group.
+ :param await_duration: The await duration.
+ :param subscription_expressions: The subscription expressions.
+ """
+ super().__init__(client_config, consumer_group)
+
+ self._consumer_group = consumer_group
+ self._await_duration = await_duration
+ self._subscription_expressions = subscription_expressions
+
+ self._simple_subscription_settings = SimpleSubscriptionSettings(self.client_id, self.endpoints, self._consumer_group, timedelta(seconds=10), 10, self._subscription_expressions)
+ self._subscription_route_data_cache = {}
+ self._topic_round_robin_index = 0
+ self._state_lock = Lock()
+ self._state = State.New
+ self._subscription_load_balancer = {} # A dictionary to keep subscription load balancers
+
+ def get_topics(self):
+ return set(self._subscription_expressions.keys())
+
+ def get_settings(self):
+ return self._simple_subscription_settings
+
+ async def subscribe(self, topic: str, filter_expression: FilterExpression):
+ if self._state != State.Running:
+ raise Exception("Simple consumer is not running")
+
+ await self.get_subscription_load_balancer(topic)
+ self._subscription_expressions[topic] = filter_expression
+
+ def unsubscribe(self, topic: str):
+ if self._state != State.Running:
+ raise Exception("Simple consumer is not running")
+ try:
+ self._subscription_expressions.pop(topic)
+ except KeyError:
+ pass
+
+ async def start(self):
+ """Start the RocketMQ consumer and log the operation."""
+ logger.info(f"Begin to start the rocketmq consumer, client_id={self.client_id}")
+ with self._state_lock:
+ if self._state != State.New:
+ raise Exception("Consumer already started")
+ await super().start()
+ # Start all necessary operations
+ self._state = State.Running
+ logger.info(f"The rocketmq consumer starts successfully, client_id={self.client_id}")
+
+ async def shutdown(self):
+ """Shutdown the RocketMQ consumer and log the operation."""
+ logger.info(f"Begin to shutdown the rocketmq consumer, client_id={self.client_id}")
+ with self._state_lock:
+ if self._state != State.Running:
+ raise Exception("Consumer is not running")
+ # Shutdown all necessary operations
+ self._state = State.Terminated
+ await super().shutdown()
+ logger.info(f"Shutdown the rocketmq consumer successfully, client_id={self.client_id}")
+
+ def update_subscription_load_balancer(self, topic, topic_route_data):
+ # if a load balancer for this topic already exists in the subscription routing data cache, update it
+ subscription_load_balancer = self._subscription_route_data_cache.get(topic)
+ if subscription_load_balancer:
+ subscription_load_balancer.update(topic_route_data)
+ # otherwise, create a new subscription load balancer
+ else:
+ subscription_load_balancer = SubscriptionLoadBalancer(topic_route_data)
+
+ # store new or updated subscription load balancers in the subscription routing data cache
+ self._subscription_route_data_cache[topic] = subscription_load_balancer
+ return subscription_load_balancer
+
+ async def get_subscription_load_balancer(self, topic):
+ # if a load balancer for this topic already exists in the subscription routing data cache, return it
+ subscription_load_balancer = self._subscription_route_data_cache.get(topic)
+ if subscription_load_balancer:
+ return subscription_load_balancer
+
+ # otherwise, obtain the routing data for the topic
+ topic_route_data = await self.get_route_data(topic)
+ # update subscription load balancer
+ return self.update_subscription_load_balancer(topic, topic_route_data)
+
+ async def receive(self, max_message_num, invisible_duration):
+ if self._state != State.Running:
+ raise Exception("Simple consumer is not running")
+ if max_message_num <= 0:
+ raise Exception("maxMessageNum must be greater than 0")
+ copy = dict(self._subscription_expressions)
+ topics = list(copy.keys())
+ if len(topics) == 0:
+ raise ValueError("There is no topic to receive message")
+
+ index = (self._topic_round_robin_index + 1) % len(topics)
+ self._topic_round_robin_index = index
+ topic = topics[index]
+ filter_expression = self._subscription_expressions[topic]
+ subscription_load_balancer = await self.get_subscription_load_balancer(topic)
+ mq = subscription_load_balancer.take_message_queue()
+ request = self.wrap_receive_message_request(max_message_num, mq, filter_expression, self._await_duration, invisible_duration)
+ result = await self.receive_message(request, mq, self._await_duration)
+ return result.messages
+
+ def wrap_change_invisible_duration(self, message_view: MessageView, invisible_duration):
+ topic_resource = ProtoResource()
+ topic_resource.name = message_view.topic
+
+ request = ProtoChangeInvisibleDurationRequest()
+ request.topic.CopyFrom(topic_resource)
+ group = ProtoResource()
+ group.name = message_view.message_group
+ logger.debug(message_view.message_group)
+ request.group.CopyFrom(group)
+ request.receipt_handle = message_view.receipt_handle
+ request.invisible_duration.CopyFrom(Duration(seconds=invisible_duration))
+ request.message_id = message_view.message_id
+
+ return request
+
+ async def change_invisible_duration(self, message_view: MessageView, invisible_duration):
+ if self._state != State.Running:
+ raise Exception("Simple consumer is not running")
+
+ request = self.wrap_change_invisible_duration(message_view, invisible_duration)
+ result = await self.client_manager.change_invisible_duration(
+ message_view.message_queue.broker.endpoints,
+ request,
+ self.client_config.request_timeout
+ )
+ logger.debug(result)
+
+ async def ack(self, message_view: MessageView):
+ if self._state != State.Running:
+ raise Exception("Simple consumer is not running")
+ request = self.wrap_ack_message_request(message_view)
+ result = await self.client_manager.ack_message(message_view.message_queue.broker.endpoints, request=request, timeout_seconds=self.client_config.request_timeout)
+ logger.info(result)
+
+ def get_protobuf_group(self):
+ return ProtoResource(name=self.consumer_group)
+
+ def wrap_ack_message_request(self, message_view: MessageView):
+ topic_resource = ProtoResource()
+ topic_resource.name = message_view.topic
+ entry = ProtoAckMessageEntry()
+ entry.message_id = message_view.message_id
+ entry.receipt_handle = message_view.receipt_handle
+
+ request = ProtoAckMessageRequest(group=self.get_protobuf_group(), topic=topic_resource, entries=[entry])
+ return request
+
+ class Builder:
+ def __init__(self):
+ self._consumer_group_regex = re.compile(r"^[%a-zA-Z0-9_-]+$")
+ self._clientConfig = None
+ self._consumerGroup = None
+ self._awaitDuration = None
+ self._subscriptionExpressions = {}
+
+ def set_client_config(self, client_config: ClientConfig):
+ if client_config is None:
+ raise ValueError("clientConfig should not be null")
+ self._clientConfig = client_config
+ return self
+
+ def set_consumer_group(self, consumer_group: str):
+ if consumer_group is None:
+ raise ValueError("consumerGroup should not be null")
+ # Assuming CONSUMER_GROUP_REGEX is defined in the outer scope
+ if not re.match(self._consumer_group_regex, consumer_group):
+ raise ValueError(f"topic does not match the regex {self._consumer_group_regex}")
+ self._consumerGroup = consumer_group
+ return self
+
+ def set_await_duration(self, await_duration: int):
+ self._awaitDuration = await_duration
+ return self
+
+ def set_subscription_expression(self, subscription_expressions: Dict[str, FilterExpression]):
+ if subscription_expressions is None:
+ raise ValueError("subscriptionExpressions should not be null")
+ if len(subscription_expressions) == 0:
+ raise ValueError("subscriptionExpressions should not be empty")
+ self._subscriptionExpressions = subscription_expressions
+ return self
+
+ async def build(self):
+ if self._clientConfig is None:
+ raise ValueError("clientConfig has not been set yet")
+ if self._consumerGroup is None:
+ raise ValueError("consumerGroup has not been set yet")
+ if len(self._subscriptionExpressions) == 0:
+ raise ValueError("subscriptionExpressions has not been set yet")
+
+ simple_consumer = SimpleConsumer(self._clientConfig, self._consumerGroup, self._awaitDuration, self._subscriptionExpressions)
+ await simple_consumer.start()
+ return simple_consumer
+
+
+async def test():
+ credentials = SessionCredentials("username", "password")
+ credentials_provider = SessionCredentialsProvider(credentials)
+ client_config = ClientConfig(
+ endpoints=Endpoints("endpoint"),
+ session_credentials_provider=credentials_provider,
+ ssl_enabled=True,
+ )
+ topic = Resource()
+ topic.name = "normal_topic"
+
+ consumer_group = "yourConsumerGroup"
+ subscription = {topic.name: FilterExpression("*")}
+ simple_consumer = (await SimpleConsumer.Builder()
+ .set_client_config(client_config)
+ .set_consumer_group(consumer_group)
+ .set_await_duration(15)
+ .set_subscription_expression(subscription)
+ .build())
+ logger.info(simple_consumer)
+ # while True:
+ message_views = await simple_consumer.receive(16, 15)
+ logger.info(message_views)
+ for message in message_views:
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+
+
+async def test_fifo_message():
+ credentials = SessionCredentials("username", "password")
+ credentials_provider = SessionCredentialsProvider(credentials)
+ client_config = ClientConfig(
+ endpoints=Endpoints("endpoint"),
+ session_credentials_provider=credentials_provider,
+ ssl_enabled=True,
+ )
+ topic = Resource()
+ topic.name = "fifo_topic"
+
+ consumer_group = "yourConsumerGroup"
+ subscription = {topic.name: FilterExpression("*")}
+ simple_consumer = (await SimpleConsumer.Builder()
+ .set_client_config(client_config)
+ .set_consumer_group(consumer_group)
+ .set_await_duration(15)
+ .set_subscription_expression(subscription)
+ .build())
+ logger.info(simple_consumer)
+ # while True:
+ message_views = await simple_consumer.receive(16, 15)
+ # logger.info(message_views)
+ for message in message_views:
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+
+
+async def test_change_invisible_duration():
+ credentials = SessionCredentials("username", "password")
+ credentials_provider = SessionCredentialsProvider(credentials)
+ client_config = ClientConfig(
+ endpoints=Endpoints("endpoint"),
+ session_credentials_provider=credentials_provider,
+ ssl_enabled=True,
+ )
+ topic = Resource()
+ topic.name = "fifo_topic"
+
+ consumer_group = "yourConsumerGroup"
+ subscription = {topic.name: FilterExpression("*")}
+ simple_consumer = (await SimpleConsumer.Builder()
+ .set_client_config(client_config)
+ .set_consumer_group(consumer_group)
+ .set_await_duration(15)
+ .set_subscription_expression(subscription)
+ .build())
+ logger.info(simple_consumer)
+ # while True:
+ message_views = await simple_consumer.receive(16, 15)
+ # logger.info(message_views)
+ for message in message_views:
+ await simple_consumer.change_invisible_duration(message_view=message, invisible_duration=3)
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+
+
+async def test_subscribe_unsubscribe():
+ credentials = SessionCredentials("username", "password")
+ credentials_provider = SessionCredentialsProvider(credentials)
+ client_config = ClientConfig(
+ endpoints=Endpoints("endpoint"),
+ session_credentials_provider=credentials_provider,
+ ssl_enabled=True,
+ )
+ topic = Resource()
+ topic.name = "normal_topic"
+
+ consumer_group = "yourConsumerGroup"
+ subscription = {topic.name: FilterExpression("*")}
+ simple_consumer = (await SimpleConsumer.Builder()
+ .set_client_config(client_config)
+ .set_consumer_group(consumer_group)
+ .set_await_duration(15)
+ .set_subscription_expression(subscription)
+ .build())
+ logger.info(simple_consumer)
+ # while True:
+ message_views = await simple_consumer.receive(16, 15)
+ logger.info(message_views)
+ for message in message_views:
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+ simple_consumer.unsubscribe('normal_topic')
+ await simple_consumer.subscribe('fifo_topic', FilterExpression("*"))
+ message_views = await simple_consumer.receive(16, 15)
+ logger.info(message_views)
+ for message in message_views:
+ logger.info(message.body)
+ logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
+ await simple_consumer.ack(message)
+ logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
+
+if __name__ == "__main__":
+ asyncio.run(test_subscribe_unsubscribe())
diff --git a/python/rocketmq/simple_subscription_settings.py b/python/rocketmq/simple_subscription_settings.py
new file mode 100644
index 0000000..6d19300
--- /dev/null
+++ b/python/rocketmq/simple_subscription_settings.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from google.protobuf.duration_pb2 import Duration
+from rocketmq.filter_expression import ExpressionType
+from rocketmq.log import logger
+from rocketmq.protocol.definition_pb2 import \
+ FilterExpression as ProtoFilterExpression
+from rocketmq.protocol.definition_pb2 import FilterType as ProtoFilterType
+from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings
+from rocketmq.protocol.definition_pb2 import Subscription as ProtoSubscription
+from rocketmq.protocol.definition_pb2 import \
+ SubscriptionEntry as ProtoSubscriptionEntry
+
+from .settings import ClientType, ClientTypeHelper, Settings
+
+
+# Assuming a simple representation of FilterExpression for the purpose of this example
+class FilterExpression:
+ def __init__(self, type, expression):
+ self.Type = type
+ self.Expression = expression
+
+
+class SimpleSubscriptionSettings(Settings):
+
+ def __init__(self, clientId, endpoints, consumerGroup, requestTimeout, longPollingTimeout,
+ subscriptionExpressions: Dict[str, FilterExpression]):
+ super().__init__(clientId, ClientType.SimpleConsumer, endpoints, None, requestTimeout)
+ self._group = consumerGroup # Simplified as string for now
+ self._longPollingTimeout = longPollingTimeout
+ self._subscriptionExpressions = subscriptionExpressions
+
+ def Sync(self, settings: ProtoSettings):
+ if not isinstance(settings, ProtoSettings):
+ logger.error(f"[Bug] Issued settings doesn't match with the client type, clientId={self.ClientId}, clientType={self.ClientType}")
+
+ def to_protobuf(self):
+ subscriptionEntries = []
+
+ for key, value in self._subscriptionExpressions.items():
+ topic = ProtoResource()
+ topic.name = key
+
+ subscriptionEntry = ProtoSubscriptionEntry()
+ filterExpression = ProtoFilterExpression()
+
+ if value.type == ExpressionType.Tag:
+ filterExpression.type = ProtoFilterType.TAG
+ elif value.type == ExpressionType.Sql92:
+ filterExpression.type = ProtoFilterType.SQL
+ else:
+ logger.warn(f"[Bug] Unrecognized filter type={value.Type} for simple consumer")
+
+ filterExpression.expression = value.expression
+ subscriptionEntry.topic.CopyFrom(topic)
+ subscriptionEntries.append(subscriptionEntry)
+
+ subscription = ProtoSubscription()
+ group = ProtoResource()
+ group.name = self._group
+ subscription.group.CopyFrom(group)
+ subscription.subscriptions.extend(subscriptionEntries)
+ duration_longPollingTimeout = Duration(seconds=self._longPollingTimeout)
+ subscription.long_polling_timeout.CopyFrom(duration_longPollingTimeout)
+
+ settings = super().to_protobuf()
+ settings.access_point.CopyFrom(self.Endpoints.to_protobuf()) # Assuming Endpoints has a to_protobuf method
+ settings.client_type = ClientTypeHelper.to_protobuf(self.ClientType)
+
+ settings.request_timeout.CopyFrom(Duration(seconds=int(self.RequestTimeout.total_seconds())))
+ settings.subscription.CopyFrom(subscription)
+
+ return settings
diff --git a/python/rocketmq/state.py b/python/rocketmq/state.py
new file mode 100644
index 0000000..e8f2d01
--- /dev/null
+++ b/python/rocketmq/state.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class State(Enum):
+ New = 1
+ Starting = 2
+ Running = 3
+ Stopping = 4
+ Terminated = 5
+ Failed = 6