blob: ff20ff9e9db5d60ba6df277aa90466cf6aa51e32 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package harness
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/log"
pb "github.com/apache/beam/sdks/go/pkg/beam/model/fnexecution_v1"
"github.com/golang/protobuf/proto"
)
// ScopedStateReader scopes the global gRPC state manager to a single instruction
// for side input use. The indirection makes it easier to control access.
type ScopedStateReader struct {
mgr *StateChannelManager
instID string
opened []io.Closer // track open readers to force close all
closed bool
mu sync.Mutex
}
// NewScopedStateReader returns a ScopedStateReader for the given instruction.
func NewScopedStateReader(mgr *StateChannelManager, instID string) *ScopedStateReader {
return &ScopedStateReader{mgr: mgr, instID: instID}
}
// OpenSideInput opens a byte stream for reading iterable side input.
func (s *ScopedStateReader) OpenSideInput(ctx context.Context, id exec.StreamID, sideInputID string, key, w []byte) (io.ReadCloser, error) {
return s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
return newSideInputReader(ch, id, sideInputID, s.instID, key, w)
})
}
// OpenIterable opens a byte stream for reading unwindowed iterables from the runner.
func (s *ScopedStateReader) OpenIterable(ctx context.Context, id exec.StreamID, key []byte) (io.ReadCloser, error) {
return s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
return newRunnerReader(ch, s.instID, key)
})
}
func (s *ScopedStateReader) openReader(ctx context.Context, id exec.StreamID, readerFn func(*StateChannel) *stateKeyReader) (*stateKeyReader, error) {
ch, err := s.open(ctx, id.Port)
if err != nil {
return nil, err
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil, errors.Errorf("instruction %v no longer processing", s.instID)
}
ret := readerFn(ch)
s.opened = append(s.opened, ret)
s.mu.Unlock()
return ret, nil
}
func (s *ScopedStateReader) open(ctx context.Context, port exec.Port) (*StateChannel, error) {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil, errors.Errorf("instruction %v no longer processing", s.instID)
}
local := s.mgr
s.mu.Unlock()
return local.Open(ctx, port) // don't hold lock over potentially slow operation
}
// Close closes all open readers.
func (s *ScopedStateReader) Close() error {
s.mu.Lock()
s.closed = true
s.mgr = nil
for _, r := range s.opened {
r.Close() // force close all opened readers
}
s.opened = nil
s.mu.Unlock()
return nil
}
type stateKeyReader struct {
instID string
key *pb.StateKey
token []byte
buf []byte
eof bool
ch *StateChannel
closed bool
mu sync.Mutex
}
func newSideInputReader(ch *StateChannel, id exec.StreamID, sideInputID string, instID string, k, w []byte) *stateKeyReader {
key := &pb.StateKey{
Type: &pb.StateKey_MultimapSideInput_{
MultimapSideInput: &pb.StateKey_MultimapSideInput{
TransformId: id.PtransformID,
SideInputId: sideInputID,
Window: w,
Key: k,
},
},
}
return &stateKeyReader{
instID: instID,
key: key,
ch: ch,
}
}
func newRunnerReader(ch *StateChannel, instID string, k []byte) *stateKeyReader {
key := &pb.StateKey{
Type: &pb.StateKey_Runner_{
Runner: &pb.StateKey_Runner{
Key: k,
},
},
}
return &stateKeyReader{
instID: instID,
key: key,
ch: ch,
}
}
func (r *stateKeyReader) Read(buf []byte) (int, error) {
if r.buf == nil {
if r.eof {
return 0, io.EOF
}
// Buffer empty. Get next segment.
r.mu.Lock()
if r.closed {
r.mu.Unlock()
return 0, errors.New("side input closed")
}
local := r.ch
r.mu.Unlock()
req := &pb.StateRequest{
// Id: set by channel
InstructionId: r.instID,
StateKey: r.key,
Request: &pb.StateRequest_Get{
Get: &pb.StateGetRequest{
ContinuationToken: r.token,
},
},
}
resp, err := local.Send(req)
if err != nil {
return 0, err
}
get := resp.GetGet()
if get == nil { // no data associated with this segment.
r.eof = true
return 0, io.EOF
}
r.token = get.GetContinuationToken()
r.buf = get.GetData()
if r.token == nil {
r.eof = true // no token == this is the last segment.
}
}
n := copy(buf, r.buf)
if len(r.buf) == n {
r.buf = nil
} else {
r.buf = r.buf[n:]
}
return n, nil
}
func (r *stateKeyReader) Close() error {
r.mu.Lock()
r.closed = true
r.ch = nil
r.mu.Unlock()
return nil
}
// StateChannelManager manages data channels over the State API. A fixed number of channels
// are generally used, each managing multiple logical byte streams. Thread-safe.
type StateChannelManager struct {
ports map[string]*StateChannel
mu sync.Mutex
}
// Open opens a R/W StateChannel over the given port.
func (m *StateChannelManager) Open(ctx context.Context, port exec.Port) (*StateChannel, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.ports == nil {
m.ports = make(map[string]*StateChannel)
}
if con, ok := m.ports[port.URL]; ok {
return con, nil
}
ch, err := newStateChannel(ctx, port)
if err != nil {
return nil, err
}
m.ports[port.URL] = ch
return ch, nil
}
// StateChannel manages state transactions over a single gRPC connection.
// It does not need to track readers and writers as carefully as the
// DataChannel, because the state protocol is request-based.
type StateChannel struct {
id string
client pb.BeamFnState_StateClient
requests chan *pb.StateRequest
nextRequestNo int32
responses map[string]chan<- *pb.StateResponse
mu sync.Mutex
}
func newStateChannel(ctx context.Context, port exec.Port) (*StateChannel, error) {
cc, err := dial(ctx, port.URL, 15*time.Second)
if err != nil {
return nil, errors.Wrapf(err, "failed to connect to state service %v", port.URL)
}
client, err := pb.NewBeamFnStateClient(cc).State(ctx)
if err != nil {
cc.Close()
return nil, errors.Wrapf(err, "failed to create state client %v", port.URL)
}
ret := &StateChannel{
id: port.URL,
client: client,
requests: make(chan *pb.StateRequest, 10),
responses: make(map[string]chan<- *pb.StateResponse),
}
go ret.read(ctx)
go ret.write(ctx)
return ret, nil
}
func (c *StateChannel) read(ctx context.Context) {
for {
msg, err := c.client.Recv()
if err != nil {
if err == io.EOF {
// TODO(herohde) 10/12/2017: can this happen before shutdown? Reconnect?
log.Warnf(ctx, "StateChannel[%v].read: closed", c.id)
return
}
log.Errorf(ctx, "StateChannel[%v].read bad: %v", c.id, err)
return
}
c.mu.Lock()
ch, ok := c.responses[msg.Id]
delete(c.responses, msg.Id)
c.mu.Unlock()
if !ok {
// This can happen if Send returns an error that write handles, but
// the message was actually sent.
log.Errorf(ctx, "StateChannel[%v].read: no consumer for state response: %v", c.id, proto.MarshalTextString(msg))
continue
}
select {
case ch <- msg:
// ok
default:
panic(fmt.Sprintf("StateChannel[%v].read: failed to consume state response: %v", c.id, proto.MarshalTextString(msg)))
}
}
}
func (c *StateChannel) write(ctx context.Context) {
for req := range c.requests {
err := c.client.Send(req)
if err == nil {
continue // ok
}
// Failed to send. Return error.
c.mu.Lock()
ch, ok := c.responses[req.Id]
delete(c.responses, req.Id)
c.mu.Unlock()
if ok {
ch <- &pb.StateResponse{Id: req.Id, Error: fmt.Sprintf("failed to send: %v", err)}
} // else ignore: already received response due to race
}
}
// Send sends a state request and returns the response.
func (c *StateChannel) Send(req *pb.StateRequest) (*pb.StateResponse, error) {
id := fmt.Sprintf("r%v", atomic.AddInt32(&c.nextRequestNo, 1))
req.Id = id
ch := make(chan *pb.StateResponse, 1)
c.mu.Lock()
c.responses[id] = ch
c.mu.Unlock()
c.requests <- req
// TODO(herohde) 7/21/2018: time out?
resp := <-ch
if resp.Error != "" {
return nil, errors.New(resp.Error)
}
return resp, nil
}