| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import math |
| import plpy |
| import time |
| from madlib_keras_automl import KerasAutoML, AutoMLConstants |
| from utilities.utilities import get_current_timestamp, get_seg_number, get_segments_per_host, \ |
| unique_string, add_postfix, extract_keyvalue_params, _assert, _assert_equal, rename_table, \ |
| is_platform_pg |
| from utilities.control import SetGUC |
| from madlib_keras_fit_multiple_model import FitMultipleModel |
| from madlib_keras_model_selection import MstSearch, ModelSelectionSchema |
| from utilities.validate_args import table_exists, drop_tables, input_tbl_valid |
| |
class HyperbandSchedule():
    """The utility class for loading a hyperband schedule table with algorithm inputs.

    Attributes:
        schedule_table (string): Name of output table containing hyperband schedule.
        R (int): Maximum number of resources (iterations) that can be allocated to a single configuration.
        eta (int): Controls the proportion of configurations discarded in each round of successive halving.
        skip_last (int): The number of last rounds to skip.
    """
    def __init__(self, schedule_table, R, eta=3, skip_last=0):
        # Hyperband evaluates many configurations concurrently, which needs a
        # multi-segment (Greenplum) deployment; plain PostgreSQL is rejected.
        if is_platform_pg():
            plpy.error(
                "DL: Hyperband schedule is not supported on PostgreSQL.")
        self.schedule_table = schedule_table # table name to store hyperband schedule
        self.R = R # maximum iterations/epochs allocated to a configuration
        self.eta = eta # defines downsampling rate
        self.skip_last = skip_last
        self.module_name = 'hyperband_schedule'
        self.validate_inputs()

        # number of unique executions of Successive Halving (minus one)
        self.s_max = int(math.floor(math.log(self.R, self.eta)))
        self.validate_s_max()

        self.schedule_vals = []

        self.calculate_schedule()

    def load(self):
        """
        The entry point for loading the hyperband schedule table.
        """
        self.create_schedule_table()
        self.insert_into_schedule_table()

    def validate_inputs(self):
        """
        Validates user input values
        """
        _assert(self.eta > 1, "{0}: eta must be greater than 1".format(self.module_name))
        _assert(self.R >= self.eta, "{0}: R should not be less than eta".format(self.module_name))

    def validate_s_max(self):
        """Validates that skip_last lies in [0, s_max]."""
        # BUGFIX: the message was previously built as
        #     "{0}: skip_last must be " + "non-negative and less than {1}".format(...)
        # The method call binds tighter than '+', so .format() was applied only
        # to the second literal: the reported message kept a literal '{0}' and
        # never substituted the module name.  Use one implicitly-concatenated
        # literal (same style as AutoMLHyperband.validate_and_define_inputs)
        # so both placeholders are filled in.
        _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1,
                "{0}: skip_last must be "
                "non-negative and less than {1}".format(self.module_name, self.s_max))

    def calculate_schedule(self):
        """
        Calculates the hyperband schedule (number of configs and allocated resources)
        in each round of each bracket and skips the number of last rounds specified in 'skip_last'
        """
        for s in reversed(range(self.s_max+1)):
            n = int(math.ceil(int((self.s_max+1)/(s+1))*math.pow(self.eta, s))) # initial number of configurations
            r = self.R * math.pow(self.eta, -s)

            # Successive halving within bracket s: each round keeps 1/eta of
            # the configs and gives each survivor eta times the resources.
            for i in range((s+1) - int(self.skip_last)):
                # Computing each of the
                n_i = n*math.pow(self.eta, -i)
                r_i = r*math.pow(self.eta, i)

                self.schedule_vals.append({AutoMLConstants.BRACKET: s,
                                           AutoMLConstants.ROUND: i,
                                           AutoMLConstants.CONFIGURATIONS: int(n_i),
                                           AutoMLConstants.RESOURCES: int(round(r_i))})

    def create_schedule_table(self):
        """Initializes the output schedule table"""
        create_query = """
                        CREATE TABLE {self.schedule_table} (
                            {s} INTEGER,
                            {i} INTEGER,
                            {n_i} INTEGER,
                            {r_i} INTEGER,
                            unique ({s}, {i})
                        );
                       """.format(self=self,
                                  s=AutoMLConstants.BRACKET,
                                  i=AutoMLConstants.ROUND,
                                  n_i=AutoMLConstants.CONFIGURATIONS,
                                  r_i=AutoMLConstants.RESOURCES)
        plpy.execute(create_query)

    def insert_into_schedule_table(self):
        """Insert everything in self.schedule_vals into the output schedule table."""
        for sd in self.schedule_vals:
            sd_s = sd[AutoMLConstants.BRACKET]
            sd_i = sd[AutoMLConstants.ROUND]
            sd_n_i = sd[AutoMLConstants.CONFIGURATIONS]
            sd_r_i = sd[AutoMLConstants.RESOURCES]
            insert_query = """
                            INSERT INTO
                                {self.schedule_table}(
                                    {s_col},
                                    {i_col},
                                    {n_i_col},
                                    {r_i_col}
                                )
                            VALUES (
                                {sd_s},
                                {sd_i},
                                {sd_n_i},
                                {sd_r_i}
                            )
                           """.format(s_col=AutoMLConstants.BRACKET,
                                      i_col=AutoMLConstants.ROUND,
                                      n_i_col=AutoMLConstants.CONFIGURATIONS,
                                      r_i_col=AutoMLConstants.RESOURCES,
                                      **locals())
            plpy.execute(insert_query)
