Merge pull request #12467 [BEAM-7996] Add map & nil encoding to Go SDK.
diff --git a/sdks/go/pkg/beam/core/graph/coder/map.go b/sdks/go/pkg/beam/core/graph/coder/map.go
new file mode 100644
index 0000000..4e5dc2c
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/map.go
@@ -0,0 +1,120 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"fmt"
+	"io"
+	"reflect"
+)
+
+// TODO(lostluck): 2020.08.04 export these for use for others?
+
+// mapDecoder produces a decoder for the beam schema map encoding: an int32
+// entry count followed by that many encoded key/value pairs.
+//
+// rt must be a map type. decodeToKey and decodeToElem each decode a single key
+// or value into the settable reflect.Value they are given. The decoded map is
+// stored in ret only once every entry has been read, so ret is left unchanged
+// on error. A negative entry count is rejected as corrupt input.
+func mapDecoder(rt reflect.Type, decodeToKey, decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
+	return func(ret reflect.Value, r io.Reader) error {
+		size, err := DecodeInt32(r)
+		if err != nil {
+			return err
+		}
+		if size < 0 {
+			return fmt.Errorf("decoding %v: invalid negative entry count %d", rt, size)
+		}
+		n := int(size)
+		m := reflect.MakeMapWithSize(rt, n)
+		for i := 0; i < n; i++ {
+			rvk := reflect.New(rt.Key()).Elem()
+			if err := decodeToKey(rvk, r); err != nil {
+				return err
+			}
+			rvv := reflect.New(rt.Elem()).Elem()
+			if err := decodeToElem(rvv, r); err != nil {
+				return err
+			}
+			m.SetMapIndex(rvk, rvv)
+		}
+		ret.Set(m)
+		return nil
+	}
+}
+
+// containerNilDecoder wraps decodeToElem to handle a nillable (pointer)
+// component of a map or iterable. Such components are prefixed by a bool
+// reporting whether a value follows; when it does, decodeToElem decodes it
+// into a newly allocated element. ret must be a settable pointer value.
+func containerNilDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
+	return func(ret reflect.Value, r io.Reader) error {
+		hasValue, err := DecodeBool(r)
+		if err != nil {
+			return err
+		}
+		if !hasValue {
+			// Store nil explicitly so a reused destination can't keep a stale pointer.
+			ret.Set(reflect.Zero(ret.Type()))
+			return nil
+		}
+		rv := reflect.New(ret.Type().Elem())
+		if err := decodeToElem(rv.Elem(), r); err != nil {
+			return err
+		}
+		ret.Set(rv)
+		return nil
+	}
+}
+
+// mapEncoder produces an encoder for the map type rt using the beam schema
+// map encoding: an int32 entry count followed by each key/value pair, written
+// with encodeKey and encodeValue. Pairs are emitted in Go map iteration order,
+// which is unspecified, so the bytes for maps with more than one entry are not
+// deterministic. A nil map encodes the same as an empty map.
+func mapEncoder(rt reflect.Type, encodeKey, encodeValue func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error {
+	return func(rv reflect.Value, w io.Writer) error {
+		if err := EncodeInt32(int32(rv.Len()), w); err != nil {
+			return err
+		}
+		iter := rv.MapRange()
+		for iter.Next() {
+			if err := encodeKey(iter.Key(), w); err != nil {
+				return err
+			}
+			if err := encodeValue(iter.Value(), w); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+}
+
+// containerNilEncoder wraps encodeElem to handle a nillable (pointer)
+// component of a map or iterable: it writes false for a nil pointer, or true
+// followed by encodeElem's encoding of the pointed-to value.
+func containerNilEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error {
+	return func(rv reflect.Value, w io.Writer) error {
+		if rv.IsNil() {
+			return EncodeBool(false, w)
+		}
+		if err := EncodeBool(true, w); err != nil {
+			return err
+		}
+		return encodeElem(rv.Elem(), w)
+	}
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/map_test.go b/sdks/go/pkg/beam/core/graph/coder/map_test.go
new file mode 100644
index 0000000..0b825c2
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/map_test.go
@@ -0,0 +1,112 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"reflect"
+	"testing"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
+	"github.com/google/go-cmp/cmp"
+)
+
+// TestEncodeDecodeMap checks mapEncoder and mapDecoder against known byte
+// encodings, including maps whose values are nillable pointers.
+func TestEncodeDecodeMap(t *testing.T) {
+	byteEnc := containerEncoderForType(reflectx.Uint8)
+	byteDec := containerDecoderForType(reflectx.Uint8)
+	bytePtrEnc := containerEncoderForType(reflect.PtrTo(reflectx.Uint8))
+	bytePtrDec := containerDecoderForType(reflect.PtrTo(reflectx.Uint8))
+
+	ptrByte := byte(42)
+
+	tests := []struct {
+		v          interface{}
+		encK, encV func(reflect.Value, io.Writer) error
+		decK, decV func(reflect.Value, io.Reader) error
+		encoded    []byte
+		decodeOnly bool
+	}{
+		{
+			v:       map[byte]byte{},
+			encK:    byteEnc,
+			encV:    byteEnc,
+			decK:    byteDec,
+			decV:    byteDec,
+			encoded: []byte{0, 0, 0, 0},
+		}, {
+			v:       map[byte]byte{10: 42},
+			encK:    byteEnc,
+			encV:    byteEnc,
+			decK:    byteDec,
+			decV:    byteDec,
+			encoded: []byte{0, 0, 0, 1, 10, 42},
+		}, {
+			v:       map[byte]*byte{10: &ptrByte},
+			encK:    byteEnc,
+			encV:    bytePtrEnc,
+			decK:    byteDec,
+			decV:    bytePtrDec,
+			encoded: []byte{0, 0, 0, 1, 10, 1, 42},
+		}, {
+			// With several entries the encoded pair order follows map
+			// iteration order, which is unspecified, so only decoding is checked.
+			v:          map[byte]*byte{10: &ptrByte, 23: nil, 53: nil},
+			encK:       byteEnc,
+			encV:       bytePtrEnc,
+			decK:       byteDec,
+			decV:       bytePtrDec,
+			encoded:    []byte{0, 0, 0, 3, 10, 1, 42, 23, 0, 53, 0},
+			decodeOnly: true,
+		},
+	}
+	for _, test := range tests {
+		test := test
+		// Name subtests by type and encoding rather than by value: formatting
+		// pointer-valued maps would embed memory addresses in the name.
+		name := fmt.Sprintf("%T %v", test.v, test.encoded)
+		if !test.decodeOnly {
+			t.Run("encode "+name, func(t *testing.T) {
+				var buf bytes.Buffer
+				err := mapEncoder(reflect.TypeOf(test.v), test.encK, test.encV)(reflect.ValueOf(test.v), &buf)
+				if err != nil {
+					t.Fatalf("mapEncoder(%v) = %v", test.v, err)
+				}
+				if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" {
+					t.Errorf("mapEncoder(%v) = %v, want %v diff(-want,+got):\n %v", test.v, buf.Bytes(), test.encoded, d)
+				}
+			})
+		}
+		t.Run("decode "+name, func(t *testing.T) {
+			buf := bytes.NewBuffer(test.encoded)
+			rt := reflect.TypeOf(test.v)
+			rv := reflect.New(rt).Elem()
+			if err := mapDecoder(rt, test.decK, test.decV)(rv, buf); err != nil {
+				t.Fatalf("mapDecoder(%v) = %v", test.encoded, err)
+			}
+			got := rv.Interface()
+			if d := cmp.Diff(test.v, got); d != "" {
+				t.Errorf("mapDecoder(%v) = %v, want %v diff(-want,+got):\n %v", test.encoded, got, test.v, d)
+			}
+			if buf.Len() != 0 {
+				t.Errorf("mapDecoder(%v) left %d unread bytes", test.encoded, buf.Len())
+			}
+		})
+	}
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row.go b/sdks/go/pkg/beam/core/graph/coder/row.go
index aac34ac..00b4c26 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row.go
@@ -152,21 +152,29 @@
 return nil
 }
 }
- decf := decoderForSingleTypeReflect(t.Elem())
- sdec := iterableDecoderForSlice(t, decf)
- return func(rv reflect.Value, r io.Reader) error {
- return sdec(rv, r)
- }
+ decf := containerDecoderForType(t.Elem())
+ return iterableDecoderForSlice(t, decf)
 case reflect.Array:
- decf := decoderForSingleTypeReflect(t.Elem())
- sdec := iterableDecoderForArray(t, decf)
- return func(rv reflect.Value, r io.Reader) error {
- return sdec(rv, r)
- }
+ decf := containerDecoderForType(t.Elem())
+ return iterableDecoderForArray(t, decf)
+ case reflect.Map:
+ decK := containerDecoderForType(t.Key())
+ decV := containerDecoderForType(t.Elem())
+ return mapDecoder(t, decK, decV)
 }
 panic(fmt.Sprintf("unimplemented type to decode: %v", t))
 }
 
