IGNITE-14432 Implement connection context managers for clients
This closes #23
diff --git a/pyignite/aio_client.py b/pyignite/aio_client.py
index d882969..d2cc3ff 100644
--- a/pyignite/aio_client.py
+++ b/pyignite/aio_client.py
@@ -33,6 +33,22 @@
__all__ = ['AioClient']
+class _ConnectionContextManager:
+ def __init__(self, client, nodes):
+ self.client = client
+ self.nodes = nodes
+
+ def __await__(self):
+ return (yield from self.__aenter__().__await__())
+
+ async def __aenter__(self):
+ await self.client._connect(self.nodes)
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.client.close()
+
+
class AioClient(BaseClient):
"""
Asynchronous Client implementation.
@@ -57,14 +73,16 @@
super().__init__(compact_footer, partition_aware, **kwargs)
self._registry_mux = asyncio.Lock()
- async def connect(self, *args):
+ def connect(self, *args):
"""
Connect to Ignite cluster node(s).
:param args: (optional) host(s) and port(s) to connect to.
"""
nodes = self._process_connect_args(*args)
+ return _ConnectionContextManager(self, nodes)
+ async def _connect(self, nodes):
for i, node in enumerate(nodes):
host, port = node
conn = AioConnection(self, host, port, **self._connection_args)
diff --git a/pyignite/client.py b/pyignite/client.py
index e4eef6a..05df617 100644
--- a/pyignite/client.py
+++ b/pyignite/client.py
@@ -243,6 +243,19 @@
return self._registry[type_id]
+class _ConnectionContextManager:
+ def __init__(self, client, nodes):
+ self.client = client
+ self.nodes = nodes
+ self.client._connect(self.nodes)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.client.close()
+
+
class Client(BaseClient):
"""
This is a main `pyignite` class, that is build upon the
@@ -280,7 +293,9 @@
:param args: (optional) host(s) and port(s) to connect to.
"""
nodes = self._process_connect_args(*args)
+ return _ConnectionContextManager(self, nodes)
+ def _connect(self, nodes):
# the following code is quite twisted, because the protocol version
# is initially unknown
diff --git a/tests/affinity/conftest.py b/tests/affinity/conftest.py
index 2ec2b1b..e23e0e6 100644
--- a/tests/affinity/conftest.py
+++ b/tests/affinity/conftest.py
@@ -39,20 +39,25 @@
@pytest.fixture
-def client():
+def connection_param():
+ return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
+
+
+@pytest.fixture
+def client(connection_param):
client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT)
try:
- client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
+ client.connect(connection_param)
yield client
finally:
client.close()
@pytest.fixture
-async def async_client():
+async def async_client(connection_param):
client = AioClient(partition_aware=True)
try:
- await client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
+ await client.connect(connection_param)
yield client
finally:
await client.close()
diff --git a/tests/affinity/test_affinity_bad_servers.py b/tests/affinity/test_affinity_bad_servers.py
index b169168..f5eec21 100644
--- a/tests/affinity/test_affinity_bad_servers.py
+++ b/tests/affinity/test_affinity_bad_servers.py
@@ -15,9 +15,10 @@
import pytest
+from pyignite import Client, AioClient
from pyignite.exceptions import ReconnectError, connection_errors
from tests.affinity.conftest import CLIENT_SOCKET_TIMEOUT
-from tests.util import start_ignite, kill_process_tree, get_client, get_client_async
+from tests.util import start_ignite, kill_process_tree
@pytest.fixture(params=['with-partition-awareness', 'without-partition-awareness'])
@@ -27,22 +28,24 @@
def test_client_with_multiple_bad_servers(with_partition_awareness):
with pytest.raises(ReconnectError, match="Can not connect."):
- with get_client(partition_aware=with_partition_awareness) as client:
- client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
+ client = Client(partition_aware=with_partition_awareness)
+ with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
+ pass
@pytest.mark.asyncio
async def test_client_with_multiple_bad_servers_async(with_partition_awareness):
with pytest.raises(ReconnectError, match="Can not connect."):
- async with get_client_async(partition_aware=with_partition_awareness) as client:
- await client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
+ client = AioClient(partition_aware=with_partition_awareness)
+ async with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
+ pass
def test_client_with_failed_server(request, with_partition_awareness):
srv = start_ignite(idx=4)
try:
- with get_client(partition_aware=with_partition_awareness) as client:
- client.connect([("127.0.0.1", 10804)])
+ client = Client(partition_aware=with_partition_awareness)
+ with client.connect([("127.0.0.1", 10804)]):
cache = client.get_or_create_cache(request.node.name)
cache.put(1, 1)
kill_process_tree(srv.pid)
@@ -62,8 +65,8 @@
async def test_client_with_failed_server_async(request, with_partition_awareness):
srv = start_ignite(idx=4)
try:
- async with get_client_async(partition_aware=with_partition_awareness) as client:
- await client.connect([("127.0.0.1", 10804)])
+ client = AioClient(partition_aware=with_partition_awareness)
+ async with client.connect([("127.0.0.1", 10804)]):
cache = await client.get_or_create_cache(request.node.name)
await cache.put(1, 1)
kill_process_tree(srv.pid)
@@ -82,8 +85,8 @@
def test_client_with_recovered_server(request, with_partition_awareness):
srv = start_ignite(idx=4)
try:
- with get_client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT) as client:
- client.connect([("127.0.0.1", 10804)])
+ client = Client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT)
+ with client.connect([("127.0.0.1", 10804)]):
cache = client.get_or_create_cache(request.node.name)
cache.put(1, 1)
@@ -108,8 +111,8 @@
async def test_client_with_recovered_server_async(request, with_partition_awareness):
srv = start_ignite(idx=4)
try:
- async with get_client_async(partition_aware=with_partition_awareness) as client:
- await client.connect([("127.0.0.1", 10804)])
+ client = AioClient(partition_aware=with_partition_awareness)
+ async with client.connect([("127.0.0.1", 10804)]):
cache = await client.get_or_create_cache(request.node.name)
await cache.put(1, 1)
diff --git a/tests/affinity/test_connection_context_manager.py b/tests/affinity/test_connection_context_manager.py
new file mode 100644
index 0000000..8056c7d
--- /dev/null
+++ b/tests/affinity/test_connection_context_manager.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from pyignite import Client, AioClient
+
+
+@pytest.fixture
+def connection_param():
+ return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
+
+
+@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
+def test_connection_context(connection_param, partition_aware):
+ is_partition_aware = partition_aware == 'with_partition_aware'
+ client = Client(partition_aware=is_partition_aware)
+
+ # Check context manager
+ with client.connect(connection_param):
+ __check_open(client, is_partition_aware)
+ __check_closed(client)
+
+ # Check standard way
+ try:
+ client.connect(connection_param)
+ __check_open(client, is_partition_aware)
+ finally:
+ client.close()
+ __check_closed(client)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
+async def test_connection_context_async(connection_param, partition_aware):
+ is_partition_aware = partition_aware == 'with_partition_aware'
+ client = AioClient(partition_aware=is_partition_aware)
+
+ # Check async context manager.
+ async with client.connect(connection_param):
+ await __check_open(client, is_partition_aware)
+ __check_closed(client)
+
+ # Check standard way.
+ try:
+ await client.connect(connection_param)
+ await __check_open(client, is_partition_aware)
+ finally:
+ await client.close()
+ __check_closed(client)
+
+
+def __check_open(client, is_partition_aware):
+ def inner_sync():
+ if is_partition_aware:
+ assert client.random_node.alive
+ else:
+ all(n.alive for n in client._nodes)
+
+ async def inner_async():
+ if is_partition_aware:
+ random_node = await client.random_node()
+ assert random_node.alive
+ else:
+ all(n.alive for n in client._nodes)
+
+ return inner_sync() if isinstance(client, Client) else inner_async()
+
+
+def __check_closed(client):
+ assert all(not n.alive for n in client._nodes)
diff --git a/tests/security/test_auth.py b/tests/security/test_auth.py
index 4a1c52d..b02f224 100644
--- a/tests/security/test_auth.py
+++ b/tests/security/test_auth.py
@@ -14,8 +14,9 @@
# limitations under the License.
import pytest
+from pyignite import Client, AioClient
from pyignite.exceptions import AuthenticationError
-from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client, get_client_async
+from tests.util import start_ignite_gen, clear_ignite_work_dir
DEFAULT_IGNITE_USERNAME = 'ignite'
DEFAULT_IGNITE_PASSWORD = 'ignite'
@@ -40,21 +41,16 @@
def test_auth_success(with_ssl, ssl_params):
ssl_params['use_ssl'] = with_ssl
-
- with get_client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params) as client:
- client.connect("127.0.0.1", 10801)
-
+ client = Client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
+ with client.connect("127.0.0.1", 10801):
assert all(node.alive for node in client._nodes)
@pytest.mark.asyncio
async def test_auth_success_async(with_ssl, ssl_params):
ssl_params['use_ssl'] = with_ssl
-
- async with get_client_async(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD,
- **ssl_params) as client:
- await client.connect("127.0.0.1", 10801)
-
+ client = AioClient(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
+ async with client.connect("127.0.0.1", 10801):
assert all(node.alive for node in client._nodes)
@@ -73,8 +69,9 @@
ssl_params['use_ssl'] = with_ssl
with pytest.raises(AuthenticationError):
- with get_client(username=username, password=password, **ssl_params) as client:
- client.connect("127.0.0.1", 10801)
+ client = Client(username=username, password=password, **ssl_params)
+ with client.connect("127.0.0.1", 10801):
+ pass
@pytest.mark.parametrize(
@@ -86,5 +83,6 @@
ssl_params['use_ssl'] = with_ssl
with pytest.raises(AuthenticationError):
- async with get_client_async(username=username, password=password, **ssl_params) as client:
- await client.connect("127.0.0.1", 10801)
+ client = AioClient(username=username, password=password, **ssl_params)
+ async with client.connect("127.0.0.1", 10801):
+ pass
diff --git a/tests/security/test_ssl.py b/tests/security/test_ssl.py
index 32db98f..7736864 100644
--- a/tests/security/test_ssl.py
+++ b/tests/security/test_ssl.py
@@ -14,8 +14,9 @@
# limitations under the License.
import pytest
+from pyignite import Client, AioClient
from pyignite.exceptions import ReconnectError
-from tests.util import start_ignite_gen, get_client, get_or_create_cache, get_client_async, get_or_create_cache_async
+from tests.util import start_ignite_gen, get_or_create_cache, get_or_create_cache_async
@pytest.fixture(scope='module', autouse=True)
@@ -45,18 +46,16 @@
kwargs['use_ssl'] = True
def inner():
- with get_client(**kwargs) as client:
- client.connect("127.0.0.1", 10801)
-
+ client = Client(**kwargs)
+ with client.connect("127.0.0.1", 10801):
with get_or_create_cache(client, 'test-cache') as cache:
cache.put(1, 1)
assert cache.get(1) == 1
async def inner_async():
- async with get_client_async(**kwargs) as client:
- await client.connect("127.0.0.1", 10801)
-
+ client = AioClient(**kwargs)
+ async with client.connect("127.0.0.1", 10801):
async with get_or_create_cache_async(client, 'test-cache') as cache:
await cache.put(1, 1)
@@ -75,13 +74,15 @@
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
def test_connection_error_with_incorrect_config(invalid_ssl_params):
with pytest.raises(ReconnectError):
- with get_client(**invalid_ssl_params) as client:
- client.connect([("127.0.0.1", 10801)])
+ client = Client(**invalid_ssl_params)
+ with client.connect([("127.0.0.1", 10801)]):
+ pass
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
@pytest.mark.asyncio
async def test_connection_error_with_incorrect_config_async(invalid_ssl_params):
with pytest.raises(ReconnectError):
- async with get_client_async(**invalid_ssl_params) as client:
- await client.connect([("127.0.0.1", 10801)])
+ client = AioClient(**invalid_ssl_params)
+ async with client.connect([("127.0.0.1", 10801)]):
+ pass
diff --git a/tests/util.py b/tests/util.py
index f1243fc..2ca898b 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -35,24 +35,6 @@
@contextlib.contextmanager
-def get_client(**kwargs):
- client = Client(**kwargs)
- try:
- yield client
- finally:
- client.close()
-
-
-@asynccontextmanager
-async def get_client_async(**kwargs):
- client = AioClient(**kwargs)
- try:
- yield client
- finally:
- await client.close()
-
-
-@contextlib.contextmanager
def get_or_create_cache(client, cache_name):
cache = client.get_or_create_cache(cache_name)
try: