blob: e50097095f0616491bb5867d3c0457649e4096e8 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import dill
import plpy
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import get_col_name_type_sql_string
from utilities.validate_args import columns_missing_from_table
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import quote_ident
from utilities.validate_args import table_exists
module_name = 'Keras Custom Function'
class CustomFunctionSchema:
    """Column layout of a custom function table.

    The attributes name the columns, and their SQL types, of the table that
    stores custom function objects, so that other code never has to hard
    code a column name.

    Example uses:

        from utilities.validate_args import columns_missing_from_table
        from madlib_keras_custom_function import CustomFunctionSchema

        # Validate names in cols list against actual table
        missing_cols = columns_missing_from_table(
            'my_custom_fn_table', CustomFunctionSchema.col_names)

        # Get function object from table, without hard coding column names
        sql = "SELECT {object} FROM {table} WHERE {id} = {my_id}".format(
            object=CustomFunctionSchema.FN_OBJ,
            table='my_custom_fn_table',
            id=CustomFunctionSchema.FN_ID,
            my_id=1)
        row = plpy.execute(sql)[0]
        fn_object = row[CustomFunctionSchema.FN_OBJ]
    """
    FN_ID = 'id'
    FN_NAME = 'name'
    FN_OBJ = 'object'
    FN_DESC = 'description'

    # col_names and col_types are parallel tuples: entry i of col_types is
    # the SQL type of the column named by entry i of col_names. Declaring
    # them as (name, type) pairs keeps the two from drifting apart.
    col_names, col_types = zip(
        (FN_ID, 'SERIAL'),
        (FN_NAME, 'TEXT'),
        (FN_DESC, 'TEXT'),
        (FN_OBJ, 'BYTEA'),
    )
def _validate_object(object, **kwargs):
    """Check that `object` is non-NULL and can be deserialized by dill.

    Args:
        object: Serialized function object to validate. Must not be None.
        **kwargs: Accepted and ignored.

    Raises:
        An error is reported through `_assert` when `object` is None and
        through `plpy.error` when `dill.loads` fails on it.
    """
    _assert(object is not None,
            "{0}: function object cannot be NULL!".format(module_name))
    # NOTE(review): dill.loads can execute code embedded in the payload.
    # Deserialization is the validation step here, so it is kept, but the
    # payload should only come from callers allowed to run code in the
    # database.
    try:
        # Only the success/failure of deserialization matters; the loaded
        # object itself is not needed.
        dill.loads(object)
    except Exception as e:
        # Include the underlying failure so the caller can see why the
        # payload was rejected (previously `e` was passed but never shown).
        plpy.error("{0}: Invalid function object: {1}".format(module_name, e))
def load_custom_function(object_table, object, name, description=None, **kwargs):
    """Insert a serialized custom function into `object_table`.

    The table is created (with the function name as primary key) when it
    does not exist yet; otherwise it is checked for the columns listed in
    CustomFunctionSchema before the row is inserted.

    Args:
        object_table: Name of the custom function table.
        object: Serialized function object; validated with _validate_object.
        name: Function name. Must not be None.
        description: Optional free-text description.
        **kwargs: Accepted and ignored.

    Raises:
        Errors are reported through `_assert` / `plpy.error` when the object
        or name is invalid, when an existing table lacks required columns,
        when the name is already present in the table, or when the insert
        fails for any other reason.
    """
    object_table = quote_ident(object_table)
    _validate_object(object)
    _assert(name is not None,
            "{0}: function name cannot be NULL!".format(module_name))
    if not table_exists(object_table):
        col_defs = get_col_name_type_sql_string(CustomFunctionSchema.col_names,
                                                CustomFunctionSchema.col_types)
        sql = "CREATE TABLE {0} ({1}, PRIMARY KEY({2}))".format(
            object_table, col_defs, CustomFunctionSchema.FN_NAME)
        plpy.execute(sql, 0)
        # Using plpy.notice here as this function can be called:
        # 1. Directly by the user, we do want to display to the user
        #    if we create a new table or later the function name that
        #    is added to the table
        # 2. From load_top_k_accuracy_function, since plpy.info
        #    displays the query context when called from the function
        #    there is a very verbose output and cannot be suppressed with
        #    MinWarning decorator as INFO is always displayed irrespective
        #    of what the decorator sets the client_min_messages to.
        #    Therefore, instead we print this information as a NOTICE
        #    when called directly by the user and suppress it by setting
        #    MinWarning decorator to 'error' level in the calling function.
        plpy.notice("{0}: Created new custom function table {1}."
                    .format(module_name, object_table))
    else:
        missing_cols = columns_missing_from_table(object_table,
                                                  CustomFunctionSchema.col_names)
        if len(missing_cols) > 0:
            plpy.error("{0}: Invalid custom function table {1},"
                       " missing columns: {2}".format(module_name,
                                                      object_table,
                                                      missing_cols))

    # Name the target columns explicitly: the check above only verifies that
    # the columns exist, not their physical order, so a positional VALUES
    # list could write into the wrong columns of a pre-existing table. The
    # id column is omitted and therefore takes its default.
    insert_cols = ", ".join(CustomFunctionSchema.col_names[1:])
    insert_query = plpy.prepare(
        "INSERT INTO {0} ({1}) VALUES($1, $2, $3);".format(object_table,
                                                           insert_cols),
        CustomFunctionSchema.col_types[1:])
    try:
        plpy.execute(insert_query, [name, description, object], 0)
    # spiexceptions.UniqueViolation is only supported for PG>=9.2. For
    # GP5(based of PG8.4) it cannot be used. Therefore, checking exception
    # message for duplicate key error.
    except Exception as e:
        # str(e) rather than e.message: the `message` attribute is
        # deprecated and is not available on every exception type.
        if 'duplicate key' in str(e):
            plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
        plpy.error(e)
    plpy.notice("{0}: Added function {1} to {2} table".
                format(module_name, name, object_table))
