/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.data.storage.jdbc

import org.apache.predictionio.data.storage._
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.specs2._
import org.specs2.specification.Step
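
/** Integration spec for the parallel events API ([[PEvents]]) backed by JDBC
  * storage. Fixtures are inserted through the local API ([[LEvents]]), then
  * read back, aggregated, and written through [[PEvents]] on a local Spark
  * context, for both the default channel and a named channel.
  */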
class PEventsSpec extends Specification with TestEvents {
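  // Clear any Spark port properties typically left over from other suites in
  // the same JVM so the fresh local[4] context below can bind cleanly.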
  System.clearProperty("spark.driver.port")
  System.clearProperty("spark.hostPort")
  val sc = new SparkContext("local[4]", "PEventsSpec test")

  val appId = 1
  val channelId = 6
  val dbName = "test_pio_storage_events_" + hashCode
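
  // Two handles onto the same JDBC database: an LEvents client for local,
  // single-JVM access (used to set up fixtures) and a PEvents client for
  // Spark-parallel access (the implementation under test).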
  def jdbcLocal = Storage.getDataObject[LEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  def jdbcPar = Storage.getDataObject[PEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  def stopSpark = {
    sc.stop()
  }
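
  // specs2 acceptance-style specification: the s2 interpolated string is the
  // spec outline, and each ${...} fragment is either an example or a
  // side-effecting Step that runs in document order.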
  def is = s2"""

  PredictionIO Storage PEvents Specification

    PEvents can be implemented by:
    - JDBCPEvents ${jdbcPEvents}
    - (stop Spark) ${Step(stopSpark)}

  """
  def jdbcPEvents = sequential ^ s2"""

  JDBCPEvents should
  - behave like any PEvents implementation ${events(jdbcLocal, jdbcPar)}
  - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_$appId"))}
  - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_${appId}_$channelId"))}

  """
  def events(localEventClient: LEvents, parEventClient: PEvents) = sequential ^ s2"""

    - (init test) ${initTest(localEventClient)}
    - (insert test events) ${insertTestEvents(localEventClient)}
    find in default ${find(parEventClient)}
    find in channel ${findChannel(parEventClient)}
    aggregate user properties in default ${aggregateUserProperties(parEventClient)}
    aggregate user properties in channel ${aggregateUserPropertiesChannel(parEventClient)}
    write to default ${write(parEventClient)}
    write to channel ${writeChannel(parEventClient)}

  """

  /* setup */

  // events from TestEvents trait
  val listOfEvents = List(u1e5, u2e2, u1e3, u1e1, u2e3, u2e1, u1e4, u1e2, r1, r2)
  val listOfEventsChannel = List(u3e1, u3e2, u3e3, r3, r4)
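
  // init prepares storage for both the default channel and the named channel
  // (for the JDBC backend this creates the backing tables).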
  def initTest(localEventClient: LEvents) = {
    localEventClient.init(appId)
    localEventClient.init(appId, Some(channelId))
  }

  def insertTestEvents(localEventClient: LEvents) = {
    listOfEvents.foreach(localEventClient.insert(_, appId))
    // insert to channel
    listOfEventsChannel.foreach(localEventClient.insert(_, appId, Some(channelId)))
    success
  }

  /* following are tests */
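
  // find should return only the events in the queried channel. eventId is
  // assigned by storage at insert time, so it is reset to None before
  // comparing against the fixtures.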
  def find(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map(_.copy(eventId = None)) // ignore eventId

    results must containTheSameElementsAs(listOfEvents)
  }

  def findChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map(_.copy(eventId = None)) // ignore eventId

    results must containTheSameElementsAs(listOfEventsChannel)
  }
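
  // aggregateProperties folds each user's property events ($set/$unset) into
  // a single PropertyMap per entityId; the expected maps are built from the
  // fixture properties and their first/last event times.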
  def aggregateUserProperties(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u1" -> PropertyMap(u1, u1BaseTime, u1LastTime),
      "u2" -> PropertyMap(u2, u2BaseTime, u2LastTime)
    )

    result must beEqualTo(expected)
  }

  def aggregateUserPropertiesChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      channelId = Some(channelId),
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u3" -> PropertyMap(u3, u3BaseTime, u3LastTime)
    )

    result must beEqualTo(expected)
  }
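
  // write appends events rather than replacing them, and the suite runs
  // sequentially, so the read-back must contain the original fixtures plus
  // the newly written events.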
  def write(parEventClient: PEvents) = {
    val written = List(r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId)(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map(_.copy(eventId = None)) // ignore eventId

    val expected = listOfEvents ++ written
    results must containTheSameElementsAs(expected)
  }
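
  // r1 already exists in the default channel; writing it to the named channel
  // as well exercises channel isolation: each channel keeps its own copy.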
  def writeChannel(parEventClient: PEvents) = {
    val written = List(r1, r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId, Some(channelId))(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map(_.copy(eventId = None)) // ignore eventId

    val expected = listOfEventsChannel ++ written
    results must containTheSameElementsAs(expected)
  }
}