IGNITE-13911 Asyncio version of client
This closes #21
diff --git a/.gitignore b/.gitignore
index 699c26d..14ec495 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@
*.so
build
distr
+docs/generated
tests/config/*.xml
junit*.xml
pyignite.egg-info
diff --git a/.travis.yml b/.travis.yml
index 7e726be..74909b8 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -43,6 +43,9 @@
- python: '3.8'
arch: amd64
env: TOXENV=py38
+ - python: '3.8'
+ arch: amd64
+ env: TOXENV=codestyle
- python: '3.9'
arch: amd64
env: TOXENV=py39
diff --git a/examples/create_binary.py b/examples/create_binary.py
index c963796..b199527 100644
--- a/examples/create_binary.py
+++ b/examples/create_binary.py
@@ -23,44 +23,44 @@
client.connect('127.0.0.1', 10800)
student_cache = client.create_cache({
- PROP_NAME: 'SQL_PUBLIC_STUDENT',
- PROP_SQL_SCHEMA: 'PUBLIC',
- PROP_QUERY_ENTITIES: [
- {
- 'table_name': 'Student'.upper(),
- 'key_field_name': 'SID',
- 'key_type_name': 'java.lang.Integer',
- 'field_name_aliases': [],
- 'query_fields': [
- {
- 'name': 'SID',
- 'type_name': 'java.lang.Integer',
- 'is_key_field': True,
- 'is_notnull_constraint_field': True,
- },
- {
- 'name': 'NAME',
- 'type_name': 'java.lang.String',
- },
- {
- 'name': 'LOGIN',
- 'type_name': 'java.lang.String',
- },
- {
- 'name': 'AGE',
- 'type_name': 'java.lang.Integer',
- },
- {
- 'name': 'GPA',
- 'type_name': 'java.math.Double',
- },
- ],
- 'query_indexes': [],
- 'value_type_name': 'SQL_PUBLIC_STUDENT_TYPE',
- 'value_field_name': None,
- },
- ],
- })
+ PROP_NAME: 'SQL_PUBLIC_STUDENT',
+ PROP_SQL_SCHEMA: 'PUBLIC',
+ PROP_QUERY_ENTITIES: [
+ {
+ 'table_name': 'Student'.upper(),
+ 'key_field_name': 'SID',
+ 'key_type_name': 'java.lang.Integer',
+ 'field_name_aliases': [],
+ 'query_fields': [
+ {
+ 'name': 'SID',
+ 'type_name': 'java.lang.Integer',
+ 'is_key_field': True,
+ 'is_notnull_constraint_field': True,
+ },
+ {
+ 'name': 'NAME',
+ 'type_name': 'java.lang.String',
+ },
+ {
+ 'name': 'LOGIN',
+ 'type_name': 'java.lang.String',
+ },
+ {
+ 'name': 'AGE',
+ 'type_name': 'java.lang.Integer',
+ },
+ {
+ 'name': 'GPA',
+ 'type_name': 'java.math.Double',
+ },
+ ],
+ 'query_indexes': [],
+ 'value_type_name': 'SQL_PUBLIC_STUDENT_TYPE',
+ 'value_field_name': None,
+ },
+ ],
+})
class Student(
diff --git a/examples/sql.py b/examples/sql.py
index 8f0ee7c..0e8c729 100644
--- a/examples/sql.py
+++ b/examples/sql.py
@@ -280,7 +280,7 @@
field_data = list(*result)
print('City info:')
-for field_name, field_value in zip(field_names*len(field_data), field_data):
+for field_name, field_value in zip(field_names * len(field_data), field_data):
print('{}: {}'.format(field_name, field_value))
# City info:
# ID: 3802
diff --git a/pyignite/__init__.py b/pyignite/__init__.py
index 0ac346f..c26c59a 100644
--- a/pyignite/__init__.py
+++ b/pyignite/__init__.py
@@ -14,4 +14,7 @@
# limitations under the License.
from pyignite.client import Client
+from pyignite.aio_client import AioClient
from pyignite.binary import GenericObjectMeta
+
+__version__ = '0.4.0-dev'
diff --git a/pyignite/aio_cache.py b/pyignite/aio_cache.py
new file mode 100644
index 0000000..b92a14c
--- /dev/null
+++ b/pyignite/aio_cache.py
@@ -0,0 +1,600 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+from typing import Any, Dict, Iterable, Optional, Union
+
+from .constants import AFFINITY_RETRIES, AFFINITY_DELAY
+from .connection import AioConnection
+from .datatypes import prop_codes
+from .datatypes.base import IgniteDataType
+from .datatypes.internal import AnyDataObject
+from .exceptions import CacheCreationError, CacheError, ParameterError, connection_errors
+from .utils import cache_id, status_to_exception
+from .api.cache_config import (
+ cache_create_async, cache_get_or_create_async, cache_destroy_async, cache_get_configuration_async,
+ cache_create_with_config_async, cache_get_or_create_with_config_async
+)
+from .api.key_value import (
+ cache_get_async, cache_contains_key_async, cache_clear_key_async, cache_clear_keys_async, cache_clear_async,
+ cache_replace_async, cache_put_all_async, cache_get_all_async, cache_put_async, cache_contains_keys_async,
+ cache_get_and_put_async, cache_get_and_put_if_absent_async, cache_put_if_absent_async, cache_get_and_remove_async,
+ cache_get_and_replace_async, cache_remove_key_async, cache_remove_keys_async, cache_remove_all_async,
+ cache_remove_if_equals_async, cache_replace_if_equals_async, cache_get_size_async,
+)
+from .cursors import AioScanCursor
+from .api.affinity import cache_get_node_partitions_async
+from .cache import __parse_settings, BaseCacheMixin
+
+
+async def get_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache':
+ name, settings = __parse_settings(settings)
+ if settings:
+ raise ParameterError('Only cache name allowed as a parameter')
+
+ return AioCache(client, name)
+
+
+async def create_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache':
+ name, settings = __parse_settings(settings)
+
+ conn = await client.random_node()
+ if settings:
+ result = await cache_create_with_config_async(conn, settings)
+ else:
+ result = await cache_create_async(conn, name)
+
+ if result.status != 0:
+ raise CacheCreationError(result.message)
+
+ return AioCache(client, name)
+
+
+async def get_or_create_cache(client: 'AioClient', settings: Union[str, dict]) -> 'AioCache':
+ name, settings = __parse_settings(settings)
+
+ conn = await client.random_node()
+ if settings:
+ result = await cache_get_or_create_with_config_async(conn, settings)
+ else:
+ result = await cache_get_or_create_async(conn, name)
+
+ if result.status != 0:
+ raise CacheCreationError(result.message)
+
+ return AioCache(client, name)
+
+
+class AioCache(BaseCacheMixin):
+ """
+ Ignite cache abstraction. Users should never use this class directly,
+ but construct its instances with
+ :py:meth:`~pyignite.client.Client.create_cache`,
+ :py:meth:`~pyignite.client.Client.get_or_create_cache` or
+ :py:meth:`~pyignite.client.Client.get_cache` methods instead. See
+ :ref:`this example <create_cache>` on how to do it.
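+
+    A minimal usage sketch (host, port and cache name are illustrative,
+    not prescribed by the API)::
+
+        client = AioClient()
+        await client.connect('127.0.0.1', 10800)
+        cache = await client.get_or_create_cache('my cache')
+        await cache.put(1, 'one')
+        print(await cache.get(1))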
+ """
+ def __init__(self, client: 'AioClient', name: str):
+ """
+ Initialize async cache object. For internal use.
+
+ :param client: Async Ignite client,
+ :param name: Cache name.
+ """
+ self._client = client
+ self._name = name
+ self._cache_id = cache_id(self._name)
+ self._settings = None
+ self._affinity_query_mux = asyncio.Lock()
+ self.affinity = {'version': (0, 0)}
+
+ async def settings(self) -> Optional[dict]:
+ """
+ Lazy Cache settings. See the :ref:`example <sql_cache_read>`
+ of reading this property.
+
+ All cache properties are documented here: :ref:`cache_props`.
+
+ :return: dict of cache properties and their values.
+ """
+ if self._settings is None:
+ conn = await self.get_best_node()
+ config_result = await cache_get_configuration_async(conn, self._cache_id)
+
+ if config_result.status == 0:
+ self._settings = config_result.value
+ else:
+ raise CacheError(config_result.message)
+
+ return self._settings
+
+ async def name(self) -> str:
+ """
+ Lazy cache name.
+
+ :return: cache name string.
+ """
+ if self._name is None:
+ settings = await self.settings()
+ self._name = settings[prop_codes.PROP_NAME]
+
+ return self._name
+
+ @property
+ def client(self) -> 'AioClient':
+ """
+ Ignite :class:`~pyignite.aio_client.AioClient` object.
+
+ :return: Async client object, through which the cache is accessed.
+ """
+ return self._client
+
+ @property
+ def cache_id(self) -> int:
+ """
+ Cache ID.
+
+ :return: integer value of the cache ID.
+ """
+ return self._cache_id
+
+ @status_to_exception(CacheError)
+ async def destroy(self):
+ """
+ Destroys cache with a given name.
+ """
+ conn = await self.get_best_node()
+ return await cache_destroy_async(conn, self._cache_id)
+
+ @status_to_exception(CacheError)
+ async def _get_affinity(self, conn: 'AioConnection') -> Dict:
+ """
+ Queries server for affinity mappings. Retries in case
+ of an intermittent error (most probably “Getting affinity for topology
+ version earlier than affinity is calculated”).
+
+        :param conn: connection to Ignite server,
+ :return: OP_CACHE_PARTITIONS operation result value.
+ """
+ for _ in range(AFFINITY_RETRIES or 1):
+ result = await cache_get_node_partitions_async(conn, self._cache_id)
+ if result.status == 0 and result.value['partition_mapping']:
+ break
+ await asyncio.sleep(AFFINITY_DELAY)
+
+ return result
+
+ async def get_best_node(self, key: Any = None, key_hint: 'IgniteDataType' = None) -> 'AioConnection':
+ """
+        Returns the node from the list of the nodes opened by the client that
+        most probably contains the needed key-value pair. See IEP-23.
+
+        This method is not a part of the public API. Unless you wish to
+        extend the `pyignite` capabilities (with additional testing, logging,
+        examining connections, etc.) you probably should not use it.
+
+ :param key: (optional) pythonic key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :return: Ignite connection object.
+ """
+ conn = await self._client.random_node()
+
+ if self.client.partition_aware and key is not None:
+ if self.__should_update_mapping():
+ async with self._affinity_query_mux:
+ while self.__should_update_mapping():
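+                    # the mapping version is re-checked under the lock, as a
+                    # concurrent task may have refreshed it already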
+ try:
+ full_affinity = await self._get_affinity(conn)
+ self._update_affinity(full_affinity)
+
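+                        # fire-and-forget: revive dead connections in the
+                        # background without delaying the current request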
+ asyncio.ensure_future(
+ asyncio.gather(
+ *[conn.reconnect() for conn in self.client._nodes if not conn.alive],
+ return_exceptions=True
+ )
+ )
+
+ break
+ except connection_errors:
+ # retry if connection failed
+                            conn = await self._client.random_node()
+ except CacheError:
+ # server did not create mapping in time
+ return conn
+
+ parts = self.affinity.get('number_of_partitions')
+
+ if not parts:
+ return conn
+
+ key, key_hint = self._get_affinity_key(key, key_hint)
+
+ hashcode = await key_hint.hashcode_async(key, self._client)
+
+ best_node = self._get_node_by_hashcode(hashcode, parts)
+ if best_node:
+ return best_node
+
+ return conn
+
+ def __should_update_mapping(self):
+ return self.affinity['version'] < self._client.affinity_version
+
+ @status_to_exception(CacheError)
+ async def get(self, key, key_hint: object = None) -> Any:
+ """
+ Retrieves a value from cache by key.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :return: value retrieved.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_get_async(conn, self._cache_id, key, key_hint=key_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def put(self, key, value, key_hint: object = None, value_hint: object = None):
+ """
+ Puts a value with a given key to cache (overwriting existing value
+ if any).
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_put_async(conn, self._cache_id, key, value, key_hint=key_hint, value_hint=value_hint)
+
+ @status_to_exception(CacheError)
+    async def get_all(self, keys: list) -> dict:
+ """
+ Retrieves multiple key-value pairs from cache.
+
+ :param keys: list of keys or tuples of (key, key_hint),
+ :return: a dict of key-value pairs.
+ """
+ conn = await self.get_best_node()
+ result = await cache_get_all_async(conn, self._cache_id, keys)
+ if result.value:
+ keys = list(result.value.keys())
+ values = await asyncio.gather(*[self.client.unwrap_binary(value) for value in result.value.values()])
+
+ for i, key in enumerate(keys):
+ result.value[key] = values[i]
+ return result
+
+ @status_to_exception(CacheError)
+ async def put_all(self, pairs: dict):
+ """
+ Puts multiple key-value pairs to cache (overwriting existing
+ associations if any).
+
+        :param pairs: a dict of key-value pairs to save. Each key or value
+         can be an item of a representable Python type or a tuple
+         of (item, hint).
+ """
+ conn = await self.get_best_node()
+ return await cache_put_all_async(conn, self._cache_id, pairs)
+
+ @status_to_exception(CacheError)
+ async def replace(self, key, value, key_hint: object = None, value_hint: object = None):
+ """
+        Puts a value with a given key to cache only if the key already exists.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_replace_async(conn, self._cache_id, key, value, key_hint=key_hint, value_hint=value_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def clear(self, keys: Optional[list] = None):
+ """
+ Clears the cache without notifying listeners or cache writers.
+
+ :param keys: (optional) list of cache keys or (key, key type
+ hint) tuples to clear (default: clear all).
+ """
+ conn = await self.get_best_node()
+ if keys:
+ return await cache_clear_keys_async(conn, self._cache_id, keys)
+ else:
+ return await cache_clear_async(conn, self._cache_id)
+
+ @status_to_exception(CacheError)
+ async def clear_key(self, key, key_hint: object = None):
+ """
+ Clears the cache key without notifying listeners or cache writers.
+
+ :param key: key for the cache entry,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_clear_key_async(conn, self._cache_id, key, key_hint=key_hint)
+
+ @status_to_exception(CacheError)
+ async def clear_keys(self, keys: Iterable):
+ """
+        Clears the cache keys without notifying listeners or cache writers.
+
+ :param keys: a list of keys or (key, type hint) tuples
+ """
+ conn = await self.get_best_node()
+ return await cache_clear_keys_async(conn, self._cache_id, keys)
+
+ @status_to_exception(CacheError)
+ async def contains_key(self, key, key_hint=None) -> bool:
+ """
+ Returns a value indicating whether given key is present in cache.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :return: boolean `True` when key is present, `False` otherwise.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_contains_key_async(conn, self._cache_id, key, key_hint=key_hint)
+
+ @status_to_exception(CacheError)
+ async def contains_keys(self, keys: Iterable) -> bool:
+ """
+ Returns a value indicating whether all given keys are present in cache.
+
+ :param keys: a list of keys or (key, type hint) tuples,
+ :return: boolean `True` when all keys are present, `False` otherwise.
+ """
+ conn = await self.get_best_node()
+ return await cache_contains_keys_async(conn, self._cache_id, keys)
+
+ @status_to_exception(CacheError)
+ async def get_and_put(self, key, value, key_hint=None, value_hint=None) -> Any:
+ """
+ Puts a value with a given key to cache, and returns the previous value
+        for that key, or None if there was no such key.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted.
+ :return: old value or None.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_get_and_put_async(conn, self._cache_id, key, value, key_hint, value_hint)
+
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def get_and_put_if_absent(self, key, value, key_hint=None, value_hint=None):
+ """
+ Puts a value with a given key to cache only if the key does not
+ already exist.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted,
+ :return: old value or None.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_get_and_put_if_absent_async(conn, self._cache_id, key, value, key_hint, value_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def put_if_absent(self, key, value, key_hint=None, value_hint=None):
+ """
+ Puts a value with a given key to cache only if the key does not
+ already exist.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_put_if_absent_async(conn, self._cache_id, key, value, key_hint, value_hint)
+
+ @status_to_exception(CacheError)
+ async def get_and_remove(self, key, key_hint=None) -> Any:
+ """
+ Removes the cache entry with specified key, returning the value.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :return: old value or None.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_get_and_remove_async(conn, self._cache_id, key, key_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def get_and_replace(self, key, value, key_hint=None, value_hint=None) -> Any:
+ """
+ Puts a value with a given key to cache, returning previous value
+ for that key, if and only if there is a value currently mapped
+ for that key.
+
+ :param key: key for the cache entry. Can be of any supported type,
+ :param value: value for the key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted.
+ :return: old value or None.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_get_and_replace_async(conn, self._cache_id, key, value, key_hint, value_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def remove_key(self, key, key_hint=None):
+ """
+        Removes the cache entry with the given key, notifying listeners
+        and cache writers.
+
+ :param key: key for the cache entry,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_remove_key_async(conn, self._cache_id, key, key_hint)
+
+ @status_to_exception(CacheError)
+ async def remove_keys(self, keys: list):
+ """
+ Removes cache entries by given list of keys, notifying listeners
+ and cache writers.
+
+ :param keys: list of keys or tuples of (key, key_hint) to remove.
+ """
+ conn = await self.get_best_node()
+ return await cache_remove_keys_async(conn, self._cache_id, keys)
+
+ @status_to_exception(CacheError)
+ async def remove_all(self):
+ """
+ Removes all cache entries, notifying listeners and cache writers.
+ """
+ conn = await self.get_best_node()
+ return await cache_remove_all_async(conn, self._cache_id)
+
+ @status_to_exception(CacheError)
+ async def remove_if_equals(self, key, sample, key_hint=None, sample_hint=None):
+ """
+        Removes an entry with a given key if the provided value is equal to
+        the actual value, notifying listeners and cache writers.
+
+ :param key: key for the cache entry,
+ :param sample: a sample to compare the stored value with,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+        :param sample_hint: (optional) Ignite data type, for which
+         the given sample should be converted.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ return await cache_remove_if_equals_async(conn, self._cache_id, key, sample, key_hint, sample_hint)
+
+ @status_to_exception(CacheError)
+ async def replace_if_equals(self, key, sample, value, key_hint=None, sample_hint=None, value_hint=None) -> Any:
+ """
+        Puts a value with a given key to cache only if the key already exists
+        and the value equals the provided sample.
+
+ :param key: key for the cache entry,
+ :param sample: a sample to compare the stored value with,
+ :param value: new value for the given key,
+ :param key_hint: (optional) Ignite data type, for which the given key
+ should be converted,
+        :param sample_hint: (optional) Ignite data type, for which
+         the given sample should be converted,
+ :param value_hint: (optional) Ignite data type, for which the given
+ value should be converted,
+ :return: boolean `True` when key is present, `False` otherwise.
+ """
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ conn = await self.get_best_node(key, key_hint)
+ result = await cache_replace_if_equals_async(conn, self._cache_id, key, sample, value, key_hint, sample_hint,
+ value_hint)
+ result.value = await self.client.unwrap_binary(result.value)
+ return result
+
+ @status_to_exception(CacheError)
+ async def get_size(self, peek_modes=0):
+ """
+ Gets the number of entries in cache.
+
+ :param peek_modes: (optional) limit count to near cache partition
+ (PeekModes.NEAR), primary cache (PeekModes.PRIMARY), or backup cache
+ (PeekModes.BACKUP). Defaults to all cache partitions (PeekModes.ALL),
+ :return: integer number of cache entries.
+ """
+ conn = await self.get_best_node()
+ return await cache_get_size_async(conn, self._cache_id, peek_modes)
+
+ def scan(self, page_size: int = 1, partitions: int = -1, local: bool = False):
+ """
+ Returns all key-value pairs from the cache, similar to `get_all`, but
+ with internal pagination, which is slower, but safer.
+
+ :param page_size: (optional) page size. Default size is 1 (slowest
+ and safest),
+ :param partitions: (optional) number of partitions to query
+ (negative to query entire cache),
+ :param local: (optional) pass True if this query should be executed
+ on local node only. Defaults to False,
+ :return: async scan query cursor
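+
+        A minimal sketch, assuming the returned cursor supports ``async with``
+        and ``async for`` (cache contents are illustrative)::
+
+            async with cache.scan(page_size=100) as cursor:
+                async for key, value in cursor:
+                    print(key, value)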
+ """
+ return AioScanCursor(self.client, self._cache_id, page_size, partitions, local)
diff --git a/pyignite/aio_client.py b/pyignite/aio_client.py
new file mode 100644
index 0000000..d882969
--- /dev/null
+++ b/pyignite/aio_client.py
@@ -0,0 +1,358 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import random
+from itertools import chain
+from typing import Iterable, Type, Union, Any
+
+from .api.binary import get_binary_type_async, put_binary_type_async
+from .api.cache_config import cache_get_names_async
+from .client import BaseClient
+from .cursors import AioSqlFieldsCursor
+from .aio_cache import AioCache, get_cache, create_cache, get_or_create_cache
+from .connection import AioConnection
+from .constants import IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT
+from .datatypes import BinaryObject
+from .exceptions import BinaryTypeError, CacheError, ReconnectError, connection_errors
+from .stream import AioBinaryStream, READ_BACKWARD
+from .utils import cache_id, entity_id, status_to_exception, is_iterable, is_wrapped
+
+
+__all__ = ['AioClient']
+
+
+class AioClient(BaseClient):
+ """
+ Asynchronous Client implementation.
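+
+    A minimal connection lifecycle sketch (addresses are illustrative)::
+
+        client = AioClient()
+        await client.connect([('127.0.0.1', 10800), ('127.0.0.1', 10801)])
+        try:
+            print(await client.get_cache_names())
+        finally:
+            await client.close()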
+ """
+
+ def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs):
+ """
+ Initialize client.
+
+ :param compact_footer: (optional) use compact (True, recommended) or
+ full (False) schema approach when serializing Complex objects.
+ Default is to use the same approach the server is using (None).
+ Apache Ignite binary protocol documentation on this topic:
+ https://apacheignite.readme.io/docs/binary-client-protocol-data-format#section-schema
+ :param partition_aware: (optional) try to calculate the exact data
+         placement from the key before issuing the key operation to the
+ server node:
+ https://cwiki.apache.org/confluence/display/IGNITE/IEP-23%3A+Best+Effort+Affinity+for+thin+clients
+ The feature is in experimental status, so the parameter is `False`
+ by default. This will be changed later.
+ """
+ super().__init__(compact_footer, partition_aware, **kwargs)
+ self._registry_mux = asyncio.Lock()
+
+ async def connect(self, *args):
+ """
+ Connect to Ignite cluster node(s).
+
+ :param args: (optional) host(s) and port(s) to connect to.
+ """
+ nodes = self._process_connect_args(*args)
+
+ for i, node in enumerate(nodes):
+ host, port = node
+ conn = AioConnection(self, host, port, **self._connection_args)
+
+ if not self.partition_aware:
+ try:
+ if self.protocol_version is None:
+ # open connection before adding to the pool
+ await conn.connect()
+
+ # do not try to open more nodes
+ self._current_node = i
+
+ except connection_errors:
+ conn.failed = True
+
+ self._nodes.append(conn)
+
+ if self.partition_aware:
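+            # in partition-aware mode open connections to all nodes
+            # concurrently; nodes that failed with a connection error are
+            # retried once more, any other error is re-raised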
+ connect_results = await asyncio.gather(
+ *[conn.connect() for conn in self._nodes],
+ return_exceptions=True
+ )
+
+ reconnect_coro = []
+ for i, res in enumerate(connect_results):
+ if isinstance(res, Exception):
+ if isinstance(res, connection_errors):
+ reconnect_coro.append(self._nodes[i].reconnect())
+ else:
+ raise res
+
+ await asyncio.gather(*reconnect_coro, return_exceptions=True)
+
+ if self.protocol_version is None:
+ raise ReconnectError('Can not connect.')
+
+ async def close(self):
+ await asyncio.gather(*[conn.close() for conn in self._nodes], return_exceptions=True)
+ self._nodes.clear()
+
+ async def random_node(self) -> AioConnection:
+ """
+        Returns a random usable node.
+
+        This method is not a part of the public API. Unless you wish to
+        extend the `pyignite` capabilities (with additional testing, logging,
+        examining connections, etc.) you probably should not use it.
+ """
+ if self.partition_aware:
+ # if partition awareness is used just pick a random connected node
+ return await self._get_random_node()
+ else:
+ # if partition awareness is not used then just return the current
+ # node if it's alive or the next usable node if connection with the
+ # current is broken
+ node = self._nodes[self._current_node]
+ if node.alive:
+ return node
+
+ # close current (supposedly failed) node
+ await self._nodes[self._current_node].close()
+
+ # advance the node index
+ self._current_node += 1
+ if self._current_node >= len(self._nodes):
+ self._current_node = 0
+
+ # prepare the list of node indexes to try to connect to
+ for i in chain(range(self._current_node, len(self._nodes)), range(self._current_node)):
+ node = self._nodes[i]
+ try:
+ await node.connect()
+ except connection_errors:
+ pass
+ else:
+ return node
+
+ # no nodes left
+ raise ReconnectError('Can not reconnect: out of nodes.')
+
+ async def _get_random_node(self, reconnect=True):
+ alive_nodes = [n for n in self._nodes if n.alive]
+ if alive_nodes:
+ return random.choice(alive_nodes)
+ elif reconnect:
+ await asyncio.gather(*[n.reconnect() for n in self._nodes], return_exceptions=True)
+ return await self._get_random_node(reconnect=False)
+ else:
+ # cannot choose from an empty sequence
+ raise ReconnectError('Can not reconnect: out of nodes.') from None
+
+ @status_to_exception(BinaryTypeError)
+ async def get_binary_type(self, binary_type: Union[str, int]) -> dict:
+ """
+ Gets the binary type information from the Ignite server. This is quite
+ a low-level implementation of Ignite thin client protocol's
+ `OP_GET_BINARY_TYPE` operation. You would probably want to use
+ :py:meth:`~pyignite.client.Client.query_binary_type` instead.
+
+ :param binary_type: binary type name or ID,
+ :return: binary type description − a dict with the following fields:
+
+ - `type_exists`: True if the type is registered, False otherwise. In
+ the latter case all the following fields are omitted,
+ - `type_id`: Complex object type ID,
+ - `type_name`: Complex object type name,
+ - `affinity_key_field`: string value or None,
+ - `is_enum`: False in case of Complex object registration,
+ - `schemas`: a list, containing the Complex object schemas in format:
+ OrderedDict[field name: field type hint]. A schema can be empty.
+ """
+ conn = await self.random_node()
+ result = await get_binary_type_async(conn, binary_type)
+ return self._process_get_binary_type_result(result)
+
+ @status_to_exception(BinaryTypeError)
+ async def put_binary_type(self, type_name: str, affinity_key_field: str = None, is_enum=False, schema: dict = None):
+ """
+        Registers binary type information in cluster. Does not update the
+        binary registry. This is a literal implementation of Ignite thin client
+ protocol's `OP_PUT_BINARY_TYPE` operation. You would probably want
+ to use :py:meth:`~pyignite.client.Client.register_binary_type` instead.
+
+ :param type_name: name of the data type being registered,
+ :param affinity_key_field: (optional) name of the affinity key field,
+        :param is_enum: (optional) register an enum if True, a binary object
+         otherwise. Defaults to False,
+        :param schema: (optional) when registering an enum, pass a dict
+         of enumerated parameter names as keys and integers as values.
+         When registering a binary type, pass a dict of field names: field
+         types. Binary type with no fields is OK.
+ """
+ conn = await self.random_node()
+ return await put_binary_type_async(conn, type_name, affinity_key_field, is_enum, schema)
+
+ async def register_binary_type(self, data_class: Type, affinity_key_field: str = None):
+ """
+ Register the given class as a representation of a certain Complex
+ object type. Discards autogenerated or previously registered class.
+
+ :param data_class: Complex object class,
+ :param affinity_key_field: (optional) affinity parameter.
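+
+        A minimal sketch (``Student`` is an illustrative class, assumed to be
+        declared with :class:`~pyignite.binary.GenericObjectMeta`)::
+
+            await client.register_binary_type(Student)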
+ """
+ if not await self.query_binary_type(data_class.type_id, data_class.schema_id):
+ await self.put_binary_type(data_class.type_name, affinity_key_field, schema=data_class.schema)
+
+ self._registry[data_class.type_id][data_class.schema_id] = data_class
+
+ async def query_binary_type(self, binary_type: Union[int, str], schema: Union[int, dict] = None):
+ """
+ Queries the registry of Complex object classes.
+
+ :param binary_type: Complex object type name or ID,
+ :param schema: (optional) Complex object schema or schema ID,
+ :return: found dataclass or None, if `schema` parameter is provided,
+ a dict of {schema ID: dataclass} format otherwise.
+ """
+ type_id = entity_id(binary_type)
+
+ result = self._get_from_registry(type_id, schema)
+
+ if not result:
+ async with self._registry_mux:
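+                # double-checked locking: another task may have synced the
+                # registry while this one was waiting for the mutex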
+ result = self._get_from_registry(type_id, schema)
+
+ if not result:
+ type_info = await self.get_binary_type(type_id)
+ self._sync_binary_registry(type_id, type_info)
+ return self._get_from_registry(type_id, schema)
+
+ return result
+
+ async def unwrap_binary(self, value: Any) -> Any:
+ """
+ Detects and recursively unwraps Binary Object.
+
+ :param value: anything that could be a Binary Object,
+ :return: the result of the Binary Object unwrapping with all other data
+ left intact.
+ """
+ if is_wrapped(value):
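+            # a wrapped Binary Object arrives as a (bytes blob, offset) pair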
+ blob, offset = value
+ with AioBinaryStream(self, blob) as stream:
+ data_class = await BinaryObject.parse_async(stream)
+ return await BinaryObject.to_python_async(stream.read_ctype(data_class, direction=READ_BACKWARD), self)
+ return value
+
+ async def create_cache(self, settings: Union[str, dict]) -> 'AioCache':
+ """
+        Creates Ignite cache by name. Raises `CacheError` if such a cache
+        already exists.
+
+ :param settings: cache name or dict of cache properties' codes
+ and values. All cache properties are documented here:
+ :ref:`cache_props`. See also the
+ :ref:`cache creation example <sql_cache_create>`,
+ :return: :class:`~pyignite.cache.Cache` object.
+ """
+ return await create_cache(self, settings)
+
+ async def get_or_create_cache(self, settings: Union[str, dict]) -> 'AioCache':
+ """
+        Creates Ignite cache, if it does not exist.
+
+ :param settings: cache name or dict of cache properties' codes
+ and values. All cache properties are documented here:
+ :ref:`cache_props`. See also the
+ :ref:`cache creation example <sql_cache_create>`,
+ :return: :class:`~pyignite.cache.Cache` object.
+ """
+ return await get_or_create_cache(self, settings)
+
+ async def get_cache(self, settings: Union[str, dict]) -> 'AioCache':
+ """
+        Creates a Cache object with the given cache name, without checking
+        that it exists on the server. If it does not, some kind of exception
+ (most probably `CacheError`) may be raised later.
+
+ :param settings: cache name or cache properties (but only `PROP_NAME`
+ property is allowed),
+ :return: :class:`~pyignite.cache.Cache` object.
+ """
+ return await get_cache(self, settings)
+
+ @status_to_exception(CacheError)
+ async def get_cache_names(self) -> list:
+ """
+ Gets existing cache names.
+
+ :return: list of cache names.
+ """
+ conn = await self.random_node()
+ return await cache_get_names_async(conn)
+
+ def sql(
+ self, query_str: str, page_size: int = 1024,
+ query_args: Iterable = None, schema: str = 'PUBLIC',
+ statement_type: int = 0, distributed_joins: bool = False,
+ local: bool = False, replicated_only: bool = False,
+ enforce_join_order: bool = False, collocated: bool = False,
+ lazy: bool = False, include_field_names: bool = False,
+ max_rows: int = -1, timeout: int = 0,
+ cache: Union[int, str, 'AioCache'] = None
+ ):
+ """
+ Runs an SQL query and returns its result.
+
+ :param query_str: SQL query string,
+ :param page_size: (optional) cursor page size. Default is 1024, which
+ means that client makes one server call per 1024 rows,
+ :param query_args: (optional) query arguments. List of values or
+ (value, type hint) tuples,
+ :param schema: (optional) schema for the query. Defaults to `PUBLIC`,
+ :param statement_type: (optional) statement type. Can be:
+
+ * StatementType.ALL − any type (default),
+ * StatementType.SELECT − select,
+ * StatementType.UPDATE − update.
+
+ :param distributed_joins: (optional) distributed joins. Defaults
+ to False,
+ :param local: (optional) pass True if this query should be executed
+ on local node only. Defaults to False,
+ :param replicated_only: (optional) whether query contains only
+ replicated tables or not. Defaults to False,
+ :param enforce_join_order: (optional) enforce join order. Defaults
+ to False,
+ :param collocated: (optional) whether your data is co-located or not.
+ Defaults to False,
+ :param lazy: (optional) lazy query execution. Defaults to False,
+ :param include_field_names: (optional) include field names in result.
+ Defaults to False,
+ :param max_rows: (optional) query-wide maximum of rows. Defaults to -1
+ (all rows),
+ :param timeout: (optional) non-negative timeout value in ms.
+ Zero disables timeout (default),
+        :param cache: (optional) name or ID of the cache to use to infer
+         schema. If set, the `schema` argument is ignored,
+        :return: async sql fields cursor with result rows as lists. If
+ `include_field_names` was set, the first row will hold field names.
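+
+        A minimal sketch (the table is illustrative; assumes the returned
+        cursor supports ``async with`` and ``async for``)::
+
+            async with client.sql('SELECT * FROM STUDENT') as cursor:
+                async for row in cursor:
+                    print(row)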
+ """
+
+ c_id = cache.cache_id if isinstance(cache, AioCache) else cache_id(cache)
+
+ if c_id != 0:
+ schema = None
+
+ return AioSqlFieldsCursor(self, c_id, query_str, page_size, query_args, schema, statement_type,
+ distributed_joins, local, replicated_only, enforce_join_order, collocated,
+ lazy, include_field_names, max_rows, timeout)
diff --git a/pyignite/api/__init__.py b/pyignite/api/__init__.py
index 7dbef0a..7deed8c 100644
--- a/pyignite/api/__init__.py
+++ b/pyignite/api/__init__.py
@@ -23,53 +23,55 @@
stable end user API see :mod:`pyignite.client` module.
"""
+# flake8: noqa
+
from .affinity import (
- cache_get_node_partitions,
+ cache_get_node_partitions, cache_get_node_partitions_async,
)
from .cache_config import (
- cache_create,
- cache_get_names,
- cache_get_or_create,
- cache_destroy,
- cache_get_configuration,
- cache_create_with_config,
- cache_get_or_create_with_config,
+ cache_create, cache_create_async,
+ cache_get_names, cache_get_names_async,
+ cache_get_or_create, cache_get_or_create_async,
+ cache_destroy, cache_destroy_async,
+ cache_get_configuration, cache_get_configuration_async,
+ cache_create_with_config, cache_create_with_config_async,
+ cache_get_or_create_with_config, cache_get_or_create_with_config_async,
)
from .key_value import (
- cache_get,
- cache_put,
- cache_get_all,
- cache_put_all,
- cache_contains_key,
- cache_contains_keys,
- cache_get_and_put,
- cache_get_and_replace,
- cache_get_and_remove,
- cache_put_if_absent,
- cache_get_and_put_if_absent,
- cache_replace,
- cache_replace_if_equals,
- cache_clear,
- cache_clear_key,
- cache_clear_keys,
- cache_remove_key,
- cache_remove_if_equals,
- cache_remove_keys,
- cache_remove_all,
- cache_get_size,
- cache_local_peek,
+ cache_get, cache_get_async,
+ cache_put, cache_put_async,
+ cache_get_all, cache_get_all_async,
+ cache_put_all, cache_put_all_async,
+ cache_contains_key, cache_contains_key_async,
+ cache_contains_keys, cache_contains_keys_async,
+ cache_get_and_put, cache_get_and_put_async,
+ cache_get_and_replace, cache_get_and_replace_async,
+ cache_get_and_remove, cache_get_and_remove_async,
+ cache_put_if_absent, cache_put_if_absent_async,
+ cache_get_and_put_if_absent, cache_get_and_put_if_absent_async,
+ cache_replace, cache_replace_async,
+ cache_replace_if_equals, cache_replace_if_equals_async,
+ cache_clear, cache_clear_async,
+ cache_clear_key, cache_clear_key_async,
+ cache_clear_keys, cache_clear_keys_async,
+ cache_remove_key, cache_remove_key_async,
+ cache_remove_if_equals, cache_remove_if_equals_async,
+ cache_remove_keys, cache_remove_keys_async,
+ cache_remove_all, cache_remove_all_async,
+ cache_get_size, cache_get_size_async,
+ cache_local_peek, cache_local_peek_async,
)
from .sql import (
- scan,
- scan_cursor_get_page,
+ scan, scan_async,
+ scan_cursor_get_page, scan_cursor_get_page_async,
sql,
sql_cursor_get_page,
- sql_fields,
- sql_fields_cursor_get_page,
- resource_close,
+ sql_fields, sql_fields_async,
+ sql_fields_cursor_get_page, sql_fields_cursor_get_page_async,
+ resource_close, resource_close_async
)
from .binary import (
- get_binary_type,
- put_binary_type,
+ get_binary_type, get_binary_type_async,
+ put_binary_type, put_binary_type_async
)
from .result import APIResult
diff --git a/pyignite/api/affinity.py b/pyignite/api/affinity.py
index 7d09517..ddf1e7a 100644
--- a/pyignite/api/affinity.py
+++ b/pyignite/api/affinity.py
@@ -15,9 +15,10 @@
from typing import Iterable, Union
+from pyignite.connection import AioConnection, Connection
from pyignite.datatypes import Bool, Int, Long, UUIDObject
from pyignite.datatypes.internal import StructArray, Conditional, Struct
-from pyignite.queries import Query
+from pyignite.queries import Query, query_perform
from pyignite.queries.op_codes import OP_CACHE_PARTITIONS
from pyignite.utils import is_iterable
from .result import APIResult
@@ -67,10 +68,7 @@
])
-def cache_get_node_partitions(
- conn: 'Connection', caches: Union[int, Iterable[int]],
- query_id: int = None,
-) -> APIResult:
+def cache_get_node_partitions(conn: 'Connection', caches: Union[int, Iterable[int]], query_id: int = None) -> APIResult:
"""
Gets partition mapping for an Ignite cache or a number of caches. See
“IEP-23: Best Effort Affinity for thin clients”.
@@ -82,6 +80,62 @@
is generated,
:return: API result data object.
"""
+ return __cache_get_node_partitions(conn, caches, query_id)
+
+
+async def cache_get_node_partitions_async(conn: 'AioConnection', caches: Union[int, Iterable[int]],
+ query_id: int = None) -> APIResult:
+ """
+ Async version of cache_get_node_partitions.
+ """
+ return await __cache_get_node_partitions(conn, caches, query_id)
+
+
+def __post_process_partitions(result):
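+    # reshape the raw OP_CACHE_PARTITIONS response into
+    # {'version': (major, minor), 'partition_mapping': {cache_id: mapping}},
+    # where each mapping holds 'is_applicable', 'number_of_partitions' and,
+    # when applicable, 'cache_config' and 'node_mapping'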
+ if result.status == 0:
+ # tidying up the result
+ value = {
+ 'version': (
+ result.value['version_major'],
+ result.value['version_minor']
+ ),
+ 'partition_mapping': {},
+ }
+ for partition_map in result.value['partition_mapping']:
+ is_applicable = partition_map['is_applicable']
+
+ node_mapping = None
+ if is_applicable:
+ node_mapping = {
+ p['node_uuid']: set(x['partition_id'] for x in p['node_partitions'])
+ for p in partition_map['node_mapping']
+ }
+
+ for cache_info in partition_map['cache_mapping']:
+ cache_id = cache_info['cache_id']
+
+ cache_partition_mapping = {
+ 'is_applicable': is_applicable,
+ }
+
+ parts = 0
+ if is_applicable:
+ cache_partition_mapping['cache_config'] = {
+ a['key_type_id']: a['affinity_key_field_id']
+ for a in cache_info['cache_config']
+ }
+ cache_partition_mapping['node_mapping'] = node_mapping
+
+ parts = sum(len(p) for p in cache_partition_mapping['node_mapping'].values())
+
+ cache_partition_mapping['number_of_partitions'] = parts
+
+ value['partition_mapping'][cache_id] = cache_partition_mapping
+ result.value = value
+ return result
+
+
+def __cache_get_node_partitions(conn, caches, query_id):
query_struct = Query(
OP_CACHE_PARTITIONS,
[
@@ -92,7 +146,8 @@
if not is_iterable(caches):
caches = [caches]
- result = query_struct.perform(
+ return query_perform(
+ query_struct,
conn,
query_params={
'cache_ids': [{'cache_id': cache} for cache in caches],
@@ -102,36 +157,5 @@
('version_minor', Int),
('partition_mapping', partition_mapping),
],
+ post_process_fun=__post_process_partitions
)
- if result.status == 0:
- # tidying up the result
- value = {
- 'version': (
- result.value['version_major'],
- result.value['version_minor']
- ),
- 'partition_mapping': [],
- }
- for i, partition_map in enumerate(result.value['partition_mapping']):
- cache_id = partition_map['cache_mapping'][0]['cache_id']
- value['partition_mapping'].insert(
- i,
- {
- 'cache_id': cache_id,
- 'is_applicable': partition_map['is_applicable'],
- }
- )
- if partition_map['is_applicable']:
- value['partition_mapping'][i]['cache_config'] = {
- a['key_type_id']: a['affinity_key_field_id']
- for a in partition_map['cache_mapping'][0]['cache_config']
- }
- value['partition_mapping'][i]['node_mapping'] = {
- p['node_uuid']: [
- x['partition_id'] for x in p['node_partitions']
- ]
- for p in partition_map['node_mapping']
- }
- result.value = value
-
- return result
diff --git a/pyignite/api/binary.py b/pyignite/api/binary.py
index 87a5232..345e8e8 100644
--- a/pyignite/api/binary.py
+++ b/pyignite/api/binary.py
@@ -15,17 +15,15 @@
from typing import Union
-from pyignite.constants import *
-from pyignite.datatypes.binary import (
- body_struct, enum_struct, schema_struct, binary_fields_struct,
-)
+from pyignite.connection import Connection, AioConnection
+from pyignite.constants import PROTOCOL_BYTE_ORDER
+from pyignite.datatypes.binary import enum_struct, schema_struct, binary_fields_struct
from pyignite.datatypes import String, Int, Bool
-from pyignite.queries import Query
-from pyignite.queries.op_codes import *
+from pyignite.queries import Query, query_perform
+from pyignite.queries.op_codes import OP_GET_BINARY_TYPE, OP_PUT_BINARY_TYPE
from pyignite.utils import entity_id, schema_id
from .result import APIResult
-from ..stream import BinaryStream, READ_BACKWARD
-from ..queries.response import Response
+from ..queries.response import BinaryTypeResponse
def get_binary_type(conn: 'Connection', binary_type: Union[str, int], query_id=None) -> APIResult:
@@ -39,75 +37,33 @@
is generated,
:return: API result data object.
"""
+ return __get_binary_type(conn, binary_type, query_id)
+
+
+async def get_binary_type_async(conn: 'AioConnection', binary_type: Union[str, int], query_id=None) -> APIResult:
+ """
+ Async version of get_binary_type.
+ """
+ return await __get_binary_type(conn, binary_type, query_id)
+
+
+def __get_binary_type(conn, binary_type, query_id):
query_struct = Query(
OP_GET_BINARY_TYPE,
[
('type_id', Int),
],
query_id=query_id,
+ response_type=BinaryTypeResponse
)
- with BinaryStream(conn) as stream:
- query_struct.from_python(stream, {
- 'type_id': entity_id(binary_type),
- })
- conn.send(stream.getbuffer())
-
- response_head_struct = Response(protocol_version=conn.get_protocol_version(),
- following=[('type_exists', Bool)])
-
- with BinaryStream(conn, conn.recv()) as stream:
- init_pos = stream.tell()
- response_head_type = response_head_struct.parse(stream)
- response_head = stream.read_ctype(response_head_type, direction=READ_BACKWARD)
-
- response_parts = []
- if response_head.type_exists:
- resp_body_type = body_struct.parse(stream)
- response_parts.append(('body', resp_body_type))
- resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD)
- if resp_body.is_enum:
- resp_enum = enum_struct.parse(stream)
- response_parts.append(('enums', resp_enum))
-
- resp_schema_type = schema_struct.parse(stream)
- response_parts.append(('schema', resp_schema_type))
-
- response_class = type(
- 'GetBinaryTypeResponse',
- (response_head_type,),
- {
- '_pack_': 1,
- '_fields_': response_parts,
- }
- )
- response = stream.read_ctype(response_class, position=init_pos)
-
- result = APIResult(response)
- if result.status != 0:
- return result
- result.value = {
- 'type_exists': Bool.to_python(response.type_exists)
- }
- if hasattr(response, 'body'):
- result.value.update(body_struct.to_python(response.body))
- if hasattr(response, 'enums'):
- result.value['enums'] = enum_struct.to_python(response.enums)
- if hasattr(response, 'schema'):
- result.value['schema'] = {
- x['schema_id']: [
- z['schema_field_id'] for z in x['schema_fields']
- ]
- for x in schema_struct.to_python(response.schema)
- }
- return result
+ return query_perform(query_struct, conn, query_params={
+ 'type_id': entity_id(binary_type),
+ })
-def put_binary_type(
- connection: 'Connection', type_name: str, affinity_key_field: str=None,
- is_enum=False, schema: dict=None, query_id=None,
-) -> APIResult:
+def put_binary_type(connection: 'Connection', type_name: str, affinity_key_field: str = None,
+ is_enum=False, schema: dict = None, query_id=None) -> APIResult:
"""
Registers binary type information in cluster.
@@ -125,6 +81,29 @@
is generated,
:return: API result data object.
"""
+ return __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id)
+
+
+async def put_binary_type_async(connection: 'AioConnection', type_name: str, affinity_key_field: str = None,
+ is_enum=False, schema: dict = None, query_id=None) -> APIResult:
+ """
+ Async version of put_binary_type.
+ """
+ return await __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id)
+
+
+def __post_process_put_binary(type_id, schema_id):
+    # the schema ID is passed in explicitly: a bare `schema_id` in this
+    # module-level closure would not see the value computed by the caller
+    def internal(result):
+        if result.status == 0:
+            result.value = {
+                'type_id': type_id,
+                'schema_id': schema_id,
+            }
+        return result
+    return internal
+
+
+def __put_binary_type(connection, type_name, affinity_key_field, is_enum, schema, query_id):
# prepare data
if schema is None:
schema = {}
@@ -195,10 +174,5 @@
],
query_id=query_id,
)
- result = query_struct.perform(connection, query_params=data)
- if result.status == 0:
- result.value = {
- 'type_id': type_id,
- 'schema_id': schema_id,
- }
- return result
+    return query_perform(query_struct, connection, query_params=data,
+                         post_process_fun=__post_process_put_binary(type_id, schema_id))
diff --git a/pyignite/api/cache_config.py b/pyignite/api/cache_config.py
index cfea416..0adb549 100644
--- a/pyignite/api/cache_config.py
+++ b/pyignite/api/cache_config.py
@@ -25,15 +25,19 @@
from typing import Union
+from pyignite.connection import Connection, AioConnection
from pyignite.datatypes.cache_config import cache_config_struct
from pyignite.datatypes.cache_properties import prop_map
-from pyignite.datatypes import (
- Int, Byte, prop_codes, Short, String, StringArray,
+from pyignite.datatypes import Int, Byte, prop_codes, Short, String, StringArray
+from pyignite.queries import Query, ConfigQuery, query_perform
+from pyignite.queries.op_codes import (
+ OP_CACHE_GET_CONFIGURATION, OP_CACHE_CREATE_WITH_NAME, OP_CACHE_GET_OR_CREATE_WITH_NAME, OP_CACHE_DESTROY,
+ OP_CACHE_GET_NAMES, OP_CACHE_CREATE_WITH_CONFIGURATION, OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION
)
-from pyignite.queries import Query, ConfigQuery
-from pyignite.queries.op_codes import *
from pyignite.utils import cache_id
+from .result import APIResult
+
def compact_cache_config(cache_config: dict) -> dict:
"""
@@ -48,14 +52,13 @@
for k, v in cache_config.items():
if k == 'length':
continue
- prop_code = getattr(prop_codes, 'PROP_{}'.format(k.upper()))
+ prop_code = getattr(prop_codes, f'PROP_{k.upper()}')
result[prop_code] = v
return result
-def cache_get_configuration(
- connection: 'Connection', cache: Union[str, int], flags: int=0, query_id=None,
-) -> 'APIResult':
+def cache_get_configuration(connection: 'Connection', cache: Union[str, int],
+ flags: int = 0, query_id=None) -> 'APIResult':
"""
Gets configuration for the given cache.
@@ -68,7 +71,24 @@
:return: API result data object. Result value is OrderedDict with
the cache configuration parameters.
"""
+ return __cache_get_configuration(connection, cache, flags, query_id)
+
+async def cache_get_configuration_async(connection: 'AioConnection', cache: Union[str, int],
+ flags: int = 0, query_id=None) -> 'APIResult':
+ """
+ Async version of cache_get_configuration.
+ """
+ return await __cache_get_configuration(connection, cache, flags, query_id)
+
+
+def __post_process_cache_config(result):
+ if result.status == 0:
+ result.value = compact_cache_config(result.value['cache_config'])
+ return result
+
+
+def __cache_get_configuration(connection, cache, flags, query_id):
query_struct = Query(
OP_CACHE_GET_CONFIGURATION,
[
@@ -77,24 +97,19 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
- query_params={
- 'hash_code': cache_id(cache),
- 'flags': flags,
- },
- response_config=[
- ('cache_config', cache_config_struct),
- ],
- )
- if result.status == 0:
- result.value = compact_cache_config(result.value['cache_config'])
- return result
+ return query_perform(query_struct, connection,
+ query_params={
+ 'hash_code': cache_id(cache),
+ 'flags': flags
+ },
+ response_config=[
+ ('cache_config', cache_config_struct)
+ ],
+ post_process_fun=__post_process_cache_config
+ )
-def cache_create(
- connection: 'Connection', name: str, query_id=None,
-) -> 'APIResult':
+def cache_create(connection: 'Connection', name: str, query_id=None) -> 'APIResult':
"""
Creates a cache with a given name. Returns error if a cache with specified
name already exists.
@@ -108,24 +123,18 @@
created successfully, non-zero status and an error description otherwise.
"""
- query_struct = Query(
- OP_CACHE_CREATE_WITH_NAME,
- [
- ('cache_name', String),
- ],
- query_id=query_id,
- )
- return query_struct.perform(
- connection,
- query_params={
- 'cache_name': name,
- },
- )
+ return __cache_create_with_name(OP_CACHE_CREATE_WITH_NAME, connection, name, query_id)
-def cache_get_or_create(
- connection: 'Connection', name: str, query_id=None,
-) -> 'APIResult':
+async def cache_create_async(connection: 'AioConnection', name: str, query_id=None) -> 'APIResult':
+ """
+ Async version of cache_create.
+ """
+
+ return await __cache_create_with_name(OP_CACHE_CREATE_WITH_NAME, connection, name, query_id)
+
+
+def cache_get_or_create(connection: 'Connection', name: str, query_id=None) -> 'APIResult':
"""
Creates a cache with a given name. Does nothing if the cache exists.
@@ -138,24 +147,22 @@
created successfully, non-zero status and an error description otherwise.
"""
- query_struct = Query(
- OP_CACHE_GET_OR_CREATE_WITH_NAME,
- [
- ('cache_name', String),
- ],
- query_id=query_id,
- )
- return query_struct.perform(
- connection,
- query_params={
- 'cache_name': name,
- },
- )
+ return __cache_create_with_name(OP_CACHE_GET_OR_CREATE_WITH_NAME, connection, name, query_id)
-def cache_destroy(
- connection: 'Connection', cache: Union[str, int], query_id=None,
-) -> 'APIResult':
+async def cache_get_or_create_async(connection: 'AioConnection', name: str, query_id=None) -> 'APIResult':
+ """
+ Async version of cache_get_or_create.
+ """
+ return await __cache_create_with_name(OP_CACHE_GET_OR_CREATE_WITH_NAME, connection, name, query_id)
+
+
+def __cache_create_with_name(op_code, conn, name, query_id):
+ query_struct = Query(op_code, [('cache_name', String)], query_id=query_id)
+ return query_perform(query_struct, conn, query_params={'cache_name': name})
+
+
+def cache_destroy(connection: 'Connection', cache: Union[str, int], query_id=None) -> 'APIResult':
"""
Destroys cache with a given name.
@@ -166,19 +173,20 @@
is generated,
:return: API result data object.
"""
+ return __cache_destroy(connection, cache, query_id)
- query_struct = Query(
- OP_CACHE_DESTROY,[
- ('hash_code', Int),
- ],
- query_id=query_id,
- )
- return query_struct.perform(
- connection,
- query_params={
- 'hash_code': cache_id(cache),
- },
- )
+
+async def cache_destroy_async(connection: 'AioConnection', cache: Union[str, int], query_id=None) -> 'APIResult':
+ """
+ Async version of cache_destroy.
+ """
+ return await __cache_destroy(connection, cache, query_id)
+
+
+def __cache_destroy(connection, cache, query_id):
+ query_struct = Query(OP_CACHE_DESTROY, [('hash_code', Int)], query_id=query_id)
+
+ return query_perform(query_struct, connection, query_params={'hash_code': cache_id(cache)})
def cache_get_names(connection: 'Connection', query_id=None) -> 'APIResult':
@@ -193,21 +201,30 @@
names, non-zero status and an error description otherwise.
"""
- query_struct = Query(OP_CACHE_GET_NAMES, query_id=query_id)
- result = query_struct.perform(
- connection,
- response_config=[
- ('cache_names', StringArray),
- ],
- )
+ return __cache_get_names(connection, query_id)
+
+
+async def cache_get_names_async(connection: 'AioConnection', query_id=None) -> 'APIResult':
+ """
+ Async version of cache_get_names.
+ """
+ return await __cache_get_names(connection, query_id)
+
+
+def __post_process_cache_names(result):
if result.status == 0:
result.value = result.value['cache_names']
return result
-def cache_create_with_config(
- connection: 'Connection', cache_props: dict, query_id=None,
-) -> 'APIResult':
+def __cache_get_names(connection, query_id):
+ query_struct = Query(OP_CACHE_GET_NAMES, query_id=query_id)
+ return query_perform(query_struct, connection,
+ response_config=[('cache_names', StringArray)],
+ post_process_fun=__post_process_cache_names)
+
+
+def cache_create_with_config(connection: 'Connection', cache_props: dict, query_id=None) -> 'APIResult':
"""
Creates cache with provided configuration. An error is returned
if the name is already in use.
@@ -222,29 +239,17 @@
:return: API result data object. Contains zero status if cache was created,
non-zero status and an error description otherwise.
"""
-
- prop_types = {}
- prop_values = {}
- for i, prop_item in enumerate(cache_props.items()):
- prop_code, prop_value = prop_item
- prop_name = 'property_{}'.format(i)
- prop_types[prop_name] = prop_map(prop_code)
- prop_values[prop_name] = prop_value
- prop_values['param_count'] = len(cache_props)
-
- query_struct = ConfigQuery(
- OP_CACHE_CREATE_WITH_CONFIGURATION,
- [
- ('param_count', Short),
- ] + list(prop_types.items()),
- query_id=query_id,
- )
- return query_struct.perform(connection, query_params=prop_values)
+ return __cache_create_with_config(OP_CACHE_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id)
-def cache_get_or_create_with_config(
- connection: 'Connection', cache_props: dict, query_id=None,
-) -> 'APIResult':
+async def cache_create_with_config_async(connection: 'AioConnection', cache_props: dict, query_id=None) -> 'APIResult':
+ """
+ Async version of cache_create_with_config.
+ """
+ return await __cache_create_with_config(OP_CACHE_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id)
+
+
+def cache_get_or_create_with_config(connection: 'Connection', cache_props: dict, query_id=None) -> 'APIResult':
"""
Creates cache with provided configuration. Does nothing if the name
is already in use.
@@ -259,9 +264,20 @@
:return: API result data object. Contains zero status if cache was created,
non-zero status and an error description otherwise.
"""
+ return __cache_create_with_config(OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION, connection, cache_props, query_id)
- prop_types = {}
- prop_values = {}
+
+async def cache_get_or_create_with_config_async(connection: 'AioConnection', cache_props: dict,
+ query_id=None) -> 'APIResult':
+ """
+ Async version of cache_get_or_create_with_config.
+ """
+ return await __cache_create_with_config(OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION, connection, cache_props,
+ query_id)
+
+
+def __cache_create_with_config(op_code, connection, cache_props, query_id):
+ prop_types, prop_values = {}, {}
for i, prop_item in enumerate(cache_props.items()):
prop_code, prop_value = prop_item
prop_name = 'property_{}'.format(i)
@@ -269,11 +285,6 @@
prop_values[prop_name] = prop_value
prop_values['param_count'] = len(cache_props)
- query_struct = ConfigQuery(
- OP_CACHE_GET_OR_CREATE_WITH_CONFIGURATION,
- [
- ('param_count', Short),
- ] + list(prop_types.items()),
- query_id=query_id,
- )
- return query_struct.perform(connection, query_params=prop_values)
+ following = [('param_count', Short)] + list(prop_types.items())
+ query_struct = ConfigQuery(op_code, following, query_id=query_id)
+ return query_perform(query_struct, connection, query_params=prop_values)
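Note on the pattern above: each public function now delegates to a private helper that calls `query_perform`, so the sync wrapper returns an `APIResult` directly while the async wrapper awaits the same call. The real dispatch lives in `pyignite/queries` and is not shown in this hunk; the following is a minimal, self-contained sketch of the idea, with `_SyncConn`, `_AioConn`, `_EchoQuery` and `perform_async` as illustrative stand-ins rather than the library's actual implementation:

import asyncio

class _SyncConn: ...
class _AioConn: ...   # stand-in for pyignite.connection.AioConnection

def query_perform(query_struct, conn, post_process_fun=None, **kwargs):
    # Dispatch on connection type: async callers receive a coroutine to await,
    # sync callers receive the result directly.
    async def _async():
        result = await query_struct.perform_async(conn, **kwargs)
        return post_process_fun(result) if post_process_fun else result

    if isinstance(conn, _AioConn):
        return _async()
    result = query_struct.perform(conn, **kwargs)
    return post_process_fun(result) if post_process_fun else result

class _EchoQuery:
    # Toy query object: perform/perform_async just echo their params back.
    def perform(self, conn, **kwargs):
        return kwargs

    async def perform_async(self, conn, **kwargs):
        return kwargs

assert query_perform(_EchoQuery(), _SyncConn(), x=1) == {'x': 1}
assert asyncio.run(query_perform(_EchoQuery(), _AioConn(), x=1)) == {'x': 1}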
diff --git a/pyignite/api/key_value.py b/pyignite/api/key_value.py
index 25601e9..6d5663c 100644
--- a/pyignite/api/key_value.py
+++ b/pyignite/api/key_value.py
@@ -15,20 +15,26 @@
from typing import Any, Iterable, Optional, Union
-from pyignite.queries.op_codes import *
-from pyignite.datatypes import (
- Map, Bool, Byte, Int, Long, AnyDataArray, AnyDataObject,
+from pyignite.connection import AioConnection, Connection
+from pyignite.queries.op_codes import (
+ OP_CACHE_PUT, OP_CACHE_GET, OP_CACHE_GET_ALL, OP_CACHE_PUT_ALL, OP_CACHE_CONTAINS_KEY, OP_CACHE_CONTAINS_KEYS,
+ OP_CACHE_GET_AND_PUT, OP_CACHE_GET_AND_REPLACE, OP_CACHE_GET_AND_REMOVE, OP_CACHE_PUT_IF_ABSENT,
+ OP_CACHE_GET_AND_PUT_IF_ABSENT, OP_CACHE_REPLACE, OP_CACHE_REPLACE_IF_EQUALS, OP_CACHE_CLEAR, OP_CACHE_CLEAR_KEY,
+ OP_CACHE_CLEAR_KEYS, OP_CACHE_REMOVE_KEY, OP_CACHE_REMOVE_IF_EQUALS, OP_CACHE_REMOVE_KEYS, OP_CACHE_REMOVE_ALL,
+ OP_CACHE_GET_SIZE, OP_CACHE_LOCAL_PEEK
)
+from pyignite.datatypes import Map, Bool, Byte, Int, Long, AnyDataArray, AnyDataObject
+from pyignite.datatypes.base import IgniteDataType
from pyignite.datatypes.key_value import PeekModes
-from pyignite.queries import Query
+from pyignite.queries import Query, query_perform
from pyignite.utils import cache_id
+from .result import APIResult
-def cache_put(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+
+def cache_put(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache (overwriting existing value if any).
@@ -48,7 +54,19 @@
:return: API result data object. Contains zero status if a value
is written, non-zero status and an error description otherwise.
"""
+ return __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+async def cache_put_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_put.
+ """
+ return await __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+
+def __cache_put(connection, cache, key, value, key_hint, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_PUT,
[
@@ -59,19 +77,19 @@
],
query_id=query_id,
)
- return query_struct.perform(connection, {
- 'hash_code': cache_id(cache),
- 'flag': 1 if binary else 0,
- 'key': key,
- 'value': value,
- })
+ return query_perform(
+ query_struct, connection,
+ query_params={
+ 'hash_code': cache_id(cache),
+ 'flag': 1 if binary else 0,
+ 'key': key,
+ 'value': value
+ }
+ )
-def cache_get(
- connection: 'Connection', cache: Union[str, int], key: Any,
- key_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Retrieves a value from cache by key.
@@ -88,7 +106,19 @@
:return: API result data object. Contains zero status and a value
retrieved on success, non-zero status and an error description on failure.
"""
+ return __cache_get(connection, cache, key, key_hint, binary, query_id)
+
+async def cache_get_async(connection: 'AioConnection', cache: Union[str, int], key: Any,
+ key_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get.
+ """
+ return await __cache_get(connection, cache, key, key_hint, binary, query_id)
+
+
+def __cache_get(connection, cache, key, key_hint, binary, query_id):
query_struct = Query(
OP_CACHE_GET,
[
@@ -98,27 +128,22 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
'key': key,
},
response_config=[
- ('value', AnyDataObject),
+ ('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status != 0:
- return result
- result.value = result.value['value']
- return result
-def cache_get_all(
- connection: 'Connection', cache: Union[str, int], keys: Iterable,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_all(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Retrieves multiple key-value pairs from cache.
@@ -134,7 +159,18 @@
retrieved key-value pairs, non-zero status and an error description
on failure.
"""
+ return __cache_get_all(connection, cache, keys, binary, query_id)
+
+async def cache_get_all_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_all.
+ """
+ return await __cache_get_all(connection, cache, keys, binary, query_id)
+
+
+def __cache_get_all(connection, cache, keys, binary, query_id):
query_struct = Query(
OP_CACHE_GET_ALL,
[
@@ -144,8 +180,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -154,16 +190,12 @@
response_config=[
('data', Map),
],
+ post_process_fun=__post_process_value_by_key('data')
)
- if result.status == 0:
- result.value = dict(result.value)['data']
- return result
-def cache_put_all(
- connection: 'Connection', cache: Union[str, int], pairs: dict,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_put_all(connection: 'Connection', cache: Union[str, int], pairs: dict, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts multiple key-value pairs to cache (overwriting existing associations
if any).
@@ -181,7 +213,18 @@
:return: API result data object. Contains zero status if key-value pairs
are written, non-zero status and an error description otherwise.
"""
+ return __cache_put_all(connection, cache, pairs, binary, query_id)
+
+async def cache_put_all_async(connection: 'AioConnection', cache: Union[str, int], pairs: dict, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_put_all.
+ """
+ return await __cache_put_all(connection, cache, pairs, binary, query_id)
+
+
+def __cache_put_all(connection, cache, pairs, binary, query_id):
query_struct = Query(
OP_CACHE_PUT_ALL,
[
@@ -191,8 +234,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -201,11 +244,8 @@
)
-def cache_contains_key(
- connection: 'Connection', cache: Union[str, int], key: Any,
- key_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_contains_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Returns a value indicating whether given key is present in cache.
@@ -223,7 +263,19 @@
retrieved on success: `True` when key is present, `False` otherwise,
non-zero status and an error description on failure.
"""
+ return __cache_contains_key(connection, cache, key, key_hint, binary, query_id)
+
+async def cache_contains_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any,
+ key_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_contains_key.
+ """
+ return await __cache_contains_key(connection, cache, key, key_hint, binary, query_id)
+
+
+def __cache_contains_key(connection, cache, key, key_hint, binary, query_id):
query_struct = Query(
OP_CACHE_CONTAINS_KEY,
[
@@ -233,9 +285,9 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
- query_params={
+ return query_perform(
+ query_struct, connection,
+ query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
'key': key,
@@ -243,16 +295,12 @@
response_config=[
('value', Bool),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_contains_keys(
- connection: 'Connection', cache: Union[str, int], keys: Iterable,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_contains_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Returns a value indicating whether all given keys are present in cache.
@@ -268,7 +316,18 @@
retrieved on success: `True` when all keys are present, `False` otherwise,
non-zero status and an error description on failure.
"""
+ return __cache_contains_keys(connection, cache, keys, binary, query_id)
+
+async def cache_contains_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_contains_keys.
+ """
+ return await __cache_contains_keys(connection, cache, keys, binary, query_id)
+
+
+def __cache_contains_keys(connection, cache, keys, binary, query_id):
query_struct = Query(
OP_CACHE_CONTAINS_KEYS,
[
@@ -278,8 +337,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -288,17 +347,13 @@
response_config=[
('value', Bool),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_get_and_put(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_and_put(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache, and returns the previous value
for that key, or a null value if there was no such key.
@@ -320,7 +375,19 @@
or None if a value is written, non-zero status and an error description
in case of error.
"""
+ return __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+async def cache_get_and_put_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_and_put.
+ """
+ return await __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+
+def __cache_get_and_put(connection, cache, key, value, key_hint, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_GET_AND_PUT,
[
@@ -331,8 +398,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -342,17 +409,13 @@
response_config=[
('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_get_and_replace(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_and_replace(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache, returning previous value
for that key, if and only if there is a value currently mapped
@@ -374,7 +437,19 @@
:return: API result data object. Contains zero status and an old value
or None on success, non-zero status and an error description otherwise.
"""
+ return __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id)
+
+async def cache_get_and_replace_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_and_replace.
+ """
+ return await __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id)
+
+
+def __cache_get_and_replace(connection, cache, key, key_hint, value, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_GET_AND_REPLACE, [
('hash_code', Int),
@@ -384,8 +459,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -395,17 +470,12 @@
response_config=[
('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_get_and_remove(
- connection: 'Connection', cache: Union[str, int], key: Any,
- key_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_and_remove(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Removes the cache entry with specified key, returning the value.
@@ -422,7 +492,16 @@
:return: API result data object. Contains zero status and an old value
or None, non-zero status and an error description otherwise.
"""
+ return __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id)
+
+async def cache_get_and_remove_async(connection: 'AioConnection', cache: Union[str, int], key: Any,
+ key_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_and_remove.
+ """
+ return await __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id)
+
+
+def __cache_get_and_remove(connection, cache, key, key_hint, binary, query_id):
query_struct = Query(
OP_CACHE_GET_AND_REMOVE, [
('hash_code', Int),
@@ -431,8 +510,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -441,17 +520,13 @@
response_config=[
('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_put_if_absent(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_put_if_absent(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache only if the key
does not already exist.
@@ -472,7 +547,19 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+async def cache_put_if_absent_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_put_if_absent.
+ """
+ return await __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+
+def __cache_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_PUT_IF_ABSENT,
[
@@ -483,8 +570,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -494,17 +581,13 @@
response_config=[
('success', Bool),
],
+ post_process_fun=__post_process_value_by_key('success')
)
- if result.status == 0:
- result.value = result.value['success']
- return result
-def cache_get_and_put_if_absent(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_and_put_if_absent(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache only if the key does not
already exist.
@@ -525,7 +608,19 @@
:return: API result data object. Contains zero status and an old value
or None on success, non-zero status and an error description otherwise.
"""
+ return __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+async def cache_get_and_put_if_absent_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_and_put_if_absent.
+ """
+ return await __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+
+def __cache_get_and_put_if_absent(connection, cache, key, value, key_hint, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_GET_AND_PUT_IF_ABSENT,
[
@@ -536,8 +631,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -547,17 +642,13 @@
response_config=[
('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status == 0:
- result.value = result.value['value']
- return result
-def cache_replace(
- connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
- key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_replace(connection: 'Connection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache only if the key already exists.
@@ -578,7 +669,19 @@
success code, or non-zero status and an error description if something
has gone wrong.
"""
+ return __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+async def cache_replace_async(connection: 'AioConnection', cache: Union[str, int], key: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_replace.
+ """
+ return await __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id)
+
+
+def __cache_replace(connection, cache, key, value, key_hint, value_hint, binary, query_id):
query_struct = Query(
OP_CACHE_REPLACE,
[
@@ -589,8 +692,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -600,18 +703,14 @@
response_config=[
('success', Bool),
],
+ post_process_fun=__post_process_value_by_key('success')
)
- if result.status == 0:
- result.value = result.value['success']
- return result
-def cache_replace_if_equals(
- connection: 'Connection', cache: Union[str, int],
- key: Any, sample: Any, value: Any, key_hint: 'IgniteDatatType' = None,
- sample_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_replace_if_equals(connection: 'Connection', cache: Union[str, int], key: Any, sample: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None,
+ value_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Puts a value with a given key to cache only if the key already exists
and value equals provided sample.
@@ -636,7 +735,23 @@
success code, or non-zero status and an error description if something
has gone wrong.
"""
+ return __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint, binary,
+ query_id)
+
+async def cache_replace_if_equals_async(
+ connection: 'AioConnection', cache: Union[str, int], key: Any, sample: Any, value: Any,
+ key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None, value_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_replace_if_equals.
+ """
+ return await __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint,
+ binary, query_id)
+
+
+def __cache_replace_if_equals(connection, cache, key, sample, value, key_hint, sample_hint, value_hint, binary,
+ query_id):
query_struct = Query(
OP_CACHE_REPLACE_IF_EQUALS,
[
@@ -648,8 +763,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -660,16 +775,12 @@
response_config=[
('success', Bool),
],
+ post_process_fun=__post_process_value_by_key('success')
)
- if result.status == 0:
- result.value = result.value['success']
- return result
-def cache_clear(
- connection: 'Connection', cache: Union[str, int],
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_clear(connection: 'Connection', cache: Union[str, int], binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Clears the cache without notifying listeners or cache writers.
@@ -683,7 +794,18 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_clear(connection, cache, binary, query_id)
+
+async def cache_clear_async(connection: 'AioConnection', cache: Union[str, int], binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_clear.
+ """
+ return await __cache_clear(connection, cache, binary, query_id)
+
+
+def __cache_clear(connection, cache, binary, query_id):
query_struct = Query(
OP_CACHE_CLEAR,
[
@@ -692,8 +814,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -701,11 +823,8 @@
)
-def cache_clear_key(
- connection: 'Connection', cache: Union[str, int], key: Any,
- key_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_clear_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Clears the cache key without notifying listeners or cache writers.
@@ -722,7 +841,19 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_clear_key(connection, cache, key, key_hint, binary, query_id)
+
+async def cache_clear_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any,
+ key_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_clear_key.
+ """
+ return await __cache_clear_key(connection, cache, key, key_hint, binary, query_id)
+
+
+def __cache_clear_key(connection, cache, key, key_hint, binary, query_id):
query_struct = Query(
OP_CACHE_CLEAR_KEY,
[
@@ -732,8 +863,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -742,10 +873,8 @@
)
-def cache_clear_keys(
- connection: 'Connection', cache: Union[str, int], keys: list,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_clear_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Clears the cache keys without notifying listeners or cache writers.
@@ -760,7 +889,18 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_clear_keys(connection, cache, keys, binary, query_id)
+
+async def cache_clear_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_clear_keys.
+ """
+ return await __cache_clear_keys(connection, cache, keys, binary, query_id)
+
+
+def __cache_clear_keys(connection, cache, keys, binary, query_id):
query_struct = Query(
OP_CACHE_CLEAR_KEYS,
[
@@ -770,8 +910,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -780,11 +920,8 @@
)
-def cache_remove_key(
- connection: 'Connection', cache: Union[str, int], key: Any,
- key_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_remove_key(connection: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Removes an entry with a given key, notifying listeners and cache writers.
@@ -802,7 +939,19 @@
success code, or non-zero status and an error description if something
has gone wrong.
"""
+ return __cache_remove_key(connection, cache, key, key_hint, binary, query_id)
+
+async def cache_remove_key_async(connection: 'AioConnection', cache: Union[str, int], key: Any,
+ key_hint: 'IgniteDataType' = None, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_remove_key.
+ """
+ return await __cache_remove_key(connection, cache, key, key_hint, binary, query_id)
+
+
+def __cache_remove_key(connection, cache, key, key_hint, binary, query_id):
query_struct = Query(
OP_CACHE_REMOVE_KEY,
[
@@ -812,8 +961,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -822,17 +971,13 @@
response_config=[
('success', Bool),
],
+ post_process_fun=__post_process_value_by_key('success')
)
- if result.status == 0:
- result.value = result.value['success']
- return result
-def cache_remove_if_equals(
- connection: 'Connection', cache: Union[str, int], key: Any, sample: Any,
- key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_remove_if_equals(connection: 'Connection', cache: Union[str, int], key: Any, sample: Any,
+ key_hint: 'IgniteDataType' = None, sample_hint: 'IgniteDataType' = None,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Removes an entry with a given key if provided value is equal to
actual value, notifying listeners and cache writers.
@@ -854,7 +999,19 @@
success code, or non-zero status and an error description if something
has gone wrong.
"""
+ return __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id)
+
+async def cache_remove_if_equals_async(
+ connection: 'AioConnection', cache: Union[str, int], key: Any, sample: Any, key_hint: 'IgniteDataType' = None,
+ sample_hint: 'IgniteDataType' = None, binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_remove_if_equals.
+ """
+ return await __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id)
+
+
+def __cache_remove_if_equals(connection, cache, key, sample, key_hint, sample_hint, binary, query_id):
query_struct = Query(
OP_CACHE_REMOVE_IF_EQUALS,
[
@@ -865,8 +1022,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -876,16 +1033,12 @@
response_config=[
('success', Bool),
],
+ post_process_fun=__post_process_value_by_key('success')
)
- if result.status == 0:
- result.value = result.value['success']
- return result
-def cache_remove_keys(
- connection: 'Connection', cache: Union[str, int], keys: Iterable,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_remove_keys(connection: 'Connection', cache: Union[str, int], keys: Iterable, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Removes entries with given keys, notifying listeners and cache writers.
@@ -900,7 +1053,18 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_remove_keys(connection, cache, keys, binary, query_id)
+
+async def cache_remove_keys_async(connection: 'AioConnection', cache: Union[str, int], keys: Iterable,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_remove_keys.
+ """
+ return await __cache_remove_keys(connection, cache, keys, binary, query_id)
+
+
+def __cache_remove_keys(connection, cache, keys, binary, query_id):
query_struct = Query(
OP_CACHE_REMOVE_KEYS,
[
@@ -910,8 +1074,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -920,10 +1084,8 @@
)
-def cache_remove_all(
- connection: 'Connection', cache: Union[str, int],
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_remove_all(connection: 'Connection', cache: Union[str, int], binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Removes all entries from cache, notifying listeners and cache writers.
@@ -937,7 +1099,18 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __cache_remove_all(connection, cache, binary, query_id)
+
+async def cache_remove_all_async(connection: 'AioConnection', cache: Union[str, int], binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_remove_all.
+ """
+ return await __cache_remove_all(connection, cache, binary, query_id)
+
+
+def __cache_remove_all(connection, cache, binary, query_id):
query_struct = Query(
OP_CACHE_REMOVE_ALL,
[
@@ -946,8 +1119,8 @@
],
query_id=query_id,
)
- return query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -955,10 +1128,8 @@
)
-def cache_get_size(
- connection: 'Connection', cache: Union[str, int], peek_modes: int = 0,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_get_size(connection: 'Connection', cache: Union[str, int], peek_modes: Union[int, list, tuple] = 0,
+ binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
"""
Gets the number of entries in cache.
@@ -976,6 +1147,16 @@
cache entries on success, non-zero status and an error description
otherwise.
"""
+ return __cache_get_size(connection, cache, peek_modes, binary, query_id)
+
+
+async def cache_get_size_async(connection: 'AioConnection', cache: Union[str, int],
+ peek_modes: Union[int, list, tuple] = 0, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_get_size.
+ """
+ return await __cache_get_size(connection, cache, peek_modes, binary, query_id)
+
+
+def __cache_get_size(connection, cache, peek_modes, binary, query_id):
if not isinstance(peek_modes, (list, tuple)):
peek_modes = [peek_modes] if peek_modes else []
@@ -988,8 +1169,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- connection,
+ return query_perform(
+ query_struct, connection,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -998,21 +1179,17 @@
response_config=[
('count', Long),
],
+ post_process_fun=__post_process_value_by_key('count')
)
- if result.status == 0:
- result.value = result.value['count']
- return result
-def cache_local_peek(
- conn: 'Connection', cache: Union[str, int],
- key: Any, key_hint: 'IgniteDataType' = None, peek_modes: int = 0,
- binary: bool = False, query_id: Optional[int] = None,
-) -> 'APIResult':
+def cache_local_peek(conn: 'Connection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ peek_modes: Union[int, list, tuple] = 0, binary: bool = False,
+ query_id: Optional[int] = None) -> 'APIResult':
"""
Peeks at in-memory cached value using default optional peek mode.
- This method will not load value from any persistent store or from a remote
+ This method will not load value from any cache store or from a remote
node.
:param conn: connection: connection to Ignite server,
@@ -1031,6 +1208,19 @@
:return: API result data object. Contains zero status and a peeked value
(null if not found).
"""
+ return __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id)
+
+
+async def cache_local_peek_async(
+ conn: 'AioConnection', cache: Union[str, int], key: Any, key_hint: 'IgniteDataType' = None,
+ peek_modes: Union[int, list, tuple] = 0, binary: bool = False, query_id: Optional[int] = None) -> 'APIResult':
+ """
+ Async version of cache_local_peek.
+ """
+ return await __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id)
+
+
+def __cache_local_peek(conn, cache, key, key_hint, peek_modes, binary, query_id):
if not isinstance(peek_modes, (list, tuple)):
peek_modes = [peek_modes] if peek_modes else []
@@ -1044,8 +1234,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -1055,8 +1245,14 @@
response_config=[
('value', AnyDataObject),
],
+ post_process_fun=__post_process_value_by_key('value')
)
- if result.status != 0:
+
+
+def __post_process_value_by_key(key):
+ def internal(result):
+ if result.status == 0:
+ result.value = result.value[key]
+
return result
- result.value = result.value['value']
- return result
+ return internal
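The `__post_process_value_by_key` closure above replaces the repeated `if result.status == 0: result.value = result.value[...]` blocks that the minus lines removed throughout this file. A standalone illustration of its behavior, using a plain stub in place of `APIResult` (whose constructor requires a response object):

class _StubResult:
    def __init__(self, status, value):
        self.status, self.value = status, value

def post_process_value_by_key(key):
    def internal(result):
        if result.status == 0:
            result.value = result.value[key]
        return result
    return internal

unwrap = post_process_value_by_key('value')
ok = unwrap(_StubResult(0, {'value': 42}))
assert ok.value == 42                      # success: the nested field is hoisted
err = unwrap(_StubResult(1, {'value': 42}))
assert err.value == {'value': 42}          # non-zero status: left untouched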
diff --git a/pyignite/api/result.py b/pyignite/api/result.py
index f60a437..f134be9 100644
--- a/pyignite/api/result.py
+++ b/pyignite/api/result.py
@@ -32,7 +32,7 @@
message = 'Success'
value = None
- def __init__(self, response: 'Response'):
+ def __init__(self, response):
self.status = getattr(response, 'status_code', OP_SUCCESS)
self.query_id = response.query_id
if hasattr(response, 'error_message'):
diff --git a/pyignite/api/sql.py b/pyignite/api/sql.py
index dc470d1..b10cc7d 100644
--- a/pyignite/api/sql.py
+++ b/pyignite/api/sql.py
@@ -15,23 +15,21 @@
from typing import Union
-from pyignite.constants import *
-from pyignite.datatypes import (
- AnyDataArray, AnyDataObject, Bool, Byte, Int, Long, Map, Null, String,
- StructArray,
-)
+from pyignite.connection import AioConnection, Connection
+from pyignite.datatypes import AnyDataArray, AnyDataObject, Bool, Byte, Int, Long, Map, Null, String, StructArray
from pyignite.datatypes.sql import StatementType
-from pyignite.queries import Query
-from pyignite.queries.op_codes import *
+from pyignite.queries import Query, query_perform
+from pyignite.queries.op_codes import (
+ OP_QUERY_SCAN, OP_QUERY_SCAN_CURSOR_GET_PAGE, OP_QUERY_SQL, OP_QUERY_SQL_CURSOR_GET_PAGE, OP_QUERY_SQL_FIELDS,
+ OP_QUERY_SQL_FIELDS_CURSOR_GET_PAGE, OP_RESOURCE_CLOSE
+)
from pyignite.utils import cache_id, deprecated
from .result import APIResult
+from ..queries.response import SQLResponse
-def scan(
- conn: 'Connection', cache: Union[str, int], page_size: int,
- partitions: int = -1, local: bool = False, binary: bool = False,
- query_id: int = None,
-) -> APIResult:
+def scan(conn: 'Connection', cache: Union[str, int], page_size: int, partitions: int = -1, local: bool = False,
+ binary: bool = False, query_id: int = None) -> APIResult:
"""
Performs scan query.
@@ -58,7 +56,24 @@
* `more`: bool, True if more data is available for subsequent
‘scan_cursor_get_page’ calls.
"""
+ return __scan(conn, cache, page_size, partitions, local, binary, query_id)
+
+async def scan_async(conn: 'AioConnection', cache: Union[str, int], page_size: int, partitions: int = -1,
+ local: bool = False, binary: bool = False, query_id: int = None) -> APIResult:
+ """
+ Async version of scan.
+ """
+ return await __scan(conn, cache, page_size, partitions, local, binary, query_id)
+
+
+def __query_result_post_process(result):
+ if result.status == 0:
+ result.value = dict(result.value)
+ return result
+
+
+def __scan(conn, cache, page_size, partitions, local, binary, query_id):
query_struct = Query(
OP_QUERY_SCAN,
[
@@ -71,8 +86,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -86,15 +101,11 @@
('data', Map),
('more', Bool),
],
+ post_process_fun=__query_result_post_process
)
- if result.status == 0:
- result.value = dict(result.value)
- return result
-def scan_cursor_get_page(
- conn: 'Connection', cursor: int, query_id: int = None,
-) -> APIResult:
+def scan_cursor_get_page(conn: 'Connection', cursor: int, query_id: int = None) -> APIResult:
"""
Fetches the next scan query cursor page by cursor ID that is obtained
from `scan` function.
@@ -114,7 +125,14 @@
* `more`: bool, True if more data is available for subsequent
‘scan_cursor_get_page’ calls.
"""
+ return __scan_cursor_get_page(conn, cursor, query_id)
+
+async def scan_cursor_get_page_async(conn: 'AioConnection', cursor: int, query_id: int = None) -> APIResult:
+ """
+ Async version of scan_cursor_get_page.
+ """
+ return await __scan_cursor_get_page(conn, cursor, query_id)
+
+
+def __scan_cursor_get_page(conn, cursor, query_id):
query_struct = Query(
OP_QUERY_SCAN_CURSOR_GET_PAGE,
[
@@ -122,8 +140,8 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'cursor': cursor,
},
@@ -131,10 +149,8 @@
('data', Map),
('more', Bool),
],
+ post_process_fun=__query_result_post_process
)
- if result.status == 0:
- result.value = dict(result.value)
- return result
@deprecated(version='1.2.0', reason="This API is deprecated and will be removed in the following major release. "
@@ -322,6 +338,31 @@
* `more`: bool, True if more data is available for subsequent
‘sql_fields_cursor_get_page’ calls.
"""
+ return __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins,
+ local, replicated_only, enforce_join_order, collocated, lazy, include_field_names, max_rows,
+ timeout, binary, query_id)
+
+
+async def sql_fields_async(
+ conn: 'AioConnection', cache: Union[str, int],
+ query_str: str, page_size: int, query_args=None, schema: str = None,
+ statement_type: int = StatementType.ANY, distributed_joins: bool = False,
+ local: bool = False, replicated_only: bool = False,
+ enforce_join_order: bool = False, collocated: bool = False,
+ lazy: bool = False, include_field_names: bool = False, max_rows: int = -1,
+ timeout: int = 0, binary: bool = False, query_id: int = None
+) -> APIResult:
+ """
+ Async version of sql_fields.
+ """
+ return await __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins,
+ local, replicated_only, enforce_join_order, collocated, lazy, include_field_names,
+ max_rows, timeout, binary, query_id)
+
+
+def __sql_fields(conn, cache, query_str, page_size, query_args, schema, statement_type, distributed_joins, local,
+ replicated_only, enforce_join_order, collocated, lazy, include_field_names, max_rows, timeout,
+ binary, query_id):
if query_args is None:
query_args = []
@@ -346,10 +387,11 @@
('include_field_names', Bool),
],
query_id=query_id,
+ response_type=SQLResponse
)
- return query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'hash_code': cache_id(cache),
'flag': 1 if binary else 0,
@@ -368,15 +410,12 @@
'timeout': timeout,
'include_field_names': include_field_names,
},
- sql=True,
include_field_names=include_field_names,
has_cursor=True,
)
-def sql_fields_cursor_get_page(
- conn: 'Connection', cursor: int, field_count: int, query_id: int = None,
-) -> APIResult:
+def sql_fields_cursor_get_page(conn: 'Connection', cursor: int, field_count: int, query_id: int = None) -> APIResult:
"""
Retrieves the next query result page by cursor ID from `sql_fields`.
@@ -396,7 +435,18 @@
* `more`: bool, True if more data is available for subsequent
‘sql_fields_cursor_get_page’ calls.
"""
+ return __sql_fields_cursor_get_page(conn, cursor, field_count, query_id)
+
+async def sql_fields_cursor_get_page_async(conn: 'AioConnection', cursor: int, field_count: int,
+ query_id: int = None) -> APIResult:
+ """
+ Async version of sql_fields_cursor_get_page.
+ """
+ return await __sql_fields_cursor_get_page(conn, cursor, field_count, query_id)
+
+
+def __sql_fields_cursor_get_page(conn, cursor, field_count, query_id):
query_struct = Query(
OP_QUERY_SQL_FIELDS_CURSOR_GET_PAGE,
[
@@ -404,16 +454,20 @@
],
query_id=query_id,
)
- result = query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'cursor': cursor,
},
response_config=[
('data', StructArray([(f'field_{i}', AnyDataObject) for i in range(field_count)])),
('more', Bool),
- ]
+ ],
+ post_process_fun=__post_process_sql_fields_cursor
)
+
+
+def __post_process_sql_fields_cursor(result):
if result.status != 0:
return result
@@ -427,9 +481,7 @@
return result
-def resource_close(
- conn: 'Connection', cursor: int, query_id: int = None
-) -> APIResult:
+def resource_close(conn: 'Connection', cursor: int, query_id: int = None) -> APIResult:
"""
Closes a resource, such as query cursor.
@@ -441,7 +493,14 @@
:return: API result data object. Contains zero status on success,
non-zero status and an error description otherwise.
"""
+ return __resource_close(conn, cursor, query_id)
+
+async def resource_close_async(conn: 'AioConnection', cursor: int, query_id: int = None) -> APIResult:
+ """
+ Async version of resource_close.
+ """
+ return await __resource_close(conn, cursor, query_id)
+
+
+def __resource_close(conn, cursor, query_id):
query_struct = Query(
OP_RESOURCE_CLOSE,
[
@@ -449,9 +508,9 @@
],
query_id=query_id,
)
- return query_struct.perform(
- conn,
+ return query_perform(
+ query_struct, conn,
query_params={
'cursor': cursor,
- },
+ }
)
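Taken together, `scan`, `scan_cursor_get_page` and `resource_close` implement a simple server-side cursor protocol. A hedged usage sketch, assuming `conn` is an already established pyignite `Connection` and 'my-cache' is a hypothetical cache name; the `cursor`/`data`/`more` response keys follow the docstrings in this file:

from pyignite.api.sql import scan, scan_cursor_get_page, resource_close

result = scan(conn, 'my-cache', page_size=100)
assert result.status == 0, result.message
cursor = result.value['cursor']
rows = dict(result.value['data'])
while result.value['more']:
    result = scan_cursor_get_page(conn, cursor)
    rows.update(result.value['data'])
resource_close(conn, cursor)   # release the server-side cursor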
diff --git a/pyignite/binary.py b/pyignite/binary.py
index da62bb5..4e34267 100644
--- a/pyignite/binary.py
+++ b/pyignite/binary.py
@@ -27,15 +27,22 @@
from collections import OrderedDict
import ctypes
+from io import SEEK_CUR
from typing import Any
import attr
-from pyignite.constants import *
-from .datatypes import *
+from .constants import PROTOCOL_BYTE_ORDER
+from .datatypes import (
+ Null, ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject,
+ DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject,
+ IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject,
+ UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String, StringArrayObject,
+ DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject, BinaryObject, WrappedDataObject
+)
from .datatypes.base import IgniteDataTypeProps
from .exceptions import ParseError
-from .utils import entity_id, hashcode, schema_id
+from .utils import entity_id, schema_id
ALLOWED_FIELD_TYPES = [
@@ -69,12 +76,14 @@
def __new__(cls, *args, **kwargs) -> Any:
# allow all items in Binary Object schema to be populated as optional
# arguments to `__init__()` with sensible defaults.
- attributes = {}
- for k, v in cls.schema.items():
- attributes[k] = attr.ib(type=getattr(v, 'pythonic', type(None)), default=getattr(v, 'default', None))
+ if not attr.has(cls):
+ attributes = {
+ k: attr.ib(type=getattr(v, 'pythonic', type(None)), default=getattr(v, 'default', None))
+ for k, v in cls.schema.items()
+ }
- attributes.update({'version': attr.ib(type=int, default=1)})
- cls = attr.s(cls, these=attributes)
+ attributes.update({'version': attr.ib(type=int, default=1)})
+ cls = attr.s(cls, these=attributes)
# skip parameters
return super().__new__(cls)
@@ -99,7 +108,7 @@
""" Sort out class creation arguments. """
result = super().__new__(
- mcs, name, (GenericObjectProps, )+base_classes, namespace
+ mcs, name, (GenericObjectProps, ) + base_classes, namespace
)
def _from_python(self, stream, save_to_buf=False):
@@ -111,10 +120,37 @@
:param stream: BinaryStream
:param save_to_buf: Optional. If True, save serialized data to buffer.
"""
+ initial_pos = stream.tell()
+ header, header_class = write_header(self, stream)
- compact_footer = stream.compact_footer
+ offsets = [ctypes.sizeof(header_class)]
+ schema_items = list(self.schema.items())
+ for field_name, field_type in schema_items:
+ val = getattr(self, field_name, getattr(field_type, 'default', None))
+ field_start_pos = stream.tell()
+ field_type.from_python(stream, val)
+ offsets.append(max(offsets) + stream.tell() - field_start_pos)
- # prepare header
+ write_footer(self, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf)
+
+ async def _from_python_async(self, stream, save_to_buf=False):
+ """
+ Async version of _from_python
+ """
+ initial_pos = stream.tell()
+ header, header_class = write_header(self, stream)
+
+ offsets = [ctypes.sizeof(header_class)]
+ schema_items = list(self.schema.items())
+ for field_name, field_type in schema_items:
+ val = getattr(self, field_name, getattr(field_type, 'default', None))
+ field_start_pos = stream.tell()
+ await field_type.from_python_async(stream, val)
+ offsets.append(max(offsets) + stream.tell() - field_start_pos)
+
+ write_footer(self, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf)
+
+ def write_header(obj, stream):
header_class = BinaryObject.build_header()
header = header_class()
header.type_code = int.from_bytes(
@@ -122,36 +158,30 @@
byteorder=PROTOCOL_BYTE_ORDER
)
header.flags = BinaryObject.USER_TYPE | BinaryObject.HAS_SCHEMA
- if compact_footer:
+ if stream.compact_footer:
header.flags |= BinaryObject.COMPACT_FOOTER
- header.version = self.version
- header.type_id = self.type_id
- header.schema_id = self.schema_id
+ header.version = obj.version
+ header.type_id = obj.type_id
+ header.schema_id = obj.schema_id
- header_len = ctypes.sizeof(header_class)
- initial_pos = stream.tell()
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
- # create fields and calculate offsets
- offsets = [ctypes.sizeof(header_class)]
- schema_items = list(self.schema.items())
+ return header, header_class
- stream.seek(initial_pos + header_len)
- for field_name, field_type in schema_items:
- val = getattr(self, field_name, getattr(field_type, 'default', None))
- field_start_pos = stream.tell()
- field_type.from_python(stream, val)
- offsets.append(max(offsets) + stream.tell() - field_start_pos)
-
+ def write_footer(obj, stream, header, header_class, schema_items, offsets, initial_pos, save_to_buf):
offsets = offsets[:-1]
+ header_len = ctypes.sizeof(header_class)
# create footer
if max(offsets, default=0) < 255:
header.flags |= BinaryObject.OFFSET_ONE_BYTE
elif max(offsets) < 65535:
header.flags |= BinaryObject.OFFSET_TWO_BYTES
+
schema_class = BinaryObject.schema_type(header.flags) * len(offsets)
schema = schema_class()
- if compact_footer:
+
+ if stream.compact_footer:
for i, offset in enumerate(offsets):
schema[i] = offset
else:
@@ -171,8 +201,8 @@
stream.write(schema)
if save_to_buf:
- self._buffer = bytes(stream.mem_view(initial_pos, stream.tell() - initial_pos))
- self._hashcode = header.hash_code
+ obj._buffer = bytes(stream.mem_view(initial_pos, stream.tell() - initial_pos))
+ obj._hashcode = header.hash_code
def _setattr(self, attr_name: str, attr_value: Any):
# reset binary representation, if any field is changed
@@ -184,6 +214,7 @@
super(result, self).__setattr__(attr_name, attr_value)
setattr(result, _from_python.__name__, _from_python)
+ setattr(result, _from_python_async.__name__, _from_python_async)
setattr(result, '__setattr__', _setattr)
setattr(result, '_buffer', None)
setattr(result, '_hashcode', None)
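The serialization refactor above shares one piece of bookkeeping between `_from_python` and `_from_python_async`: `offsets[0]` is the header size, each field appends `max(offsets)` plus the bytes it just wrote, and `offsets[:-1]` therefore holds each field's start position for the footer. A worked check with hypothetical field sizes:

header_len = 24
field_sizes = [5, 9, 2]            # hypothetical serialized field lengths
offsets = [header_len]
for size in field_sizes:
    offsets.append(max(offsets) + size)
assert offsets == [24, 29, 38, 40]
# The footer stores offsets[:-1] == [24, 29, 38], the start of each field
# relative to the object start; the last entry only sizes the next append
# and is dropped, matching `offsets = offsets[:-1]` in write_footer.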
diff --git a/pyignite/cache.py b/pyignite/cache.py
index a91a3cf..5fba6fb 100644
--- a/pyignite/cache.py
+++ b/pyignite/cache.py
@@ -16,54 +16,145 @@
import time
from typing import Any, Dict, Iterable, Optional, Tuple, Union
-from .constants import *
-from .binary import GenericObjectMeta, unwrap_binary
+from .constants import AFFINITY_RETRIES, AFFINITY_DELAY
+from .binary import GenericObjectMeta
from .datatypes import prop_codes
from .datatypes.internal import AnyDataObject
-from .exceptions import (
- CacheCreationError, CacheError, ParameterError, SQLError,
- connection_errors,
-)
-from .utils import (
- cache_id, get_field_by_id, is_wrapped,
- status_to_exception, unsigned
-)
+from .exceptions import CacheCreationError, CacheError, ParameterError, SQLError, connection_errors
+from .utils import cache_id, get_field_by_id, status_to_exception, unsigned
from .api.cache_config import (
- cache_create, cache_create_with_config,
- cache_get_or_create, cache_get_or_create_with_config,
- cache_destroy, cache_get_configuration,
+ cache_create, cache_create_with_config, cache_get_or_create, cache_get_or_create_with_config, cache_destroy,
+ cache_get_configuration
)
from .api.key_value import (
- cache_get, cache_put, cache_get_all, cache_put_all, cache_replace,
- cache_clear, cache_clear_key, cache_clear_keys,
- cache_contains_key, cache_contains_keys,
- cache_get_and_put, cache_get_and_put_if_absent, cache_put_if_absent,
- cache_get_and_remove, cache_get_and_replace,
- cache_remove_key, cache_remove_keys, cache_remove_all,
- cache_remove_if_equals, cache_replace_if_equals, cache_get_size,
+ cache_get, cache_put, cache_get_all, cache_put_all, cache_replace, cache_clear, cache_clear_key, cache_clear_keys,
+ cache_contains_key, cache_contains_keys, cache_get_and_put, cache_get_and_put_if_absent, cache_put_if_absent,
+ cache_get_and_remove, cache_get_and_replace, cache_remove_key, cache_remove_keys, cache_remove_all,
+ cache_remove_if_equals, cache_replace_if_equals, cache_get_size
)
-from .api.sql import scan, scan_cursor_get_page, sql, sql_cursor_get_page
+from .cursors import ScanCursor, SqlCursor
from .api.affinity import cache_get_node_partitions
-
PROP_CODES = set([
getattr(prop_codes, x)
for x in dir(prop_codes)
if x.startswith('PROP_')
])
-CACHE_CREATE_FUNCS = {
- True: {
- True: cache_get_or_create_with_config,
- False: cache_create_with_config,
- },
- False: {
- True: cache_get_or_create,
- False: cache_create,
- },
-}
-class Cache:
+def get_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache':
+ name, settings = __parse_settings(settings)
+ if settings:
+ raise ParameterError('Only cache name allowed as a parameter')
+
+ return Cache(client, name)
+
+
+def create_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache':
+ name, settings = __parse_settings(settings)
+
+ conn = client.random_node
+ if settings:
+ result = cache_create_with_config(conn, settings)
+ else:
+ result = cache_create(conn, name)
+
+ if result.status != 0:
+ raise CacheCreationError(result.message)
+
+ return Cache(client, name)
+
+
+def get_or_create_cache(client: 'Client', settings: Union[str, dict]) -> 'Cache':
+ name, settings = __parse_settings(settings)
+
+ conn = client.random_node
+ if settings:
+ result = cache_get_or_create_with_config(conn, settings)
+ else:
+ result = cache_get_or_create(conn, name)
+
+ if result.status != 0:
+ raise CacheCreationError(result.message)
+
+ return Cache(client, name)
+
+
+def __parse_settings(settings: Union[str, dict]) -> Tuple[Optional[str], Optional[dict]]:
+ if isinstance(settings, str):
+ return settings, None
+ elif isinstance(settings, dict) and prop_codes.PROP_NAME in settings:
+ name = settings[prop_codes.PROP_NAME]
+ if len(settings) == 1:
+ return name, None
+
+ if not set(settings).issubset(PROP_CODES):
+ raise ParameterError('One or more settings was not recognized')
+
+ return name, settings
+ else:
+ raise ParameterError('You should supply at least cache name')
+
+
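The Client facade methods further below delegate to these module-level helpers. A minimal usage sketch (assuming a local Ignite node on 127.0.0.1:10800; PROP_NAME comes from pyignite.datatypes.prop_codes):

    from pyignite import Client
    from pyignite.datatypes.prop_codes import PROP_NAME

    client = Client()
    client.connect('127.0.0.1', 10800)

    created = client.create_cache('my-cache')                  # -> create_cache(client, 'my-cache')
    by_props = client.get_or_create_cache({PROP_NAME: 'my-other-cache'})
    existing = client.get_cache('my-cache')                     # any setting besides PROP_NAME raises ParameterError here
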
+class BaseCacheMixin:
+ def _get_affinity_key(self, key, key_hint=None):
+ if key_hint is None:
+ key_hint = AnyDataObject.map_python_type(key)
+
+ if self.affinity.get('is_applicable'):
+ config = self.affinity.get('cache_config')
+ if config:
+ affinity_key_id = config.get(key_hint.type_id)
+
+ if affinity_key_id and isinstance(key, GenericObjectMeta):
+ return get_field_by_id(key, affinity_key_id)
+
+ return key, key_hint
+
+ def _update_affinity(self, full_affinity):
+ self.affinity['version'] = full_affinity['version']
+
+ full_mapping = full_affinity.get('partition_mapping')
+ if full_mapping and self.cache_id in full_mapping:
+ self.affinity.update(full_mapping[self.cache_id])
+
+ def _get_node_by_hashcode(self, hashcode, parts):
+ """
+ Get node by key hashcode. Calculate partition and return node on that it is primary.
+ (algorithm is taken from `RendezvousAffinityFunction.java`)
+ """
+
+ # calculate partition for key or affinity key
+ # (algorithm is taken from `RendezvousAffinityFunction.java`)
+ mask = parts - 1
+
+ if parts & mask == 0:
+ part = (hashcode ^ (unsigned(hashcode) >> 16)) & mask
+ else:
+ part = abs(hashcode // parts)
+
+ assert 0 <= part < parts, 'Partition calculation has failed'
+
+ node_mapping = self.affinity.get('node_mapping')
+ if not node_mapping:
+ return None
+
+ node_uuid, best_conn = None, None
+ for u, p in node_mapping.items():
+ if part in p:
+ node_uuid = u
+ break
+
+ if node_uuid:
+ for n in self.client._nodes:
+ if n.uuid == node_uuid:
+ best_conn = n
+ break
+ if best_conn and best_conn.alive:
+ return best_conn
+
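For reference, the partition math above can be exercised standalone. A self-contained sketch (calc_partition is illustrative, not part of the library; unsigned() is inlined to mimic pyignite.utils.unsigned):

    def calc_partition(hashcode: int, parts: int) -> int:
        def unsigned(value: int, bits: int = 32) -> int:
            # reinterpret a signed integer as unsigned, as utils.unsigned does
            return value & ((1 << bits) - 1)

        mask = parts - 1
        if parts & mask == 0:
            # parts is a power of two: smear the high bits into the low ones,
            # then mask (HashMap-style index spreading)
            return (hashcode ^ (unsigned(hashcode) >> 16)) & mask
        # fallback for a non-power-of-two partition count
        return abs(hashcode // parts)

    assert calc_partition(42, 1024) == 42
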
+
+class Cache(BaseCacheMixin):
"""
Ignite cache abstraction. Users should never use this class directly,
but construct its instances with
@@ -73,77 +164,18 @@
:ref:`this example <create_cache>` on how to do it.
"""
- affinity = None
- _cache_id = None
- _name = None
- _client = None
- _settings = None
-
- @staticmethod
- def _validate_settings(
- settings: Union[str, dict] = None, get_only: bool = False,
- ):
- if any([
- not settings,
- type(settings) not in (str, dict),
- type(settings) is dict and prop_codes.PROP_NAME not in settings,
- ]):
- raise ParameterError('You should supply at least cache name')
-
- if all([
- type(settings) is dict,
- not set(settings).issubset(PROP_CODES),
- ]):
- raise ParameterError('One or more settings was not recognized')
-
- if get_only and type(settings) is dict and len(settings) != 1:
- raise ParameterError('Only cache name allowed as a parameter')
-
- def __init__(
- self, client: 'Client', settings: Union[str, dict] = None,
- with_get: bool = False, get_only: bool = False,
- ):
+ def __init__(self, client: 'Client', name: str):
"""
- Initialize cache object.
+ Initialize cache object. For internal use.
:param client: Ignite client,
- :param settings: cache settings. Can be a string (cache name) or a dict
- of cache properties and their values. In this case PROP_NAME is
- mandatory,
- :param with_get: (optional) do not raise exception, if the cache
- is already exists. Defaults to False,
- :param get_only: (optional) do not communicate with Ignite server
- at all, only create Cache instance. Defaults to False.
+ :param name: Cache name.
"""
self._client = client
- self._validate_settings(settings)
- if type(settings) == str:
- self._name = settings
- else:
- self._name = settings[prop_codes.PROP_NAME]
-
- if not get_only:
- func = CACHE_CREATE_FUNCS[type(settings) is dict][with_get]
- result = func(client.random_node, settings)
- if result.status != 0:
- raise CacheCreationError(result.message)
-
+ self._name = name
+ self._settings = None
self._cache_id = cache_id(self._name)
- self.affinity = {
- 'version': (0, 0),
- }
-
- def get_protocol_version(self) -> Optional[Tuple]:
- """
- Returns the tuple of major, minor, and revision numbers of the used
- thin protocol version, or None, if no connection to the Ignite cluster
- was not yet established.
-
- This method is not a part of the public API. Unless you wish to
- extend the `pyignite` capabilities (with additional testing, logging,
- examining connections, et c.) you probably should not use it.
- """
- return self.client.protocol_version
+ self.affinity = {'version': (0, 0)}
@property
def settings(self) -> Optional[dict]:
@@ -197,18 +229,6 @@
"""
return self._cache_id
- def _process_binary(self, value: Any) -> Any:
- """
- Detects and recursively unwraps Binary Object.
-
- :param value: anything that could be a Binary Object,
- :return: the result of the Binary Object unwrapping with all other data
- left intact.
- """
- if is_wrapped(value):
- return unwrap_binary(self._client, value)
- return value
-
@status_to_exception(CacheError)
def destroy(self):
"""
@@ -234,9 +254,7 @@
return result
- def get_best_node(
- self, key: Any = None, key_hint: 'IgniteDataType' = None,
- ) -> 'Connection':
+ def get_best_node(self, key: Any = None, key_hint: 'IgniteDataType' = None) -> 'Connection':
"""
Returns the node from the list of the nodes, opened by client, that
most probably contains the needed key-value pair. See IEP-23.
@@ -253,14 +271,11 @@
conn = self._client.random_node
if self.client.partition_aware and key is not None:
- if key_hint is None:
- key_hint = AnyDataObject.map_python_type(key)
-
if self.affinity['version'] < self._client.affinity_version:
# update partition mapping
while True:
try:
- self.affinity = self._get_affinity(conn)
+ full_affinity = self._get_affinity(conn)
break
except connection_errors:
# retry if connection failed
@@ -270,68 +285,23 @@
# server did not create mapping in time
return conn
- # flatten it a bit
- try:
- self.affinity.update(self.affinity['partition_mapping'][0])
- except IndexError:
- return conn
- del self.affinity['partition_mapping']
-
- # calculate the number of partitions
- parts = 0
- if 'node_mapping' in self.affinity:
- for p in self.affinity['node_mapping'].values():
- parts += len(p)
-
- self.affinity['number_of_partitions'] = parts
+ self._update_affinity(full_affinity)
for conn in self.client._nodes:
if not conn.alive:
conn.reconnect()
- else:
- # get number of partitions
- parts = self.affinity.get('number_of_partitions')
+
+ parts = self.affinity.get('number_of_partitions')
if not parts:
return conn
- if self.affinity['is_applicable']:
- affinity_key_id = self.affinity['cache_config'].get(
- key_hint.type_id,
- None
- )
- if affinity_key_id and isinstance(key, GenericObjectMeta):
- key, key_hint = get_field_by_id(key, affinity_key_id)
+ key, key_hint = self._get_affinity_key(key, key_hint)
+ hashcode = key_hint.hashcode(key, self._client)
- # calculate partition for key or affinity key
- # (algorithm is taken from `RendezvousAffinityFunction.java`)
- base_value = key_hint.hashcode(key, self._client)
- mask = parts - 1
-
- if parts & mask == 0:
- part = (base_value ^ (unsigned(base_value) >> 16)) & mask
- else:
- part = abs(base_value // parts)
-
- assert 0 <= part < parts, 'Partition calculation has failed'
-
- # search for connection
- try:
- node_uuid, best_conn = None, None
- for u, p in self.affinity['node_mapping'].items():
- if part in p:
- node_uuid = u
- break
-
- if node_uuid:
- for n in conn.client._nodes:
- if n.uuid == node_uuid:
- best_conn = n
- break
- if best_conn and best_conn.alive:
- conn = best_conn
- except KeyError:
- pass
+ best_node = self._get_node_by_hashcode(hashcode, parts)
+ if best_node:
+ return best_node
return conn
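Key-based routing in get_best_node only takes effect when the client was constructed with partition_aware=True and the negotiated protocol is 1.4.0 or newer; otherwise a random node is returned. A hedged sketch (the feature is experimental and off by default):

    from pyignite import Client

    client = Client(partition_aware=True)
    client.connect([('127.0.0.1', 10800), ('127.0.0.1', 10801)])

    cache = client.get_or_create_cache('my-cache')
    cache.put(1, 'one')   # routed via get_best_node(key=1) under the hood
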
@@ -354,12 +324,12 @@
key,
key_hint=key_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
def put(
- self, key, value, key_hint: object = None, value_hint: object = None
+ self, key, value, key_hint: object = None, value_hint: object = None
):
"""
Puts a value with a given key to cache (overwriting existing value
@@ -392,7 +362,7 @@
result = cache_get_all(self.get_best_node(), self._cache_id, keys)
if result.value:
for key, value in result.value.items():
- result.value[key] = self._process_binary(value)
+ result.value[key] = self.client.unwrap_binary(value)
return result
@status_to_exception(CacheError)
@@ -409,7 +379,7 @@
@status_to_exception(CacheError)
def replace(
- self, key, value, key_hint: object = None, value_hint: object = None
+ self, key, value, key_hint: object = None, value_hint: object = None
):
"""
Puts a value with a given key to cache only if the key already exist.
@@ -429,7 +399,7 @@
self._cache_id, key, value,
key_hint=key_hint, value_hint=value_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
@@ -466,6 +436,16 @@
)
@status_to_exception(CacheError)
+ def clear_keys(self, keys: Iterable):
+ """
+ Clears the given cache keys without notifying listeners or cache writers.
+
+ :param keys: a list of keys or (key, type hint) tuples
+ """
+
+ return cache_clear_keys(self.get_best_node(), self._cache_id, keys)
+
+ @status_to_exception(CacheError)
def contains_key(self, key, key_hint=None) -> bool:
"""
Returns a value indicating whether given key is present in cache.
@@ -493,7 +473,7 @@
:param keys: a list of keys or (key, type hint) tuples,
:return: boolean `True` when all keys are present, `False` otherwise.
"""
- return cache_contains_keys(self._client, self._cache_id, keys)
+ return cache_contains_keys(self.get_best_node(), self._cache_id, keys)
@status_to_exception(CacheError)
def get_and_put(self, key, value, key_hint=None, value_hint=None) -> Any:
@@ -518,12 +498,12 @@
key, value,
key_hint, value_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
def get_and_put_if_absent(
- self, key, value, key_hint=None, value_hint=None
+ self, key, value, key_hint=None, value_hint=None
):
"""
Puts a value with a given key to cache only if the key does not
@@ -546,7 +526,7 @@
key, value,
key_hint, value_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
@@ -591,12 +571,12 @@
key,
key_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
def get_and_replace(
- self, key, value, key_hint=None, value_hint=None
+ self, key, value, key_hint=None, value_hint=None
) -> Any:
"""
Puts a value with a given key to cache, returning previous value
@@ -620,7 +600,7 @@
key, value,
key_hint, value_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
@@ -683,8 +663,8 @@
@status_to_exception(CacheError)
def replace_if_equals(
- self, key, sample, value,
- key_hint=None, sample_hint=None, value_hint=None
+ self, key, sample, value,
+ key_hint=None, sample_hint=None, value_hint=None
) -> Any:
"""
Puts a value with a given key to cache only if the key already exists
@@ -710,7 +690,7 @@
key, sample, value,
key_hint, sample_hint, value_hint
)
- result.value = self._process_binary(result.value)
+ result.value = self.client.unwrap_binary(result.value)
return result
@status_to_exception(CacheError)
@@ -727,9 +707,7 @@
self.get_best_node(), self._cache_id, peek_modes
)
- def scan(
- self, page_size: int = 1, partitions: int = -1, local: bool = False
- ):
+ def scan(self, page_size: int = 1, partitions: int = -1, local: bool = False):
"""
Returns all key-value pairs from the cache, similar to `get_all`, but
with internal pagination, which is slower, but safer.
@@ -740,40 +718,14 @@
(negative to query entire cache),
:param local: (optional) pass True if this query should be executed
on local node only. Defaults to False,
- :return: generator with key-value pairs.
+ :return: Scan query cursor.
"""
- node = self.get_best_node()
-
- result = scan(
- node,
- self._cache_id,
- page_size,
- partitions,
- local
- )
- if result.status != 0:
- raise CacheError(result.message)
-
- cursor = result.value['cursor']
- for k, v in result.value['data'].items():
- k = self._process_binary(k)
- v = self._process_binary(v)
- yield k, v
-
- while result.value['more']:
- result = scan_cursor_get_page(node, cursor)
- if result.status != 0:
- raise CacheError(result.message)
-
- for k, v in result.value['data'].items():
- k = self._process_binary(k)
- v = self._process_binary(v)
- yield k, v
+ return ScanCursor(self.client, self._cache_id, page_size, partitions, local)
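Call sites keep their generator-style loops; the cursor handles paging and binary unwrapping internally. A sketch, assuming ScanCursor is iterable like the generator it replaces:

    for k, v in cache.scan(page_size=100):
        print(k, v)
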
def select_row(
- self, query_str: str, page_size: int = 1,
- query_args: Optional[list] = None, distributed_joins: bool = False,
- replicated_only: bool = False, local: bool = False, timeout: int = 0
+ self, query_str: str, page_size: int = 1,
+ query_args: Optional[list] = None, distributed_joins: bool = False,
+ replicated_only: bool = False, local: bool = False, timeout: int = 0
):
"""
Executes a simplified SQL SELECT query over data stored in the cache.
@@ -791,46 +743,13 @@
on local node only. Defaults to False,
:param timeout: (optional) non-negative timeout value in ms. Zero
disables timeout (default),
- :return: generator with key-value pairs.
+ :return: Sql cursor.
"""
- node = self.get_best_node()
-
- def generate_result(value):
- cursor = value['cursor']
- more = value['more']
- for k, v in value['data'].items():
- k = self._process_binary(k)
- v = self._process_binary(v)
- yield k, v
-
- while more:
- inner_result = sql_cursor_get_page(node, cursor)
- if result.status != 0:
- raise SQLError(result.message)
- more = inner_result.value['more']
- for k, v in inner_result.value['data'].items():
- k = self._process_binary(k)
- v = self._process_binary(v)
- yield k, v
-
type_name = self.settings[
prop_codes.PROP_QUERY_ENTITIES
][0]['value_type_name']
if not type_name:
raise SQLError('Value type is unknown')
- result = sql(
- node,
- self._cache_id,
- type_name,
- query_str,
- page_size,
- query_args,
- distributed_joins,
- replicated_only,
- local,
- timeout
- )
- if result.status != 0:
- raise SQLError(result.message)
- return generate_result(result.value)
+ return SqlCursor(self.client, self._cache_id, type_name, query_str, page_size, query_args,
+ distributed_joins, replicated_only, local, timeout)
diff --git a/pyignite/client.py b/pyignite/client.py
index 9416474..e4eef6a 100644
--- a/pyignite/client.py
+++ b/pyignite/client.py
@@ -44,22 +44,20 @@
import random
import re
from itertools import chain
-from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, Type, Union, Any
from .api.binary import get_binary_type, put_binary_type
from .api.cache_config import cache_get_names
-from .api.sql import sql_fields, sql_fields_cursor_get_page
-from .cache import Cache
+from .cursors import SqlFieldsCursor
+from .cache import Cache, create_cache, get_cache, get_or_create_cache
from .connection import Connection
-from .constants import *
+from .constants import IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT, PROTOCOL_BYTE_ORDER
from .datatypes import BinaryObject
from .datatypes.internal import tc_map
-from .exceptions import (
- BinaryTypeError, CacheError, ReconnectError, SQLError, connection_errors,
-)
+from .exceptions import BinaryTypeError, CacheError, ReconnectError, connection_errors
+from .stream import BinaryStream, READ_BACKWARD
from .utils import (
- cache_id, capitalize, entity_id, schema_id, process_delimiter,
- status_to_exception, is_iterable,
+ cache_id, capitalize, entity_id, schema_id, process_delimiter, status_to_exception, is_iterable, is_wrapped
)
from .binary import GenericObjectMeta
@@ -67,7 +65,185 @@
__all__ = ['Client']
-class Client:
+class BaseClient:
+ # used for sanitizing Complex object data class names
+ _identifier = re.compile(r'[^0-9a-zA-Z_.+$]', re.UNICODE)
+ _ident_start = re.compile(r'^[^a-zA-Z_]+', re.UNICODE)
+
+ def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs):
+ self._compact_footer = compact_footer
+ self._partition_aware = partition_aware
+ self._connection_args = kwargs
+ self._registry = defaultdict(dict)
+ self._nodes = []
+ self._current_node = 0
+ self.affinity_version = (0, 0)
+ self._protocol_version = None
+
+ @property
+ def protocol_version(self):
+ """
+ Returns the tuple of major, minor, and revision numbers of the used
+ thin protocol version, or None, if no connection to the Ignite cluster
+ has yet been established.
+
+ This property is not a part of the public API. Unless you wish to
+ extend the `pyignite` capabilities (with additional testing, logging,
+ examining connections, etc.) you probably should not use it.
+ """
+ return self._protocol_version
+
+ @protocol_version.setter
+ def protocol_version(self, value):
+ self._protocol_version = value
+
+ @property
+ def partition_aware(self):
+ return self._partition_aware and self.partition_awareness_supported_by_protocol
+
+ @property
+ def partition_awareness_supported_by_protocol(self):
+ return self.protocol_version is not None and self.protocol_version >= (1, 4, 0)
+
+ @property
+ def compact_footer(self) -> bool:
+ """
+ This property remembers Complex object schema encoding approach when
+ decoding any Complex object, to use the same approach on Complex
+ object encoding.
+
+ :return: True if compact schema was used by server or no Complex
+ object decoding has yet taken place, False if full schema was used.
+ """
+ # this is an ordinary object property, but its backing storage
+ # is a class attribute
+
+ # use compact schema by default, but leave initial (falsy) backing
+ # value unchanged
+ return self._compact_footer or self._compact_footer is None
+
+ @compact_footer.setter
+ def compact_footer(self, value: bool):
+ # normally schema approach should not change
+ if self._compact_footer not in (value, None):
+ raise Warning('Can not change client schema approach.')
+ else:
+ self._compact_footer = value
+
+ @staticmethod
+ def _process_connect_args(*args):
+ if len(args) == 0:
+ # no parameters − use default Ignite host and port
+ return [(IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT)]
+ if len(args) == 1 and is_iterable(args[0]):
+ # iterable of host-port pairs is given
+ return args[0]
+ if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], int):
+ # host and port are given
+ return [args]
+
+ raise ConnectionError('Connection parameters are not valid.')
+
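The three accepted argument shapes, per the parsing above (shown as alternatives on a hypothetical client):

    from pyignite import Client

    client = Client()
    client.connect()                                            # no args: default host and port
    client.connect('127.0.0.1', 10800)                          # a single host and port
    client.connect([('10.0.0.1', 10800), ('10.0.0.2', 10800)])  # an iterable of (host, port) pairs
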
+ def _process_get_binary_type_result(self, result):
+ if result.status != 0 or not result.value['type_exists']:
+ return result
+
+ binary_fields = result.value.pop('binary_fields')
+ old_format_schemas = result.value.pop('schema')
+ result.value['schemas'] = []
+ for s_id, field_ids in old_format_schemas.items():
+ result.value['schemas'].append(self._convert_schema(field_ids, binary_fields))
+ return result
+
+ @staticmethod
+ def _convert_type(tc_type: int):
+ try:
+ return tc_map(tc_type.to_bytes(1, PROTOCOL_BYTE_ORDER))
+ except (KeyError, OverflowError):
+ # if conversion to char or type lookup failed,
+ # we probably have a binary object type ID
+ return BinaryObject
+
+ def _convert_schema(self, field_ids: list, binary_fields: list) -> OrderedDict:
+ converted_schema = OrderedDict()
+ for field_id in field_ids:
+ binary_field = next(x for x in binary_fields if x['field_id'] == field_id)
+ converted_schema[binary_field['field_name']] = self._convert_type(binary_field['type_id'])
+ return converted_schema
+
+ @staticmethod
+ def _create_dataclass(type_name: str, schema: OrderedDict = None) -> Type:
+ """
+ Creates default (generic) class for Ignite Complex object.
+
+ :param type_name: Complex object type name,
+ :param schema: Complex object schema,
+ :return: the resulting class.
+ """
+ schema = schema or {}
+ return GenericObjectMeta(type_name, (), {}, schema=schema)
+
+ @classmethod
+ def _create_type_name(cls, type_name: str) -> str:
+ """
+ Creates Python data class name from Ignite binary type name.
+
+ Handles all the special cases found in
+ `java.org.apache.ignite.binary.BinaryBasicNameMapper.simpleName()`.
+ Tries to adhere to PEP8 along the way.
+ """
+
+ # general sanitizing
+ type_name = cls._identifier.sub('', type_name)
+
+ # - name ending with '$' (Scala)
+ # - name + '$' + some digits (anonymous class)
+ # - '$$Lambda$' in the middle
+ type_name = process_delimiter(type_name, '$')
+
+ # .NET outer/inner class delimiter
+ type_name = process_delimiter(type_name, '+')
+
+ # Java fully qualified class name
+ type_name = process_delimiter(type_name, '.')
+
+ # start chars sanitizing
+ type_name = capitalize(cls._ident_start.sub('', type_name))
+
+ return type_name
+
+ def _sync_binary_registry(self, type_id: int, type_info: dict):
+ """
+ Sync the binary registry.
+
+ :param type_id: Complex object type ID.
+ :param type_info: Complex object type info.
+ """
+ if type_info['type_exists']:
+ for schema in type_info['schemas']:
+ if not self._registry[type_id].get(schema_id(schema), None):
+ data_class = self._create_dataclass(
+ self._create_type_name(type_info['type_name']),
+ schema,
+ )
+ self._registry[type_id][schema_id(schema)] = data_class
+
+ def _get_from_registry(self, type_id, schema):
+ """
+ Get binary type info from registry.
+
+ :param type_id: Complex object type ID.
+ :param schema: Complex object schema.
+ """
+ if schema:
+ try:
+ return self._registry[type_id][schema_id(schema)]
+ except KeyError:
+ return None
+ return self._registry[type_id]
+
+
+class Client(BaseClient):
"""
This is the main `pyignite` class, built upon the
:class:`~pyignite.connection.Connection`. In addition to the attributes,
@@ -79,23 +255,7 @@
* binary types registration endpoint.
"""
- _registry = defaultdict(dict)
- _compact_footer: bool = None
- _connection_args: Dict = None
- _current_node: int = None
- _nodes: List[Connection] = None
-
- # used for Complex object data class names sanitizing
- _identifier = re.compile(r'[^0-9a-zA-Z_.+$]', re.UNICODE)
- _ident_start = re.compile(r'^[^a-zA-Z_]+', re.UNICODE)
-
- affinity_version: Optional[Tuple] = None
- protocol_version: Optional[Tuple] = None
-
- def __init__(
- self, compact_footer: bool = None, partition_aware: bool = False,
- **kwargs
- ):
+ def __init__(self, compact_footer: bool = None, partition_aware: bool = False, **kwargs):
"""
Initialize client.
@@ -111,35 +271,7 @@
The feature is in experimental status, so the parameter is `False`
by default. This will be changed later.
"""
- self._compact_footer = compact_footer
- self._connection_args = kwargs
- self._nodes = []
- self._current_node = 0
- self._partition_aware = partition_aware
- self.affinity_version = (0, 0)
-
- def get_protocol_version(self) -> Optional[Tuple]:
- """
- Returns the tuple of major, minor, and revision numbers of the used
- thin protocol version, or None, if no connection to the Ignite cluster
- was not yet established.
-
- This method is not a part of the public API. Unless you wish to
- extend the `pyignite` capabilities (with additional testing, logging,
- examining connections, et c.) you probably should not use it.
- """
- return self.protocol_version
-
- @property
- def partition_aware(self):
- return self._partition_aware and self.partition_awareness_supported_by_protocol
-
- @property
- def partition_awareness_supported_by_protocol(self):
- # TODO: Need to re-factor this. I believe, we need separate class or
- # set of functions to work with protocol versions without manually
- # comparing versions with just some random tuples
- return self.protocol_version is not None and self.protocol_version >= (1, 4, 0)
+ super().__init__(compact_footer, partition_aware, **kwargs)
def connect(self, *args):
"""
@@ -147,21 +279,7 @@
:param args: (optional) host(s) and port(s) to connect to.
"""
- if len(args) == 0:
- # no parameters − use default Ignite host and port
- nodes = [(IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT)]
- elif len(args) == 1 and is_iterable(args[0]):
- # iterable of host-port pairs is given
- nodes = args[0]
- elif (
- len(args) == 2
- and isinstance(args[0], str)
- and isinstance(args[1], int)
- ):
- # host and port are given
- nodes = [args]
- else:
- raise ConnectionError('Connection parameters are not valid.')
+ nodes = self._process_connect_args(*args)
# the following code is quite twisted, because the protocol version
# is initially unknown
@@ -169,14 +287,12 @@
# TODO: open first node in foreground, others − in background
for i, node in enumerate(nodes):
host, port = node
- conn = Connection(self, **self._connection_args)
- conn.host = host
- conn.port = port
+ conn = Connection(self, host, port, **self._connection_args)
try:
if self.protocol_version is None or self.partition_aware:
# open connection before adding to the pool
- conn.connect(host, port)
+ conn.connect()
# now we have the protocol version
if not self.partition_aware:
@@ -210,13 +326,7 @@
"""
if self.partition_aware:
# if partition awareness is used just pick a random connected node
- try:
- return random.choice(
- list(n for n in self._nodes if n.alive)
- )
- except IndexError:
- # cannot choose from an empty sequence
- raise ReconnectError('Can not reconnect: out of nodes.') from None
+ return self._get_random_node()
else:
# if partition awareness is not used then just return the current
# node if it's alive or the next usable node if connection with the
@@ -238,7 +348,7 @@
for i in chain(range(self._current_node, num_nodes), range(self._current_node)):
node = self._nodes[i]
try:
- node.connect(node.host, node.port)
+ node.connect()
except connection_errors:
pass
else:
@@ -247,6 +357,19 @@
# no nodes left
raise ReconnectError('Can not reconnect: out of nodes.')
+ def _get_random_node(self, reconnect=True):
+ alive_nodes = [n for n in self._nodes if n.alive]
+ if alive_nodes:
+ return random.choice(alive_nodes)
+ elif reconnect:
+ for n in self._nodes:
+ n.reconnect()
+
+ return self._get_random_node(reconnect=False)
+ else:
+ # cannot choose from an empty sequence
+ raise ReconnectError('Can not reconnect: out of nodes.') from None
+
@status_to_exception(BinaryTypeError)
def get_binary_type(self, binary_type: Union[str, int]) -> dict:
"""
@@ -267,71 +390,8 @@
- `schemas`: a list, containing the Complex object schemas in format:
OrderedDict[field name: field type hint]. A schema can be empty.
"""
- def convert_type(tc_type: int):
- try:
- return tc_map(tc_type.to_bytes(1, PROTOCOL_BYTE_ORDER))
- except (KeyError, OverflowError):
- # if conversion to char or type lookup failed,
- # we probably have a binary object type ID
- return BinaryObject
-
- def convert_schema(
- field_ids: list, binary_fields: list
- ) -> OrderedDict:
- converted_schema = OrderedDict()
- for field_id in field_ids:
- binary_field = [
- x
- for x in binary_fields
- if x['field_id'] == field_id
- ][0]
- converted_schema[binary_field['field_name']] = convert_type(
- binary_field['type_id']
- )
- return converted_schema
-
- conn = self.random_node
-
- result = get_binary_type(conn, binary_type)
- if result.status != 0 or not result.value['type_exists']:
- return result
-
- binary_fields = result.value.pop('binary_fields')
- old_format_schemas = result.value.pop('schema')
- result.value['schemas'] = []
- for s_id, field_ids in old_format_schemas.items():
- result.value['schemas'].append(
- convert_schema(field_ids, binary_fields)
- )
- return result
-
- @property
- def compact_footer(self) -> bool:
- """
- This property remembers Complex object schema encoding approach when
- decoding any Complex object, to use the same approach on Complex
- object encoding.
-
- :return: True if compact schema was used by server or no Complex
- object decoding has yet taken place, False if full schema was used.
- """
- # this is an ordinary object property, but its backing storage
- # is a class attribute
-
- # use compact schema by default, but leave initial (falsy) backing
- # value unchanged
- return (
- self.__class__._compact_footer
- or self.__class__._compact_footer is None
- )
-
- @compact_footer.setter
- def compact_footer(self, value: bool):
- # normally schema approach should not change
- if self.__class__._compact_footer not in (value, None):
- raise Warning('Can not change client schema approach.')
- else:
- self.__class__._compact_footer = value
+ result = get_binary_type(self.random_node, binary_type)
+ return self._process_get_binary_type_result(result)
@status_to_exception(BinaryTypeError)
def put_binary_type(
@@ -353,71 +413,9 @@
When register binary type, pass a dict of field names: field types.
Binary type with no fields is OK.
"""
- return put_binary_type(
- self.random_node, type_name, affinity_key_field, is_enum, schema
- )
+ return put_binary_type(self.random_node, type_name, affinity_key_field, is_enum, schema)
- @staticmethod
- def _create_dataclass(type_name: str, schema: OrderedDict = None) -> Type:
- """
- Creates default (generic) class for Ignite Complex object.
-
- :param type_name: Complex object type name,
- :param schema: Complex object schema,
- :return: the resulting class.
- """
- schema = schema or {}
- return GenericObjectMeta(type_name, (), {}, schema=schema)
-
- def _sync_binary_registry(self, type_id: int):
- """
- Reads Complex object description from Ignite server. Creates default
- Complex object classes and puts in registry, if not already there.
-
- :param type_id: Complex object type ID.
- """
- type_info = self.get_binary_type(type_id)
- if type_info['type_exists']:
- for schema in type_info['schemas']:
- if not self._registry[type_id].get(schema_id(schema), None):
- data_class = self._create_dataclass(
- self._create_type_name(type_info['type_name']),
- schema,
- )
- self._registry[type_id][schema_id(schema)] = data_class
-
- @classmethod
- def _create_type_name(cls, type_name: str) -> str:
- """
- Creates Python data class name from Ignite binary type name.
-
- Handles all the special cases found in
- `java.org.apache.ignite.binary.BinaryBasicNameMapper.simpleName()`.
- Tries to adhere to PEP8 along the way.
- """
-
- # general sanitizing
- type_name = cls._identifier.sub('', type_name)
-
- # - name ending with '$' (Scala)
- # - name + '$' + some digits (anonymous class)
- # - '$$Lambda$' in the middle
- type_name = process_delimiter(type_name, '$')
-
- # .NET outer/inner class delimiter
- type_name = process_delimiter(type_name, '+')
-
- # Java fully qualified class name
- type_name = process_delimiter(type_name, '.')
-
- # start chars sanitizing
- type_name = capitalize(cls._ident_start.sub('', type_name))
-
- return type_name
-
- def register_binary_type(
- self, data_class: Type, affinity_key_field: str = None,
- ):
+ def register_binary_type(self, data_class: Type, affinity_key_field: str = None):
"""
Register the given class as a representation of a certain Complex
object type. Discards autogenerated or previously registered class.
@@ -425,47 +423,44 @@
:param data_class: Complex object class,
:param affinity_key_field: (optional) affinity parameter.
"""
- if not self.query_binary_type(
- data_class.type_id, data_class.schema_id
- ):
- self.put_binary_type(
- data_class.type_name,
- affinity_key_field,
- schema=data_class.schema,
- )
+ if not self.query_binary_type(data_class.type_id, data_class.schema_id):
+ self.put_binary_type(data_class.type_name, affinity_key_field, schema=data_class.schema)
self._registry[data_class.type_id][data_class.schema_id] = data_class
- def query_binary_type(
- self, binary_type: Union[int, str], schema: Union[int, dict] = None,
- sync: bool = True
- ):
+ def query_binary_type(self, binary_type: Union[int, str], schema: Union[int, dict] = None):
"""
Queries the registry of Complex object classes.
:param binary_type: Complex object type name or ID,
- :param schema: (optional) Complex object schema or schema ID,
- :param sync: (optional) look up the Ignite server for registered
- Complex objects and create data classes for them if needed,
+ :param schema: (optional) Complex object schema or schema ID
:return: found dataclass or None, if `schema` parameter is provided,
a dict of {schema ID: dataclass} format otherwise.
"""
type_id = entity_id(binary_type)
- s_id = schema_id(schema)
- if schema:
- try:
- result = self._registry[type_id][s_id]
- except KeyError:
- result = None
- else:
- result = self._registry[type_id]
-
- if sync and not result:
- self._sync_binary_registry(type_id)
- return self.query_binary_type(type_id, s_id, sync=False)
+ result = self._get_from_registry(type_id, schema)
+ if not result:
+ type_info = self.get_binary_type(type_id)
+ self._sync_binary_registry(type_id, type_info)
+ return self._get_from_registry(type_id, schema)
return result
+ def unwrap_binary(self, value: Any) -> Any:
+ """
+ Detects and recursively unwraps Binary Object.
+
+ :param value: anything that could be a Binary Object,
+ :return: the result of the Binary Object unwrapping with all other data
+ left intact.
+ """
+ if is_wrapped(value):
+ blob, offset = value
+ with BinaryStream(self, blob) as stream:
+ data_class = BinaryObject.parse(stream)
+ return BinaryObject.to_python(stream.read_ctype(data_class, direction=READ_BACKWARD), self)
+ return value
+
def create_cache(self, settings: Union[str, dict]) -> 'Cache':
"""
Creates Ignite cache by name. Raises `CacheError` if such a cache is
@@ -477,7 +472,7 @@
:ref:`cache creation example <sql_cache_create>`,
:return: :class:`~pyignite.cache.Cache` object.
"""
- return Cache(self, settings)
+ return create_cache(self, settings)
def get_or_create_cache(self, settings: Union[str, dict]) -> 'Cache':
"""
@@ -489,7 +484,7 @@
:ref:`cache creation example <sql_cache_create>`,
:return: :class:`~pyignite.cache.Cache` object.
"""
- return Cache(self, settings, with_get=True)
+ return get_or_create_cache(self, settings)
def get_cache(self, settings: Union[str, dict]) -> 'Cache':
"""
@@ -501,7 +496,7 @@
property is allowed),
:return: :class:`~pyignite.cache.Cache` object.
"""
- return Cache(self, settings, get_only=True)
+ return get_cache(self, settings)
@status_to_exception(CacheError)
def get_cache_names(self) -> list:
@@ -559,42 +554,12 @@
:return: generator with result rows as a lists. If
`include_field_names` was set, the first row will hold field names.
"""
- def generate_result(value):
- cursor = value['cursor']
- more = value['more']
-
- if include_field_names:
- yield value['fields']
- field_count = len(value['fields'])
- else:
- field_count = value['field_count']
- for line in value['data']:
- yield line
-
- while more:
- inner_result = sql_fields_cursor_get_page(
- conn, cursor, field_count
- )
- if inner_result.status != 0:
- raise SQLError(result.message)
- more = inner_result.value['more']
- for line in inner_result.value['data']:
- yield line
-
- conn = self.random_node
c_id = cache.cache_id if isinstance(cache, Cache) else cache_id(cache)
if c_id != 0:
schema = None
- result = sql_fields(
- conn, c_id, query_str, page_size, query_args, schema,
- statement_type, distributed_joins, local, replicated_only,
- enforce_join_order, collocated, lazy, include_field_names,
- max_rows, timeout,
- )
- if result.status != 0:
- raise SQLError(result.message)
-
- return generate_result(result.value)
+ return SqlFieldsCursor(self, c_id, query_str, page_size, query_args, schema, statement_type, distributed_joins,
+ local, replicated_only, enforce_join_order, collocated, lazy, include_field_names,
+ max_rows, timeout)
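As with scan(), sql() now returns a cursor instead of building a generator inline. A usage sketch, assuming SqlFieldsCursor supports next() and iteration like the generator it replaces:

    cursor = client.sql('SELECT * FROM Student', include_field_names=True)
    print(next(cursor))       # with include_field_names, the first row holds the field names
    for row in cursor:
        print(row)
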
diff --git a/pyignite/connection/__init__.py b/pyignite/connection/__init__.py
index 1114594..14e820a 100644
--- a/pyignite/connection/__init__.py
+++ b/pyignite/connection/__init__.py
@@ -34,5 +34,6 @@
"""
from .connection import Connection
+from .aio_connection import AioConnection
-__all__ = ['Connection']
+__all__ = ['Connection', 'AioConnection']
diff --git a/pyignite/connection/aio_connection.py b/pyignite/connection/aio_connection.py
new file mode 100644
index 0000000..e5c11da
--- /dev/null
+++ b/pyignite/connection/aio_connection.py
@@ -0,0 +1,242 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from asyncio import Lock
+from collections import OrderedDict
+from io import BytesIO
+from typing import Union
+
+from pyignite.constants import PROTOCOLS, PROTOCOL_BYTE_ORDER
+from pyignite.exceptions import HandshakeError, SocketError, connection_errors
+from .connection import BaseConnection
+
+from .handshake import HandshakeRequest, HandshakeResponse
+from .ssl import create_ssl_context
+from ..stream import AioBinaryStream
+
+
+class AioConnection(BaseConnection):
+ """
+ Asyncio connection to Ignite node. It serves multiple purposes:
+
+ * wrapper of asyncio streams. See also https://docs.python.org/3/library/asyncio-stream.html
+ * encapsulates handshake and reconnection.
+ """
+
+ def __init__(self, client: 'AioClient', host: str, port: int, username: str = None, password: str = None,
+ **ssl_params):
+ """
+ Initialize connection.
+
+ For the use of the SSL-related parameters see
+ https://docs.python.org/3/library/ssl.html#ssl-certificates.
+
+ :param client: Ignite client object,
+ :param host: Ignite server node's host name or IP,
+ :param port: Ignite server node's port number,
+ :param use_ssl: (optional) set to True if Ignite server uses SSL
+ on its binary connector. Defaults to using SSL when username
+ and password have been supplied, and not using SSL otherwise,
+ :param ssl_version: (optional) SSL version constant from standard
+ `ssl` module. Defaults to TLS v1.1, as in Ignite 2.5,
+ :param ssl_ciphers: (optional) ciphers to use. If not provided,
+ `ssl` default ciphers are used,
+ :param ssl_cert_reqs: (optional) determines how the remote side
+ certificate is treated:
+
+ * `ssl.CERT_NONE` − remote certificate is ignored (default),
+ * `ssl.CERT_OPTIONAL` − remote certificate will be validated,
+ if provided,
+ * `ssl.CERT_REQUIRED` − valid remote certificate is required,
+
+ :param ssl_keyfile: (optional) a path to SSL key file to identify
+ local (client) party,
+ :param ssl_keyfile_password: (optional) password for SSL key file,
+ can be provided when key file is encrypted to prevent OpenSSL
+ password prompt,
+ :param ssl_certfile: (optional) a path to ssl certificate file
+ to identify local (client) party,
+ :param ssl_ca_certfile: (optional) a path to a trusted certificate
+ or a certificate chain. Required to check the validity of the remote
+ (server-side) certificate,
+ :param username: (optional) user name to authenticate to Ignite
+ cluster,
+ :param password: (optional) password to authenticate to Ignite cluster.
+ """
+ super().__init__(client, host, port, username, password, **ssl_params)
+ self._mux = Lock()
+ self._reader = None
+ self._writer = None
+
+ @property
+ def closed(self) -> bool:
+ """ Tells if socket is closed. """
+ return self._writer is None
+
+ async def connect(self) -> Union[dict, OrderedDict]:
+ """
+ Connect to the given server node with protocol version fallback.
+ """
+ async with self._mux:
+ return await self._connect()
+
+ async def _connect(self) -> Union[dict, OrderedDict]:
+ detecting_protocol = False
+
+ # choose highest version first
+ if self.client.protocol_version is None:
+ detecting_protocol = True
+ self.client.protocol_version = max(PROTOCOLS)
+
+ try:
+ result = await self._connect_version()
+ except HandshakeError as e:
+ if e.expected_version in PROTOCOLS:
+ self.client.protocol_version = e.expected_version
+ result = await self._connect_version()
+ else:
+ raise e
+ except connection_errors:
+ # restore undefined protocol version
+ if detecting_protocol:
+ self.client.protocol_version = None
+ raise
+
+ # connection is ready for end user
+ self.uuid = result.get('node_uuid', None) # version-specific (1.4+)
+
+ self.failed = False
+ return result
+
+ async def _connect_version(self) -> Union[dict, OrderedDict]:
+ """
+ Connect to the given server node using protocol version
+ defined on client.
+ """
+
+ ssl_context = create_ssl_context(self.ssl_params)
+ self._reader, self._writer = await asyncio.open_connection(self.host, self.port, ssl=ssl_context)
+
+ protocol_version = self.client.protocol_version
+
+ hs_request = HandshakeRequest(
+ protocol_version,
+ self.username,
+ self.password
+ )
+
+ with AioBinaryStream(self.client) as stream:
+ await hs_request.from_python_async(stream)
+ await self._send(stream.getbuffer(), reconnect=False)
+
+ with AioBinaryStream(self.client, await self._recv(reconnect=False)) as stream:
+ hs_response = await HandshakeResponse.parse_async(stream, self.protocol_version)
+
+ if hs_response.op_code == 0:
+ self._close()
+ self._process_handshake_error(hs_response)
+
+ return hs_response
+
+ async def reconnect(self):
+ async with self._mux:
+ await self._reconnect()
+
+ async def _reconnect(self):
+ if self.alive:
+ return
+
+ self._close()
+
+ # connect and silence the connection errors
+ try:
+ await self._connect()
+ except connection_errors:
+ pass
+
+ async def request(self, data: Union[bytes, bytearray, memoryview]) -> bytearray:
+ """
+ Perform request.
+
+ :param data: bytes to send.
+ """
+ async with self._mux:
+ await self._send(data)
+ return await self._recv()
+
+ async def _send(self, data: Union[bytes, bytearray, memoryview], reconnect=True):
+ if self.closed:
+ raise SocketError('Attempt to use closed connection.')
+
+ try:
+ self._writer.write(data)
+ await self._writer.drain()
+ except connection_errors:
+ self.failed = True
+ if reconnect:
+ await self._reconnect()
+ raise
+
+ async def _recv(self, reconnect=True) -> bytearray:
+ if self.closed:
+ raise SocketError('Attempt to use closed connection.')
+
+ with BytesIO() as stream:
+ try:
+ buf = await self._reader.readexactly(4)
+ response_len = int.from_bytes(buf, PROTOCOL_BYTE_ORDER)
+
+ stream.write(buf)
+
+ stream.write(await self._reader.readexactly(response_len))
+ except connection_errors:
+ self.failed = True
+ if reconnect:
+ await self._reconnect()
+ raise
+
+ return bytearray(stream.getbuffer())
+
+ async def close(self):
+ async with self._mux:
+ self._close()
+
+ def _close(self):
+ """
+ Close connection.
+ """
+ if self._writer:
+ try:
+ self._writer.close()
+ except connection_errors:
+ pass
+
+ self._writer, self._reader = None, None
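Putting the pieces together, a minimal end-to-end sketch of the asyncio API this connection backs (assumes a local Ignite node, and that AioClient mirrors the synchronous Client with coroutine methods awaited):

    import asyncio
    from pyignite import AioClient

    async def main():
        client = AioClient()
        await client.connect('127.0.0.1', 10800)
        cache = await client.get_or_create_cache('my-cache')
        await cache.put(1, 'one')
        print(await cache.get(1))

    asyncio.run(main())
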
diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py
index 8db304e..901cb56 100644
--- a/pyignite/connection/connection.py
+++ b/pyignite/connection/connection.py
@@ -32,64 +32,94 @@
import socket
from typing import Union
-from pyignite.constants import *
-from pyignite.exceptions import (
- HandshakeError, ParameterError, SocketError, connection_errors, AuthenticationError,
-)
-from pyignite.datatypes import Byte, Int, Short, String, UUIDObject
-from pyignite.datatypes.internal import Struct
+from pyignite.constants import PROTOCOLS, IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT, PROTOCOL_BYTE_ORDER
+from pyignite.exceptions import HandshakeError, SocketError, connection_errors, AuthenticationError
-from .handshake import HandshakeRequest
-from .ssl import wrap
-from ..stream import BinaryStream, READ_BACKWARD
+from .handshake import HandshakeRequest, HandshakeResponse
+from .ssl import wrap, check_ssl_params
+from ..stream import BinaryStream
CLIENT_STATUS_AUTH_FAILURE = 2000
-class Connection:
+class BaseConnection:
+ def __init__(self, client, host: str = None, port: int = None, username: str = None, password: str = None,
+ **ssl_params):
+ self.client = client
+ self.host = host if host else IGNITE_DEFAULT_HOST
+ self.port = port if port else IGNITE_DEFAULT_PORT
+ self.username = username
+ self.password = password
+ self.uuid = None
+
+ check_ssl_params(ssl_params)
+
+ if self.username and self.password and 'use_ssl' not in ssl_params:
+ ssl_params['use_ssl'] = True
+
+ self.ssl_params = ssl_params
+ self._failed = False
+
+ @property
+ def closed(self) -> bool:
+ """ Tells if socket is closed. """
+ raise NotImplementedError
+
+ @property
+ def failed(self) -> bool:
+ """ Tells if connection is failed. """
+ return self._failed
+
+ @failed.setter
+ def failed(self, value):
+ self._failed = value
+
+ @property
+ def alive(self) -> bool:
+ """ Tells if connection is up and no failure detected. """
+ return not self.failed and not self.closed
+
+ def __repr__(self) -> str:
+ return '{}:{}'.format(self.host or '?', self.port or '?')
+
+ @property
+ def protocol_version(self):
+ """
+ Returns the tuple of major, minor, and revision numbers of the used
+ thin protocol version, or None, if no connection to the Ignite cluster
+ has yet been established.
+ """
+ return self.client.protocol_version
+
+ def _process_handshake_error(self, response):
+ error_text = f'Handshake error: {response.message}'
+ # if handshake fails for any reason other than protocol mismatch
+ # (i.e. authentication error), server version is 0.0.0
+ protocol_version = self.client.protocol_version
+ server_version = (response.version_major, response.version_minor, response.version_patch)
+
+ if any(server_version):
+ error_text += f' Server expects binary protocol version ' \
+ f'{server_version[0]}.{server_version[1]}.{server_version[2]}. ' \
+ f'Client provides ' \
+ f'{protocol_version[0]}.{protocol_version[1]}.{protocol_version[2]}.'
+ elif response.client_status == CLIENT_STATUS_AUTH_FAILURE:
+ raise AuthenticationError(error_text)
+ raise HandshakeError(server_version, error_text)
+
+
+class Connection(BaseConnection):
"""
This is a `pyignite` class, that represents a connection to Ignite
node. It serves multiple purposes:
* socket wrapper. Detects fragmentation and network errors. See also
https://docs.python.org/3/howto/sockets.html,
- * binary protocol connector. Incapsulates handshake and failover reconnection.
+ * binary protocol connector. Encapsulates handshake and failover reconnection.
"""
- _socket = None
- _failed = None
-
- client = None
- host = None
- port = None
- timeout = None
- username = None
- password = None
- ssl_params = {}
- uuid = None
-
- @staticmethod
- def _check_ssl_params(params):
- expected_args = [
- 'use_ssl',
- 'ssl_version',
- 'ssl_ciphers',
- 'ssl_cert_reqs',
- 'ssl_keyfile',
- 'ssl_keyfile_password',
- 'ssl_certfile',
- 'ssl_ca_certfile',
- ]
- for param in params:
- if param not in expected_args:
- raise ParameterError((
- 'Unexpected parameter for connection initialization: `{}`'
- ).format(param))
-
- def __init__(
- self, client: 'Client', timeout: float = 2.0,
- username: str = None, password: str = None, **ssl_params
- ):
+ def __init__(self, client: 'Client', host: str, port: int, timeout: float = 2.0,
+ username: str = None, password: str = None, **ssl_params):
"""
Initialize connection.
@@ -97,6 +127,8 @@
https://docs.python.org/3/library/ssl.html#ssl-certificates.
:param client: Ignite client object,
+ :param host: Ignite server node's host name or IP,
+ :param port: Ignite server node's port number,
:param timeout: (optional) sets timeout (in seconds) for each socket
operation including `connect`. 0 means non-blocking mode, which is
virtually guaranteed to fail. Can accept integer or float value.
@@ -130,84 +162,15 @@
cluster,
:param password: (optional) password to authenticate to Ignite cluster.
"""
- self.client = client
+ super().__init__(client, host, port, username, password, **ssl_params)
self.timeout = timeout
- self.username = username
- self.password = password
- self._check_ssl_params(ssl_params)
- if self.username and self.password and 'use_ssl' not in ssl_params:
- ssl_params['use_ssl'] = True
- self.ssl_params = ssl_params
- self._failed = False
+ self._socket = None
@property
def closed(self) -> bool:
- """ Tells if socket is closed. """
return self._socket is None
- @property
- def failed(self) -> bool:
- """ Tells if connection is failed. """
- return self._failed
-
- @failed.setter
- def failed(self, value):
- self._failed = value
-
- @property
- def alive(self) -> bool:
- """ Tells if connection is up and no failure detected. """
- return not self.failed and not self.closed
-
- def __repr__(self) -> str:
- return '{}:{}'.format(self.host or '?', self.port or '?')
-
- _wrap = wrap
-
- def get_protocol_version(self):
- """
- Returns the tuple of major, minor, and revision numbers of the used
- thin protocol version, or None, if no connection to the Ignite cluster
- was yet established.
- """
- return self.client.protocol_version
-
- def read_response(self) -> Union[dict, OrderedDict]:
- """
- Processes server's response to the handshake request.
-
- :return: handshake data.
- """
- response_start = Struct([
- ('length', Int),
- ('op_code', Byte),
- ])
- with BinaryStream(self, self.recv(reconnect=False)) as stream:
- start_class = response_start.parse(stream)
- start = stream.read_ctype(start_class, direction=READ_BACKWARD)
- data = response_start.to_python(start)
- response_end = None
- if data['op_code'] == 0:
- response_end = Struct([
- ('version_major', Short),
- ('version_minor', Short),
- ('version_patch', Short),
- ('message', String),
- ('client_status', Int)
- ])
- elif self.get_protocol_version() >= (1, 4, 0):
- response_end = Struct([
- ('node_uuid', UUIDObject),
- ])
- if response_end:
- end_class = response_end.parse(stream)
- end = stream.read_ctype(end_class, direction=READ_BACKWARD)
- data.update(response_end.to_python(end))
- return data
-
- def connect(
- self, host: str = None, port: int = None
- ) -> Union[dict, OrderedDict]:
+ def connect(self) -> Union[dict, OrderedDict]:
"""
Connect to the given server node with protocol version fallback.
@@ -222,11 +185,11 @@
self.client.protocol_version = max(PROTOCOLS)
try:
- result = self._connect_version(host, port)
+ result = self._connect_version()
except HandshakeError as e:
if e.expected_version in PROTOCOLS:
self.client.protocol_version = e.expected_version
- result = self._connect_version(host, port)
+ result = self._connect_version()
else:
raise e
except connection_errors:
@@ -237,28 +200,19 @@
# connection is ready for end user
self.uuid = result.get('node_uuid', None) # version-specific (1.4+)
-
self.failed = False
return result
- def _connect_version(
- self, host: str = None, port: int = None,
- ) -> Union[dict, OrderedDict]:
+ def _connect_version(self) -> Union[dict, OrderedDict]:
"""
Connect to the given server node using protocol version
defined on client.
-
- :param host: Ignite server node's host name or IP,
- :param port: Ignite server node's port number.
"""
- host = host or IGNITE_DEFAULT_HOST
- port = port or IGNITE_DEFAULT_PORT
-
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(self.timeout)
- self._socket = self._wrap(self._socket)
- self._socket.connect((host, port))
+ self._socket = wrap(self._socket, self.ssl_params)
+ self._socket.connect((self.host, self.port))
protocol_version = self.client.protocol_version
@@ -268,56 +222,41 @@
self.password
)
- with BinaryStream(self) as stream:
+ with BinaryStream(self.client) as stream:
hs_request.from_python(stream)
self.send(stream.getbuffer(), reconnect=False)
- hs_response = self.read_response()
- if hs_response['op_code'] == 0:
- self.close()
+ with BinaryStream(self.client, self.recv(reconnect=False)) as stream:
+ hs_response = HandshakeResponse.parse(stream, self.protocol_version)
- error_text = 'Handshake error: {}'.format(hs_response['message'])
- # if handshake fails for any reason other than protocol mismatch
- # (i.e. authentication error), server version is 0.0.0
- if any([
- hs_response['version_major'],
- hs_response['version_minor'],
- hs_response['version_patch'],
- ]):
- error_text += (
- ' Server expects binary protocol version '
- '{version_major}.{version_minor}.{version_patch}. Client '
- 'provides {client_major}.{client_minor}.{client_patch}.'
- ).format(
- client_major=protocol_version[0],
- client_minor=protocol_version[1],
- client_patch=protocol_version[2],
- **hs_response
- )
- elif hs_response['client_status'] == CLIENT_STATUS_AUTH_FAILURE:
- raise AuthenticationError(error_text)
- raise HandshakeError((
- hs_response['version_major'],
- hs_response['version_minor'],
- hs_response['version_patch'],
- ), error_text)
- self.host, self.port = host, port
- return hs_response
+ if hs_response.op_code == 0:
+ self.close()
+ self._process_handshake_error(hs_response)
+
+ return hs_response
def reconnect(self):
- # do not reconnect if connection is already working
- # or was closed on purpose
- if not self.failed:
+ if self.alive:
return
self.close()
# connect and silence the connection errors
try:
- self.connect(self.host, self.port)
+ self.connect()
except connection_errors:
pass
+ def request(self, data: Union[bytes, bytearray, memoryview], flags=None) -> bytearray:
+ """
+ Perform request.
+
+ :param data: bytes to send,
+ :param flags: (optional) OS-specific flags.
+ """
+ self.send(data, flags=flags)
+ return self.recv()
+
def send(self, data: Union[bytes, bytearray, memoryview], flags=None, reconnect=True):
"""
Send data down the socket.
@@ -337,7 +276,8 @@
self._socket.sendall(data, **kwargs)
except connection_errors:
self.failed = True
- self.reconnect()
+ if reconnect:
+ self.reconnect()
raise
def recv(self, flags=None, reconnect=True) -> bytearray:
diff --git a/pyignite/connection/handshake.py b/pyignite/connection/handshake.py
index 3315c4e..0b0fe50 100644
--- a/pyignite/connection/handshake.py
+++ b/pyignite/connection/handshake.py
@@ -15,8 +15,9 @@
from typing import Optional, Tuple
-from pyignite.datatypes import Byte, Int, Short, String
+from pyignite.datatypes import Byte, Int, Short, String, UUIDObject
from pyignite.datatypes.internal import Struct
+from pyignite.stream import READ_BACKWARD
OP_HANDSHAKE = 1
@@ -51,6 +52,12 @@
self.handshake_struct = Struct(fields)
def from_python(self, stream):
+ self.handshake_struct.from_python(stream, self.__create_handshake_data())
+
+ async def from_python_async(self, stream):
+ await self.handshake_struct.from_python_async(stream, self.__create_handshake_data())
+
+ def __create_handshake_data(self):
handshake_data = {
'length': 8,
'op_code': OP_HANDSHAKE,
@@ -69,5 +76,66 @@
len(self.username),
len(self.password),
])
+ return handshake_data
- self.handshake_struct.from_python(stream, handshake_data)
+
+class HandshakeResponse(dict):
+ """
+ Handshake response.
+ """
+ __response_start = Struct([
+ ('length', Int),
+ ('op_code', Byte),
+ ])
+
+ def __init__(self, data):
+ super().__init__()
+ self.update(data)
+
+ def __getattr__(self, item):
+ return self.get(item)
+
+ @classmethod
+ def parse(cls, stream, protocol_version):
+ start_class = cls.__response_start.parse(stream)
+ start = stream.read_ctype(start_class, direction=READ_BACKWARD)
+ data = cls.__response_start.to_python(start)
+
+ response_end = cls.__create_response_end(data, protocol_version)
+ if response_end:
+ end_class = response_end.parse(stream)
+ end = stream.read_ctype(end_class, direction=READ_BACKWARD)
+ data.update(response_end.to_python(end))
+
+ return cls(data)
+
+ @classmethod
+ async def parse_async(cls, stream, protocol_version):
+ start_class = cls.__response_start.parse(stream)
+ start = stream.read_ctype(start_class, direction=READ_BACKWARD)
+ data = await cls.__response_start.to_python_async(start)
+
+ response_end = cls.__create_response_end(data, protocol_version)
+ if response_end:
+ end_class = await response_end.parse_async(stream)
+ end = stream.read_ctype(end_class, direction=READ_BACKWARD)
+ data.update(await response_end.to_python_async(end))
+
+ return cls(data)
+
+ @classmethod
+ def __create_response_end(cls, start_data, protocol_version):
+ response_end = None
+ if start_data['op_code'] == 0:
+ response_end = Struct([
+ ('version_major', Short),
+ ('version_minor', Short),
+ ('version_patch', Short),
+ ('message', String),
+ ('client_status', Int)
+ ])
+ elif protocol_version >= (1, 4, 0):
+ response_end = Struct([
+ ('node_uuid', UUIDObject),
+ ])
+ return response_end
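Note: HandshakeResponse subclasses dict and mirrors keys as attributes via
__getattr__. A standalone sketch of that access pattern (names here are
illustrative, not part of the patch):

    class AttrDict(dict):
        """Keys double as attributes, as in HandshakeResponse."""
        def __getattr__(self, item):
            return self.get(item)  # a missing key yields None, not AttributeError

    resp = AttrDict({'op_code': 0, 'message': 'Authentication failed'})
    assert resp['op_code'] == resp.op_code == 0
    assert resp.node_uuid is None  # absent field: no exception raised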
diff --git a/pyignite/connection/ssl.py b/pyignite/connection/ssl.py
index 9773860..385b414 100644
--- a/pyignite/connection/ssl.py
+++ b/pyignite/connection/ssl.py
@@ -16,34 +16,62 @@
import ssl
from ssl import SSLContext
-from pyignite.constants import *
+from pyignite.constants import SSL_DEFAULT_CIPHERS, SSL_DEFAULT_VERSION
+from pyignite.exceptions import ParameterError
-def wrap(conn: 'Connection', _socket):
+def wrap(socket, ssl_params):
""" Wrap socket in SSL wrapper. """
- if conn.ssl_params.get('use_ssl', None):
- keyfile = conn.ssl_params.get('ssl_keyfile', None)
- certfile = conn.ssl_params.get('ssl_certfile', None)
+ if not ssl_params.get('use_ssl'):
+ return socket
- if keyfile and not certfile:
- raise ValueError("certfile must be specified")
+ context = create_ssl_context(ssl_params)
- password = conn.ssl_params.get('ssl_keyfile_password', None)
- ssl_version = conn.ssl_params.get('ssl_version', SSL_DEFAULT_VERSION)
- ciphers = conn.ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS)
- cert_reqs = conn.ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE)
- ca_certs = conn.ssl_params.get('ssl_ca_certfile', None)
+ return context.wrap_socket(sock=socket)
- context = SSLContext(ssl_version)
- context.verify_mode = cert_reqs
- if ca_certs:
- context.load_verify_locations(ca_certs)
- if certfile:
- context.load_cert_chain(certfile, keyfile, password)
- if ciphers:
- context.set_ciphers(ciphers)
+def check_ssl_params(params):
+ expected_args = [
+ 'use_ssl',
+ 'ssl_version',
+ 'ssl_ciphers',
+ 'ssl_cert_reqs',
+ 'ssl_keyfile',
+ 'ssl_keyfile_password',
+ 'ssl_certfile',
+ 'ssl_ca_certfile',
+ ]
+ for param in params:
+ if param not in expected_args:
+ raise ParameterError((
+ 'Unexpected parameter for connection initialization: `{}`'
+ ).format(param))
- _socket = context.wrap_socket(sock=_socket)
- return _socket
+def create_ssl_context(ssl_params):
+ if not ssl_params.get('use_ssl'):
+ return None
+
+ keyfile = ssl_params.get('ssl_keyfile', None)
+ certfile = ssl_params.get('ssl_certfile', None)
+
+ if keyfile and not certfile:
+ raise ValueError("ssl_certfile must be specified when ssl_keyfile is set")
+
+ password = ssl_params.get('ssl_keyfile_password', None)
+ ssl_version = ssl_params.get('ssl_version', SSL_DEFAULT_VERSION)
+ ciphers = ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS)
+ cert_reqs = ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE)
+ ca_certs = ssl_params.get('ssl_ca_certfile', None)
+
+ context = SSLContext(ssl_version)
+ context.verify_mode = cert_reqs
+
+ if ca_certs:
+ context.load_verify_locations(ca_certs)
+ if certfile:
+ context.load_cert_chain(certfile, keyfile, password)
+ if ciphers:
+ context.set_ciphers(ciphers)
+
+ return context
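A sketch of feeding ssl_params through the new helpers (the certificate
path is illustrative):

    import ssl
    from pyignite.connection.ssl import check_ssl_params, create_ssl_context

    ssl_params = {
        'use_ssl': True,
        'ssl_cert_reqs': ssl.CERT_REQUIRED,
        'ssl_ca_certfile': '/path/to/ca.pem',
    }
    check_ssl_params(ssl_params)              # raises ParameterError on unknown keys
    context = create_ssl_context(ssl_params)  # returns None when use_ssl is falsy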
diff --git a/pyignite/cursors.py b/pyignite/cursors.py
new file mode 100644
index 0000000..c699556
--- /dev/null
+++ b/pyignite/cursors.py
@@ -0,0 +1,319 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module contains sync and async cursors for different types of queries.
+"""
+
+import asyncio
+
+from pyignite.api import (
+ scan, scan_cursor_get_page, resource_close, scan_async, scan_cursor_get_page_async, resource_close_async, sql,
+ sql_cursor_get_page, sql_fields, sql_fields_cursor_get_page, sql_fields_cursor_get_page_async, sql_fields_async
+)
+from pyignite.exceptions import CacheError, SQLError
+
+
+__all__ = ['ScanCursor', 'SqlCursor', 'SqlFieldsCursor', 'AioScanCursor', 'AioSqlFieldsCursor']
+
+
+class BaseCursorMixin:
+ @property
+ def connection(self):
+ return getattr(self, '_conn', None)
+
+ @connection.setter
+ def connection(self, value):
+ setattr(self, '_conn', value)
+
+ @property
+ def cursor_id(self):
+ return getattr(self, '_cursor_id', None)
+
+ @cursor_id.setter
+ def cursor_id(self, value):
+ setattr(self, '_cursor_id', value)
+
+ @property
+ def more(self):
+ return getattr(self, '_more', None)
+
+ @more.setter
+ def more(self, value):
+ setattr(self, '_more', value)
+
+ @property
+ def cache_id(self):
+ return getattr(self, '_cache_id', None)
+
+ @cache_id.setter
+ def cache_id(self, value):
+ setattr(self, '_cache_id', value)
+
+ @property
+ def client(self):
+ return getattr(self, '_client', None)
+
+ @client.setter
+ def client(self, value):
+ setattr(self, '_client', value)
+
+ @property
+ def data(self):
+ return getattr(self, '_data', None)
+
+ @data.setter
+ def data(self, value):
+ setattr(self, '_data', value)
+
+
+class CursorMixin(BaseCursorMixin):
+ def __enter__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ if self.connection and self.cursor_id and self.more:
+ resource_close(self.connection, self.cursor_id)
+
+
+class AioCursorMixin(BaseCursorMixin):
+ def __await__(self):
+ return (yield from self.__aenter__().__await__())
+
+ def __aiter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ async def close(self):
+ if self.connection and self.cursor_id and self.more:
+ await resource_close_async(self.connection, self.cursor_id)
+
+
+class AbstractScanCursor:
+ def __init__(self, client, cache_id, page_size, partitions, local):
+ self.client = client
+ self.cache_id = cache_id
+ self._page_size = page_size
+ self._partitions = partitions
+ self._local = local
+
+ def _finalize_init(self, result):
+ if result.status != 0:
+ raise CacheError(result.message)
+
+ self.cursor_id, self.more = result.value['cursor'], result.value['more']
+ self.data = iter(result.value['data'].items())
+
+ def _process_page_response(self, result):
+ if result.status != 0:
+ raise CacheError(result.message)
+
+ self.data, self.more = iter(result.value['data'].items()), result.value['more']
+
+
+class ScanCursor(AbstractScanCursor, CursorMixin):
+ def __init__(self, client, cache_id, page_size, partitions, local):
+ super().__init__(client, cache_id, page_size, partitions, local)
+
+ self.connection = self.client.random_node
+ result = scan(self.connection, self.cache_id, self._page_size, self._partitions, self._local)
+ self._finalize_init(result)
+
+ def __next__(self):
+ if not self.data:
+ raise StopIteration
+
+ try:
+ k, v = next(self.data)
+ except StopIteration:
+ if self.more:
+ self._process_page_response(scan_cursor_get_page(self.connection, self.cursor_id))
+ k, v = next(self.data)
+ else:
+ raise StopIteration
+
+ return self.client.unwrap_binary(k), self.client.unwrap_binary(v)
+
+
+class AioScanCursor(AbstractScanCursor, AioCursorMixin):
+ def __init__(self, client, cache_id, page_size, partitions, local):
+ super().__init__(client, cache_id, page_size, partitions, local)
+
+ async def __aenter__(self):
+ if not self.connection:
+ self.connection = await self.client.random_node()
+ result = await scan_async(self.connection, self.cache_id, self._page_size, self._partitions, self._local)
+ self._finalize_init(result)
+ return self
+
+ async def __anext__(self):
+ if not self.connection:
+ raise CacheError("Attempting to use an uninitialized cursor; await it or enter it via an async with statement.")
+
+ if not self.data:
+ raise StopAsyncIteration
+
+ try:
+ k, v = next(self.data)
+ except StopIteration:
+ if self.more:
+ self._process_page_response(await scan_cursor_get_page_async(self.connection, self.cursor_id))
+ try:
+ k, v = next(self.data)
+ except StopIteration:
+ raise StopAsyncIteration
+ else:
+ raise StopAsyncIteration
+
+ return await asyncio.gather(
+ *[self.client.unwrap_binary(k), self.client.unwrap_binary(v)]
+ )
+
+
+class SqlCursor(CursorMixin):
+ def __init__(self, client, cache_id, *args, **kwargs):
+ self.client = client
+ self.cache_id = cache_id
+ self.connection = self.client.random_node
+ result = sql(self.connection, self.cache_id, *args, **kwargs)
+ if result.status != 0:
+ raise SQLError(result.message)
+
+ self.cursor_id, self.more = result.value['cursor'], result.value['more']
+ self.data = iter(result.value['data'].items())
+
+ def __next__(self):
+ if not self.data:
+ raise StopIteration
+
+ try:
+ k, v = next(self.data)
+ except StopIteration:
+ if self.more:
+ result = sql_cursor_get_page(self.connection, self.cursor_id)
+ if result.status != 0:
+ raise SQLError(result.message)
+ self.data, self.more = iter(result.value['data'].items()), result.value['more']
+
+ k, v = next(self.data)
+ else:
+ raise StopIteration
+
+ return self.client.unwrap_binary(k), self.client.unwrap_binary(v)
+
+
+class AbstractSqlFieldsCursor:
+ def __init__(self, client, cache_id):
+ self.client = client
+ self.cache_id = cache_id
+
+ def _finalize_init(self, result):
+ if result.status != 0:
+ raise SQLError(result.message)
+
+ self.cursor_id, self.more = result.value['cursor'], result.value['more']
+ self.data = iter(result.value['data'])
+ self._field_names = result.value.get('fields', None)
+ if self._field_names:
+ self._field_count = len(self._field_names)
+ else:
+ self._field_count = result.value['field_count']
+
+
+class SqlFieldsCursor(AbstractSqlFieldsCursor, CursorMixin):
+ def __init__(self, client, cache_id, *args, **kwargs):
+ super().__init__(client, cache_id)
+ self.connection = self.client.random_node
+ self._finalize_init(sql_fields(self.connection, self.cache_id, *args, **kwargs))
+
+ def __next__(self):
+ if not self.data:
+ raise StopIteration
+
+ if self._field_names:
+ result = self._field_names
+ self._field_names = None
+ return result
+
+ try:
+ row = next(self.data)
+ except StopIteration:
+ if self.more:
+ result = sql_fields_cursor_get_page(self.connection, self.cursor_id, self._field_count)
+ if result.status != 0:
+ raise SQLError(result.message)
+
+ self.data, self.more = iter(result.value['data']), result.value['more']
+
+ row = next(self.data)
+ else:
+ raise StopIteration
+
+ return [self.client.unwrap_binary(v) for v in row]
+
+
+class AioSqlFieldsCursor(AbstractSqlFieldsCursor, AioCursorMixin):
+ def __init__(self, client, cache_id, *args, **kwargs):
+ super().__init__(client, cache_id)
+ self._params = (args, kwargs)
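+ # the query itself is deferred until the cursor is awaited or entered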
+
+ async def __aenter__(self):
+ await self._initialize(*self._params[0], **self._params[1])
+ return self
+
+ async def __anext__(self):
+ if not self.connection:
+ raise SQLError("Attempting to use an uninitialized cursor; await it or enter it via an async with statement.")
+
+ if not self.data:
+ raise StopAsyncIteration
+
+ if self._field_names:
+ result = self._field_names
+ self._field_names = None
+ return result
+
+ try:
+ row = next(self.data)
+ except StopIteration:
+ if self.more:
+ result = await sql_fields_cursor_get_page_async(self.connection, self.cursor_id, self._field_count)
+ if result.status != 0:
+ raise SQLError(result.message)
+
+ self.data, self.more = iter(result.value['data']), result.value['more']
+ try:
+ row = next(self.data)
+ except StopIteration:
+ raise StopAsyncIteration
+ else:
+ raise StopAsyncIteration
+
+ return await asyncio.gather(*[self.client.unwrap_binary(v) for v in row])
+
+ async def _initialize(self, *args, **kwargs):
+ if self.connection and self.cursor_id:
+ return
+
+ self.connection = await self.client.random_node()
+ self._finalize_init(await sql_fields_async(self.connection, self.cache_id, *args, **kwargs))
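Usage sketch for the new cursors (assumes the cache layer returns them from
scan(), as wired into Client/AioClient elsewhere in this patch):

    import asyncio
    from pyignite import AioClient, Client

    def sync_scan():
        client = Client()
        client.connect('127.0.0.1', 10800)
        cache = client.get_or_create_cache('my-cache')
        with cache.scan() as cursor:        # ScanCursor: context manager + iterator
            for key, value in cursor:
                print(key, value)

    async def async_scan():
        client = AioClient()
        async with client.connect('127.0.0.1', 10800):
            cache = await client.get_or_create_cache('my-cache')
            async with cache.scan() as cursor:   # AioScanCursor
                async for key, value in cursor:
                    print(key, value)

    asyncio.run(async_scan())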
diff --git a/pyignite/datatypes/__init__.py b/pyignite/datatypes/__init__.py
index 49860bd..5024f79 100644
--- a/pyignite/datatypes/__init__.py
+++ b/pyignite/datatypes/__init__.py
@@ -25,22 +25,3 @@
from .primitive_arrays import *
from .primitive_objects import *
from .standard import *
-from ..stream import BinaryStream, READ_BACKWARD
-
-
-def unwrap_binary(client: 'Client', wrapped: tuple) -> object:
- """
- Unwrap wrapped BinaryObject and convert it to Python data.
-
- :param client: connection to Ignite cluster,
- :param wrapped: `WrappedDataObject` value,
- :return: dict representing wrapped BinaryObject.
- """
- from pyignite.datatypes.complex import BinaryObject
-
- blob, offset = wrapped
- with BinaryStream(client.random_node, blob) as stream:
- data_class = BinaryObject.parse(stream)
- result = BinaryObject.to_python(stream.read_ctype(data_class, direction=READ_BACKWARD), client)
-
- return result
diff --git a/pyignite/datatypes/base.py b/pyignite/datatypes/base.py
index 25b5b1e..fbd798b 100644
--- a/pyignite/datatypes/base.py
+++ b/pyignite/datatypes/base.py
@@ -47,4 +47,34 @@
This is a base class for all Ignite data types, a.k.a. parser/constructor
classes, both object and payload varieties.
"""
- pass
+ @classmethod
+ async def hashcode_async(cls, value, *args, **kwargs):
+ return cls.hashcode(value, *args, **kwargs)
+
+ @classmethod
+ def hashcode(cls, value, *args, **kwargs):
+ return 0
+
+ @classmethod
+ def parse(cls, stream):
+ raise NotImplementedError
+
+ @classmethod
+ async def parse_async(cls, stream):
+ return cls.parse(stream)
+
+ @classmethod
+ def from_python(cls, stream, value, **kwargs):
+ raise NotImplementedError
+
+ @classmethod
+ async def from_python_async(cls, stream, value, **kwargs):
+ cls.from_python(stream, value, **kwargs)
+
+ @classmethod
+ def to_python(cls, ctype_object, *args, **kwargs):
+ raise NotImplementedError
+
+ @classmethod
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ return cls.to_python(ctype_object, *args, **kwargs)
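The base class defaults every *_async hook to its sync counterpart, so a
data type only overrides the async variants when it actually awaits. A
standalone sketch of the pattern (simplified; real parse() consumes a
binary stream):

    import asyncio

    class Base:
        @classmethod
        def parse(cls, stream):
            raise NotImplementedError

        @classmethod
        async def parse_async(cls, stream):
            return cls.parse(stream)  # default: reuse the sync parser

    class IntType(Base):
        @classmethod
        def parse(cls, stream):
            return int.from_bytes(stream, 'little')

    assert asyncio.run(IntType.parse_async(b'\x01\x00\x00\x00')) == 1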
diff --git a/pyignite/datatypes/cache_properties.py b/pyignite/datatypes/cache_properties.py
index eadaef9..127b6f3 100644
--- a/pyignite/datatypes/cache_properties.py
+++ b/pyignite/datatypes/cache_properties.py
@@ -23,7 +23,6 @@
from .primitive import *
from .standard import *
-
__all__ = [
'PropName', 'PropCacheMode', 'PropCacheAtomicityMode', 'PropBackupsNumber',
'PropWriteSynchronizationMode', 'PropCopyOnRead', 'PropReadFromBackup',
@@ -81,7 +80,7 @@
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -112,10 +111,16 @@
return prop_class
@classmethod
+ async def parse_async(cls, stream):
+ return cls.parse(stream)
+
+ @classmethod
def to_python(cls, ctype_object, *args, **kwargs):
- return cls.prop_data_class.to_python(
- ctype_object.data, *args, **kwargs
- )
+ return cls.prop_data_class.to_python(ctype_object.data, *args, **kwargs)
+
+ @classmethod
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ return cls.to_python(ctype_object, *args, **kwargs)
@classmethod
def from_python(cls, stream, value):
@@ -125,6 +130,10 @@
stream.write(bytes(header))
cls.prop_data_class.from_python(stream, value)
+ @classmethod
+ async def from_python_async(cls, stream, value):
+ return cls.from_python(stream, value)
+
class PropName(PropBase):
prop_code = PROP_NAME
diff --git a/pyignite/datatypes/complex.py b/pyignite/datatypes/complex.py
index b8d9c02..5cb6160 100644
--- a/pyignite/datatypes/complex.py
+++ b/pyignite/datatypes/complex.py
@@ -12,30 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import asyncio
from collections import OrderedDict
import ctypes
from io import SEEK_CUR
-from typing import Iterable, Dict
+from typing import Optional
from pyignite.constants import *
from pyignite.exceptions import ParseError
-from .base import IgniteDataType
-from .internal import AnyDataObject, infer_from_python
+from .internal import AnyDataObject, Struct, infer_from_python, infer_from_python_async
from .type_codes import *
from .type_ids import *
from .type_names import *
from .null_object import Null, Nullable
+from ..stream import AioBinaryStream, BinaryStream
-__all__ = [
- 'Map', 'ObjectArrayObject', 'CollectionObject', 'MapObject',
- 'WrappedDataObject', 'BinaryObject',
-]
-
-from ..stream import BinaryStream
+__all__ = ['Map', 'ObjectArrayObject', 'CollectionObject', 'MapObject', 'WrappedDataObject', 'BinaryObject']
-class ObjectArrayObject(IgniteDataType, Nullable):
+class ObjectArrayObject(Nullable):
"""
Array of Ignite objects of any consistent type. Its Python representation
is tuple(type_id, iterable of any type). The only type ID that makes sense
@@ -48,15 +43,10 @@
_type_id = TYPE_OBJ_ARR
type_code = TC_OBJECT_ARRAY
- @staticmethod
- def hashcode(value: Iterable) -> int:
- # Arrays are not supported as keys at the moment.
- return 0
-
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -70,16 +60,36 @@
@classmethod
def parse_not_null(cls, stream):
- header_class = cls.build_header()
- header = stream.read_ctype(header_class)
- stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ header, header_class = cls.__parse_header(stream)
fields = []
for i in range(header.length):
c_type = AnyDataObject.parse(stream)
fields.append(('element_{}'.format(i), c_type))
- final_class = type(
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ async def parse_not_null_async(cls, stream):
+ header, header_class = cls.__parse_header(stream)
+
+ fields = []
+ for i in range(header.length):
+ c_type = await AnyDataObject.parse_async(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ def __parse_header(cls, stream):
+ header_class = cls.build_header()
+ header = stream.read_ctype(header_class)
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ return header, header_class
+
+ @classmethod
+ def __build_final_class(cls, header_class, fields):
+ return type(
cls.__name__,
(header_class,),
{
@@ -88,8 +98,6 @@
}
)
- return final_class
-
@classmethod
def to_python_not_null(cls, ctype_object, *args, **kwargs):
result = []
@@ -103,28 +111,55 @@
return ctype_object.type_id, result
@classmethod
- def from_python_not_null(cls, stream, value):
+ async def to_python_not_null_async(cls, ctype_object, *args, **kwargs):
+ result = [
+ await AnyDataObject.to_python_async(
+ getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs
+ )
+ for i in range(ctype_object.length)]
+ return ctype_object.type_id, result
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, *args, **kwargs):
type_or_id, value = value
+ try:
+ length = len(value)
+ except TypeError:
+ value = [value]
+ length = 1
+
+ cls.__write_header(stream, type_or_id, length)
+ for x in value:
+ infer_from_python(stream, x)
+
+ @classmethod
+ async def from_python_not_null_async(cls, stream, value, *args, **kwargs):
+ type_or_id, value = value
+ try:
+ length = len(value)
+ except TypeError:
+ value = [value]
+ length = 1
+
+ cls.__write_header(stream, type_or_id, length)
+ for x in value:
+ await infer_from_python_async(stream, x)
+
+ @classmethod
+ def __write_header(cls, stream, type_or_id, length):
header_class = cls.build_header()
header = header_class()
header.type_code = int.from_bytes(
cls.type_code,
byteorder=PROTOCOL_BYTE_ORDER
)
- try:
- length = len(value)
- except TypeError:
- value = [value]
- length = 1
header.length = length
header.type_id = type_or_id
stream.write(header)
- for x in value:
- infer_from_python(stream, x)
-class WrappedDataObject(IgniteDataType, Nullable):
+class WrappedDataObject(Nullable):
"""
One or more binary objects can be wrapped in an array. This allows reading,
storing, passing and writing objects efficiently without understanding
@@ -138,7 +173,7 @@
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -160,7 +195,7 @@
{
'_pack_': 1,
'_fields_': [
- ('payload', ctypes.c_byte*header.length),
+ ('payload', ctypes.c_byte * header.length),
('offset', ctypes.c_int),
],
}
@@ -170,15 +205,15 @@
return final_class
@classmethod
- def to_python(cls, ctype_object, *args, **kwargs):
+ def to_python_not_null(cls, ctype_object, *args, **kwargs):
return bytes(ctype_object.payload), ctype_object.offset
@classmethod
- def from_python(cls, stream, value):
+ def from_python(cls, stream, value, *args, **kwargs):
raise ParseError('Send unwrapped data.')
-class CollectionObject(IgniteDataType, Nullable):
+class CollectionObject(Nullable):
"""
Similar to object array, but contains platform-agnostic deserialization
type hint instead of type ID.
@@ -220,15 +255,10 @@
pythonic = list
default = []
- @staticmethod
- def hashcode(value: Iterable) -> int:
- # Collections are not supported as keys at the moment.
- return 0
-
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -242,16 +272,36 @@
@classmethod
def parse_not_null(cls, stream):
- header_class = cls.build_header()
- header = stream.read_ctype(header_class)
- stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ header, header_class = cls.__parse_header(stream)
fields = []
for i in range(header.length):
c_type = AnyDataObject.parse(stream)
fields.append(('element_{}'.format(i), c_type))
- final_class = type(
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ async def parse_not_null_async(cls, stream):
+ header, header_class = cls.__parse_header(stream)
+
+ fields = []
+ for i in range(header.length):
+ c_type = await AnyDataObject.parse_async(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ def __parse_header(cls, stream):
+ header_class = cls.build_header()
+ header = stream.read_ctype(header_class)
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ return header, header_class
+
+ @classmethod
+ def __build_final_class(cls, header_class, fields):
+ return type(
cls.__name__,
(header_class,),
{
@@ -259,46 +309,78 @@
'_fields_': fields,
}
)
- return final_class
@classmethod
def to_python(cls, ctype_object, *args, **kwargs):
- result = []
- length = getattr(ctype_object, "length", None)
+ length = cls.__get_length(ctype_object)
if length is None:
return None
- for i in range(length):
- result.append(
- AnyDataObject.to_python(
- getattr(ctype_object, 'element_{}'.format(i)),
- *args, **kwargs
- )
- )
+
+ result = [
+ AnyDataObject.to_python(getattr(ctype_object, f'element_{i}'), *args, **kwargs)
+ for i in range(length)
+ ]
return ctype_object.type, result
@classmethod
- def from_python_not_null(cls, stream, value):
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ length = cls.__get_length(ctype_object)
+ if length is None:
+ return None
+
+ result_coro = [
+ AnyDataObject.to_python_async(getattr(ctype_object, f'element_{i}'), *args, **kwargs)
+ for i in range(length)
+ ]
+
+ return ctype_object.type, await asyncio.gather(*result_coro)
+
+ @classmethod
+ def __get_length(cls, ctype_object):
+ return getattr(ctype_object, "length", None)
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, *args, **kwargs):
type_or_id, value = value
+ try:
+ length = len(value)
+ except TypeError:
+ value = [value]
+ length = 1
+
+ cls.__write_header(stream, type_or_id, length)
+ for x in value:
+ infer_from_python(stream, x)
+
+ @classmethod
+ async def from_python_not_null_async(cls, stream, value, *args, **kwargs):
+ type_or_id, value = value
+ try:
+ length = len(value)
+ except TypeError:
+ value = [value]
+ length = 1
+
+ cls.__write_header(stream, type_or_id, length)
+ for x in value:
+ await infer_from_python_async(stream, x)
+
+ @classmethod
+ def __write_header(cls, stream, type_or_id, length):
header_class = cls.build_header()
header = header_class()
header.type_code = int.from_bytes(
cls.type_code,
byteorder=PROTOCOL_BYTE_ORDER
)
- try:
- length = len(value)
- except TypeError:
- value = [value]
- length = 1
+
header.length = length
header.type = type_or_id
stream.write(header)
- for x in value:
- infer_from_python(stream, x)
-class Map(IgniteDataType, Nullable):
+class Map(Nullable):
"""
Dictionary type, payload-only.
@@ -310,15 +392,10 @@
HASH_MAP = 1
LINKED_HASH_MAP = 2
- @staticmethod
- def hashcode(value: Dict) -> int:
- # Maps are not supported as keys at the moment.
- return 0
-
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -330,16 +407,36 @@
@classmethod
def parse_not_null(cls, stream):
- header_class = cls.build_header()
- header = stream.read_ctype(header_class)
- stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ header, header_class = cls.__parse_header(stream)
fields = []
for i in range(header.length << 1):
c_type = AnyDataObject.parse(stream)
fields.append(('element_{}'.format(i), c_type))
- final_class = type(
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ async def parse_not_null_async(cls, stream):
+ header, header_class = cls.__parse_header(stream)
+
+ fields = []
+ for i in range(header.length << 1):
+ c_type = await AnyDataObject.parse_async(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return cls.__build_final_class(header_class, fields)
+
+ @classmethod
+ def __parse_header(cls, stream):
+ header_class = cls.build_header()
+ header = stream.read_ctype(header_class)
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ return header, header_class
+
+ @classmethod
+ def __build_final_class(cls, header_class, fields):
+ return type(
cls.__name__,
(header_class,),
{
@@ -347,43 +444,75 @@
'_fields_': fields,
}
)
- return final_class
@classmethod
def to_python(cls, ctype_object, *args, **kwargs):
- map_type = getattr(ctype_object, 'type', cls.HASH_MAP)
- result = OrderedDict() if map_type == cls.LINKED_HASH_MAP else {}
+ map_cls = cls.__get_map_class(ctype_object)
+ result = map_cls()
for i in range(0, ctype_object.length << 1, 2):
k = AnyDataObject.to_python(
- getattr(ctype_object, 'element_{}'.format(i)),
- *args, **kwargs
- )
+ getattr(ctype_object, 'element_{}'.format(i)),
+ *args, **kwargs
+ )
v = AnyDataObject.to_python(
- getattr(ctype_object, 'element_{}'.format(i + 1)),
- *args, **kwargs
- )
+ getattr(ctype_object, 'element_{}'.format(i + 1)),
+ *args, **kwargs
+ )
result[k] = v
return result
@classmethod
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ map_cls = cls.__get_map_class(ctype_object)
+
+ kv_pairs_coro = [
+ asyncio.gather(
+ AnyDataObject.to_python_async(
+ getattr(ctype_object, 'element_{}'.format(i)),
+ *args, **kwargs
+ ),
+ AnyDataObject.to_python_async(
+ getattr(ctype_object, 'element_{}'.format(i + 1)),
+ *args, **kwargs
+ )
+ ) for i in range(0, ctype_object.length << 1, 2)
+ ]
+
+ return map_cls(await asyncio.gather(*kv_pairs_coro))
+
+ @classmethod
+ def __get_map_class(cls, ctype_object):
+ map_type = getattr(ctype_object, 'type', cls.HASH_MAP)
+ return OrderedDict if map_type == cls.LINKED_HASH_MAP else dict
+
+ @classmethod
def from_python(cls, stream, value, type_id=None):
+ cls.__write_header(stream, type_id, len(value))
+ for k, v in value.items():
+ infer_from_python(stream, k)
+ infer_from_python(stream, v)
+
+ @classmethod
+ async def from_python_async(cls, stream, value, type_id=None):
+ cls.__write_header(stream, type_id, len(value))
+ for k, v in value.items():
+ await infer_from_python_async(stream, k)
+ await infer_from_python_async(stream, v)
+
+ @classmethod
+ def __write_header(cls, stream, type_id, length):
header_class = cls.build_header()
header = header_class()
- length = len(value)
header.length = length
+
if hasattr(header, 'type_code'):
- header.type_code = int.from_bytes(
- cls.type_code,
- byteorder=PROTOCOL_BYTE_ORDER
- )
+ header.type_code = int.from_bytes(cls.type_code, byteorder=PROTOCOL_BYTE_ORDER)
+
if hasattr(header, 'type'):
header.type = type_id
stream.write(header)
- for k, v in value.items():
- infer_from_python(stream, k)
- infer_from_python(stream, v)
class MapObject(Map):
@@ -404,7 +533,7 @@
@classmethod
def build_header(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -419,23 +548,43 @@
@classmethod
def to_python(cls, ctype_object, *args, **kwargs):
obj_type = getattr(ctype_object, "type", None)
- if obj_type is None:
- return None
- return obj_type, super().to_python(
- ctype_object, *args, **kwargs
- )
+ if obj_type:
+ return obj_type, super().to_python(ctype_object, *args, **kwargs)
+ return None
@classmethod
- def from_python(cls, stream, value):
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ obj_type = cls.__get_obj_type(ctype_object)
+ if obj_type:
+ return obj_type, await super().to_python_async(ctype_object, *args, **kwargs)
+ return None
+
+ @classmethod
+ def __get_obj_type(cls, ctype_object):
+ return getattr(ctype_object, "type", None)
+
+ @classmethod
+ def from_python(cls, stream, value, **kwargs):
+ type_id, value = cls.__unpack_value(stream, value)
+ if value is not None:  # an empty map still writes a zero-length header
+ super().from_python(stream, value, type_id)
+
+ @classmethod
+ async def from_python_async(cls, stream, value, **kwargs):
+ type_id, value = cls.__unpack_value(stream, value)
+ if value is not None:  # an empty map still writes a zero-length header
+ await super().from_python_async(stream, value, type_id)
+
+ @classmethod
+ def __unpack_value(cls, stream, value):
if value is None:
Null.from_python(stream)
- return
+ return None, None
- type_id, value = value
- super().from_python(stream, value, type_id)
+ return value  # already a (type_id, mapping) pair
-class BinaryObject(IgniteDataType, Nullable):
+class BinaryObject(Nullable):
_type_id = TYPE_BINARY_OBJ
type_code = TC_COMPLEX_OBJECT
@@ -446,19 +595,26 @@
OFFSET_TWO_BYTES = 0x0010
COMPACT_FOOTER = 0x0020
- @staticmethod
- def hashcode(value: object, client: None) -> int:
+ @classmethod
+ def hashcode(cls, value: object, client: Optional['Client']) -> int:
# a binary object's hashcode implementation is special in the sense
# that you need to fully serialize the object to calculate
# its hashcode
- if not value._hashcode and client :
-
- with BinaryStream(client.random_node) as stream:
+ if not value._hashcode and client:
+ with BinaryStream(client) as stream:
value._from_python(stream, save_to_buf=True)
return value._hashcode
@classmethod
+ async def hashcode_async(cls, value: object, client: Optional['AioClient']) -> int:
+ if not value._hashcode and client:
+ with AioBinaryStream(client) as stream:
+ await value._from_python_async(stream, save_to_buf=True)
+
+ return value._hashcode
+
+ @classmethod
def build_header(cls):
return type(
cls.__name__,
@@ -504,22 +660,47 @@
@classmethod
def parse_not_null(cls, stream):
- from pyignite.datatypes import Struct
-
- header_class = cls.build_header()
- header = stream.read_ctype(header_class)
- stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ header, header_class = cls.__parse_header(stream)
# ignore full schema, always retrieve fields' types and order
# from complex types registry
data_class = stream.get_dataclass(header)
- fields = data_class.schema.items()
- object_fields_struct = Struct(fields)
+ object_fields_struct = cls.__build_object_fields_struct(data_class)
object_fields = object_fields_struct.parse(stream)
- final_class_fields = [('object_fields', object_fields)]
+ return cls.__build_final_class(stream, header, header_class, object_fields,
+ len(object_fields_struct.fields))
+
+ @classmethod
+ async def parse_not_null_async(cls, stream):
+ header, header_class = cls.__parse_header(stream)
+
+ # ignore full schema, always retrieve fields' types and order
+ # from complex types registry
+ data_class = await stream.get_dataclass(header)
+ object_fields_struct = cls.__build_object_fields_struct(data_class)
+ object_fields = await object_fields_struct.parse_async(stream)
+
+ return cls.__build_final_class(stream, header, header_class, object_fields,
+ len(object_fields_struct.fields))
+
+ @classmethod
+ def __parse_header(cls, stream):
+ header_class = cls.build_header()
+ header = stream.read_ctype(header_class)
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ return header, header_class
+
+ @staticmethod
+ def __build_object_fields_struct(data_class):
+ fields = data_class.schema.items()
+ return Struct(fields)
+
+ @classmethod
+ def __build_final_class(cls, stream, header, header_class, object_fields, fields_len):
+ final_class_fields = [('object_fields', object_fields)]
if header.flags & cls.HAS_SCHEMA:
- schema = cls.schema_type(header.flags) * len(fields)
+ schema = cls.schema_type(header.flags) * fields_len
stream.seek(ctypes.sizeof(schema), SEEK_CUR)
final_class_fields.append(('schema', schema))
@@ -537,35 +718,71 @@
@classmethod
def to_python(cls, ctype_object, client: 'Client' = None, *args, **kwargs):
- type_id = getattr(ctype_object, "type_id", None)
- if type_id is None:
- return None
+ type_id = cls.__get_type_id(ctype_object, client)
+ if type_id:
+ data_class = client.query_binary_type(type_id, ctype_object.schema_id)
- if not client:
- raise ParseError(
- 'Can not query binary type {}'.format(type_id)
- )
-
- data_class = client.query_binary_type(
- type_id,
- ctype_object.schema_id
- )
- result = data_class()
-
- result.version = ctype_object.version
- for field_name, field_type in data_class.schema.items():
- setattr(
- result, field_name, field_type.to_python(
- getattr(ctype_object.object_fields, field_name),
- client, *args, **kwargs
+ result = data_class()
+ result.version = ctype_object.version
+ for field_name, field_type in data_class.schema.items():
+ setattr(
+ result, field_name, field_type.to_python(
+ getattr(ctype_object.object_fields, field_name),
+ client, *args, **kwargs
+ )
)
- )
- return result
+ return result
+
+ return None
@classmethod
- def from_python_not_null(cls, stream, value):
- if getattr(value, '_buffer', None):
- stream.write(value._buffer)
- else:
+ async def to_python_async(cls, ctype_object, client: 'AioClient' = None, *args, **kwargs):
+ type_id = cls.__get_type_id(ctype_object, client)
+ if type_id:
+ data_class = await client.query_binary_type(type_id, ctype_object.schema_id)
+
+ result = data_class()
+ result.version = ctype_object.version
+
+ field_values = await asyncio.gather(
+ *[
+ field_type.to_python_async(
+ getattr(ctype_object.object_fields, field_name), client, *args, **kwargs
+ )
+ for field_name, field_type in data_class.schema.items()
+ ]
+ )
+
+ for i, field_name in enumerate(data_class.schema.keys()):
+ setattr(result, field_name, field_values[i])
+
+ return result
+ return None
+
+ @classmethod
+ def __get_type_id(cls, ctype_object, client):
+ type_id = getattr(ctype_object, "type_id", None)
+ if type_id:
+ if not client:
+ raise ParseError(f'Cannot query binary type {type_id}')
+ return type_id
+ return None
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, **kwargs):
+ if cls.__write_fast_path(stream, value):
stream.register_binary_type(value.__class__)
value._from_python(stream)
+
+ @classmethod
+ async def from_python_not_null_async(cls, stream, value, **kwargs):
+ if cls.__write_fast_path(stream, value):
+ await stream.register_binary_type(value.__class__)
+ await value._from_python_async(stream)
+
+ @classmethod
+ def __write_fast_path(cls, stream, value):
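+ # returns True when the caller still has to serialize the value,
+ # False when the cached buffer was already written verbatim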
+ if getattr(value, '_buffer', None):
+ stream.write(value._buffer)
+ return False
+ return True
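For reference, the Python-side representations these (de)serializers
round-trip (a sketch; values illustrative):

    from collections import OrderedDict

    # MapObject travels as a (type_id, mapping) pair:
    hash_map   = (1, {'key1': 1, 'key2': 2})             # Map.HASH_MAP
    linked_map = (2, OrderedDict([('a', 1), ('b', 2)]))  # Map.LINKED_HASH_MAP

    # CollectionObject / ObjectArrayObject travel as (type_or_id, iterable);
    # from_python_not_null wraps a bare scalar into a one-element list.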
diff --git a/pyignite/datatypes/internal.py b/pyignite/datatypes/internal.py
index a6da9fe..0de50e2 100644
--- a/pyignite/datatypes/internal.py
+++ b/pyignite/datatypes/internal.py
@@ -12,26 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import asyncio
from collections import OrderedDict
import ctypes
import decimal
from datetime import date, datetime, timedelta
from io import SEEK_CUR
-from typing import Any, Tuple, Union, Callable, List
+from typing import Any, Union, Callable, List
import uuid
import attr
-from pyignite.constants import *
+from pyignite.constants import PROTOCOL_BYTE_ORDER
from pyignite.exceptions import ParseError
from pyignite.utils import is_binary, is_hinted, is_iterable
from .type_codes import *
__all__ = [
- 'AnyDataArray', 'AnyDataObject', 'Struct', 'StructArray', 'tc_map',
- 'infer_from_python',
+ 'AnyDataArray', 'AnyDataObject', 'Struct', 'StructArray', 'tc_map', 'infer_from_python', 'infer_from_python_async'
]
from ..stream import READ_BACKWARD
@@ -124,11 +123,25 @@
self.var2 = var2
def parse(self, stream, context):
- return self.var1.parse(stream) if self.predicate1(context) else self.var2.parse(stream)
+ if self.predicate1(context):
+ return self.var1.parse(stream)
+ return self.var2.parse(stream)
+
+ async def parse_async(self, stream, context):
+ if self.predicate1(context):
+ return await self.var1.parse_async(stream)
+ return await self.var2.parse_async(stream)
def to_python(self, ctype_object, context, *args, **kwargs):
- return self.var1.to_python(ctype_object, *args, **kwargs) if self.predicate2(context)\
- else self.var2.to_python(ctype_object, *args, **kwargs)
+ if self.predicate2(context):
+ return self.var1.to_python(ctype_object, *args, **kwargs)
+ return self.var2.to_python(ctype_object, *args, **kwargs)
+
+ async def to_python_async(self, ctype_object, context, *args, **kwargs):
+ if self.predicate2(context):
+ return await self.var1.to_python_async(ctype_object, *args, **kwargs)
+ return await self.var2.to_python_async(ctype_object, *args, **kwargs)
+
@attr.s
class StructArray:
@@ -139,7 +152,7 @@
def build_header_class(self):
return type(
- self.__class__.__name__+'Header',
+ self.__class__.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -150,19 +163,34 @@
)
def parse(self, stream):
+ fields, length = [], self.__parse_length(stream)
+
+ for i in range(length):
+ c_type = Struct(self.following).parse(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return self.__build_final_class(fields)
+
+ async def parse_async(self, stream):
+ fields, length = [], self.__parse_length(stream)
+
+ for i in range(length):
+ c_type = await Struct(self.following).parse_async(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return self.__build_final_class(fields)
+
+ def __parse_length(self, stream):
counter_type_len = ctypes.sizeof(self.counter_type)
length = int.from_bytes(
stream.mem_view(offset=counter_type_len),
byteorder=PROTOCOL_BYTE_ORDER
)
stream.seek(counter_type_len, SEEK_CUR)
+ return length
- fields = []
- for i in range(length):
- c_type = Struct(self.following).parse(stream)
- fields.append(('element_{}'.format(i), c_type))
-
- data_class = type(
+ def __build_final_class(self, fields):
+ return type(
'StructArray',
(self.build_header_class(),),
{
@@ -171,36 +199,47 @@
},
)
- return data_class
-
def to_python(self, ctype_object, *args, **kwargs):
- result = []
length = getattr(ctype_object, 'length', 0)
- for i in range(length):
- result.append(
- Struct(
- self.following, dict_type=dict
- ).to_python(
- getattr(ctype_object, 'element_{}'.format(i)),
- *args, **kwargs
- )
- )
- return result
+ return [
+ Struct(self.following, dict_type=dict).to_python(getattr(ctype_object, 'element_{}'.format(i)),
+ *args, **kwargs)
+ for i in range(length)
+ ]
+
+ async def to_python_async(self, ctype_object, *args, **kwargs):
+ length = getattr(ctype_object, 'length', 0)
+ result_coro = [
+ Struct(self.following, dict_type=dict).to_python_async(getattr(ctype_object, 'element_{}'.format(i)),
+ *args, **kwargs)
+ for i in range(length)
+ ]
+ return await asyncio.gather(*result_coro)
def from_python(self, stream, value):
- length = len(value)
- header_class = self.build_header_class()
- header = header_class()
- header.length = length
+ self.__write_header(stream, len(value))
-
- stream.write(header)
- for i, v in enumerate(value):
+ for v in value:
for default_key, default_value in self.defaults.items():
v.setdefault(default_key, default_value)
for name, el_class in self.following:
el_class.from_python(stream, v[name])
+ async def from_python_async(self, stream, value):
+ self.__write_header(stream, len(value))
+
+ for v in value:
+ for default_key, default_value in self.defaults.items():
+ v.setdefault(default_key, default_value)
+ for name, el_class in self.following:
+ await el_class.from_python_async(stream, v[name])
+
+ def __write_header(self, stream, length):
+ header_class = self.build_header_class()
+ header = header_class()
+ header.length = length
+ stream.write(header)
+
@attr.s
class Struct:
@@ -210,12 +249,7 @@
defaults = attr.ib(type=dict, default={})
def parse(self, stream):
- fields, ctx = [], {}
-
- for _, c_type in self.fields:
- if isinstance(c_type, Conditional):
- for name in c_type.fields:
- ctx[name] = None
+ fields, ctx = [], self.__prepare_conditional_ctx()
for name, c_type in self.fields:
is_cond = isinstance(c_type, Conditional)
@@ -224,7 +258,31 @@
if name in ctx:
ctx[name] = stream.read_ctype(c_type, direction=READ_BACKWARD)
- data_class = type(
+ return self.__build_final_class(fields)
+
+ async def parse_async(self, stream):
+ fields, ctx = [], self.__prepare_conditional_ctx()
+
+ for name, c_type in self.fields:
+ is_cond = isinstance(c_type, Conditional)
+ c_type = await c_type.parse_async(stream, ctx) if is_cond else await c_type.parse_async(stream)
+ fields.append((name, c_type))
+ if name in ctx:
+ ctx[name] = stream.read_ctype(c_type, direction=READ_BACKWARD)
+
+ return self.__build_final_class(fields)
+
+ def __prepare_conditional_ctx(self):
+ ctx = {}
+ for _, c_type in self.fields:
+ if isinstance(c_type, Conditional):
+ for name in c_type.fields:
+ ctx[name] = None
+ return ctx
+
+ @staticmethod
+ def __build_final_class(fields):
+ return type(
'Struct',
(ctypes.LittleEndianStructure,),
{
@@ -233,11 +291,7 @@
},
)
- return data_class
-
- def to_python(
- self, ctype_object, *args, **kwargs
- ) -> Union[dict, OrderedDict]:
+ def to_python(self, ctype_object, *args, **kwargs) -> Union[dict, OrderedDict]:
result = self.dict_type()
for name, c_type in self.fields:
is_cond = isinstance(c_type, Conditional)
@@ -251,13 +305,41 @@
)
return result
+ async def to_python_async(self, ctype_object, *args, **kwargs) -> Union[dict, OrderedDict]:
+ result = self.dict_type()
+ for name, c_type in self.fields:
+ is_cond = isinstance(c_type, Conditional)
+
+ if is_cond:
+ value = await c_type.to_python_async(
+ getattr(ctype_object, name),
+ result,
+ *args, **kwargs
+ )
+ else:
+ value = await c_type.to_python_async(
+ getattr(ctype_object, name),
+ *args, **kwargs
+ )
+ result[name] = value
+ return result
+
def from_python(self, stream, value):
- for default_key, default_value in self.defaults.items():
- value.setdefault(default_key, default_value)
+ self.__set_defaults(value)
for name, el_class in self.fields:
el_class.from_python(stream, value[name])
+ async def from_python_async(self, stream, value):
+ self.__set_defaults(value)
+
+ for name, el_class in self.fields:
+ await el_class.from_python_async(stream, value[name])
+
+ def __set_defaults(self, value):
+ for default_key, default_value in self.defaults.items():
+ value.setdefault(default_key, default_value)
+
class AnyDataObject:
"""
@@ -294,29 +376,44 @@
# if an iterable contains items of more than one non-nullable type,
# return None
- if all([
- isinstance(x, type_first)
- or ((x is None) and allow_none) for x in iterator
- ]):
+ if all(isinstance(x, type_first) or ((x is None) and allow_none) for x in iterator):
return type_first
@classmethod
def parse(cls, stream):
- type_code = bytes(stream.mem_view(offset=ctypes.sizeof(ctypes.c_byte)))
- try:
- data_class = tc_map(type_code)
- except KeyError:
- raise ParseError('Unknown type code: `{}`'.format(type_code))
+ data_class = cls.__data_class_parse(stream)
return data_class.parse(stream)
@classmethod
+ async def parse_async(cls, stream):
+ data_class = cls.__data_class_parse(stream)
+ return await data_class.parse_async(stream)
+
+ @classmethod
+ def __data_class_parse(cls, stream):
+ type_code = bytes(stream.mem_view(offset=ctypes.sizeof(ctypes.c_byte)))
+ try:
+ return tc_map(type_code)
+ except KeyError:
+ raise ParseError('Unknown type code: `{}`'.format(type_code))
+
+ @classmethod
def to_python(cls, ctype_object, *args, **kwargs):
+ data_class = cls.__data_class_from_ctype(ctype_object)
+ return data_class.to_python(ctype_object)
+
+ @classmethod
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ data_class = cls.__data_class_from_ctype(ctype_object)
+ return await data_class.to_python_async(ctype_object)
+
+ @classmethod
+ def __data_class_from_ctype(cls, ctype_object):
type_code = ctype_object.type_code.to_bytes(
ctypes.sizeof(ctypes.c_byte),
byteorder=PROTOCOL_BYTE_ORDER
)
- data_class = tc_map(type_code)
- return data_class.to_python(ctype_object)
+ return tc_map(type_code)
@classmethod
def _init_python_map(cls):
@@ -423,6 +520,11 @@
p_type = cls.map_python_type(value)
p_type.from_python(stream, value)
+ @classmethod
+ async def from_python_async(cls, stream, value):
+ p_type = cls.map_python_type(value)
+ await p_type.from_python_async(stream, value)
+
def infer_from_python(stream, value: Any):
"""
@@ -431,14 +533,26 @@
:param value: pythonic value or (value, type_hint) tuple,
:return: None; the encoded value is written to the stream.
"""
- if is_hinted(value):
- value, data_type = value
- else:
- data_type = AnyDataObject
+ value, data_type = __unpack_hinted(value)
data_type.from_python(stream, value)
+async def infer_from_python_async(stream, value: Any):
+ """
+ Async version of infer_from_python
+ """
+ value, data_type = __unpack_hinted(value)
+
+ await data_type.from_python_async(stream, value)
+
+
+def __unpack_hinted(value):
+ if is_hinted(value):
+ return value
+ return value, AnyDataObject
+
+
@attr.s
class AnyDataArray(AnyDataObject):
"""
@@ -448,7 +562,7 @@
def build_header(self):
return type(
- self.__class__.__name__+'Header',
+ self.__class__.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -459,16 +573,33 @@
)
def parse(self, stream):
- header_class = self.build_header()
- header = stream.read_ctype(header_class)
- stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ header, header_class = self.__parse_header(stream)
fields = []
for i in range(header.length):
c_type = super().parse(stream)
fields.append(('element_{}'.format(i), c_type))
- final_class = type(
+ return self.__build_final_class(header_class, fields)
+
+ async def parse_async(self, stream):
+ header, header_class = self.__parse_header(stream)
+
+ fields = []
+ for i in range(header.length):
+ c_type = await super().parse_async(stream)
+ fields.append(('element_{}'.format(i), c_type))
+
+ return self.__build_final_class(header_class, fields)
+
+ def __parse_header(self, stream):
+ header_class = self.build_header()
+ header = stream.read_ctype(header_class)
+ stream.seek(ctypes.sizeof(header_class), SEEK_CUR)
+ return header, header_class
+
+ def __build_final_class(self, header_class, fields):
+ return type(
self.__class__.__name__,
(header_class,),
{
@@ -476,34 +607,58 @@
'_fields_': fields,
}
)
- return final_class
@classmethod
def to_python(cls, ctype_object, *args, **kwargs):
- result = []
- length = getattr(ctype_object, "length", None)
- if length is None:
- return None
- for i in range(length):
- result.append(
+ length = cls.__get_length(ctype_object)
+
+ if length is None:
+ return None
+
+ # zero-argument super() cannot be used inside a comprehension on the
+ # supported Pythons, so the parent class is named explicitly
+ return [
+ AnyDataObject.to_python(getattr(ctype_object, 'element_{}'.format(i)), *args, **kwargs)
+ for i in range(length)
+ ]
+
+ @classmethod
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ length = cls.__get_length(ctype_object)
+
- )
- )
- return result
+ if length is None:
+ return None
+
+ # gather() needs awaitables, so call the parent's async variant explicitly
+ values = asyncio.gather(
+ *[
+ AnyDataObject.to_python_async(
+ getattr(ctype_object, 'element_{}'.format(i)),
+ *args, **kwargs
+ ) for i in range(length)
+ ]
+ )
+ return await values
+
+ @staticmethod
+ def __get_length(ctype_object):
+ return getattr(ctype_object, "length", None)
def from_python(self, stream, value):
- header_class = self.build_header()
- header = header_class()
-
try:
length = len(value)
except TypeError:
value = [value]
length = 1
- header.length = length
+ self.__write_header(stream, length)
- stream.write(header)
for x in value:
infer_from_python(stream, x)
+
+ async def from_python_async(self, stream, value):
+ try:
+ length = len(value)
+ except TypeError:
+ value = [value]
+ length = 1
+ self.__write_header(stream, length)
+
+ for x in value:
+ await infer_from_python_async(stream, x)
+
+ def __write_header(self, stream, length):
+ header_class = self.build_header()
+ header = header_class()
+ header.length = length
+ stream.write(header)
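A sketch of the value shapes infer_from_python and its async twin accept
(type inference per AnyDataObject; IntObject is exported by
pyignite.datatypes):

    from pyignite.datatypes import IntObject

    plain  = 42               # wire type inferred via AnyDataObject.map_python_type
    hinted = (42, IntObject)  # explicit hint: serialized as a 4-byte int object

    # infer_from_python(stream, plain)               -- sync path
    # await infer_from_python_async(stream, hinted)  -- async path, same rules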
diff --git a/pyignite/datatypes/null_object.py b/pyignite/datatypes/null_object.py
index 912ded8..f16034f 100644
--- a/pyignite/datatypes/null_object.py
+++ b/pyignite/datatypes/null_object.py
@@ -21,13 +21,12 @@
import ctypes
from io import SEEK_CUR
-from typing import Any
from .base import IgniteDataType
from .type_codes import TC_NULL
-__all__ = ['Null']
+__all__ = ['Null', 'Nullable']
from ..constants import PROTOCOL_BYTE_ORDER
@@ -37,11 +36,6 @@
pythonic = type(None)
_object_c_type = None
- @staticmethod
- def hashcode(value: Any) -> int:
- # Null object can not be a cache key.
- return 0
-
@classmethod
def build_c_type(cls):
if cls._object_c_type is None:
@@ -59,55 +53,99 @@
@classmethod
def parse(cls, stream):
- init_pos, offset = stream.tell(), ctypes.sizeof(ctypes.c_byte)
- stream.seek(offset, SEEK_CUR)
+ stream.seek(ctypes.sizeof(ctypes.c_byte), SEEK_CUR)
return cls.build_c_type()
- @staticmethod
- def to_python(*args, **kwargs):
+ @classmethod
+ def to_python(cls, *args, **kwargs):
return None
- @staticmethod
- def from_python(stream, *args):
+ @classmethod
+ def from_python(cls, stream, *args):
stream.write(TC_NULL)
-class Nullable:
+class Nullable(IgniteDataType):
@classmethod
def parse_not_null(cls, stream):
raise NotImplementedError
@classmethod
- def parse(cls, stream):
- type_len = ctypes.sizeof(ctypes.c_byte)
+ async def parse_not_null_async(cls, stream):
+ return cls.parse_not_null(stream)
- if stream.mem_view(offset=type_len) == TC_NULL:
- stream.seek(type_len, SEEK_CUR)
- return Null.build_c_type()
+ @classmethod
+ def parse(cls, stream):
+ is_null, null_type = cls.__check_null_input(stream)
+
+ if is_null:
+ return null_type
return cls.parse_not_null(stream)
@classmethod
+ async def parse_async(cls, stream):
+ is_null, null_type = cls.__check_null_input(stream)
+
+ if is_null:
+ return null_type
+
+ return await cls.parse_not_null_async(stream)
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, **kwargs):
+ raise NotImplementedError
+
+ @classmethod
+ async def from_python_not_null_async(cls, stream, value, **kwargs):
+ return cls.from_python_not_null(stream, value, **kwargs)
+
+ @classmethod
+ def from_python(cls, stream, value, **kwargs):
+ if value is None:
+ Null.from_python(stream)
+ else:
+ cls.from_python_not_null(stream, value)
+
+ @classmethod
+ async def from_python_async(cls, stream, value, **kwargs):
+ if value is None:
+ Null.from_python(stream)
+ else:
+ await cls.from_python_not_null_async(stream, value, **kwargs)
+
+ @classmethod
def to_python_not_null(cls, ctypes_object, *args, **kwargs):
raise NotImplementedError
@classmethod
+ async def to_python_not_null_async(cls, ctypes_object, *args, **kwargs):
+ return cls.to_python_not_null(ctypes_object, *args, **kwargs)
+
+ @classmethod
def to_python(cls, ctypes_object, *args, **kwargs):
- if ctypes_object.type_code == int.from_bytes(
- TC_NULL,
- byteorder=PROTOCOL_BYTE_ORDER
- ):
+ if cls.__is_null(ctypes_object):
return None
return cls.to_python_not_null(ctypes_object, *args, **kwargs)
@classmethod
- def from_python_not_null(cls, stream, value):
- raise NotImplementedError
+ async def to_python_async(cls, ctypes_object, *args, **kwargs):
+ if cls.__is_null(ctypes_object):
+ return None
+
+ return await cls.to_python_not_null_async(ctypes_object, *args, **kwargs)
@classmethod
- def from_python(cls, stream, value):
- if value is None:
- Null.from_python(stream)
- else:
- cls.from_python_not_null(stream, value)
+ def __check_null_input(cls, stream):
+ type_len = ctypes.sizeof(ctypes.c_byte)
+
+ if stream.mem_view(offset=type_len) == TC_NULL:
+ stream.seek(type_len, SEEK_CUR)
+ return True, Null.build_c_type()
+
+ return False, None
+
+ @classmethod
+ def __is_null(cls, ctypes_object):
+ return ctypes_object.type_code == int.from_bytes(TC_NULL, byteorder=PROTOCOL_BYTE_ORDER)
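A condensed standalone sketch of the Nullable template-method flow (a list
stands in for the binary stream; TC_NULL is 0x65 per type_codes):

    TC_NULL = b'\x65'

    class MiniNullable:
        @classmethod
        def from_python(cls, stream, value):
            if value is None:
                stream.append(TC_NULL)   # write only the null marker
            else:
                cls.from_python_not_null(stream, value)

    class MiniString(MiniNullable):
        @classmethod
        def from_python_not_null(cls, stream, value):
            stream.append(value.encode())

    buf = []
    MiniString.from_python(buf, None)
    MiniString.from_python(buf, 'abc')
    assert buf == [b'\x65', b'abc']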
diff --git a/pyignite/datatypes/primitive.py b/pyignite/datatypes/primitive.py
index ffa2e32..3bbb196 100644
--- a/pyignite/datatypes/primitive.py
+++ b/pyignite/datatypes/primitive.py
@@ -48,8 +48,7 @@
@classmethod
def parse(cls, stream):
- init_pos, offset = stream.tell(), ctypes.sizeof(cls.c_type)
- stream.seek(offset, SEEK_CUR)
+ stream.seek(ctypes.sizeof(cls.c_type), SEEK_CUR)
return cls.c_type
@classmethod
diff --git a/pyignite/datatypes/primitive_arrays.py b/pyignite/datatypes/primitive_arrays.py
index 7cb5b20..a21de77 100644
--- a/pyignite/datatypes/primitive_arrays.py
+++ b/pyignite/datatypes/primitive_arrays.py
@@ -15,11 +15,8 @@
import ctypes
from io import SEEK_CUR
-from typing import Any
from pyignite.constants import *
-from . import Null
-from .base import IgniteDataType
from .null_object import Nullable
from .primitive import *
from .type_codes import *
@@ -35,7 +32,7 @@
]
-class PrimitiveArray(IgniteDataType, Nullable):
+class PrimitiveArray(Nullable):
"""
Base class for array of primitives. Payload-only.
"""
@@ -44,15 +41,10 @@
primitive_type = None
type_code = None
- @staticmethod
- def hashcode(value: Any) -> int:
- # Arrays are not supported as keys at the moment.
- return 0
-
@classmethod
def build_header_class(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -88,7 +80,11 @@
return [ctype_object.data[i] for i in range(ctype_object.length)]
@classmethod
- def from_python_not_null(cls, stream, value):
+ async def to_python_async(cls, ctypes_object, *args, **kwargs):
+ return cls.to_python(ctypes_object, *args, **kwargs)
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, **kwargs):
header_class = cls.build_header_class()
header = header_class()
if hasattr(header, 'type_code'):
@@ -188,7 +184,7 @@
@classmethod
def build_header_class(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -312,7 +308,5 @@
length = getattr(ctype_object, "length", None)
if length is None:
return None
- result = [False] * length
- for i in range(length):
- result[i] = ctype_object.data[i] != 0
- return result
+
+ return [ctype_object.data[i] != 0 for i in range(length)]
diff --git a/pyignite/datatypes/primitive_objects.py b/pyignite/datatypes/primitive_objects.py
index e942dd7..5849935 100644
--- a/pyignite/datatypes/primitive_objects.py
+++ b/pyignite/datatypes/primitive_objects.py
@@ -18,11 +18,10 @@
from pyignite.constants import *
from pyignite.utils import unsigned
-from .base import IgniteDataType
from .type_codes import *
from .type_ids import *
from .type_names import *
-from .null_object import Null, Nullable
+from .null_object import Nullable
__all__ = [
'DataObject', 'ByteObject', 'ShortObject', 'IntObject', 'LongObject',
@@ -30,7 +29,7 @@
]
-class DataObject(IgniteDataType, Nullable):
+class DataObject(Nullable):
"""
Base class for primitive data objects.
@@ -65,12 +64,16 @@
stream.seek(ctypes.sizeof(data_type), SEEK_CUR)
return data_type
- @staticmethod
- def to_python(ctype_object, *args, **kwargs):
+ @classmethod
+ def to_python(cls, ctype_object, *args, **kwargs):
return getattr(ctype_object, "value", None)
@classmethod
- def from_python_not_null(cls, stream, value):
+ async def to_python_async(cls, ctype_object, *args, **kwargs):
+ return cls.to_python(ctype_object, *args, **kwargs)
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, **kwargs):
data_type = cls.build_c_type()
data_object = data_type()
data_object.type_code = int.from_bytes(
@@ -89,8 +92,8 @@
pythonic = int
default = 0
- @staticmethod
- def hashcode(value: int, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: int, *args, **kwargs) -> int:
return value
@@ -102,8 +105,8 @@
pythonic = int
default = 0
- @staticmethod
- def hashcode(value: int, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: int, *args, **kwargs) -> int:
return value
@@ -115,8 +118,8 @@
pythonic = int
default = 0
- @staticmethod
- def hashcode(value: int, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: int, *args, **kwargs) -> int:
return value
@@ -128,8 +131,8 @@
pythonic = int
default = 0
- @staticmethod
- def hashcode(value: int, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: int, *args, **kwargs) -> int:
return value ^ (unsigned(value, ctypes.c_ulonglong) >> 32)
@@ -141,8 +144,8 @@
pythonic = float
default = 0.0
- @staticmethod
- def hashcode(value: float, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: float, *args, **kwargs) -> int:
return ctypes.cast(
ctypes.pointer(ctypes.c_float(value)),
ctypes.POINTER(ctypes.c_int)
@@ -157,8 +160,8 @@
pythonic = float
default = 0.0
- @staticmethod
- def hashcode(value: float, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: float, *args, **kwargs) -> int:
bits = ctypes.cast(
ctypes.pointer(ctypes.c_double(value)),
ctypes.POINTER(ctypes.c_longlong)
@@ -180,8 +183,8 @@
pythonic = str
default = ' '
- @staticmethod
- def hashcode(value: str, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: str, *args, **kwargs) -> int:
return ord(value)
@classmethod
@@ -195,7 +198,7 @@
).decode(PROTOCOL_CHAR_ENCODING)
@classmethod
- def from_python_not_null(cls, stream, value):
+ def from_python_not_null(cls, stream, value, **kwargs):
if type(value) is str:
value = value.encode(PROTOCOL_CHAR_ENCODING)
# assuming either a bytes or an integer
@@ -216,8 +219,8 @@
pythonic = bool
default = False
- @staticmethod
- def hashcode(value: bool, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: bool, *args, **kwargs) -> int:
return 1231 if value else 1237
@classmethod
@@ -226,4 +229,3 @@
if value is None:
return None
return value != 0
-
diff --git a/pyignite/datatypes/standard.py b/pyignite/datatypes/standard.py
index af50a8e..2b61235 100644
--- a/pyignite/datatypes/standard.py
+++ b/pyignite/datatypes/standard.py
@@ -18,16 +18,15 @@
import decimal
from io import SEEK_CUR
from math import ceil
-from typing import Any, Tuple
+from typing import Tuple
import uuid
from pyignite.constants import *
from pyignite.utils import datetime_hashcode, decimal_hashcode, hashcode
-from .base import IgniteDataType
from .type_codes import *
from .type_ids import *
from .type_names import *
-from .null_object import Null, Nullable
+from .null_object import Nullable
__all__ = [
'String', 'DecimalObject', 'UUIDObject', 'TimestampObject', 'DateObject',
@@ -44,7 +43,7 @@
]
-class StandardObject(IgniteDataType, Nullable):
+class StandardObject(Nullable):
_type_name = None
_type_id = None
type_code = None
@@ -60,7 +59,7 @@
return data_type
-class String(IgniteDataType, Nullable):
+class String(Nullable):
"""
Pascal-style string: `c_int` counter, followed by count*bytes.
UTF-8-encoded, so that one character may take 1 to 4 bytes.
@@ -70,8 +69,8 @@
type_code = TC_STRING
pythonic = str
- @staticmethod
- def hashcode(value: str, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: str, *args, **kwargs) -> int:
return hashcode(value)
@classmethod
@@ -124,15 +123,15 @@
stream.write(data_object)
-class DecimalObject(IgniteDataType, Nullable):
+class DecimalObject(Nullable):
_type_name = NAME_DECIMAL
_type_id = TYPE_DECIMAL
type_code = TC_DECIMAL
pythonic = decimal.Decimal
default = decimal.Decimal('0.00')
- @staticmethod
- def hashcode(value: decimal.Decimal, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: decimal.Decimal, *args, **kwargs) -> int:
return decimal_hashcode(value)
@classmethod
@@ -180,11 +179,7 @@
range(len(data))
])
# apply scale
- result = (
- result
- / decimal.Decimal('10')
- ** decimal.Decimal(ctype_object.scale)
- )
+ result = result / decimal.Decimal('10') ** decimal.Decimal(ctype_object.scale)
if sign:
# apply sign
result = -result
@@ -195,7 +190,7 @@
sign, digits, scale = value.normalize().as_tuple()
integer = int(''.join([str(d) for d in digits]))
# calculate number of bytes (at least one, and not forget the sign bit)
- length = ceil((integer.bit_length() + 1)/8)
+ length = ceil((integer.bit_length() + 1) / 8)
# write byte string
data = []
for i in range(length):
@@ -247,8 +242,8 @@
UUID_BYTE_ORDER = (7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8)
- @staticmethod
- def hashcode(value: 'UUID', *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: 'UUID', *args, **kwargs) -> int:
msb = value.int >> 64
lsb = value.int & 0xffffffffffffffff
hilo = msb ^ lsb
@@ -309,8 +304,8 @@
pythonic = tuple
default = (datetime(1970, 1, 1), 0)
- @staticmethod
- def hashcode(value: Tuple[datetime, int], *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: Tuple[datetime, int], *args, **kwargs) -> int:
return datetime_hashcode(int(value[0].timestamp() * 1000))
@classmethod
@@ -331,7 +326,7 @@
return cls._object_c_type
@classmethod
- def from_python_not_null(cls, stream, value: tuple):
+ def from_python_not_null(cls, stream, value: tuple, **kwargs):
data_type = cls.build_c_type()
data_object = data_type()
data_object.type_code = int.from_bytes(
@@ -346,7 +341,7 @@
@classmethod
def to_python_not_null(cls, ctypes_object, *args, **kwargs):
return (
- datetime.fromtimestamp(ctypes_object.epoch/1000),
+ datetime.fromtimestamp(ctypes_object.epoch / 1000),
ctypes_object.fraction
)
@@ -365,8 +360,8 @@
pythonic = datetime
default = datetime(1970, 1, 1)
- @staticmethod
- def hashcode(value: datetime, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: datetime, *args, **kwargs) -> int:
return datetime_hashcode(int(value.timestamp() * 1000))
@classmethod
@@ -401,7 +396,7 @@
@classmethod
def to_python_not_null(cls, ctypes_object, *args, **kwargs):
- return datetime.fromtimestamp(ctypes_object.epoch/1000)
+ return datetime.fromtimestamp(ctypes_object.epoch / 1000)
class TimeObject(StandardObject):
@@ -417,8 +412,8 @@
pythonic = timedelta
default = timedelta()
- @staticmethod
- def hashcode(value: timedelta, *args, **kwargs) -> int:
+ @classmethod
+ def hashcode(cls, value: timedelta, *args, **kwargs) -> int:
return datetime_hashcode(int(value.total_seconds() * 1000))
@classmethod
@@ -510,7 +505,7 @@
type_code = TC_BINARY_ENUM
-class StandardArray(IgniteDataType, Nullable):
+class StandardArray(Nullable):
"""
Base class for array of primitives. Payload-only.
"""
@@ -519,15 +514,10 @@
standard_type = None
type_code = None
- @staticmethod
- def hashcode(value: Any) -> int:
- # Arrays are not supported as keys at the moment.
- return 0
-
@classmethod
def build_header_class(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -575,7 +565,11 @@
return result
@classmethod
- def from_python_not_null(cls, stream, value):
+ async def to_python_async(cls, ctypes_object, *args, **kwargs):
+ return cls.to_python(ctypes_object, *args, **kwargs)
+
+ @classmethod
+ def from_python_not_null(cls, stream, value, **kwargs):
header_class = cls.build_header_class()
header = header_class()
if hasattr(header, 'type_code'):
@@ -648,7 +642,7 @@
@classmethod
def build_header_class(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -723,7 +717,7 @@
@classmethod
def build_header_class(cls):
return type(
- cls.__name__+'Header',
+ cls.__name__ + 'Header',
(ctypes.LittleEndianStructure,),
{
'_pack_': 1,
@@ -736,7 +730,7 @@
)
@classmethod
- def from_python_not_null(cls, stream, value):
+ def from_python_not_null(cls, stream, value, **kwargs):
type_id, value = value
header_class = cls.build_header_class()
header = header_class()
@@ -754,7 +748,7 @@
cls.standard_type.from_python(stream, x)
@classmethod
- def to_python(cls, ctype_object, *args, **kwargs):
+ def to_python_not_null(cls, ctype_object, *args, **kwargs):
type_id = getattr(ctype_object, "type_id", None)
if type_id is None:
return None
diff --git a/pyignite/exceptions.py b/pyignite/exceptions.py
index 5933228..579aa29 100644
--- a/pyignite/exceptions.py
+++ b/pyignite/exceptions.py
@@ -93,4 +93,4 @@
pass
-connection_errors = (IOError, OSError)
+connection_errors = (IOError, OSError, EOFError)
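Adding EOFError to connection_errors means a peer that closes the stream mid-read is treated like any other transient connection failure. A minimal retry sketch assuming a connected cache object (the helper name and retry policy are illustrative, not part of the patch):

    from pyignite.exceptions import connection_errors

    def put_with_retry(cache, key, value, attempts=2):
        for attempt in range(attempts):
            try:
                return cache.put(key, value)
            except connection_errors:
                # The client reconnects on the next call; retry once more.
                if attempt == attempts - 1:
                    raise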
diff --git a/pyignite/queries/__init__.py b/pyignite/queries/__init__.py
index d558125..56c6347 100644
--- a/pyignite/queries/__init__.py
+++ b/pyignite/queries/__init__.py
@@ -21,4 +21,4 @@
:mod:`pyignite.datatypes` binary parser/generator classes.
"""
-from .query import Query, ConfigQuery
+from .query import Query, ConfigQuery, query_perform
diff --git a/pyignite/queries/query.py b/pyignite/queries/query.py
index b5be753..beea5d9 100644
--- a/pyignite/queries/query.py
+++ b/pyignite/queries/query.py
@@ -13,15 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import attr
import ctypes
+from io import SEEK_CUR
from random import randint
+import attr
+
from pyignite.api.result import APIResult
-from pyignite.connection import Connection
+from pyignite.connection import Connection, AioConnection
from pyignite.constants import MIN_LONG, MAX_LONG, RHF_TOPOLOGY_CHANGED
-from pyignite.queries.response import Response, SQLResponse
-from pyignite.stream import BinaryStream, READ_BACKWARD
+from pyignite.queries.response import Response
+from pyignite.stream import AioBinaryStream, BinaryStream, READ_BACKWARD
+
+
+def query_perform(query_struct, conn, post_process_fun=None, **kwargs):
+ async def _async_internal():
+ result = await query_struct.perform_async(conn, **kwargs)
+ if post_process_fun:
+ return post_process_fun(result)
+ return result
+
+ def _internal():
+ result = query_struct.perform(conn, **kwargs)
+ if post_process_fun:
+ return post_process_fun(result)
+ return result
+
+ if isinstance(conn, AioConnection):
+ return _async_internal()
+ return _internal()
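query_perform is the seam that lets a single API entry point serve both clients: given a synchronous Connection it runs the query and returns the APIResult directly, while given an AioConnection it returns the coroutine built by _async_internal for the caller to await. A rough usage sketch (the function name and op-code are illustrative only):

    from pyignite.queries import Query, query_perform

    def cache_example_op(conn):
        query_struct = Query(op_code=1000)  # illustrative op-code
        return query_perform(query_struct, conn)

    # with a sync connection:  result = cache_example_op(conn)
    # with an aio connection:  result = await cache_example_op(aio_conn)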
@attr.s
@@ -29,6 +49,7 @@
op_code = attr.ib(type=int)
following = attr.ib(type=list, factory=list)
query_id = attr.ib(type=int, default=None)
+ response_type = attr.ib(type=type(Response), default=Response)
_query_c_type = None
@classmethod
@@ -48,32 +69,45 @@
)
return cls._query_c_type
- def _build_header(self, stream, values: dict):
+ def from_python(self, stream, values: dict = None):
+ init_pos, header = stream.tell(), self._build_header(stream)
+ values = values if values else {}
+
+ for name, c_type in self.following:
+ c_type.from_python(stream, values[name])
+
+ self.__write_header(stream, header, init_pos)
+
+ async def from_python_async(self, stream, values: dict = None):
+ init_pos, header = stream.tell(), self._build_header(stream)
+ values = values if values else {}
+
+ for name, c_type in self.following:
+ await c_type.from_python_async(stream, values[name])
+
+ self.__write_header(stream, header, init_pos)
+
+ def _build_header(self, stream):
header_class = self.build_c_type()
header_len = ctypes.sizeof(header_class)
- init_pos = stream.tell()
- stream.seek(init_pos + header_len)
+ stream.seek(header_len, SEEK_CUR)
header = header_class()
header.op_code = self.op_code
if self.query_id is None:
header.query_id = randint(MIN_LONG, MAX_LONG)
- for name, c_type in self.following:
- c_type.from_python(stream, values[name])
-
- header.length = stream.tell() - init_pos - ctypes.sizeof(ctypes.c_int)
- stream.seek(init_pos)
-
return header
- def from_python(self, stream, values: dict = None):
- header = self._build_header(stream, values if values else {})
+ @staticmethod
+ def __write_header(stream, header, init_pos):
+ header.length = stream.tell() - init_pos - ctypes.sizeof(ctypes.c_int)
+ stream.seek(init_pos)
stream.write(header)
def perform(
self, conn: Connection, query_params: dict = None,
- response_config: list = None, sql: bool = False, **kwargs,
+ response_config: list = None, **kwargs,
) -> APIResult:
"""
Perform query and process result.
@@ -83,26 +117,60 @@
Defaults to no parameters,
:param response_config: (optional) response configuration − list of
(name, type_hint) tuples. Defaults to empty return value,
- :param sql: (optional) use normal (default) or SQL response class,
:return: instance of :class:`~pyignite.api.result.APIResult` with raw
value (may undergo further processing in API functions).
"""
- with BinaryStream(conn) as stream:
+ with BinaryStream(conn.client) as stream:
self.from_python(stream, query_params)
- conn.send(stream.getbuffer())
+ response_data = conn.request(stream.getbuffer())
- if sql:
- response_struct = SQLResponse(protocol_version=conn.get_protocol_version(),
- following=response_config, **kwargs)
- else:
- response_struct = Response(protocol_version=conn.get_protocol_version(),
- following=response_config)
+ response_struct = self.response_type(protocol_version=conn.protocol_version,
+ following=response_config, **kwargs)
- with BinaryStream(conn, conn.recv()) as stream:
+ with BinaryStream(conn.client, response_data) as stream:
response_ctype = response_struct.parse(stream)
response = stream.read_ctype(response_ctype, direction=READ_BACKWARD)
- # this test depends on protocol version
+ result = self.__post_process_response(conn, response_struct, response)
+
+ if result.status == 0:
+ result.value = response_struct.to_python(response)
+ return result
+
+ async def perform_async(
+ self, conn: AioConnection, query_params: dict = None,
+ response_config: list = None, **kwargs,
+ ) -> APIResult:
+ """
+ Perform query and process result.
+
+ :param conn: connection to Ignite server,
+ :param query_params: (optional) dict of named query parameters.
+ Defaults to no parameters,
+ :param response_config: (optional) response configuration − list of
+ (name, type_hint) tuples. Defaults to empty return value,
+ :return: instance of :class:`~pyignite.api.result.APIResult` with raw
+ value (may undergo further processing in API functions).
+ """
+ with AioBinaryStream(conn.client) as stream:
+ await self.from_python_async(stream, query_params)
+ data = await conn.request(stream.getbuffer())
+
+ response_struct = self.response_type(protocol_version=conn.protocol_version,
+ following=response_config, **kwargs)
+
+ with AioBinaryStream(conn.client, data) as stream:
+ response_ctype = await response_struct.parse_async(stream)
+ response = stream.read_ctype(response_ctype, direction=READ_BACKWARD)
+
+ result = self.__post_process_response(conn, response_struct, response)
+
+ if result.status == 0:
+ result.value = await response_struct.to_python_async(response)
+ return result
+
+ @staticmethod
+ def __post_process_response(conn, response_struct, response):
if getattr(response, 'flags', False) & RHF_TOPOLOGY_CHANGED:
# update latest affinity version
new_affinity = (response.affinity_version, response.affinity_minor)
@@ -112,10 +180,7 @@
conn.client.affinity_version = new_affinity
# build result
- result = APIResult(response)
- if result.status == 0:
- result.value = response_struct.to_python(response)
- return result
+ return APIResult(response)
class ConfigQuery(Query):
@@ -142,7 +207,7 @@
)
return cls._query_c_type
- def _build_header(self, stream, values: dict):
- header = super()._build_header(stream, values)
+ def _build_header(self, stream):
+ header = super()._build_header(stream)
header.config_length = header.length - ctypes.sizeof(type(header))
return header
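With the old sql flag gone, each query now names its parser through the response_type attribute introduced above. A sketch of how an SQL-style query would opt into the SQL-aware parsing (the op-code is illustrative; SQLResponse is the existing class in pyignite.queries.response):

    from pyignite.queries import Query
    from pyignite.queries.response import SQLResponse

    query_struct = Query(
        op_code=2004,               # illustrative op-code
        following=[],
        response_type=SQLResponse,  # perform()/perform_async() instantiate this
    )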
diff --git a/pyignite/queries/response.py b/pyignite/queries/response.py
index ca2ae14..83a6e6a 100644
--- a/pyignite/queries/response.py
+++ b/pyignite/queries/response.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
from io import SEEK_CUR
import attr
@@ -20,6 +21,7 @@
from pyignite.constants import RHF_TOPOLOGY_CHANGED, RHF_ERROR
from pyignite.datatypes import AnyDataObject, Bool, Int, Long, String, StringArray, Struct
+from pyignite.datatypes.binary import body_struct, enum_struct, schema_struct
from pyignite.queries.op_codes import OP_SUCCESS
from pyignite.stream import READ_BACKWARD
@@ -35,7 +37,7 @@
# replace None with empty list
self.following = self.following or []
- def build_header(self):
+ def __build_header(self):
if self._response_header is None:
fields = [
('length', ctypes.c_int),
@@ -57,9 +59,9 @@
)
return self._response_header
- def parse(self, stream):
+ def __parse_header(self, stream):
init_pos = stream.tell()
- header_class = self.build_header()
+ header_class = self.__build_header()
header_len = ctypes.sizeof(header_class)
header = stream.read_ctype(header_class)
stream.seek(header_len, SEEK_CUR)
@@ -85,9 +87,10 @@
if has_error:
msg_type = String.parse(stream)
fields.append(('error_message', msg_type))
- else:
- self._parse_success(stream, fields)
+ return not has_error, init_pos, header_class, fields
+
+ def __build_response_class(self, stream, init_pos, header_class, fields):
response_class = type(
self._response_class_name,
(header_class,),
@@ -100,21 +103,52 @@
stream.seek(init_pos + ctypes.sizeof(response_class))
return response_class
+ def parse(self, stream):
+ success, init_pos, header_class, fields = self.__parse_header(stream)
+ if success:
+ self._parse_success(stream, fields)
+
+ return self.__build_response_class(stream, init_pos, header_class, fields)
+
+ async def parse_async(self, stream):
+ success, init_pos, header_class, fields = self.__parse_header(stream)
+ if success:
+ await self._parse_success_async(stream, fields)
+
+ return self.__build_response_class(stream, init_pos, header_class, fields)
+
def _parse_success(self, stream, fields: list):
for name, ignite_type in self.following:
c_type = ignite_type.parse(stream)
fields.append((name, c_type))
- def to_python(self, ctype_object, *args, **kwargs):
- result = OrderedDict()
+ async def _parse_success_async(self, stream, fields: list):
+ for name, ignite_type in self.following:
+ c_type = await ignite_type.parse_async(stream)
+ fields.append((name, c_type))
+ def to_python(self, ctype_object, *args, **kwargs):
+ if not self.following:
+ return None
+
+ result = OrderedDict()
for name, c_type in self.following:
result[name] = c_type.to_python(
getattr(ctype_object, name),
*args, **kwargs
)
- return result if result else None
+ return result
+
+ async def to_python_async(self, ctype_object, *args, **kwargs):
+ if not self.following:
+ return None
+
+ values = await asyncio.gather(
+ *[c_type.to_python_async(getattr(ctype_object, name), *args, **kwargs) for name, c_type in self.following]
+ )
+
+ return OrderedDict([(name, values[i]) for i, (name, _) in enumerate(self.following)])
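asyncio.gather preserves the order of the coroutines it is given, which is what allows values above to be zipped back to self.following by index. A standalone illustration of that property (decode stands in for a per-field to_python_async call):

    import asyncio
    from collections import OrderedDict

    async def decode(raw):
        return raw * 2  # stand-in for per-field deserialization

    async def main():
        following = [('a', 1), ('b', 2), ('c', 3)]
        values = await asyncio.gather(*[decode(raw) for _, raw in following])
        print(OrderedDict((name, values[i]) for i, (name, _) in enumerate(following)))

    asyncio.run(main())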
@attr.s
@@ -135,38 +169,62 @@
return 'field_count', Int
def _parse_success(self, stream, fields: list):
- following = [
- self.fields_or_field_count(),
- ('row_count', Int),
- ]
- if self.has_cursor:
- following.insert(0, ('cursor', Long))
- body_struct = Struct(following)
+ body_struct = self.__create_body_struct()
body_class = body_struct.parse(stream)
body = stream.read_ctype(body_class, direction=READ_BACKWARD)
- if self.include_field_names:
- field_count = body.fields.length
- else:
- field_count = body.field_count
-
- data_fields = []
+ data_fields, field_count = [], self.__get_fields_count(body)
for i in range(body.row_count):
row_fields = []
for j in range(field_count):
field_class = AnyDataObject.parse(stream)
row_fields.append(('column_{}'.format(j), field_class))
- row_class = type(
- 'SQLResponseRow',
- (ctypes.LittleEndianStructure,),
- {
- '_pack_': 1,
- '_fields_': row_fields,
- }
- )
- data_fields.append(('row_{}'.format(i), row_class))
+ self.__row_post_process(i, row_fields, data_fields)
+ self.__body_class_post_process(body_class, fields, data_fields)
+
+ async def _parse_success_async(self, stream, fields: list):
+ body_struct = self.__create_body_struct()
+ body_class = await body_struct.parse_async(stream)
+ body = stream.read_ctype(body_class, direction=READ_BACKWARD)
+
+ data_fields, field_count = [], self.__get_fields_count(body)
+ for i in range(body.row_count):
+ row_fields = []
+ for j in range(field_count):
+ field_class = await AnyDataObject.parse_async(stream)
+ row_fields.append(('column_{}'.format(j), field_class))
+
+ self.__row_post_process(i, row_fields, data_fields)
+
+ self.__body_class_post_process(body_class, fields, data_fields)
+
+ def __create_body_struct(self):
+ following = [self.fields_or_field_count(), ('row_count', Int)]
+ if self.has_cursor:
+ following.insert(0, ('cursor', Long))
+ return Struct(following)
+
+ def __get_fields_count(self, body):
+ if self.include_field_names:
+ return body.fields.length
+ return body.field_count
+
+ @staticmethod
+ def __row_post_process(idx, row_fields, data_fields):
+ row_class = type(
+ 'SQLResponseRow',
+ (ctypes.LittleEndianStructure,),
+ {
+ '_pack_': 1,
+ '_fields_': row_fields,
+ }
+ )
+ data_fields.append((f'row_{idx}', row_class))
+
+ @staticmethod
+ def __body_class_post_process(body_class, fields, data_fields):
data_class = type(
'SQLResponseData',
(ctypes.LittleEndianStructure,),
@@ -182,24 +240,8 @@
def to_python(self, ctype_object, *args, **kwargs):
if getattr(ctype_object, 'status_code', 0) == 0:
- result = {
- 'more': Bool.to_python(
- ctype_object.more, *args, **kwargs
- ),
- 'data': [],
- }
- if hasattr(ctype_object, 'fields'):
- result['fields'] = StringArray.to_python(
- ctype_object.fields, *args, **kwargs
- )
- else:
- result['field_count'] = Int.to_python(
- ctype_object.field_count, *args, **kwargs
- )
- if hasattr(ctype_object, 'cursor'):
- result['cursor'] = Long.to_python(
- ctype_object.cursor, *args, **kwargs
- )
+ result = self.__to_python_result_header(ctype_object, *args, **kwargs)
+
for row_item in ctype_object.data._fields_:
row_name = row_item[0]
row_object = getattr(ctype_object.data, row_name)
@@ -207,8 +249,104 @@
for col_item in row_object._fields_:
col_name = col_item[0]
col_object = getattr(row_object, col_name)
- row.append(
- AnyDataObject.to_python(col_object, *args, **kwargs)
- )
+ row.append(AnyDataObject.to_python(col_object, *args, **kwargs))
result['data'].append(row)
return result
+
+ async def to_python_async(self, ctype_object, *args, **kwargs):
+ if getattr(ctype_object, 'status_code', 0) == 0:
+ result = self.__to_python_result_header(ctype_object, *args, **kwargs)
+
+ data_coro = []
+ for row_item in ctype_object.data._fields_:
+ row_name = row_item[0]
+ row_object = getattr(ctype_object.data, row_name)
+ row_coro = []
+ for col_item in row_object._fields_:
+ col_name = col_item[0]
+ col_object = getattr(row_object, col_name)
+ row_coro.append(AnyDataObject.to_python_async(col_object, *args, **kwargs))
+
+ data_coro.append(asyncio.gather(*row_coro))
+
+ result['data'] = await asyncio.gather(*data_coro)
+ return result
+
+ @staticmethod
+ def __to_python_result_header(ctype_object, *args, **kwargs):
+ result = {
+ 'more': Bool.to_python(ctype_object.more, *args, **kwargs),
+ 'data': [],
+ }
+ if hasattr(ctype_object, 'fields'):
+ result['fields'] = StringArray.to_python(ctype_object.fields, *args, **kwargs)
+ else:
+ result['field_count'] = Int.to_python(ctype_object.field_count, *args, **kwargs)
+
+ if hasattr(ctype_object, 'cursor'):
+ result['cursor'] = Long.to_python(ctype_object.cursor, *args, **kwargs)
+ return result
+
+
+class BinaryTypeResponse(Response):
+ _response_class_name = 'GetBinaryTypeResponse'
+
+ def _parse_success(self, stream, fields: list):
+ type_exists = self.__process_type_exists(stream, fields)
+
+ if type_exists.value:
+ resp_body_type = body_struct.parse(stream)
+ fields.append(('body', resp_body_type))
+ resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD)
+ if resp_body.is_enum:
+ resp_enum = enum_struct.parse(stream)
+ fields.append(('enums', resp_enum))
+
+ resp_schema_type = schema_struct.parse(stream)
+ fields.append(('schema', resp_schema_type))
+
+ async def _parse_success_async(self, stream, fields: list):
+ type_exists = self.__process_type_exists(stream, fields)
+
+ if type_exists.value:
+ resp_body_type = await body_struct.parse_async(stream)
+ fields.append(('body', resp_body_type))
+ resp_body = stream.read_ctype(resp_body_type, direction=READ_BACKWARD)
+ if resp_body.is_enum:
+ resp_enum = await enum_struct.parse_async(stream)
+ fields.append(('enums', resp_enum))
+
+ resp_schema_type = await schema_struct.parse_async(stream)
+ fields.append(('schema', resp_schema_type))
+
+ @staticmethod
+ def __process_type_exists(stream, fields):
+ fields.append(('type_exists', ctypes.c_byte))
+ type_exists = stream.read_ctype(ctypes.c_byte)
+ stream.seek(ctypes.sizeof(ctypes.c_byte), SEEK_CUR)
+
+ return type_exists
+
+ def to_python(self, ctype_object, *args, **kwargs):
+ if getattr(ctype_object, 'status_code', 0) == 0:
+ result = {
+ 'type_exists': Bool.to_python(ctype_object.type_exists)
+ }
+
+ if hasattr(ctype_object, 'body'):
+ result.update(body_struct.to_python(ctype_object.body))
+
+ if hasattr(ctype_object, 'enums'):
+ result['enums'] = enum_struct.to_python(ctype_object.enums)
+
+ if hasattr(ctype_object, 'schema'):
+ result['schema'] = {
+ x['schema_id']: [
+ z['schema_field_id'] for z in x['schema_fields']
+ ]
+ for x in schema_struct.to_python(ctype_object.schema)
+ }
+ return result
+
+ async def to_python_async(self, ctype_object, *args, **kwargs):
+ return self.to_python(ctype_object, *args, **kwargs)
diff --git a/pyignite/stream/__init__.py b/pyignite/stream/__init__.py
index 94153b4..76d171d 100644
--- a/pyignite/stream/__init__.py
+++ b/pyignite/stream/__init__.py
@@ -13,4 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .binary_stream import BinaryStream, READ_FORWARD, READ_BACKWARD
\ No newline at end of file
+from .binary_stream import BinaryStream, AioBinaryStream, READ_FORWARD, READ_BACKWARD
+
+__all__ = ['BinaryStream', 'AioBinaryStream', 'READ_BACKWARD', 'READ_FORWARD']
diff --git a/pyignite/stream/binary_stream.py b/pyignite/stream/binary_stream.py
index 46ac683..57b4b83 100644
--- a/pyignite/stream/binary_stream.py
+++ b/pyignite/stream/binary_stream.py
@@ -14,39 +14,23 @@
# limitations under the License.
import ctypes
from io import BytesIO
+from typing import Union, Optional
+import pyignite
import pyignite.utils as ignite_utils
READ_FORWARD = 0
READ_BACKWARD = 1
-class BinaryStream:
- def __init__(self, conn, buf=None):
- """
- Initialize binary stream around buffers.
-
- :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO.
- :param conn: Connection instance, required.
- """
- from pyignite.connection import Connection
-
- if not isinstance(conn, Connection):
- raise TypeError(f"invalid parameter: expected instance of {Connection}")
-
- if buf and not isinstance(buf, (bytearray, bytes, memoryview)):
- raise TypeError(f"invalid parameter: expected bytes-like object")
-
- self.conn = conn
- self.stream = BytesIO(buf) if buf else BytesIO()
-
+class BinaryStreamBaseMixin:
@property
def compact_footer(self) -> bool:
- return self.conn.client.compact_footer
+ return self.client.compact_footer
@compact_footer.setter
def compact_footer(self, value: bool):
- self.conn.client.compact_footer = value
+ self.client.compact_footer = value
def read(self, size):
buf = bytearray(size)
@@ -86,10 +70,10 @@
def mem_view(self, start=-1, offset=0):
start = start if start >= 0 else self.tell()
- return self.stream.getbuffer()[start:start+offset]
+ return self.stream.getbuffer()[start:start + offset]
def hashcode(self, start, bytes_len):
- return ignite_utils.hashcode(self.stream.getbuffer()[start:start+bytes_len])
+ return ignite_utils.hashcode(self.stream.getbuffer()[start:start + bytes_len])
def __enter__(self):
return self
@@ -100,15 +84,48 @@
except BufferError:
pass
+
+class BinaryStream(BinaryStreamBaseMixin):
+ """
+ Synchronous binary stream.
+ """
+ def __init__(self, client: 'pyignite.Client', buf: Optional[Union[bytes, bytearray, memoryview]] = None):
+ """
+ :param client: Client instance, required.
+ :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO.
+ """
+ self.client = client
+ self.stream = BytesIO(buf) if buf else BytesIO()
+
def get_dataclass(self, header):
- # get field names from outer space
- result = self.conn.client.query_binary_type(
- header.type_id,
- header.schema_id
- )
+ result = self.client.query_binary_type(header.type_id, header.schema_id)
if not result:
raise RuntimeError('Binary type is not registered')
return result
def register_binary_type(self, *args, **kwargs):
- return self.conn.client.register_binary_type(*args, **kwargs)
+ self.client.register_binary_type(*args, **kwargs)
+
+
+class AioBinaryStream(BinaryStreamBaseMixin):
+ """
+ Asyncio binary stream.
+ """
+ def __init__(self, client: 'pyignite.AioClient', buf: Optional[Union[bytes, bytearray, memoryview]] = None):
+ """
+ Initialize binary stream around buffers.
+
+ :param client: AioClient instance, required.
+ :param buf: Buffer, optional parameter. If not passed, creates empty BytesIO.
+ """
+ self.client = client
+ self.stream = BytesIO(buf) if buf else BytesIO()
+
+ async def get_dataclass(self, header):
+ result = await self.client.query_binary_type(header.type_id, header.schema_id)
+ if not result:
+ raise RuntimeError('Binary type is not registered')
+ return result
+
+ async def register_binary_type(self, *args, **kwargs):
+ await self.client.register_binary_type(*args, **kwargs)
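Both stream classes now wrap a client rather than a connection, so serialization code no longer needs a live socket in hand to build a stream around a buffer. A minimal sketch of the synchronous variant, assuming an Ignite node is listening on the default port:

    from pyignite import Client
    from pyignite.stream import BinaryStream

    client = Client()
    client.connect('127.0.0.1', 10800)

    with BinaryStream(client, b'\x01\x02\x03') as stream:
        head = stream.read(1)  # consumes one byte, advancing the position
        rest = stream.read(2)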
diff --git a/pyignite/utils.py b/pyignite/utils.py
index f1a7f90..975f414 100644
--- a/pyignite/utils.py
+++ b/pyignite/utils.py
@@ -15,6 +15,7 @@
import ctypes
import decimal
+import inspect
import warnings
from functools import wraps
@@ -65,23 +66,14 @@
"""
Check if a value is a tuple of data item and its type hint.
"""
- return (
- isinstance(value, tuple)
- and len(value) == 2
- and issubclass(value[1], IgniteDataType)
- )
+ return isinstance(value, tuple) and len(value) == 2 and issubclass(value[1], IgniteDataType)
def is_wrapped(value: Any) -> bool:
"""
Check if a value is of WrappedDataObject type.
"""
- return (
- type(value) is tuple
- and len(value) == 2
- and type(value[0]) is bytes
- and type(value[1]) is int
- )
+ return type(value) is tuple and len(value) == 2 and type(value[0]) is bytes and type(value[1]) is int
def int_overflow(value: int) -> int:
@@ -107,7 +99,7 @@
def __hashcode_fallback(data: Union[str, bytes, bytearray, memoryview]) -> int:
if data is None:
return 0
-
+
if isinstance(data, str):
"""
For strings we iterate over code point which are of the int type
@@ -206,8 +198,7 @@
# this is the case when Java BigDecimal digits are stored
# compactly, in the internal 64-bit integer field
int_hash = (
- (unsigned(value, ctypes.c_ulonglong) >> 32) * 31
- + (value & LONG_MASK)
+ (unsigned(value, ctypes.c_ulonglong) >> 32) * 31 + (value & LONG_MASK)
) & LONG_MASK
else:
# digits are not fit in the 64-bit long, so they get split internally
@@ -243,25 +234,31 @@
def status_to_exception(exc: Type[Exception]):
"""
Converts erroneous status code with error message to an exception
- of the given class.
+ of the given class. Supports coroutines.
:param exc: the class of exception to raise,
- :return: decorator.
+ :return: decorated function.
"""
+ def process_result(result):
+ if result.status != 0:
+ raise exc(result.message)
+ return result.value
+
def ste_decorator(fn):
- @wraps(fn)
- def ste_wrapper(*args, **kwargs):
- result = fn(*args, **kwargs)
- if result.status != 0:
- raise exc(result.message)
- return result.value
- return ste_wrapper
+ if inspect.iscoroutinefunction(fn):
+ @wraps(fn)
+ async def ste_wrapper_async(*args, **kwargs):
+ return process_result(await fn(*args, **kwargs))
+ return ste_wrapper_async
+ else:
+ @wraps(fn)
+ def ste_wrapper(*args, **kwargs):
+ return process_result(fn(*args, **kwargs))
+ return ste_wrapper
return ste_decorator
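The reworked decorator applies the same status check to plain functions and coroutines alike. A self-contained sketch of both paths (FakeResult is a stand-in for APIResult):

    import asyncio
    from collections import namedtuple

    from pyignite.utils import status_to_exception

    FakeResult = namedtuple('FakeResult', ['status', 'message', 'value'])

    @status_to_exception(RuntimeError)
    def sync_call():
        return FakeResult(status=0, message=None, value=42)

    @status_to_exception(RuntimeError)
    async def async_call():
        return FakeResult(status=1, message='boom', value=None)

    assert sync_call() == 42
    # asyncio.run(async_call()) raises RuntimeError('boom')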
-def get_field_by_id(
- obj: 'GenericObjectMeta', field_id: int
-) -> Tuple[Any, IgniteDataType]:
+def get_field_by_id(obj: 'GenericObjectMeta', field_id: int) -> Tuple[Any, IgniteDataType]:
"""
Returns a complex object's field value, given the field's entity ID.
diff --git a/requirements/tests.txt b/requirements/tests.txt
index 5d5ae84..38a8e9e 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,7 +1,10 @@
# these packages are used for testing
+async_generator==1.10; python_version < '3.7'
pytest==6.2.2
pytest-cov==2.11.1
+pytest-asyncio==0.14.0
teamcity-messages==1.28
psutil==5.8.0
jinja2==2.11.3
+flake8==3.8.4
diff --git a/setup.py b/setup.py
index 4d90e4e..5db3aed 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import re
from collections import defaultdict
from distutils.command.build_ext import build_ext
from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError
@@ -86,6 +86,14 @@
with open('README.md', 'r', encoding='utf-8') as readme_file:
long_description = readme_file.read()
+with open('pyignite/__init__.py', 'r') as fd:
+ version_match = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
+ fd.read(), re.MULTILINE)
+
+if not version_match:
+ raise RuntimeError('Cannot find version information')
+version = version_match.group(1)
+
def run_setup(with_binary=True):
if with_binary:
@@ -98,7 +106,7 @@
setuptools.setup(
name='pyignite',
- version='0.4.0',
+ version=version,
python_requires='>=3.6',
author='The Apache Software Foundation',
author_email='dev@ignite.apache.org',
diff --git a/tests/affinity/conftest.py b/tests/affinity/conftest.py
index 7595f25..2ec2b1b 100644
--- a/tests/affinity/conftest.py
+++ b/tests/affinity/conftest.py
@@ -15,8 +15,7 @@
import pytest
-from pyignite import Client
-from pyignite.api import cache_create, cache_destroy
+from pyignite import Client, AioClient
from tests.util import start_ignite_gen
# Sometimes on slow testing servers and unstable topology
@@ -42,29 +41,21 @@
@pytest.fixture
def client():
client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT)
-
- client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
-
- yield client
-
- client.close()
+ try:
+ client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
+ yield client
+ finally:
+ client.close()
@pytest.fixture
-def client_not_connected():
- client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT)
- yield client
- client.close()
-
-
-@pytest.fixture
-def cache(connected_client):
- cache_name = 'my_bucket'
- conn = connected_client.random_node
-
- cache_create(conn, cache_name)
- yield cache_name
- cache_destroy(conn, cache_name)
+async def async_client():
+ client = AioClient(partition_aware=True)
+ try:
+ await client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
+ yield client
+ finally:
+ await client.close()
@pytest.fixture(scope='module', autouse=True)
diff --git a/tests/affinity/test_affinity.py b/tests/affinity/test_affinity.py
index ee8f6c0..b1bcec7 100644
--- a/tests/affinity/test_affinity.py
+++ b/tests/affinity/test_affinity.py
@@ -13,139 +13,265 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from datetime import datetime, timedelta
+import asyncio
import decimal
+from datetime import datetime, timedelta
from uuid import UUID, uuid4
import pytest
-from pyignite import GenericObjectMeta
-from pyignite.api import *
-from pyignite.constants import *
-from pyignite.datatypes import *
-from pyignite.datatypes.cache_config import CacheMode
-from pyignite.datatypes.prop_codes import *
-
-
-def test_get_node_partitions(client):
- conn = client.random_node
-
- cache_1 = client.get_or_create_cache('test_cache_1')
- cache_2 = client.get_or_create_cache({
- PROP_NAME: 'test_cache_2',
- PROP_CACHE_KEY_CONFIGURATION: [
- {
- 'type_name': ByteArray.type_name,
- 'affinity_key_field_name': 'byte_affinity',
- }
- ],
- })
- client.get_or_create_cache('test_cache_3')
- client.get_or_create_cache('test_cache_4')
- client.get_or_create_cache('test_cache_5')
-
- result = cache_get_node_partitions(
- conn,
- [cache_1.cache_id, cache_2.cache_id]
- )
- assert result.status == 0, result.message
-
-
-@pytest.mark.parametrize(
- 'key, key_hint', [
- # integers
- (42, None),
- (43, ByteObject),
- (-44, ByteObject),
- (45, IntObject),
- (-46, IntObject),
- (47, ShortObject),
- (-48, ShortObject),
- (49, LongObject),
- (MAX_INT-50, LongObject),
- (MAX_INT+51, LongObject),
-
- # floating point
- (5.2, None),
- (5.354, FloatObject),
- (-5.556, FloatObject),
- (-57.58, DoubleObject),
-
- # boolean
- (True, None),
- (True, BoolObject),
- (False, BoolObject),
-
- # char
- ('A', CharObject),
- ('Z', CharObject),
- ('⅓', CharObject),
- ('á', CharObject),
- ('ы', CharObject),
- ('カ', CharObject),
- ('Ø', CharObject),
- ('ß', CharObject),
-
- # string
- ('This is a test string', None),
- ('Кириллица', None),
- ('Little Mary had a lamb', String),
-
- # UUID
- (UUID('12345678123456789876543298765432'), None),
- (UUID('74274274274274274274274274274274'), UUIDObject),
- (uuid4(), None),
-
- # decimal (long internal representation in Java)
- (decimal.Decimal('-234.567'), None),
- (decimal.Decimal('200.0'), None),
- (decimal.Decimal('123.456'), DecimalObject),
- (decimal.Decimal('1.0'), None),
- (decimal.Decimal('0.02'), None),
-
- # decimal (BigInteger internal representation in Java)
- (decimal.Decimal('12345671234567123.45671234567'), None),
- (decimal.Decimal('-845678456.7845678456784567845'), None),
-
- # date and time
- (datetime(1980, 1, 1), None),
- ((datetime(1980, 1, 1), 999), TimestampObject),
- (timedelta(days=99), TimeObject),
-
- ],
+from pyignite import GenericObjectMeta, AioClient
+from pyignite.api import (
+ cache_get_node_partitions, cache_get_node_partitions_async, cache_local_peek, cache_local_peek_async
)
-def test_affinity(client, key, key_hint):
- cache_1 = client.get_or_create_cache({
- PROP_NAME: 'test_cache_1',
- PROP_CACHE_MODE: CacheMode.PARTITIONED,
- })
- value = 42
- cache_1.put(key, value, key_hint=key_hint)
+from pyignite.constants import MAX_INT
+from pyignite.datatypes import (
+ BinaryObject, ByteArray, ByteObject, IntObject, ShortObject, LongObject, FloatObject, DoubleObject, BoolObject,
+ CharObject, String, UUIDObject, DecimalObject, TimestampObject, TimeObject
+)
+from pyignite.datatypes.cache_config import CacheMode
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_MODE, PROP_CACHE_KEY_CONFIGURATION
+from tests.util import wait_for_condition, wait_for_condition_async
- best_node = cache_1.get_best_node(key, key_hint=key_hint)
- for node in filter(lambda n: n.alive, client._nodes):
- result = cache_local_peek(
- node, cache_1.cache_id, key, key_hint=key_hint,
- )
- if node is best_node:
- assert result.value == value, (
- 'Affinity calculation error for {}'.format(key)
- )
+def test_get_node_partitions(client, caches):
+ cache_ids = [cache.cache_id for cache in caches]
+ __wait_for_ready_affinity(client, cache_ids)
+ mappings = __get_mappings(client, cache_ids)
+ __check_mappings(mappings, cache_ids)
+
+
+@pytest.mark.asyncio
+async def test_get_node_partitions_async(async_client, async_caches):
+ cache_ids = [cache.cache_id for cache in async_caches]
+ await __wait_for_ready_affinity(async_client, cache_ids)
+ mappings = await __get_mappings(async_client, cache_ids)
+ __check_mappings(mappings, cache_ids)
+
+
+def __wait_for_ready_affinity(client, cache_ids):
+ def inner():
+ def condition():
+ result = __get_mappings(client, cache_ids)
+ return len(result.value['partition_mapping']) == len(cache_ids)
+
+ wait_for_condition(condition)
+
+ async def inner_async():
+ async def condition():
+ result = await __get_mappings(client, cache_ids)
+ return len(result.value['partition_mapping']) == len(cache_ids)
+
+ await wait_for_condition_async(condition)
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+def __get_mappings(client, cache_ids):
+ def inner():
+ conn = client.random_node
+ result = cache_get_node_partitions(conn, cache_ids)
+ assert result.status == 0, result.message
+ return result
+
+ async def inner_async():
+ conn = await client.random_node()
+ result = await cache_get_node_partitions_async(conn, cache_ids)
+ assert result.status == 0, result.message
+ return result
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+def __check_mappings(result, cache_ids):
+ partition_mapping = result.value['partition_mapping']
+
+ for i, cache_id in enumerate(cache_ids):
+ cache_mapping = partition_mapping[cache_id]
+ assert 'is_applicable' in cache_mapping
+
+ # Check replicated cache
+ if i == 3:
+ assert not cache_mapping['is_applicable']
+ assert 'node_mapping' not in cache_mapping
+ assert cache_mapping['number_of_partitions'] == 0
else:
- assert result.value is None, (
- 'Affinity calculation error for {}'.format(key)
- )
+ # Check cache config
+ if i == 2:
+ assert cache_mapping['cache_config']
- cache_1.destroy()
+ assert cache_mapping['is_applicable']
+ assert cache_mapping['node_mapping']
+ assert cache_mapping['number_of_partitions'] == 1024
-def test_affinity_for_generic_object(client):
- cache_1 = client.get_or_create_cache({
+@pytest.fixture
+def caches(client):
+ yield from __create_caches_fixture(client)
+
+
+@pytest.fixture
+async def async_caches(async_client):
+ async for caches in __create_caches_fixture(async_client):
+ yield caches
+
+
+def __create_caches_fixture(client):
+ caches_to_create = []
+ for i in range(0, 5):
+ cache_name = f'test_cache_{i}'
+ if i == 2:
+ caches_to_create.append((
+ cache_name,
+ {
+ PROP_NAME: cache_name,
+ PROP_CACHE_KEY_CONFIGURATION: [
+ {
+ 'type_name': ByteArray.type_name,
+ 'affinity_key_field_name': 'byte_affinity',
+ }
+ ]
+ }))
+ elif i == 3:
+ caches_to_create.append((
+ cache_name,
+ {
+ PROP_NAME: cache_name,
+ PROP_CACHE_MODE: CacheMode.REPLICATED
+ }
+ ))
+ else:
+ caches_to_create.append((cache_name, None))
+
+ def generate_caches():
+ caches = []
+ for name, config in caches_to_create:
+ if config:
+ cache = client.get_or_create_cache(config)
+ else:
+ cache = client.get_or_create_cache(name)
+ caches.append(cache)
+ return asyncio.gather(*caches) if isinstance(client, AioClient) else caches
+
+ def inner():
+ caches = []
+ try:
+ caches = generate_caches()
+ yield caches
+ finally:
+ for cache in caches:
+ cache.destroy()
+
+ async def inner_async():
+ caches = []
+ try:
+ caches = await generate_caches()
+ yield caches
+ finally:
+ await asyncio.gather(*[cache.destroy() for cache in caches])
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+@pytest.fixture
+def cache(client):
+ cache = client.get_or_create_cache({
PROP_NAME: 'test_cache_1',
PROP_CACHE_MODE: CacheMode.PARTITIONED,
})
+ try:
+ yield cache
+ finally:
+ cache.destroy()
+
+@pytest.fixture
+async def async_cache(async_client):
+ cache = await async_client.get_or_create_cache({
+ PROP_NAME: 'test_cache_1',
+ PROP_CACHE_MODE: CacheMode.PARTITIONED,
+ })
+ try:
+ yield cache
+ finally:
+ await cache.destroy()
+
+
+affinity_primitives_params = [
+ # integers
+ (42, None),
+ (43, ByteObject),
+ (-44, ByteObject),
+ (45, IntObject),
+ (-46, IntObject),
+ (47, ShortObject),
+ (-48, ShortObject),
+ (49, LongObject),
+ (MAX_INT - 50, LongObject),
+ (MAX_INT + 51, LongObject),
+
+ # floating point
+ (5.2, None),
+ (5.354, FloatObject),
+ (-5.556, FloatObject),
+ (-57.58, DoubleObject),
+
+ # boolean
+ (True, None),
+ (True, BoolObject),
+ (False, BoolObject),
+
+ # char
+ ('A', CharObject),
+ ('Z', CharObject),
+ ('⅓', CharObject),
+ ('á', CharObject),
+ ('ы', CharObject),
+ ('カ', CharObject),
+ ('Ø', CharObject),
+ ('ß', CharObject),
+
+ # string
+ ('This is a test string', None),
+ ('Кириллица', None),
+ ('Little Mary had a lamb', String),
+
+ # UUID
+ (UUID('12345678123456789876543298765432'), None),
+ (UUID('74274274274274274274274274274274'), UUIDObject),
+ (uuid4(), None),
+
+ # decimal (long internal representation in Java)
+ (decimal.Decimal('-234.567'), None),
+ (decimal.Decimal('200.0'), None),
+ (decimal.Decimal('123.456'), DecimalObject),
+ (decimal.Decimal('1.0'), None),
+ (decimal.Decimal('0.02'), None),
+
+ # decimal (BigInteger internal representation in Java)
+ (decimal.Decimal('12345671234567123.45671234567'), None),
+ (decimal.Decimal('-845678456.7845678456784567845'), None),
+
+ # date and time
+ (datetime(1980, 1, 1), None),
+ ((datetime(1980, 1, 1), 999), TimestampObject),
+ (timedelta(days=99), TimeObject)
+]
+
+
+@pytest.mark.parametrize('key, key_hint', affinity_primitives_params)
+def test_affinity(client, cache, key, key_hint):
+ __check_best_node_calculation(client, cache, key, 42, key_hint=key_hint)
+
+
+@pytest.mark.parametrize('key, key_hint', affinity_primitives_params)
+@pytest.mark.asyncio
+async def test_affinity_async(async_client, async_cache, key, key_hint):
+ await __check_best_node_calculation(async_client, async_cache, key, 42, key_hint=key_hint)
+
+
+@pytest.fixture
+def key_generic_object():
class KeyClass(
metaclass=GenericObjectMeta,
schema={
@@ -158,61 +284,45 @@
key = KeyClass()
key.NO = 1
key.NAME = 'test_string'
+ yield key
- cache_1.put(key, 42, key_hint=BinaryObject)
- best_node = cache_1.get_best_node(key, key_hint=BinaryObject)
+@pytest.mark.parametrize('with_type_hint', [True, False])
+def test_affinity_for_generic_object(client, cache, key_generic_object, with_type_hint):
+ key_hint = BinaryObject if with_type_hint else None
+ __check_best_node_calculation(client, cache, key_generic_object, 42, key_hint=key_hint)
- for node in filter(lambda n: n.alive, client._nodes):
- result = cache_local_peek(
- node, cache_1.cache_id, key, key_hint=BinaryObject,
- )
+
+@pytest.mark.parametrize('with_type_hint', [True, False])
+@pytest.mark.asyncio
+async def test_affinity_for_generic_object_async(async_client, async_cache, key_generic_object, with_type_hint):
+ key_hint = BinaryObject if with_type_hint else None
+ await __check_best_node_calculation(async_client, async_cache, key_generic_object, 42, key_hint=key_hint)
+
+
+def __check_best_node_calculation(client, cache, key, value, key_hint=None):
+ def check_peek_value(node, best_node, result):
if node is best_node:
- assert result.value == 42, (
- 'Affinity calculation error for {}'.format(key)
- )
+ assert result.value == value, f'Affinity calculation error for {key}'
else:
- assert result.value is None, (
- 'Affinity calculation error for {}'.format(key)
- )
+ assert result.value is None, f'Affinity calculation error for {key}'
- cache_1.destroy()
+ def inner():
+ cache.put(key, value, key_hint=key_hint)
+ best_node = cache.get_best_node(key, key_hint=key_hint)
+ for node in filter(lambda n: n.alive, client._nodes):
+ result = cache_local_peek(node, cache.cache_id, key, key_hint=key_hint)
-def test_affinity_for_generic_object_without_type_hints(client):
- cache_1 = client.get_or_create_cache({
- PROP_NAME: 'test_cache_1',
- PROP_CACHE_MODE: CacheMode.PARTITIONED,
- })
+ check_peek_value(node, best_node, result)
- class KeyClass(
- metaclass=GenericObjectMeta,
- schema={
- 'NO': IntObject,
- 'NAME': String,
- },
- ):
- pass
+ async def inner_async():
+ await cache.put(key, value, key_hint=key_hint)
+ best_node = await cache.get_best_node(key, key_hint=key_hint)
- key = KeyClass()
- key.NO = 2
- key.NAME = 'another_test_string'
+ for node in filter(lambda n: n.alive, client._nodes):
+ result = await cache_local_peek_async(node, cache.cache_id, key, key_hint=key_hint)
- cache_1.put(key, 42)
+ check_peek_value(node, best_node, result)
- best_node = cache_1.get_best_node(key)
-
- for node in filter(lambda n: n.alive, client._nodes):
- result = cache_local_peek(
- node, cache_1.cache_id, key
- )
- if node is best_node:
- assert result.value == 42, (
- 'Affinity calculation error for {}'.format(key)
- )
- else:
- assert result.value is None, (
- 'Affinity calculation error for {}'.format(key)
- )
-
- cache_1.destroy()
+ return inner_async() if isinstance(client, AioClient) else inner()
diff --git a/tests/affinity/test_affinity_bad_servers.py b/tests/affinity/test_affinity_bad_servers.py
index 6fd08d5..b169168 100644
--- a/tests/affinity/test_affinity_bad_servers.py
+++ b/tests/affinity/test_affinity_bad_servers.py
@@ -15,9 +15,9 @@
import pytest
-from pyignite.exceptions import ReconnectError
+from pyignite.exceptions import ReconnectError, connection_errors
from tests.affinity.conftest import CLIENT_SOCKET_TIMEOUT
-from tests.util import start_ignite, kill_process_tree, get_client
+from tests.util import start_ignite, kill_process_tree, get_client, get_client_async
@pytest.fixture(params=['with-partition-awareness', 'without-partition-awareness'])
@@ -26,10 +26,16 @@
def test_client_with_multiple_bad_servers(with_partition_awareness):
- with pytest.raises(ReconnectError) as e_info:
+ with pytest.raises(ReconnectError, match="Can not connect."):
with get_client(partition_aware=with_partition_awareness) as client:
client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
- assert str(e_info.value) == "Can not connect."
+
+
+@pytest.mark.asyncio
+async def test_client_with_multiple_bad_servers_async(with_partition_awareness):
+ with pytest.raises(ReconnectError, match="Can not connect."):
+ async with get_client_async(partition_aware=with_partition_awareness) as client:
+ await client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
def test_client_with_failed_server(request, with_partition_awareness):
@@ -52,6 +58,27 @@
kill_process_tree(srv.pid)
+@pytest.mark.asyncio
+async def test_client_with_failed_server_async(request, with_partition_awareness):
+ srv = start_ignite(idx=4)
+ try:
+ async with get_client_async(partition_aware=with_partition_awareness) as client:
+ await client.connect([("127.0.0.1", 10804)])
+ cache = await client.get_or_create_cache(request.node.name)
+ await cache.put(1, 1)
+ kill_process_tree(srv.pid)
+
+ if with_partition_awareness:
+ ex_class = (ReconnectError, ConnectionResetError)
+ else:
+ ex_class = ConnectionResetError
+
+ with pytest.raises(ex_class):
+ await cache.get(1)
+ finally:
+ kill_process_tree(srv.pid)
+
+
def test_client_with_recovered_server(request, with_partition_awareness):
srv = start_ignite(idx=4)
try:
@@ -67,7 +94,7 @@
# First request may fail.
try:
cache.put(1, 2)
- except:
+ except connection_errors:
pass
# Retry succeeds
@@ -75,3 +102,29 @@
assert cache.get(1) == 2
finally:
kill_process_tree(srv.pid)
+
+
+@pytest.mark.asyncio
+async def test_client_with_recovered_server_async(request, with_partition_awareness):
+ srv = start_ignite(idx=4)
+ try:
+ async with get_client_async(partition_aware=with_partition_awareness) as client:
+ await client.connect([("127.0.0.1", 10804)])
+ cache = await client.get_or_create_cache(request.node.name)
+ await cache.put(1, 1)
+
+ # Kill and restart server
+ kill_process_tree(srv.pid)
+ srv = start_ignite(idx=4)
+
+ # First request may fail.
+ try:
+ await cache.put(1, 2)
+ except connection_errors:
+ pass
+
+ # Retry succeeds
+ await cache.put(1, 2)
+ assert await cache.get(1) == 2
+ finally:
+ kill_process_tree(srv.pid)
diff --git a/tests/affinity/test_affinity_request_routing.py b/tests/affinity/test_affinity_request_routing.py
index 101db39..64197ff 100644
--- a/tests/affinity/test_affinity_request_routing.py
+++ b/tests/affinity/test_affinity_request_routing.py
@@ -13,20 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
from collections import OrderedDict, deque
+import random
+
import pytest
-from pyignite import *
-from pyignite.connection import Connection
+from pyignite import GenericObjectMeta, AioClient, Client
+from pyignite.aio_cache import AioCache
+from pyignite.connection import Connection, AioConnection
from pyignite.constants import PROTOCOL_BYTE_ORDER
-from pyignite.datatypes import *
+from pyignite.datatypes import String, LongObject
from pyignite.datatypes.cache_config import CacheMode
-from pyignite.datatypes.prop_codes import *
-from tests.util import *
-
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_BACKUPS_NUMBER, PROP_CACHE_KEY_CONFIGURATION, PROP_CACHE_MODE
+from tests.util import wait_for_condition, wait_for_condition_async, start_ignite, kill_process_tree
requests = deque()
old_send = Connection.send
+old_send_async = AioConnection._send
def patched_send(self, *args, **kwargs):
@@ -40,13 +44,26 @@
return old_send(self, *args, **kwargs)
+async def patched_send_async(self, *args, **kwargs):
+ """Patched send function that push to queue idx of server to which request is routed."""
+ buf = args[0]
+ if buf and len(buf) >= 6:
+ op_code = int.from_bytes(buf[4:6], byteorder=PROTOCOL_BYTE_ORDER)
+ # Filter only caches operation.
+ if 1000 <= op_code < 1100:
+ requests.append(self.port % 100)
+ return await old_send_async(self, *args, **kwargs)
+
+
def setup_function():
requests.clear()
Connection.send = patched_send
+ AioConnection._send = patched_send_async
def teardown_function():
Connection.send = old_send
+ AioConnection._send = old_send_async
def wait_for_affinity_distribution(cache, key, node_idx, timeout=30):
@@ -68,6 +85,25 @@
f"got {real_node_idx} instead")
+async def wait_for_affinity_distribution_async(cache, key, node_idx, timeout=30):
+ real_node_idx = 0
+
+ async def check_grid_idx():
+ nonlocal real_node_idx
+ try:
+ await cache.get(key)
+ real_node_idx = requests.pop()
+ except (OSError, IOError):
+ return False
+ return real_node_idx == node_idx
+
+ res = await wait_for_condition_async(check_grid_idx, timeout=timeout)
+
+ if not res:
+ raise TimeoutError(f"failed to wait for affinity distribution, expected node_idx {node_idx},"
+ f"got {real_node_idx} instead")
+
+
@pytest.mark.parametrize("key,grid_idx", [(1, 1), (2, 2), (3, 3), (4, 1), (5, 1), (6, 2), (11, 1), (13, 1), (19, 1)])
@pytest.mark.parametrize("backups", [0, 1, 2, 3])
def test_cache_operation_on_primitive_key_routes_request_to_primary_node(request, key, grid_idx, backups, client):
@@ -75,52 +111,56 @@
PROP_NAME: request.node.name + str(backups),
PROP_BACKUPS_NUMBER: backups,
})
+ try:
+ __perform_operations_on_primitive_key(client, cache, key, grid_idx)
+ finally:
+ cache.destroy()
- cache.put(key, key)
- wait_for_affinity_distribution(cache, key, grid_idx)
- # Test
- cache.get(key)
- assert requests.pop() == grid_idx
+@pytest.mark.parametrize("key,grid_idx", [(1, 1), (2, 2), (3, 3), (4, 1), (5, 1), (6, 2), (11, 1), (13, 1), (19, 1)])
+@pytest.mark.parametrize("backups", [0, 1, 2, 3])
+@pytest.mark.asyncio
+async def test_cache_operation_on_primitive_key_routes_request_to_primary_node_async(
+ request, key, grid_idx, backups, async_client):
+ cache = await async_client.get_or_create_cache({
+ PROP_NAME: request.node.name + str(backups),
+ PROP_BACKUPS_NUMBER: backups,
+ })
+ try:
+ await __perform_operations_on_primitive_key(async_client, cache, key, grid_idx)
+ finally:
+ await cache.destroy()
- cache.put(key, key)
- assert requests.pop() == grid_idx
- cache.replace(key, key + 1)
- assert requests.pop() == grid_idx
+def __perform_operations_on_primitive_key(client, cache, key, grid_idx):
+ operations = [
+ ('get', 1), ('put', 2), ('replace', 2), ('clear_key', 1), ('contains_key', 1), ('get_and_put', 2),
+ ('get_and_put_if_absent', 2), ('put_if_absent', 2), ('get_and_remove', 1), ('get_and_replace', 2),
+ ('remove_key', 1), ('remove_if_equals', 2), ('replace', 2), ('replace_if_equals', 3)
+ ]
- cache.clear_key(key)
- assert requests.pop() == grid_idx
+ def inner():
+ cache.put(key, key)
+ wait_for_affinity_distribution(cache, key, grid_idx)
- cache.contains_key(key)
- assert requests.pop() == grid_idx
+ for op_name, param_nums in operations:
+ op = getattr(cache, op_name)
+ args = [random.randint(-100, 100) for _ in range(0, param_nums - 1)]
+ op(key, *args)
+ assert requests.pop() == grid_idx
- cache.get_and_put(key, 3)
- assert requests.pop() == grid_idx
+ async def inner_async():
+ await cache.put(key, key)
+ await wait_for_affinity_distribution_async(cache, key, grid_idx)
- cache.get_and_put_if_absent(key, 4)
- assert requests.pop() == grid_idx
+ for op_name, param_nums in operations:
+ op = getattr(cache, op_name)
+ args = [random.randint(-100, 100) for _ in range(0, param_nums - 1)]
+ await op(key, *args)
- cache.put_if_absent(key, 5)
- assert requests.pop() == grid_idx
+ assert requests.pop() == grid_idx
- cache.get_and_remove(key)
- assert requests.pop() == grid_idx
-
- cache.get_and_replace(key, 6)
- assert requests.pop() == grid_idx
-
- cache.remove_key(key)
- assert requests.pop() == grid_idx
-
- cache.remove_if_equals(key, -1)
- assert requests.pop() == grid_idx
-
- cache.replace(key, -1)
- assert requests.pop() == grid_idx
-
- cache.replace_if_equals(key, 10, -10)
- assert requests.pop() == grid_idx
+ return inner_async() if isinstance(client, AioClient) else inner()
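This helper returns either a plain value or a coroutine object, so one body of assertions serves both clients: a sync test calls it directly, an async test awaits the result. The pattern, reduced to its essentials (the names below are illustrative, not from the patch):

    from pyignite.aio_cache import AioCache

    def perform_check(cache, key):
        def inner():
            cache.put(key, key)
            assert cache.get(key) == key

        async def inner_async():
            await cache.put(key, key)
            assert await cache.get(key) == key

        # Calling inner_async() only creates the coroutine; nothing runs
        # until the async test awaits it. The sync branch runs eagerly.
        return inner_async() if isinstance(cache, AioCache) else inner()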
@pytest.mark.skip(reason="Custom key objects are not supported yet")
@@ -164,50 +204,144 @@
assert requests.pop() == grid_idx
-def test_cache_operation_routed_to_new_cluster_node(request, client_not_connected):
- client_not_connected.connect(
- [("127.0.0.1", 10801), ("127.0.0.1", 10802), ("127.0.0.1", 10803), ("127.0.0.1", 10804)]
- )
- cache = client_not_connected.get_or_create_cache(request.node.name)
- key = 12
- wait_for_affinity_distribution(cache, key, 3)
- cache.put(key, key)
- cache.put(key, key)
- assert requests.pop() == 3
+client_routed_connection_string = [('127.0.0.1', 10800 + idx) for idx in range(1, 5)]
- srv = start_ignite(idx=4)
+
+@pytest.fixture
+def client_routed_cache(request):
+ client = Client(partition_aware=True)
try:
- # Wait for rebalance and partition map exchange
- wait_for_affinity_distribution(cache, key, 4)
-
- # Response is correct and comes from the new node
- res = cache.get_and_remove(key)
- assert res == key
- assert requests.pop() == 4
+ client.connect(client_routed_connection_string)
+ yield client.get_or_create_cache(request.node.name)
finally:
- kill_process_tree(srv.pid)
+ client.close()
-def test_replicated_cache_operation_routed_to_random_node(request, client):
+@pytest.fixture
+async def async_client_routed_cache(request):
+ client = AioClient(partition_aware=True)
+ try:
+ await client.connect(client_routed_connection_string)
+ yield await client.get_or_create_cache(request.node.name)
+ finally:
+ await client.close()
+
+
+def test_cache_operation_routed_to_new_cluster_node(client_routed_cache):
+ __perform_cache_operation_routed_to_new_node(client_routed_cache)
+
+
+@pytest.mark.asyncio
+async def test_cache_operation_routed_to_new_cluster_node_async(async_client_routed_cache):
+ await __perform_cache_operation_routed_to_new_node(async_client_routed_cache)
+
+
+def __perform_cache_operation_routed_to_new_node(cache):
+ key = 12
+
+ def inner():
+ wait_for_affinity_distribution(cache, key, 3)
+ cache.put(key, key)
+ cache.put(key, key)
+ assert requests.pop() == 3
+
+        srv = start_ignite(idx=4)
+        try:
+            # Wait for rebalance and partition map exchange
+            wait_for_affinity_distribution(cache, key, 4)
+
+            # Response is correct and comes from the new node
+            res = cache.get_and_remove(key)
+            assert res == key
+            assert requests.pop() == 4
+        finally:
+            kill_process_tree(srv.pid)
+
+ async def inner_async():
+ await wait_for_affinity_distribution_async(cache, key, 3)
+ await cache.put(key, key)
+ await cache.put(key, key)
+ assert requests.pop() == 3
+
+ srv = start_ignite(idx=4)
+ try:
+ # Wait for rebalance and partition map exchange
+ await wait_for_affinity_distribution_async(cache, key, 4)
+
+ # Response is correct and comes from the new node
+ res = await cache.get_and_remove(key)
+ assert res == key
+ assert requests.pop() == 4
+ finally:
+ kill_process_tree(srv.pid)
+
+ return inner_async() if isinstance(cache, AioCache) else inner()
+
+
+@pytest.fixture
+def replicated_cache(request, client):
cache = client.get_or_create_cache({
PROP_NAME: request.node.name,
PROP_CACHE_MODE: CacheMode.REPLICATED,
})
+ try:
+ yield cache
+ finally:
+ cache.destroy()
- verify_random_node(cache)
+
+@pytest.fixture
+async def async_replicated_cache(request, async_client):
+ cache = await async_client.get_or_create_cache({
+ PROP_NAME: request.node.name,
+ PROP_CACHE_MODE: CacheMode.REPLICATED,
+ })
+ try:
+ yield cache
+ finally:
+ await cache.destroy()
+
+
+def test_replicated_cache_operation_routed_to_random_node(replicated_cache):
+ verify_random_node(replicated_cache)
+
+
+@pytest.mark.asyncio
+async def test_replicated_cache_operation_routed_to_random_node_async(async_replicated_cache):
+ await verify_random_node(async_replicated_cache)
def verify_random_node(cache):
key = 1
- cache.put(key, key)
- idx1 = requests.pop()
- idx2 = idx1
-
- # Try 10 times - random node may end up being the same
- for _ in range(1, 10):
+ def inner():
cache.put(key, key)
- idx2 = requests.pop()
- if idx2 != idx1:
- break
- assert idx1 != idx2
+
+ idx1 = requests.pop()
+ idx2 = idx1
+
+        # Try up to 9 more times - the randomly chosen node may repeat
+ for _ in range(1, 10):
+ cache.put(key, key)
+ idx2 = requests.pop()
+ if idx2 != idx1:
+ break
+ assert idx1 != idx2
+
+ async def inner_async():
+ await cache.put(key, key)
+
+ idx1 = requests.pop()
+
+ idx2 = idx1
+
+        # Try up to 9 more times - the randomly chosen node may repeat
+ for _ in range(1, 10):
+ await cache.put(key, key)
+ idx2 = requests.pop()
+
+ if idx2 != idx1:
+ break
+ assert idx1 != idx2
+
+ return inner_async() if isinstance(cache, AioCache) else inner()
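A note on how these routing assertions work: the test servers listen on consecutive ports 10800 + idx (see client_routed_connection_string above), so `self.port % 100` in the patched send recovers the node index directly:

    # Ports 10801..10804 map back to node indices 1..4.
    assert [(10800 + idx) % 100 for idx in range(1, 5)] == [1, 2, 3, 4]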
diff --git a/tests/affinity/test_affinity_single_connection.py b/tests/affinity/test_affinity_single_connection.py
index 0768011..c3d2473 100644
--- a/tests/affinity/test_affinity_single_connection.py
+++ b/tests/affinity/test_affinity_single_connection.py
@@ -15,15 +15,27 @@
import pytest
-from pyignite import Client
+from pyignite import Client, AioClient
-@pytest.fixture(scope='module')
+@pytest.fixture
def client():
client = Client(partition_aware=True)
- client.connect('127.0.0.1', 10801)
- yield client
- client.close()
+ try:
+ client.connect('127.0.0.1', 10801)
+ yield client
+ finally:
+ client.close()
+
+
+@pytest.fixture
+async def async_client():
+ client = AioClient(partition_aware=True)
+ try:
+ await client.connect('127.0.0.1', 10801)
+ yield client
+ finally:
+ await client.close()
def test_all_cache_operations_with_partition_aware_client_on_single_server(request, client):
@@ -108,3 +120,88 @@
assert not res
assert res2
assert cache.get(key) == key2
+
+
+@pytest.mark.asyncio
+async def test_all_cache_operations_with_partition_aware_client_on_single_server_async(request, async_client):
+ cache = await async_client.get_or_create_cache(request.node.name)
+ key = 1
+ key2 = 2
+
+ # Put/Get
+ await cache.put(key, key)
+ assert await cache.get(key) == key
+
+ # Replace
+ res = await cache.replace(key, key2)
+ assert res
+ assert await cache.get(key) == key2
+
+ # Clear
+ await cache.put(key2, key2)
+ await cache.clear_key(key2)
+ assert await cache.get(key2) is None
+
+ # ContainsKey
+ assert await cache.contains_key(key)
+ assert not await cache.contains_key(key2)
+
+ # GetAndPut
+ await cache.put(key, key)
+ res = await cache.get_and_put(key, key2)
+ assert res == key
+ assert await cache.get(key) == key2
+
+ # GetAndPutIfAbsent
+ await cache.clear_key(key)
+ res = await cache.get_and_put_if_absent(key, key)
+ res2 = await cache.get_and_put_if_absent(key, key2)
+ assert res is None
+ assert res2 == key
+ assert await cache.get(key) == key
+
+ # PutIfAbsent
+ await cache.clear_key(key)
+ res = await cache.put_if_absent(key, key)
+ res2 = await cache.put_if_absent(key, key2)
+ assert res
+ assert not res2
+ assert await cache.get(key) == key
+
+ # GetAndRemove
+ await cache.put(key, key)
+ res = await cache.get_and_remove(key)
+ assert res == key
+ assert await cache.get(key) is None
+
+ # GetAndReplace
+ await cache.put(key, key)
+ res = await cache.get_and_replace(key, key2)
+ assert res == key
+ assert await cache.get(key) == key2
+
+ # RemoveKey
+ await cache.put(key, key)
+ await cache.remove_key(key)
+ assert await cache.get(key) is None
+
+ # RemoveIfEquals
+ await cache.put(key, key)
+ res = await cache.remove_if_equals(key, key2)
+ res2 = await cache.remove_if_equals(key, key)
+ assert not res
+ assert res2
+ assert await cache.get(key) is None
+
+ # Replace
+ await cache.put(key, key)
+ await cache.replace(key, key2)
+ assert await cache.get(key) == key2
+
+ # ReplaceIfEquals
+ await cache.put(key, key)
+ res = await cache.replace_if_equals(key, key2, key2)
+ res2 = await cache.replace_if_equals(key, key, key2)
+ assert not res
+ assert res2
+ assert await cache.get(key) == key2
diff --git a/tests/common/conftest.py b/tests/common/conftest.py
index 402aede..243d822 100644
--- a/tests/common/conftest.py
+++ b/tests/common/conftest.py
@@ -15,8 +15,7 @@
import pytest
-from pyignite import Client
-from pyignite.api import cache_create, cache_destroy
+from pyignite import Client, AioClient
from tests.util import start_ignite_gen
@@ -38,19 +37,36 @@
@pytest.fixture(scope='module')
def client():
client = Client()
+ try:
+ client.connect('127.0.0.1', 10801)
+ yield client
+ finally:
+ client.close()
- client.connect('127.0.0.1', 10801)
- yield client
+@pytest.fixture(scope='module')
+async def async_client(event_loop):
+ client = AioClient()
+ try:
+ await client.connect('127.0.0.1', 10801)
+ yield client
+ finally:
+ await client.close()
- client.close()
+
+@pytest.fixture
+async def async_cache(async_client: 'AioClient'):
+ cache = await async_client.create_cache('my_bucket')
+ try:
+ yield cache
+ finally:
+ await cache.destroy()
@pytest.fixture
def cache(client):
- cache_name = 'my_bucket'
- conn = client.random_node
-
- cache_create(conn, cache_name)
- yield cache_name
- cache_destroy(conn, cache_name)
+ cache = client.create_cache('my_bucket')
+ try:
+ yield cache
+ finally:
+ cache.destroy()
diff --git a/tests/common/test_binary.py b/tests/common/test_binary.py
index 5fa2ec4..1d7192f 100644
--- a/tests/common/test_binary.py
+++ b/tests/common/test_binary.py
@@ -16,15 +16,17 @@
from collections import OrderedDict
from decimal import Decimal
+import pytest
+
from pyignite import GenericObjectMeta
+from pyignite.aio_cache import AioCache
from pyignite.datatypes import (
BinaryObject, BoolObject, IntObject, DecimalObject, LongObject, String, ByteObject, ShortObject, FloatObject,
DoubleObject, CharObject, UUIDObject, DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject,
ByteArrayObject, ShortArrayObject, IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject,
CharArrayObject, BoolArrayObject, UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject,
EnumArrayObject, StringArrayObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject)
-from pyignite.datatypes.prop_codes import *
-
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_SQL_SCHEMA, PROP_QUERY_ENTITIES
insert_data = [
[1, True, 'asdf', 42, Decimal('2.4')],
@@ -54,7 +56,7 @@
insert_query = '''
INSERT INTO {} (
- test_pk, test_bool, test_str, test_int, test_decimal,
+    test_pk, test_bool, test_str, test_int, test_decimal
) VALUES (?, ?, ?, ?, ?)'''.format(table_sql_name)
select_query = '''SELECT * FROM {}'''.format(table_sql_name)
@@ -62,51 +64,69 @@
drop_query = 'DROP TABLE {} IF EXISTS'.format(table_sql_name)
-def test_sql_read_as_binary(client):
+@pytest.fixture
+def table_cache_read(client):
client.sql(drop_query)
-
- # create table
client.sql(create_query)
- # insert some rows
for line in insert_data:
client.sql(insert_query, query_args=line)
- table_cache = client.get_cache(table_cache_name)
- result = table_cache.scan()
-
- # convert Binary object fields' values to a tuple
- # to compare it with the initial data
- for key, value in result:
- assert key in {x[0] for x in insert_data}
- assert (
- value.TEST_BOOL,
- value.TEST_STR,
- value.TEST_INT,
- value.TEST_DECIMAL
- ) in {tuple(x[1:]) for x in insert_data}
-
- client.sql(drop_query)
+ cache = client.get_cache(table_cache_name)
+ yield cache
+ cache.destroy()
-def test_sql_write_as_binary(client):
- # configure cache as an SQL table
- type_name = table_cache_name
+@pytest.fixture
+async def table_cache_read_async(async_client):
+ await async_client.sql(drop_query)
+ await async_client.sql(create_query)
- # register binary type
- class AllDataType(
- metaclass=GenericObjectMeta,
- type_name=type_name,
- schema=OrderedDict([
- ('TEST_BOOL', BoolObject),
- ('TEST_STR', String),
- ('TEST_INT', IntObject),
- ('TEST_DECIMAL', DecimalObject),
- ]),
- ):
- pass
+ for line in insert_data:
+ await async_client.sql(insert_query, query_args=line)
- table_cache = client.get_or_create_cache({
+ cache = await async_client.get_cache(table_cache_name)
+ yield cache
+ await cache.destroy()
+
+
+def test_sql_read_as_binary(table_cache_read):
+ with table_cache_read.scan() as cursor:
+ # convert Binary object fields' values to a tuple
+ # to compare it with the initial data
+ for key, value in cursor:
+ assert key in {x[0] for x in insert_data}
+ assert (value.TEST_BOOL, value.TEST_STR, value.TEST_INT, value.TEST_DECIMAL) \
+ in {tuple(x[1:]) for x in insert_data}
+
+
+@pytest.mark.asyncio
+async def test_sql_read_as_binary_async(table_cache_read_async):
+ async with table_cache_read_async.scan() as cursor:
+ # convert Binary object fields' values to a tuple
+ # to compare it with the initial data
+ async for key, value in cursor:
+ assert key in {x[0] for x in insert_data}
+ assert (value.TEST_BOOL, value.TEST_STR, value.TEST_INT, value.TEST_DECIMAL) \
+ in {tuple(x[1:]) for x in insert_data}
+
+
+class AllDataType(
+ metaclass=GenericObjectMeta,
+ type_name=table_cache_name,
+ schema=OrderedDict([
+ ('TEST_BOOL', BoolObject),
+ ('TEST_STR', String),
+ ('TEST_INT', IntObject),
+ ('TEST_DECIMAL', DecimalObject),
+ ]),
+):
+ pass
+
+
+@pytest.fixture
+def table_cache_write_settings():
+ return {
PROP_NAME: table_cache_name,
PROP_SQL_SCHEMA: scheme_name,
PROP_QUERY_ENTITIES: [
@@ -142,15 +162,18 @@
},
],
'query_indexes': [],
- 'value_type_name': type_name,
+ 'value_type_name': table_cache_name,
'value_field_name': None,
},
],
- })
- table_settings = table_cache.settings
- assert table_settings, 'SQL table cache settings are empty'
+ }
- # insert rows as k-v
+
+@pytest.fixture
+def table_cache_write(client, table_cache_write_settings):
+ cache = client.get_or_create_cache(table_cache_write_settings)
+ assert cache.settings, 'SQL table cache settings are empty'
+
for row in insert_data:
value = AllDataType()
(
@@ -159,13 +182,39 @@
value.TEST_INT,
value.TEST_DECIMAL,
) = row[1:]
- table_cache.put(row[0], value, key_hint=IntObject)
+ cache.put(row[0], value, key_hint=IntObject)
- data = table_cache.scan()
- assert len(list(data)) == len(insert_data), (
- 'Not all data was read as key-value'
- )
+ data = cache.scan()
+ assert len(list(data)) == len(insert_data), 'Not all data was read as key-value'
+ yield cache
+ cache.destroy()
+
+
+@pytest.fixture
+async def async_table_cache_write(async_client, table_cache_write_settings):
+ cache = await async_client.get_or_create_cache(table_cache_write_settings)
+ assert await cache.settings(), 'SQL table cache settings are empty'
+
+ for row in insert_data:
+ value = AllDataType()
+ (
+ value.TEST_BOOL,
+ value.TEST_STR,
+ value.TEST_INT,
+ value.TEST_DECIMAL,
+ ) = row[1:]
+ await cache.put(row[0], value, key_hint=IntObject)
+
+ async with cache.scan() as cursor:
+ data = [a async for a in cursor]
+ assert len(data) == len(insert_data), 'Not all data was read as key-value'
+
+ yield cache
+ await cache.destroy()
+
+
+def test_sql_write_as_binary(client, table_cache_write):
# read rows as SQL
data = client.sql(select_query, include_field_names=True)
@@ -176,14 +225,29 @@
data = list(data)
assert len(data) == len(insert_data), 'Not all data was read as SQL rows'
- # cleanup
- table_cache.destroy()
+
+@pytest.mark.asyncio
+async def test_sql_write_as_binary_async(async_client, async_table_cache_write):
+ # read rows as SQL
+ async with async_client.sql(select_query, include_field_names=True) as cursor:
+ header_row = await cursor.__anext__()
+ for field_name in AllDataType.schema.keys():
+ assert field_name in header_row, 'Not all field names in header row'
+
+ data = [v async for v in cursor]
+ assert len(data) == len(insert_data), 'Not all data was read as SQL rows'
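The async read path above shows the cursor protocol this patch introduces: `AioClient.sql()` yields an object that is both an async context manager and an async iterator, with the header row first when `include_field_names=True`. A reduced sketch (the function name is illustrative):

    async def read_rows(async_client, query):
        async with async_client.sql(query, include_field_names=True) as cursor:
            header = await cursor.__anext__()  # first row holds column names
            rows = [row async for row in cursor]
        return header, rows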
-def test_nested_binary_objects(client):
+def test_nested_binary_objects(cache):
+ __check_nested_binary_objects(cache)
- nested_cache = client.get_or_create_cache('nested_binary')
+@pytest.mark.asyncio
+async def test_nested_binary_objects_async(async_cache):
+ await __check_nested_binary_objects(async_cache)
+
+
+def __check_nested_binary_objects(cache):
class InnerType(
metaclass=GenericObjectMeta,
schema=OrderedDict([
@@ -203,29 +267,42 @@
):
pass
- inner = InnerType(inner_int=42, inner_str='This is a test string')
+ def prepare_obj():
+ inner = InnerType(inner_int=42, inner_str='This is a test string')
- outer = OuterType(
- outer_int=43,
- nested_binary=inner,
- outer_str='This is another test string'
- )
+ return OuterType(
+ outer_int=43,
+ nested_binary=inner,
+ outer_str='This is another test string'
+ )
- nested_cache.put(1, outer)
+ def check_obj(result):
+ assert result.outer_int == 43
+ assert result.outer_str == 'This is another test string'
+ assert result.nested_binary.inner_int == 42
+ assert result.nested_binary.inner_str == 'This is a test string'
- result = nested_cache.get(1)
- assert result.outer_int == 43
- assert result.outer_str == 'This is another test string'
- assert result.nested_binary.inner_int == 42
- assert result.nested_binary.inner_str == 'This is a test string'
+ async def inner_async():
+ await cache.put(1, prepare_obj())
+ check_obj(await cache.get(1))
- nested_cache.destroy()
+ def inner():
+ cache.put(1, prepare_obj())
+ check_obj(cache.get(1))
+
+ return inner_async() if isinstance(cache, AioCache) else inner()
-def test_add_schema_to_binary_object(client):
+def test_add_schema_to_binary_object(cache):
+ __check_add_schema_to_binary_object(cache)
- migrate_cache = client.get_or_create_cache('migrate_binary')
+@pytest.mark.asyncio
+async def test_add_schema_to_binary_object_async(async_cache):
+ await __check_add_schema_to_binary_object(async_cache)
+
+
+def __check_add_schema_to_binary_object(cache):
class MyBinaryType(
metaclass=GenericObjectMeta,
schema=OrderedDict([
@@ -236,54 +313,66 @@
):
pass
- binary_object = MyBinaryType(
- test_str='Test string',
- test_int=42,
- test_bool=True,
- )
- migrate_cache.put(1, binary_object)
+ def prepare_bo_v1():
+ return MyBinaryType(test_str='Test string', test_int=42, test_bool=True)
- result = migrate_cache.get(1)
- assert result.test_str == 'Test string'
- assert result.test_int == 42
- assert result.test_bool is True
+ def check_bo_v1(result):
+ assert result.test_str == 'Test string'
+ assert result.test_int == 42
+ assert result.test_bool is True
- modified_schema = MyBinaryType.schema.copy()
- modified_schema['test_decimal'] = DecimalObject
- del modified_schema['test_bool']
+ def prepare_bo_v2():
+ modified_schema = MyBinaryType.schema.copy()
+ modified_schema['test_decimal'] = DecimalObject
+ del modified_schema['test_bool']
- class MyBinaryTypeV2(
- metaclass=GenericObjectMeta,
- type_name='MyBinaryType',
- schema=modified_schema,
- ):
- pass
+ class MyBinaryTypeV2(
+ metaclass=GenericObjectMeta,
+ type_name='MyBinaryType',
+ schema=modified_schema,
+ ):
+ pass
- assert MyBinaryType.type_id == MyBinaryTypeV2.type_id
- assert MyBinaryType.schema_id != MyBinaryTypeV2.schema_id
+ assert MyBinaryType.type_id == MyBinaryTypeV2.type_id
+ assert MyBinaryType.schema_id != MyBinaryTypeV2.schema_id
- binary_object_v2 = MyBinaryTypeV2(
- test_str='Another test',
- test_int=43,
- test_decimal=Decimal('2.34')
- )
+ return MyBinaryTypeV2(test_str='Another test', test_int=43, test_decimal=Decimal('2.34'))
- migrate_cache.put(2, binary_object_v2)
+ def check_bo_v2(result):
+ assert result.test_str == 'Another test'
+ assert result.test_int == 43
+ assert result.test_decimal == Decimal('2.34')
+ assert not hasattr(result, 'test_bool')
- result = migrate_cache.get(2)
- assert result.test_str == 'Another test'
- assert result.test_int == 43
- assert result.test_decimal == Decimal('2.34')
- assert not hasattr(result, 'test_bool')
+ async def inner_async():
+ await cache.put(1, prepare_bo_v1())
+ check_bo_v1(await cache.get(1))
+ await cache.put(2, prepare_bo_v2())
+ check_bo_v2(await cache.get(2))
- migrate_cache.destroy()
+ def inner():
+ cache.put(1, prepare_bo_v1())
+ check_bo_v1(cache.get(1))
+ cache.put(2, prepare_bo_v2())
+ check_bo_v2(cache.get(2))
+
+ return inner_async() if isinstance(cache, AioCache) else inner()
-def test_complex_object_names(client):
+def test_complex_object_names(cache):
"""
Test the ability to work with Complex types whose names contain symbols
not suitable for use in Python identifiers.
"""
+ __check_complex_object_names(cache)
+
+
+@pytest.mark.asyncio
+async def test_complex_object_names_async(async_cache):
+ await __check_complex_object_names(async_cache)
+
+
+def __check_complex_object_names(cache):
type_name = 'Non.Pythonic#type-name$'
key = 'key'
data = 'test'
@@ -297,41 +386,47 @@
):
pass
- cache = client.get_or_create_cache('test_name_cache')
- cache.put(key, NonPythonicallyNamedType(field=data))
+ def check(obj):
+ assert obj.type_name == type_name, 'Complex type name mismatch'
+ assert obj.field == data, 'Complex object data failure'
- obj = cache.get(key)
- assert obj.type_name == type_name, 'Complex type name mismatch'
- assert obj.field == data, 'Complex object data failure'
+ async def inner_async():
+ await cache.put(key, NonPythonicallyNamedType(field=data))
+ check(await cache.get(key))
+
+ def inner():
+ cache.put(key, NonPythonicallyNamedType(field=data))
+ check(cache.get(key))
+
+ return inner_async() if isinstance(cache, AioCache) else inner()
-def test_complex_object_hash(client):
- """
- Test that Python client correctly calculates hash of the binary object that
- contains negative bytes.
- """
- class Internal(
- metaclass=GenericObjectMeta,
- type_name='Internal',
- schema=OrderedDict([
- ('id', IntObject),
- ('str', String),
- ])
- ):
- pass
+class Internal(
+ metaclass=GenericObjectMeta, type_name='Internal',
+ schema=OrderedDict([
+ ('id', IntObject),
+ ('str', String)
+ ])
+):
+ pass
- class TestObject(
- metaclass=GenericObjectMeta,
- type_name='TestObject',
- schema=OrderedDict([
- ('id', IntObject),
- ('str', String),
- ('internal', BinaryObject),
- ])
- ):
- pass
- obj_ascii = TestObject()
+class NestedObject(
+ metaclass=GenericObjectMeta, type_name='NestedObject',
+ schema=OrderedDict([
+ ('id', IntObject),
+ ('str', String),
+ ('internal', BinaryObject)
+ ])
+):
+ pass
+
+
+@pytest.fixture
+def complex_objects():
+ fixtures = []
+
+ obj_ascii = NestedObject()
obj_ascii.id = 1
obj_ascii.str = 'test_string'
@@ -339,11 +434,9 @@
obj_ascii.internal.id = 2
obj_ascii.internal.str = 'lorem ipsum'
- hash_ascii = BinaryObject.hashcode(obj_ascii, client=client)
+ fixtures.append((obj_ascii, -1314567146))
- assert hash_ascii == -1314567146, 'Invalid hashcode value for object with ASCII strings'
-
- obj_utf8 = TestObject()
+ obj_utf8 = NestedObject()
obj_utf8.id = 1
obj_utf8.str = 'юникод'
@@ -351,39 +444,63 @@
obj_utf8.internal.id = 2
obj_utf8.internal.str = 'ユニコード'
- hash_utf8 = BinaryObject.hashcode(obj_utf8, client=client)
+ fixtures.append((obj_utf8, -1945378474))
- assert hash_utf8 == -1945378474, 'Invalid hashcode value for object with UTF-8 strings'
+ yield fixtures
-def test_complex_object_null_fields(client):
+def test_complex_object_hash(client, complex_objects):
+    for obj, expected_hash in complex_objects:
+        assert expected_hash == BinaryObject.hashcode(obj, client)
+
+
+@pytest.mark.asyncio
+async def test_complex_object_hash_async(async_client, complex_objects):
+    for obj, expected_hash in complex_objects:
+        assert expected_hash == await BinaryObject.hashcode_async(obj, async_client)
+
+
+def camel_to_snake(name):
+ return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+
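For reference, this helper derives the schema field names below from the data-type class names, for example:

    assert camel_to_snake('ByteArrayObject') == 'byte_array_object'
    assert camel_to_snake('MapObject') == 'map_object'
    assert camel_to_snake('UUIDObject') == 'uuidobject'  # runs of capitals collapse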
+
+fields = {camel_to_snake(type_.__name__): type_ for type_ in [
+ ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject,
+ DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject,
+ IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject,
+ UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String,
+ StringArrayObject, DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject,
+ BinaryObject]}
+
+
+class AllTypesObject(metaclass=GenericObjectMeta, type_name='AllTypesObject', schema=fields):
+ pass
+
+
+@pytest.fixture
+def null_fields_object():
+ res = AllTypesObject()
+
+ for field in fields.keys():
+ setattr(res, field, None)
+
+ yield res
+
+
+def test_complex_object_null_fields(cache, null_fields_object):
"""
Test that the Python client can correctly write and read a binary object
that contains null fields.
"""
- def camel_to_snake(name):
- return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+ cache.put(1, null_fields_object)
+ assert cache.get(1) == null_fields_object, 'Objects mismatch'
- fields = {camel_to_snake(type_.__name__): type_ for type_ in [
- ByteObject, ShortObject, IntObject, LongObject, FloatObject, DoubleObject, CharObject, BoolObject, UUIDObject,
- DateObject, TimestampObject, TimeObject, EnumObject, BinaryEnumObject, ByteArrayObject, ShortArrayObject,
- IntArrayObject, LongArrayObject, FloatArrayObject, DoubleArrayObject, CharArrayObject, BoolArrayObject,
- UUIDArrayObject, DateArrayObject, TimestampArrayObject, TimeArrayObject, EnumArrayObject, String,
- StringArrayObject, DecimalObject, DecimalArrayObject, ObjectArrayObject, CollectionObject, MapObject,
- BinaryObject]}
- class AllTypesObject(metaclass=GenericObjectMeta, type_name='AllTypesObject', schema=fields):
- pass
-
- key = 42
- null_fields_value = AllTypesObject()
-
- for field in fields.keys():
- setattr(null_fields_value, field, None)
-
- cache = client.get_or_create_cache('all_types_test_cache')
- cache.put(key, null_fields_value)
-
- got_obj = cache.get(key)
-
- assert got_obj == null_fields_value, 'Objects mismatch'
+@pytest.mark.asyncio
+async def test_complex_object_null_fields_async(async_cache, null_fields_object):
+ """
+    Test that the Python client can correctly write and read a binary object
+    that contains null fields.
+ """
+ await async_cache.put(1, null_fields_object)
+ assert await async_cache.get(1) == null_fields_object, 'Objects mismatch'
diff --git a/tests/common/test_cache_class.py b/tests/common/test_cache_class.py
index 940160a..02dfa82 100644
--- a/tests/common/test_cache_class.py
+++ b/tests/common/test_cache_class.py
@@ -19,66 +19,56 @@
import pytest
from pyignite import GenericObjectMeta
-from pyignite.datatypes import (
- BoolObject, DecimalObject, FloatObject, IntObject, String,
-)
-from pyignite.datatypes.prop_codes import *
+from pyignite.datatypes import BoolObject, DecimalObject, FloatObject, IntObject, String
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_KEY_CONFIGURATION
from pyignite.exceptions import CacheError, ParameterError
def test_cache_create(client):
cache = client.get_or_create_cache('my_oop_cache')
- assert cache.name == cache.settings[PROP_NAME] == 'my_oop_cache'
- cache.destroy()
+ try:
+ assert cache.name == cache.settings[PROP_NAME] == 'my_oop_cache'
+ finally:
+ cache.destroy()
-def test_cache_remove(client):
- cache = client.get_or_create_cache('my_cache')
- cache.clear()
- assert cache.get_size() == 0
-
- cache.put_all({
- 'key_1': 1,
- 'key_2': 2,
- 'key_3': 3,
- 'key_4': 4,
- 'key_5': 5,
- })
- assert cache.get_size() == 5
-
- result = cache.remove_if_equals('key_1', 42)
- assert result is False
- assert cache.get_size() == 5
-
- result = cache.remove_if_equals('key_1', 1)
- assert result is True
- assert cache.get_size() == 4
-
- cache.remove_keys(['key_1', 'key_3', 'key_5', 'key_7'])
- assert cache.get_size() == 2
-
- cache.remove_all()
- assert cache.get_size() == 0
+@pytest.mark.asyncio
+async def test_cache_create_async(async_client):
+ cache = await async_client.get_or_create_cache('my_oop_cache')
+ try:
+ assert (await cache.name()) == (await cache.settings())[PROP_NAME] == 'my_oop_cache'
+ finally:
+ await cache.destroy()
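These tests also pin down an API asymmetry in the new async client: `Cache` exposes `name` and `settings` as plain properties, while `AioCache` exposes them as awaitable methods, presumably because they may involve a network round-trip. Side by side:

    def sync_meta(cache):
        # Cache: plain properties.
        return cache.name, cache.settings

    async def async_meta(aio_cache):
        # AioCache: coroutine methods that must be awaited.
        return await aio_cache.name(), await aio_cache.settings()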
-def test_cache_get(client):
+def test_get_cache(client):
my_cache = client.get_or_create_cache('my_cache')
- assert my_cache.settings[PROP_NAME] == 'my_cache'
- my_cache.destroy()
-
- error = None
+ try:
+ assert my_cache.settings[PROP_NAME] == 'my_cache'
+ finally:
+ my_cache.destroy()
my_cache = client.get_cache('my_cache')
- try:
+ with pytest.raises(CacheError):
_ = my_cache.settings[PROP_NAME]
- except CacheError as e:
- error = e
-
- assert type(error) is CacheError
-def test_cache_config(client):
- cache_config = {
+@pytest.mark.asyncio
+async def test_get_cache_async(async_client):
+ my_cache = await async_client.get_or_create_cache('my_cache')
+ try:
+ assert (await my_cache.settings())[PROP_NAME] == 'my_cache'
+ finally:
+ await my_cache.destroy()
+
+ my_cache = await async_client.get_cache('my_cache')
+ with pytest.raises(CacheError):
+ _ = (await my_cache.settings())[PROP_NAME]
+
+
+@pytest.fixture
+def cache_config():
+ yield {
PROP_NAME: 'my_oop_cache',
PROP_CACHE_KEY_CONFIGURATION: [
{
@@ -87,28 +77,31 @@
},
],
}
+
+
+def test_cache_config(client, cache_config):
client.create_cache(cache_config)
-
cache = client.get_or_create_cache('my_oop_cache')
- assert cache.name == cache_config[PROP_NAME]
- assert (
- cache.settings[PROP_CACHE_KEY_CONFIGURATION]
- == cache_config[PROP_CACHE_KEY_CONFIGURATION]
- )
-
- cache.destroy()
+ try:
+ assert cache.name == cache_config[PROP_NAME]
+ assert cache.settings[PROP_CACHE_KEY_CONFIGURATION] == cache_config[PROP_CACHE_KEY_CONFIGURATION]
+ finally:
+ cache.destroy()
-def test_cache_get_put(client):
- cache = client.get_or_create_cache('my_oop_cache')
- cache.put('my_key', 42)
- result = cache.get('my_key')
- assert result, 42
- cache.destroy()
+@pytest.mark.asyncio
+async def test_cache_config_async(async_client, cache_config):
+ await async_client.create_cache(cache_config)
+ cache = await async_client.get_or_create_cache('my_oop_cache')
+ try:
+ assert await cache.name() == cache_config[PROP_NAME]
+ assert (await cache.settings())[PROP_CACHE_KEY_CONFIGURATION] == cache_config[PROP_CACHE_KEY_CONFIGURATION]
+ finally:
+ await cache.destroy()
-def test_cache_binary_get_put(client):
-
+@pytest.fixture
+def binary_type_fixture():
class TestBinaryType(
metaclass=GenericObjectMeta,
schema=OrderedDict([
@@ -120,52 +113,63 @@
):
pass
- cache = client.create_cache('my_oop_cache')
-
- my_value = TestBinaryType(
+ return TestBinaryType(
test_bool=True,
test_str='This is a test',
test_int=42,
test_decimal=Decimal('34.56'),
)
- cache.put('my_key', my_value)
+
+def test_cache_binary_get_put(cache, binary_type_fixture):
+ cache.put('my_key', binary_type_fixture)
value = cache.get('my_key')
- assert value.test_bool is True
- assert value.test_str == 'This is a test'
- assert value.test_int == 42
- assert value.test_decimal == Decimal('34.56')
-
- cache.destroy()
+ assert value.test_bool == binary_type_fixture.test_bool
+ assert value.test_str == binary_type_fixture.test_str
+ assert value.test_int == binary_type_fixture.test_int
+ assert value.test_decimal == binary_type_fixture.test_decimal
-def test_get_binary_type(client):
- client.put_binary_type(
- 'TestBinaryType',
- schema=OrderedDict([
+@pytest.mark.asyncio
+async def test_cache_binary_get_put_async(async_cache, binary_type_fixture):
+ await async_cache.put('my_key', binary_type_fixture)
+
+ value = await async_cache.get('my_key')
+ assert value.test_bool == binary_type_fixture.test_bool
+ assert value.test_str == binary_type_fixture.test_str
+ assert value.test_int == binary_type_fixture.test_int
+ assert value.test_decimal == binary_type_fixture.test_decimal
+
+
+@pytest.fixture
+def binary_type_schemas_fixture():
+ schemas = [
+ OrderedDict([
('TEST_BOOL', BoolObject),
('TEST_STR', String),
('TEST_INT', IntObject),
- ])
- )
- client.put_binary_type(
- 'TestBinaryType',
- schema=OrderedDict([
+ ]),
+ OrderedDict([
('TEST_BOOL', BoolObject),
('TEST_STR', String),
('TEST_INT', IntObject),
('TEST_FLOAT', FloatObject),
- ])
- )
- client.put_binary_type(
- 'TestBinaryType',
- schema=OrderedDict([
+ ]),
+ OrderedDict([
('TEST_BOOL', BoolObject),
('TEST_STR', String),
('TEST_INT', IntObject),
('TEST_DECIMAL', DecimalObject),
])
- )
+ ]
+ yield 'TestBinaryType', schemas
+
+
+def test_get_binary_type(client, binary_type_schemas_fixture):
+ type_name, schemas = binary_type_schemas_fixture
+
+ for schema in schemas:
+ client.put_binary_type(type_name, schema=schema)
binary_type_info = client.get_binary_type('TestBinaryType')
assert len(binary_type_info['schemas']) == 3
@@ -175,60 +179,37 @@
assert len(binary_type_info) == 1
-@pytest.mark.parametrize('page_size', range(1, 17, 5))
-def test_cache_scan(request, client, page_size):
- test_data = {
- 1: 'This is a test',
- 2: 'One more test',
- 3: 'Foo',
- 4: 'Buzz',
- 5: 'Bar',
- 6: 'Lorem ipsum',
- 7: 'dolor sit amet',
- 8: 'consectetur adipiscing elit',
- 9: 'Nullam aliquet',
- 10: 'nisl at ante',
- 11: 'suscipit',
- 12: 'ut cursus',
- 13: 'metus interdum',
- 14: 'Nulla tincidunt',
- 15: 'sollicitudin iaculis',
- }
+@pytest.mark.asyncio
+async def test_get_binary_type_async(async_client, binary_type_schemas_fixture):
+ type_name, schemas = binary_type_schemas_fixture
- cache = client.get_or_create_cache(request.node.name)
- cache.put_all(test_data)
+ for schema in schemas:
+ await async_client.put_binary_type(type_name, schema=schema)
- gen = cache.scan(page_size=page_size)
- received_data = []
- for k, v in gen:
- assert k in test_data.keys()
- assert v in test_data.values()
- received_data.append((k, v))
- assert len(received_data) == len(test_data)
+ binary_type_info = await async_client.get_binary_type('TestBinaryType')
+ assert len(binary_type_info['schemas']) == 3
- cache.destroy()
+ binary_type_info = await async_client.get_binary_type('NonExistentType')
+ assert binary_type_info['type_exists'] is False
+ assert len(binary_type_info) == 1
-def test_get_and_put_if_absent(client):
- cache = client.get_or_create_cache('my_oop_cache')
-
- value = cache.get_and_put_if_absent('my_key', 42)
- assert value is None
- cache.put('my_key', 43)
- value = cache.get_and_put_if_absent('my_key', 42)
- assert value is 43
-
-
-def test_cache_get_when_cache_does_not_exist(client):
+def test_get_cache_errors(client):
cache = client.get_cache('missing-cache')
- with pytest.raises(CacheError) as e_info:
+
+ with pytest.raises(CacheError, match=r'Cache does not exist \[cacheId='):
cache.put(1, 1)
- assert str(e_info.value) == "Cache does not exist [cacheId= 1665146971]"
-
-def test_cache_create_with_none_name(client):
- with pytest.raises(ParameterError) as e_info:
+ with pytest.raises(ParameterError, match="You should supply at least cache name"):
client.create_cache(None)
- assert str(e_info.value) == "You should supply at least cache name"
+@pytest.mark.asyncio
+async def test_get_cache_errors_async(async_client):
+ cache = await async_client.get_cache('missing-cache')
+
+ with pytest.raises(CacheError, match=r'Cache does not exist \[cacheId='):
+ await cache.put(1, 1)
+
+ with pytest.raises(ParameterError, match="You should supply at least cache name"):
+ await async_client.create_cache(None)
diff --git a/tests/common/test_cache_class_sql.py b/tests/common/test_cache_class_sql.py
deleted file mode 100644
index 5f72b39..0000000
--- a/tests/common/test_cache_class_sql.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import pytest
-
-
-initial_data = [
- ('John', 'Doe', 5),
- ('Jane', 'Roe', 4),
- ('Joe', 'Bloggs', 4),
- ('Richard', 'Public', 3),
- ('Negidius', 'Numerius', 3),
- ]
-
-create_query = '''CREATE TABLE Student (
- id INT(11) PRIMARY KEY,
- first_name CHAR(24),
- last_name CHAR(32),
- grade INT(11))'''
-
-insert_query = '''INSERT INTO Student(id, first_name, last_name, grade)
-VALUES (?, ?, ?, ?)'''
-
-select_query = 'SELECT id, first_name, last_name, grade FROM Student'
-
-drop_query = 'DROP TABLE Student IF EXISTS'
-
-
-@pytest.mark.parametrize('page_size', range(1, 6, 2))
-def test_sql_fields(client, page_size):
-
- client.sql(drop_query, page_size)
-
- result = client.sql(create_query, page_size)
- assert next(result)[0] == 0
-
- for i, data_line in enumerate(initial_data, start=1):
- fname, lname, grade = data_line
- result = client.sql(
- insert_query,
- page_size,
- query_args=[i, fname, lname, grade]
- )
- assert next(result)[0] == 1
-
- result = client.sql(
- select_query,
- page_size,
- include_field_names=True,
- )
- field_names = next(result)
- assert set(field_names) == {'ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE'}
-
- data = list(result)
- assert len(data) == 5
- for row in data:
- assert len(row) == 4
-
- client.sql(drop_query, page_size)
-
-
-@pytest.mark.parametrize('page_size', range(1, 6, 2))
-def test_sql(client, page_size):
-
- client.sql(drop_query, page_size)
-
- result = client.sql(create_query, page_size)
- assert next(result)[0] == 0
-
- for i, data_line in enumerate(initial_data, start=1):
- fname, lname, grade = data_line
- result = client.sql(
- insert_query,
- page_size,
- query_args=[i, fname, lname, grade]
- )
- assert next(result)[0] == 1
-
- student = client.get_or_create_cache('SQL_PUBLIC_STUDENT')
- result = student.select_row('TRUE', page_size)
- for k, v in result:
- assert k in range(1, 6)
- assert v.FIRST_NAME in [
- 'John',
- 'Jane',
- 'Joe',
- 'Richard',
- 'Negidius',
- ]
-
- client.sql(drop_query, page_size)
diff --git a/tests/common/test_cache_composite_key_class_sql.py b/tests/common/test_cache_composite_key_class_sql.py
deleted file mode 100644
index 989a229..0000000
--- a/tests/common/test_cache_composite_key_class_sql.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import OrderedDict
-
-from pyignite import GenericObjectMeta
-from pyignite.datatypes import (
- IntObject, String
-)
-
-
-class StudentKey(
- metaclass=GenericObjectMeta,
- type_name='test.model.StudentKey',
- schema=OrderedDict([
- ('ID', IntObject),
- ('DEPT', String)
- ])
- ):
- pass
-
-
-class Student(
- metaclass=GenericObjectMeta,
- type_name='test.model.Student',
- schema=OrderedDict([
- ('NAME', String),
- ])
- ):
- pass
-
-
-create_query = '''CREATE TABLE StudentTable (
- id INT(11),
- dept VARCHAR,
- name CHAR(24),
- PRIMARY KEY (id, dept))
- WITH "CACHE_NAME=StudentCache, KEY_TYPE=test.model.StudentKey, VALUE_TYPE=test.model.Student"'''
-
-insert_query = '''INSERT INTO StudentTable (id, dept, name) VALUES (?, ?, ?)'''
-
-select_query = 'SELECT _KEY, id, dept, name FROM StudentTable'
-
-drop_query = 'DROP TABLE StudentTable IF EXISTS'
-
-
-def test_cache_get_with_composite_key_finds_sql_value(client):
- """
- Should query a record with composite key and calculate
- internal hashcode correctly.
- """
-
- client.sql(drop_query)
-
- # Create table.
- result = client.sql(create_query)
- assert next(result)[0] == 0
-
- student_key = StudentKey(1, 'Acct')
- student_val = Student('John')
-
- # Put new Strudent with StudentKey.
- result = client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME])
- assert next(result)[0] == 1
-
- # Cache get finds the same value.
- studentCache = client.get_cache('StudentCache')
- val = studentCache.get(student_key)
- assert val is not None
- assert val.NAME == student_val.NAME
-
- query_result = list(client.sql(select_query, include_field_names=True))
-
- validate_query_result(student_key, student_val, query_result)
-
-
-def test_python_sql_finds_inserted_value_with_composite_key(client):
- """
- Insert a record with a composite key and query it with SELECT SQL.
- """
-
- client.sql(drop_query)
-
- # Create table.
- result = client.sql(create_query)
- assert next(result)[0] == 0
-
- student_key = StudentKey(2, 'Business')
- student_val = Student('Abe')
-
- # Put new value using cache.
- studentCache = client.get_cache('StudentCache')
- studentCache.put(student_key, student_val)
-
- # Find the value using SQL.
- query_result = list(client.sql(select_query, include_field_names=True))
-
- validate_query_result(student_key, student_val, query_result)
-
-
-def validate_query_result(student_key, student_val, query_result):
- """
- Compare query result with expected key and value.
- """
- assert len(query_result) == 2
- sql_row = dict(zip(query_result[0], query_result[1]))
-
- assert sql_row['ID'] == student_key.ID
- assert sql_row['DEPT'] == student_key.DEPT
- assert sql_row['NAME'] == student_val.NAME
diff --git a/tests/common/test_cache_config.py b/tests/common/test_cache_config.py
index b708b0c..f4c8067 100644
--- a/tests/common/test_cache_config.py
+++ b/tests/common/test_cache_config.py
@@ -12,29 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
-from pyignite.api import *
-from pyignite.datatypes.prop_codes import *
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_KEY_CONFIGURATION
+from pyignite.exceptions import CacheError
+
+cache_name = 'config_cache'
-def test_get_configuration(client):
-
- conn = client.random_node
-
- result = cache_get_or_create(conn, 'my_unique_cache')
- assert result.status == 0
-
- result = cache_get_configuration(conn, 'my_unique_cache')
- assert result.status == 0
- assert result.value[PROP_NAME] == 'my_unique_cache'
-
-
-def test_create_with_config(client):
-
- cache_name = 'my_very_unique_name'
- conn = client.random_node
-
- result = cache_create_with_config(conn, {
+@pytest.fixture
+def cache_config():
+ return {
PROP_NAME: cache_name,
PROP_CACHE_KEY_CONFIGURATION: [
{
@@ -42,38 +30,86 @@
'affinity_key_field_name': 'abc1234',
}
],
- })
- assert result.status == 0
-
- result = cache_get_names(conn)
- assert cache_name in result.value
-
- result = cache_create_with_config(conn, {
- PROP_NAME: cache_name,
- })
- assert result.status != 0
+ }
-def test_get_or_create_with_config(client):
+@pytest.fixture
+def cache(client):
+ cache = client.get_or_create_cache(cache_name)
+ yield cache
+ cache.destroy()
- cache_name = 'my_very_unique_name'
- conn = client.random_node
- result = cache_get_or_create_with_config(conn, {
- PROP_NAME: cache_name,
- PROP_CACHE_KEY_CONFIGURATION: [
- {
- 'type_name': 'blah',
- 'affinity_key_field_name': 'abc1234',
- }
- ],
- })
- assert result.status == 0
+@pytest.fixture
+async def async_cache(async_client):
+ cache = await async_client.get_or_create_cache(cache_name)
+ yield cache
+ await cache.destroy()
- result = cache_get_names(conn)
- assert cache_name in result.value
- result = cache_get_or_create_with_config(conn, {
- PROP_NAME: cache_name,
- })
- assert result.status == 0
+@pytest.fixture
+def cache_with_config(client, cache_config):
+ cache = client.get_or_create_cache(cache_config)
+ yield cache
+ cache.destroy()
+
+
+@pytest.fixture
+async def async_cache_with_config(async_client, cache_config):
+ cache = await async_client.get_or_create_cache(cache_config)
+ yield cache
+ await cache.destroy()
+
+
+def test_cache_get_configuration(client, cache):
+ assert cache_name in client.get_cache_names()
+ assert cache.settings[PROP_NAME] == cache_name
+
+
+@pytest.mark.asyncio
+async def test_cache_get_configuration_async(async_client, async_cache):
+ assert cache_name in (await async_client.get_cache_names())
+ assert (await async_cache.settings())[PROP_NAME] == cache_name
+
+
+def test_get_or_create_with_config_existing(client, cache_with_config, cache_config):
+ assert cache_name in client.get_cache_names()
+
+ with pytest.raises(CacheError):
+ client.create_cache(cache_config)
+
+ cache = client.get_or_create_cache(cache_config)
+ assert cache.settings == cache_with_config.settings
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_with_config_existing_async(async_client, async_cache_with_config, cache_config):
+ assert cache_name in (await async_client.get_cache_names())
+
+ with pytest.raises(CacheError):
+ await async_client.create_cache(cache_config)
+
+ cache = await async_client.get_or_create_cache(cache_config)
+ assert (await cache.settings()) == (await async_cache_with_config.settings())
+
+
+def test_get_or_create_with_config_new(client, cache_config):
+ assert cache_name not in client.get_cache_names()
+ cache = client.get_or_create_cache(cache_config)
+ try:
+ assert cache_name in client.get_cache_names()
+ assert cache.settings[PROP_NAME] == cache_name
+ finally:
+ cache.destroy()
+
+
+@pytest.mark.asyncio
+async def test_get_or_create_with_config_new_async(async_client, cache_config):
+ assert cache_name not in (await async_client.get_cache_names())
+
+ cache = await async_client.get_or_create_cache(cache_config)
+ try:
+ assert cache_name in (await async_client.get_cache_names())
+ assert (await cache.settings())[PROP_NAME] == cache_name
+ finally:
+ await cache.destroy()
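The tests above pin down the semantic difference: `create_cache` raises `CacheError` when the cache already exists, while `get_or_create_cache` is idempotent. A small sketch of a safe-creation helper built on that contract (the helper itself is hypothetical):

    from pyignite.exceptions import CacheError

    def ensure_cache(client, cache_config):
        try:
            return client.create_cache(cache_config)
        except CacheError:
            # Already exists - fall back to fetching the existing cache.
            return client.get_or_create_cache(cache_config)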
diff --git a/tests/common/test_datatypes.py b/tests/common/test_datatypes.py
index 83e9a60..c1aa19f 100644
--- a/tests/common/test_datatypes.py
+++ b/tests/common/test_datatypes.py
@@ -20,199 +20,239 @@
import pytest
import uuid
-from pyignite.api.key_value import cache_get, cache_put
-from pyignite.datatypes import *
+from pyignite.datatypes import (
+ ByteObject, IntObject, FloatObject, CharObject, ShortObject, BoolObject, ByteArrayObject, IntArrayObject,
+ ShortArrayObject, FloatArrayObject, BoolArrayObject, CharArrayObject, TimestampObject, String, BinaryEnumObject,
+ TimestampArrayObject, BinaryEnumArrayObject, ObjectArrayObject, CollectionObject, MapObject
+)
from pyignite.utils import unsigned
+put_get_data_params = [
+ # integers
+ (42, None),
+ (42, ByteObject),
+ (42, ShortObject),
+ (42, IntObject),
+
+ # floats
+ (3.1415, None), # True for Double but not Float
+ (3.5, FloatObject),
+
+ # char is never autodetected
+    ('ы', CharObject),
+    ('カ', CharObject),
+
+ # bool
+ (True, None),
+ (False, None),
+ (True, BoolObject),
+ (False, BoolObject),
+
+ # arrays of integers
+ ([1, 2, 3, 5], None),
+ (b'buzz', ByteArrayObject),
+ (bytearray([7, 8, 8, 11]), None),
+ (bytearray([7, 8, 8, 11]), ByteArrayObject),
+ ([1, 2, 3, 5], ShortArrayObject),
+ ([1, 2, 3, 5], IntArrayObject),
+
+ # arrays of floats
+ ([2.2, 4.4, 6.6], None),
+ ([2.5, 6.5], FloatArrayObject),
+
+ # array of char
+    (['ы', 'カ'], CharArrayObject),
+
+ # array of bool
+ ([True, False, True], None),
+ ([True, False], BoolArrayObject),
+ ([False, True], BoolArrayObject),
+ ([True, False, True, False], BoolArrayObject),
+
+ # string
+ ('Little Mary had a lamb', None),
+ ('This is a test', String),
+
+ # decimals
+ (decimal.Decimal('2.5'), None),
+ (decimal.Decimal('-1.3'), None),
+
+ # uuid
+ (uuid.uuid4(), None),
+
+ # date
+ (datetime(year=1998, month=4, day=6, hour=18, minute=30), None),
+
+ # no autodetection for timestamp either
+ (
+ (datetime(year=1998, month=4, day=6, hour=18, minute=30), 1000),
+ TimestampObject
+ ),
+
+ # time
+ (timedelta(days=4, hours=4, minutes=24), None),
+
+ # enum is useless in Python, except for interoperability with Java.
+ # Also no autodetection
+ ((5, 6), BinaryEnumObject),
+
+ # arrays of standard types
+ (['String 1', 'String 2'], None),
+ (['Some of us are empty', None, 'But not the others'], None),
+
+ ([decimal.Decimal('2.71828'), decimal.Decimal('100')], None),
+ ([decimal.Decimal('2.1'), None, decimal.Decimal('3.1415')], None),
+
+ ([uuid.uuid4(), uuid.uuid4()], None),
+ (
+ [
+ datetime(year=2010, month=1, day=1),
+ datetime(year=2010, month=12, day=31),
+ ],
+ None,
+ ),
+ ([timedelta(minutes=30), timedelta(hours=2)], None),
+ (
+ [
+ (datetime(year=2010, month=1, day=1), 1000),
+ (datetime(year=2010, month=12, day=31), 200),
+ ],
+ TimestampArrayObject
+ ),
+ ((-1, [(6001, 1), (6002, 2), (6003, 3)]), BinaryEnumArrayObject),
+
+ # object array
+ ((ObjectArrayObject.OBJECT, [1, 2, decimal.Decimal('3')]), ObjectArrayObject),
+
+ # collection
+ ((CollectionObject.LINKED_LIST, [1, 2, 3]), None),
+
+ # map
+ ((MapObject.HASH_MAP, {'key': 4, 5: 6.0}), None),
+ ((MapObject.LINKED_HASH_MAP, OrderedDict([('key', 4), (5, 6.0)])), None),
+]
+
@pytest.mark.parametrize(
'value, value_hint',
- [
- # integers
- (42, None),
- (42, ByteObject),
- (42, ShortObject),
- (42, IntObject),
-
- # floats
- (3.1415, None), # True for Double but not Float
- (3.5, FloatObject),
-
- # char is never autodetected
-        ('ы', CharObject),
-        ('カ', CharObject),
-
- # bool
- (True, None),
- (False, None),
- (True, BoolObject),
- (False, BoolObject),
-
- # arrays of integers
- ([1, 2, 3, 5], None),
- (b'buzz', ByteArrayObject),
- (bytearray([7, 8, 8, 11]), None),
- (bytearray([7, 8, 8, 11]), ByteArrayObject),
- ([1, 2, 3, 5], ShortArrayObject),
- ([1, 2, 3, 5], IntArrayObject),
-
- # arrays of floats
- ([2.2, 4.4, 6.6], None),
- ([2.5, 6.5], FloatArrayObject),
-
- # array of char
-        (['ы', 'カ'], CharArrayObject),
-
- # array of bool
- ([True, False, True], None),
- ([True, False], BoolArrayObject),
- ([False, True], BoolArrayObject),
- ([True, False, True, False], BoolArrayObject),
-
- # string
- ('Little Mary had a lamb', None),
- ('This is a test', String),
-
- # decimals
- (decimal.Decimal('2.5'), None),
- (decimal.Decimal('-1.3'), None),
-
- # uuid
- (uuid.uuid4(), None),
-
- # date
- (datetime(year=1998, month=4, day=6, hour=18, minute=30), None),
-
- # no autodetection for timestamp either
- (
- (datetime(year=1998, month=4, day=6, hour=18, minute=30), 1000),
- TimestampObject
- ),
-
- # time
- (timedelta(days=4, hours=4, minutes=24), None),
-
- # enum is useless in Python, except for interoperability with Java.
- # Also no autodetection
- ((5, 6), BinaryEnumObject),
-
- # arrays of standard types
- (['String 1', 'String 2'], None),
- (['Some of us are empty', None, 'But not the others'], None),
-
- ([decimal.Decimal('2.71828'), decimal.Decimal('100')], None),
- ([decimal.Decimal('2.1'), None, decimal.Decimal('3.1415')], None),
-
- ([uuid.uuid4(), uuid.uuid4()], None),
- (
- [
- datetime(year=2010, month=1, day=1),
- datetime(year=2010, month=12, day=31),
- ],
- None,
- ),
- ([timedelta(minutes=30), timedelta(hours=2)], None),
- (
- [
- (datetime(year=2010, month=1, day=1), 1000),
- (datetime(year=2010, month=12, day=31), 200),
- ],
- TimestampArrayObject
- ),
- ((-1, [(6001, 1), (6002, 2), (6003, 3)]), BinaryEnumArrayObject),
-
- # object array
- ((ObjectArrayObject.OBJECT, [1, 2, decimal.Decimal('3')]), ObjectArrayObject),
-
- # collection
- ((CollectionObject.LINKED_LIST, [1, 2, 3]), None),
-
- # map
- ((MapObject.HASH_MAP, {'key': 4, 5: 6.0}), None),
- ((MapObject.LINKED_HASH_MAP, OrderedDict([('key', 4), (5, 6.0)])), None),
- ]
+ put_get_data_params
)
-def test_put_get_data(client, cache, value, value_hint):
+def test_put_get_data(cache, value, value_hint):
+ cache.put('my_key', value, value_hint=value_hint)
+ assert cache.get('my_key') == value
- conn = client.random_node
- result = cache_put(conn, cache, 'my_key', value, value_hint=value_hint)
- assert result.status == 0
+@pytest.mark.parametrize(
+ 'value, value_hint',
+ put_get_data_params
+)
+@pytest.mark.asyncio
+async def test_put_get_data_async(async_cache, value, value_hint):
+ await async_cache.put('my_key', value, value_hint=value_hint)
+ assert await async_cache.get('my_key') == value
- result = cache_get(conn, cache, 'my_key')
- assert result.status == 0
- assert result.value == value
- if isinstance(result.value, list):
- for res, val in zip(result.value, value):
- assert type(res) == type(val)
+bytearray_params = [
+ [1, 2, 3, 5],
+ (7, 8, 13, 18),
+ (-128, -1, 0, 1, 127, 255),
+]
@pytest.mark.parametrize(
'value',
- [
- [1, 2, 3, 5],
- (7, 8, 13, 18),
- (-128, -1, 0, 1, 127, 255),
- ]
+ bytearray_params
)
-def test_bytearray_from_list_or_tuple(client, cache, value):
+def test_bytearray_from_list_or_tuple(cache, value):
"""
ByteArrayObject's pythonic type is `bytearray`, but it should also accept
lists or tuples as content.
"""
- conn = client.random_node
+ cache.put('my_key', value, value_hint=ByteArrayObject)
- result = cache_put(
- conn, cache, 'my_key', value, value_hint=ByteArrayObject
- )
- assert result.status == 0
+ assert cache.get('my_key') == bytearray([unsigned(ch, ctypes.c_ubyte) for ch in value])
- result = cache_get(conn, cache, 'my_key')
- assert result.status == 0
- assert result.value == bytearray([
- unsigned(ch, ctypes.c_ubyte) for ch in value
- ])
+
+@pytest.mark.parametrize(
+ 'value',
+ bytearray_params
+)
+@pytest.mark.asyncio
+async def test_bytearray_from_list_or_tuple_async(async_cache, value):
+ """
+ ByteArrayObject's pythonic type is `bytearray`, but it should also accept
+    lists or tuples as content.
+ """
+
+ await async_cache.put('my_key', value, value_hint=ByteArrayObject)
+
+ result = await async_cache.get('my_key')
+ assert result == bytearray([unsigned(ch, ctypes.c_ubyte) for ch in value])
+
+
+uuid_params = [
+ 'd57babad-7bc1-4c82-9f9c-e72841b92a85',
+ '5946c0c0-2b76-479d-8694-a2e64a3968da',
+ 'a521723d-ad5d-46a6-94ad-300f850ef704',
+]
+
+uuid_table_create_sql = "CREATE TABLE test_uuid_repr (id INTEGER PRIMARY KEY, uuid_field UUID)"
+uuid_table_drop_sql = "DROP TABLE test_uuid_repr IF EXISTS"
+uuid_table_insert_sql = "INSERT INTO test_uuid_repr(id, uuid_field) VALUES (?, ?)"
+uuid_table_query_sql = "SELECT * FROM test_uuid_repr WHERE uuid_field=?"
+
+
+@pytest.fixture()
+def uuid_table(client):
+ client.sql(uuid_table_drop_sql)
+ client.sql(uuid_table_create_sql)
+ yield None
+ client.sql(uuid_table_drop_sql)
+
+
+@pytest.fixture()
+async def uuid_table_async(async_client):
+ await async_client.sql(uuid_table_drop_sql)
+ await async_client.sql(uuid_table_create_sql)
+ yield None
+ await async_client.sql(uuid_table_drop_sql)
@pytest.mark.parametrize(
'uuid_string',
- [
- 'd57babad-7bc1-4c82-9f9c-e72841b92a85',
- '5946c0c0-2b76-479d-8694-a2e64a3968da',
- 'a521723d-ad5d-46a6-94ad-300f850ef704',
- ]
+ uuid_params
)
-def test_uuid_representation(client, uuid_string):
+def test_uuid_representation(client, uuid_string, uuid_table):
""" Test if textual UUID representation is correct. """
uuid_value = uuid.UUID(uuid_string)
- # initial cleanup
- client.sql("DROP TABLE test_uuid_repr IF EXISTS")
- # create table with UUID field
- client.sql(
- "CREATE TABLE test_uuid_repr (id INTEGER PRIMARY KEY, uuid_field UUID)"
- )
# use uuid.UUID class to insert data
- client.sql(
- "INSERT INTO test_uuid_repr(id, uuid_field) VALUES (?, ?)",
- query_args=[1, uuid_value]
- )
+ client.sql(uuid_table_insert_sql, query_args=[1, uuid_value])
# use hex string to retrieve data
- result = client.sql(
- "SELECT * FROM test_uuid_repr WHERE uuid_field='{}'".format(
- uuid_string
- )
- )
+ with client.sql(uuid_table_query_sql, query_args=[str(uuid_value)]) as cursor:
+ result = list(cursor)
- # finalize query
- result = list(result)
+ # if a line was retrieved, our test was successful
+ assert len(result) == 1
+ assert result[0][1] == uuid_value
- # final cleanup
- client.sql("DROP TABLE test_uuid_repr IF EXISTS")
- # if a line was retrieved, our test was successful
- assert len(result) == 1
- # doublecheck
- assert result[0][1] == uuid_value
+@pytest.mark.parametrize(
+ 'uuid_string',
+ uuid_params
+)
+@pytest.mark.asyncio
+async def test_uuid_representation_async(async_client, uuid_string, uuid_table_async):
+ """ Test if textual UUID representation is correct. """
+ uuid_value = uuid.UUID(uuid_string)
+
+ # use uuid.UUID class to insert data
+ await async_client.sql(uuid_table_insert_sql, query_args=[1, uuid_value])
+ # use hex string to retrieve data
+ async with async_client.sql(uuid_table_query_sql, query_args=[str(uuid_value)]) as cursor:
+ result = [row async for row in cursor]
+
+ # if a line was retrieved, our test was successful
+ assert len(result) == 1
+ assert result[0][1] == uuid_value
diff --git a/tests/common/test_generic_object.py b/tests/common/test_generic_object.py
index 73dc870..d6c0ee1 100644
--- a/tests/common/test_generic_object.py
+++ b/tests/common/test_generic_object.py
@@ -14,11 +14,10 @@
# limitations under the License.
from pyignite import GenericObjectMeta
-from pyignite.datatypes import *
+from pyignite.datatypes import IntObject, String
def test_go():
-
class GenericObject(
metaclass=GenericObjectMeta,
schema={
diff --git a/tests/common/test_get_names.py b/tests/common/test_get_names.py
index 2d6c0bc..7fcb499 100644
--- a/tests/common/test_get_names.py
+++ b/tests/common/test_get_names.py
@@ -12,21 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
-from pyignite.api import cache_create, cache_get_names
+import pytest
def test_get_names(client):
-
- conn = client.random_node
-
- bucket_names = ['my_bucket', 'my_bucket_2', 'my_bucket_3']
+ bucket_names = {'my_bucket', 'my_bucket_2', 'my_bucket_3'}
for name in bucket_names:
- cache_create(conn, name)
+ client.get_or_create_cache(name)
- result = cache_get_names(conn)
- assert result.status == 0
- assert type(result.value) == list
- assert len(result.value) >= len(bucket_names)
- for i, name in enumerate(bucket_names):
- assert name in result.value
+ assert set(client.get_cache_names()) == bucket_names
+
+
+@pytest.mark.asyncio
+async def test_get_names_async(async_client):
+ bucket_names = {'my_bucket', 'my_bucket_2', 'my_bucket_3'}
+ await asyncio.gather(*[async_client.get_or_create_cache(name) for name in bucket_names])
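+ # caches are created concurrently; completion order is unspecified, hence the set comparison below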
+
+ assert set(await async_client.get_cache_names()) == bucket_names
diff --git a/tests/common/test_key_value.py b/tests/common/test_key_value.py
index a7edce1..0f492a2 100644
--- a/tests/common/test_key_value.py
+++ b/tests/common/test_key_value.py
@@ -15,426 +15,405 @@
from datetime import datetime
-from pyignite.api import *
-from pyignite.datatypes import (
- CollectionObject, IntObject, MapObject, TimestampObject,
-)
+import pytest
+
+from pyignite.datatypes import CollectionObject, IntObject, MapObject, TimestampObject
-def test_put_get(client, cache):
+def test_put_get(cache):
+ cache.put('my_key', 5)
- conn = client.random_node
-
- result = cache_put(conn, cache, 'my_key', 5)
- assert result.status == 0
-
- result = cache_get(conn, cache, 'my_key')
- assert result.status == 0
- assert result.value == 5
+ assert cache.get('my_key') == 5
-def test_get_all(client, cache):
+@pytest.mark.asyncio
+async def test_put_get_async(async_cache):
+ await async_cache.put('my_key', 5)
- conn = client.random_node
-
- result = cache_get_all(conn, cache, ['key_1', 2, (3, IntObject)])
- assert result.status == 0
- assert result.value == {}
-
- cache_put(conn, cache, 'key_1', 4)
- cache_put(conn, cache, 3, 18, key_hint=IntObject)
-
- result = cache_get_all(conn, cache, ['key_1', 2, (3, IntObject)])
- assert result.status == 0
- assert result.value == {'key_1': 4, 3: 18}
+ assert await async_cache.get('my_key') == 5
-def test_put_all(client, cache):
+def test_get_all(cache):
+ assert cache.get_all(['key_1', 2, (3, IntObject)]) == {}
- conn = client.random_node
+ cache.put('key_1', 4)
+ cache.put(3, 18, key_hint=IntObject)
+ assert cache.get_all(['key_1', 2, (3, IntObject)]) == {'key_1': 4, 3: 18}
+
+
+@pytest.mark.asyncio
+async def test_get_all_async(async_cache):
+ assert await async_cache.get_all(['key_1', 2, (3, IntObject)]) == {}
+
+ await async_cache.put('key_1', 4)
+ await async_cache.put(3, 18, key_hint=IntObject)
+
+ assert await async_cache.get_all(['key_1', 2, (3, IntObject)]) == {'key_1': 4, 3: 18}
+
+
+def test_put_all(cache):
test_dict = {
1: 2,
'key_1': 4,
(3, IntObject): 18,
}
- test_keys = ['key_1', 1, 3]
+ cache.put_all(test_dict)
- result = cache_put_all(conn, cache, test_dict)
- assert result.status == 0
+ result = cache.get_all(list(test_dict.keys()))
- result = cache_get_all(conn, cache, test_keys)
- assert result.status == 0
- assert len(test_dict) == 3
+ assert len(result) == len(test_dict)
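+ # keys supplied as (value, type_hint) tuples come back under the bare value, so unwrap before comparing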
+ for k, v in test_dict.items():
+ k = k[0] if isinstance(k, tuple) else k
+ assert result[k] == v
- for key in result.value:
- assert key in test_keys
+@pytest.mark.asyncio
+async def test_put_all_async(async_cache):
+ test_dict = {
+ 1: 2,
+ 'key_1': 4,
+ (3, IntObject): 18,
+ }
+ await async_cache.put_all(test_dict)
-def test_contains_key(client, cache):
+ result = await async_cache.get_all(list(test_dict.keys()))
- conn = client.random_node
+ assert len(result) == len(test_dict)
+ for k, v in test_dict.items():
+ k = k[0] if isinstance(k, tuple) else k
+ assert result[k] == v
- cache_put(conn, cache, 'test_key', 42)
- result = cache_contains_key(conn, cache, 'test_key')
- assert result.value is True
+def test_contains_key(cache):
+ cache.put('test_key', 42)
- result = cache_contains_key(conn, cache, 'non-existant-key')
- assert result.value is False
+ assert cache.contains_key('test_key')
+ assert not cache.contains_key('non-existent-key')
-def test_contains_keys(client, cache):
+@pytest.mark.asyncio
+async def test_contains_key_async(async_cache):
+ await async_cache.put('test_key', 42)
- conn = client.random_node
+ assert await async_cache.contains_key('test_key')
+ assert not await async_cache.contains_key('non-existent-key')
- cache_put(conn, cache, 5, 6)
- cache_put(conn, cache, 'test_key', 42)
- result = cache_contains_keys(conn, cache, [5, 'test_key'])
- assert result.value is True
+def test_contains_keys(cache):
+ cache.put(5, 6)
+ cache.put('test_key', 42)
- result = cache_contains_keys(conn, cache, [5, 'non-existent-key'])
- assert result.value is False
+ assert cache.contains_keys([5, 'test_key'])
+ assert not cache.contains_keys([5, 'non-existent-key'])
-def test_get_and_put(client, cache):
+@pytest.mark.asyncio
+async def test_contains_keys_async(async_cache):
+ await async_cache.put(5, 6)
+ await async_cache.put('test_key', 42)
- conn = client.random_node
+ assert await async_cache.contains_keys([5, 'test_key'])
+ assert not await async_cache.contains_keys([5, 'non-existent-key'])
- result = cache_get_and_put(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is None
- result = cache_get(conn, cache, 'test_key')
- assert result.status == 0
- assert result.value is 42
+def test_get_and_put(cache):
+ assert cache.get_and_put('test_key', 42) is None
+ assert cache.get('test_key') == 42
+ assert cache.get_and_put('test_key', 1234) == 42
+ assert cache.get('test_key') == 1234
- result = cache_get_and_put(conn, cache, 'test_key', 1234)
- assert result.status == 0
- assert result.value == 42
+@pytest.mark.asyncio
+async def test_get_and_put_async(async_cache):
+ assert await async_cache.get_and_put('test_key', 42) is None
+ assert await async_cache.get('test_key') == 42
+ assert await async_cache.get_and_put('test_key', 1234) == 42
+ assert await async_cache.get('test_key') == 1234
-def test_get_and_replace(client, cache):
- conn = client.random_node
+def test_get_and_replace(cache):
+ assert cache.get_and_replace('test_key', 42) is None
+ assert cache.get('test_key') is None
+ cache.put('test_key', 42)
+ assert cache.get_and_replace('test_key', 1234) == 42
- result = cache_get_and_replace(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is None
- result = cache_get(conn, cache, 'test_key')
- assert result.status == 0
- assert result.value is None
+@pytest.mark.asyncio
+async def test_get_and_replace_async(async_cache):
+ assert await async_cache.get_and_replace('test_key', 42) is None
+ assert await async_cache.get('test_key') is None
+ await async_cache.put('test_key', 42)
+ assert await async_cache.get_and_replace('test_key', 1234) == 42
- cache_put(conn, cache, 'test_key', 42)
- result = cache_get_and_replace(conn, cache, 'test_key', 1234)
- assert result.status == 0
- assert result.value == 42
+def test_get_and_remove(cache):
+ assert cache.get_and_remove('test_key') is None
+ cache.put('test_key', 42)
+ assert cache.get_and_remove('test_key') == 42
+ assert cache.get_and_remove('test_key') is None
-def test_get_and_remove(client, cache):
+@pytest.mark.asyncio
+async def test_get_and_remove_async(async_cache):
+ assert await async_cache.get_and_remove('test_key') is None
+ await async_cache.put('test_key', 42)
+ assert await async_cache.get_and_remove('test_key') == 42
+ assert await async_cache.get_and_remove('test_key') is None
- conn = client.random_node
- result = cache_get_and_remove(conn, cache, 'test_key')
- assert result.status == 0
- assert result.value is None
+def test_put_if_absent(cache):
+ assert cache.put_if_absent('test_key', 42)
+ assert not cache.put_if_absent('test_key', 1234)
- cache_put(conn, cache, 'test_key', 42)
- result = cache_get_and_remove(conn, cache, 'test_key')
- assert result.status == 0
- assert result.value == 42
+@pytest.mark.asyncio
+async def test_put_if_absent_async(async_cache):
+ assert await async_cache.put_if_absent('test_key', 42)
+ assert not await async_cache.put_if_absent('test_key', 1234)
-def test_put_if_absent(client, cache):
+def test_get_and_put_if_absent(cache):
+ assert cache.get_and_put_if_absent('test_key', 42) is None
+ assert cache.get_and_put_if_absent('test_key', 1234) == 42
+ assert cache.get_and_put_if_absent('test_key', 5678) == 42
+ assert cache.get('test_key') == 42
- conn = client.random_node
- result = cache_put_if_absent(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is True
+@pytest.mark.asyncio
+async def test_get_and_put_if_absent_async(async_cache):
+ assert await async_cache.get_and_put_if_absent('test_key', 42) is None
+ assert await async_cache.get_and_put_if_absent('test_key', 1234) == 42
+ assert await async_cache.get_and_put_if_absent('test_key', 5678) == 42
+ assert await async_cache.get('test_key') == 42
- result = cache_put_if_absent(conn, cache, 'test_key', 1234)
- assert result.status == 0
- assert result.value is False
+def test_replace(cache):
+ assert cache.replace('test_key', 42) is False
+ cache.put('test_key', 1234)
+ assert cache.replace('test_key', 42) is True
+ assert cache.get('test_key') == 42
-def test_get_and_put_if_absent(client, cache):
- conn = client.random_node
+@pytest.mark.asyncio
+async def test_replace_async(async_cache):
+ assert await async_cache.replace('test_key', 42) is False
+ await async_cache.put('test_key', 1234)
+ assert await async_cache.replace('test_key', 42) is True
+ assert await async_cache.get('test_key') == 42
- result = cache_get_and_put_if_absent(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is None
- result = cache_get_and_put_if_absent(conn, cache, 'test_key', 1234)
- assert result.status == 0
- assert result.value == 42
+def test_replace_if_equals(cache):
+ assert cache.replace_if_equals('my_test', 42, 1234) is False
+ cache.put('my_test', 42)
+ assert cache.replace_if_equals('my_test', 42, 1234) is True
+ assert cache.get('my_test') == 1234
- result = cache_get_and_put_if_absent(conn, cache, 'test_key', 5678)
- assert result.status == 0
- assert result.value == 42
+@pytest.mark.asyncio
+async def test_replace_if_equals_async(async_cache):
+ assert await async_cache.replace_if_equals('my_test', 42, 1234) is False
+ await async_cache.put('my_test', 42)
+ assert await async_cache.replace_if_equals('my_test', 42, 1234) is True
+ assert await async_cache.get('my_test') == 1234
-def test_replace(client, cache):
- conn = client.random_node
+def test_clear(cache):
+ cache.put('my_test', 42)
+ cache.clear()
+ assert cache.get('my_test') is None
- result = cache_replace(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is False
- cache_put(conn, cache, 'test_key', 1234)
+@pytest.mark.asyncio
+async def test_clear_async(async_cache):
+ await async_cache.put('my_test', 42)
+ await async_cache.clear()
+ assert await async_cache.get('my_test') is None
- result = cache_replace(conn, cache, 'test_key', 42)
- assert result.status == 0
- assert result.value is True
- result = cache_get(conn, cache, 'test_key')
- assert result.status == 0
- assert result.value == 42
+def test_clear_key(cache):
+ cache.put('my_test', 42)
+ cache.put('another_test', 24)
+ cache.clear_key('my_test')
-def test_replace_if_equals(client, cache):
-
- conn = client.random_node
-
- result = cache_replace_if_equals(conn, cache, 'my_test', 42, 1234)
- assert result.status == 0
- assert result.value is False
-
- cache_put(conn, cache, 'my_test', 42)
-
- result = cache_replace_if_equals(conn, cache, 'my_test', 42, 1234)
- assert result.status == 0
- assert result.value is True
-
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value == 1234
-
-
-def test_clear(client, cache):
-
- conn = client.random_node
-
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
-
- result = cache_clear(conn, cache)
- assert result.status == 0
-
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value is None
-
-
-def test_clear_key(client, cache):
-
- conn = client.random_node
-
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
-
- result = cache_put(conn, cache, 'another_test', 24)
- assert result.status == 0
-
- result = cache_clear_key(conn, cache, 'my_test')
- assert result.status == 0
-
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value is None
-
- result = cache_get(conn, cache, 'another_test')
- assert result.status == 0
- assert result.value == 24
-
-
-def test_clear_keys(client, cache):
-
- conn = client.random_node
-
- result = cache_put(conn, cache, 'my_test_key', 42)
- assert result.status == 0
-
- result = cache_put(conn, cache, 'another_test', 24)
- assert result.status == 0
-
- result = cache_clear_keys(conn, cache, [
- 'my_test_key',
- 'nonexistent_key',
- ])
- assert result.status == 0
+ assert cache.get('my_test') is None
+ assert cache.get('another_test') == 24
- result = cache_get(conn, cache, 'my_test_key')
- assert result.status == 0
- assert result.value is None
- result = cache_get(conn, cache, 'another_test')
- assert result.status == 0
- assert result.value == 24
+@pytest.mark.asyncio
+async def test_clear_key_async(async_cache):
+ await async_cache.put('my_test', 42)
+ await async_cache.put('another_test', 24)
+ await async_cache.clear_key('my_test')
-def test_remove_key(client, cache):
+ assert await async_cache.get('my_test') is None
+ assert await async_cache.get('another_test') == 24
- conn = client.random_node
- result = cache_put(conn, cache, 'my_test_key', 42)
- assert result.status == 0
+def test_clear_keys(cache):
+ cache.put('my_test_key', 42)
+ cache.put('another_test', 24)
- result = cache_remove_key(conn, cache, 'my_test_key')
- assert result.status == 0
- assert result.value is True
+ cache.clear_keys(['my_test_key', 'nonexistent_key'])
- result = cache_remove_key(conn, cache, 'non_existent_key')
- assert result.status == 0
- assert result.value is False
+ assert cache.get('my_test_key') is None
+ assert cache.get('another_test') == 24
-def test_remove_if_equals(client, cache):
+@pytest.mark.asyncio
+async def test_clear_keys_async(async_cache):
+ await async_cache.put('my_test_key', 42)
+ await async_cache.put('another_test', 24)
- conn = client.random_node
+ await async_cache.clear_keys(['my_test_key', 'nonexistent_key'])
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
+ assert await async_cache.get('my_test_key') is None
+ assert await async_cache.get('another_test') == 24
- result = cache_remove_if_equals(conn, cache, 'my_test', 1234)
- assert result.status == 0
- assert result.value is False
- result = cache_remove_if_equals(conn, cache, 'my_test', 42)
- assert result.status == 0
- assert result.value is True
+def test_remove_key(cache):
+ cache.put('my_test_key', 42)
+ assert cache.remove_key('my_test_key') is True
+ assert cache.remove_key('non_existent_key') is False
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value is None
+@pytest.mark.asyncio
+async def test_remove_key_async(async_cache):
+ await async_cache.put('my_test_key', 42)
+ assert await async_cache.remove_key('my_test_key') is True
+ assert await async_cache.remove_key('non_existent_key') is False
-def test_remove_keys(client, cache):
- conn = client.random_node
+def test_remove_if_equals(cache):
+ cache.put('my_test', 42)
+ assert cache.remove_if_equals('my_test', 1234) is False
+ assert cache.remove_if_equals('my_test', 42) is True
+ assert cache.get('my_test') is None
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
- result = cache_put(conn, cache, 'another_test', 24)
- assert result.status == 0
+@pytest.mark.asyncio
+async def test_remove_if_equals_async(async_cache):
+ await async_cache.put('my_test', 42)
+ assert await async_cache.remove_if_equals('my_test', 1234) is False
+ assert await async_cache.remove_if_equals('my_test', 42) is True
+ assert await async_cache.get('my_test') is None
- result = cache_remove_keys(conn, cache, ['my_test', 'non_existent'])
- assert result.status == 0
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value is None
+def test_remove_keys(cache):
+ cache.put('my_test', 42)
- result = cache_get(conn, cache, 'another_test')
- assert result.status == 0
- assert result.value == 24
+ cache.put('another_test', 24)
+ cache.remove_keys(['my_test', 'non_existent'])
+ assert cache.get('my_test') is None
+ assert cache.get('another_test') == 24
-def test_remove_all(client, cache):
- conn = client.random_node
+@pytest.mark.asyncio
+async def test_remove_keys_async(async_cache):
+ await async_cache.put('my_test', 42)
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
+ await async_cache.put('another_test', 24)
+ await async_cache.remove_keys(['my_test', 'non_existent'])
- result = cache_put(conn, cache, 'another_test', 24)
- assert result.status == 0
+ assert await async_cache.get('my_test') is None
+ assert await async_cache.get('another_test') == 24
- result = cache_remove_all(conn, cache)
- assert result.status == 0
- result = cache_get(conn, cache, 'my_test')
- assert result.status == 0
- assert result.value is None
+def test_remove_all(cache):
+ cache.put('my_test', 42)
+ cache.put('another_test', 24)
+ cache.remove_all()
- result = cache_get(conn, cache, 'another_test')
- assert result.status == 0
- assert result.value is None
+ assert cache.get('my_test') is None
+ assert cache.get('another_test') is None
-def test_cache_get_size(client, cache):
+@pytest.mark.asyncio
+async def test_remove_all_async(async_cache):
+ await async_cache.put('my_test', 42)
+ await async_cache.put('another_test', 24)
+ await async_cache.remove_all()
- conn = client.random_node
+ assert await async_cache.get('my_test') is None
+ assert await async_cache.get('another_test') is None
- result = cache_put(conn, cache, 'my_test', 42)
- assert result.status == 0
- result = cache_get_size(conn, cache)
- assert result.status == 0
- assert result.value == 1
+def test_cache_get_size(cache):
+ cache.put('my_test', 42)
+ assert cache.get_size() == 1
-def test_put_get_collection(client):
+@pytest.mark.asyncio
+async def test_cache_get_size_async(async_cache):
+ await async_cache.put('my_test', 42)
+ assert await async_cache.get_size() == 1
- test_datetime = datetime(year=1996, month=3, day=1)
- cache = client.get_or_create_cache('test_coll_cache')
- cache.put(
+collection_params = [
+ [
'simple',
- (
- 1,
- [
- (123, IntObject),
- 678,
- None,
- 55.2,
- ((test_datetime, 0), TimestampObject),
- ]
- ),
- value_hint=CollectionObject
- )
- value = cache.get('simple')
- assert value == (1, [123, 678, None, 55.2, (test_datetime, 0)])
-
- cache.put(
+ (1, [(123, IntObject), 678, None, 55.2, ((datetime(year=1996, month=3, day=1), 0), TimestampObject)]),
+ (1, [123, 678, None, 55.2, (datetime(year=1996, month=3, day=1), 0)])
+ ],
+ [
'nested',
- (
- 1,
- [
- 123,
- ((1, [456, 'inner_test_string', 789]), CollectionObject),
- 'outer_test_string',
- ]
- ),
- value_hint=CollectionObject
- )
- value = cache.get('nested')
- assert value == (
- 1,
- [
- 123,
- (1, [456, 'inner_test_string', 789]),
- 'outer_test_string'
- ]
- )
-
-
-def test_put_get_map(client):
-
- cache = client.get_or_create_cache('test_map_cache')
-
- cache.put(
- 'test_map',
+ (1, [123, ((1, [456, 'inner_test_string', 789]), CollectionObject), 'outer_test_string']),
+ (1, [123, (1, [456, 'inner_test_string', 789]), 'outer_test_string'])
+ ],
+ [
+ 'hash_map',
(
MapObject.HASH_MAP,
{
(123, IntObject): 'test_data',
456: ((1, [456, 'inner_test_string', 789]), CollectionObject),
'test_key': 32.4,
+ 'simple_strings': ['string_1', 'string_2']
}
),
- value_hint=MapObject
- )
- value = cache.get('test_map')
- assert value == (MapObject.HASH_MAP, {
- 123: 'test_data',
- 456: (1, [456, 'inner_test_string', 789]),
- 'test_key': 32.4,
- })
+ (
+ MapObject.HASH_MAP,
+ {
+ 123: 'test_data',
+ 456: (1, [456, 'inner_test_string', 789]),
+ 'test_key': 32.4,
+ 'simple_strings': ['string_1', 'string_2']
+ }
+ )
+ ],
+ [
+ 'linked_hash_map',
+ (
+ MapObject.LINKED_HASH_MAP,
+ {
+ 'test_data': 12345,
+ 456: ['string_1', 'string_2'],
+ 'test_key': 32.4
+ }
+ ),
+ (
+ MapObject.LINKED_HASH_MAP,
+ {
+ 'test_data': 12345,
+ 456: ['string_1', 'string_2'],
+ 'test_key': 32.4
+ }
+ )
+ ],
+]
+
+
+@pytest.mark.parametrize(['key', 'hinted_value', 'value'], collection_params)
+def test_put_get_collection(cache, key, hinted_value, value):
+ cache.put(key, hinted_value)
+ assert cache.get(key) == value
+
+
+@pytest.mark.parametrize(['key', 'hinted_value', 'value'], collection_params)
+@pytest.mark.asyncio
+async def test_put_get_collection_async(async_cache, key, hinted_value, value):
+ await async_cache.put(key, hinted_value)
+ assert await async_cache.get(key) == value
diff --git a/tests/common/test_scan.py b/tests/common/test_scan.py
index 2f0e056..d55fd3e 100644
--- a/tests/common/test_scan.py
+++ b/tests/common/test_scan.py
@@ -12,57 +12,153 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import OrderedDict
-from pyignite.api import (
- scan, scan_cursor_get_page, resource_close, cache_put_all,
-)
+import pytest
+
+from pyignite import GenericObjectMeta
+from pyignite.api import resource_close, resource_close_async
+from pyignite.connection import AioConnection
+from pyignite.datatypes import IntObject, String
+from pyignite.exceptions import CacheError
-def test_scan(client, cache):
-
- conn = client.random_node
- page_size = 10
-
- result = cache_put_all(conn, cache, {
- 'key_{}'.format(v): v for v in range(page_size * 2)
- })
- assert result.status == 0
-
- result = scan(conn, cache, page_size)
- assert result.status == 0
- assert len(result.value['data']) == page_size
- assert result.value['more'] is True
-
- cursor = result.value['cursor']
-
- result = scan_cursor_get_page(conn, cursor)
- assert result.status == 0
- assert len(result.value['data']) == page_size
- assert result.value['more'] is False
-
- result = scan_cursor_get_page(conn, cursor)
- assert result.status != 0
+class SimpleObject(
+ metaclass=GenericObjectMeta,
+ type_name='SimpleObject',
+ schema=OrderedDict([
+ ('id', IntObject),
+ ('str', String),
+ ])
+):
+ pass
-def test_close_resource(client, cache):
+page_size = 10
- conn = client.random_node
- page_size = 10
- result = cache_put_all(conn, cache, {
- 'key_{}'.format(v): v for v in range(page_size * 2)
- })
- assert result.status == 0
+@pytest.fixture
+def test_objects_data():
+ yield {i: SimpleObject(id=i, str=f'str_{i}') for i in range(page_size * 2)}
- result = scan(conn, cache, page_size)
- assert result.status == 0
- assert len(result.value['data']) == page_size
- assert result.value['more'] is True
- cursor = result.value['cursor']
+def test_scan_objects(cache, test_objects_data):
+ cache.put_all(test_objects_data)
- result = resource_close(conn, cursor)
- assert result.status == 0
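+ # iterate with several page sizes: the cursor must be released after normal exhaustion,
+ # after an exception inside the with-block, and after plain iteration without a with-block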
+ for p_sz in [page_size, page_size * 2, page_size * 3, page_size + 5]:
+ with cache.scan(p_sz) as cursor:
+ result = {k: v for k, v in cursor}
+ assert result == test_objects_data
- result = scan_cursor_get_page(conn, cursor)
- assert result.status != 0
+ __check_cursor_closed(cursor)
+
+ with pytest.raises(Exception):
+ with cache.scan(p_sz) as cursor:
+ for _ in cursor:
+ raise Exception
+
+ __check_cursor_closed(cursor)
+
+ cursor = cache.scan(page_size)
+ assert {k: v for k, v in cursor} == test_objects_data
+ __check_cursor_closed(cursor)
+
+
+@pytest.mark.asyncio
+async def test_scan_objects_async(async_cache, test_objects_data):
+ await async_cache.put_all(test_objects_data)
+
+ for p_sz in [page_size, page_size * 2, page_size * 3, page_size + 5]:
+ async with async_cache.scan(p_sz) as cursor:
+ result = {k: v async for k, v in cursor}
+ assert result == test_objects_data
+
+ await __check_cursor_closed(cursor)
+
+ with pytest.raises(Exception):
+ async with async_cache.scan(p_sz) as cursor:
+ async for _ in cursor:
+ raise Exception
+
+ await __check_cursor_closed(cursor)
+
+ cursor = await async_cache.scan(page_size)
+ assert {k: v async for k, v in cursor} == test_objects_data
+
+ await __check_cursor_closed(cursor)
+
+
+@pytest.fixture
+def cache_scan_data():
+ yield {
+ 1: 'This is a test',
+ 2: 'One more test',
+ 3: 'Foo',
+ 4: 'Buzz',
+ 5: 'Bar',
+ 6: 'Lorem ipsum',
+ 7: 'dolor sit amet',
+ 8: 'consectetur adipiscing elit',
+ 9: 'Nullam aliquet',
+ 10: 'nisl at ante',
+ 11: 'suscipit',
+ 12: 'ut cursus',
+ 13: 'metus interdum',
+ 14: 'Nulla tincidunt',
+ 15: 'sollicitudin iaculis',
+ }
+
+
+@pytest.mark.parametrize('page_size', range(1, 17, 5))
+def test_cache_scan(cache, cache_scan_data, page_size):
+ cache.put_all(cache_scan_data)
+
+ with cache.scan(page_size=page_size) as cursor:
+ assert {k: v for k, v in cursor} == cache_scan_data
+
+
+@pytest.mark.parametrize('page_size', range(1, 17, 5))
+@pytest.mark.asyncio
+async def test_cache_scan_async(async_cache, cache_scan_data, page_size):
+ await async_cache.put_all(cache_scan_data)
+
+ async with async_cache.scan(page_size=page_size) as cursor:
+ assert {k: v async for k, v in cursor} == cache_scan_data
+
+
+def test_uninitialized_cursor(cache, test_objects_data):
+ cache.put_all(test_objects_data)
+
+ cursor = cache.scan(page_size)
+ for _ in cursor:
+ break
+
+ cursor.close()
+ __check_cursor_closed(cursor)
+
+
+@pytest.mark.asyncio
+async def test_uninitialized_cursor_async(async_cache, test_objects_data):
+ await async_cache.put_all(test_objects_data)
+
+ # iterating over a cursor that has not been awaited must raise CacheError.
+ with pytest.raises(CacheError):
+ cursor = async_cache.scan(page_size)
+ assert {k: v async for k, v in cursor} == test_objects_data
+
+ cursor = await async_cache.scan(page_size)
+ assert {k: v async for k, v in cursor} == test_objects_data
+ await __check_cursor_closed(cursor)
+
+
+def __check_cursor_closed(cursor):
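+ # a second resource_close on an already released server-side cursor must fail,
+ # which is taken as proof that the cursor was properly closed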
+ async def check_async():
+ result = await resource_close_async(cursor.connection, cursor.cursor_id)
+ assert result.status != 0
+
+ def check():
+ result = resource_close(cursor.connection, cursor.cursor_id)
+ assert result.status != 0
+
+ return check_async() if isinstance(cursor.connection, AioConnection) else check()
diff --git a/tests/common/test_sql.py b/tests/common/test_sql.py
index cc68a02..0841b7f 100644
--- a/tests/common/test_sql.py
+++ b/tests/common/test_sql.py
@@ -15,160 +15,173 @@
import pytest
-from pyignite.api import (
- sql_fields, sql_fields_cursor_get_page,
- sql, sql_cursor_get_page,
- cache_get_configuration,
-)
+from pyignite import AioClient
+from pyignite.aio_cache import AioCache
from pyignite.datatypes.cache_config import CacheMode
-from pyignite.datatypes.prop_codes import *
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_SQL_SCHEMA, PROP_QUERY_ENTITIES, PROP_CACHE_MODE
from pyignite.exceptions import SQLError
from pyignite.utils import entity_id
-from pyignite.binary import unwrap_binary
-initial_data = [
- ('John', 'Doe', 5),
- ('Jane', 'Roe', 4),
- ('Joe', 'Bloggs', 4),
- ('Richard', 'Public', 3),
- ('Negidius', 'Numerius', 3),
- ]
+student_table_data = [
+ ('John', 'Doe', 5),
+ ('Jane', 'Roe', 4),
+ ('Joe', 'Bloggs', 4),
+ ('Richard', 'Public', 3),
+ ('Negidius', 'Numerius', 3),
+]
-create_query = '''CREATE TABLE Student (
- id INT(11) PRIMARY KEY,
- first_name CHAR(24),
- last_name CHAR(32),
- grade INT(11))'''
-
-insert_query = '''INSERT INTO Student(id, first_name, last_name, grade)
-VALUES (?, ?, ?, ?)'''
-
-select_query = 'SELECT id, first_name, last_name, grade FROM Student'
-
-drop_query = 'DROP TABLE Student IF EXISTS'
-
-page_size = 4
+student_table_select_query = 'SELECT id, first_name, last_name, grade FROM Student ORDER BY ID ASC'
-def test_sql(client):
-
- conn = client.random_node
-
- # cleanup
- client.sql(drop_query)
-
- result = sql_fields(
- conn,
- 0,
- create_query,
- page_size,
- schema='PUBLIC',
- include_field_names=True
- )
- assert result.status == 0, result.message
-
- for i, data_line in enumerate(initial_data, start=1):
- fname, lname, grade = data_line
- result = sql_fields(
- conn,
- 0,
- insert_query,
- page_size,
- schema='PUBLIC',
- query_args=[i, fname, lname, grade],
- include_field_names=True
- )
- assert result.status == 0, result.message
-
- result = cache_get_configuration(conn, 'SQL_PUBLIC_STUDENT')
- assert result.status == 0, result.message
-
- binary_type_name = result.value[PROP_QUERY_ENTITIES][0]['value_type_name']
- result = sql(
- conn,
- 'SQL_PUBLIC_STUDENT',
- binary_type_name,
- 'TRUE',
- page_size
- )
- assert result.status == 0, result.message
- assert len(result.value['data']) == page_size
- assert result.value['more'] is True
-
- for wrapped_object in result.value['data'].values():
- data = unwrap_binary(client, wrapped_object)
- assert data.type_id == entity_id(binary_type_name)
-
- cursor = result.value['cursor']
-
- while result.value['more']:
- result = sql_cursor_get_page(conn, cursor)
- assert result.status == 0, result.message
-
- for wrapped_object in result.value['data'].values():
- data = unwrap_binary(client, wrapped_object)
- assert data.type_id == entity_id(binary_type_name)
-
- # repeat cleanup
- result = sql_fields(conn, 0, drop_query, page_size, schema='PUBLIC')
- assert result.status == 0
+@pytest.fixture
+def student_table_fixture(client):
+ yield from __create_student_table_fixture(client)
-def test_sql_fields(client):
-
- conn = client.random_node
-
- # cleanup
- client.sql(drop_query)
-
- result = sql_fields(
- conn,
- 0,
- create_query,
- page_size,
- schema='PUBLIC',
- include_field_names=True
- )
- assert result.status == 0, result.message
-
- for i, data_line in enumerate(initial_data, start=1):
- fname, lname, grade = data_line
- result = sql_fields(
- conn,
- 0,
- insert_query,
- page_size,
- schema='PUBLIC',
- query_args=[i, fname, lname, grade],
- include_field_names=True
- )
- assert result.status == 0, result.message
-
- result = sql_fields(
- conn,
- 0,
- select_query,
- page_size,
- schema='PUBLIC',
- include_field_names=True
- )
- assert result.status == 0
- assert len(result.value['data']) == page_size
- assert result.value['more'] is True
-
- cursor = result.value['cursor']
-
- result = sql_fields_cursor_get_page(conn, cursor, field_count=4)
- assert result.status == 0
- assert len(result.value['data']) == len(initial_data) - page_size
- assert result.value['more'] is False
-
- # repeat cleanup
- result = sql_fields(conn, 0, drop_query, page_size, schema='PUBLIC')
- assert result.status == 0
+@pytest.fixture
+async def async_student_table_fixture(async_client):
+ async for _ in __create_student_table_fixture(async_client):
+ yield
-def test_long_multipage_query(client):
+def __create_student_table_fixture(client):
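+ # single factory behind both fixtures: it returns a plain generator for Client
+ # and an async generator for AioClient, chosen by the isinstance check at the end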
+ create_query = '''CREATE TABLE Student (
+ id INT(11) PRIMARY KEY,
+ first_name CHAR(24),
+ last_name CHAR(32),
+ grade INT(11))'''
+
+ insert_query = '''INSERT INTO Student(id, first_name, last_name, grade)
+ VALUES (?, ?, ?, ?)'''
+
+ drop_query = 'DROP TABLE Student IF EXISTS'
+
+ def inner():
+ client.sql(drop_query)
+ client.sql(create_query)
+
+ for i, data_line in enumerate(student_table_data):
+ fname, lname, grade = data_line
+ client.sql(insert_query, query_args=[i, fname, lname, grade])
+
+ yield None
+ client.sql(drop_query)
+
+ async def inner_async():
+ await client.sql(drop_query)
+ await client.sql(create_query)
+
+ for i, data_line in enumerate(student_table_data):
+ fname, lname, grade = data_line
+ await client.sql(insert_query, query_args=[i, fname, lname, grade])
+
+ yield None
+ await client.sql(drop_query)
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+@pytest.mark.parametrize('page_size', range(1, 6, 2))
+def test_sql(client, student_table_fixture, page_size):
+ cache = client.get_cache('SQL_PUBLIC_STUDENT')
+ cache_config = cache.settings
+
+ binary_type_name = cache_config[PROP_QUERY_ENTITIES][0]['value_type_name']
+
+ with cache.select_row('ORDER BY ID ASC', page_size=page_size) as cursor:
+ for i, row in enumerate(cursor):
+ k, v = row
+ assert k == i
+
+ assert (v.FIRST_NAME, v.LAST_NAME, v.GRADE) == student_table_data[i]
+ assert v.type_id == entity_id(binary_type_name)
+
+
+@pytest.mark.parametrize('page_size', range(1, 6, 2))
+def test_sql_fields(client, student_table_fixture, page_size):
+ with client.sql(student_table_select_query, page_size=page_size, include_field_names=True) as cursor:
+ for i, row in enumerate(cursor):
+ if i > 0:
+ assert tuple(row) == (i - 1,) + student_table_data[i - 1]
+ else:
+ assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE']
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('page_size', range(1, 6, 2))
+async def test_sql_fields_async(async_client, async_student_table_fixture, page_size):
+ async with async_client.sql(student_table_select_query, page_size=page_size, include_field_names=True) as cursor:
+ i = 0
+ async for row in cursor:
+ if i > 0:
+ assert tuple(row) == (i - 1,) + student_table_data[i - 1]
+ else:
+ assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE']
+ i += 1
+
+ cursor = await async_client.sql(student_table_select_query, page_size=page_size, include_field_names=True)
+ try:
+ i = 0
+ async for row in cursor:
+ if i > 0:
+ assert tuple(row) == (i - 1,) + student_table_data[i - 1]
+ else:
+ assert row == ['ID', 'FIRST_NAME', 'LAST_NAME', 'GRADE']
+ i += 1
+ finally:
+ await cursor.close()
+
+
+multipage_fields = ["id", "abc", "ghi", "def", "jkl", "prs", "mno", "tuw", "zyz", "abc1", "def1", "jkl1", "prs1"]
+
+
+@pytest.fixture
+def long_multipage_table_fixture(client):
+ yield from __long_multipage_table_fixture(client)
+
+
+@pytest.fixture
+async def async_long_multipage_table_fixture(async_client):
+ async for _ in __long_multipage_table_fixture(async_client):
+ yield
+
+
+def __long_multipage_table_fixture(client):
+ drop_query = 'DROP TABLE LongMultipageQuery IF EXISTS'
+
+ create_query = "CREATE TABLE LongMultiPageQuery (%s, %s)" % (
+ multipage_fields[0] + " INT(11) PRIMARY KEY", ",".join(map(lambda f: f + " INT(11)", multipage_fields[1:])))
+
+ insert_query = "INSERT INTO LongMultipageQuery (%s) VALUES (%s)" % (
+ ",".join(multipage_fields), ",".join("?" * len(multipage_fields)))
+
+ def query_args(_id):
+ return [_id] + list(i * _id for i in range(1, len(multipage_fields)))
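+ # column k of row id holds k * id, which the multipage selects below re-derive and verify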
+
+ def inner():
+ client.sql(drop_query)
+ client.sql(create_query)
+
+ for i in range(1, 21):
+ client.sql(insert_query, query_args=query_args(i))
+ yield None
+
+ client.sql(drop_query)
+
+ async def inner_async():
+ await client.sql(drop_query)
+ await client.sql(create_query)
+
+ for i in range(1, 21):
+ await client.sql(insert_query, query_args=query_args(i))
+ yield None
+
+ await client.sql(drop_query)
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+def test_long_multipage_query(client, long_multipage_table_fixture):
"""
The test creates a table with 13 columns (id and 12 enumerated columns)
and 20 records with id in range from 1 to 20. Values of enumerated columns
@@ -177,25 +190,20 @@
The goal is to ensure that all the values are selected in the right order.
"""
- fields = ["id", "abc", "ghi", "def", "jkl", "prs", "mno", "tuw", "zyz", "abc1", "def1", "jkl1", "prs1"]
+ with client.sql('SELECT * FROM LongMultipageQuery', page_size=1) as cursor:
+ for page in cursor:
+ assert len(page) == len(multipage_fields)
+ for field_number, value in enumerate(page[1:], start=1):
+ assert value == field_number * page[0]
- client.sql('DROP TABLE LongMultipageQuery IF EXISTS')
- client.sql("CREATE TABLE LongMultiPageQuery (%s, %s)" %
- (fields[0] + " INT(11) PRIMARY KEY", ",".join(map(lambda f: f + " INT(11)", fields[1:]))))
-
- for id in range(1, 21):
- client.sql(
- "INSERT INTO LongMultipageQuery (%s) VALUES (%s)" % (",".join(fields), ",".join("?" * len(fields))),
- query_args=[id] + list(i * id for i in range(1, len(fields))))
-
- result = client.sql('SELECT * FROM LongMultipageQuery', page_size=1)
- for page in result:
- assert len(page) == len(fields)
- for field_number, value in enumerate(page[1:], start=1):
- assert value == field_number * page[0]
-
- client.sql(drop_query)
+@pytest.mark.asyncio
+async def test_long_multipage_query_async(async_client, async_long_multipage_table_fixture):
+ async with async_client.sql('SELECT * FROM LongMultipageQuery', page_size=1) as cursor:
+ async for page in cursor:
+ assert len(page) == len(multipage_fields)
+ for field_number, value in enumerate(page[1:], start=1):
+ assert value == field_number * page[0]
def test_sql_not_create_cache_with_schema(client):
@@ -203,20 +211,30 @@
client.sql(schema=None, cache='NOT_EXISTING', query_str='select * from NotExisting')
+@pytest.mark.asyncio
+async def test_sql_not_create_cache_with_schema_async(async_client):
+ with pytest.raises(SQLError, match=r".*Cache does not exist.*"):
+ await async_client.sql(schema=None, cache='NOT_EXISTING_ASYNC', query_str='select * from NotExistingAsync')
+
+
def test_sql_not_create_cache_with_cache(client):
with pytest.raises(SQLError, match=r".*Failed to set schema.*"):
client.sql(schema='NOT_EXISTING', query_str='select * from NotExisting')
-def test_query_with_cache(client):
- test_key = 42
- test_value = 'Lorem ipsum'
+@pytest.mark.asyncio
+async def test_sql_not_create_cache_with_cache_async(async_client):
+ with pytest.raises(SQLError, match=r".*Failed to set schema.*"):
+ await async_client.sql(schema='NOT_EXISTING_ASYNC', query_str='select * from NotExistingAsync')
- cache_name = test_query_with_cache.__name__.upper()
+
+@pytest.fixture
+def indexed_cache_settings():
+ cache_name = 'indexed_cache'
schema_name = f'{cache_name}_schema'.upper()
table_name = f'{cache_name}_table'.upper()
- cache = client.create_cache({
+ yield {
PROP_NAME: cache_name,
PROP_SQL_SCHEMA: schema_name,
PROP_CACHE_MODE: CacheMode.PARTITIONED,
@@ -243,18 +261,67 @@
],
},
],
- })
+ }
- cache.put(test_key, test_value)
+
+@pytest.fixture
+def indexed_cache_fixture(client, indexed_cache_settings):
+ cache_name = indexed_cache_settings[PROP_NAME]
+ schema_name = indexed_cache_settings[PROP_SQL_SCHEMA]
+ table_name = indexed_cache_settings[PROP_QUERY_ENTITIES][0]['table_name']
+
+ cache = client.create_cache(indexed_cache_settings)
+
+ yield cache, cache_name, schema_name, table_name
+ cache.destroy()
+
+
+@pytest.fixture
+async def async_indexed_cache_fixture(async_client, indexed_cache_settings):
+ cache_name = indexed_cache_settings[PROP_NAME]
+ schema_name = indexed_cache_settings[PROP_SQL_SCHEMA]
+ table_name = indexed_cache_settings[PROP_QUERY_ENTITIES][0]['table_name']
+
+ cache = await async_client.create_cache(indexed_cache_settings)
+
+ yield cache, cache_name, schema_name, table_name
+ await cache.destroy()
+
+
+def test_query_with_cache(client, indexed_cache_fixture):
+ return __check_query_with_cache(client, indexed_cache_fixture)
+
+
+@pytest.mark.asyncio
+async def test_query_with_cache_async(async_client, async_indexed_cache_fixture):
+ return await __check_query_with_cache(async_client, async_indexed_cache_fixture)
+
+
+def __check_query_with_cache(client, cache_fixture):
+ test_key, test_value = 42, 'Lorem ipsum'
+ cache, cache_name, schema_name, table_name = cache_fixture
+ query = f'select value from {table_name}'
args_to_check = [
('schema', schema_name),
('cache', cache),
- ('cache', cache.name),
+ ('cache', cache_name),
('cache', cache.cache_id)
]
- for param, value in args_to_check:
- page = client.sql(f'select value from {table_name}', **{param: value})
- received = next(page)[0]
- assert test_value == received
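+ # the same SELECT is issued through every accepted addressing mode:
+ # schema name, cache object, cache name, and cache id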
+ def inner():
+ cache.put(test_key, test_value)
+ for param, value in args_to_check:
+ with client.sql(query, **{param: value}) as cursor:
+ received = next(cursor)[0]
+ assert test_value == received
+
+ async def async_inner():
+ await cache.put(test_key, test_value)
+ for param, value in args_to_check:
+ async with client.sql(query, **{param: value}) as cursor:
+ row = await cursor.__anext__()
+ received = row[0]
+ assert test_value == received
+
+ return async_inner() if isinstance(cache, AioCache) else inner()
diff --git a/tests/common/test_sql_composite_key.py b/tests/common/test_sql_composite_key.py
new file mode 100644
index 0000000..76de77e
--- /dev/null
+++ b/tests/common/test_sql_composite_key.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from enum import Enum
+
+import pytest
+
+from pyignite import GenericObjectMeta, AioClient
+from pyignite.datatypes import IntObject, String
+
+
+class StudentKey(
+ metaclass=GenericObjectMeta,
+ type_name='test.model.StudentKey',
+ schema=OrderedDict([
+ ('ID', IntObject),
+ ('DEPT', String)
+ ])
+):
+ pass
+
+
+class Student(
+ metaclass=GenericObjectMeta,
+ type_name='test.model.Student',
+ schema=OrderedDict([
+ ('NAME', String),
+ ])
+):
+ pass
+
+
+create_query = '''CREATE TABLE StudentTable (
+ id INT(11),
+ dept VARCHAR,
+ name CHAR(24),
+ PRIMARY KEY (id, dept))
+ WITH "CACHE_NAME=StudentCache, KEY_TYPE=test.model.StudentKey, VALUE_TYPE=test.model.Student"'''
+
+insert_query = '''INSERT INTO StudentTable (id, dept, name) VALUES (?, ?, ?)'''
+
+select_query = 'SELECT id, dept, name FROM StudentTable'
+
+select_kv_query = 'SELECT _key, _val FROM StudentTable'
+
+drop_query = 'DROP TABLE StudentTable IF EXISTS'
+
+
+@pytest.fixture
+def student_table_fixture(client):
+ yield from __create_student_table_fixture(client)
+
+
+@pytest.fixture
+async def async_student_table_fixture(async_client):
+ async for _ in __create_student_table_fixture(async_client):
+ yield
+
+
+def __create_student_table_fixture(client):
+ def inner():
+ client.sql(drop_query)
+ client.sql(create_query)
+ yield None
+ client.sql(drop_query)
+
+ async def inner_async():
+ await client.sql(drop_query)
+ await client.sql(create_query)
+ yield None
+ await client.sql(drop_query)
+
+ return inner_async() if isinstance(client, AioClient) else inner()
+
+
+class InsertMode(Enum):
+ SQL = 1
+ CACHE = 2
+
+
+@pytest.mark.parametrize('insert_mode', [InsertMode.SQL, InsertMode.CACHE])
+def test_sql_composite_key(client, insert_mode, student_table_fixture):
+ __perform_test(client, insert_mode)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('insert_mode', [InsertMode.SQL, InsertMode.CACHE])
+async def test_sql_composite_key_async(async_client, insert_mode, async_student_table_fixture):
+ await __perform_test(async_client, insert_mode)
+
+
+def __perform_test(client, insert=InsertMode.SQL):
+ student_key = StudentKey(2, 'Business')
+ student_val = Student('Abe')
+
+ def validate_query_result(key, val, query_result):
+ """
+ Compare query result with expected key and value.
+ """
+ assert len(query_result) == 2
+ sql_row = dict(zip(query_result[0], query_result[1]))
+
+ assert sql_row['ID'] == key.ID
+ assert sql_row['DEPT'] == key.DEPT
+ assert sql_row['NAME'] == val.NAME
+
+ def validate_kv_query_result(key, val, query_result):
+ """
+ Compare query result with expected key and value.
+ """
+ assert len(query_result) == 2
+ sql_row = dict(zip(query_result[0], query_result[1]))
+
+ sql_key, sql_val = sql_row['_KEY'], sql_row['_VAL']
+ assert sql_key.ID == key.ID
+ assert sql_key.DEPT == key.DEPT
+ assert sql_val.NAME == val.NAME
+
+ def inner():
+ if insert == InsertMode.SQL:
+ result = client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME])
+ assert next(result)[0] == 1
+ else:
+ student_cache = client.get_cache('StudentCache')
+ student_cache.put(student_key, student_val)
+ val = student_cache.get(student_key)
+ assert val is not None
+ assert val.NAME == student_val.NAME
+
+ query_result = list(client.sql(select_query, include_field_names=True))
+ validate_query_result(student_key, student_val, query_result)
+
+ query_result = list(client.sql(select_kv_query, include_field_names=True))
+ validate_kv_query_result(student_key, student_val, query_result)
+
+ async def inner_async():
+ if insert == InsertMode.SQL:
+ result = await client.sql(insert_query, query_args=[student_key.ID, student_key.DEPT, student_val.NAME])
+ assert (await result.__anext__())[0] == 1
+ else:
+ student_cache = await client.get_cache('StudentCache')
+ await student_cache.put(student_key, student_val)
+ val = await student_cache.get(student_key)
+ assert val is not None
+ assert val.NAME == student_val.NAME
+
+ async with client.sql(select_query, include_field_names=True) as cursor:
+ query_result = [r async for r in cursor]
+ validate_query_result(student_key, student_val, query_result)
+
+ async with client.sql(select_kv_query, include_field_names=True) as cursor:
+ query_result = [r async for r in cursor]
+ validate_kv_query_result(student_key, student_val, query_result)
+
+ return inner_async() if isinstance(client, AioClient) else inner()
diff --git a/tests/conftest.py b/tests/conftest.py
index 59b7d3a..65134fd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
+
import pytest
@@ -27,7 +29,7 @@
def skip_if_no_cext(request):
skip = False
try:
- from pyignite import _cutils
+ from pyignite import _cutils # noqa: F401
except ImportError:
if request.config.getoption('--force-cext'):
pytest.fail("C extension failed to build, fail test because of --force-cext is set.")
@@ -38,6 +40,14 @@
pytest.skip('skipped c extensions test, c extension is not available.')
+@pytest.fixture(scope='session')
+def event_loop():
+ """Create an instance of the default event loop for each test case."""
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+
+
def pytest_addoption(parser):
parser.addoption(
'--examples',
diff --git a/tests/security/test_auth.py b/tests/security/test_auth.py
index 2dd19a0..4a1c52d 100644
--- a/tests/security/test_auth.py
+++ b/tests/security/test_auth.py
@@ -15,7 +15,7 @@
import pytest
from pyignite.exceptions import AuthenticationError
-from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client
+from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client, get_client_async
DEFAULT_IGNITE_USERNAME = 'ignite'
DEFAULT_IGNITE_PASSWORD = 'ignite'
@@ -47,13 +47,27 @@
assert all(node.alive for node in client._nodes)
+@pytest.mark.asyncio
+async def test_auth_success_async(with_ssl, ssl_params):
+ ssl_params['use_ssl'] = with_ssl
+
+ async with get_client_async(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD,
+ **ssl_params) as client:
+ await client.connect("127.0.0.1", 10801)
+
+ assert all(node.alive for node in client._nodes)
+
+
+auth_failed_params = [
+ [DEFAULT_IGNITE_USERNAME, None],
+ ['invalid_user', 'invalid_password'],
+ [None, None]
+]
+
+
@pytest.mark.parametrize(
'username, password',
- [
- [DEFAULT_IGNITE_USERNAME, None],
- ['invalid_user', 'invalid_password'],
- [None, None]
- ]
+ auth_failed_params
)
def test_auth_failed(username, password, with_ssl, ssl_params):
ssl_params['use_ssl'] = with_ssl
@@ -61,3 +75,16 @@
with pytest.raises(AuthenticationError):
with get_client(username=username, password=password, **ssl_params) as client:
client.connect("127.0.0.1", 10801)
+
+
+@pytest.mark.parametrize(
+ 'username, password',
+ auth_failed_params
+)
+@pytest.mark.asyncio
+async def test_auth_failed_async(username, password, with_ssl, ssl_params):
+ ssl_params['use_ssl'] = with_ssl
+
+ with pytest.raises(AuthenticationError):
+ async with get_client_async(username=username, password=password, **ssl_params) as client:
+ await client.connect("127.0.0.1", 10801)
diff --git a/tests/security/test_ssl.py b/tests/security/test_ssl.py
index 6463a03..32db98f 100644
--- a/tests/security/test_ssl.py
+++ b/tests/security/test_ssl.py
@@ -15,7 +15,7 @@
import pytest
from pyignite.exceptions import ReconnectError
-from tests.util import start_ignite_gen, get_client, get_or_create_cache
+from tests.util import start_ignite_gen, get_client, get_or_create_cache, get_client_async, get_or_create_cache_async
@pytest.fixture(scope='module', autouse=True)
@@ -30,27 +30,58 @@
def test_connect_ssl(ssl_params):
__test_connect_ssl(**ssl_params)
-def __test_connect_ssl(**kwargs):
+
+@pytest.mark.asyncio
+async def test_connect_ssl_keystore_with_password_async(ssl_params_with_password):
+ await __test_connect_ssl(is_async=True, **ssl_params_with_password)
+
+
+@pytest.mark.asyncio
+async def test_connect_ssl_async(ssl_params):
+ await __test_connect_ssl(is_async=True, **ssl_params)
+
+
+def __test_connect_ssl(is_async=False, **kwargs):
kwargs['use_ssl'] = True
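+ # shared body for the sync and async SSL tests; is_async selects the matching inner function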
- with get_client(**kwargs) as client:
- client.connect("127.0.0.1", 10801)
+ def inner():
+ with get_client(**kwargs) as client:
+ client.connect("127.0.0.1", 10801)
- with get_or_create_cache(client, 'test-cache') as cache:
- cache.put(1, 1)
+ with get_or_create_cache(client, 'test-cache') as cache:
+ cache.put(1, 1)
- assert cache.get(1) == 1
+ assert cache.get(1) == 1
+
+ async def inner_async():
+ async with get_client_async(**kwargs) as client:
+ await client.connect("127.0.0.1", 10801)
+
+ async with get_or_create_cache_async(client, 'test-cache') as cache:
+ await cache.put(1, 1)
+
+ assert (await cache.get(1)) == 1
+
+ return inner_async() if is_async else inner()
-@pytest.mark.parametrize(
- 'invalid_ssl_params',
- [
- {'use_ssl': False},
- {'use_ssl': True},
- {'use_ssl': True, 'ssl_keyfile': 'invalid.pem', 'ssl_certfile': 'invalid.pem'}
- ]
-)
+invalid_params = [
+ {'use_ssl': False},
+ {'use_ssl': True},
+ {'use_ssl': True, 'ssl_keyfile': 'invalid.pem', 'ssl_certfile': 'invalid.pem'}
+]
+
+
+@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
def test_connection_error_with_incorrect_config(invalid_ssl_params):
with pytest.raises(ReconnectError):
with get_client(**invalid_ssl_params) as client:
client.connect([("127.0.0.1", 10801)])
+
+
+@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
+@pytest.mark.asyncio
+async def test_connection_error_with_incorrect_config_async(invalid_ssl_params):
+ with pytest.raises(ReconnectError):
+ async with get_client_async(**invalid_ssl_params) as client:
+ await client.connect([("127.0.0.1", 10801)])
diff --git a/tests/test_cutils.py b/tests/test_cutils.py
index e7c095e..d66425f 100644
--- a/tests/test_cutils.py
+++ b/tests/test_cutils.py
@@ -27,8 +27,8 @@
_cutils_hashcode = _cutils.hashcode
_cutils_schema_id = _cutils.schema_id
except ImportError:
- _cutils_hashcode = lambda x: None
- _cutils_schema_id = lambda x: None
+ _cutils_hashcode = lambda x: None # noqa: E731
+ _cutils_schema_id = lambda x: None # noqa: E731
pass
diff --git a/tests/util.py b/tests/util.py
index af4c324..f1243fc 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
import contextlib
import glob
+import inspect
import os
import shutil
@@ -24,7 +26,12 @@
import subprocess
import time
-from pyignite import Client
+from pyignite import Client, AioClient
+
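+# contextlib.asynccontextmanager appeared in Python 3.7; on 3.6 fall back to the async_generator backport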
+try:
+ from contextlib import asynccontextmanager
+except ImportError:
+ from async_generator import asynccontextmanager
@contextlib.contextmanager
@@ -36,6 +43,15 @@
client.close()
+@asynccontextmanager
+async def get_client_async(**kwargs):
+ client = AioClient(**kwargs)
+ try:
+ yield client
+ finally:
+ await client.close()
+
+
@contextlib.contextmanager
def get_or_create_cache(client, cache_name):
cache = client.get_or_create_cache(cache_name)
@@ -45,6 +61,15 @@
cache.destroy()
+@asynccontextmanager
+async def get_or_create_cache_async(client, cache_name):
+ cache = await client.get_or_create_cache(cache_name)
+ try:
+ yield cache
+ finally:
+ await cache.destroy()
+
+
def wait_for_condition(condition, interval=0.1, timeout=10, error=None):
start = time.time()
res = condition()
@@ -62,6 +87,23 @@
return False
+async def wait_for_condition_async(condition, interval=0.1, timeout=10, error=None):
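+ # condition may be a plain callable or a coroutine function; both are polled until timeout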
+ start = time.time()
+ res = await condition() if inspect.iscoroutinefunction(condition) else condition()
+
+ while not res and time.time() - start < timeout:
+ await asyncio.sleep(interval)
+ res = await condition() if inspect.iscoroutinefunction(condition) else condition()
+
+ if res:
+ return True
+
+ if error is not None:
+ raise Exception(error)
+
+ return False
+
+
def is_windows():
return os.name == "nt"
diff --git a/tox.ini b/tox.ini
index 3ab8dea..90153da 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,7 +15,15 @@
[tox]
skipsdist = True
-envlist = py{36,37,38,39}
+envlist = codestyle,py{36,37,38,39}
+
+[flake8]
+max-line-length=120
+ignore = F401,F403,F405,F821
+
+[testenv:codestyle]
+basepython = python3.8
+commands = flake8
[testenv]
passenv = TEAMCITY_VERSION IGNITE_HOME