blob: ee8de07d4685a5def2d860c38dff03f6b46fca7a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextvars
import attr
from pyignite.datatypes import Byte, String, Long, Int, Bool
from pyignite.exceptions import CacheError
from pyignite.queries import Query, query_perform
from pyignite.queries.op_codes import OP_TX_START, OP_TX_END
# Transaction active in the current execution context, stored as a
# TransactionContext, or None when there is none. Because this is a
# ContextVar, a value set in one thread (or inside one asyncio task, which
# runs in its own copy of the context) is not visible to other threads/tasks;
# tx_end/tx_end_async rely on this to reject transactions started elsewhere.
# Set by __tx_start_post_process; reset to None by tx_end/tx_end_async.
__CURRENT_TX = contextvars.ContextVar('current_tx', default=None)
def get_tx_id():
    """Return the id of the transaction active in the current context.

    :return: the ``tx_id`` of the :class:`TransactionContext` stored in the
     current context, or ``None`` if no transaction is active.
    """
    # ContextVar.get() already returns the ``None`` default when nothing was
    # set in this context, and the ContextVar object itself is always truthy,
    # so no separate guard on it is needed.
    ctx = __CURRENT_TX.get()
    return ctx.tx_id if ctx is not None else None
def get_tx_connection():
    """Return the connection of the transaction active in the current context.

    :return: the ``conn`` of the :class:`TransactionContext` stored in the
     current context, or ``None`` if no transaction is active.
    """
    # ContextVar.get() already returns the ``None`` default when nothing was
    # set in this context, and the ContextVar object itself is always truthy,
    # so no separate guard on it is needed.
    ctx = __CURRENT_TX.get()
    return ctx.conn if ctx is not None else None
@attr.s
class TransactionContext:
    """State of the transaction active in one execution context.

    Instances are put into the module's context variable when a transaction
    starts successfully, and read back by ``get_tx_id``,
    ``get_tx_connection``, ``tx_end`` and ``tx_end_async``.
    """

    # Id of the transaction as returned by the server for OP_TX_START.
    tx_id = attr.attrib(default=None, type=int)
    # Connection the transaction was started on; tx_end sends OP_TX_END on it.
    conn = attr.attrib(default=None)
def tx_start(conn, concurrency, isolation, timeout: int = 0, label: str = None):
    """Start a transaction (blocking variant of :func:`tx_start_async`).

    On success the new transaction becomes the current one for this context,
    and the returned result's ``value`` is the bare transaction id.

    :param conn: connection to start the transaction on,
    :param concurrency: transaction concurrency, sent as a Byte,
    :param isolation: transaction isolation level, sent as a Byte,
    :param timeout: (optional) transaction timeout, sent as a Long; 0 by default,
    :param label: (optional) transaction label, sent as a String,
    :return: the OP_TX_START query result.
    """
    return __tx_start_post_process(
        __tx_start(conn, concurrency, isolation, timeout, label),
        conn,
    )
async def tx_start_async(conn, concurrency, isolation, timeout: int = 0, label: str = None):
    """Start a transaction (awaitable variant of :func:`tx_start`).

    On success the new transaction becomes the current one for this context,
    and the returned result's ``value`` is the bare transaction id.

    :param conn: connection to start the transaction on,
    :param concurrency: transaction concurrency, sent as a Byte,
    :param isolation: transaction isolation level, sent as a Byte,
    :param timeout: (optional) transaction timeout, sent as a Long; 0 by default,
    :param label: (optional) transaction label, sent as a String,
    :return: the OP_TX_START query result.
    """
    start_response = await __tx_start(conn, concurrency, isolation, timeout, label)
    return __tx_start_post_process(start_response, conn)
