blob: 3313dd3b204acda88acc27ca8106368193b189ce [file] [log] [blame]
# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import os
import plpy
# Import needed for get_data_as_np_array()
from keras import utils as keras_utils
from utilities.minibatch_validation import validate_dependent_var_for_minibatch
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import is_var_valid
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
########### Helper functions to serialize and deserialize weights #####
class KerasWeightsSerializer:
    """
    Pack and unpack a model's training state to and from a flat byte string.

    Every serialized state is a float32 buffer laid out as
        [loss, accuracy, buffer_count, w_0, w_1, ..., w_n]
    where w_0..w_n are the model weights flattened and concatenated in the
    order returned by model.get_weights().
    """

    @staticmethod
    def get_model_shapes(model):
        """
        Parameters:
            model: an object exposing get_weights(), which returns a list of
            numpy arrays
        Returns:
            A list with the shape of each array in model.get_weights(), in
            the same order.
        """
        return [weights.shape for weights in model.get_weights()]

    @staticmethod
    def _unflatten_weights(flat_weights, model_shapes):
        """
        Split a flat sequence of weights into one array per entry of
        model_shapes.

        Parameters:
            flat_weights: a 1-D numpy array or a list holding all weights
            model_shapes: a list of tuples containing the shapes of each
            element in keras.get_weights()
        Returns:
            A list of numpy arrays, the j-th reshaped to model_shapes[j].
        """
        weights, start = [], 0
        for shape in model_shapes:
            # np.prod returns 1 for the empty shape () of a scalar weight,
            # where reduce() over an empty tuple would raise.
            stop = start + int(np.prod(shape))
            # np.array copies the slice, so every returned array is writable
            # and independent of the input buffer.
            weights.append(np.array(flat_weights[start:stop]).reshape(shape))
            start = stop
        return weights

    @staticmethod
    def deserialize_weights(model_state, model_shapes):
        """
        Parameters:
            model_state: a stringified (serialized) state containing loss,
            accuracy, buffer_count, and model_weights, passed from postgres
            model_shapes: a list of tuples containing the shapes of each element
            in keras.get_weights()
        Returns:
            None if model_state or model_shapes is empty, otherwise a tuple of
            loss: the loss stored in the state (float)
            accuracy: the accuracy stored in the state (float)
            buffer_count: the buffer count from state (int)
            model_weights: a list of numpy arrays that can be inputted into keras.set_weights()
        """
        if not model_state or not model_shapes:
            return None
        state = np.frombuffer(model_state, dtype=np.float32)
        model_weights = KerasWeightsSerializer._unflatten_weights(
            state[3:], model_shapes)
        # Loss and accuracy are fractional values: return them as floats, as
        # deserialize_weights_merge does, rather than truncating them to int.
        return float(state[0]), float(state[1]), int(state[2]), model_weights

    @staticmethod
    def serialize_weights(loss, accuracy, buffer_count, model_weights):
        """
        Parameters:
            loss, accuracy, buffer_count: float values
            model_weights: a list of numpy arrays, what you get from
            keras.get_weights()
        Returns:
            A stringified (serialized) state containing all these values, to be
            passed to postgres, or None if model_weights is None
        """
        if model_weights is None:
            return None
        header = np.array([loss, accuracy, buffer_count], dtype=np.float32)
        flattened_weights = [np.asarray(w, dtype=np.float32).ravel()
                             for w in model_weights]
        return np.concatenate([header] + flattened_weights).tobytes()

    @staticmethod
    def deserialize_iteration_state(iteration_result):
        """
        Parameters:
            iteration_result: the output of the step function
        Returns:
            None if iteration_result is empty, otherwise a tuple of
            loss: the averaged loss from that iteration of training
            accuracy: the averaged accuracy from that iteration of training
            new_model_state: the stringified (serialized) state to pass in to next
            iteration of step function training, represents the averaged weights
            from the last iteration of training; zeros out loss, accuracy,
            buffer_count in this state because the new iteration must start with
            fresh values
        """
        if not iteration_result:
            return None
        state = np.frombuffer(iteration_result, dtype=np.float32)
        # frombuffer yields a read-only view; copy before resetting the header.
        new_model_state = np.array(state)
        new_model_state[:3] = 0
        return float(state[0]), float(state[1]), new_model_state.tobytes()

    @staticmethod
    def deserialize_weights_merge(state):
        """
        Parameters:
            state: the stringified (serialized) state containing loss, accuracy, buffer_count, and
            model_weights, passed from postgres to merge function
        Returns:
            None if state is empty, otherwise a tuple of
            loss: the averaged loss from that iteration of training
            accuracy: the averaged accuracy from that iteration of training
            buffer_count: total buffer counts processed
            model_weights: a single flattened numpy array containing all of the
            weights, flattened because all we have to do is average them (so don't
            have to reshape)
        """
        if not state:
            return None
        values = np.frombuffer(state, dtype=np.float32)
        # Copy the weights so the caller receives a writable array.
        return (float(values[0]), float(values[1]), int(values[2]),
                np.array(values[3:]))

    @staticmethod
    def serialize_weights_merge(loss, accuracy, buffer_count, model_weights):
        """
        Parameters:
            loss, accuracy, buffer_count: float values
            model_weights: a single flattened numpy array containing all of the
            weights, averaged in merge function over the 2 states
        Returns:
            A stringified (serialized) state containing all these values, to be
            passed to postgres, or None if model_weights is None
        """
        if model_weights is None:
            return None
        header = np.array([loss, accuracy, buffer_count], dtype=np.float32)
        weights = np.asarray(model_weights, dtype=np.float32)
        return np.concatenate((header, weights)).tobytes()

    @staticmethod
    def deserialize_weights_orig(model_weights_serialized, model_shapes):
        """
        Original deserialization for warm-start, used only to parse model received
        from query at the top of this file

        Parameters:
            model_weights_serialized: a flat sequence containing only the
            weights (no loss/accuracy/buffer_count header)
            model_shapes: a list of tuples containing the shapes of each
            element in keras.get_weights()
        Returns:
            A list of numpy arrays, one per entry of model_shapes.
        """
        return KerasWeightsSerializer._unflatten_weights(
            model_weights_serialized, model_shapes)
