blob: 9189eb7c2af2a9abb566cc7cbfcedf5b140ab491 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import result
import mysql.connector
from typing import List, Tuple
import re
class SQLExecutor:
def __init__(self, user: str, password: str, host: str, port: int, database: str) -> None:
self.connection = mysql.connector.connect(
user=user,
password=password,
host=host,
port=port,
database=database
)
self.cursor = self.connection.cursor()
self.wait_fetch_time_index = 4
def execute_query(self, query: str, parameters: Tuple | None) -> List[Tuple]:
if parameters:
self.cursor.execute(query, parameters)
else:
self.cursor.execute(query)
results = self.cursor.fetchall()
return results
def get_execute_time(self, query: str) -> float:
self.execute_query(query, None)
profile = self.execute_query("show query profile\"\"", None)
return self.get_n_ms(profile[0][self.wait_fetch_time_index])
def get_n_ms(self, t: str):
res = re.search(r"(\d+h)*(\d+min)*(\d+s)*(\d+ms)", t)
if res is None:
raise Exception(f"invalid time {t}")
n = 0
h = res.group(1)
if h is not None:
n += int(h.replace("h", "")) * 60 * 60 * 1000
min = res.group(2)
if min is not None != 0:
n += int(min.replace("min", "")) * 60 * 1000
s = res.group(3)
if s is not None != 0:
n += int(s.replace("s", "")) * 1000
ms = res.group(4)
if len(ms) != 0:
n += int(ms.replace("ms", ""))
return n
def execute_many_queries(self, queries: List[Tuple[str, Tuple]]) -> List[List[Tuple]]:
results = []
for query, parameters in queries:
result = self.execute_query(query, parameters)
results.append(result)
return results
def get_plan_with_cost(self, query: str):
result = self.execute_query(f"explain optimized plan {query}", None)
cost = float(result[0][0].replace("cost = ", ""))
plan = "".join([s[0] for s in result[1:]])
return plan, cost
def commit(self) -> None:
self.connection.commit()
def rollback(self) -> None:
self.connection.rollback()
def close(self) -> None:
self.cursor.close()
self.connection.close()