package rardecode

import (
	"errors"
	"io"
)

// maxCodeLength is the longest Huffman code length, in bits, that the
// decoder handles; readSym reads this many bits at a time.
const maxCodeLength = 15

// The quick lookup tables are indexed by at most maxQuickBits leading code
// bits, so they need maxQuickSize entries.
const (
	maxQuickBits = 10
	maxQuickSize = 1 << maxQuickBits
)

// errHuffDecodeFailed reports that decoded bits pointed outside the
// decoder's symbol table.
var errHuffDecodeFailed = errors.New("rardecode: huffman decode failed")

// errInvalidLengthTable reports a malformed code length table, such as one
// that starts with a "repeat previous length" code.
var errInvalidLengthTable = errors.New("rardecode: invalid huffman code length table")

// huffmanDecoder holds the lookup tables for one canonical Huffman code.
// init builds them from per-symbol code lengths and readSym uses them to
// decode symbols. Codes are compared as left-aligned maxCodeLength-bit
// values: such a value v has code length n when limit[n-1] <= v < limit[n].
//
// Fields:
//   - limit[n]: exclusive upper bound of left-aligned code values whose
//     length is at most n (limit[0] is 0).
//   - pos[n]: index in symbol of the first symbol whose code length is n.
//   - symbol: used symbols ordered by code length, then by symbol value; it
//     has one entry per input length and the entries past the used symbols
//     are zero.
//   - min: shortest code length in use, or 0 if no symbol is used.
//   - quickbits: number of leading code bits used to index quicklen and
//     quicksym (maxQuickBits or maxQuickBits-3, chosen by init).
//   - quicklen, quicksym: code length and decoded symbol for each
//     quickbits-bit prefix.
type huffmanDecoder struct {
	limit     [maxCodeLength + 1]int
	pos       [maxCodeLength + 1]int
	symbol    []int
	min       uint
	quickbits uint
	quicklen  [maxQuickSize]uint
	quicksym  [maxQuickSize]int
}

// init (re)builds the decoding tables from codeLengths, where codeLengths[s]
// is the code length in bits of symbol s and 0 marks an unused symbol.
// Codes are assigned canonically: shorter codes come first and, within one
// length, symbols are taken in increasing order. Existing h.symbol storage
// is reused when its capacity suffices. Lengths above maxCodeLength are not
// supported (they would index past the length histogram).
func (h *huffmanDecoder) init(codeLengths []byte) {
	// Histogram of code lengths; unused (zero-length) symbols are skipped.
	var lenCount [maxCodeLength + 1]int
	for _, l := range codeLengths {
		if l != 0 {
			lenCount[l]++
		}
	}

	// limit[n]: exclusive upper bound of left-aligned codes of length <= n.
	// pos[n]:   index in h.symbol of the first symbol with code length n.
	// min:      shortest code length actually in use (0 if none).
	h.limit[0], h.pos[0], h.min = 0, 0, 0
	for n := uint(1); n <= maxCodeLength; n++ {
		h.pos[n] = h.pos[n-1] + lenCount[n-1]
		h.limit[n] = h.limit[n-1] + lenCount[n]<<(maxCodeLength-n)
		if h.min == 0 && h.limit[n] > 0 {
			h.min = n
		}
	}

	// One symbol slot per input length; reuse old storage when possible.
	size := len(codeLengths)
	if cap(h.symbol) < size {
		h.symbol = make([]int, size)
	} else {
		h.symbol = h.symbol[:size]
		for k := range h.symbol {
			h.symbol[k] = 0
		}
	}

	// Drop every used symbol into the next free slot for its code length.
	next := h.pos // arrays are values, so this copies the start positions
	for sym, l := range codeLengths {
		if l == 0 {
			continue
		}
		h.symbol[next[l]] = sym
		next[l]++
	}

	// Alphabets of 298 or more symbols get the full quick table; smaller
	// ones use three fewer index bits.
	h.quickbits = maxQuickBits - 3
	if size >= 298 {
		h.quickbits = maxQuickBits
	}

	// For every quickbits-wide prefix, record the code length and symbol it
	// decodes to when the remaining low bits are zero.
	shift := maxCodeLength - h.quickbits
	length := uint(1)
	for q := 0; q < 1<<h.quickbits; q++ {
		code := q << shift
		for code >= h.limit[length] && length < maxCodeLength {
			length++
		}
		h.quicklen[q] = length

		p := h.pos[length] + (code-h.limit[length-1])>>(maxCodeLength-length)
		if p < len(h.symbol) {
			h.quicksym[q] = h.symbol[p]
		} else {
			h.quicksym[q] = 0
		}
	}
}

