# blob: d907632a296eb3ed65f325d453b9c2d3bbdf7de5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import operator
import socket
import time
from datetime import timedelta
from enum import Enum
from functools import reduce
import certifi
from grpc import aio, ssl_channel_credentials
from rocketmq.log import logger
from rocketmq.protocol import service_pb2, service_pb2_grpc
from rocketmq.protocol.definition_pb2 import Address as ProtoAddress
from rocketmq.protocol.definition_pb2 import \
AddressScheme as ProtoAddressScheme
from rocketmq.protocol.definition_pb2 import Endpoints as ProtoEndpoints
class AddressScheme(Enum):
    """Kind of address an endpoint holds, convertible to the protobuf enum."""

    Unspecified = 0
    Ipv4 = 1
    Ipv6 = 2
    DomainName = 3

    @staticmethod
    def to_protobuf(scheme):
        """Return the ``ProtoAddressScheme`` constant that matches *scheme*.

        Anything other than ``Ipv4``, ``Ipv6`` or ``DomainName`` (including
        ``Unspecified``) maps to ``ADDRESS_SCHEME_UNSPECIFIED``.
        """
        # NOTE(review): the protobuf member names below are kept exactly as
        # found; confirm them against the generated definition_pb2 module.
        if scheme == AddressScheme.Ipv4:
            return ProtoAddressScheme.IPV4
        if scheme == AddressScheme.Ipv6:
            return ProtoAddressScheme.IPV6
        if scheme == AddressScheme.DomainName:
            return ProtoAddressScheme.DOMAIN_NAME
        # Fallback for Unspecified and for any unrecognised value.
        return ProtoAddressScheme.ADDRESS_SCHEME_UNSPECIFIED
