/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.pirk.test.utils;

import static org.junit.Assert.fail;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import org.apache.hadoop.fs.FileSystem;
import org.apache.pirk.query.wideskies.QueryUtils;
import org.apache.pirk.schema.query.QuerySchema;
import org.apache.pirk.schema.query.QuerySchemaRegistry;
import org.apache.pirk.schema.response.QueryResponseJSON;
import org.apache.pirk.test.distributed.testsuite.DistTestSuite;
import org.apache.pirk.utils.StringUtils;
import org.apache.pirk.utils.SystemConfiguration;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Class to hold the base functional distributed tests
 */
public class BaseTests
{
  private static final Logger logger = LoggerFactory.getLogger(BaseTests.class);

  public static final UUID queryIdentifier = UUID.randomUUID();
  public static final int dataPartitionBitSize = 8;

  // Selectors for domain and IP queries, queryIdentifier is the first entry for file generation
  private static ArrayList<String> selectorsDomain = new ArrayList<>(Arrays.asList("s.t.u.net", "d.e.com", "r.r.r.r", "a.b.c.com", "something.else", "x.y.net"));
  private static ArrayList<String> selectorsIP = new ArrayList<>(Arrays.asList("55.55.55.55", "5.6.7.8", "10.20.30.40", "13.14.15.16", "21.22.23.24"));

  // Encryption variables -- Paillier mechanisms are tested in the Paillier test code, so these are fixed...
  public static final int hashBitSize = 12;
  public static final String hashKey = "someKey";
  public static final int paillierBitSize = 384;
  public static final int certainty = 128;

  public static void testDNSHostnameQuery(ArrayList<JSONObject> dataElements, int numThreads, boolean testFalsePositive) throws Exception
  {
    testDNSHostnameQuery(dataElements, null, false, false, numThreads, testFalsePositive, false);
  }

  public static void testDNSHostnameQuery(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads)
      throws Exception
  {
    testDNSHostnameQuery(dataElements, fs, isSpark, isDistributed, numThreads, false, false);
  }

  // Query for the watched hostname occurred; ; watched value type: hostname (String)
  public static void testDNSHostnameQuery(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads,
      boolean testFalsePositive, boolean isStreaming) throws Exception
  {
    logger.info("Running testDNSHostnameQuery(): ");

    int numExpectedResults = 6;
    List<QueryResponseJSON> results;
    if (isDistributed)
    {
      results = DistTestSuite.performQuery(Inputs.DNS_HOSTNAME_QUERY, selectorsDomain, fs, isSpark, numThreads, isStreaming);
    }
    else
    {
      results = StandaloneQuery.performStandaloneQuery(dataElements, Inputs.DNS_HOSTNAME_QUERY, selectorsDomain, numThreads, testFalsePositive);
      if (!testFalsePositive)
      {
        numExpectedResults = 7; // all 7 for non distributed case; if testFalsePositive==true, then 6
      }
    }
    checkDNSHostnameQueryResults(results, isDistributed, numExpectedResults, testFalsePositive, dataElements);
    logger.info("Completed testDNSHostnameQuery(): ");
  }

