blob: 92875247083b2de043d82d5ac458c1ec8bab0776 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
import time
import sys
from madlib_keras import compute_loss_and_metrics
from madlib_keras import get_initial_weights
from madlib_keras import get_model_arch_weights
from madlib_keras import get_source_summary_table_dict
from madlib_keras import should_compute_metrics_this_iter
from madlib_keras_helper import *
from madlib_keras_model_selection import ModelSelectionSchema
from madlib_keras_validator import *
from madlib_keras_wrapper import *
from utilities.control import MinWarning
from utilities.control import OptimizerControl
from utilities.control import SetGUC
from utilities.utilities import add_postfix
from utilities.utilities import is_platform_gp6_or_up
from utilities.utilities import unique_string
from utilities.utilities import rotate
from utilities.utilities import madlib_version
from utilities.utilities import is_platform_pg
from utilities.utilities import get_seg_number
from utilities.utilities import get_segments_per_host
from utilities.utilities import rename_table
import json
from collections import defaultdict
import random
import datetime
from tensorflow.keras.models import *
# Short module-level aliases for column-name constants. The constants are not
# defined in this file; presumably they are provided by one of the wildcard
# imports above (TODO confirm which module exports them).
# Used below as the dependent-variable argument of the training query.
mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
# Used below as the independent-variable argument of the training query.
mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
# Column used below to join, group and distribute the intermediate tables.
dist_key_col = DISTRIBUTION_KEY_COLNAME
"""
FitMultipleModel: This class implements the Model Hopper technique for
training multiple models in parallel. The goal of this function is to train
multiple different models on the same data with different parameters.
The main advantage of this method over running the existing fit function in a
loop is avoiding inaccuracies caused by the model averaging. The basic idea of
model hopper is simple. Let's assume that there are n segments and c*n model
configurations. We begin with distributing these configs to segments. After
that, each segment trains its c models on the data it has locally for one
iteration. Once we have these models, we move them to a different segment
(hopping) and receive a different set of models in return. Once we have the new
models, we use that segment's data to refine them (similar to the warm start
functionality). Once every model has hopped through every segment, we consider
an iteration complete.
This method ensures that we don't have to average any model, and the loss and
accuracy are very close to the ideal case, where all of the data is in one
segment.
Note that this function is disabled for Postgres.
"""
@MinWarning("warning")
class FitMultipleModel():
def __init__(self, schema_madlib, source_table, model_output_table,
             model_selection_table, num_iterations,
             use_gpus=False, validation_table=None,
             metrics_compute_frequency=None, warm_start=False, name="",
             description="", use_caching=False, metrics_elapsed_time_offset=0, **kwargs):
    """
    Validate the inputs, train every model configuration and write the
    model output, info and summary tables.

    All of the work is driven from the constructor: it validates the
    arguments, builds the hopping schedule, creates the output tables, runs
    the training loop and finally updates the metadata tables.

    :param schema_madlib: schema name
    :param source_table: input table containing training dataset
    :param model_output_table: output table
    :param model_selection_table: input table containing model configs
    :param num_iterations: number of iterations to train
    :param use_gpus: determines whether GPUs are to be used for training
    :param validation_table: input table containing the validation dataset
    :param metrics_compute_frequency: Frequency to compute per-iteration metrics
                                      for the training dataset and validation dataset
                                      (defaults to num_iterations when None)
    :param warm_start: indicates whether to initialize weights with the coefficients
                       from the last call of the fit function
    :param name: name
    :param description: description
    :param use_caching: forwarded to the training query as its use_caching
                        flag; None is treated as False
    :param metrics_elapsed_time_offset: time elapsed for the previous call to fit_multiple
                                        (internal param used by automl to accumulate
                                        metrics_elapsed_time)
    :param kwargs: accepted for call compatibility and ignored
    """
    # set the random seed for visit order/scheduling
    random.seed(1)
    if is_platform_pg():
        plpy.error(
            "DL: Multiple model training is not supported on PostgreSQL.")
    self.source_table = source_table
    self.validation_table = validation_table
    self.model_selection_table = model_selection_table
    # Default the derived table names to None so that a missing table name
    # reaches the input validator below instead of failing here with an
    # AttributeError on an attribute that was never set.
    self.model_selection_summary_table = None
    if self.model_selection_table:
        self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')

    self.num_iterations = num_iterations
    self.metrics_compute_frequency = metrics_compute_frequency
    self.name = name
    self.description = description
    self.use_caching = use_caching if use_caching is not None else False
    self.module_name = 'madlib_keras_fit_multiple_model'
    self.schema_madlib = schema_madlib
    self.version = madlib_version(self.schema_madlib)
    self.mst_key_col = ModelSelectionSchema.MST_KEY
    self.model_id_col = ModelSelectionSchema.MODEL_ID
    self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
    self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
    self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
    self.model_weights_col = ModelArchSchema.MODEL_WEIGHTS
    self.model_arch_col = ModelArchSchema.MODEL_ARCH
    # Per-mst_key history of the training-set evaluations.
    self.train_mst_metric_eval_time = defaultdict(list)
    self.train_mst_loss = defaultdict(list)
    self.train_mst_metric = defaultdict(list)
    self.info_str = ""
    self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
    self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
    self.use_gpus = use_gpus
    self.segments_per_host = get_segments_per_host()
    self.cached_source_table = unique_string('cached_source_table')
    self.metrics_elapsed_time_offset = metrics_elapsed_time_offset
    if self.use_gpus:
        self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
            self.schema_madlib, self.segments_per_host, self.module_name)
    else:
        self.accessible_gpus_for_seg = get_seg_number()*[0]

    self.original_model_output_table = model_output_table
    self.model_info_table = None
    self.model_summary_table = None
    if self.original_model_output_table:
        self.model_info_table = add_postfix(self.original_model_output_table, '_info')
        self.model_summary_table = add_postfix(
            self.original_model_output_table, '_summary')
    self.model_output_table = self.original_model_output_table

    # For warm start, we need to copy the model output table to a temp table
    # because we call truncate on the model output table while training.
    # If the query gets aborted, we need to make sure that the user passed
    # model output table can be recovered.
    self.warm_start = bool(warm_start)
    self.warm_start_msts = []
    if self.warm_start:
        self.model_output_table = unique_string('initial_model')

    self.fit_validator_train = FitMultipleInputValidator(
        self.source_table, self.validation_table, self.original_model_output_table,
        self.model_selection_table, self.model_selection_summary_table,
        mb_dep_var_col, mb_indep_var_col, self.num_iterations,
        self.model_info_table, self.mst_key_col, self.model_arch_table_col,
        self.metrics_compute_frequency, self.warm_start, self.use_gpus,
        self.accessible_gpus_for_seg)
    if self.metrics_compute_frequency is None:
        self.metrics_compute_frequency = num_iterations

    self.msts = self.fit_validator_train.msts
    self.model_arch_table = self.fit_validator_train.model_arch_table
    self.object_table = self.fit_validator_train.object_table
    self.metrics_iters = []
    self.object_map_col = 'object_map'
    if self.object_table is not None:
        self.populate_object_map()

    # Remember the caller's CUDA setting; it is restored in the `finally`
    # below so that a failure during training cannot leave it modified.
    original_cuda_env = None
    if CUDA_VISIBLE_DEVICES_KEY in os.environ:
        original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]

    try:
        self.dist_key_mapping, self.images_per_seg_train = \
            get_image_count_per_seg_for_minibatched_data_from_db(
                self.source_table)

        if self.validation_table:
            self.valid_mst_metric_eval_time = defaultdict(list)
            self.valid_mst_loss = defaultdict(list)
            self.valid_mst_metric = defaultdict(list)
            self.dist_key_mapping_valid, self.images_per_seg_valid = \
                get_image_count_per_seg_for_minibatched_data_from_db(
                    self.validation_table)

        self.mst_weights_tbl = unique_string(desp='mst_weights')
        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')

        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
        # Pad the configurations with None so that every distribution key
        # has an entry in the schedule even with fewer models than keys.
        if len(self.msts) < len(self.dist_keys):
            self.msts_for_schedule = self.msts + [None] * \
                (len(self.dist_keys) - len(self.msts))
        else:
            # NOTE: this is the same list object as self.msts, so the
            # shuffle below also reorders self.msts in this branch.
            self.msts_for_schedule = self.msts

        random.shuffle(self.msts_for_schedule)
        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''

        if self.warm_start:
            self.create_model_output_table_warm_start()
        else:
            self.create_model_output_table()

        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
        self.fit_multiple_model()

        # Update and cleanup metadata tables
        self.insert_info_table()
        self.create_model_summary_table()
        if self.warm_start:
            self.cleanup_for_warm_start()
    finally:
        reset_cuda_env(original_cuda_env)
