// Copyright 2015, Joe Tsai. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.

// Package prefix implements bit readers and writers that use prefix encoding.
package prefix

import (
	"fmt"
	"sort"

	"github.com/dsnet/compress/internal"
	"github.com/dsnet/compress/internal/errors"
)

// errorf builds an errors.Error tagged with the "prefix" package name and
// error code c, whose message is produced by fmt.Sprintf(f, a...).
func errorf(c int, f string, a ...interface{}) error {
	msg := fmt.Sprintf(f, a...)
	return errors.Error{Code: c, Pkg: "prefix", Msg: msg}
}

func panicf(c int, f string, a ...interface{}) {
	errors.Panic(errorf(c, f, a...))
}

// Bit widths describing a prefix code's length and value. Note that
// countBits+valueBits == 32; presumably this lets both be packed into a single
// uint32 elsewhere in the package (TODO confirm against the users).
const (
	countBits = 5  // Number of bits to store the bit-length of the code
	valueBits = 27 // Number of bits to store the code value

	countMask = 1<<countBits - 1 // Mask selecting the low countBits bits
)

// PrefixCode is a representation of a prefix code, which is conceptually a
// mapping from some arbitrary symbol to some bit-string.
//
// The Sym and Cnt fields are typically provided by the user,
// while the Len and Val fields are generated by this package.
type PrefixCode struct {
	Sym uint32 // The symbol being mapped
	Cnt uint32 // The number of times this symbol is used
	Len uint32 // Bit-length of the prefix code
	Val uint32 // Value of the prefix code (must be in 0..(1<<Len)-1)
}

// PrefixCodes is a list of prefix codes.
type PrefixCodes []PrefixCode

// prefixCodesBySymbol implements sort.Interface, ordering codes by ascending Sym.
type prefixCodesBySymbol []PrefixCode

func (c prefixCodesBySymbol) Len() int           { return len(c) }
func (c prefixCodesBySymbol) Less(i, j int) bool { return c[i].Sym < c[j].Sym }
func (c prefixCodesBySymbol) Swap(i, j int)      { c[i], c[j] = c[j], c[i] }

// prefixCodesByCount implements sort.Interface, ordering codes by ascending
// Cnt and breaking ties by ascending Sym.
type prefixCodesByCount []PrefixCode

func (c prefixCodesByCount) Len() int { return len(c) }
func (c prefixCodesByCount) Less(i, j int) bool {
	return c[i].Cnt < c[j].Cnt || (c[i].Cnt == c[j].Cnt && c[i].Sym < c[j].Sym)
}
func (c prefixCodesByCount) Swap(i, j int) { c[i], c[j] = c[j], c[i] }

// SortBySymbol sorts the codes in place by ascending symbol.
func (pc PrefixCodes) SortBySymbol() { sort.Sort(prefixCodesBySymbol(pc)) }

// SortByCount sorts the codes in place by ascending count, breaking ties by
// ascending symbol.
func (pc PrefixCodes) SortByCount() { sort.Sort(prefixCodesByCount(pc)) }

// Length computes the total bit-length using the Len and Cnt fields,
// i.e. the sum of Len*Cnt over all codes.
//
// Each product is computed in uint, not uint32, so that large counts do not
// wrap around before being accumulated. On platforms where uint is 32 bits
// wide the total can still overflow.
func (pc PrefixCodes) Length() (nb uint) {
	for _, c := range pc {
		nb += uint(c.Len) * uint(c.Cnt)
	}
	return nb
}

// checkLengths reports whether the codes form a complete prefix tree.
//
// Each code of bit-length L claims (1<<valueBits)>>L units out of a total
// budget of 1<<valueBits. The tree is complete when the budget is consumed
// exactly. Lengths beyond valueBits claim nothing. An empty list is
// considered complete.
func (pc PrefixCodes) checkLengths() bool {
	if len(pc) == 0 {
		return true
	}
	const budget = 1 << valueBits
	remaining := budget
	for _, c := range pc {
		remaining -= budget >> uint(c.Len)
	}
	return remaining == 0
}

// checkPrefixes reports whether all codes have non-overlapping prefixes.
//
// For every ordered pair of distinct codes (a, b) with a.Len <= b.Len, the
// low a.Len bits of both values are compared. A match means a is a prefix of b.
// This is an O(n²) check intended for debugging.
func (pc PrefixCodes) checkPrefixes() bool {
	for i := range pc {
		a := pc[i]
		lowBits := uint32(1)<<a.Len - 1
		for j := range pc {
			if i == j {
				continue
			}
			b := pc[j]
			if a.Len <= b.Len && a.Val&lowBits == b.Val&lowBits {
				return false
			}
		}
	}
	return true
}

