blob: 4434e75193aa6a482784c39ab3ae01382f632964 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from ast import literal_eval
from collections import OrderedDict
from itertools import product as itertools_product
from keras_model_arch_table import ModelArchSchema
import numpy as np
import plpy
from copy import deepcopy
from madlib_keras_custom_function import CustomFunctionSchema
from madlib_keras_helper import generate_row_string
from madlib_keras_validator import MstLoaderInputValidator
from madlib_keras_wrapper import convert_string_of_args_to_dict
from madlib_keras_wrapper import parse_and_validate_fit_params
from madlib_keras_wrapper import parse_and_validate_compile_params
from utilities.control import MinWarning
from utilities.utilities import add_postfix, _assert, _assert_equal, extract_keyvalue_params
from utilities.utilities import quote_ident, get_schema, is_platform_pg
from utilities.validate_args import table_exists, drop_tables
from tensorflow.keras import losses as losses
from tensorflow.keras import metrics as metrics
class ModelSelectionSchema:
    """Column names and keyword constants used by the model selection tables."""

    # Column names of the model selection (mst) table.
    MST_KEY = 'mst_key'
    MODEL_ID = ModelArchSchema.MODEL_ID
    COMPILE_PARAMS = 'compile_params'
    FIT_PARAMS = 'fit_params'
    # SQL type names, one per mst table column listed above.
    col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')

    # Column names of the model selection summary table.
    MODEL_ARCH_TABLE = 'model_arch_table'
    OBJECT_TABLE = 'object_table'

    # Accepted values for the hyperparameter search strategy.
    GRID_SEARCH = 'grid'
    RANDOM_SEARCH = 'random'

    # Key in the compile params dictionary that groups optimizer settings.
    OPTIMIZER_PARAMS_LIST = 'optimizer_params_list'