def delete_custom_function(object_table, id=None, name=None, **kwargs):
    """Delete one function from `object_table`, selected by id or by name.

    When both `id` and `name` are given, `id` is used. After a successful
    delete the table is dropped if it has become empty.

    Args:
        object_table: Name of the custom function table.
        id: Value of the id column of the row to delete, or None.
        name: Value of the name column of the row to delete, or None.
        **kwargs: Accepted and ignored.

    Raises:
        Errors are reported through `input_tbl_valid`, `_assert` and
        `plpy.error` when the table is invalid, when neither id nor name is
        given, when required columns are missing, or when no row matches.
    """
    object_table = quote_ident(object_table)
    input_tbl_valid(object_table, module_name)
    _assert(id is not None or name is not None,
            "{0}: function id/name cannot be NULL! "
            "Use \"SELECT delete_custom_function('usage')\" for help.".format(module_name))

    missing_cols = columns_missing_from_table(object_table, CustomFunctionSchema.col_names)
    if len(missing_cols) > 0:
        plpy.error("{0}: Invalid custom function table {1},"
                   " missing columns: {2}".format(module_name, object_table,
                                                  missing_cols))

    # Pick the lookup key, and describe it for the messages below so that a
    # delete-by-name is not reported as "id None".
    if id is not None:
        key_col, key_type, key_val = CustomFunctionSchema.FN_ID, 'INTEGER', id
        target = "id {0}".format(id)
    else:
        key_col, key_type, key_val = CustomFunctionSchema.FN_NAME, 'TEXT', name
        target = "name '{0}'".format(name)

    # Bind the key as a query parameter instead of splicing it into the SQL
    # text; a name containing "$$" would otherwise terminate the literal.
    delete_plan = plpy.prepare(
        "DELETE FROM {0} WHERE {1} = $1".format(object_table, key_col),
        [key_type])
    res = plpy.execute(delete_plan, [key_val], 0)

    if res.nrows() > 0:
        plpy.notice("{0}: Object {1} has been deleted from {2}.".
                    format(module_name, target, object_table))
    else:
        plpy.error("{0}: Object {1} not found".format(module_name, target))

    # One row is enough to know whether the table is empty.
    sql = "SELECT {0} FROM {1} LIMIT 1".format(CustomFunctionSchema.FN_ID,
                                               object_table)
    res = plpy.execute(sql, 0)
    if not res:
        plpy.notice("{0}: Dropping empty custom keras function "
                    "table {1}".format(module_name, object_table))
        # NOTE(review): the usage text of delete_custom_function says that
        # dependent views only produce a warning and leave the table in
        # place; nothing here handles that case -- confirm intended behavior.
        sql = "DROP TABLE {0}".format(object_table)
        plpy.execute(sql, 0)
def update_builtin_metrics(builtin_metrics):
    """Append the additional metric names to `builtin_metrics` and return it.

    The argument is modified in place; the returned object is the same one
    that was passed in.
    """
    additional_metrics = ('accuracy', 'acc', 'crossentropy', 'ce')
    for metric_name in additional_metrics:
        builtin_metrics.append(metric_name)
    return builtin_metrics
