blob: 60b898cf0a670e5c41bc6c6c534f29d77dd9f0a9 [file] [log] [blame]
use std::io;
use byteorder::LittleEndian;
use byteorder::ReadBytesExt;
use byteorder::WriteBytesExt;
/// Bit-granular writer layered on top of an underlying sink.
///
/// Incoming bits are accumulated in `buf`, starting at bit 0, and are handed
/// to `inner` as whole bytes, low byte first.
#[derive(Debug)]
pub struct BitWriter<W> {
    /// Sink that receives the packed bytes.
    inner: W,
    /// Pending bits that have not been written to `inner` yet; the oldest
    /// pending bit sits at bit 0.
    buf: u32,
    /// Number of pending bits currently held in `buf`.
    end: u8,
}
impl<W> BitWriter<W>
where
    W: io::Write,
{
    /// Wraps `inner` in a writer that starts with no pending bits.
    pub fn new(inner: W) -> Self {
        BitWriter {
            inner,
            buf: 0,
            end: 0,
        }
    }

    /// Appends one bit; `true` is stored as `1`, `false` as `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    #[inline(always)]
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.write_bits(1, u16::from(bit))
    }

    /// Appends the low `bitwidth` bits of `bits` after the bits already pending.
    ///
    /// `bits` is OR-ed into the accumulator without masking, so the caller
    /// must keep every bit at or above position `bitwidth` cleared.
    /// Debug builds assert `bitwidth < 16`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    #[inline(always)]
    pub fn write_bits(&mut self, bitwidth: u8, bits: u16) -> io::Result<()> {
        debug_assert!(bitwidth < 16);
        debug_assert!(self.end + bitwidth <= 32);
        let positioned = u32::from(bits) << self.end;
        self.buf |= positioned;
        self.end += bitwidth;
        self.flush_if_needed()
    }

    /// Writes out every pending bit and then flushes the underlying writer.
    ///
    /// A trailing partial byte is emitted as a full byte whose unused high
    /// bits are zero, so the next write starts on a fresh byte.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer. Bytes written before
    /// the failure have already been removed from the pending bits.
    pub fn flush(&mut self) -> io::Result<()> {
        while self.end != 0 {
            let low_byte = self.buf as u8;
            self.inner.write_u8(low_byte)?;
            self.buf >>= 8;
            // The final byte may hold fewer than 8 valid bits; clamp at zero.
            self.end = self.end.saturating_sub(8);
        }
        self.inner.flush()
    }

    /// Emits the low 16 pending bits (little-endian) once at least 16 are held.
    #[inline(always)]
    fn flush_if_needed(&mut self) -> io::Result<()> {
        if self.end < 16 {
            return Ok(());
        }
        let low_half = self.buf as u16;
        self.inner.write_u16::<LittleEndian>(low_half)?;
        self.buf >>= 16;
        self.end -= 16;
        Ok(())
    }
}
impl<W> BitWriter<W> {
    /// Borrows the underlying writer.
    pub fn as_inner_ref(&self) -> &W {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the underlying writer.
    pub fn as_inner_mut(&mut self) -> &mut W {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the bit writer and returns the underlying writer.
    ///
    /// Pending bits are not written out by this call.
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }
}
/// Bit-granular reader layered on top of an underlying byte source.
///
/// Bytes pulled from `inner` enter `last_read` at the top and move down as
/// more arrive; `offset` marks where the unconsumed bits begin.
#[derive(Debug)]
pub struct BitReader<R> {
    /// Source the bytes are pulled from.
    inner: R,
    /// Sliding 32-bit window over the most recently read bytes.
    last_read: u32,
    /// Bit index in `last_read` of the next unconsumed bit; 32 means the
    /// window holds no unconsumed bits.
    offset: u8,
    /// Error recorded by an unchecked read, held until it is checked.
    last_error: Option<io::Error>,
}
impl<R> BitReader<R>
where
    R: io::Read,
{
    /// Wraps `inner` in a reader whose bit buffer starts out empty.
    pub fn new(inner: R) -> Self {
        BitReader {
            inner,
            last_read: 0,
            // An offset of 32 means no unconsumed bits are buffered.
            offset: 32,
            last_error: None,
        }
    }

    /// Records `e` so that the next `check_last_error` call reports it.
    #[inline(always)]
    pub fn set_last_error(&mut self, e: io::Error) {
        self.last_error = Some(e);
    }

    /// Returns the recorded error, if any, and clears it.
    ///
    /// # Errors
    ///
    /// Returns the error stored by `set_last_error` or by a failed unchecked
    /// read since the previous check.
    #[inline(always)]
    pub fn check_last_error(&mut self) -> io::Result<()> {
        match self.last_error.take() {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }

    /// Reads one bit and returns it as a `bool`.
    ///
    /// # Errors
    ///
    /// See `read_bits`.
    #[inline(always)]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        self.read_bits(1).map(|b| b != 0)
    }

    /// Reads `bitwidth` bits (at most 16) and returns them in the low bits of
    /// the result, first-read bit at position 0.
    ///
    /// # Errors
    ///
    /// Returns any pending error recorded on this reader, including one
    /// produced while refilling from the underlying reader (for example
    /// `UnexpectedEof` when the input runs out).
    #[inline(always)]
    pub fn read_bits(&mut self, bitwidth: u8) -> io::Result<u16> {
        let v = self.read_bits_unchecked(bitwidth);
        self.check_last_error().map(|_| v)
    }

    /// Like `read_bits`, but reports failure through the stored error rather
    /// than the return value.
    ///
    /// On failure this returns 0, still advances past `bitwidth` bits, and
    /// leaves the error for `check_last_error`.
    #[inline(always)]
    pub fn read_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        let bits = self.peek_bits_unchecked(bitwidth);
        self.skip_bits(bitwidth);
        bits
    }

