blob: 282cbf213010b9bf20912e6df014bd629d35a9f8 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tensor
import (
"github.com/apache/arrow/go/arrow"
"github.com/apache/arrow/go/arrow/array"
)
{{range .In}}
// {{.Name}} is an n-dim array of {{.Type}}s.
// It embeds tensorBase and additionally keeps a typed view of the elements.
type {{.Name}} struct {
tensorBase
// values is the element storage as a typed slice. New{{.Name}} fills it from
// the second buffer of the backing array data, restricted to the range
// [Offset, Offset+Len); it is left nil when that buffer is nil.
values []{{.Type}}
}
// New{{.Name}} builds an n-dimensional tensor of {{.Type}}s on top of data.
// If strides is nil, row-major strides will be inferred.
// If names is nil, a slice of empty strings will be created.
//
// The element slice is taken from the second buffer of the tensor's data and
// narrowed to the window [Offset, Offset+Len). When that buffer is nil the
// tensor is returned without any values attached.
func New{{.Name}}(data *array.Data, shape, strides []int64, names []string) *{{.Name}} {
	base := newTensor(arrow.PrimitiveTypes.{{.Name}}, data, shape, strides, names)
	tsr := &{{.Name}}{tensorBase: *base}

	buf := tsr.data.Buffers()[1]
	if buf == nil {
		return tsr
	}

	// Reinterpret the raw bytes as {{.Type}}s, then keep only the logical window.
	all := arrow.{{.Name}}Traits.CastFromBytes(buf.Bytes())
	lo := tsr.data.Offset()
	hi := lo + tsr.data.Len()
	tsr.values = all[lo:hi]
	return tsr
}
func (tsr *{{.Name}}) Value(i []int64) {{or .Type}} { j := int(tsr.offset(i)); return tsr.values[j] }
func (tsr *{{.Name}}) {{.Name}}Values() []{{or .Type}} { return tsr.values }
{{end}}
// Compile-time assertions that a pointer to each generated tensor type
// satisfies Interface; one blank-identifier entry is emitted per element of .In.
var (
{{range .In}}
_ Interface = (*{{.Name}})(nil)
{{- end}}
)