| package msgpack |
| |
| import ( |
| "bufio" |
| "bytes" |
| "errors" |
| "fmt" |
| "io" |
| "reflect" |
| "time" |
| |
| "github.com/vmihailenco/msgpack/codes" |
| ) |
| |
| const bytesAllocLimit = 1024 * 1024 // 1mb |
| |
| type bufReader interface { |
| io.Reader |
| io.ByteScanner |
| } |
| |
| func newBufReader(r io.Reader) bufReader { |
| if br, ok := r.(bufReader); ok { |
| return br |
| } |
| return bufio.NewReader(r) |
| } |
| |
| func makeBuffer() []byte { |
| return make([]byte, 0, 64) |
| } |
| |
| // Unmarshal decodes the MessagePack-encoded data and stores the result |
| // in the value pointed to by v. |
| func Unmarshal(data []byte, v interface{}) error { |
| return NewDecoder(bytes.NewReader(data)).Decode(v) |
| } |
| |
| type Decoder struct { |
| r io.Reader |
| s io.ByteScanner |
| buf []byte |
| |
| extLen int |
| rec []byte // accumulates read data if not nil |
| |
| useLoose bool |
| useJSONTag bool |
| |
| decodeMapFunc func(*Decoder) (interface{}, error) |
| } |
| |
| // NewDecoder returns a new decoder that reads from r. |
| // |
| // The decoder introduces its own buffering and may read data from r |
| // beyond the MessagePack values requested. Buffering can be disabled |
| // by passing a reader that implements io.ByteScanner interface. |
| func NewDecoder(r io.Reader) *Decoder { |
| d := &Decoder{ |
| buf: makeBuffer(), |
| } |
| d.resetReader(r) |
| return d |
| } |
| |
| func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) { |
| d.decodeMapFunc = fn |
| } |
| |
| // UseDecodeInterfaceLoose causes decoder to use DecodeInterfaceLoose |
| // to decode msgpack value into Go interface{}. |
| func (d *Decoder) UseDecodeInterfaceLoose(flag bool) { |
| d.useLoose = flag |
| } |
| |
| // UseJSONTag causes the Decoder to use json struct tag as fallback option |
| // if there is no msgpack tag. |
| func (d *Decoder) UseJSONTag(v bool) *Decoder { |
| d.useJSONTag = v |
| return d |
| } |
| |
| func (d *Decoder) Reset(r io.Reader) error { |
| d.resetReader(r) |
| return nil |
| } |
| |
| func (d *Decoder) resetReader(r io.Reader) { |
| reader := newBufReader(r) |
| d.r = reader |
| d.s = reader |
| } |
| |
| func (d *Decoder) Decode(v interface{}) error { |
| var err error |
| switch v := v.(type) { |
| case *string: |
| if v != nil { |
| *v, err = d.DecodeString() |
| return err |
| } |
| case *[]byte: |
| if v != nil { |
| return d.decodeBytesPtr(v) |
| } |
| case *int: |
| if v != nil { |
| *v, err = d.DecodeInt() |
| return err |
| } |
| case *int8: |
| if v != nil { |
| *v, err = d.DecodeInt8() |
| return err |
| } |
| case *int16: |
| if v != nil { |
| *v, err = d.DecodeInt16() |
| return err |
| } |
| case *int32: |
| if v != nil { |
| *v, err = d.DecodeInt32() |
| return err |
| } |
| case *int64: |
| if v != nil { |
| *v, err = d.DecodeInt64() |
| return err |
| } |
| case *uint: |
| if v != nil { |
| *v, err = d.DecodeUint() |
| return err |
| } |
| case *uint8: |
| if v != nil { |
| *v, err = d.DecodeUint8() |
| return err |
| } |
| case *uint16: |
| if v != nil { |
| *v, err = d.DecodeUint16() |
| return err |
| } |
| case *uint32: |
| if v != nil { |
| *v, err = d.DecodeUint32() |
| return err |
| } |
| case *uint64: |
| if v != nil { |
| *v, err = d.DecodeUint64() |
| return err |
| } |
| case *bool: |
| if v != nil { |
| *v, err = d.DecodeBool() |
| return err |
| } |
| case *float32: |
| if v != nil { |
| *v, err = d.DecodeFloat32() |
| return err |
| } |
| case *float64: |
| if v != nil { |
| *v, err = d.DecodeFloat64() |
| return err |
| } |
| case *[]string: |
| return d.decodeStringSlicePtr(v) |
| case *map[string]string: |
| return d.decodeMapStringStringPtr(v) |
| case *map[string]interface{}: |
| return d.decodeMapStringInterfacePtr(v) |
| case *time.Duration: |
| if v != nil { |
| vv, err := d.DecodeInt64() |
| *v = time.Duration(vv) |
| return err |
| } |
| case *time.Time: |
| if v != nil { |
| *v, err = d.DecodeTime() |
| return err |
| } |
| } |
| |
| vv := reflect.ValueOf(v) |
| if !vv.IsValid() { |
| return errors.New("msgpack: Decode(nil)") |
| } |
| if vv.Kind() != reflect.Ptr { |
| return fmt.Errorf("msgpack: Decode(nonsettable %T)", v) |
| } |
| vv = vv.Elem() |
| if !vv.IsValid() { |
| return fmt.Errorf("msgpack: Decode(nonsettable %T)", v) |
| } |
| return d.DecodeValue(vv) |
| } |
| |
| func (d *Decoder) DecodeMulti(v ...interface{}) error { |
| for _, vv := range v { |
| if err := d.Decode(vv); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| func (d *Decoder) decodeInterfaceCond() (interface{}, error) { |
| if d.useLoose { |
| return d.DecodeInterfaceLoose() |
| } |
| return d.DecodeInterface() |
| } |
| |
| func (d *Decoder) DecodeValue(v reflect.Value) error { |
| decode := getDecoder(v.Type()) |
| return decode(d, v) |
| } |
| |
| func (d *Decoder) DecodeNil() error { |
| c, err := d.readCode() |
| if err != nil { |
| return err |
| } |
| if c != codes.Nil { |
| return fmt.Errorf("msgpack: invalid code=%x decoding nil", c) |
| } |
| return nil |
| } |
| |
| func (d *Decoder) decodeNilValue(v reflect.Value) error { |
| err := d.DecodeNil() |
| if v.IsNil() { |
| return err |
| } |
| if v.Kind() == reflect.Ptr { |
| v = v.Elem() |
| } |
| v.Set(reflect.Zero(v.Type())) |
| return err |
| } |
| |
| func (d *Decoder) DecodeBool() (bool, error) { |
| c, err := d.readCode() |
| if err != nil { |
| return false, err |
| } |
| return d.bool(c) |
| } |
| |
| func (d *Decoder) bool(c codes.Code) (bool, error) { |
| if c == codes.False { |
| return false, nil |
| } |
| if c == codes.True { |
| return true, nil |
| } |
| return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c) |
| } |
| |
| // DecodeInterface decodes value into interface. It returns following types: |
| // - nil, |
| // - bool, |
| // - int8, int16, int32, int64, |
| // - uint8, uint16, uint32, uint64, |
| // - float32 and float64, |
| // - string, |
| // - []byte, |
| // - slices of any of the above, |
| // - maps of any of the above. |
| // |
| // DecodeInterface should be used only when you don't know the type of value |
| // you are decoding. For example, if you are decoding number it is better to use |
| // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers. |
| func (d *Decoder) DecodeInterface() (interface{}, error) { |
| c, err := d.readCode() |
| if err != nil { |
| return nil, err |
| } |
| |
| if codes.IsFixedNum(c) { |
| return int8(c), nil |
| } |
| if codes.IsFixedMap(c) { |
| err = d.s.UnreadByte() |
| if err != nil { |
| return nil, err |
| } |
| return d.DecodeMap() |
| } |
| if codes.IsFixedArray(c) { |
| return d.decodeSlice(c) |
| } |
| if codes.IsFixedString(c) { |
| return d.string(c) |
| } |
| |
| switch c { |
| case codes.Nil: |
| return nil, nil |
| case codes.False, codes.True: |
| return d.bool(c) |
| case codes.Float: |
| return d.float32(c) |
| case codes.Double: |
| return d.float64(c) |
| case codes.Uint8: |
| return d.uint8() |
| case codes.Uint16: |
| return d.uint16() |
| case codes.Uint32: |
| return d.uint32() |
| case codes.Uint64: |
| return d.uint64() |
| case codes.Int8: |
| return d.int8() |
| case codes.Int16: |
| return d.int16() |
| case codes.Int32: |
| return d.int32() |
| case codes.Int64: |
| return d.int64() |
| case codes.Bin8, codes.Bin16, codes.Bin32: |
| return d.bytes(c, nil) |
| case codes.Str8, codes.Str16, codes.Str32: |
| return d.string(c) |
| case codes.Array16, codes.Array32: |
| return d.decodeSlice(c) |
| case codes.Map16, codes.Map32: |
| err = d.s.UnreadByte() |
| if err != nil { |
| return nil, err |
| } |
| return d.DecodeMap() |
| case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16, |
| codes.Ext8, codes.Ext16, codes.Ext32: |
| return d.extInterface(c) |
| } |
| |
| return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) |
| } |
| |
| // DecodeInterfaceLoose is like DecodeInterface except that: |
| // - int8, int16, and int32 are converted to int64, |
| // - uint8, uint16, and uint32 are converted to uint64, |
| // - float32 is converted to float64. |
| func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) { |
| c, err := d.readCode() |
| if err != nil { |
| return nil, err |
| } |
| |
| if codes.IsFixedNum(c) { |
| return int64(c), nil |
| } |
| if codes.IsFixedMap(c) { |
| err = d.s.UnreadByte() |
| if err != nil { |
| return nil, err |
| } |
| return d.DecodeMap() |
| } |
| if codes.IsFixedArray(c) { |
| return d.decodeSlice(c) |
| } |
| if codes.IsFixedString(c) { |
| return d.string(c) |
| } |
| |
| switch c { |
| case codes.Nil: |
| return nil, nil |
| case codes.False, codes.True: |
| return d.bool(c) |
| case codes.Float, codes.Double: |
| return d.float64(c) |
| case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64: |
| return d.uint(c) |
| case codes.Int8, codes.Int16, codes.Int32, codes.Int64: |
| return d.int(c) |
| case codes.Bin8, codes.Bin16, codes.Bin32: |
| return d.bytes(c, nil) |
| case codes.Str8, codes.Str16, codes.Str32: |
| return d.string(c) |
| case codes.Array16, codes.Array32: |
| return d.decodeSlice(c) |
| case codes.Map16, codes.Map32: |
| err = d.s.UnreadByte() |
| if err != nil { |
| return nil, err |
| } |
| return d.DecodeMap() |
| case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16, |
| codes.Ext8, codes.Ext16, codes.Ext32: |
| return d.extInterface(c) |
| } |
| |
| return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) |
| } |
| |
| // Skip skips next value. |
| func (d *Decoder) Skip() error { |
| c, err := d.readCode() |
| if err != nil { |
| return err |
| } |
| |
| if codes.IsFixedNum(c) { |
| return nil |
| } else if codes.IsFixedMap(c) { |
| return d.skipMap(c) |
| } else if codes.IsFixedArray(c) { |
| return d.skipSlice(c) |
| } else if codes.IsFixedString(c) { |
| return d.skipBytes(c) |
| } |
| |
| switch c { |
| case codes.Nil, codes.False, codes.True: |
| return nil |
| case codes.Uint8, codes.Int8: |
| return d.skipN(1) |
| case codes.Uint16, codes.Int16: |
| return d.skipN(2) |
| case codes.Uint32, codes.Int32, codes.Float: |
| return d.skipN(4) |
| case codes.Uint64, codes.Int64, codes.Double: |
| return d.skipN(8) |
| case codes.Bin8, codes.Bin16, codes.Bin32: |
| return d.skipBytes(c) |
| case codes.Str8, codes.Str16, codes.Str32: |
| return d.skipBytes(c) |
| case codes.Array16, codes.Array32: |
| return d.skipSlice(c) |
| case codes.Map16, codes.Map32: |
| return d.skipMap(c) |
| case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16, |
| codes.Ext8, codes.Ext16, codes.Ext32: |
| return d.skipExt(c) |
| } |
| |
| return fmt.Errorf("msgpack: unknown code %x", c) |
| } |
| |
| // PeekCode returns the next MessagePack code without advancing the reader. |
| // Subpackage msgpack/codes contains list of available codes. |
| func (d *Decoder) PeekCode() (codes.Code, error) { |
| c, err := d.s.ReadByte() |
| if err != nil { |
| return 0, err |
| } |
| return codes.Code(c), d.s.UnreadByte() |
| } |
| |
| func (d *Decoder) hasNilCode() bool { |
| code, err := d.PeekCode() |
| return err == nil && code == codes.Nil |
| } |
| |
| func (d *Decoder) readCode() (codes.Code, error) { |
| d.extLen = 0 |
| c, err := d.s.ReadByte() |
| if err != nil { |
| return 0, err |
| } |
| if d.rec != nil { |
| d.rec = append(d.rec, c) |
| } |
| return codes.Code(c), nil |
| } |
| |
| func (d *Decoder) readFull(b []byte) error { |
| _, err := io.ReadFull(d.r, b) |
| if err != nil { |
| return err |
| } |
| if d.rec != nil { |
| d.rec = append(d.rec, b...) |
| } |
| return nil |
| } |
| |
| func (d *Decoder) readN(n int) ([]byte, error) { |
| buf, err := readN(d.r, d.buf, n) |
| if err != nil { |
| return nil, err |
| } |
| d.buf = buf |
| if d.rec != nil { |
| d.rec = append(d.rec, buf...) |
| } |
| return buf, nil |
| } |
| |
| func readN(r io.Reader, b []byte, n int) ([]byte, error) { |
| if b == nil { |
| if n == 0 { |
| return make([]byte, 0), nil |
| } |
| if n <= bytesAllocLimit { |
| b = make([]byte, n) |
| } else { |
| b = make([]byte, bytesAllocLimit) |
| } |
| } |
| |
| if n <= cap(b) { |
| b = b[:n] |
| _, err := io.ReadFull(r, b) |
| return b, err |
| } |
| b = b[:cap(b)] |
| |
| var pos int |
| for { |
| alloc := n - len(b) |
| if alloc > bytesAllocLimit { |
| alloc = bytesAllocLimit |
| } |
| b = append(b, make([]byte, alloc)...) |
| |
| _, err := io.ReadFull(r, b[pos:]) |
| if err != nil { |
| return nil, err |
| } |
| |
| if len(b) == n { |
| break |
| } |
| pos = len(b) |
| } |
| |
| return b, nil |
| } |
| |
| func min(a, b int) int { |
| if a <= b { |
| return a |
| } |
| return b |
| } |