blob: fccf256c381fe903df78e468dedf48c5494e7387 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.runtime.stream.sql
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.functions.{AggregateFunction, FunctionContext, ScalarFunction}
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.WeightedAvg
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.EventTimeSourceFunction
import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingAppendSink, UserDefinedFunctionTestUtils}
import org.apache.flink.table.planner.utils.TableTestUtil
import org.apache.flink.types.Row
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.sql.Timestamp
import java.time.{Instant, ZoneId}
import java.util.TimeZone
import scala.collection.mutable
/**
 * Integration tests for the SQL MATCH_RECOGNIZE clause on streaming tables.
 * Parameterized by [[StateBackendMode]], so every test runs once per state backend.
 */
@RunWith(classOf[Parameterized])
class MatchRecognizeITCase(backend: StateBackendMode) extends StreamingWithStateTestBase(backend) {
// Simple A B C pattern in processing time. Pattern variables deliberately use a
// quoted identifier (`A"`) and a Unicode-escaped identifier (\u006C is the letter 'l')
// to exercise escaping of pattern-variable names.
@Test
def testSimplePattern(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(Int, String)]
data.+=((1, "a"))
data.+=((2, "z"))
data.+=((3, "b"))
data.+=((4, "c"))
data.+=((5, "d"))
data.+=((6, "a"))
data.+=((7, "b"))
data.+=((8, "c"))
data.+=((9, "h"))
val t = env.fromCollection(data).toTable(tEnv,'id, 'name, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
val sqlQuery =
s"""
|SELECT T.aid, T.bid, T.cid
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| `A"`.id AS aid,
| \u006C.id AS bid,
| C.id AS cid
| PATTERN (`A"` \u006C C)
| DEFINE
| `A"` AS name = 'a',
| \u006C AS name = 'b',
| C AS name = 'c'
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
// Only rows 6,7,8 form a contiguous a-b-c sequence; the first 'a' (id 1) is
// interrupted by 'z' before a 'b' arrives.
val expected = mutable.MutableList("6,7,8")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Verifies that NULL values flow correctly through DEFINE predicates and MEASURES
// (both `IS NULL` checks and NULL projections).
@Test
def testSimplePatternWithNulls(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(Int, String, String)]
data.+=((1, "a", null))
data.+=((2, "b", null))
data.+=((3, "c", null))
data.+=((4, "d", null))
data.+=((5, null, null))
data.+=((6, "a", null))
data.+=((7, "b", null))
data.+=((8, "c", null))
data.+=((9, null, null))
val t = env.fromCollection(data).toTable(tEnv,'id, 'name, 'nullField, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
val sqlQuery =
s"""
|SELECT T.aid, T.bNull, T.cid, T.aNull
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| A.id AS aid,
| A.nullField AS aNull,
| LAST(B.nullField) AS bNull,
| C.id AS cid
| PATTERN (A B C)
| DEFINE
| A AS name = 'a' AND nullField IS NULL,
| B AS name = 'b' AND LAST(A.nullField) IS NULL,
| C AS name = 'c'
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList("1,null,3,null", "6,null,8,null")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Forces maximum code splitting (max generated code length of 1) to check that the
// generated MATCH_RECOGNIZE code still compiles and produces correct results.
@Test
def testCodeSplitsAreProperlyGenerated(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
// TODO: this code is ported from the old planner;
// however, code splitting is not yet supported in this planner.
tEnv.getConfig.setMaxGeneratedCodeLength(1)
val data = new mutable.MutableList[(Int, String, String, String)]
data.+=((1, "a", "key1", "second_key3"))
data.+=((2, "b", "key1", "second_key3"))
data.+=((3, "c", "key1", "second_key3"))
data.+=((4, "d", "key", "second_key"))
data.+=((5, "e", "key", "second_key"))
data.+=((6, "a", "key2", "second_key4"))
data.+=((7, "b", "key2", "second_key4"))
data.+=((8, "c", "key2", "second_key4"))
data.+=((9, "f", "key", "second_key"))
val t = env.fromCollection(data)
.toTable(tEnv, 'id, 'name, 'key1, 'key2, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
val sqlQuery =
s"""
|SELECT *
|FROM MyTable
|MATCH_RECOGNIZE (
| PARTITION BY key1, key2
| ORDER BY proctime
| MEASURES
| A.id AS aid,
| A.key1 AS akey1,
| LAST(B.id) AS bid,
| C.id AS cid,
| C.key2 AS ckey2
| PATTERN (A B C)
| DEFINE
| A AS name = 'a' AND key1 LIKE '%key%' AND id > 0,
| B AS name = 'b' AND LAST(A.name, 2) IS NULL,
| C AS name = 'c' AND LAST(A.name) = 'a'
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList(
"key1,second_key3,1,key1,2,3,second_key3",
"key2,second_key4,6,key2,7,8,second_key4")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Checks that ORDER BY with secondary (DESC) and ternary (ASC) sort keys is honored:
// each "breaks this match" row is ordered such that it prevents an a-b-c match.
@Test
def testEventsAreProperlyOrdered(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
// Left = timestamped element, Right = watermark (see EventTimeSourceFunction).
val data = Seq(
Left(2L, (12, 1, "a", 1)),
Left(1L, (11, 2, "b", 2)),
Left(3L, (10, 3, "c", 3)), //event time order breaks this match
Right(3L),
Left(4L, (8, 4, "a", 4)),
Left(4L, (9, 5, "b", 5)),
Left(5L, (7, 6, "c", 6)), //secondary order breaks this match
Right(5L),
Left(6L, (6, 8, "a", 7)),
Left(6L, (6, 7, "b", 8)),
Left(8L, (4, 9, "c", 9)), //ternary order breaks this match
Right(8L),
Left(9L, (3, 10, "a", 10)),
Left(10L, (2, 11, "b", 11)),
Left(11L, (1, 12, "c", 12)),
Right(11L)
)
val t = env.addSource(new EventTimeSourceFunction[(Int, Int, String, Int)](data))
.toTable(tEnv, 'secondaryOrder, 'ternaryOrder, 'name, 'id, 'rowtime.rowtime)
tEnv.registerTable("MyTable", t)
val sqlQuery =
s"""
|SELECT T.aid, T.bid, T.cid
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY rowtime, secondaryOrder DESC, ternaryOrder ASC
| MEASURES
| A.id AS aid,
| B.id AS bid,
| C.id AS cid
| PATTERN (A B C)
| DEFINE
| A AS name = 'a',
| B AS name = 'b',
| C AS name = 'c'
|) AS T
|""".stripMargin
val table = tEnv.sqlQuery(sqlQuery)
val sink = new TestingAppendSink()
val result = table.toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList("10,11,12")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// MATCH_RECOGNIZE consuming the output of a tumbling-window aggregation:
// TUMBLE_ROWTIME propagates a rowtime attribute that the pattern can ORDER BY.
@Test
def testMatchRecognizeAppliedToWindowedGrouping(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(String, Long, Int, Int)]
//first window
data.+=(("ACME", Time.seconds(1).toMilliseconds, 1, 1))
data.+=(("ACME", Time.seconds(2).toMilliseconds, 2, 2))
//second window
data.+=(("ACME", Time.seconds(4).toMilliseconds, 1, 4))
data.+=(("ACME", Time.seconds(5).toMilliseconds, 1, 3))
//third window
data.+=(("ACME", Time.seconds(7).toMilliseconds, 2, 3))
data.+=(("ACME", Time.seconds(8).toMilliseconds, 2, 3))
data.+=(("ACME1", Time.seconds(1).toMilliseconds, 20, 4))
data.+=(("ACME1", Time.seconds(1).toMilliseconds, 24, 4))
data.+=(("ACME1", Time.seconds(1).toMilliseconds, 25, 3))
data.+=(("ACME1", Time.seconds(1).toMilliseconds, 19, 8))
val t = env.fromCollection(data)
.assignAscendingTimestamps(e => e._2)
.toTable(tEnv, 'symbol, 'rowtime.rowtime, 'price, 'tax)
tEnv.registerTable("Ticker", t)
val sqlQuery =
s"""
|SELECT *
|FROM (
| SELECT
| symbol,
| SUM(price) as price,
| TUMBLE_ROWTIME(rowtime, interval '3' second) as rowTime,
| TUMBLE_START(rowtime, interval '3' second) as startTime
| FROM Ticker
| GROUP BY symbol, TUMBLE(rowtime, interval '3' second)
|)
|MATCH_RECOGNIZE (
| PARTITION BY symbol
| ORDER BY rowTime
| MEASURES
| B.price as dPrice,
| B.startTime as dTime
| ONE ROW PER MATCH
| PATTERN (A B)
| DEFINE
| B AS B.price < A.price
|)
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("ACME,2,1970-01-01T00:00:03")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// The inverse direction: a tumbling-window aggregation consuming MATCH_RECOGNIZE
// output, grouping on the rowtime produced by MATCH_ROWTIME().
@Test
def testWindowedGroupingAppliedToMatchRecognize(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(String, Long, Int, Int)]
//first window
data.+=(("ACME", Time.seconds(1).toMilliseconds, 1, 1))
data.+=(("ACME", Time.seconds(2).toMilliseconds, 2, 2))
//second window
data.+=(("ACME", Time.seconds(4).toMilliseconds, 1, 4))
data.+=(("ACME", Time.seconds(5).toMilliseconds, 1, 3))
val tickerEvents = env.fromCollection(data)
.assignAscendingTimestamps(tickerEvent => tickerEvent._2)
.toTable(tEnv, 'symbol, 'rowtime.rowtime, 'price, 'tax)
tEnv.registerTable("Ticker", tickerEvents)
val sqlQuery =
s"""
|SELECT
| symbol,
| SUM(price) as price,
| TUMBLE_ROWTIME(matchRowtime, interval '3' second) as rowTime,
| TUMBLE_START(matchRowtime, interval '3' second) as startTime
|FROM Ticker
|MATCH_RECOGNIZE (
| PARTITION BY symbol
| ORDER BY rowtime
| MEASURES
| A.price as price,
| A.tax as tax,
| MATCH_ROWTIME() as matchRowtime
| ONE ROW PER MATCH
| PATTERN (A)
| DEFINE
| A AS A.price > 0
|) AS T
|GROUP BY symbol, TUMBLE(matchRowtime, interval '3' second)
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List(
"ACME,3,1970-01-01T00:00:02.999,1970-01-01T00:00",
"ACME,2,1970-01-01T00:00:05.999,1970-01-01T00:00:03")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Same as above but with a TIMESTAMP_LTZ rowtime and a non-UTC session time zone:
// TUMBLE_ROWTIME stays in instant (Z) form while TUMBLE_START is rendered in the
// configured Asia/Shanghai (UTC+8) local time.
@Test
def testWindowedGroupingAppliedToMatchRecognizeOnLtzRowtime(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data: Seq[Row] = Seq(
//first window
rowOf("ACME", Instant.ofEpochSecond(1), 1, 1),
rowOf("ACME", Instant.ofEpochSecond(2), 2, 2),
//second window
rowOf("ACME", Instant.ofEpochSecond(3), 1, 4),
rowOf("ACME", Instant.ofEpochSecond(4), 1, 3)
)
tEnv.getConfig.setLocalTimeZone(ZoneId.of("Asia/Shanghai"))
val dataId = TestValuesTableFactory.registerData(data)
tEnv.executeSql(
s"""
|CREATE TABLE Ticker (
| `symbol` STRING,
| `ts_ltz` TIMESTAMP_LTZ(3),
| `price` INT,
| `tax` INT,
| WATERMARK FOR `ts_ltz` AS `ts_ltz` - INTERVAL '1' SECOND
|) WITH (
| 'connector' = 'values',
| 'data-id' = '$dataId'
|)
|""".stripMargin)
val sqlQuery =
s"""
|SELECT
| symbol,
| SUM(price) as price,
| TUMBLE_ROWTIME(matchRowtime, interval '3' second) as rowTime,
| TUMBLE_START(matchRowtime, interval '3' second) as startTime
|FROM Ticker
|MATCH_RECOGNIZE (
| PARTITION BY symbol
| ORDER BY ts_ltz
| MEASURES
| A.price as price,
| A.tax as tax,
| MATCH_ROWTIME(ts_ltz) as matchRowtime
| ONE ROW PER MATCH
| PATTERN (A)
| DEFINE
| A AS A.price > 0
|) AS T
|GROUP BY symbol, TUMBLE(matchRowtime, interval '3' second)
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List(
"ACME,3,1970-01-01T00:00:02.999Z,1970-01-01T08:00",
"ACME,2,1970-01-01T00:00:05.999Z,1970-01-01T08:00:03")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Exercises FIRST/LAST logical offsets in MEASURES and DEFINE (V-shape price
// pattern: strictly falling DOWN rows followed by one UP row).
@Test
def testLogicalOffsets(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(String, Long, Int, Int)]
data.+=(("ACME", 1L, 19, 1))
data.+=(("ACME", 2L, 17, 2))
data.+=(("ACME", 3L, 13, 3))
data.+=(("ACME", 4L, 20, 4))
data.+=(("ACME", 5L, 20, 5))
data.+=(("ACME", 6L, 26, 6))
data.+=(("ACME", 7L, 20, 7))
data.+=(("ACME", 8L, 25, 8))
val t = env.fromCollection(data)
.toTable(tEnv, 'symbol, 'tstamp, 'price, 'tax, 'proctime.proctime)
tEnv.registerTable("Ticker", t)
val sqlQuery =
s"""
|SELECT *
|FROM Ticker
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| FIRST(DOWN.tstamp) AS start_tstamp,
| LAST(DOWN.tstamp) AS bottom_tstamp,
| UP.tstamp AS end_tstamp,
| FIRST(DOWN.price + DOWN.tax + 1) AS bottom_total,
| UP.price + UP.tax AS end_total
| ONE ROW PER MATCH
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (DOWN{2,} UP)
| DEFINE
| DOWN AS price < LAST(DOWN.price, 1) OR LAST(DOWN.price, 1) IS NULL,
| UP AS price < FIRST(DOWN.price)
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("6,7,8,33,33")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Verifies that PARTITION BY works when the source runs with the environment's
// full parallelism (records for one key may arrive from several subtasks).
@Test
def testPartitionByWithParallelSource(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(String, Long, Int, Int)]
data.+=(("ACME", 1L, 19, 1))
data.+=(("ACME", 2L, 17, 2))
data.+=(("ACME", 3L, 13, 3))
data.+=(("ACME", 4L, 20, 4))
val t = env.fromCollection(data)
.assignAscendingTimestamps(tickerEvent => tickerEvent._2)
.setParallelism(env.getParallelism)
.toTable(tEnv, 'symbol, 'rowtime.rowtime, 'price, 'tax)
tEnv.registerTable("Ticker", t)
val sqlQuery =
s"""
|SELECT *
|FROM Ticker
|MATCH_RECOGNIZE (
| PARTITION BY symbol
| ORDER BY rowtime
| MEASURES
| DOWN.tax AS bottom_tax,
| UP.tax AS end_tax
| ONE ROW PER MATCH
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (DOWN UP)
| DEFINE
| DOWN AS DOWN.price = 13,
| UP AS UP.price = 20
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("ACME,3,4")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// FIRST/LAST offsets on an unqualified ("star") column reference, i.e. across all
// rows of the match rather than rows mapped to a single pattern variable.
@Test
def testLogicalOffsetsWithStarVariable(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(Int, String, Long, Int)]
data.+=((1, "ACME", 1L, 20))
data.+=((2, "ACME", 2L, 19))
data.+=((3, "ACME", 3L, 18))
data.+=((4, "ACME", 4L, 17))
data.+=((5, "ACME", 5L, 16))
data.+=((6, "ACME", 6L, 15))
data.+=((7, "ACME", 7L, 14))
data.+=((8, "ACME", 8L, 20))
val t = env.fromCollection(data)
.toTable(tEnv, 'id, 'symbol, 'tstamp, 'price, 'proctime.proctime)
tEnv.registerTable("Ticker", t)
val sqlQuery =
s"""
|SELECT *
|FROM Ticker
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| FIRST(id, 0) as id0,
| FIRST(id, 1) as id1,
| FIRST(id, 2) as id2,
| FIRST(id, 3) as id3,
| FIRST(id, 4) as id4,
| FIRST(id, 5) as id5,
| FIRST(id, 6) as id6,
| FIRST(id, 7) as id7,
| LAST(id, 0) as id8,
| LAST(id, 1) as id9,
| LAST(id, 2) as id10,
| LAST(id, 3) as id11,
| LAST(id, 4) as id12,
| LAST(id, 5) as id13,
| LAST(id, 6) as id14,
| LAST(id, 7) as id15
| ONE ROW PER MATCH
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (`DOWN"`{2,} UP)
| DEFINE
| `DOWN"` AS price < LAST(price, 1) OR LAST(price, 1) IS NULL,
| UP AS price = FIRST(price) AND price > FIRST(price, 3) AND price = LAST(price, 7)
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("1,2,3,4,5,6,7,8,8,7,6,5,4,3,2,1")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// A FIRST offset beyond the number of matched rows must produce NULL in MEASURES,
// not fail.
@Test
def testLogicalOffsetOutsideOfRangeInMeasures(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(String, Long, Int, Int)]
data.+=(("ACME", 1L, 19, 1))
data.+=(("ACME", 2L, 17, 2))
data.+=(("ACME", 3L, 13, 3))
data.+=(("ACME", 4L, 20, 4))
val t = env.fromCollection(data)
.toTable(tEnv, 'symbol, 'tstamp, 'price, 'tax, 'proctime.proctime)
tEnv.registerTable("Ticker", t)
val sqlQuery =
s"""
|SELECT *
|FROM Ticker
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| FIRST(DOWN.price) as first,
| LAST(DOWN.price) as last,
| FIRST(DOWN.price, 5) as nullPrice
| ONE ROW PER MATCH
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (DOWN{2,} UP)
| DEFINE
| DOWN AS price < LAST(DOWN.price, 1) OR LAST(DOWN.price, 1) IS NULL,
| UP AS price > LAST(DOWN.price)
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("19,13,null")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
/**
* This query checks:
*
* 1. count(D.price) produces 0, because no rows matched to D
* 2. sum(D.price) produces null, because no rows matched to D
* 3. aggregates that take multiple parameters work
* 4. aggregates with expressions work
*/
@Test
def testAggregates(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
// Force code splitting to also cover the generated aggregate code paths.
tEnv.getConfig.setMaxGeneratedCodeLength(1)
val data = new mutable.MutableList[(Int, String, Long, Double, Int)]
data.+=((1, "a", 1, 0.8, 1))
data.+=((2, "z", 2, 0.8, 3))
data.+=((3, "b", 1, 0.8, 2))
data.+=((4, "c", 1, 0.8, 5))
data.+=((5, "d", 4, 0.1, 5))
data.+=((6, "a", 2, 1.5, 2))
data.+=((7, "b", 2, 0.8, 3))
data.+=((8, "c", 1, 0.8, 2))
data.+=((9, "h", 4, 0.8, 3))
data.+=((10, "h", 4, 0.8, 3))
data.+=((11, "h", 2, 0.8, 3))
data.+=((12, "h", 2, 0.8, 3))
val t = env.fromCollection(data)
.toTable(tEnv, 'id, 'name, 'price, 'rate, 'weight, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
tEnv.createTemporarySystemFunction("weightedAvg", classOf[WeightedAvg])
val sqlQuery =
s"""
|SELECT *
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| FIRST(id) as startId,
| SUM(A.price) AS sumA,
| COUNT(D.price) AS countD,
| SUM(D.price) as sumD,
| weightedAvg(price, weight) as wAvg,
| AVG(B.price) AS avgB,
| SUM(B.price * B.rate) as sumExprB,
| LAST(id) as endId
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (A+ B+ C D? E )
| DEFINE
| A AS SUM(A.price) < 6,
| B AS SUM(B.price * B.rate) < SUM(A.price) AND
| SUM(B.price * B.rate) > 0.2 AND
| SUM(B.price) >= 1 AND
| AVG(B.price) >= 1 AND
| weightedAvg(price, weight) > 1
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList("1,5,0,null,2,3,3.4,8", "9,4,0,null,3,4,3.2,12")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// COUNT/SUM semantics with NULL inputs: COUNT(col) skips NULLs while COUNT(*)
// counts all rows of the match.
@Test
def testAggregatesWithNullInputs(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
tEnv.getConfig.setMaxGeneratedCodeLength(1)
val data = new mutable.MutableList[Row]
data.+=(Row.of(Int.box(1), "a", Int.box(10)))
data.+=(Row.of(Int.box(2), "z", Int.box(10)))
data.+=(Row.of(Int.box(3), "b", null))
data.+=(Row.of(Int.box(4), "c", null))
data.+=(Row.of(Int.box(5), "d", Int.box(3)))
data.+=(Row.of(Int.box(6), "c", Int.box(3)))
data.+=(Row.of(Int.box(7), "c", Int.box(3)))
data.+=(Row.of(Int.box(8), "c", Int.box(3)))
data.+=(Row.of(Int.box(9), "c", Int.box(2)))
val t = env.fromCollection(data)(Types.ROW(
BasicTypeInfo.INT_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO))
.toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
tEnv.registerFunction("weightedAvg", new WeightedAvg)
val sqlQuery =
s"""
|SELECT *
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| SUM(A.price) as sumA,
| COUNT(A.id) as countAId,
| COUNT(A.price) as countAPrice,
| COUNT(*) as countAll,
| COUNT(price) as countAllPrice,
| LAST(id) as endId
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (A+ C)
| DEFINE
| A AS SUM(A.price) < 30,
| C AS C.name = 'c'
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList("29,7,5,8,6,8")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
// Proctime and CURRENT_TIMESTAMP are usable inside MEASURES/DEFINE; only the id
// is asserted because the actual timestamps are wall-clock dependent.
@Test
def testAccessingCurrentTime(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
val data = new mutable.MutableList[(Int, String)]
data.+=((1, "a"))
val t = env.fromCollection(data).toTable(tEnv,'id, 'name, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
val sqlQuery =
s"""
|SELECT T.aid
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| A.id AS aid,
| A.proctime AS aProctime,
| LAST(A.proctime + INTERVAL '1' second) as calculatedField
| PATTERN (A)
| DEFINE
| A AS proctime >= (CURRENT_TIMESTAMP - INTERVAL '1' day)
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = List("1")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
// We do not assert the proctime in the result, cause it is currently
// accessed from System.currentTimeMillis(), so there is no graceful way to assert the proctime
}
// User-defined scalar and aggregate functions (PrefixingScalarFunc / RichAggFunc,
// defined below) inside MEASURES and DEFINE, reading config via job parameters.
@Test
def testUserDefinedFunctions(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
tEnv.getConfig.setMaxGeneratedCodeLength(1)
val data = new mutable.MutableList[(Int, String, Long)]
data.+=((1, "a", 1))
data.+=((2, "a", 1))
data.+=((3, "a", 1))
data.+=((4, "a", 1))
data.+=((5, "a", 1))
data.+=((6, "b", 1))
data.+=((7, "a", 1))
data.+=((8, "a", 1))
data.+=((9, "f", 1))
val t = env.fromCollection(data)
.toTable(tEnv, 'id, 'name, 'price, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
tEnv.registerFunction("prefix", new PrefixingScalarFunc)
tEnv.registerFunction("countFrom", new RichAggFunc)
val prefix = "PREF"
val startFrom = 4
UserDefinedFunctionTestUtils
.setJobParameters(env, Map("prefix" -> prefix, "start" -> startFrom.toString))
val sqlQuery =
s"""
|SELECT *
|FROM MyTable
|MATCH_RECOGNIZE (
| ORDER BY proctime
| MEASURES
| FIRST(id) as firstId,
| prefix(A.name) as prefixedNameA,
| countFrom(A.price) as countFromA,
| LAST(id) as lastId
| AFTER MATCH SKIP PAST LAST ROW
| PATTERN (A+ C)
| DEFINE
| A AS prefix(A.name) = '$prefix:a' AND countFrom(A.price) <= ${startFrom + 4}
|) AS T
|""".stripMargin
val sink = new TestingAppendSink()
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
result.addSink(sink)
env.execute()
val expected = mutable.MutableList("1,PREF:a,8,5", "7,PREF:a,6,9")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}
}
@SerialVersionUID(1L)
class ToMillis extends ScalarFunction {
  /**
   * Converts a [[Timestamp]] to epoch milliseconds shifted by the JVM's default
   * time-zone offset at that instant.
   */
  def eval(t: Timestamp): Long = {
    val epochMillis = t.toInstant.toEpochMilli
    epochMillis + TimeZone.getDefault.getOffset(epochMillis)
  }
}
@SerialVersionUID(1L)
private class PrefixingScalarFunc extends ScalarFunction {

  // Populated from the "prefix" job parameter in open(); the sentinel value
  // surfaces in results if open() was never invoked.
  private var prefix = "ERROR_VALUE"

  override def open(context: FunctionContext): Unit = {
    prefix = context.getJobParameter("prefix", "")
  }

  /** Returns the input string prefixed with the configured prefix and a colon. */
  def eval(value: String): String = s"$prefix:$value"
}
// Mutable accumulator used by RichAggFunc below: holds the running total.
private case class CountAcc(var count: Long)
/**
 * Aggregate function that sums its Long inputs, starting each accumulator from the
 * value of the "start" job parameter (read in [[open]]).
 */
private class RichAggFunc extends AggregateFunction[Long, CountAcc] {

  // Initial value for newly created accumulators; configured via job parameters.
  private var initialCount: Long = 0L

  override def open(context: FunctionContext): Unit = {
    initialCount = context.getJobParameter("start", "0").toLong
  }

  override def close(): Unit = {
    initialCount = 0L
  }

  override def createAccumulator(): CountAcc = CountAcc(initialCount)

  /** Adds the given value to the accumulator's running total. */
  def accumulate(countAcc: CountAcc, value: Long): Unit = {
    countAcc.count += value
  }

  override def getValue(accumulator: CountAcc): Long = accumulator.count
}