blob: 325a96274c2e2383df1f3da21b839a456017f4da [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package snippets
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder/testutil"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/schema"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
// atomicSchemaField builds a schema field called name whose type is the
// atomic type typ.
func atomicSchemaField(name string, typ pipepb.AtomicType) *pipepb.Field {
	fieldType := &pipepb.FieldType{
		TypeInfo: &pipepb.FieldType_AtomicType{AtomicType: typ},
	}
	return &pipepb.Field{Name: name, Type: fieldType}
}
// rowSchemaField builds a schema field called name whose type is a row
// described by the schema typ.
func rowSchemaField(name string, typ *pipepb.Schema) *pipepb.Field {
	row := &pipepb.RowType{Schema: typ}
	fieldType := &pipepb.FieldType{
		TypeInfo: &pipepb.FieldType_RowType{RowType: row},
	}
	return &pipepb.Field{Name: name, Type: fieldType}
}
// listSchemaField builds a schema field called name with an array type.
// Only the type of typ is read (as the array's element type); its other
// fields, including its name, are not used.
func listSchemaField(name string, typ *pipepb.Field) *pipepb.Field {
	array := &pipepb.ArrayType{ElementType: typ.GetType()}
	fieldType := &pipepb.FieldType{
		TypeInfo: &pipepb.FieldType_ArrayType{ArrayType: array},
	}
	return &pipepb.Field{Name: name, Type: fieldType}
}
// nillable sets the Nullable flag on f's type and returns f itself.
// The field is modified in place; no copy is made.
func nillable(f *pipepb.Field) *pipepb.Field {
	fieldType := f.Type
	fieldType.Nullable = true
	return f
}
// TestSchemaTypes checks, for each Go type in the table, that the schema
// produced by a fresh schema.Registry's FromType matches the expected
// schema proto. The "id" and "options" fields of Schema messages are
// ignored in the comparison.
func TestSchemaTypes(t *testing.T) {
	// Expected schemas that are reused by more than one table entry.
	transactionSchema := &pipepb.Schema{
		Fields: []*pipepb.Field{
			atomicSchemaField("bank", pipepb.AtomicType_STRING),
			atomicSchemaField("purchaseAmount", pipepb.AtomicType_DOUBLE),
		},
	}
	shippingAddressSchema := &pipepb.Schema{
		Fields: []*pipepb.Field{
			atomicSchemaField("streetAddress", pipepb.AtomicType_STRING),
			atomicSchemaField("city", pipepb.AtomicType_STRING),
			nillable(atomicSchemaField("state", pipepb.AtomicType_STRING)),
			atomicSchemaField("country", pipepb.AtomicType_STRING),
			atomicSchemaField("postCode", pipepb.AtomicType_STRING),
		},
	}
	purchaseSchema := &pipepb.Schema{
		Fields: []*pipepb.Field{
			atomicSchemaField("userId", pipepb.AtomicType_STRING),
			atomicSchemaField("itemId", pipepb.AtomicType_INT64),
			rowSchemaField("shippingAddress", shippingAddressSchema),
			atomicSchemaField("cost", pipepb.AtomicType_INT64),
			// The inner field name is a placeholder: listSchemaField reads
			// only the type of the field it is given.
			listSchemaField("transactions",
				rowSchemaField("n/a", transactionSchema)),
		},
	}
	timestampSchema := &pipepb.Schema{
		Fields: []*pipepb.Field{
			atomicSchemaField("seconds", pipepb.AtomicType_INT64),
			atomicSchemaField("nanos", pipepb.AtomicType_INT32),
		},
	}

	type testCase struct {
		rt     reflect.Type                // Go type handed to FromType
		st     *pipepb.Schema              // expected schema
		preReg func(reg *schema.Registry) // optional registry setup, may be nil
	}
	cases := []testCase{
		{rt: reflect.TypeOf(Transaction{}), st: transactionSchema},
		{rt: reflect.TypeOf(ShippingAddress{}), st: shippingAddressSchema},
		{rt: reflect.TypeOf(Purchase{}), st: purchaseSchema},
		{
			rt: tnType,
			st: timestampSchema,
			// Register a logical type built from tnType and tnStorageType
			// before asking the registry for tnType's schema.
			preReg: func(reg *schema.Registry) {
				reg.RegisterLogicalType(schema.ToLogicalType(tnType.Name(), tnType, tnStorageType))
			},
		},
	}

	// Comparison options shared by every case.
	var schemaMsg proto.Message = &pipepb.Schema{}
	diffOpts := []cmp.Option{
		protocmp.Transform(),
		protocmp.IgnoreFields(schemaMsg, "id", "options"),
	}

	for _, tc := range cases {
		t.Run(fmt.Sprint(tc.rt), func(t *testing.T) {
			registry := schema.NewRegistry()
			if setup := tc.preReg; setup != nil {
				setup(registry)
			}
			got, err := registry.FromType(tc.rt)
			if err != nil {
				t.Fatalf("error FromType(%v) = %v", tc.rt, err)
			}
			if d := cmp.Diff(tc.st, got, diffOpts...); d != "" {
				t.Errorf("diff (-want, +got): %v", d)
			}
		})
	}
}
// TestSchema_validate runs testutil.SchemaCoder.Validate for each table
// entry, passing the entry's type, the provider's BuildEncoder and
// BuildDecoder methods, and the storage and logical example values.
func TestSchema_validate(t *testing.T) {
	type testCase struct {
		rt               reflect.Type
		p                beam.SchemaProvider
		logical, storage any
	}
	cases := []testCase{{
		rt:      tnType,
		p:       &TimestampNanosProvider{},
		logical: TimestampNanos(time.Unix(2300003, 456789)),
		storage: tnStorage{},
	}}

	// Two TimestampNanos values are treated as equal when both their
	// Seconds and Nanos results match.
	sameTimestamp := func(a, b TimestampNanos) bool {
		return a.Seconds() == b.Seconds() && a.Nanos() == b.Nanos()
	}

	for _, tc := range cases {
		// A new SchemaCoder is built for every case.
		checker := &testutil.SchemaCoder{
			CmpOptions: cmp.Options{cmp.Comparer(sameTimestamp)},
		}
		checker.Validate(t, tc.rt, tc.p.BuildEncoder, tc.p.BuildDecoder, tc.storage, tc.logical)
	}
}