blob: b6ce52559de651e5d913c5b96171e898021fdf17 [file] [log] [blame]
m4_include(`SQLCommon.m4')
m4_changequote(<<<,>>>)
m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
,<<<
m4_changequote(<!,!>)
-- =================== Setup & Initialization for FitMultiple tests ========================
--
-- For fit multiple, we test end-to-end functionality along with performance elsewhere.
-- They take a long time to run. Including similar tests here would probably not be worth
-- the extra time added to dev-check.
--
-- Instead, we just want to unit test different python functions in the FitMultiple class.
-- However, most of the important behavior we need to test requires access to an actual
-- Greenplum database... mostly, we want to make sure that the models hop around to the
-- right segments in the right order. Therefore, the unit tests are here, as a part of
-- dev-check. We mock fit_transition() and some validation functions in FitMultiple, but
-- do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
-- get through to gpdb. We also want to mock the number of segments, so we can test what
-- the model hopping behavior will be for a large cluster, even though dev-check should be
-- able to run on a single dev host.
\i m4_regexp(MODULE_PATHNAME,
<!\(.*\)libmadlib\.so!>,
<!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
)
-- Shadow version() inside the test schema so that the InputValidator accepts
-- this schema as if it were the real madlib schema; the call is simply forwarded.
CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version()
RETURNS VARCHAR
AS $$
    SELECT MADLIB_SCHEMA.version();
$$
LANGUAGE sql IMMUTABLE;
-- Call this first to initialize the FitMultiple object, before anything else happens.
-- A real mst table and source table are passed in; the remaining FitMultipleModel()
-- constructor params are filled in here. They can be overridden later, before the test
-- functions are called, if necessary. The object is stored in GD['fit_mult'], which is
-- where the other helper functions in this file look it up.
CREATE OR REPLACE FUNCTION init_fit_mult(
    source_table            VARCHAR,
    model_selection_table   VARCHAR
) RETURNS VOID AS
$$
    import sys
    from mock import Mock, patch
    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
    # Test schema that carries the mocked version() and transition function
    schema_madlib = 'madlib_installcheck_deep_learning'
    constructor_args = (
        schema_madlib,
        source_table,
        'orig_model_out',
        model_selection_table,
        1
    )
    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(*constructor_args)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
-- Points the shared FitMultiple object at schedule_table, drops any existing
-- table of that name and calls init_schedule_tbl() to build it.
--
-- The Python body returns None when init_schedule_tbl() gives a truthy result
-- and an error string otherwise.
-- NOTE(review): the declared return type is BOOLEAN while the body returns
-- None or a message string, and no caller in this file inspects the result;
-- confirm what init_schedule_tbl() returns before tightening this.
--
-- The function-properties guard below uses the m4 quote pair that is active in
-- this part of the file, the same way init_fit_mult does.
CREATE OR REPLACE FUNCTION test_init_schedule(
    schedule_table VARCHAR
) RETURNS BOOLEAN AS
$$
    fit_mult = GD['fit_mult']
    fit_mult.schedule_tbl = schedule_table
    # Start from a clean slate so the schedule table can be re-created
    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
    if fit_mult.init_schedule_tbl():
        err_msg = None
    else:
        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
    return err_msg
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
-- Rotates the schedule table of the shared FitMultiple object by calling
-- rotate_schedule_tbl(). When the object is currently tracking a different
-- schedule table than schedule_table, init_schedule_tbl() is called first.
-- NOTE(review): schedule_tbl itself is not reassigned here; confirm that is intended.
--
-- The function-properties guard below uses the m4 quote pair that is active in
-- this part of the file, the same way init_fit_mult does.
CREATE OR REPLACE FUNCTION test_rotate_schedule(
    schedule_table VARCHAR
) RETURNS VOID AS
$$
    fit_mult = GD['fit_mult']
    if fit_mult.schedule_tbl != schedule_table:
        fit_mult.init_schedule_tbl()
    fit_mult.rotate_schedule_tbl()
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
-- Mock fit_transition function, for testing
-- madlib_keras_fit_multiple_model() python code.
--
-- Instead of training, each call records the arguments it received (plus a
-- running call count) in GD['transition_function_params'][dist_key], where
-- validate_transition_function_params() can read them back.
--
-- Returns NULL while the images counted so far (first entry of
-- dependent_var_shape times the number of calls) stay below images_per_seg for
-- the simulated segment; otherwise marks the record with 'reset' and returns
-- serialized_weights unchanged.
CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
    dependent_var               BYTEA[],
    independent_var             BYTEA[],
    dependent_var_shape         INTEGER[],
    independent_var_shape       INTEGER[],
    model_architecture          TEXT,
    compile_params              TEXT,
    fit_params                  TEXT,
    dist_key                    INTEGER,
    dist_key_mapping            INTEGER[],
    current_seg_id              INTEGER,
    segments_per_host           INTEGER,
    images_per_seg              INTEGER[],
    accessible_gpus_for_seg     INTEGER[],
    serialized_weights          BYTEA,
    is_final_training_call      BOOLEAN,
    use_caching                 BOOLEAN,
    custom_function_map         BYTEA
) RETURNS BYTEA AS
$$
    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
                   'independent_var_shape', 'use_caching' ]

    # Continue the call count of the previous record for this dist_key, unless
    # that record was marked 'reset' (or there is no previous record at all).
    num_calls = 1
    previous = GD.get('transition_function_params', dict()).get(dist_key)
    if previous is not None and 'reset' not in previous:
        num_calls = previous['num_calls'] + 1

    # The function arguments are looked up by name through globals()
    arg_values = globals()
    params = dict((key, arg_values[key]) for key in param_keys)

    # Only the length of the first buffer is recorded. A SQL NULL array arrives
    # as None, so guard before indexing instead of raising a TypeError.
    params['dependent_var'] = len(dependent_var[0]) if dependent_var and dependent_var[0] else 0
    params['independent_var'] = len(independent_var[0]) if independent_var and independent_var[0] else 0
    params['num_calls'] = num_calls

    GD.setdefault('transition_function_params', dict())[dist_key] = params

    # compute simulated seg_id ( current_seg_id is the actual seg_id )
    seg_id = dist_key_mapping.index(dist_key)

    first_shape = dependent_var_shape[0] if dependent_var_shape else None
    if first_shape and first_shape[0] * num_calls < images_per_seg[seg_id]:
        # More buffers are still expected on this segment
        return None

    # The next call for this dist_key starts counting from 1 again
    params['reset'] = True
    return serialized_weights
$$ LANGUAGE plpythonu VOLATILE;
-- Compares what the mocked transition function recorded for expected_dist_key
-- in GD['transition_function_params'] against the expected values passed in.
--
-- Returns 'PASS' when every compared param matches, otherwise a message that
-- names the first mismatching param (or says that no call was recorded).
-- images_per_seg and expected_dist_key_mapping are accepted but not compared.
CREATE OR REPLACE FUNCTION validate_transition_function_params(
    current_seg_id                      INTEGER,
    segments_per_host                   INTEGER,
    images_per_seg                      INTEGER[],
    expected_num_calls                  INTEGER,
    expected_dist_key                   INTEGER,
    expected_is_final_training_call     BOOLEAN,
    expected_dist_key_mapping           INTEGER[],
    dependent_var_len                   INTEGER,
    independent_var_len                 INTEGER,
    use_caching                         BOOLEAN
) RETURNS TEXT AS
$$
    not_called_msg = "transition function was not called on segment {}".format(current_seg_id)

    # Guard clauses: nothing was recorded at all, or nothing for this dist key
    if 'transition_function_params' not in GD:
        return not_called_msg
    recorded_calls = GD['transition_function_params']
    if expected_dist_key not in recorded_calls:
        return not_called_msg + " for __dist_key__ = {}".format(expected_dist_key)

    actual = recorded_calls[expected_dist_key]
    mismatch_template = """Incorrect value for {} param passed to fit_transition_multiple_model:
Actual={}, Expected={}"""

    validation_map = {
        'current_seg_id' : current_seg_id,
        'segments_per_host' : segments_per_host,
        'num_calls' : expected_num_calls,
        'is_final_training_call' : expected_is_final_training_call,
        'dist_key' : expected_dist_key,
        'dependent_var' : dependent_var_len,
        'independent_var' : independent_var_len,
        'use_caching' : use_caching
    }

    for param, expected in validation_map.items():
        observed = actual[param]
        if observed != expected:
            return mismatch_template.format(param, observed, expected)

    return 'PASS' # actual params match expected params
$$ LANGUAGE plpythonu VOLATILE;
-- Helper to rotate an array of int's: the last element moves to the front and
-- every other element shifts one position towards the end.
-- A NULL array is returned as NULL instead of raising a Python TypeError.
CREATE OR REPLACE FUNCTION rotate_keys(
    keys INTEGER[]
) RETURNS INTEGER[]
AS $$
    if keys is None:
        return None
    return keys[-1:] + keys[:-1]
$$ LANGUAGE plpythonu IMMUTABLE;
-- Inverse of rotate_keys(): the first element moves to the end and every other
-- element shifts one position towards the front.
-- A NULL array is returned as NULL instead of raising a Python TypeError.
CREATE OR REPLACE FUNCTION reverse_rotate_keys(
    keys INTEGER[]
) RETURNS INTEGER[]
AS $$
    if keys is None:
        return None
    return keys[1:] + keys[:1]
$$ LANGUAGE plpythonu IMMUTABLE;
-- Sets the model input / model output / cached source table names on the
-- shared FitMultiple object, drops any existing output and cached source
-- tables, calls init_model_output_tbl() and then overwrites the model columns
-- of the output table with small values derived from mst_key.
-- The Python body has no return statement, so the TEXT result is NULL.
CREATE OR REPLACE FUNCTION setup_model_tables(
    input_table             TEXT,
    output_table            TEXT,
    cached_source_table     TEXT
) RETURNS TEXT AS
$$
    fit_mult = GD['fit_mult']
    fit_mult.model_input_tbl = input_table
    fit_mult.model_output_tbl = output_table
    fit_mult.cached_source_table = cached_source_table

    # Remove leftovers from a previous scenario before re-initializing
    for stale_tbl in (output_table, cached_source_table):
        plpy.execute('DROP TABLE IF EXISTS {}'.format(stale_tbl))

    fit_mult.init_model_output_tbl()

    shrink_models_sql = """
UPDATE {model_out} -- Reduce size of model for faster tests
SET ( model_weights, model_arch, compile_params, fit_params )
= ( mst_key::TEXT::BYTEA,
( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
'c' || mst_key::TEXT,
'f' || mst_key::TEXT
)
WHERE mst_key IS NOT NULL;
""".format(model_out=fit_mult.model_output_tbl)
    plpy.execute(shrink_models_sql)
$$ LANGUAGE plpythonu VOLATILE;
-- Updates dist keys in src table and internal fit_mult class variables.
-- num_data_segs can be larger than the actual number of segments, since this
-- is just for simulated testing. This will also write to expected_distkey_mappings_tbl,
-- which can be used for validating dist key mappings and images per seg later.
CREATE OR REPLACE FUNCTION update_dist_keys(
    src_table                       TEXT,
    num_data_segs                   INTEGER,
    num_models                      INTEGER,
    expected_distkey_mappings_tbl   TEXT
) RETURNS VOID AS
$$
    # Spread the buffers over the simulated segments
    plpy.execute("""
UPDATE {src_table}
SET __dist_key__ = (buffer_id % {num_data_segs})
""".format(src_table=src_table, num_data_segs=num_data_segs))

    fit_mult = GD['fit_mult']

    # One row per dist key, in dist key order
    per_key_rows = plpy.execute("""
SELECT SUM(attributes_shape[1]) AS image_count,
__dist_key__
FROM {src_table}
GROUP BY __dist_key__
ORDER BY __dist_key__
""".format(src_table=src_table))

    images_per_seg = []
    dist_keys = []
    for row in per_key_rows:
        images_per_seg.append(int(row['image_count']))
        dist_keys.append(int(row['__dist_key__']))
    num_dist_keys = len(dist_keys)

    fit_mult.source_table = src_table
    fit_mult.max_dist_key = sorted(dist_keys)[-1]
    fit_mult.images_per_seg_train = images_per_seg
    # Both attributes refer to the same list object
    fit_mult.dist_key_mapping = dist_keys
    fit_mult.dist_keys = dist_keys
    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
    fit_mult.segments_per_host = num_data_segs

    # Pad the schedule with empty slots when there are fewer models than dist keys
    scheduled_msts = fit_mult.msts[:num_models]
    padding = num_dist_keys - num_models
    if padding > 0:
        scheduled_msts += [None] * padding
    fit_mult.msts_for_schedule = scheduled_msts

    fit_mult.all_mst_keys = [ 'NULL' if not mst else str(mst['mst_key'])
                              for mst in fit_mult.msts_for_schedule ]
    fit_mult.num_msts = num_models

    # Extra dist keys are only generated when there are more models than dist keys
    first_extra_key = fit_mult.max_dist_key + 1
    fit_mult.extra_dist_keys = [ first_extra_key + i
                                 for i in range(num_models - num_dist_keys) ]
    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys

    # Record the expectations that later checks read back, one row per dist key
    plpy.execute("""
DROP TABLE IF EXISTS {exp_table};
CREATE TABLE {exp_table} AS
SELECT
ARRAY( -- map of dist_keys to seg_ids from source table
SELECT __dist_key__
FROM {fm.source_table}
GROUP BY __dist_key__
ORDER BY __dist_key__ -- This would be gp_segment_id if it weren't a simulation
) AS expected_dist_key_mapping,
ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
{num_data_segs} AS segments_per_host,
__dist_key__
FROM {fm.source_table}
GROUP BY __dist_key__
DISTRIBUTED BY (__dist_key__);
""".format(
        fm=fit_mult,
        num_data_segs=num_data_segs,
        exp_table=expected_distkey_mappings_tbl
    ))
$$ LANGUAGE plpythonu VOLATILE;
-- Runs FitMultiple.run_training() for one hop on the shared FitMultiple object,
-- after clearing the call records of the mocked transition function and copying
-- the hop flags onto the object.
--
-- The function-properties guard below uses the m4 quote pair that is active in
-- this part of the file, the same way init_fit_mult does.
CREATE OR REPLACE FUNCTION test_run_training(
    source_table                TEXT,
    hop                         INTEGER,
    is_very_first_hop           BOOLEAN,
    is_final_training_call      BOOLEAN,
    use_caching                 BOOLEAN
) RETURNS VOID AS
$$
    fit_mult = GD['fit_mult']

    # Each time we start a new test, clear out stats
    # like num_calls from GD so we don't end up validating
    # against old results
    if 'transition_function_params' in GD:
        del GD['transition_function_params']

    # NOTE(review): update_dist_keys() in this file sets `source_table`, while
    # this sets `source_tbl` -- confirm which attribute run_training() reads.
    fit_mult.source_tbl = source_table
    fit_mult.is_very_first_hop = is_very_first_hop
    fit_mult.is_final_training_call = is_final_training_call

    if use_caching != fit_mult.use_caching:
        fit_mult.udf_plan = None # Otherwise it will execute the wrong
                                 # query when use_caching changes!
    fit_mult.use_caching = use_caching

    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
-- Asserts that the mst_keys of output_tbl, read in __dist_key__ order, equal
-- the mst_keys array stored in expected_tbl.
--
-- Comparing with a NULL array yields NULL, and concatenating a NULL array into
-- the message makes the whole message NULL. Both are wrapped in COALESCE so the
-- check fails definitely, with a readable message, and does not depend on how
-- assert() treats a NULL condition.
CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
RETURNS VOID AS
$$
DECLARE
    actual      INTEGER[];
    expected    INTEGER[];
BEGIN
    EXECUTE 'SELECT ARRAY(' ||
                'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)'
        INTO actual;
    EXECUTE 'SELECT mst_keys FROM ' || expected_tbl
        INTO expected;
    PERFORM assert(
        COALESCE(actual = expected, FALSE),
        'mst keys found in wrong order / wrong segments!' ||
        E'\nActual: ' || COALESCE(actual::text, 'NULL') ||
        E'\nExpected: ' || COALESCE(expected::text, 'NULL')
    );
END
$$ LANGUAGE PLpgSQL VOLATILE;
-- Create mst table
DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'iris_mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$
    ]
);

-- Create FitMultiple object for running test functions
SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');

-- Working copy of the packed source table whose dist keys get rewritten below.
-- Dropped first, like the other fixture tables in this file, so the script can
-- be re-run in the same schema.
DROP TABLE IF EXISTS src_3segs;
CREATE TABLE src_3segs AS
    SELECT * FROM iris_data_15buf_packed
DISTRIBUTED BY (__dist_key__);

-- Simulate 6 models on 3 segments --
SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
--=== Test init_schedule_tbl() ===--
-- ====================================================================
-- =========== Enough setup, now for the actual tests! ===============
-- ====================================================================
SELECT test_init_schedule('current_schedule');

-- The FULL JOIN leaves a NULL on one side for any key present in only one table
SELECT assert(
           s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
           'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
       )
FROM current_schedule s
FULL JOIN iris_mst_table m USING (mst_key);

-- Save order of mst keys in schedule for tracking
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS
    SELECT ARRAY(
               SELECT mst_key FROM current_schedule ORDER BY __dist_key__
           ) mst_keys;

--=== Test rotate_schedule() ===--
SELECT test_rotate_schedule('current_schedule');
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('current_schedule', 'expected_order');
UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys); -- Undo for later

-- Initialize model_output table, and set model_input & cached_src table names
SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
SELECT validate_mst_key_order('model_output', 'expected_order');
-- Order of params in run_training test function (for reference below):
--
-- test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)

--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
SELECT test_run_training('src_3segs', 0, False, False, False);

-- mst_keys should not have moved
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
-- ( segments_per_host is read from the expectations table, as in every other
--   check in this file, instead of repeating the literal value here )
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            False   -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
--=== Test an ordinary hop - no caching (# msts > # segs) ===--
SELECT test_run_training('src_3segs', 1, False, False, False);
SELECT test_rotate_schedule('current_schedule');

-- check that mst keys rotated onto correct segments
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            False   -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
--=== Test final training hop - no caching (# msts > # segs) ===--
SELECT test_run_training('src_3segs', 8, False, True, False);

-- check that mst keys rotated onto correct segments
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,   -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            False   -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
--=== Test very first hop - caching enabled ( # msts > # segs ) ===--
SELECT test_run_training('src_3segs', 0, True, False, True);

-- mst_keys should not have moved
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;

-- validate that cached source table was created with proper dist keys
-- ( the FULL JOIN leaves a NULL on one side for any unmatched dist key )
SELECT assert(
           c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL,
           'cached src table was not created or dist keys do not match original src table')
FROM cached_src c
FULL JOIN (SELECT __dist_key__ FROM src_3segs GROUP BY __dist_key__) s USING(__dist_key__);
-- Test ordinary hop - caching enabled ( # msts > # segs )
SELECT test_run_training('src_3segs', 7, False, False, True);

UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,      -- dependent_var length
            0,      -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
-- Test final training hop - caching enabled ( # msts > # segs )
SELECT test_run_training('src_3segs', 2, False, True, True);

UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('model_output', 'expected_order');

-- independent_var & dependent_var should have both been passed as NULL
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,   -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,      -- dependent_var length
            0,      -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
--=== Simulate 3 models on 3 segments ===--
SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings');
DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order);
SELECT test_init_schedule('current_schedule');

-- E'' literals make the line breaks in the failure message explicit, as in
-- validate_mst_key_order(), instead of depending on the server's string settings
SELECT assert(
           COUNT(*) = 3,
           E'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' ||
           E'Expected: 3\nActual: ' || COUNT(*)::TEXT
       )
FROM current_schedule;

-- Make sure none of the entries in the schedule table are NULL
-- ( this should only happen for # msts < # segs case )
SELECT assert(
           COUNT(*) = 0,
           'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs'
       )
FROM current_schedule
WHERE mst_key IS NULL;

-- Save new order of mst keys in schedule for tracking
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS
    SELECT ARRAY(
               SELECT mst_key FROM current_schedule ORDER BY __dist_key__
           ) mst_keys;

SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
SELECT validate_mst_key_order('model_output', 'expected_order');
SELECT test_rotate_schedule('current_schedule');
-- Test ordinary hop - no caching ( # msts = # segs )
SELECT test_run_training('src_3segs', 2, False, False, False);

UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('model_output', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            False   -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
--=== Simulate 3 models on 5 segments ( # msts < # segs ) ===--
-- ( by updating dist keys in source table )
-- The table is dropped first, like the other fixture tables in this file, so
-- the script can be re-run in the same schema.
DROP TABLE IF EXISTS src_5segs;
CREATE TABLE src_5segs AS
    SELECT * FROM iris_data_15buf_packed
DISTRIBUTED BY (__dist_key__);
SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings');
SELECT test_init_schedule('current_schedule');

-- E'' literals make the line breaks in the failure messages explicit, as in
-- validate_mst_key_order(), instead of depending on the server's string settings
SELECT assert(
           COUNT(*) = 2,
           E'Wrong number NULL entries in schedule table created by test_init_schedule()\n' ||
           E'Expected: 2\nActual: ' || COUNT(*)::TEXT
       )
FROM current_schedule
WHERE mst_key IS NULL;

SELECT assert(
           COUNT(*) = 3,
           E'Wrong number of non-NULL entries in schedule table created by test_init_schedule()\n' ||
           E'Expected: 3\nActual: ' || COUNT(*)::TEXT
       )
FROM current_schedule
WHERE mst_key IS NOT NULL;

-- Save expected mst_key order
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS
    SELECT ARRAY(
               SELECT mst_key FROM current_schedule ORDER BY __dist_key__
           ) mst_keys;

-- Initialize model_output table, and set model_input & cached_src table names
SELECT setup_model_tables('model_input', 'model_output', 'cached_src');

-- One row per dist key in the schedule; mst_key is NULL where model_output has no match
DROP TABLE IF EXISTS model_output_ext;
CREATE TABLE model_output_ext AS
    SELECT c.__dist_key__, o.mst_key
    FROM current_schedule c
    LEFT JOIN model_output o USING (__dist_key__);

-- Make sure model_output was created with correct mst_key order
SELECT validate_mst_key_order('model_output_ext', 'expected_order');
--=== Test very first hop - caching enabled ( # msts < # segs ) ===--
SELECT test_run_training('src_5segs', 0, True, False, True);
SELECT test_rotate_schedule('current_schedule');

-- Verify mst keys did not move
DROP TABLE IF EXISTS model_output_ext;
CREATE TABLE model_output_ext AS
    SELECT c.__dist_key__, o.mst_key
    FROM current_schedule c
    LEFT JOIN model_output o USING (__dist_key__);
SELECT validate_mst_key_order('model_output_ext', 'expected_order');

-- verify transition func was called correct # of times with correct params
-- This should generate an Assertion failure if the transition function was not called for
-- any __dist_key__, even if there is no model on that segment.
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            3,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,     -- dependent_var length
            32,     -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
-- Test ordinary hop - caching enabled ( # msts < # segs )
-- This should generate an Assertion failure if the transition function is
-- called for any __dist_key__ where mst_key is NULL
SELECT test_run_training('src_5segs', 1, False, False, True);

-- verify mst_keys moved to correct segments
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
DROP TABLE IF EXISTS model_output_ext;
CREATE TABLE model_output_ext AS
    SELECT c.__dist_key__, o.mst_key
    FROM current_schedule c
    LEFT JOIN model_output o USING (__dist_key__);
SELECT validate_mst_key_order('model_output_ext', 'expected_order');

-- verify transition func was called correct # of times with correct params
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,  -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,      -- dependent_var length
            0,      -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    -- WHERE clause restricts this check to segments with models
    WHERE ARRAY[__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL)
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;

SELECT test_rotate_schedule('current_schedule');
-- Test final training hop - caching enabled ( # msts < # segs )
SELECT test_run_training('src_5segs', 5, False, True, True);

-- verify mst_keys moved to correct segments
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
DROP TABLE IF EXISTS model_output_ext;
CREATE TABLE model_output_ext AS
    SELECT c.__dist_key__, o.mst_key
    FROM current_schedule c
    LEFT JOIN model_output o USING (__dist_key__);
SELECT validate_mst_key_order('model_output_ext', 'expected_order');

-- verify transition func was called correct # of times with correct params
-- This should generate an Assertion failure if the transition function was not
-- called for any __dist_key__, even if there is no model on that segment.
DROP TABLE IF EXISTS validate_params_results;
CREATE TABLE validate_params_results AS
    SELECT
        validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,      -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,   -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,      -- dependent_var length
            0,      -- independent_var length
            True    -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
DISTRIBUTED BY (__dist_key__);
SELECT assert(res = 'PASS', res) FROM validate_params_results;
-- We don't want to hide the madlib versions of these for any other
-- test files that run afterwards, so both mocks created in the test schema
-- by this file are removed again.
DROP FUNCTION madlib_installcheck_deep_learning.version();
-- The statement below is explicitly terminated, like every other statement in
-- this file, so that it does not rely on end-of-input handling to be executed.
DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
    dependent_var               BYTEA[],
    independent_var             BYTEA[],
    dependent_var_shape         INTEGER[],
    independent_var_shape       INTEGER[],
    model_architecture          TEXT,
    compile_params              TEXT,
    fit_params                  TEXT,
    dist_key                    INTEGER,
    dist_key_mapping            INTEGER[],
    current_seg_id              INTEGER,
    segments_per_host           INTEGER,
    images_per_seg              INTEGER[],
    accessible_gpus_for_seg     INTEGER[],
    serialized_weights          BYTEA,
    is_final_training_call      BOOLEAN,
    use_caching                 BOOLEAN,
    custom_function_map         BYTEA
);
>>>) -- m4_endif postgres