| package pq |
| |
| import ( |
| "bytes" |
| "database/sql" |
| "database/sql/driver" |
| "encoding/hex" |
| "fmt" |
| "reflect" |
| "strconv" |
| "strings" |
| ) |
| |
// Reflected types that the array encoder and decoder compare element types
// against.
var (
	// typeByteSlice is the reflected type of []byte.
	typeByteSlice = reflect.TypeOf([]byte{})
	// typeDriverValuer is the reflected driver.Valuer interface type.
	typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
	// typeSqlScanner is the reflected sql.Scanner interface type.
	typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
)
| |
| // Array returns the optimal driver.Valuer and sql.Scanner for an array or |
| // slice of any dimension. |
| // |
| // For example: |
| // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) |
| // |
| // var x []sql.NullInt64 |
| // db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) |
| // |
| // Scanning multi-dimensional arrays is not supported. Arrays where the lower |
| // bound is not one (such as `[0:0]={1}') are not supported. |
| func Array(a interface{}) interface { |
| driver.Valuer |
| sql.Scanner |
| } { |
| switch a := a.(type) { |
| case []bool: |
| return (*BoolArray)(&a) |
| case []float64: |
| return (*Float64Array)(&a) |
| case []int64: |
| return (*Int64Array)(&a) |
| case []string: |
| return (*StringArray)(&a) |
| |
| case *[]bool: |
| return (*BoolArray)(a) |
| case *[]float64: |
| return (*Float64Array)(a) |
| case *[]int64: |
| return (*Int64Array)(a) |
| case *[]string: |
| return (*StringArray)(a) |
| } |
| |
| return GenericArray{a} |
| } |
| |
// ArrayDelimiter is an optional interface for element types used with
// GenericArray. A driver.Valuer or sql.Scanner that implements it replaces
// the default "," separator placed between array elements.
type ArrayDelimiter interface {
	// ArrayDelimiter reports the delimiter character(s) used between
	// elements of this type.
	ArrayDelimiter() string
}
| |
| // BoolArray represents a one-dimensional array of the PostgreSQL boolean type. |
| type BoolArray []bool |
| |
| // Scan implements the sql.Scanner interface. |
| func (a *BoolArray) Scan(src interface{}) error { |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src) |
| case string: |
| return a.scanBytes([]byte(src)) |
| case nil: |
| *a = nil |
| return nil |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to BoolArray", src) |
| } |
| |
| func (a *BoolArray) scanBytes(src []byte) error { |
| elems, err := scanLinearArray(src, []byte{','}, "BoolArray") |
| if err != nil { |
| return err |
| } |
| if *a != nil && len(elems) == 0 { |
| *a = (*a)[:0] |
| } else { |
| b := make(BoolArray, len(elems)) |
| for i, v := range elems { |
| if len(v) != 1 { |
| return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) |
| } |
| switch v[0] { |
| case 't': |
| b[i] = true |
| case 'f': |
| b[i] = false |
| default: |
| return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) |
| } |
| } |
| *a = b |
| } |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. |
| func (a BoolArray) Value() (driver.Value, error) { |
| if a == nil { |
| return nil, nil |
| } |
| |
| if n := len(a); n > 0 { |
| // There will be exactly two curly brackets, N bytes of values, |
| // and N-1 bytes of delimiters. |
| b := make([]byte, 1+2*n) |
| |
| for i := 0; i < n; i++ { |
| b[2*i] = ',' |
| if a[i] { |
| b[1+2*i] = 't' |
| } else { |
| b[1+2*i] = 'f' |
| } |
| } |
| |
| b[0] = '{' |
| b[2*n] = '}' |
| |
| return string(b), nil |
| } |
| |
| return "{}", nil |
| } |
| |
| // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. |
| type ByteaArray [][]byte |
| |
| // Scan implements the sql.Scanner interface. |
| func (a *ByteaArray) Scan(src interface{}) error { |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src) |
| case string: |
| return a.scanBytes([]byte(src)) |
| case nil: |
| *a = nil |
| return nil |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) |
| } |
| |
| func (a *ByteaArray) scanBytes(src []byte) error { |
| elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") |
| if err != nil { |
| return err |
| } |
| if *a != nil && len(elems) == 0 { |
| *a = (*a)[:0] |
| } else { |
| b := make(ByteaArray, len(elems)) |
| for i, v := range elems { |
| b[i], err = parseBytea(v) |
| if err != nil { |
| return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) |
| } |
| } |
| *a = b |
| } |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. It uses the "hex" format which |
| // is only supported on PostgreSQL 9.0 or newer. |
| func (a ByteaArray) Value() (driver.Value, error) { |
| if a == nil { |
| return nil, nil |
| } |
| |
| if n := len(a); n > 0 { |
| // There will be at least two curly brackets, 2*N bytes of quotes, |
| // 3*N bytes of hex formatting, and N-1 bytes of delimiters. |
| size := 1 + 6*n |
| for _, x := range a { |
| size += hex.EncodedLen(len(x)) |
| } |
| |
| b := make([]byte, size) |
| |
| for i, s := 0, b; i < n; i++ { |
| o := copy(s, `,"\\x`) |
| o += hex.Encode(s[o:], a[i]) |
| s[o] = '"' |
| s = s[o+1:] |
| } |
| |
| b[0] = '{' |
| b[size-1] = '}' |
| |
| return string(b), nil |
| } |
| |
| return "{}", nil |
| } |
| |
| // Float64Array represents a one-dimensional array of the PostgreSQL double |
| // precision type. |
| type Float64Array []float64 |
| |
| // Scan implements the sql.Scanner interface. |
| func (a *Float64Array) Scan(src interface{}) error { |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src) |
| case string: |
| return a.scanBytes([]byte(src)) |
| case nil: |
| *a = nil |
| return nil |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to Float64Array", src) |
| } |
| |
| func (a *Float64Array) scanBytes(src []byte) error { |
| elems, err := scanLinearArray(src, []byte{','}, "Float64Array") |
| if err != nil { |
| return err |
| } |
| if *a != nil && len(elems) == 0 { |
| *a = (*a)[:0] |
| } else { |
| b := make(Float64Array, len(elems)) |
| for i, v := range elems { |
| if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { |
| return fmt.Errorf("pq: parsing array element index %d: %v", i, err) |
| } |
| } |
| *a = b |
| } |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. |
| func (a Float64Array) Value() (driver.Value, error) { |
| if a == nil { |
| return nil, nil |
| } |
| |
| if n := len(a); n > 0 { |
| // There will be at least two curly brackets, N bytes of values, |
| // and N-1 bytes of delimiters. |
| b := make([]byte, 1, 1+2*n) |
| b[0] = '{' |
| |
| b = strconv.AppendFloat(b, a[0], 'f', -1, 64) |
| for i := 1; i < n; i++ { |
| b = append(b, ',') |
| b = strconv.AppendFloat(b, a[i], 'f', -1, 64) |
| } |
| |
| return string(append(b, '}')), nil |
| } |
| |
| return "{}", nil |
| } |
| |
| // GenericArray implements the driver.Valuer and sql.Scanner interfaces for |
| // an array or slice of any dimension. |
| type GenericArray struct{ A interface{} } |
| |
| func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { |
| var assign func([]byte, reflect.Value) error |
| var del = "," |
| |
| // TODO calculate the assign function for other types |
| // TODO repeat this section on the element type of arrays or slices (multidimensional) |
| { |
| if reflect.PtrTo(rt).Implements(typeSqlScanner) { |
| // dest is always addressable because it is an element of a slice. |
| assign = func(src []byte, dest reflect.Value) (err error) { |
| ss := dest.Addr().Interface().(sql.Scanner) |
| if src == nil { |
| err = ss.Scan(nil) |
| } else { |
| err = ss.Scan(src) |
| } |
| return |
| } |
| goto FoundType |
| } |
| |
| assign = func([]byte, reflect.Value) error { |
| return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) |
| } |
| } |
| |
| FoundType: |
| |
| if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { |
| del = ad.ArrayDelimiter() |
| } |
| |
| return rt, assign, del |
| } |
| |
| // Scan implements the sql.Scanner interface. |
| func (a GenericArray) Scan(src interface{}) error { |
| dpv := reflect.ValueOf(a.A) |
| switch { |
| case dpv.Kind() != reflect.Ptr: |
| return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) |
| case dpv.IsNil(): |
| return fmt.Errorf("pq: destination %T is nil", a.A) |
| } |
| |
| dv := dpv.Elem() |
| switch dv.Kind() { |
| case reflect.Slice: |
| case reflect.Array: |
| default: |
| return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) |
| } |
| |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src, dv) |
| case string: |
| return a.scanBytes([]byte(src), dv) |
| case nil: |
| if dv.Kind() == reflect.Slice { |
| dv.Set(reflect.Zero(dv.Type())) |
| return nil |
| } |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) |
| } |
| |
| func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { |
| dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) |
| dims, elems, err := parseArray(src, []byte(del)) |
| if err != nil { |
| return err |
| } |
| |
| // TODO allow multidimensional |
| |
| if len(dims) > 1 { |
| return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", |
| strings.Replace(fmt.Sprint(dims), " ", "][", -1)) |
| } |
| |
| // Treat a zero-dimensional array like an array with a single dimension of zero. |
| if len(dims) == 0 { |
| dims = append(dims, 0) |
| } |
| |
| for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { |
| switch rt.Kind() { |
| case reflect.Slice: |
| case reflect.Array: |
| if rt.Len() != dims[i] { |
| return fmt.Errorf("pq: cannot convert ARRAY%s to %s", |
| strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) |
| } |
| default: |
| // TODO handle multidimensional |
| } |
| } |
| |
| values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) |
| for i, e := range elems { |
| if err := assign(e, values.Index(i)); err != nil { |
| return fmt.Errorf("pq: parsing array element index %d: %v", i, err) |
| } |
| } |
| |
| // TODO handle multidimensional |
| |
| switch dv.Kind() { |
| case reflect.Slice: |
| dv.Set(values.Slice(0, dims[0])) |
| case reflect.Array: |
| for i := 0; i < dims[0]; i++ { |
| dv.Index(i).Set(values.Index(i)) |
| } |
| } |
| |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. |
| func (a GenericArray) Value() (driver.Value, error) { |
| if a.A == nil { |
| return nil, nil |
| } |
| |
| rv := reflect.ValueOf(a.A) |
| |
| switch rv.Kind() { |
| case reflect.Slice: |
| if rv.IsNil() { |
| return nil, nil |
| } |
| case reflect.Array: |
| default: |
| return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) |
| } |
| |
| if n := rv.Len(); n > 0 { |
| // There will be at least two curly brackets, N bytes of values, |
| // and N-1 bytes of delimiters. |
| b := make([]byte, 0, 1+2*n) |
| |
| b, _, err := appendArray(b, rv, n) |
| return string(b), err |
| } |
| |
| return "{}", nil |
| } |
| |
| // Int64Array represents a one-dimensional array of the PostgreSQL integer types. |
| type Int64Array []int64 |
| |
| // Scan implements the sql.Scanner interface. |
| func (a *Int64Array) Scan(src interface{}) error { |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src) |
| case string: |
| return a.scanBytes([]byte(src)) |
| case nil: |
| *a = nil |
| return nil |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to Int64Array", src) |
| } |
| |
| func (a *Int64Array) scanBytes(src []byte) error { |
| elems, err := scanLinearArray(src, []byte{','}, "Int64Array") |
| if err != nil { |
| return err |
| } |
| if *a != nil && len(elems) == 0 { |
| *a = (*a)[:0] |
| } else { |
| b := make(Int64Array, len(elems)) |
| for i, v := range elems { |
| if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { |
| return fmt.Errorf("pq: parsing array element index %d: %v", i, err) |
| } |
| } |
| *a = b |
| } |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. |
| func (a Int64Array) Value() (driver.Value, error) { |
| if a == nil { |
| return nil, nil |
| } |
| |
| if n := len(a); n > 0 { |
| // There will be at least two curly brackets, N bytes of values, |
| // and N-1 bytes of delimiters. |
| b := make([]byte, 1, 1+2*n) |
| b[0] = '{' |
| |
| b = strconv.AppendInt(b, a[0], 10) |
| for i := 1; i < n; i++ { |
| b = append(b, ',') |
| b = strconv.AppendInt(b, a[i], 10) |
| } |
| |
| return string(append(b, '}')), nil |
| } |
| |
| return "{}", nil |
| } |
| |
| // StringArray represents a one-dimensional array of the PostgreSQL character types. |
| type StringArray []string |
| |
| // Scan implements the sql.Scanner interface. |
| func (a *StringArray) Scan(src interface{}) error { |
| switch src := src.(type) { |
| case []byte: |
| return a.scanBytes(src) |
| case string: |
| return a.scanBytes([]byte(src)) |
| case nil: |
| *a = nil |
| return nil |
| } |
| |
| return fmt.Errorf("pq: cannot convert %T to StringArray", src) |
| } |
| |
| func (a *StringArray) scanBytes(src []byte) error { |
| elems, err := scanLinearArray(src, []byte{','}, "StringArray") |
| if err != nil { |
| return err |
| } |
| if *a != nil && len(elems) == 0 { |
| *a = (*a)[:0] |
| } else { |
| b := make(StringArray, len(elems)) |
| for i, v := range elems { |
| if b[i] = string(v); v == nil { |
| return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) |
| } |
| } |
| *a = b |
| } |
| return nil |
| } |
| |
| // Value implements the driver.Valuer interface. |
| func (a StringArray) Value() (driver.Value, error) { |
| if a == nil { |
| return nil, nil |
| } |
| |
| if n := len(a); n > 0 { |
| // There will be at least two curly brackets, 2*N bytes of quotes, |
| // and N-1 bytes of delimiters. |
| b := make([]byte, 1, 1+3*n) |
| b[0] = '{' |
| |
| b = appendArrayQuotedBytes(b, []byte(a[0])) |
| for i := 1; i < n; i++ { |
| b = append(b, ',') |
| b = appendArrayQuotedBytes(b, []byte(a[i])) |
| } |
| |
| return string(append(b, '}')), nil |
| } |
| |
| return "{}", nil |
| } |
| |
| // appendArray appends rv to the buffer, returning the extended buffer and |
| // the delimiter used between elements. |
| // |
| // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. |
| func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { |
| var del string |
| var err error |
| |
| b = append(b, '{') |
| |
| if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { |
| return b, del, err |
| } |
| |
| for i := 1; i < n; i++ { |
| b = append(b, del...) |
| if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { |
| return b, del, err |
| } |
| } |
| |
| return append(b, '}'), del, nil |
| } |
| |
| // appendArrayElement appends rv to the buffer, returning the extended buffer |
| // and the delimiter to use before the next element. |
| // |
| // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted |
| // using driver.DefaultParameterConverter and the resulting []byte or string |
| // is double-quoted. |
| // |
| // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO |
| func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { |
| if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { |
| if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { |
| if n := rv.Len(); n > 0 { |
| return appendArray(b, rv, n) |
| } |
| |
| return b, "", nil |
| } |
| } |
| |
| var del string = "," |
| var err error |
| var iv interface{} = rv.Interface() |
| |
| if ad, ok := iv.(ArrayDelimiter); ok { |
| del = ad.ArrayDelimiter() |
| } |
| |
| if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { |
| return b, del, err |
| } |
| |
| switch v := iv.(type) { |
| case nil: |
| return append(b, "NULL"...), del, nil |
| case []byte: |
| return appendArrayQuotedBytes(b, v), del, nil |
| case string: |
| return appendArrayQuotedBytes(b, []byte(v)), del, nil |
| } |
| |
| b, err = appendValue(b, iv) |
| return b, del, err |
| } |
| |
// appendArrayQuotedBytes appends v to b wrapped in double quotes, placing a
// backslash before every double quote and backslash found in v, and returns
// the extended buffer.
func appendArrayQuotedBytes(b, v []byte) []byte {
	b = append(b, '"')
	for _, c := range v {
		if c == '"' || c == '\\' {
			b = append(b, '\\')
		}
		b = append(b, c)
	}
	return append(b, '"')
}
| |
| func appendValue(b []byte, v driver.Value) ([]byte, error) { |
| return append(b, encode(nil, v, 0)...), nil |
| } |
| |
// parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// dims holds the length of each dimension, outermost first, and elems holds
// the elements flattened in input order. An unquoted NULL becomes a nil
// element, while a quoted element is always non-nil, even when empty. An
// empty array ("{}") yields no dimensions and an empty, non-nil elems.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
	var depth, pos int

	if len(src) == 0 || src[0] != '{' {
		return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
	}

	// The number of leading open braces is the number of dimensions.
	for pos < len(src) && src[pos] == '{' {
		depth++
		pos++
	}
	if pos < len(src) && src[pos] == '}' {
		// A closing brace right after the opening ones: the array is empty.
		elems = make([][]byte, 0)
		goto Close
	}
	dims = make([]int, pos)

