/* ----------------------------------------------------------------------- *//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *
 * @file madlib_keras_custom_function.sql_in
 *
 * @brief Utility function to load serialized Python objects into a table
 * @date May 2020
 *
 *
 *//* ----------------------------------------------------------------------- */

m4_include(`SQLCommon.m4')
/**
@addtogroup grp_custom_function

@brief Utility function to load serialized Python objects into a table.

\warning <em> This MADlib method is still in early stage development.
Interface and implementation are subject to change. </em>

<div class="toc"><b>Contents</b><ul>
<li class="level1"><a href="#load_function">Load Function</a></li>
<li class="level1"><a href="#delete_function">Delete Function</a></li>
<li class="level1"><a href="#top_k_function">Top k Accuracy Function</a></li>
<li class="level1"><a href="#example">Examples</a></li>
<li class="level1"><a href="#literature">Literature</a></li>
<li class="level1"><a href="#related">Related Topics</a></li>
</ul></div>

This utility function loads custom Python functions
into a table for use by deep learning algorithms.
Custom functions can be useful if, for example, you need loss functions
or metrics that are not built into the standard libraries.
The functions to be loaded must be in the form of serialized Python objects
created using Dill, which extends Python's pickle module to the majority
of the built-in Python types [1].

Custom functions are also used to return top k categorical accuracy rate
in the case that you want a different k value than the default from Keras.
This module includes a helper function to create the custom function
automatically for a specified k. 

There is also a utility function to delete a function
from the table.

@anchor load_function
@par Load Function

<pre class="syntax">
load_custom_function(
    object_table,
    object,
    name,
    description
    )
</pre>
\b Arguments
<dl class="arglist">
  <dt>object_table</dt>
  <dd>VARCHAR. Table to load serialized Python objects.  If this table
  does not exist, it will be created.  If this table already
  exists, a new row is inserted into the existing table.
  </dd>

  <dt>object</dt>
  <dd>BYTEA. PostgreSQL binary data type of the Python object.
  Object must be created with the Dill package for serializing
  Python objects.
  </dd>

  <dt>name</dt>
  <dd>TEXT. Name of the object. Must be a unique identifier
  in the table, since this name is used when passing the
  object to Keras.
  </dd>

  <dt>description (optional)</dt>
  <dd>TEXT, default: NULL. Free text string to provide
  a description, if desired.
  </dd>

</dl>

<b>Output table</b>
<br>
    The output table contains the following columns:
    <table class="output">
      <tr>
        <th>id</th>
        <td>SERIAL. Object ID.
        </td>
      </tr>
      <tr>
        <th>name</th>
        <td>TEXT PRIMARY KEY. Name of the object.
        </td>
      </tr>
      <tr>
        <th>description</th>
        <td>TEXT. Description of the object (free text).
        </td>
      </tr>
      <tr>
        <th>object</th>
        <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type.
        </td>
      </tr>
    </table>
</br>

@anchor delete_function
@par Delete Function

Delete by id:
<pre class="syntax">
delete_custom_function(
    object_table,
    id
)
</pre>
Or alternatively by name:
<pre class="syntax">
delete_custom_function(
    object_table,
    name
)
</pre>
\b Arguments
<dl class="arglist">
  <dt>object_table</dt>
    <dd>VARCHAR. Table containing Python object to be deleted.
  </dd>
  <dt>id</dt>
    <dd>INTEGER. The id of the object to be deleted.
  </dd>
  <dt>name</dt>
    <dd>TEXT. Name of the object to be deleted.
  </dd>
</dl>

@anchor top_k_function
@par Top k Accuracy Function

Create and load a custom function for a specific k into the custom functions table.
The Keras accuracy parameter 'top_k_categorical_accuracy' returns top 5 accuracy by default [2].
If you want a different top k value, use this helper function to create a custom
Python function to compute the top k accuracy that you specify.

<pre class="syntax">
load_top_k_accuracy_function(
    object_table,
    k
    )
</pre>
\b Arguments
<dl class="arglist">
  <dt>object_table</dt>
  <dd>VARCHAR. Table to load serialized Python objects.  If this table
  does not exist, it will be created.  If this table already
  exists, a new row is inserted into the existing table.
  </dd>

  <dt>k</dt>
  <dd>INTEGER. k value for the top k accuracy that you want.
  </dd>

</dl>

<b>Output table</b>
<br>
    The output table contains the following columns:
    <table class="output">
      <tr>
        <th>id</th>
        <td>SERIAL. Object ID.
        </td>
      </tr>
      <tr>
        <th>name</th>
        <td>TEXT PRIMARY KEY. Name of the object.
        Generated with the following pattern: top_(k)_accuracy.
        </td>
      </tr>
      <tr>
        <th>description</th>
        <td>TEXT. Description of the object.
        </td>
      </tr>
      <tr>
        <th>object</th>
        <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type.
        </td>
      </tr>
    </table>
</br>

@anchor example
@par Examples
-# Load object using psycopg2. Psycopg is a PostgreSQL database
adapter for the Python programming language.  Note that you need to use the
psycopg2.Binary() method to pass the serialized object as bytes.
<pre class="example">
\# import database connector psycopg2 and create connection cursor
import psycopg2 as p2
conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')
cur = conn.cursor()
\# import Dill and define functions
import dill
\# custom loss
def squared_error(y_true, y_pred):
    import keras.backend as K
    return K.square(y_pred - y_true)
pb_squared_error=dill.dumps(squared_error)
\# custom metric
def rmse(y_true, y_pred):
    import keras.backend as K
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
pb_rmse=dill.dumps(rmse)
\# call load function
cur.execute("DROP TABLE IF EXISTS custom_function_table")
cur.execute("SELECT madlib.load_custom_function('custom_function_table',  %s,'squared_error', 'squared error')", [p2.Binary(pb_squared_error)])
cur.execute("SELECT madlib.load_custom_function('custom_function_table',  %s,'rmse', 'root mean square error')", [p2.Binary(pb_rmse)])
conn.commit()
</pre>
List table to see objects:
<pre class="example">
SELECT id, name, description FROM custom_function_table ORDER BY id;
</pre>
<pre class="result">
 id |     name      |      description
----+---------------+------------------------
  1 | squared_error | squared error
  2 | rmse          | root mean square error
</pre>
-# Load object using a PL/Python function.  First define the objects:
<pre class="example">
CREATE OR REPLACE FUNCTION custom_function_squared_error()
RETURNS BYTEA AS
$$
import dill
def squared_error(y_true, y_pred):
    import keras.backend as K
    return K.square(y_pred - y_true)
pb_squared_error=dill.dumps(squared_error)
return pb_squared_error
$$ language plpythonu;
CREATE OR REPLACE FUNCTION custom_function_rmse()
RETURNS BYTEA AS
$$
import dill
def rmse(y_true, y_pred):
    import keras.backend as K
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
pb_rmse=dill.dumps(rmse)
return pb_rmse
$$ language plpythonu;
</pre>
Now call loader:
<pre class="example">
DROP TABLE IF EXISTS custom_function_table;
SELECT madlib.load_custom_function('custom_function_table',
                                   custom_function_squared_error(),
                                   'squared_error',
                                   'squared error');
SELECT madlib.load_custom_function('custom_function_table',
                                   custom_function_rmse(),
                                   'rmse',
                                   'root mean square error');
</pre>
-# Delete an object by id:
<pre class="example">
SELECT madlib.delete_custom_function( 'custom_function_table', 1);
SELECT id, name, description FROM custom_function_table ORDER BY id;
</pre>
<pre class="result">
 id | name |      description
----+------+------------------------
  2 | rmse | root mean square error
</pre>
Delete an object by name:
<pre class="example">
SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
</pre>
If all objects are deleted from the table using this function, the table itself will be dropped.
-# Load top 3 accuracy function followed by a top 10 accuracy function:
<pre class="example">
DROP TABLE IF EXISTS custom_function_table;
SELECT madlib.load_top_k_accuracy_function('custom_function_table',
                                           3);
SELECT madlib.load_top_k_accuracy_function('custom_function_table',
                                           10);
SELECT id, name, description FROM custom_function_table ORDER BY id;
</pre>
<pre class="result">
 id |      name       |       description       
----+-----------------+-------------------------
  1 | top_3_accuracy  | returns top_3_accuracy
  2 | top_10_accuracy | returns top_10_accuracy
</pre>
@anchor literature
@literature

[1] Dill https://pypi.org/project/dill/

[2] https://keras.io/api/metrics/accuracy_metrics/#topkcategoricalaccuracy-class

@anchor related
@par Related Topics

See madlib_keras_custom_function.sql_in

*/

-- Load one serialized Python object into object_table.
--   object_table : table that receives the object
--   object       : the serialized Python object as BYTEA
--   name         : name stored with the object
--   description  : free-text description stored with the object
-- The PL/Python globals are forwarded as keyword arguments to the
-- Python implementation, which does the actual work.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_custom_function(
    object_table VARCHAR,
    object BYTEA,
    name TEXT,
    description TEXT
)
RETURNS VOID AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    with AOControl(False):
        madlib_keras_custom_function.load_custom_function(**globals())
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Convenience overload without a description: delegates to the
-- four-argument form and passes NULL as the description.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_custom_function(
    object_table VARCHAR,
    object BYTEA,
    name TEXT
)
RETURNS VOID AS $$
    SELECT MADLIB_SCHEMA.load_custom_function(
        $1,
        $2,
        $3,
        NULL
    )
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for load_custom_function: returns the help text selected by
-- message, as produced by the Python documentation class.
-- The macro arguments are quoted, matching the worker functions in this file.
-- NOTE(review): declared VOLATILE / MODIFIES SQL DATA although, as far as this
-- file shows, it only returns text -- confirm before relaxing the properties.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_custom_function(
    message VARCHAR
) RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_custom_function_help(schema_madlib, message)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for load_custom_function, no-argument form: same as the
-- message variant called with an empty string.
-- The macro arguments are quoted, matching the worker functions in this file.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_custom_function()
RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_custom_function_help(schema_madlib, '')
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Delete a Keras custom function from object_table, selected by its id.
--   object_table : table containing the object to delete
--   id           : id of the object to delete
-- Per the module documentation, the table itself is dropped once its last
-- object has been deleted, so the function is declared as modifying SQL data
-- (same declaration as the load functions) where function properties exist.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
    object_table VARCHAR,
    id INTEGER
)
RETURNS VOID AS $$
    PythonFunctionBodyOnly(`deep_learning',`madlib_keras_custom_function')
    with AOControl(False):
        madlib_keras_custom_function.delete_custom_function(schema_madlib, object_table, id=id)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Delete a Keras custom function from object_table, selected by its name.
