blob: dcec54d1823d89100da77a1c929ad14a320a5ae3 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file ring_buffer.h
* \brief this file aims to provide a wrapper of sockets
*/
#ifndef TVM_COMMON_RING_BUFFER_H_
#define TVM_COMMON_RING_BUFFER_H_
#include <vector>
#include <cstring>
#include <algorithm>
namespace tvm {
namespace common {
/*!
* \brief Ring buffer class for data buffering in IO.
* Enables easy usage for sync and async mode.
*/
class RingBuffer {
public:
/*! \brief Initial capacity of ring buffer. */
static const int kInitCapacity = 4 << 10;
/*! \brief constructor */
RingBuffer() : ring_(kInitCapacity) {}
/*! \return number of bytes available in buffer. */
size_t bytes_available() const {
return bytes_available_;
}
/*! \return Current capacity of buffer. */
size_t capacity() const {
return ring_.size();
}
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
* \param n The size of capacity.
*/
void Reserve(size_t n) {
if (ring_.size() < n) {
size_t old_size = ring_.size();
size_t new_size = static_cast<size_t>(n * 1.2);
ring_.resize(new_size);
if (head_ptr_ + bytes_available_ > old_size) {
// copy the ring overflow part into the tail.
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
} else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_;
std::vector<char> tmp(old_bytes);
Read(&tmp[0], old_bytes);
ring_.resize(kInitCapacity);
ring_.shrink_to_fit();
memcpy(&ring_[0], &tmp[0], old_bytes);
head_ptr_ = 0;
bytes_available_ = old_bytes;
}
}
/*!
* \brief Peform a non-blocking read from buffer
* size must be smaller than this->bytes_available()
* \param data the data pointer.
* \param size The number of bytes to read.
*/
void Read(void* data, size_t size) {
CHECK_GE(bytes_available_, size);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
memcpy(data, &ring_[0] + head_ptr_, ncopy);
if (ncopy < size) {
memcpy(reinterpret_cast<char*>(data) + ncopy,
&ring_[0], size - ncopy);
}
head_ptr_ = (head_ptr_ + size) % ring_.size();
bytes_available_ -= size;
}
/*!
* \brief Read data from buffer with and put them to non-blocking send function.
*
* \param fsend A send function handle to put the data to.
* \param max_nbytes Maximum number of bytes can to read.
* \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
*/
template<typename FSend>
size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
size_t size = std::min(max_nbytes, bytes_available_);
CHECK_NE(size, 0U);
size_t ncopy = std::min(size, ring_.size() - head_ptr_);
size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy);
bytes_available_ -= nsend;
if (ncopy == nsend && ncopy < size) {
size_t nsend2 = fsend(&ring_[0], size - ncopy);
bytes_available_ -= nsend2;
nsend += nsend2;
}
return nsend;
}
/*!
* \brief Write data into buffer, always ensures all data is written.
* \param data The data pointer
* \param size The size of data to be written.
*/
void Write(const void* data, size_t size) {
this->Reserve(bytes_available_ + size);
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
memcpy(&ring_[0] + (tail - ring_.size()), data, size);
} else {
size_t ncopy = std::min(ring_.size() - tail, size);
memcpy(&ring_[0] + tail, data, ncopy);
if (ncopy < size) {
memcpy(&ring_[0], reinterpret_cast<const char*>(data) + ncopy, size - ncopy);
}
}
bytes_available_ += size;
}
/*!
* \brief Writen data into the buffer by give it a non-blocking callback function.
*
* \param frecv A receive function handle
* \param max_nbytes Maximum number of bytes can write.
* \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
*/
template<typename FRecv>
size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
this->Reserve(bytes_available_ + max_nbytes);
size_t nbytes = max_nbytes;
size_t tail = head_ptr_ + bytes_available_;
if (tail >= ring_.size()) {
size_t nrecv = frecv(&ring_[0] + (tail - ring_.size()), nbytes);
bytes_available_ += nrecv;
return nrecv;
} else {
size_t ncopy = std::min(ring_.size() - tail, nbytes);
size_t nrecv = frecv(&ring_[0] + tail, ncopy);
bytes_available_ += nrecv;
if (nrecv == ncopy && ncopy < nbytes) {
size_t nrecv2 = frecv(&ring_[0], nbytes - ncopy);
bytes_available_ += nrecv2;
nrecv += nrecv2;
}
return nrecv;
}
}
private:
// buffer head
size_t head_ptr_{0};
// number of bytes in the buffer.
size_t bytes_available_{0};
// The internald ata ring.
std::vector<char> ring_;
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_RING_BUFFER_H_