package restful

// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.

import (
	"bytes"
	"errors"
	"fmt"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"

	"github.com/emicklei/go-restful/log"
)

// Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
// The requests are further dispatched to routes of WebServices using a RouteSelector.
//
// webServicesLock guards webServices; Add and Remove also update ServeMux and
// isRegisteredOnRoot while holding it. isRegisteredOnRoot records that the
// dispatch handler has been registered for the "/" pattern, after which Add
// registers no further patterns. containerFilters are applied ahead of any
// WebService or Route filters when a request is dispatched.
type Container struct {
	webServicesLock        sync.RWMutex
	webServices            []*WebService
	ServeMux               *http.ServeMux
	isRegisteredOnRoot     bool
	containerFilters       []FilterFunction
	doNotRecover           bool // default is true
	recoverHandleFunc      RecoverHandleFunction
	serviceErrorHandleFunc ServiceErrorHandleFunction
	router                 RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
	contentEncodingEnabled bool          // default is false
}

// NewContainer returns an empty Container backed by its own http.ServeMux.
// Routing uses a CurlyRouter, panic recovery is switched off (doNotRecover is
// true), content encoding is disabled, and the default recover and service
// error handlers (logStackOnRecover, writeServiceError) are installed.
func NewContainer() *Container {
	container := &Container{
		webServices:            []*WebService{},
		containerFilters:       []FilterFunction{},
		ServeMux:               http.NewServeMux(),
		router:                 CurlyRouter{},
		recoverHandleFunc:      logStackOnRecover,
		serviceErrorHandleFunc: writeServiceError,
		doNotRecover:           true,
	}
	// isRegisteredOnRoot and contentEncodingEnabled keep their zero value (false).
	return container
}

// RecoverHandleFunction declares functions that can be used to handle a panic situation.
// The first argument is the value returned by recover(). The second is the response
// writer that must be used to communicate an error response to the client.
type RecoverHandleFunction func(interface{}, http.ResponseWriter)

// RecoverHandler replaces the default function (logStackOnRecover) that is called
// when a panic is detected while dispatching a request.
// The handler is only invoked when panic recovery is enabled, i.e. after
// DoNotRecover(false) has been called; a Container created by NewContainer
// starts with recovery disabled.
func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
	c.recoverHandleFunc = handler
}

// ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
// The first argument is the service error, the second is the request that resulted in the error and
// the third is the response that must be used to communicate the error to the client.
type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)

// ServiceErrorHandler replaces the default function (writeServiceError) that is called
// when route selection fails with a ServiceError.
func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
	c.serviceErrorHandleFunc = handler
}

// DoNotRecover controls whether panics will be caught to return HTTP 500.
// If set to true, panics are not recovered by the container and Route functions
// are responsible for handling any error situation.
// If set to false, a panic raised during dispatch is passed to the recover handler.
// Default value is true.
func (c *Container) DoNotRecover(doNot bool) {
	c.doNotRecover = doNot
}

// Router replaces the RouteSelector used to find the WebService and Route for
// a request. The default, set by NewContainer, is a CurlyRouter.
func (c *Container) Router(aRouter RouteSelector) {
	c.router = aRouter
}

// EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
// When enabled, each dispatched request is checked for whether it wants a
// compressed response and, if so, the response writer is wrapped in a compressing one.
func (c *Container) EnableContentEncoding(enabled bool) {
	c.contentEncodingEnabled = enabled
}

// Add registers a WebService with the Container and returns the Container so
// calls can be chained. A WebService without a root path gets "/" assigned.
// If a WebService with the same root path is already registered, the problem
// is logged and the process exits with status 1.
func (c *Container) Add(service *WebService) *Container {
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// lazily default the root path
	if len(service.rootPath) == 0 {
		service.Path("/")
	}

	// root paths must be unique within the container
	for _, registered := range c.webServices {
		if registered.RootPath() != service.RootPath() {
			continue
		}
		log.Printf("[restful] WebService with duplicate root path detected:['%v']", registered)
		os.Exit(1)
	}

	// once the dispatcher is mapped on "/" it receives every request,
	// so no further patterns need to be registered
	if !c.isRegisteredOnRoot {
		c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
	}
	c.webServices = append(c.webServices, service)
	return c
}

