/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

m4_include(`SQLCommon.m4')
m4_changequote(<<<,>>>)
m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
,<<<
m4_changequote(<!,!>)

-- =================== Setup & Initialization for FitMultiple tests ========================
--
--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
--  They take a long time to run.  Including similar tests here would probably not be worth
--  the extra time added to dev-check.
--
--  Instead, we just want to unit test different python functions in the FitMultiple class.
--  However, most of the important behavior we need to test requires access to an actual
--  Greenplum database... mostly, we want to make sure that the models hop around to the
--  right segments in the right order.  Therefore, the unit tests are here, as a part of
--  dev-check.  We mock fit_transition() and some validation functions in FitMultiple, but
--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
--  get through to gpdb. We also want to mock the number of segments, so we can test what
--  the model hopping behavior will be for a large cluster, even though dev-check should be
--  able to run on a single dev host.

\i m4_regexp(MODULE_PATHNAME,
             <!\(.*\)libmadlib\.so!>,
            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
)

-- Stand-in version() placed in the test schema.  It forwards to the installed
-- library's own version(), which is meant to convince the InputValidator that
-- the test schema is the real madlib schema.
CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version()
RETURNS VARCHAR
IMMUTABLE
LANGUAGE sql
AS $$ SELECT MADLIB_SCHEMA.version() $$;

-- Builds the FitMultipleModel object that the other helpers in this file read
-- back out of GD, so it has to run before any of them.  The caller supplies a
-- real source table and a real mst table; the remaining constructor arguments
-- are fixed here.  They can be overridden on the stored object later, before
-- the test functions are called, if necessary.
CREATE OR REPLACE FUNCTION init_fit_mult(
    source_table            VARCHAR,
    model_selection_table   VARCHAR
) RETURNS VOID AS
$$
    import sys
    from mock import Mock, patch

    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
    schema_madlib = 'madlib_installcheck_deep_learning'

    # Constructor arguments are positional.  'orig_model_out' is presumably
    # the model output table name and the trailing 1 some per-run setting;
    # confirm against the FitMultipleModel constructor signature.
    fit_mult = madlib_keras_fit_multiple_model.FitMultipleModel(
        schema_madlib, source_table, 'orig_model_out', model_selection_table, 1)

    # Keep the object in the PL/Python global dictionary for later calls
    GD['fit_mult'] = fit_mult
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);

-- Points the stored FitMultiple object at the given schedule table name,
-- drops any existing table of that name, and rebuilds it by calling
-- init_schedule_tbl on the object.
--
-- Returns TRUE when init_schedule_tbl reports success.  Otherwise a warning
-- carrying the failure text is emitted and FALSE is returned.
CREATE OR REPLACE FUNCTION test_init_schedule(
    schedule_table VARCHAR
) RETURNS BOOLEAN AS
$$
    fit_mult = GD['fit_mult']
    fit_mult.schedule_tbl = schedule_table

    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))

    # PL/Python converts the result of a BOOLEAN function by Python
    # truthiness: a non-empty error string would come back as TRUE and
    # None as NULL.  Return real booleans, and surface the failure text
    # through a warning instead of the return value.
    # NOTE(review): assumes init_schedule_tbl returns a truthy value on
    # success, as the original branch did; confirm.
    if fit_mult.init_schedule_tbl():
        return True

    plpy.warning('FitMultiple.init_schedule_tbl() returned False')
    return False
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);

-- Advances the schedule table of the stored FitMultiple object by one
-- rotation.  If the object is not currently tracking the requested schedule
-- table name, the schedule is initialized first.
CREATE OR REPLACE FUNCTION test_rotate_schedule(
    schedule_table          VARCHAR
) RETURNS VOID AS
$$
    fit_mult = GD['fit_mult']

    # NOTE(review): schedule_tbl is not reassigned to schedule_table here, so
    # this re-initializes whatever table name the object already holds.
    if fit_mult.schedule_tbl != schedule_table:
        fit_mult.init_schedule_tbl()

    fit_mult.rotate_schedule_tbl()

$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);

-- Test double for the fit_transition function, used for testing the
--  madlib_keras_fit_multiple_model() python code.  Rather than training, it
--  records the arguments of its latest call per dist_key in GD, together
--  with a running call count, so validate_transition_function_params can
--  inspect them afterwards.
CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
    dependent_var               BYTEA[],
    independent_var             BYTEA[],
    dependent_var_shape         INTEGER[],
    independent_var_shape       INTEGER[],
    model_architecture          TEXT,
    compile_params              TEXT,
    fit_params                  TEXT,
    dist_key                    INTEGER,
    dist_key_mapping            INTEGER[],
    current_seg_id              INTEGER,
    segments_per_host           INTEGER[],
    images_per_seg              INTEGER[],
    accessible_gpus_for_seg     INTEGER[],
    serialized_weights          BYTEA,
    is_final_training_call      BOOLEAN,
    use_caching                 BOOLEAN,
    custom_function_map         BYTEA
) RETURNS BYTEA AS
$$
    # Arguments stored as-is in the per-dist_key record
    recorded_args = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
                      'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
                      'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
                      'independent_var_shape', 'use_caching' ]

    # Continue counting from the previous record of this dist_key, unless
    # that record was flagged 'reset'; then counting starts over at 1.
    history = GD.get('transition_function_params', dict())
    previous = history.get(dist_key)
    if previous is not None and 'reset' not in previous:
        num_calls = previous['num_calls'] + 1
    else:
        num_calls = 1

    # The SQL arguments are looked up by name among this function's globals
    arg_values = globals()
    params = dict((name, arg_values[name]) for name in recorded_args)

    # For the data arrays only the length of the first element is kept
    # ( 0 when that element is NULL or empty )
    first_dep = dependent_var[0]
    params['dependent_var'] = len(first_dep) if first_dep else 0
    first_indep = independent_var[0]
    params['independent_var'] = len(first_indep) if first_indep else 0
    params['num_calls'] = num_calls

    GD.setdefault('transition_function_params', dict())[dist_key] = params

    # compute simulated seg_id ( current_seg_id is the actual seg_id )
    seg_id = dist_key_mapping.index( dist_key )

    # Return None while the rows seen so far do not yet add up to
    # images_per_seg for this simulated segment; otherwise flag the record
    # as finished and hand the weights back unchanged.
    first_shape = dependent_var_shape[0]
    if first_shape and first_shape[0] * num_calls < images_per_seg[seg_id]:
        return None

    params['reset'] = True
    return serialized_weights
$$ LANGUAGE plpythonu VOLATILE;

-- Compares what the mocked transition function recorded in GD for one
-- dist_key against the expected values passed in.  Returns 'PASS' when every
-- checked value matches, otherwise a message describing the first mismatch
-- found, or stating that no call was recorded.
-- Note: images_per_seg and expected_dist_key_mapping are accepted but are not
-- part of the comparison.
CREATE OR REPLACE FUNCTION validate_transition_function_params(
    current_seg_id                       INTEGER,
    segments_per_host                    INTEGER[],
    images_per_seg                       INTEGER[],
    expected_num_calls                   INTEGER,
    expected_dist_key                    INTEGER,
    expected_is_final_training_call      BOOLEAN,
    expected_dist_key_mapping            INTEGER[],
    dependent_var_len                    INTEGER,
    independent_var_len                  INTEGER,
    use_caching                          BOOLEAN
) RETURNS TEXT AS
$$
    not_called = "transition function was not called on segment {}".format(current_seg_id)

    # Nothing recorded at all, or nothing recorded for this dist_key
    if 'transition_function_params' not in GD:
        return not_called
    recorded = GD['transition_function_params']
    if expected_dist_key not in recorded:
        return not_called + " for __dist_key__ = {}".format(expected_dist_key)
    actual = recorded[expected_dist_key]

    mismatch = """Incorrect value for {} param passed to fit_transition_multiple_model:
       Actual={}, Expected={}"""

    # Recorded key -> value the caller expects to find there
    expectations = {
        'current_seg_id'         : current_seg_id,
        'segments_per_host'      : segments_per_host,
        'num_calls'              : expected_num_calls,
        'is_final_training_call' : expected_is_final_training_call,
        'dist_key'               : expected_dist_key,
        'dependent_var'          : dependent_var_len,
        'independent_var'        : independent_var_len,
        'use_caching'            : use_caching
    }

    for name, wanted in expectations.items():
        found = actual[name]
        if found == wanted:
            continue
        return mismatch.format(name, found, wanted)

    return 'PASS'  # actual params match expected params
$$ LANGUAGE plpythonu VOLATILE;

-- Helper that rotates an array of ints one position to the right:
-- the last element moves to the front, e.g. {1,2,3} becomes {3,1,2}.
CREATE OR REPLACE FUNCTION rotate_keys(
    keys    INTEGER[]
) RETURNS INTEGER[]
AS $$
    # Position i takes the element that was at i - 1; index -1 wraps
    # around to the last element.
    return [ keys[i - 1] for i in range(len(keys)) ]
$$ LANGUAGE plpythonu IMMUTABLE;

-- Inverse of rotate_keys: rotates an array of ints one position to the left,
-- so the first element moves to the end, e.g. {1,2,3} becomes {2,3,1}.
CREATE OR REPLACE FUNCTION reverse_rotate_keys(
    keys    INTEGER[]
) RETURNS INTEGER[]
AS $$
    count = len(keys)
    # Position i takes the element that was at i + 1, wrapping at the end
    return [ keys[(i + 1) % count] for i in range(count) ]
$$ LANGUAGE plpythonu IMMUTABLE;

-- Sets the model input, model output and cached source table names on the
-- stored FitMultiple object, drops any existing output / cached source
-- tables, lets the object create a fresh model output table, and then
-- overwrites the model columns with tiny placeholder values derived from
-- mst_key.  Nothing is returned explicitly, so the SQL result is NULL.
CREATE OR REPLACE FUNCTION setup_model_tables(
    input_table TEXT,
    output_table TEXT,
    cached_source_table TEXT
) RETURNS TEXT AS
$$
    fit_mult = GD['fit_mult']

    fit_mult.model_input_tbl = input_table
    fit_mult.model_output_tbl = output_table
    fit_mult.cached_source_table = cached_source_table

    # Clear out leftovers from a previous scenario before re-creating
    for stale_tbl in (output_table, cached_source_table):
        plpy.execute('DROP TABLE IF EXISTS {}'.format(stale_tbl))
    fit_mult.init_model_output_tbl()

    shrink_models_cmd = """
        UPDATE {model_out} -- Reduce size of model for faster tests
            SET ( model_weights, model_arch, compile_params, fit_params )
                  = ( mst_key::TEXT::BYTEA,
                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
                      'c' || mst_key::TEXT,
                      'f' || mst_key::TEXT
                    )
        WHERE mst_key IS NOT NULL;
    """.format(model_out=fit_mult.model_output_tbl)
    plpy.execute(shrink_models_cmd)
$$ LANGUAGE plpythonu VOLATILE;

-- Reassigns the dist keys of the source table and refreshes the matching
--    bookkeeping attributes on the stored FitMultiple object.
--    num_data_segs may exceed the real number of segments, since this only
--    simulates a cluster.  The expected dist key mapping, images per seg and
--    segments per host are also written to expected_distkey_mappings_tbl,
--    which later validation queries read from.
CREATE OR REPLACE FUNCTION update_dist_keys(
    src_table TEXT,
    num_data_segs INTEGER,
    num_models INTEGER,
    expected_distkey_mappings_tbl TEXT
) RETURNS VOID AS
$$
    # Spread the buffers over the simulated segments ( every row is updated )
    reassign_cmd = """
        UPDATE {src_table}
            SET __dist_key__ = (buffer_id % {num_data_segs})
    """.format(src_table=src_table, num_data_segs=num_data_segs)
    plpy.execute(reassign_cmd)

    fit_mult = GD['fit_mult']

    # One row per dist key, ordered by dist key
    per_key_cmd = """
        SELECT SUM(attributes_shape[1]) AS image_count,
            __dist_key__
        FROM {src_table}
        GROUP BY __dist_key__
        ORDER BY __dist_key__
    """.format(src_table=src_table)
    per_key_rows = plpy.execute(per_key_cmd)

    image_counts = []
    dist_keys = []
    for row in per_key_rows:
        image_counts.append(int(row['image_count']))
        dist_keys.append(int(row['__dist_key__']))
    num_dist_keys = len(dist_keys)
    segs_per_host = [num_data_segs] * num_dist_keys

    fit_mult.source_table = src_table
    fit_mult.max_dist_key = sorted(dist_keys)[-1]
    fit_mult.images_per_seg_train = image_counts
    fit_mult.dist_key_mapping = dist_keys
    fit_mult.dist_keys = dist_keys
    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
    fit_mult.segments_per_host = segs_per_host

    # Models to schedule, padded with None when there are fewer models
    # than dist keys
    scheduled = fit_mult.msts[:num_models]
    fit_mult.msts_for_schedule = scheduled
    if num_models < num_dist_keys:
        scheduled += [None] * (num_dist_keys - num_models)
        fit_mult.msts_for_schedule = scheduled
    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'
                              for mst in scheduled ]
    fit_mult.num_msts = num_models

    # Extra dist keys, numbered past the largest real one, when there are
    # more models than dist keys ( empty list otherwise )
    fit_mult.extra_dist_keys = [ fit_mult.max_dist_key + 1 + i
                                 for i in range(num_models - num_dist_keys) ]
    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys

    create_distkey_map_tbl_cmd = """
        DROP TABLE IF EXISTS {exp_table};
        CREATE TABLE {exp_table} AS
        SELECT
            ARRAY(  -- map of dist_keys to seg_ids from source table
                SELECT __dist_key__
                FROM {fm.source_table}
                GROUP BY __dist_key__
                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
            ) AS expected_dist_key_mapping,
            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
            ARRAY{data_distribution_per_seg} AS segments_per_host,
            __dist_key__
        FROM {fm.source_table}
        GROUP BY __dist_key__
        DISTRIBUTED BY (__dist_key__);
    """.format(
            fm=fit_mult,
            data_distribution_per_seg=segs_per_host,
            exp_table=expected_distkey_mappings_tbl
        )
    plpy.execute(create_distkey_map_tbl_cmd)
$$ LANGUAGE plpythonu VOLATILE;

-- Runs one hop of FitMultiple.run_training on the stored object, after
-- clearing the call records kept by the mocked transition function and
-- applying the requested flags to the object.
--
--   source_table            source table to train on
--   hop                     hop number passed to run_training
--   is_very_first_hop       passed to run_training and stored on the object
--   is_final_training_call  stored on the object
--   use_caching             stored on the object
CREATE OR REPLACE FUNCTION test_run_training(
    source_table TEXT,
    hop INTEGER,
    is_very_first_hop BOOLEAN,
    is_final_training_call BOOLEAN,
    use_caching BOOLEAN
) RETURNS VOID AS
$$
    fit_mult = GD['fit_mult']

    # Each time we start a new test, clear out stats
    #   like num_calls from GD so we don't end up validating
    #   against old results
    if 'transition_function_params' in GD:
        del GD['transition_function_params']

    fit_mult.source_tbl = source_table
    # NOTE(review): update_dist_keys in this file reads and writes the
    # attribute source_table, not source_tbl.  Set both so the object is
    # consistent whichever name FitMultipleModel actually uses; confirm
    # against the class and drop the unused one.
    fit_mult.source_table = source_table
    fit_mult.is_very_first_hop = is_very_first_hop
    fit_mult.is_final_training_call = is_final_training_call
    if use_caching != fit_mult.use_caching:
        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
                                  # query when use_caching changes!
    fit_mult.use_caching = use_caching

    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);

-- Asserts that the mst_key values of output_tbl, read in __dist_key__ order,
-- equal the mst_keys array stored in expected_tbl.
--
--   output_tbl    table with mst_key and __dist_key__ columns
--   expected_tbl  table holding the expected INTEGER[] in column mst_keys
--
-- Both names are spliced into dynamic SQL unquoted; they are test fixtures
-- named by this script, not external input.
CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
RETURNS VOID AS
$$
DECLARE
    actual INTEGER[];
    expected INTEGER[];
BEGIN
    EXECUTE 'SELECT ARRAY(' ||
        'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)'
    INTO actual;

    EXECUTE 'SELECT mst_keys FROM ' || expected_tbl
    INTO expected;

    -- A missing expected row leaves "expected" NULL, which would make the
    -- comparison NULL instead of FALSE and the message NULL as well.  Treat
    -- that case as a failure with a readable message.
    -- NOTE(review): whether assert() lets a NULL condition pass depends on
    -- its definition, which is not in this file.
    PERFORM assert(
        COALESCE(actual = expected, FALSE),
        'mst keys found in wrong order / wrong segments!' ||
        E'\nActual: ' || COALESCE(actual::text, 'NULL') ||
        E'\nExpected: ' || COALESCE(expected::text, 'NULL')
    );
END
$$ LANGUAGE PLpgSQL VOLATILE;

-- Create mst table
--   One model arch id, three compile param strings and two fit param strings
--   are supplied; the scenarios below treat the result as 6 models.
DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'iris_mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$
    ]
);

-- Create FitMultiple object for running test functions
--   ( stored in GD by init_fit_mult and reused by every helper below )
SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');

-- Private copy of the packed source table, so its dist keys can be rewritten
CREATE TABLE src_3segs AS
    SELECT * FROM iris_data_15buf_packed
    DISTRIBUTED BY (__dist_key__);

-- Simulate 6 models on 3 segments --
--   args: source table, # simulated data segments, # models,
--         table receiving the expected dist key mappings
SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');

-- ====================================================================
-- ===========  Enough setup, now for the actual tests! ===============
-- ====================================================================
--=== Test init_schedule_tbl() ===--

SELECT test_init_schedule('current_schedule');
-- The FULL JOIN yields a row with a NULL on one side for any mst_key that is
-- present in only one of the two tables, which fails the assertion.
SELECT assert(
    s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
    'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);

-- Save order of mst keys in schedule for tracking
--   ( a single row holding the keys as an array, ordered by __dist_key__ )
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;

--=== Test rotate_schedule() ===--
-- Rotate the real schedule and the saved expectation by one step each, then
-- compare them.
SELECT test_rotate_schedule('current_schedule');
UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
SELECT validate_mst_key_order('current_schedule', 'expected_order');
UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys);  -- Undo for later

-- Initialize model_output table, and set model_input & cached_src table names
SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
-- The fresh model_output table is checked against the un-rotated order
SELECT validate_mst_key_order('model_output', 'expected_order');

-- Order of params in run_training test function (for reference below):
--
--  test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)

--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
SELECT test_run_training('src_3segs', 0, False, False, False);

    -- mst_keys should not have moved
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    --   One result row per dist key; each holds 'PASS' or an error message.
    --   segments_per_host is spelled out literally here; update_dist_keys
    --   writes 3 entries of 3 for this scenario.
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            ARRAY[3, 3, 3],
            s.expected_images_per_seg,
            5,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                -- dependent_var length
            32,                -- independent_var length
            False              -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- On failure the error message itself is used as the assertion message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Test an ordinary hop - no caching (# msts > # segs) ===--
SELECT test_run_training('src_3segs', 1, False, False, False);
-- Advance the schedule table once after the hop
SELECT test_rotate_schedule('current_schedule');

    -- check that mst keys rotated onto correct segments
    --   ( the expectation is rotated one step to follow the hop )
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                -- dependent_var length
            32,                -- independent_var length
            False              -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Test final training hop - no caching (# msts > # segs) ===--
-- Same as an ordinary hop, except is_final_training_call is True and the
-- mocked transition function is expected to have received that flag.
SELECT test_run_training('src_3segs', 8, False, True, False);

    -- check that mst keys rotated onto correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,              -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                -- dependent_var length
            32,                -- independent_var length
            False              -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Test very first hop - caching enabled   ( # msts > # segs ) ===--
-- With is_very_first_hop = True the data is still expected to reach the
-- transition function (non-zero lengths below), and the cached source table
-- is expected to exist afterwards.
SELECT test_run_training('src_3segs', 0, True, False, True);

    -- mst_keys should not have moved
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                 -- dependent_var length
            32,                 -- independent_var length
            True                -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

    -- validate that cached source table was created with proper dist keys
    --   ( the FULL JOIN leaves a NULL on one side for any dist key that
    --     exists in only one of the two tables )
    SELECT assert(
        c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL,
        'cached src table was not created or dist keys do not match original src table')
    FROM cached_src c FULL JOIN (SELECT __dist_key__ FROM src_3segs GROUP BY __dist_key__) s USING(__dist_key__);

--=== Test ordinary hop - caching enabled   ( # msts > # segs ) ===--
-- After the first hop with caching, a single call per dist key with
-- zero-length data arguments is expected.
SELECT test_run_training('src_3segs', 7, False, False, True);

    -- check that mst keys rotated onto correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,                 -- dependent_var length
            0,                 -- independent_var length
            True               -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Test final training hop - caching enabled   ( # msts > # segs ) ===--
SELECT test_run_training('src_3segs', 2, False, True, True);

    -- check that mst keys rotated onto correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- independent_var & dependent_var should have both been passed as NULL
    --   ( the mock records a length of 0 for a NULL first element )
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,              -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,                 -- dependent_var length
            0,                 -- independent_var length
            True               -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Simulate 3 models on 3 segments ===--
SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings');
-- Remove from the mst table every key listed in the saved expected order.
-- NOTE(review): the models to schedule come from the FitMultiple object, so
-- this presumably does not affect the schedule built next; confirm intent.
DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order);
SELECT test_init_schedule('current_schedule');
-- E'' literals are used so that \n is a line break regardless of the
-- standard_conforming_strings setting.
SELECT assert(
    COUNT(*) = 3,
    E'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' ||
    E'Expected: 3\nActual: ' || COUNT(*)::TEXT
) FROM current_schedule;
-- Make sure none of the entries in the schedule table are NULL
--     ( this should only happen for # msts < # segs case )
SELECT assert(
    COUNT(*) = 0,
    'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs'
) FROM current_schedule WHERE mst_key IS NULL;

-- Save new order of mst keys in schedule for tracking
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;

-- Re-create the model output table for the new schedule
SELECT setup_model_tables('model_input', 'model_output', 'cached_src');

SELECT validate_mst_key_order('model_output', 'expected_order');
-- In this scenario the schedule is rotated before the hop is run
SELECT test_rotate_schedule('current_schedule');

--=== Test ordinary hop - no caching    ( # msts = # segs ) ===--
SELECT test_run_training('src_3segs', 2, False, False, False);

    -- check that mst keys rotated onto correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    SELECT validate_mst_key_order('model_output', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            5,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                 -- dependent_var length
            32,                 -- independent_var length
            False               -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Simulate 3 models on 5 segments ( # msts < # segs ) ===--
--      ( by updating dist keys in source table )
CREATE TABLE src_5segs AS
    SELECT * FROM iris_data_15buf_packed
    DISTRIBUTED BY (__dist_key__);

SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings');
SELECT test_init_schedule('current_schedule');
-- With 3 models on 5 dist keys, 2 schedule entries carry no model.
-- E'' literals are used so that \n is a line break regardless of the
-- standard_conforming_strings setting.
SELECT assert(
    COUNT(*) = 2,
    E'Wrong number of NULL entries in schedule table created by test_init_schedule()\n' ||
    E'Expected: 2\nActual: ' || COUNT(*)::TEXT
) FROM current_schedule WHERE mst_key IS NULL;

SELECT assert(
    COUNT(*) = 3,
    E'Wrong number of non-NULL entries in schedule table created by test_init_schedule()\n' ||
    E'Expected: 3\nActual: ' || COUNT(*)::TEXT
) FROM current_schedule WHERE mst_key IS NOT NULL;

-- Save expected mst_key order
DROP TABLE IF EXISTS expected_order;
CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;

-- Initialize model_output table, and set model_input & cached_src table names
SELECT setup_model_tables('model_input', 'model_output', 'cached_src');

-- Left join from the schedule, so that dist keys without a model still
-- show up with a NULL mst_key and the array lines up with expected_order.
DROP TABLE IF EXISTS model_output_ext;
CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
-- Make sure model_output was created with correct mst_key order
SELECT validate_mst_key_order('model_output_ext', 'expected_order');

--=== Test very first hop - caching enabled   ( # msts < # segs ) ===--
SELECT test_run_training('src_5segs', 0, True, False, True);
-- Advance the schedule table once after the hop
SELECT test_rotate_schedule('current_schedule');

    -- Verify mst keys did not move
    --   ( model_output is left-joined onto the schedule so that dist keys
    --     without a model appear with a NULL mst_key )
    DROP TABLE IF EXISTS model_output_ext;
    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
    SELECT validate_mst_key_order('model_output_ext', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    --   This should generate an Assertion failure if the transition function was not called for
    --   any __dist_key__, even if there is no model on that segment.
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            3,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            12,                 -- dependent_var length
            32,                 -- independent_var length
            True                -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

--=== Test ordinary hop - caching enabled   ( # msts < # segs ) ===--
  -- This should generate an Assertion failure if the transition function is
  --  called for any __dist_key__ where mst_key is NULL
SELECT test_run_training('src_5segs', 1, False, False, True);

    -- verify mst_keys moved to correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    DROP TABLE IF EXISTS model_output_ext;
    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
    SELECT validate_mst_key_order('model_output_ext', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            False,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,                 -- dependent_var length
            0,                 -- independent_var length
            True               -- use_caching
        ) AS res,
        s.__dist_key__  -- WHERE clause restricts this check to segments with models
    FROM expected_dist_key_mappings s WHERE ARRAY[__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL)
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

    -- Advance the schedule table once, after validation, for the next hop
    SELECT test_rotate_schedule('current_schedule');

--=== Test final training hop - caching enabled   ( # msts < # segs ) ===--
SELECT test_run_training('src_5segs', 5, False, True, True);

    -- verify mst_keys moved to correct segments
    UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
    DROP TABLE IF EXISTS model_output_ext;
    CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__);
    SELECT validate_mst_key_order('model_output_ext', 'expected_order');

    -- verify transition func was called correct # of times with correct params
    --    This should generate an Assertion failure if the transition function was not
    --    called for any __dist_key__, even if there is no model on that segment.
    --    ( unlike the ordinary hop above, no WHERE clause narrows the check )
    DROP TABLE IF EXISTS validate_params_results;
    CREATE TABLE validate_params_results AS
        SELECT validate_transition_function_params(
            s.gp_segment_id,
            s.segments_per_host,
            s.expected_images_per_seg,
            1,                 -- expected num_calls (per dist_key)
            s.__dist_key__,
            True,             -- expected is_final_training_call
            s.expected_dist_key_mapping,
            0,                 -- dependent_var length
            0,                 -- independent_var length
            True               -- use_caching
        ) AS res,
        s.__dist_key__
    FROM expected_dist_key_mappings s
    DISTRIBUTED BY (__dist_key__);
    -- Any row that is not 'PASS' fails, reporting its own message
    SELECT assert(res = 'PASS', res) FROM validate_params_results;

-- We don't want to hide the madlib versions of these for any other
--   test files that run afterwards
DROP FUNCTION madlib_installcheck_deep_learning.version();
-- The argument list must match the mock's CREATE FUNCTION signature above.
-- The statement is terminated explicitly, so it does not depend on how the
-- client treats an unterminated statement at end of file.
DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
    dependent_var               BYTEA[],
    independent_var             BYTEA[],
    dependent_var_shape         INTEGER[],
    independent_var_shape       INTEGER[],
    model_architecture          TEXT,
    compile_params              TEXT,
    fit_params                  TEXT,
    dist_key                    INTEGER,
    dist_key_mapping            INTEGER[],
    current_seg_id              INTEGER,
    segments_per_host           INTEGER[],
    images_per_seg              INTEGER[],
    accessible_gpus_for_seg     INTEGER[],
    serialized_weights          BYTEA,
    is_final_training_call      BOOLEAN,
    use_caching                 BOOLEAN,
    custom_function_map         BYTEA
);

>>>)  -- m4_endif postgres
