/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
| |
| package org.apache.samza.storage.kv |
| |
| import java.nio.file.Path |
| import java.util.Optional |
| |
| import com.google.common.annotations.VisibleForTesting |
| import org.apache.samza.checkpoint.CheckpointId |
| import org.apache.samza.util.Logging |
| import org.apache.samza.serializers._ |
| |
/**
 * A [[KeyValueStore]] decorator that transparently serializes keys and values with the
 * supplied serdes before delegating to an underlying byte-array store, and deserializes
 * results on the way back out. Operation counts, record sizes, and serialized byte totals
 * are recorded in the supplied metrics object.
 *
 * @param store    the underlying store operating on serialized byte arrays
 * @param keySerde serde used to (de)serialize keys of type K
 * @param msgSerde serde used to (de)serialize values of type V
 * @param metrics  sink for operation counts and record-size metrics
 */
class SerializedKeyValueStore[K, V](
  store: KeyValueStore[Array[Byte], Array[Byte]],
  keySerde: Serde[K],
  msgSerde: Serde[V],
  metrics: SerializedKeyValueStoreMetrics = new SerializedKeyValueStoreMetrics) extends KeyValueStore[K, V] with Logging {

  /** Returns the deserialized value for `key`, or null if the key is absent. */
  def get(key: K): V = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val found = store.get(keyBytes)
    metrics.gets.inc()
    fromBytesOrNull(found, msgSerde)
  }

  /**
   * Bulk get: serializes every key, fetches them from the underlying store in one call,
   * and deserializes the resulting entries. Returns null when the underlying store does.
   */
  override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
    metrics.gets.inc(keys.size)
    val mapBytes = store.getAll(serializeKeys(keys))
    if (mapBytes != null) {
      val map = new java.util.HashMap[K, V](mapBytes.size)
      val entryIterator = mapBytes.entrySet.iterator
      while (entryIterator.hasNext) {
        val entry = entryIterator.next
        map.put(fromBytesOrNull(entry.getKey, keySerde), fromBytesOrNull(entry.getValue, msgSerde))
      }
      map
    } else {
      null.asInstanceOf[java.util.Map[K, V]]
    }
  }

  /** Serializes and stores a single key-value pair, updating size and put metrics. */
  def put(key: K, value: V): Unit = {
    val keyBytes = toBytesOrNull(key, keySerde)
    val valBytes = toBytesOrNull(value, msgSerde)
    store.put(keyBytes, valBytes)
    // A null serialized form counts as zero bytes for the size metrics.
    val keySizeBytes = if (keyBytes == null) 0 else keyBytes.length
    val valSizeBytes = if (valBytes == null) 0 else valBytes.length
    metrics.recordKeySizeBytes.update(keySizeBytes)
    metrics.recordValueSizeBytes.update(valSizeBytes)
    updatePutMetrics(1, keySizeBytes, valSizeBytes)
  }

  /**
   * Serializes and stores a batch of entries via a single underlying putAll call.
   * Tracks the largest key and value sizes seen in this batch for the max-size gauges.
   */
  def putAll(entries: java.util.List[Entry[K, V]]): Unit = {
    val list = new java.util.ArrayList[Entry[Array[Byte], Array[Byte]]](entries.size())
    val iter = entries.iterator
    var newMaxRecordKeySizeBytes = 0
    var newMaxRecordSizeBytes = 0
    while (iter.hasNext) {
      val curr = iter.next
      val keyBytes = toBytesOrNull(curr.getKey, keySerde)
      val valBytes = toBytesOrNull(curr.getValue, msgSerde)
      val keySizeBytes = if (keyBytes == null) 0 else keyBytes.length
      val valSizeBytes = if (valBytes == null) 0 else valBytes.length
      metrics.recordKeySizeBytes.update(keySizeBytes)
      metrics.recordValueSizeBytes.update(valSizeBytes)
      newMaxRecordKeySizeBytes = Math.max(newMaxRecordKeySizeBytes, keySizeBytes)
      newMaxRecordSizeBytes = Math.max(newMaxRecordSizeBytes, valSizeBytes)
      list.add(new Entry(keyBytes, valBytes))
    }
    store.putAll(list)
    updatePutMetrics(list.size, newMaxRecordKeySizeBytes, newMaxRecordSizeBytes)
  }

  /** Deletes the entry for `key` from the underlying store. */
  def delete(key: K): Unit = {
    metrics.deletes.inc()
    val keyBytes = toBytesOrNull(key, keySerde)
    store.delete(keyBytes)
  }

  /** Bulk delete of all the given keys in a single underlying call. */
  override def deleteAll(keys: java.util.List[K]): Unit = {
    metrics.deletes.inc(keys.size)
    store.deleteAll(serializeKeys(keys))
  }

  /** Returns an iterator over the given key range that deserializes entries lazily. */
  def range(from: K, to: K): KeyValueIterator[K, V] = {
    metrics.ranges.inc()
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    new DeserializingIterator(store.range(fromBytes, toBytes))
  }

  /** Returns an iterator over every entry in the store, deserializing lazily. */
  def all(): KeyValueIterator[K, V] = {
    metrics.alls.inc()
    new DeserializingIterator(store.all())
  }

  /** Adapts a byte-level iterator into one yielding deserialized keys and values. */
  private class DeserializingIterator(iter: KeyValueIterator[Array[Byte], Array[Byte]]) extends KeyValueIterator[K, V] {
    override def hasNext() = iter.hasNext()
    override def remove(): Unit = iter.remove()
    override def close(): Unit = iter.close()
    override def next(): Entry[K, V] = {
      val nxt = iter.next()
      val key = fromBytesOrNull(nxt.getKey, keySerde)
      val value = fromBytesOrNull(nxt.getValue, msgSerde)
      new Entry(key, value)
    }
  }

  /** Flushes the underlying store. */
  def flush(): Unit = {
    trace("Flushing store.")
    metrics.flushes.inc()
    store.flush()
    trace("Flushed store.")
  }

  /** Closes the underlying store, releasing its resources. */
  def close(): Unit = {
    trace("Closing.")
    store.close()
  }

  /**
   * Serializes `t` with `serde`, passing null straight through.
   * Counts serialized bytes; a serde may return null for a non-null input, so only
   * real output is counted.
   */
  private def toBytesOrNull[T](t: T, serde: Serde[T]): Array[Byte] = if (t == null) {
    null
  } else {
    val bytes = serde.toBytes(t)
    if (bytes != null) {
      metrics.bytesSerialized.inc(bytes.length)
    }
    bytes
  }

  /** Deserializes `bytes` with `serde`, passing null straight through; counts deserialized bytes. */
  private def fromBytesOrNull[T](bytes: Array[Byte], serde: Serde[T]): T = if (bytes == null) {
    null.asInstanceOf[T]
  } else {
    val obj = serde.fromBytes(bytes)
    metrics.bytesDeserialized.inc(bytes.length)
    obj
  }

  /** Serializes each key in order, preserving list positions (null keys map to null bytes). */
  private def serializeKeys(keys: java.util.List[K]): java.util.List[Array[Byte]] = {
    val bytes = new java.util.ArrayList[Array[Byte]](keys.size)
    val keysIterator = keys.iterator
    while (keysIterator.hasNext) {
      bytes.add(toBytesOrNull(keysIterator.next, keySerde))
    }
    bytes
  }

  /**
   * Updates put metrics with the given batch and record sizes. The max record size metric is updated with a
   * thread UN-SAFE read-then-write, so accuracy is not guaranteed; if multiple threads overlap in their invocation of
   * this method, the last to write simply wins regardless of the value it read.
   */
  private def updatePutMetrics(batchSize: Long, newMaxRecordKeySizeBytes: Long, newMaxRecordSizeBytes: Long): Unit = {
    metrics.puts.inc(batchSize)
    val keyMax = metrics.maxRecordKeySizeBytes.getValue
    val valueMax = metrics.maxRecordSizeBytes.getValue
    if (newMaxRecordKeySizeBytes > keyMax) {
      metrics.maxRecordKeySizeBytes.set(newMaxRecordKeySizeBytes)
    }
    if (newMaxRecordSizeBytes > valueMax) {
      metrics.maxRecordSizeBytes.set(newMaxRecordSizeBytes)
    }
  }

  /** Returns a snapshot over the given key range whose iterators deserialize lazily. */
  override def snapshot(from: K, to: K): KeyValueSnapshot[K, V] = {
    val fromBytes = toBytesOrNull(from, keySerde)
    val toBytes = toBytesOrNull(to, keySerde)
    val snapshot = store.snapshot(fromBytes, toBytes)
    new KeyValueSnapshot[K, V] {
      override def iterator(): KeyValueIterator[K, V] = {
        new DeserializingIterator(snapshot.iterator())
      }

      override def close(): Unit = {
        snapshot.close()
      }
    }
  }

  /** Delegates checkpoint creation to the underlying store. */
  override def checkpoint(id: CheckpointId): Optional[Path] = {
    store.checkpoint(id)
  }

  /** Exposes the wrapped byte-array store for tests only. */
  @VisibleForTesting
  private[kv] def getStore: KeyValueStore[Array[Byte], Array[Byte]] = {
    store
  }
}