# blob: a85eb809287ff80b792a4eb5c33e05800b4856d9 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
import re
import threading
from datetime import timedelta
from threading import Lock
from typing import Dict
import rocketmq
from google.protobuf.duration_pb2 import Duration
from rocketmq.client_config import ClientConfig
from rocketmq.consumer import Consumer
from rocketmq.definition import PermissionHelper
from rocketmq.filter_expression import FilterExpression
from rocketmq.log import logger
from rocketmq.message import MessageView
from rocketmq.protocol.definition_pb2 import Resource
from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
from rocketmq.protocol.service_pb2 import \
AckMessageEntry as ProtoAckMessageEntry
from rocketmq.protocol.service_pb2 import \
AckMessageRequest as ProtoAckMessageRequest
from rocketmq.protocol.service_pb2 import \
ChangeInvisibleDurationRequest as ProtoChangeInvisibleDurationRequest
from rocketmq.rpc_client import Endpoints
from rocketmq.session_credentials import (SessionCredentials,
SessionCredentialsProvider)
from rocketmq.simple_subscription_settings import SimpleSubscriptionSettings
from rocketmq.state import State
from utils import get_positive_mod
class SubscriptionLoadBalancer:
    """Round-robin selector over the message queues of one topic.

    Only queues that are readable and hosted on the master broker are
    kept. A rotating index, guarded by a lock, decides which queue is
    handed out next so that successive calls spread over all queues.
    """

    def __init__(self, topic_route_data):
        """Build a balancer from the route data of a topic.

        :param topic_route_data: Object exposing ``message_queues``; each
            queue must expose ``permission`` and ``broker.id``.
        """
        #: rotating index for queue selection; the starting offset is random
        #: (the 0-10000 range is arbitrary)
        self._index = random.randint(0, 10000)
        #: guards ``_index`` and ``_message_queues`` so that selection and
        #: route updates never interleave
        self._index_lock = threading.Lock()
        #: readable queues served by the master broker
        self._message_queues = self._select_message_queues(topic_route_data)

    @staticmethod
    def _select_message_queues(topic_route_data):
        """Return the readable, master-broker queues of ``topic_route_data``."""
        return [
            mq for mq in topic_route_data.message_queues
            if PermissionHelper().is_readable(mq.permission)
            and mq.broker.id == rocketmq.utils.master_broker_id
        ]

    def update(self, topic_route_data):
        """Replace the queue list with the one derived from new route data.

        The rotating index is advanced by one, as before. The index and the
        queue list are swapped under the lock so a concurrent
        ``take_message_queue`` never sees a half-applied update.

        :param topic_route_data: Fresh route data for the topic.
        :return: This balancer, to allow chaining.
        """
        # Filter outside the lock: it does not touch shared state.
        message_queues = self._select_message_queues(topic_route_data)
        with self._index_lock:
            self._index += 1
            self._message_queues = message_queues
        return self

    def take_message_queue(self):
        """Return the next message queue in round-robin order.

        :return: The selected message queue.
        :raises ValueError: If no readable master-broker queue is available.
        """
        with self._index_lock:
            # Work on a single snapshot so the index always matches the list
            # it was computed for.
            message_queues = self._message_queues
            if not message_queues:
                raise ValueError("No readable message queue is available on the master broker")
            # Python's % is already non-negative for a positive modulus.
            position = self._index % len(message_queues)
            self._index += 1
            return message_queues[position]
