Merge pull request #58 from cvictory/fix/support_return_length
Ftr: make WritePkg return the total encoded length and the length actually written
diff --git a/client_test.go b/client_test.go
index 8a8528c..5bdc58f 100644
--- a/client_test.go
+++ b/client_test.go
@@ -170,8 +170,10 @@
func TestUDPClient(t *testing.T) {
var (
- err error
- conn *net.UDPConn
+ err error
+ conn *net.UDPConn
+ sendLen int
+ totalLen int
)
func() {
ip := net.ParseIP("127.0.0.1")
@@ -205,10 +207,14 @@
assert.Equal(t, 1, msgHandler.SessionNumber())
ss := msgHandler.array[0]
- err = ss.WritePkg(nil, 0)
+ totalLen, sendLen, err = ss.WritePkg(nil, 0)
assert.NotNil(t, err)
- err = ss.WritePkg([]byte("hello"), 0)
+ assert.True(t, sendLen == 0)
+ assert.True(t, totalLen == 0)
+ totalLen, sendLen, err = ss.WritePkg([]byte("hello"), 0)
assert.NotNil(t, perrors.Cause(err))
+ assert.True(t, sendLen == 0)
+ assert.True(t, totalLen == 0)
l, err := ss.WriteBytes([]byte("hello"))
assert.Zero(t, l)
assert.NotNil(t, err)
@@ -240,9 +246,11 @@
assert.Nil(t, err)
beforeWritePkgNum := atomic.LoadUint32(&udpConn.writePkgNum)
- err = ss.WritePkg(udpCtx, 0)
+ totalLen, sendLen, err = ss.WritePkg(udpCtx, 0)
assert.Equal(t, beforeWritePkgNum+1, atomic.LoadUint32(&udpConn.writePkgNum))
assert.Nil(t, err)
+ assert.True(t, sendLen == 0)
+ assert.True(t, totalLen == 0)
clt.Close()
assert.True(t, clt.IsClosed())
diff --git a/demo/hello/hello.go b/demo/hello/hello.go
index 917b911..72d6d7f 100644
--- a/demo/hello/hello.go
+++ b/demo/hello/hello.go
@@ -31,7 +31,7 @@
go func() {
echoTimes := 10
for i := 0; i < echoTimes; i++ {
- err := ss.WritePkg("hello", WritePkgTimeout)
+ _, _, err := ss.WritePkg("hello", WritePkgTimeout)
if err != nil {
log.Infof("session.WritePkg(session{%s}, error{%v}", ss.Stat(), err)
ss.Close()
diff --git a/getty.go b/getty.go
index 6622329..cc6afe2 100644
--- a/getty.go
+++ b/getty.go
@@ -171,7 +171,10 @@
// the Writer will invoke this function. Pls attention that if timeout is less than 0, WritePkg will send @pkg asap.
// for udp session, the first parameter should be UDPContext.
- WritePkg(pkg interface{}, timeout time.Duration) error
+ // totalBytesLength: length of the byte stream produced by encoding @pkg.
+ // sendBytesLength: number of those bytes that were sent out successfully.
+ // err: non-nil on illegal input (e.g. a nil @pkg), a closed session, an encoding failure, or a write error.
+ WritePkg(pkg interface{}, timeout time.Duration) (totalBytesLength int, sendBytesLength int, err error)
WriteBytes([]byte) (int, error)
WriteBytesArray(...[]byte) (int, error)
Close()
diff --git a/session.go b/session.go
index 2334d24..45e53c6 100644
--- a/session.go
+++ b/session.go
@@ -347,12 +347,12 @@
s.name, s.EndPoint().EndPointType(), s.ID(), s.LocalAddr(), s.RemoteAddr())
}
-func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error {
+func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, error) {
if pkg == nil {
- return fmt.Errorf("@pkg is nil")
+ return 0, 0, fmt.Errorf("@pkg is nil")
}
if s.IsClosed() {
- return ErrSessionClosed
+ return 0, 0, ErrSessionClosed
}
defer func() {
@@ -367,7 +367,7 @@
pkgBytes, err := s.writer.Write(s, pkg)
if err != nil {
log.Warnf("%s, [session.WritePkg] session.writer.Write(@pkg:%#v) = error:%+v", s.Stat(), pkg, err)
- return perrors.WithStack(err)
+ return len(pkgBytes), 0, perrors.WithStack(err)
}
var udpCtxPtr *UDPContext
if udpCtx, ok := pkg.(UDPContext); ok {
@@ -384,13 +384,13 @@
if 0 < timeout {
s.Connection.SetWriteTimeout(timeout)
}
- _, err = s.Connection.send(pkg)
+ var succssCount int
+ succssCount, err = s.Connection.send(pkg)
if err != nil {
log.Warnf("%s, [session.WritePkg] @s.Connection.Write(pkg:%#v) = err:%+v", s.Stat(), pkg, err)
- return perrors.WithStack(err)
+ return len(pkgBytes), succssCount, perrors.WithStack(err)
}
-
- return nil
+ return len(pkgBytes), succssCount, nil
}
// for codecs