  public static void checkDNSHostnameQueryResults(List<QueryResponseJSON> results, boolean isDistributed, int numExpectedResults, boolean testFalsePositive,
      List<JSONObject> dataElements)
  {
    QuerySchema qSchema = QuerySchemaRegistry.get(Inputs.DNS_HOSTNAME_QUERY);
    logger.info("results:");
    printResultList(results);

    if (isDistributed && SystemConfiguration.isSetTrue("pir.limitHitsPerSelector"))
    {
      // 3 elements returned - one for each qname -- a.b.c.com, d.e.com, something.else
      if (results.size() != 3)
      {
        fail("results.size() = " + results.size() + " -- must equal 3");
      }

      // Check that each qname appears once in the result set
      HashSet<String> correctQnames = new HashSet<>();
      correctQnames.add("a.b.c.com");
      correctQnames.add("d.e.com");
      correctQnames.add("something.else");

      HashSet<String> resultQnames = new HashSet<>();
      for (QueryResponseJSON qrJSON : results)
      {
        resultQnames.add((String) qrJSON.getValue(Inputs.QNAME));
      }

      if (correctQnames.size() != resultQnames.size())
      {
        fail("correctQnames.size() = " + correctQnames.size() + " != resultQnames.size() " + resultQnames.size());
      }

      for (String resultQname : resultQnames)
      {
        if (!correctQnames.contains(resultQname))
        {
          fail("correctQnames does not contain resultQname = " + resultQname);
        }
      }
    }
    else
    {
      if (results.size() != numExpectedResults)
      {
        fail("results.size() = " + results.size() + " -- must equal " + numExpectedResults);
      }

      // Number of original elements at the end of the list that we do not need to consider for hits
      int removeTailElements = 2; // the last two data elements should not hit
      if (testFalsePositive)
      {
        removeTailElements = 3;
      }

      ArrayList<QueryResponseJSON> correctResults = new ArrayList<>();
      int i = 0;
      while (i < (dataElements.size() - removeTailElements))
      {
        JSONObject dataMap = dataElements.get(i);

        boolean addElement = true;
        if (isDistributed && dataMap.get(Inputs.RCODE).toString().equals("3"))
        {
          addElement = false;
        }
        if (addElement)
        {
          QueryResponseJSON wlJSON = new QueryResponseJSON();
          wlJSON.setMapping(QueryResponseJSON.QUERY_ID, queryIdentifier.toString());
          wlJSON.setMapping(QueryResponseJSON.EVENT_TYPE, Inputs.DNS_HOSTNAME_QUERY);
          wlJSON.setMapping(Inputs.DATE, dataMap.get(Inputs.DATE));
          wlJSON.setMapping(Inputs.SRCIP, dataMap.get(Inputs.SRCIP));
          wlJSON.setMapping(Inputs.DSTIP, dataMap.get(Inputs.DSTIP));
          wlJSON.setMapping(Inputs.QNAME, dataMap.get(Inputs.QNAME)); // this gets re-embedded as the original selector after decryption
          wlJSON.setMapping(Inputs.QTYPE, parseShortArray(dataMap, Inputs.QTYPE));
          wlJSON.setMapping(Inputs.RCODE, dataMap.get(Inputs.RCODE));
          wlJSON.setMapping(Inputs.IPS, parseArray(dataMap, Inputs.IPS, true));
          wlJSON.setMapping(QueryResponseJSON.SELECTOR, QueryUtils.getSelectorByQueryTypeJSON(qSchema, dataMap));
          correctResults.add(wlJSON);
        }
        ++i;
      }
      logger.info("correctResults: ");
      printResultList(correctResults);

      if (results.size() != correctResults.size())
      {
        logger.info("correctResults:");
        printResultList(correctResults);
        fail("results.size() = " + results.size() + " != correctResults.size() = " + correctResults.size());
      }
      for (QueryResponseJSON result : results)
      {
        if (!compareResultArray(correctResults, result))
        {
          fail("correctResults does not contain result = " + result.toString());
        }
      }
    }
  }

  public static void testDNSIPQuery(ArrayList<JSONObject> dataElements, int numThreads) throws Exception
  {
    testDNSIPQuery(dataElements, null, false, false, numThreads, false);
  }

