| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| |
| package org.apache.predictionio.data.storage.jdbc |
| |
| import org.apache.predictionio.data.storage._ |
| import org.apache.spark.SparkContext |
| import org.apache.spark.rdd.RDD |
| import org.specs2._ |
| import org.specs2.specification.Step |
| |
/** Acceptance specification for the parallel events API ([[PEvents]])
  * backed by JDBC storage.
  *
  * Test data is inserted through the local API ([[LEvents]]) and then read
  * back / aggregated / written through the parallel API with a local
  * SparkContext, verifying that both views of the same tables agree.
  * Events and their expected property aggregations come from the shared
  * `TestEvents` trait.
  */
class PEventsSpec extends Specification with TestEvents {

  // Clear networking properties a previous SparkContext in this JVM may
  // have left behind; a stale port binding would make context creation fail.
  System.clearProperty("spark.driver.port")
  System.clearProperty("spark.hostPort")
  // App name matches this spec (was copy-pasted from PEventAggregatorSpec).
  val sc = new SparkContext("local[4]", "PEventsSpec test")

  val appId = 1
  val channelId = 6
  // Unique per spec instance so repeated/concurrent runs don't collide on
  // the same JDBC tables.
  val dbName = "test_pio_storage_events_" + hashCode

  /** Local (driver-side) events client over the same JDBC tables. */
  def jdbcLocal = Storage.getDataObject[LEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  /** Parallel (Spark RDD) events client over the same JDBC tables. */
  def jdbcPar = Storage.getDataObject[PEvents](
    StorageTestUtils.jdbcSourceName,
    dbName
  )

  /** Shuts down the SparkContext; run as the final step of the spec. */
  def stopSpark = {
    sc.stop()
  }

  def is = s2"""

  PredictionIO Storage PEvents Specification

    PEvents can be implemented by:
    - JDBCPEvents ${jdbcPEvents}
    - (stop Spark) ${Step(stopSpark)}

  """

  def jdbcPEvents = sequential ^ s2"""

    JDBCPEvents should
    - behave like any PEvents implementation ${events(jdbcLocal, jdbcPar)}
    - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_$appId"))}
    - (table cleanup) ${Step(StorageTestUtils.dropJDBCTable(s"${dbName}_${appId}_$channelId"))}

  """

  // Shared behavior for any PEvents implementation. Must run sequentially:
  // the write tests mutate the tables that earlier find/aggregate tests read.
  def events(localEventClient: LEvents, parEventClient: PEvents) = sequential ^ s2"""

    - (init test) ${initTest(localEventClient)}
    - (insert test events) ${insertTestEvents(localEventClient)}
    find in default ${find(parEventClient)}
    find in channel ${findChannel(parEventClient)}
    aggregate user properties in default ${aggregateUserProperties(parEventClient)}
    aggregate user properties in channel ${aggregateUserPropertiesChannel(parEventClient)}
    write to default ${write(parEventClient)}
    write to channel ${writeChannel(parEventClient)}

  """

  /* setup */

  // events from TestEvents trait
  val listOfEvents = List(u1e5, u2e2, u1e3, u1e1, u2e3, u2e1, u1e4, u1e2, r1, r2)
  val listOfEventsChannel = List(u3e1, u3e2, u3e3, r3, r4)

  /** Creates the backing tables for the default channel and channel `channelId`. */
  def initTest(localEventClient: LEvents) = {
    localEventClient.init(appId)
    localEventClient.init(appId, Some(channelId))
  }

  /** Inserts the fixture events via the local client (default + channel). */
  def insertTestEvents(localEventClient: LEvents) = {
    // foreach, not map: inserts are performed for their side effect only.
    listOfEvents.foreach( localEventClient.insert(_, appId) )
    // insert to channel
    listOfEventsChannel.foreach( localEventClient.insert(_, appId, Some(channelId)) )
    success
  }

  /* following are tests */

  /** find() on the default channel returns exactly the inserted events. */
  def find(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map {_.copy(eventId = None)} // ignore eventId assigned by storage

    results must containTheSameElementsAs(listOfEvents)
  }

  /** find() scoped to a channel returns only that channel's events. */
  def findChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[Event] = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map {_.copy(eventId = None)} // ignore eventId assigned by storage

    results must containTheSameElementsAs(listOfEventsChannel)
  }

  /** aggregateProperties() on the default channel folds $set/$unset events
    * into the expected per-user PropertyMap (expected values from TestEvents).
    */
  def aggregateUserProperties(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u1" -> PropertyMap(u1, u1BaseTime, u1LastTime),
      "u2" -> PropertyMap(u2, u2BaseTime, u2LastTime)
    )

    result must beEqualTo(expected)
  }

  /** aggregateProperties() scoped to a channel only sees that channel's users. */
  def aggregateUserPropertiesChannel(parEventClient: PEvents) = {
    val resultRDD: RDD[(String, PropertyMap)] = parEventClient.aggregateProperties(
      appId = appId,
      channelId = Some(channelId),
      entityType = "user"
    )(sc)
    val result: Map[String, PropertyMap] = resultRDD.collectAsMap.toMap

    val expected = Map(
      "u3" -> PropertyMap(u3, u3BaseTime, u3LastTime)
    )

    result must beEqualTo(expected)
  }

  /** write() to the default channel appends events; reading back yields the
    * original fixtures plus the newly written ones.
    */
  def write(parEventClient: PEvents) = {
    val written = List(r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId)(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId
    )(sc)

    val results = resultRDD.collect.toList
      .map { _.copy(eventId = None)} // ignore eventId assigned by storage

    val expected = listOfEvents ++ written

    results must containTheSameElementsAs(expected)
  }

  /** write() to a channel appends only to that channel's table; duplicates
    * (r1 also exists in the default channel) are allowed since eventIds differ.
    */
  def writeChannel(parEventClient: PEvents) = {
    val written = List(r1, r5, r6)
    val writtenRDD = sc.parallelize(written)
    parEventClient.write(writtenRDD, appId, Some(channelId))(sc)

    // read back
    val resultRDD = parEventClient.find(
      appId = appId,
      channelId = Some(channelId)
    )(sc)

    val results = resultRDD.collect.toList
      .map { _.copy(eventId = None)} // ignore eventId assigned by storage

    val expected = listOfEventsChannel ++ written

    results must containTheSameElementsAs(expected)
  }

}