// blob: c35b2bb4912269021eb2a1c412e90b750b3451a3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
// Acknowledgement: This file originates from incubator-tvm
#include <mxnet/runtime/c_runtime_api.h>
#include <dmlc/thread_local.h>
#include <mxnet/runtime/packed_func.h>
#include <mxnet/runtime/registry.h>
#include <sstream>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include <cctype>
#include "../c_api/c_api_common.h"
using namespace mxnet::runtime;
/*!
 * \brief Scratch state shared by the C runtime API functions in this file.
 *
 * One instance is obtained through MXNetAPIRuntimeStore::Get().
 */
struct MXNetRuntimeEntry {
  /*! \brief Copy of the most recent string/bytes result of MXNetFuncCall. */
  std::string ret_str;
  /*! \brief Byte-array view (data pointer + size) over ret_str. */
  MXNetByteArray ret_bytes;
  /*! \brief Message stored by MXAPISetLastError and read by MXGetLastError. */
  std::string last_error;
};

/*! \brief dmlc::ThreadLocalStore holding the MXNetRuntimeEntry. */
using MXNetAPIRuntimeStore = dmlc::ThreadLocalStore<MXNetRuntimeEntry>;
/*!
 * \brief Destroy the PackedFunc object behind a function handle.
 *
 * \param func Handle that is cast to PackedFunc* and deleted.
 * \return Status code produced by the API_BEGIN/API_END macros
 *         (defined outside this file).
 */
int MXNetFuncFree(MXNetFunctionHandle func) {
  API_BEGIN();
  PackedFunc* packed = static_cast<PackedFunc*>(func);
  delete packed;
  API_END();
}
/*!
 * \brief Invoke a PackedFunc through its C handle.
 *
 * String and bytes results are copied into the MXNetRuntimeEntry obtained
 * from MXNetAPIRuntimeStore, and the caller receives a pointer into that
 * entry. The storage is overwritten by the next string/bytes result that
 * goes through the same entry.
 *
 * \param func Handle that is cast to const PackedFunc*.
 * \param args Argument values.
 * \param arg_type_codes Type code of each argument.
 * \param num_args Number of arguments.
 * \param ret_val Receives the return value.
 * \param ret_type_code Receives the type code of the return value.
 * \return Status code produced by the API_BEGIN/API_END macros
 *         (defined outside this file).
 */
int MXNetFuncCall(MXNetFunctionHandle func,
                  MXNetValue* args,
                  int* arg_type_codes,
                  int num_args,
                  MXNetValue* ret_val,
                  int* ret_type_code) {
  API_BEGIN();
  MXNetRetValue result;
  const PackedFunc* callee = static_cast<const PackedFunc*>(func);
  callee->CallPacked(MXNetArgs(args, arg_type_codes, num_args), &result);

  const auto code = result.type_code();
  const bool is_bytes = (code == kBytes);
  if (!is_bytes && code != kStr) {
    // Everything except strings/bytes is handed over directly.
    result.MoveToCHost(ret_val, ret_type_code);
  } else {
    // Keep a copy of the payload in the runtime entry so the pointer given
    // to the caller outlives `result`.
    MXNetRuntimeEntry* entry = MXNetAPIRuntimeStore::Get();
    entry->ret_str = *result.ptr<std::string>();
    if (is_bytes) {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    }
  }
  API_END();
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:} // stack traces follow this line
// {trace 0} // each trace line begins with two spaces.
// {trace 1}
// {trace 2}
//--------------------------------------------------------
/*!
 * \brief Normalize error message
 *
 * Parse the header generated by LOG(FATAL) and CHECK
 * and reformat the message into the standard format.
 *
 * This function will also merge all the stack traces into
 * one trace and trim them.
 *
 * \param err_msg The error message.
 * \return The normalized message. If the message starts with a
 *         "[timestamp]" header that cannot be parsed, err_msg is
 *         returned unchanged.
 */
