blob: fd6a7f79f22676fe1b3836f78990a2a87ff3f566 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import threading
from protocol import definition_pb2, service_pb2
from protocol.definition_pb2 import Code as ProtoCode
from protocol.service_pb2 import HeartbeatRequest, QueryRouteRequest
from rocketmq.client_config import ClientConfig
from rocketmq.client_id_encoder import ClientIdEncoder
from rocketmq.definition import Resource, TopicRouteData
from rocketmq.log import logger
from rocketmq.rpc_client import Endpoints, RpcClient
from rocketmq.session import Session
from rocketmq.signature import Signature
class ScheduleWithFixedDelay:
    """Repeatedly await an async callable with a fixed pause between runs.

    After an initial ``delay`` the ``action`` is awaited; when it finishes,
    whether it succeeded or raised, the scheduler sleeps for ``period`` and
    runs it again. The loop only ends when the driving task is cancelled.
    """

    def __init__(self, action, delay, period):
        """
        :param action: Zero-argument callable returning an awaitable; it is
            awaited once per tick.
        :param delay: Seconds to sleep before the first invocation.
        :param period: Seconds to sleep between the end of one invocation
            and the start of the next.
        """
        self.action = action
        self.delay = delay
        self.period = period
        #: Task driving :meth:`start`; ``None`` until :meth:`schedule` is called.
        self.task = None

    async def start(self):
        """Run the schedule loop until the surrounding task is cancelled."""
        await asyncio.sleep(self.delay)
        while True:
            try:
                await self.action()
            except Exception as e:
                # A failing tick must not stop the schedule; log and go on.
                logger.error(f"Failed to execute scheduled task, Exception: {str(e)}")
            # The pause lives outside of try/finally on purpose: a
            # cancellation raised while the action is running then
            # propagates immediately instead of first sleeping a full period.
            await asyncio.sleep(self.period)

    def schedule(self):
        """Start the loop as a task on the currently running event loop.

        :raises RuntimeError: If no event loop is running in this thread.
        """
        # Bind to the loop that is already running; creating and installing
        # a fresh loop here would leak it and replace the thread's current
        # loop without the task ever running on it.
        loop = asyncio.get_running_loop()
        self.task = loop.create_task(self.start())

    def cancel(self):
        """Cancel the scheduled task; does nothing if it was never scheduled."""
        if self.task:
            self.task.cancel()