// readSym decodes and returns the next symbol from r.
//
// It first reads maxCodeLength bits at once and gives back, via unreadBits,
// the bits that are not part of the decoded code. Codes no longer than
// h.quickbits are resolved through the quick lookup tables; longer ones by
// scanning h.limit. If the wide read reports io.EOF (e.g. near the end of
// the input), decoding falls back to reading one bit at a time so that a
// short final code can still be decoded. Any other read error, or any error
// during the fallback, is returned unchanged. errHuffDecodeFailed is
// returned when the decoded position lies outside h.symbol.
func (h *huffmanDecoder) readSym(r bitReader) (int, error) {
	bits := uint(maxCodeLength)
	v, err := r.readBits(maxCodeLength)
	if err != nil {
		if err != io.EOF {
			return 0, err
		}
		// Fall back to 1 bit at a time if we read past EOF. The value
		// returned alongside an error is not meaningful, so start the
		// prefix from zero; it must hold only the bits read below.
		v = 0
		for i := uint(1); i <= maxCodeLength; i++ {
			b, err := r.readBits(1)
			if err != nil {
				return 0, err // not enough bits for a complete code
			}
			v |= b << (maxCodeLength - i)
			if v < h.limit[i] {
				bits = i
				break
			}
		}
	} else {
		// Fast path: the code is at most quickbits long, so the
		// precomputed tables give both its length and its symbol.
		if v < h.limit[h.quickbits] {
			i := v >> (maxCodeLength - h.quickbits)
			r.unreadBits(maxCodeLength - h.quicklen[i])
			return h.quicksym[i], nil
		}

		// Slow path: the code length is the first n with v < limit[n];
		// hand the unused low bits back to the reader. If no length
		// matches, bits stays maxCodeLength and all bits read are consumed.
		for i, n := range h.limit[h.min:] {
			if v < n {
				bits = h.min + uint(i)
				r.unreadBits(maxCodeLength - bits)
				break
			}
		}
	}

	// Right-aligned offset of v among the codes of length bits.
	dist := v - h.limit[bits-1]
	dist >>= maxCodeLength - bits

	pos := h.pos[bits] + dist
	if pos >= len(h.symbol) {
		return 0, errHuffDecodeFailed
	}

	return h.symbol[pos], nil
}

// readCodeLengthTable reads a new code length table into codeLength from br.
// If addOld is set the old table is added to the new one.
//
// The table is itself Huffman coded. It starts with the 4-bit code lengths
// of a 20-symbol pre-code. A length of 0xf is an escape followed by a 4-bit
// count c: c == 0 means a literal length of 15, otherwise c+2 lengths are
// zero. Pre-code symbols are then decoded until codeLength is full:
//
//	0-15  literal length (added to the old entry mod 16 when addOld is set)
//	16    repeat the previous length 3-10 times   (3-bit count)
//	17    repeat the previous length 11-138 times (7-bit count)
//	18    write 3-10 zero lengths                  (3-bit count)
//	19    write 11-138 zero lengths                (7-bit count)
//
// Runs are truncated at the end of codeLength, and run codes never add the
// old table. errInvalidLengthTable is returned if a repeat-previous code
// appears before any length has been written. Read errors from br are
// returned unchanged.
func readCodeLengthTable(br bitReader, codeLength []byte, addOld bool) error {
	// Stage 1: code lengths of the pre-code.
	var preLengths [20]byte
	for j := 0; j < len(preLengths); {
		n, err := br.readBits(4)
		if err != nil {
			return err
		}
		if n != 0xf {
			preLengths[j] = byte(n)
			j++
			continue
		}
		zeros, err := br.readBits(4)
		if err != nil {
			return err
		}
		if zeros == 0 {
			preLengths[j] = 0xf // escaped literal 15
			j++
			continue
		}
		// preLengths starts zeroed, so skipping the run is enough.
		j += zeros + 2
	}

	// Stage 2: decode the table itself using the pre-code.
	var pre huffmanDecoder
	pre.init(preLengths[:])

	for i := 0; i < len(codeLength); {
		sym, err := pre.readSym(br)
		if err != nil {
			return err
		}

		if sym < 16 {
			if addOld {
				codeLength[i] = (codeLength[i] + byte(sym)) & 0xf
			} else {
				codeLength[i] = byte(sym)
			}
			i++
			continue
		}

		var run int
		if sym == 16 || sym == 18 {
			run, err = br.readBits(3)
			run += 3
		} else {
			run, err = br.readBits(7)
			run += 11
		}
		if err != nil {
			return err
		}

		var fill byte // codes 18 and 19 write zeros
		if sym < 18 {
			if i == 0 {
				return errInvalidLengthTable
			}
			fill = codeLength[i-1]
		}

		end := i + run
		if end > len(codeLength) {
			end = len(codeLength)
		}
		for ; i < end; i++ {
			codeLength[i] = fill
		}
	}
	return nil
}