class SimpleConsumer(Consumer):
    """Consumer that explicitly receives, acknowledges and re-times messages.

    The class extends :class:`Consumer`. Subscribed topics are polled in
    round-robin order and, within a topic, a
    :class:`SubscriptionLoadBalancer` picks the message queue to read from.
    Instances are normally created through :class:`SimpleConsumer.Builder`.
    """

    def __init__(self, client_config: ClientConfig, consumer_group: str, await_duration: int,
                 subscription_expressions: Dict[str, FilterExpression]):
        """Create a new SimpleConsumer.

        :param client_config: The configuration for the client.
        :param consumer_group: The consumer group.
        :param await_duration: The await duration passed to every receive request.
        :param subscription_expressions: Mapping of topic name to filter expression.
        """
        super().__init__(client_config, consumer_group)
        self._consumer_group = consumer_group
        self._await_duration = await_duration
        # Keep a private copy: subscribe()/unsubscribe() mutate this mapping
        # and must not silently alter the dict owned by the caller.
        self._subscription_expressions = dict(subscription_expressions)
        # The settings object shares the consumer's own mapping.
        self._simple_subscription_settings = SimpleSubscriptionSettings(
            self.client_id,
            self.endpoints,
            self._consumer_group,
            timedelta(seconds=10),
            10,
            self._subscription_expressions,
        )
        # topic -> SubscriptionLoadBalancer
        self._subscription_route_data_cache = {}
        self._topic_round_robin_index = 0
        self._state_lock = Lock()
        self._state = State.New
        # True while start() is in flight; lets concurrent start() calls be
        # rejected without holding the lock across an await.
        self._starting = False
        self._subscription_load_balancer = {}  # A dictionary to keep subscription load balancers

    def get_topics(self):
        """Return the set of currently subscribed topic names."""
        return set(self._subscription_expressions.keys())

    def get_settings(self):
        """Return the subscription settings object created for this consumer."""
        return self._simple_subscription_settings

    async def subscribe(self, topic: str, filter_expression: FilterExpression):
        """Add (or replace) the subscription for ``topic``.

        The load balancer of the topic is resolved first, so a failure to
        fetch its route leaves the subscriptions untouched.

        :param topic: Topic name.
        :param filter_expression: Filter to apply to the topic.
        :raises Exception: If the consumer is not running.
        """
        if self._state != State.Running:
            raise Exception("Simple consumer is not running")
        await self.get_subscription_load_balancer(topic)
        self._subscription_expressions[topic] = filter_expression

    def unsubscribe(self, topic: str):
        """Remove the subscription for ``topic``; unknown topics are ignored.

        :param topic: Topic name.
        :raises Exception: If the consumer is not running.
        """
        if self._state != State.Running:
            raise Exception("Simple consumer is not running")
        self._subscription_expressions.pop(topic, None)

    async def start(self):
        """Start the RocketMQ consumer and log the operation.

        :raises Exception: If the consumer was already started or a start is
            in progress.
        """
        logger.info(f"Begin to start the rocketmq consumer, client_id={self.client_id}")
        # The lock is a blocking threading.Lock: it must never be held across
        # an await, otherwise a second coroutine trying to take it would
        # block the event loop and the holder could never resume.
        with self._state_lock:
            if self._state != State.New or self._starting:
                raise Exception("Consumer already started")
            self._starting = True
        try:
            await super().start()
            with self._state_lock:
                self._state = State.Running
        finally:
            # On failure the state stays New, so start() may be retried.
            with self._state_lock:
                self._starting = False
        logger.info(f"The rocketmq consumer starts successfully, client_id={self.client_id}")

    async def shutdown(self):
        """Shutdown the RocketMQ consumer and log the operation.

        :raises Exception: If the consumer is not running.
        """
        logger.info(f"Begin to shutdown the rocketmq consumer, client_id={self.client_id}")
        # Flip the state under the lock, then release it before awaiting so
        # the blocking lock is never held while the event loop runs.
        with self._state_lock:
            if self._state != State.Running:
                raise Exception("Consumer is not running")
            self._state = State.Terminated
        await super().shutdown()
        logger.info(f"Shutdown the rocketmq consumer successfully, client_id={self.client_id}")

    def update_subscription_load_balancer(self, topic, topic_route_data):
        """Create or refresh the cached load balancer of ``topic``.

        :param topic: Topic name.
        :param topic_route_data: Route data used to build/update the balancer.
        :return: The load balancer now cached for the topic.
        """
        subscription_load_balancer = self._subscription_route_data_cache.get(topic)
        if subscription_load_balancer is None:
            subscription_load_balancer = SubscriptionLoadBalancer(topic_route_data)
        else:
            subscription_load_balancer.update(topic_route_data)
        self._subscription_route_data_cache[topic] = subscription_load_balancer
        return subscription_load_balancer

    async def get_subscription_load_balancer(self, topic):
        """Return the load balancer of ``topic``, fetching its route on a cache miss.

        :param topic: Topic name.
        :return: The cached or newly created load balancer.
        """
        subscription_load_balancer = self._subscription_route_data_cache.get(topic)
        if subscription_load_balancer is not None:
            return subscription_load_balancer
        topic_route_data = await self.get_route_data(topic)
        return self.update_subscription_load_balancer(topic, topic_route_data)

    async def receive(self, max_message_num, invisible_duration):
        """Receive messages from the next subscribed topic in round-robin order.

        :param max_message_num: Maximum number of messages to fetch; must be > 0.
        :param invisible_duration: Invisible duration forwarded to the receive request.
        :return: The ``messages`` attribute of the receive result.
        :raises Exception: If the consumer is not running or
            ``max_message_num`` is not positive.
        :raises ValueError: If there is no subscription.
        """
        if self._state != State.Running:
            raise Exception("Simple consumer is not running")
        if max_message_num <= 0:
            raise Exception("maxMessageNum must be greater than 0")
        # Work on a snapshot so both the topic choice and the filter lookup
        # see the same set of subscriptions.
        subscriptions = dict(self._subscription_expressions)
        if not subscriptions:
            raise ValueError("There is no topic to receive message")
        topics = list(subscriptions)
        index = (self._topic_round_robin_index + 1) % len(topics)
        self._topic_round_robin_index = index
        topic = topics[index]
        filter_expression = subscriptions[topic]
        subscription_load_balancer = await self.get_subscription_load_balancer(topic)
        mq = subscription_load_balancer.take_message_queue()
        request = self.wrap_receive_message_request(
            max_message_num, mq, filter_expression, self._await_duration, invisible_duration)
        result = await self.receive_message(request, mq, self._await_duration)
        return result.messages

    def wrap_change_invisible_duration(self, message_view: MessageView, invisible_duration):
        """Build the request that changes the invisible duration of a message.

        :param message_view: The received message whose visibility is changed.
        :param invisible_duration: New invisible duration, in seconds.
        :return: A populated ``ChangeInvisibleDurationRequest``.
        """
        request = ProtoChangeInvisibleDurationRequest()
        request.topic.CopyFrom(ProtoResource(name=message_view.topic))
        # The ``group`` field identifies the consumer group, exactly as in the
        # ack request; it is unrelated to the message's own message group.
        request.group.CopyFrom(self.get_protobuf_group())
        request.receipt_handle = message_view.receipt_handle
        request.invisible_duration.CopyFrom(Duration(seconds=invisible_duration))
        request.message_id = message_view.message_id
        return request

    async def change_invisible_duration(self, message_view: MessageView, invisible_duration):
        """Change the invisible duration of a received message.

        :param message_view: The received message.
        :param invisible_duration: New invisible duration, in seconds.
        :raises Exception: If the consumer is not running.
        """
        if self._state != State.Running:
            raise Exception("Simple consumer is not running")
        request = self.wrap_change_invisible_duration(message_view, invisible_duration)
        result = await self.client_manager.change_invisible_duration(
            message_view.message_queue.broker.endpoints,
            request,
            self.client_config.request_timeout
        )
        logger.debug(result)

    async def ack(self, message_view: MessageView):
        """Acknowledge a received message at the broker it came from.

        :param message_view: The received message.
        :raises Exception: If the consumer is not running.
        """
        if self._state != State.Running:
            raise Exception("Simple consumer is not running")
        request = self.wrap_ack_message_request(message_view)
        result = await self.client_manager.ack_message(
            message_view.message_queue.broker.endpoints,
            request=request,
            timeout_seconds=self.client_config.request_timeout,
        )
        logger.info(result)

    def get_protobuf_group(self):
        """Return the consumer group as a protobuf ``Resource``.

        NOTE(review): reads ``self.consumer_group``, which is not assigned in
        this class - presumably provided by ``Consumer``; confirm.
        """
        return ProtoResource(name=self.consumer_group)

    def wrap_ack_message_request(self, message_view: MessageView):
        """Build the ack request for a single received message.

        :param message_view: The received message.
        :return: A populated ``AckMessageRequest`` with one entry.
        """
        entry = ProtoAckMessageEntry()
        entry.message_id = message_view.message_id
        entry.receipt_handle = message_view.receipt_handle
        return ProtoAckMessageRequest(
            group=self.get_protobuf_group(),
            topic=ProtoResource(name=message_view.topic),
            entries=[entry],
        )

    class Builder:
        """Fluent builder that validates its inputs and starts the consumer."""

        def __init__(self):
            # Allowed characters for a consumer group name.
            self._consumer_group_regex = re.compile(r"^[%a-zA-Z0-9_-]+$")
            self._clientConfig = None
            self._consumerGroup = None
            self._awaitDuration = None
            self._subscriptionExpressions = {}

        def set_client_config(self, client_config: ClientConfig):
            """Set the client configuration.

            :raises ValueError: If ``client_config`` is None.
            :return: This builder.
            """
            if client_config is None:
                raise ValueError("clientConfig should not be null")
            self._clientConfig = client_config
            return self

        def set_consumer_group(self, consumer_group: str):
            """Set the consumer group name.

            :raises ValueError: If ``consumer_group`` is None or contains
                characters outside the allowed set.
            :return: This builder.
            """
            if consumer_group is None:
                raise ValueError("consumerGroup should not be null")
            if not self._consumer_group_regex.match(consumer_group):
                # Report the offending field and the readable pattern text.
                raise ValueError(
                    f"consumerGroup does not match the regex {self._consumer_group_regex.pattern}")
            self._consumerGroup = consumer_group
            return self

        def set_await_duration(self, await_duration: int):
            """Set the await duration handed to the consumer.

            :return: This builder.
            """
            self._awaitDuration = await_duration
            return self

        def set_subscription_expression(self, subscription_expressions: Dict[str, FilterExpression]):
            """Set the topic -> filter expression mapping.

            :raises ValueError: If the mapping is None or empty.
            :return: This builder.
            """
            if subscription_expressions is None:
                raise ValueError("subscriptionExpressions should not be null")
            if not subscription_expressions:
                raise ValueError("subscriptionExpressions should not be empty")
            self._subscriptionExpressions = subscription_expressions
            return self

        async def build(self):
            """Create the consumer, start it and return it.

            The await duration is passed through unvalidated and is still
            None if ``set_await_duration`` was never called.

            :raises ValueError: If the client config, consumer group or
                subscription expressions have not been set.
            :return: A started :class:`SimpleConsumer`.
            """
            if self._clientConfig is None:
                raise ValueError("clientConfig has not been set yet")
            if self._consumerGroup is None:
                raise ValueError("consumerGroup has not been set yet")
            if not self._subscriptionExpressions:
                raise ValueError("subscriptionExpressions has not been set yet")
            simple_consumer = SimpleConsumer(
                self._clientConfig,
                self._consumerGroup,
                self._awaitDuration,
                self._subscriptionExpressions,
            )
            await simple_consumer.start()
            return simple_consumer
