blob: 627e67a919b229699edefc38ed402572c2b2ac1f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipc // import "github.com/apache/arrow/go/arrow/ipc"
import (
"bytes"
"io"
"sync/atomic"
"github.com/apache/arrow/go/arrow"
"github.com/apache/arrow/go/arrow/array"
"github.com/apache/arrow/go/arrow/internal/debug"
"github.com/apache/arrow/go/arrow/internal/flatbuf"
"github.com/apache/arrow/go/arrow/memory"
"github.com/pkg/errors"
)
// Reader reads records from an io.Reader.
// Reader expects a schema (plus any dictionaries) as the first messages
// in the stream, followed by records.
type Reader struct {
r *MessageReader
schema *arrow.Schema
refCount int64
rec array.Record
err error
types dictTypeMap
memo dictMemo
mem memory.Allocator
done bool
}
// NewReader returns a reader that reads records from an input stream.
func NewReader(r io.Reader, opts ...Option) (*Reader, error) {
cfg := newConfig()
for _, opt := range opts {
opt(cfg)
}
rr := &Reader{
r: NewMessageReader(r),
types: make(dictTypeMap),
memo: newMemo(),
mem: cfg.alloc,
}
err := rr.readSchema(cfg.schema)
if err != nil {
return nil, errors.Wrap(err, "arrow/ipc: could not read schema from stream")
}
return rr, nil
}
// Err returns the last error encountered during the iteration over the
// underlying stream.
func (r *Reader) Err() error { return r.err }
func (r *Reader) Schema() *arrow.Schema { return r.schema }
func (r *Reader) readSchema(schema *arrow.Schema) error {
msg, err := r.r.Message()
if err != nil {
return errors.Wrap(err, "arrow/ipc: could not read message schema")
}
if msg.Type() != MessageSchema {
return errors.Errorf("arrow/ipc: invalid message type (got=%v, want=%v)", msg.Type(), MessageSchema)
}
// FIXME(sbinet) refactor msg-header handling.
var schemaFB flatbuf.Schema
initFB(&schemaFB, msg.msg.Header)
r.types, err = dictTypesFromFB(&schemaFB)
if err != nil {
return errors.Wrap(err, "arrow/ipc: could read dictionary types from message schema")
}
// TODO(sbinet): in the future, we may want to reconcile IDs in the stream with
// those found in the schema.
for range r.types {
panic("not implemented") // FIXME(sbinet): ReadNextDictionary
}
r.schema, err = schemaFromFB(&schemaFB, &r.memo)
if err != nil {
return errors.Wrap(err, "arrow/ipc: could not decode schema from message schema")
}
// check the provided schema match the one read from stream.
if schema != nil && !schema.Equal(r.schema) {
return errInconsistentSchema
}
return nil
}
// Retain increases the reference count by 1.
// Retain may be called simultaneously from multiple goroutines.
func (r *Reader) Retain() {
atomic.AddInt64(&r.refCount, 1)
}
// Release decreases the reference count by 1.
// When the reference count goes to zero, the memory is freed.
// Release may be called simultaneously from multiple goroutines.
func (r *Reader) Release() {
debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases")
if atomic.AddInt64(&r.refCount, -1) == 0 {
if r.rec != nil {
r.rec.Release()
r.rec = nil
}
if r.r != nil {
r.r.Release()
r.r = nil
}
}
}
// Next returns whether a Record could be extracted from the underlying stream.
func (r *Reader) Next() bool {
if r.rec != nil {
r.rec.Release()
r.rec = nil
}
if r.err != nil || r.done {
return false
}
return r.next()
}
func (r *Reader) next() bool {
var msg *Message
msg, r.err = r.r.Message()
if r.err != nil {
r.done = true
if r.err == io.EOF {
r.err = nil
}
return false
}
if got, want := msg.Type(), MessageRecordBatch; got != want {
r.err = errors.Errorf("arrow/ipc: invalid message type (got=%v, want=%v", got, want)
return false
}
r.rec = newRecord(r.schema, msg.meta, bytes.NewReader(msg.body.Bytes()))
return true
}
// Record returns the current record that has been extracted from the
// underlying stream.
// It is valid until the next call to Next.
func (r *Reader) Record() array.Record {
return r.rec
}
// Read reads the current record from the underlying stream and an error, if any.
// When the Reader reaches the end of the underlying stream, it returns (nil, io.EOF).
func (r *Reader) Read() (array.Record, error) {
if r.rec != nil {
r.rec.Release()
r.rec = nil
}
if !r.next() {
if r.done {
return nil, io.EOF
}
return nil, r.err
}
return r.rec, nil
}
var (
_ array.RecordReader = (*Reader)(nil)
)