| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file c_runtime_api.cc |
| * \brief Device specific implementations |
| */ |
| // Acknowledgement: This file originates from incubator-tvm |
| |
| #include <mxnet/runtime/c_runtime_api.h> |
| #include <dmlc/thread_local.h> |
| #include <mxnet/runtime/packed_func.h> |
| #include <mxnet/runtime/registry.h> |
| #include <sstream> |
| #include <array> |
| #include <algorithm> |
| #include <string> |
| #include <cstdlib> |
| #include <cctype> |
| |
| #include "../c_api/c_api_common.h" |
| |
| using namespace mxnet::runtime; |
| |
| struct MXNetRuntimeEntry { |
| std::string ret_str; |
| MXNetByteArray ret_bytes; |
| std::string last_error; |
| }; |
| |
| typedef dmlc::ThreadLocalStore<MXNetRuntimeEntry> MXNetAPIRuntimeStore; |
| |
| int MXNetFuncFree(MXNetFunctionHandle func) { |
| API_BEGIN(); |
| delete static_cast<PackedFunc*>(func); |
| API_END(); |
| } |
| |
| int MXNetFuncCall(MXNetFunctionHandle func, |
| MXNetValue* args, |
| int* arg_type_codes, |
| int num_args, |
| MXNetValue* ret_val, |
| int* ret_type_code) { |
| API_BEGIN(); |
| MXNetRetValue rv; |
| (*static_cast<const PackedFunc*>(func)) |
| .CallPacked(MXNetArgs(args, arg_type_codes, num_args), &rv); |
| // handle return string. |
| if (rv.type_code() == kStr || rv.type_code() == kBytes) { |
| MXNetRuntimeEntry* e = MXNetAPIRuntimeStore::Get(); |
| e->ret_str = *rv.ptr<std::string>(); |
| if (rv.type_code() == kBytes) { |
| e->ret_bytes.data = e->ret_str.c_str(); |
| e->ret_bytes.size = e->ret_str.length(); |
| *ret_type_code = kBytes; |
| ret_val->v_handle = &(e->ret_bytes); |
| } else { |
| *ret_type_code = kStr; |
| ret_val->v_str = e->ret_str.c_str(); |
| } |
| } else { |
| rv.MoveToCHost(ret_val, ret_type_code); |
| } |
| API_END(); |
| } |
| |
| #ifndef _LIBCPP_SGX_NO_IOSTREAMS |
//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:}    // stack traces follow this line
//   {trace 0}       // two spaces at the beginning.
//   {trace 1}
//   {trace 2}
//--------------------------------------------------------
/*!
 * \brief Normalize an error message.
 *
 * Parses the header generated by LOG(FATAL) and CHECK
 * and reformats the message into the standard format.
 *
 * This function also merges all the stack traces into
 * one trace and trims them.
 *
 * \param err_msg The error message.
 * \return The normalized message. When err_msg starts with '[' but the rest
 *         of the log header cannot be parsed, err_msg is returned unchanged.
 */