--   object_table : table containing the object to delete
--   name         : name of the object to delete
-- Per the module documentation, the table itself is dropped once its last
-- object has been deleted, so the function is declared as modifying SQL data
-- (same declaration as the load functions) where function properties exist.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
    object_table VARCHAR,
    name TEXT
)
RETURNS VOID AS $$
    PythonFunctionBodyOnly(`deep_learning',`madlib_keras_custom_function')
    with AOControl(False):
        madlib_keras_custom_function.delete_custom_function(schema_madlib, object_table, name=name)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for delete_custom_function: returns the help text selected by
-- message, as produced by the Python documentation class.
-- The macro arguments are quoted, matching the worker functions in this file.
-- NOTE(review): declared VOLATILE / MODIFIES SQL DATA although, as far as this
-- file shows, it only returns text -- confirm before relaxing the properties.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
    message VARCHAR
) RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, message)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for delete_custom_function, no-argument form: same as the
-- message variant called with an empty string.
-- The macro arguments are quoted, matching the worker functions in this file.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function()
RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, '')
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Top k accuracy function: create a custom top-k accuracy function for the
-- requested k and load it into object_table.
--   object_table : table that receives the generated object
--   k            : the k value for the top k accuracy
-- The PL/Python globals are forwarded as keyword arguments to the
-- Python implementation, which does the actual work.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
    object_table VARCHAR,
    k INTEGER
)
RETURNS VOID AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    with AOControl(False):
        madlib_keras_custom_function.load_top_k_accuracy_function(**globals())
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for load_top_k_accuracy_function: returns the help text
-- selected by message, as produced by the Python documentation class.
-- The macro arguments are quoted, matching the worker functions in this file.
-- NOTE(review): declared VOLATILE / MODIFIES SQL DATA although, as far as this
-- file shows, it only returns text -- confirm before relaxing the properties.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
    message VARCHAR
) RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, message)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Online help for load_top_k_accuracy_function, no-argument form: same as
-- the message variant called with an empty string.
-- The macro arguments are quoted, matching the worker functions in this file.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function()
RETURNS VARCHAR AS $$
    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, '')
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

-- Build a top-k categorical accuracy metric and return it serialized.
--   n       : the k value handed to the Keras top_k_categorical_accuracy metric
--   fn_name : value assigned to the __name__ of the generated function
-- Returns the dill-serialized function as BYTEA.
-- The body runs no SQL statements, so it is declared NO SQL (rather than as
-- modifying SQL data) where function properties are available.
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.top_k_categorical_acc_pickled(
    n INTEGER,
    fn_name VARCHAR
) RETURNS BYTEA AS $$
    import dill
    from keras.metrics import top_k_categorical_accuracy

    def fn(Y_true, Y_pred):
        return top_k_categorical_accuracy(Y_true,
                                          Y_pred,
                                          k = n)
    fn.__name__= fn_name
    pb=dill.dumps(fn)
    return pb
$$ LANGUAGE plpythonu
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
