| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
# specific language governing permissions and limitations
# under the License.
| |
| import dill |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import _assert |
| from utilities.utilities import get_col_name_type_sql_string |
| from utilities.utilities import current_user |
| from utilities.utilities import is_superuser |
| from utilities.utilities import get_schema |
| from utilities.validate_args import columns_missing_from_table |
| from utilities.validate_args import input_tbl_valid |
| from utilities.validate_args import quote_ident |
| from utilities.validate_args import unquote_ident |
| from utilities.validate_args import table_exists |
| |
# Prefix placed in front of the error/notice messages emitted by this module.
module_name = 'Keras Custom Function'
class CustomFunctionSchema:
    """Column layout of the table that stores Keras custom functions.

    `col_names` and `col_types` are kept in matching order, so callers can
    refer to the columns symbolically instead of hard coding their names.

    Example uses:

        from utilities.validate_args import columns_missing_from_table
        from madlib_keras_custom_function import CustomFunctionSchema

        # Check the column names of an existing table against the schema
        missing_cols = columns_missing_from_table(
            'my_custom_fn_table', CustomFunctionSchema.col_names)

        # Fetch a function object without hard coding column names
        sql = "SELECT {object} FROM {table} WHERE {id} = {my_id}".format(
            object=CustomFunctionSchema.FN_OBJ,
            table='my_custom_fn_table',
            id=CustomFunctionSchema.FN_ID,
            my_id=1)
        row = plpy.execute(sql)[0]
    """
    # Individual column names.
    FN_ID = 'id'
    FN_NAME = 'name'
    FN_DESC = 'description'
    FN_OBJ = 'object'

    # Columns in table order, with the SQL type of each one.
    col_names = (FN_ID, FN_NAME, FN_DESC, FN_OBJ)
    col_types = ('SERIAL', 'TEXT', 'TEXT', 'BYTEA')
| |
def _validate_object(object, **kwargs):
    """Check that `object` is a non-NULL payload that dill can deserialize.

    Args:
        object: Serialized (dill pickled) function definition.
        **kwargs: Accepted for call compatibility; not used.

    Raises:
        An error is reported through `_assert` when `object` is None and
        through `plpy.error` when `dill.loads` fails on it.
    """
    _assert(object is not None,
            "{0}: function object cannot be NULL!".format(module_name))
    try:
        # NOTE(security): dill.loads can execute code embedded in the payload.
        # Only call this on objects supplied by trusted users.
        dill.loads(object)
    except Exception as e:
        # Include the underlying reason; the original passed `e` to format()
        # but never referenced it in the template.
        plpy.error("{0}: Invalid function object: {1}".format(module_name, e))
