| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from keras_model_arch_table import ModelArchSchema |
| from utilities.utilities import add_postfix |
| from utilities.validate_args import input_tbl_valid |
| |
| from madlib_keras_helper import CLASS_VALUES_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME |
| from madlib_keras_helper import MODEL_WEIGHTS_COLNAME |
| from madlib_keras_helper import NORMALIZING_CONST_COLNAME |
| from madlib_keras_helper import create_summary_view |
| |
class PredictParamsProcessor:
    """Load the stored model parameters that prediction needs.

    On construction, one row is read from the model summary table (or from
    the summary view built for ``mst_key``) and one row from the model
    table. The getter methods return individual columns of those two rows.
    """

    def __init__(self, model_table, module_name, mst_key):
        """Read the summary row and the model row for ``model_table``.

        Args:
            model_table: Name of the model table to read from.
            module_name: Module name; forwarded to the validation helpers
                and used as the prefix of error messages raised here.
            mst_key: When truthy, selects a single row of ``model_table``
                via ``WHERE mst_key = <mst_key>`` and switches the summary
                source to whatever ``create_summary_view`` returns.
                When falsy, ``<model_table>_summary`` is used and no
                filter is applied.
        """
        self.module_name = module_name
        self.model_table = model_table
        self.mst_key = mst_key
        self.model_summary_table = add_postfix(self.model_table, '_summary')
        self.mult_where_clause = ""
        if self.mst_key:
            self.model_summary_table = create_summary_view(self.module_name,
                                                           self.model_table,
                                                           self.mst_key)
            # NOTE(review): mst_key is interpolated directly into SQL;
            # presumably it is always an integer supplied by the SQL-level
            # function signature -- confirm against the callers.
            self.mult_where_clause = "WHERE mst_key = {0}".format(self.mst_key)
        input_tbl_valid(self.model_summary_table, self.module_name)
        # The same WHERE clause (possibly empty) is appended to both queries.
        self.model_summary_dict = self._get_dict_for_table(self.model_summary_table)
        self.model_arch_dict = self._get_dict_for_table(self.model_table)

    def _get_dict_for_table(self, table_name):
        """Return the first row of ``table_name`` under the current filter.

        Args:
            table_name: Table or view to select from.

        Returns:
            The first row of the query result, indexable by column name.

        Raises:
            Reports through ``plpy.error`` when the query yields no rows,
            instead of letting a bare ``IndexError`` escape.
        """
        query = "SELECT * FROM {0} {1}".format(table_name, self.mult_where_clause)
        # The second argument limits the result to at most one row.
        rows = plpy.execute(query, 1)
        try:
            return rows[0]
        except IndexError:
            # An empty result means the table has no rows or no row matches
            # mst_key; fail with context rather than "list index out of range".
            plpy.error("{0}: no rows returned by query '{1}'".format(
                self.module_name, query))

    def get_class_values(self, dep):
        """Return the ``<dep>_class_values`` column of the summary row."""
        col_name = add_postfix(dep, "_class_values")
        return self.model_summary_dict[col_name]

    def get_dependent_varname(self):
        """Return the dependent variable name column of the summary row."""
        return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]

    def get_dependent_vartype(self):
        """Return the dependent variable type column of the summary row."""
        return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]

    def get_model_arch(self):
        """Return the model architecture column of the model table row."""
        return self.model_arch_dict[ModelArchSchema.MODEL_ARCH]

    def get_model_weights(self):
        """Return the model weights column of the model table row."""
        return self.model_arch_dict[MODEL_WEIGHTS_COLNAME]

    def get_normalizing_const(self):
        """Return the normalizing constant column of the summary row."""
        return self.model_summary_dict[NORMALIZING_CONST_COLNAME]