blob: 54e97cebf06b495dac75bb2ac125b65db629d66b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spot.dns
import org.apache.log4j.Logger
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spot.SuspiciousConnects.SuspiciousConnectsAnalysisResults
import org.apache.spot.SuspiciousConnectsArgumentParser.SuspiciousConnectsConfig
import org.apache.spot.dns.DNSSchema._
import org.apache.spot.dns.model.DNSSuspiciousConnectsModel
import org.apache.spot.proxy.ProxySchema.Score
import org.apache.spot.utilities.data.validation.InputSchema.InputSchemaValidationResponse
import org.apache.spot.utilities.data.validation.{InputSchema, InvalidDataHandler => dataValidation}
/**
* The suspicious connections analysis of DNS log data develops a probabilistic model of the DNS queries
* made by each client IP and flags those assigned a low probability as "suspicious".
*/
object DNSSuspiciousConnectsAnalysis {

  /** Replacement for a null query class before modelling (see the `na.fill` calls in [[run]]). */
  val DefaultQueryClass = "unknown"

  /** Replacement for a null query type before modelling. */
  val DefaultQueryType = -1

  /** Replacement for a null query response code before modelling. */
  val DefaultQueryResponseCode = -1

  /** Fields selected from the incoming DNS records before training and scoring. */
  val InStructType = StructType(List(TimeStampField, UnixTimeStampField, FrameLengthField, ClientIPField,
    QueryNameField, QueryClassField, QueryTypeField, QueryResponseCodeField))

  /** Column expressions, one per field of [[InStructType]], used to project the input data frame. */
  val InSchema = InStructType.fieldNames.map(col)

  /** Column expressions for the scored output: the input fields followed by the score field. */
  val OutSchema = StructType(
    List(TimeStampField,
      UnixTimeStampField,
      FrameLengthField,
      ClientIPField,
      QueryNameField,
      QueryClassField,
      QueryTypeField,
      QueryResponseCodeField,
      ScoreField)).fieldNames.map(col)

  /**
    * Run suspicious connections analysis on DNS log data.
    *
    * The input schema is validated first; when validation fails every error message is logged and
    * no analysis is performed. Otherwise the valid records are used to train the model and are
    * scored, the scored records are filtered by `config.threshold`, ordered by score and, when
    * `config.maxResults` is positive, truncated to that many rows. This method does not persist
    * anything itself; the results are handed back to the caller.
    *
    * @param config          Object encapsulating runtime parameters and CLI options.
    * @param sparkSession    Spark session handed to model training and scoring.
    * @param logger          Logger receiving progress and validation error messages.
    * @param inputDNSRecords Raw DNS records to analyse.
    * @return `None` when the input schema is invalid; otherwise the most suspicious records
    *         together with the records rejected as invalid.
    */
  def run(config: SuspiciousConnectsConfig, sparkSession: SparkSession, logger: Logger,
          inputDNSRecords: DataFrame): Option[SuspiciousConnectsAnalysisResults] = {

    logger.info("Starting DNS suspicious connects analysis.")

    logger.info("Validating schema...")
    val InputSchemaValidationResponse(isValid, errorMessages) = validateSchema(inputDNSRecords)

    if (!isValid) {
      errorMessages.foreach(logger.error(_))
      None
    } else {
      // Keep only valid rows, project to the model's input fields and replace nulls in the
      // optional query columns with their defaults.
      val dnsRecords = filterRecords(inputDNSRecords)
        .select(InSchema: _*)
        .na.fill(DefaultQueryClass, Seq(QueryClass))
        .na.fill(DefaultQueryType, Seq(QueryType))
        .na.fill(DefaultQueryResponseCode, Seq(QueryResponseCode))

      logger.info("Fitting probabilistic model to data")
      val model =
        DNSSuspiciousConnectsModel.trainModel(sparkSession, logger, config, dnsRecords)

      logger.info("Identifying outliers")
      val scoredDNSRecords = model.score(sparkSession, dnsRecords, config.userDomain, config.precisionUtility)

      val filteredScored = filterScoredRecords(scoredDNSRecords, config.threshold)

      // NOTE(review): `Score` is the name imported from ProxySchema, not DNSSchema; presumably it
      // is the same column name as ScoreField - confirm.
      val orderedDNSRecords = filteredScored.orderBy(Score)

      // A non-positive maxResults means "no limit".
      val mostSuspiciousDNSRecords =
        if (config.maxResults > 0) orderedDNSRecords.limit(config.maxResults)
        else orderedDNSRecords

      val outputDNSRecords = mostSuspiciousDNSRecords.select(OutSchema: _*)

      val invalidDNSRecords = filterInvalidRecords(inputDNSRecords).select(InSchema: _*)

      Some(SuspiciousConnectsAnalysisResults(outputDNSRecords, invalidDNSRecords))
    }
  }

