blob: f8eff5c80dff11d378a826a043e39cf3d8ad1e39 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Objects;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheStats;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Weigher;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.MapMaker;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
 * Process-wide cache of per-key state.
 *
 * <p>This is backed by a Guava {@link Cache}, which is thread-safe. Entries are accessed often from
 * multiple threads. Logical consistency of each entry requires that each key (computation,
 * processing key, state family and namespace) is accessed by only a single thread at a time.
 * {@link StreamingDataflowWorker} ensures that a single computation and processing key is executing
 * on one thread at a time, so this is safe.
 *
 * <p>Cache entries are keyed by {@code StateId}s that reference a {@link ForKey} by identity. Each
 * computation key maps to at most one current {@link ForKey}; replacing or removing it makes the
 * entries created through the previous instance inaccessible, and they are subsequently evicted
 * through normal cache operation.
 */
public class WindmillStateCache implements StatusDataProvider {
  // Number of bytes per megabyte; converts the configured size and formats statistics.
  private static final long MEGABYTES = 1024 * 1024;
  // Estimated per-StateId overhead in bytes, on top of its key and state-family contents.
  private static final long PER_STATE_ID_OVERHEAD = 28;
  // Initial capacity of the value map held by each StateCacheEntry.
  private static final int INITIAL_HASH_MAP_CAPACITY = 4;
  // Estimated overhead in bytes of each hash map entry.
  private static final int HASH_MAP_ENTRY_OVERHEAD = 16;
  // Estimated overhead in bytes of each StateCacheEntry: one long, plus a hash table.
  private static final int PER_CACHE_ENTRY_OVERHEAD =
      8 + HASH_MAP_ENTRY_OVERHEAD * INITIAL_HASH_MAP_CAPACITY;

  // Shared, weight-bounded cache holding state entries for all computations and keys.
  private final Cache<StateId, StateCacheEntry> stateCache;

  // Contains the current valid ForKey object for each computation key. Entries in the cache are
  // keyed by ForKey with pointer equality, so entries may be invalidated by creating a new key
  // object, rendering the previous entries inaccessible. They will be evicted through normal cache
  // operation. Values are held weakly (MapMaker#weakValues).
  private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex =
      new MapMaker().weakValues().concurrencyLevel(4).makeMap();

  // Maximum total cache weight: workerCacheMb converted to bytes.
  private final long workerCacheBytes;

  /**
   * Creates a cache whose total weight is bounded by the given size.
   *
   * @param workerCacheMb maximum total weight of the cache, in megabytes
   */
  public WindmillStateCache(long workerCacheMb) {
    final Weigher<Weighted, Weighted> weigher = Weighers.weightedKeysAndValues();
    workerCacheBytes = workerCacheMb * MEGABYTES;
    stateCache =
        CacheBuilder.newBuilder()
            .maximumWeight(workerCacheBytes)
            .recordStats()
            .weigher(weigher)
            .concurrencyLevel(4)
            .build();
  }

  /** Aggregate statistics gathered by one pass over the cache contents. */
  private static class EntryStats {
    // Number of cache entries visited.
    long entries;
    // Sum of the weights of all StateId keys.
    long idWeight;
    // Sum of the weights of all StateCacheEntry values.
    long entryWeight;
    // Total number of values stored across all entries.
    long entryValues;
    // Largest number of values stored in a single entry.
    long maxEntryValues;
  }

  /** Iterates over every entry currently in the cache and aggregates its statistics. */
  private EntryStats calculateEntryStats() {
    EntryStats stats = new EntryStats();
    BiConsumer<StateId, StateCacheEntry> consumer =
        (stateId, stateCacheEntry) -> {
          stats.entries++;
          stats.idWeight += stateId.getWeight();
          stats.entryWeight += stateCacheEntry.getWeight();
          stats.entryValues += stateCacheEntry.values.size();
          stats.maxEntryValues = Math.max(stats.maxEntryValues, stateCacheEntry.values.size());
        };
    stateCache.asMap().forEach(consumer);
    return stats;
  }

  /**
   * Returns the combined weight of all keys and values currently in the cache. This is computed by
   * iterating over every entry, so it reflects each entry's current weight.
   */
  public long getWeight() {
    EntryStats w = calculateEntryStats();
    return w.idWeight + w.entryWeight;
  }

