/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalCause;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Weigher;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap;

/**
 * Process-wide cache of per-key state.
 *
 * <p>This is backed by Guava {@link Cache} which is thread-safe. The entries are accessed often
 * from multiple threads. Logical consistency of each entry requires accessing each key (computation
 * * processing key * state_family * namespace) by a single thread at a time. {@link
 * StreamingDataflowWorker} ensures that a single computation * processing key is executing on one
 * thread at a time, so this is safe.
 */
public class WindmillStateCache implements StatusDataProvider {
  // Estimate of overhead per StateId.
  private static final int PER_STATE_ID_OVERHEAD = 20;
  // Initial size of hash tables per entry.
  private static final int INITIAL_HASH_MAP_CAPACITY = 4;
  // Overhead of each hash map entry.
  private static final int HASH_MAP_ENTRY_OVERHEAD = 16;
  // Overhead of each cache entry.  Three longs, plus a hash table.
  private static final int PER_CACHE_ENTRY_OVERHEAD =
      24 + HASH_MAP_ENTRY_OVERHEAD * INITIAL_HASH_MAP_CAPACITY;

  // The weighted cache of state entries. Assigned exactly once, in the constructor.
  private final Cache<StateId, StateCacheEntry> stateCache;
  // Index from (computation, processing key) to the StateIds stored for that key, so that all
  // entries of a key can be found and invalidated together. The removal listener and put update
  // it while holding this instance's monitor.
  private final HashMultimap<ComputationKey, StateId> keyIndex =
      HashMultimap.<ComputationKey, StateId>create();
  // Running total of the weights added by put and subtracted by the removal listener.
  private int displayedWeight = 0; // Only used for status pages and unit tests.

  /**
   * Creates the cache with a fixed maximum weight of 100 MB, statistics recording enabled, and a
   * removal listener that keeps {@code keyIndex} and {@code displayedWeight} up to date.
   */
  public WindmillStateCache() {
    final Weigher<Weighted, Weighted> weigher = Weighers.weightedKeysAndValues();

    stateCache =
        CacheBuilder.newBuilder()
            .maximumWeight(100000000 /* 100 MB */)
            .recordStats()
            .weigher(weigher)
            .removalListener(
                removal -> {
                  final RemovalCause cause = removal.getCause();
                  if (cause == RemovalCause.REPLACED) {
                    // Replacements are not tracked here.
                    return;
                  }
                  synchronized (this) {
                    final StateId removedId = (StateId) removal.getKey();
                    // When we invalidate a key explicitly, we'll also update the keyIndex, so
                    // the index only needs to be cleaned up here for the other causes.
                    if (cause != RemovalCause.EXPLICIT) {
                      keyIndex.remove(removedId.getComputationKey(), removedId);
                    }
                    displayedWeight -= weigher.weigh(removedId, removal.getValue());
                  }
                })
            .build();
  }

  /**
   * Returns the weight currently tracked in {@code displayedWeight}, the counter used for status
   * pages and unit tests.
   */
  public long getWeight() {
    final long trackedWeight = displayedWeight;
    return trackedWeight;
  }

  /** Per-computation view of the state cache. */
  public class ForComputation {
    private final String computation;

    private ForComputation(String computation) {
      this.computation = computation;
    }

    /**
     * Invalidate all cache entries for this computation and {@code processingKey}.
     *
     * @param processingKey the processing key whose cached state should be dropped
     */
    public void invalidate(ByteString processingKey) {
      ComputationKey key = new ComputationKey(computation, processingKey);
      // Lock the enclosing cache rather than this view: keyIndex is shared by every view and is
      // updated elsewhere under the WindmillStateCache monitor, so locking a per-view object
      // would not exclude those updates.
      synchronized (WindmillStateCache.this) {
        // Detach the ids from the index first and iterate over the returned collection, so that
        // removal notifications delivered while invalidating cannot modify the collection being
        // iterated.
        for (StateId id : keyIndex.removeAll(key)) {
          stateCache.invalidate(id);
        }
      }
    }

