| m4_include(`SQLCommon.m4') |
| m4_changequote(<<<,>>>) |
| m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres |
| ,<<< |
| m4_changequote(<!,!>) |
| |
| -- =================== Setup & Initialization for FitMultiple tests ======================== |
| -- |
| -- For fit multiple, we test end-to-end functionality along with performance elsewhere. |
| -- They take a long time to run. Including similar tests here would probably not be worth |
| -- the extra time added to dev-check. |
| -- |
| -- Instead, we just want to unit test different python functions in the FitMultiple class. |
| -- However, most of the important behavior we need to test requires access to an actual |
| -- Greenplum database... mostly, we want to make sure that the models hop around to the |
| -- right segments in the right order. Therefore, the unit tests are here, as a part of |
| -- dev-check. We mock fit_transition() and some validation functions in FitMultiple, but |
| -- do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to |
| -- get through to gpdb. We also want to mock the number of segments, so we can test what |
| -- the model hopping behavior will be for a large cluster, even though dev-check should be |
| -- able to run on a single dev host. |
| |
| \i m4_regexp(MODULE_PATHNAME, |
| <!\(.*\)libmadlib\.so!>, |
| <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!> |
| ) |
| |
| -- Shadow version() inside the test schema so the InputValidator treats |
| -- madlib_installcheck_deep_learning as if it were the real madlib schema. |
| -- The mock just forwards to the real implementation. |
| CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() |
| RETURNS VARCHAR |
| LANGUAGE sql IMMUTABLE |
| AS $$ |
|     SELECT MADLIB_SCHEMA.version(); |
| $$; |
| |
| -- Build the shared FitMultipleModel object and store it in GD['fit_mult']. |
| -- Call this before anything else: the test_* and setup helpers below all read |
| -- GD['fit_mult']. Only a real source table and mst table are passed in; the |
| -- rest of the constructor params are fixed here. They can be overridden on the |
| -- object later, before test functions are called, if necessary. |
| CREATE OR REPLACE FUNCTION init_fit_mult( |
|     source_table            VARCHAR, |
|     model_selection_table   VARCHAR |
| ) RETURNS VOID AS |
| $$ |
|     import sys |
|     from mock import Mock, patch |
| |
|     PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model) |
| |
|     # The test schema stands in for the madlib schema, so the mocked UDFs |
|     # defined in it are the ones that get resolved. |
|     schema_madlib = 'madlib_installcheck_deep_learning' |
| |
|     constructor_args = ( |
|         schema_madlib, |
|         source_table, |
|         'orig_model_out', |
|         model_selection_table, |
|         1   # fifth positional arg; its meaning is defined by the constructor |
|     ) |
|     GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(*constructor_args) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); |
| |
| -- Point the shared FitMultiple object at schedule_table, drop any old copy of |
| -- that table, and build a fresh one with init_schedule_tbl(). |
| -- |
| -- Returns TRUE on success. If init_schedule_tbl() reports failure, an error is |
| -- raised so that the calling test aborts. |
| CREATE OR REPLACE FUNCTION test_init_schedule( |
|     schedule_table VARCHAR |
| ) RETURNS BOOLEAN AS |
| $$ |
|     fit_mult = GD['fit_mult'] |
|     fit_mult.schedule_tbl = schedule_table |
| |
|     plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table)) |
| |
|     # Raise instead of returning the message: a non-empty string returned from |
|     # a BOOLEAN PL/Python function is truthy, so it would come back as TRUE |
|     # and the failure would go unnoticed by callers that ignore the result. |
|     if not fit_mult.init_schedule_tbl(): |
|         plpy.error('FitMultiple.init_schedule_tbl() returned False') |
| |
|     return True |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); |
| |
| -- Advance the schedule by one hop via rotate_schedule_tbl(). |
| -- If the FitMultiple object is tracking a schedule table name other than |
| -- schedule_table, init_schedule_tbl() is called first. |
| -- NOTE(review): that branch does not assign schedule_table to |
| -- fit_mult.schedule_tbl before re-initializing; confirm this is intended. |
| CREATE OR REPLACE FUNCTION test_rotate_schedule( |
|     schedule_table VARCHAR |
| ) RETURNS VOID AS |
| $$ |
|     fit_mult = GD['fit_mult'] |
| |
|     if fit_mult.schedule_tbl != schedule_table: |
|         fit_mult.init_schedule_tbl() |
| |
|     fit_mult.rotate_schedule_tbl() |
| |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); |
| |
| -- Mock fit_transition function, for testing the |
| -- madlib_keras_fit_multiple_model() python code. |
| -- |
| -- Instead of training, each call records its arguments for the given dist_key |
| -- in GD['transition_function_params'], together with a running call count |
| -- (num_calls) and the lengths of the first dependent / independent var buffers. |
| -- It returns NULL while (first dim of the first dependent var shape) * num_calls |
| -- is below images_per_seg for the simulated segment; otherwise it marks the |
| -- entry with 'reset' and returns serialized_weights unchanged. |
| CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model( |
|     dependent_var               BYTEA[], |
|     independent_var             BYTEA[], |
|     dependent_var_shape         INTEGER[], |
|     independent_var_shape       INTEGER[], |
|     model_architecture          TEXT, |
|     compile_params              TEXT, |
|     fit_params                  TEXT, |
|     dist_key                    INTEGER, |
|     dist_key_mapping            INTEGER[], |
|     current_seg_id              INTEGER, |
|     segments_per_host           INTEGER[], |
|     images_per_seg              INTEGER[], |
|     accessible_gpus_for_seg     INTEGER[], |
|     serialized_weights          BYTEA, |
|     is_final_training_call      BOOLEAN, |
|     use_caching                 BOOLEAN, |
|     custom_function_map         BYTEA |
| ) RETURNS BYTEA AS |
| $$ |
|     # Arguments recorded verbatim; the two data buffers are handled separately |
|     param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping', |
|                    'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call', |
|                    'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params', |
|                    'independent_var_shape', 'use_caching' ] |
| |
|     recorded = GD.setdefault('transition_function_params', dict()) |
| |
|     # Continue the call count for this dist_key, unless the previous call was |
|     # marked 'reset' (or there was none), in which case start again at 1. |
|     previous = recorded.get(dist_key) |
|     if previous is not None and 'reset' not in previous: |
|         num_calls = previous['num_calls'] + 1 |
|     else: |
|         num_calls = 1 |
| |
|     # PL/Python exposes the function arguments as global variables |
|     g = globals() |
|     params = dict((k, g[k]) for k in param_keys) |
| |
|     # Record only the length of the first buffer. A NULL array, or a NULL |
|     # first element, counts as length 0 rather than raising a TypeError. |
|     params['dependent_var'] = len(dependent_var[0]) if dependent_var and dependent_var[0] else 0 |
|     params['independent_var'] = len(independent_var[0]) if independent_var and independent_var[0] else 0 |
|     params['num_calls'] = num_calls |
| |
|     recorded[dist_key] = params |
| |
|     # compute simulated seg_id ( current_seg_id is the actual seg_id ) |
|     seg_id = dist_key_mapping.index( dist_key ) |
| |
|     # A NULL shape array is treated like a missing first shape |
|     first_shape = dependent_var_shape[0] if dependent_var_shape else None |
|     if first_shape and first_shape[0] * num_calls < images_per_seg[seg_id]: |
|         return None |
| |
|     recorded[dist_key]['reset'] = True |
|     return serialized_weights |
| $$ LANGUAGE plpythonu VOLATILE; |
| |
| -- Compare what the mocked transition function recorded for expected_dist_key |
| -- (in GD['transition_function_params']) against the expected values. |
| -- |
| -- Returns 'PASS' when every compared param matches; otherwise returns a message |
| -- naming the first mismatching param found, or saying that the transition |
| -- function was never called for this segment / dist key. |
| -- |
| -- NOTE(review): images_per_seg and expected_dist_key_mapping are accepted but |
| -- are not part of the comparison below. |
| CREATE OR REPLACE FUNCTION validate_transition_function_params( |
|     current_seg_id                      INTEGER, |
|     segments_per_host                   INTEGER[], |
|     images_per_seg                      INTEGER[], |
|     expected_num_calls                  INTEGER, |
|     expected_dist_key                   INTEGER, |
|     expected_is_final_training_call     BOOLEAN, |
|     expected_dist_key_mapping           INTEGER[], |
|     dependent_var_len                   INTEGER, |
|     independent_var_len                 INTEGER, |
|     use_caching                         BOOLEAN |
| ) RETURNS TEXT AS |
| $$ |
|     not_called_msg = "transition function was not called on segment {}".format(current_seg_id) |
| |
|     # Nothing recorded at all, or nothing recorded for this dist key |
|     if 'transition_function_params' not in GD: |
|         return not_called_msg |
|     recorded = GD['transition_function_params'] |
|     if expected_dist_key not in recorded: |
|         return not_called_msg + " for __dist_key__ = {}".format(expected_dist_key) |
| |
|     actual = recorded[expected_dist_key] |
| |
|     mismatch_msg = """Incorrect value for {} param passed to fit_transition_multiple_model: |
| Actual={}, Expected={}""" |
| |
|     # recorded key -> value it is expected to hold |
|     validation_map = { |
|         'current_seg_id' : current_seg_id, |
|         'segments_per_host' : segments_per_host, |
|         'num_calls' : expected_num_calls, |
|         'is_final_training_call' : expected_is_final_training_call, |
|         'dist_key' : expected_dist_key, |
|         'dependent_var' : dependent_var_len, |
|         'independent_var' : independent_var_len, |
|         'use_caching' : use_caching |
|     } |
| |
|     for param in validation_map: |
|         expected = validation_map[param] |
|         if actual[param] == expected: |
|             continue |
|         return mismatch_msg.format(param, actual[param], expected) |
| |
|     return 'PASS'  # actual params match expected params |
| $$ LANGUAGE plpythonu VOLATILE; |
| |
| -- Helper to rotate an array of ints one position to the right: |
| -- the last element moves to the front. |
| CREATE OR REPLACE FUNCTION rotate_keys( |
|     keys INTEGER[] |
| ) RETURNS INTEGER[] |
| AS $$ |
|     last_key = keys[-1:] |
|     leading_keys = keys[:-1] |
|     return last_key + leading_keys |
| $$ LANGUAGE plpythonu IMMUTABLE; |
| |
| -- Inverse of rotate_keys(): rotate an array of ints one position to the left, |
| -- so the first element moves to the back. |
| CREATE OR REPLACE FUNCTION reverse_rotate_keys( |
|     keys INTEGER[] |
| ) RETURNS INTEGER[] |
| AS $$ |
|     trailing_keys = keys[1:] |
|     first_key = keys[:1] |
|     return trailing_keys + first_key |
| $$ LANGUAGE plpythonu IMMUTABLE; |
| |
| -- Set the model input / model output / cached source table names on the shared |
| -- FitMultiple object, drop any old output and cached source tables, and |
| -- re-create the model output table with init_model_output_tbl(). |
| -- The weights, arch and params of every model are then replaced with tiny |
| -- values derived from mst_key, to keep the tests fast. |
| -- Declared as RETURNS TEXT, but nothing is returned (the result is NULL). |
| CREATE OR REPLACE FUNCTION setup_model_tables( |
|     input_table             TEXT, |
|     output_table            TEXT, |
|     cached_source_table     TEXT |
| ) RETURNS TEXT AS |
| $$ |
|     fit_mult = GD['fit_mult'] |
| |
|     fit_mult.model_input_tbl = input_table |
|     fit_mult.model_output_tbl = output_table |
|     fit_mult.cached_source_table = cached_source_table |
| |
|     # Start from a clean slate before the output table is rebuilt |
|     for stale_tbl in (output_table, cached_source_table): |
|         plpy.execute('DROP TABLE IF EXISTS {}'.format(stale_tbl)) |
|     fit_mult.init_model_output_tbl() |
| |
|     shrink_models_sql = """ |
| UPDATE {model_out} -- Reduce size of model for faster tests |
| SET ( model_weights, model_arch, compile_params, fit_params ) |
| = ( mst_key::TEXT::BYTEA, |
| ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON, |
| 'c' || mst_key::TEXT, |
| 'f' || mst_key::TEXT |
| ) |
| WHERE mst_key IS NOT NULL; |
| """.format(model_out=fit_mult.model_output_tbl) |
|     plpy.execute(shrink_models_sql) |
| $$ LANGUAGE plpythonu VOLATILE; |
| |
| -- Re-key src_table so its buffers are spread over num_data_segs simulated |
| -- segments, then sync the shared FitMultiple object with the new layout. |
| -- num_data_segs can be larger than the actual number of segments, since this |
| -- is just a simulation. The function also (re)creates |
| -- expected_distkey_mappings_tbl, which later checks use to validate dist key |
| -- mappings and images per segment. |
| CREATE OR REPLACE FUNCTION update_dist_keys( |
|     src_table                       TEXT, |
|     num_data_segs                   INTEGER, |
|     num_models                      INTEGER, |
|     expected_distkey_mappings_tbl   TEXT |
| ) RETURNS VOID AS |
| $$ |
|     # Assign each buffer to a simulated segment: buffer_id modulo num_data_segs |
|     plpy.execute(""" |
| UPDATE {src_table} |
| SET __dist_key__ = (buffer_id % {num_data_segs}) |
| """.format(src_table=src_table, num_data_segs=num_data_segs)) |
| |
|     fit_mult = GD['fit_mult'] |
| |
|     # One row per dist key, in dist key order, with its total image count |
|     per_key_rows = plpy.execute(""" |
| SELECT SUM(attributes_shape[1]) AS image_count, |
| __dist_key__ |
| FROM {src_table} |
| GROUP BY __dist_key__ |
| ORDER BY __dist_key__ |
| """.format(src_table=src_table)) |
| |
|     image_counts = [ int(row['image_count']) for row in per_key_rows ] |
|     dist_keys = [ int(row['__dist_key__']) for row in per_key_rows ] |
|     num_dist_keys = len(dist_keys) |
| |
|     fit_mult.source_table = src_table |
|     fit_mult.max_dist_key = sorted(dist_keys)[-1] |
|     fit_mult.images_per_seg_train = image_counts |
|     # Both attributes refer to the same list object |
|     fit_mult.dist_keys = dist_keys |
|     fit_mult.dist_key_mapping = dist_keys |
|     fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys |
|     data_distribution_per_seg = [num_data_segs] * num_dist_keys |
|     fit_mult.segments_per_host = data_distribution_per_seg |
| |
|     # Fewer models than dist keys: pad the schedule with empty (None) slots |
|     scheduled_msts = fit_mult.msts[:num_models] |
|     if num_models < num_dist_keys: |
|         scheduled_msts += [None] * (num_dist_keys - num_models) |
|     fit_mult.msts_for_schedule = scheduled_msts |
|     fit_mult.all_mst_keys = [ str(entry['mst_key']) if entry else 'NULL' |
|                               for entry in fit_mult.msts_for_schedule ] |
|     fit_mult.num_msts = num_models |
| |
|     # More models than dist keys: add extra keys past the current maximum |
|     fit_mult.extra_dist_keys = [ fit_mult.max_dist_key + 1 + offset |
|                                  for offset in range(num_models - num_dist_keys) ] |
|     fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys |
| |
|     # The python lists are rendered as [a, b, ...], which forms valid |
|     # ARRAY[a, b, ...] literals once prefixed with ARRAY in the SQL below. |
|     create_distkey_map_tbl_cmd = """ |
| DROP TABLE IF EXISTS {exp_table}; |
| CREATE TABLE {exp_table} AS |
| SELECT |
| ARRAY( -- map of dist_keys to seg_ids from source table |
| SELECT __dist_key__ |
| FROM {fm.source_table} |
| GROUP BY __dist_key__ |
| ORDER BY __dist_key__ -- This would be gp_segment_id if it weren't a simulation |
| ) AS expected_dist_key_mapping, |
| ARRAY{fm.images_per_seg_train} AS expected_images_per_seg, |
| ARRAY{data_distribution_per_seg} AS segments_per_host, |
| __dist_key__ |
| FROM {fm.source_table} |
| GROUP BY __dist_key__ |
| DISTRIBUTED BY (__dist_key__); |
| """.format( |
|         exp_table=expected_distkey_mappings_tbl, |
|         data_distribution_per_seg=data_distribution_per_seg, |
|         fm=fit_mult |
|     ) |
|     plpy.execute(create_distkey_map_tbl_cmd) |
| $$ LANGUAGE plpythonu VOLATILE; |
| |
| -- Run a single training hop through FitMultiple.run_training(), with the |
| -- object's state overridden by the given arguments. |
| -- |
| --   source_table            assigned to fit_mult.source_tbl |
| --   hop                     hop number passed to run_training() |
| --   is_very_first_hop       set on the object and passed to run_training() |
| --   is_final_training_call  set on the object |
| --   use_caching             set on the object |
| -- |
| -- NOTE(review): this sets fit_mult.source_tbl, while update_dist_keys() sets |
| -- fit_mult.source_table; confirm which attribute run_training() reads. |
| CREATE OR REPLACE FUNCTION test_run_training( |
|     source_table            TEXT, |
|     hop                     INTEGER, |
|     is_very_first_hop       BOOLEAN, |
|     is_final_training_call  BOOLEAN, |
|     use_caching             BOOLEAN |
| ) RETURNS VOID AS |
| $$ |
|     fit_mult = GD['fit_mult'] |
| |
|     # Each time we start a new test, clear out stats |
|     # like num_calls from GD so we don't end up validating |
|     # against old results |
|     if 'transition_function_params' in GD: |
|         del GD['transition_function_params'] |
| |
|     fit_mult.source_tbl = source_table |
|     fit_mult.is_very_first_hop = is_very_first_hop |
|     fit_mult.is_final_training_call = is_final_training_call |
|     if use_caching != fit_mult.use_caching: |
|         fit_mult.udf_plan = None  # Otherwise it will execute the wrong |
|                                   # query when use_caching changes! |
|         fit_mult.use_caching = use_caching |
| |
|     fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop) |
| $$ LANGUAGE plpythonu VOLATILE |
| m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); |
| |
| -- Assert that the mst_keys in output_tbl, read in __dist_key__ order, are |
| -- exactly the mst_keys array stored in expected_tbl. |
| -- Both table names are concatenated into dynamic SQL unquoted, so they must |
| -- be trusted identifiers supplied by the test script. |
| CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT) |
| RETURNS VOID AS |
| $$ |
| DECLARE |
|     actual INTEGER[]; |
|     expected INTEGER[]; |
| BEGIN |
|     EXECUTE 'SELECT ARRAY(' || |
|             'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)' |
|     INTO actual; |
| |
|     EXECUTE 'SELECT mst_keys FROM ' || expected_tbl |
|     INTO expected; |
| |
|     -- Null-safe comparison: with a plain "=", a missing or NULL expected row |
|     -- makes both the condition and the message NULL, so the check could not |
|     -- fail. COALESCE keeps the message readable in that case. |
|     PERFORM assert( |
|         actual IS NOT DISTINCT FROM expected, |
|         'mst keys found in wrong order / wrong segments!' || |
|         E'\nActual: ' || COALESCE(actual::text, 'NULL') || |
|         E'\nExpected: ' || COALESCE(expected::text, 'NULL') |
|     ); |
| END |
| $$ LANGUAGE PLpgSQL VOLATILE; |
| |
| -- Create the mst table: one model arch id, three compile-param strings and |
| -- two fit-param strings (presumably 1 x 3 x 2 = 6 model configs, matching the |
| -- 6 models simulated below). |
| DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary; |
| SELECT load_model_selection_table( |
|     'iris_model_arch', |
|     'iris_mst_table', |
|     ARRAY[1], |
|     ARRAY[ |
|         $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$, |
|         $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$, |
|         $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$ |
|     ], |
|     ARRAY[ |
|         $$batch_size=5,epochs=1$$, |
|         $$batch_size=10,epochs=1$$ |
|     ] |
| ); |
| |
| -- Create the FitMultiple object that all test functions operate on |
| SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table'); |
| |
| -- Private copy of the packed source data, so its dist keys can be rewritten |
| CREATE TABLE src_3segs AS |
|     SELECT * |
|     FROM iris_data_15buf_packed |
| DISTRIBUTED BY (__dist_key__); |
| |
| -- Simulate 6 models on 3 segments -- |
| SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings'); |
| |
| -- ==================================================================== |
| -- =========== Enough setup, now for the actual tests! =============== |
| -- ==================================================================== |
| |
| --=== Test init_schedule_tbl() ===-- |
| SELECT test_init_schedule('current_schedule'); |
| |
| -- The FULL JOIN leaves a NULL on one side for any key found in only one table |
| SELECT assert( |
|     s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL, |
|     'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table' |
| ) |
| FROM current_schedule AS s |
| FULL JOIN iris_mst_table AS m USING (mst_key); |
| |
| -- Save order of mst keys in schedule for tracking |
| DROP TABLE IF EXISTS expected_order; |
| CREATE TABLE expected_order AS |
|     SELECT ARRAY( |
|         SELECT mst_key |
|         FROM current_schedule |
|         ORDER BY __dist_key__ |
|     ) AS mst_keys; |
| |
| --=== Test rotate_schedule() ===-- |
| SELECT test_rotate_schedule('current_schedule'); |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('current_schedule', 'expected_order'); |
| -- Undo the rotation of the expected order; it is reused by the checks below |
| UPDATE expected_order |
|     SET mst_keys = reverse_rotate_keys(mst_keys); |
| |
| -- Initialize model_output table, and set model_input & cached_src table names |
| SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- Argument order of the run_training test function, for reference below: |
| -- |
| --   test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching) |
| |
| --=== Test first hop of an iteration - no caching (# msts > # segs) ===-- |
| SELECT test_run_training('src_3segs', 0, FALSE, FALSE, FALSE); |
| |
| -- mst_keys should not have moved |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             ARRAY[3, 3, 3],                 -- segments_per_host, hard-coded here |
|             s.expected_images_per_seg, |
|             5,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             FALSE                           -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Test an ordinary hop - no caching (# msts > # segs) ===-- |
| SELECT test_run_training('src_3segs', 1, FALSE, FALSE, FALSE); |
| SELECT test_rotate_schedule('current_schedule'); |
| |
| -- check that mst keys rotated onto correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             5,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             FALSE                           -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Test final training hop - no caching (# msts > # segs) ===-- |
| SELECT test_run_training('src_3segs', 8, FALSE, TRUE, FALSE); |
| |
| -- check that mst keys rotated onto correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             5,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             TRUE,                           -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             FALSE                           -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Test very first hop - caching enabled ( # msts > # segs ) ===-- |
| SELECT test_run_training('src_3segs', 0, TRUE, FALSE, TRUE); |
| |
| -- mst_keys should not have moved |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             5,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| -- validate that cached source table was created with proper dist keys: |
| -- the FULL JOIN leaves a NULL on one side for any key found in only one table |
| SELECT assert( |
|     c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL, |
|     'cached src table was not created or dist keys do not match original src table' |
| ) |
| FROM cached_src AS c |
| FULL JOIN ( |
|     SELECT __dist_key__ |
|     FROM src_3segs |
|     GROUP BY __dist_key__ |
| ) AS s USING (__dist_key__); |
| |
| --=== Test ordinary hop - caching enabled ( # msts > # segs ) ===-- |
| SELECT test_run_training('src_3segs', 7, FALSE, FALSE, TRUE); |
| |
| -- check that mst keys rotated onto correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             1,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             0,                              -- dependent_var length |
|             0,                              -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Test final training hop - caching enabled ( # msts > # segs ) ===-- |
| SELECT test_run_training('src_3segs', 2, FALSE, TRUE, TRUE); |
| |
| -- check that mst keys rotated onto correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- independent_var & dependent_var should have both been passed as NULL, |
| -- hence the expected lengths of 0 |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             1,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             TRUE,                           -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             0,                              -- dependent_var length |
|             0,                              -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Simulate 3 models on 3 segments ===-- |
| SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings'); |
| -- NOTE(review): expected_order still holds the keys saved from the previous |
| -- schedule at this point, so this removes every mst_key listed there; confirm |
| -- that is the intended set of rows to delete. |
| DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order); |
| SELECT test_init_schedule('current_schedule'); |
| |
| -- E'' strings are used so that \n is a real newline in the failure message, |
| -- whatever the standard_conforming_strings setting is. |
| SELECT assert( |
|     COUNT(*) = 3, |
|     E'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' || |
|     E'Expected: 3\nActual: ' || COUNT(*)::TEXT |
| ) FROM current_schedule; |
| -- Make sure none of the entries in the schedule table are NULL |
| -- ( this should only happen for # msts < # segs case ) |
| SELECT assert( |
|     COUNT(*) = 0, |
|     'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs' |
| ) FROM current_schedule WHERE mst_key IS NULL; |
| |
| -- Save new order of mst keys in schedule for tracking |
| DROP TABLE IF EXISTS expected_order; |
| CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys; |
| |
| -- Rebuild model_output for the new schedule and check its initial key order |
| SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); |
| |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| -- Advance the schedule so the hop below moves the models |
| SELECT test_rotate_schedule('current_schedule'); |
| |
| --=== Test ordinary hop - no caching ( # msts = # segs ) ===-- |
| SELECT test_run_training('src_3segs', 2, FALSE, FALSE, FALSE); |
| |
| -- check that mst keys rotated onto correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| SELECT validate_mst_key_order('model_output', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             5,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             FALSE                           -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Simulate 3 models on 5 segments ( # msts < # segs ) ===-- |
| -- ( by updating dist keys in source table ) |
| CREATE TABLE src_5segs AS |
| SELECT * FROM iris_data_15buf_packed |
| DISTRIBUTED BY (__dist_key__); |
| |
| SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings'); |
| SELECT test_init_schedule('current_schedule'); |
| |
| -- With 3 models on 5 dist keys, 2 schedule slots must be empty (NULL mst_key). |
| -- E'' strings are used so that \n is a real newline in the failure message, |
| -- whatever the standard_conforming_strings setting is. |
| SELECT assert( |
|     COUNT(*) = 2, |
|     E'Wrong number NULL entries in schedule table created by test_init_schedule()\n' || |
|     E'Expected: 2\nActual: ' || COUNT(*)::TEXT |
| ) FROM current_schedule WHERE mst_key IS NULL; |
| |
| SELECT assert( |
|     COUNT(*) = 3, |
|     E'Wrong number of non-NULL entries in schedule table created by test_init_schedule()\n' || |
|     E'Expected: 3\nActual: ' || COUNT(*)::TEXT |
| ) FROM current_schedule WHERE mst_key IS NOT NULL; |
| |
| -- Save expected mst_key order |
| DROP TABLE IF EXISTS expected_order; |
| CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys; |
| |
| -- Initialize model_output table, and set model_input & cached_src table names |
| SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); |
| |
| -- Extend model_output to one row per scheduled dist key: the LEFT JOIN yields a |
| -- NULL mst_key for any dist key that has no matching row in model_output. |
| DROP TABLE IF EXISTS model_output_ext; |
| CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__); |
| -- Make sure model_output was created with correct mst_key order |
| SELECT validate_mst_key_order('model_output_ext', 'expected_order'); |
| |
| --=== Test very first hop - caching enabled ( # msts < # segs ) ===-- |
| SELECT test_run_training('src_5segs', 0, TRUE, FALSE, TRUE); |
| SELECT test_rotate_schedule('current_schedule'); |
| |
| -- Verify mst keys did not move |
| DROP TABLE IF EXISTS model_output_ext; |
| CREATE TABLE model_output_ext AS |
|     SELECT |
|         c.__dist_key__, |
|         o.mst_key |
|     FROM current_schedule AS c |
|     LEFT JOIN model_output AS o USING (__dist_key__); |
| SELECT validate_mst_key_order('model_output_ext', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| -- This should generate an Assertion failure if the transition function was not |
| -- called for any __dist_key__, even if there is no model on that segment. |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             3,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             12,                             -- dependent_var length |
|             32,                             -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| --=== Test ordinary hop - caching enabled ( # msts < # segs ) ===-- |
| -- This should generate an Assertion failure if the transition function is |
| -- called for any __dist_key__ where mst_key is NULL |
| SELECT test_run_training('src_5segs', 1, FALSE, FALSE, TRUE); |
| |
| -- verify mst_keys moved to correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| DROP TABLE IF EXISTS model_output_ext; |
| CREATE TABLE model_output_ext AS |
|     SELECT |
|         c.__dist_key__, |
|         o.mst_key |
|     FROM current_schedule AS c |
|     LEFT JOIN model_output AS o USING (__dist_key__); |
| SELECT validate_mst_key_order('model_output_ext', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             1,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             FALSE,                          -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             0,                              -- dependent_var length |
|             0,                              -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
|     -- WHERE clause restricts this check to segments with models |
|     WHERE ARRAY[s.__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL) |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| SELECT test_rotate_schedule('current_schedule'); |
| |
| --=== Test final training hop - caching enabled ( # msts < # segs ) ===-- |
| SELECT test_run_training('src_5segs', 5, FALSE, TRUE, TRUE); |
| |
| -- verify mst_keys moved to correct segments |
| UPDATE expected_order |
|     SET mst_keys = rotate_keys(mst_keys); |
| DROP TABLE IF EXISTS model_output_ext; |
| CREATE TABLE model_output_ext AS |
|     SELECT |
|         c.__dist_key__, |
|         o.mst_key |
|     FROM current_schedule AS c |
|     LEFT JOIN model_output AS o USING (__dist_key__); |
| SELECT validate_mst_key_order('model_output_ext', 'expected_order'); |
| |
| -- verify transition func was called correct # of times with correct params |
| -- This should generate an Assertion failure if the transition function was not |
| -- called for any __dist_key__, even if there is no model on that segment. |
| DROP TABLE IF EXISTS validate_params_results; |
| CREATE TABLE validate_params_results AS |
|     SELECT |
|         validate_transition_function_params( |
|             s.gp_segment_id, |
|             s.segments_per_host, |
|             s.expected_images_per_seg, |
|             1,                              -- expected num_calls (per dist_key) |
|             s.__dist_key__, |
|             TRUE,                           -- expected is_final_training_call |
|             s.expected_dist_key_mapping, |
|             0,                              -- dependent_var length |
|             0,                              -- independent_var length |
|             TRUE                            -- use_caching |
|         ) AS res, |
|         s.__dist_key__ |
|     FROM expected_dist_key_mappings AS s |
| DISTRIBUTED BY (__dist_key__); |
| SELECT assert(res = 'PASS', res) |
| FROM validate_params_results; |
| |
| -- Remove the mocks created above: we don't want to hide the madlib versions |
| -- of these functions from any other test files that run afterwards. |
| DROP FUNCTION madlib_installcheck_deep_learning.version(); |
| -- The statement is terminated explicitly instead of relying on end-of-file |
| -- to flush it. |
| DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model( |
|     dependent_var               BYTEA[], |
|     independent_var             BYTEA[], |
|     dependent_var_shape         INTEGER[], |
|     independent_var_shape       INTEGER[], |
|     model_architecture          TEXT, |
|     compile_params              TEXT, |
|     fit_params                  TEXT, |
|     dist_key                    INTEGER, |
|     dist_key_mapping            INTEGER[], |
|     current_seg_id              INTEGER, |
|     segments_per_host           INTEGER[], |
|     images_per_seg              INTEGER[], |
|     accessible_gpus_for_seg     INTEGER[], |
|     serialized_weights          BYTEA, |
|     is_final_training_call      BOOLEAN, |
|     use_caching                 BOOLEAN, |
|     custom_function_map         BYTEA |
| ); |
| |
| >>>) -- m4_endif postgres |