blob: 607c65ac5fdb6adc3884d9f27fac96dcf75552fc [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import abstractmethod
from typing import Generator, Optional, Tuple

from torch.utils.data import DataLoader

from src.search_space.core.model_params import ModelMacroCfg, ModelMicroCfg
class SpaceWrapper:
    """
    Base interface that every search-space wrapper implements.

    It stores the space's macro configuration (``self.model_cfg``) and a
    name (``self.name``), and declares hooks for sampling architectures,
    (de)serializing architecture encodings, building models, and profiling.
    Apart from ``__init__`` and ``update_bn_flag``, every method here raises
    ``NotImplementedError`` and must be provided by a subclass.

    NOTE(review): this class does not inherit from ``abc.ABC`` (nor use
    ``abc.ABCMeta``), so the ``@abstractmethod`` decorators below are NOT
    enforced at instantiation time; they only document intent. Turning the
    class into an ABC would make any subclass that leaves one of these
    methods unimplemented fail to instantiate, so audit subclasses first.
    """

    def __init__(self, cfg: ModelMacroCfg, name: str):
        """
        Args:
            cfg: macro-level configuration of the search space; stored as
                ``self.model_cfg``.
            name: name of this space wrapper; stored as ``self.name``.
        """
        self.model_cfg = cfg
        self.name = name

    @abstractmethod
    def sample_all_models(self) -> Generator[str, None, None]:
        """
        Enumerate every architecture in the space.

        Returns:
            A generator yielding architecture ids (strings).
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Serialize and deserialize
    # ------------------------------------------------------------------ #

    @classmethod
    def serialize_model_encoding(cls, arch_micro: ModelMicroCfg) -> str:
        """
        Convert a micro-level architecture configuration to a string encoding.

        Args:
            arch_micro: micro setting of one architecture.
        Returns:
            The string encoding of ``arch_micro``.
        """
        raise NotImplementedError

    @classmethod
    def deserialize_model_encoding(cls, model_encoding) -> ModelMicroCfg:
        """
        Parse a model encoding back into a micro-level configuration.

        Presumably this is the inverse of ``serialize_model_encoding``;
        confirm against subclasses.

        Args:
            model_encoding: encoded architecture.
        Returns:
            The corresponding ``ModelMicroCfg``.
        """
        raise NotImplementedError

    @classmethod
    def new_arch_scratch(cls, arch_macro: ModelMacroCfg, arch_micro: ModelMicroCfg, bn: bool = True):
        """
        Build a new architecture from explicit macro and micro settings.

        Args:
            arch_macro: macro setting for one architecture.
            arch_micro: micro setting for one architecture.
            bn: boolean flag (presumably enables batch normalization;
                confirm in subclasses).
        Returns:
            The constructed model; its type is defined by subclasses.
        """
        raise NotImplementedError

    def new_arch_scratch_with_default_setting(self, model_encoding: str, bn: bool):
        """
        Build a new architecture using this search space's own macro setting.

        Args:
            model_encoding: string encoding of the model.
            bn: boolean flag (presumably enables batch normalization;
                confirm in subclasses).
        Returns:
            The constructed model; its type is defined by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def load(self):
        """
        Load the API this search space depends on.
        """
        raise NotImplementedError

    @abstractmethod
    def profiling(self, dataset: str,
                  train_loader: Optional[DataLoader] = None, val_loader: Optional[DataLoader] = None,
                  args=None, is_simulate: bool = False) -> Tuple[float, float, int]:
        """
        Profile the training and scoring time.

        Args:
            dataset: dataset name.
            train_loader: optional training data loader.
            val_loader: optional validation data loader.
            args: extra configuration, interpreted by subclasses.
            is_simulate: whether to simulate rather than run the profiling.
        Returns:
            A ``(float, float, int)`` tuple. Presumably the scoring time,
            the training time and a count; confirm in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def micro_to_id(self, arch_struct: ModelMicroCfg) -> str:
        """
        Convert a micro-level architecture configuration to its architecture id.

        Args:
            arch_struct: micro setting of one architecture.
        Returns:
            The architecture id string.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Initialize new architectures
    # ------------------------------------------------------------------ #

    @abstractmethod
    def new_architecture(self, arch_id: str):
        """
        Generate an architecture from its id.

        Args:
            arch_id: architecture id.
        Returns:
            The generated architecture; its type is defined by subclasses.
        """
        raise NotImplementedError

    def new_architecture_with_micro_cfg(self, arch_micro: ModelMicroCfg):
        """
        Generate an architecture from a micro-level configuration.

        Args:
            arch_micro: micro setting of one architecture.
        Returns:
            The generated architecture; its type is defined by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """
        Return how many architectures the space contains.
        """
        raise NotImplementedError

    @abstractmethod
    def get_arch_size(self, architecture):
        """
        Get how many edges there are in each cell of the architecture.

        Args:
            architecture: architecture to measure.
        """
        raise NotImplementedError

    def update_bn_flag(self, bn: bool):
        """
        Overwrite the ``bn`` flag on this space's macro configuration.

        Args:
            bn: new value assigned to ``self.model_cfg.bn``.
        """
        self.model_cfg.bn = bn

    # ------------------------------------------------------------------ #
    # Hooks for integrating the space with the various samplers
    # ------------------------------------------------------------------ #

    def random_architecture_id(self) -> Tuple[str, ModelMicroCfg]:
        """
        Randomly sample an architecture.

        The original notes say this supports the "RN, RL, R" samplers;
        presumably those are sampler names, to be confirmed.

        Returns:
            A ``(arch_id, arch_micro)`` tuple.
        """
        raise NotImplementedError

    def mutate_architecture(self, parent_arch: ModelMicroCfg) -> Tuple[str, ModelMicroCfg]:
        """
        Mutate an architecture; this supports the EA sampler.

        Args:
            parent_arch: micro setting of the architecture to mutate.
        Returns:
            A ``(arch_id, arch_micro)`` tuple for the mutated architecture.
        """
        raise NotImplementedError

    def get_reinforcement_learning_policy(self, lr_rate):
        """
        Build the policy used by the reinforcement-learning sampler.

        Args:
            lr_rate: presumably the policy's learning rate; confirm in
                subclasses.
        Returns:
            The policy object; its type is defined by subclasses.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # In-RDBMS helper functions
    # ------------------------------------------------------------------ #

    def profiling_score_time(self, dataset: str, train_loader: Optional[DataLoader] = None,
                             val_loader: Optional[DataLoader] = None,
                             args=None, is_simulate: bool = False) -> float:
        """
        Profile the scoring time.

        Args:
            dataset: dataset name.
            train_loader: optional training data loader.
            val_loader: optional validation data loader.
            args: extra configuration, interpreted by subclasses.
            is_simulate: whether to simulate rather than run the profiling.
        Returns:
            The profiled scoring time.
        """
        raise NotImplementedError

    def profiling_train_time(self, dataset: str, train_loader: Optional[DataLoader] = None,
                             val_loader: Optional[DataLoader] = None,
                             args=None, is_simulate: bool = False) -> float:
        """
        Profile the training time.

        Args:
            dataset: dataset name.
            train_loader: optional training data loader.
            val_loader: optional validation data loader.
            args: extra configuration, interpreted by subclasses.
            is_simulate: whether to simulate rather than run the profiling.
        Returns:
            The profiled training time.
        """
        raise NotImplementedError