    /**
     * Returns a per-computation, per-key view of the state cache.
     *
     * @param key the processing key
     * @param stateFamily the state family the view reads and writes
     * @param cacheToken token that cached entries must match to be considered valid
     * @param workToken token of the work item using the view
     */
    public ForKey forKey(ByteString key, String stateFamily, long cacheToken, long workToken) {
      return new ForKey(computation, key, stateFamily, cacheToken, workToken);
    }
  }

  /**
   * View of the state cache bound to one computation, processing key, state family, cache token
   * and work token. Both operations delegate to the enclosing {@link WindmillStateCache}.
   */
  public class ForKey {
    private final String computation;
    private final ByteString key;
    private final String stateFamily;

    // Entries are only valid while the cache token stored with them equals this one.
    private final long cacheToken;

    // Work tokens increase for a key, so a token that is less than or equal to the one last
    // recorded on an entry belongs to stale processing. Such lookups bypass the cache so that
    // the fetches performed instead will fail with a no longer valid work token.
    private final long workToken;

    private ForKey(
        String computation, ByteString key, String stateFamily, long cacheToken, long workToken) {
      this.computation = computation;
      this.key = key;
      this.stateFamily = stateFamily;
      this.cacheToken = cacheToken;
      this.workToken = workToken;
    }

    /** Looks up the cached state for {@code address} in {@code namespace}, or returns null. */
    public <T extends State> T get(StateNamespace namespace, StateTag<T> address) {
      final WindmillStateCache cache = WindmillStateCache.this;
      return cache.get(computation, key, stateFamily, cacheToken, workToken, namespace, address);
    }

    /**
     * Stores {@code value} with the given {@code weight} for {@code address} in {@code
     * namespace}.
     *
     * <p>After a put with this view's work token, gets using that same work token no longer
     * return cached values. That is acceptable because entries are only put while building the
     * commit, after which no more gets are performed for the work token.
     */
    public <T extends State> void put(
        StateNamespace namespace, StateTag<T> address, T value, long weight) {
      final WindmillStateCache cache = WindmillStateCache.this;
      cache.put(
          computation, key, stateFamily, cacheToken, workToken, namespace, address, value, weight);
    }
  }

  /**
   * Creates a view of this cache scoped to a single computation.
   *
   * @param computation the computation the returned view operates on
   */
  public ForComputation forComputation(String computation) {
    final ForComputation view = new ForComputation(computation);
    return view;
  }

  /**
   * Looks up a cached state object.
   *
   * @return the cached value, or null when there is no entry, the entry was stored under a
   *     different cache token (it is then dropped), the work token is not newer than the one last
   *     recorded on the entry, or the entry holds no value for {@code address}
   */
  private <T extends State> T get(
      String computation,
      ByteString processingKey,
      String stateFamily,
      long cacheToken,
      long workToken,
      StateNamespace namespace,
      StateTag<T> address) {
    StateId id = new StateId(computation, processingKey, stateFamily, namespace);
    StateCacheEntry entry = stateCache.getIfPresent(id);
    if (entry == null) {
      return null;
    }
    if (entry.getCacheToken() != cacheToken) {
      // The removal listener leaves keyIndex alone for explicit invalidations, so drop the id
      // from the index here as well; otherwise the index keeps an id with no cache entry.
      synchronized (this) {
        stateCache.invalidate(id);
        keyIndex.remove(id.getComputationKey(), id);
      }
      return null;
    }
    if (workToken <= entry.getLastWorkToken()) {
      // We don't use the cached item but we don't invalidate it.
      return null;
    }
    return entry.get(namespace, address);
  }

