blob: 8ced08bcdc95bf95cbe52b17b0d285353d6896b7 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from keras_model_arch_table import ModelArchSchema
from model_arch_info import get_num_classes
from madlib_keras_custom_function import CustomFunctionSchema
from madlib_keras_helper import CLASS_VALUES_COLNAME
from madlib_keras_helper import COMPILE_PARAMS_COLNAME
from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
from madlib_keras_helper import MODEL_ID_COLNAME
from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
from madlib_keras_helper import MODEL_WEIGHTS_COLNAME
from madlib_keras_helper import NORMALIZING_CONST_COLNAME
from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
from madlib_keras_helper import METRIC_TYPE_COLNAME
from madlib_keras_helper import INTERNAL_GPU_CONFIG
from madlib_keras_helper import query_model_configs
from utilities.minibatch_validation import validate_bytea_var_for_minibatch
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import is_platform_pg
from utilities.utilities import is_var_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from madlib_keras_wrapper import parse_and_validate_fit_params
from madlib_keras_wrapper import parse_and_validate_compile_params
import keras.losses as losses
import keras.metrics as metrics
class InputValidator:
    """Namespace of static argument checks used by the keras functions.

    Failed checks are reported through ``_assert`` or ``plpy.error``; none of
    the methods return a value. ``module_name`` is only used to prefix the
    error messages.
    """

    @staticmethod
    def validate_predict_evaluate_tables(
            module_name, model_table, model_summary_table, test_table,
            output_table, independent_varname):
        """Check the model, model summary, test and output tables in turn."""
        InputValidator._validate_model_weights_tbl(module_name, model_table)
        InputValidator._validate_model_summary_tbl(
            module_name, model_summary_table)
        InputValidator._validate_test_tbl(
            module_name, test_table, independent_varname)
        output_tbl_valid(output_table, module_name)

    @staticmethod
    def validate_id_in_test_tbl(module_name, test_table, id_col):
        """Assert that ``is_var_valid`` accepts id_col for test_table."""
        _assert(is_var_valid(test_table, id_col),
                "{module_name} error: invalid id column "
                "('{id_col}') for test table ({table}).".format(
                    module_name=module_name,
                    id_col=id_col,
                    table=test_table))

    @staticmethod
    def validate_predict_byom_tables(module_name, model_arch_table, model_id,
                                     test_table, id_col, output_table,
                                     independent_varname):
        """Check the tables passed to predict when the model is user supplied.

        Validates the model architecture table and id, the test table and its
        independent variable, the id column, and finally the output table.
        """
        InputValidator.validate_model_arch_table(
            module_name, model_arch_table, model_id)
        InputValidator._validate_test_tbl(
            module_name, test_table, independent_varname)
        InputValidator.validate_id_in_test_tbl(module_name, test_table, id_col)
        output_tbl_valid(output_table, module_name)

    @staticmethod
    def validate_pred_type(module_name, pred_type, class_values):
        """Raise an error unless pred_type is 'prob' or 'response'.

        class_values is accepted for interface compatibility and is not used.
        """
        if pred_type not in ('prob', 'response'):
            plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "
                       "either response or prob.".format(module_name, pred_type))

    @staticmethod
    def validate_input_shape(table, independent_varname, input_shape, offset,
                             is_minibatched=False):
        """
        Validate that the input shape specified in the model architecture is
        the same as the shape of the image stored in the independent variable
        of the input table. Only the first row returned by the table is
        inspected.

        offset: index of the start of the image array, taking into account
        that sql array indexes start from 1. It is only used when the table is
        not minibatched.
        For ex if the image is of shape [32,32,3] and is minibatched, the image
        will look like [10, 32, 32, 3]. The offset in this case is
        1 (start the index at 1) + 1 (ignore the buffer size 10) = 2.
        If the image is not batched then it will look like [32, 32, 3] and the
        offset in this case is 1 (start the index at 1).

        When is_minibatched is True the shape is read from the
        '<independent_varname>_shape' column and its first entry (the buffer
        size) is dropped.
        """
        num_dims = len(input_shape)
        if is_minibatched:
            ind_shape_col = add_postfix(independent_varname, "_shape")
            query = """
                SELECT {ind_shape_col} AS shape
                FROM {table}
                LIMIT 1
            """.format(ind_shape_col=ind_shape_col, table=table)
            # Skip the leading entry of the stored shape (the buffer size).
            input_shape_from_table = plpy.execute(query)[0]['shape'][1:]
        else:
            key_format = "n_{0}"
            array_upper_query = ", ".join(
                "array_upper({0}, {1}) AS n_{2}".format(
                    independent_varname, i + offset, i)
                for i in range(num_dims))
            query = """
                SELECT {0}
                FROM {1}
                LIMIT 1
            """.format(array_upper_query, table)
            row = plpy.execute(query)[0]
            input_shape_from_table = [row[key_format.format(i)]
                                      for i in range(num_dims)]

        _assert(len(input_shape_from_table) == num_dims,
                "model_keras error: The number of dimensions ({0}) of each image"
                " in model architecture and {1} in {2} ({3}) do not match.".format(
                    num_dims, independent_varname, table,
                    len(input_shape_from_table)))

        # Any single mismatching dimension is enough to reject the table; the
        # full shape read from the table is shown to make the message useful.
        for expected, actual in zip(input_shape, input_shape_from_table):
            if actual != expected:
                plpy.error("model_keras error: Input shape {0} in the model"
                           " architecture does not match the input shape {1} of column"
                           " {2} in table {3}.".format(
                               input_shape, input_shape_from_table,
                               independent_varname, table))

    @staticmethod
    def validate_model_arch_table(module_name, model_arch_table, model_id):
        """Check the model architecture table and that model_id is not None."""
        input_tbl_valid(model_arch_table, module_name)
        _assert(model_id is not None,
                "{0}: Invalid model architecture ID.".format(module_name))

    @staticmethod
    def validate_normalizing_const(module_name, normalizing_const):
        """Assert that normalizing_const is strictly positive."""
        _assert(normalizing_const > 0,
                "{0} error: Normalizing constant has to be greater than 0.".
                format(module_name))

    @staticmethod
    def validate_class_values(module_name, class_values, pred_type, model_arch):
        """Check class_values against the architecture and the column limit.

        Nothing is checked when class_values is empty or None. Otherwise the
        number of class values must equal ``get_num_classes(model_arch)``, and
        for pred_type 'prob' the resulting column count (one more than the
        number of classes) must stay below the 1600 limit.
        """
        if not class_values:
            return
        num_classes = len(class_values)
        _assert(num_classes == get_num_classes(model_arch),
                "{0}: The number of class values do not match the "
                "provided architecture.".format(module_name))
        # NOTE(review): `>=` also rejects an output of exactly 1600 columns;
        # confirm whether that is intended.
        if pred_type == 'prob' and num_classes + 1 >= 1600:
            # The message must be passed as a plain string; wrapping it in a
            # set literal garbles what the user sees.
            plpy.error("{0}: The output will have {1} columns, exceeding the "
                       " max number of columns that can be created (1600)".format(
                           module_name, num_classes + 1))

    @staticmethod
    def validate_model_weights(module_name, model_arch, model_weights):
        """Assert that both model_weights and model_arch are truthy."""
        _assert(model_weights and model_arch,
                "{0}: Model weights and architecture must be valid.".format(
                    module_name))

    @staticmethod
    def _validate_model_weights_tbl(module_name, model_table):
        """Assert that model_table has the weights and architecture columns."""
        _assert(is_var_valid(model_table, MODEL_WEIGHTS_COLNAME),
                "{module_name} error: column '{model_weights}' "
                "does not exist in model table '{table}'.".format(
                    module_name=module_name,
                    model_weights=MODEL_WEIGHTS_COLNAME,
                    table=model_table))
        _assert(is_var_valid(model_table, ModelArchSchema.MODEL_ARCH),
                "{module_name} error: column '{model_arch}' "
                "does not exist in model table '{table}'.".format(
                    module_name=module_name,
                    model_arch=ModelArchSchema.MODEL_ARCH,
                    table=model_table))

    @staticmethod
    def _validate_test_tbl(module_name, test_table, independent_varname):
        """Check test_table and its independent variable."""
        input_tbl_valid(test_table, module_name)
        _assert(is_var_valid(test_table, independent_varname),
                "{module_name} error: invalid independent_varname "
                "('{independent_varname}') for test table "
                "({table}).".format(
                    module_name=module_name,
                    independent_varname=independent_varname,
                    table=test_table))

    @staticmethod
    def _validate_model_summary_tbl(module_name, model_summary_table):
        """Check that the model summary table has every expected column."""
        input_tbl_valid(model_summary_table, module_name)
        cols_to_check_for = [CLASS_VALUES_COLNAME,
                             DEPENDENT_VARNAME_COLNAME,
                             DEPENDENT_VARTYPE_COLNAME,
                             MODEL_ID_COLNAME,
                             MODEL_ARCH_TABLE_COLNAME,
                             NORMALIZING_CONST_COLNAME,
                             COMPILE_PARAMS_COLNAME,
                             METRIC_TYPE_COLNAME]
        _assert(columns_exist_in_table(
                    model_summary_table, cols_to_check_for),
                "{0} error: One or more expected columns missing in model "
                "summary table ('{1}'). The expected columns are {2}.".format(
                    module_name, model_summary_table, cols_to_check_for))

    @staticmethod
    def _validate_gpu_config(module_name, source_table, accessible_gpus_for_seg):
        """Check the gpu configuration stored in the source summary table.

        The configuration is read from the INTERNAL_GPU_CONFIG column of
        '<source_table>_summary'. If it is 'all_segments', no entry of
        accessible_gpus_for_seg may be 0; otherwise each value in the
        configuration is used as an index into accessible_gpus_for_seg and
        that entry must be non-zero.
        """
        summary_table = add_postfix(source_table, "_summary")
        gpu_config = plpy.execute(
            "SELECT {0} FROM {1}".format(INTERNAL_GPU_CONFIG, summary_table)
        )[0][INTERNAL_GPU_CONFIG]
        if gpu_config == 'all_segments':
            _assert(0 not in accessible_gpus_for_seg,
                    "{0} error: Host(s) are missing gpus.".format(module_name))
        else:
            for i in gpu_config:
                _assert(accessible_gpus_for_seg[i] != 0,
                        "{0} error: Segment {1} does not have gpu".format(
                            module_name, i))
