# blob: 7569d3920a0715867deefd4418a69e90a7008800 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
import time
import sys
from keras.models import *
from madlib_keras import compute_loss_and_metrics
from madlib_keras import get_initial_weights
from madlib_keras import get_model_arch_weights
from madlib_keras import get_segments_and_gpus
from madlib_keras import get_source_summary_table_dict
from madlib_keras_helper import *
from madlib_keras_model_selection import ModelSelectionSchema
from madlib_keras_validator import *
from madlib_keras_wrapper import *
from utilities.control import MinWarning
from utilities.control import OptimizerControl
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import rotate
from utilities.utilities import madlib_version
from utilities.utilities import is_platform_pg
import json
from collections import defaultdict
import random
import datetime
# Short module-level aliases for the column-name constants that are spliced
# into the SQL built by FitMultipleModel. The constants are not defined in
# this file; presumably they come from one of the star imports above
# -- TODO confirm which module provides them.
mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
dist_key_col = DISTRIBUTION_KEY_COLNAME
@MinWarning("warning")
class FitMultipleModel():
    """ Train every model selection tuple (MST) of a model selection table.

    Each MST (model id, compile params, fit params) is rotated ("hopped")
    across the distribution keys of the minibatched source table. One
    iteration is complete once every slot of the schedule has been visited;
    after each iteration all models are evaluated on the source table and,
    when given, on the validation table.

    Constructing an instance runs the whole workflow: validation, creation
    of the output/info tables, training, and creation of the summary table.
    """

    def __init__(self, schema_madlib, source_table, model_output_table,
                 model_selection_table, num_iterations,
                 gpus_per_host=0, validation_table=None, **kwargs):
        """ Validate the inputs, build the hop schedule and run training.

        Args:
            schema_madlib: Schema used to look up the MADlib version and to
                qualify the fit_step_multiple_model call.
            source_table: Name of the minibatched training table.
            model_output_table: Name of the table to create for the model
                weights; the '_info' and '_summary' table names are derived
                from it.
            model_selection_table: Name of the table listing the MSTs; its
                '_summary' table name is derived from it.
            num_iterations: Number of full passes over the schedule.
            gpus_per_host: Forwarded to get_segments_and_gpus.
            validation_table: Optional minibatched table that is evaluated
                after every iteration.
            **kwargs: Accepted and ignored.
        """
        # set the random seed for visit order/scheduling
        random.seed(1)
        if is_platform_pg():
            plpy.error(
                "DL: Multiple model training is not supported on PostgreSQL.")
        self.source_table = source_table
        self.validation_table = validation_table
        self.model_selection_table = model_selection_table
        # Derived table names default to None so that a missing table name
        # reaches the input validator below instead of failing earlier with
        # an AttributeError on an attribute that was never set.
        self.model_selection_summary_table = None
        if self.model_selection_table:
            self.model_selection_summary_table = add_postfix(
                self.model_selection_table, '_summary')
        self.model_output_table = model_output_table
        self.model_info_table = None
        self.model_summary_table = None
        if self.model_output_table:
            self.model_info_table = add_postfix(model_output_table, '_info')
            self.model_summary_table = add_postfix(
                model_output_table, '_summary')

        self.num_iterations = num_iterations
        self.module_name = 'madlib_keras_fit_multiple_model'
        self.schema_madlib = schema_madlib
        self.version = madlib_version(self.schema_madlib)

        # Column names shared with the model selection / model arch tables.
        self.mst_key_col = ModelSelectionSchema.MST_KEY
        self.model_id_col = ModelSelectionSchema.MODEL_ID
        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
        self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
        self.model_weights_col = ModelArchSchema.MODEL_WEIGHTS
        self.model_arch_col = ModelArchSchema.MODEL_ARCH

        # Per-iteration training results, keyed by mst_key.
        self.train_mst_metric_eval_time = defaultdict(list)
        self.train_mst_loss = defaultdict(list)
        self.train_mst_metric = defaultdict(list)
        self.info_str = ""
        self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
        self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")

        self.fit_validator_train = FitMultipleInputValidator(
            self.source_table, self.validation_table, self.model_output_table,
            self.model_selection_table, self.model_selection_summary_table,
            mb_dep_var_col, mb_indep_var_col, self.num_iterations,
            self.model_info_table, self.mst_key_col, self.model_arch_table_col,
            1, False)
        self.msts = self.fit_validator_train.msts
        self.model_arch_table = self.fit_validator_train.model_arch_table

        self.seg_ids_train, self.images_per_seg_train = \
            get_image_count_per_seg_for_minibatched_data_from_db(
                self.source_table)
        if self.validation_table:
            # Per-iteration validation results, keyed by mst_key.
            self.valid_mst_metric_eval_time = defaultdict(list)
            self.valid_mst_loss = defaultdict(list)
            self.valid_mst_metric = defaultdict(list)
            self.seg_ids_valid, self.images_per_seg_valid = \
                get_image_count_per_seg_for_minibatched_data_from_db(
                    self.validation_table)

        # Names of the working tables recreated on every hop.
        self.mst_weights_tbl = unique_string(desp='mst_weights')
        self.mst_current_schedule_tbl = unique_string(
            desp='mst_current_schedule')

        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
        if len(self.msts) < len(self.dist_keys):
            # Pad with None so that every distribution key has a slot.
            self.msts_for_schedule = self.msts + [None] * \
                (len(self.dist_keys) - len(self.msts))
        else:
            # NOTE(review): this aliases self.msts, so the shuffle below
            # also reorders self.msts in place -- confirm that is intended.
            self.msts_for_schedule = self.msts
        random.shuffle(self.msts_for_schedule)
        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)

        self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
            gpus_per_host)
        self.create_model_output_table()
        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
        self.fit_multiple_model()

    def fit_multiple_model(self):
        """ Run training, then fill the info table and create the summary.

        The value of the CUDA_VISIBLE_DEVICES_KEY environment variable seen
        on entry is handed back to reset_cuda_env when the method exits,
        whether training succeeded or raised.
        """
        # `os` is not imported explicitly at module level; import it here
        # rather than rely on a star import re-exporting it.
        import os

        # WARNING: set orca off to prevent unwanted redistribution
        with OptimizerControl(False):
            original_cuda_env = None
            if CUDA_VISIBLE_DEVICES_KEY in os.environ:
                original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
            try:
                self.start_training_time = datetime.datetime.now()
                self.train_multiple_model()
                self.end_training_time = datetime.datetime.now()
                self.insert_info_table()
                self.create_model_summary_table()
            finally:
                # Always restore the environment: os.environ outlives this
                # call, so a failed run must not leave it modified.
                reset_cuda_env(original_cuda_env)

    def train_multiple_model(self):
        """ Run num_iterations passes over the schedule, evaluating each.

        In every iteration each slot index of the schedule is visited once
        (one "hop" per slot). After the last hop all models are evaluated
        and the collected timing/metric text is reported via plpy.info.
        """
        total_msts = len(self.msts_for_schedule)
        for iteration in range(self.num_iterations):
            for mst_idx in range(total_msts):
                # The MST assigned to each distribution key for this hop.
                mst_row = [self.grand_schedule[dist_key][mst_idx]
                           for dist_key in self.dist_keys]
                self.create_mst_schedule_table(mst_row)
                if mst_idx == 0:
                    start_iteration = time.time()
                self.run_training()
                if mst_idx == (total_msts - 1):
                    end_iteration = time.time()
                    self.info_str = \
                        "\tTime for training in iteration {0}: {1} sec\n".format(
                            iteration, end_iteration - start_iteration)
            self.info_str += "\tTraining set after iteration {0}:".format(
                iteration)
            self.evaluate_model(iteration, self.source_table, True)
            if self.validation_table:
                self.evaluate_model(iteration, self.validation_table, False)
            plpy.info("\n" + self.info_str)

    def evaluate_model(self, epoch, table, is_train):
        """ Compute loss and metric of every model on the given table.

        Args:
            epoch: Iteration index, forwarded to compute_loss_and_metrics.
            table: Name of the minibatched table to evaluate on.
            is_train: True to record the results in the training
                accumulators, False to record them in the validation ones.

        The results are appended to the per-mst_key accumulators and a line
        per model is appended to self.info_str.
        """
        if is_train:
            mst_metric_eval_time = self.train_mst_metric_eval_time
            mst_loss = self.train_mst_loss
            mst_metric = self.train_mst_metric
            seg_ids = self.seg_ids_train
            images_per_seg = self.images_per_seg_train
        else:
            mst_metric_eval_time = self.valid_mst_metric_eval_time
            mst_loss = self.valid_mst_loss
            mst_metric = self.valid_mst_metric
            seg_ids = self.seg_ids_valid
            images_per_seg = self.images_per_seg_valid

        for mst in self.msts:
            mst_key = mst[self.mst_key_col]
            state = query_weights(self.model_output_table,
                                  self.model_weights_col,
                                  self.mst_key_col, mst_key)
            model_arch, _ = get_model_arch_weights(self.model_arch_table,
                                                   mst[self.model_id_col])
            serialized_weights = \
                madlib_keras_serializer.get_serialized_1d_weights_from_state(
                    state)
            metric_eval_time, metric, loss = compute_loss_and_metrics(
                self.schema_madlib, table, "$madlib${0}$madlib$".format(
                    mst[self.compile_params_col]),
                model_arch,
                serialized_weights,
                self.gpus_per_host,
                self.segments_per_host,
                seg_ids,
                images_per_seg, [], [], epoch, True)
            mst_metric_eval_time[mst_key].append(metric_eval_time)
            mst_loss[mst_key].append(loss)
            mst_metric[mst_key].append(metric)
            self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(
                mst_key, metric, loss)

    def generate_schedule(self, msts):
        """ Generate the schedule for models hopping to segments.

        Returns a dict mapping each distribution key to `msts` rotated by
        the position of that key in self.dist_keys.
        """
        grand_schedule = {}
        for index, dist_key in enumerate(self.dist_keys):
            grand_schedule[dist_key] = rotate(msts, index)
        return grand_schedule

    def create_mst_schedule_table(self, mst_row):
        """ Create the temp table holding the schedule of the current hop.

        Args:
            mst_row: One entry per distribution key (same order as
                self.dist_keys); an MST, or None for an idle slot.
        """
        mst_temp_query = """
            CREATE TEMP TABLE {self.mst_current_schedule_tbl}
                ({self.model_id_col} INTEGER,
                 {self.compile_params_col} VARCHAR,
                 {self.fit_params_col} VARCHAR,
                 {dist_key_col} INTEGER,
                 {self.mst_key_col} INTEGER)
            """.format(dist_key_col=dist_key_col, self=self)
        plpy.execute(mst_temp_query)
        for mst, dist_key in zip(mst_row, self.dist_keys):
            if mst:
                model_id = mst[self.model_id_col]
                compile_params = mst[self.compile_params_col]
                fit_params = mst[self.fit_params_col]
                mst_key = mst[self.mst_key_col]
            else:
                # Idle slot. The two params end up inside $madlib$ quotes
                # below, i.e. as the text 'NULL'; the row is skipped later
                # by run_training because its mst_key is a real SQL NULL.
                model_id = "NULL"
                compile_params = "NULL"
                fit_params = "NULL"
                mst_key = "NULL"
            mst_insert_query = """
                INSERT INTO {self.mst_current_schedule_tbl}
                    VALUES ({model_id},
                            $madlib${compile_params}$madlib$,
                            $madlib${fit_params}$madlib$,
                            {dist_key},
                            {mst_key})
                """.format(self=self,
                           model_id=model_id,
                           compile_params=compile_params,
                           fit_params=fit_params,
                           dist_key=dist_key,
                           mst_key=mst_key)
            plpy.execute(mst_insert_query)

    def create_model_output_table(self):
        """ Create the model output and info tables with one row per MST.

        The output table receives the serialized initial state and the
        model architecture; the info table receives the MST parameters,
        model type, model size and metrics type. The remaining info columns
        are filled in later by update_info_table.
        """
        output_table_create_query = """
            CREATE TABLE {self.model_output_table}
                ({self.mst_key_col} INTEGER PRIMARY KEY,
                 {self.model_weights_col} BYTEA,
                 {self.model_arch_col} JSON)
            """.format(self=self)
        info_table_create_query = """
            CREATE TABLE {self.model_info_table}
                ({self.mst_key_col} INTEGER PRIMARY KEY,
                 {self.model_id_col} INTEGER,
                 {self.compile_params_col} TEXT,
                 {self.fit_params_col} TEXT,
                 model_type TEXT,
                 model_size DOUBLE PRECISION,
                 metrics_elapsed_time DOUBLE PRECISION[],
                 metrics_type TEXT[],
                 training_metrics_final DOUBLE PRECISION,
                 training_loss_final DOUBLE PRECISION,
                 training_metrics DOUBLE PRECISION[],
                 training_loss DOUBLE PRECISION[],
                 validation_metrics_final DOUBLE PRECISION,
                 validation_loss_final DOUBLE PRECISION,
                 validation_metrics DOUBLE PRECISION[],
                 validation_loss DOUBLE PRECISION[])
            """.format(self=self)
        plpy.execute(output_table_create_query)
        plpy.execute(info_table_create_query)

        for mst in self.msts:
            model_arch, model_weights = get_model_arch_weights(
                self.model_arch_table, mst[self.model_id_col])
            serialized_weights = get_initial_weights(self.model_output_table,
                                                     model_arch,
                                                     model_weights,
                                                     False,
                                                     self.gpus_per_host)
            # NOTE(review): the stored state is built from a freshly
            # instantiated model, while `serialized_weights` is only used
            # for the size estimate below -- confirm this is intended.
            model = model_from_json(model_arch)
            serialized_state = \
                madlib_keras_serializer.serialize_state_with_nd_weights(
                    0, model.get_weights())
            model_size = sys.getsizeof(serialized_weights) / 1024.0

            metrics_list = get_metrics_from_compile_param(
                mst[self.compile_params_col])
            is_metrics_specified = bool(metrics_list)
            metrics_type = 'ARRAY{0}'.format(
                metrics_list) if is_metrics_specified else 'NULL'

            output_table_insert_query = """
                INSERT INTO {self.model_output_table}(
                    {self.mst_key_col}, {self.model_weights_col},
                    {self.model_arch_col})
                VALUES ({mst_key}, $1, $2)
                """.format(self=self,
                           mst_key=mst[self.mst_key_col])
            output_table_insert_query_prepared = plpy.prepare(
                output_table_insert_query, ["bytea", "json"])
            # NOTE(review): model_arch is fed to model_from_json above, so
            # it looks like it is already a JSON string; json.dumps would
            # then store it as a JSON string scalar -- confirm.
            plpy.execute(output_table_insert_query_prepared, [
                serialized_state, json.dumps(model_arch)])

            info_table_insert_query = """
                INSERT INTO {self.model_info_table}({self.mst_key_col},
                    {self.model_id_col}, {self.compile_params_col},
                    {self.fit_params_col}, model_type, model_size,
                    metrics_type)
                VALUES ({mst_key_val}, {model_id},
                        $madlib${compile_params}$madlib$,
                        $madlib${fit_params}$madlib$, '{model_type}',
                        {model_size}, {metrics_type})
                """.format(self=self,
                           mst_key_val=mst[self.mst_key_col],
                           model_id=mst[self.model_id_col],
                           compile_params=mst[self.compile_params_col],
                           fit_params=mst[self.fit_params_col],
                           model_type='madlib_keras',
                           model_size=model_size,
                           metrics_type=metrics_type)
            plpy.execute(info_table_insert_query)

    def create_model_summary_table(self):
        """ Create the one-row summary table describing this training run.

        Combines the table names and timings held on this instance with the
        values read from the source summary table (class values, dependent
        variable type, variable names and normalizing constant).
        """
        src_summary_dict = get_source_summary_table_dict(
            self.fit_validator_train)
        class_values = src_summary_dict['class_values']
        dep_vartype = src_summary_dict['dep_vartype']
        dependent_varname = \
            src_summary_dict['dependent_varname_in_source_table']
        independent_varname = \
            src_summary_dict['independent_varname_in_source_table']
        norm_const = src_summary_dict['norm_const']
        num_classes = len(class_values)
        update_query = """
            CREATE TABLE {self.model_summary_table} AS
            SELECT
                $MAD${self.source_table}$MAD$::TEXT AS source_table,
                $MAD${self.validation_table}$MAD$::TEXT AS validation_table,
                $MAD${self.model_output_table}$MAD$::TEXT AS model,
                $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
                $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
                $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
                {self.num_iterations}::INTEGER AS num_iterations,
                '{self.start_training_time}'::TIMESTAMP AS start_training_time,
                '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                '{self.version}'::TEXT AS madlib_version,
                {num_classes}::INTEGER AS num_classes,
                ARRAY{class_values}::TEXT[] AS {class_values_colname},
                $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
                {norm_const}::{float32_sql_type} AS {normalizing_const_colname}
            """.format(self=self,
                       dependent_varname=dependent_varname,
                       independent_varname=independent_varname,
                       num_classes=num_classes,
                       class_values=class_values,
                       class_values_colname=CLASS_VALUES_COLNAME,
                       dep_vartype=dep_vartype,
                       dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
                       norm_const=norm_const,
                       float32_sql_type=FLOAT32_SQL_TYPE,
                       normalizing_const_colname=NORMALIZING_CONST_COLNAME)
        plpy.execute(update_query)

    def update_info_table(self, mst, is_train):
        """ Write the collected metrics and losses of one MST to the info table.

        Args:
            mst: The MST whose info-table row is updated.
            is_train: True to write the training_* columns from the training
                accumulators, False to write the validation_* columns from
                the validation accumulators.

        Columns for which nothing was recorded are set to NULL. Note that
        metrics_elapsed_time is written by both variants.
        """
        mst_key = mst[self.mst_key_col]
        metrics, metrics_final, metrics_elapsed_time = \
            "NULL", "NULL", "NULL"
        loss, loss_final = "NULL", "NULL"
        if is_train:
            mst_metric = self.train_mst_metric
            mst_metric_eval_time = self.train_mst_metric_eval_time
            mst_loss = self.train_mst_loss
            column_prefix = "training"
        else:
            mst_metric = self.valid_mst_metric
            mst_metric_eval_time = self.valid_mst_metric_eval_time
            mst_loss = self.valid_mst_loss
            column_prefix = "validation"

        if mst_key in mst_metric:
            metric_values = mst_metric[mst_key]
            metrics_final = metric_values[-1]
            metrics = "ARRAY{}".format(metric_values)
            metrics_elapsed_time = "ARRAY{}".format(
                mst_metric_eval_time[mst_key])
        # Use .get() so that the lookup does not insert an empty list into
        # the defaultdict, and fall back to NULL instead of failing on
        # loss[-1] when no loss was recorded for this MST.
        loss_values = mst_loss.get(mst_key)
        if loss_values:
            loss_final = loss_values[-1]
            loss = "ARRAY{}".format(loss_values)

        update_query = """
            UPDATE {self.model_info_table} SET
            {column_prefix}_metrics_final = {metrics_final},
            {column_prefix}_loss_final = {loss_final},
            metrics_elapsed_time = {metrics_elapsed_time},
            {column_prefix}_metrics = {metrics},
            {column_prefix}_loss = {loss}
            WHERE {self.mst_key_col} = {mst_key}
            """.format(self=self,
                       column_prefix=column_prefix,
                       metrics_final=metrics_final,
                       loss_final=loss_final,
                       metrics_elapsed_time=metrics_elapsed_time,
                       metrics=metrics,
                       loss=loss,
                       mst_key=mst_key)
        plpy.execute(update_query)

    def insert_info_table(self):
        """ Update the info table row of every MST with its final results. """
        for mst in self.msts:
            self.update_info_table(mst, True)
            if self.validation_table:
                self.update_info_table(mst, False)

    def run_training(self):
        """ Train the currently scheduled models for one hop.

        Steps:
          1. Join the current schedule with the stored weights and the
             model architectures into a temp table distributed by the
             distribution key.
          2. Run fit_step_multiple_model over the source table joined with
             that temp table, grouped by distribution key and mst_key,
             skipping idle slots (NULL mst_key).
          3. Copy the resulting weights back into the model output table.
          4. Drop the three working tables.
        """
        mst_weights_query = """
            CREATE TEMP TABLE {self.mst_weights_tbl} AS
                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
                       model_arch_tbl.{self.model_arch_col}
                FROM
                    {self.mst_current_schedule_tbl} mst_tbl
                    LEFT JOIN {self.model_output_table} wgh_tbl
                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
                        LEFT JOIN {self.model_arch_table} model_arch_tbl
                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
                DISTRIBUTED BY ({dist_key_col})
            """.format(dist_key_col=dist_key_col,
                       self=self)
        plpy.execute(mst_weights_query)

        uda_query = """
            CREATE TABLE {self.weights_to_update_tbl} AS
            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                {mb_indep_var_col},
                {self.dep_shape_col},
                {self.ind_shape_col},
                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
                src.gp_segment_id,
                ARRAY{self.seg_ids_train},
                ARRAY{self.images_per_seg_train},
                {self.gpus_per_host},
                {self.segments_per_host},
                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
                {is_final_iteration}::BOOLEAN
                )::BYTEA AS {self.model_weights_col},
                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                ,src.{dist_key_col} AS {dist_key_col}
            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
                USING ({dist_key_col})
            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
            DISTRIBUTED BY({dist_key_col})
            """.format(mb_dep_var_col=mb_dep_var_col,
                       mb_indep_var_col=mb_indep_var_col,
                       is_final_iteration=True,
                       dist_key_col=dist_key_col,
                       self=self)
        plpy.execute(uda_query)

        update_query = """
            UPDATE {self.model_output_table}
            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
            FROM {self.weights_to_update_tbl}
            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
            """.format(self=self)
        plpy.execute(update_query)

        # The working tables are recreated under the same names on the
        # next hop, so they must be gone before this method returns.
        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
            self.mst_weights_tbl,
            self.mst_current_schedule_tbl,
            self.weights_to_update_tbl))