blob: 4f896846d92f71afbb7614b25e3b96084d15b789 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, too-many-arguments
"""Lightweight API for mxnet prediction.
This is for prediction only; use the mxnet Python package instead for most tasks.
"""
from __future__ import absolute_import
import os
import sys
from array import array
import ctypes
import logging
import numpy as np
# pylint: disable= no-member
# Forward table: numpy scalar type -> integer dtype code exchanged with the
# predict C API (sent to MXPredCreateEx, received from MXPredGetOutputType).
# ``None`` is paired with the code -1.
_DTYPE_NP_TO_MX = {
    None: -1,
    np.float32: 0,
    np.float64: 1,
    np.float16: 2,
    np.uint8: 3,
    np.int32: 4,
    np.int8: 5,
    np.int64: 6,
}

# Reverse table: integer dtype code -> numpy scalar type. Built from the
# forward table so the two can never drift apart.
_DTYPE_MX_TO_NP = {
    code: np_type for np_type, code in _DTYPE_NP_TO_MX.items()
}
# Public names exported by ``from <module> import *``.
__all__ = ["Predictor", "load_ndarray_file"]
def py_str(x):
    """Decode a UTF-8 encoded bytes object into a Python string."""
    return x.decode('utf-8')
def c_str_array(strings):
    """Build a ctypes ``const char **`` from a list of Python strings.

    Parameters
    ----------
    strings : list of string
        Python strings; each one is encoded as UTF-8.

    Returns
    -------
    (ctypes.c_char_p * len(strings))
        A ctypes array of C strings that can be passed to the C API.
    """
    array_type = ctypes.c_char_p * len(strings)
    return array_type(*(text.encode('utf-8') for text in strings))
def c_str(string):
    """Convert a Python string to a C string (``ctypes.c_char_p``).

    A non-``str`` argument is first decoded as ASCII; the resulting text is
    then encoded as UTF-8 for the C side.
    """
    text = string if isinstance(string, str) else string.decode('ascii')
    return ctypes.c_char_p(text.encode('utf-8'))
def c_array(ctype, values):
    """Copy a Python sequence into a new ctypes array of element type ``ctype``."""
    array_type = ctype * len(values)
    return array_type(*values)
def c_array_buf(ctype, buf):
    """Wrap a writable Python buffer as a ctypes array of ``ctype``.

    The array is created with ``from_buffer``, so it shares memory with
    ``buf`` rather than copying it.
    """
    array_type = ctype * len(buf)
    return array_type.from_buffer(buf)
def _find_lib_path():
    """Find the MXNet shared library.

    The search order is:

    1. ``../../lib/libmxnet_predict.so`` relative to this file;
    2. ``mxnet.libinfo.find_lib_path()`` from an importable ``mxnet`` package;
    3. ``find_lib_path()`` obtained by executing
       ``../../python/mxnet/libinfo.py`` relative to this file.

    Returns
    -------
    list of str
        Candidate library paths. For the first case this is a one-element
        list; otherwise it is whatever ``find_lib_path()`` returns.

    Raises
    ------
    RuntimeError
        If ``mxnet.libinfo`` cannot be imported and no ``libinfo.py`` exists
        at the fallback location.
    """
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    amalgamation_lib_path = os.path.join(curr_path, '../../lib/libmxnet_predict.so')
    # isfile() is already False for a missing path, so no separate exists() check.
    if os.path.isfile(amalgamation_lib_path):
        return [amalgamation_lib_path]

    logging.info('Cannot find libmxnet_predict.so. Will search for MXNet library using libinfo.py then.')
    try:
        from mxnet.libinfo import find_lib_path
        # Deliberately kept inside the try: an ImportError raised while
        # locating the library also falls through to the source-tree fallback.
        return find_lib_path()
    except ImportError:
        libinfo_path = os.path.join(curr_path, '../../python/mxnet/libinfo.py')
        if not os.path.isfile(libinfo_path):
            raise RuntimeError('Cannot find libinfo.py at %s.' % libinfo_path)
        # Run libinfo.py in its own namespace (with __file__ set so that its
        # relative lookups work) instead of importing the whole mxnet package.
        libinfo = {'__file__': libinfo_path}
        with open(libinfo_path, "rb") as libinfo_file:
            libinfo_source = libinfo_file.read()
        exec(compile(libinfo_source, libinfo_path, 'exec'), libinfo, libinfo)
        return libinfo['find_lib_path']()
