# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from input_data_preprocessor import DistributionRulesOptions
from keras_model_arch_table import ModelArchSchema
from model_arch_info import get_num_classes
from madlib_keras_custom_function import CustomFunctionSchema
from madlib_keras_helper import CLASS_VALUES_COLNAME
from madlib_keras_helper import COMPILE_PARAMS_COLNAME
from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
from madlib_keras_helper import MODEL_ID_COLNAME
from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
from madlib_keras_helper import MODEL_WEIGHTS_COLNAME
from madlib_keras_helper import NORMALIZING_CONST_COLNAME
from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
from madlib_keras_helper import METRIC_TYPE_COLNAME
from madlib_keras_helper import INTERNAL_GPU_CONFIG
from madlib_keras_helper import query_model_configs
from utilities.minibatch_validation import validate_bytea_var_for_minibatch
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import is_platform_pg
from utilities.utilities import is_var_valid
from utilities.utilities import is_superuser
from utilities.utilities import get_table_owner
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from madlib_keras_wrapper import parse_and_validate_fit_params
from madlib_keras_wrapper import parse_and_validate_compile_params
from madlib_keras_custom_function import update_builtin_metrics
from madlib_keras_custom_function import update_builtin_losses
import tensorflow.keras.losses as losses
import tensorflow.keras.metrics as metrics
class InputValidator:
@staticmethod
def validate_predict_evaluate_tables(
module_name, model_table, model_summary_table, test_table, output_table):
InputValidator._validate_model_weights_tbl(module_name, model_table)
InputValidator._validate_model_summary_tbl(
module_name, model_summary_table)
        independent_varname = plpy.execute(
            "SELECT independent_varname FROM {0}".format(
                model_summary_table))[0]["independent_varname"]
InputValidator._validate_test_tbl(
module_name, test_table, independent_varname)
output_tbl_valid(output_table, module_name)
@staticmethod
def validate_id_in_test_tbl(module_name, test_table, id_col):
_assert(is_var_valid(test_table, id_col),
"{module_name} error: invalid id column "
"('{id_col}') for test table ({table}).".format(
module_name=module_name,
id_col=id_col,
table=test_table))
@staticmethod
def validate_predict_byom_tables(module_name, model_arch_table, model_id,
test_table, id_col, output_table,
independent_varname):
InputValidator.validate_model_arch_table(
module_name, model_arch_table, model_id)
InputValidator._validate_test_tbl(
module_name, test_table, independent_varname)
InputValidator.validate_id_in_test_tbl(module_name, test_table, id_col)
output_tbl_valid(output_table, module_name)
@staticmethod
def validate_pred_type(module_name, pred_type, class_values):
error = False
        if isinstance(pred_type, str):
            if pred_type not in ('prob', 'response'):
                error = True
        elif isinstance(pred_type, int):
            if pred_type <= 0:
                error = True
        else:
            if pred_type < 0.0 or pred_type >= 1.0:
                error = True
        if error:
            plpy.error("{0}: Invalid value for pred_type param ({1}). "
                       "Must be an integer > 0, a double precision value "
                       "in the range [0.0, 1.0), 'response', or 'prob'."
                       .format(module_name, pred_type))
@staticmethod
def validate_input_shape(table, independent_varname, input_shape, offset,
is_minibatched=False):
"""
Validate if the input shape specified in model architecture is the same
as the shape of the image specified in the indepedent var of the input
table.
offset: This offset is the index of the start of the image array. We also
need to consider that sql array indexes start from 1
For ex if the image is of shape [32,32,3] and is minibatched, the image will
look like [10, 32, 32, 3]. The offset in this case is 1 (start the index at 1) +
1 (ignore the buffer size 10) = 2.
If the image is not batched then it will look like [32, 32 ,3] and the offset in
this case is 1 (start the index at 1).
"""
shapes = []
if is_minibatched:
for i in independent_varname:
ind_shape_col = add_postfix(i, "_shape")
query = """
SELECT {ind_shape_col} AS shape
FROM {table}
LIMIT 1
""".format(**locals())
# This query will fail if an image in independent var does not have the
# same number of dimensions as the input_shape.
result = plpy.execute(query)
result = result[0]['shape']
shapes.append(result[1:])
else:
for counter, ind in enumerate(independent_varname):
array_upper_query = ", ".join("array_upper({0}, {1})".format(
ind, i+offset, i) for i in range(len(input_shape[counter])))
query = """
SELECT ARRAY[{0}] AS shape
FROM {1}
LIMIT 1
""".format(array_upper_query, table)
# This query will fail if an image in independent var does not have the
# same number of dimensions as the input_shape.
result = plpy.execute(query)
result = result[0]['shape']
shapes.append(result)
        _assert(len(shapes) == len(input_shape),
                "model_keras error: The number of input shapes ({0}) in the"
                " model architecture does not match the number of independent"
                " variables {1} in table {2}.".format(
                    len(input_shape), independent_varname, table))
        for i in range(len(input_shape)):
            arch_shape = input_shape[i]
            table_shape = shapes[i]
            for j in range(len(arch_shape)):
                if table_shape[j] != arch_shape[j]:
                    # Display the full shapes and the offending column for a
                    # meaningful error msg.
                    plpy.error("model_keras error: Input shape {0} in the model"
                               " architecture does not match the input shape {1}"
                               " of column {2} in table {3}.".format(
                                   arch_shape, table_shape,
                                   independent_varname[i], table))
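    # A minimal usage sketch (assumed table/column names): for a minibatched
    # table 'tbl' whose column 'x' stores [buffer_size, 32, 32, 3] images,
    #   InputValidator.validate_input_shape('tbl', ['x'], [[32, 32, 3]], 2,
    #                                       is_minibatched=True)
    # returns silently when the stored shapes match the architecture.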
@staticmethod
def validate_model_arch_table(module_name, model_arch_table, model_id):
input_tbl_valid(model_arch_table, module_name)
_assert(model_id is not None,
"{0}: Invalid model architecture ID.".format(module_name))
@staticmethod
def validate_normalizing_const(module_name, normalizing_const):
_assert(normalizing_const > 0,
"{0} error: Normalizing constant has to be greater than 0.".
format(module_name))
@staticmethod
def validate_class_values(module_name, class_values, pred_type, model_arch):
if not class_values:
return
num_classes = [len(i) for i in class_values]
_assert(num_classes == get_num_classes(model_arch, len(class_values)),
"{0}: The number of class values do not match the " \
"provided architecture.".format(module_name))
for i in num_classes:
if pred_type == 'prob' and i+1 >= 1600:
plpy.error({"{0}: The output will have {1} columns, exceeding the "\
" max number of columns that can be created (1600)".format(
module_name, i+1)})
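        # e.g. class_values [['cat', 'dog']] gives num_classes [2], which must
        # equal the per-output class counts from get_num_classes(model_arch, 1).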
@staticmethod
def validate_model_weights(module_name, model_arch, model_weights):
_assert(model_weights and model_arch,
"{0}: Model weights and architecture must be valid.".format(
module_name))
@staticmethod
def _validate_model_weights_tbl(module_name, model_table):
_assert(is_var_valid(model_table, MODEL_WEIGHTS_COLNAME),
"{module_name} error: column '{model_weights}' "
"does not exist in model table '{table}'.".format(
module_name=module_name,
model_weights=MODEL_WEIGHTS_COLNAME,
table=model_table))
_assert(is_var_valid(model_table, ModelArchSchema.MODEL_ARCH),
"{module_name} error: column '{model_arch}' "
"does not exist in model table '{table}'.".format(
module_name=module_name,
model_arch=ModelArchSchema.MODEL_ARCH,
table=model_table))
@staticmethod
def _validate_test_tbl(module_name, test_table, independent_varname):
input_tbl_valid(test_table, module_name)
for i in independent_varname:
_assert(is_var_valid(test_table, i),
"{module_name} error: invalid independent_varname "
"('{i}') for test table "
"({table}).".format(
module_name=module_name,
i=i,
table=test_table))
@staticmethod
def _validate_model_summary_tbl(module_name, model_summary_table):
input_tbl_valid(model_summary_table, module_name)
cols_to_check_for = [DEPENDENT_VARNAME_COLNAME,
DEPENDENT_VARTYPE_COLNAME,
MODEL_ID_COLNAME,
MODEL_ARCH_TABLE_COLNAME,
NORMALIZING_CONST_COLNAME,
COMPILE_PARAMS_COLNAME,
METRIC_TYPE_COLNAME]
_assert(columns_exist_in_table(
model_summary_table, cols_to_check_for),
"{0} error: One or more expected columns missing in model "
"summary table ('{1}'). The expected columns are {2}.".format(
module_name, model_summary_table, cols_to_check_for))
@staticmethod
def _validate_gpu_config(module_name, source_table, accessible_gpus_for_seg):
summary_table = add_postfix(source_table, "_summary")
gpu_config = plpy.execute(
"SELECT {0} FROM {1}".format(INTERNAL_GPU_CONFIG, summary_table)
)[0][INTERNAL_GPU_CONFIG]
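        # gpu_config is either DistributionRulesOptions.ALL_SEGMENTS or a list
        # of the segment ids that were configured for GPUs by the preprocessor.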
if gpu_config == DistributionRulesOptions.ALL_SEGMENTS:
            _assert(0 not in accessible_gpus_for_seg,
                    "{0} error: One or more hosts are missing GPUs.".format(
                        module_name))
else:
for i in gpu_config:
                _assert(accessible_gpus_for_seg[i] != 0,
                        "{0} error: Segment {1} does not have a GPU.".format(
                            module_name, i))
class FitCommonValidator(object):
def __init__(self, source_table, validation_table, output_model_table,
num_iterations, metrics_compute_frequency, warm_start,
use_gpus, accessible_gpus_for_seg, module_name, object_table):
self.source_table = source_table
self.validation_table = validation_table
self.output_model_table = output_model_table
self.metrics_compute_frequency = metrics_compute_frequency
self.warm_start = warm_start
self.num_iterations = num_iterations
self.object_table = object_table
self.source_summary_table = None
if self.source_table:
self.source_summary_table = add_postfix(
self.source_table, "_summary")
if self.validation_table:
self.validation_summary_table = add_postfix(
self.validation_table, "_summary")
if self.output_model_table:
self.output_summary_model_table = add_postfix(
self.output_model_table, "_summary")
self.accessible_gpus_for_seg = accessible_gpus_for_seg
self.module_name = module_name
self._validate_tables()
self.src_summary_dict = self.get_source_summary_table_dict(self.source_summary_table)
self.dependent_varname = self.src_summary_dict['dependent_varname']
self.independent_varname = self.src_summary_dict['independent_varname']
if not isinstance(self.dependent_varname, list) or \
not isinstance(self.independent_varname, list):
plpy.error("Input table '{0}' was preprocessed with "\
"an older version of the input preprocessor. "
"Please re-run the current version of input preprocessor "\
"on the dataset.".format(self.source_table))
self.dependent_shape_varname = [add_postfix(i, "_shape") for i in self.dependent_varname]
self.independent_shape_varname = [add_postfix(i, "_shape") for i in self.independent_varname]
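        # e.g. dependent_varname ['y'] yields the shape columns ['y_shape'];
        # these _shape columns are created by the input preprocessor.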
self.val_dependent_varname = None
self.val_independent_varname = None
self.val_dependent_shape_varname = None
self.val_independent_shape_varname = None
if self.validation_table:
val_summary_dict = self.get_source_summary_table_dict(self.validation_summary_table)
self.val_dependent_varname = val_summary_dict['dependent_varname']
self.val_independent_varname = val_summary_dict['independent_varname']
self.val_dependent_shape_varname = [add_postfix(i, "_shape") for i in self.val_dependent_varname]
self.val_independent_shape_varname = [add_postfix(i, "_shape") for i in self.val_independent_varname]
self._validate_tables_schema()
if use_gpus:
InputValidator._validate_gpu_config(self.module_name,
self.source_table, self.accessible_gpus_for_seg)
def _validate_tables(self):
input_tbl_valid(self.source_table, self.module_name)
input_tbl_valid(self.source_summary_table, self.module_name)
if self.validation_table:
input_tbl_valid(self.validation_table, self.module_name)
input_tbl_valid(self.validation_summary_table, self.module_name)
if self.object_table is not None:
input_tbl_valid(self.object_table, self.module_name)
            _assert(is_superuser(get_table_owner(self.object_table)),
                    "DL: Cannot use a table owned by a non-superuser as the "
                    "object table.")
cols_in_tbl_valid(self.object_table, CustomFunctionSchema.col_names, self.module_name)
if self.warm_start:
input_tbl_valid(self.output_model_table, self.module_name)
input_tbl_valid(self.output_summary_model_table, self.module_name)
else:
output_tbl_valid(self.output_model_table, self.module_name)
output_tbl_valid(self.output_summary_model_table, self.module_name)
def _validate_tables_schema(self):
# Source table and validation tables must have the same schema
additional_cols = []
if not is_platform_pg():
additional_cols.append(DISTRIBUTION_KEY_COLNAME)
self._validate_columns_in_preprocessed_table(self.source_table,
self.independent_varname +
self.dependent_varname +
self.independent_shape_varname +
self.dependent_shape_varname +
additional_cols)
for i in self.dependent_varname:
validate_bytea_var_for_minibatch(self.source_table, i)
if self.validation_table and self.validation_table.strip() != '':
self._validate_columns_in_preprocessed_table(self.validation_table,
self.val_independent_varname +
self.val_dependent_varname +
self.val_independent_shape_varname +
self.val_dependent_shape_varname +
additional_cols)
for i in self.val_dependent_varname:
validate_bytea_var_for_minibatch(self.validation_table, i)
cols_in_tbl_valid(self.source_summary_table,
[NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
'dependent_varname', 'independent_varname'], self.module_name)
def _validate_misc_args(self):
_assert(self.num_iterations > 0,
"{0}: Number of iterations cannot be < 1.".format(self.module_name))
        _assert(self._is_valid_metrics_compute_frequency(),
                "{0}: metrics_compute_frequency must be in the range [1, {1}].".format(
                    self.module_name, self.num_iterations))
def get_source_summary_table_dict(self, source_summary_table):
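        """Return the single row of the preprocessor's summary table as a
        dict, e.g. with keys such as 'dependent_varname',
        'independent_varname' and 'normalizing_const' (the exact set of keys
        depends on the preprocessor version)."""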
source_summary = plpy.execute("""
SELECT *
FROM {0}
""".format(source_summary_table))[0]
return source_summary
def _validate_columns_in_preprocessed_table(self, table_name, col_names):
for col in col_names:
_assert(is_var_valid(table_name, col),
"{module_name}: invalid column name "
"('{col}') for table ({table_name}). "
"Please ensure that the input table ({table_name}) "
"has been preprocessed.".format(
module_name=self.module_name,
**locals()))
def _is_valid_metrics_compute_frequency(self):
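        # A frequency of None is accepted as-is; otherwise it must lie in
        # [1, num_iterations].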
return self.metrics_compute_frequency is None or \
(self.metrics_compute_frequency >= 1 and \
self.metrics_compute_frequency <= self.num_iterations)
def _validate_validation_table(self):
if self.validation_table and self.validation_table.strip() != '':
input_tbl_valid(self.validation_table, self.module_name)
self._validate_input_table(self.validation_table, True)
validation_summary_table = add_postfix(self.validation_table, "_summary")
input_tbl_valid(validation_summary_table, self.module_name)
            for i in self.val_dependent_varname:
                dependent_vartype = get_expr_type(i, self.validation_table)
                _assert(dependent_vartype == 'bytea',
                        "Dependent variable column {0} in validation table {1} "
                        "must be of type bytea and one-hot encoded.".format(
                            i, self.validation_table))
def validate_input_shapes(self, input_shape):
InputValidator.validate_input_shape(self.source_table, self.independent_varname,
input_shape, 2, True)
if self.validation_table:
InputValidator.validate_input_shape(
self.validation_table, self.val_independent_varname,
input_shape, 2, True)
class FitInputValidator(FitCommonValidator):
def __init__(self, source_table, validation_table, output_model_table,
model_arch_table, model_id, num_iterations,
metrics_compute_frequency, warm_start,
use_gpus, accessible_gpus_for_seg, object_table):
self.module_name = 'madlib_keras_fit'
super(FitInputValidator, self).__init__(source_table,
validation_table,
output_model_table,
num_iterations,
metrics_compute_frequency,
warm_start,
use_gpus,
accessible_gpus_for_seg,
self.module_name,
object_table
)
InputValidator.validate_model_arch_table(self.module_name, model_arch_table,
model_id)
class FitMultipleInputValidator(FitCommonValidator):
def __init__(self, source_table, validation_table, output_model_table,
model_selection_table, num_iterations, mst_key_col,
model_arch_table_col, metrics_compute_frequency, warm_start,
use_gpus, accessible_gpus_for_seg):
self.module_name = 'madlib_keras_fit_multiple'
input_tbl_valid(model_selection_table, self.module_name)
self.model_selection_summary_table = add_postfix(model_selection_table,
'_summary')
input_tbl_valid(self.model_selection_summary_table, self.module_name,
error_suffix_str="Please ensure that the model selection table ({0}) "
"has been created by "
"load_model_selection_table().".format(
model_selection_table))
self.msts, self.model_arch_table, self.object_table = query_model_configs(
model_selection_table, self.model_selection_summary_table,
mst_key_col, model_arch_table_col)
input_tbl_valid(self.model_arch_table, self.module_name)
super(FitMultipleInputValidator, self).__init__(source_table,
validation_table,
output_model_table,
num_iterations,
metrics_compute_frequency,
warm_start,
use_gpus,
accessible_gpus_for_seg,
self.module_name,
self.object_table)
        _assert(len(self.dependent_varname) == 1
                and len(self.independent_varname) == 1,
                "Multiple dependent and independent variables are not "
                "supported for madlib_keras_fit_multiple_model!")
self.output_model_info_table = add_postfix(output_model_table, '_info')
if warm_start:
input_tbl_valid(self.output_model_info_table, self.module_name)
else:
output_tbl_valid(self.output_model_info_table, self.module_name)
class MstLoaderInputValidator():
def __init__(self,
schema_madlib,
model_arch_table,
model_selection_table,
model_selection_summary_table,
model_id_list,
compile_params_list,
fit_params_list,
object_table,
module_name='load_model_selection_table'
):
self.schema_madlib = schema_madlib
self.model_arch_table = model_arch_table
self.model_selection_table = model_selection_table
self.model_selection_summary_table = model_selection_summary_table
self.model_id_list = model_id_list
self.compile_params_list = compile_params_list
self.fit_params_list = fit_params_list
self.object_table = object_table
        self.module_name = module_name
self._validate_input_args()
def _validate_input_args(self):
self._validate_input_output_tables()
self._validate_model_ids()
if self.module_name == 'load_model_selection_table':
self._validate_compile_and_fit_params()
def _validate_model_ids(self):
model_id_str = '({0})'\
.format(','.join([str(x) for x in self.model_id_list]))
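        # e.g. model_id_list [1, 2, 3] yields the SQL fragment '(1,2,3)'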
query = """
SELECT count(model_id)
FROM {self.model_arch_table}
WHERE model_id IN {model_id_str}
""".format(**locals())
res = int(plpy.execute(query)[0]['count'])
_assert(
res == len(self.model_id_list),
"{0}: One or more model_id of {1} not found in table {2}".format(
self.module_name,
model_id_str,
self.model_arch_table
)
)
def _validate_compile_and_fit_params(self):
        if not self.fit_params_list:
            plpy.error("fit_params_list cannot be NULL")
        for fit_params in self.fit_params_list:
            try:
                parse_and_validate_fit_params(fit_params)
            except Exception as e:
                plpy.error("Fit param check failed for: {0}\n"
                           "{1}".format(fit_params, str(e)))
        if not self.compile_params_list:
            plpy.error("compile_params_list cannot be NULL")
custom_fn_names = []
# Initialize builtin loss/metrics functions
builtin_losses = update_builtin_losses(dir(losses))
builtin_metrics = update_builtin_metrics(dir(metrics))
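        # The builtin name lists come from dir(tensorflow.keras.losses) and
        # dir(tensorflow.keras.metrics) (e.g. 'categorical_crossentropy'),
        # adjusted by update_builtin_losses/update_builtin_metrics.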
if self.object_table is not None:
res = plpy.execute("SELECT {0} from {1}".format(CustomFunctionSchema.FN_NAME,
self.object_table))
for r in res:
custom_fn_names.append(r[CustomFunctionSchema.FN_NAME])
for compile_params in self.compile_params_list:
try:
_, _, res = parse_and_validate_compile_params(compile_params)
# Validating if loss/metrics function called in compile_params
# is either defined in object table or is a built_in keras
# loss/metrics function
error_suffix = "but input object table missing!"
if self.object_table is not None:
error_suffix = "is not defined in object table '{0}'!".format(self.object_table)
_assert(res['loss'] in custom_fn_names or res['loss'] in builtin_losses,
"custom function '{0}' used in compile params "\
"{1}".format(res['loss'], error_suffix))
if 'metrics' in res:
_assert((len(set(res['metrics']).intersection(custom_fn_names)) > 0
or len(set(res['metrics']).intersection(builtin_metrics)) > 0),
"custom function '{0}' used in compile params " \
"{1}".format(res['metrics'], error_suffix))
            except Exception as e:
                plpy.error("Compile param check failed for: {0}\n"
                           "{1}".format(compile_params, str(e)))
def _validate_input_output_tables(self):
input_tbl_valid(self.model_arch_table, self.module_name)
if self.object_table is not None:
input_tbl_valid(self.object_table, self.module_name)
            _assert(is_superuser(get_table_owner(self.object_table)),
                    "DL: Cannot use a table owned by a non-superuser as the "
                    "object table.")
        if self.module_name in ('load_model_selection_table', 'madlib_keras_automl'):
output_tbl_valid(self.model_selection_table, self.module_name)
output_tbl_valid(self.model_selection_summary_table, self.module_name)