Element:
	for pos < len(src) {
		switch src[pos] {
		case '{':
			if depth == len(dims) {
				// Deeper than the leading braces announced; the loop below
				// reports the unexpected brace.
				break Element
			}
			depth++
			dims[depth-1] = 0
			pos++
		case '"':
			// Quoted element: a backslash makes the next byte literal.
			quoted := []byte{}
			escaped := false
			for pos++; pos < len(src); pos++ {
				c := src[pos]
				if escaped {
					quoted = append(quoted, c)
					escaped = false
					continue
				}
				if c == '\\' {
					escaped = true
					continue
				}
				if c == '"' {
					elems = append(elems, quoted)
					pos++
					break Element
				}
				quoted = append(quoted, c)
			}
		default:
			// Unquoted element: runs up to the next delimiter or closing brace.
			for start := pos; pos < len(src); pos++ {
				if !bytes.HasPrefix(src[pos:], del) && src[pos] != '}' {
					continue
				}
				token := src[start:pos]
				if len(token) == 0 {
					return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[pos], pos)
				}
				if string(token) == "NULL" {
					token = nil
				}
				elems = append(elems, token)
				break Element
			}
		}
	}

	// After an element: a delimiter starts the next element, a closing brace
	// ends the current dimension.
	for pos < len(src) {
		switch {
		case depth > 0 && bytes.HasPrefix(src[pos:], del):
			dims[depth-1]++
			pos += len(del)
			goto Element
		case depth > 0 && src[pos] == '}':
			dims[depth-1]++
			depth--
			pos++
		default:
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[pos], pos)
		}
	}

Close:
	for pos < len(src) {
		if depth <= 0 || src[pos] != '}' {
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[pos], pos)
		}
		depth--
		pos++
	}
	if depth > 0 {
		// Input ended while braces were still open.
		err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', pos)
		return dims, elems, err
	}

	// Every dimension length must divide the total number of elements.
	for _, d := range dims {
		if len(elems)%d != 0 {
			err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
			return dims, elems, err
		}
	}
	return dims, elems, nil
}
| |
| func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { |
| dims, elems, err := parseArray(src, del) |
| if err != nil { |
| return nil, err |
| } |
| if len(dims) > 1 { |
| return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) |
| } |
| return elems, err |
| } |