// checkCanonical reports whether all codes are canonical.
// That is, they have the following properties:
//
//	1. All codes of a given bit-length are consecutive values.
//	2. Shorter codes lexicographically precede longer codes.
//
// The codes must have unique symbols and be sorted by the symbol.
// The Len and Val fields in each code must be populated.
// Codes with a zero Len are skipped.
//
// Rule 2 is checked loosely: the last value of each populated bit-length must
// be numerically smaller than the first value of the next populated length.
func (pc PrefixCodes) checkCanonical() bool {
	// For each bit-length n, lastVal[n] holds the most recent code value seen
	// (Val with its low n bits reversed by internal.ReverseUint32N) and
	// numVals[n] counts how many codes have that length.
	var lastVal, numVals [valueBits + 1]uint32

	// Rule 1.
	for _, c := range pc {
		if c.Len == 0 {
			continue
		}
		v := internal.ReverseUint32N(c.Val, uint(c.Len))
		if numVals[c.Len] > 0 && lastVal[c.Len]+1 != v {
			return false
		}
		lastVal[c.Len] = v
		numVals[c.Len]++
	}

	// Rule 2.
	var prevLast uint32
	seen := false
	for n, cnt := range numVals {
		if cnt == 0 {
			continue
		}
		first := lastVal[n] - cnt + 1 // First value of this bit-length
		if seen && prevLast >= first {
			return false
		}
		prevLast, seen = lastVal[n], true
	}
	return true
}

