# blob: 0d6fc7cf2d2613cc35af07c43ff52b94a1c07b67 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
@file keras_model_arch_table.py_in
@brief keras model arch table management helper functions
@namespace keras_model_arch_table
"""
from internal.db_utils import quote_literal
import plpy
from utilities.control import MinWarning
from utilities.utilities import get_col_name_type_sql_string
from utilities.utilities import unique_string
from utilities.validate_args import columns_missing_from_table
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import quote_ident
from utilities.validate_args import table_exists
class ModelArchSchema:
    """Column layout expected of a keras model architecture table.

    ``col_names`` and ``col_types`` are parallel tuples (same length, same
    order).  Each column name is also exposed as an upper-case class
    attribute so callers never have to hard code a column name.

    Example uses:

        from utilities.validate_args import columns_missing_from_table
        from keras_model_arch_table import ModelArchSchema

        # Validate names in cols list against actual table
        missing_cols = columns_missing_from_table('my_arch_table',
                                                  ModelArchSchema.col_names)

        # Get model arch from keras model arch table, without hard coding
        # column names
        sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}" \
            .format(arch=ModelArchSchema.MODEL_ARCH,
                    table='my_arch_table',
                    id=ModelArchSchema.MODEL_ID,
                    my_id=1)
        arch = plpy.execute(sql)[0]
    """
    # Column names, in table order.
    col_names = (
        'model_id',
        'model_arch',
        'model_weights',
        'name',
        'description',
        '__internal_madlib_id__',
    )
    # SQL type of each column; position i describes col_names[i].
    col_types = (
        'SERIAL PRIMARY KEY',
        'JSON',
        'bytea',
        'TEXT',
        'TEXT',
        'TEXT',
    )

    # Named aliases for the individual column names.
    MODEL_ID = col_names[0]
    MODEL_ARCH = col_names[1]
    MODEL_WEIGHTS = col_names[2]
    NAME = col_names[3]
    DESCRIPTION = col_names[4]
    __INTERNAL_MADLIB_ID__ = col_names[5]
@MinWarning("error")
def load_keras_model(keras_model_arch_table, model_arch, model_weights,
                     name, description, **kwargs):
    """Insert one model architecture row, creating the table when needed.

    If the table does not exist it is created with the columns described by
    ModelArchSchema.  If it exists, it must contain every column named in
    ModelArchSchema.col_names.

    Args:
        keras_model_arch_table: Name of the model architecture table; it is
            passed through quote_ident() before being used in SQL.
        model_arch: Value stored in the JSON ``model_arch`` column.
        model_weights: Value stored in the bytea ``model_weights`` column.
        name: Value stored in the TEXT ``name`` column.
        description: Value stored in the TEXT ``description`` column.
        **kwargs: Accepted for call compatibility; not used here.

    Errors:
        Reports via plpy.error() when the table exists but is missing any
        expected column.
    """
    model_arch_table = quote_ident(keras_model_arch_table)

    if not table_exists(model_arch_table):
        col_defs = get_col_name_type_sql_string(ModelArchSchema.col_names,
                                                ModelArchSchema.col_types)
        create_sql = "CREATE TABLE {0} ({1})".format(model_arch_table, col_defs)
        plpy.execute(create_sql, 0)
        plpy.info("Keras Model Arch: Created new keras model architecture table {0}." \
            .format(model_arch_table))
    else:
        missing_cols = columns_missing_from_table(model_arch_table,
                                                  ModelArchSchema.col_names)
        if len(missing_cols) > 0:
            plpy.error("Keras Model Arch: Invalid keras model architecture table {0},"
                       " missing columns: {1}".format(model_arch_table,
                                                      missing_cols))

    # Marker stored in the internal id column so the row inserted below can
    # be located again to report its generated model id.
    unique_str = unique_string(prefix_has_temp=False)

    # Name the target columns explicitly.  A pre-existing table is only
    # checked for *missing* columns, so its column order may differ from
    # ModelArchSchema; a positional VALUES list would then put values into
    # the wrong columns.  model_id is omitted so it takes its column default.
    insert_cols = ModelArchSchema.col_names[1:]
    insert_types = ModelArchSchema.col_types[1:]
    placeholders = ', '.join('${0}'.format(pos)
                             for pos in range(1, len(insert_cols) + 1))
    insert_sql = "INSERT INTO {table} ({cols}) VALUES ({params});".format(
        table=model_arch_table,
        cols=', '.join(insert_cols),
        params=placeholders)
    insert_plan = plpy.prepare(insert_sql, insert_types)
    # Argument order must match insert_cols.
    plpy.execute(insert_plan,
                 [model_arch, model_weights, name, description, unique_str],
                 0)

    # Only the generated id is reported, so fetch just that column.
    select_query = """SELECT {model_id_col} FROM {model_arch_table}
                      WHERE {internal_id_col} = '{unique_str}'""".format(
        model_id_col=ModelArchSchema.MODEL_ID,
        model_arch_table=model_arch_table,
        internal_id_col=ModelArchSchema.__INTERNAL_MADLIB_ID__,
        unique_str=unique_str)
    select_res = plpy.execute(select_query, 1)

    plpy.info("Keras Model Arch: Added model id {0} to {1} table".
        format(select_res[0][ModelArchSchema.MODEL_ID], model_arch_table))
@MinWarning("error")
def delete_keras_model(keras_model_arch_table, model_id, **kwargs):
    """Delete one model architecture row; drop the table if it becomes empty.

    Args:
        keras_model_arch_table: Name of the model architecture table; it is
            passed through quote_ident() before being used in SQL.
        model_id: Id of the row to delete (bound as an INTEGER parameter).
        **kwargs: Accepted for call compatibility; not used here.

    Errors:
        Reports via plpy.error() when the table is missing any expected
        column, or when no row with the given model_id was deleted.
    """
    model_arch_table = quote_ident(keras_model_arch_table)
    input_tbl_valid(model_arch_table, "Keras Model Arch")

    missing_cols = columns_missing_from_table(model_arch_table, ModelArchSchema.col_names)
    if len(missing_cols) > 0:
        plpy.error("Keras Model Arch: Invalid keras model architecture table {0},"
                   " missing columns: {1}".format(model_arch_table, missing_cols))

    # Bind model_id as a typed parameter instead of splicing it into the
    # SQL text, so the value can never be interpreted as SQL.
    delete_sql = "DELETE FROM {model_arch_table} WHERE {model_id_col} = $1".format(
        model_arch_table=model_arch_table,
        model_id_col=ModelArchSchema.MODEL_ID)
    delete_plan = plpy.prepare(delete_sql, ["INTEGER"])
    res = plpy.execute(delete_plan, [model_id], 0)

    if res.nrows() > 0:
        plpy.info("Keras Model Arch: Model id {0} has been deleted from {1}.".
            format(model_id, model_arch_table))
    else:
        plpy.error("Keras Model Arch: Model id {0} not found".format(model_id))

    # One row is enough to know whether the table is empty; do not pull
    # every remaining id into memory.
    probe_sql = "SELECT {0} FROM {1} LIMIT 1".format(ModelArchSchema.MODEL_ID,
                                                     model_arch_table)
    remaining = plpy.execute(probe_sql, 1)
    if not remaining:
        plpy.info("Keras Model Arch: Dropping empty keras model architecture "\
            "table {model_arch_table}".format(model_arch_table=model_arch_table))
        # NOTE(review): the help text for this function says a warning is
        # shown and the table is kept when views depend on it; this plain
        # DROP TABLE has no such handling here -- confirm intended behavior.
        drop_sql = "DROP TABLE {0}".format(model_arch_table)
        plpy.execute(drop_sql, 0)
class KerasModelArchDocumentation:
    """Help-text builders for load_keras_model and delete_keras_model.

    Each ``*_help`` method returns the summary text when ``message`` is
    empty, the usage text when ``message`` is 'usage', 'help' or '?'
    (case-insensitive), and a "No such option" hint otherwise.
    """
    # The help strings below are returned verbatim.  Their lines start at
    # column 0 because any leading whitespace inside a triple-quoted string
    # becomes part of the displayed text.

    @staticmethod
    def _returnHelpMsg(schema_madlib, message, summary, usage, method):
        """Pick the text to return for a help request.

        Args:
            schema_madlib: Schema name substituted into the fallback hint.
            message: Option given by the user; may be None or empty.
            summary: Text returned when no option is given.
            usage: Text returned for 'usage', 'help' or '?'.
            method: Function name substituted into the fallback hint.

        Returns:
            The selected help string.
        """
        if not message:
            return summary
        elif message.lower() in ('usage', 'help', '?'):
            return usage
        return """