  /**
   * Stores {@code value} for {@code address}, creating or replacing the cache entry as needed,
   * and records {@code workToken} as the entry's last work token.
   *
   * <p>An entry stored under a different cache token is discarded and replaced by a fresh one.
   */
  private <T extends State> void put(
      String computation,
      ByteString processingKey,
      String stateFamily,
      long cacheToken,
      long workToken,
      StateNamespace namespace,
      StateTag<T> address,
      T value,
      long weight) {
    StateId id = new StateId(computation, processingKey, stateFamily, namespace);
    StateCacheEntry entry = stateCache.getIfPresent(id);
    final boolean isNewId = entry == null;
    long weightDelta = 0;
    if (isNewId) {
      // The id is only accounted for once, when it first enters the cache.
      weightDelta += id.getWeight();
    } else if (entry.getCacheToken() != cacheToken) {
      // The stale entry is dropped by the stateCache.put below. The removal listener ignores
      // replacements, so its weight has to be deducted here or the counter would only grow.
      weightDelta -= entry.getWeight();
      entry = null;
    }
    if (entry == null) {
      entry = new StateCacheEntry(cacheToken);
      weightDelta += entry.getWeight();
    }
    entry.setLastWorkToken(workToken);
    weightDelta += entry.put(namespace, address, value, weight);
    // Update the index and the counter under the same monitor the removal listener uses, and
    // before the cache put so that an immediate eviction sees both already updated.
    synchronized (this) {
      if (isNewId) {
        keyIndex.put(id.getComputationKey(), id);
      }
      displayedWeight += (int) weightDelta;
    }
    // Always add back to the cache to update the weight.
    stateCache.put(id, entry);
  }

  /** Value type pairing a computation name with a processing key; used as the index key. */
  private static class ComputationKey {
    private final String computation;
    private final ByteString key;

    public ComputationKey(String computation, ByteString key) {
      this.computation = computation;
      this.key = key;
    }

    /** Returns the processing key part of this pair. */
    public ByteString getKey() {
      return key;
    }

    /** Two keys are equal when both the computation and the processing key are equal. */
    @Override
    public boolean equals(Object that) {
      if (!(that instanceof ComputationKey)) {
        return false;
      }
      final ComputationKey rhs = (ComputationKey) that;
      if (!computation.equals(rhs.computation)) {
        return false;
      }
      return key.equals(rhs.key);
    }

    @Override
    public int hashCode() {
      return Objects.hash(computation, key);
    }
  }

  /** Struct identifying a cache entry that contains all data for a key and namespace. */
  private static class StateId implements Weighted {
    private final ComputationKey computationKey;
    private final String stateFamily;
    // Obtained from StateNamespace#getCacheKey; identifies the namespace within the cache.
    private final Object namespaceKey;

    public StateId(
        String computation,
        ByteString processingKey,
        String stateFamily,
        StateNamespace namespace) {
      this.computationKey = new ComputationKey(computation, processingKey);
      this.stateFamily = stateFamily;
      this.namespaceKey = namespace.getCacheKey();
    }

    /** Ids are equal when computation key, state family and namespace key are all equal. */
    @Override
    public boolean equals(Object other) {
      if (other instanceof StateId) {
        StateId otherId = (StateId) other;
        return computationKey.equals(otherId.computationKey)
            && stateFamily.equals(otherId.stateFamily)
            && namespaceKey.equals(otherId.namespaceKey);
      }
      return false;
    }

    public ComputationKey getComputationKey() {
      return computationKey;
    }

    /**
     * Hashes every field that {@link #equals} compares, so that ids differing only in their
     * state family do not systematically collide.
     */
    @Override
    public int hashCode() {
      return Objects.hash(computationKey, stateFamily, namespaceKey);
    }

    /** Returns the size of the processing key plus a fixed per-id overhead estimate. */
    @Override
    public long getWeight() {
      return (long) computationKey.getKey().size() + PER_STATE_ID_OVERHEAD;
    }
  }

  /**
   * A single cache value: the state objects stored for one {@link StateId}, the cache token they
   * were stored under, and the most recent work token that wrote to the entry.
   */
  private static class StateCacheEntry implements Weighted {
    private final long cacheToken;
    private long lastWorkToken;
    private final Map<NamespacedTag<?>, WeightedValue<?>> values;
    // Sum of the weights of the stored values plus a per-value hash map overhead.
    private long weight;

    public StateCacheEntry(long cacheToken) {
      this.cacheToken = cacheToken;
      this.lastWorkToken = Long.MIN_VALUE;
      this.weight = 0;
      this.values = new HashMap<>(INITIAL_HASH_MAP_CAPACITY);
    }

    public long getCacheToken() {
      return cacheToken;
    }

    public long getLastWorkToken() {
      return lastWorkToken;
    }