def fit_multiple_model(self):
    """
    Run the complete training loop and record when it started and ended.

    Sets self.start_training_time and self.end_training_time (datetime
    objects) and self.metrics_elapsed_start_time (time.time() value that
    evaluate_model() uses as the reference point for elapsed times).
    """
    # WARNING: set orca off to prevent unwanted redistribution
    with OptimizerControl(False):
        self.start_training_time = datetime.datetime.now()
        self.metrics_elapsed_start_time = time.time()
        self.train_multiple_model()
        self.end_training_time = datetime.datetime.now()
def cleanup_for_warm_start(self):
    """
    Replace the user's model output table with the table that was trained.

    1. drop the original model output table
    2. rename self.model_output_table to the original table name
    :return:
    """
    plpy.execute("DROP TABLE IF EXISTS {}".format(
        self.original_model_output_table))
    rename_table(self.schema_madlib,
                 self.model_output_table,
                 self.original_model_output_table)
def train_multiple_model(self):
    """
    Drive the training: for every iteration, run one hop per entry of the
    schedule, then evaluate the models when metrics are due.

    Each hop writes the current schedule table and calls run_training().
    A per-iteration summary is accumulated in self.info_str and reported
    through plpy.info. The cached source table is dropped at the very end.
    """
    hops_per_iteration = len(self.msts_for_schedule)
    last_hop = hops_per_iteration - 1
    for iteration in range(1, self.num_iterations + 1):
        for hop in range(hops_per_iteration):
            # One model config (or None) per distribution key for this hop.
            scheduled_msts = [self.grand_schedule[key][hop]
                              for key in self.dist_keys]
            self.create_mst_schedule_table(scheduled_msts)
            self.is_final_training_call = (
                iteration == self.num_iterations and hop == last_hop)
            if hop == 0:
                iteration_started = time.time()
            self.run_training(hop, hop == 0 and iteration == 1)
            if hop == last_hop:
                iteration_ended = time.time()
                self.info_str = "\tTime for training in iteration " \
                                "{0}: {1} sec\n".format(
                                    iteration,
                                    iteration_ended - iteration_started)

        metrics_due = should_compute_metrics_this_iter(
            iteration, self.metrics_compute_frequency, self.num_iterations)
        if metrics_due:
            self.metrics_iters.append(iteration)
            self.info_str += "\tTraining set after iteration {0}:".format(iteration)
            self.evaluate_model(iteration, self.source_table, True)
            if self.validation_table:
                self.evaluate_model(iteration, self.validation_table, False)
        plpy.info("\n"+self.info_str)

    plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
