| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
# Entry-point script for model selection: runs an online model search within a time budget.
| |
| import calendar |
| import os |
| import time |
| from src.common.constant import Config |
| from src.dataset_utils.structure_data_loader import libsvm_dataloader |
| from exps.shared_args import parse_arguments |
| |
| |
def generate_data_loader():
    """Build the train/validation/test loaders for ``args.dataset``.

    Reads the module-level ``args`` (assigned in the ``__main__`` block) and,
    for image datasets, the lazily imported ``dataset`` module.

    :return: tuple ``(train_loader, val_loader, test_loader, class_num)``.
        For ``Config.c10`` / ``Config.c100`` / ``Config.imgNet`` the loaders
        and class count come from ``dataset.get_dataloader`` and the
        validation loader doubles as the test loader. Any other dataset is
        loaded with ``libsvm_dataloader`` and ``class_num`` is taken from
        ``args.num_labels``.
    """
    image_datasets = (Config.c10, Config.c100, Config.imgNet)

    if args.dataset not in image_datasets:
        # Structured data is expected under <base_dir>/data/structure_data/<dataset>.
        structured_dir = os.path.join(args.base_dir, "data", "structure_data", args.dataset)
        train_dl, valid_dl, test_dl = libsvm_dataloader(
            args=args,
            data_dir=structured_dir,
            nfield=args.nfield,
            batch_size=args.batch_size,
        )
        return train_dl, valid_dl, test_dl, args.num_labels

    train_dl, valid_dl, num_classes = dataset.get_dataloader(
        train_batch_size=args.batch_size,
        test_batch_size=args.batch_size,
        dataset=args.dataset,
        num_workers=1,
        datadir=os.path.join(args.base_dir, "data"),
    )
    # No separate test split is built for image datasets; reuse validation.
    return train_dl, valid_dl, valid_dl, num_classes
| |
| |
def run_with_time_budget(time_budget: float):
    """Run online model selection for ``args.dataset`` under a time budget.

    Side effect: overwrites ``args.num_labels`` with the class count reported
    by ``generate_data_loader``.

    :param time_budget: the given time budget, in seconds; passed straight to
        ``RunModelSelection.select_model_online``.
    :return: the first element of the 8-tuple returned by
        ``select_model_online`` (the selected architecture).
    """
    train_dl, valid_dl, test_dl, num_classes = generate_data_loader()
    args.num_labels = num_classes
    loaders = [train_dl, valid_dl, test_dl]

    selector = RunModelSelection(args.search_space, args, is_simulate=False)
    outcome = selector.select_model_online(time_budget, loaders)
    # Strict 8-way unpack: only the architecture is used by this script.
    best_arch, _, _, _, _, _, _, _ = outcome
    return best_arch
| |
| |
if __name__ == "__main__":
    args = parse_arguments()

    # Log file is named "<log_name>_<unix seconds>.log". setdefault leaves any
    # value already present in the environment untouched.
    unix_seconds = calendar.timegm(time.gmtime())
    os.environ.setdefault("log_file_name", args.log_name + "_" + str(unix_seconds) + ".log")
    os.environ.setdefault("base_dir", args.base_dir)

    # These imports are deliberately deferred until after the environment
    # variables above are set. Presumably the modules read log_file_name /
    # base_dir at import time (TODO confirm). They also bind the module-level
    # names RunModelSelection and dataset that the functions above use.
    from src.eva_engine.run_ms import RunModelSelection
    from src.dataset_utils import dataset

    run_with_time_budget(args.budget)