  // The watched IP address was detected in the response to a query; watched value type: IP address (String)
  public static void testDNSIPQuery(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads, boolean isStreaming)
      throws Exception
  {
    logger.info("Running testDNSIPQuery(): ");

    QuerySchema qSchema = QuerySchemaRegistry.get(Inputs.DNS_IP_QUERY);
    List<QueryResponseJSON> results;

    if (isDistributed)
    {
      results = DistTestSuite.performQuery(Inputs.DNS_IP_QUERY, selectorsIP, fs, isSpark, numThreads, isStreaming);

      if (results.size() != 5)
      {
        fail("results.size() = " + results.size() + " -- must equal 5");
      }
    }
    else
    {
      results = StandaloneQuery.performStandaloneQuery(dataElements, Inputs.DNS_IP_QUERY, selectorsIP, numThreads, false);

      if (results.size() != 6)
      {
        fail("results.size() = " + results.size() + " -- must equal 6");
      }
    }
    printResultList(results);

    ArrayList<QueryResponseJSON> correctResults = new ArrayList<>();
    int i = 0;
    while (i < (dataElements.size() - 3)) // last three data elements not hit - one on stoplist, two don't match selectors
    {
      JSONObject dataMap = dataElements.get(i);

      boolean addElement = true;
      if (isDistributed && dataMap.get(Inputs.RCODE).toString().equals("3"))
      {
        addElement = false;
      }
      if (addElement)
      {
        QueryResponseJSON wlJSON = new QueryResponseJSON();
        wlJSON.setMapping(QueryResponseJSON.QUERY_ID, queryIdentifier);
        wlJSON.setMapping(QueryResponseJSON.EVENT_TYPE, Inputs.DNS_IP_QUERY);
        wlJSON.setMapping(Inputs.SRCIP, dataMap.get(Inputs.SRCIP));
        wlJSON.setMapping(Inputs.DSTIP, dataMap.get(Inputs.DSTIP));
        wlJSON.setMapping(Inputs.IPS, parseArray(dataMap, Inputs.IPS, true));
        wlJSON.setMapping(QueryResponseJSON.SELECTOR, QueryUtils.getSelectorByQueryTypeJSON(qSchema, dataMap));
        correctResults.add(wlJSON);
      }
      ++i;
    }
    if (results.size() != correctResults.size())
    {
      logger.info("correctResults:");
      printResultList(correctResults);
      fail("results.size() = " + results.size() + " != correctResults.size() = " + correctResults.size());
    }
    for (QueryResponseJSON result : results)
    {
      if (!compareResultArray(correctResults, result))
      {
        fail("correctResults does not contain result = " + result.toString());
      }
    }
    logger.info("Completed testDNSIPQuery(): ");
  }

  public static void testDNSNXDOMAINQuery(ArrayList<JSONObject> dataElements, int numThreads) throws Exception
  {
    testDNSNXDOMAINQuery(dataElements, null, false, false, numThreads);
  }

  // A query that returned an nxdomain response was made for the watched hostname; watched value type: hostname (String)
  public static void testDNSNXDOMAINQuery(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads)
      throws Exception
  {
    logger.info("Running testDNSNXDOMAINQuery(): ");

    QuerySchema qSchema = QuerySchemaRegistry.get(Inputs.DNS_NXDOMAIN_QUERY);
    List<QueryResponseJSON> results;

    if (isDistributed)
    {
      results = DistTestSuite.performQuery(Inputs.DNS_NXDOMAIN_QUERY, selectorsDomain, fs, isSpark, numThreads, false);
    }
    else
    {
      results = StandaloneQuery.performStandaloneQuery(dataElements, Inputs.DNS_NXDOMAIN_QUERY, selectorsDomain, numThreads, false);
    }
    printResultList(results);

    if (results.size() != 1)
    {
      fail("results.size() = " + results.size() + " -- must equal 1");
    }

    ArrayList<QueryResponseJSON> correctResults = new ArrayList<>();
    int i = 0;
    while (i < dataElements.size())
    {
      JSONObject dataMap = dataElements.get(i);

      if (dataMap.get(Inputs.RCODE).toString().equals("3"))
      {
        QueryResponseJSON wlJSON = new QueryResponseJSON();
        wlJSON.setMapping(QueryResponseJSON.QUERY_ID, queryIdentifier);
        wlJSON.setMapping(QueryResponseJSON.EVENT_TYPE, Inputs.DNS_NXDOMAIN_QUERY);
        wlJSON.setMapping(Inputs.QNAME, dataMap.get(Inputs.QNAME)); // this gets re-embedded as the original selector after decryption
        wlJSON.setMapping(Inputs.DSTIP, dataMap.get(Inputs.DSTIP));
        wlJSON.setMapping(Inputs.SRCIP, dataMap.get(Inputs.SRCIP));
        wlJSON.setMapping(QueryResponseJSON.SELECTOR, QueryUtils.getSelectorByQueryTypeJSON(qSchema, dataMap));
        correctResults.add(wlJSON);
      }
      ++i;
    }
    if (results.size() != correctResults.size())
    {
      logger.info("correctResults:");
      printResultList(correctResults);
      fail("results.size() = " + results.size() + " != correctResults.size() = " + correctResults.size());
    }
    for (QueryResponseJSON result : results)
    {
      if (!compareResultArray(correctResults, result))
      {
        fail("correctResults does not contain result = " + result.toString());
      }
    }
    logger.info("Completed testDNSNXDOMAINQuery(): ");
  }

