diff --git a/src/streams.h b/src/streams.h index 32a707381..d4a309be3 100644 --- a/src/streams.h +++ b/src/streams.h @@ -510,12 +510,105 @@ public: } }; +template +class BitStreamReader +{ +private: + IStream& m_istream; + /// Buffered byte read in from the input stream. A new byte is read into the + /// buffer when m_offset reaches 8. + uint8_t m_buffer{0}; + /// Number of high order bits in m_buffer already returned by previous + /// Read() calls. The next bit to be returned is at this offset from the + /// most significant bit position. + int m_offset{8}; +public: + explicit BitStreamReader(IStream& istream) : m_istream(istream) {} + /** Read the specified number of bits from the stream. The data is returned + * in the nbits least signficant bits of a 64-bit uint. + */ + uint64_t Read(int nbits) { + if (nbits < 0 || nbits > 64) { + throw std::out_of_range("nbits must be between 0 and 64"); + } + uint64_t data = 0; + while (nbits > 0) { + if (m_offset == 8) { + m_istream >> m_buffer; + m_offset = 0; + } + int bits = std::min(8 - m_offset, nbits); + data <<= bits; + data |= static_cast(m_buffer << m_offset) >> (8 - bits); + m_offset += bits; + nbits -= bits; + } + return data; + } +}; + +template +class BitStreamWriter +{ +private: + OStream& m_ostream; + + /// Buffered byte waiting to be written to the output stream. The byte is + /// written buffer when m_offset reaches 8 or Flush() is called. + uint8_t m_buffer{0}; + + /// Number of high order bits in m_buffer already written by previous + /// Write() calls and not yet flushed to the stream. The next bit to be + /// written to is at this offset from the most significant bit position. + int m_offset{0}; + +public: + explicit BitStreamWriter(OStream& ostream) : m_ostream(ostream) {} + + ~BitStreamWriter() + { + Flush(); + } + + /** Write the nbits least significant bits of a 64-bit int to the output + * stream. Data is buffered until it completes an octet. + */ + void Write(uint64_t data, int nbits) { + if (nbits < 0 || nbits > 64) { + throw std::out_of_range("nbits must be between 0 and 64"); + } + + while (nbits > 0) { + int bits = std::min(8 - m_offset, nbits); + m_buffer |= (data << (64 - nbits)) >> (64 - 8 + m_offset); + m_offset += bits; + nbits -= bits; + + if (m_offset == 8) { + Flush(); + } + } + } + + /** Flush any unwritten bits to the output stream, padding with 0's to the + * next byte boundary. + */ + void Flush() { + if (m_offset == 0) { + return; + } + + m_ostream << m_buffer; + m_buffer = 0; + m_offset = 0; + } +};