blob: a2d673f9218c589238d59329492718b155c215bb [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import shlex
import subprocess
import sys
import time

import requests
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# MySQL client command used to run the SQL files. Adjust the host, port, user
# name, password and database name to match your cluster.
mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G"
# Base URL of the FE HTTP server, in the form http://host:port.
feHttp = "http://localhost:8030"
# Manager REST endpoints; the trailing "{}" is filled in later via str.format()
# with a trace id or a query id respectively.
_manager_query_api = feHttp + "/rest/v2/manager/query"
trace_url = _manager_query_api + "/trace_id/{}"
qerror_url = _manager_query_api + "/qerror/{}"
# File the test results are appended to. Each entry is the query number on
# one line followed by the JSON returned by the q-error endpoint, e.g.:
#
#   8
#   {
#     "legacyPlanIdToPhysicalPlan": {
#       "0": {
#         "first": 1.0,
#         "second": 1.0
#       },
#       ...
#     "qError": 34.5
#   }
#
# Here `8` stands for q8 of the TPC-H suite. For the plan node whose plan id
# is 0, `first` is the estimated row count and `second` is the number of rows
# actually returned.
qerr_saved_file_path = ""
# Directory whose *.sql files are executed (TPC-H / TPC-DS / SSB queries).
original_sql_dir = "add your tpc-h/tpch-ds/ssb sql directory path here"
# Statements prepended to every query before it is traced; "{}" receives the
# query number, which is also used as the trace id.
sql_file_prefix_for_trace = """
SET enable_nereids_planner=true;
SET session_context='trace_id:{}';
"""
# q-error values collected so far, one float per traced query.
q_err_list = list()
def extract_number(string):
    """Return the integer formed by concatenating every digit in *string*.

    For example ``"q8.sql"`` yields 8 and ``"a1b2"`` yields 12. A string
    without any digit raises ValueError (``int("")``).
    """
    digits = "".join(ch for ch in string if ch.isdigit())
    return int(digits)
def write_results(path: str, title: str, result: list):
    """Append a titled section to the text file at *path*.

    The section consists of *title* followed by a newline, then one line per
    element of *result* (``str(item)`` plus a trailing space), and finally an
    empty line. The file is created if it does not exist yet.
    """
    with open(path, "a") as out:
        out.write(title + "\n")
        out.writelines(str(item) + " \n" for item in result)
        out.write("\n")
def read_lines(path: str) -> list:
    """Return all lines of the text file at *path*, line terminators kept."""
    with open(path, "r") as handle:
        lines = list(handle)
    return lines
def write_result(title: str, result: str):
    """Append a single *result* under *title* to ``qerr_saved_file_path``."""
    write_results(qerr_saved_file_path, title, [result])
def execute_command(cmd: str):
    """Run *cmd* through the shell and return its captured stdout as bytes.

    A non-zero exit status does not raise (the run continues, as before), but
    the exit code, the command and its stderr are reported on this process's
    stderr so that failures, e.g. the MySQL client rejecting a connection,
    are no longer silently lost.

    NOTE: ``shell=True`` is used because callers pass pipelines and
    redirections; *cmd* must only be built from trusted configuration.
    """
    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        err_text = result.stderr.decode("utf-8", errors="replace").strip()
        print(
            "command failed with exit code {}: {}\n{}".format(result.returncode, cmd, err_text),
            file=sys.stderr,
        )
    return result.stdout
def execute_sql(sql_file: str):
    """Run the statements in *sql_file* through the MySQL client.

    The file is redirected into ``mycli_cmd`` via the shell, and the client's
    stdout is returned as text.
    """
    # Quote the path so directories containing spaces or other shell
    # metacharacters are passed through as a single redirection target.
    command = mycli_cmd + " < " + shlex.quote(sql_file)
    # Callers in this script discard the output; replace undecodable bytes
    # instead of aborting the whole benchmark run.
    return execute_command(command).decode("utf-8", errors="replace")
def get_q_error(trace_id, timeout: float = 30.0):
    """Fetch the q-error report of the query tagged with *trace_id* and record it.

    The FE REST API is queried twice: first to map the trace id to a query
    id, then to fetch that query's q-error report. The raw report is appended
    to ``qerr_saved_file_path`` (via ``write_result``) and echoed to stdout,
    and its top-level ``qError`` value is appended to ``q_err_list``.

    Args:
        trace_id: Trace id set through ``session_context`` when the query ran.
        timeout: Seconds to wait for each HTTP request before giving up.

    Raises:
        requests.HTTPError: if either request returns an error status.
        RuntimeError: if the FE returns no query id for *trace_id*.
    """
    # NOTE(review): presumably gives the FE time to register the finished
    # query before it is looked up by trace id -- confirm.
    time.sleep(1)
    # 'YWRtaW46' is the base64 encoded result for 'admin:'
    headers = {'Authorization': 'BASIC YWRtaW46'}
    # Without a timeout, requests may wait forever on an unresponsive FE.
    trace_resp = requests.get(trace_url.format(trace_id), headers=headers, timeout=timeout)
    trace_resp.raise_for_status()
    query_id = json.loads(trace_resp.text).get("data")
    if query_id is None:
        # Otherwise the next request would silently target ".../qerror/None".
        raise RuntimeError("no query id found for trace id {}: {}".format(trace_id, trace_resp.text))
    qerr_resp = requests.get(qerror_url.format(query_id), headers=headers, timeout=timeout)
    qerr_resp.raise_for_status()
    resp_text = qerr_resp.text
    write_result(str(trace_id), resp_text)
    print(trace_id)
    print(resp_text)
    qerr = json.loads(resp_text)["qError"]
    q_err_list.append(float(qerr))
def iterates_sqls(path: str, if_write_results: bool) -> list:
    """Run every ``*.sql`` file in *path*, ordered by the number in its name.

    When *if_write_results* is false, each file is simply executed (a warm-up
    pass). When it is true, each query is copied into a temporary
    ``<file>.traced`` file prefixed with ``sql_file_prefix_for_trace`` (using
    the query number as trace id), executed, and its q-error is fetched via
    ``get_q_error``. The temporary file is always removed afterwards.

    Returns:
        ``cost_times``, which nothing in this function populates, so it is
        always an empty list; kept for interface compatibility.
    """
    cost_times = []
    # Filter before sorting: extract_number raises ValueError on names
    # without digits, and non-.sql entries (README, leftovers) may have none.
    sql_files = sorted(
        (name for name in os.listdir(path) if name.endswith(".sql")),
        key=extract_number,
    )
    for filename in sql_files:
        filepath = os.path.join(path, filename)
        sql_num = extract_number(filename)
        print("sql num" + str(sql_num))
        if not if_write_results:
            execute_sql(filepath)
            continue
        traced_sql_file = filepath + ".traced"
        # write_results() appends, so drop any file left behind by an
        # interrupted earlier run; otherwise its statements would run twice.
        if os.path.exists(traced_sql_file):
            os.remove(traced_sql_file)
        content = read_lines(filepath)
        write_results(traced_sql_file, sql_file_prefix_for_trace.format(sql_num), content)
        try:
            execute_sql(traced_sql_file)
            get_q_error(sql_num)
        finally:
            # Clean up even if the query or the REST lookup fails.
            os.remove(traced_sql_file)
    return cost_times
if __name__ == '__main__':
    # Enable the Nereids planner globally and forbid fallback, using the same
    # configurable client command (host, port, user, database) as the queries
    # themselves instead of a separate hard-coded connection.
    execute_command("echo 'set global enable_nereids_planner=true' | " + mycli_cmd)
    execute_command("echo 'set global enable_fallback_to_original_planner=false' | " + mycli_cmd)
    print("Preparing")
    # First pass: execute every query once without tracing.
    iterates_sqls(original_sql_dir, False)
    print("Started...")
    # Second pass: trace each query and collect its q-error.
    iterates_sqls(original_sql_dir, True)
    if q_err_list:
        write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / len(q_err_list)])
    else:
        # Avoid ZeroDivisionError when no .sql file produced a q-error.
        print("No q-error values were collected; skipping AVG")