/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.wayang.api
import org.apache.wayang.api.dataquanta.DataQuanta
import java.io.File
import java.net.URI
import java.nio.file.{Files, Paths}
import java.sql.{Connection, Statement}
import java.util.function.Consumer
import org.junit.{Assert, Test}
import org.apache.wayang.basic.WayangBasics
import org.apache.wayang.core.api.{Configuration, WayangContext}
import org.apache.wayang.core.function.FunctionDescriptor.ExtendedSerializablePredicate
import org.apache.wayang.core.function.{ExecutionContext, TransformationDescriptor}
import org.apache.wayang.core.util.fs.LocalFileSystem
import org.apache.wayang.java.Java
import org.apache.wayang.java.operators.JavaMapOperator
import org.apache.wayang.spark.Spark
import org.apache.wayang.sqlite3.Sqlite3
import org.apache.wayang.sqlite3.operators.Sqlite3TableSource
/**
* Tests the Wayang API.
*/
class ApiTest {
@Test
def testReadMapCollect(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val outputValues = wayang
.loadCollection(inputValues).withName("Load input values")
.map(_ + 2).withName("Add 2")
.collect()
// Check the outcome.
val expectedOutputValues = inputValues.map(_ + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
}
@Test
def testCustomOperator(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val inputDataSet = wayang.loadCollection(inputValues).withName("Load input values")
// Add the custom operator.
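    // customOperator returns an IndexedSeq with one DataQuanta per operator output; this operator has a single output.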
val IndexedSeq(addedValues) = wayang.customOperator(new JavaMapOperator(
dataSetType[Int],
dataSetType[Int],
new TransformationDescriptor(
toSerializableFunction[Int, Int](_ + 2),
basicDataUnitType[Int], basicDataUnitType[Int]
)
), inputDataSet)
addedValues.withName("Add 2")
// Collect the result.
val outputValues = addedValues.asInstanceOf[DataQuanta[Int]].collect()
// Check the outcome.
val expectedOutputValues = inputValues.map(_ + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
}
@Test
def testCustomOperatorShortCut(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = (for (i <- 1 to 10) yield i).toArray
// Build and execute a Wayang plan.
val outputValues = wayang
.loadCollection(inputValues).withName("Load input values")
.customOperator[Int](new JavaMapOperator(
dataSetType[Int],
dataSetType[Int],
new TransformationDescriptor(
toSerializableFunction[Int, Int](_ + 2),
basicDataUnitType[Int], basicDataUnitType[Int]
)
)).withName("Add 2")
.collect()
// Check the outcome.
val expectedOutputValues = inputValues.map(_ + 2)
Assert.assertArrayEquals(expectedOutputValues, outputValues.toArray)
}
@Test
def testWordCount(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array("Big data is big.", "Is data big data?")
// Build and execute a word count WayangPlan.
val wordCounts = wayang
.loadCollection(inputValues).withName("Load input values")
.flatMap(_.split("\\s+")).withName("Split words")
.map(_.replaceAll("\\W+", "").toLowerCase).withName("To lowercase")
.map((_, 1)).withName("Attach counter")
.reduceByKey(_._1, (a, b) => (a._1, a._2 + b._2)).withName("Sum counters")
.collect().toSet
val expectedWordCounts = Set(("big", 3), ("is", 2), ("data", 3))
Assert.assertEquals(expectedWordCounts, wordCounts)
}
@Test
def testWordCountOnSparkAndJava(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array("Big data is big.", "Is data big data?")
// Build and execute a word count WayangPlan.
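    // Pin the first two operators to Java and the remaining ones to Spark to force a cross-platform plan.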
val wordCounts = wayang
.loadCollection(inputValues).withName("Load input values").withTargetPlatforms(Java.platform)
.flatMap(_.split("\\s+")).withName("Split words").withTargetPlatforms(Java.platform)
.map(_.replaceAll("\\W+", "").toLowerCase).withName("To lowercase").withTargetPlatforms(Spark.platform)
.map((_, 1)).withName("Attach counter").withTargetPlatforms(Spark.platform)
.reduceByKey(_._1, (a, b) => (a._1, a._2 + b._2)).withName("Sum counters").withTargetPlatforms(Spark.platform)
.collect().toSet
val expectedWordCounts = Set(("big", 3), ("is", 2), ("data", 3))
Assert.assertEquals(expectedWordCounts, wordCounts)
}
@Test
def testSample(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = for (i <- 0 until 100) yield i
// Build and execute the WayangPlan.
val sample = wayang
.loadCollection(inputValues)
.sample(10)
.collect()
// Check the result.
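    // Sampling 10 out of 100 values must yield exactly 10 elements, all distinct.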
Assert.assertEquals(10, sample.size)
Assert.assertEquals(10, sample.toSet.size)
}
@Test
def testDoWhile(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array(1, 2)
    // Build and execute a WayangPlan with a do-while loop.
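    // Each iteration appends the sum of all current values: 1,2 -> +3 -> +6 -> +12 -> +24 -> +48 -> +96 -> +192; the loop stops once the maximum exceeds 100.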
val values = wayang
.loadCollection(inputValues).withName("Load input values")
.doWhile[Int](vals => vals.max > 100, {
start =>
val sum = start.reduce(_ + _).withName("Sum")
(start.union(sum).withName("Old+new"), sum)
}).withName("While <= 100")
.collect().toSet
val expectedValues = Set(1, 2, 3, 6, 12, 24, 48, 96, 192)
Assert.assertEquals(expectedValues, values)
}
@Test
def testRepeat(): Unit = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
// Generate some test data.
val inputValues = Array(1, 2)
    // Build and execute a WayangPlan with a fixed number of iterations.
val values = wayang
      .loadCollection(inputValues).withName("Load input values")
.repeat(3,
_.reduce(_ * _).withName("Multiply")
.flatMap(v => Seq(v, v + 1)).withName("Duplicate")
).withName("Repeat 3x")
.collect().toSet
    // initial: 1,2 -> 1st: 2,3 -> 2nd: 6,7 -> 3rd: 42,43
val expectedValues = Set(42, 43)
Assert.assertEquals(expectedValues, values)
}
@Test
def testBroadcast() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val builder = new PlanBuilder(wayang)
// Generate some test data.
val inputStrings = Array("Hello", "World", "Hi", "Mars")
val selectors = Array('o', 'l')
val selectorsDataSet = builder.loadCollection(selectors).withName("Load selectors")
    // Build and execute a WayangPlan that filters with a broadcast variable.
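    // The predicate receives the "selectors" broadcast in open() and keeps only strings that contain every selector character.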
val values = builder
.loadCollection(inputStrings).withName("Load input values")
.filterJava(new ExtendedSerializablePredicate[String] {
var selectors: Iterable[Char] = _
override def open(ctx: ExecutionContext): Unit = {
import scala.collection.JavaConversions._
selectors = collectionAsScalaIterable(ctx.getBroadcast[Char]("selectors"))
}
override def test(t: String): Boolean = selectors.forall(selector => t.contains(selector))
}).withName("Filter words")
.withBroadcast(selectorsDataSet, "selectors")
.collect().toSet
val expectedValues = Set("Hello", "World")
Assert.assertEquals(expectedValues, values)
}
@Test
def testGroupBy() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
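    // Compute the median per group: odd values {1, 3, 5, 7, 9} -> 5, even values {2, 4, 8, 10} -> (4 + 8) / 2 = 6.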
val result = wayang
.loadCollection(inputValues)
.groupByKey(_ % 2).withName("group odd and even")
.map {
group =>
import scala.collection.JavaConversions._
          val buffer = group.toBuffer.sortBy(identity)
          if (buffer.size % 2 == 0) (buffer(buffer.size / 2 - 1) + buffer(buffer.size / 2)) / 2
          else buffer(buffer.size / 2)
}.withName("median")
.collect()
val expectedValues = Set(5, 6)
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testGroup() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
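    // Compute the median of the single group holding all nine values -> 5.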
val result = wayang
.loadCollection(inputValues)
.group()
.map {
group =>
import scala.collection.JavaConversions._
          val buffer = group.toBuffer.sortBy(identity)
          if (buffer.size % 2 == 0) (buffer(buffer.size / 2 - 1) + buffer(buffer.size / 2)) / 2
          else buffer(buffer.size / 2)
}
.collect()
val expectedValues = Set(5)
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testJoin() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Tonic", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
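    // Join drink categories (name, price) with products (product, category) on the category name, then emit (product, price).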
val result = dataQuanta1
.join[(String, String), String](_._1, dataQuanta2, _._2)
.map(joinTuple => (joinTuple.field1._1, joinTuple.field0._2))
.collect()
val expectedValues = Set(("Apple juice", 10), ("Tap water", 0), ("Orange juice", 10))
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testJoinAndAssemble() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Tonic", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
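    // keyBy/join is the keyed variant of join; assemble combines each matching pair into a result tuple.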
val result = dataQuanta1.keyBy(_._1).join(dataQuanta2.keyBy(_._2))
.assemble((dq1, dq2) => (dq2._1, dq1._2))
.collect()
val expectedValues = Set(("Apple juice", 10), ("Tap water", 0), ("Orange juice", 10))
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testCoGroup() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(("Water", 0), ("Cola", 5), ("Juice", 10))
val inputValues2 = Array(("Apple juice", "Juice"), ("Tap water", "Water"), ("Orange juice", "Juice"))
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
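    // coGroup emits, per key, the matching groups from both inputs; "Cola" has no matching products, hence the empty group.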
val result = dataQuanta1
.coGroup[(String, String), String](_._1, dataQuanta2, _._2)
.collect()
import scala.collection.JavaConversions._
val actualValues = result.map(coGroup => (coGroup.field0.toSet, coGroup.field1.toSet)).toSet
val expectedValues = Set(
(Set(("Water", 0)), Set(("Tap water", "Water"))),
(Set(("Cola", 5)), Set()),
(Set(("Juice", 10)), Set(("Apple juice", "Juice"), ("Orange juice", "Juice")))
)
Assert.assertEquals(expectedValues, actualValues)
}
@Test
def testIntersect() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(1, 2, 3, 4, 5, 7, 8, 9, 10)
val inputValues2 = Array(0, 2, 3, 3, 4, 5, 7, 8, 9, 11)
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val dataQuanta2 = builder.loadCollection(inputValues2)
val result = dataQuanta1
.intersect(dataQuanta2)
.collect()
val expectedValues = Set(2, 3, 4, 5, 7, 8, 9)
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testSort() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues1 = Array(3, 4, 5, 2, 1)
val builder = new PlanBuilder(wayang)
val dataQuanta1 = builder.loadCollection(inputValues1)
val result = dataQuanta1
      .sort(r => r)
.collect()
val expectedValues = Array(1, 2, 3, 4, 5)
Assert.assertArrayEquals(expectedValues, result.toArray)
}
@Test
def testPageRank() = {
// Set up WayangContext.
val wayang = new WayangContext()
.withPlugin(Java.graphPlugin)
.withPlugin(WayangBasics.graphPlugin)
.withPlugin(Java.basicPlugin)
import org.apache.wayang.api.graph._
val edges = Seq((0, 1), (0, 2), (0, 3), (1, 0), (2, 1), (3, 2), (3, 1)).map(t => Edge(t._1, t._2))
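    // Small directed graph: vertex 1 receives links from vertices 0, 2, and 3.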
val pageRanks = wayang
.loadCollection(edges).withName("Load edges")
.pageRank(20).withName("PageRank")
.collect()
.map(t => t.field0.longValue -> t.field1)
.toMap
print(pageRanks)
// Let's not check absolute numbers but only the relative ordering.
Assert.assertTrue(pageRanks(1) > pageRanks(0))
Assert.assertTrue(pageRanks(0) > pageRanks(2))
Assert.assertTrue(pageRanks(2) > pageRanks(3))
}
@Test
def testMapPartitions() = {
// Set up WayangContext.
val wayang = new WayangContext()
.withPlugin(Java.basicPlugin())
.withPlugin(Spark.basicPlugin)
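    // Count odd and even numbers per partition, then sum the partial counts.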
val typeCounts = wayang
.loadCollection(Seq(0, 1, 2, 3, 4, 6, 8))
.mapPartitions { ints =>
var (numOdds, numEvens) = (0, 0)
ints.foreach(i => if ((i & 1) == 0) numEvens += 1 else numOdds += 1)
Seq(("odd", numOdds), ("even", numEvens))
}
.reduceByKey(_._1, { case ((kind1, count1), (kind2, count2)) => (kind1, count1 + count2) })
.collect()
Assert.assertEquals(Set(("odd", 2), ("even", 5)), typeCounts.toSet)
}
@Test
def testZipWithId() = {
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin).withPlugin(Spark.basicPlugin)
val inputValues = for (i <- 0 until 100; j <- 0 until 42) yield i
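    // Each of the 100 distinct values occurs 42 times, so zipWithId must assign 42 distinct ids per value.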
val result = wayang
.loadCollection(inputValues)
.zipWithId
.groupByKey(_.field1)
.map { group =>
import scala.collection.JavaConversions._
(group.map(_.field0).toSet.size, 1)
}
.reduceByKey(_._1, (t1, t2) => (t1._1, t1._2 + t2._2))
.collect()
val expectedValues = Set((42, 100))
Assert.assertEquals(expectedValues, result.toSet)
}
@Test
def testWriteTextFile() = {
val tempDir = LocalFileSystem.findTempDir
val targetUrl = LocalFileSystem.toURL(new File(tempDir, "testWriteTextFile.txt"))
// Set up WayangContext.
val wayang = new WayangContext().withPlugin(Java.basicPlugin)
val inputValues = for (i <- 0 to 5) yield i * 0.333333333333
    wayang
      .loadCollection(inputValues)
      .writeTextFile(targetUrl, formatterUdf = d => f"${d}%.2f")
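    // Read the written file back and collect its lines.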
val lines = scala.collection.mutable.Set[String]()
Files.lines(Paths.get(new URI(targetUrl))).forEach(new Consumer[String] {
override def accept(line: String): Unit = lines += line
})
    val expectedLines = inputValues.map(v => f"${v}%.2f").toSet
Assert.assertEquals(expectedLines, lines)
}
@Test
def testSqlOnJava() = {
// Initialize some test data.
val configuration = new Configuration
val sqlite3dbFile = File.createTempFile("wayang-sqlite3", "db")
sqlite3dbFile.deleteOnExit()
configuration.setProperty("wayang.sqlite3.jdbc.url", "jdbc:sqlite:" + sqlite3dbFile.getAbsolutePath)
try {
val connection: Connection = Sqlite3.platform.createDatabaseDescriptor(configuration).createJdbcConnection
try {
val statement: Statement = connection.createStatement
statement.addBatch("DROP TABLE IF EXISTS customer;")
statement.addBatch("CREATE TABLE customer (name TEXT, age INT);")
statement.addBatch("INSERT INTO customer VALUES ('John', 20)")
statement.addBatch("INSERT INTO customer VALUES ('Timmy', 16)")
statement.addBatch("INSERT INTO customer VALUES ('Evelyn', 35)")
statement.executeBatch()
} finally {
if (connection != null) connection.close()
}
}
// Set up WayangContext.
val wayang = new WayangContext(configuration).withPlugin(Java.basicPlugin).withPlugin(Sqlite3.plugin)
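    // The filter is pinned to Java, so the Scala predicate (not the SQL UDF) does the filtering.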
val result = wayang
.readTable(new Sqlite3TableSource("customer", "name", "age"))
.filter(r => r.getField(1).asInstanceOf[Integer] >= 18, sqlUdf = "age >= 18").withTargetPlatforms(Java.platform)
.projectRecords(Seq("name"))
.map(_.getField(0).asInstanceOf[String])
.collect()
.toSet
val expectedValues = Set("John", "Evelyn")
Assert.assertEquals(expectedValues, result)
}
@Test
def testSqlOnSqlite3() = {
// Initialize some test data.
val configuration = new Configuration
val sqlite3dbFile = File.createTempFile("wayang-sqlite3", "db")
sqlite3dbFile.deleteOnExit()
configuration.setProperty("wayang.sqlite3.jdbc.url", "jdbc:sqlite:" + sqlite3dbFile.getAbsolutePath)
try {
val connection: Connection = Sqlite3.platform.createDatabaseDescriptor(configuration).createJdbcConnection
try {
val statement: Statement = connection.createStatement
statement.addBatch("DROP TABLE IF EXISTS customer;")
statement.addBatch("CREATE TABLE customer (name TEXT, age INT);")
statement.addBatch("INSERT INTO customer VALUES ('John', 20)")
statement.addBatch("INSERT INTO customer VALUES ('Timmy', 16)")
statement.addBatch("INSERT INTO customer VALUES ('Evelyn', 35)")
        statement.executeBatch()
} finally {
if (connection != null) connection.close()
}
}
// Set up WayangContext.
val wayang = new WayangContext(configuration).withPlugin(Java.basicPlugin).withPlugin(Sqlite3.plugin)
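    // Pinning the projection to SQLite3 lets Wayang push the filter (via its SQL UDF) and the projection into the database.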
val result = wayang
.readTable(new Sqlite3TableSource("customer", "name", "age"))
.filter(r => r.getField(1).asInstanceOf[Integer] >= 18, sqlUdf = "age >= 18")
.projectRecords(Seq("name")).withTargetPlatforms(Sqlite3.platform)
.map(_.getField(0).asInstanceOf[String])
.collect()
.toSet
val expectedValues = Set("John", "Evelyn")
Assert.assertEquals(expectedValues, result)
}
}