// blob: 2186fd68aee52aa4bbd3adbe913f2d514fb99bc6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package schema
import (
"testing"
"github.com/apache/arrow/go/v6/parquet"
format "github.com/apache/arrow/go/v6/parquet/internal/gen-go/parquet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
// NewPrimitive builds a thrift SchemaElement describing a primitive column
// with the given name, repetition and physical type.
//
// A negative fieldID means "no field id": FieldID is left nil in that case,
// otherwise it points at the supplied value.
func NewPrimitive(name string, repetition format.FieldRepetitionType, typ format.Type, fieldID int32) *format.SchemaElement {
	var id *int32
	if fieldID >= 0 {
		id = &fieldID
	}

	return &format.SchemaElement{
		Name:           name,
		RepetitionType: format.FieldRepetitionTypePtr(repetition),
		Type:           format.TypePtr(typ),
		FieldID:        id,
	}
}
// NewGroup builds a thrift SchemaElement describing a group node with the
// given name, repetition and number of children.
//
// A negative fieldID means "no field id": FieldID is left nil in that case,
// otherwise it points at the supplied value.
func NewGroup(name string, repetition format.FieldRepetitionType, numChildren, fieldID int32) *format.SchemaElement {
	var id *int32
	if fieldID >= 0 {
		id = &fieldID
	}

	return &format.SchemaElement{
		Name:           name,
		RepetitionType: format.FieldRepetitionTypePtr(repetition),
		NumChildren:    &numChildren,
		FieldID:        id,
	}
}
// SchemaFlattenSuite groups the tests that flatten a schema node tree into
// its thrift SchemaElement list via ToThrift.
type SchemaFlattenSuite struct {
	suite.Suite

	// name is the name given to the root group node in the tests; it is
	// assigned in SetupSuite.
	name string
}
// SetupSuite assigns the root schema name shared by the suite's tests.
func (s *SchemaFlattenSuite) SetupSuite() {
	const rootName = "parquet_schema"
	s.name = rootName
}
// TestDecimalMetadata checks that ToThrift sets precision and scale on the
// flattened element of a decimal primitive, whether the decimal is declared
// through a converted type or a logical type, and leaves both unset for a
// plain int64 column.
//
// Each section requires (rather than merely asserts) the flattened length
// before indexing into the result, so an unexpected element count is reported
// as a test failure instead of an index-out-of-range panic.
func (s *SchemaFlattenSuite) TestDecimalMetadata() {
	// Decimal declared through a converted type.
	group := MustGroup(NewGroupNodeConverted("group" /* name */, parquet.Repetitions.Repeated, FieldList{
		MustPrimitive(NewPrimitiveNodeConverted("decimal" /* name */, parquet.Repetitions.Required, parquet.Types.Int64,
			ConvertedTypes.Decimal, 0 /* type len */, 8 /* precision */, 4 /* scale */, -1 /* fieldID */)),
	}, ConvertedTypes.List, -1 /* fieldID */))
	elements := ToThrift(group)
	s.Require().Len(elements, 2)
	s.Equal("decimal", elements[1].GetName())
	s.True(elements[1].IsSetPrecision())
	s.True(elements[1].IsSetScale())

	// Decimal declared through a logical type.
	group = MustGroup(NewGroupNodeLogical("group" /* name */, parquet.Repetitions.Repeated, FieldList{
		MustPrimitive(NewPrimitiveNodeLogical("decimal" /* name */, parquet.Repetitions.Required, NewDecimalLogicalType(10 /* precision */, 5 /* scale */),
			parquet.Types.Int64, 0 /* type len */, -1 /* fieldID */)),
	}, NewListLogicalType(), -1 /* fieldID */))
	elements = ToThrift(group)
	s.Require().Len(elements, 2)
	s.Equal("decimal", elements[1].GetName())
	s.True(elements[1].IsSetPrecision())
	s.True(elements[1].IsSetScale())

	// Non-decimal column: neither the group nor the primitive may carry
	// precision or scale.
	group = MustGroup(NewGroupNodeConverted("group" /* name */, parquet.Repetitions.Repeated, FieldList{
		NewInt64Node("int64" /* name */, parquet.Repetitions.Required, -1 /* fieldID */)}, ConvertedTypes.List, -1 /* fieldID */))
	elements = ToThrift(group)
	s.Require().Len(elements, 2)
	s.Equal("int64", elements[1].GetName())
	s.False(elements[0].IsSetPrecision())
	s.False(elements[1].IsSetPrecision())
	s.False(elements[0].IsSetScale())
	s.False(elements[1].IsSetScale())
}
// TestNestedExample flattens a nested schema (a required int32 "a" and an
// optional group "bag" holding a repeated LIST group "b" with an optional
// int64 "item") and compares every element produced by ToThrift, in order,
// against hand-built thrift SchemaElements.
//
// The length is required (not just asserted) before the element-wise
// comparison so that a longer-than-expected result fails cleanly instead of
// indexing past the end of the expected slice.
func (s *SchemaFlattenSuite) TestNestedExample() {
	// The list group needs its converted and logical type set in addition to
	// what NewGroup fills in.
	listElem := NewGroup("b" /* name */, format.FieldRepetitionType_REPEATED, 1 /* numChildren */, 3 /* fieldID */)
	listElem.ConvertedType = format.ConvertedTypePtr(format.ConvertedType_LIST)
	listElem.LogicalType = &format.LogicalType{LIST: format.NewListType()}

	// Expected flattened elements, in the order they are compared below.
	expected := []*format.SchemaElement{
		NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */),
		NewPrimitive("a" /* name */, format.FieldRepetitionType_REQUIRED, format.Type_INT32, 1 /* fieldID */),
		NewGroup("bag" /* name */, format.FieldRepetitionType_OPTIONAL, 1 /* numChildren */, 2 /* fieldID */),
		listElem,
		NewPrimitive("item" /* name */, format.FieldRepetitionType_OPTIONAL, format.Type_INT64, 4 /* fieldID */),
	}

	// The same schema expressed as a node tree.
	list := MustGroup(NewGroupNodeConverted("b" /* name */, parquet.Repetitions.Repeated, FieldList{
		NewInt64Node("item" /* name */, parquet.Repetitions.Optional, 4 /* fieldID */)}, ConvertedTypes.List, 3 /* fieldID */))
	fields := FieldList{
		NewInt32Node("a" /* name */, parquet.Repetitions.Required, 1 /* fieldID */),
		MustGroup(NewGroupNode("bag" /* name */, parquet.Repetitions.Optional, FieldList{list}, 2 /* fieldID */)),
	}
	sc := MustGroup(NewGroupNode(s.name, parquet.Repetitions.Repeated, fields, 0 /* fieldID */))

	flattened := ToThrift(sc)
	s.Require().Len(flattened, len(expected))
	for idx, elem := range flattened {
		s.Equal(expected[idx], elem, "flattened element %d", idx)
	}
}
func TestSchemaFlatten(t *testing.T) {
suite.Run(t, new(SchemaFlattenSuite))
}
func TestInvalidConvertedTypeInDeserialize(t *testing.T) {
n := MustPrimitive(NewPrimitiveNodeLogical("string" /* name */, parquet.Repetitions.Required, StringLogicalType{},
parquet.Types.ByteArray, -1 /* type len */, -1 /* fieldID */))
assert.True(t, n.LogicalType().Equals(StringLogicalType{}))
assert.True(t, n.LogicalType().IsValid())
assert.True(t, n.LogicalType().IsSerialized())
intermediary := n.toThrift()
// corrupt it
intermediary.LogicalType.STRING = nil
assert.Panics(t, func() {
PrimitiveNodeFromThrift(intermediary)
})
}
func TestInvalidTimeUnitInTimeLogical(t *testing.T) {
n := MustPrimitive(NewPrimitiveNodeLogical("time" /* name */, parquet.Repetitions.Required,
NewTimeLogicalType(true /* adjustedToUTC */, TimeUnitNanos), parquet.Types.Int64, -1 /* type len */, -1 /* fieldID */))
intermediary := n.toThrift()
// corrupt it
intermediary.LogicalType.TIME.Unit.NANOS = nil
assert.Panics(t, func() {
PrimitiveNodeFromThrift(intermediary)
})
}
func TestInvalidTimeUnitInTimestampLogical(t *testing.T) {
n := MustPrimitive(NewPrimitiveNodeLogical("time" /* name */, parquet.Repetitions.Required,
NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitNanos), parquet.Types.Int64, -1 /* type len */, -1 /* fieldID */))
intermediary := n.toThrift()
// corrupt it
intermediary.LogicalType.TIMESTAMP.Unit.NANOS = nil
assert.Panics(t, func() {
PrimitiveNodeFromThrift(intermediary)
})
}