| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from keras_model_arch_table import ModelArchSchema |
| from model_arch_info import get_num_classes |
| from madlib_keras_custom_function import CustomFunctionSchema |
| from madlib_keras_helper import CLASS_VALUES_COLNAME |
| from madlib_keras_helper import COMPILE_PARAMS_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME |
| from madlib_keras_helper import MODEL_ID_COLNAME |
| from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME |
| from madlib_keras_helper import MODEL_WEIGHTS_COLNAME |
| from madlib_keras_helper import NORMALIZING_CONST_COLNAME |
| from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME |
| from madlib_keras_helper import METRIC_TYPE_COLNAME |
| from madlib_keras_helper import INTERNAL_GPU_CONFIG |
| from madlib_keras_helper import query_model_configs |
| |
| from utilities.minibatch_validation import validate_bytea_var_for_minibatch |
| from utilities.utilities import _assert |
| from utilities.utilities import add_postfix |
| from utilities.utilities import is_platform_pg |
| from utilities.utilities import is_var_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import output_tbl_valid |
| from madlib_keras_wrapper import parse_and_validate_fit_params |
| from madlib_keras_wrapper import parse_and_validate_compile_params |
| import keras.losses as losses |
| import keras.metrics as metrics |
| |
class InputValidator:
    """Static validation helpers for the keras fit/predict/evaluate inputs.

    Every check reports a failure through ``_assert`` or ``plpy.error``;
    none of the methods returns a meaningful value.
    """

    @staticmethod
    def validate_predict_evaluate_tables(
            module_name, model_table, model_summary_table, test_table,
            output_table, independent_varname):
        """Validate the tables used by predict/evaluate.

        Checks the model table, the model summary table and the test table
        (including ``independent_varname``), then checks ``output_table``
        with ``output_tbl_valid``.
        """
        InputValidator._validate_model_weights_tbl(module_name, model_table)
        InputValidator._validate_model_summary_tbl(
            module_name, model_summary_table)
        InputValidator._validate_test_tbl(
            module_name, test_table, independent_varname)
        output_tbl_valid(output_table, module_name)

    @staticmethod
    def validate_id_in_test_tbl(module_name, test_table, id_col):
        """Assert that ``id_col`` is a valid expression for ``test_table``."""
        _assert(is_var_valid(test_table, id_col),
                "{module_name} error: invalid id column "
                "('{id_col}') for test table ({table}).".format(
                    module_name=module_name,
                    id_col=id_col,
                    table=test_table))

    @staticmethod
    def validate_predict_byom_tables(module_name, model_arch_table, model_id,
                                     test_table, id_col, output_table,
                                     independent_varname):
        """Validate the tables used when predicting with a user-supplied model.

        Checks the model architecture table and id, the test table with its
        independent variable and id column, and finally the output table.
        """
        InputValidator.validate_model_arch_table(
            module_name, model_arch_table, model_id)
        InputValidator._validate_test_tbl(
            module_name, test_table, independent_varname)
        InputValidator.validate_id_in_test_tbl(module_name, test_table, id_col)

        output_tbl_valid(output_table, module_name)

    @staticmethod
    def validate_pred_type(module_name, pred_type, class_values):
        """Raise a plpy error unless ``pred_type`` is an accepted value.

        Accepted values are the strings 'prob' and 'response', an integer
        greater than 0, or any other value ``v`` with ``0.0 <= v < 1.0``.
        ``class_values`` is accepted for interface compatibility and is not
        used by this check.
        """
        # Exact type comparison (rather than isinstance) is deliberate: it
        # keeps subclasses such as bool out of the integer branch, so they
        # are range-checked like any other non-str, non-int value.
        pred_kind = type(pred_type)
        if pred_kind is str:
            is_valid = pred_type in ('prob', 'response')
        elif pred_kind is int:
            is_valid = pred_type > 0
        else:
            is_valid = not (pred_type < 0.0 or pred_type >= 1.0)

        if not is_valid:
            plpy.error(
                "{0}: Invalid value for pred_type param ({1}). "
                "Must be integer>0, double precision in the range [0.0,1.0), "
                "'response' or 'prob'.".format(module_name, pred_type))

    @staticmethod
    def validate_input_shape(table, independent_varname, input_shape, offset,
                             is_minibatched=False):
        """
        Validate if the input shape specified in model architecture is the same
        as the shape of the image specified in the independent var of the input
        table.
        offset: This offset is the index of the start of the image array. We also
        need to consider that sql array indexes start from 1
        For ex if the image is of shape [32,32,3] and is minibatched, the image will
        look like [10, 32, 32, 3]. The offset in this case is 1 (start the index at 1) +
        1 (ignore the buffer size 10) = 2.
        If the image is not batched then it will look like [32, 32 ,3] and the offset in
        this case is 1 (start the index at 1).

        Note: when is_minibatched is True the shape is read from the
        '<independent_varname>_shape' column and its first entry is dropped;
        offset is only used in the non-minibatched branch.
        """
        num_dims = len(input_shape)

        if is_minibatched:
            ind_shape_col = add_postfix(independent_varname, "_shape")
            query = """
                SELECT {0} AS shape
                FROM {1}
                LIMIT 1
            """.format(ind_shape_col, table)
            # Skip the first entry of the stored shape (the buffer size in the
            # example above) so that only the image dimensions remain.
            result = plpy.execute(query)[0]['shape'][1:]
            keys = list(range(num_dims))
        else:
            keys = ["n_{0}".format(i) for i in range(num_dims)]
            array_upper_query = ", ".join(
                "array_upper({0}, {1}) AS {2}".format(
                    independent_varname, i + offset, key)
                for i, key in enumerate(keys))
            query = """
                SELECT {0}
                FROM {1}
                LIMIT 1
            """.format(array_upper_query, table)
            # NOTE(review): the original comment said this query fails when
            # the image has a different number of dimensions than
            # input_shape; confirm, since a mismatch may instead surface in
            # the comparison below.
            result = plpy.execute(query)[0]

        _assert(len(result) == num_dims,
                "model_keras error: The number of dimensions ({0}) of each image"
                " in model architecture and {1} in {2} ({3}) do not match.".format(
                    num_dims, independent_varname, table, len(result)))

        # Build the shape found in the table once; it is both the value that
        # is compared and the value shown in the error message.
        input_shape_from_table = [result[key] for key in keys]
        mismatch = any(actual != expected for actual, expected
                       in zip(input_shape_from_table, input_shape))
        if mismatch:
            plpy.error("model_keras error: Input shape {0} in the model"
                       " architecture does not match the input shape {1} of column"
                       " {2} in table {3}.".format(
                           input_shape, input_shape_from_table,
                           independent_varname, table))

    @staticmethod
    def validate_model_arch_table(module_name, model_arch_table, model_id):
        """Check the architecture table and that ``model_id`` is not None."""
        input_tbl_valid(model_arch_table, module_name)
        _assert(model_id is not None,
                "{0}: Invalid model architecture ID.".format(module_name))

    @staticmethod
    def validate_normalizing_const(module_name, normalizing_const):
        """Assert that ``normalizing_const`` is strictly positive."""
        _assert(normalizing_const > 0,
                "{0} error: Normalizing constant has to be greater than 0.".
                format(module_name))

    @staticmethod
    def validate_class_values(module_name, class_values, pred_type, model_arch):
        """Check ``class_values`` against the model architecture.

        Does nothing when ``class_values`` is empty or None. Otherwise the
        number of class values must equal ``get_num_classes(model_arch)``,
        and for ``pred_type == 'prob'`` the output column count
        (number of classes + 1) must stay below 1600.
        """
        if not class_values:
            return
        num_classes = len(class_values)
        _assert(num_classes == get_num_classes(model_arch),
                "{0}: The number of class values do not match the "
                "provided architecture.".format(module_name))
        if pred_type == 'prob' and num_classes + 1 >= 1600:
            # Pass the message as a plain string. The earlier code wrapped it
            # in a set literal, so the reported error was the set's repr.
            plpy.error("{0}: The output will have {1} columns, exceeding the "
                       "max number of columns that can be created "
                       "(1600)".format(module_name, num_classes + 1))

    @staticmethod
    def validate_model_weights(module_name, model_arch, model_weights):
        """Assert that both ``model_weights`` and ``model_arch`` are truthy."""
        _assert(model_weights and model_arch,
                "{0}: Model weights and architecture must be valid.".format(
                    module_name))

    @staticmethod
    def _validate_model_weights_tbl(module_name, model_table):
        """Assert that the weights and architecture columns are valid for
        ``model_table``."""
        _assert(is_var_valid(model_table, MODEL_WEIGHTS_COLNAME),
                "{module_name} error: column '{model_weights}' "
                "does not exist in model table '{table}'.".format(
                    module_name=module_name,
                    model_weights=MODEL_WEIGHTS_COLNAME,
                    table=model_table))
        _assert(is_var_valid(model_table, ModelArchSchema.MODEL_ARCH),
                "{module_name} error: column '{model_arch}' "
                "does not exist in model table '{table}'.".format(
                    module_name=module_name,
                    model_arch=ModelArchSchema.MODEL_ARCH,
                    table=model_table))

    @staticmethod
    def _validate_test_tbl(module_name, test_table, independent_varname):
        """Check ``test_table`` and that ``independent_varname`` is valid
        for it."""
        input_tbl_valid(test_table, module_name)
        _assert(is_var_valid(test_table, independent_varname),
                "{module_name} error: invalid independent_varname "
                "('{independent_varname}') for test table "
                "({table}).".format(
                    module_name=module_name,
                    independent_varname=independent_varname,
                    table=test_table))

    @staticmethod
    def _validate_model_summary_tbl(module_name, model_summary_table):
        """Check the model summary table and its expected columns."""
        input_tbl_valid(model_summary_table, module_name)
        cols_to_check_for = [CLASS_VALUES_COLNAME,
                             DEPENDENT_VARNAME_COLNAME,
                             DEPENDENT_VARTYPE_COLNAME,
                             MODEL_ID_COLNAME,
                             MODEL_ARCH_TABLE_COLNAME,
                             NORMALIZING_CONST_COLNAME,
                             COMPILE_PARAMS_COLNAME,
                             METRIC_TYPE_COLNAME]
        _assert(columns_exist_in_table(
                    model_summary_table, cols_to_check_for),
                "{0} error: One or more expected columns missing in model "
                "summary table ('{1}'). The expected columns are {2}.".format(
                    module_name, model_summary_table, cols_to_check_for))

    @staticmethod
    def _validate_gpu_config(module_name, source_table, accessible_gpus_for_seg):
        """Check the stored gpu configuration against the available gpus.

        Reads the INTERNAL_GPU_CONFIG column from '<source_table>_summary'.
        If it equals 'all_segments', no entry of ``accessible_gpus_for_seg``
        may be 0. Otherwise each element of the stored configuration is used
        as an index into ``accessible_gpus_for_seg`` and that entry must be
        non-zero.
        """
        summary_table = add_postfix(source_table, "_summary")
        gpu_config = plpy.execute(
            "SELECT {0} FROM {1}".format(INTERNAL_GPU_CONFIG, summary_table)
        )[0][INTERNAL_GPU_CONFIG]
        if gpu_config == 'all_segments':
            _assert(0 not in accessible_gpus_for_seg,
                    "{0} error: Host(s) are missing gpus.".format(module_name))
        else:
            for i in gpu_config:
                _assert(accessible_gpus_for_seg[i] != 0,
                        "{0} error: Segment {1} does not have gpu".format(
                            module_name, i))