def evaluate_model(self, epoch, table, is_train):
    """
    Compute loss and metric of every model configuration on `table` and
    append the results to the per-mst_key history and to self.info_str.

    :param epoch: iteration number the evaluation belongs to
    :param table: table to evaluate on
    :param is_train: True to record into the training history, False to
                     record into the validation history
    """
    if is_train:
        eval_time_by_mst = self.train_mst_metric_eval_time
        loss_by_mst = self.train_mst_loss
        metric_by_mst = self.train_mst_metric
        seg_ids = self.dist_key_mapping
        images_per_seg = self.images_per_seg_train
    else:
        eval_time_by_mst = self.valid_mst_metric_eval_time
        loss_by_mst = self.valid_mst_loss
        metric_by_mst = self.valid_mst_metric
        seg_ids = self.dist_key_mapping_valid
        images_per_seg = self.images_per_seg_valid
        self.info_str += "\n\tValidation set after iteration {0}:".format(epoch)

    for mst in self.msts:
        model_arch, _ = get_model_arch_weights(self.model_arch_table,
                                               mst[self.model_id_col])
        quoted_compile_params = "$madlib${0}$madlib$".format(
            mst[self.compile_params_col])
        # The remaining positional arguments are passed through unchanged.
        _, metric, loss = compute_loss_and_metrics(
            self.schema_madlib, table, quoted_compile_params,
            model_arch,
            None,
            self.use_gpus,
            self.accessible_gpus_for_seg,
            seg_ids,
            images_per_seg,
            [], [], epoch, True,
            mst[self.object_map_col],
            self.model_output_table,
            mst[self.mst_key_col])

        mst_key = mst[self.mst_key_col]
        # Elapsed time is measured from the start of training, shifted by
        # the offset carried over from a previous call.
        elapsed = self.metrics_elapsed_time_offset + (time.time() - self.metrics_elapsed_start_time)
        eval_time_by_mst[mst_key].append(elapsed)
        loss_by_mst[mst_key].append(loss)
        metric_by_mst[mst_key].append(metric)
        self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(mst_key, metric, loss)
