blob: a690d94f4dbf5de516c031ecaa549de95beb7790 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains sync and async cursors for different types of queries.
"""
import asyncio
from pyignite.api import (
scan, scan_cursor_get_page, resource_close, scan_async, scan_cursor_get_page_async, resource_close_async, sql,
sql_cursor_get_page, sql_fields, sql_fields_cursor_get_page, sql_fields_cursor_get_page_async, sql_fields_async
)
from pyignite.exceptions import CacheError, SQLError
# Public cursor classes exported by ``from ... import *``.
__all__ = [
    'ScanCursor',
    'SqlCursor',
    'SqlFieldsCursor',
    'AioScanCursor',
    'AioSqlFieldsCursor',
]
class BaseCursorMixin:
    """
    Common state accessors shared by every cursor class.

    Each property reads a private backing attribute on the instance
    (yielding ``None`` until it has been assigned) and writes it back with
    ``setattr``.
    """

    def _backed_by(attr_name, doc):  # pylint: disable=no-self-argument
        # Runs once while the class body is evaluated; builds a read/write
        # property whose value lives in ``attr_name`` on the instance.
        def _read(self):
            return getattr(self, attr_name, None)

        def _write(self, value):
            setattr(self, attr_name, value)

        return property(_read, _write, doc=doc)

    connection = _backed_by('_conn', "Ignite cluster connection.")
    cursor_id = _backed_by('_cursor_id', "Cursor id.")
    more = _backed_by('_more', "Whether cursor has more values.")
    cache_info = _backed_by('_cache_info', "Cache meta info.")
    client = _backed_by('_client', "Apache Ignite client.")
    data = _backed_by('_data', "Current fetched data.")

    # The factory is only needed while building the class.
    del _backed_by
class CursorMixin(BaseCursorMixin):
    """
    Synchronous cursor protocol: an iterator that is also a context manager.

    The mixin returns ``self`` from ``__iter__``; concrete subclasses supply
    ``__next__``.
    """

    def __enter__(self):
        return self

    def __iter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """
        Close cursor.

        Releases the server-side cursor resource when the cursor still has
        unfetched pages (``more`` is truthy). Afterwards ``more`` is reset to
        ``False``. A repeated ``close()``, such as an explicit call after a
        ``with`` block, therefore does not send a second close request for the
        same cursor id. Subclasses in this module check ``more`` before
        requesting another page, so they also stop at the end of the buffered
        data instead of querying a released cursor.
        """
        if self.connection and self.cursor_id and self.more:
            resource_close(self.connection, self.cursor_id)
            # The server resource is gone; nothing more can be fetched.
            self.more = False
class AioCursorMixin(BaseCursorMixin):
    """
    Asynchronous cursor protocol: an async iterator and async context manager.

    Concrete subclasses supply ``__aenter__`` (initialisation) and
    ``__anext__``.
    """

    def __await__(self):
        # ``cursor = await ...`` performs the same initialisation as
        # ``async with ...`` by delegating to the subclass ``__aenter__``.
        return (yield from self.__aenter__().__await__())

    def __aiter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        """
        Close cursor.

        Releases the server-side cursor resource when the cursor still has
        unfetched pages (``more`` is truthy). Afterwards ``more`` is reset to
        ``False``, so a repeated ``close()`` does not send a second close
        request for the same cursor id. Subclasses in this module also stop
        paging once the buffered data is exhausted.
        """
        if self.connection and self.cursor_id and self.more:
            await resource_close_async(self.connection, self.cursor_id)
            # The server resource is gone; nothing more can be fetched.
            self.more = False
class AbstractScanCursor:
    """
    Query parameters and response handling shared by the sync and async scan
    cursors.
    """

    def __init__(self, client, cache_info, page_size, partitions, local):
        self.client, self.cache_info = client, cache_info
        self._page_size, self._partitions, self._local = page_size, partitions, local

    @staticmethod
    def _raise_on_error(result):
        # A non-zero status means the server rejected the request.
        if result.status != 0:
            raise CacheError(result.message)

    def _finalize_init(self, result):
        """
        Consume the response to the initial scan request: remember the cursor
        id and the "more pages" flag, and buffer the first page of entries.

        :raises CacheError: if the response status is non-zero.
        """
        self._raise_on_error(result)
        payload = result.value
        self.cursor_id, self.more = payload['cursor'], payload['more']
        self.data = iter(payload['data'].items())

    def _process_page_response(self, result):
        """
        Replace the buffered entries with the next page returned by the server.

        :raises CacheError: if the response status is non-zero.
        """
        self._raise_on_error(result)
        payload = result.value
        self.data, self.more = iter(payload['data'].items()), payload['more']
class ScanCursor(AbstractScanCursor, CursorMixin):
    """
    Synchronous scan cursor.
    """

    def __init__(self, client, cache_info, page_size, partitions, local):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        :param page_size: page size.
        :param partitions: number of partitions to query (negative to query entire cache).
        :param local: pass True if this query should be executed on local node only.
        """
        super().__init__(client, cache_info, page_size, partitions, local)
        self.connection = self.client.random_node
        self._finalize_init(
            scan(self.connection, self.cache_info, self._page_size, self._partitions, self._local)
        )

    def __next__(self):
        # No buffered data at all: the cursor was never filled.
        if not self.data:
            raise StopIteration
        try:
            key, value = next(self.data)
        except StopIteration:
            # Current page exhausted; fetch another one only if the server
            # reported that more pages exist.
            if not self.more:
                raise StopIteration
            page = scan_cursor_get_page(self.connection, self.cursor_id)
            self._process_page_response(page)
            key, value = next(self.data)
        return self.client.unwrap_binary(key), self.client.unwrap_binary(value)