  /**
    * Return all DNS records with valid values.
    *
    * A record is valid when its timestamp, query name and client IP are present (not null, not
    * empty, not "-", and for the query name not "(empty)"), its unix timestamp and frame length
    * are non-negative, and at least one of query class, query type or query response code carries
    * a usable value.
    *
    * @param inputDNSRecords raw DNS records.
    * @return the subset of `inputDNSRecords` that passes every check above.
    * @see filterInvalidRecords(DataFrame)
    */
  def filterRecords(inputDNSRecords: DataFrame): DataFrame = {

    // Comparisons against a null column evaluate to null, which `filter` treats as false, so
    // `geq(0)` also rejects null numeric values.
    val cleanDNSRecordsFilter = inputDNSRecords(TimeStamp).isNotNull &&
      inputDNSRecords(TimeStamp).notEqual("") &&
      inputDNSRecords(TimeStamp).notEqual("-") &&
      inputDNSRecords(UnixTimeStamp).geq(0) &&
      inputDNSRecords(FrameLength).geq(0) &&
      inputDNSRecords(QueryName).isNotNull &&
      inputDNSRecords(QueryName).notEqual("") &&
      inputDNSRecords(QueryName).notEqual("-") &&
      inputDNSRecords(QueryName).notEqual("(empty)") &&
      inputDNSRecords(ClientIP).isNotNull &&
      inputDNSRecords(ClientIP).notEqual("") &&
      inputDNSRecords(ClientIP).notEqual("-") &&
      ((inputDNSRecords(QueryClass).isNotNull &&
        inputDNSRecords(QueryClass).notEqual("") &&
        inputDNSRecords(QueryClass).notEqual("-")) ||
        inputDNSRecords(QueryType).isNotNull ||
        inputDNSRecords(QueryResponseCode).geq(0))

    inputDNSRecords
      .filter(cleanDNSRecordsFilter)
  }

  /**
    * Return all DNS records with invalid values.
    *
    * This is the exact complement of [[filterRecords]]: every input record is returned by one of
    * the two methods and by only one. In particular, negative unix timestamps, frame lengths and
    * (when no other query field is usable) response codes are reported as invalid, matching the
    * `>= 0` requirements of [[filterRecords]].
    *
    * @param inputDNSRecords raw DNS records.
    * @return the subset of `inputDNSRecords` rejected by [[filterRecords]].
    * @see filterRecords(DataFrame)
    */
  def filterInvalidRecords(inputDNSRecords: DataFrame): DataFrame = {

    // Each numeric check pairs `isNull` with `lt(0)`: on a null value `lt(0)` is null, but
    // `true OR null` is still true, so the overall predicate never evaluates to null.
    val invalidDNSRecordsFilter = inputDNSRecords(TimeStamp).isNull ||
      inputDNSRecords(TimeStamp).equalTo("") ||
      inputDNSRecords(TimeStamp).equalTo("-") ||
      inputDNSRecords(UnixTimeStamp).isNull ||
      inputDNSRecords(UnixTimeStamp).lt(0) ||
      inputDNSRecords(FrameLength).isNull ||
      inputDNSRecords(FrameLength).lt(0) ||
      inputDNSRecords(QueryName).isNull ||
      inputDNSRecords(QueryName).equalTo("") ||
      inputDNSRecords(QueryName).equalTo("-") ||
      inputDNSRecords(QueryName).equalTo("(empty)") ||
      inputDNSRecords(ClientIP).isNull ||
      inputDNSRecords(ClientIP).equalTo("") ||
      inputDNSRecords(ClientIP).equalTo("-") ||
      ((inputDNSRecords(QueryClass).isNull ||
        inputDNSRecords(QueryClass).equalTo("") ||
        inputDNSRecords(QueryClass).equalTo("-")) &&
        inputDNSRecords(QueryType).isNull &&
        (inputDNSRecords(QueryResponseCode).isNull ||
          inputDNSRecords(QueryResponseCode).lt(0)))

    inputDNSRecords
      .filter(invalidDNSRecordsFilter)
  }

  /**
    * Keep the scored records whose score lies in the half-open range
    * `(dataValidation.ScoreError, threshold]`.
    *
    * @param scoredDNSRecords scored DNS records.
    * @param threshold        score tolerance; records scoring above it are dropped.
    * @return the records with a score at or below `threshold` and strictly above
    *         `dataValidation.ScoreError` (presumably the sentinel score assigned to records that
    *         could not be scored - confirm against InvalidDataHandler).
    */
  def filterScoredRecords(scoredDNSRecords: DataFrame, threshold: Double): DataFrame = {

    val filteredDNSRecordsFilter = scoredDNSRecords(Score).leq(threshold) &&
      scoredDNSRecords(Score).gt(dataValidation.ScoreError)

    scoredDNSRecords.filter(filteredDNSRecordsFilter)
  }

  /**
    * Validates incoming data matches required schema for model training and scoring.
    *
    * @param inputDNSRecords incoming data frame
    * @return the response of `InputSchema.validate` for the data frame's schema checked against
    *         `DNSSuspiciousConnectsModel.ModelSchema`; [[run]] reads its validity flag and error
    *         messages.
    */
  def validateSchema(inputDNSRecords: DataFrame): InputSchemaValidationResponse = {

    InputSchema.validate(inputDNSRecords.schema, DNSSuspiciousConnectsModel.ModelSchema)
  }
}