def generate_schedule(self, msts):
    """
    Build the hopping schedule.

    :param msts: list of model configurations (may contain None entries)
    :return: dict mapping each distribution key to rotate(msts, i), where i
             is the position of that key in self.dist_keys
    """
    return {dist_key: rotate(msts, position)
            for position, dist_key in enumerate(self.dist_keys)}
def populate_object_map(self):
    """
    Attach the custom-function map to every model configuration whose
    compile params name a loss or metric that is not a builtin.

    The names of all such custom functions are collected, looked up once in
    self.object_table, and the resulting map is stored under
    self.object_map_col of each affected entry of self.msts. Entries that
    only use builtin functions are left untouched by this method.
    """
    # Lower-cased builtin names are loop invariant: build them once.
    builtin_losses = set(fn.lower() for fn in dir(losses))
    builtin_metrics = set(
        fn.lower() for fn in update_builtin_metrics(dir(metrics)))

    # Distinct custom function names found in the compile params.
    custom_fn_names = set()
    # Indexes (into self.msts) of the configurations that need the map.
    custom_fn_mst_idx = []
    for mst_idx, mst in enumerate(self.msts):
        # We assume that the compile_param is validated as part
        # of the loading mst_table and thus not validating here
        # Also, it is validated later when we compile the model
        # on the segments
        compile_dict = convert_string_of_args_to_dict(
            mst[self.compile_params_col])

        local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
        # [2:-2] drops the first two and last two characters of the
        # metrics value (presumably the "['" and "']" of a one-item list).
        local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None

        needs_map = False
        if local_loss and local_loss not in builtin_losses:
            custom_fn_names.add(local_loss)
            needs_map = True
        if local_metric and local_metric not in builtin_metrics:
            custom_fn_names.add(local_metric)
            needs_map = True
        if needs_map:
            custom_fn_mst_idx.append(mst_idx)

    if custom_fn_names:
        # Pass only unique custom_fn_names to query from object table
        custom_fn_object_map = query_custom_functions_map(
            self.object_table, list(custom_fn_names))
        for mst_idx in custom_fn_mst_idx:
            self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
def create_mst_schedule_table(self, mst_row):
    """
    Create the schedule table for the current hop and insert one row per
    distribution key.

    :param mst_row: model configurations for this hop, in the same order as
                    self.dist_keys; a falsy entry means no model is
                    scheduled for that key
    """
    plpy.execute("""
        CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
        ({self.model_id_col} INTEGER,
        {self.compile_params_col} VARCHAR,
        {self.fit_params_col} VARCHAR,
        {dist_key_col} INTEGER,
        {self.mst_key_col} INTEGER,
        {self.object_map_col} BYTEA)
        """.format(self=self, dist_key_col=dist_key_col))

    insert_template = """
        INSERT INTO {self.mst_current_schedule_tbl}
        VALUES ({model_id},
        $madlib${compile_params}$madlib$,
        $madlib${fit_params}$madlib$,
        {dist_key},
        {mst_key},
        $1)
        """
    for mst, dist_key in zip(mst_row, self.dist_keys):
        if mst:
            row_values = dict(model_id=mst[self.model_id_col],
                              compile_params=mst[self.compile_params_col],
                              fit_params=mst[self.fit_params_col],
                              mst_key=mst[self.mst_key_col])
            object_map = mst[self.object_map_col]
        else:
            # Placeholder row for a key without a scheduled model.
            row_values = dict(model_id="NULL",
                              compile_params="NULL",
                              fit_params="NULL",
                              mst_key="NULL")
            object_map = None
        # The object map is binary, so it is bound as a parameter rather
        # than formatted into the statement text.
        insert_plan = plpy.prepare(
            insert_template.format(self=self, dist_key=dist_key, **row_values),
            ["BYTEA"])
        plpy.execute(insert_plan, [object_map])
def create_model_output_table(self):
    """
    Create the empty model output table (mst key, weights, architecture)
    and then populate it together with the info table.
    """
    plpy.execute("""
        CREATE TABLE {self.model_output_table}
        ({self.mst_key_col} INTEGER PRIMARY KEY,
        {self.model_weights_col} BYTEA,
        {self.model_arch_col} JSON)
        """.format(self=self))
    self.initialize_model_output_and_info()
