blob: ef01af1854bd4abe9ce3f7484f4d2dcd48502619 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from asyncio import TimeoutError, InvalidStateError
import pytest
from pyignite import AioClient
from pyignite.aio_cache import AioCache
from pyignite.datatypes.key_value import PeekModes
from tests.util import start_ignite_gen
@pytest.fixture(scope='module', autouse=True)
def server1():
    """Module-wide fixture: run Ignite node #1 for every test in this module.

    Setup and teardown are fully delegated to the ``start_ignite_gen`` generator.
    """
    ignite_node = start_ignite_gen(idx=1)
    yield from ignite_node
@pytest.fixture(autouse=True)
async def proxy(event_loop, server1, cache):
    """Per-test proxy listening on 127.0.0.1:10802 and forwarding to 127.0.0.1:10801.

    ``server1`` and ``cache`` are requested only for ordering, so the Ignite node
    and the "test" cache are set up before the proxy starts.
    The proxy is always closed on teardown, even if starting it failed.
    """
    listen_addr = ("127.0.0.1", 10802)
    upstream_addr = ("127.0.0.1", 10801)
    srv = ProxyServer(listen_addr, upstream_addr)
    try:
        await srv.start()
        yield srv
    finally:
        await srv.close()
@pytest.fixture(scope='module', autouse=True)
async def cache(server1):
    """Create the shared "test" cache over a direct connection (port 10801).

    Yields the cache object and destroys it on module teardown.
    """
    c = AioClient(partition_aware=False)
    async with c.connect("127.0.0.1", 10801):
        cache = None
        try:
            cache = await c.get_or_create_cache("test")
            yield cache
        finally:
            # If get_or_create_cache failed there is nothing to destroy, and
            # touching an unbound name here would mask the original error.
            if cache is not None:
                await cache.destroy()
@pytest.fixture(autouse=True)
def invalid_states_errors():
    """Collect every asyncio ``InvalidStateError`` raised while a test runs.

    A ``sys.settrace`` hook records the ``(type, value, traceback)`` tuple of each
    such exception, including ones that are caught and never propagate.
    Yields the list of recorded tuples. The tracer that was active before the
    test is restored afterwards.
    """
    errors = []

    def trace(_frame, event, arg):
        if event == 'exception':
            exc_type, _, _ = arg
            if exc_type is InvalidStateError:
                errors.append(arg)
        # Returning the tracer keeps local (per-frame) tracing enabled.
        return trace

    # Remember any pre-existing tracer (e.g. coverage, a debugger) so it is
    # restored rather than switched off for the rest of the session.
    previous_trace = sys.gettrace()
    try:
        sys.settrace(trace)
        yield errors
    finally:
        sys.settrace(previous_trace)
# Arguments for ``pytest.mark.parametrize``: each AioCache coroutine method paired
# with the positional arguments it is called with. The tests pass ``timeout=0.1``
# on top of these.
cache_method_params = (
    "method,targs",
    [
        (AioCache.get, (1,)),
        (AioCache.put, (1, 1)),
        (AioCache.get_all, ([1, 1],)),
        (AioCache.put_all, ({1: 1},)),
        (AioCache.replace, (1, 1)),
        (AioCache.clear, ()),
        (AioCache.clear_key, (1,)),
        (AioCache.clear_keys, ([1, 1],)),
        (AioCache.contains_key, (1,)),
        (AioCache.contains_keys, ([1, 1],)),
        (AioCache.get_and_put, (1, 1)),
        (AioCache.get_and_put_if_absent, (1, 1)),
        (AioCache.put_if_absent, (1, 1)),
        (AioCache.get_and_remove, (1,)),
        (AioCache.get_and_replace, (1, 1)),
        (AioCache.remove_key, (1,)),
        (AioCache.remove_keys, ([1, 1],)),
        (AioCache.remove_all, ()),
        (AioCache.remove_if_equals, (1, 1)),
        (AioCache.replace_if_equals, (1, 1, 1)),
        (AioCache.get_size, ([PeekModes.PRIMARY, PeekModes.BACKUP],)),
        (AioCache.get_size, ()),
        (AioCache.settings, ()),
    ],
)
@pytest.mark.parametrize(*cache_method_params)
@pytest.mark.asyncio
async def test_cancellation_on_slow_response(event_loop, proxy, invalid_states_errors,
                                             method, targs):
    """A cache call that times out on a slowed-down proxy must not trigger InvalidStateError."""
    client = AioClient(partition_aware=False)
    async with client.connect("127.0.0.1", 10802):
        test_cache = await client.get_cache("test")

        proxy.slow_response = True
        with pytest.raises(TimeoutError):
            await method(test_cache, *targs, timeout=0.1)
        proxy.slow_response = False

        assert not len(invalid_states_errors)
@pytest.mark.parametrize(*cache_method_params)
@pytest.mark.asyncio
async def test_cancellation_on_disconnect(event_loop, proxy, invalid_states_errors,
                                          method, targs):
    """A pending cache call whose response is dropped, followed by a disconnect,
    must time out without triggering InvalidStateError."""
    client = AioClient(partition_aware=False)
    async with client.connect("127.0.0.1", 10802):
        test_cache = await client.get_cache("test")
        proxy.discard_response = True

        pending_call = asyncio.ensure_future(method(test_cache, *targs, timeout=0.1))
        await asyncio.sleep(0.2)
        await proxy.disconnect_peers()

        with pytest.raises(TimeoutError):
            await pending_call

        assert not len(invalid_states_errors)