class Client:
    """
    Main client class which handles interaction with the server.

    Subclasses are expected to implement :meth:`get_topics`.
    """

    def __init__(self, client_config: ClientConfig):
        """
        Initialization method for the Client class.

        :param client_config: Client configuration.
        """
        self.client_config = client_config
        self.client_id = ClientIdEncoder.generate()
        self.endpoints = client_config.endpoints
        #: A cache to store topic routes.
        self.topic_route_cache = {}
        #: A table to store session information.
        self.sessions_table = {}
        self.sessionsLock = threading.Lock()
        self.client_manager = ClientManager(self)
        #: A dictionary to store isolated items.
        self.isolated = {}
        #: Background schedulers started by :meth:`start`, kept so that
        #: :meth:`shutdown` can cancel them.
        self._schedulers = []

    def get_topics(self):
        """Return the topics whose routes are fetched on start.

        :raises NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("This method should be implemented by the subclass.")

    async def start(self):
        """
        Start method which initiates fetching of topic routes and schedules heartbeats.
        """
        # get topic route
        logger.debug(f"Begin to start the rocketmq client, client_id={self.client_id}")
        for topic in self.get_topics():
            self.topic_route_cache[topic] = await self.fetch_topic_route(topic)
        # Hold on to the schedulers: the event loop keeps only weak references
        # to tasks, so an unreferenced scheduler (and its task) could be
        # garbage-collected mid-flight, and shutdown() could never cancel it.
        new_schedulers = [
            ScheduleWithFixedDelay(self.heartbeat, 3, 12),
            ScheduleWithFixedDelay(self.sync_settings, 3, 12),
        ]
        self._schedulers.extend(new_schedulers)
        for scheduler in new_schedulers:
            scheduler.schedule()
        logger.debug(f"Start the rocketmq client successfully, client_id={self.client_id}")

    async def shutdown(self):
        """Cancel the background schedulers started by :meth:`start`."""
        logger.debug(f"Begin to shutdown rocketmq client, client_id={self.client_id}")
        for scheduler in self._schedulers:
            scheduler.cancel()
        self._schedulers.clear()
        logger.debug(f"Shutdown the rocketmq client successfully, client_id={self.client_id}")

    async def heartbeat(self):
        """
        Asynchronous method that sends a heartbeat to the server.

        One request is sent to every endpoint returned by
        :meth:`get_total_route_endpoints`. A failure for one endpoint is
        logged and does not prevent the remaining endpoints from being tried;
        no exception escapes this method.
        """
        try:
            endpoints = self.get_total_route_endpoints()
            request = HeartbeatRequest()
            request.client_type = definition_pb2.PRODUCER
            for item in endpoints:
                try:
                    response = await self.client_manager.heartbeat(
                        item, request, self.client_config.request_timeout
                    )
                    code = response.status.code
                    if code == ProtoCode.OK:
                        logger.info(f"Send heartbeat successfully, endpoints={item}, client_id={self.client_id}")
                        if item in self.isolated:
                            self.isolated.pop(item)
                            logger.info(f"Rejoin endpoints which was isolated before, endpoints={item}, "
                                        + f"client_id={self.client_id}")
                        # Move on to the next endpoint; returning here would
                        # leave every remaining endpoint without a heartbeat.
                        continue
                    status_message = response.status.message
                    logger.info(f"Failed to send heartbeat, endpoints={item}, code={code}, "
                                + f"status_message={status_message}, client_id={self.client_id}")
                except Exception as e:
                    logger.error(f"Failed to send heartbeat, endpoints={item}, Exception: {str(e)}")
        except Exception as e:
            logger.error(f"[Bug] unexpected exception raised during heartbeat, client_id={self.client_id}, Exception: {str(e)}")

    def get_total_route_endpoints(self):
        """
        Method that returns all route endpoints.

        :return: The set of broker endpoints referenced by any message queue
            of any cached topic route.
        """
        return {
            mq.broker.endpoints
            for route in self.topic_route_cache.values()
            for mq in route.message_queues
        }

    async def get_route_data(self, topic):
        """
        Asynchronous method that fetches route data for a given topic.

        The cached route is returned when present; otherwise the route is
        fetched through :meth:`fetch_topic_route`.

        :param topic: The topic to fetch route data for.
        """
        if topic in self.topic_route_cache:
            return self.topic_route_cache[topic]
        return await self.fetch_topic_route(topic=topic)

    def get_client_config(self):
        """
        Method to return client configuration.
        """
        return self.client_config

    async def sync_settings(self):
        """Sync settings through the session of every known route endpoint."""
        for endpoints in self.get_total_route_endpoints():
            _, session = await self.get_session(endpoints)
            await session.sync_settings(True)
            logger.info(f"Sync settings to remote, endpoints={endpoints}")

    def stats(self):
        # TODO: stats implement
        pass

    async def notify_client_termination(self):
        """No-op placeholder."""
        pass

    async def on_recover_orphaned_transaction_command(self, endpoints, command):
        """No-op placeholder; the command is ignored."""
        pass

    async def on_verify_message_command(self, endpoints, command):
        """Log that an unexpected verify-message command is being ignored."""
        # Logger.warn is a deprecated alias; warning() is the supported name.
        logger.warning(f"Ignore verify message command from remote, which is not expected, clientId={self.client_id}, "
                       + f"endpoints={endpoints}, command={command}")

    async def on_print_thread_stack_trace_command(self, endpoints, command):
        """No-op placeholder; the command is ignored."""
        pass

    async def on_settings_command(self, endpoints, settings):
        """No-op placeholder; the settings are ignored."""
        pass

    async def on_topic_route_data_fetched(self, topic, topic_route_data):
        """
        Asynchronous method that handles the process once the topic route data is fetched.

        A session is obtained for every endpoint that no cached route
        referenced before; settings are synced on each newly created session.
        The route is then stored in the cache.

        :param topic: The topic for which the route data is fetched.
        :param topic_route_data: The fetched topic route data.
        """
        route_endpoints = {mq.broker.endpoints for mq in topic_route_data.message_queues}
        # Only endpoints absent from every cached route are treated as new.
        new_endpoints = route_endpoints.difference(self.get_total_route_endpoints())
        for endpoints in new_endpoints:
            created, session = await self.get_session(endpoints)
            if not created:
                continue
            logger.info(f"Begin to establish session for endpoints={endpoints}, client_id={self.client_id}")
            await session.sync_settings(True)
            logger.info(f"Establish session for endpoints={endpoints} successfully, client_id={self.client_id}")
        self.topic_route_cache[topic] = topic_route_data

    async def fetch_topic_route0(self, topic):
        """
        Asynchronous method that fetches the topic route.

        :param topic: The topic to fetch the route for.
        :return: A ``TopicRouteData`` built from the response message queues.
        :raises Exception: Any error raised while building or sending the
            request is logged and re-raised.
        """
        try:
            req = QueryRouteRequest()
            req.topic.name = topic
            address = req.endpoints.addresses.add()
            address.host = self.endpoints.Addresses[0].host
            address.port = self.endpoints.Addresses[0].port
            req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme)
            response = await self.client_manager.query_route(self.endpoints, req, 10)
            code = response.status.code
            if code != ProtoCode.OK:
                # NOTE(review): a non-OK status is only logged; the queues of
                # the response are still wrapped and returned below. Confirm
                # whether callers rely on this before turning it into an error.
                logger.error(f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}, code={code}, "
                             + f"statusMessage={response.status.message}")
            return TopicRouteData(response.message_queues)
        except Exception as e:
            # The text goes first: passing the exception as the first
            # positional argument makes it the format string and the text a
            # stray format argument.
            logger.error(f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}, "
                         + f"Exception: {str(e)}")
            raise

    async def fetch_topic_route(self, topic):
        """
        Asynchronous method that fetches the topic route and updates the data.

        :param topic: The topic to fetch the route for.
        :return: The fetched topic route data.
        """
        topic_route_data = await self.fetch_topic_route0(topic)
        await self.on_topic_route_data_fetched(topic, topic_route_data)
        logger.info(f"Fetch topic route successfully, client_id={self.client_id}, topic={topic}, topicRouteData={topic_route_data}")
        return topic_route_data

    async def get_session(self, endpoints):
        """
        Asynchronous method that gets the session for a given endpoint.

        :param endpoints: The endpoints to get the session for.
        :return: A ``(created, session)`` tuple; ``created`` is ``True`` only
            when this call had to create the session.
        """
        # A single critical section covers both the lookup and the creation,
        # and the context manager releases the lock on every exit path.
        with self.sessionsLock:
            # Session exists, return in advance.
            if endpoints in self.sessions_table:
                return (False, self.sessions_table[endpoints])
            stream = self.client_manager.telemetry(endpoints, 10000000)
            created = Session(endpoints, stream, self)
            self.sessions_table[endpoints] = created
            return (True, created)

    def get_client_id(self):
        """Return the id generated for this client."""
        return self.client_id
class ClientManager:
    """Manager class for RPC Clients in a thread-safe manner.

    Each instance is created by a specific client and can manage
    multiple RPC clients, one per endpoints value. Every request method
    looks up the RPC client for the target endpoints, signs the call with
    the owning client's configuration and id, and delegates to the RPC
    client method of the same name.
    """

    def __init__(self, client: Client):
        """
        :param client: The client that instantiated this manager.
        """
        #: The client that instantiated this manager.
        self.__client = client
        #: A dictionary that maps endpoints to the corresponding RPC clients.
        self.__rpc_clients = {}
        #: A lock used to ensure thread safety when accessing __rpc_clients.
        self.__rpc_clients_lock = threading.Lock()

    def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
        """Retrieve the RPC client corresponding to the given endpoints.

        If not present, a new RPC client is created and stored in __rpc_clients.

        :param endpoints: The endpoints associated with the RPC client.
        :param ssl_enabled: A flag indicating whether SSL is enabled.
        :return: The RPC client associated with the given endpoints.
        """
        with self.__rpc_clients_lock:
            rpc_client = self.__rpc_clients.get(endpoints)
            if rpc_client:
                return rpc_client
            rpc_client = RpcClient(endpoints.grpc_target(True), ssl_enabled)
            self.__rpc_clients[endpoints] = rpc_client
            return rpc_client

    def __prepare_call(self, endpoints: Endpoints):
        """Resolve what every outgoing call needs.

        :param endpoints: The endpoints the call is addressed to.
        :return: A ``(rpc_client, metadata)`` tuple: the RPC client for the
            endpoints and the signed metadata for the owning client.
        """
        rpc_client = self.__get_rpc_client(
            endpoints, self.__client.client_config.ssl_enabled
        )
        metadata = Signature.sign(self.__client.client_config, self.__client.client_id)
        return rpc_client, metadata

    async def query_route(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryRouteRequest,
        timeout_seconds: int,
    ):
        """Query the routing information.

        :param endpoints: The endpoints to query.
        :param request: The request containing the details of the query.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the query.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_route(request, metadata, timeout_seconds)

    async def heartbeat(
        self,
        endpoints: Endpoints,
        request: service_pb2.HeartbeatRequest,
        timeout_seconds: int,
    ):
        """Send a heartbeat to the server to indicate that the client is still alive.

        :param endpoints: The endpoints to send the heartbeat to.
        :param request: The request containing the details of the heartbeat.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the heartbeat.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.heartbeat(request, metadata, timeout_seconds)

    async def send_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.SendMessageRequest,
        timeout_seconds: int,
    ):
        """Send a message to the server.

        :param endpoints: The endpoints to send the message to.
        :param request: The request containing the details of the message.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the message sending operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.send_message(request, metadata, timeout_seconds)

    async def query_assignment(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryAssignmentRequest,
        timeout_seconds: int,
    ):
        """Query the assignment information.

        :param endpoints: The endpoints to query.
        :param request: The request containing the details of the query.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the query.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_assignment(request, metadata, timeout_seconds)

    async def ack_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.AckMessageRequest,
        timeout_seconds: int,
    ):
        """Send an acknowledgment for a message to the server.

        :param endpoints: The endpoints to send the acknowledgment to.
        :param request: The request containing the details of the acknowledgment.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the acknowledgment.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.ack_message(request, metadata, timeout_seconds)

    async def forward_message_to_dead_letter_queue(
        self,
        endpoints: Endpoints,
        request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
        timeout_seconds: int,
    ):
        """Forward a message to the dead letter queue.

        :param endpoints: The endpoints to send the request to.
        :param request: The request containing the details of the message to forward.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the forward operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.forward_message_to_dead_letter_queue(
            request, metadata, timeout_seconds
        )

    async def end_transaction(
        self,
        endpoints: Endpoints,
        request: service_pb2.EndTransactionRequest,
        timeout_seconds: int,
    ):
        """Ends a transaction.

        :param endpoints: The endpoints to send the request to.
        :param request: The request to end the transaction.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the end transaction operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.end_transaction(request, metadata, timeout_seconds)

    async def notify_client_termination(
        self,
        endpoints: Endpoints,
        request: service_pb2.NotifyClientTerminationRequest,
        timeout_seconds: int,
    ):
        """Notify server about client termination.

        :param endpoints: The endpoints to send the notification to.
        :param request: The request containing the details of the termination.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the notification operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.notify_client_termination(
            request, metadata, timeout_seconds
        )

    async def change_invisible_duration(
        self,
        endpoints: Endpoints,
        request: service_pb2.ChangeInvisibleDurationRequest,
        timeout_seconds: int,
    ):
        """Change the invisible duration of a message.

        :param endpoints: The endpoints to send the request to.
        :param request: The request containing the new invisible duration.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the change operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.change_invisible_duration(
            request, metadata, timeout_seconds
        )

    async def receive_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.ReceiveMessageRequest,
        timeout_seconds: int,
    ):
        """Receive messages from the server.

        :param endpoints: The endpoints to send the request to.
        :param request: The request containing the details of the receive operation.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: The result of the receive operation.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.receive_message(request, metadata, timeout_seconds)

    def telemetry(
        self,
        endpoints: Endpoints,
        timeout_seconds: int,
    ):
        """Open telemetry with the given endpoints.

        Unlike the other request methods this one is synchronous and takes
        no request object.

        :param endpoints: The endpoints to send the request to.
        :param timeout_seconds: The maximum time to wait for a response.
        :return: Whatever the RPC client's ``telemetry`` call returns.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return rpc_client.telemetry(metadata, timeout_seconds)