# blob: b032997749327d44346b0c7214beaf487d6f3f77 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import sqlite3
import traceback
from src.common.constant import Config
# Root folder of the project data: taken from the "base_dir" environment
# variable when it is set, otherwise the current working directory.
base_folder_dir = os.environ["base_dir"] if "base_dir" in os.environ else os.getcwd()

# Folder that holds the ground-truth SQLite files used below.
base_dir = os.path.join(base_folder_dir, "img_data", "ground_truth")
print("local api running at {}".format(base_dir))

# Ground-truth files built from the summed score (sum score is better).
(
    tf_smt_file_NB101C10,
    tf_smt_file_NB201C10,
    tf_smt_file_NB201C100,
    tf_smt_file_NB201Img,
) = (
    os.path.join(base_dir, file_name)
    for file_name in (
        "TFMEM_101_c10_100run_8k_models_score_sum",
        "TFMEM_201_c10_100run_score_sum",
        "TFMEM_201_c100_100run_score_sum",
        "TFMEM_201_imgNet_100run_score_sum",
    )
)

# Rank-based alternatives, kept for reference (rank is not as good as sum):
# tf_smt_file_NB201C10 = os.path.join(base_dir, "TFMEM_201_c10_100run_rank_bugs")
# tf_smt_file_NB201C100 = os.path.join(base_dir, "TFMEM_201_c100_200run_rank")
# tf_smt_file_NB201Img = os.path.join(base_dir, "TFMEM_201_imgNet_200run_rank")

# Shared SQLite connection and cursor, opened on first use by fetch_from_db.
con = cur = None
# Fetch a pre-computed result from the simulated-experiment database.
# Open (connection, cursor) pairs keyed by database file path, so that every
# search-space / dataset pair is answered from its own ground-truth file
# instead of whichever file happened to be opened first.
_db_handles = {}


def _resolve_db_path(space_name, dataset):
    """Return the ground-truth SQLite file for a search-space / dataset pair."""
    if space_name == Config.NB201:
        if dataset == Config.c10:
            return tf_smt_file_NB201C10
        if dataset == Config.c100:
            return tf_smt_file_NB201C100
        if dataset == Config.imgNet:
            return tf_smt_file_NB201Img
        raise NotImplementedError(f"{dataset} is Not implemented")
    if space_name == Config.NB101:
        if dataset == Config.c10:
            return tf_smt_file_NB101C10
        raise NotImplementedError(f"{dataset} is Not implemented")
    raise NotImplementedError(f"{space_name} is Not implemented")


def fetch_from_db(space_name, dataset, run_id_m, N_m):
    """
    Look up one row of the ``simulateExp`` table in the ground-truth database.

    :param space_name: search space, compared against Config.NB101 / Config.NB201
    :param dataset: dataset, compared against Config.c10 / Config.c100 / Config.imgNet
    :param run_id_m: value matched against the ``run_num`` column (run_id, 100 max)
    :param N_m: value matched against the ``model_explored`` column
        (number of models evaluated)
    :return: tuple ``(arch_id, candidates, current_time)`` read from columns
        2, 3 and 4 of the matching row; ``candidates`` is JSON-decoded and
        ``current_time`` is converted to float
    :raises NotImplementedError: if the space / dataset pair has no database
    :raises ValueError: if no row matches or the row cannot be decoded
    """
    global con
    global cur

    db_path = _resolve_db_path(space_name, dataset)
    if db_path not in _db_handles:
        print(db_path)
        connection = sqlite3.connect(db_path)
        _db_handles[db_path] = (connection, connection.cursor())
    # Keep the module-level handles pointing at the database in use, for any
    # code that reads them directly.
    con, cur = _db_handles[db_path]

    # NOTE(review): the values are interpolated exactly as before; callers are
    # assumed to pass trusted numeric ids, not user-supplied text.
    res = cur.execute(
        "SELECT * FROM simulateExp WHERE run_num = {} and model_explored = {}".format(run_id_m, N_m))
    fetch_res = res.fetchone()
    if fetch_res is None:
        raise ValueError(f"res is None when using run_id ={run_id_m} and bm = {N_m}")

    try:
        arch_id = fetch_res[2]
        candidates = json.loads(fetch_res[3])
        current_time = float(fetch_res[4])
    except (IndexError, TypeError, ValueError) as err:
        print(traceback.format_exc())
        raise ValueError(
            f"cannot decode result row when using run_id ={run_id_m} and bm = {N_m}") from err
    return arch_id, candidates, current_time
if __name__ == "__main__":
    # Manual smoke check: look up run 3 with 10 explored models on NB201 / c10.
    sample_result = fetch_from_db(Config.NB201, Config.c10, 3, 10)
    print(sample_result)