// addHandler registers c.dispatch on serveMux for the fixed prefix of the
// service's root path, unless a WebService with the same root path is already
// present in c.webServices (in which case nothing is registered).
// It must be called while holding webServicesLock.
// It returns true only when the handler was registered on root ("/").
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
	pattern := fixedPrefixPath(service.RootPath())

	// an empty or "/" prefix means the dispatcher has to serve everything
	if pattern == "" || pattern == "/" {
		serveMux.HandleFunc("/", c.dispatch)
		return true
	}

	// skip registration when the root path is already known to the container
	for _, registered := range c.webServices {
		if registered.RootPath() == service.RootPath() {
			return false
		}
	}

	// map both the exact pattern and its subtree
	serveMux.HandleFunc(pattern, c.dispatch)
	if !strings.HasSuffix(pattern, "/") {
		serveMux.HandleFunc(pattern+"/", c.dispatch)
	}
	return false
}

// Remove unregisters every WebService whose root path equals that of ws.
// Because a http.ServeMux offers no way to remove a pattern, a fresh ServeMux
// is built and the remaining WebServices are registered on it again.
//
// It returns an error (and logs it) when the Container uses
// http.DefaultServeMux, which cannot be replaced.
//
// NOTE: only WebService handlers are re-registered; the new ServeMux does not
// carry over patterns that were added through Handle or HandleWithFilter.
func (c *Container) Remove(ws *WebService) error {
	if c.ServeMux == http.DefaultServeMux {
		errMsg := fmt.Sprintf("[restful] cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
		log.Print(errMsg)
		return errors.New(errMsg)
	}
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	// build a new ServeMux and re-register all remaining WebServices
	newServeMux := http.NewServeMux()
	previous := c.webServices
	// addHandler inspects c.webServices to decide whether a root path is
	// already mapped. Start from an empty list and grow it as services are
	// re-registered; otherwise every service would find itself in the list
	// and no pattern would ever be registered on the new ServeMux.
	c.webServices = make([]*WebService, 0, len(previous))
	newIsRegisteredOnRoot := false
	for _, each := range previous {
		if each.rootPath == ws.rootPath {
			continue
		}
		// If not registered on root then add specific mapping
		if !newIsRegisteredOnRoot {
			newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
		}
		c.webServices = append(c.webServices, each)
	}
	c.ServeMux, c.isRegisteredOnRoot = newServeMux, newIsRegisteredOnRoot
	return nil
}

// logStackOnRecover is the default RecoverHandleFunction. It is used when
// panic recovery is enabled and no other handler was set with RecoverHandler.
// It builds a report consisting of the panic reason followed by the file:line
// of each stack frame, logs that report and sends it to the client as the body
// of a 500 (Internal Server Error) response.
// This may be a security issue as it exposes source code information.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
	var trace bytes.Buffer
	fmt.Fprintf(&trace, "[restful] recover from panic situation: - %v\r\n", panicReason)
	// start at depth 2: skip this function and its direct caller
	for depth := 2; ; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		fmt.Fprintf(&trace, "    %s:%d\r\n", file, line)
	}
	log.Print(trace.String())
	httpWriter.WriteHeader(http.StatusInternalServerError)
	httpWriter.Write(trace.Bytes())
}

// writeServiceError is the default ServiceErrorHandleFunction and is called
// when a ServiceError is returned during route selection. It ignores the
// request and writes the error's Code and Message to the response
// using resp.WriteErrorString(err.Code, err.Message).
func writeServiceError(err ServiceError, req *Request, resp *Response) {
	resp.WriteErrorString(err.Code, err.Message)
}