class FitCommonValidator(object):
    """Argument validation shared by the keras fit entry points.

    Constructing an instance stores the arguments and immediately runs the
    common checks; when use_gpus is truthy the gpu configuration of the
    source table is checked as well. Failed checks are reported by the
    helpers that are called (``_assert``, ``input_tbl_valid``, ...).
    """

    # Shared wording of the per-column checks in _validate_input_table.
    _INPUT_COL_ERROR = (
        "{module_name}: {problem} ('{col}') for table ({table}). "
        "Please ensure that the input table ({table}) "
        "has been preprocessed by the image preprocessor.")

    def __init__(self, source_table, validation_table, output_model_table,
                 model_arch_table, model_id, dependent_varname,
                 independent_varname, num_iterations,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg, module_name, object_table):
        self.source_table = source_table
        self.validation_table = validation_table
        self.output_model_table = output_model_table
        self.model_arch_table = model_arch_table
        self.model_id = model_id
        self.dependent_varname = dependent_varname
        self.independent_varname = independent_varname
        self.dep_shape_col = add_postfix(dependent_varname, "_shape")
        self.ind_shape_col = add_postfix(independent_varname, "_shape")
        self.metrics_compute_frequency = metrics_compute_frequency
        self.warm_start = warm_start
        self.num_iterations = num_iterations
        self.object_table = object_table
        # Both derived table names always exist as attributes, so later
        # checks see None rather than a missing attribute.
        self.source_summary_table = None
        self.output_summary_model_table = None
        if self.source_table:
            self.source_summary_table = add_postfix(
                self.source_table, "_summary")
        if self.output_model_table:
            self.output_summary_model_table = add_postfix(
                self.output_model_table, "_summary")
        self.accessible_gpus_for_seg = accessible_gpus_for_seg
        self.module_name = module_name
        self._validate_common_args()
        if use_gpus:
            InputValidator._validate_gpu_config(
                self.module_name, self.source_table,
                self.accessible_gpus_for_seg)

    def _validate_common_args(self):
        """Run every check that does not depend on the model architecture."""
        _assert(self.num_iterations > 0,
                "{0}: Number of iterations cannot be < 1.".format(
                    self.module_name))
        _assert(self._is_valid_metrics_compute_frequency(),
                "{0}: metrics_compute_frequency must be in the range (1 - {1}).".format(
                    self.module_name, self.num_iterations))
        input_tbl_valid(self.source_table, self.module_name)
        if self.object_table is not None:
            input_tbl_valid(self.object_table, self.module_name)
            cols_in_tbl_valid(self.object_table,
                              CustomFunctionSchema.col_names, self.module_name)
        input_tbl_valid(self.source_summary_table, self.module_name,
                        error_suffix_str="Please ensure that the source table ({0}) "
                                         "has been preprocessed by "
                                         "the image preprocessor.".format(self.source_table))
        cols_in_tbl_valid(self.source_summary_table, [CLASS_VALUES_COLNAME,
                          NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
                          'dependent_varname', 'independent_varname'],
                          self.module_name)
        if not is_platform_pg():
            cols_in_tbl_valid(self.source_table, [DISTRIBUTION_KEY_COLNAME],
                              self.module_name)

        # Source table and validation tables must have the same schema
        self._validate_input_table(self.source_table)
        validate_bytea_var_for_minibatch(self.source_table,
                                         self.dependent_varname)

        self._validate_validation_table()
        # With warm start the output tables are read, so they are checked as
        # input tables; otherwise they are checked as output tables.
        if self.warm_start:
            input_tbl_valid(self.output_model_table, self.module_name)
            input_tbl_valid(self.output_summary_model_table, self.module_name)
        else:
            output_tbl_valid(self.output_model_table, self.module_name)
            output_tbl_valid(self.output_summary_model_table, self.module_name)

    def _validate_input_table(self, table):
        """Assert that table has the variable, shape and distribution columns.

        The distribution key column is only checked when is_platform_pg()
        returns a falsy value.
        """
        checks = [
            ("invalid independent_varname", self.independent_varname),
            ("invalid dependent_varname", self.dependent_varname),
            ("invalid independent_var_shape", self.ind_shape_col),
            ("invalid dependent_var_shape", self.dep_shape_col),
        ]
        if not is_platform_pg():
            checks.append(("missing distribution key",
                           DISTRIBUTION_KEY_COLNAME))
        for problem, col in checks:
            _assert(is_var_valid(table, col),
                    self._INPUT_COL_ERROR.format(
                        module_name=self.module_name,
                        problem=problem,
                        col=col,
                        table=table))

    def _is_valid_metrics_compute_frequency(self):
        """Return True if the frequency is None or within [1, num_iterations]."""
        frequency = self.metrics_compute_frequency
        return frequency is None or 1 <= frequency <= self.num_iterations

    def _has_validation_table(self):
        """Return True if a non-blank validation table name was supplied."""
        return bool(self.validation_table and self.validation_table.strip())

    def _validate_validation_table(self):
        """Check the validation table, if one was supplied.

        The table must pass the same column checks as the source table and
        its dependent variable must be of type bytea.
        """
        if not self._has_validation_table():
            return
        input_tbl_valid(self.validation_table, self.module_name)
        self._validate_input_table(self.validation_table)
        dependent_vartype = get_expr_type(self.dependent_varname,
                                          self.validation_table)
        _assert(dependent_vartype == 'bytea',
                "Dependent variable column {0} in validation table {1} should be "
                "a bytea and also one hot encoded.".format(
                    self.dependent_varname, self.validation_table))

    def validate_input_shapes(self, input_shape):
        """Compare input_shape with the source and validation table shapes.

        Both tables are checked as minibatched tables (offset 2). A blank
        validation table name is treated as "no validation table", matching
        _validate_validation_table.
        """
        InputValidator.validate_input_shape(
            self.source_table, self.independent_varname, input_shape, 2, True)
        if self._has_validation_table():
            InputValidator.validate_input_shape(
                self.validation_table, self.independent_varname,
                input_shape, 2, True)