@MinWarning("warning")
class MstLoader():
    """Utility class for loading a model selection table with model parameters.

    Takes every combination of the model ids, compile params and fit params
    passed in and writes one row per combination to the model selection (mst)
    table. The inputs are validated on construction.

    Attributes:
        compile_params_list (list): The de-duplicated compile params choices.
        fit_params_list (list): The de-duplicated fit params choices.
        model_id_list (list): The sorted, de-duplicated model id choices.
        model_arch_table (str): The name of model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output summary
            table (the mst table name with a "_summary" postfix).
        object_table (str): The schema-qualified custom function table name,
            or None when no object table was given.
        msts (list): The list of generated msts (one dict per combination).
    """

    def __init__(self,
                 schema_madlib,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table=None,
                 **kwargs):
        """Validate the inputs and generate all parameter combinations.

        Args:
            schema_madlib (str): Name of the MADlib schema.
            model_arch_table (str): The name of model architecture table.
            model_selection_table (str): The name of the output mst table.
            model_id_list (list): Model id choices.
            compile_params_list (list): Compile params strings.
            fit_params_list (list): Fit params strings.
            object_table (str): Optional name of the custom function table.
            **kwargs: Accepted and ignored.
        """
        self.schema_madlib = schema_madlib
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))
        if object_table is not None:
            object_table = "{0}.{1}".format(schema_madlib,
                                            quote_ident(object_table))
        self.object_table = object_table

        MstLoaderInputValidator(
            schema_madlib=self.schema_madlib,
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_list,
            fit_params_list=fit_params_list,
            object_table=object_table
        )

        self.compile_params_list = self.params_preprocessed(
            compile_params_list)
        self.fit_params_list = self.params_preprocessed(fit_params_list)
        self.msts = []
        self.find_combinations()

    def load(self):
        """The entry point for loading the model selection table.

        All of the database side effects happen here: both output tables are
        created and then populated from self.msts.
        """
        self.create_mst_table()
        self.create_mst_summary_table()
        self.insert_into_mst_table()

    def params_preprocessed(self, list_strs):
        """De-duplicate a list of parameter strings.

        Two strings are treated as duplicates when they parse to the same
        key/value pairs, regardless of the order in which the keys appear.
        When duplicates exist, the last one seen is the string that is kept.

        Args:
            list_strs (list): A list of parameter strings.

        Returns:
            list: The de-duplicated parameter strings (original text kept).
        """
        dict_dedup = {}
        for string in list_strs:
            d = convert_string_of_args_to_dict(string)
            # Canonical, order-independent key for this set of arguments.
            hash_tuple = tuple('{0} = {1}'.format(x, d[x])
                               for x in sorted(d.keys()))
            dict_dedup[hash_tuple] = string
        # Materialize as a real list: on Python 3 dict.values() is only a view.
        return list(dict_dedup.values())

    def find_combinations(self):
        """Append every (model_id, compile_params, fit_params) combination
        to self.msts as a dict keyed by the mst column names.

        Combinations are enumerated with model id varying slowest and fit
        params varying fastest.
        """
        param_grid = OrderedDict([
            (ModelSelectionSchema.MODEL_ID, self.model_id_list),
            (ModelSelectionSchema.COMPILE_PARAMS, self.compile_params_list),
            (ModelSelectionSchema.FIT_PARAMS, self.fit_params_list)
        ])
        # list() keeps this working on Python 3, where keys() is a view that
        # cannot be indexed.
        param_names = list(param_grid.keys())
        for values in itertools_product(*param_grid.values()):
            self.msts.append(dict(zip(param_names, values)))

    def create_mst_table(self):
        """Create the (empty) output mst table.

        On platforms other than PostgreSQL the table is created with
        appendonly=false.
        """
        with_query = "" if is_platform_pg() else "with(appendonly=false)"
        create_query = """
            CREATE TABLE {table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            ) {with_query};
            """.format(table=self.model_selection_table,
                       mst_key=ModelSelectionSchema.MST_KEY,
                       model_id=ModelSelectionSchema.MODEL_ID,
                       compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params=ModelSelectionSchema.FIT_PARAMS,
                       with_query=with_query)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Create the (empty) output mst summary table.

        The summary table holds the model architecture table name and the
        object table name.
        """
        create_query = """
            CREATE TABLE {table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
            """.format(table=self.model_selection_summary_table,
                       model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert everything in self.msts into the mst table, then insert one
        row describing this load into the summary table.
        """
        # Param strings are embedded as dollar-quoted SQL string literals.
        insert_template = """
            INSERT INTO
                {table}(
                    {model_id_col},
                    {compile_params_col},
                    {fit_params_col}
                )
            VALUES (
                {model_id},
                $${compile_params}$$,
                $${fit_params}$$
            )
            """
        for mst in self.msts:
            insert_query = insert_template.format(
                table=self.model_selection_table,
                model_id_col=ModelSelectionSchema.MODEL_ID,
                compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                fit_params_col=ModelSelectionSchema.FIT_PARAMS,
                model_id=mst[ModelSelectionSchema.MODEL_ID],
                compile_params=mst[ModelSelectionSchema.COMPILE_PARAMS],
                fit_params=mst[ModelSelectionSchema.FIT_PARAMS])
            plpy.execute(insert_query)

        # A missing object table is stored as a typed SQL NULL.
        if self.object_table is None:
            object_table = 'NULL::VARCHAR'
        else:
            object_table = '$${0}$$'.format(self.object_table)
        insert_summary_query = """
            INSERT INTO
                {table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES (
                $${model_arch_table}$$,
                {object_table}
            )
            """.format(table=self.model_selection_summary_table,
                       model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table_name=ModelSelectionSchema.OBJECT_TABLE,
                       model_arch_table=self.model_arch_table,
                       object_table=object_table)
        plpy.execute(insert_summary_query)
@MinWarning("warning")
class MstSearch():
    """
    The utility class for generating model selection configs and loading them
    into a MST table with model parameters.
    Currently takes string representations of python dictionaries for compile
    and fit params. Generates configs with a chosen search algorithm.

    Attributes:
        model_arch_table (str): The name of model architecture table.
        model_selection_table (str): The name of the output mst table.
        model_selection_summary_table (str): The name of the output summary
            table (the mst table name with a "_summary" postfix).
        model_id_list (list): The sorted, de-duplicated model id choices.
        compile_params_dict (dict): Parsed compile params choices.
        fit_params_dict (dict): Parsed fit params choices.
        search_type (str, default 'grid'): Hyperparameter search strategy,
            'grid' or 'random' (a non-empty prefix of either is accepted).
        accepted_distributions (list): Distribution names usable with random
            search.
        msts (list): The list of generated msts (one dict per config).

        Only for 'random' search type (defaults None):
            num_configs (int): Number of configs to generate.
            random_state (int): Seed for result reproducibility.

        object_table (str, default None): The name of the object table, for
            custom (metric) functions.
    """

    def __init__(self,
                 schema_madlib,
                 model_arch_table,
                 model_selection_table,
                 model_id_list,
                 compile_params_grid,
                 fit_params_grid,
                 search_type='grid',
                 num_configs=None,
                 random_state=None,
                 object_table=None,
                 **kwargs):
        """
        Validates the inputs, generates the configs and validates the result.
        :param schema_madlib: Name of the MADlib schema.
        :param model_arch_table: The name of model architecture table.
        :param model_selection_table: The name of the output mst table.
        :param model_id_list: Model id choices.
        :param compile_params_grid: String repr of a dict of compile params.
        :param fit_params_grid: String repr of a dict of fit params.
        :param search_type: 'grid' or 'random'.
        :param num_configs: Number of configs to generate (random search).
        :param random_state: Seed for reproducibility (random search).
        :param object_table: Optional name of the custom function table.
        :param kwargs: Accepted and ignored.
        """
        self.schema_madlib = schema_madlib
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = add_postfix(
            model_selection_table, "_summary")
        self.model_id_list = sorted(set(model_id_list))

        if object_table is not None:
            schema_name = get_schema(object_table)
            if schema_name is None:
                # Unqualified name: place it in the MADlib schema.
                object_table = "{0}.{1}".format(schema_madlib,
                                                quote_ident(object_table))
            elif schema_name != schema_madlib:
                plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))

        MstLoaderInputValidator(
            schema_madlib=self.schema_madlib,
            model_arch_table=self.model_arch_table,
            model_selection_table=self.model_selection_table,
            model_selection_summary_table=self.model_selection_summary_table,
            model_id_list=self.model_id_list,
            compile_params_list=compile_params_grid,
            fit_params_list=fit_params_grid,
            object_table=object_table,
            module_name='generate_model_configs'
        )

        self.search_type = search_type
        self.num_configs = num_configs
        self.random_state = random_state
        self.object_table = object_table

        # Newlines and all spaces are removed before parsing.
        compile_params_grid = compile_params_grid.replace('\n', '').replace(' ', '')
        fit_params_grid = fit_params_grid.replace('\n', '').replace(' ', '')

        self.accepted_distributions = ['linear', 'log', 'log_near_one']

        # Extract the python dicts. Only ordinary exceptions are converted to
        # a user-facing error; KeyboardInterrupt/SystemExit propagate.
        try:
            self.compile_params_dict = literal_eval(compile_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'compile_params_dict'")
        try:
            self.fit_params_dict = literal_eval(fit_params_grid)
        except Exception:
            plpy.error("Invalid syntax in 'fit_params_dict'")

        self.validate_inputs(compile_params_grid, fit_params_grid)

        self.msts = []
        if self._is_grid_search():
            self.find_grid_combinations()
        elif self._is_random_search():
            self.find_random_combinations()

        # Param checks and validation on the generated configs.
        compile_params_lst = [mst[ModelSelectionSchema.COMPILE_PARAMS]
                              for mst in self.msts]
        fit_params_lst = [mst[ModelSelectionSchema.FIT_PARAMS]
                          for mst in self.msts]
        self._validate_params_and_object_table(compile_params_lst, fit_params_lst)

    def _search_type_matches(self, canonical_name):
        """
        :param canonical_name: Full name of a search strategy.
        :return: True if search_type is a non-empty, case-insensitive prefix
                 of canonical_name. An empty or NULL search_type matches
                 nothing.
        """
        prefix = (self.search_type or '').lower()
        return bool(prefix) and canonical_name.startswith(prefix)

    def _is_grid_search(self):
        """:return: True if the requested search type is grid search."""
        return self._search_type_matches(ModelSelectionSchema.GRID_SEARCH)

    def _is_random_search(self):
        """:return: True if the requested search type is random search."""
        return self._search_type_matches(ModelSelectionSchema.RANDOM_SEARCH)

    def _reseed(self, seed_changes):
        """
        Reseeds numpy's global random generator when a random_state was
        given, so that each sampling step is reproducible.
        :param seed_changes: Offset added to random_state for this step.
        :return: The offset to use for the next step.
        """
        # 'is not None' rather than truthiness: 0 is a valid seed.
        if self.random_state is None:
            return seed_changes
        np.random.seed(self.random_state + seed_changes)
        return seed_changes + 1

    def load(self):
        """The entry point for loading the model selection table.

        All of the database side effects happen here. If the mst table
        already exists, rows are appended to it (incremental loading).
        """
        if not table_exists(self.model_selection_table):
            self.create_mst_table()
            self.create_mst_summary_table()
        elif not table_exists(self.model_selection_summary_table):
            self.create_mst_summary_table()
        else:
            # Appending: the existing table must use the same model arch table.
            res = plpy.execute("SELECT model_arch_table from {0}".format(
                self.model_selection_summary_table))
            for r in res:
                _assert_equal(r['model_arch_table'], self.model_arch_table,
                              "DL: Inconsistent model arch table. Use '{0}' if appending rows to '{1}'".format(
                                  r['model_arch_table'], self.model_selection_table
                              ))
        # NOTE(review): insert_into_mst_table() adds a summary row on every
        # call, so appending to an existing table also adds another summary
        # row -- confirm this is intended.
        self.insert_into_mst_table()

    def validate_inputs(self, compile_params_grid, fit_params_grid):
        """
        Ensures validity of inputs related to grid and random search.
        :param compile_params_grid: The input string repr of compile params choices.
        :param fit_params_grid: The input string repr of fit params choices.
        """
        _assert(isinstance(self.compile_params_dict, dict),
                "DL: 'compile_params_grid' must be a dictionary")
        _assert(isinstance(self.fit_params_dict, dict),
                "DL: 'fit_params_grid' must be a dictionary")

        for c in self.compile_params_dict:
            _assert_equal(type(self.compile_params_dict[c]), list,
                          "DL: compile param values must be specified in a list")
        for f in self.fit_params_dict:
            _assert_equal(type(self.fit_params_dict[f]), list,
                          "DL: fit param values must be specified in a list")

        if self._is_grid_search():
            _assert(self.num_configs is None and self.random_state is None,
                    "DL: 'num_configs' and 'random_state' must be NULL for grid search")
            for distribution_type in self.accepted_distributions:
                # If the distribution is used, it will be in the following format:
                # [123, 456, '<dist name>']
                # Matching single quotes and the closing bracket minimizes false
                # positives from the log_dir parameter.
                tmp_dist = "'{0}']".format(distribution_type)
                _assert(tmp_dist not in compile_params_grid and
                        tmp_dist not in fit_params_grid,
                        "DL: Cannot search from a distribution with grid search")
        elif self._is_random_search():
            _assert(self.num_configs is not None and self.num_configs > 0,
                    "DL: 'num_configs' cannot be NULL and "
                    "needs to be a natural number "
                    "for random search")
        else:
            plpy.error("DL: 'search_type' must be either 'grid' or 'random'")

        if ModelSelectionSchema.OPTIMIZER_PARAMS_LIST in self.compile_params_dict:
            optimizer_params_list = self.compile_params_dict[ModelSelectionSchema.OPTIMIZER_PARAMS_LIST]
            optimizer_param_keys = set(key
                                       for opt_params in optimizer_params_list
                                       for key in opt_params)
            # Optimizer params may not also appear at the top level.
            _assert(set(self.compile_params_dict).isdisjoint(optimizer_param_keys),
                    "DL: 'optimizer_params_list' key should only contain 'optimizer' and/or optimizer related params "
                    "and no such params should reside out of the key")
            for k in optimizer_params_list:
                _assert(len(k) != 0,
                        "DL: empty dictionaries cannot be specified in the value list of "
                        "'optimizer_params_list'")

    def _validate_params_and_object_table(self, compile_params_lst, fit_params_lst):
        """
        Validates every generated fit/compile params string, and checks that
        the loss and metrics named in the compile params are either listed in
        the object table or are attributes of the keras losses/metrics modules.
        :param compile_params_lst: Generated compile params strings.
        :param fit_params_lst: Generated fit params strings.
        """
        if not fit_params_lst:
            plpy.error("fit_params_list cannot be NULL")
        for fit_params in fit_params_lst:
            try:
                parse_and_validate_fit_params(fit_params)
            except Exception as e:
                plpy.error(
                    "Fit param check failed for: {0} \n\n{1}\n".format(
                        fit_params, str(e)))

        if not compile_params_lst:
            plpy.error("compile_params_list cannot be NULL")

        custom_fn_name = []
        # Initialize builtin loss/metrics functions
        builtin_losses = dir(losses)
        builtin_metrics = dir(metrics)
        # Default metrics, since it is not part of the builtin metrics list
        builtin_metrics.append('accuracy')
        if self.object_table is not None:
            res = plpy.execute("SELECT {0} from {1}".format(
                CustomFunctionSchema.FN_NAME, self.object_table))
            for r in res:
                custom_fn_name.append(r[CustomFunctionSchema.FN_NAME])

        # The suffix of the error message does not depend on the compile params.
        if self.object_table is not None:
            error_suffix = "is not defined in object table '{0}'!".format(self.object_table)
        else:
            error_suffix = "but input object table missing!"

        for compile_params in compile_params_lst:
            try:
                _, _, res = parse_and_validate_compile_params(
                    compile_params,
                    [ModelSelectionSchema.OPTIMIZER_PARAMS_LIST])
                # Validating if loss/metrics function called in compile_params
                # is either defined in object table or is a built_in keras
                # loss/metrics function
                _assert(res['loss'] in custom_fn_name or res['loss'] in builtin_losses,
                        "custom function '{0}' used in compile params "
                        "{1}".format(res['loss'], error_suffix))
                if 'metrics' in res:
                    requested_metrics = set(res['metrics'])
                    _assert((len(requested_metrics.intersection(custom_fn_name)) > 0
                             or len(requested_metrics.intersection(builtin_metrics)) > 0),
                            "custom function '{0}' used in compile params "
                            "{1}".format(res['metrics'], error_suffix))
            except Exception as e:
                # Failures above are re-reported together with the offending
                # compile params string.
                plpy.error(
                    "Compile param check failed for: {0} \n\n{1}\n".format(
                        compile_params, str(e)))

    def find_grid_combinations(self):
        """
        Finds combinations using grid search.
        When 'optimizer_params_list' is present, each of its dictionaries is
        expanded into its own grid of optimizer configs and combined with the
        remaining compile and fit params.
        """
        opt_key = ModelSelectionSchema.OPTIMIZER_PARAMS_LIST
        if opt_key not in self.compile_params_dict:
            self.grid_combinations_helper(self.compile_params_dict, self.fit_params_dict)
            return

        for opt_params_dict in self.compile_params_dict[opt_key]:
            keys, values = zip(*opt_params_dict.items())
            opt_configs_params = [dict(zip(keys, v))
                                  for v in itertools_product(*values)]
            # Work on a copy so that the parsed input dict is left untouched.
            copied_compile_dict = deepcopy(self.compile_params_dict)
            copied_compile_dict[opt_key] = opt_configs_params
            self.grid_combinations_helper(copied_compile_dict, self.fit_params_dict)

    def grid_combinations_helper(self, compile_dict, fit_dict):
        """
        Appends to self.msts every combination of model id, compile param
        values and fit param values.
        :param compile_dict: Dictionary of compile params choices.
        :param fit_dict: Dictionary of fit params choices.
        """
        combined_dict = dict(compile_dict, **fit_dict)
        combined_dict[ModelSelectionSchema.MODEL_ID] = self.model_id_list
        keys, values = zip(*combined_dict.items())
        all_configs_params = [dict(zip(keys, v))
                              for v in itertools_product(*values)]

        # to separate the compile and fit configs
        for config in all_configs_params:
            combination = {}
            compile_configs, fit_configs = {}, {}
            for k in config:
                if k == ModelSelectionSchema.MODEL_ID:
                    combination[ModelSelectionSchema.MODEL_ID] = config[k]
                elif k in compile_dict:
                    compile_configs[k] = config[k]
                elif k in fit_dict:
                    fit_configs[k] = config[k]
                else:
                    plpy.error("DL: {0} is an unidentified key".format(k))
            combination[ModelSelectionSchema.COMPILE_PARAMS] = generate_row_string(compile_configs)
            combination[ModelSelectionSchema.FIT_PARAMS] = generate_row_string(fit_configs)
            self.msts.append(combination)

    def find_random_combinations(self):
        """
        Finds combinations using random search.
        Appends num_configs randomly sampled configs to self.msts. When
        random_state is given, the generator is reseeded before every
        sampling step.
        """
        seed_changes = 0
        for _ in range(self.num_configs):
            combination = {}
            seed_changes = self._reseed(seed_changes)
            combination[ModelSelectionSchema.MODEL_ID] = np.random.choice(self.model_id_list)
            compile_dict, seed_changes = self.generate_param_config(
                self.compile_params_dict, seed_changes)
            combination[ModelSelectionSchema.COMPILE_PARAMS] = generate_row_string(compile_dict)
            fit_dict, seed_changes = self.generate_param_config(
                self.fit_params_dict, seed_changes)
            combination[ModelSelectionSchema.FIT_PARAMS] = generate_row_string(fit_dict)
            self.msts.append(combination)

    def generate_param_config(self, params_dict, seed_changes):
        """
        Generating a parameter configuration for random search.
        :param params_dict: Dictionary of params choices.
        :param seed_changes: Changes in seed for random sampling + reproducibility.
        :return: config_dict (the sampled param config), seed_changes.
        """
        config_dict = {}
        for cp in params_dict:
            seed_changes = self._reseed(seed_changes)
            param_values = params_dict[cp]
            if cp == ModelSelectionSchema.OPTIMIZER_PARAMS_LIST:
                # Pick one optimizer dictionary, then sample each of its params.
                opt_dict = np.random.choice(param_values)
                opt_combination = {}
                for i in opt_dict:
                    opt_values = opt_dict[i]
                    seed_changes = self._reseed(seed_changes)
                    opt_combination[i] = self.sample_val(i, opt_values)
                config_dict[cp] = opt_combination
            else:
                config_dict[cp] = self.sample_val(cp, param_values)
        return config_dict, seed_changes

    def sample_val(self, cp, param_value_list):
        """
        Samples a value from a given list of values, either randomly from a list of discrete elements,
        or from a specified distribution.
        A list is treated as a distribution spec when it has more than one
        element, its last element is a string and none of the preceding
        elements is a string or a callable; it must then have the form
        [lower_bound, upper_bound, distribution_type].
        :param cp: compile param
        :param param_value_list: list of values (or specified distribution) for a param
        :return: sampled value
        """
        _assert(len(param_value_list) > 0,
                "DL: '{0}' must have at least one value to choose from".format(cp))

        # check if need to sample from a distribution
        is_distribution = (
            len(param_value_list) > 1
            and isinstance(param_value_list[-1], str)
            and all(not isinstance(i, str) and not callable(i)
                    for i in param_value_list[:-1]))
        if not is_distribution:
            # random sampling
            return np.random.choice(param_value_list)

        _assert_equal(len(param_value_list), 3,
                      "DL: '{0}' should have exactly 3 elements if picking from a distribution".format(cp))
        _assert(param_value_list[1] > param_value_list[0],
                "DL: '{0}' should be of the format [lower_bound, upper_bound, distribution_type]".format(cp))
        lower, upper, distribution = param_value_list
        if distribution == 'linear':
            return np.random.uniform(lower, upper)
        elif distribution == 'log':
            # Uniform in the base-10 exponent.
            return np.power(10, np.random.uniform(np.log10(lower),
                                                  np.log10(upper)))
        elif distribution == 'log_near_one':
            # Log-uniform in the distance from 1.0.
            return 1.0 - np.power(10, np.random.uniform(np.log10(1.0 - upper),
                                                        np.log10(1.0 - lower)))
        else:
            plpy.error("DL: Please choose a valid distribution type for '{0}': {1}".format(
                cp, self.accepted_distributions))

    def create_mst_table(self):
        """Create the (empty) output mst table; load() only calls this when
        the table doesn't exist yet (for incremental loading).

        On platforms other than PostgreSQL the table is created with
        appendonly=false.
        """
        with_query = "" if is_platform_pg() else "with(appendonly=false)"
        create_query = """
            CREATE TABLE {table} (
                {mst_key} SERIAL,
                {model_id} INTEGER,
                {compile_params} VARCHAR,
                {fit_params} VARCHAR,
                unique ({model_id}, {compile_params}, {fit_params})
            ) {with_query};
            """.format(table=self.model_selection_table,
                       mst_key=ModelSelectionSchema.MST_KEY,
                       model_id=ModelSelectionSchema.MODEL_ID,
                       compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                       fit_params=ModelSelectionSchema.FIT_PARAMS,
                       with_query=with_query)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def create_mst_summary_table(self):
        """Create the (empty) output mst summary table.

        The summary table holds the model architecture table name and the
        object table name.
        """
        create_query = """
            CREATE TABLE {table} (
                {model_arch_table} VARCHAR,
                {object_table} VARCHAR
            );
            """.format(table=self.model_selection_summary_table,
                       model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table=ModelSelectionSchema.OBJECT_TABLE)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_mst_table(self):
        """Insert everything in self.msts into the mst table, then insert one
        row describing this load into the summary table.
        """
        # Param strings are embedded as dollar-quoted SQL string literals.
        insert_template = """
            INSERT INTO
                {table}(
                    {model_id_col},
                    {compile_params_col},
                    {fit_params_col}
                )
            VALUES (
                {model_id},
                $${compile_params}$$,
                $${fit_params}$$
            )
            """
        for mst in self.msts:
            insert_query = insert_template.format(
                table=self.model_selection_table,
                model_id_col=ModelSelectionSchema.MODEL_ID,
                compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
                fit_params_col=ModelSelectionSchema.FIT_PARAMS,
                model_id=mst[ModelSelectionSchema.MODEL_ID],
                compile_params=mst[ModelSelectionSchema.COMPILE_PARAMS],
                fit_params=mst[ModelSelectionSchema.FIT_PARAMS])
            plpy.execute(insert_query)

        # A missing object table is stored as a typed SQL NULL.
        if self.object_table is None:
            object_table = 'NULL::VARCHAR'
        else:
            object_table = '$${0}$$'.format(self.object_table)
        insert_summary_query = """
            INSERT INTO
                {table}(
                    {model_arch_table_name},
                    {object_table_name}
                )
            VALUES (
                $${model_arch_table}$$,
                {object_table}
            )
            """.format(table=self.model_selection_summary_table,
                       model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
                       object_table_name=ModelSelectionSchema.OBJECT_TABLE,
                       model_arch_table=self.model_arch_table,
                       object_table=object_table)
        plpy.execute(insert_summary_query)