| |
class FitCommonValidator(object):
    """Input validation shared by the keras fit entry points.

    Constructing an instance stores the arguments, derives the names of the
    shape columns and summary tables, and immediately runs the common
    validation. When ``use_gpus`` is truthy the gpu configuration is
    validated as well.
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_arch_table, model_id, dependent_varname,
                 independent_varname, num_iterations,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg, module_name, object_table):
        self.source_table = source_table
        self.validation_table = validation_table
        self.output_model_table = output_model_table
        self.model_arch_table = model_arch_table
        self.model_id = model_id
        self.dependent_varname = dependent_varname
        self.independent_varname = independent_varname
        self.dep_shape_col = add_postfix(dependent_varname, "_shape")
        self.ind_shape_col = add_postfix(independent_varname, "_shape")
        self.metrics_compute_frequency = metrics_compute_frequency
        self.warm_start = warm_start
        self.num_iterations = num_iterations
        self.object_table = object_table
        self.source_summary_table = None
        if self.source_table:
            self.source_summary_table = add_postfix(
                self.source_table, "_summary")
        # Always define the attribute so the table checks below report a
        # proper validation error rather than an AttributeError when no
        # output table name was given.
        self.output_summary_model_table = None
        if self.output_model_table:
            self.output_summary_model_table = add_postfix(
                self.output_model_table, "_summary")
        self.accessible_gpus_for_seg = accessible_gpus_for_seg
        self.module_name = module_name
        self._validate_common_args()
        if use_gpus:
            InputValidator._validate_gpu_config(
                self.module_name, self.source_table,
                self.accessible_gpus_for_seg)

    def _validate_common_args(self):
        """Run every check that does not depend on the model architecture."""
        _assert(self.num_iterations > 0,
                "{0}: Number of iterations cannot be < 1.".format(
                    self.module_name))
        _assert(self._is_valid_metrics_compute_frequency(),
                "{0}: metrics_compute_frequency must be in the range "
                "(1 - {1}).".format(self.module_name, self.num_iterations))
        input_tbl_valid(self.source_table, self.module_name)
        if self.object_table is not None:
            input_tbl_valid(self.object_table, self.module_name)
            cols_in_tbl_valid(self.object_table,
                              CustomFunctionSchema.col_names,
                              self.module_name)
        input_tbl_valid(self.source_summary_table, self.module_name,
                        error_suffix_str="Please ensure that the source table ({0}) "
                                         "has been preprocessed by "
                                         "the image preprocessor.".format(
                                             self.source_table))
        cols_in_tbl_valid(self.source_summary_table,
                          [CLASS_VALUES_COLNAME,
                           NORMALIZING_CONST_COLNAME,
                           DEPENDENT_VARTYPE_COLNAME,
                           'dependent_varname',
                           'independent_varname'],
                          self.module_name)
        if not is_platform_pg():
            cols_in_tbl_valid(self.source_table, [DISTRIBUTION_KEY_COLNAME],
                              self.module_name)

        # Source table and validation tables must have the same schema
        self._validate_input_table(self.source_table)
        validate_bytea_var_for_minibatch(self.source_table,
                                         self.dependent_varname)

        self._validate_validation_table()
        if self.warm_start:
            # Warm start continues from an existing model, so the model
            # tables are checked as inputs.
            input_tbl_valid(self.output_model_table, self.module_name)
            input_tbl_valid(self.output_summary_model_table, self.module_name)
        else:
            output_tbl_valid(self.output_model_table, self.module_name)
            output_tbl_valid(self.output_summary_model_table, self.module_name)

    def _assert_preprocessed_column(self, table, column, problem):
        """Assert ``column`` is valid for ``table``; ``problem`` is the phrase
        that opens the error message."""
        _assert(is_var_valid(table, column),
                "{module_name}: {problem} "
                "('{column}') for table ({table}). "
                "Please ensure that the input table ({table}) "
                "has been preprocessed by the image preprocessor.".format(
                    module_name=self.module_name,
                    problem=problem,
                    column=column,
                    table=table))

    def _validate_input_table(self, table):
        """Check that ``table`` has the variable and shape columns (and, off
        the pg platform, the distribution key column)."""
        self._assert_preprocessed_column(
            table, self.independent_varname, "invalid independent_varname")
        self._assert_preprocessed_column(
            table, self.dependent_varname, "invalid dependent_varname")
        self._assert_preprocessed_column(
            table, self.ind_shape_col, "invalid independent_var_shape")
        self._assert_preprocessed_column(
            table, self.dep_shape_col, "invalid dependent_var_shape")

        if not is_platform_pg():
            self._assert_preprocessed_column(
                table, DISTRIBUTION_KEY_COLNAME, "missing distribution key")

    def _is_valid_metrics_compute_frequency(self):
        """Return True if the frequency is None or within
        [1, num_iterations]."""
        frequency = self.metrics_compute_frequency
        return frequency is None or 1 <= frequency <= self.num_iterations

    def _has_validation_table(self):
        """Return True if a non-blank validation table name was supplied."""
        return bool(self.validation_table and
                    self.validation_table.strip() != '')

    def _validate_validation_table(self):
        """Validate the validation table, if one was supplied."""
        if not self._has_validation_table():
            return
        input_tbl_valid(self.validation_table, self.module_name)
        self._validate_input_table(self.validation_table)
        dependent_vartype = get_expr_type(self.dependent_varname,
                                          self.validation_table)
        _assert(dependent_vartype == 'bytea',
                "Dependent variable column {0} in validation table {1} should be "
                "a bytea and also one hot encoded.".format(
                    self.dependent_varname, self.validation_table))

    def validate_input_shapes(self, input_shape):
        """Check ``input_shape`` against the source table and, if supplied,
        the validation table (both treated as minibatched)."""
        InputValidator.validate_input_shape(
            self.source_table, self.independent_varname,
            input_shape, 2, True)
        # Use the same "was a validation table given" rule as
        # _validate_validation_table, so a blank name is never queried.
        if self._has_validation_table():
            InputValidator.validate_input_shape(
                self.validation_table, self.independent_varname,
                input_shape, 2, True)
| |
| |
class FitInputValidator(FitCommonValidator):
    """Validator for 'madlib_keras_fit'.

    Runs the common fit validation first and then checks the model
    architecture table and model id.
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_arch_table, model_id, dependent_varname,
                 independent_varname, num_iterations,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg, object_table):
        fit_module_name = 'madlib_keras_fit'
        self.module_name = fit_module_name
        super(FitInputValidator, self).__init__(
            source_table=source_table,
            validation_table=validation_table,
            output_model_table=output_model_table,
            model_arch_table=model_arch_table,
            model_id=model_id,
            dependent_varname=dependent_varname,
            independent_varname=independent_varname,
            num_iterations=num_iterations,
            metrics_compute_frequency=metrics_compute_frequency,
            warm_start=warm_start,
            use_gpus=use_gpus,
            accessible_gpus_for_seg=accessible_gpus_for_seg,
            module_name=fit_module_name,
            object_table=object_table)
        # The architecture check runs only after the common checks passed.
        InputValidator.validate_model_arch_table(
            self.module_name, self.model_arch_table, self.model_id)