  /** Returns the maximum total weight of the cache, in bytes, as configured at construction. */
  public long getMaxWeight() {
    return workerCacheBytes;
  }

  /** Returns the hit/miss/eviction statistics recorded by the underlying cache. */
  public CacheStats getCacheStats() {
    return stateCache.stats();
  }

  /** Per-computation view of the state cache. */
  public class ForComputation {
    private final String computation;

    private ForComputation(String computation) {
      this.computation = computation;
    }

    /** Invalidate all cache entries for this computation and {@code processingKey}. */
    public void invalidate(ByteString processingKey, long shardingKey) {
      WindmillComputationKey key =
          WindmillComputationKey.create(computation, processingKey, shardingKey);
      // By removing the ForKey object, all state for the key is orphaned in the cache and will
      // be removed by normal cache cleanup.
      keyIndex.remove(key);
    }

    /**
     * Returns a per-computation, per-key view of the state cache. Access to the cached data for
     * this key is not thread-safe. Callers should ensure that there is only a single ForKey object
     * in use at a time and that access to it is synchronized or single-threaded.
     *
     * <p>The existing view is reused only when {@code cacheToken} matches it and {@code workToken}
     * is strictly greater than its last work token; otherwise a fresh view is registered, which
     * makes all previously cached state for the key inaccessible.
     */
    public ForKey forKey(WindmillComputationKey computationKey, long cacheToken, long workToken) {
      ForKey forKey = keyIndex.get(computationKey);
      if (forKey == null || !forKey.updateTokens(cacheToken, workToken)) {
        forKey = new ForKey(computationKey, cacheToken, workToken);
        // We prefer this implementation to using compute because that is implemented similarly for
        // ConcurrentHashMap with the downside of it performing inserts for unchanged existing
        // values as well.
        keyIndex.put(computationKey, forKey);
      }
      return forKey;
    }
  }

  /** Per-computation, per-key view of the state cache. */
  // Note that we utilize the default equality and hashCode for this class based upon the instance
  // (instead of the fields) to optimize cache invalidation.
  public class ForKey {
    private final WindmillComputationKey computationKey;
    // Cache token must be consistent for the key for the cache to be valid.
    private final long cacheToken;
    // The work token for processing must be greater than the last work token. As work items are
    // increasing for a key, a less-than or equal to work token indicates that the current token is
    // for stale processing.
    private long workToken;

    /**
     * Returns a per-computation, per-key, per-family view of the state cache. Access to the cached
     * data for this key is not thread-safe. Callers should ensure that there is only a single
     * ForKeyAndFamily object in use at a time for a given computation, key, family tuple and that
     * access to it is synchronized or single-threaded.
     */
    public ForKeyAndFamily forFamily(String stateFamily) {
      return new ForKeyAndFamily(this, stateFamily);
    }

    private ForKey(WindmillComputationKey computationKey, long cacheToken, long workToken) {
      this.computationKey = computationKey;
      this.cacheToken = cacheToken;
      this.workToken = workToken;
    }

    /**
     * Advances this view to {@code workToken} if it is still valid for the given tokens.
     *
     * @return {@code false} (leaving this view unchanged) if the cache token differs or the work
     *     token is not strictly newer; {@code true} after recording the new work token otherwise
     */
    private boolean updateTokens(long cacheToken, long workToken) {
      if (this.cacheToken != cacheToken || workToken <= this.workToken) {
        return false;
      }
      this.workToken = workToken;
      return true;
    }
  }

  /**
   * Per-computation, per-key, per-family view of the state cache. Modifications are cached locally
   * and must be flushed to the cache by calling persist. This class is not thread-safe.
   */
  public class ForKeyAndFamily {
    final ForKey forKey;
    final String stateFamily;
    // Entries read or written through this view; written back to stateCache by persist().
    private final HashMap<StateId, StateCacheEntry> localCache;

    private ForKeyAndFamily(ForKey forKey, String stateFamily) {
      this.forKey = forKey;
      this.stateFamily = stateFamily;
      localCache = new HashMap<>();
    }

    public String getStateFamily() {
      return stateFamily;
    }