async def test():
    """Demo: receive one batch from ``normal_topic`` and acknowledge each message.

    The endpoint and credentials are placeholders; the consumer is built
    (and thereby started) through ``SimpleConsumer.Builder``.
    """
    provider = SessionCredentialsProvider(SessionCredentials("username", "password"))
    config = ClientConfig(
        endpoints=Endpoints("endpoint"),
        session_credentials_provider=provider,
        ssl_enabled=True,
    )
    topic_resource = Resource(name="normal_topic")
    expressions = {topic_resource.name: FilterExpression("*")}

    builder = SimpleConsumer.Builder()
    builder.set_client_config(config)
    builder.set_consumer_group("yourConsumerGroup")
    builder.set_await_duration(15)
    builder.set_subscription_expression(expressions)
    consumer = await builder.build()
    logger.info(consumer)

    # A single poll: up to 16 messages, invisible duration 15.
    received = await consumer.receive(max_message_num=16, invisible_duration=15)
    logger.info(received)
    for message in received:
        logger.info(message.body)
        logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
        await consumer.ack(message)
        logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
async def test_fifo_message():
    """Demo: receive one batch from ``fifo_topic`` and acknowledge each message.

    The endpoint and credentials are placeholders; the consumer is built
    (and thereby started) through ``SimpleConsumer.Builder``.
    """
    provider = SessionCredentialsProvider(SessionCredentials("username", "password"))
    config = ClientConfig(
        endpoints=Endpoints("endpoint"),
        session_credentials_provider=provider,
        ssl_enabled=True,
    )
    topic_resource = Resource(name="fifo_topic")
    expressions = {topic_resource.name: FilterExpression("*")}

    builder = SimpleConsumer.Builder()
    builder.set_client_config(config)
    builder.set_consumer_group("yourConsumerGroup")
    builder.set_await_duration(15)
    builder.set_subscription_expression(expressions)
    consumer = await builder.build()
    logger.info(consumer)

    # A single poll: up to 16 messages, invisible duration 15.
    received = await consumer.receive(max_message_num=16, invisible_duration=15)
    for message in received:
        logger.info(message.body)
        logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
        await consumer.ack(message)
        logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