| |
@MinWarning("error")
def load_custom_function(schema_madlib, object_table, object, name, description=None, **kwargs):
    """Add a dill pickled custom function as a new row of `object_table`.

    The table is always addressed inside `schema_madlib`. It is created
    (with SELECT granted to PUBLIC) when it does not exist yet; otherwise
    it is checked for the columns listed in CustomFunctionSchema.

    Args:
        schema_madlib: Name of the schema that holds the table.
        object_table: Name of the custom function table.
        object: dill pickled function definition (stored as BYTEA).
        name: Function name; this is the primary key of the table.
        description: Optional free text description.
        **kwargs: Accepted for call compatibility; not used.

    Raises:
        An error is reported through `_assert`/`plpy.error` when the caller
        is not a superuser, a required argument is NULL, the object cannot
        be deserialized, an existing table lacks expected columns, or a
        function called `name` is already stored.
    """
    # Check privileges first: validating the object deserializes it with
    # dill, which must not happen on behalf of an unprivileged caller.
    _assert(is_superuser(current_user()), "DL: The user has to have admin "\
        "privilages to load a custom function")
    _assert(object_table is not None,
            "{0}: custom function table name cannot be NULL!".format(module_name))
    object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
    _validate_object(object)
    _assert(name is not None,
            "{0}: function name cannot be NULL!".format(module_name))
    try:
        if not table_exists(object_table):
            col_defs = get_col_name_type_sql_string(CustomFunctionSchema.col_names,
                                                    CustomFunctionSchema.col_types)

            sql = """CREATE TABLE {object_table}
                ({col_defs}, PRIMARY KEY({fn_name}))
                """.format(object_table=object_table,
                           col_defs=col_defs,
                           fn_name=CustomFunctionSchema.FN_NAME)

            plpy.execute(sql, 0)
            # Using plpy.notice here as this function can be called:
            # 1. Directly by the user, we do want to display to the user
            #    if we create a new table or later the function name that
            #    is added to the table
            # 2. From load_top_k_accuracy_function, since plpy.info
            #    displays the query context when called from the function
            #    there is a very verbose output and cannot be suppressed with
            #    MinWarning decorator as INFO is always displayed irrespective
            #    of what the decorator sets the client_min_messages to.
            #    Therefore, instead we print this information as a NOTICE
            #    when called directly by the user and suppress it by setting
            #    MinWarning decorator to 'error' level in the calling function.
            plpy.notice("{0}: Created new custom function table {1}." \
                .format(module_name, object_table))
            plpy.execute("GRANT SELECT ON {0} TO PUBLIC".format(object_table))
        else:
            missing_cols = columns_missing_from_table(object_table,
                                                      CustomFunctionSchema.col_names)
            if len(missing_cols) > 0:
                plpy.error("{0}: Invalid custom function table {1},"
                           " missing columns: {2}".format(module_name,
                                                          object_table,
                                                          missing_cols))

        # Name the target columns so the insert does not depend on the
        # physical column order of a pre-existing table; the id column is
        # left out and therefore takes its default value.
        insert_cols = ", ".join(CustomFunctionSchema.col_names[1:])
        insert_query = plpy.prepare(
            "INSERT INTO {0} ({1}) VALUES($1, $2, $3);".format(object_table,
                                                               insert_cols),
            CustomFunctionSchema.col_types[1:])

        plpy.execute(insert_query, [name, description, object], 0)
    # spiexceptions.UniqueViolation is only supported for PG>=9.2. For
    # GP5(based of PG8.4) it cannot be used. Therefore, checking exception
    # message for duplicate key error.
    except Exception as e:
        # str(e) instead of e.message: the attribute only exists on Python 2.
        if 'duplicate key' in str(e):
            plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
        plpy.error(e)

    plpy.notice("{0}: Added function {1} to {2} table".
                format(module_name, name, object_table))
