blob: f1473f199057648056a4df822d96925e45bb4b63 [file]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package xlang
import (
"flag"
"fmt"
"log"
"reflect"
"sort"
"testing"
"github.com/apache/beam/sdks/v2/go/examples/xlang"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/samza"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/spark"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
"github.com/apache/beam/sdks/v2/go/test/integration"
)
var expansionAddr string // Populate with expansion address labelled "test".
func init() {
beam.RegisterType(reflect.TypeOf((*IntString)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*StringInt)(nil)).Elem())
beam.RegisterFunction(formatIntStringsFn)
beam.RegisterFunction(formatStringIntFn)
beam.RegisterFunction(formatStringIntsFn)
beam.RegisterFunction(formatIntFn)
beam.RegisterFunction(getIntString)
beam.RegisterFunction(getStringInt)
beam.RegisterFunction(sumCounts)
beam.RegisterFunction(collectValues)
}
func checkFlags(t *testing.T) {
if expansionAddr == "" {
t.Skip("No Test expansion address provided.")
}
}
// formatIntStringsFn is a DoFn that formats an int64 and a list of strings.
func formatIntStringsFn(i int64, s []string) string {
sort.Strings(s)
return fmt.Sprintf("%v:%v", i, s)
}
// formatStringIntFn is a DoFn that formats a string and an int64.
func formatStringIntFn(s string, i int64) string {
return fmt.Sprintf("%s:%v", s, i)
}
// formatStringIntsFn is a DoFn that formats a string and a list of ints.
func formatStringIntsFn(s string, i []int) string {
sort.Ints(i)
return fmt.Sprintf("%v:%v", s, i)
}
// formatIntFn is a DoFn that formats an int64 as a string.
func formatIntFn(i int64) string {
return fmt.Sprintf("%v", i)
}
// IntString used to represent KV PCollection values of int64, string.
type IntString struct {
X int64
Y string
}
func getIntString(kv IntString, emit func(int64, string)) {
emit(kv.X, kv.Y)
}
// StringInt used to represent KV PCollection values of string, int64.
type StringInt struct {
X string
Y int64
}
func getStringInt(kv StringInt, emit func(string, int64)) {
emit(kv.X, kv.Y)
}
func sumCounts(key int64, iter1 func(*string) bool) (int64, []string) {
var val string
var values []string
for iter1(&val) {
values = append(values, val)
}
return key, values
}
func collectValues(key string, iter func(*int64) bool) (string, []int) {
var count int64
var values []int
for iter(&count) {
values = append(values, int(count))
}
return key, values
}
func TestXLang_Prefix(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
// Using the cross-language transform
strings := beam.Create(s, "a", "b", "c")
prefixed := xlang.Prefix(s, "prefix_", expansionAddr, strings)
passert.Equals(s, prefixed, "prefix_a", "prefix_b", "prefix_c")
ptest.RunAndValidate(t, p)
}
func TestXLang_CoGroupBy(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
// Using the cross-language transform
col1 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "1"}, IntString{X: 0, Y: "2"}, IntString{X: 1, Y: "3"}))
col2 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "4"}, IntString{X: 1, Y: "5"}, IntString{X: 1, Y: "6"}))
c := xlang.CoGroupByKey(s, expansionAddr, col1, col2)
sums := beam.ParDo(s, sumCounts, c)
formatted := beam.ParDo(s, formatIntStringsFn, sums)
passert.Equals(s, formatted, "0:[1 2 4]", "1:[3 5 6]")
ptest.RunAndValidate(t, p)
}
func TestXLang_Combine(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
// Using the cross-language transform
kvs := beam.Create(s, StringInt{X: "a", Y: 1}, StringInt{X: "a", Y: 2}, StringInt{X: "b", Y: 3})
ins := beam.ParDo(s, getStringInt, kvs)
c := xlang.CombinePerKey(s, expansionAddr, ins)
formatted := beam.ParDo(s, formatStringIntFn, c)
passert.Equals(s, formatted, "a:3", "b:3")
ptest.RunAndValidate(t, p)
}
func TestXLang_CombineGlobally(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
in := beam.CreateList(s, []int64{1, 2, 3})
// Using the cross-language transform
c := xlang.CombineGlobally(s, expansionAddr, in)
formatted := beam.ParDo(s, formatIntFn, c)
passert.Equals(s, formatted, "6")
ptest.RunAndValidate(t, p)
}
func TestXLang_Flatten(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
col1 := beam.CreateList(s, []int64{1, 2, 3})
col2 := beam.CreateList(s, []int64{4, 5, 6})
// Using the cross-language transform
c := xlang.Flatten(s, expansionAddr, col1, col2)
formatted := beam.ParDo(s, formatIntFn, c)
passert.Equals(s, formatted, "1", "2", "3", "4", "5", "6")
ptest.RunAndValidate(t, p)
}
func TestXLang_GroupBy(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
// Using the cross-language transform
kvs := beam.Create(s, StringInt{X: "0", Y: 1}, StringInt{X: "0", Y: 2}, StringInt{X: "1", Y: 3})
in := beam.ParDo(s, getStringInt, kvs)
out := xlang.GroupByKey(s, expansionAddr, in)
vals := beam.ParDo(s, collectValues, out)
formatted := beam.ParDo(s, formatStringIntsFn, vals)
passert.Equals(s, formatted, "0:[1 2]", "1:[3]")
ptest.RunAndValidate(t, p)
}
func TestXLang_Multi(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
main1 := beam.CreateList(s, []string{"a", "bb"})
main2 := beam.CreateList(s, []string{"x", "yy", "zzz"})
side := beam.CreateList(s, []string{"s"})
// Using the cross-language transform
mainOut, sideOut := xlang.Multi(s, expansionAddr, main1, main2, side)
passert.Equals(s, mainOut, "as", "bbs", "xs", "yys", "zzzs")
passert.Equals(s, sideOut, "ss")
ptest.RunAndValidate(t, p)
}
func TestXLang_Partition(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
p := beam.NewPipeline()
s := p.Root()
col := beam.CreateList(s, []int64{1, 2, 3, 4, 5, 6})
// Using the cross-language transform
out0, out1 := xlang.Partition(s, expansionAddr, col)
formatted0 := beam.ParDo(s, formatIntFn, out0)
formatted1 := beam.ParDo(s, formatIntFn, out1)
passert.Equals(s, formatted0, "2", "4", "6")
passert.Equals(s, formatted1, "1", "3", "5")
ptest.RunAndValidate(t, p)
}
func TestMain(m *testing.M) {
flag.Parse()
beam.Init()
services := integration.NewExpansionServices()
defer func() { services.Shutdown() }()
addr, err := services.GetAddr("test")
if err != nil {
log.Printf("skipping missing expansion service: %v", err)
} else {
expansionAddr = addr
}
ptest.MainRet(m)
}