package hbase

import (
	"bufio"
	"bytes"
	"io"
	"net"
	"strings"
	"sync"

	pb "github.com/golang/protobuf/proto"
	"github.com/juju/errors"
	"github.com/ngaut/log"
	"github.com/pingcap/go-hbase/iohelper"
	"github.com/pingcap/go-hbase/proto"
)

// ServiceType identifies which HBase RPC service a connection talks to.
type ServiceType byte

// Service identifiers, numbered from 1. They are untyped integer constants,
// so they can be used both as a ServiceType and as a plain integer.
const (
	MasterMonitorService      = 1
	MasterService             = 2
	MasterAdminService        = 3
	AdminService              = 4
	ClientService             = 5
	RegionServerStatusService = 6
)

// ServiceString maps each service identifier above to the service name
// written into the protobuf ConnectionHeader when a connection is opened.
var ServiceString = map[ServiceType]string{
	MasterMonitorService:      "MasterMonitorService",
	MasterService:             "MasterService",
	MasterAdminService:        "MasterAdminService",
	AdminService:              "AdminService",
	ClientService:             "ClientService",
	RegionServerStatusService: "RegionServerStatusService",
}

// idGenerator hands out increasing integer ids. Every access to the counter
// goes through mu, so one generator may be shared between goroutines.
type idGenerator struct {
	n  int
	mu *sync.RWMutex
}

// newIdGenerator returns a generator whose counter starts at 0, so the first
// incrAndGet returns 1.
func newIdGenerator() *idGenerator {
	return &idGenerator{mu: new(sync.RWMutex)}
}

// get reports the most recently issued id (0 before any has been issued)
// without advancing the counter.
func (g *idGenerator) get() int {
	g.mu.RLock()
	defer g.mu.RUnlock()
	return g.n
}

// incrAndGet advances the counter by one and returns the new value.
func (g *idGenerator) incrAndGet() int {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.n++
	return g.n
}

// connection is one TCP connection to an HBase server for a single
// ServiceType. Requests are framed by call and queued on in. The dispatch
// goroutine writes them through bw, and the processMessages goroutine reads
// responses from conn and completes the matching entry in ongoingCalls.
type connection struct {
	// mu guards ongoingCalls.
	mu           sync.Mutex
	// addr is the address that was dialed.
	addr         string
	// conn is the underlying network connection. Handshake bytes are
	// written to it directly, and responses are read from it directly.
	conn         net.Conn
	// bw buffers writes of queued requests to conn (used by dispatch).
	bw           *bufio.Writer
	// idGen produces the call ids placed in each RequestHeader.
	idGen        *idGenerator
	// serviceType selects the service name sent in the ConnectionHeader.
	serviceType  ServiceType
	// in queues framed requests for dispatch; its capacity is set by the
	// constructor.
	in           chan *iohelper.PbBuffer
	// ongoingCalls holds in-flight calls keyed by call id. An entry is
	// removed when its response arrives.
	ongoingCalls map[int]*call
}

// processMessage splits msg into its varint-length-delimited payloads, in
// order.
//
// Decoding stops at the first error whose text contains "unexpected EOF".
// The protobuf buffer presumably reports running out of input this way, so
// that error is treated as normal termination. Note that a truncated trailing
// payload produces the same error and is dropped without notice. Any other
// decode error is logged and returned.
func processMessage(msg []byte) ([][]byte, error) {
	decoder := pb.NewBuffer(msg)
	payloads := [][]byte{}

	for {
		payload, err := decoder.DecodeRawBytes(true)
		if err == nil {
			payloads = append(payloads, payload)
			continue
		}

		if strings.Contains(err.Error(), "unexpected EOF") {
			return payloads, nil
		}

		log.Errorf("Decode raw bytes error - %v", errors.ErrorStack(err))
		return nil, errors.Trace(err)
	}
}

