blob: 4ccda2b9f4330c6770e34a14c85d4fa65382af94 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Set
from protocol import service_pb2
from protocol.service_pb2 import QueryRouteRequest
from rocketmq.client_config import ClientConfig
from rocketmq.client_id_encoder import ClientIdEncoder
from rocketmq.definition import TopicRouteData
from rocketmq.rpc_client import Endpoints, RpcClient
from rocketmq.session import Session
from rocketmq.signature import Signature
class Client:
    """Client core that caches topic routes and keeps one session per endpoints.

    ``start_up`` fetches the route of every configured topic. Whenever a
    fetched route mentions endpoints that no cached route mentions yet, a
    session is obtained for them and, if it was newly created, its settings
    are synced.
    """

    def __init__(self, client_config: ClientConfig, topics: Set[str]):
        """Store the configuration and create empty route/session tables.

        Args:
            client_config: Client configuration; its ``endpoints`` attribute
                is used as the target of route queries.
            topics: Names of the topics whose routes ``start_up`` fetches.
        """
        self.client_config = client_config
        self.client_id = ClientIdEncoder.generate()
        self.endpoints = client_config.endpoints
        self.topics = topics
        # topic name -> TopicRouteData, filled by OnTopicRouteDataFetched.
        self.topic_route_cache = {}
        # endpoints -> Session, guarded by sessionsLock.
        self.sessions_table = {}
        self.sessionsLock = threading.Lock()
        self.client_manager = ClientManager(self)

    async def start_up(self):
        """Fetch and cache the route of every topic in ``self.topics``."""
        for topic in self.topics:
            self.topic_route_cache[topic] = await self.fetch_topic_route(topic)

    def GetTotalRouteEndpoints(self):
        """Return the set of broker endpoints used by any cached topic route."""
        endpoints = set()
        for route in self.topic_route_cache.values():
            endpoints.update(mq.broker.endpoints for mq in route.message_queues)
        return endpoints

    def get_client_config(self):
        """Return the configuration this client was constructed with."""
        return self.client_config

    async def OnTopicRouteDataFetched(self, topic, topicRouteData):
        """Open sessions for newly seen endpoints, then cache the route.

        Args:
            topic: Name of the topic the route belongs to.
            topicRouteData: Route whose ``message_queues`` each expose
                ``broker.endpoints``.
        """
        route_endpoints = {mq.broker.endpoints for mq in topicRouteData.message_queues}
        # Compare against the cache *before* this route is stored, so only
        # endpoints that no cached route references count as new.
        new_endpoints = route_endpoints - self.GetTotalRouteEndpoints()

        for endpoints in new_endpoints:
            created, session = await self.GetSession(endpoints)
            if created:
                # Only sessions created by this call get their settings synced.
                await session.sync_settings(True)

        self.topic_route_cache[topic] = topicRouteData
        # NOTE: the OnTopicRouteDataUpdated0(topic, topicRouteData) hook is
        # currently not invoked from here.

    async def fetch_topic_route0(self, topic):
        """Query the route of ``topic`` and wrap it in a ``TopicRouteData``.

        The request carries the topic name, the first address of
        ``self.endpoints`` and the endpoints scheme. It is sent to
        ``self.endpoints`` with a 10 second timeout.
        """
        req = QueryRouteRequest()
        req.topic.name = topic
        # Only the first configured address is put into the request.
        first_address = self.endpoints.Addresses[0]
        address = req.endpoints.addresses.add()
        address.host = first_address.host
        address.port = first_address.port
        req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme)
        response = await self.client_manager.query_route(self.endpoints, req, 10)
        return TopicRouteData(response.message_queues)

    async def fetch_topic_route(self, topic):
        """Fetch the route of ``topic``, process it and return it.

        Processing is delegated to ``OnTopicRouteDataFetched``, which also
        stores the route in ``topic_route_cache``.
        """
        topic_route_data = await self.fetch_topic_route0(topic)
        await self.OnTopicRouteDataFetched(topic, topic_route_data)
        return topic_route_data

    async def GetSession(self, endpoints):
        """Return the session for ``endpoints``, creating it when missing.

        Returns:
            A ``(created, session)`` tuple. ``created`` is ``True`` only when
            this call opened a telemetry stream and registered a new
            ``Session``; otherwise the already registered session is returned.
        """
        # A single critical section covers both the lookup and the insert;
        # ``with`` releases the lock even if opening the stream or building
        # the session raises.
        with self.sessionsLock:
            if endpoints in self.sessions_table:
                return (False, self.sessions_table[endpoints])
            stream = self.client_manager.telemetry(endpoints, 10)
            session = Session(endpoints, stream, self)
            self.sessions_table[endpoints] = session
            return (True, session)
