blob: 8c3bd9adc1d31c5ec5e204c8bd3cc35a5d3c84d3 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# custom service file
# model_handler.py
"""
ModelHandler defines a base model handler.
"""
import logging
import data_transformer
import keras
import sys
import numpy as np
import defs
import mxnet as mx
class ModelHandler(object):
    """
    A base Model handler implementation.

    ``initialize`` loads an MXNet checkpoint and binds it for inference;
    ``handle`` then serves a request by chaining ``preprocess`` ->
    ``inference`` -> ``postprocess``.
    """

    # Checkpoint location and input signature of the served network. They are
    # class attributes so a subclass can point the handler at another model.
    MODEL_PREFIX = './prog'
    MODEL_EPOCH = 0
    INPUT_NAME = '/dropout_1_input1'
    INPUT_SHAPE = (1, 2048, 70)
    INPUT_DTYPE = 'float32'
    INPUT_LAYOUT = 'NTC'

    # Scores below this value are marked with ' ', all others with '*'.
    SCORE_THRESHOLD = 0.5

    def __init__(self):
        self.error = None
        self._context = None
        self._batch_size = 0
        self.initialized = False
        self.mod = None

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time

        :param context: Initial context contains model server system properties.
                        ``context.system_properties["batch_size"]`` is read.
        :return: None
        :raises Exception: whatever checkpoint loading / binding raises; in
                           that case ``self.initialized`` stays False.
        """
        self.initialized = False
        self._context = context
        self._batch_size = context.system_properties["batch_size"]

        sym, arg_params, aux_params = mx.model.load_checkpoint(
            prefix=self.MODEL_PREFIX, epoch=self.MODEL_EPOCH)
        module = mx.mod.Module(symbol=sym,
                               data_names=[self.INPUT_NAME],
                               context=mx.cpu(),
                               label_names=None)
        module.bind(for_training=False,
                    data_shapes=[(self.INPUT_NAME, self.INPUT_SHAPE,
                                  self.INPUT_DTYPE, self.INPUT_LAYOUT)],
                    label_shapes=module._label_shapes)
        module.set_params(arg_params, aux_params)

        # Publish the module and flag readiness only once everything above
        # succeeded, so a failed load never leaves the handler marked ready.
        self.mod = module
        self.initialized = True

    def preprocess(self, batch):
        """
        Transform raw input into model input data.

        :param batch: list of raw requests, should match batch size
        :return: NDArray built from the ``'body'`` of the first request
        :raises ValueError: if ``len(batch)`` differs from the configured
                            batch size
        """
        # Explicit check instead of ``assert`` so that the validation is not
        # stripped when Python runs with -O.
        if self._batch_size != len(batch):
            raise ValueError("Invalid input batch size: {}".format(len(batch)))

        # NOTE(review): only batch[0] is vectorised; any further requests in
        # the batch are ignored -- confirm this is intended for batch_size > 1.
        vector = data_transformer.file_to_vec(
            batch[0].get('body'),
            file_vector_size=defs.file_chars_trunc_limit)
        return mx.nd.array(vector)

    def inference(self, model_input):
        """
        Internal inference methods

        :param model_input: transformed model input data
        :return: the result of ``self.mod.predict(model_input)``
        """
        return self.mod.predict(model_input)

    def postprocess(self, inference_output):
        """
        Return predict result in batch.

        Builds one line per entry of ``defs.langs`` from the first row of
        ``inference_output``, formatted as ``"<marker> - <lang>: <score>%"``
        where ``<score>`` is the row value times 100 with two decimals and
        ``<marker>`` is ``'*'`` unless the value is below ``SCORE_THRESHOLD``.

        :param inference_output: indexable inference output; row 0 is used
        :return: single-element list holding the list of result strings
        """
        first_row = inference_output[0]
        # Convert to plain numbers up front. Formatting the NDArray itself and
        # trimming its repr with str.strip() is fragile: strip() treats its
        # argument as a set of characters, so it also removed the ' ' marker
        # and the trailing '%'.
        if hasattr(first_row, 'asnumpy'):
            scores = first_row.asnumpy()
        else:
            scores = np.asarray(first_row)

        results = []
        for i, lang in enumerate(defs.langs):
            score = float(scores[i])
            marker = ' ' if score < self.SCORE_THRESHOLD else '*'
            results.append("{} - {}: {:.2f}%".format(marker, lang, 100 * score))
        return [results]

    def handle(self, data, context):
        """
        Custom service entry point function.

        :param data: list of objects, raw input from request
        :param context: model server context
        :return: list of outputs to be send back to client; on failure the
                 error text repeated ``batch_size`` times, after reporting
                 status 500 through ``context.request_processor``
        """
        try:
            model_input = self.preprocess(data)
            model_output = self.inference(model_input)
            response = self.postprocess(model_output)
            logging.debug("after %s", response)
            return response
        except Exception as e:
            # Top-level service boundary: log with traceback, report a 500 to
            # the server and still return one entry per expected batch item.
            logging.error(e, exc_info=True)
            request_processor = context.request_processor
            request_processor.report_status(500, "Unknown inference error")
            return [str(e)] * self._batch_size