// Dispatch hands the incoming HTTP request to the matching WebService of this
// Container, bypassing the ServeMux. It panics when httpWriter or httpRequest
// is nil.
func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	switch {
	case httpWriter == nil:
		panic("httpWriter cannot be nil")
	case httpRequest == nil:
		panic("httpRequest cannot be nil")
	}
	c.dispatch(httpWriter, httpRequest)
}

// dispatch routes the incoming HTTP request to the best matching Route of the
// registered WebServices and invokes it through the applicable filters.
//
// The steps are:
//  1. install panic recovery, unless doNotRecover is set;
//  2. wrap httpWriter in a CompressingResponseWriter when content encoding is
//     enabled and the request asks for a compressed response;
//  3. select the WebService and Route with the container's router while
//     holding the read lock on the registered WebServices;
//  4. if selection fails, run the container filters and pass a ServiceError
//     to serviceErrorHandleFunc;
//  5. otherwise extract the path parameters and call the Route function after
//     the container, WebService and Route filters (in that order).
func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	writer := httpWriter

	// CompressingResponseWriter should be closed after all operations are done.
	// This defer is registered first and therefore runs last, after recovery.
	defer func() {
		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
			compressWriter.Close()
		}
	}()

	// Install panic recovery unless told otherwise
	if !c.doNotRecover { // catch all for 500 response
		defer func() {
			if r := recover(); r != nil {
				c.recoverHandleFunc(r, writer)
			}
		}()
	}

	// Detect if compression is needed
	// assume without compression, test for override
	if c.contentEncodingEnabled {
		if doCompress, encoding := wantsCompressedResponse(httpRequest); doCompress {
			// Only replace writer once the compressor is known to be usable.
			// Assigning the result directly would leave a possibly nil
			// compressing writer in place on failure, which the deferred
			// Close above would then be invoked on.
			compressWriter, err := NewCompressingResponseWriter(httpWriter, encoding)
			if err != nil {
				log.Print("[restful] unable to install compressor: ", err)
				httpWriter.WriteHeader(http.StatusInternalServerError)
				return
			}
			writer = compressWriter
		}
	}

	// Find best match Route ; err is non nil if no match was found
	var webService *WebService
	var route *Route
	var err error
	func() {
		c.webServicesLock.RLock()
		defer c.webServicesLock.RUnlock()
		webService, route, err = c.router.SelectRoute(
			c.webServices,
			httpRequest)
	}()
	if err != nil {
		// No route was selected. Run the container filters anyway and let the
		// service error handler produce the response as the chain's target.
		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
			if ser, ok := err.(ServiceError); ok {
				c.serviceErrorHandleFunc(ser, req, resp)
			}
			// TODO: errors of any other type are ignored and nothing is written for them
		}}
		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
		return
	}

	// Let the router extract path parameters if it knows how to, else use the default
	pathProcessor, routerProcessesPath := c.router.(PathProcessor)
	if !routerProcessesPath {
		pathProcessor = defaultPathProcessor{}
	}
	pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
	wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)

	filterCount := len(c.containerFilters) + len(webService.filters) + len(route.Filters)
	if filterCount == 0 {
		// no filters, handle request by route
		route.Function(wrappedRequest, wrappedResponse)
		return
	}

	// compose filter chain: container filters first, then WebService, then Route
	allFilters := make([]FilterFunction, 0, filterCount)
	allFilters = append(allFilters, c.containerFilters...)
	allFilters = append(allFilters, webService.filters...)
	allFilters = append(allFilters, route.Filters...)
	chain := FilterChain{Filters: allFilters, Target: func(req *Request, resp *Response) {
		// handle request by route after passing all filters
		route.Function(wrappedRequest, wrappedResponse)
	}}
	chain.ProcessFilter(wrappedRequest, wrappedResponse)
}

