blob: 3a225cc3b055726d7dc62a05f1473e70356c5275 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
m4_changequote(`<!', `!>')
import sys
import json
import plpy
from keras_model_arch_table import ModelArchSchema
def _get_layers(model_arch):
    """
    Extract the list of layer descriptions from a model architecture.

    Two layouts of the top-level 'config' entry are accepted: a plain list
    of layers, or a dict holding that list under its 'layers' key. Any other
    layout is reported through plpy.error.

    :param model_arch: JSON string describing the model architecture
    :return: list of layer entries found in the architecture
    """
    arch_config = json.loads(model_arch)['config']

    # Original note: in keras 2.1.x, all models are sequential, and the
    # config itself is the list of layers.
    if isinstance(arch_config, list):
        return arch_config

    # Original note: in keras 2.x, only sequential models are supported;
    # there the layers sit one level deeper, under 'layers'.
    if isinstance(arch_config, dict):
        nested_layers = arch_config.get('layers')
        if isinstance(nested_layers, list):
            return nested_layers

    plpy.error("Unable to read model architecture JSON.")
def get_input_shape(model_arch):
    """
    Read the input shape declared on the first layer of the architecture.

    The first layer's 'batch_input_shape' is returned with its leading
    entry sliced off. If the first layer has no 'batch_input_shape', the
    failure is reported through plpy.error.

    :param model_arch: JSON string describing the model architecture
    :return: 'batch_input_shape' of the first layer without its first entry
    """
    first_layer_config = _get_layers(model_arch)[0]['config']
    if 'batch_input_shape' in first_layer_config:
        full_shape = first_layer_config['batch_input_shape']
        # Drop the leading entry (presumably the batch dimension).
        return full_shape[1:]
    plpy.error('Unable to get input shape from model architecture.')
def get_num_classes(model_arch):
    """
    Return the 'units' value of the last layer that declares one.

    The last dense layer of the architecture is assumed to carry the number
    of classes as its units. It does not have to be the very last layer; an
    activation may follow it, for example:
    ```
    ...
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))
    ```
    If no layer config contains 'units', the failure is reported through
    plpy.error.

    :param model_arch: JSON string describing the model architecture
    :return: value stored under 'units' in the last layer config that has it
    """
    # Scan from the output end so the first match is the last such layer.
    for layer in reversed(_get_layers(model_arch)):
        layer_config = layer['config']
        if 'units' in layer_config:
            return layer_config['units']
    plpy.error('Unable to get number of classes from model architecture.')
def get_model_arch_layers_str(model_arch):
    """
    Build a human-readable, top-to-bottom listing of the model's layers.

    The result starts with a "Model arch layers:" header line, followed by
    one line per layer holding the layer's class name. Dense layers also
    show their number of units in square brackets, e.g. "Dense[512]".
    Consecutive layers are separated by a two-line "|" / "V" arrow.

    :param model_arch: JSON string describing the model architecture
    :return: multi-line string describing the layers in order
    """
    arch_layers = _get_layers(model_arch)

    layer_lines = []
    for layer in arch_layers:
        class_name = layer['class_name']
        if class_name == 'Dense':
            # Positional format fields are zero-based; the previous
            # "{1}[{2}]" / "{1}" templates raised IndexError on every layer.
            layer_lines.append(
                "{0}[{1}]\n".format(class_name, layer['config']['units']))
        else:
            layer_lines.append("{0}\n".format(class_name))

    # The arrow goes only between layers, never before the first one.
    return "Model arch layers:\n" + " |\n V\n".join(layer_lines)
def get_model_arch_weights(model_arch_table, model_id):
    """
    Fetch the architecture and weights stored for one model id.

    Runs a SELECT against model_arch_table for the row whose model id
    column equals model_id and returns the two requested columns of the
    first row. If the query yields no rows, the failure is reported through
    plpy.error.

    :param model_arch_table: name of the table holding model architectures
    :param model_id: id of the model to look up
    :return: tuple (model_arch, model_weights) taken from the first row
    """
    # The arguments are interpolated into the SQL text as-is; callers are
    # assumed to have validated them already.
    arch_col = ModelArchSchema.MODEL_ARCH
    weights_col = ModelArchSchema.MODEL_WEIGHTS
    query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
        arch_col, weights_col, model_arch_table,
        ModelArchSchema.MODEL_ID, model_id)

    rows = plpy.execute(query)
    if not rows:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_id))

    first_row = rows[0]
    return first_row[arch_col], first_row[weights_col]