def create_model_output_table_warm_start(self):
    """
    For warm start, we need to copy the model output table to a temp table
    because we call truncate on the model output table while training.
    If the query gets aborted, we need to make sure that the user passed
    model output table can be recovered.

    The copy keeps only the rows whose mst key is still present in the
    model selection table. Those keys are stored in self.warm_start_msts
    (an empty list when none are left). The existing info table is dropped
    and rebuilt.
    """
    plpy.execute("""
        CREATE TABLE {self.model_output_table} (
        LIKE {self.original_model_output_table} INCLUDING indexes);
        """.format(self=self))

    plpy.execute("""INSERT INTO {self.model_output_table}
        SELECT * FROM {self.original_model_output_table};
        """.format(self=self))

    plpy.execute(""" DELETE FROM {self.model_output_table}
        WHERE {self.mst_key_col} NOT IN (
        SELECT {self.mst_key_col} FROM {self.model_selection_table})
        """.format(self=self))

    kept_keys = plpy.execute(
        """ SELECT array_agg({0}) AS a FROM {1}
        """.format(self.mst_key_col, self.model_output_table))[0]['a']
    # array_agg() over zero rows yields NULL (None here); membership tests
    # on self.warm_start_msts would then raise TypeError, so use a list.
    self.warm_start_msts = kept_keys if kept_keys is not None else []

    plpy.execute("DROP TABLE {0}".format(self.model_info_table))
    self.initialize_model_output_and_info()
def initialize_model_output_and_info(self):
    """
    Create the model info table and insert one row per model
    configuration; also insert the initial weights into the model output
    table for every configuration that is not being warm started.
    """
    info_table_create_query = """
        CREATE TABLE {self.model_info_table}
        ({self.mst_key_col} INTEGER PRIMARY KEY,
        {self.model_id_col} INTEGER,
        {self.compile_params_col} TEXT,
        {self.fit_params_col} TEXT,
        model_type TEXT,
        model_size DOUBLE PRECISION,
        metrics_elapsed_time DOUBLE PRECISION[],
        metrics_type TEXT[],
        loss_type TEXT,
        training_metrics_final DOUBLE PRECISION,
        training_loss_final DOUBLE PRECISION,
        training_metrics DOUBLE PRECISION[],
        training_loss DOUBLE PRECISION[],
        validation_metrics_final DOUBLE PRECISION,
        validation_loss_final DOUBLE PRECISION,
        validation_metrics DOUBLE PRECISION[],
        validation_loss DOUBLE PRECISION[])
        """.format(self=self)
    plpy.execute(info_table_create_query)

    # Tolerate a None value so the membership tests below cannot raise.
    warm_start_keys = self.warm_start_msts or []
    for mst in self.msts:
        # Always address the key through self.mst_key_col instead of a
        # hard-coded column name.
        mst_key = mst[self.mst_key_col]
        is_warm_started = mst_key in warm_start_keys
        model_arch, model_weights = get_model_arch_weights(self.model_arch_table,
                                                           mst[self.model_id_col])

        # If warm start is enabled, weights from transfer learning cannot be
        # used, even if a particular model doesn't have warm start weights.
        if self.warm_start:
            model_weights = None
            mst_filter = """
                WHERE {mst_col}={mst_key}
                """.format(
                    mst_col=self.mst_key_col,
                    mst_key=mst_key
                )
        else:
            mst_filter = ''

        serialized_weights = get_initial_weights(self.model_output_table,
                                                 model_arch,
                                                 model_weights,
                                                 is_warm_started,
                                                 self.use_gpus,
                                                 self.accessible_gpus_for_seg,
                                                 mst_filter
                                                 )
        # Size in KB of the serialized weights object as reported by
        # sys.getsizeof.
        model_size = sys.getsizeof(serialized_weights) / 1024.0

        metrics_list = get_metrics_from_compile_param(
            mst[self.compile_params_col])
        metrics_type = 'ARRAY{0}'.format(
            metrics_list) if metrics_list else 'NULL'

        loss_type = get_loss_from_compile_param(mst[self.compile_params_col])
        loss_type = loss_type if loss_type else 'NULL'

        info_table_insert_query = """
            INSERT INTO {self.model_info_table}({self.mst_key_col},
            {self.model_id_col}, {self.compile_params_col},
            {self.fit_params_col}, model_type, model_size,
            metrics_type, loss_type)
            VALUES ({mst_key_val}, {model_id},
            $madlib${compile_params}$madlib$,
            $madlib${fit_params}$madlib$, '{model_type}',
            {model_size}, {metrics_type}, '{loss_type}')
            """.format(self=self,
                       mst_key_val=mst_key,
                       model_id=mst[self.model_id_col],
                       compile_params=mst[self.compile_params_col],
                       fit_params=mst[self.fit_params_col],
                       model_type='madlib_keras',
                       model_size=model_size,
                       metrics_type=metrics_type,
                       loss_type=loss_type)
        plpy.execute(info_table_insert_query)

        # Warm started configurations already have their row in the model
        # output table; only the others need their initial weights stored.
        if not is_warm_started:
            output_table_insert_query = """
                INSERT INTO {self.model_output_table}(
                {self.mst_key_col}, {self.model_weights_col},
                {self.model_arch_col})
                VALUES ({mst_key}, $1, $2)
                """.format(self=self,
                           mst_key=mst_key)
            output_table_insert_query_prepared = plpy.prepare(
                output_table_insert_query, ["bytea", "json"])
            plpy.execute(output_table_insert_query_prepared, [
                serialized_weights, model_arch])