| |
class AutoMLHyperband(KerasAutoML):
    """
    This class implements Hyperband, an infinite-arm bandit based algorithm that speeds up random search
    through adaptive resource allocation, successive halving (SHA), and early stopping.

    This class showcases a novel hyperband implementation by executing the hyperband rounds 'diagonally'
    to evaluate multiple configurations together and leverage the compute power of MPP databases such as Greenplum.

    This automl method inherits qualities from the automl class.
    """
    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
                 model_id_list, compile_params_grid, fit_params_grid, automl_method,
                 automl_params, random_state=None, object_table=None,
                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
                 name=None, description=None, use_caching=False, **kwargs):
        # Fall back to hyperband defaults when the caller leaves these unset.
        automl_method = automl_method if automl_method else AutoMLConstants.HYPERBAND
        automl_params = automl_params if automl_params else 'R=6, eta=3, skip_last=0'
        KerasAutoML.__init__(self, schema_madlib, source_table, model_output_table, model_arch_table,
                             model_selection_table, model_id_list, compile_params_grid, fit_params_grid,
                             automl_method, automl_params, random_state, object_table, use_gpus,
                             validation_table, metrics_compute_frequency, name, description, use_caching, **kwargs)
        self.validate_and_define_inputs()
        self.create_model_output_table()
        self.create_model_output_info_table()
        self.find_hyperband_config()

    def validate_and_define_inputs(self):
        """Parses and validates 'automl_params', defining self.R, self.eta and
        self.skip_last, plus the derived number of SHA brackets (self.s_max + 1).

        NOTE(review): if the user supplies only a subset of R/eta/skip_last the
        missing attributes are never set and the asserts below raise
        AttributeError rather than a clean validation error — confirm whether
        partial specification is meant to be supported.
        """
        automl_params_dict = extract_keyvalue_params(self.automl_params,
                                                     lower_case_names=False)
        # casting dict values to int
        for i in automl_params_dict:
            _assert(i in AutoMLConstants.HYPERBAND_PARAMS,
                    "{0}: Invalid param(s) passed in for hyperband. "\
                    "Only R, eta, and skip_last may be specified".format(self.module_name))
            automl_params_dict[i] = int(automl_params_dict[i])
        _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
                "{0}: Only R, eta, and skip_last may be specified".format(self.module_name))
        for i in automl_params_dict:
            if i == AutoMLConstants.R:
                self.R = automl_params_dict[AutoMLConstants.R]
            elif i == AutoMLConstants.ETA:
                self.eta = automl_params_dict[AutoMLConstants.ETA]
            elif i == AutoMLConstants.SKIP_LAST:
                self.skip_last = automl_params_dict[AutoMLConstants.SKIP_LAST]
            else:
                plpy.error("{0}: {1} is an invalid automl param".format(self.module_name, i))
        _assert(self.eta > 1, "{0}: eta must be greater than 1".format(self.module_name))
        _assert(self.R >= self.eta, "{0}: R should not be less than eta".format(self.module_name))
        # number of unique executions of Successive Halving (minus one)
        self.s_max = int(math.floor(math.log(self.R, self.eta)))
        _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "{0}: skip_last must be " \
                "non-negative and less than {1}".format(self.module_name, self.s_max))

    def find_hyperband_config(self):
        """
        Executes the diagonal hyperband algorithm: samples all needed model
        configurations up front, then trains the brackets' rounds together one
        diagonal at a time with FitMultipleModel, pruning configs between
        diagonals via successive halving.
        """
        initial_vals = {}

        # get hyper parameter configs for each s (keys inserted s_max .. 0)
        for s in reversed(range(self.s_max+1)):
            n = int(math.ceil(int((self.s_max+1)/(s+1))*math.pow(self.eta, s))) # initial number of configurations
            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
            initial_vals[s] = (n, int(round(r)))
        self.start_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)
        # Total configs to sample: the diagonal loop below never evaluates
        # brackets s < skip_last (its inner loop bottoms out at s = skip_last),
        # so their counts are excluded.  BUGFIX: iterate the keys in explicit
        # ascending order instead of relying on dict iteration order — under
        # Python 3 the dict iterates in insertion order (s_max..0), which made
        # the slice drop the *largest* brackets' counts (the ones that ARE
        # evaluated) rather than the smallest.
        random_search = MstSearch(self.schema_madlib,
                                  self.model_arch_table,
                                  self.model_selection_table,
                                  self.model_id_list,
                                  self.compile_params_grid,
                                  self.fit_params_grid,
                                  'random',
                                  sum([initial_vals[k][0] for k in sorted(initial_vals)][self.skip_last:]),
                                  self.random_state,
                                  self.object_table)
        random_search.load() # for populating mst tables

        # for creating the summary table for usage in fit multiple
        plpy.execute("CREATE TABLE {AutoMLSchema.MST_SUMMARY_TABLE} AS " \
                     "SELECT * FROM {random_search.model_selection_summary_table}".format(AutoMLSchema=AutoMLConstants,
                                                                                         random_search=random_search))
        ranges_dict = self.mst_key_ranges_dict(initial_vals)
        # to store the bracket and round numbers
        s_dict, i_dict = {}, {}
        for key, val in ranges_dict.items():
            for mst_key in range(val[0], val[1]+1):
                s_dict[mst_key] = key
                i_dict[mst_key] = -1

        # outer loop on diagonal
        metrics_elapsed_time_offset = 0
        for i in range((self.s_max+1) - int(self.skip_last)):
            # inner loop on s desc: diagonal i covers round (s - s_max + i)
            # of every bracket s in [s_max - i, s_max]
            temp_lst = []
            configs_prune_lookup = {}
            for s in range(self.s_max, self.s_max-i-1, -1):
                n = initial_vals[s][0]
                n_i = n * math.pow(self.eta, -i+self.s_max-s)
                configs_prune_lookup[s] = int(round(n_i))
                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
            num_iterations = int(initial_vals[self.s_max-i][1])
            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
                num_iterations))

            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup) # has keys to evaluate
            active_keys = plpy.execute("SELECT {ModelSelectionSchema.MST_KEY} " \
                                       "FROM {AutoMLSchema.MST_TABLE}".format(AutoMLSchema=AutoMLConstants,
                                                                              ModelSelectionSchema=ModelSelectionSchema))
            for k in active_keys:
                i_dict[k[ModelSelectionSchema.MST_KEY]] += 1
            # warm start from the previous diagonal's weights (all but the first)
            self.warm_start = int(i != 0)
            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
            start_time = time.time()
            with SetGUC("plan_cache_mode", "force_generic_plan"):
                model_training = FitMultipleModel(self.schema_madlib,
                                                  self.source_table,
                                                  AutoMLConstants.MODEL_OUTPUT_TABLE,
                                                  AutoMLConstants.MST_TABLE,
                                                  num_iterations, self.use_gpus,
                                                  self.validation_table, mcf,
                                                  self.warm_start, self.name,
                                                  self.description,
                                                  self.use_caching,
                                                  metrics_elapsed_time_offset)
                model_training.fit_multiple_model()
            metrics_elapsed_time_offset += time.time() - start_time
            self.update_model_output_table()
            self.update_model_output_info_table(i, initial_vals)

            self.print_best_mst_so_far()

        self.end_training_time = get_current_timestamp(AutoMLConstants.TIME_FORMAT)
        self.add_additional_info_cols(s_dict, i_dict)
        self.update_model_selection_table()
        self.generate_model_output_summary_table()
        self.remove_temp_tables()

    def mst_key_ranges_dict(self, initial_vals):
        """
        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
        executing a particular SHA bracket.  Keys are assigned contiguously,
        bracket s_max first (starting at 1), then each smaller bracket after it.
        """
        d = {}
        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
            if s_val == self.s_max:
                d[s_val] = (1, initial_vals[s_val][0])
            else:
                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
        return d

    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
        """
        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
        :param i: outer diagonal loop iteration.
        :param ranges_dict: model config ranges to group by bracket number.
        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
        :return:
        """
        if i == 0:
            # first diagonal: only bracket s_max, round 0 — take its full range
            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
            lower_bound, upper_bound = ranges_dict[self.s_max]
            plpy.execute("CREATE TABLE {AutoMLSchema.MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
                         "WHERE {ModelSelectionSchema.MST_KEY} >= {lower_bound} " \
                         "AND {ModelSelectionSchema.MST_KEY} <= {upper_bound}".format(self=self,
                                                                                      AutoMLSchema=AutoMLConstants,
                                                                                      lower_bound=lower_bound,
                                                                                      upper_bound=upper_bound,
                                                                                      ModelSelectionSchema=ModelSelectionSchema))
            return
        # dropping and repopulating temp_mst_table
        drop_tables([AutoMLConstants.MST_TABLE])

        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
        create_query = """
                        CREATE TABLE {AutoMLSchema.MST_TABLE} (
                            {mst_key} INTEGER,
                            {model_id} INTEGER,
                            {compile_params} VARCHAR,
                            {fit_params} VARCHAR,
                            unique ({model_id}, {compile_params}, {fit_params})
                        );
                       """.format(AutoMLSchema=AutoMLConstants,
                                  mst_key=ModelSelectionSchema.MST_KEY,
                                  model_id=ModelSelectionSchema.MODEL_ID,
                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
        plpy.execute(create_query)

        query = ""
        new_configs = True
        for s_val in configs_prune_lookup:
            lower_bound, upper_bound = ranges_dict[s_val]
            if new_configs:
                # newest bracket entering round 0: take all of its sampled configs
                query += "INSERT INTO {AutoMLSchema.MST_TABLE} SELECT {ModelSelectionSchema.MST_KEY}, " \
                         "{ModelSelectionSchema.MODEL_ID}, {ModelSelectionSchema.COMPILE_PARAMS}, " \
                         "{ModelSelectionSchema.FIT_PARAMS} FROM {self.model_selection_table} WHERE " \
                         "{ModelSelectionSchema.MST_KEY} >= {lower_bound} AND {ModelSelectionSchema.MST_KEY} <= " \
                         "{upper_bound};".format(self=self, AutoMLSchema=AutoMLConstants,
                                                 ModelSelectionSchema=ModelSelectionSchema,
                                                 lower_bound=lower_bound, upper_bound=upper_bound)
                new_configs = False
            else:
                # surviving brackets: keep only the best (lowest-loss) configs
                query += "INSERT INTO {AutoMLSchema.MST_TABLE} SELECT {ModelSelectionSchema.MST_KEY}, " \
                         "{ModelSelectionSchema.MODEL_ID}, {ModelSelectionSchema.COMPILE_PARAMS}, " \
                         "{ModelSelectionSchema.FIT_PARAMS} " \
                         "FROM {self.model_info_table} WHERE {ModelSelectionSchema.MST_KEY} >= {lower_bound} " \
                         "AND {ModelSelectionSchema.MST_KEY} <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLConstants,
                                                                    ModelSelectionSchema=ModelSelectionSchema,
                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
        plpy.execute(query)

    def update_model_output_table(self):
        """
        Updates gathered information of a hyperband diagonal run to the overall model output table.
        """
        # updates model weights for any previously trained configs
        plpy.execute("UPDATE {self.model_output_table} a SET model_weights=" \
                     "t.model_weights FROM {AutoMLSchema.MODEL_OUTPUT_TABLE} t " \
                     "WHERE a.{mst_key}=t.{mst_key}".format(self=self,
                                                            mst_key=ModelSelectionSchema.MST_KEY,
                                                            AutoMLSchema=AutoMLConstants))

        # truncate and re-creates table to avoid memory blow-ups
        with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
            temp_model_table = unique_string('updated_model')
            plpy.execute("CREATE TABLE {temp_model_table} AS SELECT * FROM {self.model_output_table};" \
                         "TRUNCATE {self.model_output_table}; " \
                         "DROP TABLE {self.model_output_table};".format(temp_model_table=temp_model_table, self=self))
            rename_table(self.schema_madlib, temp_model_table, self.model_output_table)

        # inserts any newly trained configs
        plpy.execute("""
            INSERT INTO {self.model_output_table} SELECT *
            FROM {AutoMLSchema.MODEL_OUTPUT_TABLE}
            WHERE {AutoMLSchema.MODEL_OUTPUT_TABLE}.mst_key NOT IN
                ( SELECT {ModelSelectionSchema.MST_KEY} FROM {self.model_output_table} )
            """.format(self=self, AutoMLSchema=AutoMLConstants, ModelSelectionSchema=ModelSelectionSchema)
        )

    def update_model_output_info_table(self, i, initial_vals):
        """
        Updates gathered information of a hyperband diagonal run to the overall model output info table.
        :param i: outer diagonal loop iteration.
        :param initial_vals: Dictionary of initial configurations and resources as part of the initial hyperband
        schedule.
        """
        # normalizing factor for metrics_iters due to warm start: epochs already
        # run in the previous i diagonals, i.e. r(s_max) + ... + r(s_max-i+1).
        # BUGFIX: iterate keys in explicit descending order instead of relying
        # on dict iteration order — under Python 3 insertion order (s_max..0)
        # the old values()[::-1][:i] summed the *smallest* brackets' resources.
        epochs_factor = sum([initial_vals[k][1] for k in sorted(initial_vals, reverse=True)][:i])
        iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
                             "FROM {AutoMLSchema.MODEL_SUMMARY_TABLE}".format(AutoMLSchema=AutoMLConstants))
        metrics_iters_val = [epochs_factor+mi for mi in iters[0]['metrics_iters']] # global iteration counter

        validation_update_q = "validation_metrics_final=t.validation_metrics_final, " \
                              "validation_loss_final=t.validation_loss_final, " \
                              "validation_metrics=a.validation_metrics || t.validation_metrics, " \
                              "validation_loss=a.validation_loss || t.validation_loss, " \
            if self.validation_table else ""

        # updates train/val info for any previously trained configs
        plpy.execute("UPDATE {self.model_info_table} a SET " \
                     "metrics_elapsed_time=a.metrics_elapsed_time || t.metrics_elapsed_time, " \
                     "training_metrics_final=t.training_metrics_final, " \
                     "training_loss_final=t.training_loss_final, " \
                     "training_metrics=a.training_metrics || t.training_metrics, " \
                     "training_loss=a.training_loss || t.training_loss, ".format(self=self) + validation_update_q +
                     "{AutoMLSchema.METRICS_ITERS}=a.metrics_iters || ARRAY{metrics_iters_val}::INTEGER[] " \
                     "FROM {AutoMLSchema.MODEL_INFO_TABLE} t " \
                     "WHERE a.{mst_key}=t.{mst_key}".format(AutoMLSchema=AutoMLConstants,
                                                            mst_key=ModelSelectionSchema.MST_KEY,
                                                            metrics_iters_val=metrics_iters_val))

        # inserts info about metrics and validation for newly trained model configs
        plpy.execute("INSERT INTO {self.model_info_table} SELECT t.*, ARRAY{metrics_iters_val}::INTEGER[] AS metrics_iters " \
                     "FROM {AutoMLSchema.MODEL_INFO_TABLE} t WHERE t.mst_key NOT IN " \
                     "(SELECT {ModelSelectionSchema.MST_KEY} FROM {self.model_info_table})".format(self=self,
                                                                                                   AutoMLSchema=AutoMLConstants,
                                                                                                   metrics_iters_val=metrics_iters_val,
                                                                                                   ModelSelectionSchema=ModelSelectionSchema))

    def add_additional_info_cols(self, s_dict, i_dict):
        """Adds s and i columns to the info table"""

        plpy.execute("ALTER TABLE {self.model_info_table} ADD COLUMN s int, ADD COLUMN i int;".format(self=self))

        # bulk-update bracket (s) and final round (i) per mst_key via unnest
        l = [(k, s_dict[k], i_dict[k]) for k in s_dict]
        query = "UPDATE {self.model_info_table} t SET s=b.s_val, i=b.i_val FROM unnest(ARRAY{l}) " \
                "b (key integer, s_val integer, i_val integer) WHERE t.mst_key=b.key".format(self=self, l=l)
        plpy.execute(query)