  public static void testSRCIPQuery(ArrayList<JSONObject> dataElements, int numThreads) throws Exception
  {
    testSRCIPQuery(dataElements, null, false, false, numThreads, false);
  }

  // Query for responses from watched srcIPs
  public static void testSRCIPQuery(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads, boolean isStreaming)
      throws Exception
  {
    logger.info("Running testSRCIPQuery(): ");

    QuerySchema qSchema = QuerySchemaRegistry.get(Inputs.DNS_SRCIP_QUERY);
    List<QueryResponseJSON> results;

    int removeTailElements = 0;
    int numExpectedResults = 1;
    if (isDistributed)
    {
      results = DistTestSuite.performQuery(Inputs.DNS_SRCIP_QUERY, selectorsIP, fs, isSpark, numThreads, isStreaming);
      removeTailElements = 2; // The last two elements are on the distributed stoplist
    }
    else
    {
      numExpectedResults = 3;
      results = StandaloneQuery.performStandaloneQuery(dataElements, Inputs.DNS_SRCIP_QUERY, selectorsIP, numThreads, false);
    }
    printResultList(results);

    if (results.size() != numExpectedResults)
    {
      fail("results.size() = " + results.size() + " -- must equal " + numExpectedResults);
    }

    ArrayList<QueryResponseJSON> correctResults = new ArrayList<>();
    int i = 0;
    while (i < (dataElements.size() - removeTailElements))
    {
      JSONObject dataMap = dataElements.get(i);

      boolean addElement = false;
      if (dataMap.get(Inputs.SRCIP).toString().equals("55.55.55.55") || dataMap.get(Inputs.SRCIP).toString().equals("5.6.7.8"))
      {
        addElement = true;
      }
      if (addElement)
      {
        // Form the correct result QueryResponseJSON object
        QueryResponseJSON qrJSON = new QueryResponseJSON();
        qrJSON.setMapping(QueryResponseJSON.QUERY_ID, queryIdentifier);
        qrJSON.setMapping(QueryResponseJSON.EVENT_TYPE, Inputs.DNS_SRCIP_QUERY);
        qrJSON.setMapping(Inputs.QNAME, parseString(dataMap, Inputs.QNAME));
        qrJSON.setMapping(Inputs.DSTIP, dataMap.get(Inputs.DSTIP));
        qrJSON.setMapping(Inputs.SRCIP, dataMap.get(Inputs.SRCIP));
        qrJSON.setMapping(Inputs.IPS, parseArray(dataMap, Inputs.IPS, true));
        qrJSON.setMapping(QueryResponseJSON.SELECTOR, QueryUtils.getSelectorByQueryTypeJSON(qSchema, dataMap));
        correctResults.add(qrJSON);
      }
      ++i;
    }
    logger.info("correctResults:");
    printResultList(correctResults);

    if (results.size() != correctResults.size())
    {
      logger.info("correctResults:");
      printResultList(correctResults);
      fail("results.size() = " + results.size() + " != correctResults.size() = " + correctResults.size());
    }
    for (QueryResponseJSON result : results)
    {
      if (!compareResultArray(correctResults, result))
      {
        fail("correctResults does not contain result = " + result.toString());
      }
    }
    logger.info("Completed testSRCIPQuery(): ");
  }

