DL: Add custom loss function support for DL module
This commit adds support for custom loss functions to DL fit, evaluate, and
fit_multiple. The changes are as follows:
1. fit: A new optional param for passing in the object_table name
object_table (optional) VARCHAR: Name of the table containing Python
objects in the case that custom loss functions or custom metrics are
specified in the parameter `compile_params`
```
madlib_keras_fit(
source_table,
model,
model_arch_table,
model_id,
compile_params,
fit_params,
num_iterations,
use_gpus,
validation_table,
metrics_compute_frequency,
warm_start,
name,
description,
object_table -- new parameter
)
```
This new parameter is also written to the output summary table.
2. Adding helper functions that parse the custom loss function names from
compile_params, query their definitions from the object_table, and build a
serialized dictionary of {'fn_name': fn_object}. This dictionary is passed to
the fit functions, where it is deserialized and the function is passed to
keras as a Python object.
3. Evaluate: No change to the madlib_keras_evaluate() interface. It reads the
object_table information from the fit/fit_multiple output model summary table.
The output table adds a new column:
loss_type: Type of loss that was used in the training step.
If a custom loss or metric is used, its name is given;
otherwise the built-in one used is listed.
4. fit_multiple: No change to the madlib_keras_fit_multiple_model() interface.
It reads the object_table information from the model selection summary table.
Each mst row is populated with a NULL object_map by default. If the
object_table exists, the custom loss function names are parsed from the
compile_params of each mst row. Once we have all the custom function names,
we query their definitions from the object_table and create a single
dictionary of {'fn_name1': fn_object1, 'fn_name2': fn_object2, ...} and pass
it to the fit multiple functions, where it is read and the corresponding
function definition is passed as a Python object to keras.
A summary table named <model>_summary is also created, which has the
following new columns:
model_selection_table: Name of the table containing model selection
parameters to be tried.
object_table: Name of the object table containing the serialized
Python objects for custom loss functions and custom metrics (read from
the mst_summary table).
5. Adding corresponding unit tests and dev-check tests
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 091fce2..8bb1531 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -61,7 +61,7 @@
del SD[SD_STORE.SESS]
def get_init_model_and_sess(SD, device_name, gpu_count, segments_per_host,
- model_architecture, compile_params):
+ model_architecture, compile_params, custom_function_map):
# If a live session is present, re-use it. Otherwise, recreate it.
if SD_STORE.SESS in SD :
if SD_STORE.SEGMENT_MODEL not in SD:
@@ -73,7 +73,7 @@
else:
sess = get_keras_session(device_name, gpu_count, segments_per_host)
K.set_session(sess)
- segment_model = init_model(model_architecture, compile_params)
+ segment_model = init_model(model_architecture, compile_params, custom_function_map)
SD_STORE.init_SD(SD, sess, segment_model)
return segment_model, sess
@@ -82,7 +82,7 @@
model_id, compile_params, fit_params, num_iterations,
use_gpus, validation_table=None,
metrics_compute_frequency=None, warm_start=False, name="",
- description="", **kwargs):
+ description="", object_table=None, **kwargs):
module_name = 'madlib_keras_fit'
fit_params = "" if not fit_params else fit_params
@@ -107,7 +107,7 @@
source_table, validation_table, model, model_arch_table,
model_id, mb_dep_var_col, mb_indep_var_col,
num_iterations, metrics_compute_frequency, warm_start,
- use_gpus, accessible_gpus_for_seg)
+ use_gpus, accessible_gpus_for_seg, object_table)
if metrics_compute_frequency is None:
metrics_compute_frequency = num_iterations
@@ -148,6 +148,19 @@
# Prepare the SQL for running distributed training via UDA
compile_params_to_pass = "$madlib$" + compile_params + "$madlib$"
fit_params_to_pass = "$madlib$" + fit_params + "$madlib$"
+ custom_function_map = None
+
+ # If the object_table exists, we read the list of custom
+ # function used in the compile_params and map it to their
+ # object definition from the object table
+ custom_fn_list = get_custom_functions_list(compile_params)
+ if object_table is not None:
+ custom_function_map = query_custom_functions_map(object_table, custom_fn_list)
+ elif len(custom_fn_list) >= 1:
+ # Error out if custom_function is called without specifying the object table
+ # with the function definition
+ plpy.error("Object table not specified for function {0} in compile_params".format(custom_fn_list))
+
run_training_iteration = plpy.prepare("""
SELECT {schema_madlib}.fit_step(
{mb_dep_var_col},
@@ -165,10 +178,11 @@
{use_gpus}::BOOLEAN,
ARRAY{accessible_gpus_for_seg},
$1,
- $2
+ $2,
+ $3
) AS iteration_result
FROM {source_table}
- """.format(**locals()), ["bytea", "boolean"])
+ """.format(**locals()), ["bytea", "boolean", "bytea"])
# Define the state for the model and loss/metric storage lists
training_loss, training_metrics, metrics_elapsed_time = [], [], []
@@ -182,7 +196,7 @@
start_iteration = time.time()
is_final_iteration = (i == num_iterations)
serialized_weights = plpy.execute(run_training_iteration,
- [serialized_weights, is_final_iteration]
+ [serialized_weights, is_final_iteration, custom_function_map]
)[0]['iteration_result']
end_iteration = time.time()
info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
@@ -194,7 +208,8 @@
compute_out = compute_loss_and_metrics(
schema_madlib, source_table, compile_params_to_pass, model_arch,
serialized_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
- images_per_seg_train, training_metrics, training_loss, i, is_final_iteration)
+ images_per_seg_train, training_metrics, training_loss, i, is_final_iteration,
+ custom_function_map)
metrics_iters.append(i)
compute_time, compute_metrics, compute_loss = compute_out
@@ -211,7 +226,7 @@
schema_madlib, validation_table, compile_params_to_pass,
model_arch, serialized_weights, use_gpus, accessible_gpus_for_seg,
seg_ids_val, images_per_seg_val, validation_metrics,
- validation_loss, i, is_final_iteration)
+ validation_loss, i, is_final_iteration, custom_function_map)
val_compute_time, val_compute_metrics, val_compute_loss = val_compute_out
info_str += "\n\tTime for evaluating validation dataset in "\
@@ -238,7 +253,6 @@
independent_varname = src_summary_dict['independent_varname_in_source_table']
# Define some constants to be inserted into the summary table.
model_type = "madlib_keras"
- compile_params_dict = convert_string_of_args_to_dict(compile_params)
metrics_list = get_metrics_from_compile_param(compile_params)
is_metrics_specified = True if metrics_list else False
metrics_type = 'ARRAY{0}'.format(metrics_list) if is_metrics_specified else 'NULL'
@@ -264,6 +278,7 @@
validation_metrics_final = validation_loss_final = 'NULL'
validation_table = 'NULL'
+ object_table = "$MAD${0}$MAD$".format(object_table) if object_table is not None else 'NULL'
if warm_start:
plpy.execute("DROP TABLE {0}, {1}".format
(model, fit_validator.output_summary_model_table))
@@ -280,6 +295,7 @@
$2 AS fit_params,
{num_iterations}::INTEGER AS num_iterations,
{validation_table}::TEXT AS validation_table,
+ {object_table}::TEXT AS object_table,
{metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
$3 AS name,
$4 AS description,
@@ -398,7 +414,8 @@
def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
serialized_weights, use_gpus, accessible_gpus_for_seg,
dist_key_mapping, images_per_seg_val, metrics_list, loss_list,
- curr_iter, is_final_iteration, model_table=None, mst_key=None):
+ curr_iter, is_final_iteration, custom_fn_name,
+ model_table=None, mst_key=None):
"""
Compute the loss and metric using a given model (serialized_weights) on the
given dataset (table.)
@@ -414,6 +431,7 @@
dist_key_mapping,
images_per_seg_val,
is_final_iteration,
+ custom_fn_name,
model_table,
mst_key)
end_val = time.time()
@@ -444,12 +462,12 @@
return (curr_iter)%metrics_compute_frequency == 0 or \
curr_iter == num_iterations
-def init_model(model_architecture, compile_params):
+def init_model(model_architecture, compile_params, custom_function_map):
"""
Should only be called at the first row of first iteration.
"""
segment_model = model_from_json(model_architecture)
- compile_model(segment_model, compile_params)
+ compile_model(segment_model, compile_params, custom_function_map)
return segment_model
def update_model(segment_model, prev_serialized_weights):
@@ -466,7 +484,7 @@
compile_params, fit_params, dist_key, dist_key_mapping,
current_seg_id, segments_per_host, images_per_seg, use_gpus,
accessible_gpus_for_seg, prev_serialized_weights, is_final_iteration=True,
- is_multiple_model=False, **kwargs):
+ is_multiple_model=False, custom_function_map=None, **kwargs):
"""
This transition function is common for madlib_keras_fit() and
madlib_keras_fit_multiple_model(). The important difference between
@@ -491,7 +509,8 @@
segment_model, sess = get_init_model_and_sess(SD, device_name,
accessible_gpus_for_seg[current_seg_id],
segments_per_host,
- model_architecture, compile_params)
+ model_architecture, compile_params,
+ custom_function_map)
if not state:
agg_image_count = 0
set_model_weights(segment_model, prev_serialized_weights)
@@ -623,10 +642,16 @@
InputValidator.validate_input_shape(
test_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, input_shape, 2, True)
- compile_params_query = "SELECT compile_params, metrics_type FROM {0}".format(model_summary_table)
+ compile_params_query = "SELECT compile_params, metrics_type, object_table FROM {0}".format(model_summary_table)
res = plpy.execute(compile_params_query)[0]
metrics_type = res['metrics_type']
compile_params = "$madlib$" + res['compile_params'] + "$madlib$"
+ object_table = res['object_table']
+ loss_type = get_loss_from_compile_param(res['compile_params'])
+ custom_function_map = None
+ if object_table is not None:
+ custom_fn_list = get_custom_functions_list(res['compile_params'])
+ custom_function_map = query_custom_functions_map(object_table, custom_fn_list)
dist_key_mapping, images_per_seg = get_image_count_per_seg_for_minibatched_data_from_db(test_table)
@@ -634,7 +659,7 @@
get_loss_metric_from_keras_eval(
schema_madlib, test_table, compile_params, model_arch,
model_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
- images_per_seg)
+ images_per_seg, custom_function_map=custom_function_map)
if not metrics_type:
metrics_type = None
@@ -643,8 +668,8 @@
with MinWarning("error"):
create_output_table = plpy.prepare("""
CREATE TABLE {0} AS
- SELECT $1 as loss, $2 as metric, $3 as metrics_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]"])
- plpy.execute(create_output_table, [loss, metric, metrics_type])
+ SELECT $1 as loss, $2 as metric, $3 as metrics_type, $4 as loss_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]", "TEXT"])
+ plpy.execute(create_output_table, [loss, metric, metrics_type, loss_type])
if is_mult_model:
plpy.execute("DROP VIEW IF EXISTS {0}".format(model_summary_table))
@@ -674,7 +699,8 @@
def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
model_arch, serialized_weights, use_gpus,
accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
- is_final_iteration=True, model_table=None, mst_key=None):
+ is_final_iteration=True, custom_function_map=None,
+ model_table=None, mst_key=None):
dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
gp_segment_id_col = '0' if is_platform_pg() else '__table__.{0}'.format(GP_SEGMENT_ID_COLNAME)
@@ -709,7 +735,8 @@
ARRAY{images_per_seg},
{use_gpus}::BOOLEAN,
ARRAY{accessible_gpus_for_seg},
- {is_final_iteration}
+ {is_final_iteration},
+ {custom_map_var}
)) as loss_metric
from {table} AS __table__ {mult_sql}
"""
@@ -718,12 +745,15 @@
weights = '__mt__.{0}'.format(MODEL_WEIGHTS_COLNAME)
mst_key_col = ModelSelectionSchema.MST_KEY
mult_sql = ', {model_table} AS __mt__ WHERE {mst_key_col} = {mst_key}'.format(**locals())
- res = plpy.execute(eval_sql.format(**locals()))
+ custom_map_var = '$1'
+ evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea"])
+ res = plpy.execute(evaluate_query, [custom_function_map])
else:
weights = '$1'
mult_sql = ''
- evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea"])
- res = plpy.execute(evaluate_query, [serialized_weights])
+ custom_map_var = '$2'
+ evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea", "bytea"])
+ res = plpy.execute(evaluate_query, [serialized_weights, custom_function_map])
loss_metric = res[0]['loss_metric']
return loss_metric
@@ -735,7 +765,7 @@
dist_key, dist_key_mapping, current_seg_id,
segments_per_host, images_per_seg,
use_gpus, accessible_gpus_for_seg,
- is_final_iteration, **kwargs):
+ is_final_iteration, custom_function_map=None, **kwargs):
SD = kwargs['SD']
device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
agg_loss, agg_metric, agg_image_count = state
@@ -755,7 +785,7 @@
accessible_gpus_for_seg[current_seg_id],
segments_per_host,
model_architecture,
- compile_params)
+ compile_params, custom_function_map)
if not agg_image_count:
# These should already be 0, but just in case make sure
agg_metric = 0
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 90e7a98..f4f03cf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1653,7 +1653,8 @@
metrics_compute_frequency INTEGER,
warm_start BOOLEAN,
name VARCHAR,
- description VARCHAR
+ description VARCHAR,
+ object_table VARCHAR
) RETURNS VOID AS $$
PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
from utilities.control import SetGUC
@@ -1675,6 +1676,25 @@
validation_table VARCHAR,
metrics_compute_frequency INTEGER,
warm_start BOOLEAN,
+ name VARCHAR,
+ description VARCHAR
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
+ source_table VARCHAR,
+ model VARCHAR,
+ model_arch_table VARCHAR,
+ model_id INTEGER,
+ compile_params VARCHAR,
+ fit_params VARCHAR,
+ num_iterations INTEGER,
+ use_gpus BOOLEAN,
+ validation_table VARCHAR,
+ metrics_compute_frequency INTEGER,
+ warm_start BOOLEAN,
name VARCHAR
) RETURNS VOID AS $$
SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL);
@@ -1775,7 +1795,8 @@
use_gpus BOOLEAN,
accessible_gpus_for_seg INTEGER[],
prev_serialized_weights BYTEA,
- is_final_iteration BOOLEAN
+ is_final_iteration BOOLEAN,
+ custom_function_map BYTEA
) RETURNS BYTEA AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
return madlib_keras.fit_transition(**globals())
@@ -1815,7 +1836,8 @@
BOOLEAN,
INTEGER[],
BYTEA,
- BOOLEAN);
+ BOOLEAN,
+ BYTEA);
CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
/* dep_var */ BYTEA,
/* ind_var */ BYTEA,
@@ -1832,7 +1854,8 @@
/* use_gpus */ BOOLEAN,
/* segments_per_host */ INTEGER[],
/* serialized_weights */ BYTEA,
- /* is_final_iteration */ BOOLEAN
+ /* is_final_iteration */ BOOLEAN,
+ /* custom_loss_cfunction */ BYTEA
)(
STYPE=BYTEA,
SFUNC=MADLIB_SCHEMA.fit_transition,
@@ -2048,7 +2071,8 @@
images_per_seg INTEGER[],
use_gpus BOOLEAN,
accessible_gpus_for_seg INTEGER[],
- is_final_iteration BOOLEAN
+ is_final_iteration BOOLEAN,
+ custom_function_map BYTEA
) RETURNS REAL[3] AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
return madlib_keras.internal_keras_eval_transition(**globals())
@@ -2087,7 +2111,8 @@
INTEGER[],
BOOLEAN,
INTEGER[],
- BOOLEAN);
+ BOOLEAN,
+ BYTEA);
CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
/* dependent_var */ BYTEA,
@@ -2104,7 +2129,8 @@
/* images_per_seg*/ INTEGER[],
/* use_gpus */ BOOLEAN,
/* accessible_gpus_for_seg */ INTEGER[],
- /* is_final_iteration */ BOOLEAN
+ /* is_final_iteration */ BOOLEAN,
+ /* custom_function_map */ BYTEA
)(
STYPE=REAL[3],
INITCOND='{0,0,0}',
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index c122def..424250d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -153,7 +153,11 @@
self.msts = self.fit_validator_train.msts
self.model_arch_table = self.fit_validator_train.model_arch_table
+ self.object_table = self.fit_validator_train.object_table
self.metrics_iters = []
+ self.object_map_col = 'object_map'
+ if self.object_table is not None:
+ self.populate_object_map()
original_cuda_env = None
if CUDA_VISIBLE_DEVICES_KEY in os.environ:
@@ -272,6 +276,7 @@
seg_ids,
images_per_seg,
[], [], epoch, True,
+ mst[self.object_map_col],
self.model_output_table,
mst[self.mst_key_col])
mst_metric_eval_time[mst[self.mst_key_col]] \
@@ -287,6 +292,31 @@
grand_schedule[dist_key] = rotate(msts, index)
return grand_schedule
+ def populate_object_map(self):
+ builtin_losses = dir(losses)
+ # Track distinct custom functions in compile_params
+ custom_fn_names = []
+ # Track their corresponding mst_keys to pass along the custom function
+ # definition read from the object table.
+ # For compile_params calling builtin functions the object_map is set to
+ # None.
+ custom_fn_mst_idx = []
+ for mst, mst_idx in zip(self.msts, range(len(self.msts))):
+ compile_params = mst[self.compile_params_col]
+ # We assume that the compile_param is validated as part
+ # of the loading mst_table and thus not validating here
+ # Also, it is validated later when we compile the model
+ # on the segments
+ compile_dict = convert_string_of_args_to_dict(compile_params)
+ if (compile_dict['loss'] not in builtin_losses):
+ custom_fn_names.append(compile_dict['loss'])
+ custom_fn_mst_idx.append(mst_idx)
+ if len(custom_fn_names) > 0:
+ # Pass only unique custom_fn_names to query from object table
+ custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
+ for mst_idx in custom_fn_mst_idx:
+ self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
+
def create_mst_schedule_table(self, mst_row):
mst_temp_query = """
CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
@@ -294,7 +324,8 @@
{self.compile_params_col} VARCHAR,
{self.fit_params_col} VARCHAR,
{dist_key_col} INTEGER,
- {self.mst_key_col} INTEGER)
+ {self.mst_key_col} INTEGER,
+ {self.object_map_col} BYTEA)
""".format(dist_key_col=dist_key_col, **locals())
plpy.execute(mst_temp_query)
for mst, dist_key in zip(mst_row, self.dist_keys):
@@ -303,21 +334,24 @@
compile_params = mst[self.compile_params_col]
fit_params = mst[self.fit_params_col]
mst_key = mst[self.mst_key_col]
+ object_map = mst[self.object_map_col]
else:
model_id = "NULL"
compile_params = "NULL"
fit_params = "NULL"
mst_key = "NULL"
- mst_insert_query = """
+ object_map = None
+ mst_insert_query = plpy.prepare(
+ """
INSERT INTO {self.mst_current_schedule_tbl}
VALUES ({model_id},
$madlib${compile_params}$madlib$,
$madlib${fit_params}$madlib$,
{dist_key},
- {mst_key})
- """.format(**locals())
- plpy.execute(mst_insert_query)
-
+ {mst_key},
+ $1)
+ """.format(**locals()), ["BYTEA"])
+ plpy.execute(mst_insert_query, [object_map])
def create_model_output_table(self):
output_table_create_query = """
@@ -464,6 +498,8 @@
num_classes = len(class_values)
name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
+ object_table = 'NULL' if self.object_table is None \
+ else '$MAD${0}$MAD$'.format(self.object_table)
metrics_iters = self.metrics_iters if self.metrics_iters else 'NULL'
class_values_colname = CLASS_VALUES_COLNAME
dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
@@ -479,6 +515,8 @@
$MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
$MAD${independent_varname}$MAD$::TEXT AS independent_varname,
$MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
+ $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
+ {object_table}::TEXT AS object_table,
{self.num_iterations}::INTEGER AS num_iterations,
{self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
{self.warm_start} AS warm_start,
@@ -590,7 +628,8 @@
{use_gpus}::BOOLEAN,
ARRAY{self.accessible_gpus_for_seg},
{self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
- {is_final_iteration}::BOOLEAN
+ {is_final_iteration}::BOOLEAN,
+ {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
)::BYTEA AS {self.model_weights_col},
{self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
,src.{dist_key_col} AS {dist_key_col}
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 8d68385..392a3be 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1508,7 +1508,8 @@
use_gpus BOOLEAN,
accessible_gpus_for_seg INTEGER[],
prev_serialized_weights BYTEA,
- is_final_iteration BOOLEAN
+ is_final_iteration BOOLEAN,
+ custom_function_map BYTEA
) RETURNS BYTEA AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
return madlib_keras.fit_transition(is_multiple_model = True, **globals())
@@ -1531,7 +1532,8 @@
BOOLEAN,
INTEGER[],
BYTEA,
- BOOLEAN);
+ BOOLEAN,
+ BYTEA);
CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
/* dependent_var */ BYTEA,
/* independent_var */ BYTEA,
@@ -1546,9 +1548,10 @@
/* segments_per_host */ INTEGER,
/* images_per_seg */ INTEGER[],
/* use_gpus */ BOOLEAN,
- /* accessible_gpus_for_seg */ INTEGER[],
+ /* accessible_gpus_for_seg */ INTEGER[],
/* prev_serialized_weights */ BYTEA,
- /* is_final_iteration */ BOOLEAN
+ /* is_final_iteration */ BOOLEAN,
+ /* custom_function_obj_map */ BYTEA
)(
STYPE=BYTEA,
SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index b2b7397..49e3a12 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -213,16 +213,20 @@
def query_model_configs(model_selection_table, model_selection_summary_table,
mst_key_col, model_arch_table_col):
msts_query = """
- SELECT * FROM {model_selection_table}
+ SELECT *, NULL as object_map FROM {model_selection_table}
ORDER BY {mst_key_col}
""".format(**locals())
- model_arch_table_query = """
- SELECT {model_arch_table_col}
+ from madlib_keras_model_selection import ModelSelectionSchema
+ object_table_col = ModelSelectionSchema.OBJECT_TABLE
+ summary_table_query = """
+ SELECT {model_arch_table_col}, {object_table_col}
FROM {model_selection_summary_table}
""".format(**locals())
msts = list(plpy.execute(msts_query))
- model_arch_table = plpy.execute(model_arch_table_query)[0][model_arch_table_col]
- return msts, model_arch_table
+ summary_res = plpy.execute(summary_table_query)
+ model_arch_table = summary_res[0][model_arch_table_col]
+ object_table = summary_res[0][object_table_col]
+ return msts, model_arch_table, object_table
def query_dist_keys(source_table, dist_key_col):
""" Read distinct keys from the source table """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index a364a9e..bb2e744 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -249,7 +249,7 @@
model_arch_table, model_id, dependent_varname,
independent_varname, num_iterations,
metrics_compute_frequency, warm_start,
- use_gpus, accessible_gpus_for_seg, module_name):
+ use_gpus, accessible_gpus_for_seg, module_name, object_table):
self.source_table = source_table
self.validation_table = validation_table
self.output_model_table = output_model_table
@@ -262,6 +262,7 @@
self.metrics_compute_frequency = metrics_compute_frequency
self.warm_start = warm_start
self.num_iterations = num_iterations
+ self.object_table = object_table
self.source_summary_table = None
if self.source_table:
self.source_summary_table = add_postfix(
@@ -283,6 +284,9 @@
"{0}: metrics_compute_frequency must be in the range (1 - {1}).".format(
self.module_name, self.num_iterations))
input_tbl_valid(self.source_table, self.module_name)
+ if self.object_table is not None:
+ input_tbl_valid(self.object_table, self.module_name)
+ cols_in_tbl_valid(self.object_table, CustomFunctionSchema.col_names, self.module_name)
input_tbl_valid(self.source_summary_table, self.module_name,
error_suffix_str="Please ensure that the source table ({0}) "
"has been preprocessed by "
@@ -384,7 +388,7 @@
model_arch_table, model_id, dependent_varname,
independent_varname, num_iterations,
metrics_compute_frequency, warm_start,
- use_gpus, accessible_gpus_for_seg):
+ use_gpus, accessible_gpus_for_seg, object_table):
self.module_name = 'madlib_keras_fit'
super(FitInputValidator, self).__init__(source_table,
@@ -399,7 +403,8 @@
warm_start,
use_gpus,
accessible_gpus_for_seg,
- self.module_name)
+ self.module_name,
+ object_table)
InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
self.model_id)
@@ -418,7 +423,7 @@
"has been created by "
"load_model_selection_table().".format(
model_selection_table))
- self.msts, self.model_arch_table = query_model_configs(
+ self.msts, self.model_arch_table, self.object_table = query_model_configs(
model_selection_table, model_selection_summary_table,
mst_key_col, model_arch_table_col)
if warm_start:
@@ -437,7 +442,8 @@
warm_start,
use_gpus,
accessible_gpus_for_seg,
- self.module_name)
+ self.module_name,
+ self.object_table)
class MstLoaderInputValidator():
def __init__(self,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 541d370..575be98 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -18,8 +18,10 @@
# under the License.
import ast
+import dill
import os
import plpy
+from collections import defaultdict
from math import ceil
from keras import backend as K
@@ -31,6 +33,8 @@
import madlib_keras_serializer
import madlib_keras_gpu_info
+from madlib_keras_custom_function import CustomFunctionSchema
+
from utilities.utilities import _assert
from utilities.utilities import is_platform_pg
@@ -175,6 +179,16 @@
"please refer to the documentation").format(ckey))
return metrics
+def get_loss_from_compile_param(str_of_args):
+ compile_dict = convert_string_of_args_to_dict(str_of_args)
+ loss = None
+ if 'loss' in compile_dict:
+ loss = compile_dict['loss']
+ else:
+ plpy.error(("Invalid input value for parameter 'loss', "
+ "please refer to the documentation"))
+ return loss
+
# Parse the compile parameters and the optimizer.
def parse_and_validate_compile_params(str_of_args):
"""
@@ -307,9 +321,12 @@
return optimizers
# Run the keras.compile with the given parameters
-def compile_model(model, compile_params):
+def compile_model(model, compile_params, custom_function_map=None):
optimizers = get_optimizers()
(opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
+ if custom_function_map is not None:
+ map=dill.loads(custom_function_map)
+ compile_dict['loss']=map[compile_dict['loss']]
compile_dict['optimizer'] = optimizers[opt_name](**final_args) if final_args else opt_name
model.compile(**compile_dict)
@@ -331,3 +348,66 @@
compile_dict['sample_weight_mode'] is None or
compile_dict['sample_weight_mode'] == "temporal",
"""compile parameter sample_weight_mode can only be "temporal" or None""")
+
+# Returns an object of custom function name and its corresponding object
+def query_custom_functions_map(object_table, custom_fn_names):
+ """
+ Args:
+ @param: object_table Name of the object table
+ @param: custom_fn_names List of custom function read from compile_param
+ if custom function exist in compile_params,
+ expected list length >= 1
+ else,
+ an empty list is passed in
+ Returns:
+ custom_fn_map_obj: A dill object of a dictionary mapping custom function
+ name to its definition object as read from the object
+ table
+ Example:
+ {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+
+ """
+ if len(custom_fn_names) < 1:
+ return None
+ custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
+ # Dictionary map of name:object
+ # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+ custom_fn_map = defaultdict(list)
+ # Query the custom function if not yet loaded from table
+ res = plpy.execute("""
+ SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
+ WHERE {custom_fn_col_name} = ANY(ARRAY{custom_fn_names})
+ """.format(custom_obj_col_name=custom_obj_col_name,
+ object_table=object_table,
+ custom_fn_col_name=CustomFunctionSchema.FN_NAME,
+ custom_fn_names=custom_fn_names))
+ if res.nrows() < len(custom_fn_names):
+ plpy.error("Custom function {0} not defined in object table '{1}'".format(custom_fn_names, object_table))
+ for r in res:
+ custom_fn_map[r[CustomFunctionSchema.FN_NAME]] = dill.loads(r[custom_obj_col_name])
+ custom_fn_map_obj = dill.dumps(custom_fn_map)
+ return custom_fn_map_obj
+
+def get_custom_functions_list(compile_params):
+ """
+ Args:
+ @param: compile_params compile params passed to keras.compile
+ Returns:
+ custom_fn_list: List of custom function read from compile_param
+ if custom function exist in compile_params,
+ returns list length >= 1
+ else,
+ returns an empty list
+ Example:
+ if custom function exist in compile_params,
+ returns [custom_fn1, custom_fn2, ....]
+ else,
+ []
+
+ """
+ compile_dict = convert_string_of_args_to_dict(compile_params)
+ builtin_losses = dir(losses)
+ custom_fn_list = []
+ if compile_dict['loss'] not in builtin_losses:
+ custom_fn_list.append(compile_dict['loss'])
+ return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index a35eb6b..bd77532 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -56,6 +56,7 @@
pg_typeof(normalizing_const) = 'real'::regtype AND
name is NULL AND
description is NULL AND
+ object_table is NULL AND
model_size > 0 AND
madlib_version is NOT NULL AND
compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text AND
@@ -417,3 +418,15 @@
$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
$$ batch_size=2, epochs=1, verbose=0 $$::text,
3);
+
+-- Test invalid loss function in compile_param
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+ 'cifar_10_sample_test_shape_batched',
+ 'keras_saved_out',
+ 'model_arch',
+ 3,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='custom_fn', metrics=['accuracy']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3);$TRAP$) = 1,
+ 'Object table not specified for custom function in compile_params.');
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index 20e2332..7eb4c15 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -137,3 +137,96 @@
metric >= 0 AND
metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
+
+-- TEST custom loss function
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+ c = a*b*0
+ return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+ 'iris_data_packed',
+ 'iris_model',
+ 'iris_model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ FALSE, NULL, 1, NULL, NULL, NULL, -- use_gpus, validation_table, metrics_compute_frequency, warm_start, name, description
+ 'test_custom_function_table' -- object_table: table loaded above with the dill-serialized test_custom_fn
+);
+
+SELECT assert(
+ model_arch_table = 'iris_model_arch' AND
+ model_id = 1 AND
+ model_type = 'madlib_keras' AND
+ source_table = 'iris_data_packed' AND
+ model = 'iris_model' AND
+ dependent_varname = 'class_text' AND
+ independent_varname = 'attributes' AND
+ dependent_vartype LIKE '%char%' AND
+ normalizing_const = 1 AND
+ pg_typeof(normalizing_const) = 'real'::regtype AND
+ name is NULL AND
+ description is NULL AND
+ object_table = 'test_custom_function_table' AND
+ model_size > 0 AND
+ madlib_version is NOT NULL AND
+ compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text AND
+ fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+ num_iterations = 3 AND
+ metrics_compute_frequency = 1 AND
+ num_classes = 3 AND
+ class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+ metrics_type = '{mae}' AND
+ array_upper(training_metrics, 1) = 3 AND
+ training_loss = '{0,0,0}' AND
+ array_upper(metrics_elapsed_time, 1) = 3 ,
+ 'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_model_summary) summary;
+
+SELECT assert(
+ model_weights IS NOT NULL AND
+ model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
+FROM (SELECT * FROM iris_model) k;
+
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_model',
+ 'iris_data_val',
+ 'evaluate_out',
+ FALSE);
+
+SELECT assert(loss >= 0 AND
+ metric >= 0 AND
+ metrics_type = '{mae}' AND
+ loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+ 'iris_data_packed',
+ 'iris_model',
+ 'iris_model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn1', metrics=['mae']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ FALSE, NULL, 1, NULL, NULL, NULL,
+ 'test_custom_function_table'
+);$TRAP$) = 1,
+'custom function in compile_params not defined in Object table.');
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index fa90c86..90442e9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -201,6 +201,8 @@
model_info = 'iris_multiple_model_info' AND
source_table = 'iris_data_one_hot_encoded_packed' AND
model = 'iris_multiple_model' AND
+ model_selection_table = 'mst_table_4row' AND
+ object_table IS NULL AND
dependent_varname = 'class_one_hot_encoded' AND
independent_varname = 'attributes' AND
madlib_version is NOT NULL AND
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index d738f48..0860b7d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -161,5 +161,115 @@
metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
+-- TEST custom loss function
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+ c = a*b*0
+ return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+-- Prepare model selection table with two rows (1 model id x 2 compile params x 1 fit params)
+DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch',
+ 'mst_object_table',
+ ARRAY[1],
+ ARRAY[
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+ $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=16, epochs=1$$
+ ],
+ 'test_custom_function_table'
+);
+
+DROP TABLE if exists iris_multiple_model_custom_fn, iris_multiple_model_custom_fn_summary, iris_multiple_model_custom_fn_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model_custom_fn',
+ 'mst_object_table',
+ 3,
+ FALSE,
+ 'iris_data_one_hot_encoded_packed',
+ 1
+);
+
+SELECT assert(
+ model_arch_table = 'iris_model_arch' AND
+ validation_table = 'iris_data_one_hot_encoded_packed' AND
+ model_info = 'iris_multiple_model_custom_fn_info' AND
+ source_table = 'iris_data_packed' AND
+ model = 'iris_multiple_model_custom_fn' AND
+ dependent_varname = 'class_text' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ num_classes = 3 AND
+ class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+ dependent_vartype LIKE '%char%' AND
+ normalizing_const = 1,
+ 'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_custom_fn_summary) summary;
+
+SELECT assert(
+ model_type = 'madlib_keras' AND
+ model_size > 0 AND
+ fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+ metrics_type = '{accuracy}' AND
+ training_metrics_final >= 0 AND
+ training_loss_final = 0 AND
+ training_loss = '{0,0,0}' AND
+ array_upper(training_metrics, 1) = 3 AND
+ array_upper(training_loss, 1) = 3 AND
+ validation_metrics_final >= 0 AND
+ validation_loss_final = 0 AND
+ array_upper(validation_metrics, 1) = 3 AND
+ array_upper(validation_loss, 1) = 3 AND
+ array_upper(metrics_elapsed_time, 1) = 3,
+ 'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_custom_fn_info where compile_params like '%test_custom_fn%') info;
+
+-- Run Predict
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+ 'iris_multiple_model_custom_fn',
+ 'iris_data',
+ 'id',
+ 'attributes',
+ 'pg_temp.iris_predict',
+ 'prob',
+ NULL,
+ 1);
+
+-- Run Evaluate
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_multiple_model_custom_fn',
+ 'iris_data_val',
+ 'evaluate_out',
+ NULL, -- use_gpus
+ 2); -- mst_key to evaluate; the assert below expects this key to be the test_custom_fn configuration (TODO confirm key order)
+
+SELECT assert(loss = 0 AND
+ metric >= 0 AND
+ metrics_type = '{accuracy}' AND
+ loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+
DROP SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
!>)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 43ad38b..bffd5a9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -270,6 +270,86 @@
5 IN (SELECT mst_key FROM iris_multiple_model_info),
'mst_key 5 should be in the info table since it has been added to mst_table');
+-- warm start with custom function
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+ c = a*b*0
+ return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch',
+ 'mst_table',
+ ARRAY[1,2],
+ ARRAY[
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$,
+ $$loss='test_custom_fn', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=5,epochs=1$$
+ ],
+ 'test_custom_function_table'
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ FALSE, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+ training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+-- warm start for fit multiple model
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ FALSE,
+ NULL, 1,
+ TRUE -- warm_start
+);
+
+-- Test that when warm_start is TRUE, all the output tables are persistent (not unlogged)
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = false, 'Model info output table is unlogged');
+
+
+SELECT assert(
+ array_upper(training_loss, 1) = 3 AND
+ array_upper(training_metrics, 1) = 3,
+ 'metrics compute frequency must be 1.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+ abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+ abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+ abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+ abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+ 'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+
-- Transfer learning tests
-- Load the same arch again so that we can compare transfer learning results
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index ed3e0da..d3b2cd7 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1020,7 +1020,7 @@
obj = self.subject.FitCommonValidator(
'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
'dep_varname', 'independent_varname', 5, None, False, False, [0],
- 'module_name')
+ 'module_name', None)
self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
def test_is_valid_metrics_compute_frequency_True_num(self):
@@ -1028,7 +1028,7 @@
obj = self.subject.FitCommonValidator(
'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
'dep_varname', 'independent_varname', 5, 3, False, False, [0],
- 'module_name')
+ 'module_name', None)
self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
def test_is_valid_metrics_compute_frequency_False_zero(self):
@@ -1036,7 +1036,7 @@
obj = self.subject.FitCommonValidator(
'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
'dep_varname', 'independent_varname', 5, 0, False, False, [0],
- 'module_name')
+ 'module_name', None)
self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
def test_is_valid_metrics_compute_frequency_False_greater(self):
@@ -1044,7 +1044,7 @@
obj = self.subject.FitCommonValidator(
'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
'dep_varname', 'independent_varname', 5, 6, False, False, [0],
- 'module_name')
+ 'module_name', None)
self.assertEqual(False, obj._is_valid_metrics_compute_frequency())