| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from collections import OrderedDict |
| from madlib_keras_validator import MstLoaderInputValidator |
| from utilities.control import MinWarning |
| from utilities.utilities import add_postfix |
| from madlib_keras_wrapper import convert_string_of_args_to_dict |
| from keras_model_arch_table import ModelArchSchema |
| |
class ModelSelectionSchema:
    """Column names and SQL column types used for the model selection tables."""

    # Column names of the model selection (mst) table.
    MST_KEY = 'mst_key'
    # The model id column reuses the name defined by the model arch schema.
    MODEL_ID = ModelArchSchema.MODEL_ID
    COMPILE_PARAMS = 'compile_params'
    FIT_PARAMS = 'fit_params'

    # Column names of the model selection summary table.
    MODEL_ARCH_TABLE = 'model_arch_table'
    OBJECT_TABLE = 'object_table'

    # SQL types; they line up, in order, with the mst_key, model_id,
    # compile_params and fit_params columns declared in create_mst_table.
    col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')
| |
@MinWarning("warning")
class MstLoader(object):
    """Utility class that loads a model selection table with model parameters.

    Currently the loader takes every combination (the cross product) of the
    distinct model ids, compile params and fit params passed in. The inputs
    are handed to MstLoaderInputValidator during construction.

    Attributes:
        compile_params_list (list): The de-duplicated compile params choices.
        fit_params_list (list): The de-duplicated fit params choices.
        model_id_list (list): The sorted, de-duplicated model id choices.
        model_arch_table (str): The name of the model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output summary
            table, built with add_postfix(model_selection_table, "_summary").
        object_table (str): The object table name recorded in the summary
            table, or None.
        msts (list): The list of generated msts, one dict per combination.

    """

    def __init__(self,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table=None,
                 **kwargs):
        """Validate the inputs and generate all parameter combinations.

        No table is created or written here; call load() for that.

        Args:
            model_arch_table (str): The name of the model architecture table.
            model_selection_table (str): The name of the output mst table.
            model_id_list (iterable): The model id choices.
            compile_params_list (list): Compile params choices (strings).
            fit_params_list (list): Fit params choices (strings).
            object_table (str): Optional object table name; defaults to None.
            **kwargs: Accepted for caller compatibility and ignored.
        """
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))
        self.object_table = object_table
        # The validator is instantiated only for its checks; the raw
        # (not yet de-duplicated) params lists are what gets validated.
        MstLoaderInputValidator(
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_list,
            fit_params_list=fit_params_list,
            object_table=object_table
        )
        self.compile_params_list = self.params_preprocessed(
            compile_params_list)
        self.fit_params_list = self.params_preprocessed(fit_params_list)

        self.msts = []

        self.find_combinations()

    def load(self):
        """The entry point for loading the model selection table.

        Creates the mst table and its summary table, then inserts the
        generated rows into both.
        """
        # All of the side effects happen in this function.
        self.create_mst_table()
        self.create_mst_summary_table()
        self.insert_into_mst_table()

    def params_preprocessed(self, list_strs):
        """De-duplicate a list of parameter strings.

        Two strings are duplicates when convert_string_of_args_to_dict
        yields the same key/value pairs for both, regardless of the order
        in which the pairs are written.

        Args:
            list_strs (list): A list of parameter strings.

        Returns:
            list: The distinct parameter strings, in order of first
                appearance. When duplicates exist, the last spelling seen
                is the one that is kept.
        """
        # OrderedDict keeps the result order deterministic, so the same
        # input always produces the same sequence of msts.
        dict_dedup = OrderedDict()
        for string in list_strs:
            d = convert_string_of_args_to_dict(string)
            # Canonical, order-independent key for this set of arguments.
            hash_tuple = tuple(
                '{0} = {1}'.format(key, d[key]) for key in sorted(d))
            dict_dedup[hash_tuple] = string

        return list(dict_dedup.values())

    def find_combinations(self):
        """Append every combination of the parameter choices to self.msts.

        Each mst is a dict keyed by the model id, compile params and fit
        params column names. Model id varies slowest and fit params
        fastest.
        """
        from itertools import product

        param_grid = OrderedDict([
            (ModelSelectionSchema.MODEL_ID, self.model_id_list),
            (ModelSelectionSchema.COMPILE_PARAMS, self.compile_params_list),
            (ModelSelectionSchema.FIT_PARAMS, self.fit_params_list)
        ])
        param_names = list(param_grid.keys())
        for combination in product(*param_grid.values()):
            self.msts.append(dict(zip(param_names, combination)))

    def create_mst_table(self):
        """Create the (empty) output mst table.
        """
        create_query = """
            CREATE TABLE {self.model_selection_table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            );
            """.format(self=self,
                       mst_key=ModelSelectionSchema.MST_KEY,
                       model_id=ModelSelectionSchema.MODEL_ID,
                       compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params=ModelSelectionSchema.FIT_PARAMS)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Create the (empty) output mst summary table.
        """
        create_query = """
            CREATE TABLE {self.model_selection_summary_table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
            """.format(self=self,
                       model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert everything in self.msts into the mst table.

        Also inserts the single row of the summary table, which records the
        model architecture table name and the object table name (NULL when
        object_table is None).
        """
        # Only identifiers are formatted into the statements. All values
        # are bound as parameters, so a params string containing quotes or
        # "$$" cannot break (or inject into) the query.
        insert_query = """
            INSERT INTO
                {mst_table}(
                    {model_id_col},
                    {compile_params_col},
                    {fit_params_col}
                )
            VALUES ($1, $2, $3)
            """.format(mst_table=self.model_selection_table,
                       model_id_col=ModelSelectionSchema.MODEL_ID,
                       compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params_col=ModelSelectionSchema.FIT_PARAMS)
        # int4/varchar correspond to the INTEGER/VARCHAR columns declared in
        # create_mst_table. The plan is prepared once and reused per row.
        insert_plan = plpy.prepare(insert_query,
                                   ["int4", "varchar", "varchar"])
        for mst in self.msts:
            plpy.execute(insert_plan,
                         [mst[ModelSelectionSchema.MODEL_ID],
                          mst[ModelSelectionSchema.COMPILE_PARAMS],
                          mst[ModelSelectionSchema.FIT_PARAMS]])

        insert_summary_query = """
            INSERT INTO
                {summary_table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES ($1, $2)
            """.format(
                summary_table=self.model_selection_summary_table,
                model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                object_table_name=ModelSelectionSchema.OBJECT_TABLE)
        summary_plan = plpy.prepare(insert_summary_query,
                                    ["varchar", "varchar"])
        # A None object_table is bound as SQL NULL.
        plpy.execute(summary_plan,
                     [self.model_arch_table, self.object_table])