| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package getty |
| |
| import ( |
| "bytes" |
| "net" |
| "net/http" |
| "os" |
| "strconv" |
| "sync" |
| "testing" |
| "time" |
| ) |
| |
| import ( |
| "github.com/stretchr/testify/assert" |
| ) |
| |
| type PackageHandler struct{} |
| |
| func (h *PackageHandler) Read(ss Session, data []byte) (interface{}, int, error) { |
| return nil, 0, nil |
| } |
| |
| func (h *PackageHandler) Write(ss Session, pkg interface{}) ([]byte, error) { |
| return nil, nil |
| } |
| |
| type MessageHandler struct { |
| lock sync.Mutex |
| array []Session |
| } |
| |
| func newMessageHandler() *MessageHandler { |
| return &MessageHandler{} |
| } |
| |
| func (h *MessageHandler) SessionNumber() int { |
| h.lock.Lock() |
| connNum := len(h.array) |
| h.lock.Unlock() |
| |
| return connNum |
| } |
| |
| func (h *MessageHandler) OnOpen(session Session) error { |
| h.lock.Lock() |
| defer h.lock.Unlock() |
| h.array = append(h.array, session) |
| |
| return nil |
| } |
| func (h *MessageHandler) OnError(session Session, err error) {} |
| func (h *MessageHandler) OnClose(session Session) {} |
| func (h *MessageHandler) OnMessage(session Session, pkg interface{}) {} |
| func (h *MessageHandler) OnCron(session Session) {} |
| |
// Package is an empty test payload type; its String, Marshal and Unmarshal
// methods are all stubs that do no work.
type Package struct{}

// String returns the empty string.
func (Package) String() string { return "" }

// Marshal returns a nil buffer and no error.
func (Package) Marshal() (*bytes.Buffer, error) {
	return nil, nil
}

