/******************************************************
# DESC       : getty server
# MAINTAINER : Alex Stocks
# LICENCE    : Apache License 2.0
# EMAIL      : alexstocks@foxmail.com
# MOD        : 2016-08-17 11:21
# FILE       : server.go
******************************************************/

package getty

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

import (
	"github.com/dubbogo/gost/net"
	"github.com/gorilla/websocket"
	perrors "github.com/pkg/errors"
)

// errSelfConnect is returned (wrapped) by accept when a freshly accepted
// connection has identical local and remote addresses.
var errSelfConnect = perrors.New("connect self!")

// serverFastFailTimeout bounds how long stop waits for the websocket
// http.Server to shut down gracefully.
var serverFastFailTimeout = 1 * time.Second

// serverID is the package-wide counter from which newServer atomically
// allocates endpoint IDs.
var serverID = EndPointID(0)

// server is the shared implementation returned by NewTCPServer,
// NewUDPPEndPoint, NewWSServer and NewWSSServer; endPointType decides which
// listener is opened and which event loop RunEventLoop starts.
type server struct {
	// options applied by the ServerOption functions passed to newServer
	ServerOptions

	// endpoint ID, taken from the package-level serverID counter in newServer
	endPointID EndPointID

	// net
	// pktListener is the UDP socket opened by listenUDP.
	pktListener    net.PacketConn
	// streamListener is the TCP listener opened by listenTCP (TCP, WS and WSS).
	streamListener net.Listener
	// lock guards the server field below; stop and the WS/WSS event loops take it.
	lock           sync.Mutex // for server
	endPointType   EndPointType
	server         *http.Server // for ws or wss server

	// Once makes the shutdown body of stop run at most once.
	sync.Once
	// done is closed by stop; IsClosed reports whether that has happened.
	done chan struct{}
	// wg tracks the goroutines started by the TCP/WS/WSS event loops; Close waits on it.
	wg   sync.WaitGroup
}

// init applies every functional option, in the order given, to the embedded
// ServerOptions.
func (s *server) init(opts ...ServerOption) {
	options := &s.ServerOptions
	for i := range opts {
		opts[i](options)
	}
}

// newServer allocates a server of endpoint type t, assigns it the next
// endpoint ID from serverID and applies opts. No socket is opened here;
// that happens in RunEventLoop.
func newServer(t EndPointType, opts ...ServerOption) *server {
	id := atomic.AddInt32(&serverID, 1)
	srv := &server{
		endPointType: t,
		endPointID:   id,
		done:         make(chan struct{}),
	}
	srv.init(opts...)
	return srv
}

// NewTCPServer builds a tcp server configured by opts.
// The listening socket is not opened until RunEventLoop is called.
func NewTCPServer(opts ...ServerOption) Server {
	return newServer(TCP_SERVER, opts...)
}

// NewUDPPEndPoint builds an unconnected udp server configured by opts.
// The packet socket is not opened until RunEventLoop is called.
// (The doubled "P" in the name is kept for backward compatibility.)
func NewUDPPEndPoint(opts ...ServerOption) Server {
	return newServer(UDP_ENDPOINT, opts...)
}

// NewWSServer returns a plain (non-TLS) websocket server configured by opts;
// its listener is only opened once RunEventLoop is invoked.
func NewWSServer(opts ...ServerOption) Server {
	srv := newServer(WS_SERVER, opts...)
	return srv
}

// NewWSSServer returns a TLS websocket server configured by opts.
// It panics when the address, certificate or private key option is empty;
// the CA certificate is optional.
func NewWSSServer(opts ...ServerOption) Server {
	srv := newServer(WSS_SERVER, opts...)

	incomplete := srv.addr == "" || srv.cert == "" || srv.privateKey == ""
	if !incomplete {
		return srv
	}
	panic(fmt.Sprintf("@addr:%s, @cert:%s, @privateKey:%s, @caCert:%s",
		srv.addr, srv.cert, srv.privateKey, srv.caCert))
}

// ID returns the endpoint ID assigned by newServer.
// A pointer receiver is used because server embeds sync.Mutex, sync.Once and
// sync.WaitGroup: a value receiver would copy those locks (go vet copylocks)
// and race with stop() mutating the struct.
func (s *server) ID() int32 {
	return s.endPointID
}

// EndPointType reports which kind of endpoint (TCP, UDP, WS, WSS) this server is.
// Pointer receiver: copying server would copy its embedded locks (go vet copylocks).
func (s *server) EndPointType() EndPointType {
	return s.endPointType
}

