blob: b7df6a5e51fe077c48b9fb5c03576e556bbbb3b8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file c_api_error.cc
* \brief C error handling
*/
#include <nnvm/c_api.h>
#include "./c_api_common.h"
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:} // stack traces follow this line
// {trace 0} // two spaces in the beginning.
// {trace 1}
// {trace 2}
//--------------------------------------------------------
/*!
* \brief Normalize error message
*
* Parse the header generated by LOG(FATAL) and CHECK
* and reformat the message into the standard format.
*
* This function will also merge all the stack traces into
* one trace and trim them.
*
* \param err_msg The error message.
* \return normalized message.
*/
std::string NormalizeError(std::string err_msg) {
  // ------------------------------------------------------------------------
  // log with header, {} indicates optional
  //-------------------------------------------------------------------------
  // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
  // {message1}
  // Stack trace:
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  // Normalized version
  //-------------------------------------------------------------------------
  // error_type: check_msg message0
  // {message1}
  // Stack trace:
  //   File file_name, line lineno
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------

  // Prefix test that always compares exactly prefix.size() characters, so the
  // literal and the compared length can never drift apart (a hand-written
  // length longer than the literal makes compare() silently never match).
  auto starts_with = [](const std::string& str, const std::string& prefix) {
    return str.compare(0, prefix.size(), prefix) == 0;
  };
  // Character class for error type names: [A-Za-z0-9_.].
  // The cast to unsigned char is required: passing a negative char (e.g. a
  // UTF-8 byte where char is signed) to <cctype> functions is undefined.
  auto is_type_char = [](char ch) {
    const unsigned char uch = static_cast<unsigned char>(ch);
    return std::isalpha(uch) || std::isdigit(uch) || ch == '_' || ch == '.';
  };
  // Every stack trace frame is indented by two spaces (see format above).
  const std::string trace_indent = "  ";

  int line_number = 0;
  std::istringstream is(err_msg);
  std::string line, file_name, error_type, check_msg;
  // Parse the log header and set the fields above.
  // Returns true if the message is usable (a message without a "[timestamp]"
  // header is accepted as-is), false if the header is malformed.
  auto parse_log_header = [&]() {
    // No "[timestamp]" header: take the first line verbatim.
    if (is.peek() != '[') {
      getline(is, line);
      return true;
    }
    // Skip the timestamp token.
    if (!(is >> line)) return false;
    // Get the file name.
    while (is.peek() == ' ') is.get();
    if (!getline(is, file_name, ':')) return false;
    // A path separator right after the ':' means the ':' belonged to a
    // Windows drive letter (e.g. C:/src/a.cc); re-join both halves.
    if (is.peek() == '\\' || is.peek() == '/') {
      if (!getline(is, line, ':')) return false;
      file_name = file_name + ':' + line;
    }
    // Get the line number.
    if (!(is >> line_number)) return false;
    // Get the rest of the message.
    while (is.peek() == ' ' || is.peek() == ':') is.get();
    if (!getline(is, line)) return false;
    // Detect a CHECK failure and split off "Check failed: <cond>:" so its
    // extra ':' is not mistaken for the error type separator.
    const std::string check_prefix = "Check failed:";
    if (starts_with(line, check_prefix)) {
      size_t end_pos = line.find(':', check_prefix.size());
      if (end_pos == std::string::npos) return false;
      check_msg = line.substr(0, end_pos + 1) + ' ';
      line = line.substr(end_pos + 1);
    }
    return true;
  };
  // If not in the expected format, do not rewrite anything.
  if (!parse_log_header()) return err_msg;

  // Parse the error type: a leading [A-Za-z0-9_.]+ token terminated by ':'.
  {
    size_t start_pos = 0;
    while (start_pos < line.length() && line[start_pos] == ' ') ++start_pos;
    size_t end_pos = start_pos;
    for (; end_pos < line.length(); ++end_pos) {
      const char ch = line[end_pos];
      if (ch == ':') {
        error_type = line.substr(start_pos, end_pos - start_pos);
        break;
      }
      if (!is_type_char(ch)) break;
    }
    if (!error_type.empty()) {
      // Error type detected: drop it, its ':' and any following spaces.
      start_pos = end_pos + 1;
      while (start_pos < line.length() && line[start_pos] == ' ') ++start_pos;
      line = line.substr(start_pos);
    } else {
      // No error type detected: fall back to the default.
      line = line.substr(start_pos);
      error_type = "MXNetError";
    }
  }

  // Separate the stack trace from the remaining message lines. Indented
  // lines right after the header, or after a "Stack trace" marker, are
  // collected as frames; everything else is part of the message.
  std::ostringstream os;
  os << error_type << ": " << check_msg << line << '\n';
  bool trace_mode = true;
  std::vector<std::string> stack_trace;
  while (getline(is, line)) {
    if (trace_mode) {
      if (starts_with(line, trace_indent)) {
        stack_trace.push_back(line);
      } else {
        trace_mode = false;
        // Skip the blank line that terminates a stack trace.
        if (line.empty()) continue;
      }
    }
    if (!trace_mode) {
      if (starts_with(line, "Stack trace")) {
        trace_mode = true;
      } else {
        os << line << '\n';
      }
    }
  }
  if (!stack_trace.empty() || !file_name.empty()) {
    os << "Stack trace:\n";
    if (!file_name.empty()) {
      os << trace_indent << "File \"" << file_name << "\", line "
         << line_number << "\n";
    }
    // Print the stack trace, trimming the C++ frames that belong to the
    // frontend (the frontend reports those frames itself).
    const std::string cpp_frame_prefix = trace_indent + "[bt]";
    bool ffi_boundary = false;
    for (const auto& frame : stack_trace) {
      // Heuristic to detect the Python FFI.
      if (frame.find("libffi.so") != std::string::npos ||
          frame.find("core.cpython") != std::string::npos) {
        ffi_boundary = true;
      }
      // Stop trimming at the first frame that is not a C++ "[bt]" frame.
      if (ffi_boundary && !starts_with(frame, cpp_frame_prefix)) {
        ffi_boundary = false;
      }
      if (!ffi_boundary) {
        os << frame << '\n';
      }
    }
  }
  return os.str();
}
#else
std::string NormalizeError(std::string err_msg) {
  // Without iostream support the message cannot be parsed, so it is handed
  // back unchanged (swapped out of the by-value parameter, no copy).
  std::string unchanged;
  unchanged.swap(err_msg);
  return unchanged;
}
#endif
int MXAPIHandleException(const std::exception &e) {
  // Rewrite the exception text into the normalized error format, store it
  // as the last error, and report failure to the C API caller with -1.
  const std::string normalized = NormalizeError(e.what());
  MXAPISetLastError(normalized.c_str());
  const int kFailure = -1;
  return kFailure;
}
const char *MXGetLastError() {
  // The last error is stored by the NNVM layer; forward its pointer as-is.
  const char *last_error = NNGetLastError();
  return last_error;
}
void MXAPISetLastError(const char* msg) {
  // Storage of the last error is delegated to the NNVM layer.
  const char* message = msg;
  NNAPISetLastError(message);
}