| |
class FitMultipleInputValidator(FitCommonValidator):
    """Validator for 'madlib_keras_fit_multiple'.

    Checks the model selection tables, loads the model configurations from
    them, checks the model info table, and then runs the common fit
    validation with the architecture and object tables that were loaded.
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_selection_table, model_selection_summary_table,
                 dependent_varname, independent_varname, num_iterations,
                 model_info_table, mst_key_col, model_arch_table_col,
                 metrics_compute_frequency, warm_start,
                 use_gpus, accessible_gpus_for_seg):
        fit_multiple_name = 'madlib_keras_fit_multiple'
        self.module_name = fit_multiple_name

        input_tbl_valid(model_selection_table, fit_multiple_name)
        summary_hint = ("Please ensure that the model selection table ({0}) "
                        "has been created by "
                        "load_model_selection_table().".format(
                            model_selection_table))
        input_tbl_valid(model_selection_summary_table, fit_multiple_name,
                        error_suffix_str=summary_hint)

        model_configs = query_model_configs(
            model_selection_table, model_selection_summary_table,
            mst_key_col, model_arch_table_col)
        self.msts, self.model_arch_table, self.object_table = model_configs

        # With warm start the info table is checked as an input table,
        # otherwise as an output table.
        check_info_table = input_tbl_valid if warm_start else output_tbl_valid
        check_info_table(model_info_table, fit_multiple_name)

        super(FitMultipleInputValidator, self).__init__(
            source_table=source_table,
            validation_table=validation_table,
            output_model_table=output_model_table,
            model_arch_table=self.model_arch_table,
            model_id=None,
            dependent_varname=dependent_varname,
            independent_varname=independent_varname,
            num_iterations=num_iterations,
            metrics_compute_frequency=metrics_compute_frequency,
            warm_start=warm_start,
            use_gpus=use_gpus,
            accessible_gpus_for_seg=accessible_gpus_for_seg,
            module_name=fit_multiple_name,
            object_table=self.object_table)
| |
class MstLoaderInputValidator():
    """Validates the arguments used to build a model selection table.

    Constructing an instance stores the arguments and runs the validation
    immediately. Compile and fit parameters are only validated when
    ``module_name`` is 'load_model_selection_table'.
    """

    def __init__(self,
                 model_arch_table,
                 model_selection_table,
                 model_selection_summary_table,
                 model_id_list,
                 compile_params_list,
                 fit_params_list,
                 object_table,
                 module_name='load_model_selection_table'
                 ):
        self.model_arch_table = model_arch_table
        self.model_selection_table = model_selection_table
        self.model_selection_summary_table = model_selection_summary_table
        self.model_id_list = model_id_list
        self.compile_params_list = compile_params_list
        self.fit_params_list = fit_params_list
        self.object_table = object_table
        self.module_name = module_name
        self._validate_input_args()

    def _validate_input_args(self):
        """Run the table, model id and (where applicable) parameter checks."""
        self._validate_input_output_tables()
        self._validate_model_ids()
        if self.module_name == 'load_model_selection_table':
            self._validate_compile_and_fit_params()

    def _validate_model_ids(self):
        """Assert that every requested model id is found in the model
        architecture table."""
        # An empty or NULL list would otherwise produce "IN ()" (invalid SQL)
        # or a TypeError instead of a readable validation error.
        _assert(self.model_id_list,
                "{0}: model_id_list cannot be NULL or empty.".format(
                    self.module_name))
        model_id_str = '({0})'.format(
            ','.join(str(model_id) for model_id in self.model_id_list))
        query = """
            SELECT count(model_id)
            FROM {0}
            WHERE model_id IN {1}
        """.format(self.model_arch_table, model_id_str)
        num_found = int(plpy.execute(query)[0]['count'])
        _assert(
            num_found == len(self.model_id_list),
            "{0}: One or more model_id of {1} not found in table {2}".format(
                self.module_name,
                model_id_str,
                self.model_arch_table
            )
        )

    def _validate_compile_and_fit_params(self):
        """Parse every fit/compile parameter string and check that the loss
        and metrics it names are either builtin keras functions or listed in
        the object table."""
        if not self.fit_params_list:
            plpy.error("fit_params_list cannot be NULL")
        for fit_params in self.fit_params_list:
            try:
                parse_and_validate_fit_params(fit_params)
            except Exception as e:
                plpy.error(
                    """Fit param check failed for: {0} \n
                    {1}
                    """.format(fit_params, str(e)))
        if not self.compile_params_list:
            plpy.error("compile_params_list cannot be NULL")

        # Names exported by the keras losses/metrics modules count as builtin.
        builtin_losses = dir(losses)
        builtin_metrics = dir(metrics)
        # Default metrics, since it is not part of the builtin metrics list
        builtin_metrics.append('accuracy')

        custom_fn_name = []
        if self.object_table is not None:
            rows = plpy.execute("SELECT {0} from {1}".format(
                CustomFunctionSchema.FN_NAME, self.object_table))
            custom_fn_name = [row[CustomFunctionSchema.FN_NAME]
                              for row in rows]
            error_suffix = "is not defined in object table '{0}'!".format(
                self.object_table)
        else:
            error_suffix = "but input object table missing!"

        for compile_params in self.compile_params_list:
            # The asserts stay inside the try block on purpose: their failures
            # are reported through the same "Compile param check failed"
            # message as parse errors.
            try:
                _, _, compile_dict = parse_and_validate_compile_params(
                    compile_params)
                # Validating if loss/metrics function called in compile_params
                # is either defined in object table or is a built_in keras
                # loss/metrics function
                loss_name = compile_dict['loss']
                _assert(loss_name in custom_fn_name
                        or loss_name in builtin_losses,
                        "custom function '{0}' used in compile params "
                        "{1}".format(loss_name, error_suffix))
                if 'metrics' in compile_dict:
                    requested_metrics = set(compile_dict['metrics'])
                    _assert(
                        len(requested_metrics.intersection(custom_fn_name)) > 0
                        or len(requested_metrics.intersection(builtin_metrics)) > 0,
                        "custom function '{0}' used in compile params "
                        "{1}".format(compile_dict['metrics'], error_suffix))
            except Exception as e:
                plpy.error(
                    """Compile param check failed for: {0} \n
                    {1}
                    """.format(compile_params, str(e)))

    def _validate_input_output_tables(self):
        """Check the input tables and, for the two modules that create the
        model selection tables, the output tables."""
        input_tbl_valid(self.model_arch_table, self.module_name)
        if self.object_table is not None:
            input_tbl_valid(self.object_table, self.module_name)
        if self.module_name in ('load_model_selection_table',
                                'madlib_keras_automl'):
            output_tbl_valid(self.model_selection_table, self.module_name)
            output_tbl_valid(self.model_selection_summary_table,
                             self.module_name)
| |
| |