| |
@MinWarning("error")
def delete_custom_function(schema_madlib, object_table, id=None, name=None, **kwargs):
    """Delete one function from the custom function table.

    The row is selected by `id` when it is given, otherwise by `name`.
    When the table has no rows left afterwards, the table is dropped.

    Args:
        schema_madlib: Name of the schema the table has to live in.
        object_table: Table name, optionally schema qualified. An
            unqualified name is resolved inside `schema_madlib`.
        id: Id of the function to delete (takes precedence over `name`).
        name: Name of the function to delete.
        **kwargs: Accepted for call compatibility; not used.

    Raises:
        An error is reported through `plpy.error`/`_assert` when the table
        is qualified with another schema, both `id` and `name` are NULL,
        the table lacks expected columns, or no matching row exists.
        `input_tbl_valid` performs further table checks.
    """
    if object_table is not None:
        schema_name = get_schema(object_table)
        if schema_name is None:
            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
        elif schema_name != schema_madlib:
            plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))

    input_tbl_valid(object_table, module_name)
    _assert(id is not None or name is not None,
            "{0}: function id/name cannot be NULL! " \
            "Use \"SELECT delete_custom_function('usage')\" for help.".format(module_name))

    missing_cols = columns_missing_from_table(object_table, CustomFunctionSchema.col_names)
    if len(missing_cols) > 0:
        plpy.error("{0}: Invalid custom function table {1},"
                   " missing columns: {2}".format(module_name, object_table,
                                                  missing_cols))

    # Pick the lookup key and remember how to describe it in messages.
    if id is not None:
        key_col, key_type, key_value = CustomFunctionSchema.FN_ID, "INTEGER", id
        target = "id {0}".format(id)
    else:
        key_col, key_type, key_value = CustomFunctionSchema.FN_NAME, "TEXT", name
        target = "name '{0}'".format(name)

    # Bind the key as a parameter instead of splicing it into the SQL text,
    # so a name containing quoting characters cannot alter the statement.
    delete_plan = plpy.prepare(
        "DELETE FROM {0} WHERE {1} = $1".format(object_table, key_col),
        [key_type])
    res = plpy.execute(delete_plan, [key_value], 0)

    if res.nrows() > 0:
        plpy.notice("{0}: Object {1} has been deleted from {2}.".
                    format(module_name, target, object_table))
    else:
        plpy.error("{0}: Object {1} not found".format(module_name, target))

    # One row is enough to know whether the table is still in use.
    sql = "SELECT {0} FROM {1} LIMIT 1".format(
        CustomFunctionSchema.FN_ID, object_table)
    remaining = plpy.execute(sql, 0)
    if remaining.nrows() == 0:
        plpy.notice("{0}: Dropping empty custom keras function "
                    "table {1}".format(module_name, object_table))
        sql = "DROP TABLE {0}".format(object_table)
        plpy.execute(sql, 0)
| |
def update_builtin_metrics(builtin_metrics):
    """Append the extra metric names to `builtin_metrics` and return it.

    The list passed in is modified in place; the same object is returned.
    """
    for metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
        builtin_metrics.append(metric)
    return builtin_metrics
