blob: d37652dd2c051976abc96c91574dc35fafe93fca [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file monitor.hpp
* \brief monitor implementation
* \author Xin Li
*/
#ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_
#define CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_
#include <cmath>
#include <sstream>
#include <algorithm>
#include <vector>
#include <string>
#include "mxnet-cpp/monitor.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Default statistic used by Monitor.
 *
 * Invokes the "norm" operator on \p x and divides its first output by
 * sqrt(x.Size()).
 *
 * \param x array to summarise
 * \return the scaled norm as an NDArray
 */
inline NDArray _default_monitor_func(const NDArray &x) {
  const auto scale = std::sqrt(x.Size());
  NDArray norm_out = Operator("norm").PushInput(x).Invoke()[0];
  return norm_out / scale;
}
/*!
 * \brief Construct a monitor.
 * \param interval tic() activates collection whenever step % interval == 0
 * \param pattern regular expression a name must fully match (std::regex_match)
 *        for its array to be recorded
 * \param stat_func function applied to each matching array to produce the
 *        recorded statistic
 */
inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func)
  : interval(interval), pattern(pattern), stat_func(stat_func), step(0) {
  // Start inactive: toc() and executor_callback() read this flag, and it was
  // previously left indeterminate until the first tic().
  activated = false;
}
/*!
 * \brief Attach this monitor to an executor.
 *
 * Registers executor_callback on the executor's handle, passing this monitor
 * as the callback's user pointer, and records the executor so toc() can
 * inspect its arrays.
 *
 * \param exe executor to monitor; the pointer is stored, not copied
 */
inline void Monitor::install(Executor *exe) {
  ExecutorMonitorCallback callback = &Monitor::executor_callback;
  MXExecutorSetMonitorCallback(exe->handle_, callback, this);
  exes.push_back(exe);
}
/*!
 * \brief Begin a monitored step.
 *
 * When the current step is a multiple of the interval, collection is switched
 * on and any previously gathered statistics are discarded. An interval of 0
 * never activates collection.
 */
inline void Monitor::tic() {
  // A zero interval would make the modulo below undefined behaviour.
  if (interval == 0) {
    return;
  }
  if (step % interval == 0) {
    activated = true;
    stats.clear();
  }
}
inline std::vector<Monitor::Stat> Monitor::toc() {
std::vector<Monitor::Stat> results;
if (activated) {
activated = false;
for (auto* exe : exes) {
for (auto& arg : exe->arg_arrays) {
arg.WaitToRead();
}
for (auto& aux : exe->aux_arrays) {
aux.WaitToRead();
}
for (auto &pair : exe->arg_dict()) {
if (std::regex_match(pair.first, pattern)) {
stats.emplace_back(step, pair.first, stat_func(pair.second));
}
}
for (auto &pair : exe->aux_dict()) {
if (std::regex_match(pair.first, pattern)) {
stats.emplace_back(step, pair.first, stat_func(pair.second));
}
}
}
results.swap(stats);
}
++step;
return results;
}
/*!
 * \brief End a monitored step and log every gathered statistic.
 *
 * Calls toc() and writes one log line per statistic in the form
 * "Batch: <step> <name> <value>". Single-element arrays are printed as a
 * scalar; larger arrays are formatted through operator<<.
 */
inline void Monitor::toc_print() {
  auto results = toc();
  std::vector<float> data(1);
  for (auto &stat : results) {
    NDArray ndarray = std::get<2>(stat);
    std::string str;
    if (ndarray.Size() == 1) {
      if (ndarray.GetContext().GetDeviceType() != DeviceType::kGPU) {
        // The statistic comes from stat_func and is read here through a raw
        // pointer, so wait for it to be readable first.
        // NOTE(review): presumably GetData() does not block on its own - confirm.
        ndarray.WaitToRead();
        data[0] = ndarray.GetData()[0];
      } else {
        // GPU arrays are fetched with a synchronous copy into host memory.
        ndarray.SyncCopyToCPU(&data);
      }
      str = std::to_string(data[0]);
    } else {
      std::ostringstream out;
      out << ndarray;
      str = out.str();
    }
    LG << "Batch: " << std::get<0>(stat) << ' ' << std::get<1>(stat) << ' ' << str;
  }
}
inline void Monitor::executor_callback(const char *name, NDArrayHandle handle,
void *monitor_ptr) {
Monitor *monitor = static_cast<Monitor*>(monitor_ptr);
if (monitor->activated && std::regex_match(name, monitor->pattern)) {
monitor->stats.emplace_back(monitor->step, name, monitor->stat_func(NDArray(handle)));
}
}
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_