def create_model_summary_table(self):
    """
    Create the one-row model summary table describing this training run.

    For warm start the existing summary table is dropped first. Values that
    may be missing are rendered as the SQL keyword NULL; text values are
    dollar-quoted with the $MAD$ tag.

    NOTE: as a side effect self.validation_table is overwritten with its
    SQL literal form ('NULL' or the dollar-quoted name). This is kept
    as-is for backward compatibility.
    """
    if self.warm_start:
        plpy.execute("DROP TABLE {0}".format(self.model_summary_table))

    src_summary_dict = get_source_summary_table_dict(self.fit_validator_train)
    class_values = src_summary_dict['class_values']
    class_values_type = src_summary_dict['class_values_type']
    dep_vartype = src_summary_dict['dep_vartype']
    dependent_varname = \
        src_summary_dict['dependent_varname_in_source_table']
    independent_varname = \
        src_summary_dict['independent_varname_in_source_table']
    norm_const = src_summary_dict['norm_const']

    self.validation_table = 'NULL' if self.validation_table is None \
        else '$MAD${0}$MAD$'.format(self.validation_table)
    num_classes = 'NULL' if class_values is None else len(class_values)
    name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
    descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
    object_table = 'NULL' if self.object_table is None \
        else '$MAD${0}$MAD$'.format(self.object_table)
    # Render the complete SQL expression here: prefixing ARRAY inside the
    # template would turn an empty list into the invalid text "ARRAYNULL".
    metrics_iters = 'ARRAY{0}'.format(self.metrics_iters) \
        if self.metrics_iters else 'NULL'

    create_query = plpy.prepare("""
        CREATE TABLE {self.model_summary_table} AS
        SELECT
        $MAD${self.source_table}$MAD$::TEXT AS source_table,
        {self.validation_table}::TEXT AS validation_table,
        $MAD${self.model_output_table}$MAD$::TEXT AS model,
        $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
        $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
        $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
        $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
        $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
        {object_table}::TEXT AS object_table,
        {self.num_iterations}::INTEGER AS num_iterations,
        {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
        {self.warm_start} AS warm_start,
        {name}::TEXT AS name,
        {descr}::TEXT AS description,
        '{self.start_training_time}'::TIMESTAMP AS start_training_time,
        '{self.end_training_time}'::TIMESTAMP AS end_training_time,
        '{self.version}'::TEXT AS madlib_version,
        {num_classes}::INTEGER AS num_classes,
        $1 AS {class_values_colname},
        $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
        {norm_const}::{float32_sql_type} AS {normalizing_const_colname},
        {metrics_iters}::INTEGER[] AS metrics_iters
        """.format(self=self,
                   dependent_varname=dependent_varname,
                   independent_varname=independent_varname,
                   object_table=object_table,
                   name=name,
                   descr=descr,
                   num_classes=num_classes,
                   class_values_colname=CLASS_VALUES_COLNAME,
                   dep_vartype=dep_vartype,
                   dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
                   norm_const=norm_const,
                   float32_sql_type=FLOAT32_SQL_TYPE,
                   normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                   metrics_iters=metrics_iters),
        [class_values_type])
    # class_values is bound as a parameter using the type reported by the
    # source summary.
    plpy.execute(create_query, [class_values])
