DL: Add unit tests for model hopping in FitMultiple

This is a set of "unit tests" implemented as a dev-check test, since
 they need DB access to work.

 Functions unit tested in deep_learning/tests/madlib_keras_fit_multiple.sql_in:
    FitMultiple.init_schedule()
    FitMultiple.rotate_schedule()
    FitMultiple.run_training()

- Tests schedule creation and rotation
- Tests that models are hopping to right segments in right order

- Tests caching, # msts > segs, = , and # mst < segs
  Uses simulated segments instead of actual segments
  so the same test can be run on any size cluster.

- 6 models on 3 segments
- 3 models on 3 segments
- 3 models on 5 segments
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
new file mode 100644
index 0000000..a7bd3fc
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
@@ -0,0 +1,786 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment {}".format(current_seg_id)
+
+    if 'transition_function_params' not in GD:
+        return err_msg
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+BEGIN
+    EXECUTE 'SELECT ARRAY(' ||
+        'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)'
+    INTO actual;
+
+    EXECUTE 'SELECT mst_keys FROM ' || expected_tbl
+    INTO expected;
+
+    PERFORM assert(
+        actual = expected,
+        'mst keys found in wrong order / wrong segments!' ||
+        E'\nActual: ' || actual::text ||
+        E'\nExpected: ' || expected::text
+    );
+END
+$$ LANGUAGE PLpgSQL VOLATILE;
+
+-- Create mst table
+DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'iris_mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$
+    ]
+);
+
+-- Create FitMultiple object for running test functions
+SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');
+
+CREATE TABLE src_3segs AS
+    SELECT * FROM iris_data_15buf_packed
+    DISTRIBUTED BY (__dist_key__);
+
+-- Simulate 6 models on 3 segments --
+SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
+
+--=== Test init_schedule_tbl() ===--
+-- ====================================================================
+-- ===========  Enough setup, now for the actual tests! ===============
+-- ====================================================================
+
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
+    'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
+) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);
+
+-- Save order of mst keys in schedule for tracking 
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+--=== Test rotate_schedule() ===--
+SELECT test_rotate_schedule('current_schedule');
+UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+SELECT validate_mst_key_order('current_schedule', 'expected_order');
+UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys);  -- Undo for later
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+SELECT validate_mst_key_order('model_output', 'expected_order');
+
+-- Order of params in run_training test function (for reference below):
+--
+--  test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)
+
+--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 0, False, False, False);
+
+    -- mst_keys should not have moved
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            3,
+            s.expected_images_per_seg,
+            5,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                -- dependent_var length
+            32,                -- independent_var length
+            False              -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test an ordinary hop - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 1, False, False, False);
+SELECT test_rotate_schedule('current_schedule');
+
+    -- check that mst keys rotated onto correct segments
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            5,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                -- dependent_var length
+            32,                -- independent_var length
+            False              -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test final training hop - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 8, False, True, False);
+
+    -- check that mst keys rotated onto correct segments
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            5,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            True,              -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                -- dependent_var length
+            32,                -- independent_var length
+            False              -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Test very first hop - caching enabled   ( # msts > # segs ) ===--
+SELECT test_run_training('src_3segs', 0, True, False, True);
+
+    -- mst_keys should not have moved
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            5,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                 -- dependent_var length
+            32,                 -- independent_var length
+            True                -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+    -- validate that cached source table was created with proper dist keys
+    SELECT assert(
+        c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL,
+        'cached src table was not created or dist keys do not match original src table')
+    FROM cached_src c FULL JOIN (SELECT __dist_key__ FROM src_3segs GROUP BY __dist_key__) s USING(__dist_key__);
+
+-- Test ordinary hop - caching enabled   ( # msts > # segs )
+SELECT test_run_training('src_3segs', 7, False, False, True);
+
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            1,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            0,                 -- dependent_var length
+            0,                 -- independent_var length
+            True               -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- Test final training hop - caching enabled   ( # msts > # segs )
+SELECT test_run_training('src_3segs', 2, False, True, True);
+
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- independent_var & dependent_var should have both been passed as NULL
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            1,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            True,              -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            0,                 -- dependent_var length
+            0,                 -- independent_var length
+            True               -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Simulate 3 models on 3 segments ===--
+SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings');
+DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order);
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    COUNT(*) = 3,
+    'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' ||
+    'Expected: 3\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule;
+-- Make sure none of the entries in the schedule table are NULL
+--     ( this should only happen for # msts < # segs case )
+SELECT assert(
+    COUNT(*) = 0,
+    'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs'
+) FROM current_schedule WHERE mst_key IS NULL;
+
+-- Save new order of mst keys in schedule for tracking 
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+
+SELECT validate_mst_key_order('model_output', 'expected_order');
+SELECT test_rotate_schedule('current_schedule');
+
+-- Test ordinary hop - no caching    ( # msts = # segs )
+SELECT test_run_training('src_3segs', 2, False, False, False);
+
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    SELECT validate_mst_key_order('model_output', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            5,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                 -- dependent_var length
+            32,                 -- independent_var length
+            False               -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+--=== Simulate 3 models on 5 segments ( # msts < # segs ) ===--
+--      ( by updating dist keys in source table )
+CREATE TABLE src_5segs AS
+    SELECT * FROM iris_data_15buf_packed
+    DISTRIBUTED BY (__dist_key__);
+
+SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings');
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    COUNT(*) = 2,
+    'Wrong number NULL entries in schedule table created by test_init_schedule()\n' ||
+    'Expected: 2\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule WHERE mst_key IS NULL;
+
+SELECT assert(
+    COUNT(*) = 3,
+    'Wrong number of non-NULL entries in schedule table created by test_init_schedule()\n' ||
+    'Expected: 3\nActual: ' || COUNT(*)::TEXT
+) FROM current_schedule WHERE mst_key IS NOT NULL;
+
+-- Save expected mst_key order
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+
+DROP TABLE IF EXISTS model_output_ext;
+CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+-- Make sure model_output was created with correct mst_key order
+SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+--=== Test very first hop - caching enabled   ( # msts < # segs ) ===--
+SELECT test_run_training('src_5segs', 0, True, False, True);
+SELECT test_rotate_schedule('current_schedule');
+
+    -- Verify mst keys did not move
+    DROP TABLE IF EXISTS model_output_ext;
+    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+    SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    --   This should generate an Assertion failure if the transition function was not called for
+    --   any __dist_key__, even if there is no model on that segment.
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            3,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            12,                 -- dependent_var length
+            32,                 -- independent_var length
+            True                -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- Test ordinary hop - caching enabled   ( # msts < # segs )
+  -- This should generate an Assertion failure if the transition function is
+  --  called for any __dist_key__ where mst_key is NULL
+SELECT test_run_training('src_5segs', 1, False, False, True);
+
+    -- verify mst_keys moved to correct segments
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    DROP TABLE IF EXISTS model_output_ext;
+    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+    SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            1,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            False,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            0,                 -- dependent_var length
+            0,                 -- independent_var length
+            True               -- use_caching
+        ) AS res,
+        s.__dist_key__  -- WHERE clause restricts this check to segments with models
+    FROM expected_dist_key_mappings s WHERE ARRAY[__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL)
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+    SELECT test_rotate_schedule('current_schedule');
+
+-- Test final training hop - caching enabled   ( # msts < # segs )
+SELECT test_run_training('src_5segs', 5, False, True, True);
+
+    -- verify mst_keys moved to correct segments
+    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+    DROP TABLE IF EXISTS model_output_ext;
+    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
+    SELECT validate_mst_key_order('model_output_ext', 'expected_order');
+
+    -- verify transition func was called correct # of times with correct params
+    --    This should generate an Assertion failure if the transition function was not
+    --    called for any __dist_key__, even if there is no model on that segment.
+    DROP TABLE IF EXISTS validate_params_results;
+    CREATE TABLE validate_params_results AS
+        SELECT validate_transition_function_params(
+            s.gp_segment_id,
+            s.segments_per_host,
+            s.expected_images_per_seg,
+            1,                 -- expected num_calls (per dist_key)
+            s.__dist_key__,
+            True,             -- expected is_final_training_call
+            s.expected_dist_key_mapping,
+            0,                 -- dependent_var length
+            0,                 -- independent_var length
+            True               -- use_caching
+        ) AS res,
+        s.__dist_key__
+    FROM expected_dist_key_mappings s
+    DISTRIBUTED BY (__dist_key__);
+    SELECT assert(res = 'PASS', res) FROM validate_params_results;
+
+-- We don't want to hide the madlib versions of these for any other
+--   test files that run afterwards
+DROP FUNCTION madlib_installcheck_deep_learning.version();
+DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+)
+
+>>> )  -- m4_endif postgres
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
index 389efa3..88011ef 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
@@ -293,3 +293,11 @@
                         NULL,           -- Sample without replacement
                         TRUE            -- Separate output tables
                         );
+
+DROP TABLE IF EXISTS iris_data_15buf_packed, iris_data_15buf_packed_summary;
+SELECT training_preprocessor_dl('iris_test',         -- Source table
+                                'iris_data_15buf_packed',  -- Output table
+                                'class_text',        -- Dependent variable
+                                'attributes',        -- Independent variable
+                                2                    -- buffer_size  (15 buffers)
+                                );
diff --git a/src/ports/postgres/modules/internal/db_utils.py_in b/src/ports/postgres/modules/internal/db_utils.py_in
index 90c09ca..d39c845 100644
--- a/src/ports/postgres/modules/internal/db_utils.py_in
+++ b/src/ports/postgres/modules/internal/db_utils.py_in
@@ -85,6 +85,14 @@
                                             input_str=input_str)
 # ------------------------------------------------------------------------------
 
+def quote_nullable(input_str):
+    if input_str is not None:
+        return quote_literal(input_str)
+    else:
+        return 'NULL'
+
+# ------------------------------------------------------------------------------
+
 def is_col_1d_array(source_table, col_name):
     query = """
         SELECT array_upper({0}, 2) IS NULL AS n_y
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index f3d30ea..20a11c2 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -51,13 +51,12 @@
         return input_str
 # -------------------------------------------------------------------------
 
-
 def quote_ident(input_str):
     """
     Returns input_str with quotes added per Postgres identifier rules.
 
-    This function is available via plpy.quote_ident in PG > 9.1. We add this
-    function for compatibility with Greenplum.
+    This function is available via plpy.quote_ident in PG > 9.1 and GPDB >= 6.0.
+    We add this function for compatibility with older versions of Greenplum.
 
     If the input_str is a lower case string with characters in [a-z0-9_] then the
     string is returned as is, else a double quote is added in front and back of the string.