package org.apache.s2graph.s2jobs.wal.udafs

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.annotation.tailrec
import scala.collection.mutable

/**
 * Shared schema, orderings and heap/merge helpers used by the edge-aggregating UDAFs
 * defined in this file.
 *
 * The heap helpers treat a `mutable.Seq` as an array-backed binary min-heap with respect to
 * the supplied `Ordering`: for index `i`, the parent is `(i - 1) / 2` and the children are
 * `2 * i + 1` and `2 * i + 2`.
 */
object S2EdgeDataAggregate {
  /** Tuple view of one edge, in the same column order as `fields`: (timestamp, to, label, props). */
  type Element = (Long, String, String, String)

  /** Placeholder row: timestamp `-1` and the literal "empty" in the three string slots. */
  val emptyRow: GenericRow = new GenericRow(Array(-1L, "empty", "empty", "empty"))

  /** Orders `Element` tuples by their first component (the timestamp), ascending. */
  val elementOrd: Ordering[Element] = Ordering.by[Element, Long](_._1)

  /** Orders rows by the `Long` in column 0, ascending. */
  val rowOrdering: Ordering[Row] = new Ordering[Row] {
    override def compare(x: Row, y: Row): Int = {
      x.getAs[Long](0).compareTo(y.getAs[Long](0))
    }
  }

  /** Orders rows by the `Long` in column 0, descending. */
  val rowOrderingDesc: Ordering[Row] = new Ordering[Row] {
    override def compare(x: Row, y: Row): Int = {
      // swapping the operands reverses the order without negating an Int
      y.getAs[Long](0).compareTo(x.getAs[Long](0))
    }
  }

  /** Columns of a single edge struct. */
  val fields: Seq[StructField] = Seq(
    StructField(name = "timestamp", LongType),
    StructField(name = "to", StringType),
    StructField(name = "label", StringType),
    StructField(name = "props", StringType)
  )

  /** Array of edge structs; used by the UDAFs as buffer and result type. */
  val arrayType: ArrayType = ArrayType(elementType = StructType(fields))

  /** Creates a [[GroupByAggOptimized]] that keeps at most `maxNumOfEdges` rows per group. */
  def apply(maxNumOfEdges: Int = 1000): GroupByAggOptimized = {
    new GroupByAggOptimized(maxNumOfEdges)
  }

  /** Exchanges the elements at positions `i` and `j` of `array` in place. */
  def swap[T](array: mutable.Seq[T], i: Int, j: Int): Unit = {
    val tmp = array(i)
    array(i) = array(j)
    array(j) = tmp
  }

  /**
   * Restores the min-heap property below `idx` by repeatedly swapping the element at `idx`
   * with its smaller child until neither child compares smaller.
   *
   * @param array heap storage, mutated in place
   * @param idx   index of the element to sift down
   */
  @tailrec
  def percolateDown[T](array: mutable.Seq[T], idx: Int)(implicit ordering: Ordering[T]): Unit = {
    val left = 2 * idx + 1
    val right = 2 * idx + 2
    var smallest = idx

    if (left < array.size && ordering.compare(array(left), array(smallest)) < 0) {
      smallest = left
    }

    if (right < array.size && ordering.compare(array(right), array(smallest)) < 0) {
      smallest = right
    }

    if (smallest != idx) {
      swap(array, idx, smallest)
      percolateDown(array, smallest)
    }
  }

  /**
   * Restores the min-heap property above `idx` by swapping the element with its parent
   * for as long as it compares smaller than that parent.
   *
   * @param array heap storage, mutated in place
   * @param idx   index of the element to sift up
   */
  def percolateUp[T](array: mutable.Seq[T],
                     idx: Int)(implicit ordering: Ordering[T]): Unit = {
    var pos = idx
    // the root (pos == 0) has no parent, so stop there explicitly
    while (pos > 0 && ordering.compare(array(pos), array((pos - 1) / 2)) < 0) {
      val parent = (pos - 1) / 2
      swap(array, parent, pos)
      pos = parent
    }
  }