// stop shuts the server down at most once. It closes done so IsClosed reports
// true, then gracefully shuts down the websocket http.Server if one is running.
// That shutdown waits at most serverFastFailTimeout. Finally it closes and
// clears the stream and packet listeners. Later calls return immediately.
func (s *server) stop() {
	select {
	case <-s.done:
		return
	default:
	}

	s.Once.Do(func() {
		close(s.done)

		s.lock.Lock()
		if s.server != nil {
			ctx, cancel := context.WithTimeout(context.Background(), serverFastFailTimeout)
			if err := s.server.Shutdown(ctx); err != nil {
				// if the log output is "shutdown ctx: context deadline exceeded", it means that
				// there are still some active connections.
				log.Errorf("server shutdown ctx:%s error:%v", ctx, err)
			}
			// release the timer backing ctx as soon as Shutdown has returned
			cancel()
		}
		s.server = nil
		s.lock.Unlock()

		if s.streamListener != nil {
			// let the server exit asap when got error from RunEventLoop.
			s.streamListener.Close()
			s.streamListener = nil
		}
		if s.pktListener != nil {
			s.pktListener.Close()
			s.pktListener = nil
		}
	})
}

// IsClosed reports whether stop has been called, i.e. whether done is closed.
// It never blocks.
func (s *server) IsClosed() bool {
	closed := false
	select {
	case <-s.done:
		closed = true
	default:
	}
	return closed
}

// listenTCP opens the stream listener used by TCP, WS and WSS servers.
// If s.addr contains no ':' (which includes the empty string), gxnet picks a
// random port on that host. Otherwise s.addr is handed to net.Listen verbatim.
//
// Related kernel tunables: net.ipv4.tcp_max_syn_backlog,
// net.ipv4.tcp_timestamps, net.ipv4.tcp_tw_recycle.
func (s *server) listenTCP() error {
	if s.addr != "" && strings.Contains(s.addr, ":") {
		ln, err := net.Listen("tcp", s.addr)
		if err != nil {
			return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr)
		}
		s.streamListener = ln
		return nil
	}

	ln, err := gxnet.ListenOnTCPRandomPort(s.addr)
	if err != nil {
		return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr)
	}
	s.streamListener = ln
	return nil
}

// listenUDP opens the packet socket used by UDP endpoints.
// If s.addr contains no ':' (including the empty string), gxnet picks a
// random port on that host. Otherwise s.addr is resolved and bound with
// net.ListenUDP.
func (s *server) listenUDP() error {
	if s.addr != "" && strings.Contains(s.addr, ":") {
		localAddr, err := net.ResolveUDPAddr("udp", s.addr)
		if err != nil {
			return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr)
		}
		conn, err := net.ListenUDP("udp", localAddr)
		if err != nil {
			return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr)
		}
		s.pktListener = conn
		return nil
	}

	conn, err := gxnet.ListenOnUDPRandomPort(s.addr)
	if err != nil {
		return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr)
	}
	s.pktListener = conn
	return nil
}

// listen opens the local socket matching s.endPointType. TCP, WS and WSS
// servers get a stream listener and UDP endpoints get a packet socket. Any
// other type is a no-op returning nil. Errors are annotated with a stack trace.
func (s *server) listen() error {
	var err error
	switch s.endPointType {
	case TCP_SERVER, WS_SERVER, WSS_SERVER:
		err = s.listenTCP()
	case UDP_ENDPOINT:
		err = s.listenUDP()
	default:
		return nil
	}
	return perrors.WithStack(err)
}

// accept waits on the stream listener for the next connection, wraps it in a
// TCP session and passes that session to newSession.
// The connection is closed and a stack-annotated error returned when it is a
// self-connection (local == remote address, errSelfConnect), or when
// newSession rejects it. Listener errors are returned wrapped with a stack.
func (s *server) accept(newSession NewSessionCallback) (Session, error) {
	conn, err := s.streamListener.Accept()
	if err != nil {
		return nil, perrors.WithStack(err)
	}
	if gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
		log.Warnf("conn.localAddr{%s} == conn.RemoteAddr{%s}", conn.LocalAddr().String(), conn.RemoteAddr().String())
		// the connection is unusable; close it so its descriptor is not leaked
		conn.Close()
		return nil, perrors.WithStack(errSelfConnect)
	}

	ss := newTCPSession(conn, s)
	err = newSession(ss)
	if err != nil {
		conn.Close()
		return nil, perrors.WithStack(err)
	}

	return ss, nil
}