def _load_lib():
    """Load the library from the first path returned by ``_find_lib_path``."""
    candidates = _find_lib_path()
    library = ctypes.cdll.LoadLibrary(candidates[0])
    # ctypes assumes an int return by default; declare the real return type so
    # MXGetLastError yields the error message as bytes.
    library.MXGetLastError.restype = ctypes.c_char_p
    return library
def _check_call(ret):
    """Raise ``RuntimeError`` with the library's last error if ``ret`` is non-zero.

    Parameters
    ----------
    ret : int
        Return code of a C API call; 0 is treated as success.
    """
    if ret == 0:
        return
    last_error = _LIB.MXGetLastError()
    raise RuntimeError(py_str(last_error))
def _monitor_callback_wrapper(callback):
    """Adapt a two-argument user callback to a three-argument signature.

    The returned function accepts ``(name, array, extra)`` and forwards only
    the first two arguments to ``callback``; the third is ignored and nothing
    is returned.
    """
    def callback_handle(name, arr_handle, _unused):
        """Function handed to ctypes; drops the trailing argument."""
        callback(name, arr_handle)

    return callback_handle
# The shared library is loaded once, when this module is imported.
_LIB = _load_lib()

# ctypes aliases used when marshalling arguments for the C API.
mx_uint = ctypes.c_uint
mx_int = ctypes.c_int
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)

# Both kinds of handle are plain void pointers on the Python side.
PredictorHandle = ctypes.c_void_p
NDListHandle = ctypes.c_void_p