  /**
   * Offers `newData` to a bounded min-heap that retains the `size` largest elements seen.
   *
   * While the heap holds fewer than `size` elements, `newData` is appended to a copy of `array`
   * (`:+` builds a new sequence) and sifted up. Once full, `newData` replaces the current minimum
   * at index 0 unless it compares strictly smaller than that minimum, in which case it is dropped.
   *
   * @param array   current heap; mutated in place when it is already full
   * @param size    maximum number of elements to retain
   * @param newData candidate element
   * @return the heap after the offer: `array` itself when it was full, otherwise the grown copy
   */
  def addToTopK[T](array: mutable.Seq[T],
                   size: Int,
                   newData: T)(implicit ordering: Ordering[T]): mutable.Seq[T] = {
    if (array.size >= size) {
      // An empty heap can only be "full" when size <= 0; there is no minimum to replace then,
      // and reading array(0) would throw.
      if (array.nonEmpty && ordering.compare(newData, array(0)) >= 0) {
        // evict the current minimum and let the new element sink to its place
        array(0) = newData
        percolateDown(array, 0)
      }
      array
    } else {
      // there is room left: append, then sift the new last element up
      val grown = array :+ newData
      percolateUp(grown, grown.size - 1)
      grown
    }
  }

  /**
   * Linear merge of two sequences, returning at most `size` elements.
   *
   * Elements are consumed from the front of both inputs; at each step the one comparing smaller
   * is emitted, and on ties the element from `prev` goes first. No sorting is performed, so the
   * output is only ordered if both inputs already are.
   *
   * @param prev first input sequence
   * @param cur  second input sequence
   * @param size upper bound on the number of elements returned; non-positive yields an empty result
   * @return the merged prefix, never longer than `size`
   */
  def mergeTwoSeq[T](prev: Seq[T], cur: Seq[T], size: Int)(implicit ordering: Ordering[T]): Seq[T] = {
    // never pre-allocate more slots than could possibly be filled
    val capacity = math.max(0, math.min(size, prev.size + cur.size))
    val arr = new mutable.ArrayBuffer[T](capacity)

    val fromCur = cur.iterator.buffered
    val fromPrev = prev.iterator.buffered

    while (arr.size < size && fromCur.hasNext && fromPrev.hasNext) {
      if (ordering.compare(fromCur.head, fromPrev.head) < 0) {
        arr += fromCur.next()
      } else {
        arr += fromPrev.next()
      }
    }

    // Drain whichever input still has elements, but only up to the remaining budget.
    // (take with a non-positive count yields nothing.)
    arr ++= fromCur.take(size - arr.size)
    arr ++= fromPrev.take(size - arr.size)

    arr
  }
}

/**
 * UDAF that gathers edge rows (timestamp, to, label, props) into one array per group while
 * holding at most `maxNumOfEdges` rows in the aggregation buffer.
 *
 * The buffer is maintained through `S2EdgeDataAggregate.addToTopK` using `rowOrdering`
 * (ascending on column 0), so once the buffer is full a new row replaces the row with the
 * smallest timestamp unless the new row's timestamp is smaller still. The final result is
 * the buffer sorted by that same ordering.
 *
 * @param maxNumOfEdges maximum number of rows kept per group
 */
class GroupByAggOptimized(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {

  import S2EdgeDataAggregate._

  /** Concrete view of the buffer column as retrieved from Spark. */
  type ROWS = mutable.Seq[Row]

  /** Ordering handed implicitly to the heap helpers and to `sorted`. */
  implicit val ord: Ordering[Row] = rowOrdering

  val arrayType = ArrayType(elementType = StructType(fields))

  /** One input row carries the four edge columns directly. */
  override def inputSchema: StructType = StructType(fields)

  /** Single buffer column holding the rows collected so far. */
  override def bufferSchema: StructType =
    StructType(Seq(StructField(name = "edges", dataType = arrayType)))

  override def dataType: DataType = arrayType

  override def deterministic: Boolean = true

  /** Starts every group with an empty row collection. */
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, mutable.ArrayBuffer.empty[Row])

  /** Offers the incoming row to the bounded heap and stores the resulting heap back. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer.update(0, addToTopK(buffer.getAs[ROWS](0), maxNumOfEdges, input))

  /** Offers every non-null row of `buffer2` to the heap held in `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val incoming = buffer2.getAs[ROWS](0).filter(_ != null)
    val combined = incoming.foldLeft(buffer1.getAs[ROWS](0)) { (heap, row) =>
      addToTopK(heap, maxNumOfEdges, row)
    }

    buffer1.update(0, combined)
  }

  /** Returns the buffered rows sorted by `ord`, cut to at most `maxNumOfEdges` entries. */
  override def evaluate(buffer: Row): Any =
    buffer.getAs[ROWS](0).sorted.take(maxNumOfEdges)
}

/**
 * Straightforward (non-optimized) UDAF that gathers edge rows (timestamp, to, label, props)
 * into one array per group.
 *
 * `update` appends every input row without bound; trimming to `maxNumOfEdges` happens in
 * `merge` or, if no merge took place, in `evaluate`. Rows are ordered with `rowOrderingDesc`
 * (descending on column 0), so the rows kept are those with the largest timestamps.
 *
 * The second buffer column, `buffered`, records whether column 0 is already sorted and
 * trimmed (set by `merge`) so `evaluate` can skip the sort.
 *
 * @param maxNumOfEdges maximum number of rows in the result
 */
