| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import asyncio |
| from collections import OrderedDict |
| from typing import Union |
| |
| from pyignite.constants import PROTOCOLS, PROTOCOL_BYTE_ORDER |
| from pyignite.exceptions import HandshakeError, SocketError, connection_errors, AuthenticationError |
| from .bitmask_feature import BitmaskFeature |
| from .connection import BaseConnection |
| |
| from .handshake import HandshakeRequest, HandshakeResponse |
| from .protocol_context import ProtocolContext |
| from .ssl import create_ssl_context |
| from ..stream.binary_stream import BinaryStreamBase |
| |
| |
| class BaseProtocol(asyncio.Protocol): |
| def __init__(self, conn, handshake_fut): |
| super().__init__() |
| self._buffer = bytearray() |
| self._conn = conn |
| self._handshake_fut = handshake_fut |
| |
| def connection_lost(self, exc): |
| self.__process_connection_error(exc if exc else SocketError("Connection closed")) |
| |
| def connection_made(self, transport: asyncio.WriteTransport) -> None: |
| try: |
| self.__send_handshake(transport, self._conn) |
| except Exception as e: |
| self._handshake_fut.set_exception(e) |
| |
| def data_received(self, data: bytes) -> None: |
| self._buffer += data |
| while self.__has_full_response(): |
| packet_sz = self.__packet_size(self._buffer) |
| packet = self._buffer[0:packet_sz] |
| if not self._handshake_fut.done(): |
| hs_response = self.__parse_handshake(packet, self._conn.client) |
| self._handshake_fut.set_result(hs_response) |
| else: |
| self._conn.process_message(packet) |
| self._buffer = self._buffer[packet_sz:len(self._buffer)] |
| |
| def __has_full_response(self): |
| if len(self._buffer) > 4: |
| response_len = int.from_bytes(self._buffer[0:4], byteorder=PROTOCOL_BYTE_ORDER, signed=True) |
| return response_len + 4 <= len(self._buffer) |
| |
| @staticmethod |
| def __packet_size(buffer): |
| return int.from_bytes(buffer[0:4], byteorder=PROTOCOL_BYTE_ORDER, signed=True) + 4 |
| |
| def __process_connection_error(self, exc): |
| connected = self._handshake_fut.done() |
| if not connected: |
| self._handshake_fut.set_exception(exc) |
| self._conn.process_connection_lost(exc, connected) |
| |
| @staticmethod |
| def __send_handshake(transport, conn): |
| hs_request = HandshakeRequest(conn.protocol_context, conn.username, conn.password) |
| with BinaryStreamBase(client=conn.client) as stream: |
| hs_request.from_python(stream) |
| transport.write(stream.getvalue()) |
| |
| @staticmethod |
| def __parse_handshake(data, client): |
| with BinaryStreamBase(client, data) as stream: |
| return HandshakeResponse.parse(stream, client.protocol_context) |
| |
| |
| class AioConnection(BaseConnection): |
| """ |
| Asyncio connection to Ignite node. It serves multiple purposes: |
| |
| * wrapper of asyncio streams. See also https://docs.python.org/3/library/asyncio-stream.html |
| * encapsulates handshake and reconnection. |
| """ |
| |
| def __init__(self, client: 'AioClient', host: str, port: int, username: str = None, password: str = None, |
| **ssl_params): |
| """ |
| Initialize connection. |
| |
| For the use of the SSL-related parameters see |
| https://docs.python.org/3/library/ssl.html#ssl-certificates. |
| |
| :param client: Ignite client object, |
| :param host: Ignite server node's host name or IP, |
| :param port: Ignite server node's port number, |
| :param use_ssl: (optional) set to True if Ignite server uses SSL |
| on its binary connector. Defaults to use SSL when username |
| and password has been supplied, not to use SSL otherwise, |
| :param ssl_version: (optional) SSL version constant from standard |
| `ssl` module. Defaults to TLS v1.1, as in Ignite 2.5, |
| :param ssl_ciphers: (optional) ciphers to use. If not provided, |
| `ssl` default ciphers are used, |
| :param ssl_cert_reqs: (optional) determines how the remote side |
| certificate is treated: |
| |
| * `ssl.CERT_NONE` − remote certificate is ignored (default), |
| * `ssl.CERT_OPTIONAL` − remote certificate will be validated, |
| if provided, |
| * `ssl.CERT_REQUIRED` − valid remote certificate is required, |
| |
| :param ssl_keyfile: (optional) a path to SSL key file to identify |
| local (client) party, |
| :param ssl_keyfile_password: (optional) password for SSL key file, |
| can be provided when key file is encrypted to prevent OpenSSL |
| password prompt, |
| :param ssl_certfile: (optional) a path to ssl certificate file |
| to identify local (client) party, |
| :param ssl_ca_certfile: (optional) a path to a trusted certificate |
| or a certificate chain. Required to check the validity of the remote |
| (server-side) certificate, |
| :param username: (optional) user name to authenticate to Ignite |
| cluster, |
| :param password: (optional) password to authenticate to Ignite cluster. |
| """ |
| super().__init__(client, host, port, username, password, **ssl_params) |
| self._pending_reqs = {} |
| self._transport = None |
| self._loop = asyncio.get_event_loop() |
| self._closed = False |
| self._transport_closed_fut = None |
| |
| @property |
| def closed(self) -> bool: |
| """ Tells if socket is closed. """ |
| return self._closed or not self._transport or self._transport.is_closing() |
| |
| async def connect(self): |
| """ |
| Connect to the given server node with protocol version fallback. |
| """ |
| if self.alive: |
| return |
| self._closed = False |
| await self._connect() |
| |
| async def _connect(self): |
| detecting_protocol = False |
| |
| # choose highest version first |
| if self.client.protocol_context is None: |
| detecting_protocol = True |
| self.client.protocol_context = ProtocolContext(max(PROTOCOLS), BitmaskFeature.all_supported()) |
| |
| try: |
| self._on_handshake_start() |
| result = await self._connect_version() |
| except HandshakeError as e: |
| if e.expected_version in PROTOCOLS: |
| self.client.protocol_context.version = e.expected_version |
| result = await self._connect_version() |
| else: |
| self._on_handshake_fail(e) |
| raise e |
| except AuthenticationError as e: |
| self._on_handshake_fail(e) |
| raise e |
| except Exception as e: |
| # restore undefined protocol version |
| if detecting_protocol: |
| self.client.protocol_context = None |
| self._on_handshake_fail(e) |
| raise e |
| |
| self._on_handshake_success(result) |
| |
| def process_connection_lost(self, err, reconnect=False): |
| self.failed = True |
| for _, fut in self._pending_reqs.items(): |
| fut.set_exception(err) |
| self._pending_reqs.clear() |
| |
| if self._transport_closed_fut and not self._transport_closed_fut.done(): |
| self._transport_closed_fut.set_result(None) |
| |
| if reconnect and not self._closed: |
| self._on_connection_lost(err) |
| self._loop.create_task(self._reconnect()) |
| |
| def process_message(self, data): |
| req_id = int.from_bytes(data[4:12], byteorder=PROTOCOL_BYTE_ORDER, signed=True) |
| if req_id in self._pending_reqs: |
| self._pending_reqs[req_id].set_result(data) |
| del self._pending_reqs[req_id] |
| |
| async def _connect_version(self) -> Union[dict, OrderedDict]: |
| """ |
| Connect to the given server node using protocol version |
| defined on client. |
| """ |
| |
| ssl_context = create_ssl_context(self.ssl_params) |
| handshake_fut = self._loop.create_future() |
| self._transport, _ = await self._loop.create_connection(lambda: BaseProtocol(self, handshake_fut), |
| host=self.host, port=self.port, ssl=ssl_context) |
| hs_response = await handshake_fut |
| |
| if hs_response.op_code == 0: |
| await self.close() |
| self._process_handshake_error(hs_response) |
| |
| return hs_response |
| |
| async def reconnect(self): |
| await self._reconnect() |
| |
| async def _reconnect(self): |
| if self.alive: |
| return |
| |
| await self._close_transport() |
| # connect and silence the connection errors |
| try: |
| await self._connect() |
| except connection_errors: |
| pass |
| |
| async def request(self, query_id, data: Union[bytes, bytearray]) -> bytearray: |
| """ |
| Perform request. |
| :param query_id: id of query. |
| :param data: bytes to send. |
| """ |
| if not self.alive: |
| raise SocketError('Attempt to use closed connection.') |
| |
| return await self._send(query_id, data) |
| |
| async def _send(self, query_id, data): |
| fut = self._loop.create_future() |
| self._pending_reqs[query_id] = fut |
| self._transport.write(data) |
| return await fut |
| |
| async def close(self): |
| self._closed = True |
| await self._close_transport() |
| |
| async def _close_transport(self): |
| """ |
| Close connection. |
| """ |
| if self._transport and not self._transport.is_closing(): |
| self._transport_closed_fut = self._loop.create_future() |
| |
| self._transport.close() |
| self._transport = None |
| try: |
| await asyncio.wait_for(self._transport_closed_fut, 1.0) |
| except asyncio.TimeoutError: |
| pass |
| finally: |
| self._on_connection_lost(expected=True) |
| self._transport_closed_fut = None |