| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from ast import literal_eval |
| from collections import OrderedDict |
| from itertools import product as itertools_product |
| from keras_model_arch_table import ModelArchSchema |
| import numpy as np |
| import plpy |
| from copy import deepcopy |
| |
| from madlib_keras_custom_function import CustomFunctionSchema |
| from madlib_keras_helper import generate_row_string |
| from madlib_keras_validator import MstLoaderInputValidator |
| from madlib_keras_wrapper import convert_string_of_args_to_dict |
| from madlib_keras_wrapper import parse_and_validate_fit_params |
| from madlib_keras_wrapper import parse_and_validate_compile_params |
| from utilities.control import MinWarning |
| from utilities.utilities import add_postfix, _assert, _assert_equal, extract_keyvalue_params |
| from utilities.utilities import quote_ident, get_schema, is_platform_pg |
| from utilities.validate_args import table_exists, drop_tables |
| |
| from tensorflow.keras import losses as losses |
| from tensorflow.keras import metrics as metrics |
| |
class ModelSelectionSchema:
    """Names shared by the model selection (mst) table loaders.

    Holds the column names of the model selection table and of its summary
    table, the keywords naming the supported search strategies, and the
    compile-params key under which optimizer choices are grouped.
    """

    # Columns of the model selection table.
    MST_KEY = 'mst_key'
    MODEL_ID = ModelArchSchema.MODEL_ID
    COMPILE_PARAMS = 'compile_params'
    FIT_PARAMS = 'fit_params'
    # SQL types; they line up with the four columns above in the order
    # (mst_key, model_id, compile_params, fit_params).
    col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')

    # Columns of the model selection summary table.
    MODEL_ARCH_TABLE = 'model_arch_table'
    OBJECT_TABLE = 'object_table'

    # Keywords naming the supported hyperparameter search strategies.
    GRID_SEARCH = 'grid'
    RANDOM_SEARCH = 'random'

    # Key of the compile params grid that groups optimizer related choices.
    OPTIMIZER_PARAMS_LIST = 'optimizer_params_list'
