| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| m4_changequote(`<!', `!>') |
| |
| import sys |
| import json |
| import plpy |
| from keras_model_arch_table import ModelArchSchema |
| |
def _get_layers(model_arch):
    """
    Extract the list of layer definitions from a model architecture.

    Two layouts of the decoded JSON are recognised:
      * ``config`` is itself a list, which is returned as-is;
      * ``config`` is a dict whose ``layers`` entry is a list, which is
        returned.

    Any other layout, including a top-level value that is not a JSON
    object or that has no ``config`` entry, is reported through
    plpy.error with "Unable to read model architecture JSON.".
    Errors raised by json.loads for text that is not valid JSON are not
    intercepted here.

    :param model_arch: JSON text describing the model architecture
    :return: list of layer entries found in the JSON
    """
    arch = json.loads(model_arch)
    # A non-object top level or a missing 'config' key is handled like any
    # other unrecognised layout instead of surfacing as TypeError/KeyError.
    config = arch.get('config') if isinstance(arch, dict) else None
    if isinstance(config, list):
        return config  # In keras 2.1.x, all models are sequential
    if isinstance(config, dict) and isinstance(config.get('layers'), list):
        return config['layers']  # In keras 2.x, only sequential models are supported
    plpy.error("Unable to read model architecture JSON.")
| |
def get_input_shape(model_arch):
    """
    Return the input shape declared by the first layer of the model.

    The first layer's ``batch_input_shape`` entry is read and everything
    after its first element is returned.

    If the architecture has no layers, or the first layer's config has no
    ``batch_input_shape`` entry, the failure is reported through
    plpy.error.

    :param model_arch: JSON text describing the model architecture
    :return: ``batch_input_shape`` of the first layer minus its first
             element
    """
    arch_layers = _get_layers(model_arch)
    # An architecture without layers has no first layer to inspect; fall
    # through to the error below rather than raising a bare IndexError.
    if arch_layers:
        first_config = arch_layers[0]['config']
        if 'batch_input_shape' in first_config:
            return first_config['batch_input_shape'][1:]
    plpy.error('Unable to get input shape from model architecture.')
| |
def get_num_classes(model_arch):
    """
    Return the ``units`` value of the last layer that declares one.

    The assumption is that the last dense layer of the architecture
    carries num_classes as its ``units``. Other layers, such as an
    activation, may follow it:
    ```
    ...
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))
    ```
    Layers are therefore examined from the end of the list towards the
    start, and the first config containing ``units`` wins. When no layer
    config contains ``units``, the failure is reported through plpy.error.

    :param model_arch: JSON text describing the model architecture
    :return: the ``units`` entry of the last layer config that has one
    """
    for layer in reversed(_get_layers(model_arch)):
        layer_config = layer['config']
        if 'units' in layer_config:
            return layer_config['units']
    plpy.error('Unable to get number of classes from model architecture.')
| |
def get_model_arch_layers_str(model_arch):
    """
    Render the model's layers as a vertical, arrow-separated listing.

    The result starts with a "Model arch layers:" header line, followed by
    one line per layer holding its ``class_name``. Layers whose class name
    is 'Dense' also show their ``units`` in square brackets. Consecutive
    layers are separated by a two-line arrow (a "|" line, then a "V" line).

    :param model_arch: JSON text describing the model architecture
    :return: the formatted multi-line string
    """
    layer_lines = []
    for layer in _get_layers(model_arch):
        class_name = layer['class_name']
        # str.format numbers positional arguments from 0; the previous
        # {1}/{2} indices pointed past the supplied arguments and raised
        # IndexError for every layer.
        if class_name == 'Dense':
            layer_lines.append(
                "{0}[{1}]\n".format(class_name, layer['config']['units']))
        else:
            layer_lines.append("{0}\n".format(class_name))
    # The arrow goes between layers only, never before the first one.
    return "Model arch layers:\n" + " |\n V\n".join(layer_lines)
| |
def get_model_arch_weights(model_arch_table, model_id):
    """
    Fetch the stored architecture and weights for one model id.

    Builds a SELECT over ``model_arch_table`` that filters on the model id
    column and runs it with plpy.execute. Both arguments are interpolated
    into the SQL text as given. When the query yields nothing, the failure
    is reported through plpy.error.

    :param model_arch_table: name of the table to query
    :param model_id: value the model id column is compared against
    :return: tuple (model_arch, model_weights) taken from the first row
             of the result
    """
    # assume validation is already called
    arch_col = ModelArchSchema.MODEL_ARCH
    weights_col = ModelArchSchema.MODEL_WEIGHTS
    query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
        arch_col, weights_col, model_arch_table,
        ModelArchSchema.MODEL_ID, model_id)

    rows = plpy.execute(query)
    if not rows:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_id))

    first_row = rows[0]
    return first_row[arch_col], first_row[weights_col]