blob: 298f63acfd1f50f887222311df830b0f82d7fcb4 [file] [log] [blame]
# coding=utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
m4_changequote(`<!', `!>')
import sys
import json
import plpy
from keras_model_arch_table import ModelArchSchema
def _get_layers(model_arch):
d = json.loads(model_arch)
config = d['config']
if type(config) == list:
return config # In keras 2.1.x, all models are sequential
elif type(config) == dict and 'layers' in config:
layers = config['layers']
if type(layers) == list:
return config['layers'] # In keras 2.x, only sequential models are supported
plpy.error("Unable to read model architecture JSON.")
def get_input_shape(model_arch):
    """
    Return the input shape declared by the first layer of the model.

    The shape is read from the first layer's 'batch_input_shape' entry with
    its first element removed (in Keras that element is the batch size).

    :param model_arch: model architecture serialized as a JSON string
    :return: the 'batch_input_shape' list without its first element
    Reports an error through plpy.error when the model has no layers or the
    first layer does not define 'batch_input_shape'.
    """
    arch_layers = _get_layers(model_arch)
    # An empty layer list is routed to the descriptive error below rather
    # than failing with IndexError on arch_layers[0].
    if arch_layers:
        first_config = arch_layers[0]['config']
        if 'batch_input_shape' in first_config:
            return first_config['batch_input_shape'][1:]
    # The two literals are concatenated, so the first one must end in a space.
    plpy.error('Unable to get input shape from model architecture. '
               'Make sure that the first layer defines an input_shape.')
def get_num_classes(model_arch):
    """
    Return the number of classes predicted by the model.

    We assume that the last dense layer in the model architecture contains
    the num_classes (units). An example can be:

        model.add(Dense(num_classes))
        model.add(Activation('softmax'))

    where activation can be after the dense layer. In practice the layers are
    scanned from last to first and the 'units' value of the first layer whose
    config defines it is returned.

    :param model_arch: model architecture serialized as a JSON string
    :return: the 'units' value of the last layer that defines one
    Reports an error through plpy.error when no layer defines 'units'.
    """
    arch_layers = _get_layers(model_arch)
    # Walk backwards so trailing layers without 'units' (e.g. an activation)
    # are skipped.
    for layer in reversed(arch_layers):
        layer_config = layer['config']
        if 'units' in layer_config:
            return layer_config['units']
    plpy.error('Unable to get number of classes from model architecture.')
def get_model_arch_layers_str(model_arch):
    """
    Build a human-readable, top-to-bottom listing of the model's layers.

    Each layer is written on its own line as its class name; Dense layers
    additionally show their 'units' value in square brackets. Consecutive
    layers are separated by an arrow drawn over two lines.

    :param model_arch: model architecture serialized as a JSON string
    :return: multi-line string that starts with "Model arch layers:"
    """
    arch_layers = _get_layers(model_arch)
    pieces = ["Model arch layers:\n"]
    for position, layer in enumerate(arch_layers):
        if position > 0:
            # Arrow between consecutive layers; nothing is drawn above the
            # first one.
            pieces.append(" |\n")
            pieces.append(" V\n")
        class_name = layer['class_name']
        if class_name == 'Dense':
            # str.format positional indices are zero-based: {0} is the class
            # name and {1} the number of units.
            pieces.append("{0}[{1}]\n".format(class_name,
                                              layer['config']['units']))
        else:
            pieces.append("{0}\n".format(class_name))
    return "".join(pieces)
def get_model_arch(model_arch_table, model_id):
    """
    Fetch only the model architecture for the given model id.

    For fit_multiple, we don't want to keep sending weights back and
    forth between the main host and the segment hosts. Weights can be
    up to 1GB in size, whereas the model arch in JSON is usually very
    small.

    :param model_arch_table: name of the table holding model architectures
    :param model_id: id of the row to look up
    :return: value of the model architecture column for that row
    Reports an error through plpy.error unless exactly one row matches.
    """
    s = ModelArchSchema
    # NOTE(review): table name and id are interpolated directly into the SQL
    # text; this presumes both were validated by the caller -- confirm.
    model_arch_query = """
        SELECT {s.MODEL_ARCH} FROM {model_arch_table}
        WHERE {s.MODEL_ID} = {model_id}
    """.format(s=s, model_arch_table=model_arch_table, model_id=model_id)
    model_arch_result = plpy.execute(model_arch_query)
    if not model_arch_result or len(model_arch_result) != 1:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_id))
    return model_arch_result[0][s.MODEL_ARCH]
def get_model_arch_weights(model_arch_table, model_id):
    """
    Fetch the model architecture together with the model weights.

    For fit, we need both the model arch & model weights.

    :param model_arch_table: name of the table holding model architectures
    :param model_id: id of the row to look up
    :return: tuple (model_arch, model_weights) read from the first matching row
    Reports an error through plpy.error when no row matches.
    """
    #assume validation is already called
    # All five placeholders must be supplied; {4} is the id being looked up.
    model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
        model_arch_table, ModelArchSchema.MODEL_ID,
        model_id)
    model_arch_result = plpy.execute(model_arch_query)
    if not model_arch_result:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_id))

    first_row = model_arch_result[0]
    model_arch = first_row[ModelArchSchema.MODEL_ARCH]
    model_weights = first_row[ModelArchSchema.MODEL_WEIGHTS]
    return model_arch, model_weights