########### General Helper functions #######
def get_data_as_np_array(table_name, y, x, input_shape, num_classes):
    """
    Read every row of a table of batched images and return the images and
    their labels as two numpy arrays with one entry per image.

    :param table_name: Table containing the batch of images per row
    :param y: Column name for y
    :param x: Column name for x
    :param input_shape: input_shape of data in array format [L , W , C]
    :param num_classes: num of distinct classes in y
    :return: Tuple (x_validation, y_validation) of float64 arrays.
             x_validation has shape (num_images,) + input_shape and
             y_validation has shape (num_images, num_classes), where
             num_images is the total number of images over all rows.
             Both arrays have zero rows if the table is empty.
    """
    # Materialize the dimensions as a list so they can be reused below.
    image_shape = [int(dim) for dim in input_shape]
    val_data_qry = "SELECT {0}, {1} FROM {2}".format(y, x, table_name)
    val_data = plpy.execute(val_data_qry)

    x_buffers, y_buffers = [], []
    for row in val_data:
        # The leading -1 lets each row carry its own number of images, so a
        # final buffer smaller than the others is still accepted.
        x_buffers.append(
            np.asarray(row[x], dtype='float64').reshape([-1] + image_shape))
        y_buffers.append(
            np.asarray(row[y], dtype='float64').reshape(-1, num_classes))

    if not x_buffers:
        return (np.empty([0] + image_shape, dtype='float64'),
                np.empty((0, num_classes), dtype='float64'))

    # Concatenate once at the end instead of growing the arrays row by row.
    return np.concatenate(x_buffers), np.concatenate(y_buffers)
