| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| import threading |
| from typing import Optional |
| |
| from dubbo.bootstrap import Dubbo |
| from dubbo.classes import MethodDescriptor |
| from dubbo.configs import ReferenceConfig |
| from dubbo.constants import common_constants |
| from dubbo.extension import extensionLoader |
| from dubbo.protocol import Invoker, Protocol |
| from dubbo.proxy import RpcCallable, RpcCallableFactory |
| from dubbo.proxy.callables import DefaultRpcCallableFactory |
| from dubbo.registry.protocol import RegistryProtocol |
| from dubbo.types import ( |
| DeserializingFunction, |
| SerializingFunction, |
| RpcTypes, |
| ) |
| from dubbo.url import URL |
| |
# Names exported by a star-import of this module: only the Client class.
__all__ = ["Client"]
| |
| |
class Client:
    """
    Consumer-side entry point that turns a service reference into RPC callables.

    On construction the client instantiates the protocol extension named by
    the reference, wraps it in a ``RegistryProtocol`` when the ``Dubbo``
    instance carries a registry configuration, and obtains an invoker by
    calling ``refer`` on the resulting protocol. ``unary``, ``client_stream``,
    ``server_stream`` and ``bi_stream`` then build one callable per remote
    method on top of that invoker.
    """

    def __init__(self, reference: ReferenceConfig, dubbo: Optional[Dubbo] = None):
        """
        Create the client and build its invoker immediately.

        :param reference: The reference configuration; it supplies the
            protocol name and the reference URL.
        :param dubbo: The Dubbo instance whose registry configuration is
            used. A fresh ``Dubbo()`` is created when this is omitted.
        """
        self._initialized = False
        self._global_lock = threading.RLock()

        self._dubbo = dubbo or Dubbo()
        self._reference = reference

        # Populated by _initialize()
        self._url: Optional[URL] = None
        self._protocol: Optional[Protocol] = None
        self._invoker: Optional[Invoker] = None

        self._callable_factory: RpcCallableFactory = DefaultRpcCallableFactory()

        # initialize the invoker
        self._initialize()

    def _initialize(self):
        """
        Initialize the invoker.

        The work is done while holding the client's lock and is skipped when
        a previous call has already completed, so it runs at most once.
        """
        with self._global_lock:
            if self._initialized:
                return

            # Instantiate the protocol extension registered under the name
            # given by the reference.
            protocol = extensionLoader.get_extension(
                Protocol, self._reference.protocol
            )()

            # Read the registry configuration once; the same value decides
            # both which protocol is used and how the URL is built.
            registry_config = self._dubbo.registry_config

            self._protocol = (
                RegistryProtocol(registry_config, protocol)
                if registry_config
                else protocol
            )

            # build url
            reference_url = self._reference.to_url()
            if registry_config:
                # Start from a copy of the registry URL, then overlay the
                # reference's path and parameters on top of it.
                self._url = registry_config.to_url().copy()
                self._url.path = reference_url.path
                for key, value in reference_url.parameters.items():
                    self._url.parameters[key] = value
            else:
                self._url = reference_url

            # create invoker
            self._invoker = self._protocol.refer(self._url)

            self._initialized = True

    def unary(
        self,
        method_name: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> RpcCallable:
        """
        Create a callable for a method invoked with the ``UNARY`` RPC type.

        :param method_name: The name of the remote method.
        :param request_serializer: Optional function used to serialize the request.
        :param response_deserializer: Optional function used to deserialize the response.
        :return: The callable for the method.
        :rtype: RpcCallable
        """
        return self._method_callable(
            method_name,
            RpcTypes.UNARY.value,
            request_serializer,
            response_deserializer,
        )

    def client_stream(
        self,
        method_name: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> RpcCallable:
        """
        Create a callable for a method invoked with the ``CLIENT_STREAM`` RPC type.

        :param method_name: The name of the remote method.
        :param request_serializer: Optional function used to serialize the request.
        :param response_deserializer: Optional function used to deserialize the response.
        :return: The callable for the method.
        :rtype: RpcCallable
        """
        return self._method_callable(
            method_name,
            RpcTypes.CLIENT_STREAM.value,
            request_serializer,
            response_deserializer,
        )

    def server_stream(
        self,
        method_name: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> RpcCallable:
        """
        Create a callable for a method invoked with the ``SERVER_STREAM`` RPC type.

        :param method_name: The name of the remote method.
        :param request_serializer: Optional function used to serialize the request.
        :param response_deserializer: Optional function used to deserialize the response.
        :return: The callable for the method.
        :rtype: RpcCallable
        """
        return self._method_callable(
            method_name,
            RpcTypes.SERVER_STREAM.value,
            request_serializer,
            response_deserializer,
        )

    def bi_stream(
        self,
        method_name: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> RpcCallable:
        """
        Create a callable for a method invoked with the ``BI_STREAM`` RPC type.

        :param method_name: The name of the remote method.
        :param request_serializer: Optional function used to serialize the request.
        :param response_deserializer: Optional function used to deserialize the response.
        :return: The callable for the method.
        :rtype: RpcCallable
        """
        return self._method_callable(
            method_name,
            RpcTypes.BI_STREAM.value,
            request_serializer,
            response_deserializer,
        )

    def _method_callable(
        self,
        method_name: str,
        rpc_type,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> RpcCallable:
        """
        Build the method descriptor shared by all RPC types and wrap it in a callable.

        The request serializer fills the first slot of ``arg_serialization``
        and the response deserializer fills the second slot of
        ``return_serialization``; the remaining slots are ``None``.

        :param method_name: The name of the remote method.
        :param rpc_type: The ``value`` of the ``RpcTypes`` member to use.
        :param request_serializer: Optional function used to serialize the request.
        :param response_deserializer: Optional function used to deserialize the response.
        :return: The callable for the method.
        :rtype: RpcCallable
        """
        descriptor = MethodDescriptor(
            method_name=method_name,
            arg_serialization=(request_serializer, None),
            return_serialization=(None, response_deserializer),
            rpc_type=rpc_type,
        )
        return self._callable(descriptor)

    def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable:
        """
        Generate a proxy for the given method
        :param method_descriptor: The method descriptor.
        :return: The proxy.
        :rtype: RpcCallable
        """
        # Clone the invoker's URL; the method-specific entries below are
        # written to the clone.
        url = self._invoker.get_url().copy()
        url.parameters[common_constants.METHOD_KEY] = (
            method_descriptor.get_method_name()
        )
        # set method descriptor
        url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor

        # create proxy
        return self._callable_factory.get_callable(self._invoker, url)