DL: Add optional parameters for multi-model training
JIRA: MADLIB-1397
This commit adds the following optional params:
- metrics_compute_frequency
- warm_start
- name
- description
It also fixes a bug where the user's CUDA environment variable was overwritten before it could be saved.
Closes #461
Co-authored-by: Ekta Khanna <ekhanna@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 702c288..5e35b46 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -306,7 +306,8 @@
#TODO add a unit test for this in a future PR
reset_cuda_env(original_cuda_env)
-def get_initial_weights(model_table, model_arch, serialized_weights, warm_start, gpus_per_host):
+def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
+ gpus_per_host):
"""
If warm_start is True, return back initial weights from model table.
If warm_start is False, first try to get the weights from model_arch
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 4ce21dd..883ed22 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -27,6 +27,7 @@
from madlib_keras import get_model_arch_weights
from madlib_keras import get_segments_and_gpus
from madlib_keras import get_source_summary_table_dict
+from madlib_keras import should_compute_metrics_this_iter
from madlib_keras_helper import *
from madlib_keras_model_selection import ModelSelectionSchema
from madlib_keras_validator import *
@@ -73,7 +74,9 @@
class FitMultipleModel():
def __init__(self, schema_madlib, source_table, model_output_table,
model_selection_table, num_iterations,
- gpus_per_host=0, validation_table=None, **kwargs):
+ gpus_per_host=0, validation_table=None,
+ metrics_compute_frequency=None, warm_start=False, name="",
+ description="", **kwargs):
# set the random seed for visit order/scheduling
random.seed(1)
if is_platform_pg():
@@ -90,6 +93,9 @@
self.model_summary_table = add_postfix(
model_output_table, '_summary')
self.num_iterations = num_iterations
+ self.metrics_compute_frequency = metrics_compute_frequency
+ self.name = name
+ self.description = description
self.module_name = 'madlib_keras_fit_multiple_model'
self.schema_madlib = schema_madlib
self.version = madlib_version(self.schema_madlib)
@@ -111,9 +117,18 @@
self.model_selection_table, self.model_selection_summary_table,
mb_dep_var_col, mb_indep_var_col, self.num_iterations,
self.model_info_table, self.mst_key_col, self.model_arch_table_col,
- 1, False)
+ self.metrics_compute_frequency, warm_start)
+ if self.metrics_compute_frequency is None:
+ self.metrics_compute_frequency = num_iterations
+ self.warm_start = bool(warm_start)
self.msts = self.fit_validator_train.msts
self.model_arch_table = self.fit_validator_train.model_arch_table
+ self.metrics_iters = []
+
+ original_cuda_env = None
+ if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+ original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+
self.seg_ids_train, self.images_per_seg_train = \
get_image_count_per_seg_for_minibatched_data_from_db(
self.source_table)
@@ -138,26 +153,24 @@
self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
gpus_per_host)
- self.create_model_output_table()
+ if not self.warm_start:
+ self.create_model_output_table()
self.weights_to_update_tbl = unique_string(desp='weights_to_update')
self.fit_multiple_model()
+ reset_cuda_env(original_cuda_env)
def fit_multiple_model(self):
# WARNING: set orca off to prevent unwanted redistribution
with OptimizerControl(False):
- original_cuda_env = None
- if CUDA_VISIBLE_DEVICES_KEY in os.environ:
- original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
self.start_training_time = datetime.datetime.now()
self.train_multiple_model()
self.end_training_time = datetime.datetime.now()
self.insert_info_table()
self.create_model_summary_table()
- reset_cuda_env(original_cuda_env)
def train_multiple_model(self):
total_msts = len(self.msts_for_schedule)
- for iter in range(self.num_iterations):
+ for iter in range(1, self.num_iterations+1):
for mst_idx in range(total_msts):
mst_row = [self.grand_schedule[dist_key][mst_idx]
for dist_key in self.dist_keys]
@@ -167,12 +180,16 @@
self.run_training()
if mst_idx == (total_msts - 1):
end_iteration = time.time()
- self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(
- iter, end_iteration - start_iteration)
- self.info_str += "\tTraining set after iteration {0}:".format(iter)
- self.evaluate_model(iter, self.source_table, True)
- if self.validation_table:
- self.evaluate_model(iter, self.validation_table, False)
+ self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(iter,
+ end_iteration - start_iteration)
+ if should_compute_metrics_this_iter(iter,
+ self.metrics_compute_frequency,
+ self.num_iterations):
+ self.metrics_iters.append(iter)
+ self.info_str += "\tTraining set after iteration {0}:".format(iter)
+ self.evaluate_model(iter, self.source_table, True)
+ if self.validation_table:
+ self.evaluate_model(iter, self.validation_table, False)
plpy.info("\n"+self.info_str)
def evaluate_model(self, epoch, table, is_train):
@@ -246,7 +263,6 @@
plpy.execute(mst_insert_query)
def create_model_output_table(self):
-
output_table_create_query = """
CREATE TABLE {self.model_output_table}
({self.mst_key_col} INTEGER PRIMARY KEY,
@@ -282,10 +298,8 @@
model_arch,
model_weights,
False,
- self.gpus_per_host
- )
+ self.gpus_per_host)
model = model_from_json(model_arch)
-
serialized_state = model_weights if model_weights else \
madlib_keras_serializer.serialize_nd_weights(model.get_weights())
@@ -295,7 +309,6 @@
is_metrics_specified = True if metrics_list else False
metrics_type = 'ARRAY{0}'.format(
metrics_list) if is_metrics_specified else 'NULL'
-
output_table_insert_query = """
INSERT INTO {self.model_output_table}(
{self.mst_key_col}, {self.model_weights_col},
@@ -327,6 +340,8 @@
plpy.execute(info_table_insert_query)
def create_model_summary_table(self):
+ if self.warm_start:
+ plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
src_summary_dict = get_source_summary_table_dict(self.fit_validator_train)
class_values = src_summary_dict['class_values']
dep_vartype = src_summary_dict['dep_vartype']
@@ -344,6 +359,9 @@
class_values_str = 'ARRAY{0}::{1}'.format(class_values,
src_summary_dict['class_values_type'])
num_classes = len(class_values)
+ name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
+ descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
+ metrics_iters = self.metrics_iters if self.metrics_iters else 'NULL'
class_values_colname = CLASS_VALUES_COLNAME
dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
normalizing_const_colname = NORMALIZING_CONST_COLNAME
@@ -359,13 +377,18 @@
$MAD${independent_varname}$MAD$::TEXT AS independent_varname,
$MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
{self.num_iterations}::INTEGER AS num_iterations,
+ {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
+ {self.warm_start} AS warm_start,
+ {name}::TEXT AS name,
+ {descr}::TEXT AS description,
'{self.start_training_time}'::TIMESTAMP AS start_training_time,
'{self.end_training_time}'::TIMESTAMP AS end_training_time,
'{self.version}'::TEXT AS madlib_version,
{num_classes}::INTEGER AS num_classes,
{class_values_str} AS {class_values_colname},
$MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
- {norm_const}::{float32_sql_type} AS {normalizing_const_colname}
+ {norm_const}::{float32_sql_type} AS {normalizing_const_colname},
+ ARRAY{metrics_iters}::INTEGER[] AS metrics_iters
""".format(**locals())
plpy.execute(update_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index ea19523..433535d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -35,7 +35,11 @@
model_selection_table VARCHAR,
num_iterations INTEGER,
gpus_per_host INTEGER,
- validation_table VARCHAR
+ validation_table VARCHAR,
+ metrics_compute_frequency INTEGER,
+ warm_start BOOLEAN,
+ name VARCHAR,
+ description VARCHAR
) RETURNS VOID AS $$
PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
with AOControl(False):
@@ -48,9 +52,63 @@
model_output_table VARCHAR,
model_selection_table VARCHAR,
num_iterations INTEGER,
+ gpus_per_host INTEGER,
+ validation_table VARCHAR,
+ metrics_compute_frequency INTEGER,
+ warm_start BOOLEAN,
+ name VARCHAR
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+ source_table VARCHAR,
+ model_output_table VARCHAR,
+ model_selection_table VARCHAR,
+ num_iterations INTEGER,
+ gpus_per_host INTEGER,
+ validation_table VARCHAR,
+ metrics_compute_frequency INTEGER,
+ warm_start BOOLEAN
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+ source_table VARCHAR,
+ model_output_table VARCHAR,
+ model_selection_table VARCHAR,
+ num_iterations INTEGER,
+ gpus_per_host INTEGER,
+ validation_table VARCHAR,
+ metrics_compute_frequency INTEGER
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+ source_table VARCHAR,
+ model_output_table VARCHAR,
+ model_selection_table VARCHAR,
+ num_iterations INTEGER,
+ gpus_per_host INTEGER,
+ validation_table VARCHAR
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, NULL, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+ source_table VARCHAR,
+ model_output_table VARCHAR,
+ model_selection_table VARCHAR,
+ num_iterations INTEGER,
gpus_per_host INTEGER
) RETURNS VOID AS $$
- SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL);
+ SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL, NULL, FALSE, NULL, NULL);
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index e24b8bd..49b8934 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -394,7 +394,10 @@
self.msts, self.model_arch_table = query_model_configs(
model_selection_table, model_selection_summary_table,
mst_key_col, model_arch_table_col)
- output_tbl_valid(model_info_table, self.module_name)
+ if warm_start:
+ input_tbl_valid(model_info_table, self.module_name)
+ else:
+ output_tbl_valid(model_info_table, self.module_name)
super(FitMultipleInputValidator, self).__init__(source_table,
validation_table,
output_model_table,
@@ -407,9 +410,12 @@
warm_start,
self.module_name)
+ if warm_start:
+ mst_count = plpy.execute("SELECT count(*) FROM {0}".format(model_selection_table))[0]['count']
+ warm_count = plpy.execute("SELECT count(*) FROM {0}".format(output_model_table))[0]['count']
-
-
+ _assert(mst_count <= warm_count,
+ "{self.module_name} error: Model table and mst table do not match".format(self=self))
class MstLoaderInputValidator():
def __init__(self,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 73df519..c9511d9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -54,7 +54,8 @@
if value:
set_cuda_env(value)
else:
- del os.environ[CUDA_VISIBLE_DEVICES_KEY]
+ if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+ del os.environ[CUDA_VISIBLE_DEVICES_KEY]
def get_device_name_and_set_cuda_env(gpus_per_host, seg):
if gpus_per_host > 0:
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 2a20467..adc771e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -205,7 +205,8 @@
dependent_vartype = 'integer[]' AND
num_classes = NULL AND
class_values = NULL AND
- normalizing_const = 1,
+ normalizing_const = 1 AND
+ metrics_iters = ARRAY[3],
'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_summary) summary;
@@ -236,7 +237,10 @@
num_classes = 3 AND
class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
dependent_vartype LIKE '%char%' AND
- normalizing_const = 1,
+ normalizing_const = 1 AND
+ name IS NULL AND
+ description IS NULL AND
+ metrics_compute_frequency = 6,
'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_summary) summary;
@@ -251,13 +255,13 @@
metrics_type = '{accuracy}' AND
training_metrics_final >= 0 AND
training_loss_final >= 0 AND
- array_upper(training_metrics, 1) = 6 AND
- array_upper(training_loss, 1) = 6 AND
+ array_upper(training_metrics, 1) = 1 AND
+ array_upper(training_loss, 1) = 1 AND
validation_metrics_final >= 0 AND
validation_loss_final >= 0 AND
- array_upper(validation_metrics, 1) = 6 AND
- array_upper(validation_loss, 1) = 6 AND
- array_upper(metrics_elapsed_time, 1) = 6,
+ array_upper(validation_metrics, 1) = 1 AND
+ array_upper(validation_loss, 1) = 1 AND
+ array_upper(metrics_elapsed_time, 1) = 1,
'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
FROM (SELECT * FROM iris_multiple_model_info) info;
@@ -291,7 +295,12 @@
'iris_multiple_model',
'mst_table_1row',
3,
- 0
+ 0,
+ NULL,
+ 1,
+ FALSE,
+ 'multi_model_name',
+ 'multi_model_descr'
);
SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
@@ -311,6 +320,13 @@
'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
FROM (SELECT * FROM iris_multiple_model_info) info;
+SELECT assert(
+ name = 'multi_model_name' AND
+ description = 'multi_model_descr' AND
+ metrics_compute_frequency = 1,
+ 'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
SELECT assert(cnt = 1,
'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
FROM (SELECT count(*) cnt FROM iris_multiple_model_info
@@ -336,9 +352,9 @@
metrics_type = '{accuracy}' AND
training_metrics_final >= 0 AND
training_loss_final >= 0 AND
- array_upper(training_metrics, 1) = 3 AND
- array_upper(training_loss, 1) = 3 AND
- array_upper(metrics_elapsed_time, 1) = 3,
+ array_upper(training_metrics, 1) = 1 AND
+ array_upper(training_loss, 1) = 1 AND
+ array_upper(metrics_elapsed_time, 1) = 1,
'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
FROM (SELECT * FROM iris_multiple_model_info) info;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 220569d..0ab09b7 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -126,6 +126,129 @@
m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch',
+ 'mst_table',
+ ARRAY[1,2],
+ ARRAY[
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=5,epochs=1$$
+ ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ 0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+ training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+-- warm start for fit multiple model
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ 0,
+ NULL, 1,
+ TRUE -- warm_start
+);
+
+SELECT assert(
+ array_upper(training_loss, 1) = 3 AND
+ array_upper(training_metrics, 1) = 3,
+ 'metrics compute frequency must be 1.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+ abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+ abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+ abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+ abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+ 'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+
+-- warm start with different mst tables
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch',
+ 'mst_table',
+ ARRAY[1],
+ ARRAY[
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=5,epochs=1$$,
+ $$batch_size=10,epochs=1$$,
+ $$batch_size=15,epochs=1$$,
+ $$batch_size=20,epochs=1$$
+ ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ 0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+ training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+DELETE FROM mst_table WHERE mst_key = 4;
+
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ 0, NULL, 1,
+ TRUE);
+
+
+SELECT assert(
+ abs(first.training_loss_final-second.training_loss_final) < 1e-6,
+ 'The loss should not change for mst_key 4 since it has been removed from mst_table')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND second.mst_key = 4;
+
+INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params,
+ 'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params,
+ 'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table',
+ 3,
+ 0,
+ NULL, 1,
+ TRUE -- warm_start
+);$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
+
+-- Transfer learning tests
+
-- Load the same arch again so that we can compare transfer learning results
SELECT load_keras_model('iris_model_arch', -- Output table,
$$
@@ -169,15 +292,9 @@
]
);
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
- 'iris_data_packed',
- 'iris_multiple_model',
- 'mst_table',
- 3,
- 0
-);
+UPDATE iris_model_arch
+SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE mst_key=1)
+WHERE model_id = 1;
DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
@@ -185,10 +302,6 @@
training_loss_final, training_metrics_final
FROM iris_multiple_model_info;
-UPDATE iris_model_arch
-SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE mst_key=1)
-WHERE model_id = 1;
-
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT madlib_keras_fit_multiple_model(
'iris_data_packed',
@@ -199,11 +312,8 @@
);
SELECT assert(
- abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
- abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
- abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
- abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+ (first.training_loss_final-second.training_loss_final) > 1e-6,
'Transfer learning test failed because training loss and metrics don''t match the expected value.')
FROM iris_model_first_run AS first, iris_multiple_model_info AS second
-WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+WHERE first.mst_key = second.mst_key AND first.model_id = 1;
!>)