| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from keras_model_arch_table import ModelArchSchema |
| from utilities.utilities import add_postfix |
| from utilities.validate_args import input_tbl_valid |
| |
| from madlib_keras_helper import CLASS_VALUES_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME |
| from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME |
| from madlib_keras_helper import MODEL_DATA_COLNAME |
| from madlib_keras_helper import NORMALIZING_CONST_COLNAME |
| |
class PredictParamsProcessor:
    """Read the parameters needed for prediction from a trained model's tables.

    On construction, the first row of the model table and the first row of its
    summary table (``<model_table>_summary``) are fetched once and cached; the
    ``get_*`` accessors then read individual columns from those cached rows.
    """

    def __init__(self, model_table, module_name):
        """Load and cache one row from the model table and its summary table.

        Args:
            model_table: Name of the model output table.
            module_name: Name of the calling module; passed to
                ``input_tbl_valid`` and used as the prefix of error messages.
        """
        self.module_name = module_name
        self.model_table = model_table
        self.model_summary_table = add_postfix(self.model_table, '_summary')
        # Only the summary table is validated here.
        # NOTE(review): the model table itself is presumably validated by the
        # caller -- confirm. _get_dict_for_table still guards against it
        # being empty.
        input_tbl_valid(self.model_summary_table, self.module_name)
        self.model_summary_dict = self._get_dict_for_table(self.model_summary_table)
        self.model_arch_dict = self._get_dict_for_table(self.model_table)

    def _get_dict_for_table(self, table_name):
        """Return the first row of ``table_name`` as a column-name -> value mapping.

        Args:
            table_name: Name of the table to read from.

        Raises:
            Reports an error through ``plpy.error`` when the query returns no
            rows, instead of failing with a bare IndexError.
        """
        # The second argument limits the result to at most one row.
        rows = plpy.execute("SELECT * FROM {0}".format(table_name), 1)
        if len(rows) == 0:
            plpy.error("{0}: table '{1}' is empty, cannot read prediction "
                       "parameters from it.".format(self.module_name,
                                                    table_name))
        return rows[0]

    def get_class_values(self):
        """Return the CLASS_VALUES_COLNAME column of the cached summary row."""
        return self.model_summary_dict[CLASS_VALUES_COLNAME]

    def get_dependent_varname(self):
        """Return the DEPENDENT_VARNAME_COLNAME column of the cached summary row."""
        return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]

    def get_dependent_vartype(self):
        """Return the DEPENDENT_VARTYPE_COLNAME column of the cached summary row."""
        return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]

    def get_model_arch(self):
        """Return the ModelArchSchema.MODEL_ARCH column of the cached model row."""
        return self.model_arch_dict[ModelArchSchema.MODEL_ARCH]

    def get_model_data(self):
        """Return the MODEL_DATA_COLNAME column of the cached model row."""
        return self.model_arch_dict[MODEL_DATA_COLNAME]

    def get_normalizing_const(self):
        """Return the NORMALIZING_CONST_COLNAME column of the cached summary row."""
        return self.model_summary_dict[NORMALIZING_CONST_COLNAME]