  // Query for responses from watched srcIPs
  public static void testSRCIPQueryNoFilter(List<JSONObject> dataElements, FileSystem fs, boolean isSpark, boolean isDistributed, int numThreads,
      boolean isStreaming) throws Exception
  {
    logger.info("Running testSRCIPQueryNoFilter(): ");

    QuerySchema qSchema = QuerySchemaRegistry.get(Inputs.DNS_SRCIP_QUERY_NO_FILTER);
    List<QueryResponseJSON> results;

    int numExpectedResults = 3;
    if (isDistributed)
    {
      results = DistTestSuite.performQuery(Inputs.DNS_SRCIP_QUERY_NO_FILTER, selectorsIP, fs, isSpark, numThreads, isStreaming);
    }
    else
    {
      results = StandaloneQuery.performStandaloneQuery(dataElements, Inputs.DNS_SRCIP_QUERY_NO_FILTER, selectorsIP, numThreads, false);
    }
    printResultList(results);

    if (results.size() != numExpectedResults)
    {
      fail("results.size() = " + results.size() + " -- must equal " + numExpectedResults);
    }

    ArrayList<QueryResponseJSON> correctResults = new ArrayList<>();
    int i = 0;
    while (i < dataElements.size())
    {
      JSONObject dataMap = dataElements.get(i);

      boolean addElement = false;
      if (dataMap.get(Inputs.SRCIP).toString().equals("55.55.55.55") || dataMap.get(Inputs.SRCIP).toString().equals("5.6.7.8"))
      {
        addElement = true;
      }
      if (addElement)
      {
        // Form the correct result QueryResponseJSON object
        QueryResponseJSON qrJSON = new QueryResponseJSON();
        qrJSON.setMapping(QueryResponseJSON.QUERY_ID, queryIdentifier);
        qrJSON.setMapping(QueryResponseJSON.EVENT_TYPE, Inputs.DNS_SRCIP_QUERY_NO_FILTER);
        qrJSON.setMapping(Inputs.QNAME, parseString(dataMap, Inputs.QNAME));
        qrJSON.setMapping(Inputs.DSTIP, dataMap.get(Inputs.DSTIP));
        qrJSON.setMapping(Inputs.SRCIP, dataMap.get(Inputs.SRCIP));
        qrJSON.setMapping(Inputs.IPS, parseArray(dataMap, Inputs.IPS, true));
        qrJSON.setMapping(QueryResponseJSON.SELECTOR, QueryUtils.getSelectorByQueryTypeJSON(qSchema, dataMap));
        correctResults.add(qrJSON);
      }
      ++i;
    }
    logger.info("correctResults:");
    printResultList(correctResults);

    if (results.size() != correctResults.size())
    {
      logger.info("correctResults:");
      printResultList(correctResults);
      fail("results.size() = " + results.size() + " != correctResults.size() = " + correctResults.size());
    }
    for (QueryResponseJSON result : results)
    {
      if (!compareResultArray(correctResults, result))
      {
        fail("correctResults does not contain result = " + result.toString());
      }
    }
    logger.info("Completed testSRCIPQueryNoFilter(): ");
  }