@MinWarning("error")
def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
    """Load a top-k accuracy function into `object_table`.

    The function is stored under the generated name "top_<k>_accuracy" by
    calling the SQL-level load_custom_function with the object produced by
    top_k_categorical_acc_pickled.

    Args:
        schema_madlib: Schema in which the MADlib SQL functions live.
        object_table: Name of the custom function table.
        k: Must be greater than 0.
        **kwargs: Accepted and ignored.
    """
    object_table = quote_ident(object_table)
    _assert(k > 0,
            "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))
    fn_name = "top_{k}_accuracy".format(k=k)
    load_sql = """
        SELECT {schema_madlib}.load_custom_function('{object_table}',
        {schema_madlib}.top_k_categorical_acc_pickled({k}, '{fn_name}'),
        '{fn_name}',
        'returns {fn_name}');
        """
    plpy.execute(load_sql.format(schema_madlib=schema_madlib,
                                 object_table=object_table,
                                 k=k,
                                 fn_name=fn_name))
    # As this function allocates the name for the top_k_accuracy function,
    # printing it out here so the user doesn't need to lookup for the
    # newly added custom function name in the object_table
    plpy.info("{0}: Added function '{1}' to '{2}' table".
              format(module_name, fn_name, object_table))
class KerasCustomFunctionDocumentation:
    """Help-text builders for the custom function SQL entry points.

    Each `*_help` static method builds a summary and a usage text for one
    function and lets `_returnHelpMsg` choose which one to return.
    """

    @staticmethod
    def _returnHelpMsg(schema_madlib, message, summary, usage, method):
        """Select the help text that matches `message`.

        Returns `summary` for an empty/None message, `usage` for 'usage',
        'help' or '?' (case-insensitive), and a "No such option" hint that
        names `schema_madlib`.`method` for anything else.
        """
        if not message:
            return summary
        if message.lower() in ('usage', 'help', '?'):
            return usage
        no_such_option = """
            No such option. Use "SELECT {schema_madlib}.{method}()"
            for help.
        """
        return no_such_option.format(schema_madlib=schema_madlib,
                                     method=method)

    @staticmethod
    def load_custom_function_help(schema_madlib, message):
        """Return the help text of load_custom_function for `message`."""
        method = "load_custom_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        The user can specify custom functions as part of the parameters
        passed to madlib_keras_fit()/madlib_keras_fit_multiple(). These
        custom function(s) definition must be stored in a table to pass.
        This is a helper function to help users insert object(BYTEA) of
        the function definitions into a table.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)
        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table,       --  VARCHAR. Output table to load custom function.
            object,             --  BYTEA. dill pickled object of the function definition.
            name,               --  TEXT. Free text string to identify a name
            description         --  TEXT. Free text string to provide a description
        );

        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_custom_function contains the following columns:

        'id'                    -- SERIAL. Function ID.
        'name'                  -- TEXT PRIMARY KEY. unique function name.
        'description'           -- TEXT. function description.
        'object'                -- BYTEA. dill pickled function object.
        """.format(**names)
        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)

    # ---------------------------------------------------------------------

    @staticmethod
    def delete_custom_function_help(schema_madlib, message):
        """Return the help text of delete_custom_function for `message`."""
        method = "delete_custom_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        Delete the custom function corresponding to the provided id
        from the custom function repository table (object_table).

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)
        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table VARCHAR, -- Table containing keras custom function objects.
            id INTEGER            -- The id of the keras custom function object
                                     to be deleted.
        );

        SELECT {schema_madlib}.{method}(
            object_table VARCHAR, -- Table containing keras custom function objects.
            name TEXT             -- Function name of the keras custom function
                                     object to be deleted.
        );

        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        This method deletes the row corresponding to the given id in the
        object_table. This also tries to drop the table if the table is
        empty after dropping the id. If there are any views depending on the
        table, a warning message is displayed and the table is not dropped.

        ---------------------------------------------------------------------------
        """.format(**names)
        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)

    @staticmethod
    def load_top_k_accuracy_function_help(schema_madlib, message):
        """Return the help text of load_top_k_accuracy_function for `message`."""
        method = "load_top_k_accuracy_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
                            SUMMARY
        ----------------------------------------------------------------
        The user can specify a custom n value for top_n_accuracy metric.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)
        usage = """
        ---------------------------------------------------------------------------
                                        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
            object_table,       --  VARCHAR. Output table to load custom function.
            k                   --  INTEGER. The number of samples for top n accuracy
        );

        ---------------------------------------------------------------------------
                                        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_top_k_accuracy_function contains the following columns:

        'id'                    -- SERIAL. Function ID.
        'name'                  -- TEXT PRIMARY KEY. unique function name.
        'description'           -- TEXT. function description.
        'object'                -- BYTEA. dill pickled function object.
        """.format(**names)
        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
# ---------------------------------------------------------------------