/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
)

// TBinaryProtocol implements the Thrift binary protocol: fixed-width
// big-endian integers and i32 length-prefixed strings and binaries.
type TBinaryProtocol struct {
	// trans carries every read and write. It is origTransport itself when
	// that already implements TRichTransport, otherwise a wrapper around it.
	trans         TRichTransport
	// origTransport is the transport handed to the constructor; Transport()
	// returns it.
	origTransport TTransport
	// strictRead makes ReadMessageBegin reject headers without a version.
	strictRead    bool
	// strictWrite makes WriteMessageBegin emit the versioned header.
	strictWrite   bool
	// buffer is scratch space for fixed-width values and short strings.
	buffer        [64]byte
}

// TBinaryProtocolFactory builds TBinaryProtocol instances that all share the
// same strictRead/strictWrite settings.
type TBinaryProtocolFactory struct {
	strictRead  bool
	strictWrite bool
}

// NewTBinaryProtocolTransport wraps t in a binary protocol that accepts
// unversioned message headers on read but always writes versioned ones.
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
	const strictRead, strictWrite = false, true
	return NewTBinaryProtocol(t, strictRead, strictWrite)
}

// NewTBinaryProtocol returns a binary protocol over t.
//
// If t already implements TRichTransport it is used directly for I/O;
// otherwise it is wrapped with NewTRichTransport. The unwrapped t is kept
// so that Transport() can return it.
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
	rich, isRich := t.(TRichTransport)
	if !isRich {
		rich = NewTRichTransport(t)
	}
	return &TBinaryProtocol{
		trans:         rich,
		origTransport: t,
		strictRead:    strictRead,
		strictWrite:   strictWrite,
	}
}

// NewTBinaryProtocolFactoryDefault returns a factory whose protocols accept
// unversioned headers on read and write versioned headers.
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
	const strictRead, strictWrite = false, true
	return NewTBinaryProtocolFactory(strictRead, strictWrite)
}

// NewTBinaryProtocolFactory returns a factory whose protocols use the given
// strictRead/strictWrite settings.
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
	f := new(TBinaryProtocolFactory)
	f.strictRead, f.strictWrite = strictRead, strictWrite
	return f
}

// GetProtocol wraps t in a new TBinaryProtocol configured with this
// factory's strictness settings.
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
	proto := NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
	return proto
}

/**
 * Writing Methods
 */

// WriteMessageBegin writes a message header.
//
// With strictWrite the header is i32(VERSION_1 | typeId), the name as a
// length-prefixed string, and the i32 sequence id. Without it the legacy
// layout is used: the name, a single message-type byte, and the sequence id.
// The first write error encountered is returned.
func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
	if p.strictWrite {
		version := uint32(VERSION_1) | uint32(typeId)
		if e := p.WriteI32(ctx, int32(version)); e != nil {
			return e
		}
		if e := p.WriteString(ctx, name); e != nil {
			return e
		}
		return p.WriteI32(ctx, seqId)
	}

	if e := p.WriteString(ctx, name); e != nil {
		return e
	}
	if e := p.WriteByte(ctx, int8(typeId)); e != nil {
		return e
	}
	return p.WriteI32(ctx, seqId)
}

// WriteMessageEnd writes nothing; the binary encoding has no message trailer.
func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error { return nil }

// WriteStructBegin writes nothing; struct names are not part of the binary
// encoding.
func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error { return nil }

// WriteStructEnd writes nothing; a struct is terminated by WriteFieldStop.
func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error { return nil }

// WriteFieldBegin writes a field header: the type byte, then the i16 field
// id. The field name is not encoded.
func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
	if err := p.WriteByte(ctx, int8(typeId)); err != nil {
		return err
	}
	return p.WriteI16(ctx, id)
}

// WriteFieldEnd writes nothing; fields carry no trailer.
func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error { return nil }

// WriteFieldStop writes the single STOP byte that ends a struct's fields.
func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error {
	return p.WriteByte(ctx, STOP)
}

// WriteMapBegin writes a map header: the key type byte, the value type
// byte, then size converted with int32() as an i32 element count.
func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
	for _, t := range [...]TType{keyType, valueType} {
		if err := p.WriteByte(ctx, int8(t)); err != nil {
			return err
		}
	}
	return p.WriteI32(ctx, int32(size))
}

// WriteMapEnd writes nothing; maps are delimited by their element count.
func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error { return nil }