// readPayloads reads one length-prefixed RPC frame from r and splits its body
// into the delimited payloads it contains (see processMessage).
//
// The frame is a size prefix (read with iohelper.ReadInt32) followed by that
// many bytes. It returns an error in these cases:
//   - the size cannot be read;
//   - the size is not positive;
//   - the body cannot be read in full;
//   - the body fails to decode;
//   - the frame carries no payloads.
func readPayloads(r io.Reader) ([][]byte, error) {
	nBytesExpecting, err := iohelper.ReadInt32(r)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if nBytesExpecting <= 0 {
		return nil, errors.New("unexpected payload")
	}

	// Any read failure means the frame is incomplete: a short read
	// (io.ErrUnexpectedEOF), a timeout or a reset. Decoding a partially
	// filled buffer would yield truncated payloads and leave the stream out
	// of sync, so every error must abort, not only io.EOF.
	buf, err := iohelper.ReadN(r, nBytesExpecting)
	if err != nil {
		return nil, errors.Trace(err)
	}

	payloads, err := processMessage(buf)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(payloads) == 0 {
		return nil, errors.New("unexpected payload")
	}
	return payloads, nil
}

// newConnection dials addr over TCP, performs the HBase connection handshake
// for srvType and starts the connection's background reader and writer.
//
// srvType is validated before dialing, and the socket is closed if the
// handshake fails, so an error return never leaks a connection.
func newConnection(addr string, srvType ServiceType) (*connection, error) {
	if _, ok := ServiceString[srvType]; !ok {
		return nil, errors.Errorf("unexpected service type [serviceType=%d]", srvType)
	}

	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, errors.Trace(err)
	}

	c := &connection{
		addr:         addr,
		bw:           bufio.NewWriter(conn),
		conn:         conn,
		in:           make(chan *iohelper.PbBuffer, 20),
		serviceType:  srvType,
		idGen:        newIdGenerator(),
		ongoingCalls: map[int]*call{},
	}

	if err := c.init(); err != nil {
		// init starts its goroutines only after the handshake succeeds, so
		// nothing else is using conn yet. The close error is ignored because
		// the handshake error is the one worth reporting.
		_ = conn.Close()
		return nil, errors.Trace(err)
	}

	return c, nil
}

// init writes the connection head and then the ConnectionHeader. Once both
// have been written, it starts two goroutines: one reads responses, and one
// drains c.in to the socket. If either write fails, no goroutine is started.
func (c *connection) init() error {
	if err := c.writeHead(); err != nil {
		return errors.Trace(err)
	}
	if err := c.writeConnectionHeader(); err != nil {
		return errors.Trace(err)
	}

	// The reader only stops on failure, which is logged here.
	go func() {
		if err := c.processMessages(); err != nil {
			log.Warnf("process messages failed - %v", errors.ErrorStack(err))
		}
	}()
	go c.dispatch()

	return nil
}

// processMessages is the connection's read loop. For each response frame it:
//   - decodes the ResponseHeader;
//   - removes the matching entry from c.ongoingCalls;
//   - completes that call with the server exception, the response body, or
//     an error when the body is missing.
//
// It returns only on failure. That is either a read/decode error, or a
// response whose call id is not in flight, after which the stream can no
// longer be trusted to be in sync with the server.
func (c *connection) processMessages() error {
	for {
		msgs, err := readPayloads(c.conn)
		if err != nil {
			return errors.Trace(err)
		}

		// readPayloads never returns an empty slice without an error, so the
		// response header in msgs[0] is always present.
		var rh proto.ResponseHeader
		if err := pb.Unmarshal(msgs[0], &rh); err != nil {
			return errors.Trace(err)
		}

		callId := rh.GetCallId()
		c.mu.Lock()
		pending, ok := c.ongoingCalls[int(callId)]
		if !ok {
			c.mu.Unlock()
			return errors.Errorf("Invalid call id: %d", callId)
		}
		delete(c.ongoingCalls, int(callId))
		c.mu.Unlock()

		switch exception := rh.GetException(); {
		case exception != nil:
			pending.complete(errors.Errorf("Exception returned: %s\n%s", exception.GetExceptionClassName(), exception.GetStackTrace()), nil)
		case len(msgs) < 2:
			// The call has already been removed from ongoingCalls. It must be
			// completed here, otherwise nothing would ever complete it.
			pending.complete(errors.Errorf("response for call id %d has no body", callId), nil)
		default:
			pending.complete(nil, msgs[1])
		}
	}
}