  @SuppressWarnings("unchecked")
  // Method to convert a ArrayList<String> into the correct (padded) returned ArrayList
  private static ArrayList<String> parseArray(JSONObject dataMap, String fieldName, boolean isIP)
  {
    ArrayList<String> retArray = new ArrayList<>();

    ArrayList<String> values;
    if (dataMap.get(fieldName) instanceof ArrayList)
    {
      values = (ArrayList<String>) dataMap.get(fieldName);
    }
    else
    {
      values = StringUtils.jsonArrayStringToArrayList((String) dataMap.get(fieldName));
    }

    int numArrayElementsToReturn = SystemConfiguration.getIntProperty("pir.numReturnArrayElements", 1);
    for (int i = 0; i < numArrayElementsToReturn; ++i)
    {
      if (i < values.size())
      {
        retArray.add(values.get(i));
      }
      else if (isIP)
      {
        retArray.add("0.0.0.0");
      }
      else
      {
        retArray.add("0");
      }
    }

    return retArray;
  }

  // Method to convert a ArrayList<Short> into the correct (padded) returned ArrayList
  private static ArrayList<Short> parseShortArray(JSONObject dataMap, String fieldName)
  {
    ArrayList<Short> retArray = new ArrayList<>();

    ArrayList<Short> values = (ArrayList<Short>) dataMap.get(fieldName);

    int numArrayElementsToReturn = SystemConfiguration.getIntProperty("pir.numReturnArrayElements", 1);
    for (int i = 0; i < numArrayElementsToReturn; ++i)
    {
      if (i < values.size())
      {
        retArray.add(values.get(i));
      }
      else
      {
        retArray.add((short) 0);
      }
    }

    return retArray;
  }

  // Method to convert the String field value to the correct returned substring
  private static String parseString(JSONObject dataMap, String fieldName)
  {
    String ret;

    String element = (String) dataMap.get(fieldName);
    int numParts = Integer.parseInt(SystemConfiguration.getProperty("pir.stringBits")) / dataPartitionBitSize;
    int len = numParts;
    if (element.length() < numParts)
    {
      len = element.length();
    }
    ret = new String(element.getBytes(), 0, len);

    return ret;
  }

  // Method to determine whether or not the correctResults contains an object equivalent to
  // the given result
  private static boolean compareResultArray(ArrayList<QueryResponseJSON> correctResults, QueryResponseJSON result)
  {
    boolean equivalent = false;

    for (QueryResponseJSON correct : correctResults)
    {
      equivalent = compareResults(correct, result);
      if (equivalent)
      {
        break;
      }
    }

    return equivalent;
  }

  // Method to test the equivalence of two test results
  private static boolean compareResults(QueryResponseJSON r1, QueryResponseJSON r2)
  {
    boolean equivalent = true;

    JSONObject jsonR1 = r1.getJSONObject();
    JSONObject jsonR2 = r2.getJSONObject();

    Set<String> r1KeySet = jsonR1.keySet();
    Set<String> r2KeySet = jsonR2.keySet();
    if (!r1KeySet.equals(r2KeySet))
    {
      equivalent = false;
    }
    if (equivalent)
    {
      for (String key : r1KeySet)
      {
        if (key.equals(Inputs.QTYPE) || key.equals(Inputs.IPS)) // array types
        {
          HashSet<String> set1 = getSetFromList(jsonR1.get(key));
          HashSet<String> set2 = getSetFromList(jsonR2.get(key));

          if (!set1.equals(set2))
          {
            equivalent = false;
          }
        }
        else
        {
          if (!(jsonR1.get(key).toString()).equals(jsonR2.get(key).toString()))
          {
            equivalent = false;
          }
        }
      }
    }
    return equivalent;
  }

  // Method to pull the elements of a list (either an ArrayList or JSONArray) into a HashSet
  private static HashSet<String> getSetFromList(Object list)
  {
    HashSet<String> set = new HashSet<>();

    if (list instanceof ArrayList)
    {
      for (Object obj : (ArrayList) list)
      {
        set.add(obj.toString());
      }
    }
    else
    // JSONArray
    {
      for (Object obj : (JSONArray) list)
      {
        set.add(obj.toString());
      }
    }

    return set;
  }

  private static void printResultList(List<QueryResponseJSON> list)
  {
    for (QueryResponseJSON obj : list)
    {
      logger.info(obj.toString());
    }
  }
}