// WriteListBegin writes a list header: the element type byte, then size
// converted with int32() as an i32 element count.
func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
	if err := p.WriteByte(ctx, int8(elemType)); err != nil {
		return err
	}
	return p.WriteI32(ctx, int32(size))
}

// WriteListEnd writes nothing; lists are delimited by their element count.
func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error { return nil }

// WriteSetBegin writes a set header, which has the same layout as a list
// header: the element type byte, then size converted with int32() as an i32.
func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
	if err := p.WriteByte(ctx, int8(elemType)); err != nil {
		return err
	}
	return p.WriteI32(ctx, int32(size))
}

// WriteSetEnd writes nothing; sets are delimited by their element count.
func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error { return nil }

// WriteBool encodes value as one byte: 1 for true, 0 for false.
func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error {
	var encoded int8
	if value {
		encoded = 1
	}
	return p.WriteByte(ctx, encoded)
}

// WriteByte writes value's bit pattern as a single byte. Transport errors
// are wrapped with NewTProtocolException.
func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error {
	return NewTProtocolException(p.trans.WriteByte(byte(value)))
}

// WriteI16 writes value as 2 big-endian bytes, staged in the scratch buffer.
func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error {
	out := p.buffer[:2]
	binary.BigEndian.PutUint16(out, uint16(value))
	if _, err := p.trans.Write(out); err != nil {
		return NewTProtocolException(err)
	}
	return nil
}

// WriteI32 writes value as 4 big-endian bytes, staged in the scratch buffer.
func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error {
	out := p.buffer[:4]
	binary.BigEndian.PutUint32(out, uint32(value))
	if _, err := p.trans.Write(out); err != nil {
		return NewTProtocolException(err)
	}
	return nil
}

// WriteI64 writes value as 8 big-endian bytes, staged in the scratch buffer.
func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error {
	out := p.buffer[:8]
	binary.BigEndian.PutUint64(out, uint64(value))
	if _, err := p.trans.Write(out); err != nil {
		return NewTProtocolException(err)
	}
	return nil
}

// WriteDouble writes value's IEEE-754 bit pattern using the i64 encoding.
func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error {
	bits := math.Float64bits(value)
	return p.WriteI64(ctx, int64(bits))
}

// WriteString writes value as an i32 byte length followed by the raw bytes.
//
// The length prefix is an i32, so a value longer than math.MaxInt32 bytes
// cannot be represented. Converting such a length with int32() would
// silently wrap and produce a corrupt frame, so the value is rejected with
// invalidDataLength before anything is written.
func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error {
	if len(value) > math.MaxInt32 {
		return invalidDataLength
	}
	if e := p.WriteI32(ctx, int32(len(value))); e != nil {
		return e
	}
	_, err := p.trans.WriteString(value)
	return NewTProtocolException(err)
}

// WriteBinary writes value as an i32 byte length followed by the raw bytes.
//
// The length prefix is an i32, so a slice longer than math.MaxInt32 bytes
// cannot be represented. Converting such a length with int32() would
// silently wrap and produce a corrupt frame, so the value is rejected with
// invalidDataLength before anything is written.
func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error {
	if len(value) > math.MaxInt32 {
		return invalidDataLength
	}
	if e := p.WriteI32(ctx, int32(len(value))); e != nil {
		return e
	}
	_, err := p.trans.Write(value)
	return NewTProtocolException(err)
}

/**
 * Reading methods
 */

// ReadMessageBegin reads a message header written in either layout.
//
// A negative first i32 marks the strict layout. Its low byte is the message
// type and the bits selected by VERSION_MASK must equal VERSION_1. It is
// followed by the name (a length-prefixed string) and the i32 sequence id.
//
// Otherwise the first i32 is the name length of the legacy layout, which is
// rejected when strictRead is set. It is followed by the name bytes, one
// message-type byte and the i32 sequence id.
//
// Every error path returns an error built by NewTProtocolException or
// NewTProtocolExceptionWithType, either here or inside the read helpers.
func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
	size, e := p.ReadI32(ctx)
	if e != nil {
		return "", typeId, 0, NewTProtocolException(e)
	}
	if size < 0 {
		// Strict header: type in the low byte, version in the masked bits.
		typeId = TMessageType(size & 0x0ff)
		version := int64(int64(size) & VERSION_MASK)
		if version != VERSION_1 {
			return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
		}
		name, e = p.ReadString(ctx)
		if e != nil {
			return name, typeId, seqId, NewTProtocolException(e)
		}
		seqId, e = p.ReadI32(ctx)
		if e != nil {
			return name, typeId, seqId, NewTProtocolException(e)
		}
		return name, typeId, seqId, nil
	}

	// Legacy header: size is the byte length of the name.
	if p.strictRead {
		return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
	}
	name, e = p.readStringBody(size)
	if e != nil {
		return name, typeId, seqId, e
	}
	b, e := p.ReadByte(ctx)
	if e != nil {
		// ReadByte returns the transport error unwrapped; wrap it so this
		// path reports errors the same way as every other one.
		return name, typeId, seqId, NewTProtocolException(e)
	}
	typeId = TMessageType(b)
	seqId, e = p.ReadI32(ctx)
	if e != nil {
		return name, typeId, seqId, e
	}
	return name, typeId, seqId, nil
}

