refactor string decode algorithm
diff --git a/string.go b/string.go
index 8702b7d..e01e83e 100644
--- a/string.go
+++ b/string.go
@@ -88,59 +88,6 @@
return 3
}
-func decodeUcs4Rune(r *bufio.Reader) (c rune, cLen, bLen int, err error) {
- c1, n1, err1 := decodeUcs2Rune(r)
- if err1 != nil {
- return c1, 0, n1, err1
- }
-
- if c1 >= 0xD800 && c1 <= 0xDBFF {
- c2, n2, err2 := decodeUcs2Rune(r)
- if err2 != nil {
- return c2, 0, n2, err2
- }
-
- c := (c1-0xD800)<<10 + (c2 - 0xDC00) + 0x10000
- return c, 2, n1 + n2, nil
- }
-
- return c1, 1, n1, nil
-}
-
-func decodeUcs2Rune(r *bufio.Reader) (rune, int, error) {
- ch, err := r.ReadByte()
- if err != nil {
- return utf8.RuneError, 1, err
- }
-
- if ch < 0x80 {
- return rune(ch), 1, nil
- }
-
- if (ch & 0xe0) == 0xc0 {
- ch1, err := r.ReadByte()
- if err != nil {
- return utf8.RuneError, 2, err
- }
- return rune(((uint32(ch) & 0x1f) << 6) + (uint32(ch1) & 0x3f)), 2, nil
- }
-
- if (ch & 0xf0) == 0xe0 {
- ch1, err := r.ReadByte()
- if err != nil {
- return utf8.RuneError, 2, err
- }
- ch2, err := r.ReadByte()
- if err != nil {
- return utf8.RuneError, 3, err
- }
- c := ((uint32(ch) & 0x0f) << 12) + ((uint32(ch1) & 0x3f) << 6) + (uint32(ch2) & 0x3f)
- return rune(c), 3, nil
- }
-
- return utf8.RuneError, 0, fmt.Errorf("bad utf-8 encoding at %x", ch)
-}
-
// # UTF-8 encoded character string split into 64k chunks
// ::= x52 b1 b0 <utf8-data> string # non-final chunk
// ::= 'S' b1 b0 <utf8-data> # string of length 0-65535
@@ -255,11 +202,8 @@
func (d *Decoder) decString(flag int32) (string, error) {
var (
- tag byte
- charTotal int
- last bool
- s string
- r rune
+ tag byte
+ s string
)
if flag != TAG_READ {
@@ -308,75 +252,160 @@
(tag >= 0x30 && tag <= 0x33) ||
(tag == BC_STRING_CHUNK || tag == BC_STRING) {
- if tag == BC_STRING_CHUNK {
- last = false
- } else {
- last = true
+ if tag != BC_STRING_CHUNK {
+ data, err := d.readStringChunkData(tag)
+ if err != nil {
+ return "", err
+ }
+ return *(*string)(unsafe.Pointer(&data)), nil
}
- l, err := d.getStringLength(tag)
- if err != nil {
- return s, perrors.WithStack(err)
- }
- charTotal = l
- charCount := 0
-
- runeData := make([]rune, charTotal)
- runeIndex := 0
-
- byteCount := 0
- byteLen := 0
- charLen := 0
+ var chunkDataSlice [][]byte
+ dataLength := 0
for {
- if charCount == charTotal {
- if last {
- return string(runeData[:runeIndex]), nil
- }
-
- b, _ := d.ReadByte()
- switch {
- case (tag >= BC_STRING_DIRECT && tag <= STRING_DIRECT_MAX) ||
- (tag >= 0x30 && tag <= 0x33) ||
- (tag == BC_STRING_CHUNK || tag == BC_STRING):
-
- if b == BC_STRING_CHUNK {
- last = false
- } else {
- last = true
- }
-
- l, err := d.getStringLength(b)
- if err != nil {
- return s, perrors.WithStack(err)
- }
- charTotal += l
- bs := make([]rune, charTotal)
- copy(bs, runeData)
- runeData = bs
-
- default:
- return s, perrors.New("expect string tag")
- }
- }
-
- r, charLen, byteLen, err = decodeUcs4Rune(d.reader)
+ data, err := d.readStringChunkData(tag)
if err != nil {
- if err == io.EOF {
- break
- }
- return s, perrors.WithStack(err)
+ return "", err
}
- runeData[runeIndex] = r
- runeIndex++
+ chunkDataSlice = append(chunkDataSlice, data)
+ dataLength += len(data)
- charCount += charLen
- byteCount += byteLen
+ // last chunk
+ if tag != BC_STRING_CHUNK {
+ allData := make([]byte, dataLength)
+ index := 0
+ for _, b := range chunkDataSlice {
+ copy(allData[index:], b)
+ index += len(b)
+ }
+ return *(*string)(unsafe.Pointer(&allData)), nil
+ }
+
+ // read next string chunk tag
+ tag, _ = d.ReadByte()
+ switch {
+ case (tag >= BC_STRING_DIRECT && tag <= STRING_DIRECT_MAX) ||
+ (tag >= 0x30 && tag <= 0x33) ||
+ (tag == BC_STRING_CHUNK || tag == BC_STRING):
+
+ default:
+ return s, perrors.New("expect string tag")
+ }
}
- return string(runeData[:runeIndex]), nil
}
return s, perrors.Errorf("unknown string tag %#x\n", tag)
}
+
+// readStringChunkData reads the payload of one string chunk, whose length in characters comes from getStringLength(tag), and returns it as a UTF-8 byte slice.
+func (d *Decoder) readStringChunkData(tag byte) ([]byte, error) {
+ charTotal, err := d.getStringLength(tag)
+ if err != nil {
+ return nil, perrors.WithStack(err)
+ }
+
+ data := make([]byte, charTotal*3) // 3 bytes per character: the widest sequence decode2utf8 accepts
+
+ start := 0 // index of the next byte to decode
+ end := 0 // index just past the last byte read into data
+
+ charCount := 0 // characters decoded so far
+ charRead := 0 // characters decoded by the latest decode2utf8 call
+
+ for charCount < charTotal {
+ _, err = io.ReadFull(d.reader, data[end:end+charTotal-charCount]) // each remaining character occupies at least one byte
+ if err != nil {
+ return nil, err // NOTE(review): returned unwrapped, unlike the getStringLength error above
+ }
+
+ end += charTotal - charCount
+
+ start, end, charRead, err = decode2utf8(d.reader, data, start, end) // may read extra bytes and move end
+ if err != nil {
+ return nil, err
+ }
+
+ charCount += charRead
+ }
+
+ return data[:end], nil // the slice keeps the full charTotal*3 backing array
+}
+
+// decode2utf8 scans data[start:end] in place, reading missing bytes of a multi-byte sequence from r and re-encoding each surrogate pair (two 3-byte units) as one UTF-8 rune.
+// parameters:
+// - r: the reader that supplies bytes missing from an incomplete sequence
+// - data: the buffer holding the bytes read so far; it is modified in place
+// - start: the index in data at which decoding begins
+// - end: the index just past the last byte already read into data
+// response: updated start, updated end, number of characters decoded (a surrogate pair counts as 2; 0 on error), error.
+func decode2utf8(r *bufio.Reader, data []byte, start, end int) (int, int, int, error) {
+ var err error
+
+ charCount := 0
+
+ for start < end {
+ ch := data[start]
+ if ch < 0x80 { // single-byte character
+ start++
+ charCount++
+ continue
+ }
+
+ if start+1 == end { // the second byte of this sequence has not been read yet
+ data[end], err = r.ReadByte() // NOTE(review): no check that end < len(data)
+ if err != nil {
+ return start, end, 0, err
+ }
+ end++
+ }
+
+ if (ch & 0xe0) == 0xc0 { // two-byte sequence
+ start += 2
+ charCount++
+ continue
+ }
+
+ if start+2 == end { // the third byte of this sequence has not been read yet
+ data[end], err = r.ReadByte()
+ if err != nil {
+ return start, end, 0, err
+ }
+ end++
+ }
+
+ if (ch & 0xf0) == 0xe0 { // three-byte sequence
+ c1 := ((uint32(ch) & 0x0f) << 12) + ((uint32(data[start+1]) & 0x3f) << 6) + (uint32(data[start+2]) & 0x3f)
+
+ if c1 >= 0xD800 && c1 <= 0xDBFF { // high surrogate: combine it with the 3-byte unit that follows
+ if start+6 >= end { // make sure all 6 bytes of the pair are in data
+ _, err = io.ReadFull(r, data[end:start+6]) // NOTE(review): no check that start+6 <= len(data)
+ if err != nil {
+ return start, end, 0, err
+ }
+ end = start + 6
+ }
+
+ c2 := ((uint32(data[start+3]) & 0x0f) << 12) + ((uint32(data[start+4]) & 0x3f) << 6) + (uint32(data[start+5]) & 0x3f) // NOTE(review): not checked to be a low surrogate
+ c := (c1-0xD800)<<10 + (c2 - 0xDC00) + 0x10000
+
+ n := utf8.EncodeRune(data[start:], rune(c)) // overwrite the pair with its UTF-8 encoding
+ copy(data[start+n:], data[start+6:end]) // shift the remaining bytes left to close the gap
+ start, end = start+n, end-6+n
+
+ charCount += 2 // the pair advances the character count by two
+ continue
+ }
+
+ start += 3
+ charCount++
+ continue
+ }
+
+ return start, end, 0, fmt.Errorf("bad utf-8 encoding at %x", ch) // lead byte matches none of the accepted 1-, 2- or 3-byte forms
+ }
+
+ return start, end, charCount, nil
+}
diff --git a/string_test.go b/string_test.go
index be3ac73..de945f2 100644
--- a/string_test.go
+++ b/string_test.go
@@ -19,6 +19,7 @@
import (
"fmt"
+ "strings"
"sync"
"testing"
)
@@ -212,3 +213,18 @@
testDecodeFramework(t, "customReplyComplexString", s0)
testJavaDecode(t, "customArgComplexString", s0)
}
+
+func BenchmarkDecodeString(b *testing.B) { // benchmarks decoding of one long string that mixes emoji and CJK text
+ s := "❄️🚫🚫🚫🚫 多次自我介绍、任务、动态和"
+ s = strings.Repeat(s, 4096) // enlarge the input by repeating the sample 4096 times
+
+ e := NewEncoder()
+ _ = e.Encode(s) // NOTE(review): encode error is discarded
+ buf := e.buffer
+
+ d := NewDecoder(buf)
+ for i := 0; i < b.N; i++ {
+ d.Reset(buf) // reuse the same decoder on the same encoded bytes each iteration
+ _, _ = d.Decode() // NOTE(review): result and error are discarded, so a failing decode would go unnoticed
+ }
+}