# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| |
| import threading |
| import time |
| |
| import pytest |
| |
| import pyignite_dbapi |
| from tests.conftest import TEST_CONNECT_KWARGS |
| from tests.util import wait_for_condition |
| |
| NUM_THREADS = 50 |
| |
| |
@pytest.fixture()
def module_level_threadsafety():
    """Guard fixture: the driver must declare at least module-level thread safety."""
    declared_level = pyignite_dbapi.threadsafety
    assert declared_level >= 1, "Module can not be used concurrently"
| |
| |
@pytest.fixture()
def connection_level_threadsafety(module_level_threadsafety):
    """Guard fixture: connections must be shareable across threads (level >= 2)."""
    declared_level = pyignite_dbapi.threadsafety
    assert declared_level >= 2, "Connections can not be used concurrently"
| |
| |
@pytest.fixture()
def table(table_name, service_cursor, drop_table_cleanup):
    """Create an empty two-column test table and yield its name.

    Dropping the table is delegated to the ``drop_table_cleanup`` fixture.
    """
    ddl = f"CREATE TABLE {table_name} (id int primary key, data varchar)"
    service_cursor.execute(ddl)
    yield table_name
| |
| |
def run_threads(fn, n=None, *args):
    """Run *fn* concurrently in *n* threads that all start simultaneously.

    Each thread waits on a shared barrier before invoking ``fn``, so every
    invocation begins at (nearly) the same moment — this maximises the chance
    of exposing race conditions.  ``fn`` is called as ``fn(thread_id, *args)``.

    :param fn: callable executed in every thread.
    :param n: number of threads; defaults to ``NUM_THREADS``.  The default is
        resolved lazily (at call time, not at function-definition time) so the
        module-level constant is not frozen into the signature.
    :param args: extra positional arguments forwarded to ``fn``.  NOTE:
        because ``n`` precedes ``*args``, callers supplying extra arguments
        positionally must also pass ``n`` explicitly.
    :raises Exception: re-raises the first exception captured in any worker
        thread, after every thread has been joined.
    """
    if n is None:
        n = NUM_THREADS
    barrier = threading.Barrier(n)
    errors = []
    errors_lock = threading.Lock()

    def wrapper(tid):
        try:
            # Rendezvous first so all threads call fn at the same instant.
            barrier.wait()
            fn(tid, *args)
        except Exception as e:
            with errors_lock:
                errors.append(e)

    threads = [threading.Thread(target=wrapper, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if errors:
        # Only the first captured failure is surfaced; the rest are dropped.
        raise errors[0]
| |
| |
def test_concurrent_module_import(module_level_threadsafety):
    """Importing the driver module from many threads at once must be safe."""
    import importlib

    def worker(_):
        module = importlib.import_module(pyignite_dbapi.__name__)
        assert module.threadsafety > 0, "Module can not be used concurrently"

    run_threads(worker)
| |
| |
def test_concurrent_connect_use_close(module_level_threadsafety):
    """Every thread opens its own connection, queries through it, and closes it.

    The connection is used as a context manager (as elsewhere in this file) so
    it is released even when the query or assertion fails mid-thread; the
    previous explicit ``close()`` was skipped on failure, leaking the
    connection.
    """
    def worker(_):
        with pyignite_dbapi.connect(**TEST_CONNECT_KWARGS) as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                assert cur.fetchone() is not None

    run_threads(worker)
| |
| |
def test_shared_connection_per_thread_cursors(connection, connection_level_threadsafety):
    """One shared connection must serve an independent cursor per thread."""
    def worker(_):
        with connection.cursor() as cur:
            cur.execute("SELECT 1")
            fetched = cur.fetchone()
            assert fetched is not None

    run_threads(worker)
| |
| |
def test_concurrent_inserts_no_lost_writes(table, connection, connection_level_threadsafety):
    """Single-row inserts from many threads over one connection are not lost."""
    rows_per_thread = 50

    def worker(thread_id):
        base = thread_id * rows_per_thread
        with connection.cursor() as cur:
            for i in range(rows_per_thread):
                cur.execute(
                    f"INSERT INTO {table} (id, data) VALUES (?, ?)",
                    (base + i, f"v{thread_id}-{i}"),
                )

    run_threads(worker)

    # Every thread writes a disjoint id range, so nothing can collide and the
    # final row count must be exactly threads * rows-per-thread.
    with connection.cursor() as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        (count,) = cur.fetchone()
    assert count == NUM_THREADS * rows_per_thread
| |
| |
def test_concurrent_commit_and_rollback(table, module_level_threadsafety):
    """Half the threads commit, half rollback. Only committed rows appear."""
    committed_ids = []
    lock = threading.Lock()

    def worker(thread_id):
        should_commit = thread_id % 2 == 0
        with pyignite_dbapi.connect(**TEST_CONNECT_KWARGS) as conn:
            conn.autocommit = False
            with conn.cursor() as cur:
                cur.execute(f"INSERT INTO {table} (id, data) VALUES (?, ?)", (thread_id, "x"))
                if should_commit:
                    conn.commit()
                    with lock:
                        committed_ids.append(thread_id)
                else:
                    conn.rollback()

    run_threads(worker)

    def fetch_ids():
        with pyignite_dbapi.connect(**TEST_CONNECT_KWARGS) as conn:
            with conn.cursor() as cur:
                cur.execute(f"SELECT id FROM {table} ORDER BY id")
                return {row[0] for row in cur.fetchall()}

    # There is currently no mechanism to synchronize the observable timestamp
    # across multiple connections, so committed changes become visible
    # eventually, not necessarily immediately — hence the polling wait.
    expected = set(committed_ids)
    wait_for_condition(lambda: fetch_ids() == expected, interval=0.5)
| |
| |
def test_concurrent_fetchall_result_integrity(table, connection, connection_level_threadsafety):
    """Concurrent full fetches over a shared connection stay intact.

    Seeds the table once, then every thread reads the complete result set and
    verifies both the row count and that each row's payload still matches its
    id — i.e. rows from parallel fetches are neither interleaved nor
    corrupted.
    """
    rows_num = 200
    with connection.cursor() as cur:
        cur.executemany(
            f"INSERT INTO {table} (id, data) VALUES (?, ?)",
            [(i, f"val-{i}") for i in range(rows_num)],
        )

    def task(_):
        with connection.cursor() as cur:
            cur.execute(f"SELECT id, data FROM {table} ORDER BY id")
            rows = cur.fetchall()
            assert len(rows) == rows_num, f"Expected {rows_num} rows, got {len(rows)}"

            # Each row self-checks via its id; the previous enumerate() index
            # was never used.
            for rid, val in rows:
                assert val == f"val-{rid}", f"Corrupted row: id={rid}, val={val!r}"

    run_threads(task)
| |
| |
def test_cursor_description_thread_safety(table, connection, connection_level_threadsafety):
    """cursor.description stays correct when cursors run in parallel."""
    expected_names = {"ID", "DATA"}

    def worker(_):
        with connection.cursor() as cur:
            cur.execute(f"SELECT id, data FROM {table} LIMIT 1")
            description = cur.description
            assert description is not None
            col_names = {column[0] for column in description}
            assert col_names == expected_names, f"Unexpected columns: {col_names}"

    run_threads(worker)
| |
| |
def test_concurrent_executemany(table, connection, connection_level_threadsafety):
    """Batched inserts issued from many threads all land in the table."""
    rows_per_thread = 20

    def worker(thread_id):
        # Disjoint id ranges per thread, so batches can never collide.
        batch = [(thread_id * 1000 + i, f"{thread_id}-{i}") for i in range(rows_per_thread)]
        with connection.cursor() as cur:
            cur.executemany(f"INSERT INTO {table} (id, data) VALUES (?, ?)", batch)

    run_threads(worker)

    with connection.cursor() as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        count = cur.fetchone()[0]

    assert count == NUM_THREADS * rows_per_thread