| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| import time |
| import sys |
| |
| from madlib_keras import compute_loss_and_metrics |
| from madlib_keras import get_initial_weights |
| from madlib_keras import get_model_arch_weights |
| from madlib_keras import get_source_summary_table_dict |
| from madlib_keras import should_compute_metrics_this_iter |
| from madlib_keras_helper import * |
| from madlib_keras_model_selection import ModelSelectionSchema |
| from madlib_keras_validator import * |
| from madlib_keras_wrapper import * |
| |
| from utilities.control import MinWarning |
| from utilities.control import OptimizerControl |
| from utilities.control import SetGUC |
| from utilities.utilities import add_postfix |
| from utilities.utilities import is_platform_gp6_or_up |
| from utilities.utilities import unique_string |
| from utilities.utilities import rotate |
| from utilities.utilities import madlib_version |
| from utilities.utilities import is_platform_pg |
| from utilities.utilities import get_seg_number |
| from utilities.utilities import get_segments_per_host |
| from utilities.utilities import rename_table |
| |
| import json |
| from collections import defaultdict |
| import random |
| import datetime |
| |
| from tensorflow.keras.models import * |
# Short module-level aliases for the column-name constants used when building
# SQL below. The constants are not imported by name in this file; they are
# expected to arrive through one of the star imports above.
dist_key_col = DISTRIBUTION_KEY_COLNAME
mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
| |
| """ |
| FitMultipleModel: This class implements the Model Hopper technique for |
| training multiple models in parallel. The goal of this function is to train |
| multiple different models on the same data with different parameters. |
| The main advantage of this method over running the existing fit function in a |
| loop is avoiding inaccuracies caused by the model averaging. The basic idea of |
| model hopper is simple. Let's assume that there are n segments and c*n model |
| configurations. We begin with distributing these configs to segments. After |
| that, each segment trains its c models on the data it has locally for one |
| iteration. Once we have these models, we move them to a different segment |
| (hopping) as well as receive a different set of models. Once a segment has its |
| new models, it uses its own segment's data to refine them (similar to the warm |
| start functionality). Once every model hops through every segment, we consider |
| an iteration complete. |
| |
| This method ensures that we don't have to average any model and that the |
| loss and accuracy are very close to the ideal case, where all of the data is |
| in one segment. |
| |
| Note that this function is disabled for Postgres. |
| """ |
| |
@MinWarning("warning")
class FitMultipleModel():
    """Train several model configurations together using model hopping.

    Constructing an instance performs the entire run: it validates the
    inputs, creates the model output and info tables, trains every model
    selection tuple (mst) by hopping it across all distribution keys for
    ``num_iterations`` iterations, records metrics, and finally writes the
    summary table. Training is refused on PostgreSQL.
    """

    def __init__(self, schema_madlib, source_table, model_output_table,
                 model_selection_table, num_iterations,
                 use_gpus=False, validation_table=None,
                 metrics_compute_frequency=None, warm_start=False, name="",
                 description="", use_caching=False,
                 metrics_elapsed_time_offset=0, **kwargs):
        """
        Validate the inputs and run the whole training workflow.

        :param schema_madlib: schema name
        :param source_table: input table containing training dataset
        :param model_output_table: output table
        :param model_selection_table: input table containing model configs
        :param num_iterations: number of iterations to train
        :param use_gpus: determines whether GPUs are to be used for training
        :param validation_table: input table containing the validation dataset
        :param metrics_compute_frequency: Frequency to compute per-iteration metrics
                                          for the training dataset and validation dataset
        :param warm_start: indicates whether to initialize weights with the coefficients
                           from the last call of the fit function
        :param name: name
        :param description: description
        :param use_caching: when true, the source data is handed to the
                            transition function only on the very first hop;
                            later hops pass NULLs and join against a small
                            table of distribution keys instead
        :param metrics_elapsed_time_offset: time elapsed for the previous call to fit_multiple
                                            (internal param used by automl to accumulate
                                            metrics_elapsed_time)
        :param kwargs: accepted for call compatibility; not used here
        """
        # set the random seed for visit order/scheduling
        random.seed(1)
        if is_platform_pg():
            plpy.error(
                "DL: Multiple model training is not supported on PostgreSQL.")
        self.source_table = source_table
        self.validation_table = validation_table
        self.model_selection_table = model_selection_table
        # Always define the derived names so that a missing table name reaches
        # the input validator instead of failing here with AttributeError.
        self.model_selection_summary_table = None
        if self.model_selection_table:
            self.model_selection_summary_table = add_postfix(
                self.model_selection_table, '_summary')

        self.num_iterations = num_iterations
        self.metrics_compute_frequency = metrics_compute_frequency
        self.name = name
        self.description = description
        self.use_caching = use_caching if use_caching is not None else False
        self.module_name = 'madlib_keras_fit_multiple_model'
        self.schema_madlib = schema_madlib
        self.version = madlib_version(self.schema_madlib)
        self.mst_key_col = ModelSelectionSchema.MST_KEY
        self.model_id_col = ModelSelectionSchema.MODEL_ID
        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
        self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
        self.model_weights_col = ModelArchSchema.MODEL_WEIGHTS
        self.model_arch_col = ModelArchSchema.MODEL_ARCH
        self.train_mst_metric_eval_time = defaultdict(list)
        self.train_mst_loss = defaultdict(list)
        self.train_mst_metric = defaultdict(list)
        self.info_str = ""
        self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
        self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
        self.use_gpus = use_gpus
        self.segments_per_host = get_segments_per_host()
        self.cached_source_table = unique_string('cached_source_table')
        self.metrics_elapsed_time_offset = metrics_elapsed_time_offset
        if self.use_gpus:
            self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                self.schema_madlib, self.segments_per_host, self.module_name)
        else:
            self.accessible_gpus_for_seg = get_seg_number()*[0]

        self.original_model_output_table = model_output_table
        self.model_info_table = None
        self.model_summary_table = None
        if self.original_model_output_table:
            self.model_info_table = add_postfix(
                self.original_model_output_table, '_info')
            self.model_summary_table = add_postfix(
                self.original_model_output_table, '_summary')

        self.model_output_table = self.original_model_output_table

        # For warm start, we need to copy the model output table to a temp
        # table because we call truncate on the model output table while
        # training. If the query gets aborted, we need to make sure that the
        # user passed model output table can be recovered.
        self.warm_start = bool(warm_start)
        self.warm_start_msts = []
        if self.warm_start:
            self.model_output_table = unique_string('initial_model')

        self.fit_validator_train = FitMultipleInputValidator(
            self.source_table, self.validation_table,
            self.original_model_output_table,
            self.model_selection_table, self.model_selection_summary_table,
            mb_dep_var_col, mb_indep_var_col, self.num_iterations,
            self.model_info_table, self.mst_key_col, self.model_arch_table_col,
            self.metrics_compute_frequency, self.warm_start, self.use_gpus,
            self.accessible_gpus_for_seg)
        if self.metrics_compute_frequency is None:
            self.metrics_compute_frequency = num_iterations

        self.msts = self.fit_validator_train.msts
        self.model_arch_table = self.fit_validator_train.model_arch_table
        self.object_table = self.fit_validator_train.object_table
        self.metrics_iters = []
        self.object_map_col = 'object_map'
        if self.object_table is not None:
            self.populate_object_map()

        original_cuda_env = None
        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
            original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]

        # Restore the CUDA environment even when training fails, so that the
        # long-lived backend process is not left with a modified setting.
        try:
            self.dist_key_mapping, self.images_per_seg_train = \
                get_image_count_per_seg_for_minibatched_data_from_db(
                    self.source_table)

            if self.validation_table:
                self.valid_mst_metric_eval_time = defaultdict(list)
                self.valid_mst_loss = defaultdict(list)
                self.valid_mst_metric = defaultdict(list)
                self.dist_key_mapping_valid, self.images_per_seg_valid = \
                    get_image_count_per_seg_for_minibatched_data_from_db(
                        self.validation_table)
            self.mst_weights_tbl = unique_string(desp='mst_weights')
            self.mst_current_schedule_tbl = unique_string(
                desp='mst_current_schedule')

            self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
            if len(self.msts) < len(self.dist_keys):
                # Pad with None so that every distribution key has a slot in
                # each hop even when there are fewer configs than keys.
                self.msts_for_schedule = self.msts + [None] * \
                    (len(self.dist_keys) - len(self.msts))
            else:
                self.msts_for_schedule = self.msts
            random.shuffle(self.msts_for_schedule)
            self.grand_schedule = self.generate_schedule(
                self.msts_for_schedule)
            self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
            self.unlogged_table = \
                "UNLOGGED" if is_platform_gp6_or_up() else ''

            if self.warm_start:
                self.create_model_output_table_warm_start()
            else:
                self.create_model_output_table()

            self.weights_to_update_tbl = unique_string(
                desp='weights_to_update')
            self.fit_multiple_model()

            # Update and cleanup metadata tables
            self.insert_info_table()
            self.create_model_summary_table()
            if self.warm_start:
                self.cleanup_for_warm_start()
        finally:
            reset_cuda_env(original_cuda_env)

    @staticmethod
    def _sql_text_or_null(value):
        """Return ``value`` as a $MAD$-quoted SQL literal, or NULL for None."""
        if value is None:
            return 'NULL'
        return '$MAD${0}$MAD$'.format(value)

    def fit_multiple_model(self):
        """Run the training loop and record its start and end timestamps."""
        # WARNING: set orca off to prevent unwanted redistribution
        with OptimizerControl(False):
            self.start_training_time = datetime.datetime.now()
            self.metrics_elapsed_start_time = time.time()
            self.train_multiple_model()
            self.end_training_time = datetime.datetime.now()

    def cleanup_for_warm_start(self):
        """
        1. drop original model table
        2. rename temp to original
        :return:
        """
        drop_query = "DROP TABLE IF EXISTS {}".format(
            self.original_model_output_table)
        plpy.execute(drop_query)
        rename_table(self.schema_madlib, self.model_output_table,
                     self.original_model_output_table)

    def train_multiple_model(self):
        """Hop every model through every distribution key, once per iteration.

        After each iteration the timing (and, when due, the training and
        validation metrics) are reported through ``plpy.info``. The cached
        source table is dropped once all iterations are done.
        """
        total_msts = len(self.msts_for_schedule)
        for iteration in range(1, self.num_iterations + 1):
            for mst_idx in range(total_msts):
                mst_row = [self.grand_schedule[dist_key][mst_idx]
                           for dist_key in self.dist_keys]
                self.create_mst_schedule_table(mst_row)
                self.is_final_training_call = (
                    iteration == self.num_iterations and
                    mst_idx == total_msts - 1)
                if mst_idx == 0:
                    start_iteration = time.time()
                self.run_training(mst_idx, mst_idx == 0 and iteration == 1)
                if mst_idx == (total_msts - 1):
                    end_iteration = time.time()
                    self.info_str = "\tTime for training in iteration " \
                                    "{0}: {1} sec\n".format(
                                        iteration,
                                        end_iteration - start_iteration)
            if should_compute_metrics_this_iter(iteration,
                                                self.metrics_compute_frequency,
                                                self.num_iterations):
                self.metrics_iters.append(iteration)
                self.info_str += \
                    "\tTraining set after iteration {0}:".format(iteration)
                self.evaluate_model(iteration, self.source_table, True)
                if self.validation_table:
                    self.evaluate_model(iteration, self.validation_table,
                                        False)
            plpy.info("\n" + self.info_str)
        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(
            self=self))

    def evaluate_model(self, epoch, table, is_train):
        """Compute loss and metric of every model on ``table``.

        :param epoch: iteration number, used only in the progress message
        :param table: table to evaluate on
        :param is_train: True to record results as training metrics, False to
                         record them as validation metrics
        """
        if is_train:
            mst_metric_eval_time = self.train_mst_metric_eval_time
            mst_loss = self.train_mst_loss
            mst_metric = self.train_mst_metric
            seg_ids = self.dist_key_mapping
            images_per_seg = self.images_per_seg_train
        else:
            mst_metric_eval_time = self.valid_mst_metric_eval_time
            mst_loss = self.valid_mst_loss
            mst_metric = self.valid_mst_metric
            seg_ids = self.dist_key_mapping_valid
            images_per_seg = self.images_per_seg_valid
            self.info_str += \
                "\n\tValidation set after iteration {0}:".format(epoch)
        for mst in self.msts:
            mst_key = mst[self.mst_key_col]
            model_arch, _ = get_model_arch_weights(self.model_arch_table,
                                                   mst[self.model_id_col])
            _, metric, loss = compute_loss_and_metrics(
                self.schema_madlib, table, "$madlib${0}$madlib$".format(
                    mst[self.compile_params_col]),
                model_arch,
                None,
                self.use_gpus,
                self.accessible_gpus_for_seg,
                seg_ids,
                images_per_seg,
                [], [], True,
                mst[self.object_map_col],
                self.model_output_table,
                mst_key)
            # Elapsed time is measured from the start of training and shifted
            # by the offset carried over from a previous call.
            mst_metric_eval_time[mst_key].append(
                self.metrics_elapsed_time_offset +
                (time.time() - self.metrics_elapsed_start_time))
            mst_loss[mst_key].append(loss)
            mst_metric[mst_key].append(metric)
            self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(
                mst_key, metric, loss)

    def generate_schedule(self, msts):
        """ Generate the schedule for models hopping to segments """
        # Each distribution key gets the same list rotated by its position.
        return dict((dist_key, rotate(msts, index))
                    for index, dist_key in enumerate(self.dist_keys))

    def populate_object_map(self):
        """Attach custom function definitions to the msts that use them.

        A loss or metric named in compile_params that is not a builtin is
        treated as a custom function; its definition is read from the object
        table and stored on the mst under ``self.object_map_col``.
        """
        # Lower-cased once up front; the comparison below is case-insensitive.
        builtin_losses = set(fn.lower() for fn in dir(losses))
        builtin_metrics = set(
            fn.lower() for fn in update_builtin_metrics(dir(metrics)))

        # Track distinct custom functions in compile_params
        custom_fn_names = []
        # Track their corresponding mst_keys to pass along the custom function
        # definition read from the object table.
        # For compile_params calling builtin functions the object_map is set to
        # None.
        custom_fn_mst_idx = []
        for mst_idx, mst in enumerate(self.msts):
            compile_params = mst[self.compile_params_col]
            # We assume that the compile_param is validated as part
            # of the loading mst_table and thus not validating here
            # Also, it is validated later when we compile the model
            # on the segments
            compile_dict = convert_string_of_args_to_dict(compile_params)

            local_loss = compile_dict['loss'].lower() \
                if 'loss' in compile_dict else None
            # [2:-2] drops the two leading and two trailing characters of the
            # metrics value (presumably the "['" and "']" wrapping one name).
            local_metric = compile_dict['metrics'].lower()[2:-2] \
                if 'metrics' in compile_dict else None
            if local_loss and local_loss not in builtin_losses:
                custom_fn_names.append(local_loss)
                custom_fn_mst_idx.append(mst_idx)
            if local_metric and local_metric not in builtin_metrics:
                custom_fn_names.append(local_metric)
                custom_fn_mst_idx.append(mst_idx)

        if custom_fn_names:
            # Pass only unique custom_fn_names to query from object table
            custom_fn_object_map = query_custom_functions_map(
                self.object_table, list(set(custom_fn_names)))
            for mst_idx in custom_fn_mst_idx:
                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map

    def create_mst_schedule_table(self, mst_row):
        """Create the schedule table for one hop.

        :param mst_row: one entry per distribution key (in the order of
                        ``self.dist_keys``); an entry is either an mst or
                        None when that key has no model in this hop
        """
        mst_temp_query = """
                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
                                ({self.model_id_col} INTEGER,
                                 {self.compile_params_col} VARCHAR,
                                 {self.fit_params_col} VARCHAR,
                                 {dist_key_col} INTEGER,
                                 {self.mst_key_col} INTEGER,
                                 {self.object_map_col} BYTEA)
                         """.format(self=self, dist_key_col=dist_key_col)
        plpy.execute(mst_temp_query)
        for mst, dist_key in zip(mst_row, self.dist_keys):
            if mst:
                model_id = mst[self.model_id_col]
                compile_params = mst[self.compile_params_col]
                fit_params = mst[self.fit_params_col]
                mst_key = mst[self.mst_key_col]
                object_map = mst[self.object_map_col]
            else:
                # Placeholder row. The integer columns become SQL NULL, while
                # the two dollar-quoted text columns receive the string NULL.
                model_id = "NULL"
                compile_params = "NULL"
                fit_params = "NULL"
                mst_key = "NULL"
                object_map = None
            mst_insert_query = plpy.prepare(
                """
                    INSERT INTO {self.mst_current_schedule_tbl}
                        VALUES ({model_id},
                                $madlib${compile_params}$madlib$,
                                $madlib${fit_params}$madlib$,
                                {dist_key},
                                {mst_key},
                                $1)
                """.format(self=self,
                           model_id=model_id,
                           compile_params=compile_params,
                           fit_params=fit_params,
                           dist_key=dist_key,
                           mst_key=mst_key), ["BYTEA"])
            plpy.execute(mst_insert_query, [object_map])

    def create_model_output_table(self):
        """Create the model output table and fill it with initial weights."""
        output_table_create_query = """
                                    CREATE TABLE {self.model_output_table}
                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
                                     {self.model_weights_col} BYTEA,
                                     {self.model_arch_col} JSON)
                                    """.format(self=self)
        plpy.execute(output_table_create_query)
        self.initialize_model_output_and_info()

    def create_model_output_table_warm_start(self):
        """
        For warm start, we need to copy the model output table to a temp table
        because we call truncate on the model output table while training.
        If the query gets aborted, we need to make sure that the user passed
        model output table can be recovered.
        """
        plpy.execute("""
            CREATE TABLE {self.model_output_table} (
            LIKE {self.original_model_output_table} INCLUDING indexes);
            """.format(self=self))

        plpy.execute("""INSERT INTO {self.model_output_table}
            SELECT * FROM {self.original_model_output_table};
            """.format(self=self))

        # Keep only the models that are still in the model selection table.
        plpy.execute(""" DELETE FROM {self.model_output_table}
                WHERE {self.mst_key_col} NOT IN (
                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
            """.format(self=self))
        # array_agg over zero rows yields NULL (None); fall back to an empty
        # list so the membership tests during initialization keep working.
        self.warm_start_msts = plpy.execute(
            """ SELECT array_agg({0}) AS a FROM {1}
            """.format(self.mst_key_col, self.model_output_table))[0]['a'] or []
        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
        self.initialize_model_output_and_info()

    def initialize_model_output_and_info(self):
        """Create the info table and seed both output tables for every mst.

        One info row is inserted per mst. A row is added to the model output
        table only for msts that do not already have warm start weights.
        """
        info_table_create_query = """
                                  CREATE TABLE {self.model_info_table}
                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
                                   {self.model_id_col} INTEGER,
                                   {self.compile_params_col} TEXT,
                                   {self.fit_params_col} TEXT,
                                   model_type TEXT,
                                   model_size DOUBLE PRECISION,
                                   metrics_elapsed_time DOUBLE PRECISION[],
                                   metrics_type TEXT[],
                                   loss_type TEXT,
                                   training_metrics_final DOUBLE PRECISION,
                                   training_loss_final DOUBLE PRECISION,
                                   training_metrics DOUBLE PRECISION[],
                                   training_loss DOUBLE PRECISION[],
                                   validation_metrics_final DOUBLE PRECISION,
                                   validation_loss_final DOUBLE PRECISION,
                                   validation_metrics DOUBLE PRECISION[],
                                   validation_loss DOUBLE PRECISION[])
                                       """.format(self=self)

        plpy.execute(info_table_create_query)
        for mst in self.msts:
            mst_key = mst[self.mst_key_col]
            has_warm_start_weights = mst_key in self.warm_start_msts
            model_arch, model_weights = get_model_arch_weights(
                self.model_arch_table, mst[self.model_id_col])

            # If warm start is enabled, weights from transfer learning cannot be
            # used, even if a particular model doesn't have warm start weights.
            if self.warm_start:
                model_weights = None
                mst_filter = """
                            WHERE {mst_col}={mst_key}
                        """.format(
                    mst_col=self.mst_key_col,
                    mst_key=mst_key
                )
            else:
                mst_filter = ''

            serialized_weights = get_initial_weights(
                self.model_output_table,
                model_arch,
                model_weights,
                has_warm_start_weights,
                self.use_gpus,
                self.accessible_gpus_for_seg,
                mst_filter)
            model_size = sys.getsizeof(serialized_weights) / 1024.0

            metrics_list = get_metrics_from_compile_param(
                mst[self.compile_params_col])
            metrics_type = 'ARRAY{0}'.format(metrics_list) \
                if metrics_list else 'NULL'

            # Quote the loss name only when there is one, so that a missing
            # loss is stored as SQL NULL and not as the text 'NULL'.
            loss_name = get_loss_from_compile_param(
                mst[self.compile_params_col])
            loss_type = "'{0}'".format(loss_name) if loss_name else 'NULL'

            info_table_insert_query = """
                            INSERT INTO {self.model_info_table}({self.mst_key_col},
                                        {self.model_id_col}, {self.compile_params_col},
                                        {self.fit_params_col}, model_type, model_size,
                                        metrics_type, loss_type)
                                VALUES ({mst_key_val}, {model_id},
                                        $madlib${compile_params}$madlib$,
                                        $madlib${fit_params}$madlib$, '{model_type}',
                                        {model_size}, {metrics_type}, {loss_type})
                        """.format(self=self,
                                   mst_key_val=mst_key,
                                   model_id=mst[self.model_id_col],
                                   compile_params=mst[self.compile_params_col],
                                   fit_params=mst[self.fit_params_col],
                                   model_type='madlib_keras',
                                   model_size=model_size,
                                   metrics_type=metrics_type,
                                   loss_type=loss_type)
            plpy.execute(info_table_insert_query)

            if not has_warm_start_weights:
                output_table_insert_query = """
                                    INSERT INTO {self.model_output_table}(
                                        {self.mst_key_col}, {self.model_weights_col},
                                        {self.model_arch_col})
                                    VALUES ({mst_key}, $1, $2)
                                       """.format(self=self,
                                                  mst_key=mst_key)
                output_table_insert_query_prepared = plpy.prepare(
                    output_table_insert_query, ["bytea", "json"])
                plpy.execute(output_table_insert_query_prepared, [
                             serialized_weights, model_arch])

    def create_model_summary_table(self):
        """Write the one-row summary table describing this training run.

        For warm start the previous summary table is dropped first. The
        ``model`` column always holds the user-supplied output table name,
        because the temporary warm start table is renamed to it afterwards.
        """
        if self.warm_start:
            plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
        src_summary_dict = get_source_summary_table_dict(
            self.fit_validator_train)
        class_values = src_summary_dict['class_values']
        class_values_type = src_summary_dict['class_values_type']
        dep_vartype = src_summary_dict['dep_vartype']
        dependent_varname = \
            src_summary_dict['dependent_varname_in_source_table']
        independent_varname = \
            src_summary_dict['independent_varname_in_source_table']
        norm_const = src_summary_dict['norm_const']

        # SQL literals are kept in locals; instance attributes keep the plain
        # table names.
        validation_table = self._sql_text_or_null(self.validation_table)
        name = self._sql_text_or_null(self.name)
        descr = self._sql_text_or_null(self.description)
        object_table = self._sql_text_or_null(self.object_table)
        num_classes = 'NULL' if class_values is None else len(class_values)
        if self.metrics_iters:
            metrics_iters = "ARRAY{0}::INTEGER[]".format(self.metrics_iters)
        else:
            metrics_iters = "NULL::INTEGER[]"

        create_query = plpy.prepare("""
                CREATE TABLE {self.model_summary_table} AS
                SELECT
                    $MAD${self.source_table}$MAD$::TEXT AS source_table,
                    {validation_table}::TEXT AS validation_table,
                    $MAD${self.original_model_output_table}$MAD$::TEXT AS model,
                    $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
                    $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
                    $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                    $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
                    $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
                    {object_table}::TEXT AS object_table,
                    {self.num_iterations}::INTEGER AS num_iterations,
                    {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
                    {self.warm_start} AS warm_start,
                    {name}::TEXT AS name,
                    {descr}::TEXT AS description,
                    '{self.start_training_time}'::TIMESTAMP AS start_training_time,
                    '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                    '{self.version}'::TEXT AS madlib_version,
                    {num_classes}::INTEGER AS num_classes,
                    $1 AS {class_values_colname},
                    $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
                    {norm_const}::{float32_sql_type} AS {normalizing_const_colname},
                    {metrics_iters} AS metrics_iters
            """.format(self=self,
                       validation_table=validation_table,
                       dependent_varname=dependent_varname,
                       independent_varname=independent_varname,
                       object_table=object_table,
                       name=name,
                       descr=descr,
                       num_classes=num_classes,
                       class_values_colname=CLASS_VALUES_COLNAME,
                       dep_vartype=dep_vartype,
                       dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
                       norm_const=norm_const,
                       float32_sql_type=FLOAT32_SQL_TYPE,
                       normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                       metrics_iters=metrics_iters),
            [class_values_type])
        plpy.execute(create_query, [class_values])

    def update_info_table(self, mst, is_train):
        """Write the collected metrics and losses of one mst to the info table.

        :param mst: model selection tuple whose row is updated
        :param is_train: True to fill the training_* columns, False to fill
                         the validation_* columns
        """
        mst_key = mst[self.mst_key_col]
        if is_train:
            column_prefix = 'training'
            mst_metric = self.train_mst_metric
            mst_metric_eval_time = self.train_mst_metric_eval_time
            mst_loss = self.train_mst_loss
        else:
            column_prefix = 'validation'
            mst_metric = self.valid_mst_metric
            mst_metric_eval_time = self.valid_mst_metric_eval_time
            mst_loss = self.valid_mst_loss

        # Anything that was never recorded is written as SQL NULL. Membership
        # is tested first so the defaultdicts do not gain empty entries.
        metrics = metrics_final = metrics_elapsed_time = "NULL"
        loss = loss_final = "NULL"
        if mst_key in mst_metric:
            metrics_final, metrics = get_metrics_sql_string(
                mst_metric[mst_key])
            metrics_elapsed_time = "ARRAY{}".format(
                mst_metric_eval_time[mst_key])
        if mst_key in mst_loss:
            loss_final, loss = get_metrics_sql_string(mst_loss[mst_key])

        # NOTE: metrics_elapsed_time is a single column shared by both calls,
        # so the validation update replaces the value written for training.
        update_query = """
                       UPDATE {self.model_info_table} SET
                       {column_prefix}_metrics_final = {metrics_final},
                       {column_prefix}_loss_final = {loss_final},
                       metrics_elapsed_time = {metrics_elapsed_time},
                       {column_prefix}_metrics = {metrics},
                       {column_prefix}_loss = {loss}
                       WHERE {self.mst_key_col} = {mst_key}
                   """.format(self=self,
                              column_prefix=column_prefix,
                              metrics_final=metrics_final,
                              loss_final=loss_final,
                              metrics_elapsed_time=metrics_elapsed_time,
                              metrics=metrics,
                              loss=loss,
                              mst_key=mst_key)
        plpy.execute(update_query)

    def insert_info_table(self):
        """Store training (and, if present, validation) results of every mst."""
        for mst in self.msts:
            self.update_info_table(mst, True)
            if self.validation_table:
                self.update_info_table(mst, False)

    def run_training(self, mst_idx, is_very_first_hop):
        """Train the models scheduled for the current hop and save the weights.

        :param mst_idx: index of the current hop (not used in this method)
        :param is_very_first_hop: True only for the first hop of the first
                                  iteration; with caching enabled this is the
                                  hop on which the source data is passed in
        """
        # NOTE: In the DL module, we want to avoid CREATING TEMP tables
        # (creates a slice which stays until the session is disconnected)
        # or minimize writing queries that generate plans with Motions (creating
        # multiple slices on segments).
        # This is mainly to avoid any GPU memory allocation failures. Since GPU
        # memory allocation is tied to the process where it is initialized, failures
        # may occur when a newly created slice(process) tries allocating GPU memory
        # which is already allocated by a previously created slice(process).
        # Therefore we want to have queries that do not add motions and all the
        # sub-queries running Keras/tensorflow operations reuse the same slice(process)
        # that was used for initializing GPU memory.
        mst_weights_query = """
            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
                       model_arch_tbl.{self.model_arch_col}
                FROM
                    {self.mst_current_schedule_tbl} mst_tbl
                    LEFT JOIN {self.model_output_table} wgh_tbl
                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
                        LEFT JOIN {self.model_arch_table} model_arch_tbl
                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
                DISTRIBUTED BY ({dist_key_col})
        """.format(self=self, dist_key_col=dist_key_col)
        plpy.execute(mst_weights_query)

        dep_shape_col = self.dep_shape_col
        ind_shape_col = self.ind_shape_col
        dep_var = mb_dep_var_col
        indep_var = mb_indep_var_col
        source_table = self.source_table
        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
        if self.use_caching:
            # Caching populates the independent_var and dependent_var into the cache on the very first hop
            # For the very_first_hop, we want to run the transition function on all segments, including
            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
            # on mst_key. Once the cache is populated with the independent_var and dependent_var values,
            # all subsequent hops pass independent_var and dependent_var as NULLs and use a dummy src
            # table to join for referencing the dist_key
            if is_very_first_hop:
                plpy.execute("""
                    DROP TABLE IF EXISTS {self.cached_source_table};
                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
                    """.format(self=self, dist_key_col=dist_key_col))
            else:
                dep_shape_col = 'ARRAY[0]'
                ind_shape_col = 'ARRAY[0]'
                dep_var = 'NULL'
                indep_var = 'NULL'
                source_table = self.cached_source_table
            if is_very_first_hop or self.is_final_training_call:
                where_clause = ""

        uda_query = """
            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                {mb_indep_var_col},
                {dep_shape_col},
                {ind_shape_col},
                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
                src.{dist_key_col},
                ARRAY{self.dist_key_mapping},
                src.{self.gp_segment_id_col},
                {self.segments_per_host},
                ARRAY{self.images_per_seg_train},
                ARRAY{self.accessible_gpus_for_seg},
                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
                {is_final_training_call}::BOOLEAN,
                {use_caching}::BOOLEAN,
                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                )::BYTEA AS {self.model_weights_col},
                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                ,src.{dist_key_col} AS {dist_key_col}
            FROM {source_table} src JOIN {self.mst_weights_tbl}
                USING ({dist_key_col})
            {where_clause}
            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
            DISTRIBUTED BY({dist_key_col})
            """.format(mb_dep_var_col=dep_var,
                       mb_indep_var_col=indep_var,
                       dep_shape_col=dep_shape_col,
                       ind_shape_col=ind_shape_col,
                       is_final_training_call=self.is_final_training_call,
                       use_caching=self.use_caching,
                       dist_key_col=dist_key_col,
                       source_table=source_table,
                       where_clause=where_clause,
                       self=self)
        plpy.execute(uda_query)

        update_query = """
            UPDATE {self.model_output_table}
            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
            FROM {self.weights_to_update_tbl}
            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
        """.format(self=self)
        plpy.execute(update_query)

        self.truncate_and_drop_tables()

    def truncate_and_drop_tables(self):
        """
        Context: UPDATE statements in postgres are not in-place replacements but
        the row to be updated is marked for deletion (note that the disk space for
        this row doesn't get released until vacuum is called) and a new row is
        inserted.

        This function will clear out the disk space used by the model_output_table
        and also drop all the other intermediate tables.
        If available, set the `dev_opt_unsafe_truncate_in_subtransaction` guc so
        that the truncate command can release the disk space. The disk space will
        be released immediately and hence the model_output table won't grow in
        size with each UPDATE statement.

        Without this guc, the disk space won't be released and each
        call to the UPDATE statement will keep adding to the disk space. The disk
        space will only be released when the query is completed.

        The guc can cause data loss if not used properly. Since truncate will
        actually clear the disk space immediately, there is no way to recover to
        the state before truncate was called on that table. So this guc should only
        be set for intermediate tables and never for tables created outside the
        scope of the fit_multiple udf.

        Workflow
        1. Create temp table from model table (including the indexes)
        2. truncate the model table to release disk space
        3. rename temp table to model table so that it can be reused for the next
        hop
        :return:
        """

        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
            temp_model_table = unique_string('updated_model')
            # The copy made on the final training call is the one that
            # remains afterwards, so it is not created as UNLOGGED.
            unlogged_table = self.unlogged_table \
                if not self.is_final_training_call else ''
            plpy.execute("""
            CREATE {unlogged_table} TABLE {temp_model_table} ( LIKE {self.model_output_table}
            INCLUDING indexes);""".format(temp_model_table=temp_model_table,
                                          unlogged_table=unlogged_table,
                                          self=self))
            plpy.execute("""
            INSERT INTO {temp_model_table} SELECT * FROM {self.model_output_table};
            TRUNCATE TABLE {self.model_output_table};
            DROP TABLE {self.model_output_table};
            """.format(temp_model_table=temp_model_table, self=self))
            rename_table(self.schema_madlib, temp_model_table,
                         self.model_output_table)
            plpy.execute("""
            TRUNCATE TABLE {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
            {self.weights_to_update_tbl};
            DROP TABLE IF EXISTS {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
            {self.weights_to_update_tbl};""".format(self=self))