// ReadMessageEnd consumes nothing; there is no message trailer.
func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error { return nil }

// ReadStructBegin consumes nothing and returns an empty name; struct names
// are not present in the binary encoding.
func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
	return "", nil
}

// ReadStructEnd consumes nothing; the STOP field already ended the struct.
func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error { return nil }

// ReadFieldBegin reads a field header: a type byte and, unless that byte is
// STOP, the i16 field id. The name is always empty. An error from reading
// the type byte is returned as ReadByte produced it.
func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) {
	raw, err := p.ReadByte(ctx)
	typeId = TType(raw)
	if err == nil && raw != STOP {
		seqId, err = p.ReadI16(ctx)
	}
	return name, typeId, seqId, err
}

// ReadFieldEnd consumes nothing; fields carry no trailer.
func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error { return nil }

// invalidDataLength is the shared INVALID_DATA protocol exception returned
// when a length or element count read from the wire is negative.
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))

// ReadMapBegin reads a map header: key type byte, value type byte and an
// i32 element count. A negative count yields invalidDataLength. On an
// error, any type already decoded is still returned alongside the error.
func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
	keyRaw, err := p.ReadByte(ctx)
	if err != nil {
		return kType, vType, size, NewTProtocolException(err)
	}
	kType = TType(keyRaw)

	valRaw, err := p.ReadByte(ctx)
	if err != nil {
		return kType, vType, size, NewTProtocolException(err)
	}
	vType = TType(valRaw)

	count, err := p.ReadI32(ctx)
	if err != nil {
		return kType, vType, size, NewTProtocolException(err)
	}
	if count < 0 {
		return kType, vType, size, invalidDataLength
	}
	return kType, vType, int(count), nil
}

// ReadMapEnd consumes nothing; maps are delimited by their element count.
func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error { return nil }

// ReadListBegin reads a list header: element type byte and an i32 element
// count. A negative count yields invalidDataLength. If the count cannot be
// read, the already-decoded element type is still returned.
func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
	raw, err := p.ReadByte(ctx)
	if err != nil {
		return elemType, size, NewTProtocolException(err)
	}
	elemType = TType(raw)

	count, err := p.ReadI32(ctx)
	if err != nil {
		return elemType, size, NewTProtocolException(err)
	}
	if count < 0 {
		return elemType, size, invalidDataLength
	}
	return elemType, int(count), nil
}

// ReadListEnd consumes nothing; lists are delimited by their element count.
func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error { return nil }

// ReadSetBegin reads a set header, laid out like a list header: element
// type byte and an i32 element count. A negative count yields
// invalidDataLength. If the count cannot be read, the already-decoded
// element type is still returned.
func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
	raw, err := p.ReadByte(ctx)
	if err != nil {
		return elemType, size, NewTProtocolException(err)
	}
	elemType = TType(raw)

	count, err := p.ReadI32(ctx)
	switch {
	case err != nil:
		return elemType, size, NewTProtocolException(err)
	case count < 0:
		return elemType, size, invalidDataLength
	}
	return elemType, int(count), nil
}

// ReadSetEnd consumes nothing; sets are delimited by their element count.
func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error { return nil }

// ReadBool reads one byte; only the value 1 decodes as true, any other byte
// decodes as false. The ReadByte error is passed through unchanged.
func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) {
	raw, err := p.ReadByte(ctx)
	return raw == 1, err
}

// ReadByte reads one byte from the transport and reinterprets it as int8.
// The transport error, if any, is returned without wrapping.
func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) {
	raw, err := p.trans.ReadByte()
	return int8(raw), err
}

