/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spot.dns.model
import org.apache.log4j.Logger
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spot.SuspiciousConnectsArgumentParser.SuspiciousConnectsConfig
import org.apache.spot.dns.DNSSchema._
import org.apache.spot.dns.DNSWordCreation
import org.apache.spot.lda.SpotLDAWrapperSchema._
import org.apache.spot.lda._
import org.apache.spot.utilities.DomainProcessor.DomainInfo
import org.apache.spot.utilities._
import org.apache.spot.utilities.data.validation.InvalidDataHandler
/**
* A probabilistic model of the DNS queries issued by each client IP.
*
 * The model uses a topic-modelling approach that:
 * 1. Simplifies DNS log entries into words.
 * 2. Treats the DNS queries of each client IP as a collection of words.
 * 3. Decomposes common query behaviors into a collection of "topics" that represent common profiles
 * of query behavior. These "topics" are probability distributions on words.
 * 4. Assigns each client IP a mix of topics corresponding to its behavior.
 * 5. Estimates the probability of a query from an IP by simplifying the query into a word and then
 * combining the per-topic word probabilities using the topic mix of that IP, as sketched below.
*
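 * Schematically, writing word(q) for the simplified form of a query q (informal notation,
 * not identifiers from this code):
 *
 *   P(q | IP) ~= sum over topics t of P(word(q) | t) * P(t | IP)
 *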
* Create these models using the factory in the companion object.
*
* @param inTopicCount Number of topics to use in the topic model.
* @param inIpToTopicMix Per-IP topic mix.
 * @param inWordToPerTopicProb For each word, an array giving the conditional probability of the word under each topic.
*/
class DNSSuspiciousConnectsModel(inTopicCount: Int,
inIpToTopicMix: DataFrame,
inWordToPerTopicProb: Map[String, Array[Double]]) {
val topicCount: Int = inTopicCount
val ipToTopicMix: DataFrame = inIpToTopicMix
val wordToPerTopicProb: Map[String, Array[Double]] = inWordToPerTopicProb
/**
* Use a suspicious connects model to assign estimated probabilities to a dataframe of
* DNS log events.
*
* @param sparkSession Spark Session
* @param inDF Dataframe of DNS log events, containing at least the columns of [[DNSSuspiciousConnectsModel.ModelSchema]]
   * @param userDomain Domain associated with the network data (e.g. 'intel')
   * @param precisionUtility Floating-point precision utility used to read the per-IP topic mixes.
* @return Dataframe with a column named [[org.apache.spot.dns.DNSSchema.Score]] that contains the
* probability estimated for the network event at that row
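   *
   * A minimal usage sketch (assuming FloatPointPrecisionUtility64 is one of the library's
   * precision utilities; the input names are illustrative):
   * {{{
   * val scored = model.score(sparkSession, dnsRecords, "intel", FloatPointPrecisionUtility64)
   * }}}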
*/
def score(sparkSession: SparkSession, inDF: DataFrame, userDomain: String, precisionUtility: FloatPointPrecisionUtility): DataFrame = {
val topDomainsBC = sparkSession.sparkContext.broadcast(TopDomains.TopDomains)
val wordToPerTopicProbBC = sparkSession.sparkContext.broadcast(wordToPerTopicProb)
val scoreFunction =
new DNSScoreFunction(topicCount,
wordToPerTopicProbBC,
topDomainsBC,
userDomain)
val scoringUDF = udf((timeStamp: String,
unixTimeStamp: Long,
frameLength: Int,
clientIP: String,
queryName: String,
queryClass: String,
queryType: Int,
queryResponseCode: Int,
documentTopicMix: Seq[precisionUtility.TargetType]) =>
scoreFunction.score(precisionUtility)(timeStamp,
unixTimeStamp,
frameLength,
clientIP,
queryName,
queryClass,
queryType,
queryResponseCode,
documentTopicMix))
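    // Left-join each record's client IP against its topic mix (broadcast, on the assumption that
    // the per-IP mix table is small), score each row with the UDF, then drop the helper column.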
inDF
.join(org.apache.spark.sql.functions.broadcast(ipToTopicMix), inDF(ClientIP) === ipToTopicMix(DocumentName),
"left_outer")
.selectExpr(inDF.schema.fieldNames :+ TopicProbabilityMix: _*)
.withColumn(Score, scoringUDF(DNSSuspiciousConnectsModel.modelColumns :+ col(TopicProbabilityMix): _*))
.drop(TopicProbabilityMix)
}
}
/**
 * Contains dataframe schema information as well as the train-from-dataframe factory routine
 * for [[DNSSuspiciousConnectsModel]] instances.
*/
object DNSSuspiciousConnectsModel {
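  /** Schema of the input columns required to train or score with this model. */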
val ModelSchema = StructType(List(TimeStampField,
UnixTimeStampField,
FrameLengthField,
ClientIPField,
QueryNameField,
QueryClassField,
QueryTypeField,
QueryResponseCodeField))
val modelColumns: List[Column] = ModelSchema.fieldNames.toList.map(col)
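  /** Schema of the derived per-domain statistics (top domain, subdomain length/entropy, period count). */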
val DomainStatsSchema = StructType(List(TopDomainField, SubdomainLengthField, SubdomainEntropyField, NumPeriodsField))
/**
* Create a new DNS Suspicious Connects model by training it on a data frame and a feedback file.
*
* @param sparkSession Spark Session
* @param logger Application logger
   * @param config Analysis configuration object containing CLI parameters.
   *               Contains the path to the feedback file in config.feedbackFile
* @param inputRecords Data used to train the model.
* @return A new [[DNSSuspiciousConnectsModel]] instance trained on the dataframe and feedback file.
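   *
   * A minimal usage sketch (spark, logger, config, and dnsRecords are assumed to be in scope):
   * {{{
   * val model = DNSSuspiciousConnectsModel.trainModel(spark, logger, config, dnsRecords)
   * }}}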
*/
def trainModel(sparkSession: SparkSession,
logger: Logger,
config: SuspiciousConnectsConfig,
inputRecords: DataFrame): DNSSuspiciousConnectsModel = {
logger.info("Training DNS suspicious connects model from " + config.inputPath)
val selectedRecords = inputRecords.select(modelColumns: _*)
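    // Append analyst feedback records (replicated duplicationFactor times) so that
    // labeled examples can carry extra weight during training.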
val totalRecords = selectedRecords.union(DNSFeedback.loadFeedbackDF(sparkSession,
config.feedbackFile,
config.duplicationFactor))
val topDomainsBC = sparkSession.sparkContext.broadcast(TopDomains.TopDomains)
val userDomain = config.userDomain
// simplify DNS log entries into "words"
val dnsWordCreator = new DNSWordCreation(topDomainsBC, userDomain)
val dataWithWord = totalRecords.withColumn(Word, dnsWordCreator.wordCreationUDF(modelColumns: _*))
import sparkSession.implicits._
// aggregate per-word counts at each IP
    val ipWordCounts =
      dataWithWord
        .select(ClientIP, Word)
        .filter(dataWithWord(Word).notEqual(InvalidDataHandler.WordError))
        .map({ case Row(clientIp: String, word: String) => (clientIp, word) -> 1 })
        .rdd
        .reduceByKey(_ + _)
        .map({ case ((clientIp, word), count) => SpotLDAInput(clientIp, word, count) })
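    // Each SpotLDAInput is a (document, word, count) triple: the client IP plays the role of a
    // "document" and its simplified queries play the role of "words" for the LDA topic model.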
    val spotLDAHelper: SpotLDAHelper = SpotLDAHelper(ipWordCounts, config.precisionUtility, sparkSession)
val model: SpotLDAModel = SpotLDAWrapper.run(config.topicCount,
logger,
config.ldaPRGSeed,
config.ldaAlpha,
config.ldaBeta,
config.ldaOptimizer,
config.ldaMaxiterations,
spotLDAHelper)
val results: SpotLDAResult = model.predict(spotLDAHelper)
new DNSSuspiciousConnectsModel(config.topicCount, results.documentToTopicMix, results.wordToTopicMix)
}
  /**
   * Extract the derived domain fields (top-domain class, subdomain length and entropy,
   * number of periods) from a URL.
   *
   * @param countryCodesBC Broadcast of the country codes set.
   * @param topDomainsBC Broadcast of the most-popular domains set.
   * @param userDomain Domain associated with the network data (e.g. 'intel')
   * @param url URL string to analyze for domain and subdomain information.
* @return [[TempFields]]
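   *
   * A minimal usage sketch (the broadcast variables and the URL are illustrative):
   * {{{
   * val fields = DNSSuspiciousConnectsModel.createTempFields(countryCodesBC, topDomainsBC, "intel", "a.b.example.com")
   * }}}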
*/
def createTempFields(countryCodesBC: Broadcast[Set[String]],
topDomainsBC: Broadcast[Set[String]],
userDomain: String,
url: String): TempFields = {
val DomainInfo(_, topDomainClass, _, subDomainLength, subDomainEntropy, numPeriods) =
DomainProcessor.extractDomainInfo(url, topDomainsBC, userDomain)
TempFields(topDomainClass = topDomainClass,
subdomainLength = subDomainLength,
subdomainEntropy = subDomainEntropy,
numPeriods = numPeriods)
}
case class TempFields(topDomainClass: Int, subdomainLength: Integer, subdomainEntropy: Double, numPeriods: Integer)
}