class ClientManager:
    """Dispatch signed RPCs for a ``Client``, reusing one ``RpcClient`` per endpoints.

    Every public method resolves the ``RpcClient`` for the given endpoints,
    signs the call with the owning client's configuration and id, and then
    forwards ``request`` / ``timeout_seconds`` to the same-named method of the
    ``RpcClient``.
    """

    def __init__(self, client: Client):
        """Bind the manager to ``client`` and start with an empty RPC client cache."""
        self.__client = client
        # endpoints -> RpcClient, guarded by __rpc_clients_lock.
        self.__rpc_clients = {}
        self.__rpc_clients_lock = threading.Lock()

    def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
        """Return the cached ``RpcClient`` for ``endpoints``, creating it on first use."""
        with self.__rpc_clients_lock:
            rpc_client = self.__rpc_clients.get(endpoints)
            # Compare against None rather than relying on truthiness, so a
            # cached client is never rebuilt just because it evaluates falsy.
            if rpc_client is None:
                rpc_client = RpcClient(endpoints.grpc_target(True), ssl_enabled)
                self.__rpc_clients[endpoints] = rpc_client
            return rpc_client

    def __prepare_call(self, endpoints: Endpoints):
        """Return ``(rpc_client, metadata)`` for one call to ``endpoints``.

        The metadata is produced by ``Signature.sign`` on every invocation.
        """
        config = self.__client.client_config
        rpc_client = self.__get_rpc_client(endpoints, config.ssl_enabled)
        metadata = Signature.sign(config, self.__client.client_id)
        return rpc_client, metadata

    async def query_route(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryRouteRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``query_route`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_route(request, metadata, timeout_seconds)

    async def heartbeat(
        self,
        endpoints: Endpoints,
        request: service_pb2.HeartbeatRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``heartbeat`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.heartbeat(request, metadata, timeout_seconds)

    async def send_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.SendMessageRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``send_message`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.send_message(request, metadata, timeout_seconds)

    async def query_assignment(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryAssignmentRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``query_assignment`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_assignment(request, metadata, timeout_seconds)

    async def ack_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.AckMessageRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``ack_message`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.ack_message(request, metadata, timeout_seconds)

    async def forward_message_to_dead_letter_queue(
        self,
        endpoints: Endpoints,
        request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``forward_message_to_dead_letter_queue`` call and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.forward_message_to_dead_letter_queue(
            request, metadata, timeout_seconds
        )

    async def end_transaction(
        self,
        endpoints: Endpoints,
        request: service_pb2.EndTransactionRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``end_transaction`` call to ``endpoints`` and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.end_transaction(request, metadata, timeout_seconds)

    async def notify_client_termination(
        self,
        endpoints: Endpoints,
        request: service_pb2.NotifyClientTerminationRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``notify_client_termination`` call and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.notify_client_termination(
            request, metadata, timeout_seconds
        )

    async def change_invisible_duration(
        self,
        endpoints: Endpoints,
        request: service_pb2.ChangeInvisibleDurationRequest,
        timeout_seconds: int,
    ):
        """Send a signed ``change_invisible_duration`` call and return its result."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.change_invisible_duration(
            request, metadata, timeout_seconds
        )

    def telemetry(
        self,
        endpoints: Endpoints,
        timeout_seconds: int,
    ):
        """Return ``rpc_client.telemetry(...)`` for ``endpoints``, signed like the other calls.

        Unlike the other public methods this one is synchronous and takes no
        request object.
        """
        rpc_client, metadata = self.__prepare_call(endpoints)
        return rpc_client.telemetry(metadata, timeout_seconds)