blob: b78e14a88aeba8a4ad4c706c11d122063a3a6ce8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.storage.kv
import java.nio.file.Path
import java.util.Optional
import com.google.common.annotations.VisibleForTesting
import org.apache.samza.checkpoint.CheckpointId
import org.apache.samza.util.Logging
import org.apache.samza.serializers._
/**
 * A [[KeyValueStore]] over typed keys and values that delegates to an underlying byte-array store.
 * Keys are converted with `keySerde` and values with `msgSerde` on the way in, and converted back
 * on the way out.
 *
 * `null` keys and values are passed to the underlying store as `null` without invoking the serdes.
 * A `null` returned by the underlying store is handed back to the caller as `null`.
 *
 * @param store    the underlying store holding the serialized bytes
 * @param keySerde serde used to convert keys to and from bytes
 * @param msgSerde serde used to convert values to and from bytes
 * @param metrics  metrics updated by gets, puts, deletes, ranges, flushes and (de)serialization
 * @tparam K key type
 * @tparam V value type
 */
class SerializedKeyValueStore[K, V](
  store: KeyValueStore[Array[Byte], Array[Byte]],
  keySerde: Serde[K],
  msgSerde: Serde[V],
  metrics: SerializedKeyValueStoreMetrics = new SerializedKeyValueStoreMetrics)
  extends KeyValueStore[K, V] with Logging {

  /** Looks up `key` and returns its deserialized value, or `null` if the underlying store returned `null`. */
  def get(key: K): V = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val found = store.get(keyBytes)
    metrics.gets.inc()
    fromBytesOrNull(found, msgSerde)
  }

  /**
   * Looks up all `keys` in a single call to the underlying store.
   *
   * @return a new map of deserialized keys to deserialized values, or `null` if the underlying
   *         store returned `null`
   */
  override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
    metrics.gets.inc(keys.size)
    val mapBytes = store.getAll(serializeKeys(keys))
    if (mapBytes == null) {
      // Propagate the underlying store's null result unchanged.
      null
    } else {
      val result = new java.util.HashMap[K, V](mapBytes.size)
      val entries = mapBytes.entrySet.iterator
      while (entries.hasNext) {
        val entry = entries.next()
        result.put(fromBytesOrNull(entry.getKey, keySerde), fromBytesOrNull(entry.getValue, msgSerde))
      }
      result
    }
  }

  /** Serializes and writes a single entry, then records its key/value sizes in the metrics. */
  def put(key: K, value: V): Unit = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val valBytes = toBytesOrNull(value, msgSerde)
    store.put(keyBytes, valBytes)
    val keySizeBytes = sizeInBytes(keyBytes)
    val valSizeBytes = sizeInBytes(valBytes)
    recordEntrySizes(keySizeBytes, valSizeBytes)
    updatePutMetrics(1, keySizeBytes, valSizeBytes)
  }

  /**
   * Serializes every entry and writes them to the underlying store in one `putAll` call.
   * Per-record size metrics are recorded as each entry is serialized. The put counter and
   * max-size gauges are updated once after the batch is written.
   */
  def putAll(entries: java.util.List[Entry[K, V]]): Unit = {
    val serialized = new java.util.ArrayList[Entry[Array[Byte], Array[Byte]]](entries.size())
    var maxKeySizeBytes = 0
    var maxValSizeBytes = 0
    val iter = entries.iterator
    while (iter.hasNext) {
      val entry = iter.next()
      val keyBytes = toBytesOrNull(entry.getKey, keySerde)
      val valBytes = toBytesOrNull(entry.getValue, msgSerde)
      val keySizeBytes = sizeInBytes(keyBytes)
      val valSizeBytes = sizeInBytes(valBytes)
      recordEntrySizes(keySizeBytes, valSizeBytes)
      maxKeySizeBytes = Math.max(maxKeySizeBytes, keySizeBytes)
      maxValSizeBytes = Math.max(maxValSizeBytes, valSizeBytes)
      serialized.add(new Entry(keyBytes, valBytes))
    }
    store.putAll(serialized)
    // Only the batch maxima are needed to update the max-size gauges.
    updatePutMetrics(serialized.size, maxKeySizeBytes, maxValSizeBytes)
  }

  /** Deletes `key` from the underlying store. */
  def delete(key: K): Unit = {
    metrics.deletes.inc()
    val keyBytes = toBytesOrNull(key, keySerde)
    store.delete(keyBytes)
  }

  /** Deletes all `keys` from the underlying store in a single call. */
  override def deleteAll(keys: java.util.List[K]): Unit = {
    metrics.deletes.inc(keys.size)
    store.deleteAll(serializeKeys(keys))
  }

  /** Returns an iterator that lazily deserializes the underlying store's `range(from, to)` entries. */
  def range(from: K, to: K): KeyValueIterator[K, V] = {
    metrics.ranges.inc()
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    new DeserializingIterator(store.range(fromBytes, toBytes))
  }

  /** Returns an iterator that lazily deserializes every entry of the underlying store. */
  def all(): KeyValueIterator[K, V] = {
    metrics.alls.inc()
    new DeserializingIterator(store.all)
  }

  /**
   * Wraps a byte-array iterator and deserializes each entry as `next()` is called.
   * `hasNext`, `remove` and `close` delegate directly to the wrapped iterator.
   */
  private class DeserializingIterator(iter: KeyValueIterator[Array[Byte], Array[Byte]])
    extends KeyValueIterator[K, V] {
    override def hasNext(): Boolean = iter.hasNext()
    override def remove(): Unit = iter.remove()
    override def close(): Unit = iter.close()
    override def next(): Entry[K, V] = {
      val nxt = iter.next()
      val key = fromBytesOrNull(nxt.getKey, keySerde)
      val value = fromBytesOrNull(nxt.getValue, msgSerde)
      new Entry(key, value)
    }
  }

  /** Flushes the underlying store and counts the flush in the metrics. */
  def flush(): Unit = {
    trace("Flushing store.")
    metrics.flushes.inc()
    store.flush()
    trace("Flushed store.")
  }

  /** Closes the underlying store. */
  def close(): Unit = {
    trace("Closing.")
    store.close()
  }

  /** Serializes `t` with `serde`. Returns `null` for a `null` input, without calling the serde. */
  private def toBytesOrNull[T](t: T, serde: Serde[T]): Array[Byte] = if (t == null) {
    null
  } else {
    val bytes = serde.toBytes(t)
    // A serde may legitimately return null; only count bytes actually produced.
    if (bytes != null) {
      metrics.bytesSerialized.inc(bytes.length)
    }
    bytes
  }

  /** Deserializes `bytes` with `serde`. Returns `null` for a `null` input, without calling the serde. */
  private def fromBytesOrNull[T](bytes: Array[Byte], serde: Serde[T]): T = if (bytes == null) {
    null.asInstanceOf[T]
  } else {
    val obj = serde.fromBytes(bytes)
    metrics.bytesDeserialized.inc(bytes.length)
    obj
  }

  /** Serializes each key in `keys`, preserving order; `null` keys stay `null`. */
  private def serializeKeys(keys: java.util.List[K]): java.util.List[Array[Byte]] = {
    val bytes = new java.util.ArrayList[Array[Byte]](keys.size)
    val keysIterator = keys.iterator
    while (keysIterator.hasNext) {
      bytes.add(toBytesOrNull(keysIterator.next(), keySerde))
    }
    bytes
  }

  /** Returns the length of `bytes`, treating a `null` array as zero bytes. */
  private def sizeInBytes(bytes: Array[Byte]): Int = if (bytes == null) 0 else bytes.length

  /** Records one entry's serialized key and value sizes in the per-record size metrics. */
  private def recordEntrySizes(keySizeBytes: Int, valSizeBytes: Int): Unit = {
    metrics.recordKeySizeBytes.update(keySizeBytes)
    metrics.recordValueSizeBytes.update(valSizeBytes)
  }

  /**
   * Updates put metrics with the given batch and record sizes. The max record size metrics are updated
   * with a thread-UNSAFE read-then-write, so accuracy is not guaranteed. If multiple threads overlap in
   * their invocation of this method, the last to write wins regardless of the value it read.
   *
   * @param batchSize                number of entries written
   * @param newMaxRecordKeySizeBytes largest serialized key size in the batch
   * @param newMaxRecordSizeBytes    largest serialized value size in the batch
   */
  private def updatePutMetrics(batchSize: Long, newMaxRecordKeySizeBytes: Long, newMaxRecordSizeBytes: Long): Unit = {
    metrics.puts.inc(batchSize)
    val keyMax = metrics.maxRecordKeySizeBytes.getValue
    val valueMax = metrics.maxRecordSizeBytes.getValue
    if (newMaxRecordKeySizeBytes > keyMax) {
      metrics.maxRecordKeySizeBytes.set(newMaxRecordKeySizeBytes)
    }
    if (newMaxRecordSizeBytes > valueMax) {
      metrics.maxRecordSizeBytes.set(newMaxRecordSizeBytes)
    }
  }

  /**
   * Takes a snapshot of the underlying store over `[from, to]` (bounds serialized with `keySerde`).
   * The returned snapshot deserializes entries lazily and closes the underlying snapshot on `close()`.
   */
  override def snapshot(from: K, to: K): KeyValueSnapshot[K, V] = {
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    val snapshot = store.snapshot(fromBytes, toBytes)
    new KeyValueSnapshot[K, V] {
      override def iterator(): KeyValueIterator[K, V] = new DeserializingIterator(snapshot.iterator())

      override def close(): Unit = snapshot.close()
    }
  }

  /** Delegates checkpoint creation to the underlying store. */
  override def checkpoint(id: CheckpointId): Optional[Path] = {
    store.checkpoint(id)
  }

  /** Exposes the wrapped byte-array store for tests. */
  @VisibleForTesting
  private[kv] def getStore: KeyValueStore[Array[Byte], Array[Byte]] = {
    store
  }
}