Deep Learning: Use compile_and_set_weights() in predict
Commit SHA 137ba49 changed the way we process compile_params, and
although it was used in fit, it wasn't used in the predict function.
This commit makes necessary changes to process compile_params using this
new function. We also now use KerasWeightsSerializer.get_model_shapes to
get the model shape in predict function.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 5e4e62b..c484c09 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -32,6 +32,7 @@
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
+from madlib_keras_wrapper import compile_and_set_weights
from madlib_keras_wrapper import convert_string_of_args_to_dict
from madlib_keras_helper import get_class_values_and_type
from madlib_keras_helper import KerasWeightsSerializer
@@ -76,19 +77,12 @@
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
compile_params, class_values):
model = model_from_json(model_arch)
- compile_params = convert_string_of_args_to_dict(compile_params)
device_name = '/cpu:0'
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
+ model_shapes = KerasWeightsSerializer.get_model_shapes(model)
+ compile_and_set_weights(model, compile_params, device_name,
+ model_data, model_shapes)
- with K.tf.device(device_name):
- model.compile(**compile_params)
-
- model_shapes = []
- for weight_arr in model.get_weights():
- model_shapes.append(weight_arr.shape)
- _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
- model_data, model_shapes)
- model.set_weights(model_weights)
x_test = np.array(x_test).reshape(1, *input_shape)
x_test /= 255
proba_argmax = model.predict_classes(x_test)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index f393db3..3e8a18c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -171,5 +171,5 @@
'model_arch',
1,
'x',
- '''optimizer''=SGD(lr=0.01, decay=1e-6, nesterov=True), ''loss''=''categorical_crossentropy'', ''metrics''=[''accuracy'']'::text,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');