IGNITE-14072 Refactor, remove duplicates and optimize Response and SQLResponse
This closes #8
diff --git a/pyignite/api/binary.py b/pyignite/api/binary.py
index 97f9fbd..1d63b49 100644
--- a/pyignite/api/binary.py
+++ b/pyignite/api/binary.py
@@ -20,10 +20,11 @@
body_struct, enum_struct, schema_struct, binary_fields_struct,
)
from pyignite.datatypes import String, Int, Bool
-from pyignite.queries import Query, get_response_class
+from pyignite.queries import Query
from pyignite.queries.op_codes import *
from pyignite.utils import int_overflow, entity_id
from .result import APIResult
+from ..queries.response import Response
def get_binary_type(
@@ -53,9 +54,9 @@
})
connection.send(send_buffer)
- response_head_struct = get_response_class(connection)([
- ('type_exists', Bool),
- ])
+ response_head_struct = Response(protocol_version=connection.get_protocol_version(),
+ following=[('type_exists', Bool)])
+
response_head_type, recv_buffer = response_head_struct.parse(connection)
response_head = response_head_type.from_buffer_copy(recv_buffer)
response_parts = []
diff --git a/pyignite/datatypes/complex.py b/pyignite/datatypes/complex.py
index d9ce36a..ad2a770 100644
--- a/pyignite/datatypes/complex.py
+++ b/pyignite/datatypes/complex.py
@@ -456,7 +456,7 @@
frame = rec[0]
code = frame.f_code
for varname in code.co_varnames:
- suspect = frame.f_locals[varname]
+ suspect = frame.f_locals.get(varname)
if isinstance(suspect, Client):
return suspect
if isinstance(suspect, Connection):
diff --git a/pyignite/queries/__init__.py b/pyignite/queries/__init__.py
index 3029f87..d558125 100644
--- a/pyignite/queries/__init__.py
+++ b/pyignite/queries/__init__.py
@@ -21,4 +21,4 @@
:mod:`pyignite.datatypes` binary parser/generator classes.
"""
-from .query import Query, ConfigQuery, get_response_class
+from .query import Query, ConfigQuery
diff --git a/pyignite/queries/query.py b/pyignite/queries/query.py
index 0e7cfa3..69b6fa2 100644
--- a/pyignite/queries/query.py
+++ b/pyignite/queries/query.py
@@ -13,26 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
import ctypes
from random import randint
-import attr
-
from pyignite.api.result import APIResult
-from pyignite.constants import *
-from pyignite.queries import response
-
-
-def get_response_class(obj: object, sql: bool = False):
- """
- Response class factory.
-
- :param obj: cache, connection or client object,
- :param sql: (optional) return normal (default) or SQL response class,
- :return: response class.
- """
- template = 'SQLResponse{}{}{}' if sql else 'Response{}{}{}'
- return getattr(response, template.format(*obj.get_protocol_version()))
+from pyignite.connection import Connection
+from pyignite.constants import MIN_LONG, MAX_LONG, RHF_TOPOLOGY_CHANGED
+from pyignite.queries.response import Response, SQLResponse
@attr.s
@@ -59,11 +47,7 @@
)
return cls._query_c_type
- def from_python(self, values: dict = None):
- if values is None:
- values = {}
- buffer = b''
-
+ def _build_header(self, buffer: bytearray, values: dict):
header_class = self.build_c_type()
header = header_class()
header.op_code = self.op_code
@@ -74,14 +58,23 @@
buffer += c_type.from_python(values[name])
header.length = (
- len(buffer)
- + ctypes.sizeof(header_class)
- - ctypes.sizeof(ctypes.c_int)
+ len(buffer)
+ + ctypes.sizeof(header_class)
+ - ctypes.sizeof(ctypes.c_int)
)
- return header.query_id, bytes(header) + buffer
+
+ return header
+
+ def from_python(self, values: dict = None):
+ if values is None:
+ values = {}
+ buffer = bytearray()
+ header = self._build_header(buffer, values)
+ buffer[:0] = bytes(header)
+ return header.query_id, bytes(buffer)
def perform(
- self, conn: 'Connection', query_params: dict = None,
+ self, conn: Connection, query_params: dict = None,
response_config: list = None, sql: bool = False, **kwargs,
) -> APIResult:
"""
@@ -98,8 +91,14 @@
"""
_, send_buffer = self.from_python(query_params)
conn.send(send_buffer)
- response_class = get_response_class(conn, sql)
- response_struct = response_class(response_config, **kwargs)
+
+ if sql:
+ response_struct = SQLResponse(protocol_version=conn.get_protocol_version(),
+ following=response_config, **kwargs)
+ else:
+ response_struct = Response(protocol_version=conn.get_protocol_version(),
+ following=response_config)
+
response_ctype, recv_buffer = response_struct.parse(conn)
response = response_ctype.from_buffer_copy(recv_buffer)
@@ -141,24 +140,7 @@
)
return cls._query_c_type
- def from_python(self, values: dict = None):
- if values is None:
- values = {}
- buffer = b''
-
- header_class = self.build_c_type()
- header = header_class()
- header.op_code = self.op_code
- if self.query_id is None:
- header.query_id = randint(MIN_LONG, MAX_LONG)
-
- for name, c_type in self.following:
- buffer += c_type.from_python(values[name])
-
- header.length = (
- len(buffer)
- + ctypes.sizeof(header_class)
- - ctypes.sizeof(ctypes.c_int)
- )
- header.config_length = header.length - ctypes.sizeof(header_class)
- return header.query_id, bytes(header) + buffer
+ def _build_header(self, buffer: bytearray, values: dict):
+ header = super()._build_header(buffer, values)
+ header.config_length = header.length - ctypes.sizeof(type(header))
+ return header
diff --git a/pyignite/queries/response.py b/pyignite/queries/response.py
index 5fb4879..6003959 100644
--- a/pyignite/queries/response.py
+++ b/pyignite/queries/response.py
@@ -13,74 +13,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
from collections import OrderedDict
import ctypes
-import attr
-
-from pyignite.constants import *
-from pyignite.datatypes import (
- AnyDataObject, Bool, Int, Long, String, StringArray, Struct,
-)
-from .op_codes import *
+from pyignite.constants import RHF_TOPOLOGY_CHANGED, RHF_ERROR
+from pyignite.connection import Connection
+from pyignite.datatypes import AnyDataObject, Bool, Int, Long, String, StringArray, Struct
+from pyignite.queries.op_codes import OP_SUCCESS
@attr.s
-class Response140:
+class Response:
following = attr.ib(type=list, factory=list)
+ protocol_version = attr.ib(type=tuple, factory=tuple)
_response_header = None
def __attrs_post_init__(self):
# replace None with empty list
self.following = self.following or []
- @classmethod
- def build_header(cls):
- if cls._response_header is None:
- cls._response_header = type(
+ def build_header(self):
+ if self._response_header is None:
+ fields = [
+ ('length', ctypes.c_int),
+ ('query_id', ctypes.c_longlong),
+ ]
+
+ if self.protocol_version and self.protocol_version >= (1, 4, 0):
+ fields.append(('flags', ctypes.c_short))
+ else:
+ fields.append(('status_code', ctypes.c_int),)
+
+ self._response_header = type(
'ResponseHeader',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
- '_fields_': [
- ('length', ctypes.c_int),
- ('query_id', ctypes.c_longlong),
- ('flags', ctypes.c_short),
- ],
+ '_fields_': fields,
},
)
- return cls._response_header
+ return self._response_header
- def parse(self, conn: 'Connection'):
+ def parse(self, conn: Connection):
header_class = self.build_header()
- buffer = conn.recv(ctypes.sizeof(header_class))
+ buffer = bytearray(conn.recv(ctypes.sizeof(header_class)))
header = header_class.from_buffer_copy(buffer)
fields = []
- if header.flags & RHF_TOPOLOGY_CHANGED:
- fields = [
- ('affinity_version', ctypes.c_longlong),
- ('affinity_minor', ctypes.c_int),
- ]
+ has_error = False
+ if self.protocol_version and self.protocol_version >= (1, 4, 0):
+ if header.flags & RHF_TOPOLOGY_CHANGED:
+ fields = [
+ ('affinity_version', ctypes.c_longlong),
+ ('affinity_minor', ctypes.c_int),
+ ]
- if header.flags & RHF_ERROR:
- fields.append(('status_code', ctypes.c_int))
+ if header.flags & RHF_ERROR:
+ fields.append(('status_code', ctypes.c_int))
+ has_error = True
+ else:
+ has_error = header.status_code != OP_SUCCESS
+
+ if fields:
buffer += conn.recv(
- sum([ctypes.sizeof(field[1]) for field in fields])
+ sum([ctypes.sizeof(c_type) for _, c_type in fields])
)
+
+ if has_error:
msg_type, buffer_fragment = String.parse(conn)
buffer += buffer_fragment
fields.append(('error_message', msg_type))
-
else:
- buffer += conn.recv(
- sum([ctypes.sizeof(field[1]) for field in fields])
- )
- for name, ignite_type in self.following:
- c_type, buffer_fragment = ignite_type.parse(conn)
- buffer += buffer_fragment
- fields.append((name, c_type))
+ self._parse_success(conn, buffer, fields)
+ return self._create_parse_result(conn, header_class, fields, buffer)
+
+ def _create_parse_result(self, conn: Connection, header_class, fields: list, buffer: bytearray):
response_class = type(
'Response',
(header_class,),
@@ -89,7 +98,13 @@
'_fields_': fields,
}
)
- return response_class, buffer
+ return response_class, bytes(buffer)
+
+ def _parse_success(self, conn: Connection, buffer: bytearray, fields: list):
+ for name, ignite_type in self.following:
+ c_type, buffer_fragment = ignite_type.parse(conn)
+ buffer += buffer_fragment
+ fields.append((name, c_type))
def to_python(self, ctype_object, *args, **kwargs):
result = OrderedDict()
@@ -104,7 +119,7 @@
@attr.s
-class SQLResponse140(Response140):
+class SQLResponse(Response):
"""
The response class of SQL functions is special in the way the row-column
data is counted in it. Basically, Ignite thin client API is following a
@@ -119,80 +134,55 @@
return 'fields', StringArray
return 'field_count', Int
- def parse(self, conn: 'Connection'):
- header_class = self.build_header()
- buffer = conn.recv(ctypes.sizeof(header_class))
- header = header_class.from_buffer_copy(buffer)
- fields = []
+ def _parse_success(self, conn: Connection, buffer: bytearray, fields: list):
+ following = [
+ self.fields_or_field_count(),
+ ('row_count', Int),
+ ]
+ if self.has_cursor:
+ following.insert(0, ('cursor', Long))
+ body_struct = Struct(following)
+ body_class, body_buffer = body_struct.parse(conn)
+ body = body_class.from_buffer_copy(body_buffer)
+ buffer += body_buffer
- if header.flags & RHF_TOPOLOGY_CHANGED:
- fields = [
- ('affinity_version', ctypes.c_longlong),
- ('affinity_minor', ctypes.c_int),
- ]
-
- if header.flags & RHF_ERROR:
- fields.append(('status_code', ctypes.c_int))
- buffer += conn.recv(
- sum([ctypes.sizeof(field[1]) for field in fields])
- )
- msg_type, buffer_fragment = String.parse(conn)
- buffer += buffer_fragment
- fields.append(('error_message', msg_type))
+ if self.include_field_names:
+ field_count = body.fields.length
else:
- buffer += conn.recv(
- sum([ctypes.sizeof(field[1]) for field in fields])
- )
- following = [
- self.fields_or_field_count(),
- ('row_count', Int),
- ]
- if self.has_cursor:
- following.insert(0, ('cursor', Long))
- body_struct = Struct(following)
- body_class, body_buffer = body_struct.parse(conn)
- body = body_class.from_buffer_copy(body_buffer)
+ field_count = body.field_count
- if self.include_field_names:
- field_count = body.fields.length
- else:
- field_count = body.field_count
+ data_fields = []
+ for i in range(body.row_count):
+ row_fields = []
+ for j in range(field_count):
+ field_class, field_buffer = AnyDataObject.parse(conn)
+ row_fields.append(('column_{}'.format(j), field_class))
+ buffer += field_buffer
- data_fields = []
- data_buffer = b''
- for i in range(body.row_count):
- row_fields = []
- row_buffer = b''
- for j in range(field_count):
- field_class, field_buffer = AnyDataObject.parse(conn)
- row_fields.append(('column_{}'.format(j), field_class))
- row_buffer += field_buffer
-
- row_class = type(
- 'SQLResponseRow',
- (ctypes.LittleEndianStructure,),
- {
- '_pack_': 1,
- '_fields_': row_fields,
- }
- )
- data_fields.append(('row_{}'.format(i), row_class))
- data_buffer += row_buffer
-
- data_class = type(
- 'SQLResponseData',
+ row_class = type(
+ 'SQLResponseRow',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
- '_fields_': data_fields,
+ '_fields_': row_fields,
}
)
- fields += body_class._fields_ + [
- ('data', data_class),
- ('more', ctypes.c_bool),
- ]
- buffer += body_buffer + data_buffer
+ data_fields.append(('row_{}'.format(i), row_class))
+ data_class = type(
+ 'SQLResponseData',
+ (ctypes.LittleEndianStructure,),
+ {
+ '_pack_': 1,
+ '_fields_': data_fields,
+ }
+ )
+ fields += body_class._fields_ + [
+ ('data', data_class),
+ ('more', ctypes.c_bool),
+ ]
+
+ def _create_parse_result(self, conn: Connection, header_class, fields: list, buffer: bytearray):
final_class = type(
'SQLResponse',
(header_class,),
@@ -202,10 +192,10 @@
}
)
buffer += conn.recv(ctypes.sizeof(final_class) - len(buffer))
- return final_class, buffer
+ return final_class, bytes(buffer)
def to_python(self, ctype_object, *args, **kwargs):
- if not hasattr(ctype_object, 'status_code'):
+ if getattr(ctype_object, 'status_code', 0) == 0:
result = {
'more': Bool.to_python(
ctype_object.more, *args, **kwargs
@@ -236,193 +226,3 @@
)
result['data'].append(row)
return result
-
-
-@attr.s
-class Response130:
- following = attr.ib(type=list, factory=list)
- _response_header = None
-
- def __attrs_post_init__(self):
- # replace None with empty list
- self.following = self.following or []
-
- @classmethod
- def build_header(cls):
- if cls._response_header is None:
- cls._response_header = type(
- 'ResponseHeader',
- (ctypes.LittleEndianStructure,),
- {
- '_pack_': 1,
- '_fields_': [
- ('length', ctypes.c_int),
- ('query_id', ctypes.c_longlong),
- ('status_code', ctypes.c_int),
- ],
- },
- )
- return cls._response_header
-
- def parse(self, client: 'Client'):
- header_class = self.build_header()
- buffer = client.recv(ctypes.sizeof(header_class))
- header = header_class.from_buffer_copy(buffer)
- fields = []
-
- if header.status_code == OP_SUCCESS:
- for name, ignite_type in self.following:
- c_type, buffer_fragment = ignite_type.parse(client)
- buffer += buffer_fragment
- fields.append((name, c_type))
- else:
- c_type, buffer_fragment = String.parse(client)
- buffer += buffer_fragment
- fields.append(('error_message', c_type))
-
- response_class = type(
- 'Response',
- (header_class,),
- {
- '_pack_': 1,
- '_fields_': fields,
- }
- )
- return response_class, buffer
-
- def to_python(self, ctype_object, *args, **kwargs):
- result = OrderedDict()
-
- for name, c_type in self.following:
- result[name] = c_type.to_python(
- getattr(ctype_object, name),
- *args, **kwargs
- )
-
- return result if result else None
-
-
-@attr.s
-class SQLResponse130(Response130):
- """
- The response class of SQL functions is special in the way the row-column
- data is counted in it. Basically, Ignite thin client API is following a
- “counter right before the counted objects” rule in most of its parts.
- SQL ops are breaking this rule.
- """
- include_field_names = attr.ib(type=bool, default=False)
- has_cursor = attr.ib(type=bool, default=False)
-
- def fields_or_field_count(self):
- if self.include_field_names:
- return 'fields', StringArray
- return 'field_count', Int
-
- def parse(self, client: 'Client'):
- header_class = self.build_header()
- buffer = client.recv(ctypes.sizeof(header_class))
- header = header_class.from_buffer_copy(buffer)
- fields = []
-
- if header.status_code == OP_SUCCESS:
- following = [
- self.fields_or_field_count(),
- ('row_count', Int),
- ]
- if self.has_cursor:
- following.insert(0, ('cursor', Long))
- body_struct = Struct(following)
- body_class, body_buffer = body_struct.parse(client)
- body = body_class.from_buffer_copy(body_buffer)
-
- if self.include_field_names:
- field_count = body.fields.length
- else:
- field_count = body.field_count
-
- data_fields = []
- data_buffer = b''
- for i in range(body.row_count):
- row_fields = []
- row_buffer = b''
- for j in range(field_count):
- field_class, field_buffer = AnyDataObject.parse(client)
- row_fields.append(('column_{}'.format(j), field_class))
- row_buffer += field_buffer
-
- row_class = type(
- 'SQLResponseRow',
- (ctypes.LittleEndianStructure,),
- {
- '_pack_': 1,
- '_fields_': row_fields,
- }
- )
- data_fields.append(('row_{}'.format(i), row_class))
- data_buffer += row_buffer
-
- data_class = type(
- 'SQLResponseData',
- (ctypes.LittleEndianStructure,),
- {
- '_pack_': 1,
- '_fields_': data_fields,
- }
- )
- fields += body_class._fields_ + [
- ('data', data_class),
- ('more', ctypes.c_bool),
- ]
- buffer += body_buffer + data_buffer
- else:
- c_type, buffer_fragment = String.parse(client)
- buffer += buffer_fragment
- fields.append(('error_message', c_type))
-
- final_class = type(
- 'SQLResponse',
- (header_class,),
- {
- '_pack_': 1,
- '_fields_': fields,
- }
- )
- buffer += client.recv(ctypes.sizeof(final_class) - len(buffer))
- return final_class, buffer
-
- def to_python(self, ctype_object, *args, **kwargs):
- if ctype_object.status_code == 0:
- result = {
- 'more': Bool.to_python(
- ctype_object.more, *args, **kwargs
- ),
- 'data': [],
- }
- if hasattr(ctype_object, 'fields'):
- result['fields'] = StringArray.to_python(
- ctype_object.fields, *args, **kwargs
- )
- else:
- result['field_count'] = Int.to_python(
- ctype_object.field_count, *args, **kwargs
- )
- if hasattr(ctype_object, 'cursor'):
- result['cursor'] = Long.to_python(
- ctype_object.cursor, *args, **kwargs
- )
- for row_item in ctype_object.data._fields_:
- row_name = row_item[0]
- row_object = getattr(ctype_object.data, row_name)
- row = []
- for col_item in row_object._fields_:
- col_name = col_item[0]
- col_object = getattr(row_object, col_name)
- row.append(
- AnyDataObject.to_python(col_object, *args, **kwargs)
- )
- result['data'].append(row)
- return result
-
-
-Response120 = Response130
-SQLResponse120 = SQLResponse130