blob: f5d3102a1d19fe3f9095aef32bd4930b3987299f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package harness
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/apache/beam/sdks/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/go/pkg/beam/log"
pb "github.com/apache/beam/sdks/go/pkg/beam/model/fnexecution_v1"
"github.com/golang/protobuf/proto"
)
// ScopedSideInputReader scopes the global gRPC state manager to a single instruction
// for side input use. The indirection makes it easier to control access.
type ScopedSideInputReader struct {
mgr *StateChannelManager
instID string
opened []io.Closer // track open readers to force close all
closed bool
mu sync.Mutex
}
// NewScopedSideInputReader returns a ScopedSideInputReader for the given instruction.
func NewScopedSideInputReader(mgr *StateChannelManager, instID string) *ScopedSideInputReader {
return &ScopedSideInputReader{mgr: mgr, instID: instID}
}
func (s *ScopedSideInputReader) Open(ctx context.Context, id exec.StreamID, key, w []byte) (io.ReadCloser, error) {
ch, err := s.open(ctx, id.Port)
if err != nil {
return nil, err
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil, errors.Errorf("instruction %v no longer processing", s.instID)
}
ret := newSideInputReader(ch, id.Target, s.instID, key, w)
s.opened = append(s.opened, ret)
s.mu.Unlock()
return ret, nil
}
func (s *ScopedSideInputReader) open(ctx context.Context, port exec.Port) (*StateChannel, error) {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil, errors.Errorf("instruction %v no longer processing", s.instID)
}
local := s.mgr
s.mu.Unlock()
return local.Open(ctx, port) // don't hold lock over potentially slow operation
}
func (s *ScopedSideInputReader) Close() error {
s.mu.Lock()
s.closed = true
s.mgr = nil
for _, r := range s.opened {
r.Close() // force close all opened readers
}
s.opened = nil
s.mu.Unlock()
return nil
}
type sideInputReader struct {
instID string
key *pb.StateKey
token []byte
buf []byte
eof bool
ch *StateChannel
closed bool
mu sync.Mutex
}
func newSideInputReader(ch *StateChannel, target exec.Target, instID string, k, w []byte) *sideInputReader {
key := &pb.StateKey{
Type: &pb.StateKey_MultimapSideInput_{
MultimapSideInput: &pb.StateKey_MultimapSideInput{
PtransformId: target.ID,
SideInputId: target.Name,
Window: w,
Key: k,
},
},
}
return &sideInputReader{
instID: instID,
key: key,
ch: ch,
}
}
func (r *sideInputReader) Read(buf []byte) (int, error) {
if r.buf == nil {
if r.eof {
return 0, io.EOF
}
// Buffer empty. Get next segment.
r.mu.Lock()
if r.closed {
r.mu.Unlock()
return 0, errors.New("side input closed")
}
local := r.ch
r.mu.Unlock()
req := &pb.StateRequest{
// Id: set by channel
InstructionReference: r.instID,
StateKey: r.key,
Request: &pb.StateRequest_Get{
Get: &pb.StateGetRequest{
ContinuationToken: r.token,
},
},
}
resp, err := local.Send(req)
if err != nil {
return 0, err
}
get := resp.GetGet()
r.token = get.ContinuationToken
r.buf = get.Data
if r.token == nil {
r.eof = true // no token == this is the last segment.
}
}
n := copy(buf, r.buf)
if len(r.buf) == n {
r.buf = nil
} else {
r.buf = r.buf[n:]
}
return n, nil
}
func (r *sideInputReader) Close() error {
r.mu.Lock()
r.closed = true
r.ch = nil
r.mu.Unlock()
return nil
}
// StateChannelManager manages data channels over the State API. A fixed number of channels
// are generally used, each managing multiple logical byte streams. Thread-safe.
type StateChannelManager struct {
ports map[string]*StateChannel
mu sync.Mutex
}
// Open opens a R/W StateChannel over the given port.
func (m *StateChannelManager) Open(ctx context.Context, port exec.Port) (*StateChannel, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.ports == nil {
m.ports = make(map[string]*StateChannel)
}
if con, ok := m.ports[port.URL]; ok {
return con, nil
}
ch, err := newStateChannel(ctx, port)
if err != nil {
return nil, err
}
m.ports[port.URL] = ch
return ch, nil
}
// StateChannel manages state transactions over a single gRPC connection.
// It does not need to track readers and writers as carefully as the
// DataChannel, because the state protocol is request-based.
type StateChannel struct {
id string
client pb.BeamFnState_StateClient
requests chan *pb.StateRequest
nextRequestNo int32
responses map[string]chan<- *pb.StateResponse
mu sync.Mutex
}
func newStateChannel(ctx context.Context, port exec.Port) (*StateChannel, error) {
cc, err := dial(ctx, port.URL, 15*time.Second)
if err != nil {
return nil, errors.Wrap(err, "failed to connect")
}
client, err := pb.NewBeamFnStateClient(cc).State(ctx)
if err != nil {
cc.Close()
return nil, errors.Wrap(err, "failed to connect to data service")
}
ret := &StateChannel{
id: port.URL,
client: client,
requests: make(chan *pb.StateRequest, 10),
responses: make(map[string]chan<- *pb.StateResponse),
}
go ret.read(ctx)
go ret.write(ctx)
return ret, nil
}
func (c *StateChannel) read(ctx context.Context) {
for {
msg, err := c.client.Recv()
if err != nil {
if err == io.EOF {
// TODO(herohde) 10/12/2017: can this happen before shutdown? Reconnect?
log.Warnf(ctx, "StateChannel %v closed", c.id)
return
}
panic(errors.Wrapf(err, "state channel %v bad", c.id))
}
c.mu.Lock()
ch, ok := c.responses[msg.Id]
delete(c.responses, msg.Id)
c.mu.Unlock()
if !ok {
// This can happen if Send returns an error that write handles, but
// the message was actually sent.
log.Errorf(ctx, "no consumer for state response: %v", proto.MarshalTextString(msg))
continue
}
select {
case ch <- msg:
// ok
default:
panic(fmt.Sprintf("failed to consume state response: %v", proto.MarshalTextString(msg)))
}
}
}
func (c *StateChannel) write(ctx context.Context) {
for req := range c.requests {
err := c.client.Send(req)
if err == nil {
continue // ok
}
// Failed to send. Return error.
c.mu.Lock()
ch, ok := c.responses[req.Id]
delete(c.responses, req.Id)
c.mu.Unlock()
if ok {
ch <- &pb.StateResponse{Id: req.Id, Error: fmt.Sprintf("failed to send: %v", err)}
} // else ignore: already received response due to race
}
}
// Send sends a state request and returns the response.
func (c *StateChannel) Send(req *pb.StateRequest) (*pb.StateResponse, error) {
id := fmt.Sprintf("r%v", atomic.AddInt32(&c.nextRequestNo, 1))
req.Id = id
ch := make(chan *pb.StateResponse, 1)
c.mu.Lock()
c.responses[id] = ch
c.mu.Unlock()
c.requests <- req
// TODO(herohde) 7/21/2018: time out?
resp := <-ch
if resp.Error != "" {
return nil, errors.New(resp.Error)
}
return resp, nil
}