blob: 3003eb633c0667cc8e3968c285bc6ed1681dc6f2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
from typing import Union, Type
from pyignite.api.tx_api import tx_end, tx_start, tx_end_async, tx_start_async
from pyignite.datatypes import TransactionIsolation, TransactionConcurrency
from pyignite.exceptions import CacheError
from pyignite.utils import status_to_exception
def _validate_int_enum_param(value: Union[int, IntEnum], cls: Type[IntEnum]):
    """
    Check that ``value`` equals one of the member values of the enum ``cls``.

    :param value: plain integer or enum member to validate,
    :param cls: :class:`~enum.IntEnum` subclass listing the allowed values,
    :return: ``value`` unchanged,
    :raises ValueError: if ``value`` matches none of the member values.
    """
    # Membership is tested against the collected member values instead of
    # the enum class itself; this avoids a warning on python 3.7.
    allowed = {member.value for member in cls}
    if value in allowed:
        return value
    raise ValueError(f'{value} not in {cls}')
def _validate_timeout(value):
    """
    Check that ``value`` is usable as a transaction timeout.

    :param value: timeout, a non-negative ``int`` (``0`` is accepted and is
     the default used by the transaction classes in this module),
    :return: ``value`` unchanged,
    :raises ValueError: if ``value`` is not an ``int``, is a ``bool``,
     or is negative.
    """
    # bool is a subclass of int, so it has to be rejected explicitly:
    # True/False passed as a timeout is a caller mistake, not 1 or 0.
    if isinstance(value, bool) or not isinstance(value, int) or value < 0:
        raise ValueError(f'Timeout value should be a non-negative integer, {value} passed instead')
    return value
def _validate_label(value):
    """
    Check that ``value`` is a valid transaction label.

    :param value: label string, or ``None`` when no label is given,
    :return: ``value`` unchanged,
    :raises ValueError: if ``value`` is neither ``None`` nor a ``str``.
    """
    # Compare against None explicitly: a truthiness test would let falsy
    # non-string values such as 0, b'' or [] through without validation.
    if value is not None and not isinstance(value, str):
        raise ValueError(f'Label should be str, {type(value)} passed instead')
    return value
class _BaseTransaction:
    """
    Common state of the transaction classes in this module.

    The constructor only validates and stores the transaction parameters;
    it does not call any transaction API itself.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        :param client: client object, stored as ``self.client``,
        :param concurrency: one of the :class:`TransactionConcurrency` values,
        :param isolation: one of the :class:`TransactionIsolation` values,
        :param timeout: non-negative integer timeout,
        :param label: optional transaction label,
        :raises ValueError: if any of the parameters fails its validation.
        """
        self.client = client
        self.concurrency = _validate_int_enum_param(concurrency, cls=TransactionConcurrency)
        self.isolation = _validate_int_enum_param(isolation, cls=TransactionIsolation)
        self.timeout = _validate_timeout(timeout)
        self.label = _validate_label(label)
        # Flipped to True by commit()/close() of the subclasses.
        self.closed = False
class Transaction(_BaseTransaction):
    """
    Thin client transaction.

    The transaction is started as part of construction. The object can be
    used as a context manager: leaving the ``with`` block calls
    :py:meth:`close`, which does nothing if the transaction was already
    committed or closed.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        Validate the parameters and start the transaction right away.

        :raises ValueError: if a parameter fails validation in the base class.
        """
        super().__init__(client, concurrency=concurrency, isolation=isolation,
                         timeout=timeout, label=label)
        self.tx_id = self.__start_tx()

    def commit(self) -> None:
        """
        Commit transaction. Has no effect once the transaction is closed.
        """
        if self.closed:
            return None
        # Mark as closed before the request so that the end call is
        # issued at most once.
        self.closed = True
        return self.__end_tx(True)

    def rollback(self) -> None:
        """
        Rollback transaction. Delegates to :py:meth:`close`.
        """
        self.close()

    def close(self) -> None:
        """
        Close transaction, ending it as not committed. Has no effect once
        the transaction is closed.
        """
        if self.closed:
            return None
        self.closed = True
        return self.__end_tx(False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions raised inside the ``with`` block are not suppressed.
        self.close()

    # NOTE(review): status_to_exception presumably turns a failed API result
    # into CacheError -- confirm against pyignite.utils.
    @status_to_exception(CacheError)
    def __start_tx(self):
        return tx_start(self.client.random_node, self.concurrency, self.isolation,
                        self.timeout, self.label)

    @status_to_exception(CacheError)
    def __end_tx(self, committed):
        return tx_end(self.tx_id, committed)
class AioTransaction(_BaseTransaction):
    """
    Async thin client transaction.

    Nothing is started by the constructor. The transaction is started when
    the object is awaited or entered with ``async with``; doing either while
    a transaction is already active reuses it instead of starting another.
    Leaving the ``async with`` block calls :py:meth:`close`.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        Validate and store the parameters; the transaction is not started here.

        :raises ValueError: if a parameter fails validation in the base class.
        """
        super().__init__(client, concurrency, isolation, timeout, label)
        # None until the first successful start, so that commit()/close()
        # on a never-started transaction do not fail with AttributeError.
        self.tx_id = None

    def __await__(self):
        # ``await tx`` is equivalent to entering the async context manager.
        return (yield from self.__aenter__().__await__())

    async def commit(self) -> None:
        """
        Commit transaction. Has no effect if no transaction is active
        (never started, or already committed/closed).
        """
        if self.__is_active():
            # Mark as closed before the request so that the end call is
            # issued at most once.
            self.closed = True
            return await self.__end_tx(True)

    async def rollback(self) -> None:
        """
        Rollback transaction. Delegates to :py:meth:`close`.
        """
        await self.close()

    async def close(self) -> None:
        """
        Close transaction, ending it as not committed. Has no effect if no
        transaction is active (never started, or already committed/closed).
        """
        if self.__is_active():
            self.closed = True
            return await self.__end_tx(False)

    async def __aenter__(self):
        # Start only when nothing is active: starting again would overwrite
        # tx_id and leave the previous transaction without an end call.
        # Re-entering after commit()/close() still starts a new transaction.
        if not self.__is_active():
            self.tx_id = await self.__start_tx()
            self.closed = False
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Exceptions raised inside the ``async with`` block are not suppressed.
        await self.close()

    def __is_active(self):
        # A transaction is active once started and until committed/closed.
        return self.tx_id is not None and not self.closed

    # NOTE(review): status_to_exception presumably turns a failed API result
    # into CacheError -- confirm against pyignite.utils.
    @status_to_exception(CacheError)
    async def __start_tx(self):
        conn = await self.client.random_node()
        return await tx_start_async(conn, self.concurrency, self.isolation, self.timeout, self.label)

    @status_to_exception(CacheError)
    async def __end_tx(self, committed):
        return await tx_end_async(self.tx_id, committed)