blob: 46267f075c1df28cefbfe573d66acc425b8623e2 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
import itertools

import plpy

from keras_model_arch_table import ModelArchSchema
from madlib_keras_validator import MstLoaderInputValidator
from madlib_keras_wrapper import convert_string_of_args_to_dict
from utilities.control import MinWarning
from utilities.utilities import add_postfix
class ModelSelectionSchema:
    """Column names (and SQL column types) used by the model selection
    (mst) table and its companion summary table."""

    # --- Columns of the model selection table, in table order. ---
    MST_KEY = 'mst_key'
    # Same column name as the model architecture table's model id column.
    MODEL_ID = ModelArchSchema.MODEL_ID
    COMPILE_PARAMS = 'compile_params'
    FIT_PARAMS = 'fit_params'

    # SQL types of the four columns above, listed in the same order.
    col_types = (
        'SERIAL',   # mst key
        'INTEGER',  # model id
        'VARCHAR',  # compile params
        'VARCHAR',  # fit params
    )

    # --- Columns of the model selection summary table. ---
    MODEL_ARCH_TABLE = 'model_arch_table'
    OBJECT_TABLE = 'object_table'
@MinWarning("warning")
class MstLoader():
    """The utility class for loading a model selection table with model parameters.

    Currently just takes all combinations of input parameters passed. This
    utility validates the inputs.

    Attributes:
        compile_params_list (list): The input compile params choices, with
            equivalent entries collapsed (see params_preprocessed).
        fit_params_list (list): The input fit params choices, with
            equivalent entries collapsed (see params_preprocessed).
        model_id_list (list): The sorted, deduplicated model id choices.
        model_arch_table (str): The name of the model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output mst
            summary table (model_selection_table with a "_summary" postfix).
        object_table (str or None): The name of the object table, if given.
        msts (list): The list of generated msts. Each mst is a dict keyed by
            the ModelSelectionSchema model id, compile params and fit params
            column names.
    """

    def __init__(self,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table=None,
                 **kwargs):
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))
        self.object_table = object_table
        MstLoaderInputValidator(
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_list,
            fit_params_list=fit_params_list,
            object_table=object_table
        )
        self.compile_params_list = self.params_preprocessed(
            compile_params_list)
        self.fit_params_list = self.params_preprocessed(fit_params_list)
        self.msts = []
        self.find_combinations()

    def load(self):
        """The entry point for loading the model selection table.
        """
        # All of the side effects happen in this function.
        self.create_mst_table()
        self.create_mst_summary_table()
        self.insert_into_mst_table()

    def params_preprocessed(self, list_strs):
        """Collapse param strings that parse to the same arguments.

        Each string is parsed with convert_string_of_args_to_dict. Strings
        whose parsed key/value pairs are equal (compared with the keys
        sorted, so key order does not matter) are treated as duplicates, and
        the last such string seen is the one kept. Whether white-space
        differences are also ignored depends on how
        convert_string_of_args_to_dict parses the string.

        The result is ordered by each group's first occurrence in the input.
        This keeps it deterministic on every Python version, so the SERIAL
        mst keys assigned on insert are reproducible.

        Args:
            list_strs (list): A list of param strings.
        Returns:
            list: The deduplicated list of strings.
        """
        dict_dedup = OrderedDict()
        for string in list_strs:
            d = convert_string_of_args_to_dict(string)
            hash_tuple = tuple('{0} = {1}'.format(x, d[x])
                               for x in sorted(d.keys()))
            dict_dedup[hash_tuple] = string
        return list(dict_dedup.values())

    def find_combinations(self):
        """Append every combination of model id, compile params and fit
        params to self.msts.

        Combinations are generated with model id varying slowest and fit
        params varying fastest. Each one is a dict keyed by the
        ModelSelectionSchema column names. If any choice list is empty,
        no msts are produced.
        """
        param_grid = OrderedDict([
            (ModelSelectionSchema.MODEL_ID, self.model_id_list),
            (ModelSelectionSchema.COMPILE_PARAMS, self.compile_params_list),
            (ModelSelectionSchema.FIT_PARAMS, self.fit_params_list)
        ])
        # Materialize the keys: on Python 3, dict views are not subscriptable.
        param_names = list(param_grid.keys())
        for combination in itertools.product(*param_grid.values()):
            self.msts.append(dict(zip(param_names, combination)))

    def create_mst_table(self):
        """Initialize the output mst table.
        """
        create_query = """
            CREATE TABLE {self.model_selection_table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            );
        """.format(self=self,
                   mst_key=ModelSelectionSchema.MST_KEY,
                   model_id=ModelSelectionSchema.MODEL_ID,
                   compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                   fit_params=ModelSelectionSchema.FIT_PARAMS)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Initialize the output mst summary table.
        """
        create_query = """
            CREATE TABLE {self.model_selection_summary_table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
        """.format(self=self,
                   model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                   object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert every mst in self.msts into the mst table, then insert one
        row into the summary table.

        Row values are bound as query parameters through plpy.prepare rather
        than interpolated into the SQL text. As a result, a user-supplied
        param string containing '$$' or other quoting characters cannot break
        or alter the statement. Only identifiers (table and column names) are
        formatted into the SQL.
        """
        insert_query = """
            INSERT INTO
                {self.model_selection_table}(
                    {model_id_col},
                    {compile_params_col},
                    {fit_params_col}
                )
            VALUES ($1, $2, $3)
        """.format(self=self,
                   model_id_col=ModelSelectionSchema.MODEL_ID,
                   compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                   fit_params_col=ModelSelectionSchema.FIT_PARAMS)
        # Prepare the plan once and reuse it for every mst row.
        insert_plan = plpy.prepare(insert_query,
                                   ["INTEGER", "VARCHAR", "VARCHAR"])
        for mst in self.msts:
            plpy.execute(insert_plan,
                         [mst[ModelSelectionSchema.MODEL_ID],
                          mst[ModelSelectionSchema.COMPILE_PARAMS],
                          mst[ModelSelectionSchema.FIT_PARAMS]])

        insert_summary_query = """
            INSERT INTO
                {self.model_selection_summary_table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES ($1, $2)
        """.format(self=self,
                   model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                   object_table_name=ModelSelectionSchema.OBJECT_TABLE)
        summary_plan = plpy.prepare(insert_summary_query,
                                    ["VARCHAR", "VARCHAR"])
        # PL/Python passes a None argument as SQL NULL, so a missing
        # object table is stored as NULL.
        plpy.execute(summary_plan, [self.model_arch_table, self.object_table])