class AioScanCursor(AbstractScanCursor, AioCursorMixin):
    """
    Asynchronous scan query cursor.
    """

    def __init__(self, client, cache_info, page_size, partitions, local):
        """
        :param client: Asynchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        :param page_size: page size.
        :param partitions: number of partitions to query (negative to query entire cache).
        :param local: pass True if this query should be executed on local node only.
        """
        super().__init__(client, cache_info, page_size, partitions, local)

    async def __aenter__(self):
        # Already initialised (connection chosen): nothing to do.
        if self.connection:
            return self
        self.connection = await self.client.random_node()
        response = await scan_async(self.connection, self.cache_info, self._page_size, self._partitions, self._local)
        self._finalize_init(response)
        return self

    async def __anext__(self):
        if not self.connection:
            raise CacheError("Using uninitialized cursor, initialize it using async with expression.")
        if not self.data:
            raise StopAsyncIteration
        try:
            key, value = next(self.data)
        except StopIteration:
            # Current page exhausted; only ask the server again when it said
            # more pages exist.
            if not self.more:
                raise StopAsyncIteration
            page = await scan_cursor_get_page_async(self.connection, self.cursor_id)
            self._process_page_response(page)
            try:
                key, value = next(self.data)
            except StopIteration:
                raise StopAsyncIteration
        # Key and value are unwrapped concurrently.
        return await asyncio.gather(self.client.unwrap_binary(key), self.client.unwrap_binary(value))
class SqlCursor(CursorMixin):
    """
    Synchronous SQL query cursor.
    """

    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        self.client = client
        self.cache_info = cache_info
        self.connection = self.client.random_node
        payload = self._checked_value(sql(self.connection, self.cache_info, *args, **kwargs))
        self.cursor_id, self.more = payload['cursor'], payload['more']
        self.data = iter(payload['data'].items())

    @staticmethod
    def _checked_value(result):
        # Turn a failed server response into SQLError; otherwise hand back
        # the response payload.
        if result.status != 0:
            raise SQLError(result.message)
        return result.value

    def __next__(self):
        if not self.data:
            raise StopIteration
        try:
            key, value = next(self.data)
        except StopIteration:
            # Current page exhausted; fetch another one only when the server
            # reported that more pages exist.
            if not self.more:
                raise StopIteration
            payload = self._checked_value(sql_cursor_get_page(self.connection, self.cursor_id))
            self.data, self.more = iter(payload['data'].items()), payload['more']
            key, value = next(self.data)
        return self.client.unwrap_binary(key), self.client.unwrap_binary(value)
