blob: 94902f71a5f6a2152eb19736c54e6eda1309a571 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file test_perf.h
* \brief operator unit test utility functions
* \author Chris Olivier
*/
#ifndef TEST_PERF_H_
#define TEST_PERF_H_
#ifndef _WIN32
#include <sys/time.h>
#else
#include <Windows.h>
#endif
#include <dmlc/logging.h>
#include <iomanip>
#include <iostream>
#include <atomic>
#include <unordered_set>
#include <unordered_map>
#include <mutex>
#include <string>
#include <map>
namespace mxnet {
namespace test {
namespace perf {
/*! \brief current timestamp: millionths of a second */
inline uint64_t getMicroTickCount() {
#ifndef _WIN32
struct timeval tv;
gettimeofday(&tv, nullptr);
return uint64_t(tv.tv_sec) * 1000000 + tv.tv_usec;
#else
LARGE_INTEGER CurrentTime;
LARGE_INTEGER Frequency;
QueryPerformanceFrequency(&Frequency);
QueryPerformanceCounter(&CurrentTime);
CurrentTime.QuadPart *= 1000000;
CurrentTime.QuadPart /= Frequency.QuadPart;
return CurrentTime.QuadPart;
#endif
}
/*! \brief current timestamp: millionths of a second */
inline uint64_t getNannoTickCount() {
#ifndef _WIN32
struct timeval tv;
gettimeofday(&tv, NULL);
return (uint64_t(tv.tv_sec) * 1000000 + tv.tv_usec) * 1000;
#else
LARGE_INTEGER CurrentTime;
LARGE_INTEGER Frequency;
QueryPerformanceFrequency(&Frequency);
QueryPerformanceCounter(&CurrentTime);
CurrentTime.QuadPart *= 1000000000;
CurrentTime.QuadPart /= Frequency.QuadPart;
return CurrentTime.QuadPart;
#endif
}
#define MICRO2MS(__micro$) (((__micro$) + 500) / 1000)
#define MICRO2MSF(__micro$) (static_cast<float>(__micro$) / 1000)
#define MICRO2MSF(__micro$) (static_cast<float>(__micro$) / 1000)
#define MS2MICRO(__ms$) ((__ms$)*1000)
#define NANO2MSF(__nano$) (static_cast<float>(__nano$) / 1000000)
#define MICRO2S(__micro$) (((__micro$) + 500000) / 1000000)
#define MICRO2SF(__micro$) (MICRO2MSF(__micro$) / 1000)
/*! \brief Calculate time between construction and destruction */
class TimedScope {
std::string label_;
uint64_t startTime_;
uint64_t stopTime_;
const size_t count_;
public:
explicit inline TimedScope(const char* msg = nullptr, size_t count = 1, const bool start = true)
: startTime_(start ? getMicroTickCount() : 0), stopTime_(0), count_(count) {
CHECK_NE(count, 0U);
if (msg && *msg) {
label_ = msg;
}
}
explicit inline TimedScope(const std::string& msg, size_t count = 1, const bool start = true)
: startTime_(start ? getMicroTickCount() : 0), count_(count) {
CHECK_NE(count, 0U);
if (!msg.empty()) {
label_ = msg;
}
}
inline ~TimedScope() {
print();
}
inline void start() {
startTime_ = getMicroTickCount();
}
inline void stop() {
stopTime_ = getMicroTickCount();
}
inline float elapsedMilliseconds() const {
const uint64_t elapsed = stopTime_ ? stopTime_ - startTime_ : getMicroTickCount() - startTime_;
return MICRO2MSF(elapsed);
}
inline uint64_t elapsedMicroseconds() const {
return stopTime_ ? stopTime_ - startTime_ : getMicroTickCount() - startTime_;
}
void print() const {
const uint64_t diff = elapsedMicroseconds();
std::stringstream ss;
if (!label_.empty()) {
ss << label_ << " ";
}
ss << "elapsed time: " << std::setprecision(4) << std::fixed << MICRO2MSF(diff) << " ms";
if (count_ != 0 && count_ != 1) {
const float microSecondsEach = static_cast<float>(diff) / count_;
ss << " ( " << MICRO2MSF(microSecondsEach) << " ms each )";
}
std::cout << ss.str() << std::endl;
}
};
/*! \brief Accumulate separate timing values mapped by label/id -> total time spent */
class TimingInstrument {
public:
explicit TimingInstrument(const char* name = "") : name_(name) {}
void startTiming(int id, const char* s) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
auto i = data_.find(id);
if (i == data_.end()) {
i = data_.emplace(std::make_pair(id, Info(s))).first;
}
if (!i->second.nestingCount_++) {
i->second.baseTime_ = getMicroTickCount();
}
}
void stopTiming(int id, const size_t subIterationCount = 1) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
auto i = data_.find(id);
CHECK_NE(i == data_.end(), true) << "Can't stop timing on an object that we don't know about";
if (i != data_.end()) {
CHECK_NE(i->second.nestingCount_, 0U) << "While stopping timing, invalid nesting count of 0";
if (!--i->second.nestingCount_) {
CHECK_NE(i->second.baseTime_, 0U) << "Invalid base time";
i->second.duration_.fetch_add(getMicroTickCount() - i->second.baseTime_);
i->second.baseTime_ = 0;
i->second.cycleCount_.fetch_add(subIterationCount);
}
}
}
uint64_t getDuration(int id) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
auto i = data_.find(id);
if (i != data_.end()) {
const Info& info = i->second;
const uint64_t duration =
info.nestingCount_.load() ?
info.duration_.load() + (getMicroTickCount() - info.baseTime_.load()) :
info.duration_.load();
return duration;
}
return 0;
}
bool isTiming(int id) {
std::unordered_map<int, Info>::const_iterator i = data_.find(id);
if (i != data_.end()) {
return i->second.nestingCount_.load() != 0;
}
return false;
}
template <typename StreamType>
void print(StreamType* os, const std::string& label_, bool doReset = false) {
std::unique_lock<std::recursive_mutex> lk(mutex_);
// Sorted output
std::map<int, Info> data(data_.begin(), data_.end());
for (std::map<int, Info>::const_iterator i = data.begin(), e = data.end(); i != e; ++i) {
const Info& info = i->second;
const uint64_t duration = getDuration(i->first);
*os << label_ << ": " << name_ << " Timing [" << info.name_ << "] "
<< (info.nestingCount_.load() ? "*" : "") << MICRO2MSF(duration) << " ms";
if (info.cycleCount_.load()) {
*os << ", avg: " << (MICRO2MSF(duration) / info.cycleCount_) << " ms X " << info.cycleCount_
<< " passes";
}
*os << std::endl;
}
*os << std::flush;
if (doReset) {
reset();
}
}
void reset() {
std::unique_lock<std::recursive_mutex> lk(mutex_);
for (auto i = data_.begin(), e = data_.end(); i != e; ++i) {
const int id = i->first;
const bool wasTiming = isTiming(id);
if (wasTiming) {
stopTiming(id);
}
// need zero count here
CHECK_EQ(i->second.nestingCount_.load(), 0U);
i->second.duration_ = 0;
if (wasTiming) {
startTiming(id, i->second.name_.c_str());
}
}
}
TimingInstrument& operator+=(const TimingInstrument& o) {
for (auto i = o.data_.begin(), e = o.data_.end(); i != e; ++i) {
auto j = data_.find(i->first);
if (j != data_.end()) {
const Info& oInfo = i->second;
CHECK_EQ(oInfo.nestingCount_, 0U);
j->second.duration_ += oInfo.duration_;
j->second.cycleCount_ += oInfo.cycleCount_;
} else {
data_.insert(std::make_pair(i->first, i->second));
}
}
return *this;
}
struct Info {
explicit inline Info(const char* s)
: name_(s ? s : ""), baseTime_(0), nestingCount_(0), cycleCount_(0), duration_(0) {}
inline Info(const Info& o)
: name_(o.name_),
baseTime_(o.baseTime_.load()),
nestingCount_(o.nestingCount_.load()),
cycleCount_(o.cycleCount_.load()),
duration_(o.duration_.load()) {
CHECK_EQ(o.nestingCount_, 0U);
}
inline Info& operator=(const Info& o) {
name_ = o.name_;
baseTime_.store(baseTime_.load());
nestingCount_.store(nestingCount_.load());
cycleCount_.store(cycleCount_.load());
duration_.store(duration_.load());
return *this;
}
/*!
* \brief Return time for each operation in milliseconds
* \return Time for each operation in milliseconds
*/
inline double TimeEach() const {
return static_cast<double>(duration_) / cycleCount_.load() / 1000.0f;
}
std::string name_;
std::atomic<uint64_t> baseTime_;
std::atomic<uint64_t> nestingCount_;
std::atomic<uint64_t> cycleCount_; // Note that nesting may skew averages
std::atomic<uint64_t> duration_;
};
typedef std::unordered_map<int, TimingInstrument::Info> timing_map_t;
const timing_map_t& data() const {
return data_;
}
private:
std::string name_;
mutable std::recursive_mutex mutex_;
std::unordered_map<int, Info> data_;
};
using timing_map_t = TimingInstrument::timing_map_t;
/*! \brief Accumulated scoped timing, indexed by ID */
class TimingItem {
public:
inline TimingItem(TimingInstrument* ti,
int id,
const char* name,
const size_t subIterationCount = 1)
: ti_(ti), id_(id), subIterationCount_(subIterationCount) {
if (ti_) {
ti_->startTiming(id, name);
}
}
inline ~TimingItem() {
if (ti_) {
ti_->stopTiming(id_, subIterationCount_);
}
}
private:
TimingInstrument* ti_;
const int id_;
const size_t subIterationCount_;
};
} // namespace perf
} // namespace test
} // namespace mxnet
#endif // TEST_PERF_H_