fix(#114): fix transfering large body failed (#124)
* fix(#114): fix transfering large body failed
diff --git a/internal/http/request.go b/internal/http/request.go
index 561e1b7..5c00172 100644
--- a/internal/http/request.go
+++ b/internal/http/request.go
@@ -366,19 +366,19 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err := c.Write(header)
+ n, err := util.WriteBytes(c, header, len(header))
if err != nil {
util.WriteErr(n, err)
return nil, common.ErrConnClosed
}
- n, err = c.Write(out)
+ n, err = util.WriteBytes(c, out, size)
if err != nil {
util.WriteErr(n, err)
return nil, common.ErrConnClosed
}
- n, err = c.Read(header)
+ n, err = util.ReadBytes(c, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return nil, common.ErrConnClosed
}
@@ -390,7 +390,7 @@
log.Infof("receive rpc type: %d data length: %d", ty, length)
buf := make([]byte, length)
- n, err = c.Read(buf)
+ n, err = util.ReadBytes(c, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return nil, common.ErrConnClosed
}
diff --git a/internal/http/request_test.go b/internal/http/request_test.go
index 51fe954..d6b58ed 100644
--- a/internal/http/request_test.go
+++ b/internal/http/request_test.go
@@ -305,7 +305,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -316,7 +316,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -336,13 +336,13 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err = sc.Write(header)
+ n, err = util.WriteBytes(sc, header, len(header))
if err != nil {
util.WriteErr(n, err)
return
}
- n, err = sc.Write(out)
+ n, err = util.WriteBytes(sc, out, size)
if err != nil {
util.WriteErr(n, err)
return
@@ -365,7 +365,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -385,7 +385,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -396,7 +396,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -458,7 +458,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -469,7 +469,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -488,13 +488,13 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err = sc.Write(header)
+ n, err = util.WriteBytes(sc, header, len(header))
if err != nil {
util.WriteErr(n, err)
return
}
- n, err = sc.Write(out)
+ n, err = util.WriteBytes(sc, out, size)
if err != nil {
util.WriteErr(n, err)
return
diff --git a/internal/http/response.go b/internal/http/response.go
index 7c97f42..7d7870e 100644
--- a/internal/http/response.go
+++ b/internal/http/response.go
@@ -73,19 +73,19 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err := c.Write(header)
+ n, err := util.WriteBytes(c, header, len(header))
if err != nil {
util.WriteErr(n, err)
return nil, common.ErrConnClosed
}
- n, err = c.Write(out)
+ n, err = util.WriteBytes(c, out, size)
if err != nil {
util.WriteErr(n, err)
return nil, common.ErrConnClosed
}
- n, err = c.Read(header)
+ n, err = util.ReadBytes(c, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return nil, common.ErrConnClosed
}
@@ -97,7 +97,7 @@
log.Infof("receive rpc type: %d data length: %d", ty, length)
buf := make([]byte, length)
- n, err = c.Read(buf)
+ n, err = util.ReadBytes(c, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return nil, common.ErrConnClosed
}
diff --git a/internal/http/response_test.go b/internal/http/response_test.go
index 128bf73..4062dd6 100644
--- a/internal/http/response_test.go
+++ b/internal/http/response_test.go
@@ -202,7 +202,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -213,7 +213,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -233,13 +233,13 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err = sc.Write(header)
+ n, err = util.WriteBytes(sc, header, len(header))
if err != nil {
util.WriteErr(n, err)
return
}
- n, err = sc.Write(out)
+ n, err = util.WriteBytes(sc, out, size)
if err != nil {
util.WriteErr(n, err)
return
@@ -262,7 +262,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -282,7 +282,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -293,7 +293,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -314,7 +314,7 @@
go func() {
header := make([]byte, util.HeaderLen)
- n, err := sc.Read(header)
+ n, err := util.ReadBytes(sc, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
return
}
@@ -325,7 +325,7 @@
length := binary.BigEndian.Uint32(header)
buf := make([]byte, length)
- n, err = sc.Read(buf)
+ n, err = util.ReadBytes(sc, buf, int(length))
if util.ReadErr(n, err, int(length)) {
return
}
@@ -344,13 +344,13 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = util.RPCExtraInfo
- n, err = sc.Write(header)
+ n, err = util.WriteBytes(sc, header, len(header))
if err != nil {
util.WriteErr(n, err)
return
}
- n, err = sc.Write(out)
+ n, err = util.WriteBytes(sc, out, size)
if err != nil {
util.WriteErr(n, err)
return
diff --git a/internal/server/server.go b/internal/server/server.go
index 5ac291f..e8a0d8c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -117,7 +117,7 @@
header := make([]byte, util.HeaderLen)
for {
- n, err := c.Read(header)
+ n, err := util.ReadBytes(c, header, util.HeaderLen)
if util.ReadErr(n, err, util.HeaderLen) {
break
}
@@ -131,7 +131,7 @@
log.Infof("receive rpc type: %d data length: %d", ty, length)
buf := make([]byte, length)
- n, err = c.Read(buf)
+ n, err = util.ReadBytes(c, buf, int(length))
if util.ReadErr(n, err, int(length)) {
break
}
@@ -142,13 +142,13 @@
binary.BigEndian.PutUint32(header, uint32(size))
header[0] = respTy
- n, err = c.Write(header)
+ n, err = util.WriteBytes(c, header, len(header))
if err != nil {
util.WriteErr(n, err)
break
}
- n, err = c.Write(out)
+ n, err = util.WriteBytes(c, out, size)
if err != nil {
util.WriteErr(n, err)
break
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index b854973..a976b18 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -121,7 +121,7 @@
conn, err := net.DialTimeout("unix", addr[len("unix:"):], 1*time.Second)
assert.NotNil(t, conn, err)
defer conn.Close()
- conn.Write(c.header)
+ util.WriteBytes(conn, c.header, len(c.header))
}
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
diff --git a/internal/util/msg.go b/internal/util/msg.go
index f95e2d9..78121ff 100644
--- a/internal/util/msg.go
+++ b/internal/util/msg.go
@@ -20,6 +20,7 @@
import (
"fmt"
"io"
+ "net"
flatbuffers "github.com/google/flatbuffers/go"
@@ -65,3 +66,27 @@
log.Errorf("write: %s", err)
}
}
+
+func ReadBytes(c net.Conn, b []byte, n int) (int, error) {
+ l := 0
+ for l < n {
+ tmp, err := c.Read(b[l:])
+ if err != nil {
+ return l + tmp, err
+ }
+ l += tmp
+ }
+ return l, nil
+}
+
+func WriteBytes(c net.Conn, b []byte, n int) (int, error) {
+ l := 0
+ for l < n {
+ tmp, err := c.Write(b[l:])
+ if err != nil {
+ return l + tmp, err
+ }
+ l += tmp
+ }
+ return l, nil
+}
diff --git a/internal/util/msg_test.go b/internal/util/msg_test.go
new file mode 100644
index 0000000..24a526e
--- /dev/null
+++ b/internal/util/msg_test.go
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestReadAndWriteBytes(t *testing.T) {
+ path := "/tmp/test.sock"
+ server, err := net.Listen("unix", path)
+ assert.NoError(t, err)
+ defer server.Close()
+
+ // transfer large enough data
+ n := 10000000
+
+ const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ in := make([]byte, n)
+ for i := range in {
+ in[i] = letterBytes[rand.Intn(len(letterBytes))]
+ }
+
+ go func() {
+ client, err := net.DialTimeout("unix", path, 1*time.Second)
+ assert.NoError(t, err)
+ defer client.Close()
+ WriteBytes(client, in, len(in))
+ }()
+
+ fd, err := server.Accept()
+ assert.NoError(t, err)
+ out := make([]byte, n)
+ ReadBytes(fd, out, n)
+ assert.Equal(t, in, out)
+}