DL: Add support for custom metrics
JIRA: MADLIB-1433
This commit adds support for custom metrics to DL functions.
Most of the additions depend on the commit that added the custom
loss functions.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 9dcefed..23e16f6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -139,6 +139,14 @@
sql = "DROP TABLE {0}".format(object_table)
plpy.execute(sql, 0)
+def update_builtin_metrics(builtin_metrics):
+ builtin_metrics.append('accuracy')
+ builtin_metrics.append('acc')
+ builtin_metrics.append('crossentropy')
+ builtin_metrics.append('ce')
+ return builtin_metrics
+
+
class KerasCustomFunctionDocumentation:
@staticmethod
def _returnHelpMsg(schema_madlib, message, summary, usage, method):
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index a22a9a5..01523f3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -159,18 +159,22 @@
import psycopg2 as p2
conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')
cur = conn.cursor()
-\# import Dill and define 2 functions
+\# import Dill and define functions
import dill
-def test_sum_fn(a, b):
- return a+b
-pb_sum=dill.dumps(test_sum_fn)
-def test_mult_fn(a, b):
- return a*b
-pb_mult=dill.dumps(test_mult_fn)
+\# custom loss
+def squared_error(y_true, y_pred):
+ import keras.backend as K
+ return K.square(y_pred - y_true)
+pb_squared_error=dill.dumps(squared_error)
+\# custom metric
+def rmse(y_true, y_pred):
+ import keras.backend as K
+ return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
+pb_rmse=dill.dumps(rmse)
\# call load function
-cur.execute("DROP TABLE IF EXISTS test_custom_function_table")
-cur.execute("SELECT madlib.load_custom_function('test_custom_function_table', %s,'sum_fn', 'returns sum')", [p2.Binary(pb_sum)])
-cur.execute("SELECT madlib.load_custom_function('test_custom_function_table', %s,'mult_fn', 'returns mult')", [p2.Binary(pb_mult)])
+cur.execute("DROP TABLE IF EXISTS custom_function_table")
+cur.execute("SELECT madlib.load_custom_function('custom_function_table', %s,'squared_error', 'squared error')", [p2.Binary(pb_squared_error)])
+cur.execute("SELECT madlib.load_custom_function('custom_function_table', %s,'rmse', 'root mean square error')", [p2.Binary(pb_rmse)])
conn.commit()
</pre>
List table to see objects:
@@ -178,59 +182,59 @@
SELECT id, name, description FROM test_custom_function_table ORDER BY id;
</pre>
<pre class="result">
- id | name | description
-----+---------+--------------
- 1 | sum_fn | returns sum
- 2 | mult_fn | returns mult
+ id | name | description
+----+---------------+------------------------
+ 1 | squared_error | squared error
+ 2 | rmse | root mean square error
</pre>
-# Load object using a PL/Python function. First define the objects:
<pre class="example">
-CREATE OR REPLACE FUNCTION custom_function_object_sum()
+CREATE OR REPLACE FUNCTION custom_function_squared_error()
RETURNS BYTEA AS
$$
import dill
-def test_sum_fn(a, b):
- return a+b
-pb_sum=dill.dumps(test_sum_fn)
-return pb_sum
+def squared_error(y_true, y_pred):
+ import keras.backend as K
+ return K.square(y_pred - y_true)
+pb_squared_error=dill.dumps(squared_error)
+return pb_squared_error
$$ language plpythonu;
-CREATE OR REPLACE FUNCTION custom_function_object_mult()
+CREATE OR REPLACE FUNCTION custom_function_rmse()
RETURNS BYTEA AS
$$
import dill
-def test_mult_fn(a, b):
- return a*b
-pb_mult=dill.dumps(test_mult_fn)
-return pb_mult
+def rmse(y_true, y_pred):
+ import keras.backend as K
+ return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
+pb_rmse=dill.dumps(rmse)
+return pb_rmse
$$ language plpythonu;
</pre>
Now call loader:
<pre class="result">
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT madlib.load_custom_function('test_custom_function_table',
- custom_function_object_sum(),
- 'sum_fn',
- 'returns sum'
- );
-SELECT madlib.load_custom_function('test_custom_function_table',
- custom_function_object_mult(),
- 'mult_fn',
- 'returns mult'
- );
+DROP TABLE IF EXISTS custom_function_table;
+SELECT madlib.load_custom_function('custom_function_table',
+ custom_function_squared_error(),
+ 'squared_error',
+ 'squared error');
+SELECT madlib.load_custom_function('custom_function_table',
+ custom_function_rmse(),
+ 'rmse',
+ 'root mean square error');
</pre>
-# Delete an object by id:
<pre class="example">
-SELECT madlib.delete_custom_function( 'test_custom_function_table', 1);
-SELECT id, name, description FROM test_custom_function_table ORDER BY id;
+SELECT madlib.delete_custom_function( 'custom_function_table', 1);
+SELECT id, name, description FROM custom_function_table ORDER BY id;
</pre>
<pre class="result">
- id | name | description
-----+---------+--------------
- 2 | mult_fn | returns mult
+ id | name | description
+----+------+------------------------
+ 2 | rmse | root mean square error
</pre>
Delete an object by name:
<pre class="example">
-SELECT madlib.delete_custom_function( 'test_custom_function_table', 'mult_fn');
+SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
</pre>
If all objects are deleted from the table using this function, the table itself will be dropped.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 424250d..0a9b9ae 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -294,6 +294,8 @@
def populate_object_map(self):
builtin_losses = dir(losses)
+ builtin_metrics = update_builtin_metrics(dir(metrics))
+
# Track distinct custom functions in compile_params
custom_fn_names = []
# Track their corresponding mst_keys to pass along the custom function
@@ -308,9 +310,16 @@
# Also, it is validated later when we compile the model
# on the segments
compile_dict = convert_string_of_args_to_dict(compile_params)
- if (compile_dict['loss'] not in builtin_losses):
- custom_fn_names.append(compile_dict['loss'])
+
+ local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
+ local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
+ if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
+ custom_fn_names.append(local_loss)
custom_fn_mst_idx.append(mst_idx)
+ if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
+ custom_fn_names.append(local_metric)
+ custom_fn_mst_idx.append(mst_idx)
+
if len(custom_fn_names) > 0:
# Pass only unique custom_fn_names to query from object table
custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
@@ -400,6 +409,7 @@
model_size DOUBLE PRECISION,
metrics_elapsed_time DOUBLE PRECISION[],
metrics_type TEXT[],
+ loss_type TEXT,
training_metrics_final DOUBLE PRECISION,
training_loss_final DOUBLE PRECISION,
training_metrics DOUBLE PRECISION[],
@@ -446,15 +456,18 @@
metrics_type = 'ARRAY{0}'.format(
metrics_list) if is_metrics_specified else 'NULL'
+ loss_type = get_loss_from_compile_param(mst[self.compile_params_col])
+ loss_type = loss_type if loss_type else 'NULL'
+
info_table_insert_query = """
INSERT INTO {self.model_info_table}({self.mst_key_col},
{self.model_id_col}, {self.compile_params_col},
{self.fit_params_col}, model_type, model_size,
- metrics_type)
+ metrics_type, loss_type)
VALUES ({mst_key_val}, {model_id},
$madlib${compile_params}$madlib$,
$madlib${fit_params}$madlib$, '{model_type}',
- {model_size}, {metrics_type})
+ {model_size}, {metrics_type}, '{loss_type}')
""".format(self=self,
mst_key_val=mst[self.mst_key_col],
model_id=mst[self.model_id_col],
@@ -462,7 +475,8 @@
fit_params=mst[self.fit_params_col],
model_type='madlib_keras',
model_size=model_size,
- metrics_type=metrics_type)
+ metrics_type=metrics_type,
+ loss_type=loss_type)
plpy.execute(info_table_insert_query)
if not mst['mst_key'] in self.warm_start_msts:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index a4463e7..20d1574 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -30,10 +30,12 @@
import keras.optimizers as opt
import keras.losses as losses
+import keras.metrics as metrics
import madlib_keras_serializer
import madlib_keras_gpu_info
from madlib_keras_custom_function import CustomFunctionSchema
+from madlib_keras_custom_function import update_builtin_metrics
from utilities.utilities import _assert
from utilities.utilities import is_platform_pg
@@ -327,8 +329,19 @@
optimizers = get_optimizers()
(opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
if custom_function_map is not None:
- map=dill.loads(custom_function_map)
- compile_dict['loss']=map[compile_dict['loss']]
+ local_map=dill.loads(custom_function_map)
+
+ compile_dict['loss']=local_map[compile_dict['loss']] \
+ if compile_dict['loss'] in local_map else compile_dict['loss']
+
+ new_metrics = []
+ for i in compile_dict['metrics']:
+ if i in local_map:
+ new_metrics.append(local_map[i])
+ else:
+ new_metrics.append(i)
+ compile_dict['metrics'] = new_metrics
+
compile_dict['optimizer'] = optimizers[opt_name](**final_args) if final_args else opt_name
model.compile(**compile_dict)
@@ -371,6 +384,10 @@
"""
if len(custom_fn_names) < 1:
return None
+
+ fn_set = set(custom_fn_names)
+ unique_fn_list = (list(fn_set))
+
custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
# Dictionary map of name:object
# {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
@@ -378,13 +395,13 @@
# Query the custom function if not yet loaded from table
res = plpy.execute("""
SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
- WHERE {custom_fn_col_name} = ANY(ARRAY{custom_fn_names})
+ WHERE {custom_fn_col_name} = ANY(ARRAY{unique_fn_list})
""".format(custom_obj_col_name=custom_obj_col_name,
object_table=object_table,
custom_fn_col_name=CustomFunctionSchema.FN_NAME,
- custom_fn_names=custom_fn_names))
- if res.nrows() < len(custom_fn_names):
- plpy.error("Custom function {0} not defined in object table '{1}'".format(custom_fn_names, object_table))
+ unique_fn_list=unique_fn_list))
+ if res.nrows() < len(unique_fn_list):
+ plpy.error("Custom function {0} not defined in object table '{1}'".format(unique_fn_list, object_table))
for r in res:
custom_fn_map[r[CustomFunctionSchema.FN_NAME]] = dill.loads(r[custom_obj_col_name])
custom_fn_map_obj = dill.dumps(custom_fn_map)
@@ -409,7 +426,14 @@
"""
compile_dict = convert_string_of_args_to_dict(compile_params)
builtin_losses = dir(losses)
+ builtin_metrics = update_builtin_metrics(dir(metrics))
+
custom_fn_list = []
- if compile_dict['loss'] not in builtin_losses:
- custom_fn_list.append(compile_dict['loss'])
+ local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
+ local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
+ if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
+ custom_fn_list.append(local_loss)
+ if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
+ custom_fn_list.append(local_metric)
+
return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
index 671cf07..1389326 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
@@ -39,3 +39,28 @@
res=obj(arg1, arg2)
return res
$$ language plpythonu;
+
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+ c = a*b*0
+ return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+CREATE OR REPLACE FUNCTION custom_function_one_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn1(a, b):
+ c = a*b*0+1
+ return c
+
+pb=dill.dumps(test_custom_fn1)
+return pb
+$$ language plpythonu;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index 82d5e97..ddfcc8d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -66,8 +66,8 @@
/* Test adding an existing function name should error out */
SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
-$TRAP$) = 1, 'Should error out for duplicate function name');
+ SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+ $TRAP$) = 1, 'Should error out for duplicate function name');
/* Test deletion by id where valid table exists */
/* Assert id exists before deleting */
@@ -88,10 +88,12 @@
/* Test deleting an already deleted entry should error out */
SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-SELECT delete_custom_function('test_custom_function_table', 2);
-$TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
+ SELECT delete_custom_function('test_custom_function_table', 2);
+ $TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
/* Test delete drops the table after deleting last entry*/
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
SELECT delete_custom_function('test_custom_function_table', 1);
SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.')
FROM pg_class where relname='test_custom_function_table';
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index 7eb4c15..fecd19f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -26,6 +26,11 @@
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
+\i m4_regexp(MODULE_PATHNAME,
+ `\(.*\)libmadlib\.so',
+ `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
DROP TABLE if exists pg_temp.iris_model, pg_temp.iris_model_summary;
SELECT madlib_keras_fit(
'iris_data_packed',
@@ -139,22 +144,11 @@
FROM evaluate_out;
-- TEST custom loss function
--- Custom loss function returns 0 as the loss
-CREATE OR REPLACE FUNCTION custom_function_zero_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_custom_fn(a, b):
- c = a*b*0
- return c
-
-pb=dill.dumps(test_custom_fn)
-return pb
-$$ language plpythonu;
DROP TABLE IF EXISTS test_custom_function_table;
SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
SELECT madlib_keras_fit(
@@ -168,6 +162,30 @@
FALSE, NULL, 1, NULL, NULL, NULL,
'test_custom_function_table'
);
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+ 'iris_data_packed',
+ 'iris_model',
+ 'iris_model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['test_custom_fn1']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ FALSE, NULL, 1, NULL, NULL, NULL,
+ 'test_custom_function_table'
+);
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+ 'iris_data_packed',
+ 'iris_model',
+ 'iris_model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ FALSE, NULL, 1, NULL, NULL, NULL,
+ 'test_custom_function_table'
+);
SELECT assert(
model_arch_table = 'iris_model_arch' AND
@@ -185,13 +203,13 @@
object_table = 'test_custom_function_table' AND
model_size > 0 AND
madlib_version is NOT NULL AND
- compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text AND
+ compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND
fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
num_iterations = 3 AND
metrics_compute_frequency = 1 AND
num_classes = 3 AND
class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
- metrics_type = '{mae}' AND
+ metrics_type = '{test_custom_fn1}' AND
array_upper(training_metrics, 1) = 3 AND
training_loss = '{0,0,0}' AND
array_upper(metrics_elapsed_time, 1) = 3 ,
@@ -212,7 +230,7 @@
SELECT assert(loss >= 0 AND
metric >= 0 AND
- metrics_type = '{mae}' AND
+ metrics_type = '{test_custom_fn1}' AND
loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
@@ -223,7 +241,21 @@
'iris_model',
'iris_model_arch',
1,
- $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn1', metrics=['mae']$$::text,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='fail_test_custom_fn', metrics=['mae']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ FALSE, NULL, 1, NULL, NULL, NULL,
+ 'test_custom_function_table'
+);$TRAP$) = 1,
+'custom function in compile_params not defined in Object table.');
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+ 'iris_data_packed',
+ 'iris_model',
+ 'iris_model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='accuracy', metrics=['fail_test_custom_fn']$$::text,
$$ batch_size=2, epochs=1, verbose=0 $$::text,
3,
FALSE, NULL, 1, NULL, NULL, NULL,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index 0860b7d..c4c0315 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -26,6 +26,11 @@
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
+\i m4_regexp(MODULE_PATHNAME,
+ `\(.*\)libmadlib\.so',
+ `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
-- Multiple models End-to-End test
@@ -162,22 +167,10 @@
FROM evaluate_out;
-- TEST custom loss function
--- Custom loss function returns 0 as the loss
-CREATE OR REPLACE FUNCTION custom_function_zero_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_custom_fn(a, b):
- c = a*b*0
- return c
-
-pb=dill.dumps(test_custom_fn)
-return pb
-$$ language plpythonu;
-
DROP TABLE IF EXISTS test_custom_function_table;
SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
-- Prepare model selection table with four rows
DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
@@ -187,7 +180,7 @@
ARRAY[1],
ARRAY[
$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
- $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
+ $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$
],
ARRAY[
$$batch_size=16, epochs=1$$
@@ -229,7 +222,7 @@
model_type = 'madlib_keras' AND
model_size > 0 AND
fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
- metrics_type = '{accuracy}' AND
+ metrics_type = '{test_custom_fn1}' AND
training_metrics_final >= 0 AND
training_loss_final = 0 AND
training_loss = '{0,0,0}' AND
@@ -266,7 +259,7 @@
SELECT assert(loss = 0 AND
metric >= 0 AND
- metrics_type = '{accuracy}' AND
+ metrics_type = '{test_custom_fn1}' AND
loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;