class FitInputValidator(FitCommonValidator):
    """Validator for a single-model fit ('madlib_keras_fit').

    Runs the common checks of the parent class and then checks the model
    architecture table and model id.
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_arch_table, model_id, dependent_varname,
                 independent_varname, num_iterations,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg, object_table):
        # Set before the parent constructor runs, which receives it again as
        # its module_name argument.
        self.module_name = 'madlib_keras_fit'
        super(FitInputValidator, self).__init__(
            source_table=source_table,
            validation_table=validation_table,
            output_model_table=output_model_table,
            model_arch_table=model_arch_table,
            model_id=model_id,
            dependent_varname=dependent_varname,
            independent_varname=independent_varname,
            num_iterations=num_iterations,
            metrics_compute_frequency=metrics_compute_frequency,
            warm_start=warm_start,
            use_gpus=use_gpus,
            accessible_gpus_for_seg=accessible_gpus_for_seg,
            module_name=self.module_name,
            object_table=object_table)
        # The architecture table is checked only after the common checks.
        InputValidator.validate_model_arch_table(
            self.module_name, self.model_arch_table, self.model_id)
class FitMultipleInputValidator(FitCommonValidator):
    """Validator for a multi-model fit ('madlib_keras_fit_multiple').

    Checks the model selection tables, loads the model configurations with
    ``query_model_configs``, checks the model info table and finally runs the
    common checks of the parent class (with model_id set to None).
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_selection_table, model_selection_summary_table,
                 dependent_varname, independent_varname, num_iterations,
                 model_info_table, mst_key_col, model_arch_table_col,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg):
        self.module_name = 'madlib_keras_fit_multiple'

        input_tbl_valid(model_selection_table, self.module_name)
        summary_hint = ("Please ensure that the model selection table ({0}) "
                        "has been created by "
                        "load_model_selection_table().".format(
                            model_selection_table))
        input_tbl_valid(model_selection_summary_table, self.module_name,
                        error_suffix_str=summary_hint)

        configs, arch_table, custom_fn_table = query_model_configs(
            model_selection_table, model_selection_summary_table,
            mst_key_col, model_arch_table_col)
        self.msts = configs
        self.model_arch_table = arch_table
        self.object_table = custom_fn_table

        # With warm start the info table is checked as an input table,
        # otherwise as an output table.
        check_info_table = input_tbl_valid if warm_start else output_tbl_valid
        check_info_table(model_info_table, self.module_name)

        super(FitMultipleInputValidator, self).__init__(
            source_table=source_table,
            validation_table=validation_table,
            output_model_table=output_model_table,
            model_arch_table=self.model_arch_table,
            model_id=None,
            dependent_varname=dependent_varname,
            independent_varname=independent_varname,
            num_iterations=num_iterations,
            metrics_compute_frequency=metrics_compute_frequency,
            warm_start=warm_start,
            use_gpus=use_gpus,
            accessible_gpus_for_seg=accessible_gpus_for_seg,
            module_name=self.module_name,
            object_table=self.object_table)
