blob: 83e254880926269429e2f8714714ce479ccb5a2d [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.wayang.api
import java.nio.file.{Files, Paths}
import java.sql.{Connection, Statement}
import java.util.function.Consumer
import org.apache.wayang.basic.WayangBasics
import org.apache.wayang.core.api.{Configuration, WayangContext}
import org.apache.wayang.core.function.FunctionDescriptor.ExtendedSerializablePredicate
import org.apache.wayang.core.function.{ExecutionContext, TransformationDescriptor}
import org.apache.wayang.core.util.fs.LocalFileSystem
import org.apache.wayang.spark.Spark
import org.apache.wayang.sqlite3.Sqlite3
import org.apache.wayang.sqlite3.operators.Sqlite3TableSource
import org.junit.{Assert, Test}
/**
 * Tests the Wayang API.
 */
class ApiTest {
/**
 * Loads the integers 1 to 10 into a Wayang plan, adds 2 to each element via `map`,
 * and compares the resulting array against the expected values.
 *
 * NOTE(review): the tail of the operator chain and the right-hand side of
 * `expectedOutputValues` look truncated in this view — confirm against the full file.
 */
def testReadMapCollect(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data: the integers 1 through 10.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val outputValues = wayang
.loadCollection(inputValues).withName("Load input values")
.map(_ + 2).withName("Add 2")
// Check the outcome.
val expectedOutputValues = + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
/**
 * Runs the "add 2" mapping through `wayang.customOperator`, passing a
 * `JavaMapOperator` built from a `TransformationDescriptor`, then collects the
 * single output and compares it against the expected array.
 *
 * NOTE(review): the right-hand side of `expectedOutputValues` looks truncated
 * in this view — confirm against the full file.
 */
def testCustomOperator(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data: the integers 1 through 10.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val inputDataSet = wayang.loadCollection(inputValues).withName("Load input values")
// Add the custom operator; the extractor pattern only matches a one-element result.
val IndexedSeq(addedValues) = wayang.customOperator(new JavaMapOperator(
new TransformationDescriptor(
toSerializableFunction[Int, Int](_ + 2),
basicDataUnitType[Int], basicDataUnitType[Int]
), inputDataSet)
addedValues.withName("Add 2")
// Collect the result; the cast narrows the operator output to DataQuanta[Int].
val outputValues = addedValues.asInstanceOf[DataQuanta[Int]].collect()
// Check the outcome.
val expectedOutputValues = + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
/**
 * Same "add 2" scenario as a custom operator, but attached directly to the
 * loaded data via the typed `customOperator[Int]` call on the operator chain.
 *
 * NOTE(review): the tail of the operator chain and the right-hand side of
 * `expectedOutputValues` look truncated in this view — confirm against the full file.
 */
def testCustomOperatorShortCut(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data: the integers 1 through 10.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val outputValues = wayang
.loadCollection(inputValues).withName("Load input values")
.customOperator[Int](new JavaMapOperator(
new TransformationDescriptor(
toSerializableFunction[Int, Int](_ + 2),
basicDataUnitType[Int], basicDataUnitType[Int]
)).withName("Add 2")
// Check the outcome.
val expectedOutputValues = + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
/**
 * Word count over two short sentences: split on whitespace, normalize each
 * token, pair it with a counter of 1, and sum the counters per word.
 *
 * NOTE(review): no terminal action or conversion to a set is visible on the
 * chain before the comparison — the chain tail looks truncated in this view.
 */
def testWordCount(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array("Big data is big.", "Is data big data?")
// Build and execute a word count WayangPlan.
val wordCounts = wayang
.loadCollection(inputValues).withName("Load input values")
// Split each sentence on runs of whitespace.
.flatMap(_.split("\\s+")).withName("Split words")
// Strip non-word characters (e.g. the trailing '.' and '?') and lowercase the token.
.map(_.replaceAll("\\W+", "").toLowerCase).withName("To lowercase")
.map((_, 1)).withName("Attach counter")
// Key by the word; merge two pairs by keeping the word and adding the counters.
.reduceByKey(_._1, (a, b) => (a._1, a._2 + b._2)).withName("Sum counters")
val expectedWordCounts = Set(("big", 3), ("is", 2), ("data", 3))
Assert.assertEquals(expectedWordCounts, wordCounts)
/**
 * Word count with explicit per-operator platform targets: loading and splitting
 * are targeted at `Java.platform`, the remaining operators at `Spark.platform`.
 *
 * NOTE(review): no terminal action or conversion to a set is visible on the
 * chain before the comparison — the chain tail looks truncated in this view.
 */
def testWordCountOnSparkAndJava(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array("Big data is big.", "Is data big data?")
// Build and execute a word count WayangPlan.
val wordCounts = wayang
.loadCollection(inputValues).withName("Load input values").withTargetPlatforms(Java.platform)
// Split each sentence on runs of whitespace.
.flatMap(_.split("\\s+")).withName("Split words").withTargetPlatforms(Java.platform)
// Strip non-word characters and lowercase the token.
.map(_.replaceAll("\\W+", "").toLowerCase).withName("To lowercase").withTargetPlatforms(Spark.platform)
.map((_, 1)).withName("Attach counter").withTargetPlatforms(Spark.platform)
// Key by the word; merge two pairs by keeping the word and adding the counters.
.reduceByKey(_._1, (a, b) => (a._1, a._2 + b._2)).withName("Sum counters").withTargetPlatforms(Spark.platform)
val expectedWordCounts = Set(("big", 3), ("is", 2), ("data", 3))
Assert.assertEquals(expectedWordCounts, wordCounts)
/**
 * Checks that the sampled result holds exactly 10 elements and that all of
 * them are distinct.
 *
 * NOTE(review): the operator chain after `wayang` is not visible in this view,
 * so how `sample` is derived from `inputValues` cannot be seen here — confirm
 * against the full file.
 */
def testSample(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data: the integers 0 through 99 (all distinct).
val inputValues = for (i <- 0 until 100) yield i
// Build and execute the WayangPlan.
val sample = wayang
// Check the result: 10 elements, and still 10 after de-duplication.
Assert.assertEquals(10, sample.size)
Assert.assertEquals(10, sample.toSet.size)
/**
 * Exercises `doWhile`: each iteration sums the current values and unions that
 * sum back into them; the condition function checks `vals.max > 100`.
 *
 * NOTE(review): presumably the condition is evaluated on the second element of
 * the returned pair (`sum`) and the loop stops once it holds — verify against
 * the `doWhile` contract. No terminal action or conversion to a set is visible
 * on the chain before the comparison in this view.
 */
def testDoWhile(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array(1, 2)
// Build and execute a WayangPlan containing a do-while loop.
val values = wayang
.loadCollection(inputValues).withName("Load input values")
.doWhile[Int](vals => vals.max > 100, {
start =>
// Sum of all current values.
val sum = start.reduce(_ + _).withName("Sum")
// Loop body result: (current values plus the new sum, the new sum).
(start.union(sum).withName("Old+new"), sum)
}).withName("While <= 100")
// Each expected value after 1 and 2 is the sum of all values before it.
val expectedValues = Set(1, 2, 3, 6, 12, 24, 48, 96, 192)
Assert.assertEquals(expectedValues, values)
/**
 * Exercises a loop named "Repeat 3x" whose body multiplies all values together
 * and then emits the product and the product plus one.
 *
 * NOTE(review): the line opening the loop call (before `_.reduce`) and the
 * chain tail are not visible in this view — confirm against the full file.
 */
def testRepeat(): Unit = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array(1, 2)
// Build and execute a WayangPlan containing a repeat loop.
val values = wayang
// withName is called twice on the same operator: first a fixed label, then "1,2".
.loadCollection(inputValues).withName("Load input values").withName(inputValues.mkString(","))
// Loop body: product of all values ...
_.reduce(_ * _).withName("Multiply")
// ... then emit the product and its successor.
.flatMap(v => Seq(v, v + 1)).withName("Duplicate")
).withName("Repeat 3x")
// initial: 1,2 -> 1st: 2,3 -> 2nd: 6,7 => 3rd: 42,43
val expectedValues = Set(42, 43)
Assert.assertEquals(expectedValues, values)
/**
 * Filters strings with a predicate that reads a broadcast named "selectors":
 * a string is kept only if it contains every selector character.
 *
 * NOTE(review): no terminal action or conversion to a set is visible on the
 * chain before the comparison — the chain tail looks truncated in this view.
 */
def testBroadcast() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val builder = new PlanBuilder(wayang)
// Generate some test data.
val inputStrings = Array("Hello", "World", "Hi", "Mars")
val selectors = Array('o', 'l')
val selectorsDataSet = builder.loadCollection(selectors).withName("Load selectors")
// Build and execute a WayangPlan that filters with a broadcast.
val values = builder
.loadCollection(inputStrings).withName("Load input values")
.filterJava(new ExtendedSerializablePredicate[String] {
// Default-initialized here; assigned in open().
var selectors: Iterable[Char] = _
override def open(ctx: ExecutionContext): Unit = {
import scala.collection.JavaConversions._
// Fetch the broadcast registered under the name "selectors".
selectors = collectionAsScalaIterable(ctx.getBroadcast[Char]("selectors"))
// Keep a string only when it contains all selector characters.
override def test(t: String): Boolean = selectors.forall(selector => t.contains(selector))
}).withName("Filter words")
.withBroadcast(selectorsDataSet, "selectors")
// Only "Hello" and "World" contain both 'o' and 'l'.
val expectedValues = Set("Hello", "World")
Assert.assertEquals(expectedValues, values)
/**
 * Groups integers by parity (`_ % 2`) and maps each group to its middle value:
 * the mean of the two middle elements for even-sized groups, the middle
 * element otherwise.
 *
 * NOTE(review): the line feeding `inputValues` into the plan and the chain
 * tail are not visible in this view. The group buffer is not sorted here, so
 * the result presumably relies on the input order being preserved — verify.
 */
def testGroupBy() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
val result = wayang
.groupByKey(_ % 2).withName("group odd and even")
.map {
group =>
import scala.collection.JavaConversions._
val buffer = group.toBuffer
// Even size: integer mean of the two middle elements; odd size: the middle element.
if (buffer.size % 2 == 0) (buffer(buffer.size / 2 - 1) + buffer(buffer.size / 2)) / 2
else buffer(buffer.size / 2)
// Odd values 1,3,5,7,9 -> 5; even values 2,4,8,10 -> (4 + 8) / 2 = 6.
val expectedValues = Set(5, 6)
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Maps a group of integers to its median and expects 5 for the nine input
 * values.
 *
 * NOTE(review): the lines feeding `inputValues` into the plan and producing
 * the group, as well as the chain tail, are not visible in this view —
 * confirm against the full file.
 */
def testGroup() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
val result = wayang
.map {
group =>
import scala.collection.JavaConversions._
// sortBy returns a new collection instead of sorting in place, so the sorted
// result must be the one that is indexed below.
val buffer = group.toBuffer.sortBy(int => int)
// Median: integer mean of the two middle elements (indices size/2 - 1 and size/2)
// for even sizes, the single middle element otherwise.
if (buffer.size % 2 == 0) (buffer(buffer.size / 2 - 1) + buffer(buffer.size / 2)) / 2
else buffer(buffer.size / 2)
// Nine values: the middle (5th smallest) element is 5.
val expectedValues = Set(5)
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Joins (category, number) pairs with (product, category) pairs on the
 * category and projects each match to (product, number).
 *
 * NOTE(review): no terminal action is visible on the chain before
 * `result.toSet` — the chain tail looks truncated in this view.
 */
def testJoin() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Tonic", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
val result = dataQuanta1
// Join key: first component on the left side, second component on the right side.
.join[(String, String), String](_._1, dataQuanta2, _._2)
// field0 is the left tuple, field1 the right one: keep (product name, number).
.map(joinTuple => (joinTuple.field1._1, joinTuple.field0._2))
// "Tonic" has no matching product, so it does not appear in the expected result.
val expectedValues = Set(("Apple juice", 10), ("Tap water", 0), ("Orange juice", 10))
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Same join scenario expressed with `keyBy(...).join(...)` followed by
 * `assemble`, which combines each matching pair into (product, number).
 *
 * NOTE(review): no terminal action is visible on the chain before
 * `result.toSet` — the chain tail looks truncated in this view.
 */
def testJoinAndAssemble() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Tonic", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
// Key the left side by its first component and the right side by its second.
val result = dataQuanta1.keyBy(_._1).join(dataQuanta2.keyBy(_._2))
// dq1 is the left tuple, dq2 the right one: keep (product name, number).
.assemble((dq1, dq2) => (dq2._1, dq1._2))
val expectedValues = Set(("Apple juice", 10), ("Tap water", 0), ("Orange juice", 10))
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Co-groups (category, number) pairs with (product, category) pairs on the
 * category and compares the per-key groups, converted to sets, against the
 * expected groups.
 *
 * NOTE(review): the right-hand side of `actualValues` and the chain tail look
 * truncated in this view — confirm against the full file.
 */
def testCoGroup() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Cola", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
val result = dataQuanta1
// Group key: first component on the left side, second component on the right side.
.coGroup[(String, String), String](_._1, dataQuanta2, _._2)
import scala.collection.JavaConversions._
// Turn both sides of every co-group into sets so the comparison ignores order.
val actualValues = => (coGroup.field0.toSet, coGroup.field1.toSet)).toSet
val expectedValues = Set(
(Set(("Water", 0)), Set(("Tap water", "Water"))),
// "Cola" has no matching product, so its right-hand group is empty.
(Set(("Cola", 5)), Set()),
(Set(("Juice", 10)), Set(("Apple juice", "Juice"), ("Orange juice", "Juice")))
Assert.assertEquals(expectedValues, actualValues)
/**
 * Compares the result derived from two integer collections against the set of
 * values the two inputs have in common.
 *
 * NOTE(review): the operator chain after `dataQuanta1` (where `dataQuanta2`
 * would be used) is not visible in this view — confirm against the full file.
 */
def testIntersect() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
// The second input contains 3 twice.
val inputValues2 = Array(0, 2, 3, 3, 4, 5, 7, 8, 9, 11)
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
val result = dataQuanta1
// Values present in both inputs.
val expectedValues = Set(2, 3, 4, 5, 7, 8, 9)
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Compares the result derived from an unsorted integer array against the same
 * values in ascending order.
 *
 * NOTE(review): the operator chain after `dataQuanta1` is not visible in this
 * view — confirm against the full file.
 */
def testSort() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(3, 4, 5, 2, 1)
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val result = dataQuanta1
// Array comparison is order-sensitive, so the result must be in ascending order.
val expectedValues = Array(1, 2, 3, 4, 5)
Assert.assertArrayEquals(expectedValues, result.toArray)
/**
 * Builds a small directed graph of seven edges over vertices 0 to 3 and checks
 * only the relative ordering of the resulting ranks: 1 > 0 > 2 > 3.
 *
 * NOTE(review): no plugin registration is visible on this WayangContext, and
 * the operator chain between loading the edges and the final `map`, as well as
 * its tail, look truncated in this view — confirm against the full file.
 */
def testPageRank() = {
// Set up WayangContext.
val wayang = new WayangContext()
import org.apache.wayang.api.graph._
// Directed edges given as (source, target) pairs.
val edges = Seq((0, 1), (0, 2), (0, 3), (1, 0), (2, 1), (3, 2), (3, 1)).map(t => Edge(t._1, t._2))
val pageRanks = wayang
.loadCollection(edges).withName("Load edges")
// Turn each result tuple into a (field0 as Long -> field1) pair.
.map(t => t.field0.longValue -> t.field1)
// Let's not check absolute numbers but only the relative ordering.
Assert.assertTrue(pageRanks(1) > pageRanks(0))
Assert.assertTrue(pageRanks(0) > pageRanks(2))
Assert.assertTrue(pageRanks(2) > pageRanks(3))
/**
 * Counts odd and even numbers per partition with `mapPartitions`, then merges
 * the partial counts per kind with `reduceByKey`.
 *
 * NOTE(review): no plugin registration is visible on this WayangContext and
 * the chain tail looks truncated in this view — confirm against the full file.
 */
def testMapPartitions() = {
// Set up WayangContext.
val wayang = new WayangContext()
val typeCounts = wayang
.loadCollection(Seq(0, 1, 2, 3, 4, 6, 8))
.mapPartitions { ints =>
// Local counters for this partition.
var (numOdds, numEvens) = (0, 0)
// The lowest bit is 0 for even numbers.
ints.foreach(i => if ((i & 1) == 0) numEvens += 1 else numOdds += 1)
// Emit one (kind, count) pair per kind for this partition.
Seq(("odd", numOdds), ("even", numEvens))
// Key by kind; merge two pairs by keeping the kind and adding the counts.
.reduceByKey(_._1, { case ((kind1, count1), (kind2, count2)) => (kind1, count1 + count2) })
// Odd inputs: 1, 3; even inputs: 0, 2, 4, 6, 8.
Assert.assertEquals(Set(("odd", 2), ("even", 5)), typeCounts.toSet)
/**
 * Builds an input in which each of the values 0 to 99 occurs 42 times and
 * expects the final aggregation to yield the single pair (42, 100).
 *
 * NOTE(review): the lines feeding `inputValues` into the plan, the first
 * component of the mapped pair, and the chain tail are not visible in this
 * view — confirm against the full file.
 */
def testZipWithId() = {
// Set up WayangContext with the Java and Spark basic plugins.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// 100 * 42 elements: every i in 0..99 is yielded once per j in 0..41.
val inputValues = for (i <- 0 until 100; j <- 0 until 42) yield i
val result = wayang
.map { group =>
import scala.collection.JavaConversions._
(, 1)
// Key by the first component; merge two pairs by adding their counters.
.reduceByKey(_._1, (t1, t2) => (t1._1, t1._2 + t2._2))
val expectedValues = Set((42, 100))
Assert.assertEquals(expectedValues, result.toSet)
/**
 * Writes formatted doubles to a text file in the local temp directory, reads
 * the file back line by line, and compares the lines against the expected
 * formatted values.
 *
 * NOTE(review): `${d % .2f}` is parsed as the remainder of `d` divided by the
 * float literal `.2f`, not as a `% .2f` format specifier — confirm this is
 * intended. The same expression is used for the expected lines, so both sides
 * agree either way.
 * NOTE(review): the stream returned by `Files.lines` is not closed in the
 * visible code. The line loading `inputValues` into the plan and the
 * right-hand side of `expectedLines` look truncated in this view.
 */
def testWriteTextFile() = {
val tempDir = LocalFileSystem.findTempDir
val targetUrl = LocalFileSystem.toURL(new File(tempDir, "testWriteTextFile.txt"))
// Set up WayangContext with the Java basic plugin only.
val wayang = new WayangContext().withPlugin(Java.basicPlugin)
// Six doubles: 0 through 5, each multiplied by 0.333333333333.
val inputValues = for (i <- 0 to 5) yield i * 0.333333333333
val result = wayang
.writeTextFile(targetUrl, formatterUdf = d => f"${d % .2f}")
// Read the written file back, collecting its lines into a mutable set.
val lines = scala.collection.mutable.Set[String]()
Files.lines(Paths.get(new URI(targetUrl))).forEach(new Consumer[String] {
override def accept(line: String): Unit = lines += line
val expectedLines = => f"${v % .2f}").toSet
Assert.assertEquals(expectedLines, lines)
/**
 * Prepares a temporary SQLite database with a `customer` table, reads the
 * table through Wayang, and filters for customers aged 18 or older with the
 * filter operator targeted at `Java.platform`.
 *
 * NOTE(review): the batched statements are queued with `addBatch`, but no
 * `executeBatch()` call is visible in this view, and the outer `try` has no
 * visible `catch`/`finally` — confirm against the full file. No projection to
 * the name column or terminal action is visible before the comparison either.
 */
def testSqlOnJava() = {
// Initialize some test data.
val configuration = new Configuration
val sqlite3dbFile = File.createTempFile("wayang-sqlite3", "db")
// Point the SQLite3 JDBC URL at the freshly created temp file.
configuration.setProperty("wayang.sqlite3.jdbc.url", "jdbc:sqlite:" + sqlite3dbFile.getAbsolutePath)
try {
val connection: Connection = Sqlite3.platform.createDatabaseDescriptor(configuration).createJdbcConnection
try {
val statement: Statement = connection.createStatement
// Queue the schema and the three sample rows.
statement.addBatch("DROP TABLE IF EXISTS customer;")
statement.addBatch("CREATE TABLE customer (name TEXT, age INT);")
statement.addBatch("INSERT INTO customer VALUES ('John', 20)")
statement.addBatch("INSERT INTO customer VALUES ('Timmy', 16)")
statement.addBatch("INSERT INTO customer VALUES ('Evelyn', 35)")
} finally {
if (connection != null) connection.close()
// Set up WayangContext with the Java basic plugin and the Sqlite3 plugin.
val wayang = new WayangContext(configuration).withPlugin(Java.basicPlugin).withPlugin(Sqlite3.plugin)
val result = wayang
.readTable(new Sqlite3TableSource("customer", "name", "age"))
// Field 1 is the age column; the same predicate is supplied as a SQL fragment.
.filter(r => r.getField(1).asInstanceOf[Integer] >= 18, sqlUdf = "age >= 18").withTargetPlatforms(Java.platform)
// 'Timmy' (16) is below the threshold.
val expectedValues = Set("John", "Evelyn")
Assert.assertEquals(expectedValues, result)
/**
 * Prepares a temporary SQLite database with a `customer` table, reads the
 * table through Wayang, and filters for customers aged 18 or older without
 * pinning the filter operator to a platform.
 *
 * NOTE(review): the batched statements are queued with `addBatch`, but no
 * `executeBatch()` call is visible in this view, and the outer `try` has no
 * visible `catch`/`finally` — confirm against the full file. No projection to
 * the name column or terminal action is visible before the comparison either.
 */
def testSqlOnSqlite3() = {
// Initialize some test data.
val configuration = new Configuration
val sqlite3dbFile = File.createTempFile("wayang-sqlite3", "db")
// Point the SQLite3 JDBC URL at the freshly created temp file.
configuration.setProperty("wayang.sqlite3.jdbc.url", "jdbc:sqlite:" + sqlite3dbFile.getAbsolutePath)
try {
val connection: Connection = Sqlite3.platform.createDatabaseDescriptor(configuration).createJdbcConnection
try {
val statement: Statement = connection.createStatement
// Queue the schema and the three sample rows.
statement.addBatch("DROP TABLE IF EXISTS customer;")
statement.addBatch("CREATE TABLE customer (name TEXT, age INT);")
statement.addBatch("INSERT INTO customer VALUES ('John', 20)")
statement.addBatch("INSERT INTO customer VALUES ('Timmy', 16)")
statement.addBatch("INSERT INTO customer VALUES ('Evelyn', 35)")
} finally {
if (connection != null) connection.close()
// Set up WayangContext with the Java basic plugin and the Sqlite3 plugin.
val wayang = new WayangContext(configuration).withPlugin(Java.basicPlugin).withPlugin(Sqlite3.plugin)
val result = wayang
.readTable(new Sqlite3TableSource("customer", "name", "age"))
// Field 1 is the age column; the same predicate is supplied as a SQL fragment.
.filter(r => r.getField(1).asInstanceOf[Integer] >= 18, sqlUdf = "age >= 18")
// 'Timmy' (16) is below the threshold.
val expectedValues = Set("John", "Evelyn")
Assert.assertEquals(expectedValues, result)