// ReadI16 reads 2 big-endian bytes via readAll. On error the returned value
// is decoded from whatever the scratch buffer holds and should be ignored.
func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) {
	in := p.buffer[:2]
	err = p.readAll(ctx, in)
	return int16(binary.BigEndian.Uint16(in)), err
}

// ReadI32 reads 4 big-endian bytes via readAll. On error the returned value
// is decoded from whatever the scratch buffer holds and should be ignored.
func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) {
	in := p.buffer[:4]
	err = p.readAll(ctx, in)
	return int32(binary.BigEndian.Uint32(in)), err
}

// ReadI64 reads 8 big-endian bytes via readAll. On error the returned value
// is decoded from whatever the scratch buffer holds and should be ignored.
func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) {
	in := p.buffer[:8]
	err = p.readAll(ctx, in)
	return int64(binary.BigEndian.Uint64(in)), err
}

// ReadDouble reads an IEEE-754 float64 stored as its 8-byte big-endian bit
// pattern, which is exactly the wire form ReadI64 decodes.
func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
	bits, err := p.ReadI64(ctx)
	return math.Float64frombits(uint64(bits)), err
}

// ReadString reads an i32 length followed by that many bytes.
//
// A negative length yields invalidDataLength and a zero length yields "".
// Strings shorter than the scratch buffer are read into it with
// io.ReadFull to avoid an extra allocation. On a short read that path
// returns the bytes obtained so far together with the wrapped error. Longer
// strings go through readStringBody.
func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) {
	size, err := p.ReadI32(ctx)
	if err != nil {
		return "", err
	}
	switch {
	case size < 0:
		return "", invalidDataLength
	case size == 0:
		return "", nil
	case size < int32(len(p.buffer)):
		scratch := p.buffer[:size]
		n, readErr := io.ReadFull(p.trans, scratch)
		return string(scratch[:n]), NewTProtocolException(readErr)
	default:
		return p.readStringBody(size)
	}
}

// ReadBinary reads an i32 length followed by that many bytes. A negative
// length yields invalidDataLength. The bytes are read with safeReadBytes,
// so a huge bogus length does not trigger a matching up-front allocation.
func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
	size, err := p.ReadI32(ctx)
	if err != nil {
		return nil, err
	}
	if size < 0 {
		return nil, invalidDataLength
	}
	data, err := safeReadBytes(size, p.trans)
	return data, NewTProtocolException(err)
}

// Flush flushes the underlying transport, wrapping any failure with
// NewTProtocolException.
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
	if err = p.trans.Flush(ctx); err != nil {
		return NewTProtocolException(err)
	}
	return nil
}

// Skip consumes and discards one value of fieldType by delegating to
// SkipDefaultDepth.
func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
	err = SkipDefaultDepth(ctx, p, fieldType)
	return err
}

// Transport returns the transport originally passed to the constructor,
// not any TRichTransport wrapper built around it.
func (p *TBinaryProtocol) Transport() TTransport { return p.origTransport }

// readAll fills buf completely from the transport with io.ReadFull.
//
// When ctx carries a deadline, a read that fails with a timeout before
// delivering any byte is retried as long as ctx.Err() is still nil. Every
// other outcome ends the loop: success, a partial read, a non-timeout
// error, or ctx being done. The final error, if any, is wrapped with
// NewTProtocolException.
func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
	_, hasDeadline := ctx.Deadline()
	for {
		n, readErr := io.ReadFull(p.trans, buf)
		retry := hasDeadline && n == 0 && isTimeoutError(readErr) && ctx.Err() == nil
		if !retry {
			return NewTProtocolException(readErr)
		}
	}
}

// readStringBody reads size bytes via safeReadBytes and returns them as a
// string, wrapping any read error with NewTProtocolException.
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
	raw, readErr := safeReadBytes(size, p.trans)
	value = string(raw)
	return value, NewTProtocolException(readErr)
}

// safeReadBytes reads exactly size bytes from trans. It is shared between
// TBinaryProtocol and TCompactProtocol.
//
// size comes off the wire and may be absurdly large in a malformed message.
// So instead of allocating a size-byte slice up front, the data is copied
// into a bytes.Buffer that grows only as bytes actually arrive.
//
// A negative size yields (nil, nil). If trans runs out before size bytes are
// read, the bytes obtained so far are returned with a non-nil error (io.EOF
// when the reader simply ended, per io.CopyN).
func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
	if size < 0 {
		return nil, nil
	}
	var out bytes.Buffer
	_, err := io.CopyN(&out, trans, int64(size))
	return out.Bytes(), err
}