    /// Returns the next `bitwidth` bits (at most 16) without consuming them,
    /// refilling from the underlying reader as needed.
    ///
    /// Returns 0 if an error is already pending or if refilling fails; in the
    /// latter case the error is stored for `check_last_error`.
    #[inline(always)]
    pub fn peek_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        debug_assert!(bitwidth <= 16);
        while (32 - self.offset) < bitwidth {
            if self.last_error.is_some() {
                return 0;
            }
            if let Err(e) = self.fill_next_u8() {
                self.last_error = Some(e);
                return 0;
            }
        }
        debug_assert!(self.offset < 32 || bitwidth == 0);
        let bits = self.last_read.wrapping_shr(u32::from(self.offset)) as u16;
        // Build the mask in `u32`: for the permitted `bitwidth == 16`,
        // `1u16 << 16` overflows (panic in debug builds, zero mask in release).
        let mask = ((1u32 << bitwidth) - 1) as u16;
        bits & mask
    }

    /// Marks `bitwidth` already-buffered bits as consumed.
    ///
    /// The bits must have been buffered by a preceding peek; debug builds
    /// assert this.
    #[inline(always)]
    pub fn skip_bits(&mut self, bitwidth: u8) {
        debug_assert!(32 - self.offset >= bitwidth);
        self.offset += bitwidth;
    }

    /// Shifts the window down by one byte and loads a fresh byte at the top.
    ///
    /// The window is shifted before the read is attempted, so a failed read
    /// still leaves `offset` lowered by 8 and `last_read` shifted.
    /// `read_bits_unchecked` relies on this when it skips after a failed peek.
    #[inline(always)]
    fn fill_next_u8(&mut self) -> io::Result<()> {
        self.offset -= 8;
        self.last_read >>= 8;
        let next = u32::from(self.inner.read_u8()?);
        self.last_read |= next << (32 - 8);
        Ok(())
    }

    /// Captures the bit-buffer position (`last_read` and `offset`).
    ///
    /// The underlying reader and any stored error are not part of the
    /// snapshot.
    #[inline]
    pub(crate) fn state(&self) -> BitReaderState {
        BitReaderState {
            last_read: self.last_read,
            offset: self.offset,
        }
    }

    /// Restores a bit-buffer position previously captured with `state`.
    ///
    /// Only `last_read` and `offset` are overwritten.
    #[inline]
    pub(crate) fn restore_state(&mut self, state: BitReaderState) {
        self.last_read = state.last_read;
        self.offset = state.offset;
    }
}
impl<R> BitReader<R> {
    /// Discards every buffered bit so the next read starts from a fresh byte
    /// of the underlying reader.
    ///
    /// Only `offset` is changed; a stored error is left in place.
    pub fn reset(&mut self) {
        let Self { offset, .. } = self;
        *offset = 32;
    }

    /// Borrows the underlying reader.
    pub fn as_inner_ref(&self) -> &R {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the underlying reader.
    pub fn as_inner_mut(&mut self) -> &mut R {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the bit reader and returns the underlying reader.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
/// Snapshot of a `BitReader`'s bit-buffer position, as produced by
/// `BitReader::state` and consumed by `BitReader::restore_state`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BitReaderState {
    /// Copy of the reader's `last_read` window.
    last_read: u32,
    /// Copy of the reader's `offset` into that window.
    offset: u8,
}
#[cfg(test)]
mod test {
    use std::io;

    use super::*;

    /// Bits are packed starting at bit 0 of each byte, and `flush` pads a
    /// trailing partial byte with zero bits.
    #[test]
    fn writer_works() {
        let mut bit_writer = BitWriter::new(Vec::new());

        // 1 + 3 + 11 = 15 bits, which `flush` emits as two bytes.
        bit_writer.write_bit(true).unwrap();
        bit_writer.write_bits(3, 0b010).unwrap();
        bit_writer.write_bits(11, 0b10101011010).unwrap();
        bit_writer.flush().unwrap();

        // A bit written after a flush starts a new byte.
        bit_writer.write_bit(true).unwrap();
        bit_writer.flush().unwrap();

        let written = bit_writer.into_inner();
        assert_eq!(written, [0b10100101, 0b01010101, 0b00000001]);
    }

    /// Bits come back in the order they were packed, peeking does not
    /// consume, and running out of input surfaces as `UnexpectedEof`.
    #[test]
    fn reader_works() {
        let input = [0b10100101, 0b11010101];
        let mut bit_reader = BitReader::new(&input[..]);

        assert!(bit_reader.read_bit().unwrap());
        assert!(!bit_reader.read_bit().unwrap());
        assert_eq!(bit_reader.read_bits(8).unwrap(), 0b01101001);

        // Repeated peeks see the same bits until something is skipped.
        assert_eq!(bit_reader.peek_bits_unchecked(3), 0b101);
        assert_eq!(bit_reader.peek_bits_unchecked(3), 0b101);
        bit_reader.skip_bits(1);
        assert_eq!(bit_reader.peek_bits_unchecked(3), 0b010);

        // Only 5 bits remain, so asking for 8 must fail.
        let outcome = bit_reader.read_bits(8).map_err(|e| e.kind());
        assert_eq!(outcome, Err(io::ErrorKind::UnexpectedEof));
    }
}