class GroupByAgg(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {
  import S2EdgeDataAggregate._

  /** Ordering used implicitly by `sorted` in `takeTopK`. */
  implicit val ord: Ordering[Row] = rowOrderingDesc

  val arrayType = ArrayType(elementType = StructType(fields))

  /** One input row carries the four edge columns directly. */
  override def inputSchema: StructType = StructType(fields)

  /** Column 0: rows collected so far. Column 1: whether column 0 is already sorted and trimmed. */
  override def bufferSchema: StructType = StructType(Seq(
    StructField(name = "edges", dataType = arrayType),
    StructField(name = "buffered", dataType = BooleanType)
  ))

  override def dataType: DataType = arrayType

  override def deterministic: Boolean = true

  /**
   * Resets both buffer columns: an empty row collection and `buffered = false`.
   * Setting the flag here means `evaluate` never has to rely on a null unboxing to `false`.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Seq.empty[Row])
    buffer.update(1, false)
  }

  /** Appends the input row as-is and marks the buffer as not yet sorted/trimmed. Not optimized. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val prev = buffer.getAs[Seq[Row]](0)

    buffer.update(0, prev :+ input)
    buffer.update(1, false)
  }

  /** Sorts `ls` by `ord` and keeps at most the first `k` rows. */
  private def takeTopK(ls: Seq[Row], k: Int): Seq[Row] = ls.sorted.take(k)

  /** Concatenates both buffers, keeps the top `maxNumOfEdges` rows and flags the result as final. Not optimized. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val cur = buffer2.getAs[Seq[Row]](0)
    val prev = buffer1.getAs[Seq[Row]](0)

    buffer1.update(0, takeTopK(prev ++ cur, maxNumOfEdges))
    buffer1.update(1, true)
  }

  /** Returns the buffered rows, sorting and trimming them first unless `merge` already did. */
  override def evaluate(buffer: Row): Any = {
    val ls = buffer.getAs[Seq[Row]](0)
    val buffered = buffer.getAs[Boolean](1)
    if (buffered) ls
    else takeTopK(ls, maxNumOfEdges)
  }
}

/**
 * UDAF whose input is already an array of edge structs per row; it combines those arrays
 * into a single array per group via `S2EdgeDataAggregate.mergeTwoSeq` with `rowOrdering`
 * (ascending on column 0).
 *
 * `mergeTwoSeq` is a linear merge that does not sort.
 * NOTE(review): this presumably expects every input array to be ordered by ascending
 * timestamp already — confirm against the upstream producer.
 *
 * @param maxNumOfEdges size limit passed to `mergeTwoSeq` on every merge step
 */
class GroupByArrayAgg(maxNumOfEdges: Int = 1000) extends UserDefinedAggregateFunction {
  import S2EdgeDataAggregate._

  /** Ordering handed implicitly to `mergeTwoSeq`. */
  implicit val ord: Ordering[Row] = rowOrdering

  import scala.collection.mutable

  /** A single input column holding an array of edge structs. */
  override def inputSchema: StructType = StructType(Seq(
    StructField(name = "edges", dataType = arrayType)
  ))

  /** A single buffer column holding the merged array so far. */
  override def bufferSchema: StructType = StructType(Seq(
    StructField(name = "edges", dataType = arrayType)
  ))

  override def dataType: DataType = arrayType

  override def deterministic: Boolean = true

  /** Starts every group with an empty row collection. */
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, mutable.ListBuffer.empty[Row])

  /** Merges the input row's array into the buffer; a null array contributes no edges. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // the array column may be null; treat that as an empty array instead of failing with an NPE
    val cur = Option(input.getAs[Seq[Row]](0)).getOrElse(Seq.empty[Row])
    val prev = buffer.getAs[Seq[Row]](0)
    // argument order kept as in the original call: on equal keys the incoming rows go first
    val merged = mergeTwoSeq(cur, prev, maxNumOfEdges)

    buffer.update(0, merged)
  }

  /** Merges the array held in `buffer2` into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val cur = buffer2.getAs[Seq[Row]](0)
    val prev = buffer1.getAs[Seq[Row]](0)

    val merged = mergeTwoSeq(cur, prev, maxNumOfEdges)
    buffer1.update(0, merged)
  }

  /** Returns the merged array unchanged. */
  override def evaluate(buffer: Row): Any = buffer.getAs[Seq[Row]](0)
}