    public void setLastWorkToken(long workToken) {
      this.lastWorkToken = workToken;
    }

    /** Returns the accumulated value weight plus the fixed per-entry overhead. */
    @Override
    public long getWeight() {
      return weight + PER_CACHE_ENTRY_OVERHEAD;
    }

    /** Returns the value stored for {@code tag} in {@code namespace}, or null if there is none. */
    @SuppressWarnings("unchecked")
    public <T extends State> T get(StateNamespace namespace, StateTag<T> tag) {
      WeightedValue<T> cached =
          (WeightedValue<T>) values.get(new NamespacedTag<>(namespace, tag));
      if (cached == null) {
        return null;
      }
      return cached.value;
    }

    /**
     * Stores {@code value} for {@code tag} in {@code namespace} and returns the resulting change
     * in this entry's weight.
     *
     * <p>A first-time value contributes its weight plus the hash map entry overhead; an
     * overwritten value contributes the difference between the new and the previous weight.
     */
    public <T extends State> long put(
        StateNamespace namespace, StateTag<T> tag, T value, long weight) {
      final NamespacedTag<T> mapKey = new NamespacedTag<>(namespace, tag);
      @SuppressWarnings("unchecked")
      WeightedValue<T> slot = (WeightedValue<T>) values.get(mapKey);
      final long delta;
      if (slot == null) {
        slot = new WeightedValue<>();
        delta = HASH_MAP_ENTRY_OVERHEAD + weight;
      } else {
        delta = weight - slot.weight;
      }
      slot.value = value;
      slot.weight = weight;
      this.weight += delta;
      values.put(mapKey, slot);
      return delta;
    }

    /** Map key combining a namespace with a state tag wrapped in an equivalence. */
    private static class NamespacedTag<T extends State> {
      private final StateNamespace namespace;
      // Tags are compared through StateTags.ID_EQUIVALENCE rather than their own equals.
      private final Equivalence.Wrapper<StateTag> tag;

      NamespacedTag(StateNamespace namespace, StateTag<T> tag) {
        this.namespace = namespace;
        this.tag = StateTags.ID_EQUIVALENCE.wrap((StateTag) tag);
      }

      @Override
      public boolean equals(Object other) {
        if (other instanceof NamespacedTag) {
          final NamespacedTag<?> rhs = (NamespacedTag<?>) other;
          return namespace.equals(rhs.namespace) && tag.equals(rhs.tag);
        }
        return false;
      }

      @Override
      public int hashCode() {
        return Objects.hash(namespace, tag);
      }
    }

    /** Mutable holder pairing a stored value with the weight it was put with. */
    private static class WeightedValue<T> {
      public long weight = 0;
      public T value = null;
    }
  }

  /**
   * Writes a one-row HTML table with the cache's hit rate, eviction count, entry count and
   * tracked weight to {@code response}.
   */
  @Override
  public void appendSummaryHtml(PrintWriter response) {
    final double hitRate = stateCache.stats().hitRate();
    final long evictions = stateCache.stats().evictionCount();
    final long entryCount = stateCache.size();
    final long trackedWeight = getWeight();

    response.println("Cache Stats: <br><table border=0>");
    response.println(
        "<tr><th>Hit Ratio</th><th>Evictions</th><th>Size</th><th>Weight</th></tr><tr>");
    response.println("<th>" + hitRate + "</th>");
    response.println("<th>" + evictions + "</th>");
    response.println("<th>" + entryCount + "</th>");
    response.println("<th>" + trackedWeight + "</th>");
    response.println("</tr></table><br>");
  }

  /**
   * Returns a servlet registered under {@code /cachez} whose GET handler writes a heading
   * followed by the cache summary table.
   */
  public BaseStatusServlet statusServlet() {
    final BaseStatusServlet servlet =
        new BaseStatusServlet("/cachez") {
          @Override
          protected void doGet(HttpServletRequest request, HttpServletResponse response)
              throws IOException, ServletException {
            final PrintWriter out = response.getWriter();
            out.println("<h1>Cache Information</h1>");
            appendSummaryHtml(out);
          }
        };
    return servlet;
  }
}
