IGNITE-15479 Fix incorrect partial read from socket in sync client - Fixes #50.
diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py
index 98ba7e0..3d86f01 100644
--- a/pyignite/connection/connection.py
+++ b/pyignite/connection/connection.py
@@ -156,6 +156,9 @@
return self.client._event_listeners
+DEFAULT_INITIAL_BUF_SIZE = 1024
+
+
class Connection(BaseConnection):
"""
This is a `pyignite` class, that represents a connection to Ignite
@@ -348,15 +351,15 @@
if flags is not None:
kwargs['flags'] = flags
- data = bytearray(1024)
+ data = bytearray(DEFAULT_INITIAL_BUF_SIZE)
buffer = memoryview(data)
- bytes_total_received, bytes_to_receive = 0, 0
+ total_rcvd, packet_len = 0, 0
while True:
try:
- bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs)
- if bytes_received == 0:
+ bytes_rcvd = self._socket.recv_into(buffer, len(buffer), **kwargs)
+ if bytes_rcvd == 0:
raise SocketError('Connection broken.')
- bytes_total_received += bytes_received
+ total_rcvd += bytes_rcvd
except connection_errors as e:
self.failed = True
if reconnect:
@@ -364,23 +367,19 @@
self.reconnect()
raise e
- if bytes_total_received < 4:
- continue
- elif bytes_to_receive == 0:
- response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER)
- bytes_to_receive = response_len
-
- if response_len + 4 > len(data):
+ if packet_len == 0 and total_rcvd > 4:
+ packet_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER, signed=True) + 4
+ if packet_len > len(data):
buffer.release()
- data.extend(bytearray(response_len + 4 - len(data)))
- buffer = memoryview(data)[bytes_total_received:]
+ data.extend(bytearray(packet_len - len(data)))
+ buffer = memoryview(data)[total_rcvd:]
continue
- if bytes_total_received >= bytes_to_receive:
+ if 0 < packet_len <= total_rcvd:
buffer.release()
break
- buffer = buffer[bytes_received:]
+ buffer = buffer[bytes_rcvd:]
return data
diff --git a/tests/common/test_sync_socket.py b/tests/common/test_sync_socket.py
new file mode 100644
index 0000000..cd41809
--- /dev/null
+++ b/tests/common/test_sync_socket.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import secrets
+import socket
+import unittest.mock as mock
+
+import pytest
+
+from pyignite import Client
+from tests.util import get_or_create_cache
+
+old_recv_into = socket.socket.recv_into
+
+
+def patched_recv_into_factory(buf_len):
+ def patched_recv_into(self, buffer, nbytes, **kwargs):
+ return old_recv_into(self, buffer, min(nbytes, buf_len) if buf_len else nbytes, **kwargs)
+ return patched_recv_into
+
+
+@pytest.mark.parametrize('buf_len', [0, 1, 4, 16, 32, 64, 128, 256, 512, 1024])
+def test_get_large_value(buf_len):
+ with mock.patch.object(socket.socket, 'recv_into', new=patched_recv_into_factory(buf_len)):
+ c = Client()
+ with c.connect("127.0.0.1", 10801):
+ with get_or_create_cache(c, 'test') as cache:
+ value = secrets.token_hex((1 << 16) + 1)
+ cache.put(1, value)
+ assert value == cache.get(1)