DL: improvements to "serializing weights with image count"
JIRA: MADLIB-1416
We need to add the image count to the model weights to get the final
state for fit_transition.
Previously we would call np.concatenate twice: once to join the flattened
model weights and once more to join the image count with the result. A
cleaner and faster way is to prepend the image count to the list of
flattened weights and get rid of the extra np.concatenate call.
Ran the unit test `test_serialize_image_nd_weights_valid_output` with large
model_weights `[np.array([1]*100000000), np.array([1]*100000000)]` to
confirm the speed improvement.
test_serialize_image_nd_weights_valid_output with old code and large
weights: 17.8 s
test_serialize_image_nd_weights_valid_output with new code and large
weights: 16.4 s
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index d70a2f8..5ab1f36 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -86,11 +86,9 @@
if model_weights is None:
return None
flattened_weights = [w.flatten() for w in model_weights]
- model_weights_serialized = np.concatenate(flattened_weights)
- new_model_string = np.array([image_count])
- new_model_string = np.concatenate((new_model_string, model_weights_serialized))
- new_model_string = np.float32(new_model_string)
- return new_model_string.tostring()
+ state = [np.array([image_count])] + flattened_weights
+ state = np.concatenate(state)
+ return np.float32(state).tostring()
def serialize_state_with_1d_weights(image_count, model_weights):
@@ -141,8 +139,8 @@
if model_weights is None:
return None
flattened_weights = [w.flatten() for w in model_weights]
- model_weights_serialized = np.concatenate(flattened_weights)
- return np.float32(model_weights_serialized).tostring()
+ flattened_weights = np.concatenate(flattened_weights)
+ return np.float32(flattened_weights).tostring()
def deserialize_as_nd_weights(model_weights_serialized, model_shapes):