# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from model_arch_info import *
from madlib_keras_helper import *
from madlib_keras_validator import *
from predict_input_params import PredictParamsProcessor
from utilities.control import MinWarning
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import get_segments_per_host
from utilities.utilities import unique_string
from utilities.utilities import get_psql_type
from utilities.utilities import split_quoted_delimited_str
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import quote_ident
from madlib_keras_wrapper import *
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
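# BasePredict is the shared driver for madlib_keras_predict and
# madlib_keras_predict_byom: it resolves the GPU configuration, normalizes
# pred_type, and builds and runs the prediction SQL over the test table.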
class BasePredict():
def __init__(self, schema_madlib, table_to_validate, test_table, id_col,
independent_varname, output_table, pred_type, use_gpus, module_name):
self.schema_madlib = schema_madlib
self.table_to_validate = table_to_validate
self.test_table = test_table
self.id_col = id_col
self.independent_varname = split_quoted_delimited_str(independent_varname)
self.output_table = output_table
self.pred_type = pred_type
self.module_name = module_name
self.segments_per_host = get_segments_per_host()
self.use_gpus = use_gpus if use_gpus else False
if self.use_gpus:
accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, self.segments_per_host, self.module_name)
_assert(len(set(accessible_gpus_for_seg)) == 1,
'{0}: Asymmetric gpu configurations are not supported'.format(self.module_name))
self.gpus_per_host = accessible_gpus_for_seg[0]
else:
self.gpus_per_host = 0
self._set_default_pred_type()
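    # pred_type accepted values (defaults applied in _set_default_pred_type):
    #   'prob'          -> return the probability of every class value
    #   'response'      -> return only the top-1 class (treated as k = 1)
    #   integer k >= 1  -> return the k highest-ranked class values per id
    #   float in (0, 1) -> return every class value whose probability exceeds
    #                      the threshold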
    def _set_default_pred_type(self):
        self.pred_type = 'prob' if self.pred_type is None else self.pred_type
        self.is_response = self.pred_type == 'response'
        # 'response' is equivalent to asking for the single top-ranked class.
        self.pred_type = 1 if self.is_response else self.pred_type
        self.get_all = self.pred_type == 'prob'
        # use_ratio is only meaningful for a numeric probability threshold in
        # (0, 1); guard the comparison so the string 'prob' is never compared
        # with an int (a TypeError on Python 3).
        self.use_ratio = (not self.get_all) and self.pred_type < 1
def call_internal_keras(self):
pred_col_name = 'prob'
pred_col_type = 'double precision'
class_values = strip_trailing_nulls_from_class_values(self.class_values)
gp_segment_id_col, seg_ids_test, \
images_per_seg_test = get_image_count_per_seg_for_non_minibatched_data_from_db(
self.test_table)
segments_per_host = get_segments_per_host()
if self.pred_type == 1:
rank_create_sql = ""
self.pred_vartype = [i.strip('[]') for i in self.dependent_vartype]
unnest_sql = []
full_class_name_list = []
full_class_value_list = []
for i in range(self.dependent_var_count):
if self.pred_vartype[i] in ['text', 'character varying', 'varchar']:
unnest_sql.append("unnest(ARRAY{0}) AS {1}".format(
['NULL' if j is None else j for j in class_values[i]],
quote_ident(self.dependent_varname[i])))
else:
unnest_sql.append("unnest(ARRAY[{0}]) AS {1}".format(
','.join(['NULL' if j is None else str(j) for j in class_values[i]]),
quote_ident(self.dependent_varname[i])))
for j in class_values[i]:
tmp_class_name = self.dependent_varname[i] if self.dependent_varname[i] is not None else "NULL::TEXT"
full_class_name_list.append(tmp_class_name)
tmp_class_value = j if j is not None else "NULL::TEXT"
full_class_value_list.append(tmp_class_value)
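        # Unnest the flattened class-name/class-value pairs in parallel with
        # the probability array below, so that each output row corresponds to
        # one (id, class_name, class_value, prob) combination.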
unnest_sql = """unnest(ARRAY{full_class_name_list}::TEXT[]) AS class_name,
unnest(ARRAY{full_class_value_list}::TEXT[]) AS class_value
""".format(**locals())
if self.get_all:
filter_sql = ""
elif self.use_ratio:
filter_sql = "WHERE {pred_col_name} > {self.pred_type}".format(**locals())
else:
filter_sql = "WHERE rank <= {self.pred_type}".format(**locals())
select_segmentid_comma = ""
group_by_clause = ""
join_cond_on_segmentid = ""
if not is_platform_pg():
select_segmentid_comma = "{self.test_table}.gp_segment_id AS gp_segment_id,".format(self=self)
group_by_clause = "GROUP BY {self.test_table}.gp_segment_id".format(self=self)
join_cond_on_segmentid = "{self.test_table}.gp_segment_id=min_ctid.gp_segment_id AND".format(self=self)
        # Use CREATE TABLE followed by INSERT (rather than CTAS) so that the
        # plan_cache_mode GUC code path is exercised when the model weights
        # are passed in as a prepared-statement parameter.
sql = """
CREATE TABLE {self.output_table}
({self.id_col} {self.id_col_type},
class_name TEXT,
class_value TEXT,
{pred_col_name} {pred_col_type},
rank INTEGER)
""".format(**locals())
plpy.execute(sql)
independent_varname_sql = ["{0}::REAL[]".format(quote_ident(i)) for i in self.independent_varname]
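        # The prediction UDF accepts up to five independent variable arrays
        # (cf. internal_keras_predict_wide below); pad the unused slots with
        # NULL::REAL[], which internal_keras_predict later filters out.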
while len(independent_varname_sql) < 5:
independent_varname_sql.append("NULL::REAL[]")
independent_varname_sql = ', '.join(independent_varname_sql)
        # Passing the (potentially huge) model weights to
        # internal_keras_predict() for every row slowed down the overall
        # madlib_keras_predict() call. To avoid this, a CASE expression passes
        # the model weights only for the very first row (min(ctid)) fetched on
        # each segment, and NULL for all other rows.
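        # Rank each class value within an (id, class_name) group by descending
        # probability; filter_sql above then keeps either the top-k rows or
        # the rows above the probability threshold.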
rank_sql = """ row_number() OVER (PARTITION BY {self.id_col}, class_name
ORDER BY {pred_col_name} DESC) AS rank
""".format(**locals())
sql1 = """
INSERT INTO {self.output_table}
SELECT *
FROM (
SELECT *, {rank_sql}
FROM (
SELECT {self.id_col}::{self.id_col_type},
{unnest_sql},
unnest(
{self.schema_madlib}.internal_keras_predict
({independent_varname_sql},
$1,
CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
{self.normalizing_const},
{gp_segment_id_col},
ARRAY{seg_ids_test},
ARRAY{images_per_seg_test},
{self.gpus_per_host},
{segments_per_host})) AS prob
FROM {self.test_table}
LEFT JOIN
(SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
FROM {self.test_table}
{group_by_clause}) min_ctid
ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
) __subq1__
) __subq2__
{filter_sql}
""".format(**locals())
predict_query = plpy.prepare(sql1, ["text", "bytea"])
plpy.execute(predict_query, [self.model_arch, self.model_weights])
if self.is_response:
# Drop the rank column since it is irrelevant
plpy.execute("""
ALTER TABLE {self.output_table}
DROP COLUMN rank
""".format(**locals()))
def set_default_class_values(self, in_class_values, dependent_var_count):
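        # For any output whose class values were not supplied or recorded,
        # fall back to the integer labels 0 .. num_classes-1, where
        # num_classes is derived from the model architecture via
        # get_num_classes().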
self.class_values = []
num_classes = get_num_classes(self.model_arch, dependent_var_count)
for counter, i in enumerate(in_class_values):
if (i is None) or (i==[None]):
                self.class_values.append(list(range(0, num_classes[counter])))
else:
self.class_values.append(i)
@MinWarning("warning")
class Predict(BasePredict):
def __init__(self, schema_madlib, model_table,
test_table, id_col, independent_varname,
output_table, pred_type, use_gpus,
mst_key, **kwargs):
self.module_name = 'madlib_keras_predict'
self.model_table = model_table
self.mst_key = mst_key
self.is_mult_model = mst_key is not None
if self.model_table:
self.model_summary_table = add_postfix(self.model_table, "_summary")
BasePredict.__init__(self, schema_madlib, model_table, test_table,
id_col, independent_varname,
output_table, pred_type,
use_gpus, self.module_name)
param_proc = PredictParamsProcessor(self.model_table, self.module_name, self.mst_key)
if self.is_mult_model:
self.temp_summary_view = param_proc.model_summary_table
self.model_summary_table = self.temp_summary_view
self.dependent_vartype = param_proc.get_dependent_vartype()
self.model_weights = param_proc.get_model_weights()
self.model_arch = param_proc.get_model_arch()
self.dependent_varname = param_proc.get_dependent_varname()
self.dependent_var_count = len(self.dependent_varname)
class_values = []
for dep in self.dependent_varname:
class_values.append(param_proc.get_class_values(dep))
self.set_default_class_values(class_values, self.dependent_var_count)
self.normalizing_const = param_proc.get_normalizing_const()
self.validate()
self.id_col_type = get_expr_type(self.id_col, self.test_table)
BasePredict.call_internal_keras(self)
if self.is_mult_model:
plpy.execute("DROP VIEW IF EXISTS {0}".format(self.temp_summary_view))
def validate(self):
input_tbl_valid(self.model_table, self.module_name)
        if self.is_mult_model and not columns_exist_in_table(self.model_table, ['mst_key']):
            plpy.error("{self.module_name}: mst_key should not be passed for a single-model table".format(**locals()))
        if not self.is_mult_model and columns_exist_in_table(self.model_table, ['mst_key']):
            plpy.error("{self.module_name}: mst_key must be passed for a multi-model output table".format(**locals()))
InputValidator.validate_predict_evaluate_tables(
self.module_name, self.model_table, self.model_summary_table,
self.test_table, self.output_table)
InputValidator.validate_id_in_test_tbl(
self.module_name, self.test_table, self.id_col)
input_shape = get_input_shape(self.model_arch)
InputValidator.validate_pred_type(
self.module_name, self.pred_type, self.class_values)
InputValidator.validate_input_shape(
self.test_table, self.independent_varname, input_shape, 1)
@MinWarning("warning")
class PredictBYOM(BasePredict):
def __init__(self, schema_madlib, model_arch_table, model_id,
test_table, id_col, independent_varname, output_table,
pred_type, use_gpus, class_values, normalizing_const,
dependent_count, **kwargs):
        self.module_name = 'madlib_keras_predict_byom'
self.model_arch_table = model_arch_table
self.model_id = model_id
self.class_values = class_values
self.normalizing_const = normalizing_const
self.dependent_var_count = dependent_count
if self.dependent_var_count == 1:
self.dependent_varname = ['dependent_var']
else:
self.dependent_varname = ['dependent_var_{0}'.format(i) for i in range(self.dependent_var_count)]
BasePredict.__init__(self, schema_madlib, model_arch_table,
test_table, id_col, independent_varname,
output_table, pred_type, use_gpus, self.module_name)
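        # Infer the dependent variable type: from the supplied class labels
        # when present, otherwise TEXT when only the top-1 prediction is
        # requested and DOUBLE PRECISION when probabilities are returned.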
self.dependent_vartype = []
if self.class_values:
for i in self.class_values:
self.dependent_vartype.append(get_psql_type(i[0]))
else:
self.class_values = [None]*self.dependent_var_count
if self.pred_type == 1:
self.dependent_vartype = ['text']*self.dependent_var_count
else:
self.dependent_vartype = ['double precision']*self.dependent_var_count
        ## Set the default value for the normalizing constant.
        # use_gpus and pred_type are defaulted in BasePredict's __init__;
        # class_values defaults are filled in validate_and_set_defaults().
        self.normalizing_const = normalizing_const
        if self.normalizing_const is None:
            self.normalizing_const = DEFAULT_NORMALIZING_CONST
InputValidator.validate_predict_byom_tables(
self.module_name, self.model_arch_table, self.model_id,
self.test_table, self.id_col, self.output_table,
self.independent_varname)
self.validate_and_set_defaults()
self.id_col_type = get_expr_type(self.id_col, self.test_table)
BasePredict.call_internal_keras(self)
def validate_and_set_defaults(self):
# Set some defaults first and then validate and then set some more defaults
self.model_arch, self.model_weights = get_model_arch_weights(
self.model_arch_table, self.model_id)
# Assert model_weights and model_arch are not empty.
_assert(self.model_weights and self.model_arch,
"{0}: Model weights and architecture should not be NULL.".format(
self.module_name))
self.set_default_class_values(self.class_values, self.dependent_var_count)
InputValidator.validate_pred_type(
self.module_name, self.pred_type, self.class_values)
InputValidator.validate_normalizing_const(
self.module_name, self.normalizing_const)
        # TODO: Fix this validation.
        # The current method looks at the 'units' keyword, which is not
        # meaningful on its own because every layer has it; the check only
        # passed because the layers are traversed in order. It does not work
        # for multi-input/multi-output models and is prone to breaking even
        # in the regular case.
# InputValidator.validate_class_values(
# self.module_name, self.class_values, self.pred_type, self.model_arch)
InputValidator.validate_input_shape(
self.test_table, self.independent_varname,
get_input_shape(self.model_arch), 1)
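# SQL-facing wrapper: accepts up to five independent variable arrays (unused
# slots arrive as NULL, matching the NULL::REAL[] padding built in
# BasePredict.call_internal_keras) and forwards them as a single list.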
def internal_keras_predict_wide(independent_var, independent_var2,
independent_var3, independent_var4, independent_var5,
model_architecture, model_weights,
normalizing_const, current_seg_id, seg_ids,
images_per_seg, gpus_per_host, segments_per_host,
**kwargs):
return internal_keras_predict(
[independent_var, independent_var2, independent_var3, independent_var4, independent_var5],
model_architecture, model_weights, normalizing_const, current_seg_id,
seg_ids, images_per_seg, gpus_per_host, segments_per_host,
**kwargs)
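# Runs on every segment through plpython. The deserialized model is cached in
# the session dictionary (SD) so the architecture and weights are loaded only
# once per segment; the cache and the Keras session are cleared once the last
# image on the segment has been scored, or on any exception.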
def internal_keras_predict(independent_var, model_architecture, model_weights,
normalizing_const, current_seg_id, seg_ids,
images_per_seg, gpus_per_host, segments_per_host,
**kwargs):
SD = kwargs['SD']
model_key = 'segment_model_predict'
row_count_key = 'row_count'
try:
device_name = get_device_name_and_set_cuda_env(gpus_per_host, current_seg_id)
if model_key not in SD:
set_keras_session(device_name, gpus_per_host, segments_per_host)
model = model_from_json(model_architecture)
set_model_weights(model, model_weights)
SD[model_key] = model
SD[row_count_key] = 0
else:
model = SD[model_key]
SD[row_count_key] += 1
        # Since the test data isn't mini-batched, make sure each test data
        # np array has a leading batch dimension in addition to input_shape;
        # expand_input_dims() adds that extra dimension.
independent_var_filtered = []
for i in independent_var:
if i is not None:
independent_var_filtered.append(expand_input_dims(i)/normalizing_const)
with tf.device(device_name):
probs = model.predict(independent_var_filtered)
        # probs is a list containing a list of probability values for all
        # class levels. Since each input is a single image rather than a
        # mini-batch, the list contains exactly one inner list, so return
        # the first list in probs.
result = []
if len(independent_var_filtered) > 1:
for i in probs:
for j in i[0]:
result.append(j)
else:
result = probs[0]
total_images = get_image_count_per_seg_from_array(seg_ids.index(current_seg_id),
images_per_seg)
if SD[row_count_key] == total_images:
SD.pop(model_key, None)
SD.pop(row_count_key, None)
clear_keras_session()
return result
except Exception as ex:
SD.pop(model_key, None)
SD.pop(row_count_key, None)
clear_keras_session()
plpy.error(ex)
def predict_help(schema_madlib, message, **kwargs):
"""
Help function for keras predict
Args:
@param schema_madlib
@param message: string, Help message string
@param kwargs
Returns:
String. Help/usage information
"""
if not message:
help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
This function allows the user to predict using a model trained with
madlib_keras_fit.
For more details on function usage:
SELECT {schema_madlib}.madlib_keras_predict('usage')
"""
elif message in ['usage', 'help', '?']:
help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.madlib_keras_predict(
model_table, -- Name of the table containing the model
test_table, -- Name of the table containing the evaluation dataset
id_col, -- Name of the id column in the test data table
independent_varname, -- Name of the column with independent
variables in the test table
output_table, -- Name of the output table
pred_type, -- The type of the desired output
use_gpus, -- Flag for enabling GPU support
    mst_key              -- Identifier for the desired model from the
                            multi-model training output
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) contains the following columns:
    id:          The 'id' of the corresponding row in the test_table.
    class_name:  Name of the dependent variable the prediction refers to.
    class_value: The predicted class value (label).
    prob:        The probability of that class value.
    rank:        The rank of that class value by probability (dropped when
                 pred_type is 'response').
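-----------------------------------------------------------------------
EXAMPLE
-----------------------------------------------------------------------
A minimal sketch of a call; 'iris_model', 'iris_test', and the column names
are hypothetical and only illustrate the argument order:

SELECT {schema_madlib}.madlib_keras_predict(
    'iris_model',      -- model table produced by madlib_keras_fit
    'iris_test',       -- table with the rows to score
    'id',              -- id column in iris_test
    'attributes',      -- independent variable column
    'iris_predict',    -- output table to be created
    'prob',            -- return probabilities for all classes
    FALSE              -- do not use GPUs
);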
"""
else:
help_string = "No such option. Use {schema_madlib}.madlib_keras_predict()"
return help_string.format(schema_madlib=schema_madlib)
def predict_byom_help(schema_madlib, message, **kwargs):
"""
Help function for keras predict
Args:
@param schema_madlib
@param message: string, Help message string
@param kwargs
Returns:
String. Help/usage information
"""
if not message:
help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
This function allows the user to predict with their own pre-trained model
(note that this model need not have been trained using MADlib).
For more details on function usage:
SELECT {schema_madlib}.madlib_keras_predict_byom('usage')
"""
elif message in ['usage', 'help', '?']:
help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.madlib_keras_predict_byom(
    model_arch_table,    -- Name of the table containing the model architecture
                            and the pre-trained model weights
    model_id,            -- The id of the entry in 'model_arch_table' that
                            contains the desired model architecture
test_table, -- Name of the table containing the evaluation dataset
id_col, -- Name of the id column in the test data table
independent_varname, -- Name of the column with independent
variables in the test table
output_table, -- Name of the output table
pred_type, -- The type of the desired output
use_gpus, -- Flag for enabling GPU support
    class_values,        -- List of class labels that were used while training
                            the model. If class_values is NULL, the class
                            indices 0, 1, ... are used as the labels. The
                            predicted labels appear in the 'class_value'
                            column of the output table.
normalizing_const, -- Normalizing constant used for standardizing arrays in
independent_varname
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) contains the following columns:
    id:          The 'id' of the corresponding row in the test_table.
    class_name:  Name of the dependent variable the prediction refers to.
    class_value: The predicted class value (label).
    prob:        The probability of that class value.
    rank:        The rank of that class value by probability (dropped when
                 pred_type is 'response').
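-----------------------------------------------------------------------
EXAMPLE
-----------------------------------------------------------------------
A minimal sketch of a call; 'model_arch_library', 'places_test', and the
column names are hypothetical and only illustrate the argument order:

SELECT {schema_madlib}.madlib_keras_predict_byom(
    'model_arch_library',  -- table with the architecture and weights
    2,                     -- model_id of the desired entry
    'places_test',         -- table with the rows to score
    'id',                  -- id column in places_test
    'attributes',          -- independent variable column
    'places_predict',      -- output table to be created
    'prob',                -- return probabilities for all classes
    FALSE,                 -- do not use GPUs
    NULL,                  -- class_values; NULL falls back to class indices
    255.0                  -- normalizing constant (e.g. 255.0 for 8-bit images)
);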
"""
else:
help_string = "No such option. Use {schema_madlib}.madlib_keras_predict_byom()"
return help_string.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------