// Unmarshal consumes nothing: it reports 0 bytes read and no error.
func (*Package) Unmarshal(*bytes.Buffer) (int, error) {
	return 0, nil
}
| |
| func newSessionCallback(session Session, handler *MessageHandler) error { |
| var pkgHandler PackageHandler |
| session.SetName("hello-client-session") |
| session.SetMaxMsgLen(128 * 1024) // max message package length 128k |
| session.SetPkgHandler(&pkgHandler) |
| session.SetEventListener(handler) |
| session.SetReadTimeout(3e9) |
| session.SetWriteTimeout(3e9) |
| session.SetCronPeriod((int)(30e9 / 1e6)) |
| session.SetWaitTime(3e9) |
| |
| return nil |
| } |
| |
| func TestTCPClient(t *testing.T) { |
| listenLocalServer := func() (net.Listener, error) { |
| listener, err := net.Listen("tcp", ":0") |
| if err != nil { |
| return nil, err |
| } |
| |
| go http.Serve(listener, nil) |
| return listener, nil |
| } |
| |
| listener, err := listenLocalServer() |
| assert.Nil(t, err) |
| assert.NotNil(t, listener) |
| |
| addr := listener.Addr().(*net.TCPAddr) |
| t.Logf("server addr: %v", addr) |
| clt := NewTCPClient( |
| WithServerAddress(addr.String()), |
| WithReconnectInterval(5e8), |
| WithConnectionNumber(1), |
| ) |
| assert.NotNil(t, clt) |
| assert.True(t, clt.ID() > 0) |
| // assert.Equal(t, clt.endPointType, TCP_CLIENT) |
| |
| var msgHandler MessageHandler |
| cb := func(session Session) error { |
| return newSessionCallback(session, &msgHandler) |
| } |
| |
| clt.RunEventLoop(cb) |
| time.Sleep(1e9) |
| |
| assert.Equal(t, 1, msgHandler.SessionNumber()) |
| ss := msgHandler.array[0] |
| ss.SetCompressType(CompressNone) |
| conn := ss.(*session).Connection.(*gettyTCPConn) |
| assert.True(t, conn.compress == CompressNone) |
| beforeWriteBytes := conn.writeBytes |
| beforeWritePkgNum := conn.writePkgNum |
| l, err := conn.send([]byte("hello")) |
| assert.Nil(t, err) |
| assert.True(t, l == 5) |
| beforeWritePkgNum.Add(1) |
| beforeWriteBytes.Add(5) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| l, err = ss.WriteBytes([]byte("hello")) |
| assert.Nil(t, err) |
| assert.True(t, l == 5) |
| beforeWriteBytes.Add(5) |
| beforeWritePkgNum.Add(1) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| var pkgs [][]byte |
| pkgs = append(pkgs, []byte("hello"), []byte("hello")) |
| l, err = conn.send(pkgs) |
| assert.Nil(t, err) |
| assert.True(t, l == 10) |
| beforeWritePkgNum.Add(2) |
| beforeWriteBytes.Add(10) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| ss.SetCompressType(CompressSnappy) |
| l, err = ss.WriteBytesArray(pkgs...) |
| assert.Nil(t, err) |
| assert.True(t, l == 10) |
| beforeWritePkgNum.Add(2) |
| beforeWriteBytes.Add(10) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| assert.True(t, conn.compress == CompressSnappy) |
| |
| batchSize := 128 * 1023 |
| source := make([]byte, batchSize) |
| for i := 0; i < batchSize; i++ { |
| source[i] = 't' |
| } |
| l, err = ss.WriteBytes(source) |
| assert.Nil(t, err) |
| assert.True(t, l == batchSize) |
| beforeWriteBytes.Add(uint32(batchSize)) |
| beforeWritePkgNum.Add(uint32(batchSize/16/1024) + 1) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| |
| batchSize = 32 * 1024 |
| source = make([]byte, batchSize) |
| for i := 0; i < batchSize; i++ { |
| source[i] = 't' |
| } |
| l, err = ss.WriteBytes(source) |
| assert.Nil(t, err) |
| assert.True(t, l == batchSize) |
| beforeWriteBytes.Add(uint32(batchSize)) |
| beforeWritePkgNum.Add(2) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| |
| clt.Close() |
| assert.True(t, clt.IsClosed()) |
| } |
| |
| func TestUDPClient(t *testing.T) { |
| var ( |
| err error |
| conn *net.UDPConn |
| sendLen int |
| totalLen int |
| ) |
| func() { |
| ip := net.ParseIP("127.0.0.1") |
| srcAddr := &net.UDPAddr{IP: ip, Port: 0} |
| conn, err = net.ListenUDP("udp", srcAddr) |
| assert.Nil(t, err) |
| assert.NotNil(t, conn) |
| }() |
| defer conn.Close() |
| |
| addr := conn.LocalAddr() |
| t.Logf("server addr: %v", addr) |
| clt := NewUDPClient( |
| WithServerAddress(addr.String()), |
| WithReconnectInterval(5e8), |
| WithConnectionNumber(1), |
| ) |
| assert.NotNil(t, clt) |
| assert.True(t, clt.ID() > 0) |
| // assert.Equal(t, clt.endPointType, UDP_CLIENT) |
| |
| var msgHandler MessageHandler |
| cb := func(session Session) error { |
| return newSessionCallback(session, &msgHandler) |
| } |
| |
| clt.RunEventLoop(cb) |
| time.Sleep(1e9) |
| |
| assert.Equal(t, 1, msgHandler.SessionNumber()) |
| ss := msgHandler.array[0] |
| totalLen, sendLen, err = ss.WritePkg(nil, 0) |
| assert.NotNil(t, err) |
| assert.True(t, sendLen == 0) |
| assert.True(t, totalLen == 0) |
| totalLen, sendLen, err = ss.WritePkg([]byte("hello"), 0) |
| assert.NotNil(t, err) |
| assert.True(t, sendLen == 0) |
| assert.True(t, totalLen == 0) |
| l, err := ss.WriteBytes([]byte("hello")) |
| assert.Zero(t, l) |
| assert.NotNil(t, err) |
| l, err = ss.WriteBytesArray([]byte("hello")) |
| assert.Zero(t, l) |
| assert.NotNil(t, err) |
| l, err = ss.WriteBytesArray([]byte("hello"), []byte("world")) |
| assert.Zero(t, l) |
| assert.NotNil(t, err) |
| ss.SetCompressType(CompressNone) |
| host, port, _ := net.SplitHostPort(addr.String()) |
| if len(host) < 8 { |
| host = "127.0.0.1" |
| } |
| remotePort, _ := strconv.Atoi(port) |
| serverAddr := net.UDPAddr{IP: net.ParseIP(host), Port: remotePort} |
| udpCtx := UDPContext{ |
| Pkg: "hello", |
| PeerAddr: &serverAddr, |
| } |
| t.Logf("udp context:%s", udpCtx) |
| udpConn := ss.(*session).Connection.(*gettyUDPConn) |
| _, err = udpConn.send(udpCtx) |
| assert.NotNil(t, err) |
| udpCtx.Pkg = []byte("hello") |
| beforeWriteBytes := udpConn.writeBytes |
| _, err = udpConn.send(udpCtx) |
| beforeWriteBytes.Add(5) |
| assert.Equal(t, beforeWriteBytes, udpConn.writeBytes) |
| assert.Nil(t, err) |
| |
| beforeWritePkgNum := udpConn.writePkgNum |
| totalLen, sendLen, err = ss.WritePkg(udpCtx, 0) |
| beforeWritePkgNum.Add(1) |
| assert.Equal(t, beforeWritePkgNum, udpConn.writePkgNum) |
| assert.Nil(t, err) |
| assert.True(t, sendLen == 0) |
| assert.True(t, totalLen == 0) |
| |
| clt.Close() |
| assert.True(t, clt.IsClosed()) |
| msgHandler.array[0].Reset() |
| assert.Nil(t, msgHandler.array[0].Conn()) |
| // ss.WritePkg([]byte("hello"), 0) |
| } |
| |
| func TestNewWSClient(t *testing.T) { |
| var ( |
| server Server |
| serverMsgHandler MessageHandler |
| ) |
| addr := "127.0.0.1:65000" |
| path := "/hello" |
| func() { |
| server = NewWSServer( |
| WithLocalAddress(addr), |
| WithWebsocketServerPath(path), |
| ) |
| newServerSession := func(session Session) error { |
| return newSessionCallback(session, &serverMsgHandler) |
| } |
| go server.RunEventLoop(newServerSession) |
| }() |
| time.Sleep(1e9) |
| |
| client := NewWSClient( |
| WithServerAddress("ws://"+addr+path), |
| WithConnectionNumber(1), |
| ) |
| |
| var msgHandler MessageHandler |
| cb := func(session Session) error { |
| return newSessionCallback(session, &msgHandler) |
| } |
| |
| client.RunEventLoop(cb) |
| time.Sleep(1e9) |
| |
| assert.Equal(t, 1, msgHandler.SessionNumber()) |
| ss := msgHandler.array[0] |
| ss.SetCompressType(CompressNone) |
| conn := ss.(*session).Connection.(*gettyWSConn) |
| assert.True(t, conn.compress == CompressNone) |
| err := conn.handlePing("hello") |
| assert.Nil(t, err) |
| l, err := conn.send("hello") |
| assert.NotNil(t, err) |
| assert.True(t, l == 0) |
| beforeWriteBytes := conn.writeBytes |
| _, err = conn.send([]byte("hello")) |
| assert.Nil(t, err) |
| beforeWriteBytes.Add(5) |
| assert.Equal(t, beforeWriteBytes, conn.writeBytes) |
| beforeWritePkgNum := conn.writePkgNum |
| l, err = ss.WriteBytes([]byte("hello")) |
| assert.Nil(t, err) |
| assert.True(t, l == 5) |
| beforeWritePkgNum.Add(1) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| l, err = ss.WriteBytesArray([]byte("hello"), []byte("hello")) |
| assert.Nil(t, err) |
| assert.True(t, l == 10) |
| beforeWritePkgNum.Add(2) |
| assert.Equal(t, beforeWritePkgNum, conn.writePkgNum) |
| err = conn.writePing() |
| assert.Nil(t, err) |
| |
| ss.SetReader(nil) |
| assert.Nil(t, ss.(*session).reader) |
| ss.SetWriter(nil) |
| assert.Nil(t, ss.(*session).writer) |
| assert.Nil(t, ss.(*session).GetAttribute("hello")) |
| |
| client.Close() |
| assert.True(t, client.IsClosed()) |
| server.Close() |
| assert.True(t, server.IsClosed()) |
| } |
| |
// TLS material for TestNewWSSClient. At test time the PEM blocks are written
// with DownloadFile to the *File paths declared here, then removed again.
var (
	// WssServerCRT is the PEM certificate the WSS test server presents.
	// NOTE(review): it appears to be a self-signed "Acme Co" certificate of
	// the kind produced by Go's crypto/tls generate_cert.go, presumably valid
	// for localhost / 127.0.0.1; confirm before changing the test address.
	WssServerCRT = []byte(`-----BEGIN CERTIFICATE-----
MIICHjCCAYegAwIBAgIQKpKqamBqmZ0hfp8sYb4uNDANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQC5Nxsk6WjeaYazRYiGxHZ5G3FXSlSjV7lZeebItdEPzO8kVPIGCSTy/M5X
Nnpp3uVDFXQub0/O5t9Y6wcuqpUGMOV+XL7MZqSZlodXm0XhNYzCAjZ+URNjTHGP
NXIqdDEG5Ba8SXMOfY6H97+QxugZoAMFZ+N83ggr12IYNO/FbQIDAQABo3MwcTAO
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
AwEB/zA5BgNVHREEMjAwgglsb2NhbGhvc3SCC2V4YW1wbGUuY29thwR/AAABhxAA
AAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4GBAE5dr9q7ORmKZ7yZqeSL
305armc13A7UxffUajeJFujpl2jOqnb5PuKJ7fn5HQKGB0qSq3IHsFua2WONXcTW
Vn4gS0k50IaDpW+yl+ArIo0QwbjPIAcFysX10p9dVO7A1uEpHbRDzefem6r9uVGk
i7dOLEoC8hkfk6nJsNEIEqu6
-----END CERTIFICATE-----`)
	// WssServerCRTFile is the path WssServerCRT is written to for the test.
	// NOTE(review): a fixed /tmp path is not portable to every OS and is
	// shared between concurrent runs.
	WssServerCRTFile = "/tmp/server.crt"
	// WssServerKEY is the RSA private key the WSS test server pairs with
	// WssServerCRT.
	WssServerKEY = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQC5Nxsk6WjeaYazRYiGxHZ5G3FXSlSjV7lZeebItdEPzO8kVPIG
CSTy/M5XNnpp3uVDFXQub0/O5t9Y6wcuqpUGMOV+XL7MZqSZlodXm0XhNYzCAjZ+
URNjTHGPNXIqdDEG5Ba8SXMOfY6H97+QxugZoAMFZ+N83ggr12IYNO/FbQIDAQAB
AoGBAJgvuXQY/fxSxUWkysvBvn9Al17cSrN0r23gBkvBaakMASvfSIbBGMU4COwM
bYV0ivkWNcK539/oQHk1lU85Bv0K9V9wtuFrYW0mN3TU6jnl6eEnzW5oy0Z9TwyY
wuGQOSXGr/aDVu8Wr7eOmSvn6j8rWO2dSMHCllJnSBoqQ1aZAkEA5YQspoMhUaq+
kC53GTgMhotnmK3fWfWKrlLf0spsaNl99W3+plwqxnJbye+5uEutRR1PWSWCCKq5
bN9veOXViwJBAM6WS5aeKO/JX09O0Ang9Y0+atMKO0YjX6fNFE2UJ5Ewzyr4DMZK
TmBpyzm4x/GhV9ukqcDcd3dNlUOtgRqY3+cCQQDCGmssk1+dUpqBE1rT8CvfqYv+
eqWWzerwDNSPz3OppK4630Bqby4Z0GNCP8RAUXgDKIuPqAH11HSm17vNcgqLAkA8
8FCzyUvCD+CxgEoV3+oPFA5m2mnJsr2QvgnzKHTTe1ZhEnKSO3ELN6nfCQbR3AoS
nGwGnAIRiy0wnYmr0tSZAkEAsWFm/D7sTQhX4Qnh15ZDdUn1WSWjBZevUtJnQcpx
TjihZq2sd3uK/XrzG+w7B+cPZlrZtQ94sDSVQwWl/sxB4A==
-----END RSA PRIVATE KEY-----`)
	// WssServerKEYFile is the path WssServerKEY is written to for the test.
	WssServerKEYFile = "/tmp/server.key"
	// WssClientCRT is the root certificate the WSS test client trusts. Its
	// PEM content is identical to WssServerCRT, so the client trusts the
	// server's self-signed certificate directly.
	WssClientCRT = []byte(`-----BEGIN CERTIFICATE-----
MIICHjCCAYegAwIBAgIQKpKqamBqmZ0hfp8sYb4uNDANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQC5Nxsk6WjeaYazRYiGxHZ5G3FXSlSjV7lZeebItdEPzO8kVPIGCSTy/M5X
Nnpp3uVDFXQub0/O5t9Y6wcuqpUGMOV+XL7MZqSZlodXm0XhNYzCAjZ+URNjTHGP
NXIqdDEG5Ba8SXMOfY6H97+QxugZoAMFZ+N83ggr12IYNO/FbQIDAQABo3MwcTAO
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
AwEB/zA5BgNVHREEMjAwgglsb2NhbGhvc3SCC2V4YW1wbGUuY29thwR/AAABhxAA
AAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4GBAE5dr9q7ORmKZ7yZqeSL
305armc13A7UxffUajeJFujpl2jOqnb5PuKJ7fn5HQKGB0qSq3IHsFua2WONXcTW
Vn4gS0k50IaDpW+yl+ArIo0QwbjPIAcFysX10p9dVO7A1uEpHbRDzefem6r9uVGk
i7dOLEoC8hkfk6nJsNEIEqu6
-----END CERTIFICATE-----`)
	// WssClientCRTFile is the path WssClientCRT is written to for the test.
	WssClientCRTFile = "/tmp/client.crt"
)
| |
// DownloadFile writes content to the file at filepath, creating it or
// truncating an existing file. Despite its name it performs no network I/O.
//
// It returns the first error from creating, writing or closing the file.
// The Close error is reported because, for a written file, a failed Close can
// mean the data never reached the file.
func DownloadFile(filepath string, content []byte) (err error) {
	out, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		// Surface a Close failure unless an earlier error is already being returned.
		if cerr := out.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = out.Write(content)
	return err
}
| |
| func TestNewWSSClient(t *testing.T) { |
| var ( |
| err error |
| server Server |
| serverMsgHandler MessageHandler |
| ) |
| |
| os.Remove(WssServerCRTFile) |
| err = DownloadFile(WssServerCRTFile, WssServerCRT) |
| assert.Nil(t, err) |
| defer os.Remove(WssServerCRTFile) |
| |
| os.Remove(WssServerKEYFile) |
| err = DownloadFile(WssServerKEYFile, WssServerKEY) |
| assert.Nil(t, err) |
| defer os.Remove(WssServerKEYFile) |
| |
| os.Remove(WssClientCRTFile) |
| err = DownloadFile(WssClientCRTFile, WssClientCRT) |
| assert.Nil(t, err) |
| defer os.Remove(WssClientCRTFile) |
| |
| addr := "127.0.0.1:63450" |
| path := "/hello" |
| func() { |
| server = NewWSSServer( |
| WithLocalAddress(addr), |
| WithWebsocketServerPath(path), |
| WithWebsocketServerCert(WssServerCRTFile), |
| WithWebsocketServerPrivateKey(WssServerKEYFile), |
| ) |
| newServerSession := func(session Session) error { |
| return newSessionCallback(session, &serverMsgHandler) |
| } |
| go server.RunEventLoop(newServerSession) |
| }() |
| time.Sleep(1e9) |
| |
| client := NewWSSClient( |
| WithServerAddress("wss://"+addr+path), |
| WithConnectionNumber(1), |
| WithRootCertificateFile(WssClientCRTFile), |
| ) |
| |
| var msgHandler MessageHandler |
| cb := func(session Session) error { |
| return newSessionCallback(session, &msgHandler) |
| } |
| |
| client.RunEventLoop(cb) |
| time.Sleep(1e9) |
| |
| assert.Equal(t, 1, msgHandler.SessionNumber()) |
| client.Close() |
| assert.True(t, client.IsClosed()) |
| assert.False(t, server.IsClosed()) |
| // time.Sleep(1000e9) |
| // server.Close() |
| // assert.True(t, server.IsClosed()) |
| } |