| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| @file keras_model_arch_table.py_in |
| |
| @brief keras model arch table management helper functions |
| |
| @namespace keras_model_arch_table |
| """ |
| |
| from internal.db_utils import quote_literal |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import get_col_name_type_sql_string |
| from utilities.utilities import unique_string |
| from utilities.validate_args import columns_missing_from_table |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import quote_ident |
| from utilities.validate_args import table_exists |
| |
class ModelArchSchema:
    """Column layout expected of a keras model architecture table.

    ``col_names`` and ``col_types`` are parallel tuples (same order, same
    length); each column name is also exposed as an upper-case class
    attribute so callers never have to hard-code a column name.

    Example uses:

        from utilities.validate_args import columns_missing_from_table
        from keras_model_arch_table import ModelArchSchema

        # Validate names in cols list against actual table
        missing_cols = columns_missing_from_table('my_arch_table',
                                                  ModelArchSchema.col_names)

        # Get model arch from keras model arch table, without hard coding
        # column names
        sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}".format(
            arch=ModelArchSchema.MODEL_ARCH,
            table='my_arch_table',
            id=ModelArchSchema.MODEL_ID,
            my_id=1)
        arch = plpy.execute(sql)[0]
    """
    # Individual column names.
    MODEL_ID = 'model_id'
    MODEL_ARCH = 'model_arch'
    MODEL_WEIGHTS = 'model_weights'
    NAME = 'name'
    DESCRIPTION = 'description'
    __INTERNAL_MADLIB_ID__ = '__internal_madlib_id__'

    # Column names in table order.
    col_names = (MODEL_ID, MODEL_ARCH, MODEL_WEIGHTS, NAME, DESCRIPTION,
                 __INTERNAL_MADLIB_ID__)
    # SQL type of each column, positionally matching col_names.
    col_types = ('SERIAL PRIMARY KEY', 'JSON', 'bytea', 'TEXT', 'TEXT', 'TEXT')
| |
@MinWarning("error")
def load_keras_model(keras_model_arch_table, model_arch, model_weights,
                     name, description, **kwargs):
    """Insert one Keras model architecture row into a model arch table.

    If ``keras_model_arch_table`` does not exist it is created with the
    columns described by ``ModelArchSchema``. If it exists, it must contain
    every ``ModelArchSchema`` column, otherwise ``plpy.error`` is called.
    After the insert, the generated model id is reported via ``plpy.info``.

    Args:
        keras_model_arch_table: Name of the table to insert into (passed
            through ``quote_ident`` before use).
        model_arch: Model architecture; bound as a JSON parameter.
        model_weights: Model weights; bound as a bytea parameter.
        name: Name for the model; bound as a TEXT parameter.
        description: Description for the model; bound as a TEXT parameter.
        **kwargs: Accepted but not used by this function.
    """
    model_arch_table = quote_ident(keras_model_arch_table)
    if not table_exists(model_arch_table):
        col_defs = get_col_name_type_sql_string(ModelArchSchema.col_names,
                                                ModelArchSchema.col_types)
        create_sql = "CREATE TABLE {0} ({1})".format(model_arch_table,
                                                     col_defs)
        plpy.execute(create_sql, 0)
        plpy.info("Keras Model Arch: Created new keras model architecture table {0}." \
            .format(model_arch_table))
    else:
        missing_cols = columns_missing_from_table(model_arch_table,
                                                  ModelArchSchema.col_names)
        if len(missing_cols) > 0:
            plpy.error("Keras Model Arch: Invalid keras model architecture table {0},"
                       " missing columns: {1}".format(model_arch_table,
                                                      missing_cols))

    # The serial model_id is assigned by the database, so the new row is
    # tagged with a unique string that lets us look the id up afterwards.
    unique_str = unique_string(prefix_has_temp=False)

    # Name the target columns explicitly: an existing table is validated only
    # for column presence, not column order or count, so a positional INSERT
    # could fail or store values in the wrong columns. model_id is left out
    # so that it takes its default value.
    insert_cols = ", ".join(ModelArchSchema.col_names[1:])
    insert_plan = plpy.prepare(
        "INSERT INTO {table} ({cols}) "
        "VALUES($1, $2, $3, $4, $5);".format(table=model_arch_table,
                                             cols=insert_cols),
        ModelArchSchema.col_types[1:])
    plpy.execute(insert_plan,
                 [model_arch, model_weights, name, description, unique_str],
                 0)

    # Only the id is needed for the message; do not fetch the JSON blob.
    select_query = """SELECT {model_id_col} FROM {model_arch_table}
        WHERE {internal_id_col} = '{unique_str}'""".format(
            model_id_col=ModelArchSchema.MODEL_ID,
            model_arch_table=model_arch_table,
            internal_id_col=ModelArchSchema.__INTERNAL_MADLIB_ID__,
            unique_str=unique_str)
    select_res = plpy.execute(select_query, 1)

    plpy.info("Keras Model Arch: Added model id {0} to {1} table".
              format(select_res[0][ModelArchSchema.MODEL_ID], model_arch_table))
| |
@MinWarning("error")
def delete_keras_model(keras_model_arch_table, model_id, **kwargs):
    """Delete one model from a keras model arch table.

    Validates the table, deletes the row whose model id equals ``model_id``
    and, if the table is left without rows, drops the table.

    Args:
        keras_model_arch_table: Name of the model architecture table (passed
            through ``quote_ident`` before use).
        model_id: Id of the model to delete; bound as an integer parameter.
        **kwargs: Accepted but not used by this function.

    ``plpy.error`` is called when the table lacks any ``ModelArchSchema``
    column or when no row matches ``model_id``.
    """
    model_arch_table = quote_ident(keras_model_arch_table)
    input_tbl_valid(model_arch_table, "Keras Model Arch")

    missing_cols = columns_missing_from_table(model_arch_table,
                                              ModelArchSchema.col_names)
    if len(missing_cols) > 0:
        plpy.error("Keras Model Arch: Invalid keras model architecture table {0},"
                   " missing columns: {1}".format(model_arch_table, missing_cols))

    # Bind model_id as a parameter instead of splicing it into the SQL text:
    # this avoids SQL injection and invalid SQL when model_id is None.
    delete_plan = plpy.prepare(
        "DELETE FROM {0} WHERE {1} = $1".format(model_arch_table,
                                                ModelArchSchema.MODEL_ID),
        ["integer"])
    res = plpy.execute(delete_plan, [model_id], 0)

    if res.nrows() > 0:
        plpy.info("Keras Model Arch: Model id {0} has been deleted from {1}.".
                  format(model_id, model_arch_table))
    else:
        plpy.error("Keras Model Arch: Model id {0} not found".format(model_id))

    # One row is enough to know whether the table is empty; do not fetch
    # every remaining model id.
    remaining = plpy.execute(
        "SELECT {0} FROM {1}".format(ModelArchSchema.MODEL_ID,
                                     model_arch_table),
        1)
    if not remaining:
        plpy.info("Keras Model Arch: Dropping empty keras model architecture "\
            "table {model_arch_table}".format(model_arch_table=model_arch_table))
        # NOTE(review): the user-facing help says a warning is shown and the
        # table is kept when views depend on it; this function issues a plain
        # DROP TABLE and does not catch errors -- confirm where that is
        # handled.
        plpy.execute("DROP TABLE {0}".format(model_arch_table), 0)
| |
class KerasModelArchDocumentation:
    """Builds the help text for load_keras_model and delete_keras_model.

    Each ``*_help`` method returns the summary when ``message`` is empty or
    None, the usage text when ``message`` is 'usage', 'help' or '?' (case
    insensitive), and a "No such option" hint for anything else.
    """

    @staticmethod
    def _returnHelpMsg(schema_madlib, message, summary, usage, method):
        """Pick the help text matching ``message``.

        Args:
            schema_madlib: Schema name substituted into the fallback hint.
            message: Help option requested by the user; may be None or empty.
            summary: Text returned when no option is given.
            usage: Text returned for 'usage', 'help' or '?'.
            method: Function name substituted into the fallback hint.

        Returns:
            The selected help string.
        """
        if not message:
            return summary
        if message.lower() in ('usage', 'help', '?'):
            return usage
        return """
