blob: 0b825c2d51058716d31e03357fb8338cd23151ca [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package coder
import (
"bytes"
"fmt"
"io"
"reflect"
"testing"
"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
"github.com/google/go-cmp/cmp"
)
func TestEncodeDecodeMap(t *testing.T) {
byteEnc := containerEncoderForType(reflectx.Uint8)
byteDec := containerDecoderForType(reflectx.Uint8)
bytePtrEnc := containerEncoderForType(reflect.PtrTo(reflectx.Uint8))
bytePtrDec := containerDecoderForType(reflect.PtrTo(reflectx.Uint8))
ptrByte := byte(42)
tests := []struct {
v interface{}
encK, encV func(reflect.Value, io.Writer) error
decK, decV func(reflect.Value, io.Reader) error
encoded []byte
decodeOnly bool
}{
{
v: map[byte]byte{10: 42},
encK: byteEnc,
encV: byteEnc,
decK: byteDec,
decV: byteDec,
encoded: []byte{0, 0, 0, 1, 10, 42},
}, {
v: map[byte]*byte{10: &ptrByte},
encK: byteEnc,
encV: bytePtrEnc,
decK: byteDec,
decV: bytePtrDec,
encoded: []byte{0, 0, 0, 1, 10, 1, 42},
}, {
v: map[byte]*byte{10: &ptrByte, 23: nil, 53: nil},
encK: byteEnc,
encV: bytePtrEnc,
decK: byteDec,
decV: bytePtrDec,
encoded: []byte{0, 0, 0, 3, 10, 1, 42, 23, 0, 53, 0},
decodeOnly: true,
},
}
for _, test := range tests {
test := test
if !test.decodeOnly {
t.Run(fmt.Sprintf("encode %q", test.v), func(t *testing.T) {
var buf bytes.Buffer
err := mapEncoder(reflect.TypeOf(test.v), test.encK, test.encV)(reflect.ValueOf(test.v), &buf)
if err != nil {
t.Fatalf("mapEncoder(%q) = %v", test.v, err)
}
if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" {
t.Errorf("mapEncoder(%q) = %v, want %v diff(-want,+got):\n %v", test.v, buf.Bytes(), test.encoded, d)
}
})
}
t.Run(fmt.Sprintf("decode %v", test.v), func(t *testing.T) {
buf := bytes.NewBuffer(test.encoded)
rt := reflect.TypeOf(test.v)
var dec func(reflect.Value, io.Reader) error
dec = mapDecoder(rt, test.decK, test.decV)
rv := reflect.New(rt).Elem()
err := dec(rv, buf)
if err != nil {
t.Fatalf("mapDecoder(%q) = %v", test.encoded, err)
}
got := rv.Interface()
if d := cmp.Diff(test.v, got); d != "" {
t.Errorf("mapDecoder(%q) = %q, want %v diff(-want,+got):\n %v", test.encoded, got, test.v, d)
}
})
}
}