class MstLoaderInputValidator():
    """Argument validation for loading a model selection table.

    Constructing an instance stores the arguments and runs every check
    immediately. The output-table checks and the compile/fit parameter checks
    only run when module_name is 'load_model_selection_table'.
    """

    def __init__(self,
                 model_arch_table,
                 model_selection_table,
                 model_selection_summary_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table,
                 module_name='load_model_selection_table'
                 ):
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = model_selection_summary_table
        self.model_id_list = model_id_list
        self.compile_params_list = compile_params_list
        self.fit_params_list = fit_params_list
        self.object_table = object_table
        self.module_name = module_name
        self._validate_input_args()

    def _validate_input_args(self):
        """Run the table, model id and (optionally) parameter checks."""
        self._validate_input_output_tables()
        self._validate_model_ids()
        if self.module_name == 'load_model_selection_table':
            self._validate_compile_and_fit_params()

    def _validate_model_ids(self):
        """Assert that every requested model id exists in the arch table."""
        # An empty list would produce "IN ()", which is not valid SQL; report
        # it as a validation error instead.
        _assert(bool(self.model_id_list),
                "{0}: model_id_list cannot be NULL or empty".format(
                    self.module_name))

        # Collapse repeated ids (keeping first-seen order) so that a
        # duplicate is not reported as a missing model id.
        unique_ids = []
        seen = set()
        for model_id in self.model_id_list:
            if model_id not in seen:
                seen.add(model_id)
                unique_ids.append(model_id)

        model_id_str = '({0})'.format(','.join(str(x) for x in unique_ids))
        query = """
            SELECT count(model_id)
            FROM {model_arch_table}
            WHERE model_id IN {model_id_str}
        """.format(model_arch_table=self.model_arch_table,
                   model_id_str=model_id_str)
        res = int(plpy.execute(query)[0]['count'])
        _assert(
            res == len(unique_ids),
            "{0}: One or more model_id of {1} not found in table {2}".format(
                self.module_name,
                model_id_str,
                self.model_arch_table
            )
        )

    def _validate_compile_and_fit_params(self):
        """Parse every fit/compile parameter string and check its functions.

        The loss of each compile parameter string must be either a name found
        in the object table or an attribute of ``keras.losses``. When metrics
        are given, at least one of them must be a name found in the object
        table, an attribute of ``keras.metrics``, or 'accuracy'.
        Any exception raised while checking a parameter string is re-reported
        through ``plpy.error`` together with the offending string.
        """
        if not self.fit_params_list:
            plpy.error("fit_params_list cannot be NULL")
        for fit_params in self.fit_params_list:
            try:
                parse_and_validate_fit_params(fit_params)
            except Exception as e:
                plpy.error(
                    """Fit param check failed for: {0} \n
                    {1}
                    """.format(fit_params, str(e)))
        if not self.compile_params_list:
            plpy.error("compile_params_list cannot be NULL")

        # Built-in loss/metrics function names
        builtin_losses = dir(losses)
        builtin_metrics = dir(metrics)
        # Default metrics, since it is not part of the builtin metrics list
        builtin_metrics.append('accuracy')

        custom_fn_name = []
        error_suffix = "but input object table missing!"
        if self.object_table is not None:
            res = plpy.execute("SELECT {0} from {1}".format(
                CustomFunctionSchema.FN_NAME, self.object_table))
            custom_fn_name = [r[CustomFunctionSchema.FN_NAME] for r in res]
            error_suffix = "is not defined in object table '{0}'!".format(
                self.object_table)

        for compile_params in self.compile_params_list:
            try:
                _, _, res = parse_and_validate_compile_params(compile_params)
                # Validating if loss/metrics function called in compile_params
                # is either defined in object table or is a built_in keras
                # loss/metrics function
                _assert(res['loss'] in custom_fn_name
                        or res['loss'] in builtin_losses,
                        "custom function '{0}' used in compile params "
                        "{1}".format(res['loss'], error_suffix))
                if 'metrics' in res:
                    requested_metrics = set(res['metrics'])
                    _assert((len(requested_metrics.intersection(custom_fn_name)) > 0
                             or len(requested_metrics.intersection(builtin_metrics)) > 0),
                            "custom function '{0}' used in compile params "
                            "{1}".format(res['metrics'], error_suffix))
            except Exception as e:
                plpy.error(
                    """Compile param check failed for: {0} \n
                    {1}
                    """.format(compile_params, str(e)))

    def _validate_input_output_tables(self):
        """Check the input tables and, when loading, the output tables."""
        input_tbl_valid(self.model_arch_table, self.module_name)
        if self.object_table is not None:
            input_tbl_valid(self.object_table, self.module_name)
        if self.module_name == 'load_model_selection_table':
            output_tbl_valid(self.model_selection_table, self.module_name)
            output_tbl_valid(self.model_selection_summary_table,
                             self.module_name)