+// containerDecoderForType returns a decoder for a slice, array, or map
+// component of type t. Pointer types are wrapped with containerNilDecoder so
+// a nil/non-nil indicator is read before the pointed-to value.
+func containerDecoderForType(t reflect.Type) func(reflect.Value, io.Reader) error {
+ if t.Kind() == reflect.Ptr {
+ return containerNilDecoder(decoderForSingleTypeReflect(t.Elem()))
+ }
+ return decoderForSingleTypeReflect(t)
+}
+
 type typeDecoderReflect struct {
 typ reflect.Type
 fields []func(reflect.Value, io.Reader) error
@@ -270,15 +278,29 @@
 return EncodeBytes(rv.Bytes(), w)
 }
 }
- encf := encoderForSingleTypeReflect(t.Elem())
+ encf := containerEncoderForType(t.Elem())
 return iterableEncoder(t, encf)
 case reflect.Array:
- encf := encoderForSingleTypeReflect(t.Elem())
+ encf := containerEncoderForType(t.Elem())
 return iterableEncoder(t, encf)
+ case reflect.Map:
+ encK := containerEncoderForType(t.Key())
+ encV := containerEncoderForType(t.Elem())
+ return mapEncoder(t, encK, encV)
 }
 panic(fmt.Sprintf("unimplemented type to encode: %v", t))
 }
 
