// blob: 0fa6d23674343ec3f2fb8b1a9585fa122dcbd060 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package exec
import (
"context"
"fmt"
"io"
"testing"
"github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/go/pkg/beam/core/graph/mtime"
"github.com/apache/beam/sdks/go/pkg/beam/core/graph/window"
)
// TestDataSource_PerElement streams windowed int64 elements to a DataSource
// through a pipe and verifies that each one reaches the downstream capture node.
func TestDataSource_PerElement(t *testing.T) {
	// encodeFn writes vals to the pipe with the supplied coder and closes the
	// pipe once everything has been written.
	type encodeFn func(c *coder.Coder, w io.WriteCloser, vals []interface{})

	// perElement emits one windowed-value header followed by the encoded
	// element for every value.
	perElement := func(c *coder.Coder, w io.WriteCloser, vals []interface{}) {
		winEnc := MakeWindowEncoder(c.Window)
		elmEnc := MakeElementEncoder(coder.SkipW(c))
		for i := range vals {
			EncodeWindowedValueHeader(winEnc, window.SingleGlobalWindow, mtime.ZeroTimestamp, w)
			elmEnc.Encode(&FullValue{Elm: vals[i]}, w)
		}
		w.Close()
	}

	cases := []struct {
		name      string
		want      []interface{}
		elemCoder *coder.Coder
		encode    encodeFn
	}{
		{
			name:      "perElement",
			want:      []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)},
			elemCoder: coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow()),
			encode:    perElement,
		},
		// TODO: Test progress.
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			capture := &CaptureNode{UID: 1}
			src := &DataSource{
				UID:   2,
				SID:   StreamID{PtransformID: "myPTransform"},
				Name:  tc.name,
				Coder: tc.elemCoder,
				Out:   capture,
			}

			// The encoder runs in the background and feeds the plan through the pipe.
			reader, writer := io.Pipe()
			go tc.encode(src.Coder, writer, tc.want)

			dc := DataContext{Data: &TestDataManager{R: reader}}
			constructAndExecutePlanWithContext(t, []Unit{capture, src}, dc)

			validateSource(t, capture, src, makeValues(tc.want...))
		})
	}
}
// tokenString is the token payload that the state-backed iterable test driver
// writes (length-prefixed) to the data stream.
const tokenString = "token"
// TestDataSource_Iterators covers the wire encodings for ITERABLEs (see
// beam_runner_api.proto) as read by a DataSource with a CoGBK coder.
func TestDataSource_Iterators(t *testing.T) {
	// driverFn writes every key in ks, each followed by the values in vs, to the
	// data stream dmw and closes it. newStateWriter supplies a fresh writer for
	// each state-backed value stream.
	type driverFn func(c *coder.Coder, dmw io.WriteCloser, newStateWriter func() io.WriteCloser, ks, vs []interface{})

	// encoders derives the window, key and value encoders from a windowed coder
	// whose inner coder has the key and value coders as its first two components.
	encoders := func(c *coder.Coder) (WindowEncoder, ElementEncoder, ElementEncoder) {
		winEnc := MakeWindowEncoder(c.Window)
		inner := coder.SkipW(c)
		return winEnc, MakeElementEncoder(inner.Components[0]), MakeElementEncoder(inner.Components[1])
	}

	// Fixture constructors: every test case gets its own fresh instances.
	newKeys := func() []interface{} {
		return []interface{}{int64(42), int64(53)}
	}
	newVals := func() []interface{} {
		return []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)}
	}
	newCoder := func() *coder.Coder {
		return coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow())
	}

	// singleChunk writes the element count up front, then all the values.
	singleChunk := func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []interface{}) {
		winEnc, keyEnc, valEnc := encoders(c)
		for _, key := range ks {
			EncodeWindowedValueHeader(winEnc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
			keyEnc.Encode(&FullValue{Elm: key}, dmw)
			coder.EncodeInt32(int32(len(vs)), dmw) // Total element count.
			for _, val := range vs {
				valEnc.Encode(&FullValue{Elm: val}, dmw)
			}
		}
		dmw.Close()
	}

	// multiChunk writes each value as its own one-element chunk.
	multiChunk := func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []interface{}) {
		winEnc, keyEnc, valEnc := encoders(c)
		for _, key := range ks {
			EncodeWindowedValueHeader(winEnc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
			keyEnc.Encode(&FullValue{Elm: key}, dmw)
			// -1 flags a multi-chunk iterable (the beam runner proto reportedly says to use 0).
			coder.EncodeInt32(-1, dmw)
			for _, val := range vs {
				coder.EncodeVarInt(1, dmw) // This chunk holds a single element.
				valEnc.Encode(&FullValue{Elm: val}, dmw)
			}
			coder.EncodeVarInt(0, dmw) // A zero-length chunk ends the iterable for this key.
		}
		dmw.Close()
	}

	// stateBacked writes only a token on the data stream and sends the values
	// over a separate state stream.
	stateBacked := func(c *coder.Coder, dmw io.WriteCloser, newStateWriter func() io.WriteCloser, ks, vs []interface{}) {
		winEnc, keyEnc, valEnc := encoders(c)
		for _, key := range ks {
			EncodeWindowedValueHeader(winEnc, window.SingleGlobalWindow, mtime.ZeroTimestamp, dmw)
			keyEnc.Encode(&FullValue{Elm: key}, dmw)
			// -1 flags a multi-chunk iterable (the beam runner proto reportedly says to use 0).
			coder.EncodeInt32(-1, dmw)
			coder.EncodeVarInt(-1, dmw) // The following chunks are "state backed".
			// Length-prefixed token.
			tok := []byte(tokenString)
			coder.EncodeVarInt(int64(len(tok)), dmw)
			dmw.Write(tok)
			// Every state stream must have its own writer, so request a new one per key.
			stateW := newStateWriter()
			for _, val := range vs {
				valEnc.Encode(&FullValue{Elm: val}, stateW)
			}
			stateW.Close()
		}
		dmw.Close()
	}

	cases := []struct {
		name       string
		keys, vals []interface{}
		kvCoder    *coder.Coder
		drive      driverFn
	}{
		{
			name:    "beam:coder:iterable:v1-singleChunk",
			keys:    newKeys(),
			vals:    newVals(),
			kvCoder: newCoder(),
			drive:   singleChunk,
		},
		{
			name:    "beam:coder:iterable:v1-multiChunk",
			keys:    newKeys(),
			vals:    newVals(),
			kvCoder: newCoder(),
			drive:   multiChunk,
		},
		{
			name:    "beam:coder:state_backed_iterable:v1",
			keys:    newKeys(),
			vals:    newVals(),
			kvCoder: newCoder(),
			drive:   stateBacked,
		},
		// TODO: Test progress.
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			capture := &IteratorCaptureNode{CaptureNode: CaptureNode{UID: 1}}
			src := &DataSource{
				UID:   2,
				SID:   StreamID{PtransformID: "myPTransform"},
				Name:  tc.name,
				Coder: tc.kvCoder,
				Out:   capture,
			}
			dataR, dataW := io.Pipe()

			// Individual state channels are simulated with pipes handed over a channel.
			stateReaders := make(chan io.ReadCloser)
			newStateWriter := func() io.WriteCloser {
				r, w := io.Pipe()
				stateReaders <- r
				return w
			}
			go tc.drive(src.Coder, dataW, newStateWriter, tc.keys, tc.vals)

			dc := DataContext{
				Data:  &TestDataManager{R: dataR},
				State: &TestStateReader{Rc: stateReaders},
			}
			constructAndExecutePlanWithContext(t, []Unit{capture, src}, dc)

			if len(capture.CapturedInputs) == 0 {
				t.Fatal("did not capture source output")
			}

			wantKeys := makeValues(tc.keys...)
			wantVals := makeValuesNoWindowOrTime(tc.vals...)
			if got, want := len(capture.CapturedInputs), len(wantKeys); got != want {
				t.Fatalf("lengths don't match: got %v, want %v", got, want)
			}

			// Every captured key must carry the full value list; the keys are
			// collected on the way for a final comparison.
			var gotKeys []FullValue
			for _, in := range capture.CapturedInputs {
				gotKeys = append(gotKeys, in.Key)
				if equalList(in.Values, wantVals) {
					continue
				}
				t.Errorf("DataSource => key(%v) = %#v, want %#v", in.Key, extractValues(in.Values...), extractValues(wantVals...))
			}
			if !equalList(gotKeys, wantKeys) {
				t.Errorf("DataSource => %#v, want %#v", extractValues(gotKeys...), extractValues(wantKeys...))
			}
		})
	}
}
// TestDataSource_Split checks that a split requested after StartBundle limits
// the elements a DataSource processes, and that split requests which cannot be
// satisfied return an error without breaking processing.
func TestDataSource_Split(t *testing.T) {
	elements := []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5)}

	// newSource returns a fresh DataSource, its capture node, and a reader that
	// a background goroutine fills with every element before closing it.
	newSource := func(name string) (*DataSource, *CaptureNode, io.ReadCloser) {
		capture := &CaptureNode{UID: 1}
		wcoder := coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow())
		src := &DataSource{
			UID:   2,
			SID:   StreamID{PtransformID: "myPTransform"},
			Name:  name,
			Coder: wcoder,
			Out:   capture,
		}
		reader, writer := io.Pipe()
		go func() {
			winEnc := MakeWindowEncoder(wcoder.Window)
			elmEnc := MakeElementEncoder(coder.SkipW(wcoder))
			for _, elm := range elements {
				EncodeWindowedValueHeader(winEnc, window.SingleGlobalWindow, mtime.ZeroTimestamp, writer)
				elmEnc.Encode(&FullValue{Elm: elm}, writer)
			}
			writer.Close()
		}()
		return src, capture, reader
	}

	// activate calls Up on every unit of the plan and then sets the plan
	// status to Active.
	activate := func(ctx context.Context, t *testing.T, p *Plan) {
		t.Helper()
		for idx, u := range p.units {
			if err := u.Up(ctx); err != nil {
				t.Fatalf("error in root[%d].Up: %v", idx, err)
			}
		}
		p.status = Active
	}

	// startBundle adapts Root.StartBundle to the callback shape runOnRoots expects.
	startBundle := func(dc DataContext) func(Root, context.Context) error {
		return func(r Root, ctx context.Context) error {
			return r.StartBundle(ctx, "1", dc)
		}
	}

	cases := []struct {
		name     string
		expected []interface{}
		splitIdx int64
	}{
		{splitIdx: 1},
		{splitIdx: 2},
		{splitIdx: 3},
		{splitIdx: 4},
		{splitIdx: 5},
		{
			name:     "wellBeyondRange",
			expected: elements,
			splitIdx: 1000,
		},
	}

	for _, tc := range cases {
		// Fill in defaults: name the case after its index and expect the
		// elements that precede the split point.
		name, want, splitAt := tc.name, tc.expected, tc.splitIdx
		if name == "" {
			name = fmt.Sprintf("atIndex%d", splitAt)
		}
		if want == nil {
			want = elements[:splitAt]
		}

		t.Run(name, func(t *testing.T) {
			src, capture, reader := newSource(name)
			p, err := NewPlan("a", []Unit{capture, src})
			if err != nil {
				t.Fatalf("failed to construct plan: %v", err)
			}
			dc := DataContext{Data: &TestDataManager{R: reader}}
			ctx := context.Background()

			// StartBundle resets the source, so a split cannot take effect before
			// it. The plan is therefore driven manually, with the split request
			// inserted right after StartBundle.
			activate(ctx, t, p)
			runOnRoots(ctx, t, p, "StartBundle", startBundle(dc))

			// SDK never splits on 0, so check that every test.
			gotIdx, err := p.Split(SplitPoints{Splits: []int64{0, splitAt}})
			if err != nil {
				t.Fatalf("error in Split: %v", err)
			}
			if gotIdx != splitAt {
				t.Fatalf("error in Split: got splitIdx = %v, want %v ", gotIdx, splitAt)
			}

			runOnRoots(ctx, t, p, "Process", Root.Process)
			runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)

			validateSource(t, capture, src, makeValues(want...))
		})
	}

	// Test expects splitting errors, but for processing to be successful.
	t.Run("errors", func(t *testing.T) {
		src, capture, reader := newSource("noSplitsUntilStarted")
		p, err := NewPlan("a", []Unit{capture, src})
		if err != nil {
			t.Fatalf("failed to construct plan: %v", err)
		}
		dc := DataContext{Data: &TestDataManager{R: reader}}
		ctx := context.Background()

		// mustFailSplit fails the test with msg unless the split request errors.
		mustFailSplit := func(splits []int64, msg string) {
			t.Helper()
			if _, err := p.Split(SplitPoints{Splits: splits, Frac: -1}); err == nil {
				t.Fatal(msg)
			}
		}

		mustFailSplit([]int64{0, 3}, "plan uninitialized, expected error when splitting, got nil")

		activate(ctx, t, p)
		mustFailSplit([]int64{0, 3}, "plan not started, expected error when splitting, got nil")

		runOnRoots(ctx, t, p, "StartBundle", startBundle(dc))
		mustFailSplit([]int64{0}, "plan started, expected error when splitting, got nil")

		runOnRoots(ctx, t, p, "Process", Root.Process)
		mustFailSplit([]int64{0}, "plan in progress, expected error when unable to get a desired split, got nil")

		runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)
		mustFailSplit([]int64{0}, "plan finished, expected error when splitting, got nil")

		validateSource(t, capture, src, makeValues(elements...))
	})

	t.Run("sanity_errors", func(t *testing.T) {
		// Both calls are made on a nil *DataSource and must return an error.
		var nilSource *DataSource
		if _, splitErr := nilSource.Split([]int64{0}, -1); splitErr == nil {
			t.Fatal("expected error splitting nil *DataSource")
		}
		if _, splitErr := nilSource.Split(nil, -1); splitErr == nil {
			t.Fatal("expected error splitting nil desired splits")
		}
	})
}
// runOnRoots invokes mthd on each root of the plan in order and fails the test
// at the first error, using name to label the failing method in the message.
func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd func(Root, context.Context) error) {
	t.Helper()
	for idx, r := range p.roots {
		err := mthd(r, ctx)
		if err == nil {
			continue
		}
		t.Fatalf("error in root[%d].%s: %v", idx, name, err)
	}
}
// TestDataManager is a test double for the data manager in a DataContext. It
// serves one pre-built reader for every read request.
type TestDataManager struct {
// R is the reader returned by OpenRead.
R io.ReadCloser
}
// OpenRead returns the manager's reader R with a nil error. The context and
// stream ID are not consulted.
func (dm *TestDataManager) OpenRead(_ context.Context, _ StreamID) (io.ReadCloser, error) {
	reader := dm.R
	return reader, nil
}
// OpenWrite is a stub that always returns a nil writer and a nil error. The
// context and stream ID are not consulted.
func (dm *TestDataManager) OpenWrite(_ context.Context, _ StreamID) (io.WriteCloser, error) {
	var noWriter io.WriteCloser
	return noWriter, nil
}
// TestStateReader simulates state reads using channels.
type TestStateReader struct {
// StateReader is embedded; this file only defines OpenIterable on top of it.
// NOTE(review): presumably embedded so the type satisfies the full interface — confirm.
StateReader
// Rc delivers the readers handed out by OpenIterable, one per receive.
Rc <-chan io.ReadCloser
}
// OpenIterable returns the next reader delivered on Rc. The stream ID and key
// are not consulted.
//
// The receive blocks until a reader is available. If ctx is cancelled first,
// OpenIterable returns a nil reader and ctx.Err() instead of blocking forever.
func (si *TestStateReader) OpenIterable(ctx context.Context, id StreamID, key []byte) (io.ReadCloser, error) {
	select {
	case r := <-si.Rc:
		return r, nil
	case <-ctx.Done():
		// Lets a cancelled or timed-out run fail fast when no writer ever
		// publishes a reader.
		return nil, ctx.Err()
	}
}
// constructAndExecutePlanWithContext builds a plan from the units us, executes
// it once with the data context dc, and then brings it down. The test fails
// immediately at the first step that returns an error.
func constructAndExecutePlanWithContext(t *testing.T, us []Unit, dc DataContext) {
	// Mark as a helper so failures are reported at the calling test's line,
	// as runOnRoots and validateSource already do.
	t.Helper()

	p, err := NewPlan("a", us)
	if err != nil {
		t.Fatalf("failed to construct plan: %v", err)
	}
	if err := p.Execute(context.Background(), "1", dc); err != nil {
		t.Fatalf("execute failed: %v", err)
	}
	if err := p.Down(context.Background()); err != nil {
		t.Fatalf("down failed: %v", err)
	}
}
// validateSource checks that the capture node holds exactly the expected
// elements and that the source's progress count equals their number. A length
// or count mismatch is fatal; a content mismatch is reported as an error.
func validateSource(t *testing.T, out *CaptureNode, source *DataSource, expected []FullValue) {
	t.Helper()

	wantLen := len(expected)
	if gotLen := len(out.Elements); gotLen != wantLen {
		t.Fatalf("lengths don't match: got %v, want %v", gotLen, wantLen)
	}

	wantCount := int64(wantLen)
	if gotCount := source.Progress().Count; gotCount != wantCount {
		t.Fatalf("progress count didn't match: got %v, want %v", gotCount, wantCount)
	}

	if equalList(out.Elements, expected) {
		return
	}
	t.Errorf("DataSource => %#v, want %#v", extractValues(out.Elements...), extractValues(expected...))
}