    /**
     * Returns the cached value for {@code address} in {@code namespace}, or {@code null} if none.
     *
     * <p>The local view is consulted first. On a local miss the shared cache is queried, and an
     * entry found there is adopted into the local view so later calls see the same object.
     */
    public <T extends State> @Nullable T get(StateNamespace namespace, StateTag<T> address) {
      StateId id = new StateId(forKey, stateFamily, namespace);
      // computeIfAbsent does not record a mapping when the function returns null, so a shared-cache
      // miss leaves the local view unchanged.
      @SuppressWarnings("nullness") // Unsure how to annotate lambda return allowing null.
      @Nullable
      StateCacheEntry entry = localCache.computeIfAbsent(id, key -> stateCache.getIfPresent(key));
      return entry == null ? null : entry.get(namespace, address);
    }

    /**
     * Records {@code value} with the given {@code weight} for {@code address} in {@code namespace}.
     *
     * <p>If the local view has no entry for the namespace, the shared cache's entry is adopted and
     * mutated in place; if the shared cache has none either, a new entry is created. The shared
     * cache only re-weighs the entry when it is written back by {@link #persist}.
     */
    public <T extends State> void put(
        StateNamespace namespace, StateTag<T> address, T value, long weight) {
      StateId id = new StateId(forKey, stateFamily, namespace);
      @Nullable StateCacheEntry entry = localCache.get(id);
      if (entry == null) {
        entry = stateCache.getIfPresent(id);
        if (entry == null) {
          entry = new StateCacheEntry();
        }
        boolean hadValue = localCache.putIfAbsent(id, entry) != null;
        Preconditions.checkState(!hadValue);
      }
      entry.put(namespace, address, value, weight);
    }

    /**
     * Writes every entry in the local view back to the shared cache. Re-inserting the entries lets
     * the cache recompute their weights after local modifications.
     */
    public void persist() {
      localCache.forEach((id, entry) -> stateCache.put(id, entry));
    }
  }

  /** Returns a per-computation view of the state cache. */
  public ForComputation forComputation(String computation) {
    return new ForComputation(computation);
  }

  /**
   * Struct identifying a cache entry that contains all data for a ForKey instance and namespace.
   *
   * <p>The ForKey is compared by identity, so ids built from a replaced ForKey never match ids built
   * from its successor. Only the namespace's cache key is used, so all namespaces sharing a cache
   * key are grouped into one entry.
   */
  private static class StateId implements Weighted {
    private final ForKey forKey;
    private final String stateFamily;
    private final Object namespaceKey;
    // Precomputed from the final fields above.
    private final int hashCode;

    public StateId(ForKey forKey, String stateFamily, StateNamespace namespace) {
      this.forKey = forKey;
      this.stateFamily = stateFamily;
      this.namespaceKey = namespace.getCacheKey();
      this.hashCode = Objects.hash(forKey, stateFamily, namespaceKey);
    }

    @Override
    public boolean equals(@Nullable Object other) {
      if (this == other) {
        return true;
      }
      if (!(other instanceof StateId)) {
        return false;
      }
      StateId otherId = (StateId) other;
      return hashCode == otherId.hashCode
          && forKey == otherId.forKey
          && stateFamily.equals(otherId.stateFamily)
          && namespaceKey.equals(otherId.namespaceKey);
    }

    @Override
    public int hashCode() {
      return hashCode;
    }

    /** Estimated size: key bytes plus state-family length plus a fixed per-id overhead. */
    @Override
    public long getWeight() {
      return forKey.computationKey.key().size() + stateFamily.length() + PER_STATE_ID_OVERHEAD;
    }
  }

  /**
   * Entry in the state cache that stores a map of values and tracks the combined weight of the
   * values stored in it.
   */
  private static class StateCacheEntry implements Weighted {
    private final HashMap<NamespacedTag<?>, WeightedValue<?>> values;
    // Sum of stored value weights plus HASH_MAP_ENTRY_OVERHEAD per mapping.
    private long weight;

    public StateCacheEntry() {
      this.values = new HashMap<>(INITIAL_HASH_MAP_CAPACITY);
      this.weight = 0;
    }

    /** Returns the value stored for {@code tag} in {@code namespace}, or {@code null} if none. */
    public <T extends State> @Nullable T get(StateNamespace namespace, StateTag<T> tag) {
      @SuppressWarnings("unchecked")
      @Nullable
      WeightedValue<T> weightedValue =
          (WeightedValue<T>) values.get(new NamespacedTag<>(namespace, tag));
      return weightedValue == null ? null : weightedValue.value;
    }