| |
@MinWarning("warning")
class MstLoader():
    """Utility class for loading a model selection table with model parameters.

    Takes every combination of the model ids, compile params and fit params
    passed in and stores each combination as one row of the model selection
    table. The inputs are checked by MstLoaderInputValidator during
    construction.

    Attributes:
        schema_madlib (str): Name of the MADlib schema.
        model_arch_table (str): The name of model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output summary
            table, built from model_selection_table with the "_summary"
            postfix.
        model_id_list (list): Sorted, de-duplicated model id choices.
        compile_params_list (list): De-duplicated compile params choices.
        fit_params_list (list): De-duplicated fit params choices.
        object_table (str): Object table name qualified with schema_madlib,
            or None when no object table was given.
        msts (list): The list of generated msts (one dict per table row).

    """

    def __init__(self,
                 schema_madlib,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table=None,
                 **kwargs):
        """Validate the inputs and generate all parameter combinations.

        Nothing is written to the database here; call load() for that.

        Args:
            schema_madlib (str): Name of the MADlib schema.
            model_arch_table (str): The name of model architecture table.
            model_selection_table (str): The name of the output mst table.
            model_id_list (list): Model id choices.
            compile_params_list (list): Compile params choices (strings).
            fit_params_list (list): Fit params choices (strings).
            object_table (str): Optional name of the object table; it is
                always qualified with schema_madlib.
            **kwargs: Accepted and ignored.
        """
        self.schema_madlib = schema_madlib
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))
        if object_table is not None:
            object_table = "{0}.{1}".format(schema_madlib,
                                            quote_ident(object_table))
        self.object_table = object_table

        # The validator receives the raw (not yet de-duplicated) param lists.
        MstLoaderInputValidator(
            schema_madlib=self.schema_madlib,
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_list,
            fit_params_list=fit_params_list,
            object_table=object_table
        )
        self.compile_params_list = self.params_preprocessed(
            compile_params_list)
        self.fit_params_list = self.params_preprocessed(fit_params_list)

        self.msts = []

        self.find_combinations()

    def load(self):
        """The entry point for loading the model selection table.

        Creates the mst table and its summary table, then inserts the
        generated rows.
        """
        # All of the side effects happen in this function.
        self.create_mst_table()
        self.create_mst_summary_table()
        self.insert_into_mst_table()

    def params_preprocessed(self, list_strs):
        """De-duplicate a list of parameter strings.

        Two strings count as duplicates when convert_string_of_args_to_dict()
        parses them into the same key/value pairs, regardless of the order in
        which the keys were written. For duplicates the last string seen is
        kept, at the position of the first occurrence.

        Args:
            list_strs (list): A list of strings.

        Returns:
            list: The de-duplicated strings, in input order.
        """
        # An OrderedDict keeps the output order deterministic; a plain dict
        # would hand the rows back in arbitrary order on Python 2.
        deduped = OrderedDict()
        for params_str in list_strs:
            parsed = convert_string_of_args_to_dict(params_str)
            signature = tuple('{0} = {1}'.format(key, parsed[key])
                              for key in sorted(parsed.keys()))
            deduped[signature] = params_str

        return list(deduped.values())

    def find_combinations(self):
        """Append every (model_id, compile_params, fit_params) combination
        to self.msts.

        Model ids vary slowest and fit params fastest.
        """
        param_names = (ModelSelectionSchema.MODEL_ID,
                       ModelSelectionSchema.COMPILE_PARAMS,
                       ModelSelectionSchema.FIT_PARAMS)
        for choice in itertools_product(self.model_id_list,
                                        self.compile_params_list,
                                        self.fit_params_list):
            self.msts.append(dict(zip(param_names, choice)))

    def create_mst_table(self):
        """Initialize the output mst table.
        """
        # On platforms other than PostgreSQL the table is explicitly created
        # with appendonly=false.
        with_query = "" if is_platform_pg() else "with(appendonly=false)"
        create_query = """
            CREATE TABLE {mst_table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            ) {with_query};
            """.format(mst_table=self.model_selection_table,
                       mst_key=ModelSelectionSchema.MST_KEY,
                       model_id=ModelSelectionSchema.MODEL_ID,
                       compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params=ModelSelectionSchema.FIT_PARAMS,
                       with_query=with_query)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Initialize the output mst summary table.
        """
        create_query = """
            CREATE TABLE {summary_table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
            """.format(summary_table=self.model_selection_summary_table,
                       model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert every thing in self.msts into the mst table, then record
        the model arch table and object table in the summary table.

        Values are passed as query parameters rather than spliced into the
        SQL text, so params containing '$$' or quotes cannot break the
        statement. Only the table and column names are formatted in.
        """
        if self.msts:
            # Prepared once and reused for every row.
            insert_plan = plpy.prepare(
                """
                INSERT INTO
                    {mst_table}(
                        {model_id_col},
                        {compile_params_col},
                        {fit_params_col}
                    )
                VALUES ($1, $2, $3)
                """.format(
                    mst_table=self.model_selection_table,
                    model_id_col=ModelSelectionSchema.MODEL_ID,
                    compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                    fit_params_col=ModelSelectionSchema.FIT_PARAMS),
                ["int4", "varchar", "varchar"])
            for mst in self.msts:
                plpy.execute(insert_plan,
                             [mst[ModelSelectionSchema.MODEL_ID],
                              mst[ModelSelectionSchema.COMPILE_PARAMS],
                              mst[ModelSelectionSchema.FIT_PARAMS]])

        # A None object_table is stored as SQL NULL.
        summary_plan = plpy.prepare(
            """
            INSERT INTO
                {summary_table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES ($1, $2)
            """.format(
                summary_table=self.model_selection_summary_table,
                model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                object_table_name=ModelSelectionSchema.OBJECT_TABLE),
            ["varchar", "varchar"])
        plpy.execute(summary_plan, [self.model_arch_table, self.object_table])
| |
@MinWarning("warning")
class MstSearch():
    """
    The utility class for generating model selection configs and loading
    them into a MST table with model parameters.

    Takes string representations of python dictionaries for the compile and
    fit params and generates configs with the chosen search algorithm.

    Attributes:
        schema_madlib (str): Name of the MADlib schema.
        model_arch_table (str): The name of model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output summary
            table, built from model_selection_table with the "_summary"
            postfix.
        model_id_list (list): Sorted, de-duplicated model id choices.
        compile_params_dict (dict): Parsed compile params choices.
        fit_params_dict (dict): Parsed fit params choices.
        search_type (str, default 'grid'): Hyperparameter search strategy,
            'grid' or 'random' (any prefix of either keyword is accepted).
        accepted_distributions (list): Distribution names usable with
            random search.

        Only for 'random' search type (defaults None):
            num_configs (int): Number of configs to generate.
            random_state (int): Seed for result reproducibility.

        object_table (str, default None): The name of the object table, for
            custom (metric) functions.
        msts (list): The list of generated msts (one dict per table row).

    """

    def __init__(self,
                 schema_madlib,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_grid,
                 fit_params_grid,
                 search_type='grid',
                 num_configs=None,
                 random_state=None,
                 object_table=None,
                 **kwargs):
        """Validate the inputs and generate the configs.

        Nothing is written to the model selection tables here; call load()
        for that.

        Args:
            schema_madlib (str): Name of the MADlib schema.
            model_arch_table (str): The name of model architecture table.
            model_selection_table (str): The name of the output mst table.
            model_id_list (list): Model id choices.
            compile_params_grid (str): String repr of a python dict mapping
                each compile param to a list of choices.
            fit_params_grid (str): String repr of a python dict mapping each
                fit param to a list of choices.
            search_type (str): 'grid' or 'random'.
            num_configs (int): Number of configs (random search only).
            random_state (int): Seed (random search only).
            object_table (str): Optional object table; it must either be
                unqualified or live in schema_madlib.
            **kwargs: Accepted and ignored.
        """
        self.schema_madlib = schema_madlib
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))

        if object_table is not None:
            schema_name = get_schema(object_table)
            if schema_name is None:
                object_table = "{0}.{1}".format(schema_madlib,
                                                quote_ident(object_table))
            elif schema_name != schema_madlib:
                plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))
        MstLoaderInputValidator(
            schema_madlib=self.schema_madlib,
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_grid,
            fit_params_list=fit_params_grid,
            object_table=object_table,
            module_name='generate_model_configs'
        )

        self.search_type = search_type
        self.num_configs = num_configs
        self.random_state = random_state
        self.object_table = object_table

        # Newlines and ALL spaces are dropped before parsing; this includes
        # spaces inside quoted values.
        compile_params_grid = compile_params_grid.replace('\n', '').replace(' ', '')
        fit_params_grid = fit_params_grid.replace('\n', '').replace(' ', '')
        self.accepted_distributions = ['linear', 'log', 'log_near_one']

        # extracting python dict
        try:
            self.compile_params_dict = literal_eval(compile_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'compile_params_dict'")
        try:
            self.fit_params_dict = literal_eval(fit_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'fit_params_dict'")
        self.validate_inputs(compile_params_grid, fit_params_grid)

        self.msts = []

        # validate_inputs() has already rejected any other search type.
        search_type_lower = self.search_type.lower()
        if ModelSelectionSchema.GRID_SEARCH.startswith(search_type_lower):
            self.find_grid_combinations()
        elif ModelSelectionSchema.RANDOM_SEARCH.startswith(search_type_lower):
            self.find_random_combinations()

        # param checks and validation
        compile_params_lst = [mst[ModelSelectionSchema.COMPILE_PARAMS]
                              for mst in self.msts]
        fit_params_lst = [mst[ModelSelectionSchema.FIT_PARAMS]
                          for mst in self.msts]
        self._validate_params_and_object_table(compile_params_lst,
                                               fit_params_lst)

    def load(self):
        """The entry point for loading the model selection table.

        Creates the mst table and/or the summary table when they do not
        exist yet. When both already exist, the model arch table stored in
        the summary table must match this loader's model arch table.
        """
        # All of the side effects happen in this function.
        if table_exists(self.model_selection_table):
            if table_exists(self.model_selection_summary_table):
                res = plpy.execute("SELECT {0} from {1}".format(
                    ModelSelectionSchema.MODEL_ARCH_TABLE,
                    self.model_selection_summary_table))
                # exactly one value
                for r in res:
                    _assert_equal(r['model_arch_table'], self.model_arch_table,
                                  "DL: Inconsistent model arch table. Use '{0}' if appending rows to '{1}'".format(
                                      r['model_arch_table'], self.model_selection_table
                                  ))
            else:
                self.create_mst_summary_table()
        else:
            self.create_mst_table()
            self.create_mst_summary_table()
        self.insert_into_mst_table()

    def validate_inputs(self, compile_params_grid, fit_params_grid):
        """
        Ensures validity of inputs related to grid and random search.

        :param compile_params_grid: The input string repr of compile params
            choices (with spaces and newlines already removed).
        :param fit_params_grid: The input string repr of fit params choices
            (with spaces and newlines already removed).
        """
        _assert(isinstance(self.compile_params_dict, dict),
                "DL: compile params must be specified as a dictionary")
        _assert(isinstance(self.fit_params_dict, dict),
                "DL: fit params must be specified as a dictionary")

        for c in self.compile_params_dict:
            _assert_equal(type(self.compile_params_dict[c]), list,
                          "DL: compile param values must be specified in a list")
        for f in self.fit_params_dict:
            _assert_equal(type(self.fit_params_dict[f]), list,
                          "DL: fit param values must be specified in a list")

        search_type_lower = self.search_type.lower()
        if ModelSelectionSchema.GRID_SEARCH.startswith(search_type_lower):
            _assert(self.num_configs is None and self.random_state is None,
                    "DL: 'num_configs' and 'random_state' must be NULL for grid search")
            for distribution_type in self.accepted_distributions:
                # If the distribution is used, it will be in the following format:
                # [123, 456, '<dist name>']
                # Matching single quotes and the closing bracket minimizes false
                # positives from the log_dir parameter.
                tmp_dist = "'{0}']".format(distribution_type)
                _assert(tmp_dist not in compile_params_grid and
                        tmp_dist not in fit_params_grid,
                        "DL: Cannot search from a distribution with grid search")
        elif ModelSelectionSchema.RANDOM_SEARCH.startswith(search_type_lower):
            _assert(self.num_configs is not None and self.num_configs > 0,
                    "DL: 'num_configs' cannot be NULL and "
                    "needs to be a natural number "
                    "for random search")
        else:
            plpy.error("DL: 'search_type' must be either 'grid' or 'random'")

        if ModelSelectionSchema.OPTIMIZER_PARAMS_LIST in self.compile_params_dict:
            optimizer_params_list = self.compile_params_dict[
                ModelSelectionSchema.OPTIMIZER_PARAMS_LIST]
            optimizer_param_keys = set(key
                                       for opt_params in optimizer_params_list
                                       for key in opt_params)
            # A param may be given inside 'optimizer_params_list' or at the
            # top level of the compile params, but not in both places.
            _assert(set(self.compile_params_dict).isdisjoint(optimizer_param_keys),
                    "DL: 'optimizer_params_list' key should only contain "
                    "'optimizer' and/or optimizer related params "
                    "and no such params should reside out of the key")
            for opt_params in optimizer_params_list:
                _assert(len(opt_params) != 0,
                        "DL: empty dictionaries cannot be specified in the "
                        "value list of 'optimizer_params_list'")

    def _validate_params_and_object_table(self, compile_params_lst, fit_params_lst):
        """Validate every generated fit/compile params string.

        Each fit params string must pass parse_and_validate_fit_params().
        Each compile params string must pass
        parse_and_validate_compile_params(), and its loss (and metrics, when
        present) must be either a builtin keras name or a function listed in
        the object table.

        :param compile_params_lst: Generated compile params strings.
        :param fit_params_lst: Generated fit params strings.
        """
        if not fit_params_lst:
            plpy.error("fit_params_list cannot be NULL")
        for fit_params in fit_params_lst:
            try:
                parse_and_validate_fit_params(fit_params)
            except Exception as e:
                plpy.error(
                    """Fit param check failed for: {0} \n
                    {1}
                    """.format(fit_params, str(e)))
        if not compile_params_lst:
            plpy.error( "compile_params_list cannot be NULL")
        custom_fn_name = []
        ## Initialize builtin loss/metrics functions
        builtin_losses = dir(losses)
        builtin_metrics = dir(metrics)
        # Default metrics, since it is not part of the builtin metrics list
        builtin_metrics.append('accuracy')
        if self.object_table is not None:
            res = plpy.execute("SELECT {0} from {1}".format(CustomFunctionSchema.FN_NAME,
                                                            self.object_table))
            for r in res:
                custom_fn_name.append(r[CustomFunctionSchema.FN_NAME])

        # The suffix only depends on the object table, so build it once.
        error_suffix = "but input object table missing!"
        if self.object_table is not None:
            error_suffix = "is not defined in object table '{0}'!".format(self.object_table)

        for compile_params in compile_params_lst:
            try:
                _, _, res = parse_and_validate_compile_params(
                    compile_params,
                    [ModelSelectionSchema.OPTIMIZER_PARAMS_LIST])
                # Validating if loss/metrics function called in compile_params
                # is either defined in object table or is a built_in keras
                # loss/metrics function
                _assert(res['loss'] in custom_fn_name or res['loss'] in builtin_losses,
                        "custom function '{0}' used in compile params " \
                        "{1}".format(res['loss'], error_suffix))
                if 'metrics' in res:
                    _assert((len(set(res['metrics']).intersection(custom_fn_name)) > 0
                             or len(set(res['metrics']).intersection(builtin_metrics)) > 0),
                            "custom function '{0}' used in compile params " \
                            "{1}".format(res['metrics'], error_suffix))

            except Exception as e:
                # Failed assertions above land here too and are reported
                # together with the offending compile params string.
                plpy.error(
                    """Compile param check failed for: {0} \n
                    {1}
                    """.format(compile_params, str(e)))

    def find_grid_combinations(self):
        """
        Finds combinations using grid search.
        """
        if ModelSelectionSchema.OPTIMIZER_PARAMS_LIST in self.compile_params_dict:
            # Every dict of 'optimizer_params_list' is expanded on its own
            # into all combinations of its values, then crossed with the
            # remaining compile params and the fit params.
            for opt_params_dict in self.compile_params_dict[ModelSelectionSchema.OPTIMIZER_PARAMS_LIST]:
                keys, values = zip(*opt_params_dict.items())
                opt_configs_params = [dict(zip(keys, v)) for v in itertools_product(*values)]
                copied_compile_dict = deepcopy(self.compile_params_dict)
                copied_compile_dict[ModelSelectionSchema.OPTIMIZER_PARAMS_LIST] = opt_configs_params
                self.grid_combinations_helper(copied_compile_dict, self.fit_params_dict)
        else:
            self.grid_combinations_helper(self.compile_params_dict, self.fit_params_dict)

    def grid_combinations_helper(self, compile_dict, fit_dict):
        """
        Appends to self.msts every combination of model id, compile params
        and fit params.

        :param compile_dict: Dictionary of compile params choices.
        :param fit_dict: Dictionary of fit params choices.
        """
        # Compile params, fit params and model ids are crossed in one go. If
        # a key exists in both dicts, the fit choices win here.
        combined_dict = dict(compile_dict, **fit_dict)
        combined_dict[ModelSelectionSchema.MODEL_ID] = self.model_id_list
        keys, values = zip(*combined_dict.items())
        all_configs_params = [dict(zip(keys, v)) for v in itertools_product(*values)]

        # to separate the compile and fit configs
        for config in all_configs_params:
            combination = {}
            compile_configs, fit_configs = {}, {}
            for k in config:
                if k == ModelSelectionSchema.MODEL_ID:
                    combination[ModelSelectionSchema.MODEL_ID] = config[k]
                elif k in compile_dict:
                    compile_configs[k] = config[k]
                elif k in fit_dict:
                    fit_configs[k] = config[k]
                else:
                    plpy.error("DL: {0} is an unidentified key".format(k))
            combination[ModelSelectionSchema.COMPILE_PARAMS] = generate_row_string(compile_configs)
            combination[ModelSelectionSchema.FIT_PARAMS] = generate_row_string(fit_configs)
            self.msts.append(combination)

    def _reseed(self, seed_changes):
        """
        Re-seeds numpy's global random generator when a random_state was given.

        :param seed_changes: Number of re-seeds done so far.
        :return: updated seed_changes (unchanged when no random_state is set).
        """
        # 'is not None' rather than truthiness: 0 is a valid seed.
        if self.random_state is not None:
            np.random.seed(self.random_state + seed_changes)
            seed_changes += 1
        return seed_changes

    def find_random_combinations(self):
        """
        Finds combinations using random search.
        """
        seed_changes = 0

        for _ in range(self.num_configs):
            combination = {}
            seed_changes = self._reseed(seed_changes)
            combination[ModelSelectionSchema.MODEL_ID] = np.random.choice(self.model_id_list)
            compile_dict, seed_changes = self.generate_param_config(self.compile_params_dict, seed_changes)
            combination[ModelSelectionSchema.COMPILE_PARAMS] = generate_row_string(compile_dict)
            fit_dict, seed_changes = self.generate_param_config(self.fit_params_dict, seed_changes)
            combination[ModelSelectionSchema.FIT_PARAMS] = generate_row_string(fit_dict)
            self.msts.append(combination)

    def generate_param_config(self, params_dict, seed_changes):
        """
        Generating a parameter configuration for random search.
        :param params_dict: Dictionary of params choices.
        :param seed_changes: Changes in seed for random sampling + reproducibility.
        :return: config_dict (one sampled value per param), seed_changes.
        """
        config_dict = {}
        for cp in params_dict:
            seed_changes = self._reseed(seed_changes)
            param_values = params_dict[cp]
            if cp == ModelSelectionSchema.OPTIMIZER_PARAMS_LIST:
                # Pick one of the optimizer dicts first, then sample a value
                # for each of its params.
                opt_dict = np.random.choice(param_values)
                opt_combination = {}
                for i in opt_dict:
                    opt_values = opt_dict[i]
                    seed_changes = self._reseed(seed_changes)
                    opt_combination[i] = self.sample_val(i, opt_values)
                config_dict[cp] = opt_combination
            else:
                config_dict[cp] = self.sample_val(cp, param_values)
        return config_dict, seed_changes

    def sample_val(self, cp, param_value_list):
        """
        Samples a value from a given list of values, either randomly from a list of discrete elements,
        or from a specified distribution.
        :param cp: compile param
        :param param_value_list: list of values (or specified distribution) for a param
        :return: sampled value
        """
        _assert(len(param_value_list) > 0,
                "DL: '{0}' must have at least one value to sample from".format(cp))

        # A distribution is written as [lower_bound, upper_bound, 'name']:
        # more than one element, a string in last position, and no string or
        # callable before it. The length is checked first so that the last
        # element is only read from a non-empty list.
        from_distribution = (
            len(param_value_list) > 1
            and isinstance(param_value_list[-1], str)
            and all(not isinstance(v, str) and not callable(v)
                    for v in param_value_list[:-1]))
        if not from_distribution:
            # random sampling
            return np.random.choice(param_value_list)

        _assert_equal(len(param_value_list), 3,
                      "DL: '{0}' should have exactly 3 elements if picking from a distribution".format(cp))
        lower_bound, upper_bound, distribution = param_value_list
        _assert(upper_bound > lower_bound,
                "DL: '{0}' should be of the format [lower_bound, upper_bound, distribution_type]".format(cp))
        if distribution == 'linear':
            return np.random.uniform(lower_bound, upper_bound)
        elif distribution == 'log':
            # log10 is only defined for positive bounds.
            _assert(lower_bound > 0,
                    "DL: '{0}' needs a lower bound greater than 0 for the 'log' distribution".format(cp))
            return np.power(10, np.random.uniform(np.log10(lower_bound),
                                                  np.log10(upper_bound)))
        elif distribution == 'log_near_one':
            # log10(1 - bound) is only defined for bounds below 1.
            _assert(upper_bound < 1,
                    "DL: '{0}' needs an upper bound less than 1 for the 'log_near_one' distribution".format(cp))
            return 1.0 - np.power(10, np.random.uniform(np.log10(1.0 - upper_bound),
                                                        np.log10(1.0 - lower_bound)))
        else:
            plpy.error("DL: Please choose a valid distribution type for '{0}': {1}".format(
                cp, self.accepted_distributions))

    def create_mst_table(self):
        """Initialize the output mst table, if it doesn't exist (for incremental loading).
        """
        # On platforms other than PostgreSQL the table is explicitly created
        # with appendonly=false.
        with_query = "" if is_platform_pg() else "with(appendonly=false)"
        create_query = """
            CREATE TABLE {mst_table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            ) {with_query};
            """.format(mst_table=self.model_selection_table,
                       mst_key=ModelSelectionSchema.MST_KEY,
                       model_id=ModelSelectionSchema.MODEL_ID,
                       compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params=ModelSelectionSchema.FIT_PARAMS,
                       with_query=with_query)

        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Initialize the output mst summary table.
        """
        create_query = """
            CREATE TABLE {summary_table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
            """.format(summary_table=self.model_selection_summary_table,
                       model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert every thing in self.msts into the mst table, then record
        the model arch table and object table in the summary table.

        Values are passed as query parameters rather than spliced into the
        SQL text, so params containing '$$' or quotes cannot break the
        statement. Only the table and column names are formatted in.
        """
        if self.msts:
            # Prepared once and reused for every row.
            insert_plan = plpy.prepare(
                """
                INSERT INTO
                    {mst_table}(
                        {model_id_col},
                        {compile_params_col},
                        {fit_params_col}
                    )
                VALUES ($1, $2, $3)
                """.format(
                    mst_table=self.model_selection_table,
                    model_id_col=ModelSelectionSchema.MODEL_ID,
                    compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                    fit_params_col=ModelSelectionSchema.FIT_PARAMS),
                ["int4", "varchar", "varchar"])
            for mst in self.msts:
                # Random search yields numpy integers; hand a plain int over.
                plpy.execute(insert_plan,
                             [int(mst[ModelSelectionSchema.MODEL_ID]),
                              mst[ModelSelectionSchema.COMPILE_PARAMS],
                              mst[ModelSelectionSchema.FIT_PARAMS]])

        # NOTE(review): a summary row is appended on every call, including
        # incremental loads into an already existing summary table - confirm
        # that more than one row in the summary table is intended.
        # A None object_table is stored as SQL NULL.
        summary_plan = plpy.prepare(
            """
            INSERT INTO
                {summary_table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES ($1, $2)
            """.format(
                summary_table=self.model_selection_summary_table,
                model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                object_table_name=ModelSelectionSchema.OBJECT_TABLE),
            ["varchar", "varchar"])
        plpy.execute(summary_plan, [self.model_arch_table, self.object_table])