| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.dataflow.worker; |
| |
| import java.io.IOException; |
| import java.io.PrintWriter; |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.util.Objects; |
| import javax.servlet.ServletException; |
| import javax.servlet.http.HttpServletRequest; |
| import javax.servlet.http.HttpServletResponse; |
| import org.apache.beam.runners.core.StateNamespace; |
| import org.apache.beam.runners.core.StateTag; |
| import org.apache.beam.runners.core.StateTags; |
| import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet; |
| import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; |
| import org.apache.beam.sdk.state.State; |
| import org.apache.beam.sdk.util.Weighted; |
| import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalCause; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Weigher; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap; |
| |
| /** |
| * Process-wide cache of per-key state. |
| * |
| * <p>This is backed by Guava {@link Cache} which is thread-safe. The entries are accessed often |
| * from multiple threads. Logical consistency of each entry requires accessing each key (computation |
| * * processing key * state_family * namespace) by a single thread at a time. {@link |
| * StreamingDataflowWorker} ensures that a single computation * processing key is executing on one |
| * thread at a time, so this is safe. |
| */ |
| public class WindmillStateCache implements StatusDataProvider { |
| // Estimate of overhead per StateId. |
| private static final int PER_STATE_ID_OVERHEAD = 20; |
| // Initial size of hash tables per entry. |
| private static final int INITIAL_HASH_MAP_CAPACITY = 4; |
| // Overhead of each hash map entry. |
| private static final int HASH_MAP_ENTRY_OVERHEAD = 16; |
| // Overhead of each cache entry. Three longs, plus a hash table. |
| private static final int PER_CACHE_ENTRY_OVERHEAD = |
| 24 + HASH_MAP_ENTRY_OVERHEAD * INITIAL_HASH_MAP_CAPACITY; |
| |
| private Cache<StateId, StateCacheEntry> stateCache; |
| private HashMultimap<ComputationKey, StateId> keyIndex = |
| HashMultimap.<ComputationKey, StateId>create(); |
| private int displayedWeight = 0; // Only used for status pages and unit tests. |
| |
| public WindmillStateCache() { |
| final Weigher<Weighted, Weighted> weigher = Weighers.weightedKeysAndValues(); |
| |
| stateCache = |
| CacheBuilder.newBuilder() |
| .maximumWeight(100000000 /* 100 MB */) |
| .recordStats() |
| .weigher(weigher) |
| .removalListener( |
| removal -> { |
| if (removal.getCause() != RemovalCause.REPLACED) { |
| synchronized (this) { |
| StateId id = (StateId) removal.getKey(); |
| if (removal.getCause() != RemovalCause.EXPLICIT) { |
| // When we invalidate a key explicitly, we'll also update the keyIndex, so |
| // no need to do it here. |
| keyIndex.remove(id.getComputationKey(), id); |
| } |
| displayedWeight -= weigher.weigh(id, removal.getValue()); |
| } |
| } |
| }) |
| .build(); |
| } |
| |
| public long getWeight() { |
| return displayedWeight; |
| } |
| |
| /** Per-computation view of the state cache. */ |
| public class ForComputation { |
| private final String computation; |
| |
| private ForComputation(String computation) { |
| this.computation = computation; |
| } |
| |
| /** Invalidate all cache entries for this computation and {@code processingKey}. */ |
| public void invalidate(ByteString processingKey) { |
| synchronized (this) { |
| ComputationKey key = new ComputationKey(computation, processingKey); |
| for (StateId id : keyIndex.get(key)) { |
| stateCache.invalidate(id); |
| } |
| keyIndex.removeAll(key); |
| } |
| } |
| |
| /** Returns a per-computation, per-key view of the state cache. */ |
| public ForKey forKey(ByteString key, String stateFamily, long cacheToken, long workToken) { |
| return new ForKey(computation, key, stateFamily, cacheToken, workToken); |
| } |
| } |
| |
| /** Per-computation, per-key view of the state cache. */ |
| public class ForKey { |
| private final String computation; |
| private final ByteString key; |
| private final String stateFamily; |
| // Cache token must be consistent for the key for the cache to be valid. |
| private final long cacheToken; |
| |
| // The work token for processing must be greater than the last work token. As work items are |
| // increasing for a key, a less-than or equal to work token indicates that the current token is |
| // for stale processing. We don't use the cache so that fetches performed will fail with a no |
| // longer valid work token. |
| private final long workToken; |
| |
| private ForKey( |
| String computation, ByteString key, String stateFamily, long cacheToken, long workToken) { |
| this.computation = computation; |
| this.key = key; |
| this.stateFamily = stateFamily; |
| this.cacheToken = cacheToken; |
| this.workToken = workToken; |
| } |
| |
| public <T extends State> T get(StateNamespace namespace, StateTag<T> address) { |
| return WindmillStateCache.this.get( |
| computation, key, stateFamily, cacheToken, workToken, namespace, address); |
| } |
| |
| // Note that once a value has been put for a given workToken, get calls with that same workToken |
| // will fail. This is ok as we only put entries when we are building the commit and will no |
| // longer be performing gets for the same work token. |
| public <T extends State> void put( |
| StateNamespace namespace, StateTag<T> address, T value, long weight) { |
| WindmillStateCache.this.put( |
| computation, key, stateFamily, cacheToken, workToken, namespace, address, value, weight); |
| } |
| } |
| |
| /** Returns a per-computation view of the state cache. */ |
| public ForComputation forComputation(String computation) { |
| return new ForComputation(computation); |
| } |
| |
| private <T extends State> T get( |
| String computation, |
| ByteString processingKey, |
| String stateFamily, |
| long cacheToken, |
| long workToken, |
| StateNamespace namespace, |
| StateTag<T> address) { |
| StateId id = new StateId(computation, processingKey, stateFamily, namespace); |
| StateCacheEntry entry = stateCache.getIfPresent(id); |
| if (entry == null) { |
| return null; |
| } |
| if (entry.getCacheToken() != cacheToken) { |
| stateCache.invalidate(id); |
| return null; |
| } |
| if (workToken <= entry.getLastWorkToken()) { |
| // We don't used the cached item but we don't invalidate it. |
| return null; |
| } |
| return entry.get(namespace, address); |
| } |
| |
| private <T extends State> void put( |
| String computation, |
| ByteString processingKey, |
| String stateFamily, |
| long cacheToken, |
| long workToken, |
| StateNamespace namespace, |
| StateTag<T> address, |
| T value, |
| long weight) { |
| StateId id = new StateId(computation, processingKey, stateFamily, namespace); |
| StateCacheEntry entry = stateCache.getIfPresent(id); |
| if (entry == null) { |
| synchronized (this) { |
| keyIndex.put(id.getComputationKey(), id); |
| } |
| } |
| if (entry == null || entry.getCacheToken() != cacheToken) { |
| entry = new StateCacheEntry(cacheToken); |
| this.displayedWeight += (int) id.getWeight(); |
| this.displayedWeight += (int) entry.getWeight(); |
| } |
| entry.setLastWorkToken(workToken); |
| this.displayedWeight += (int) entry.put(namespace, address, value, weight); |
| // Always add back to the cache to update the weight. |
| stateCache.put(id, entry); |
| } |
| |
| private static class ComputationKey { |
| private final String computation; |
| private final ByteString key; |
| |
| public ComputationKey(String computation, ByteString key) { |
| this.computation = computation; |
| this.key = key; |
| } |
| |
| public ByteString getKey() { |
| return key; |
| } |
| |
| @Override |
| public boolean equals(Object that) { |
| if (that instanceof ComputationKey) { |
| ComputationKey other = (ComputationKey) that; |
| return computation.equals(other.computation) && key.equals(other.key); |
| } |
| return false; |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(computation, key); |
| } |
| } |
| |
| /** Struct identifying a cache entry that contains all data for a key and namespace. */ |
| private static class StateId implements Weighted { |
| private final ComputationKey computationKey; |
| private final String stateFamily; |
| private final Object namespaceKey; |
| |
| public StateId( |
| String computation, |
| ByteString processingKey, |
| String stateFamily, |
| StateNamespace namespace) { |
| this.computationKey = new ComputationKey(computation, processingKey); |
| this.stateFamily = stateFamily; |
| this.namespaceKey = namespace.getCacheKey(); |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| if (other instanceof StateId) { |
| StateId otherId = (StateId) other; |
| return computationKey.equals(otherId.computationKey) |
| && stateFamily.equals(otherId.stateFamily) |
| && namespaceKey.equals(otherId.namespaceKey); |
| } |
| return false; |
| } |
| |
| public ComputationKey getComputationKey() { |
| return computationKey; |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(computationKey, namespaceKey); |
| } |
| |
| @Override |
| public long getWeight() { |
| return (long) computationKey.getKey().size() + PER_STATE_ID_OVERHEAD; |
| } |
| } |
| |
| /** |
| * Entry in the state cache that stores a map of values, a cache token representing the validity |
| * of the values, and a work token that is increasing to ensure sequential processing. |
| */ |
| private static class StateCacheEntry implements Weighted { |
| private final long cacheToken; |
| private long lastWorkToken; |
| private final Map<NamespacedTag<?>, WeightedValue<?>> values; |
| private long weight; |
| |
| public StateCacheEntry(long cacheToken) { |
| this.values = new HashMap<>(INITIAL_HASH_MAP_CAPACITY); |
| this.cacheToken = cacheToken; |
| this.lastWorkToken = Long.MIN_VALUE; |
| this.weight = 0; |
| } |
| |
| public void setLastWorkToken(long workToken) { |
| this.lastWorkToken = workToken; |
| } |
| |
| @SuppressWarnings("unchecked") |
| public <T extends State> T get(StateNamespace namespace, StateTag<T> tag) { |
| WeightedValue<T> weightedValue = |
| (WeightedValue<T>) values.get(new NamespacedTag<>(namespace, tag)); |
| return weightedValue == null ? null : weightedValue.value; |
| } |
| |
| public <T extends State> long put( |
| StateNamespace namespace, StateTag<T> tag, T value, long weight) { |
| @SuppressWarnings("unchecked") |
| WeightedValue<T> weightedValue = |
| (WeightedValue<T>) values.get(new NamespacedTag<>(namespace, tag)); |
| long weightDelta = 0; |
| if (weightedValue == null) { |
| weightedValue = new WeightedValue<>(); |
| weightDelta += HASH_MAP_ENTRY_OVERHEAD; |
| } else { |
| weightDelta -= weightedValue.weight; |
| } |
| weightedValue.value = value; |
| weightedValue.weight = weight; |
| weightDelta += weight; |
| this.weight += weightDelta; |
| values.put(new NamespacedTag<>(namespace, tag), weightedValue); |
| return weightDelta; |
| } |
| |
| @Override |
| public long getWeight() { |
| return weight + PER_CACHE_ENTRY_OVERHEAD; |
| } |
| |
| public long getCacheToken() { |
| return cacheToken; |
| } |
| |
| public long getLastWorkToken() { |
| return lastWorkToken; |
| } |
| |
| private static class NamespacedTag<T extends State> { |
| private final StateNamespace namespace; |
| private final Equivalence.Wrapper<StateTag> tag; |
| |
| NamespacedTag(StateNamespace namespace, StateTag<T> tag) { |
| this.namespace = namespace; |
| this.tag = StateTags.ID_EQUIVALENCE.wrap((StateTag) tag); |
| } |
| |
| @Override |
| public boolean equals(Object other) { |
| if (!(other instanceof NamespacedTag)) { |
| return false; |
| } |
| NamespacedTag<?> that = (NamespacedTag<?>) other; |
| return namespace.equals(that.namespace) && tag.equals(that.tag); |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(namespace, tag); |
| } |
| } |
| |
| private static class WeightedValue<T> { |
| public long weight = 0; |
| public T value = null; |
| } |
| } |
| |
| /** Print summary statistics of the cache to the given {@link PrintWriter}. */ |
| @Override |
| public void appendSummaryHtml(PrintWriter response) { |
| response.println("Cache Stats: <br><table border=0>"); |
| response.println( |
| "<tr><th>Hit Ratio</th><th>Evictions</th><th>Size</th><th>Weight</th></tr><tr>"); |
| response.println("<th>" + stateCache.stats().hitRate() + "</th>"); |
| response.println("<th>" + stateCache.stats().evictionCount() + "</th>"); |
| response.println("<th>" + stateCache.size() + "</th>"); |
| response.println("<th>" + getWeight() + "</th>"); |
| response.println("</tr></table><br>"); |
| } |
| |
| public BaseStatusServlet statusServlet() { |
| return new BaseStatusServlet("/cachez") { |
| @Override |
| protected void doGet(HttpServletRequest request, HttpServletResponse response) |
| throws IOException, ServletException { |
| |
| PrintWriter writer = response.getWriter(); |
| writer.println("<h1>Cache Information</h1>"); |
| appendSummaryHtml(writer); |
| } |
| }; |
| } |
| } |