class AbstractSqlFieldsCursor:
    """
    Response handling shared by the sync and async SQL fields cursors.
    """

    def __init__(self, client, cache_info):
        self.client, self.cache_info = client, cache_info

    def _finalize_init(self, result):
        """
        Consume the response to the initial SQL fields request.

        Stores the cursor id, the "more pages" flag and the first page of rows.
        It also stores the column names (when the response carries a truthy
        ``fields`` list) and the column count, which is taken from that list or
        else from ``field_count``.

        :raises SQLError: if the response status is non-zero.
        """
        if result.status != 0:
            raise SQLError(result.message)
        payload = result.value
        self.cursor_id, self.more = payload['cursor'], payload['more']
        self.data = iter(payload['data'])
        names = payload.get('fields', None)
        self._field_names = names
        self._field_count = len(names) if names else payload['field_count']
class SqlFieldsCursor(AbstractSqlFieldsCursor, CursorMixin):
    """
    Synchronous SQL fields query cursor.
    """

    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        super().__init__(client, cache_info)
        self.connection = self.client.random_node
        response = sql_fields(self.connection, self.cache_info, *args, **kwargs)
        self._finalize_init(response)

    def __next__(self):
        if not self.data:
            raise StopIteration
        header = self._field_names
        if header:
            # Column names (when requested) are yielded once, as the first row.
            self._field_names = None
            return header
        try:
            row = next(self.data)
        except StopIteration:
            # Current page exhausted; fetch another one only when the server
            # reported that more pages exist.
            if not self.more:
                raise StopIteration
            page = sql_fields_cursor_get_page(self.connection, self.cursor_id, self._field_count)
            if page.status != 0:
                raise SQLError(page.message)
            self.data, self.more = iter(page.value['data']), page.value['more']
            row = next(self.data)
        return [self.client.unwrap_binary(cell) for cell in row]
class AioSqlFieldsCursor(AbstractSqlFieldsCursor, AioCursorMixin):
    """
    Asynchronous SQL fields query cursor.
    """

    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Asynchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        super().__init__(client, cache_info)
        # The query is deferred until ``async with``/``await``: it needs a
        # connection, which is obtained asynchronously.
        self._params = (args, kwargs)

    async def __aenter__(self):
        args, kwargs = self._params
        # Keyword arguments must be forwarded with ``**``. A single ``*``
        # would pass the dict's keys as extra positional arguments.
        await self._initialize(*args, **kwargs)
        return self

    async def __anext__(self):
        if not self.connection:
            raise SQLError("Attempting to use uninitialized aio cursor, please await on it or use with expression.")
        if not self.data:
            raise StopAsyncIteration
        if self._field_names:
            # Column names (when requested) are yielded once, as the first row.
            result = self._field_names
            self._field_names = None
            return result
        try:
            row = next(self.data)
        except StopIteration:
            # Current page exhausted; fetch another one only when the server
            # reported that more pages exist.
            if self.more:
                result = await sql_fields_cursor_get_page_async(self.connection, self.cursor_id, self._field_count)
                if result.status != 0:
                    raise SQLError(result.message)
                self.data, self.more = iter(result.value['data']), result.value['more']
                try:
                    row = next(self.data)
                except StopIteration:
                    raise StopAsyncIteration
            else:
                raise StopAsyncIteration
        return await asyncio.gather(*[self.client.unwrap_binary(v) for v in row])

    async def _initialize(self, *args, **kwargs):
        """
        Pick a node and run the SQL fields query, unless already initialised.

        :raises SQLError: if the server reports a non-zero status.
        """
        if self.connection and self.cursor_id:
            return
        self.connection = await self.client.random_node()
        self._finalize_init(await sql_fields_async(self.connection, self.cache_info, *args, **kwargs))