| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package getty |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "fmt" |
| "io/ioutil" |
| "net" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
| ) |
| |
| import ( |
| gxnet "github.com/dubbogo/gost/net" |
| gxsync "github.com/dubbogo/gost/sync" |
| gxtime "github.com/dubbogo/gost/time" |
| |
| "github.com/gorilla/websocket" |
| |
| perrors "github.com/pkg/errors" |
| |
| uatomic "go.uber.org/atomic" |
| ) |
| |
| var ( |
| errSelfConnect = perrors.New("connect self!") |
| serverFastFailTimeout = time.Second * 1 |
| |
| serverID uatomic.Int32 |
| ) |
| |
| // Server interface |
| type Server interface { |
| EndPoint |
| } |
| |
| // StreamServer is like tcp/websocket/wss server |
| type StreamServer interface { |
| Server |
| // Listener get the network listener |
| Listener() net.Listener |
| } |
| |
| // PacketServer is like udp listen endpoint |
| type PacketServer interface { |
| Server |
| // PacketConn get the network listener |
| PacketConn() net.PacketConn |
| } |
| |
| type server struct { |
| ServerOptions |
| |
| // endpoint ID |
| endPointID EndPointID |
| |
| // net |
| pktListener net.PacketConn |
| streamListener net.Listener |
| lock sync.Mutex // for server |
| endPointType EndPointType |
| server *http.Server // for ws or wss server |
| sync.Once |
| done chan struct{} |
| wg sync.WaitGroup |
| } |
| |
| func (s *server) init(opts ...ServerOption) { |
| for _, opt := range opts { |
| opt(&(s.ServerOptions)) |
| } |
| } |
| |
| func newServer(t EndPointType, opts ...ServerOption) *server { |
| s := &server{ |
| endPointID: serverID.Add(1), |
| endPointType: t, |
| done: make(chan struct{}), |
| } |
| |
| s.init(opts...) |
| |
| return s |
| } |
| |
| // NewTCPServer builds a tcp server. |
| func NewTCPServer(opts ...ServerOption) Server { |
| return newServer(TCP_SERVER, opts...) |
| } |
| |
| // NewUDPEndPoint builds a unconnected udp server. |
| func NewUDPEndPoint(opts ...ServerOption) Server { |
| return newServer(UDP_ENDPOINT, opts...) |
| } |
| |
| // NewWSServer builds a websocket server. |
| func NewWSServer(opts ...ServerOption) Server { |
| return newServer(WS_SERVER, opts...) |
| } |
| |
| // NewWSSServer builds a secure websocket server. |
| func NewWSSServer(opts ...ServerOption) Server { |
| s := newServer(WSS_SERVER, opts...) |
| |
| if s.addr == "" || s.cert == "" || s.privateKey == "" { |
| panic(fmt.Sprintf("@addr:%s, @cert:%s, @privateKey:%s, @caCert:%s", |
| s.addr, s.cert, s.privateKey, s.caCert)) |
| } |
| |
| return s |
| } |
| |
| func (s *server) ID() int32 { |
| return s.endPointID |
| } |
| |
| func (s *server) EndPointType() EndPointType { |
| return s.endPointType |
| } |
| |
| func (s *server) stop() { |
| select { |
| case <-s.done: |
| return |
| default: |
| s.Once.Do(func() { |
| close(s.done) |
| s.lock.Lock() |
| if s.server != nil { |
| ctx, cancel := context.WithTimeout(context.Background(), serverFastFailTimeout) |
| if err := s.server.Shutdown(ctx); err != nil { |
| // if the log output is "shutdown ctx: context deadline exceeded", it means that |
| // there are still some active connections. |
| log.Errorf("server shutdown ctx:%s error:%v", ctx, err) |
| } |
| cancel() |
| } |
| s.server = nil |
| s.lock.Unlock() |
| if s.streamListener != nil { |
| // let the server exit asap when got error from RunEventLoop. |
| s.streamListener.Close() |
| s.streamListener = nil |
| } |
| if s.pktListener != nil { |
| s.pktListener.Close() |
| s.pktListener = nil |
| } |
| }) |
| } |
| } |
| |
| func (s *server) GetTaskPool() gxsync.GenericTaskPool { |
| return s.tPool |
| } |
| |
| func (s *server) IsClosed() bool { |
| select { |
| case <-s.done: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| // net.ipv4.tcp_max_syn_backlog |
| // net.ipv4.tcp_timestamps |
| // net.ipv4.tcp_tw_recycle |
| func (s *server) listenTCP() error { |
| var ( |
| err error |
| streamListener net.Listener |
| ) |
| |
| if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { |
| streamListener, err = gxnet.ListenOnTCPRandomPort(s.addr) |
| if err != nil { |
| return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr) |
| } |
| } else { |
| if s.sslEnabled { |
| if sslConfig, buildTlsConfErr := s.tlsConfigBuilder.BuildTlsConfig(); buildTlsConfErr == nil && sslConfig != nil { |
| streamListener, err = tls.Listen("tcp", s.addr, sslConfig) |
| } |
| } else { |
| streamListener, err = net.Listen("tcp", s.addr) |
| } |
| if err != nil { |
| return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr) |
| } |
| } |
| |
| s.streamListener = streamListener |
| s.addr = s.streamListener.Addr().String() |
| |
| return nil |
| } |
| |
| func (s *server) listenUDP() error { |
| var ( |
| err error |
| localAddr *net.UDPAddr |
| pktListener *net.UDPConn |
| ) |
| |
| if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { |
| pktListener, err = gxnet.ListenOnUDPRandomPort(s.addr) |
| if err != nil { |
| return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr) |
| } |
| } else { |
| localAddr, err = net.ResolveUDPAddr("udp", s.addr) |
| if err != nil { |
| return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr) |
| } |
| pktListener, err = net.ListenUDP("udp", localAddr) |
| if err != nil { |
| return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr) |
| } |
| } |
| |
| s.pktListener = pktListener |
| s.addr = s.pktListener.LocalAddr().String() |
| |
| return nil |
| } |
| |
| // Listen announces on the local network address. |
| func (s *server) listen() error { |
| switch s.endPointType { |
| case TCP_SERVER, WS_SERVER, WSS_SERVER: |
| return perrors.WithStack(s.listenTCP()) |
| case UDP_ENDPOINT: |
| return perrors.WithStack(s.listenUDP()) |
| } |
| |
| return nil |
| } |
| |
| func (s *server) accept(newSession NewSessionCallback) (Session, error) { |
| conn, err := s.streamListener.Accept() |
| if err != nil { |
| return nil, perrors.WithStack(err) |
| } |
| if gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) { |
| log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String()) |
| return nil, perrors.WithStack(errSelfConnect) |
| } |
| |
| ss := newTCPSession(conn, s) |
| err = newSession(ss) |
| if err != nil { |
| conn.Close() |
| return nil, perrors.WithStack(err) |
| } |
| |
| return ss, nil |
| } |
| |
| func (s *server) runTCPEventLoop(newSession NewSessionCallback) { |
| s.wg.Add(1) |
| go func() { |
| defer s.wg.Done() |
| var ( |
| err error |
| client Session |
| delay time.Duration |
| ) |
| for { |
| if s.IsClosed() { |
| log.Infof("server{%s} stop accepting client connect request.", s.addr) |
| return |
| } |
| if delay != 0 { |
| <-gxtime.After(delay) |
| } |
| client, err = s.accept(newSession) |
| if err != nil { |
| if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Temporary() { |
| if delay == 0 { |
| delay = 5 * time.Millisecond |
| } else { |
| delay *= 2 |
| } |
| if max := 1 * time.Second; delay > max { |
| delay = max |
| } |
| continue |
| } |
| log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.WithStack(err)) |
| continue |
| } |
| delay = 0 |
| client.(*session).run() |
| } |
| }() |
| } |
| |
| func (s *server) runUDPEventLoop(newSession NewSessionCallback) { |
| s.wg.Add(1) |
| go func() { |
| defer s.wg.Done() |
| var ( |
| err error |
| conn *net.UDPConn |
| ss Session |
| ) |
| |
| conn = s.pktListener.(*net.UDPConn) |
| ss = newUDPSession(conn, s) |
| if err = newSession(ss); err != nil { |
| conn.Close() |
| panic(err.Error()) |
| } |
| ss.(*session).run() |
| }() |
| } |
| |
| type wsHandler struct { |
| http.ServeMux |
| server *server |
| newSession NewSessionCallback |
| upgrader websocket.Upgrader |
| } |
| |
| func newWSHandler(server *server, newSession NewSessionCallback) *wsHandler { |
| return &wsHandler{ |
| server: server, |
| newSession: newSession, |
| upgrader: websocket.Upgrader{ |
| // in default, ReadBufferSize & WriteBufferSize is 4k |
| // HandshakeTimeout: server.HTTPTimeout, |
| CheckOrigin: func(_ *http.Request) bool { return true }, // allow connections from any origin |
| EnableCompression: true, |
| }, |
| } |
| } |
| |
| func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) { |
| if r.Method != "GET" { |
| // w.WriteHeader(http.StatusMethodNotAllowed) |
| http.Error(w, "Method not allowed", 405) |
| return |
| } |
| |
| if s.server.IsClosed() { |
| http.Error(w, "HTTP server is closed(code:500-11).", 500) |
| log.Warnf("server{%s} stop acceptting client connect request.", s.server.addr) |
| return |
| } |
| |
| conn, err := s.upgrader.Upgrade(w, r, nil) |
| if err != nil { |
| log.Warnf("upgrader.Upgrader(http.Request{%#v}) = error:%+v", r, err) |
| return |
| } |
| if conn.RemoteAddr().String() == conn.LocalAddr().String() { |
| log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String()) |
| return |
| } |
| // conn.SetReadLimit(int64(handler.maxMsgLen)) |
| ss := newWSSession(conn, s.server) |
| err = s.newSession(ss) |
| if err != nil { |
| conn.Close() |
| log.Warnf("server{%s}.newSession(ss{%#v}) = err {%s}", s.server.addr, ss, err) |
| return |
| } |
| if ss.(*session).maxMsgLen > 0 { |
| conn.SetReadLimit(int64(ss.(*session).maxMsgLen)) |
| } |
| ss.(*session).run() |
| } |
| |
| // runWSEventLoop serve websocket client request |
| // @newSession: new websocket connection callback |
| func (s *server) runWSEventLoop(newSession NewSessionCallback) { |
| s.wg.Add(1) |
| go func() { |
| defer s.wg.Done() |
| var ( |
| err error |
| handler *wsHandler |
| server *http.Server |
| ) |
| handler = newWSHandler(s, newSession) |
| handler.HandleFunc(s.path, handler.serveWSRequest) |
| server = &http.Server{ |
| Addr: s.addr, |
| Handler: handler, |
| // ReadTimeout: server.HTTPTimeout, |
| // WriteTimeout: server.HTTPTimeout, |
| } |
| s.lock.Lock() |
| s.server = server |
| s.lock.Unlock() |
| err = server.Serve(s.streamListener) |
| if err != nil { |
| log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err)) |
| } |
| }() |
| } |
| |
| // serve websocket client request |
| // RunWSSEventLoop serve websocket client request |
| func (s *server) runWSSEventLoop(newSession NewSessionCallback) { |
| s.wg.Add(1) |
| go func() { |
| var ( |
| err error |
| certPem []byte |
| certificate tls.Certificate |
| certPool *x509.CertPool |
| config *tls.Config |
| handler *wsHandler |
| server *http.Server |
| ) |
| defer s.wg.Done() |
| |
| if certificate, err = tls.LoadX509KeyPair(s.cert, s.privateKey); err != nil { |
| panic(fmt.Sprintf("tls.LoadX509KeyPair(certs{%s}, privateKey{%s}) = err:%+v", |
| s.cert, s.privateKey, perrors.WithStack(err))) |
| } |
| config = &tls.Config{ |
| InsecureSkipVerify: true, // do not verify peer certs |
| ClientAuth: tls.NoClientCert, |
| NextProtos: []string{"http/1.1"}, |
| Certificates: []tls.Certificate{certificate}, |
| } |
| |
| if s.caCert != "" { |
| certPem, err = ioutil.ReadFile(s.caCert) |
| if err != nil { |
| panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.caCert, perrors.WithStack(err))) |
| } |
| certPool = x509.NewCertPool() |
| if ok := certPool.AppendCertsFromPEM(certPem); !ok { |
| panic("failed to parse root certificate file") |
| } |
| config.ClientCAs = certPool |
| config.ClientAuth = tls.RequireAndVerifyClientCert |
| config.InsecureSkipVerify = false |
| } |
| |
| handler = newWSHandler(s, newSession) |
| handler.HandleFunc(s.path, handler.serveWSRequest) |
| server = &http.Server{ |
| Addr: s.addr, |
| Handler: handler, |
| // ReadTimeout: server.HTTPTimeout, |
| // WriteTimeout: server.HTTPTimeout, |
| } |
| server.SetKeepAlivesEnabled(true) |
| s.lock.Lock() |
| s.server = server |
| s.lock.Unlock() |
| err = server.Serve(tls.NewListener(s.streamListener, config)) |
| if err != nil { |
| log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err)) |
| panic(err) |
| } |
| }() |
| } |
| |
| // RunEventLoop serves client request. |
| // @newSession: new connection callback |
| func (s *server) RunEventLoop(newSession NewSessionCallback) { |
| if err := s.listen(); err != nil { |
| panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err))) |
| } |
| |
| switch s.endPointType { |
| case TCP_SERVER: |
| s.runTCPEventLoop(newSession) |
| case UDP_ENDPOINT: |
| s.runUDPEventLoop(newSession) |
| case WS_SERVER: |
| s.runWSEventLoop(newSession) |
| case WSS_SERVER: |
| s.runWSSEventLoop(newSession) |
| default: |
| panic(fmt.Sprintf("illegal server type %s", s.endPointType.String())) |
| } |
| } |
| |
| func (s *server) Listener() net.Listener { |
| return s.streamListener |
| } |
| |
| func (s *server) PacketConn() net.PacketConn { |
| return s.pktListener |
| } |
| |
| func (s *server) Close() { |
| s.stop() |
| s.wg.Wait() |
| } |