# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 |  | 
 | import asyncio | 
 | import threading | 
 | import time | 
 | from concurrent.futures import Future | 
 |  | 
 | from grpc import ChannelConnectivity | 
 | from rocketmq.grpc_protocol import (AckMessageRequest, | 
 |                                     ChangeInvisibleDurationRequest, | 
 |                                     EndTransactionRequest, HeartbeatRequest, | 
 |                                     NotifyClientTerminationRequest, | 
 |                                     QueryRouteRequest, ReceiveMessageRequest, | 
 |                                     SendMessageRequest, TelemetryCommand) | 
 | from rocketmq.v5.client.connection import RpcChannel, RpcEndpoints | 
 | from rocketmq.v5.log import logger | 
 | from rocketmq.v5.util import ConcurrentMap | 
 |  | 
 |  | 
 | class RpcClient: | 
 |     _instance_lock = threading.Lock() | 
 |     _channel_lock = threading.Lock() | 
 |     _io_loop = None  # event loop for all async io | 
 |     _io_loop_thread = None  # thread for io loop | 
 |     RPC_CLIENT_MAX_IDLE_SECONDS = 60 * 30 | 
 |  | 
 |     def __init__(self, tls_enable=False): | 
 |         with RpcClient._instance_lock: | 
 |             # start an event loop for async io | 
 |             if RpcClient._io_loop is None: | 
 |                 initialized_event = threading.Event() | 
 |                 RpcClient._io_loop_thread = threading.Thread( | 
 |                     target=RpcClient.__init_io_loop, | 
 |                     args=(initialized_event,), | 
 |                     name="channel_io_loop_thread", | 
 |                 ) | 
 |                 RpcClient._io_loop_thread.daemon = True | 
 |                 RpcClient._io_loop_thread.start() | 
 |                 # waiting for thread start success | 
 |                 initialized_event.wait() | 
 |         self.channels = ConcurrentMap() | 
 |         self.__clean_idle_channel_scheduler = None | 
 |         self.__clean_idle_channel_scheduler_threading_event = None | 
 |         self.__enable_retrieve_channel = True | 
 |         self.__tls_enable = tls_enable | 
 |  | 
 |     def retrieve_or_create_channel(self, endpoints: RpcEndpoints): | 
 |         if self.__enable_retrieve_channel is False: | 
 |             raise Exception("RpcClient is not running.") | 
 |         try: | 
 |             # get or create a new grpc channel | 
 |             channel = self.__get_channel(endpoints) | 
 |             if channel is not None: | 
 |                 channel.update_time = int(time.time()) | 
 |             else: | 
 |                 with RpcClient._channel_lock: | 
 |                     channel = RpcChannel(endpoints, self.__tls_enable) | 
 |                     channel.create_channel(RpcClient.get_channel_io_loop()) | 
 |                     self.__put_channel(endpoints, channel) | 
 |             return channel | 
 |         except Exception as e: | 
 |             logger.error(f"retrieve or create channel exception: {e}") | 
 |             raise e | 
 |  | 
 |     def clear_idle_rpc_channels(self): | 
 |         items = self.channels.items() | 
 |         now = int(time.time()) | 
 |         idle_endpoints = list() | 
 |         for endpoints, channel in items: | 
 |             if now - channel.update_time > RpcClient.RPC_CLIENT_MAX_IDLE_SECONDS: | 
 |                 idle_endpoints.append(endpoints) | 
 |         if len(idle_endpoints) > 0: | 
 |             with RpcClient._channel_lock: | 
 |                 for endpoints in idle_endpoints: | 
 |                     logger.info(f"remove idle channel {endpoints.__str__()}") | 
 |                     self.__close_rpc_channel(endpoints) | 
 |                     self.channels.remove(endpoints) | 
 |  | 
 |     def stop(self): | 
 |         with RpcClient._channel_lock: | 
 |             self.__enable_retrieve_channel = False | 
 |             all_endpoints = self.channels.keys() | 
 |             for endpoints in all_endpoints: | 
 |                 self.__close_rpc_channel(endpoints) | 
 |  | 
 |     @staticmethod | 
 |     def get_channel_io_loop(): | 
 |         return RpcClient._io_loop | 
 |  | 
 |     """ grpc MessageService """ | 
 |  | 
 |     def query_topic_route_async( | 
 |         self, endpoints: RpcEndpoints, req: QueryRouteRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__query_route_async_0( | 
 |                 endpoints, req, metadata=metadata, timeout=timeout | 
 |             ) | 
 |         ) | 
 |  | 
 |     def send_message_async( | 
 |         self, endpoints: RpcEndpoints, req: SendMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__send_message_0(endpoints, req, metadata=metadata, timeout=timeout) | 
 |         ) | 
 |  | 
 |     def receive_message_async( | 
 |         self, endpoints: RpcEndpoints, req: ReceiveMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__receive_message_0(endpoints, req, metadata=metadata, timeout=timeout) | 
 |         ) | 
 |  | 
 |     def ack_message_async( | 
 |         self, endpoints: RpcEndpoints, req: AckMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__ack_message_0(endpoints, req, metadata=metadata, timeout=timeout) | 
 |         ) | 
 |  | 
 |     def change_invisible_duration_async( | 
 |         self, | 
 |         endpoints: RpcEndpoints, | 
 |         req: ChangeInvisibleDurationRequest, | 
 |         metadata, | 
 |         timeout=3, | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__change_invisible_duration_0( | 
 |                 endpoints, req, metadata=metadata, timeout=timeout | 
 |             ) | 
 |         ) | 
 |  | 
 |     def heartbeat_async( | 
 |         self, endpoints: RpcEndpoints, req: HeartbeatRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__heartbeat_async_0(endpoints, req, metadata=metadata, timeout=timeout) | 
 |         ) | 
 |  | 
 |     def telemetry_write_async(self, endpoints: RpcEndpoints, req: TelemetryCommand): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.retrieve_or_create_channel( | 
 |                 endpoints | 
 |             ).telemetry_stream_stream_call.stream_write(req) | 
 |         ) | 
 |  | 
 |     def end_transaction_async( | 
 |         self, endpoints: RpcEndpoints, req: EndTransactionRequest, metadata, timeout=3 | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__end_transaction_0(endpoints, req, metadata=metadata, timeout=timeout) | 
 |         ) | 
 |  | 
 |     def notify_client_termination( | 
 |         self, | 
 |         endpoints: RpcEndpoints, | 
 |         req: NotifyClientTerminationRequest, | 
 |         metadata, | 
 |         timeout=3, | 
 |     ): | 
 |         return RpcClient.__run_message_service_async( | 
 |             self.__notify_client_termination_0( | 
 |                 endpoints, req, metadata=metadata, timeout=timeout | 
 |             ) | 
 |         ) | 
 |  | 
 |     def end_transaction_for_server_check( | 
 |         self, endpoints: RpcEndpoints, req: EndTransactionRequest, metadata, timeout=3 | 
 |     ): | 
 |         # assert asyncio.get_running_loop() == RpcClient._io_loop | 
 |         try: | 
 |             return self.__end_transaction_0( | 
 |                 endpoints, req, metadata=metadata, timeout=timeout | 
 |             ) | 
 |         except Exception as e: | 
 |             logger.error( | 
 |                 f"end transaction exception, topic:{req.topic.name}, message_id:{req.message_id}, transaction_id:{req.transaction_id}: {e}" | 
 |             ) | 
 |             raise e | 
 |  | 
 |     """ build stream_stream_call """ | 
 |  | 
 |     def telemetry_stream( | 
 |         self, endpoints: RpcEndpoints, client, metadata, rebuild, timeout=3000 | 
 |     ): | 
 |         # assert asyncio.get_running_loop() == RpcClient._io_loop | 
 |         try: | 
 |             channel = self.retrieve_or_create_channel(endpoints) | 
 |             stream = channel.async_stub.Telemetry( | 
 |                 metadata=metadata, timeout=timeout, wait_for_ready=True | 
 |             ) | 
 |             channel.register_telemetry_stream_stream_call(stream, client) | 
 |             asyncio.run_coroutine_threadsafe( | 
 |                 channel.telemetry_stream_stream_call.start_stream_read(), | 
 |                 RpcClient.get_channel_io_loop(), | 
 |             ) | 
 |             logger.info( | 
 |                 f"{client.__str__()} rebuild stream_steam_call to {endpoints.__str__()}." | 
 |                 if rebuild | 
 |                 else f"{client.__str__()} create stream_steam_call to {endpoints.__str__()}." | 
 |             ) | 
 |             return channel | 
 |         except Exception as e: | 
 |             raise e | 
 |  | 
 |     """ MessageService.stub impl """ | 
 |  | 
 |     async def __query_route_async_0( | 
 |         self, endpoints: RpcEndpoints, req: QueryRouteRequest, metadata, timeout=3 | 
 |     ): | 
 |         return await self.retrieve_or_create_channel(endpoints).async_stub.QueryRoute( | 
 |             req, metadata=metadata, timeout=timeout | 
 |         ) | 
 |  | 
 |     async def __send_message_0( | 
 |         self, endpoints: RpcEndpoints, req: SendMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return await self.retrieve_or_create_channel(endpoints).async_stub.SendMessage( | 
 |             req, metadata=metadata, timeout=timeout | 
 |         ) | 
 |  | 
 |     async def __receive_message_0( | 
 |         self, endpoints: RpcEndpoints, req: ReceiveMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return self.retrieve_or_create_channel(endpoints).async_stub.ReceiveMessage( | 
 |             req, metadata=metadata, timeout=timeout | 
 |         ) | 
 |  | 
 |     async def __ack_message_0( | 
 |         self, endpoints: RpcEndpoints, req: AckMessageRequest, metadata, timeout=3 | 
 |     ): | 
 |         return await self.retrieve_or_create_channel(endpoints).async_stub.AckMessage( | 
 |             req, metadata=metadata, timeout=timeout | 
 |         ) | 
 |  | 
 |     async def __heartbeat_async_0( | 
 |         self, endpoints: RpcEndpoints, req: HeartbeatRequest, metadata, timeout=3 | 
 |     ): | 
 |         return await self.retrieve_or_create_channel(endpoints).async_stub.Heartbeat( | 
 |             req, metadata=metadata, timeout=timeout | 
 |         ) | 
 |  | 
 |     async def __change_invisible_duration_0( | 
 |         self, | 
 |         endpoints: RpcEndpoints, | 
 |         req: ChangeInvisibleDurationRequest, | 
 |         metadata, | 
 |         timeout=3, | 
 |     ): | 
 |         return await self.retrieve_or_create_channel( | 
 |             endpoints | 
 |         ).async_stub.ChangeInvisibleDuration(req, metadata=metadata, timeout=timeout) | 
 |  | 
 |     async def __end_transaction_0( | 
 |         self, endpoints: RpcEndpoints, req: EndTransactionRequest, metadata, timeout=3 | 
 |     ): | 
 |         return await self.retrieve_or_create_channel( | 
 |             endpoints | 
 |         ).async_stub.EndTransaction(req, metadata=metadata, timeout=timeout) | 
 |  | 
 |     async def __notify_client_termination_0( | 
 |         self, | 
 |         endpoints: RpcEndpoints, | 
 |         req: NotifyClientTerminationRequest, | 
 |         metadata, | 
 |         timeout=3, | 
 |     ): | 
 |         return await self.retrieve_or_create_channel( | 
 |             endpoints | 
 |         ).async_stub.NotifyClientTermination(req, metadata=metadata, timeout=timeout) | 
 |  | 
 |     async def __create_channel_async(self, endpoints: RpcEndpoints): | 
 |         return self.retrieve_or_create_channel(endpoints) | 
 |  | 
 |     """ private """ | 
 |  | 
 |     def __get_channel(self, endpoints: RpcEndpoints) -> RpcChannel: | 
 |         return self.channels.get(endpoints) | 
 |  | 
 |     def __put_channel(self, endpoints: RpcEndpoints, channel): | 
 |         self.channels.put(endpoints, channel) | 
 |  | 
 |     def __close_rpc_channel(self, endpoints: RpcEndpoints): | 
 |         channel = self.__get_channel(endpoints) | 
 |         if ( | 
 |             channel is not None | 
 |             and channel.channel_state() is not ChannelConnectivity.SHUTDOWN | 
 |         ): | 
 |             try: | 
 |                 channel.close_channel(RpcClient.get_channel_io_loop()) | 
 |                 self.channels.remove(endpoints) | 
 |             except Exception as e: | 
 |                 logger.error(f"close channel {endpoints.__str__()} error: {e}") | 
 |                 raise e | 
 |  | 
 |     @staticmethod | 
 |     def __init_io_loop(initialized_event): | 
 |         # start a thread, set an event loop to the thread. all clients use the same event loop for io operation | 
 |         # loop only init once, running forever until the process ends | 
 |         # RpcClient use RpcClient._io_loop to execute grpc call | 
 |         try: | 
 |             loop = asyncio.new_event_loop() | 
 |             asyncio.set_event_loop(loop) | 
 |             RpcClient._io_loop = loop | 
 |             initialized_event.set() | 
 |             logger.info("start io loop thread success.") | 
 |             loop.run_forever() | 
 |         except Exception as e: | 
 |             logger.error(f"start io loop thread exception: {e}") | 
 |  | 
 |     @staticmethod | 
 |     def __run_message_service_async(func): | 
 |         try: | 
 |             # execute grpc call in RpcClient._io_loop | 
 |             return asyncio.run_coroutine_threadsafe( | 
 |                 func, RpcClient.get_channel_io_loop() | 
 |             ) | 
 |         except Exception as e: | 
 |             future = Future() | 
 |             future.set_exception(e) | 
 |             return future |