    /**
     * Stores {@code value} for {@code tag} in {@code namespace}, replacing any previous value.
     *
     * <p>A new mapping adds {@code HASH_MAP_ENTRY_OVERHEAD} to this entry's weight. A replaced
     * mapping first subtracts the previous value's weight. In both cases the new {@code weight} is
     * then added.
     */
    public <T extends State> void put(
        StateNamespace namespace, StateTag<T> tag, T value, long weight) {
      values.compute(
          new NamespacedTag<>(namespace, tag),
          (t, v) -> {
            @SuppressWarnings("unchecked")
            WeightedValue<T> weightedValue = (WeightedValue<T>) v;
            if (weightedValue == null) {
              weightedValue = new WeightedValue<>();
              this.weight += HASH_MAP_ENTRY_OVERHEAD;
            } else {
              this.weight -= weightedValue.weight;
            }
            this.weight += weight;
            weightedValue.value = value;
            weightedValue.weight = weight;
            return weightedValue;
          });
    }

    /** Returns the tracked value weight plus a fixed per-entry overhead. */
    @Override
    public long getWeight() {
      return weight + PER_CACHE_ENTRY_OVERHEAD;
    }

    // Even though we use the namespace at the higher cache level, we are only using the cacheKey.
    // That allows for grouped eviction of entries sharing a cacheKey but we require the full
    // namespace here to distinguish between grouped entries.
    private static class NamespacedTag<T extends State> {
      private final StateNamespace namespace;
      // Tags are compared using StateTags.ID_EQUIVALENCE rather than StateTag#equals.
      private final Equivalence.Wrapper<StateTag<T>> tag;

      NamespacedTag(StateNamespace namespace, StateTag<T> tag) {
        this.namespace = namespace;
        this.tag = StateTags.ID_EQUIVALENCE.wrap(tag);
      }

      @Override
      public boolean equals(@Nullable Object other) {
        if (other == this) {
          return true;
        }
        if (!(other instanceof NamespacedTag)) {
          return false;
        }
        NamespacedTag<?> that = (NamespacedTag<?>) other;
        return namespace.equals(that.namespace) && tag.equals(that.tag);
      }

      @Override
      public int hashCode() {
        return Objects.hash(namespace, tag);
      }
    }

    /** Mutable holder pairing a cached value with the weight it was stored with. */
    private static class WeightedValue<T> {
      public long weight;
      public @Nullable T value;
    }
  }

  /** Print summary statistics of the cache to the given {@link PrintWriter}. */
  @Override
  public void appendSummaryHtml(PrintWriter response) {
    response.println("Cache Stats: <br><table>");
    response.println(
        "<tr><th>Hit Ratio</th><th>Evictions</th><th>Entries</th>"
            + "<th>Entry Values</th><th>Max Entry Values</th>"
            + "<th>Id Weight</th><th>Entry Weight</th><th>Max Weight</th><th>Keys</th>"
            + "</tr><tr>");
    CacheStats cacheStats = stateCache.stats();
    EntryStats entryStats = calculateEntryStats();
    response.println("<td>" + cacheStats.hitRate() + "</td>");
    response.println("<td>" + cacheStats.evictionCount() + "</td>");
    // Separate the visited-entry count from the cache's own size so the two numbers don't run
    // together.
    response.println("<td>" + entryStats.entries + " (" + stateCache.size() + " inc. weak) </td>");
    response.println("<td>" + entryStats.entryValues + "</td>");
    response.println("<td>" + entryStats.maxEntryValues + "</td>");
    response.println("<td>" + entryStats.idWeight / MEGABYTES + "MB</td>");
    response.println("<td>" + entryStats.entryWeight / MEGABYTES + "MB</td>");
    response.println("<td>" + getMaxWeight() / MEGABYTES + "MB</td>");
    response.println("<td>" + keyIndex.size() + "</td>");
    response.println("</tr></table><br>");
  }

  /**
   * Returns a status servlet for the {@code /cachez} page that writes a heading followed by the
   * output of {@link #appendSummaryHtml}.
   */
  public BaseStatusServlet statusServlet() {
    return new BaseStatusServlet("/cachez") {
      @Override
      protected void doGet(HttpServletRequest request, HttpServletResponse response)
          throws IOException {
        PrintWriter writer = response.getWriter();
        writer.println("<h1>Cache Information</h1>");
        appendSummaryHtml(writer);
      }
    };
  }
}