# blob: bf14d1e08a2a760deb4658b36b1a65a166120fa3 [file] [log] [blame]
# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
import os
import keras
from keras import backend as K
from keras.layers import *
from keras.models import *
from keras.optimizers import *
import numpy as np
from utilities.model_arch_info import get_input_shape
from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
from madlib_keras_wrapper import compile_and_set_weights
from madlib_keras_wrapper import convert_string_of_args_to_dict
from madlib_keras_helper import CLASS_VALUES_COLNAME
from madlib_keras_helper import KerasWeightsSerializer
# predict: entry point that creates `output_table` from a prepared
# "CREATE TABLE ... AS SELECT" statement run over `test_table`, using the
# model stored in `model_table` and the architecture stored in
# `model_arch_table` under id `model_arch_id`.
#
# NOTE(review): this copy of the function appears to be truncated and is not
# valid Python as it stands:
#   * the parameter list is never closed -- the line(s) following
#     `output_table,` are missing;
#   * the SQL template jumps from `SELECT {id_col},` straight to
#     `)[1] as prediction`, so the expression that produces the prediction
#     is missing;
#   * all indentation has been lost.
# Restore the function from the canonical source before relying on it.
def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
model_arch_id, independent_varname, compile_params, output_table,
module_name = 'madlib_keras_predict'
# Validate the table arguments up front; module_name is passed along,
# presumably so error messages can name this module -- TODO confirm against
# utilities.validate_args.
input_tbl_valid(test_table, module_name)
input_tbl_valid(model_arch_table, module_name)
output_tbl_valid(output_table, module_name)
# _validate_input_args(test_table, model_arch_table, output_table)
# Fetch the model data; only the first row returned from model_table is used.
model_data_query = "SELECT model_data from {0}".format(model_table)
model_data = plpy.execute(model_data_query)[0]['model_data']
# Look up the architecture row by id. Table name and id are interpolated
# directly into the SQL text. `model_weights` is selected here but is not
# read anywhere in the visible code.
model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
"WHERE id = {1}".format(model_arch_table, model_arch_id)
query_result = plpy.execute(model_arch_query)
# Report an error when no architecture row matches the requested id.
if not query_result or len(query_result) == 0:
plpy.error("no model arch found in table {0} with id {1}".format(model_arch_table, model_arch_id))
query_result = query_result[0]
model_arch = query_result['model_arch']
input_shape = get_input_shape(model_arch)
# Wrap compile_params in $madlib$ delimiters -- presumably so it can be
# spliced into the SQL text below as a dollar-quoted string literal
# (the consuming part of the template is not visible here; TODO confirm).
compile_params = "$madlib$" + compile_params + "$madlib$"
# Class labels are read from the first row of the "<model_table>_summary"
# table built with add_postfix.
model_summary_table = add_postfix(model_table, "_summary")
class_values = plpy.execute("SELECT {0} AS cv FROM {1}".format(
CLASS_VALUES_COLNAME, model_summary_table))[0]['cv']
# The template placeholders are filled from this function's local variables
# via format(**locals()), so locals that look unused above may be consumed
# by the (partly missing) template. model_data is not interpolated; it is
# bound at execution time as the single "bytea" parameter of the plan.
predict_query = plpy.prepare("""
CREATE TABLE {output_table} AS
SELECT {id_col},
)[1] as prediction
from {test_table}""".format(**locals()), ["bytea"])
plpy.execute(predict_query, [model_data])
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
                           compile_params, class_values):
    """
    Predict the class of a single test sample with a Keras model.

    The model is rebuilt from its JSON architecture, then compiled and loaded
    with its weights by compile_and_set_weights. Prediction is pinned to the
    CPU: the device name is '/cpu:0' and CUDA_VISIBLE_DEVICES is set to '-1'
    for the whole process (a process-wide side effect of this call).

    @param x_test: array-like holding exactly one sample; it is reshaped to
                   (1,) + input_shape, so its element count must match.
    @param model_arch: JSON description of the model, as accepted by
                       model_from_json.
    @param model_data: model weights, forwarded unchanged to
                       compile_and_set_weights.
    @param input_shape: sequence with the dimensions of one sample.
    @param compile_params: compile parameters, forwarded unchanged to
                           compile_and_set_weights.
    @param class_values: list of class labels, or None.
    @returns: the entry of class_values at the predicted index, or the
              predicted index itself when class_values is None or empty.
    """
    model = model_from_json(model_arch)
    device_name = '/cpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
    model_shapes = KerasWeightsSerializer.get_model_shapes(model)
    compile_and_set_weights(model, compile_params, device_name,
                            model_data, model_shapes)
    # Scale by 255 (presumably 8-bit pixel intensities -- TODO confirm).
    # Dividing by the float literal instead of using in-place `/= 255` keeps
    # this correct for integer input: in-place division of an integer array
    # either truncates (Python 2) or raises a casting error (Python 3),
    # while float input keeps its dtype exactly as before.
    x_test = np.array(x_test).reshape(1, *input_shape) / 255.0
    proba_argmax = model.predict_classes(x_test)
    # proba_argmax is a list with exactly one element in it. That element
    # refers to the index containing the largest probability value in the
    # output of Keras' predict function.
    return _get_class_label(class_values, proba_argmax[0])
def _get_class_label(class_values, class_index):
Returns back the class label associated with the index returned by Keras'
predict_classes function. Keras' predict_classes function returns back
the index of the 1-hot encoded output that has the highest probability
value. We should infer the exact class label corresponding to the index
by looking at the class_values list (which is obtained from the
class_values column of the model summary table). If class_values is None,
we return the index as is.
@param class_values: list of class labels.
@param class_index: integer representing the index with max
probability value.
scalar. If class_values is None, returns class_index, else returns
if not class_values:
return class_index
elif class_index != int(class_index):
plpy.error("Invalid class index {0} returned from Keras predict. "\
"Index value must be an integer".format(
elif class_index < 0 or class_index >= len(class_values):
plpy.error("Invalid class index {0} returned from Keras predict. "\
"Index value must be less than {1}".format(
class_index, len(class_values)))
return class_values[class_index]