class Address:
    """A single ``host``/``port`` pair.

    Instances compare and hash by ``(host, port)``, so two lists holding
    equivalent addresses compare equal even when the objects are distinct.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port

    def __eq__(self, other):
        """Return True when *other* is an ``Address`` with the same host and port."""
        if not isinstance(other, Address):
            return NotImplemented
        return (self.host, self.port) == (other.host, other.port)

    def __hash__(self):
        # Defined together with __eq__ so instances stay usable in sets and
        # as dict keys, and equal addresses hash identically.
        return hash((self.host, self.port))

    def __repr__(self):
        return f"Address(host={self.host!r}, port={self.port!r})"

    def to_protobuf(self):
        """Build a ``ProtoAddress`` message carrying this host and port."""
        proto_address = ProtoAddress()
        proto_address.host = self.host
        proto_address.port = self.port
        return proto_address
class Endpoints:
    """A set of addresses sharing one address scheme.

    Can be built either from a ``"host[:port]"`` string (optionally prefixed
    with ``http://`` or ``https://``) or from an object exposing ``addresses``
    (items with ``host``/``port``) and ``scheme``.

    Equality and hashing are based on the scheme and on the ``(host, port)``
    of every address, in order.
    """

    HttpPrefix = "http://"
    HttpsPrefix = "https://"
    DefaultPort = 80
    EndpointSeparator = ":"

    def __init__(self, endpoints):
        """Parse *endpoints* into ``Addresses`` and ``scheme``.

        Raises:
            ValueError: if a string port is not an integer, if the object
                form has no addresses, or if it has several addresses while
                falling into the domain-name scheme.
        """
        self.Addresses = []
        self.scheme = AddressScheme.Unspecified
        self._hash = None

        if isinstance(endpoints, str):
            if endpoints.startswith(self.HttpPrefix):
                endpoints = endpoints[len(self.HttpPrefix):]
            if endpoints.startswith(self.HttpsPrefix):
                endpoints = endpoints[len(self.HttpsPrefix):]

            # Split on the first separator; without one (or with a leading
            # one) the whole string is the host and the default port applies.
            index = endpoints.find(self.EndpointSeparator)
            if index > 0:
                host = endpoints[:index]
                port = int(endpoints[index + 1:])
            else:
                host = endpoints
                port = self.DefaultPort

            # Exactly one address per string endpoint.
            self.Addresses.append(Address(host, port))
            self.scheme = self._detect_scheme(host)
        else:
            self.Addresses = [
                Address(addr.host, addr.port) for addr in endpoints.addresses
            ]
            if not self.Addresses:
                raise ValueError("No available address")

            # NOTE(review): the scheme is matched against the literal strings
            # "Ipv4"/"Ipv6"; any other value (for example an enum integer)
            # takes the domain-name branch. Confirm what callers pass here.
            if endpoints.scheme == "Ipv4":
                self.scheme = AddressScheme.Ipv4
            elif endpoints.scheme == "Ipv6":
                self.scheme = AddressScheme.Ipv6
            else:
                self.scheme = AddressScheme.DomainName
                if len(self.Addresses) > 1:
                    raise ValueError(
                        "Multiple addresses are not allowed in domain scheme"
                    )

        # Both construction paths share one hash formula so that equal
        # endpoints always hash the same.
        self._hash = self._calculate_hash()

    @staticmethod
    def _detect_scheme(host):
        """Classify *host* as an IPv4 literal, an IPv6 literal, or a domain name."""
        candidates = (
            (socket.AF_INET, AddressScheme.Ipv4),
            (socket.AF_INET6, AddressScheme.Ipv6),
        )
        for family, scheme in candidates:
            try:
                socket.inet_pton(family, host)
            except OSError:
                # Not a literal of this family; try the next one.
                continue
            return scheme
        return AddressScheme.DomainName

    def _address_keys(self):
        """Return the ``(host, port)`` of every address, in order."""
        return tuple((address.host, address.port) for address in self.Addresses)

    def _calculate_hash(self):
        """Combine the address keys and the scheme into one hash value."""
        hash_value = 17
        for key in self._address_keys():
            hash_value = (hash_value * 31) + hash(key)
        hash_value = (hash_value * 31) + hash(self.scheme)
        return hash_value

    def __str__(self):
        """Return ``"host:port"`` of the first address, or ``""`` if there is none."""
        if not self.Addresses:
            return ""
        first = self.Addresses[0]
        return str(first.host) + self.EndpointSeparator + str(first.port)

    def grpc_target(self, sslEnabled):
        """Return ``"host:port"`` of the first address for use as a gRPC target.

        *sslEnabled* is accepted for interface compatibility and is not used.

        Raises:
            ValueError: if there is no address.
        """
        if not self.Addresses:
            raise ValueError("No available address")
        first = self.Addresses[0]
        return first.host + ":" + str(first.port)

    def __eq__(self, other):
        """Return True when *other* has the same scheme and address keys."""
        if self is other:
            return True
        if not isinstance(other, Endpoints):
            # Covers None and foreign types; Python then falls back to identity.
            return NotImplemented
        return (
            self.scheme == other.scheme
            and self._address_keys() == other._address_keys()
        )

    def __hash__(self):
        return self._hash

    def to_protobuf(self):
        """Build a ``ProtoEndpoints`` message from the scheme and addresses."""
        proto_endpoints = ProtoEndpoints()
        proto_endpoints.scheme = AddressScheme.to_protobuf(self.scheme)
        proto_endpoints.addresses.extend(
            [address.to_protobuf() for address in self.Addresses]
        )
        return proto_endpoints
class RpcClient:
    """Asyncio gRPC client wrapping one ``MessagingServiceStub``.

    Each instance owns one channel to ``endpoints``. Every RPC helper forwards
    its arguments to the stub method of the matching name and refreshes
    ``activity_nano_time``, so ``idle_duration`` reports the time elapsed
    since the client was last used.
    """

    # Channel arguments applied to every channel this class creates.
    channel_options = [
        ("grpc.max_send_message_length", -1),
        ("grpc.max_receive_message_length", -1),
        ("grpc.connect_timeout_ms", 3000),
    ]

    def __init__(self, endpoints: str, ssl_enabled: bool = True):
        """Open a channel to *endpoints* and create the service stub.

        Args:
            endpoints: gRPC target string handed to the channel factory.
            ssl_enabled: when True a secure channel is created using the
                certifi CA bundle as root certificates; otherwise an
                insecure channel is used.
        """
        self.__endpoints = endpoints
        self.__cert = certifi.contents().encode(encoding="utf-8")
        if ssl_enabled:
            self.__channel = aio.secure_channel(
                endpoints,
                ssl_channel_credentials(root_certificates=self.__cert),
                options=RpcClient.channel_options,
            )
        else:
            self.__channel = aio.insecure_channel(
                endpoints, options=RpcClient.channel_options
            )
        self.__stub = service_pb2_grpc.MessagingServiceStub(self.__channel)
        self.activity_nano_time = time.monotonic_ns()

    def _mark_active(self):
        """Record the current monotonic time as the moment of last use."""
        self.activity_nano_time = time.monotonic_ns()

    def idle_duration(self):
        """Return the time since the last recorded activity as a ``timedelta``."""
        return timedelta(
            microseconds=(time.monotonic_ns() - self.activity_nano_time) / 1000
        )

    async def query_route(
        self, request: service_pb2.QueryRouteRequest, metadata, timeout_seconds: int
    ):
        """Call ``QueryRoute`` and return its response."""
        self._mark_active()
        return await self.__stub.QueryRoute(
            request, timeout=timeout_seconds, metadata=metadata
        )

    async def heartbeat(
        self, request: service_pb2.HeartbeatRequest, metadata, timeout_seconds: int
    ):
        """Call ``Heartbeat`` and return its response."""
        self._mark_active()
        return await self.__stub.Heartbeat(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def send_message(
        self, request: service_pb2.SendMessageRequest, metadata, timeout_seconds: int
    ):
        """Call ``SendMessage`` and return its response."""
        self._mark_active()
        return await self.__stub.SendMessage(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def receive_message(
        self, request: service_pb2.ReceiveMessageRequest, metadata, timeout_seconds: int
    ):
        """Call ``ReceiveMessage`` and collect the streamed messages.

        Returns:
            A list with the ``message`` field of every streamed result that
            has one. If the stream fails part-way, the error is logged and
            the messages collected so far are returned.
        """
        self._mark_active()
        results = self.__stub.ReceiveMessage(
            request, metadata=metadata, timeout=timeout_seconds
        )
        response = []
        try:
            async for result in results:
                if result.HasField("message"):
                    response.append(result.message)
        except Exception as e:
            # Best effort: keep whatever was received before the failure.
            logger.info("An error occurred: %s", e)
        return response

    async def query_assignment(
        self,
        request: service_pb2.QueryAssignmentRequest,
        metadata,
        timeout_seconds: int,
    ):
        """Call ``QueryAssignment`` and return its response."""
        self._mark_active()
        return await self.__stub.QueryAssignment(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def ack_message(
        self, request: service_pb2.AckMessageRequest, metadata, timeout_seconds: int
    ):
        """Call ``AckMessage`` and return its response."""
        self._mark_active()
        return await self.__stub.AckMessage(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def forward_message_to_dead_letter_queue(
        self,
        request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
        metadata,
        timeout_seconds: int,
    ):
        """Call ``ForwardMessageToDeadLetterQueue`` and return its response."""
        self._mark_active()
        return await self.__stub.ForwardMessageToDeadLetterQueue(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def end_transaction(
        self, request: service_pb2.EndTransactionRequest, metadata, timeout_seconds: int
    ):
        """Call ``EndTransaction`` and return its response."""
        self._mark_active()
        return await self.__stub.EndTransaction(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def notify_client_termination(
        self,
        request: service_pb2.NotifyClientTerminationRequest,
        metadata,
        timeout_seconds: int,
    ):
        """Call ``NotifyClientTermination`` and return its response."""
        self._mark_active()
        return await self.__stub.NotifyClientTermination(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def change_invisible_duration(
        self,
        request: service_pb2.ChangeInvisibleDurationRequest,
        metadata,
        timeout_seconds: int,
    ):
        """Call ``ChangeInvisibleDuration`` and return its response."""
        self._mark_active()
        return await self.__stub.ChangeInvisibleDuration(
            request, metadata=metadata, timeout=timeout_seconds
        )

    async def send_requests(self, requests, stream):
        """Send every item of *requests* on *stream*, one after another."""
        self._mark_active()
        for request in requests:
            await stream.send_message(request)

    def telemetry(self, metadata, timeout_seconds: int):
        """Open a ``Telemetry`` call and return the resulting stream object."""
        self._mark_active()
        stream = self.__stub.Telemetry(metadata=metadata, timeout=timeout_seconds)
        return stream
async def test():
    """Manual smoke test: send an empty ``SendMessageRequest`` and log the reply.

    Connects to a hard-coded endpoint, so it needs network access.
    """
    client = RpcClient("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8081")
    request = service_pb2.SendMessageRequest()
    # send_message takes (request, metadata, timeout_seconds); keywords make
    # it explicit that no metadata is attached and that 3 is the timeout.
    response = await client.send_message(request, metadata=None, timeout_seconds=3)
    logger.info(response)
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    asyncio.run(test())