def __tx_start(conn, concurrency, isolation, timeout, label):
    """Build and send the OP_TX_START request over ``conn``.

    :param conn: connection to send the request on,
    :param concurrency: sent as a Byte,
    :param isolation: sent as a Byte,
    :param timeout: sent as a Long,
    :param label: sent as a String,
    :return: whatever ``query_perform`` returns for ``conn``. ``tx_start``
     uses it directly, while ``tx_start_async`` awaits it. On success its
     ``value`` is a mapping with a ``'tx_id'`` key.
    """
    request = Query(
        OP_TX_START,
        [
            ('concurrency', Byte),
            ('isolation', Byte),
            ('timeout', Long),
            ('label', String),
        ],
    )
    request_params = {
        'concurrency': concurrency,
        'isolation': isolation,
        'timeout': timeout,
        'label': label,
    }
    return query_perform(
        request,
        conn,
        query_params=request_params,
        response_config=[('tx_id', Int)],
    )
def tx_end(tx_id, committed):
    """Finish the transaction active in this context (blocking variant).

    :param tx_id: id of the transaction to finish; must equal the id of the
     transaction registered in the current context,
    :param committed: flag sent to the server as ``committed`` (presumably
     True commits and False rolls back),
    :return: the OP_TX_END query result,
    :raise CacheError: if no transaction is active in the current context,
     or if its id differs from ``tx_id``.
    """
    current = __CURRENT_TX.get()
    # Only the context (thread / asyncio task) that started the transaction
    # holds it in the ContextVar; anything else is refused here.
    owned_here = current is not None and current.tx_id == tx_id
    if not owned_here:
        # NOTE(review): message says "commit" even when rolling back; kept
        # unchanged because callers may match on it.
        raise CacheError("Cannot commit transaction from different thread or coroutine")
    try:
        return __tx_end(current.conn, tx_id, committed)
    finally:
        # Clear the current transaction whether or not the server call succeeded.
        __CURRENT_TX.set(None)
async def tx_end_async(tx_id, committed):
    """Finish the transaction active in this context (awaitable variant).

    :param tx_id: id of the transaction to finish; must equal the id of the
     transaction registered in the current context,
    :param committed: flag sent to the server as ``committed`` (presumably
     True commits and False rolls back),
    :return: the OP_TX_END query result,
    :raise CacheError: if no transaction is active in the current context,
     or if its id differs from ``tx_id``.
    """
    current = __CURRENT_TX.get()
    # Only the context (thread / asyncio task) that started the transaction
    # holds it in the ContextVar; anything else is refused here.
    owned_here = current is not None and current.tx_id == tx_id
    if not owned_here:
        # NOTE(review): message says "commit" even when rolling back; kept
        # unchanged because callers may match on it.
        raise CacheError("Cannot commit transaction from different thread or coroutine")
    try:
        return await __tx_end(current.conn, tx_id, committed)
    finally:
        # Clear the current transaction whether or not the server call succeeded.
        __CURRENT_TX.set(None)
def __tx_end(conn, tx_id, committed):
    """Build and send the OP_TX_END request over ``conn``.

    :param conn: connection the transaction was started on,
    :param tx_id: transaction id, sent as an Int,
    :param committed: sent as a Bool,
    :return: whatever ``query_perform`` returns for ``conn``. ``tx_end``
     uses it directly, while ``tx_end_async`` awaits it.
    """
    request = Query(
        OP_TX_END,
        [
            ('tx_id', Int),
            ('committed', Bool),
        ],
    )
    request_params = {
        'tx_id': tx_id,
        'committed': committed,
    }
    return query_perform(request, conn, query_params=request_params)
def __tx_start_post_process(result, conn):
    """Register a successfully started transaction in the current context.

    When ``result.status`` is 0, the server-assigned ``tx_id`` and ``conn``
    are stored as the current :class:`TransactionContext`, and
    ``result.value`` is replaced in place by the bare ``tx_id``. For any
    other status, ``result`` is returned unchanged.

    :param result: OP_TX_START query result,
    :param conn: connection the transaction was started on,
    :return: the same ``result`` object.
    """
    if result.status != 0:
        return result
    started_tx_id = result.value['tx_id']
    __CURRENT_TX.set(TransactionContext(tx_id=started_tx_id, conn=conn))
    result.value = started_tx_id
    return result