+// containerEncoderForType returns an encoder for a slice, array, or map
+// component of type t. Pointer types are wrapped with containerNilEncoder so
+// a nil/non-nil indicator is written before the pointed-to value.
+func containerEncoderForType(t reflect.Type) func(reflect.Value, io.Writer) error {
+ if t.Kind() == reflect.Ptr {
+ return containerNilEncoder(encoderForSingleTypeReflect(t.Elem()))
+ }
+ return encoderForSingleTypeReflect(t)
+}
+
 type typeEncoderReflect struct {
 fields []func(reflect.Value, io.Writer) error
 }
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_test.go b/sdks/go/pkg/beam/core/graph/coder/row_test.go
index f1089b8..38b7c5d 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row_test.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row_test.go
@@ -78,16 +78,18 @@
V12 [0]int
V13 [2]int
V14 []int
- // V15 map[string]int // not yet a standard coder (BEAM-7996)
+ V15 map[string]int
V16 float32
V17 float64
V18 []byte
+ V19 [2]*int
+ V20 map[*string]*int
}{},
}, {
want: struct {
V00 bool
- V01 byte
- V02 uint8
+ V01 byte // unsupported by spec (same as uint8)
+ V02 uint8 // unsupported by spec
V03 int16
// V04 uint16 // unsupported by spec
V05 int32
@@ -100,10 +102,13 @@
V12 [0]int
V13 [2]int
V14 []int
- // V15 map[string]int // not yet a standard coder (BEAM-7996) (encoding unspecified)
+ V15 map[string]int
V16 float32
V17 float64
V18 []byte
+ V19 [2]*int
+ V20 map[string]*int
+ V21 []*int
}{
V00: true,
V01: 1,
@@ -117,9 +122,16 @@
V12: [0]int{},
V13: [2]int{72, 908},
V14: []int{12, 9326, 641346, 6},
+ V15: map[string]int{"pants": 42},
V16: 3.14169,
V17: 2.6e100,
V18: []byte{21, 17, 65, 255, 0, 16},
+ V19: [2]*int{nil, &num},
+ V20: map[string]*int{
+ "notnil": &num,
+ "nil": nil,
+ },
+ V21: []*int{nil, &num, nil},
},
// TODO add custom types such as protocol buffers.
},