// writeHead sends the fixed connection head: hbaseHeaderBytes followed by the
// two bytes 0 and 80. Per the HBase RPC preamble these are presumably the RPC
// version and the auth-method code (80 = SIMPLE); NOTE(review): confirm
// against the HBase RPC spec.
func (c *connection) writeHead() error {
	var head bytes.Buffer
	head.Write(hbaseHeaderBytes)
	head.Write([]byte{0, 80})

	if _, err := c.conn.Write(head.Bytes()); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// writeConnectionHeader sends a size-prefixed ConnectionHeader protobuf. The
// header names the effective user ("pingcap") and the service this
// connection will call, looked up in ServiceString by c.serviceType.
func (c *connection) writeConnectionHeader() error {
	header := &proto.ConnectionHeader{
		UserInfo:    &proto.UserInformation{EffectiveUser: pb.String("pingcap")},
		ServiceName: pb.String(ServiceString[c.serviceType]),
	}

	frame := iohelper.NewPbBuffer()
	if err := frame.WritePBMessage(header); err != nil {
		return errors.Trace(err)
	}
	if err := frame.PrependSize(); err != nil {
		return errors.Trace(err)
	}
	if _, err := c.conn.Write(frame.Bytes()); err != nil {
		return errors.Trace(err)
	}
	return nil
}

// dispatch is the connection's write loop. It copies framed requests from
// c.in into the buffered writer and flushes whenever the queue is empty, so
// back-to-back requests share a single flush.
//
// Write and flush failures are logged instead of being silently dropped. The
// loop keeps draining after a failure, because call sends on c.in and would
// block once the queue fills if nothing consumed it. The loop returns when
// c.in is closed.
func (c *connection) dispatch() {
	for buf := range c.in {
		if _, err := c.bw.Write(buf.Bytes()); err != nil {
			// bufio.Writer errors are sticky, so a flush would fail the same
			// way; skip it.
			log.Warnf("write request to %s failed - %v", c.addr, err)
			continue
		}
		if len(c.in) == 0 {
			if err := c.bw.Flush(); err != nil {
				log.Warnf("flush requests to %s failed - %v", c.addr, err)
			}
		}
	}
}

// call assigns request a fresh call id, registers it as in flight and queues
// its framed bytes for the dispatch goroutine. The response is delivered to
// request by processMessages.
//
// The call is stored in ongoingCalls under its 32-bit wire id. processMessages
// looks calls up with int(ResponseHeader.CallId), so the keys must match even
// after the generator passes the uint32 range. The queue send happens after
// c.mu is released, so a full queue cannot stall processMessages, which needs
// c.mu to complete calls.
func (c *connection) call(request *call) error {
	id := uint32(c.idGen.incrAndGet())
	request.id = id

	rh := &proto.RequestHeader{
		CallId:       pb.Uint32(id),
		MethodName:   pb.String(request.methodName),
		RequestParam: pb.Bool(true),
	}

	bfrh := iohelper.NewPbBuffer()
	if err := bfrh.WritePBMessage(rh); err != nil {
		return errors.Trace(err)
	}

	bfr := iohelper.NewPbBuffer()
	if err := bfr.WritePBMessage(request.request); err != nil {
		return errors.Trace(err)
	}

	// Buf =>
	// | total size | pb1 size | pb1 | pb2 size | pb2 | ...
	buf := iohelper.NewPbBuffer()
	buf.WriteDelimitedBuffers(bfrh, bfr)

	// Register before sending so the response can always find the call.
	c.mu.Lock()
	c.ongoingCalls[int(id)] = request
	c.mu.Unlock()

	c.in <- buf
	return nil
}

// close closes the underlying network connection and returns the error, if
// any, reported by net.Conn.Close.
func (c *connection) close() error {
	err := c.conn.Close()
	return err
}