No such option. Use "SELECT {schema_madlib}.{method}()"
for help.
""".format(schema_madlib=schema_madlib, method=method)

    @staticmethod
    def load_keras_model_help(schema_madlib, message):
        """Return the help text for load_keras_model.

        Args:
            schema_madlib: Schema name substituted into the example calls.
            message: Help option; see _returnHelpMsg for how it is handled.
        """
        method = "load_keras_model"
        summary = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
The architecture of the model to be used in madlib_keras_train()
function must be stored in a table, the details of which must be
provided as parameters to the madlib_keras_train module. This is
a helper function to help users insert JSON blobs of Keras model
architectures into a table.
If the output table already exists, the model_arch specified will
be added as a new row into the table. The output table could thus
act as a repository of Keras model architectures.
For more details on function usage:
SELECT {schema_madlib}.{method}('usage')
""".format(schema_madlib=schema_madlib, method=method)
        # The OUTPUT section lists every column of the table, including the
        # free-text 'name' and 'description' columns.
        usage = """
---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.{method}(
keras_model_arch_table, -- Output table to load keras model arch.
model_arch, -- JSON of the model architecture to insert.
model_weights, -- Model weights to load as a PostgreSQL
binary data type.
name, -- Free text string to identify a name
description -- Free text string to provide a description
);
---------------------------------------------------------------------------
OUTPUT
---------------------------------------------------------------------------
The output table produced by load_keras_model contains the following columns:
'model_id' -- SERIAL PRIMARY KEY. Model ID.
'model_arch' -- JSON. JSON blob of the model architecture.
'model_weights' -- bytea. weights of the model for warm start.
'name' -- TEXT. Free text name of the model.
'description' -- TEXT. Free text description of the model.
'__internal_madlib_id__' -- TEXT. Unique id for model arch.
""".format(schema_madlib=schema_madlib, method=method)
        return KerasModelArchDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)

    # ---------------------------------------------------------------------
    @staticmethod
    def delete_keras_model_help(schema_madlib, message):
        """Return the help text for delete_keras_model.

        Args:
            schema_madlib: Schema name substituted into the example calls.
            message: Help option; see _returnHelpMsg for how it is handled.
        """
        method = "delete_keras_model"
        summary = """
----------------------------------------------------------------
SUMMARY
----------------------------------------------------------------
Delete the model architecture corresponding to the provided model_id
from the model architecture repository table (keras_model_arch_table).
For more details on function usage:
SELECT {schema_madlib}.{method}('usage')
""".format(schema_madlib=schema_madlib, method=method)
        # NOTE(review): the OUTPUT text below describes a warning when views
        # depend on the table; confirm this matches what delete_keras_model
        # actually does before relying on it.
        usage = """
---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.{method}(
keras_model_arch_table VARCHAR, -- Table containing Keras model architectures.
model_id INTEGER -- The id of the model arch to be deleted.
);
---------------------------------------------------------------------------
OUTPUT
---------------------------------------------------------------------------
This method deletes the row corresponding to the given model_id in
keras_model_arch_table. This also tries to drop the table if the table is
empty after dropping the model_id. If there are any views depending on the
table, a warning message is displayed and the table is not dropped.
---------------------------------------------------------------------------
""".format(schema_madlib=schema_madlib, method=method)
        return KerasModelArchDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)