class ProxyServer:
    """
    Proxy for simulating a discarded or slow response from the Ignite server.

    Set `discard_response` or `slow_response` to `True` to simulate that condition.
    Set `slow_response_timeout` to change the delay applied to slow responses.
    Call `disconnect_peers()` to simulate a lost connection to the Ignite server.
    """
    def __init__(self, local_host, remote_host):
        # (host, port) the proxy listens on / forwards to.
        self.local_host = local_host
        self.remote_host = remote_host
        # peername of a client connection -> (client-side protocol, upstream protocol)
        self.peers = {}
        self.discard_response, self.slow_response = False, False
        self.slow_response_timeout = 0.1
        self.server = None

    async def start(self):
        """Start listening on `local_host`; each accepted connection gets a ProxyTcpProtocol."""
        loop = asyncio.get_event_loop()
        host, port = self.local_host
        self.server = await loop.create_server(
            lambda: ProxyTcpProtocol(self), host=host, port=port)

    async def disconnect_peers(self):
        """Close both sides of every proxied connection and forget them."""
        # Iterate over a snapshot: closing awaits, and new peers may register meanwhile.
        for conn_info, pair in list(self.peers.items()):
            # Forget the pair up front so a repeated call does not wait on
            # protocols that are already closed.
            self.peers.pop(conn_info, None)
            if not pair:
                continue
            local, remote = pair
            # Drop the upstream (server) side first, then the client side.
            if remote:
                await remote.close()
            if local:
                await local.close()

    async def close(self):
        """Disconnect all peers and stop listening. Safe to call if `start()` failed."""
        try:
            await self.disconnect_peers()
        except TimeoutError:
            pass
        if self.server is not None:
            self.server.close()
class ProxyTcpProtocol(asyncio.Protocol):
    """Client-facing side of a proxied connection.

    On the first chunk of client data it opens an upstream connection to
    `proxy.remote_host` (served by RemoteTcpProtocol) and registers the pair in
    `proxy.peers`. Later chunks are forwarded to that upstream connection.
    """
    def __init__(self, proxy):
        self.addr, self.port = proxy.remote_host
        self.proxy = proxy
        self.transport, self.remote_protocol = None, None
        self.conn_info, self.close_fut = None, None
        super().__init__()

    def connection_made(self, transport):
        self.transport = transport
        self.conn_info = transport.get_extra_info("peername")

    def data_received(self, data):
        if self.remote_protocol is not None:
            if self.remote_protocol.transport:
                self.remote_protocol.transport.write(data)
            else:
                # Upstream connect still in progress: buffer the bytes instead of
                # opening a second connection. RemoteTcpProtocol writes its `data`
                # once the connection is made.
                self.remote_protocol.data += data
            return
        loop = asyncio.get_event_loop()
        self.remote_protocol = RemoteTcpProtocol(self.proxy, self, data)
        coro = loop.create_connection(lambda: self.remote_protocol, host=self.addr, port=self.port)
        asyncio.ensure_future(coro)
        self.proxy.peers[self.conn_info] = (self, self.remote_protocol)

    async def close(self):
        """Close the client transport and wait up to 0.1 s for connection_lost."""
        if not self.transport:
            return
        self.close_fut = asyncio.get_event_loop().create_future()
        self.transport.close()
        try:
            await asyncio.wait_for(self.close_fut, 0.1)
        except TimeoutError:
            pass

    def connection_lost(self, exc):
        # Resolve the future close() is waiting on. Future.done() only *queries*
        # the state; set_result() is what completes it. The future may already be
        # cancelled if wait_for timed out first.
        if self.close_fut is not None and not self.close_fut.done():
            self.close_fut.set_result(None)
class RemoteTcpProtocol(asyncio.Protocol):
    """Upstream side of a proxied connection (proxy -> Ignite server).

    `data` (the client bytes received before the connection existed) is written
    as soon as the connection is made. Server responses are queued and relayed
    to the client by `send_loop`, which can drop them (`proxy.discard_response`)
    or delay them (`proxy.slow_response`).
    """
    def __init__(self, proxy, proxy_protocol, data):
        self.proxy = proxy
        self.proxy_protocol = proxy_protocol
        self.data = data
        self.transport, self.close_fut = None, None
        self.queue = []
        # Initialise the flag before scheduling the task that reads it.
        self.stop_flag = False
        self.loop_task = asyncio.ensure_future(self.send_loop())
        super().__init__()

    def connection_made(self, transport):
        self.transport = transport
        self.transport.write(self.data)

    async def close(self):
        """Stop the relay loop and close the upstream transport, if one exists.

        Each wait is bounded by 0.1 s.
        """
        # Always stop the relay loop, even if the upstream connection was never
        # made; otherwise the task would poll forever.
        self.stop_flag = True
        close_fut = None
        if self.transport:
            close_fut = self.close_fut = asyncio.get_event_loop().create_future()
            self.transport.close()
        # A task that is already finished (or was cancelled by an earlier timed-out
        # close) must not be awaited again; that would raise CancelledError.
        if not self.loop_task.done():
            try:
                await asyncio.wait_for(self.loop_task, 0.1)
            except TimeoutError:
                pass
        if close_fut is None:
            return
        try:
            await asyncio.wait_for(close_fut, 0.1)
        except TimeoutError:
            pass

    def connection_lost(self, exc):
        # set_result() completes the future close() waits on; done() only queries it.
        if self.close_fut is not None and not self.close_fut.done():
            self.close_fut.set_result(None)

    async def send_loop(self):
        """Relay queued server responses to the client until `stop_flag` is set."""
        while not self.stop_flag:
            if not self.queue:
                await asyncio.sleep(0.01)
                continue
            # FIFO: TCP stream chunks must reach the client in arrival order.
            packet = self.queue.pop(0)
            if packet:
                if self.proxy.slow_response:
                    await asyncio.sleep(self.proxy.slow_response_timeout)
                self.proxy_protocol.transport.write(packet)

    def data_received(self, data):
        if self.proxy.discard_response:
            return
        self.queue.append(data)