std::string NormalizeError(std::string err_msg) {
  // ------------------------------------------------------------------------
  // log with header, {} indicates optional
  //-------------------------------------------------------------------------
  // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
  // {message1}
  // Stack trace:
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  // Normalized version
  //-------------------------------------------------------------------------
  // error_type: check_msg message0
  // {message1}
  // Stack trace:
  //   File file_name, line lineno
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  int line_number = 0;
  std::istringstream is(err_msg);
  std::string line, file_name, error_type, check_msg;

  // Prefix test whose compared length is always derived from the literal
  // itself, so the length and the literal can never get out of sync.
  auto has_prefix = [](const std::string& text, const char* prefix) {
    return text.compare(0, std::char_traits<char>::length(prefix), prefix) == 0;
  };

  // Parse log header and set the fields.
  // Return true if the log is in the correct format,
  // return false if something is wrong.
  auto parse_log_header = [&]() {
    // No "[timestamp]" header: the first line is the message itself.
    if (is.peek() != '[') {
      std::getline(is, line);
      return true;
    }
    // skip timestamp
    if (!(is >> line)) {
      return false;
    }
    // get filename
    while (is.peek() == ' ') {
      is.get();
    }
    if (!std::getline(is, file_name, ':')) {
      return false;
    }
    if (is.peek() == '\\' || is.peek() == '/') {
      // windows path: the first ':' belonged to the drive letter.
      if (!std::getline(is, line, ':')) {
        return false;
      }
      file_name = file_name + ':' + line;
    }
    // get line number
    if (!(is >> line_number)) {
      return false;
    }
    // get rest of the message.
    while (is.peek() == ' ' || is.peek() == ':') {
      is.get();
    }
    if (!std::getline(is, line)) {
      return false;
    }
    // detect check message, rewrite to remove the extra ':'
    const char kCheckPrefix[] = "Check failed:";
    if (has_prefix(line, kCheckPrefix)) {
      const size_t end_pos = line.find(':', sizeof(kCheckPrefix) - 1);
      if (end_pos == std::string::npos) {
        return false;
      }
      check_msg = line.substr(0, end_pos + 1) + ' ';
      line = line.substr(end_pos + 1);
    }
    return true;
  };

  // if not in correct format, do not do any rewrite.
  if (!parse_log_header()) {
    return err_msg;
  }

  // Parse error type.
  {
    size_t start_pos = 0;
    size_t end_pos = 0;
    while (start_pos < line.length() && line[start_pos] == ' ') {
      ++start_pos;
    }
    for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
      // <cctype> classifiers require a value representable as unsigned char;
      // a plain (possibly negative) char is undefined behaviour.
      const unsigned char ch = static_cast<unsigned char>(line[end_pos]);
      if (ch == ':') {
        error_type = line.substr(start_pos, end_pos - start_pos);
        break;
      }
      // [A-Z0-9a-z_.]
      if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') {
        break;
      }
    }
    if (!error_type.empty()) {
      // if we successfully detected error_type: trim the following space.
      start_pos = end_pos + 1;
      while (start_pos < line.length() && line[start_pos] == ' ') {
        ++start_pos;
      }
      line = line.substr(start_pos);
    } else {
      // did not detect error_type, use default value.
      line = line.substr(start_pos);
      error_type = "MXNetError";
    }
  }

  // Separate out stack trace.
  std::ostringstream os;
  os << error_type << ": " << check_msg << line << '\n';
  bool trace_mode = true;
  std::vector<std::string> stack_trace;
  while (std::getline(is, line)) {
    if (trace_mode) {
      // Trace lines are the ones indented by two spaces.
      if (has_prefix(line, "  ")) {
        stack_trace.push_back(line);
      } else {
        trace_mode = false;
        // remove EOL trailing stacktrace.
        if (line.empty()) {
          continue;
        }
      }
    }
    if (!trace_mode) {
      if (has_prefix(line, "Stack trace")) {
        trace_mode = true;
      } else {
        os << line << '\n';
      }
    }
  }

  if (!stack_trace.empty() || !file_name.empty()) {
    os << "Stack trace:\n";
    if (!file_name.empty()) {
      // Two leading spaces, like every other trace line.
      os << "  File \"" << file_name << "\", line " << line_number << "\n";
    }
    // Print out stack traces, optionally trim the c++ traces
    // about the frontends (as they will be provided by the frontends).
    bool ffi_boundary = false;
    for (const auto& frame : stack_trace) {
      // Heuristic to detect python ffi.
      if (frame.find("libffi.so") != std::string::npos ||
          frame.find("core.cpython") != std::string::npos) {
        ffi_boundary = true;
      }
      // If the backtrace is not c++ backtrace with the prefix "  [bt]",
      // then we can stop trimming.
      if (ffi_boundary && !has_prefix(frame, "  [bt]")) {
        ffi_boundary = false;
      }
      if (!ffi_boundary) {
        os << frame << '\n';
      }
    }
  }
  return os.str();
}
#else
// Variant compiled when _LIBCPP_SGX_NO_IOSTREAMS is defined: no parsing or
// reformatting is performed and the message is handed back as received.
std::string NormalizeError(std::string err_msg) {
return err_msg;
}
#endif
/*!
 * \brief Record an exception as the last error.
 *
 * The exception text is passed through NormalizeError and then stored
 * with MXAPISetLastError.
 *
 * \param e The exception whose what() text is recorded.
 * \return Always -1.
 */
int MXAPIHandleException(const std::exception& e) {
  const std::string normalized = NormalizeError(e.what());
  MXAPISetLastError(normalized.c_str());
  return -1;
}
const char* MXGetLastError() {
return MXNetAPIRuntimeStore::Get()->last_error.c_str();
}
/*!
 * \brief Store a message as the last error in the runtime entry.
 *
 * \param msg NUL-terminated message to copy. A null pointer is stored as
 *            an empty message, because assigning a null const char* to
 *            std::string is undefined behaviour.
 */
void MXAPISetLastError(const char* msg) {
  MXNetAPIRuntimeStore::Get()->last_error = (msg != nullptr) ? msg : "";
}