async def test_change_invisible_duration():
    """Demo: receive from ``fifo_topic``, re-time each message, then acknowledge it.

    For every received message the invisible duration is first changed to
    3, after which the message is logged and acknowledged. The endpoint and
    credentials are placeholders.
    """
    provider = SessionCredentialsProvider(SessionCredentials("username", "password"))
    config = ClientConfig(
        endpoints=Endpoints("endpoint"),
        session_credentials_provider=provider,
        ssl_enabled=True,
    )
    topic_resource = Resource(name="fifo_topic")
    expressions = {topic_resource.name: FilterExpression("*")}

    builder = SimpleConsumer.Builder()
    builder.set_client_config(config)
    builder.set_consumer_group("yourConsumerGroup")
    builder.set_await_duration(15)
    builder.set_subscription_expression(expressions)
    consumer = await builder.build()
    logger.info(consumer)

    # A single poll: up to 16 messages, invisible duration 15.
    received = await consumer.receive(max_message_num=16, invisible_duration=15)
    for message in received:
        # NOTE(review): the ack below reuses the message as received before
        # this call - confirm its receipt handle is still valid afterwards.
        await consumer.change_invisible_duration(message, 3)
        logger.info(message.body)
        logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
        await consumer.ack(message)
        logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")
async def test_subscribe_unsubscribe():
    """Demo: consume from ``normal_topic``, then switch the subscription to ``fifo_topic``.

    One batch is received and acknowledged, ``normal_topic`` is
    unsubscribed, ``fifo_topic`` is subscribed, and a second batch is
    received and acknowledged. The endpoint and credentials are placeholders.
    """
    provider = SessionCredentialsProvider(SessionCredentials("username", "password"))
    config = ClientConfig(
        endpoints=Endpoints("endpoint"),
        session_credentials_provider=provider,
        ssl_enabled=True,
    )
    topic_resource = Resource(name="normal_topic")
    expressions = {topic_resource.name: FilterExpression("*")}

    builder = SimpleConsumer.Builder()
    builder.set_client_config(config)
    builder.set_consumer_group("yourConsumerGroup")
    builder.set_await_duration(15)
    builder.set_subscription_expression(expressions)
    consumer = await builder.build()
    logger.info(consumer)

    async def poll_and_ack():
        """Receive one batch (up to 16 messages), log it and ack every message."""
        received = await consumer.receive(max_message_num=16, invisible_duration=15)
        logger.info(received)
        for message in received:
            logger.info(message.body)
            logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}")
            await consumer.ack(message)
            logger.info(f"Message is acknowledged successfully, message-id={message.message_id}")

    await poll_and_ack()
    # Swap the subscription, then poll again against the new topic.
    consumer.unsubscribe('normal_topic')
    await consumer.subscribe('fifo_topic', FilterExpression("*"))
    await poll_and_ack()
if __name__ == "__main__":
    # Script entry point: run the subscribe/unsubscribe demo to completion
    # on a fresh event loop.
    entry_point = test_subscribe_unsubscribe
    asyncio.run(entry_point())