DL: Add unit tests for model hopping in FitMultiple
This is a set of "unit tests" implemented as a dev-check test, since
they need DB access to work.
Functions unit tested in deep_learning/test/madlib_keras_fit_multiple.sql_in:
FitMultiple.init_schedule_tbl()
FitMultiple.rotate_schedule_tbl()
FitMultiple.run_training()
- Tests schedule creation and rotation
- Tests that models are hopping to right segments in right order
- Tests caching, and the cases # msts > # segs, # msts = # segs, and # msts < # segs
Uses simulated segments instead of actual segments
so the same test can be run on any size cluster.
- 6 models on 3 segments
- 3 models on 3 segments
- 3 models on 5 segments
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
new file mode 100644
index 0000000..a7bd3fc
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
@@ -0,0 +1,786 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+-- For fit multiple, we test end-to-end functionality along with performance elsewhere.
+-- They take a long time to run. Including similar tests here would probably not be worth
+-- the extra time added to dev-check.
+--
+-- Instead, we just want to unit test different python functions in the FitMultiple class.
+-- However, most of the important behavior we need to test requires access to an actual
+-- Greenplum database... mostly, we want to make sure that the models hop around to the
+-- right segments in the right order. Therefore, the unit tests are here, as a part of
+-- dev-check. We mock fit_transition() and some validation functions in FitMultiple, but
+-- do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+-- get through to gpdb. We also want to mock the number of segments, so we can test what
+-- the model hopping behavior will be for a large cluster, even though dev-check should be
+-- able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+ <!\(.*\)libmadlib\.so!>,
+ <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() in the test schema (it forwards to the installed MADlib version()) to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+ SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+-- are filled in. They can be overridden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    # Builds one FitMultipleModel for the test schema and keeps it in GD['fit_mult'] for the other helpers
+    import sys
+    from mock import Mock, patch
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    constructor_args = (
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1,  # presumably num_iterations -- TODO confirm against FitMultipleModel
+    )
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(*constructor_args)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    # Drops schedule_table and rebuilds it via FitMultiple.init_schedule_tbl().
+    # Returns True on success; raises on failure so dev-check cannot silently pass.
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if not fit_mult.init_schedule_tbl():
+        # A text message returned from a BOOLEAN function is never reported as a failure
+        plpy.error('FitMultiple.init_schedule_tbl() returned False')
+    return True
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table VARCHAR
+) RETURNS VOID AS
+$$
+    # Rotates the schedule table once via FitMultiple.rotate_schedule_tbl().
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.schedule_tbl = schedule_table  # init the requested table, not the stale one
+        fit_mult.init_schedule_tbl()
+    fit_mult.rotate_schedule_tbl()
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+-- madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    # Test double for the transition function: for each dist_key it records the
+    # arguments of the latest call, plus a running call count, under
+    # GD['transition_function_params'] so they can be validated afterwards.
+    recorded_names = (
+        'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape',
+        'dist_key_mapping', 'current_seg_id', 'segments_per_host',
+        'custom_function_map', 'is_final_training_call', 'dist_key',
+        'serialized_weights', 'images_per_seg', 'model_architecture',
+        'fit_params', 'independent_var_shape', 'use_caching',
+    )
+
+    # Continue counting from the previous call for this dist_key, unless there
+    # was none or it was flagged 'reset' (see the end of this function).
+    previous = GD.get('transition_function_params', dict()).get(dist_key)
+    if previous is not None and 'reset' not in previous:
+        num_calls = previous['num_calls'] + 1
+    else:
+        num_calls = 1
+
+    # The function arguments are looked up by name among the globals
+    arg_values = globals()
+    params = dict((name, arg_values[name]) for name in recorded_names)
+    # For the two buffers only the length is recorded (0 when NULL or empty)
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+    GD.setdefault('transition_function_params', dict())[dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg[seg_id]:
+        return None  # count so far is still below images_per_seg for this simulated segment
+    params['reset'] = True  # params is the very dict that was just stored in GD
+    return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                  INTEGER,
+    segments_per_host               INTEGER,
+    images_per_seg                  INTEGER[],
+    expected_num_calls              INTEGER,
+    expected_dist_key               INTEGER,
+    expected_is_final_training_call BOOLEAN,
+    expected_dist_key_mapping       INTEGER[],
+    dependent_var_len               INTEGER,
+    independent_var_len             INTEGER,
+    use_caching                     BOOLEAN
+) RETURNS TEXT AS
+$$
+    # Compares what was recorded in GD['transition_function_params'] for
+    # expected_dist_key with the expected values. Returns 'PASS', or a message
+    # describing the first mismatch (or that no call was recorded at all).
+    # NOTE(review): images_per_seg and expected_dist_key_mapping are accepted but never compared.
+    not_called_msg = "transition function was not called on segment {}".format(current_seg_id)
+
+    recorded_calls = GD.get('transition_function_params')
+    if recorded_calls is None:
+        return not_called_msg
+    if expected_dist_key not in recorded_calls:
+        return not_called_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = recorded_calls[expected_dist_key]
+
+    mismatch_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+            Actual={}, Expected={}"""
+
+    # A list of pairs rather than a dict, so mismatches are always reported in this order
+    expectations = [
+        ('current_seg_id', current_seg_id),
+        ('segments_per_host', segments_per_host),
+        ('num_calls', expected_num_calls),
+        ('is_final_training_call', expected_is_final_training_call),
+        ('dist_key', expected_dist_key),
+        ('dependent_var', dependent_var_len),
+        ('independent_var', independent_var_len),
+        ('use_caching', use_caching),
+    ]
+    for param, expected in expectations:
+        if actual[param] != expected:
+            return mismatch_msg.format(param, actual[param], expected)
+    return 'PASS' # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys INTEGER[]
+) RETURNS INTEGER[] AS $$
+    # Rotate right by one: the last element wraps around to the front.
+    return [keys[pos - 1] for pos in range(len(keys))]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys INTEGER[]
+) RETURNS INTEGER[] AS $$
+    # Rotate left by one: the first element wraps around to the end.
+    return [keys[(pos + 1) % len(keys)] for pos in range(len(keys))]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table         TEXT,
+    output_table        TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$
+    # Sets the three table names on FitMultiple, drops stale tables, calls init_model_output_tbl(),
+    # then overwrites every model with tiny mst_key-derived values. Returns nothing (NULL).
+    fit_mult = GD['fit_mult']
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    for stale_tbl in (output_table, cached_source_table):
+        plpy.execute('DROP TABLE IF EXISTS {}'.format(stale_tbl))
+    fit_mult.init_model_output_tbl()
+    plpy.execute("""
+        UPDATE {model_out} -- Reduce size of model for faster tests
+        SET ( model_weights, model_arch, compile_params, fit_params )
+        = ( mst_key::TEXT::BYTEA,
+        ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+        'c' || mst_key::TEXT,
+        'f' || mst_key::TEXT
+        )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl))
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+-- num_data_segs can be larger than actual number of segments, since this
+-- is just for simulated testing. This will also write to expected_distkey_mappings_tbl
+-- which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table                     TEXT,
+    num_data_segs                 INTEGER,
+    num_models                    INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Step 1: spread the buffers of src_table over num_data_segs simulated segments
+    plpy.execute("""
+        UPDATE {src_table}
+        SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(src_table=src_table, num_data_segs=num_data_segs))
+
+    # Step 2: sum independent_var_shape[1] per dist key, in dist key order
+    per_key_rows = plpy.execute("""
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(src_table=src_table))
+    images_per_seg, dist_keys = [], []
+    for row in per_key_rows:
+        images_per_seg.append(int(row['image_count']))
+        dist_keys.append(int(row['__dist_key__']))
+    num_dist_keys = len(dist_keys)
+
+    # Step 3: overwrite the FitMultiple attributes that describe the segment layout
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    # Step 4: schedule the first num_models msts, padded with None up to one
+    # entry per dist key when there are fewer models than dist keys
+    scheduled_msts = fit_mult.msts[:num_models]
+    if num_dist_keys > num_models:
+        scheduled_msts += [None] * (num_dist_keys - num_models)
+    fit_mult.msts_for_schedule = scheduled_msts
+    fit_mult.all_mst_keys = [ 'NULL' if not mst else str(mst['mst_key'])
+                              for mst in scheduled_msts ]
+    fit_mult.num_msts = num_models
+
+    # Step 5: one extra dist key, numbered after max_dist_key, per surplus model
+    fit_mult.extra_dist_keys = [ fit_mult.max_dist_key + 1 + surplus
+                                 for surplus in range(num_models - num_dist_keys) ]
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    # Step 6: record the expected mapping, image counts and segments_per_host per dist key
+    plpy.execute("""
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY( -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__ -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(fm=fit_mult, num_data_segs=num_data_segs, exp_table=expected_distkey_mappings_tbl))
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table            TEXT,
+    hop                     INTEGER,
+    is_very_first_hop       BOOLEAN,
+    is_final_training_call  BOOLEAN,
+    use_caching             BOOLEAN
+) RETURNS VOID AS
+$$
+    # Runs one FitMultiple.run_training() hop with the given flags, after
+    # clearing what the mock transition function recorded for earlier hops.
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats like num_calls from GD
+    # so we don't end up validating against old results
+    GD.pop('transition_function_params', None)
+
+    fit_mult.source_tbl = source_table
+    fit_mult.source_table = source_table  # the attribute name update_dist_keys() sets and reads
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None # Otherwise it will execute the wrong
+                                 # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual   INTEGER[];  -- mst_key values of output_tbl, ordered by __dist_key__
+    expected INTEGER[];  -- mst_keys array stored in expected_tbl
+BEGIN
+    EXECUTE 'SELECT ARRAY(' ||
+            'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)'
+    INTO actual;
+    EXECUTE 'SELECT mst_keys FROM ' || expected_tbl
+    INTO expected;
+
+    -- COALESCE makes a NULL comparison (e.g. no row in expected_tbl) count as a failure
+    PERFORM assert(
+        COALESCE(actual = expected, FALSE),
+        'mst keys found in wrong order / wrong segments!' ||
+        E'\nActual: ' || COALESCE(actual::text, 'NULL') ||
+        E'\nExpected: ' || COALESCE(expected::text, 'NULL')
+    );
+END
+$$ LANGUAGE PLpgSQL VOLATILE;
+
+-- Create mst table
+DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch',
+ 'iris_mst_table',
+ ARRAY[1],
+ ARRAY[
+ $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=5,epochs=1$$,
+ $$batch_size=10,epochs=1$$
+ ]
+);
+
+-- Create FitMultiple object for running test functions
+SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');
+
+CREATE TABLE src_3segs AS
+ SELECT * FROM iris_data_15buf_packed
+ DISTRIBUTED BY (__dist_key__);
+
+-- Simulate 6 models on 3 segments --
+SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
+
+-- ====================================================================
+-- =========== Enough setup, now for the actual tests! ===============
+-- ====================================================================
+--=== Test init_schedule_tbl() ===--
+
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+ s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
+ 'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
+) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);
+
+-- Save order of mst keys in schedule for tracking
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+--=== Test rotate_schedule() ===--
+SELECT test_rotate_schedule('current_schedule');
+UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+SELECT validate_mst_key_order('current_schedule', 'expected_order');
+UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys); -- Undo for later
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+SELECT validate_mst_key_order('model_output', 'expected_order');
+
+-- Order of params in run_training test function (for reference below):
+--
+-- test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)
+
+--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 0, False, False, False);
+
+ -- mst_keys should not have moved
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+     SELECT validate_transition_function_params(
+         s.gp_segment_id,
+         s.segments_per_host, -- read from the mapping table instead of a hard-coded 3
+         s.expected_images_per_seg,
+         5, -- expected num_calls (per dist_key)
+         s.__dist_key__,
+         False, -- expected is_final_training_call
+         s.expected_dist_key_mapping,
+         12, -- dependent_var length
+         32, -- independent_var length
+         False -- use_caching
+     ) AS res,
+     s.__dist_key__
+     FROM expected_dist_key_mappings s
+     DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test an ordinary hop - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 1, False, False, False);
+SELECT test_rotate_schedule('current_schedule');
+
+ -- check that mst keys rotated onto correct segments
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 5, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 12, -- dependent_var length
+ 32, -- independent_var length
+ False -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test final training hop - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 8, False, True, False);
+
+ -- check that mst keys rotated onto correct segments
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 5, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ True, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 12, -- dependent_var length
+ 32, -- independent_var length
+ False -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test very first hop - caching enabled ( # msts > # segs ) ===--
+SELECT test_run_training('src_3segs', 0, True, False, True);
+
+ -- mst_keys should not have moved
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 5, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 12, -- dependent_var length
+ 32, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+ -- validate that cached source table was created with proper dist keys
+ SELECT assert(
+ c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL,
+ 'cached src table was not created or dist keys do not match original src table')
+ FROM cached_src c FULL JOIN (SELECT __dist_key__ FROM src_3segs GROUP BY __dist_key__) s USING(__dist_key__);
+
+-- Test ordinary hop - caching enabled ( # msts > # segs )
+SELECT test_run_training('src_3segs', 7, False, False, True);
+
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 1, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 0, -- dependent_var length
+ 0, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- Test final training hop - caching enabled ( # msts > # segs )
+SELECT test_run_training('src_3segs', 2, False, True, True);
+
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- independent_var & dependent_var should have both been passed as NULL
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 1, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ True, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 0, -- dependent_var length
+ 0, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Simulate 3 models on 3 segments ===--
+SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings');
+DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order);
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    COUNT(*) = 3,  -- the message uses E'' strings so that \n is a real newline
+    E'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' ||
+    E'Expected: 3\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule;
+-- Make sure none of the entries in the schedule table are NULL
+-- ( this should only happen for # msts < # segs case )
+SELECT assert(
+ COUNT(*) = 0,
+ 'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs'
+) FROM current_schedule WHERE mst_key IS NULL;
+
+-- Save new order of mst keys in schedule for tracking
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+
+SELECT validate_mst_key_order('model_output', 'expected_order');
+SELECT test_rotate_schedule('current_schedule');
+
+-- Test ordinary hop - no caching ( # msts = # segs )
+SELECT test_run_training('src_3segs', 2, False, False, False);
+
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ SELECT validate_mst_key_order('model_output', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 5, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 12, -- dependent_var length
+ 32, -- independent_var length
+ False -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Simulate 3 models on 5 segments ( # msts < # segs ) ===--
+-- ( by updating dist keys in source table )
+CREATE TABLE src_5segs AS
+ SELECT * FROM iris_data_15buf_packed
+ DISTRIBUTED BY (__dist_key__);
+
+SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings');
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    COUNT(*) = 2,  -- the message uses E'' strings so that \n is a real newline
+    E'Wrong number NULL entries in schedule table created by test_init_schedule()\n' ||
+    E'Expected: 2\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule WHERE mst_key IS NULL;
+
+SELECT assert(
+    COUNT(*) = 3,  -- the message uses E'' strings so that \n is a real newline
+    E'Wrong number of non-NULL entries in schedule table created by test_init_schedule()\n' ||
+    E'Expected: 3\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule WHERE mst_key IS NOT NULL;
+
+-- Save expected mst_key order
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+
+DROP TABLE IF EXISTS model_output_ext;
+CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+-- Make sure model_output was created with correct mst_key order
+SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+--=== Test very first hop - caching enabled ( # msts < # segs ) ===--
+SELECT test_run_training('src_5segs', 0, True, False, True);
+SELECT test_rotate_schedule('current_schedule');
+
+ -- Verify mst keys did not move
+ DROP TABLE IF EXISTS model_output_ext;
+ CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+ SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ -- This should generate an Assertion failure if the transition function was not called for
+ -- any __dist_key__, even if there is no model on that segment.
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 3, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 12, -- dependent_var length
+ 32, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- Test ordinary hop - caching enabled ( # msts < # segs )
+ -- This should generate an Assertion failure if the transition function is
+ -- called for any __dist_key__ where mst_key is NULL
+SELECT test_run_training('src_5segs', 1, False, False, True);
+
+ -- verify mst_keys moved to correct segments
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ DROP TABLE IF EXISTS model_output_ext;
+ CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+ SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 1, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ False, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 0, -- dependent_var length
+ 0, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__ -- WHERE clause restricts this check to segments with models
+ FROM expected_dist_key_mappings s WHERE ARRAY[__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL)
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+ SELECT test_rotate_schedule('current_schedule');
+
+-- Test final training hop - caching enabled ( # msts < # segs )
+SELECT test_run_training('src_5segs', 5, False, True, True);
+
+ -- verify mst_keys moved to correct segments
+ UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+ DROP TABLE IF EXISTS model_output_ext;
+ CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+ SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+ -- verify transition func was called correct # of times with correct params
+ -- This should generate an Assertion failure if the transition function was not
+ -- called for any __dist_key__, even if there is no model on that segment.
+ DROP TABLE IF EXISTS validate_params_results;
+ CREATE TABLE validate_params_results AS
+ SELECT validate_transition_function_params(
+ s.gp_segment_id,
+ s.segments_per_host,
+ s.expected_images_per_seg,
+ 1, -- expected num_calls (per dist_key)
+ s.__dist_key__,
+ True, -- expected is_final_training_call
+ s.expected_dist_key_mapping,
+ 0, -- dependent_var length
+ 0, -- independent_var length
+ True -- use_caching
+ ) AS res,
+ s.__dist_key__
+ FROM expected_dist_key_mappings s
+ DISTRIBUTED BY (__dist_key__);
+ SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- We don't want to hide the madlib versions of these for any other
+-- test files that run afterwards
+DROP FUNCTION madlib_installcheck_deep_learning.version();
+DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+); -- terminated explicitly so the drop does not depend on end-of-file handling
+
+>>> ) -- m4_endif postgres
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
index 389efa3..88011ef 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
@@ -293,3 +293,11 @@
NULL, -- Sample without replacement
TRUE -- Separate output tables
);
+
+DROP TABLE IF EXISTS iris_data_15buf_packed, iris_data_15buf_packed_summary;
+SELECT training_preprocessor_dl('iris_test', -- Source table
+ 'iris_data_15buf_packed', -- Output table
+ 'class_text', -- Dependent variable
+ 'attributes', -- Independent variable
+ 2 -- buffer_size (15 buffers)
+ );
diff --git a/src/ports/postgres/modules/internal/db_utils.py_in b/src/ports/postgres/modules/internal/db_utils.py_in
index 90c09ca..d39c845 100644
--- a/src/ports/postgres/modules/internal/db_utils.py_in
+++ b/src/ports/postgres/modules/internal/db_utils.py_in
@@ -85,6 +85,14 @@
input_str=input_str)
# ------------------------------------------------------------------------------
+def quote_nullable(input_str):
+ if input_str is not None:
+ return quote_literal(input_str)
+ else:
+ return 'NULL'
+
+# ------------------------------------------------------------------------------
+
def is_col_1d_array(source_table, col_name):
query = """
SELECT array_upper({0}, 2) IS NULL AS n_y
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index f3d30ea..20a11c2 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -51,13 +51,12 @@
return input_str
# -------------------------------------------------------------------------
-
def quote_ident(input_str):
"""
Returns input_str with quotes added per Postgres identifier rules.
- This function is available via plpy.quote_ident in PG > 9.1. We add this
- function for compatibility with Greenplum.
+ This function is available via plpy.quote_ident in PG > 9.1 and GPDB >= 6.0.
+ We add this function for compatibility with older versions of Greenplum.
If the input_str is a lower case string with characters in [a-z0-9_] then the
string is returned as is, else a double quote is added in front and back of the string.