// Source viewer residue, kept as a comment so the file parses: blob 375fd8f37bca80d2d1d0837990f9872e20a42e7c
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package harness
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"log"
"strings"
"sync"
"testing"
"time"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
)
// fakeStateClient is a test double that mimics the call and response
// protocol of the state channel. Each direction is backed by a channel that
// the test drives from the other end, plus an error value that the
// corresponding method reports once its channel operation completes.
type fakeStateClient struct {
	// Read side: Recv blocks receiving from recv, then returns recvErr.
	// recvMu guards recvErr.
	recvMu  sync.Mutex
	recvErr error
	recv    chan *fnpb.StateResponse

	// Write side: Send blocks sending on send, then returns sendErr.
	// sendMu guards sendErr.
	sendMu  sync.Mutex
	sendErr error
	send    chan *fnpb.StateRequest
}
// Recv waits for the test to supply a response on f.recv and returns it
// together with whatever receive error is configured at that moment.
func (f *fakeStateClient) Recv() (*fnpb.StateResponse, error) {
	// Parks here until a response is handed over.
	resp := <-f.recv

	// Snapshot the error under the lock; it may be changed by setRecvErr.
	f.recvMu.Lock()
	err := f.recvErr
	f.recvMu.Unlock()

	return resp, err
}
// Send hands req to the test via f.send, blocking until it is consumed, and
// then returns whatever send error is configured at that moment.
func (f *fakeStateClient) Send(req *fnpb.StateRequest) error {
	// Parks here until the test receives the request.
	f.send <- req

	// Snapshot the error under the lock; it may be changed by setSendErr.
	f.sendMu.Lock()
	err := f.sendErr
	f.sendMu.Unlock()

	return err
}
// setRecvErr sets the error that Recv returns, holding recvMu while writing.
func (f *fakeStateClient) setRecvErr(err error) {
	f.recvMu.Lock()
	f.recvErr = err
	f.recvMu.Unlock()
}
// setSendErr sets the error that Send returns, holding sendMu while writing.
func (f *fakeStateClient) setSendErr(err error) {
	f.sendMu.Lock()
	f.sendErr = err
	f.sendMu.Unlock()
}
// TestStateChannel drives StateChannel.Send against a fakeStateClient. It
// covers request/response routing plus read-side and write-side stream
// failures. After each case it issues one more Send and checks that it
// reports the same expected error; for cases with validateCancelled set it
// also checks that the context was cancelled and that forceRecreate was
// called with the expected error.
func TestStateChannel(t *testing.T) {
	// The logging of channels closed is quite noisy for this test.
	// The previous writer is restored on exit so the silencing does not
	// leak into other tests in the package.
	origLogOutput := log.Writer()
	log.SetOutput(ioutil.Discard)
	defer log.SetOutput(origLogOutput)

	expectedError := fmt.Errorf("EXPECTED ERROR")

	tests := []struct {
		name              string
		caseFn            func(t *testing.T, c *StateChannel, client *fakeStateClient) error
		expectedErr       error
		validateCancelled bool
	}{
		{
			name: "HappyPath",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				// Verify that we can send a bunch of requests and get the expected paired response back.
				// This is not a real StateAPI protocol, since that's handled at a higher level than the StateChannel,
				// but validates that the routing occurs correctly.
				const count = 10
				for i := 0; i < count; i++ {
					go func() {
						req := <-client.send
						client.recv <- &fnpb.StateResponse{
							Id: req.Id, // Ids need to match up to ensure routing can occur properly.
							Response: &fnpb.StateResponse_Get{
								Get: &fnpb.StateGetResponse{
									ContinuationToken: req.GetGet().GetContinuationToken(),
								},
							},
						}
					}()
				}
				for i := 0; i < count; i++ {
					token := []byte(fmt.Sprintf("%d", i))
					resp, err := c.Send(&fnpb.StateRequest{
						Request: &fnpb.StateRequest_Get{
							Get: &fnpb.StateGetRequest{
								ContinuationToken: token,
							},
						},
					})
					if err != nil {
						t.Fatalf("unexpected error from Send: %v", err)
					}
					if got, want := string(resp.GetGet().GetContinuationToken()), string(token); got != want {
						t.Fatalf("req/response mismatch from Send: got %v, want %v", got, want)
					}
				}
				return nil
			},
		}, {
			name: "readEOF",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				go func() {
					req := <-client.send // Send should succeed.
					client.setRecvErr(io.EOF)
					client.recv <- &fnpb.StateResponse{
						Id: req.Id,
					}
				}()
				_, err := c.Send(&fnpb.StateRequest{})
				return err
			},
			expectedErr:       io.EOF,
			validateCancelled: true,
		}, {
			name: "readOtherErr",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				go func() {
					req := <-client.send // Send should succeed.
					client.setRecvErr(expectedError)
					client.recv <- &fnpb.StateResponse{
						Id: req.Id,
					}
				}()
				_, err := c.Send(&fnpb.StateRequest{})
				return err
			},
			expectedErr:       expectedError,
			validateCancelled: true,
		}, {
			name: "readResponseChannelDeleted",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				go func() {
					req := <-client.send // Send should succeed.
					// Remove the response channel from the routing table
					// before the response arrives, keeping a handle to it.
					c.mu.Lock()
					ch := c.responses[req.Id]
					delete(c.responses, req.Id)
					c.mu.Unlock()
					resp := &fnpb.StateResponse{
						Id: req.Id,
					}
					client.recv <- resp
					// unblock Send.
					ch <- resp
				}()
				_, err := c.Send(&fnpb.StateRequest{})
				return err
			},
		}, {
			name: "writeEOF",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				go func() {
					client.setSendErr(io.EOF)
					req := <-client.send
					// This can be plumbed through on either side, write or read,
					// the important part is that we get it.
					client.setRecvErr(expectedError)
					client.recv <- &fnpb.StateResponse{
						Id: req.Id,
					}
				}()
				_, err := c.Send(&fnpb.StateRequest{})
				return err
			},
			expectedErr:       expectedError,
			validateCancelled: true,
		}, {
			name: "writeOtherError",
			caseFn: func(t *testing.T, c *StateChannel, client *fakeStateClient) error {
				go func() {
					client.setSendErr(expectedError)
					<-client.send
					// Shouldn't need to unblock any Recv calls.
				}()
				_, err := c.Send(&fnpb.StateRequest{})
				return err
			},
			expectedErr:       expectedError,
			validateCancelled: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			client := &fakeStateClient{
				recv: make(chan *fnpb.StateResponse),
				send: make(chan *fnpb.StateRequest),
			}
			ctx, cancelFn := context.WithCancel(context.Background())
			// Cases that never fail the stream would otherwise leave the
			// context uncancelled. Calling a CancelFunc again is a no-op.
			defer cancelFn()
			c := makeStateChannel(ctx, "id", client, cancelFn)

			forceRecreateCalled := false
			var forceRecreateError error
			c.forceRecreate = func(_ string, err error) {
				forceRecreateCalled = true
				forceRecreateError = err
			}

			retErr := test.caseFn(t, c, client)
			if got, want := retErr, test.expectedErr; !contains(got, want) {
				t.Errorf("Unexpected error: got %v, want %v", got, want)
			}

			// Closed when the subtest ends, so the helper goroutine below
			// cannot outlive it if the follow-up Send fails without ever
			// reaching the client.
			done := make(chan struct{})
			defer close(done)

			// Verify that new Sends return the same error on their reads after client.Recv is done.
			go func() {
				// Ensure that the client isn't helping us.
				client.setSendErr(nil)
				client.setRecvErr(nil)
				// Drain the next send, and ensure the response is unblocked.
				select {
				case req := <-client.send:
					select {
					case client.recv <- &fnpb.StateResponse{Id: req.Id}: // Ids need to match up to ensure routing can occur properly.
					case <-done:
					}
				case <-done:
				}
			}()
			if _, err := c.Send(&fnpb.StateRequest{}); !contains(err, test.expectedErr) {
				t.Errorf("Unexpected error from Send: got %v, want %v", err, test.expectedErr)
			}

			if test.validateCancelled {
				select {
				case <-ctx.Done(): // Assert that the context must have been cancelled on read failures.
				case <-time.After(time.Second * 5):
					t.Fatal("context wasn't cancelled")
				}
				if !forceRecreateCalled {
					t.Fatal("forceRecreate wasn't called")
				}
				if got, want := forceRecreateError, test.expectedErr; !contains(got, want) {
					t.Errorf("Unexpected error from forceRecreate: got %v, want %v", got, want)
				}
			}
		})
	}
}
// TestStateKeyReader validates ordinary Read cases.
//
// Each case lists the sizes of the data buffers a responder goroutine
// returns, one per request (a negative size means a nil buffer). The reader
// is read in readLen-sized chunks until it returns an error; the test then
// checks the number of Read calls, the total byte count, and the final error
// (io.EOF, or "side input closed" when the reader was closed up front).
func TestStateKeyReader(t *testing.T) {
	const readLen = 4
	tests := []struct {
		name     string
		buflens  []int // sizes of the buffers received on the state channel.
		numReads int
		closed   bool // tries to read from closed reader
		noGet    bool // tries to read from nil get response reader
	}{
		{
			name:     "emptyData",
			buflens:  []int{-1},
			numReads: 1,
		}, {
			name:     "singleBufferSingleRead",
			buflens:  []int{readLen},
			numReads: 2,
		}, {
			name:     "singleBufferMultipleReads",
			buflens:  []int{2 * readLen},
			numReads: 3,
		}, {
			name:     "singleBufferShortRead",
			buflens:  []int{readLen - 1},
			numReads: 2,
		}, {
			name:     "multiBuffer",
			buflens:  []int{readLen, readLen},
			numReads: 3,
		}, {
			name:     "multiBuffer-short-reads",
			buflens:  []int{readLen - 1, readLen - 1, readLen - 2},
			numReads: 4,
		}, {
			name:     "emptyDataFirst", // Shouldn't happen, but not unreasonable to handle.
			buflens:  []int{-1, readLen, readLen},
			numReads: 4,
		}, {
			name:     "emptyDataMid", // Shouldn't happen, but not unreasonable to handle.
			buflens:  []int{readLen, readLen, -1, readLen},
			numReads: 5,
		}, {
			name:     "emptyDataLast", // Shouldn't happen, but not unreasonable to handle.
			buflens:  []int{readLen, readLen, -1},
			numReads: 3,
		}, {
			name:     "emptyDataLast-short",
			buflens:  []int{3*readLen - 2, -1},
			numReads: 4,
		}, {
			name:     "closed",
			buflens:  []int{-1, -1},
			numReads: 1,
			closed:   true,
		}, {
			name:     "noGet",
			buflens:  []int{-1},
			numReads: 1,
			noGet:    true,
		},
	}
	for _, test := range tests {
		// The responder goroutine reads test and may still be running when
		// the loop advances, so give each iteration its own copy (needed
		// before Go 1.22, harmless afterwards).
		test := test
		t.Run(test.name, func(t *testing.T) {
			ctx, cancelFn := context.WithCancel(context.Background())
			// Release the context and stop the responder goroutine when the
			// subtest ends.
			defer cancelFn()
			ch := &StateChannel{
				id:        "test",
				requests:  make(chan *fnpb.StateRequest),
				responses: make(map[string]chan<- *fnpb.StateResponse),
				cancelFn:  cancelFn,
				DoneCh:    ctx.Done(),
			}

			// nextRequest waits for the reader's next request. It gives up
			// once the subtest's context is cancelled, so the responder does
			// not block forever when the reader stops asking (for example
			// when the reader was closed before the first Read).
			nextRequest := func() (*fnpb.StateRequest, bool) {
				select {
				case req := <-ch.requests:
					return req, true
				case <-ctx.Done():
					return nil, false
				}
			}

			// Handle the channel behavior asynchronously.
			go func() {
				if test.noGet {
					req, ok := nextRequest()
					if !ok {
						return
					}
					// Respond without a Get payload.
					ch.responses[req.Id] <- &fnpb.StateResponse{
						Id: req.Id,
					}
					return
				}
				for i, buflen := range test.buflens {
					var buf []byte
					if buflen >= 0 {
						buf = bytes.Repeat([]byte{42}, buflen)
					}
					token := []byte(fmt.Sprint(i))
					if i+1 == len(test.buflens) {
						// On the last request response pair, send no token.
						token = nil
					}
					req, ok := nextRequest()
					if !ok {
						return
					}
					ch.responses[req.Id] <- &fnpb.StateResponse{
						Id: req.Id,
						Response: &fnpb.StateResponse_Get{
							Get: &fnpb.StateGetResponse{
								ContinuationToken: token,
								Data:              buf,
							},
						},
					}
				}
			}()

			r := stateKeyReader{
				ch: ch,
			}

			if test.closed {
				err := r.Close()
				if err != nil {
					t.Errorf("unexpected error on Close(), got %v", err)
				}
			}

			// Only positive sizes contribute data; negative entries stand
			// for nil buffers.
			var totalBytes int
			for _, l := range test.buflens {
				if l > 0 {
					totalBytes += l
				}
			}

			var finalerr error
			var count, reads int

			// Read all the bytes.
			for count <= totalBytes {
				reads++
				b := make([]byte, readLen) // io.Read is keyed off of length, not capacity.
				n, err := r.Read(b)
				if err != nil {
					finalerr = err
					break
				}
				count += n
				// Special check to avoid spurious zero elements.
				if count == totalBytes && n == 0 {
					t.Error("expected byte count read, last read is 0, but no EOF")
				}
			}
			if got, want := reads, test.numReads; got != want {
				t.Errorf("read %d times, want %d", got, want)
			}
			if got, want := count, totalBytes; got != want {
				t.Errorf("read %v bytes, want %v", got, want)
			}

			// finalerr stays nil if the loop ended because count exceeded
			// totalBytes, so the checks below must not call Error() on it.
			if test.closed {
				if got, want := finalerr, errors.New("side input closed"); got == nil || !contains(got, want) {
					t.Errorf("got err %v, want to contain %q", got, want.Error())
				}
				return
			}
			if got, want := finalerr, io.EOF; got != want {
				t.Errorf("got err %v, want %v", got, want)
			}
		})
	}
}
// contains reports whether got matches want: either the two are the same
// error value (including both being nil), or got's message includes want's
// message as a substring.
//
// This likely can't be replaced by the "errors" package helpers,
// since we serialize errors in some cases.
//
// If exactly one of got and want is nil the result is false; calling Error()
// on a nil error would panic.
func contains(got, want error) bool {
	if got == want {
		return true
	}
	if got == nil || want == nil {
		return false
	}
	return strings.Contains(got.Error(), want.Error())
}