// GenerateLengths assigns non-zero bit-lengths to all codes. Codes with high
// frequency counts will be assigned shorter codes to reduce bit entropy.
// This function is used primarily by compressors.
//
// The input codes must have the Cnt field populated and be sorted by count in
// ascending order; otherwise an error is returned. Even if a code has a count
// of 0, a non-zero bit-length will be assigned. As a special case, a single
// code is assigned a bit-length of 0.
//
// The result will have the Len field populated. The algorithm used guarantees
// that Len <= maxBits and that it is a complete prefix tree. The resulting
// codes will remain sorted by count.
//
// An error is returned if maxBits is too small to give every code a distinct
// bit-string, that is, if len(codes) > 1<<maxBits.
func GenerateLengths(codes PrefixCodes, maxBits uint) error {
	if len(codes) <= 1 {
		if len(codes) == 1 {
			codes[0].Len = 0
		}
		return nil
	}

	// Verify that the codes are in ascending order by count.
	cntLast := codes[0].Cnt
	for _, c := range codes[1:] {
		if c.Cnt < cntLast {
			return errorf(errors.Invalid, "non-monotonically increasing symbol counts")
		}
		cntLast = c.Cnt
	}

	// A prefix tree whose depth is at most maxBits holds at most 1<<maxBits
	// leaves. Reject impossible requests up front. Otherwise the
	// length-limiting pass below would recurse down to level 0 and panic with
	// an out-of-range index on symBits.
	if maxBits < 64 && uint64(len(codes)) > uint64(1)<<maxBits {
		return errorf(errors.Invalid, "cannot encode %d symbols with at most %d bits", len(codes), maxBits)
	}

	// Construct a Huffman tree used to generate the bit-lengths.
	//
	// The Huffman tree is a binary tree where each symbol lies as a leaf node
	// on this tree. The length of the prefix code to assign is the depth of
	// that leaf from the root. The Huffman algorithm, which runs in O(n),
	// is used to generate the tree. It assumes that codes are sorted in
	// increasing order of frequency.
	//
	// The algorithm is as follows:
	//	1. Start with two queues, F and Q, where F contains all of the starting
	//	symbols sorted such that symbols with lowest counts come first.
	//	2. While len(F)+len(Q) > 1:
	//		2a. Dequeue the node from F or Q that has the lowest weight as N0.
	//		2b. Dequeue the node from F or Q that has the lowest weight as N1.
	//		2c. Create a new node N that has N0 and N1 as its children.
	//		2d. Enqueue N into the back of Q.
	//	3. The tree's root node is Q[0].
	type node struct {
		cnt uint32

		// n0 or c0 represent the left child of this node.
		// Since Go does not have unions, only one of these will be set.
		// Similarly, n1 or c1 represent the right child of this node.
		//
		// If n0 or n1 is set, then it represents a "pointer" to another
		// node in the Huffman tree. Since Go's pointer analysis cannot reason
		// that these node pointers do not escape (golang.org/issue/13493),
		// we use an index to a node in the nodes slice as a pseudo-pointer.
		//
		// If c0 or c1 is set, then it represents a leaf "node" in the
		// Huffman tree. The leaves are the PrefixCode values themselves.
		n0, n1 int // Index to child nodes
		c0, c1 *PrefixCode
	}
	var nodeIdx int
	var nodeArr [1024]node // Large enough to handle most cases on the stack
	nodes := nodeArr[:]
	if len(nodes) < len(codes) {
		nodes = make([]node, len(codes)) // Number of internal nodes < number of leaves
	}
	freqs, queue := codes, nodes[:0]
	for len(freqs)+len(queue) > 1 {
		// These are the two smallest nodes at the front of freqs and queue.
		var n node
		if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
			n.c0, freqs = &freqs[0], freqs[1:]
			n.cnt += n.c0.Cnt
		} else {
			n.cnt += queue[0].cnt
			n.n0 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
			nodeIdx++
			queue = queue[1:]
		}
		if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
			n.c1, freqs = &freqs[0], freqs[1:]
			n.cnt += n.c1.Cnt
		} else {
			n.cnt += queue[0].cnt
			n.n1 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
			nodeIdx++
			queue = queue[1:]
		}
		// queue is a sub-slice of nodes, so append writes into the nodes backing array.
		queue = append(queue, n)
	}
	rootIdx := nodeIdx

	// Search the whole binary tree, noting when we hit each leaf node.
	// We do not care about the exact Huffman tree structure, but rather we only
	// care about depth of each of the leaf nodes. That is, the depth determines
	// how long each symbol is in bits.
	//
	// Since the number of leaves is n, there are at most n internal nodes.
	// Thus, this algorithm runs in O(n).
	var fixBits bool
	var explore func(int, uint)
	explore = func(rootIdx int, level uint) {
		root := &nodes[rootIdx]

		// Explore left branch.
		if root.c0 == nil {
			explore(root.n0, level+1)
		} else {
			fixBits = fixBits || (level > maxBits)
			root.c0.Len = uint32(level)
		}

		// Explore right branch.
		if root.c1 == nil {
			explore(root.n1, level+1)
		} else {
			fixBits = fixBits || (level > maxBits)
			root.c1.Len = uint32(level)
		}
	}
	explore(rootIdx, 1)

	// Fix the bit-lengths if we violate the maxBits requirement.
	if fixBits {
		// Create histogram for number of symbols with each bit-length.
		var symBitsArr [valueBits + 1]uint32
		symBits := symBitsArr[:] // symBits[nb] indicates number of symbols using nb bits
		for _, c := range codes {
			for int(c.Len) >= len(symBits) {
				symBits = append(symBits, 0)
			}
			symBits[c.Len]++
		}

		// Fudge the tree such that the largest bit-length is <= maxBits.
		// This is accomplished by effectively doing a tree rotation. That is, we
		// increase the bit-length of some higher frequency code, so that the
		// bit-lengths of lower frequency codes can be decreased.
		//
		// Visually, this looks like the following transform:
		//
		//	Level   Before       After
		//	          __          ___
		//	         /  \        /   \
		//	 n-1    X  / \      /\   /\
		//	 n        X  /\    X  X X  X
		//	 n+1        X  X
		//
		// The counters are uint32. Intermediate wrap-around during the
		// recursion is undone by the caller's subsequent "+= 3".
		var treeRotate func(uint)
		treeRotate = func(nb uint) {
			if symBits[nb-1] == 0 {
				treeRotate(nb - 1)
			}
			symBits[nb-1] -= 1 // Push this node to the level below
			symBits[nb] += 3   // This level gets one node from above, two from below
			symBits[nb+1] -= 2 // Push two nodes to the level above
		}
		for i := uint(len(symBits)) - 1; i > maxBits; i-- {
			for symBits[i] > 0 {
				treeRotate(i - 1)
			}
		}

		// Assign bit-lengths to each code. Since codes is sorted in increasing
		// order of frequency, that means that the most frequently used symbols
		// should have the shortest bit-lengths. Thus, we copy symbols to codes
		// from the back of codes first.
		cs := codes
		for nb, cnt := range symBits {
			if cnt > 0 {
				pos := len(cs) - int(cnt)
				cs2 := cs[pos:]
				for i := range cs2 {
					cs2[i].Len = uint32(nb)
				}
				cs = cs[:pos]
			}
		}
		if len(cs) != 0 {
			panic("not all codes were used up")
		}
	}

	if internal.Debug && !codes.checkLengths() {
		panic("incomplete prefix tree detected")
	}
	return nil
}