// runTcpEventLoop starts a goroutine, tracked by s.wg, that accepts TCP
// connections until the server is closed. Each accepted connection becomes a
// session via accept and is started with run.
// Temporary accept errors trigger an exponential back-off from 5ms up to 1s.
// The delay resets after the next successful accept.
func (s *server) runTcpEventLoop(newSession NewSessionCallback) {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		var (
			err    error
			client Session
			delay  time.Duration
		)
		for {
			if s.IsClosed() {
				log.Warnf("server{%s} stop accepting client connect request.", s.addr)
				return
			}
			if delay != 0 {
				// back off, but wake up immediately if the server is closed meanwhile;
				// the IsClosed check at the loop head then exits.
				select {
				case <-wheel.After(delay):
				case <-s.done:
					continue
				}
			}
			client, err = s.accept(newSession)
			if err != nil {
				if s.IsClosed() {
					// the listener was closed by stop; this error is expected
					continue
				}
				// accept wraps errors with perrors.WithStack, so unwrap to the
				// original error before asserting net.Error; otherwise the
				// temporary-error back-off below can never trigger.
				if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Temporary() {
					if delay == 0 {
						delay = 5 * time.Millisecond
					} else {
						delay *= 2
					}
					if maxDelay := 1 * time.Second; delay > maxDelay {
						delay = maxDelay
					}
					continue
				}
				log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.WithStack(err))
				continue
			}
			delay = 0
			client.(*session).run()
		}
	}()
}

// runUDPEventLoop wraps the packet socket in a single UDP session and passes
// it to newSession, panicking if the callback rejects it. It then starts the
// session. This function itself does not spawn a goroutine or touch s.wg.
func (s *server) runUDPEventLoop(newSession NewSessionCallback) {
	udpConn := s.pktListener.(*net.UDPConn)
	var ss Session = newUDPSession(udpConn, s)
	if err := newSession(ss); err != nil {
		panic(err.Error())
	}
	ss.(*session).run()
}

// wsHandler is the http.Handler of a WS/WSS server. The embedded ServeMux
// routes the configured path to serveWSRequest, which upgrades each request
// to a websocket connection and turns it into a session.
type wsHandler struct {
	http.ServeMux
	// server is the owning server; used for IsClosed, addr and session creation.
	server     *server
	// newSession is invoked for every upgraded connection.
	newSession NewSessionCallback
	// upgrader performs the HTTP -> websocket handshake.
	upgrader   websocket.Upgrader
}

// newWSHandler builds a wsHandler for server. Its upgrader accepts requests
// from any origin and enables compression. Read/write buffer sizes are left
// at the upgrader's zero values, i.e. the websocket package defaults.
func newWSHandler(server *server, newSession NewSessionCallback) *wsHandler {
	allowAnyOrigin := func(_ *http.Request) bool { return true }

	h := new(wsHandler)
	h.server = server
	h.newSession = newSession
	h.upgrader = websocket.Upgrader{
		// HandshakeTimeout: server.HTTPTimeout,
		EnableCompression: true,
		CheckOrigin:       allowAnyOrigin,
	}
	return h
}

// serveWSRequest upgrades a GET request to a websocket connection and runs it
// as a session. Non-GET requests are answered with 405. Requests arriving
// after the server was closed get 500.
// Self-connections (local == remote address) and connections rejected by
// newSession are closed and dropped. A positive session maxMsgLen becomes the
// connection's read limit.
func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if s.server.IsClosed() {
		http.Error(w, "HTTP server is closed(code:500-11).", http.StatusInternalServerError)
		log.Warnf("server{%s} stop acceptting client connect request.", s.server.addr)
		return
	}

	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Warnf("upgrader.Upgrader(http.Request{%#v}) = error:%+v", r, err)
		return
	}
	if conn.RemoteAddr().String() == conn.LocalAddr().String() {
		log.Warnf("conn.localAddr{%s} == conn.RemoteAddr{%s}", conn.LocalAddr().String(), conn.RemoteAddr().String())
		// the connection is unusable; close it so it is not leaked
		conn.Close()
		return
	}

	ss := newWSSession(conn, s.server)
	if err = s.newSession(ss); err != nil {
		conn.Close()
		log.Warnf("server{%s}.newSession(ss{%#v}) = err {%s}", s.server.addr, ss, err)
		return
	}
	if maxLen := ss.(*session).maxMsgLen; maxLen > 0 {
		conn.SetReadLimit(int64(maxLen))
	}
	ss.(*session).run()
}