# Column that FitInputValidator requires to be present in the source summary
# table.
CLASS_VALUES_COLNAME = "class_values"
# NOTE(review): the two names below are not referenced in this part of the
# file; presumably they are also column names of the summary table -- confirm
# against their callers.
NORMALIZING_CONST_COLNAME = "normalizing_const"
DEPENDENT_VARTYPE = "dependent_vartype"
class FitInputValidator:
    """
    Validate the table and column arguments supplied for model training.

    Constructing an instance derives the summary table names and runs
    _validate_input_args(). validate_input_shapes() is a separate check that
    the caller invokes once the model architecture's input shape is known.
    """

    def __init__(self, source_table, validation_table, output_model_table,
                 model_arch_table, dependent_varname, independent_varname,
                 num_iterations):
        """
        Parameters:
            source_table: name of the training table
            validation_table: name of the validation table; may be None or
            blank, in which case it is not validated
            output_model_table: name of the model table to be created
            model_arch_table: name of the table holding the model architecture
            dependent_varname: name of the dependent variable column
            independent_varname: name of the independent variable column
            num_iterations: number of training iterations, must be > 0
        """
        self.source_table = source_table
        self.validation_table = validation_table
        self.output_model_table = output_model_table
        self.model_arch_table = model_arch_table
        self.dependent_varname = dependent_varname
        self.independent_varname = independent_varname
        self.num_iterations = num_iterations

        self.source_summary_table = None
        if self.source_table:
            self.source_summary_table = add_postfix(
                self.source_table, "_summary")

        # Always define the attribute, so that a missing output table name is
        # reported by output_tbl_valid() instead of as an AttributeError.
        self.output_summary_model_table = None
        if self.output_model_table:
            self.output_summary_model_table = add_postfix(
                self.output_model_table, "_summary")

        self.module_name = 'model_keras'
        self._validate_input_args()

    def _validate_input_table(self, table):
        """
        Assert that both the independent and the dependent variable are valid
        for the given table.
        """
        _assert(is_var_valid(table, self.independent_varname),
                "model_keras error: invalid independent_varname "
                "('{independent_varname}') for table "
                "({table}).".format(
                    independent_varname=self.independent_varname,
                    table=table))
        _assert(is_var_valid(table, self.dependent_varname),
                "model_keras error: invalid dependent_varname "
                "('{dependent_varname}') for table "
                "({table}).".format(
                    dependent_varname=self.dependent_varname,
                    table=table))

    def _validate_input_args(self):
        """
        Validate the iteration count and every input and output table.
        """
        _assert(self.num_iterations > 0,
                "model_keras error: Number of iterations cannot be < 1.")
        input_tbl_valid(self.source_table, self.module_name)
        input_tbl_valid(self.source_summary_table, self.module_name)
        _assert(is_var_valid(
                    self.source_summary_table, CLASS_VALUES_COLNAME),
                "model_keras error: invalid class_values varname "
                "('{class_values}') for source_summary_table "
                "({summary_table}).".format(
                    class_values=CLASS_VALUES_COLNAME,
                    summary_table=self.source_summary_table))

        # Source table and validation tables must have the same schema
        self._validate_input_table(self.source_table)
        validate_dependent_var_for_minibatch(self.source_table,
                                             self.dependent_varname)
        if self.validation_table and self.validation_table.strip() != '':
            input_tbl_valid(self.validation_table, self.module_name)
            self._validate_input_table(self.validation_table)
            validate_dependent_var_for_minibatch(self.validation_table,
                                                 self.dependent_varname)

        # Validate model arch table's schema.
        input_tbl_valid(self.model_arch_table, self.module_name)

        # Validate output tables
        output_tbl_valid(self.output_model_table, self.module_name)
        output_tbl_valid(self.output_summary_model_table, self.module_name)

    def validate_input_shapes(self, table, input_shape):
        """
        Validate if the input shape specified in model architecture is the same
        as the shape of the image specified in the independent var of the input
        table.

        Parameters:
            table: name of the table whose independent variable is checked
            input_shape: list with the image dimensions from the model
            architecture
        """
        num_dims = len(input_shape)
        # The weird indexing with 'i+2' and 'i' below has two reasons:
        # 1) The indexing for array_upper() starts from 1, but indexing in the
        # input_shape list starts from 0.
        # 2) Input_shape is only the image's dimension, whereas a row of
        # independent varname in a table contains buffer size as the first
        # dimension, followed by the image's dimension. So we must ignore
        # the first dimension from independent varname.
        array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
            self.independent_varname, i+2, i) for i in range(num_dims))
        # Only the first row is inspected, so do not fetch more than one.
        query = """
            SELECT {0}
            FROM {1}
            LIMIT 1
        """.format(array_upper_query, table)
        # This query will fail if an image in independent var does not have the
        # same number of dimensions as the input_shape.
        result = plpy.execute(query)[0]
        _assert(len(result) == num_dims,
                "model_keras error: The number of dimensions ({0}) of each image" \
                " in model architecture and {1} in {2} ({3}) do not match.".format(
                    num_dims, self.independent_varname, table, len(result)))

        # Shape of the image as stored in the table, one entry per dimension.
        input_shape_from_table = [result["n_{0}".format(dim)]
                                  for dim in range(num_dims)]
        if input_shape_from_table != list(input_shape):
            plpy.error("model_keras error: Input shape {0} in the model" \
                       " architecture does not match the input shape {1} of column" \
                       " {2} in table {3}.".format(
                           input_shape, input_shape_from_table,
                           self.independent_varname, table))