std::string NormalizeError(std::string err_msg) {
  // ------------------------------------------------------------------------
  // log with header, {} indicates optional
  //-------------------------------------------------------------------------
  // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
  // {message1}
  // Stack trace:
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------
  // Normalized version
  //-------------------------------------------------------------------------
  // error_type: check_msg message0
  // {message1}
  // Stack trace:
  //   File file_name, line lineno
  //   {stack trace 0}
  //   {stack trace 1}
  //-------------------------------------------------------------------------

  // Every stack-trace entry is indented by exactly two spaces.
  const std::string kTraceIndent = "  ";
  const std::string kBacktracePrefix = kTraceIndent + "[bt]";
  const std::string kTraceHeader = "Stack trace";
  const std::string kCheckPrefix = "Check failed:";

  // Prefix test whose compared length is always the prefix's own length,
  // so the length and the literal cannot drift apart.
  const auto starts_with = [](const std::string& text, const std::string& prefix) {
    return text.compare(0, prefix.size(), prefix) == 0;
  };

  int line_number = 0;
  std::istringstream is(err_msg);
  std::string line, file_name, error_type, check_msg;

  // Parse log header and set the fields.
  // Return true if the log is in the correct format,
  // return false if something is wrong.
  auto parse_log_header = [&]() {
    // No timestamp: there is no header, take the first line verbatim.
    if (is.peek() != '[') {
      std::getline(is, line);
      return true;
    }
    // skip timestamp
    if (!(is >> line))
      return false;
    // get filename
    while (is.peek() == ' ')
      is.get();
    if (!std::getline(is, file_name, ':'))
      return false;
    if (is.peek() == '\\' || is.peek() == '/') {
      // windows path: the first ':' belonged to the drive letter.
      if (!std::getline(is, line, ':'))
        return false;
      file_name = file_name + ':' + line;
    }
    // get line number
    if (!(is >> line_number))
      return false;
    // get rest of the message.
    while (is.peek() == ' ' || is.peek() == ':')
      is.get();
    if (!std::getline(is, line))
      return false;
    // detect check message, rewrite to remove the extra ':'
    if (starts_with(line, kCheckPrefix)) {
      const std::string::size_type end_pos = line.find(':', kCheckPrefix.size());
      if (end_pos == std::string::npos)
        return false;
      check_msg = line.substr(0, end_pos + 1) + ' ';
      line = line.substr(end_pos + 1);
    }
    return true;
  };
  // if not in correct format, do not do any rewrite.
  if (!parse_log_header())
    return err_msg;

  // Parse error type.
  {
    std::string::size_type start_pos = 0;
    while (start_pos < line.length() && line[start_pos] == ' ')
      ++start_pos;
    std::string::size_type end_pos = start_pos;
    for (; end_pos < line.length(); ++end_pos) {
      const char ch = line[end_pos];
      if (ch == ':') {
        error_type = line.substr(start_pos, end_pos - start_pos);
        break;
      }
      // [A-Z0-9a-z_.]
      // The cast is required: passing a negative char (e.g. a UTF-8 byte)
      // to the <cctype> classifiers is undefined behavior.
      if (!std::isalnum(static_cast<unsigned char>(ch)) && ch != '_' && ch != '.')
        break;
    }
    if (!error_type.empty()) {
      // if we successfully detected error_type: trim the following space.
      start_pos = end_pos + 1;
      while (start_pos < line.length() && line[start_pos] == ' ')
        ++start_pos;
      line = line.substr(start_pos);
    } else {
      // did not detect error_type, use default value.
      line = line.substr(start_pos);
      error_type = "MXNetError";
    }
  }

  // Separate out stack trace.
  std::ostringstream os;
  os << error_type << ": " << check_msg << line << '\n';

  // trace_mode starts out true, so indented lines directly after the first
  // line are collected as trace entries as well.
  bool trace_mode = true;
  std::vector<std::string> stack_trace;
  while (std::getline(is, line)) {
    if (trace_mode) {
      if (starts_with(line, kTraceIndent)) {
        stack_trace.push_back(line);
      } else {
        trace_mode = false;
        // remove EOL trailing stacktrace.
        if (line.empty())
          continue;
      }
    }
    if (!trace_mode) {
      if (starts_with(line, kTraceHeader)) {
        trace_mode = true;
      } else {
        os << line << '\n';
      }
    }
  }

  if (!stack_trace.empty() || !file_name.empty()) {
    os << "Stack trace:\n";
    if (!file_name.empty()) {
      os << kTraceIndent << "File \"" << file_name << "\", line " << line_number << "\n";
    }
    // Print out stack traces, optionally trim the c++ traces
    // about the frontends (as they will be provided by the frontends).
    bool ffi_boundary = false;
    for (const auto& frame : stack_trace) {
      // Heuristic to detect python ffi.
      if (frame.find("libffi.so") != std::string::npos ||
          frame.find("core.cpython") != std::string::npos) {
        ffi_boundary = true;
      }
      // If the backtrace is not c++ backtrace with the prefix "  [bt]",
      // then we can stop trimming.
      if (ffi_boundary && !starts_with(frame, kBacktracePrefix)) {
        ffi_boundary = false;
      }
      if (!ffi_boundary) {
        os << frame << '\n';
      }
    }
  }
  return os.str();
}
| |
| #else |
/*!
 * \brief Fallback used when _LIBCPP_SGX_NO_IOSTREAMS is defined:
 *        the message is passed through without any rewriting.
 * \param err_msg The error message.
 * \return err_msg, unchanged.
 */
std::string NormalizeError(std::string err_msg) { return err_msg; }
| #endif |
| |
| int MXAPIHandleException(const std::exception& e) { |
| MXAPISetLastError(NormalizeError(e.what()).c_str()); |
| return -1; |
| } |
| |
| const char* MXGetLastError() { |
| return MXNetAPIRuntimeStore::Get()->last_error.c_str(); |
| } |
| |
| void MXAPISetLastError(const char* msg) { |
| MXNetAPIRuntimeStore::Get()->last_error = msg; |
| } |