def update_info_table(self, mst, is_train):
    """
    Write the recorded metric/loss history of one model configuration into
    its row of the model info table.

    :param mst: model configuration whose row is updated
    :param is_train: True to fill the training_* columns from the training
                     history, False to fill the validation_* columns from
                     the validation history
    """
    mst_key = mst[self.mst_key_col]
    if is_train:
        column_prefix = 'training'
        metric_by_mst = self.train_mst_metric
        eval_time_by_mst = self.train_mst_metric_eval_time
        loss_by_mst = self.train_mst_loss
    else:
        column_prefix = 'validation'
        metric_by_mst = self.valid_mst_metric
        eval_time_by_mst = self.valid_mst_metric_eval_time
        loss_by_mst = self.valid_mst_loss

    # Without a recorded metric the three metric values stay SQL NULL.
    metrics = metrics_final = metrics_elapsed_time = "NULL"
    if mst_key in metric_by_mst:
        metrics_final, metrics = get_metrics_sql_string(metric_by_mst[mst_key])
        metrics_elapsed_time = "ARRAY{}".format(eval_time_by_mst[mst_key])
    loss_final, loss = get_metrics_sql_string(loss_by_mst[mst_key])

    # The training and validation statements differ only in the prefix of
    # the metric/loss columns; metrics_elapsed_time is shared by both.
    plpy.execute("""
        UPDATE {self.model_info_table} SET
        {prefix}_metrics_final = {metrics_final},
        {prefix}_loss_final = {loss_final},
        metrics_elapsed_time = {metrics_elapsed_time},
        {prefix}_metrics = {metrics},
        {prefix}_loss = {loss}
        WHERE {self.mst_key_col} = {mst_key}
        """.format(self=self,
                   prefix=column_prefix,
                   metrics_final=metrics_final,
                   loss_final=loss_final,
                   metrics_elapsed_time=metrics_elapsed_time,
                   metrics=metrics,
                   loss=loss,
                   mst_key=mst_key))
def insert_info_table(self):
    """
    Store the recorded training results of every model configuration in the
    info table, followed by its validation results when a validation table
    was given.
    """
    has_validation = bool(self.validation_table)
    for mst in self.msts:
        self.update_info_table(mst, True)
        if has_validation:
            self.update_info_table(mst, False)