// runWSEventLoop serves websocket client requests on the stream listener from
// a goroutine tracked by s.wg. The goroutine runs until stop shuts down the
// http.Server or closes the listener.
// @newSession: new websocket connection callback
func (s *server) runWSEventLoop(newSession NewSessionCallback) {
	// Read the listener synchronously: stop may reset s.streamListener to nil
	// before the goroutine is scheduled, and Serve(nil) would panic.
	listener := s.streamListener
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		handler := newWSHandler(s, newSession)
		handler.HandleFunc(s.path, handler.serveWSRequest)
		httpServer := &http.Server{
			Addr:    s.addr,
			Handler: handler,
			// ReadTimeout:    server.HTTPTimeout,
			// WriteTimeout:   server.HTTPTimeout,
		}
		s.lock.Lock()
		s.server = httpServer
		s.lock.Unlock()

		// ErrServerClosed (graceful Shutdown) and errors caused by stop closing
		// the listener are the normal way this loop ends; do not report them.
		err := httpServer.Serve(listener)
		if err != nil && err != http.ErrServerClosed && !s.IsClosed() {
			log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
		}
	}()
}

// runWSSEventLoop serves secure websocket client requests from a goroutine
// tracked by s.wg. It loads the TLS key pair from s.cert / s.privateKey.
// When s.caCert is set, it requires and verifies client certificates signed
// by that CA. It then serves TLS on the stream listener until stop shuts the
// server down.
// It panics if the certificates cannot be loaded or parsed, and if Serve fails
// for a reason other than the server being stopped.
// @newSession: new websocket connection callback
func (s *server) runWSSEventLoop(newSession NewSessionCallback) {
	// Read the listener synchronously: stop may reset s.streamListener to nil
	// before the goroutine is scheduled.
	listener := s.streamListener
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		certificate, err := tls.LoadX509KeyPair(s.cert, s.privateKey)
		if err != nil {
			panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err:%+v",
				s.cert, s.privateKey, perrors.WithStack(err)))
		}
		config := &tls.Config{
			InsecureSkipVerify: true, // do not verify peer cert
			ClientAuth:         tls.NoClientCert,
			NextProtos:         []string{"http/1.1"},
			Certificates:       []tls.Certificate{certificate},
		}

		if s.caCert != "" {
			certPem, readErr := ioutil.ReadFile(s.caCert)
			if readErr != nil {
				panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.caCert, perrors.WithStack(readErr)))
			}
			certPool := x509.NewCertPool()
			if ok := certPool.AppendCertsFromPEM(certPem); !ok {
				panic("failed to parse root certificate file")
			}
			// mutual TLS: only clients presenting a certificate signed by caCert are accepted
			config.ClientCAs = certPool
			config.ClientAuth = tls.RequireAndVerifyClientCert
			config.InsecureSkipVerify = false
		}

		handler := newWSHandler(s, newSession)
		handler.HandleFunc(s.path, handler.serveWSRequest)
		httpServer := &http.Server{
			Addr:    s.addr,
			Handler: handler,
			// ReadTimeout:    server.HTTPTimeout,
			// WriteTimeout:   server.HTTPTimeout,
		}
		httpServer.SetKeepAlivesEnabled(true)
		s.lock.Lock()
		s.server = httpServer
		s.lock.Unlock()

		// Serve returns ErrServerClosed after stop's graceful Shutdown, or a
		// closed-listener error if stop ran first. Both are normal shutdown and
		// must not crash the process.
		err = httpServer.Serve(tls.NewListener(listener, config))
		if err != nil && err != http.ErrServerClosed && !s.IsClosed() {
			log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err))
			panic(err)
		}
	}()
}

// RunEventLoop opens the listener for this endpoint type and starts serving.
// The TCP/WS/WSS loops run in background goroutines; runUDPEventLoop runs
// synchronously.
// It panics if listening fails, or if the endpoint type is unknown.
// @newSession: new connection callback
func (s *server) RunEventLoop(newSession NewSessionCallback) {
	err := s.listen()
	if err != nil {
		panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err)))
	}

	var loop func(NewSessionCallback)
	switch s.endPointType {
	case TCP_SERVER:
		loop = s.runTcpEventLoop
	case UDP_ENDPOINT:
		loop = s.runUDPEventLoop
	case WS_SERVER:
		loop = s.runWSEventLoop
	case WSS_SERVER:
		loop = s.runWSSEventLoop
	default:
		panic(fmt.Sprintf("illegal server type %s", s.endPointType.String()))
	}
	loop(newSession)
}

// Listener returns the stream listener opened by listenTCP. It is nil for UDP
// endpoints, before RunEventLoop has run, and after Close.
func (s *server) Listener() net.Listener {
	ln := s.streamListener
	return ln
}

// Close stops the server and then blocks until every event-loop goroutine
// registered in s.wg has returned.
func (s *server) Close() {
	defer s.wg.Wait()
	s.stop()
}