# Device name -> integer device-type code passed to MXPredCreateEx.
devstr2type = dict(cpu=1, gpu=2, cpu_pinned=3)
class Predictor(object):
    """A predictor class that runs prediction.

    Parameters
    ----------
    symbol_file : str
        Symbol definition, passed as the first argument of ``MXPredCreateEx``.
        NOTE(review): earlier documentation called this ``symbol_json_str``
        and described it as a path; presumably the C API expects the JSON
        text of the symbol -- confirm against the C predict API header.
    param_raw_bytes : str, bytes
        The raw parameter bytes.
    input_shapes : dict of str to tuple
        The shape of input data.
    dev_type : str, optional
        The device type of the predictor; must be a key of ``devstr2type``.
    dev_id : int, optional
        The device id of the predictor.
    type_dict : Dict of str->numpy.dtype
        Input type dictionary, name->dtype. Entries whose dtype has no MXNet
        type code are not forwarded to the C API.
    """

    def __init__(self, symbol_file,
                 param_raw_bytes, input_shapes,
                 dev_type="cpu", dev_id=0, type_dict=None):
        dev_type = devstr2type[dev_type]
        keys, indptr, sdata = self._flatten_input_shapes(input_shapes)
        handle = PredictorHandle()
        # from_buffer needs a writable buffer, hence the bytearray copy.
        param_raw_bytes = bytearray(param_raw_bytes)
        ptr = (ctypes.c_char * len(param_raw_bytes)).from_buffer(param_raw_bytes)

        # Defaults when no type_dict is given: zero entries and NULL pointers.
        num_provided_arg_types = 0
        provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()
        provided_arg_type_data = ctypes.POINTER(mx_uint)()
        if type_dict is not None:
            type_names = []
            type_codes = []
            for name, dtype in type_dict.items():
                np_type = np.dtype(dtype).type
                if np_type in _DTYPE_NP_TO_MX:
                    type_names.append(name)
                    type_codes.append(_DTYPE_NP_TO_MX[np_type])
            num_provided_arg_types = mx_uint(len(type_names))
            provided_arg_type_names = c_str_array(type_names)
            provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', type_codes))

        _check_call(_LIB.MXPredCreateEx(
            c_str(symbol_file),
            ptr, len(param_raw_bytes),
            ctypes.c_int(dev_type), ctypes.c_int(dev_id),
            mx_uint(len(indptr) - 1),
            c_array(ctypes.c_char_p, keys),
            c_array(mx_uint, indptr),
            c_array(mx_uint, sdata),
            num_provided_arg_types,
            provided_arg_type_names,
            provided_arg_type_data,
            ctypes.byref(handle)))
        self.type_dict = type_dict
        self.handle = handle

    @staticmethod
    def _flatten_input_shapes(input_shapes):
        """Flatten a name->shape dict into the three lists the C API takes.

        Returns ``(keys, indptr, sdata)`` where ``keys`` holds the input names
        as C strings, ``sdata`` is all shape dimensions concatenated, and
        ``sdata[indptr[i]:indptr[i + 1]]`` is the shape of the i-th input.

        Raises ``ValueError`` if any shape is not a tuple.
        """
        keys = []
        indptr = [0]
        sdata = []
        for name, shape in input_shapes.items():
            if not isinstance(shape, tuple):
                raise ValueError("Expect input_shapes to be dict str->tuple")
            keys.append(c_str(name))
            sdata.extend(shape)
            indptr.append(len(sdata))
        return keys, indptr, sdata

    def __del__(self):
        # __init__ may have raised before ``self.handle`` was assigned; in
        # that case there is nothing to free.
        handle = getattr(self, 'handle', None)
        if handle is not None:
            _check_call(_LIB.MXPredFree(handle))

    def forward(self, **kwargs):
        """Perform forward to get the output.

        Parameters
        ----------
        **kwargs
            Keyword arguments of input variable name to data. Each value must
            be a numpy ndarray; it is converted to the dtype given in
            ``type_dict`` for that name, or to float32 otherwise.

        Raises
        ------
        ValueError
            If a ``type_dict`` was supplied and the number of inputs differs
            from its length, or if an input is not a numpy ndarray.

        Examples
        --------
        >>> predictor.forward(data=mydata)
        >>> out = predictor.get_output(0)
        """
        if self.type_dict and len(self.type_dict) != len(kwargs):
            raise ValueError("number of kwargs should be same as len of type_dict. "
                             "Please check your forward pass inputs "
                             "or type_dict passed to Predictor instantiation")
        for k, v in kwargs.items():
            if not isinstance(v, np.ndarray):
                raise ValueError("Expect numpy ndarray as input")
            if self.type_dict and k in self.type_dict:
                v = np.asarray(v, dtype=self.type_dict[k], order='C')
            else:
                v = np.asarray(v, dtype=np.float32, order='C')
            _check_call(_LIB.MXPredSetInput(
                self.handle, c_str(k),
                v.ctypes.data_as(mx_float_p),
                mx_uint(v.size)))
        _check_call(_LIB.MXPredForward(self.handle))

    def reshape(self, input_shapes):
        """Change the input shape of the predictor.

        A new handle is obtained from ``MXPredReshape``; the previous handle
        is then freed and replaced.

        Parameters
        ----------
        input_shapes : dict of str to tuple
            The new shape of input data.

        Raises
        ------
        ValueError
            If any shape in ``input_shapes`` is not a tuple.

        Examples
        --------
        >>> predictor.reshape({'data':data_shape_tuple})
        """
        keys, indptr, sdata = self._flatten_input_shapes(input_shapes)
        new_handle = PredictorHandle()
        _check_call(_LIB.MXPredReshape(
            mx_uint(len(indptr) - 1),
            c_array(ctypes.c_char_p, keys),
            c_array(mx_uint, indptr),
            c_array(mx_uint, sdata),
            self.handle,
            ctypes.byref(new_handle)))
        _check_call(_LIB.MXPredFree(self.handle))
        self.handle = new_handle

    def get_output(self, index):
        """Get the index-th output.

        Parameters
        ----------
        index : int
            The index of output.

        Returns
        -------
        out : numpy array.
            The output array, allocated with the shape and dtype reported by
            the library for this output.
        """
        pdata = ctypes.POINTER(mx_uint)()
        ndim = mx_uint()
        out_type = mx_int()
        # Convert once so all three C calls receive the index the same way.
        c_index = mx_uint(index)
        _check_call(_LIB.MXPredGetOutputShape(
            self.handle, c_index,
            ctypes.byref(pdata),
            ctypes.byref(ndim)))
        _check_call(_LIB.MXPredGetOutputType(
            self.handle, c_index,
            ctypes.byref(out_type)))
        shape = tuple(pdata[:ndim.value])
        data = np.empty(shape, dtype=_DTYPE_MX_TO_NP[out_type.value])
        _check_call(_LIB.MXPredGetOutput(
            self.handle, c_index,
            data.ctypes.data_as(mx_float_p),
            mx_uint(data.size)))
        return data

    def set_monitor_callback(self, callback, monitor_all=False):
        """Register a monitor callback through ``MXPredSetMonitorCallback``.

        Parameters
        ----------
        callback : callable
            Called as ``callback(name, array)``; ctypes delivers ``name`` from
            a ``char *`` argument and ``array`` from a ``void *`` argument.
        monitor_all : bool, optional
            Forwarded to the C API as an int flag.
            NOTE(review): presumably selects monitoring of inputs as well as
            outputs -- confirm against the C predict API.
        """
        cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p)
        # Keep a reference on the instance: ctypes callback objects must stay
        # alive for as long as C code may call them.
        self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
        _check_call(_LIB.MXPredSetMonitorCallback(self.handle,
                                                  self._monitor_callback,
                                                  None,
                                                  ctypes.c_int(monitor_all)))