// GeneratePrefixes assigns a prefix value to all codes according to the
// bit-lengths. This function is used by both compressors and decompressors.
//
// The input codes must have the Sym and Len fields populated and be
// sorted by symbol. The bit-lengths of each code must be properly allocated,
// such that they form a complete tree.
//
// The result will have the Val field populated and will produce a canonical
// prefix tree. The resulting codes will remain sorted by symbol. Each Val is
// stored with its low Len bits reversed via internal.ReverseUint32N.
//
// An error is returned, and no Val field is modified, in these cases:
//   - the symbols are not strictly increasing;
//   - a single code has a non-zero length;
//   - any length is zero (for two or more codes) or exceeds valueBits;
//   - the lengths do not form a complete tree.
func GeneratePrefixes(codes PrefixCodes) error {
	if len(codes) <= 1 {
		if len(codes) == 1 {
			if codes[0].Len != 0 {
				return errorf(errors.Invalid, "degenerate prefix tree with one node")
			}
			codes[0].Val = 0
		}
		return nil
	}

	// Compute basic statistics on the symbols.
	//
	// Lengths above valueBits are rejected before they are used to index the
	// fixed-size bitCnts and nextCodes tables. Lengths may come from
	// decompressor input, so reject them with an error instead of panicking.
	var bitCnts [valueBits + 1]uint
	c0 := codes[0]
	if c0.Len > valueBits {
		return errorf(errors.Invalid, "invalid prefix bit-length")
	}
	bitCnts[c0.Len]++
	minBits, maxBits, symLast := c0.Len, c0.Len, c0.Sym
	for _, c := range codes[1:] {
		if c.Sym <= symLast {
			return errorf(errors.Invalid, "non-unique or non-monotonically increasing symbols")
		}
		if c.Len > valueBits {
			return errorf(errors.Invalid, "invalid prefix bit-length")
		}
		if minBits > c.Len {
			minBits = c.Len
		}
		if maxBits < c.Len {
			maxBits = c.Len
		}
		bitCnts[c.Len]++ // Histogram of bit counts
		symLast = c.Sym  // Keep track of last symbol
	}
	if minBits == 0 {
		return errorf(errors.Invalid, "invalid prefix bit-length")
	}

	// Compute the next code for a symbol of a given bit length.
	// The first code of each length is (previous length's end) << 1.
	// The tree is complete exactly when the running code reaches 1<<maxBits.
	var nextCodes [valueBits + 1]uint
	var code uint
	for i := minBits; i <= maxBits; i++ {
		code <<= 1
		nextCodes[i] = code
		code += bitCnts[i]
	}
	if code != 1<<maxBits {
		return errorf(errors.Invalid, "degenerate prefix tree")
	}

	// Assign the code to each symbol.
	for i, c := range codes {
		codes[i].Val = internal.ReverseUint32N(uint32(nextCodes[c.Len]), uint(c.Len))
		nextCodes[c.Len]++
	}

	if internal.Debug && !codes.checkPrefixes() {
		panic("overlapping prefixes detected")
	}
	if internal.Debug && !codes.checkCanonical() {
		panic("non-canonical prefixes detected")
	}
	return nil
}

// allocUint32s returns a slice of length n. It reslices s when s already has
// enough capacity; the reused elements are not cleared. Otherwise it
// allocates a fresh zeroed slice with 50% spare capacity.
func allocUint32s(s []uint32, n int) []uint32 {
	if cap(s) < n {
		return make([]uint32, n, n*3/2)
	}
	return s[:n]
}

// extendSliceUint32s returns s resliced to length n when its capacity allows.
// Otherwise it allocates a new outer slice with 50% spare capacity. It then
// copies every element within the old capacity, not just the old length, so
// inner slices beyond len(s) are carried over for reuse.
func extendSliceUint32s(s [][]uint32, n int) [][]uint32 {
	if cap(s) < n {
		grown := make([][]uint32, n, n*3/2)
		copy(grown, s[:cap(s)])
		return grown
	}
	return s[:n]
}