/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.samza.storage.kv

import java.nio.file.Path
import java.util.Optional

import com.google.common.annotations.VisibleForTesting
import org.apache.samza.checkpoint.CheckpointId
import org.apache.samza.util.Logging
import org.apache.samza.serializers._

/**
 * A [[KeyValueStore]] wrapper that serializes keys and values with the supplied serdes before delegating to an
 * underlying byte-array store, and deserializes results read back from it.
 *
 * A `null` key or value is passed through as `null` rather than being handed to a serde, and a `null` byte array
 * read from the underlying store is returned as `null`.
 *
 * Every operation also updates the corresponding entries of `metrics` (operation counts, bytes
 * serialized/deserialized, and per-record / max record sizes for puts).
 *
 * @param store    the underlying byte-array store that every operation is delegated to
 * @param keySerde serde used to convert keys to and from bytes
 * @param msgSerde serde used to convert values to and from bytes
 * @param metrics  metrics updated by the operations of this wrapper
 * @tparam K key type
 * @tparam V value type
 */
class SerializedKeyValueStore[K, V](
  store: KeyValueStore[Array[Byte], Array[Byte]],
  keySerde: Serde[K],
  msgSerde: Serde[V],
  metrics: SerializedKeyValueStoreMetrics = new SerializedKeyValueStoreMetrics) extends KeyValueStore[K, V] with Logging {

  /**
   * Looks up a single key.
   *
   * @return the deserialized value, or `null` if the underlying store returned `null`
   */
  override def get(key: K): V = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val found = store.get(keyBytes)
    metrics.gets.inc()
    fromBytesOrNull(found, msgSerde)
  }

  /**
   * Looks up a batch of keys; counts one get per requested key.
   *
   * @return a new map of deserialized keys to deserialized values, or `null` if the underlying store returned `null`
   */
  override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
    metrics.gets.inc(keys.size)
    val mapBytes = store.getAll(serializeKeys(keys))
    if (mapBytes == null) {
      null
    } else {
      val result = new java.util.HashMap[K, V](mapBytes.size)
      val entryIterator = mapBytes.entrySet.iterator
      while (entryIterator.hasNext) {
        val entry = entryIterator.next
        result.put(fromBytesOrNull(entry.getKey, keySerde), fromBytesOrNull(entry.getValue, msgSerde))
      }
      result
    }
  }

  /**
   * Serializes and stores a single entry, then records its key/value sizes and the put count.
   */
  override def put(key: K, value: V): Unit = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val valBytes = toBytesOrNull(value, msgSerde)
    store.put(keyBytes, valBytes)
    val keySizeBytes = sizeOf(keyBytes)
    val valSizeBytes = sizeOf(valBytes)
    metrics.recordKeySizeBytes.update(keySizeBytes)
    metrics.recordValueSizeBytes.update(valSizeBytes)
    updatePutMetrics(1, keySizeBytes, valSizeBytes)
  }

  /**
   * Serializes every entry, writes them to the underlying store in a single `putAll`, and then updates the put
   * count and the max key/value size metrics with the largest sizes seen in this batch.
   */
  override def putAll(entries: java.util.List[Entry[K, V]]): Unit = {
    val serialized = new java.util.ArrayList[Entry[Array[Byte], Array[Byte]]](entries.size())
    val iter = entries.iterator
    var batchMaxKeySizeBytes = 0
    var batchMaxValueSizeBytes = 0
    while (iter.hasNext) {
      val curr = iter.next
      val keyBytes = toBytesOrNull(curr.getKey, keySerde)
      val valBytes = toBytesOrNull(curr.getValue, msgSerde)
      val keySizeBytes = sizeOf(keyBytes)
      val valSizeBytes = sizeOf(valBytes)
      metrics.recordKeySizeBytes.update(keySizeBytes)
      metrics.recordValueSizeBytes.update(valSizeBytes)
      batchMaxKeySizeBytes = Math.max(batchMaxKeySizeBytes, keySizeBytes)
      batchMaxValueSizeBytes = Math.max(batchMaxValueSizeBytes, valSizeBytes)
      serialized.add(new Entry(keyBytes, valBytes))
    }
    store.putAll(serialized)
    updatePutMetrics(serialized.size, batchMaxKeySizeBytes, batchMaxValueSizeBytes)
  }

  /** Deletes a single key from the underlying store. */
  override def delete(key: K): Unit = {
    metrics.deletes.inc()
    val keyBytes = toBytesOrNull(key, keySerde)
    store.delete(keyBytes)
  }

  /** Deletes a batch of keys; counts one delete per requested key. */
  override def deleteAll(keys: java.util.List[K]): Unit = {
    metrics.deletes.inc(keys.size)
    store.deleteAll(serializeKeys(keys))
  }

  /**
   * Returns an iterator over the underlying store's `range(from, to)` whose entries are deserialized lazily as
   * they are consumed.
   */
  override def range(from: K, to: K): KeyValueIterator[K, V] = {
    metrics.ranges.inc()
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    new DeserializingIterator(store.range(fromBytes, toBytes))
  }

  /** Returns an iterator over all entries of the underlying store, deserialized lazily as they are consumed. */
  override def all(): KeyValueIterator[K, V] = {
    metrics.alls.inc()
    new DeserializingIterator(store.all())
  }

  /**
   * Adapts a byte-array iterator to a typed iterator; `remove` and `close` are forwarded unchanged and each
   * `next` deserializes one entry.
   */
  private class DeserializingIterator(iter: KeyValueIterator[Array[Byte], Array[Byte]]) extends KeyValueIterator[K, V] {
    override def hasNext(): Boolean = iter.hasNext()
    override def remove(): Unit = iter.remove()
    override def close(): Unit = iter.close()
    override def next(): Entry[K, V] = {
      val nxt = iter.next()
      val key = fromBytesOrNull(nxt.getKey, keySerde)
      val value = fromBytesOrNull(nxt.getValue, msgSerde)
      new Entry(key, value)
    }
  }

  /** Flushes the underlying store. */
  override def flush: Unit = {
    trace("Flushing store.")

    metrics.flushes.inc()

    store.flush()
    trace("Flushed store.")
  }

  /** Closes the underlying store. */
  override def close: Unit = {
    trace("Closing.")

    store.close()
  }

  /**
   * Serializes `t` with `serde`, counting the produced bytes in `bytesSerialized`.
   *
   * @return `null` if `t` is `null` (the serde is not invoked) or if the serde itself returned `null`
   */
  private def toBytesOrNull[T](t: T, serde: Serde[T]): Array[Byte] = if (t == null) {
    null
  } else {
    val bytes = serde.toBytes(t)
    if (bytes != null) {
      metrics.bytesSerialized.inc(bytes.length)
    }
    bytes
  }

  /**
   * Deserializes `bytes` with `serde`, counting the consumed bytes in `bytesDeserialized`.
   *
   * @return `null` (cast to `T`) if `bytes` is `null`; the serde is not invoked in that case
   */
  private def fromBytesOrNull[T](bytes: Array[Byte], serde: Serde[T]): T = if (bytes == null) {
    null.asInstanceOf[T]
  } else {
    val obj = serde.fromBytes(bytes)
    metrics.bytesDeserialized.inc(bytes.length)
    obj
  }

  /** Serializes every key in order; a `null` key becomes a `null` element. */
  private def serializeKeys(keys: java.util.List[K]): java.util.List[Array[Byte]] = {
    val bytes = new java.util.ArrayList[Array[Byte]](keys.size)
    val keysIterator = keys.iterator
    while (keysIterator.hasNext) {
      bytes.add(toBytesOrNull(keysIterator.next, keySerde))
    }
    bytes
  }

  /** Size in bytes of a serialized key or value, treating `null` as 0. */
  private def sizeOf(bytes: Array[Byte]): Int = if (bytes == null) 0 else bytes.length

  /**
   * Increments the put count by `batchSize` and raises the max key/value size metrics if the given sizes exceed
   * their current values.
   *
   * The max-size metrics are updated with a thread UN-SAFE read-then-write, so accuracy is not guaranteed: if
   * multiple threads overlap in their invocation of this method, the last to write simply wins regardless of the
   * value it read.
   */
  private def updatePutMetrics(batchSize: Long, newMaxRecordKeySizeBytes: Long, newMaxRecordSizeBytes: Long): Unit = {
    metrics.puts.inc(batchSize)
    val keyMax = metrics.maxRecordKeySizeBytes.getValue
    val valueMax = metrics.maxRecordSizeBytes.getValue
    if (newMaxRecordKeySizeBytes > keyMax) {
      metrics.maxRecordKeySizeBytes.set(newMaxRecordKeySizeBytes)
    }
    if (newMaxRecordSizeBytes > valueMax) {
      metrics.maxRecordSizeBytes.set(newMaxRecordSizeBytes)
    }
  }

  /**
   * Takes a snapshot of the underlying store over the serialized `[from, to]` range and exposes it with
   * deserializing iterators; closing the returned snapshot closes the underlying one.
   */
  override def snapshot(from: K, to: K): KeyValueSnapshot[K, V] = {
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    val snapshot = store.snapshot(fromBytes, toBytes)
    new KeyValueSnapshot[K, V] {
      override def iterator(): KeyValueIterator[K, V] = {
        new DeserializingIterator(snapshot.iterator())
      }

      override def close(): Unit = {
        snapshot.close()
      }
    }
  }

  /** Delegates checkpoint creation to the underlying store unchanged. */
  override def checkpoint(id: CheckpointId): Optional[Path] = {
    store.checkpoint(id)
  }

  /** Exposes the wrapped byte-array store for tests in this package. */
  @VisibleForTesting
  private[kv] def getStore: KeyValueStore[Array[Byte], Array[Byte]] = {
    store
  }
}