| |
@MinWarning("error")
def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
    """Load a top-k accuracy function into the custom function table.

    The function is stored under the name "top_<k>_accuracy" by calling the
    SQL functions `load_custom_function` and `top_k_categorical_acc_pickled`
    of `schema_madlib`.

    Args:
        schema_madlib: Name of the schema holding the SQL functions.
        object_table: Name of the custom function table.
        k: Positive integer used for the top-k accuracy function.
        **kwargs: Accepted for call compatibility; not used.

    Raises:
        An error is reported through `_assert` when `object_table` is NULL
        or `k` is NULL or not positive; errors from the SQL call propagate.
    """
    # Without this check a NULL table name would be formatted as 'None'.
    _assert(object_table is not None,
            "{0}: custom function table name cannot be NULL!".format(module_name))
    # Test for None explicitly: comparing None with an int fails on Python 3.
    _assert(k is not None and k > 0,
            "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))
    object_table = quote_ident(object_table)
    fn_name = "top_{0}_accuracy".format(k)

    # Pass the values as parameters rather than hand-quoted SQL literals.
    load_plan = plpy.prepare(
        """
        SELECT {schema_madlib}.load_custom_function(
            $1,
            {schema_madlib}.top_k_categorical_acc_pickled($2, $3),
            $3,
            $4);
        """.format(schema_madlib=schema_madlib),
        ["TEXT", "INTEGER", "TEXT", "TEXT"])
    plpy.execute(load_plan,
                 [object_table, k, fn_name, "returns {0}".format(fn_name)])
    # As this function allocates the name for the top_k_accuracy function,
    # printing it out here so the user doesn't need to lookup for the
    # newly added custom function name in the object_table
    plpy.info("{0}: Added function \'{1}\' to \'{2}\' table".
              format(module_name, fn_name, object_table))
    return
| |
class KerasCustomFunctionDocumentation:
    """Help text for the custom function SQL entry points."""

    @staticmethod
    def _returnHelpMsg(schema_madlib, message, summary, usage, method):
        """Pick the help text that matches `message`.

        An empty/NULL message yields `summary`; 'usage', 'help' or '?'
        (case-insensitive) yields `usage`; anything else yields a hint
        that points at the no-argument help call.
        """
        if not message:
            return summary
        if message.lower() in ('usage', 'help', '?'):
            return usage
        unknown_option = """
            No such option. Use "SELECT {schema_madlib}.{method}()"
            for help.
        """
        return unknown_option.format(schema_madlib=schema_madlib,
                                     method=method)

    @staticmethod
    def load_custom_function_help(schema_madlib, message):
        """Return the help text for load_custom_function."""
        method = "load_custom_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
        SUMMARY
        ----------------------------------------------------------------
        The user can specify custom functions as part of the parameters
        passed to madlib_keras_fit()/madlib_keras_fit_multiple(). These
        custom function(s) definition must be stored in a table to pass.
        This is a helper function to help users insert object(BYTEA) of
        the function definitions into a table.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)

        usage = """
        ---------------------------------------------------------------------------
        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
        object_table, -- VARCHAR. Output table to load custom function.
        object, -- BYTEA. dill pickled object of the function definition.
        name, -- TEXT. Free text string to identify a name
        description -- TEXT. Free text string to provide a description
        );


        ---------------------------------------------------------------------------
        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_custom_function contains the following columns:

        'id' -- SERIAL. Function ID.
        'name' -- TEXT PRIMARY KEY. unique function name.
        'description' -- TEXT. function description.
        'object' -- BYTEA. dill pickled function object.

        """.format(**names)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
    # ---------------------------------------------------------------------

    @staticmethod
    def delete_custom_function_help(schema_madlib, message):
        """Return the help text for delete_custom_function."""
        method = "delete_custom_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
        SUMMARY
        ----------------------------------------------------------------
        Delete the custom function corresponding to the provided id
        from the custom function repository table (object_table).

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)

        usage = """
        ---------------------------------------------------------------------------
        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
        object_table VARCHAR, -- Table containing keras custom function objects.
        id INTEGER -- The id of the keras custom function object
        to be deleted.
        );

        SELECT {schema_madlib}.{method}(
        object_table VARCHAR, -- Table containing keras custom function objects.
        name TEXT -- Function name of the keras custom function
        object to be deleted.
        );

        ---------------------------------------------------------------------------
        OUTPUT
        ---------------------------------------------------------------------------
        This method deletes the row corresponding to the given id in the
        object_table. This also tries to drop the table if the table is
        empty after dropping the id. If there are any views depending on the
        table, a warning message is displayed and the table is not dropped.

        ---------------------------------------------------------------------------
        """.format(**names)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)

    @staticmethod
    def load_top_k_accuracy_function_help(schema_madlib, message):
        """Return the help text for load_top_k_accuracy_function."""
        method = "load_top_k_accuracy_function"
        names = dict(schema_madlib=schema_madlib, method=method)
        summary = """
        ----------------------------------------------------------------
        SUMMARY
        ----------------------------------------------------------------
        The user can specify a custom n value for top_n_accuracy metric.
        If the output table already exists, the custom function specified
        will be added as a new row into the table. The output table could
        thus act as a repository of Keras custom functions.

        For more details on function usage:
        SELECT {schema_madlib}.{method}('usage')
        """.format(**names)

        usage = """
        ---------------------------------------------------------------------------
        USAGE
        ---------------------------------------------------------------------------
        SELECT {schema_madlib}.{method}(
        object_table, -- VARCHAR. Output table to load custom function.
        k -- INTEGER. The number of samples for top n accuracy
        );


        ---------------------------------------------------------------------------
        OUTPUT
        ---------------------------------------------------------------------------
        The output table produced by load_top_k_accuracy_function contains the following columns:

        'id' -- SERIAL. Function ID.
        'name' -- TEXT PRIMARY KEY. unique function name.
        'description' -- TEXT. function description.
        'object' -- BYTEA. dill pickled function object.

        """.format(**names)

        return KerasCustomFunctionDocumentation._returnHelpMsg(
            schema_madlib, message, summary, usage, method)
    # ---------------------------------------------------------------------