// fixedPrefixPath returns the fixed part of pathspec, i.e. everything that
// precedes the first template variable marker '{'. A pathspec without any
// '{' is returned unchanged.
func fixedPrefixPath(pathspec string) string {
	if firstVar := strings.IndexByte(pathspec, '{'); firstVar >= 0 {
		return pathspec[:firstVar]
	}
	return pathspec
}

// ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server.
// The request is delegated to the Container's ServeMux.
func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
	c.ServeMux.ServeHTTP(httpwriter, httpRequest)
}

// Handle registers the handler for the given pattern on the Container's ServeMux.
// The handler is called directly; container filters are not applied (see HandleWithFilter).
// If a handler already exists for pattern, Handle panics.
func (c *Container) Handle(pattern string, handler http.Handler) {
	c.ServeMux.Handle(pattern, handler)
}

// HandleWithFilter registers the handler for the given pattern and runs the
// Container's filter chain in front of it. The set of container filters is
// looked up on every request; when there are none the handler is called directly.
// The handler always receives the original http.ResponseWriter and *http.Request.
// If a handler already exists for pattern, HandleWithFilter panics.
func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
	filtered := http.HandlerFunc(func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
		if len(c.containerFilters) > 0 {
			// the handler is the final target of the container filter chain
			chain := FilterChain{Filters: c.containerFilters, Target: func(_ *Request, _ *Response) {
				handler.ServeHTTP(httpResponse, httpRequest)
			}}
			chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
			return
		}
		handler.ServeHTTP(httpResponse, httpRequest)
	})
	c.Handle(pattern, filtered)
}

// Filter appends a container FilterFunction. Container filters run before the
// filters of the selected WebService and Route when a http.Request is dispatched,
// and are also applied when route selection fails and to handlers registered
// with HandleWithFilter.
// NOTE(review): the append is not guarded by a lock — presumably filters are
// meant to be added before the Container starts serving; confirm.
func (c *Container) Filter(filter FilterFunction) {
	c.containerFilters = append(c.containerFilters, filter)
}

// RegisteredWebServices returns a snapshot of the WebServices that were added
// to the Container. The returned slice is a copy, so callers may modify it
// without affecting the Container.
func (c *Container) RegisteredWebServices() []*WebService {
	c.webServicesLock.RLock()
	defer c.webServicesLock.RUnlock()

	snapshot := make([]*WebService, len(c.webServices))
	copy(snapshot, c.webServices)
	return snapshot
}

// computeAllowedMethods returns the HTTP methods of all Routes that match the
// path of the given Request. The result is empty (never nil) when nothing matches.
func (c *Container) computeAllowedMethods(req *Request) []string {
	allowed := []string{}
	urlPath := req.Request.URL.Path
	for _, service := range c.RegisteredWebServices() {
		// the WebService root path has to match first
		serviceMatch := service.pathExpr.Matcher.FindStringSubmatch(urlPath)
		if serviceMatch == nil {
			continue
		}
		// the last submatch is what the Routes are matched against
		remainder := serviceMatch[len(serviceMatch)-1]
		for _, candidate := range service.Routes() {
			routeMatch := candidate.pathExpr.Matcher.FindStringSubmatch(remainder)
			if routeMatch == nil {
				continue
			}
			// only count the Route when its last submatch is empty or a lone "/"
			switch routeMatch[len(routeMatch)-1] {
			case "", "/":
				allowed = append(allowed, candidate.Method)
			}
		}
	}
	// "OPTIONS" itself is not appended; the original author left that undecided
	return allowed
}

// newBasicRequestResponse wraps the http request and writer into a Request and
// Response pair. It is "basic" because no path parameters or (produces)
// content-type information is attached; only the request's Accept header is
// recorded on the Response.
func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
	response := NewResponse(httpWriter)
	response.requestAccept = httpRequest.Header.Get(HEADER_Accept)
	request := NewRequest(httpRequest)
	return request, response
}