def load_ndarray_file(nd_bytes):
    """Load ndarray file and return as list of numpy array.

    Parameters
    ----------
    nd_bytes : str or bytes
        The internal ndarray bytes.

    Returns
    -------
    out : dict of str to numpy array or list of numpy array
        A list when there are no entries or the first entry's key is empty;
        otherwise a dict keyed by entry name. Every array is a float32 copy
        that does not reference library memory.

    Raises
    ------
    RuntimeError
        If any C API call reports failure.
    """
    handle = NDListHandle()
    olen = mx_uint()
    # from_buffer needs a writable buffer, hence the bytearray copy.
    nd_bytes = bytearray(nd_bytes)
    ptr = (ctypes.c_char * len(nd_bytes)).from_buffer(nd_bytes)
    _check_call(_LIB.MXNDListCreate(
        ptr, len(nd_bytes),
        ctypes.byref(handle), ctypes.byref(olen)))
    keys = []
    arrs = []
    try:
        for i in range(olen.value):
            key = ctypes.c_char_p()
            cptr = mx_float_p()
            pdata = ctypes.POINTER(mx_uint)()
            ndim = mx_uint()
            _check_call(_LIB.MXNDListGet(
                handle, mx_uint(i), ctypes.byref(key),
                ctypes.byref(cptr), ctypes.byref(pdata), ctypes.byref(ndim)))
            shape = tuple(pdata[:ndim.value])
            # np.prod(()) is the float 1.0, which ctypes rejects as an array
            # length; force a plain Python int.
            size = int(np.prod(shape, dtype=np.int64))
            if size == 0:
                # Nothing to read; avoid dereferencing a possibly-NULL pointer.
                ret = np.empty(shape, dtype=np.float32)
            else:
                dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
                # np.array copies, so the result outlives the list handle.
                ret = np.array(
                    np.frombuffer(dbuffer, dtype=np.float32).reshape(shape),
                    dtype=np.float32)
            keys.append(py_str(key.value))
            arrs.append(ret)
    finally:
        # Release the list even when reading an entry failed.
        _check_call(_LIB.MXNDListFree(handle))
    if not keys or not keys[0]:
        return arrs
    return dict(zip(keys, arrs))