def run_training(self, mst_idx, is_very_first_hop):
    """
    Execute one hop: join the scheduled model configurations with their
    current weights, run the fit step over the source data, write the new
    weights back into the model output table and drop the intermediates.

    :param mst_idx: index of the hop within the iteration (not used in the
                    body of this method)
    :param is_very_first_hop: True only for the first hop of the first
                              iteration
    """
    # NOTE: In the DL module, we want to avoid CREATING TEMP tables
    # (creates a slice which stays until the session is disconnected)
    # or minimize writing queries that generate plans with Motions (creating
    # multiple slices on segments).
    # This is mainly to avoid any GPU memory allocation failures. Since GPU
    # memory allocation is tied to the process where it is initialized, failures
    # may occur when a newly created slice(process) tries allocating GPU memory
    # which is already allocated by a previously created slice(process).
    # Therefore we want to have queries that do not add motions and all the
    # sub-queries running Keras/tensorflow operations reuse the same slice(process)
    # that was used for initializing GPU memory.
    plpy.execute("""
        CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
        SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
        model_arch_tbl.{self.model_arch_col}
        FROM
        {self.mst_current_schedule_tbl} mst_tbl
        LEFT JOIN {self.model_output_table} wgh_tbl
        ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
        LEFT JOIN {self.model_arch_table} model_arch_tbl
        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
        DISTRIBUTED BY ({dist_key_col})
        """.format(self=self, dist_key_col=dist_key_col))

    # Defaults for the non-cached case: read the real data columns from the
    # source table and skip schedule rows without a model.
    use_gpus = self.use_gpus or False
    dep_var, indep_var = mb_dep_var_col, mb_indep_var_col
    dep_shape_col, ind_shape_col = self.dep_shape_col, self.ind_shape_col
    source_table = self.source_table
    where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)

    if self.use_caching:
        # Caching populates the independent_var and dependent_var into the cache on the very first hop
        # For the very_first_hop, we want to run the transition function on all segments, including
        # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
        # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
        # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
        # table to join for referencing the dist_key
        if is_very_first_hop:
            plpy.execute("""
                DROP TABLE IF EXISTS {self.cached_source_table};
                CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
                """.format(self=self, dist_key_col=dist_key_col))
        else:
            dep_var, indep_var = 'NULL', 'NULL'
            dep_shape_col, ind_shape_col = 'ARRAY[0]', 'ARRAY[0]'
            source_table = self.cached_source_table

        if is_very_first_hop or self.is_final_training_call:
            where_clause = ""

    fit_step_query = """
        CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
        SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
        {mb_indep_var_col},
        {dep_shape_col},
        {ind_shape_col},
        {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
        {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
        {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
        src.{dist_key_col},
        ARRAY{self.dist_key_mapping},
        src.{self.gp_segment_id_col},
        {self.segments_per_host},
        ARRAY{self.images_per_seg_train},
        {use_gpus}::BOOLEAN,
        ARRAY{self.accessible_gpus_for_seg},
        {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
        {is_final_training_call}::BOOLEAN,
        {use_caching}::BOOLEAN,
        {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
        )::BYTEA AS {self.model_weights_col},
        {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
        ,src.{dist_key_col} AS {dist_key_col}
        FROM {source_table} src JOIN {self.mst_weights_tbl}
        USING ({dist_key_col})
        {where_clause}
        GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
        DISTRIBUTED BY({dist_key_col})
        """.format(self=self,
                   mb_dep_var_col=dep_var,
                   mb_indep_var_col=indep_var,
                   dep_shape_col=dep_shape_col,
                   ind_shape_col=ind_shape_col,
                   dist_key_col=dist_key_col,
                   use_gpus=use_gpus,
                   is_final_training_call=self.is_final_training_call,
                   use_caching=self.use_caching,
                   source_table=source_table,
                   where_clause=where_clause)
    plpy.execute(fit_step_query)

    # Copy the freshly trained weights into the model output table.
    plpy.execute("""
        UPDATE {self.model_output_table}
        SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
        FROM {self.weights_to_update_tbl}
        WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
        """.format(self=self))

    self.truncate_and_drop_tables()
def truncate_and_drop_tables(self):
    """
    Context: UPDATE statements in postgres are not in-place replacements but
    the row to be updated is marked for deletion (note that the disk space for
    this row doesn't get released until vacuum is called) and a new row is
    inserted.

    This function will clear out the disk space used by the model_output_table
    and also drop all the other intermediate tables.
    If available, set the `dev_opt_unsafe_truncate_in_subtransaction` guc so
    that the truncate command can release the disk space. The disk space will
    be released immediately and hence the model_output table won't grow in
    size with each UPDATE statement.

    Without this guc, the disk space won't be released and each
    call to the UPDATE statement will keep adding to the disk space. The disk
    space will only be released when the query is completed.

    The guc can cause data loss if not used properly. Since truncate will
    actually clear the disk space immediately, there is no way to recover to
    the state before truncate was called on that table. So this guc should only
    be set for intermediate tables and never for tables created outside the
    scope of the fit_multiple udf.

    Workflow
    1. Create temp table from model table (including the indexes)
    2. truncate the model table to release disk space
    3. rename temp table to model table so that it can be reused for the next
    hop
    :return:
    """
    with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
        temp_model_table = unique_string('updated_model')
        # On the final training call the copy is created without the
        # UNLOGGED keyword.
        unlogged_table = '' if self.is_final_training_call else self.unlogged_table
        names = dict(temp_model_table=temp_model_table,
                     unlogged_table=unlogged_table,
                     self=self)

        plpy.execute("""
            CREATE {unlogged_table} TABLE {temp_model_table} ( LIKE {self.model_output_table}
            INCLUDING indexes);""".format(**names))
        plpy.execute("""
            INSERT INTO {temp_model_table} SELECT * FROM {self.model_output_table};
            TRUNCATE TABLE {self.model_output_table};
            DROP TABLE {self.model_output_table};
            """.format(**names))
        rename_table(self.schema_madlib,
                     temp_model_table,
                     self.model_output_table)

        # The per-hop intermediate tables are emptied and then removed.
        plpy.execute("""
            TRUNCATE TABLE {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
            {self.weights_to_update_tbl};
            DROP TABLE IF EXISTS {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
            {self.weights_to_update_tbl};""".format(self=self))