No such option. Use "SELECT {schema_madlib}.{method}()"
for help.
""".format(schema_madlib=schema_madlib, method=method)

    @staticmethod
    def load_keras_model_help(schema_madlib, message):
        """Return the help text of load_keras_model for ``message``."""
        method = "load_keras_model"
        summary = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
The architecture of the model to be used in madlib_keras_train()
function must be stored in a table, the details of which must be
provided as parameters to the madlib_keras_train module. This is
a helper function to help users insert JSON blobs of Keras model
architectures into a table.
If the output table already exists, the model_arch specified will
be added as a new row into the table. The output table could thus
act as a repository of Keras model architectures.

For more details on function usage:
SELECT {schema_madlib}.{method}('usage')
""".format(schema_madlib=schema_madlib, method=method)

        # The OUTPUT section lists every column of ModelArchSchema,
        # including 'name' and 'description'.
        usage = """
---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.{method}(
keras_model_arch_table, -- Output table to load keras model arch.
model_arch, -- JSON of the model architecture to insert.
model_weights, -- Model weights to load as a PostgreSQL
binary data type.
name, -- Free text string to identify a name
description -- Free text string to provide a description
);


---------------------------------------------------------------------------
OUTPUT
---------------------------------------------------------------------------
The output table produced by load_keras_model contains the following columns:

'model_id' -- SERIAL PRIMARY KEY. Model ID.
'model_arch' -- JSON. JSON blob of the model architecture.
'model_weights' -- bytea. weights of the model for warm start.
'name' -- TEXT. Free text string to identify a name.
'description' -- TEXT. Free text string to provide a description.
'__internal_madlib_id__' -- TEXT. Unique id for model arch.

""".format(schema_madlib=schema_madlib, method=method)

        return KerasModelArchDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
    # ---------------------------------------------------------------------

    @staticmethod
    def delete_keras_model_help(schema_madlib, message):
        """Return the help text of delete_keras_model for ``message``."""
        method = "delete_keras_model"
        summary = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Delete the model architecture corresponding to the provided model_id
from the model architecture repository table (keras_model_arch_table).

For more details on function usage:
SELECT {schema_madlib}.{method}('usage')
""".format(schema_madlib=schema_madlib, method=method)

        usage = """
---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.{method}(
keras_model_arch_table VARCHAR, -- Table containing Keras model architectures.
model_id INTEGER -- The id of the model arch to be deleted.
);


---------------------------------------------------------------------------
OUTPUT
---------------------------------------------------------------------------
This method deletes the row corresponding to the given model_id in
keras_model_arch_table. This also tries to drop the table if the table is
empty after dropping the model_id. If there are any views depending on the
table, a warning message is displayed and the table is not dropped.

---------------------------------------------------------------------------
""".format(schema_madlib=schema_madlib, method=method)

        return KerasModelArchDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)