/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.spark
import org.apache.ignite.{Ignite, Ignition}
import org.apache.ignite.configuration.{CacheConfiguration, IgniteConfiguration}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalatest._
import java.lang.{Long ⇒ JLong}
import org.apache.ignite.cache.query.SqlFieldsQuery
import org.apache.ignite.cache.query.annotations.QuerySqlField
import org.apache.ignite.internal.IgnitionEx.loadConfiguration
import org.apache.ignite.spark.impl.IgniteSQLAccumulatorRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.ignite.spark.AbstractDataFrameSpec._
import scala.annotation.meta.field
import scala.reflect.ClassTag
/**
 * Base fixture for the Ignite DataFrame test specs: starts `NUM_SERVERS` server nodes
 * and a client node, creates a local `SparkSession` and provides helpers to populate
 * test tables/caches and to verify query results and plan optimization.
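 *
 * A minimal usage sketch (`MySpec` is a hypothetical subclass, not part of this file):
 * {{{
 *   class MySpec extends AbstractDataFrameSpec {
 *       override protected def beforeAll(): Unit = {
 *           super.beforeAll()
 *
 *           createCityTable(client, DEFAULT_CACHE)
 *       }
 *   }
 * }}}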
*/
abstract class AbstractDataFrameSpec extends FunSpec with Matchers with BeforeAndAfterAll with BeforeAndAfter
with Assertions {
var spark: SparkSession = _
var client: Ignite = _
private val NUM_SERVERS = 5
override protected def beforeAll(): Unit = {
for (i ← 0 until NUM_SERVERS)
Ignition.start(configuration("grid-" + i, client = false))
client = Ignition.getOrStart(configuration("client", client = true))
createSparkSession()
}
override protected def afterAll(): Unit = {
Ignition.stop("client", false)
for (i ← 0 until NUM_SERVERS)
Ignition.stop("grid-" + i, false)
spark.close()
}
protected def createSparkSession(): Unit = {
spark = SparkSession.builder()
.appName("DataFrameSpec")
.master("local")
.config("spark.executor.instances", "2")
.getOrCreate()
}
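/** Creates the `person2` test table in the given cache and fills it with sample rows. */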
def createPersonTable2(client: Ignite, cacheName: String): Unit =
createPersonTable0(client, cacheName, PERSON_TBL_NAME_2)
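/** Creates the `person` test table in the given cache and fills it with sample rows. */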
def createPersonTable(client: Ignite, cacheName: String): Unit =
createPersonTable0(client, cacheName, PERSON_TBL_NAME)
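/** Creates the named person table, with `city_id` as its affinity key, and inserts sample rows. */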
private def createPersonTable0(client: Ignite, cacheName: String, tblName: String): Unit = {
val cache = client.cache(cacheName)
cache.query(new SqlFieldsQuery(
s"""
| CREATE TABLE $tblName (
| id LONG,
| name VARCHAR,
| birth_date DATE,
| is_resident BOOLEAN,
| salary FLOAT,
| pension DOUBLE,
| account DECIMAL,
| age INT,
| city_id LONG,
| PRIMARY KEY (id, city_id)) WITH "backups=1, affinityKey=city_id"
""".stripMargin)).getAll
val qry = new SqlFieldsQuery(s"INSERT INTO $tblName (id, name, city_id) values (?, ?, ?)")
cache.query(qry.setArgs(1L.asInstanceOf[JLong], "John Doe", 3L.asInstanceOf[JLong])).getAll
cache.query(qry.setArgs(2L.asInstanceOf[JLong], "Jane Roe", 2L.asInstanceOf[JLong])).getAll
cache.query(qry.setArgs(3L.asInstanceOf[JLong], "Mary Major", 1L.asInstanceOf[JLong])).getAll
cache.query(qry.setArgs(4L.asInstanceOf[JLong], "Richard Miles", 2L.asInstanceOf[JLong])).getAll
cache.query(qry.setArgs(5L.asInstanceOf[JLong], null, 2L.asInstanceOf[JLong])).getAll
}
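/** Creates the replicated `city` test table in the given cache and fills it with sample rows. */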
def createCityTable(client: Ignite, cacheName: String): Unit = {
val cache = client.cache(cacheName)
cache.query(new SqlFieldsQuery(
"CREATE TABLE city (id LONG PRIMARY KEY, name VARCHAR) WITH \"template=replicated\"")).getAll
val qry = new SqlFieldsQuery("INSERT INTO city (id, name) VALUES (?, ?)")
cache.query(qry.setArgs(1L.asInstanceOf[JLong], "Forest Hill")).getAll
cache.query(qry.setArgs(2L.asInstanceOf[JLong], "Denver")).getAll
cache.query(qry.setArgs(3L.asInstanceOf[JLong], "St. Petersburg")).getAll
cache.query(qry.setArgs(4L.asInstanceOf[JLong], "St. Petersburg")).getAll
}
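/** Creates a cache with indexed [[Employee]] values, optionally bound to the given SQL schema, and fills it with sample entries. */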
def createEmployeeCache(client: Ignite, cacheName: String, schemaName: Option[String] = None): Unit = {
val ccfg = AbstractDataFrameSpec.cacheConfiguration[String, Employee](cacheName)
schemaName.foreach(ccfg.setSqlSchema)
val cache = client.getOrCreateCache(ccfg)
cache.put("key1", Employee(1, "John Connor", 15, 0))
cache.put("key2", Employee(2, "Sarah Connor", 32, 10000))
cache.put("key3", Employee(3, "Arnold Schwarzenegger", 27, 1000))
}
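/**
 * Checks that the rows of `res`, sorted by their first column, match `expectedRes` element-wise
 * (the overload below accepts a custom sort key extractor).
 *
 * A usage sketch, assuming the `city` table created by [[createCityTable]] is readable
 * through `spark`; the explicit type parameter selects the sort column type:
 * {{{
 *   val cities = spark.sql("SELECT id, name FROM city")
 *
 *   checkQueryData[Long](cities, (
 *       (1L, "Forest Hill"),
 *       (2L, "Denver"),
 *       (3L, "St. Petersburg"),
 *       (4L, "St. Petersburg")))
 * }}}
 */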
def checkQueryData[T](res: DataFrame, expectedRes: Product)
(implicit ord: T ⇒ Ordered[T]): Unit =
checkQueryData(res, expectedRes, _.getAs[T](0))
def checkQueryData[Ordered](res: DataFrame, expectedRes: Product, sorter: Row => Ordered)
(implicit ord: Ordering[Ordered]): Unit = {
val data = res.rdd.collect.sortBy(sorter)
for (i ← 0 until expectedRes.productArity) {
val row = data(i)
if (row.size == 1)
assert(row(0) == expectedRes.productElement(i), s"row[$i, 0] = ${row(0)} should be equal to ${expectedRes.productElement(i)}")
else {
val expectedRow: Product = expectedRes.productElement(i).asInstanceOf[Product]
assert(expectedRow.productArity == row.size, s"Row sizes should be equal, but expected.size=${expectedRow.productArity} " +
s"and row.size=${row.size}")
for (j ← 0 until expectedRow.productArity)
assert(row(j) == expectedRow.productElement(j), s"row[$i, $j] = ${row(j)} should be equal to ${expectedRow.productElement(j)}")
}
}
}
}
object AbstractDataFrameSpec {
val TEST_CONFIG_FILE = "ignite-spark-config.xml"
val DEFAULT_CACHE = "cache1"
val TEST_OBJ_CACHE_NAME = "cache2"
val EMPLOYEE_CACHE_NAME = "cache3"
val PERSON_TBL_NAME = "person"
val PERSON_TBL_NAME_2 = "person2"
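/**
 * Builds a node configuration from `TEST_CONFIG_FILE` with the given instance name and client mode.
 *
 * @param igniteInstanceName Ignite instance name.
 * @param client Client mode flag.
 * @return Ignite configuration.
 */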
def configuration(igniteInstanceName: String, client: Boolean): IgniteConfiguration = {
val cfg = loadConfiguration(TEST_CONFIG_FILE).get1()
cfg.setClientMode(client)
cfg.setIgniteInstanceName(igniteInstanceName)
cfg
}
/**
 * Gets cache configuration for the given cache name.
 *
 * @tparam K Class of cached keys.
 * @tparam V Class of cached values.
 * @param cacheName Cache name.
 * @return Cache configuration.
 */
def cacheConfiguration[K : ClassTag, V : ClassTag](cacheName : String): CacheConfiguration[Object, Object] = {
val ccfg = new CacheConfiguration[Object, Object]()
ccfg.setBackups(1)
ccfg.setName(cacheName)
ccfg.setIndexedTypes(
implicitly[reflect.ClassTag[K]].runtimeClass.asInstanceOf[Class[K]],
implicitly[reflect.ClassTag[V]].runtimeClass.asInstanceOf[Class[V]])
ccfg
}
/**
 * Checks that the optimized plan of `df` reads from an [[IgniteSQLAccumulatorRelation]],
 * i.e. that the query was pushed down to Ignite. When `qry` is non-empty, also checks
 * that the SQL generated for Ignite equals it (case-insensitively).
 *
* @param df Data frame.
* @param qry SQL Query.
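 *
 * A usage sketch run inside a spec (the expected SQL string here is illustrative,
 * not the exact text Ignite's optimizer emits):
 * {{{
 *   val df = spark.sql("SELECT name FROM city WHERE id = 1")
 *
 *   checkOptimizationResult(df, "SELECT name FROM city WHERE id = 1")
 * }}}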
*/
def checkOptimizationResult(df: DataFrame, qry: String = ""): Unit = {
df.explain(true)
val plan = df.queryExecution.optimizedPlan
val cnt = plan.collectLeaves.count {
case LogicalRelation(relation: IgniteSQLAccumulatorRelation[_, _], _, _, _) ⇒
if (qry != "")
assert(qry.toLowerCase == relation.acc.compileQuery().toLowerCase,
s"Generated query should be equal to expected.\nexpected - ${qry.toLowerCase}\ngenerated - ${relation.acc.compileQuery().toLowerCase}")
true
case _ ⇒
false
}
assert(cnt != 0, "Plan should contain an IgniteSQLAccumulatorRelation")
}
/**
 * Encloses the given closure so that it doesn't capture the outer object (the default Scala behaviour) when serialized.
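 *
 * A usage sketch (hypothetical lambda; the returned closure captures only `name`,
 * not the enclosing spec instance):
 * {{{
 *   val name = "John Doe"
 *
 *   val filter = enclose(name)(n ⇒ (r: Row) ⇒ r.getAs[String]("name") == n)
 * }}}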
*/
def enclose[E, R](enclosed: E)(func: E => R): R = func(enclosed)
}
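/**
 * Test value object for the employee cache: `QuerySqlField`-annotated fields are visible
 * to Ignite SQL; `id` and `salary` are indexed.
 */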
case class Employee (
@(QuerySqlField @field)(index = true) id: Long,
@(QuerySqlField @field) name: String,
age: Int,
@(QuerySqlField @field)(index = true, descending = true) salary: Float
) extends Serializable