package org.apache.helix;

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import org.apache.helix.PropertyKey.Builder;
import org.apache.helix.manager.zk.ZKHelixDataAccessor;
import org.apache.helix.manager.zk.ZkBaseDataAccessor;
import org.apache.helix.model.ExternalView;
import org.apache.helix.zookeeper.api.client.HelixZkClient;
import org.apache.helix.zookeeper.api.client.RealmAwareZkClient;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
import org.apache.helix.zookeeper.zkclient.IZkChildListener;
import org.apache.helix.zookeeper.zkclient.IZkDataListener;
import org.apache.helix.zookeeper.zkclient.IZkStateListener;
import org.apache.helix.zookeeper.zkclient.ZkClient;
import org.apache.helix.zookeeper.zkclient.ZkConnection;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.Watcher.Event.EventType;
import org.apache.zookeeper.Watcher.Event.KeeperState;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.ZooKeeper.States;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;


/**
 * Static helpers for tests that need to observe or manipulate ZooKeeper sessions, watches and
 * {@link ZkClient} internals. Several helpers read private fields of {@link ZkClient} and
 * {@link ZooKeeper} via reflection, so they depend on the field names of the library versions
 * on the test classpath.
 */
public class ZkTestHelper {
  private static final Logger LOG = LoggerFactory.getLogger(ZkTestHelper.class);

  /** Number of connection attempts {@link #getListenersByZkPath(String)} makes before giving up. */
  private static final int WCHP_MAX_ATTEMPTS = 5;

  /**
   * Executor used by {@link #injectExpire(RealmAwareZkClient)}. Its worker is a daemon thread so
   * that an idle worker cannot keep the test JVM from exiting.
   */
  private static final ExecutorService _executor =
      Executors.newSingleThreadExecutor(new ThreadFactory() {
        @Override
        public Thread newThread(Runnable runnable) {
          Thread thread = new Thread(runnable, "ZkTestHelper-executor");
          thread.setDaemon(true);
          return thread;
        }
      });

  /**
   * Simulates a disconnect followed by a reconnect by feeding a {@link KeeperState#Disconnected}
   * event and then a {@link KeeperState#SyncConnected} event directly into
   * {@link ZkClient#process(WatchedEvent)}. This method does not touch the underlying ZooKeeper
   * connection itself.
   *
   * @param client the client to notify; must be a {@link ZkClient}
   */
  public static void simulateZkStateReconnected(RealmAwareZkClient client) {
    ZkClient zkClient = (ZkClient) client;
    zkClient.process(new WatchedEvent(EventType.None, KeeperState.Disconnected, null));
    zkClient.process(new WatchedEvent(EventType.None, KeeperState.SyncConnected, null));
  }

  /**
   * Returns the id of the client's current ZooKeeper session.
   *
   * @param client the client to inspect; must be a {@link ZkClient} backed by a {@link ZkConnection}
   * @return the session id as lower-case hexadecimal, without a {@code 0x} prefix
   */
  public static String getSessionId(RealmAwareZkClient client) {
    ZkConnection connection = (ZkConnection) ((ZkClient) client).getConnection();
    return Long.toHexString(connection.getZookeeper().getSessionId());
  }

  /**
   * Ends the client's current ZooKeeper session by closing a second handle that shares the same
   * session id and password.
   *
   * Unlike {@link #expireSession(RealmAwareZkClient)}, this method does NOT wait for the client to
   * establish a new session. The state listener it registers is removed as soon as the duplicate
   * handle has been closed.
   *
   * @param client the client whose session should be ended; must be a {@link ZkClient}
   * @throws Exception if creating the duplicate handle fails or the wait is interrupted
   */
  public static void disconnectSession(HelixZkClient client) throws Exception {
    final ZkClient zkClient = (ZkClient) client;
    IZkStateListener listener = new IZkStateListener() {
      @Override
      public void handleStateChanged(KeeperState state) throws Exception {
      }

      @Override
      public void handleNewSession(final String sessionId) throws Exception {
        // make sure zkclient is connected again
        // NOTE(review): DEFAULT_CONNECTION_TIMEOUT is paired with TimeUnit.SECONDS here; confirm
        // the constant's unit (it may be milliseconds).
        zkClient.waitUntilConnected(HelixZkClient.DEFAULT_CONNECTION_TIMEOUT, TimeUnit.SECONDS);

        LOG.info("handleNewSession. sessionId: {}.", sessionId);
      }

      @Override
      public void handleSessionEstablishmentError(Throwable var1) throws Exception {
      }
    };

    zkClient.subscribeStateChanges(listener);
    try {
      ZkConnection connection = (ZkConnection) zkClient.getConnection();
      ZooKeeper curZookeeper = connection.getZookeeper();
      LOG.info("Before expiry. sessionId: " + Long.toHexString(curZookeeper.getSessionId()));

      connectDuplicateSession(connection, curZookeeper, newLoggingWatcher("Process watchEvent: "))
          .close();
    } finally {
      zkClient.unsubscribeStateChanges(listener);
    }

    ZooKeeper afterZookeeper = ((ZkConnection) zkClient.getConnection()).getZookeeper();
    LOG.info("After expiry. sessionId: " + Long.toHexString(afterZookeeper.getSessionId()));
  }

  /**
   * Expires the client's current ZooKeeper session and blocks until
   * {@link IZkStateListener#handleNewSession(String)} has been invoked. It then asserts that the
   * session id has changed.
   *
   * NOTE(review): the wait for the new session is unbounded, so this call hangs if the client never
   * re-establishes a session.
   *
   * @param client the client whose session should be expired; must be a {@link ZkClient}
   * @throws Exception if creating the duplicate handle fails or the wait is interrupted
   */
  public static void expireSession(RealmAwareZkClient client) throws Exception {
    final CountDownLatch waitNewSession = new CountDownLatch(1);
    final ZkClient zkClient = (ZkClient) client;

    IZkStateListener listener = new IZkStateListener() {
      @Override
      public void handleStateChanged(KeeperState state) throws Exception {
        LOG.info("IZkStateListener#handleStateChanged, state: " + state);
      }

      @Override
      public void handleNewSession(final String sessionId) throws Exception {
        // make sure zkclient is connected again
        // NOTE(review): DEFAULT_CONNECTION_TIMEOUT is paired with TimeUnit.SECONDS here; confirm
        // the constant's unit (it may be milliseconds).
        zkClient.waitUntilConnected(HelixZkClient.DEFAULT_CONNECTION_TIMEOUT, TimeUnit.SECONDS);

        LOG.info("handleNewSession. sessionId: {}.", sessionId);
        waitNewSession.countDown();
      }

      @Override
      public void handleSessionEstablishmentError(Throwable var1) throws Exception {
      }
    };

    zkClient.subscribeStateChanges(listener);
    String oldSessionId;
    try {
      ZkConnection connection = (ZkConnection) zkClient.getConnection();
      ZooKeeper curZookeeper = connection.getZookeeper();
      oldSessionId = Long.toHexString(curZookeeper.getSessionId());
      LOG.info("Before session expiry. sessionId: " + oldSessionId + ", zk: " + curZookeeper);

      ZooKeeper dupZookeeper = connectDuplicateSession(connection, curZookeeper,
          newLoggingWatcher("Watcher#process, event: "));
      try {
        Assert.assertEquals(dupZookeeper.getState(), States.CONNECTED,
            "Fail to connect to zk using current session info");
      } finally {
        // Close even if the assertion fails, so the duplicate handle never leaks.
        dupZookeeper.close();
      }

      // make sure session expiry really happens
      waitNewSession.await();
    } finally {
      zkClient.unsubscribeStateChanges(listener);
    }

    ZooKeeper newZookeeper = ((ZkConnection) zkClient.getConnection()).getZookeeper();
    String newSessionId = Long.toHexString(newZookeeper.getSessionId());
    LOG.info("After session expiry. sessionId: " + newSessionId + ", zk: " + newZookeeper);
    Assert.assertFalse(newSessionId.equals(oldSessionId),
        "Fail to expire current session, zk: " + newZookeeper);
  }

  /**
   * Ends the client's current ZooKeeper session without waiting for a new session to be
   * established. No state listener is registered.
   *
   * @param client the client whose session should be ended; must be a {@link ZkClient}
   * @throws Exception if creating the duplicate handle fails or the wait is interrupted
   */
  public static void asyncExpireSession(RealmAwareZkClient client) throws Exception {
    final ZkClient zkClient = (ZkClient) client;
    ZkConnection connection = (ZkConnection) zkClient.getConnection();
    ZooKeeper curZookeeper = connection.getZookeeper();
    LOG.info("Before expiry. sessionId: " + Long.toHexString(curZookeeper.getSessionId()));

    connectDuplicateSession(connection, curZookeeper, newLoggingWatcher("Process watchEvent: "))
        .close();

    ZooKeeper afterZookeeper = ((ZkConnection) zkClient.getConnection()).getZookeeper();
    LOG.info("After expiry. sessionId: " + Long.toHexString(afterZookeeper.getSessionId()));
  }

  /**
   * Creates a watcher that only logs each event it receives, prefixed by {@code prefix}.
   */
  private static Watcher newLoggingWatcher(final String prefix) {
    return new Watcher() {
      @Override
      public void process(WatchedEvent event) {
        LOG.info(prefix + event);
      }
    };
  }

  /**
   * Opens a second ZooKeeper handle that reuses {@code curZookeeper}'s session id and password,
   * then polls every 10 ms until that handle reports {@link States#CONNECTED}. Closing the returned
   * handle ends the shared session on the server, which is how the callers trigger expiry (the
   * technique described in the ZooKeeper FAQ).
   *
   * If the wait does not complete normally (for example, it is interrupted), the handle is closed
   * here. Otherwise the caller owns the returned handle and must close it.
   */
  private static ZooKeeper connectDuplicateSession(ZkConnection connection, ZooKeeper curZookeeper,
      Watcher watcher) throws Exception {
    ZooKeeper dupZookeeper =
        new ZooKeeper(connection.getServers(), curZookeeper.getSessionTimeout(), watcher,
            curZookeeper.getSessionId(), curZookeeper.getSessionPasswd());
    boolean connected = false;
    try {
      while (dupZookeeper.getState() != States.CONNECTED) {
        Thread.sleep(10);
      }
      connected = true;
      return dupZookeeper;
    } finally {
      if (!connected) {
        dupZookeeper.close();
      }
    }
  }

  /**
   * Compares a resource's external view against an expected state map.
   *
   * Keys of {@code expectStateMap} (partition names) and of its inner maps (instance names) are
   * treated as regular expressions and matched against the actual partition and instance names.
   * Each mismatch is logged.
   *
   * @param expectStateMap partition regex -> instance regex -> expected state
   * @param op "==" to require matching entries to equal the expected state, or "!=" to require them
   *          to differ
   * @return true if every matched entry satisfies {@code op}; false if any does not, or if the
   *         resource has no external view
   * @throws IllegalArgumentException if {@code op} is neither "==" nor "!="
   */
  public static boolean verifyState(RealmAwareZkClient zkclient, String clusterName,
      String resourceName, Map<String, Map<String, String>> expectStateMap, String op) {
    final boolean expectEqual;
    if ("==".equals(op)) {
      expectEqual = true;
    } else if ("!=".equals(op)) {
      expectEqual = false;
    } else {
      // Previously an unknown op silently made every comparison pass.
      throw new IllegalArgumentException("Unsupported op: " + op + ", expected \"==\" or \"!=\"");
    }

    ZkBaseDataAccessor<ZNRecord> baseAccessor = new ZkBaseDataAccessor<ZNRecord>(zkclient);
    ZKHelixDataAccessor accessor = new ZKHelixDataAccessor(clusterName, baseAccessor);
    Builder keyBuilder = accessor.keyBuilder();

    ExternalView extView = accessor.getProperty(keyBuilder.externalView(resourceName));
    if (extView == null) {
      LOG.warn("No external view for resource " + resourceName + " in cluster " + clusterName);
      return false;
    }

    boolean result = true;
    Map<String, Map<String, String>> actualStateMap = extView.getRecord().getMapFields();
    for (Map.Entry<String, Map<String, String>> actualPartition : actualStateMap.entrySet()) {
      String partition = actualPartition.getKey();
      Map<String, String> actualInstanceStateMap = actualPartition.getValue();
      for (Map.Entry<String, Map<String, String>> expectPartition : expectStateMap.entrySet()) {
        if (!partition.matches(expectPartition.getKey())) {
          continue;
        }
        Map<String, String> expectInstanceStateMap = expectPartition.getValue();
        for (Map.Entry<String, String> actualInstance : actualInstanceStateMap.entrySet()) {
          String instance = actualInstance.getKey();
          String actualState = actualInstance.getValue();
          for (Map.Entry<String, String> expectInstance : expectInstanceStateMap.entrySet()) {
            if (!instance.matches(expectInstance.getKey())) {
              continue;
            }
            String expectState = expectInstance.getValue();
            // Null-safe so that a null expected state does not throw.
            boolean equals = Objects.equals(expectState, actualState);
            if (equals != expectEqual) {
              LOG.warn(partition + "/" + instance + " state mismatch. actual state: " + actualState
                  + ", but expect: " + expectState + ", op: " + op);
              result = false;
            }
          }
        }
      }
    }
    return result;
  }

  /**
   * Returns the number of sessions that have watches on the given ZooKeeper path.
   *
   * @param zkAddr ZooKeeper server address in {@code host:port} form
   * @param path the znode path to look up
   * @return the number of distinct watching session ids; 0 if the path has none
   * @throws Exception if the address is malformed
   */
  public static int numberOfListeners(String zkAddr, String path) throws Exception {
    Set<String> listeners = getListenersByZkPath(zkAddr).get(path);
    return listeners == null ? 0 : listeners.size();
  }

  /**
   * Queries a ZooKeeper server with the {@code wchp} four-letter command. The result maps each
   * watched path to the set of session ids (as printed by the server) that watch it.
   *
   * Makes up to {@link #WCHP_MAX_ATTEMPTS} attempts, because reads sometimes fail with connection
   * resets in tests. Each attempt starts from an empty result. If every attempt fails, an empty map
   * is returned.
   *
   * @param zkAddr ZooKeeper server address in {@code host:port} form
   * @return a sorted map from path to the sorted set of session ids watching it
   * @throws IllegalArgumentException if {@code zkAddr} is not in {@code host:port} form
   * @throws Exception declared for backward compatibility
   */
  public static Map<String, Set<String>> getListenersByZkPath(String zkAddr) throws Exception {
    String[] splits = zkAddr.split(":");
    if (splits.length < 2) {
      throw new IllegalArgumentException("zkAddr must be host:port, got: " + zkAddr);
    }
    String host = splits[0];
    int port = Integer.parseInt(splits[1]);

    for (int attempt = 1; attempt <= WCHP_MAX_ATTEMPTS; attempt++) {
      try (Socket sock = new Socket(host, port)) {
        return readWatchesByPath(sock);
      } catch (IOException e) {
        // sometimes in test, we see connection-reset exceptions when in.readLine()
        // so retry; each attempt rebuilds the result from scratch
        LOG.warn("Failed to read wchp output from " + zkAddr + " (attempt " + attempt + " of "
            + WCHP_MAX_ATTEMPTS + ")", e);
      }
    }
    return new TreeMap<String, Set<String>>();
  }

  /**
   * Sends {@code wchp} over {@code sock} and parses the reply. Lines starting with "/" are paths;
   * lines starting with "0x" are session ids that belong to the most recent path. All other lines
   * are ignored.
   */
  private static Map<String, Set<String>> readWatchesByPath(Socket sock) throws IOException {
    PrintWriter out = new PrintWriter(sock.getOutputStream(), true);
    BufferedReader in = new BufferedReader(
        new InputStreamReader(sock.getInputStream(), StandardCharsets.UTF_8));

    out.println("wchp");

    Map<String, Set<String>> listenerMap = new TreeMap<String, Set<String>>();
    String lastPath = null;
    for (String line = in.readLine(); line != null; line = in.readLine()) {
      line = line.trim();
      if (line.startsWith("/")) {
        lastPath = line;
        if (!listenerMap.containsKey(lastPath)) {
          listenerMap.put(lastPath, new TreeSet<String>());
        }
      } else if (line.startsWith("0x")) {
        Set<String> sessions = lastPath == null ? null : listenerMap.get(lastPath);
        if (sessions != null) {
          sessions.add(line);
        } else {
          LOG.error("Not path associated with listener sessionId: " + line + ", lastPath: "
              + lastPath);
        }
      }
    }
    return listenerMap;
  }

  /**
   * Returns the inverse of {@link #getListenersByZkPath(String)}: a map from each session id to
   * the set of paths that session watches.
   *
   * @param zkAddr ZooKeeper server address in {@code host:port} form
   * @return a sorted map from session id to the sorted set of watched paths
   * @throws Exception if the address is malformed
   */
  public static Map<String, Set<String>> getListenersBySession(String zkAddr) throws Exception {
    Map<String, Set<String>> listenerMapBySession = new TreeMap<>();
    for (Map.Entry<String, Set<String>> pathEntry : getListenersByZkPath(zkAddr).entrySet()) {
      for (String sessionId : pathEntry.getValue()) {
        Set<String> paths = listenerMapBySession.get(sessionId);
        if (paths == null) {
          paths = new TreeSet<String>();
          listenerMapBySession.put(sessionId, paths);
        }
        paths.add(pathEntry.getKey());
      }
    }
    return listenerMapBySession;
  }

  /**
   * Finds a field declared on {@code clazz} or on one of its superclasses, searching from
   * {@code clazz} upward.
   *
   * @throws NoSuchFieldException if no class in the hierarchy declares {@code fieldName}
   */
  static java.lang.reflect.Field getField(Class<?> clazz, String fieldName)
      throws NoSuchFieldException {
    NoSuchFieldException lastFailure = null;
    for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
      try {
        return current.getDeclaredField(fieldName);
      } catch (NoSuchFieldException e) {
        lastFailure = e;
      }
    }
    throw lastFailure;
  }

  /**
   * Reads, via reflection into the client's {@link ZooKeeper} watch manager, the paths on which
   * watches are registered.
   *
   * @return a map with keys "dataWatches", "existWatches" and "childWatches", each mapping to a
   *         snapshot list of watched paths
   */
  public static Map<String, List<String>> getZkWatch(RealmAwareZkClient client) throws Exception {
    ZkClient zkClient = (ZkClient) client;
    ZooKeeper zk = ((ZkConnection) zkClient.getConnection()).getZookeeper();

    java.lang.reflect.Field field = getField(zk.getClass(), "watchManager");
    field.setAccessible(true);
    Object watchManager = field.get(zk);

    Map<String, List<String>> lists = new HashMap<String, List<String>>();
    lists.put("dataWatches", readWatchPaths(watchManager, "dataWatches"));
    lists.put("existWatches", readWatchPaths(watchManager, "existWatches"));
    lists.put("childWatches", readWatchPaths(watchManager, "childWatches"));
    return lists;
  }

  /** Copies the key set of the named {@code Map<String, Set<Watcher>>} field on the watch manager. */
  private static List<String> readWatchPaths(Object watchManager, String fieldName)
      throws Exception {
    java.lang.reflect.Field field = getField(watchManager.getClass(), fieldName);
    field.setAccessible(true);
    @SuppressWarnings("unchecked") // field type is fixed by the ZooKeeper client implementation
    Map<String, Set<Watcher>> watches = (Map<String, Set<Watcher>>) field.get(watchManager);
    // The ZooKeeper client appears to guard these maps by synchronizing on them; copying under the
    // same lock avoids ConcurrentModificationException (NOTE(review): confirm for the ZK version
    // in use).
    synchronized (watches) {
      return new ArrayList<>(watches.keySet());
    }
  }

  /**
   * Returns the client's internal {@code _dataListener} map, read via reflection. This is the live
   * internal object, not a copy.
   */
  public static Map<String, Set<IZkDataListener>> getZkDataListener(RealmAwareZkClient client)
      throws Exception {
    java.lang.reflect.Field field = getField(client.getClass(), "_dataListener");
    field.setAccessible(true);
    @SuppressWarnings("unchecked")
    Map<String, Set<IZkDataListener>> dataListener =
        (Map<String, Set<IZkDataListener>>) field.get(client);
    return dataListener;
  }

  /**
   * Returns the client's internal {@code _childListener} map, read via reflection. This is the live
   * internal object, not a copy.
   */
  public static Map<String, Set<IZkChildListener>> getZkChildListener(RealmAwareZkClient client)
      throws Exception {
    java.lang.reflect.Field field = getField(client.getClass(), "_childListener");
    field.setAccessible(true);
    @SuppressWarnings("unchecked")
    Map<String, Set<IZkChildListener>> childListener =
        (Map<String, Set<IZkChildListener>>) field.get(client);
    return childListener;
  }

  /**
   * Polls, for roughly 2 seconds (20 checks, 100 ms apart, plus a final check), until the client's
   * internal event queue ({@code _eventThread._events}, read via reflection) is empty.
   *
   * @return true if the queue was observed empty; false if it never drained or could not be read
   */
  public static boolean tryWaitZkEventsCleaned(RealmAwareZkClient zkclient) throws Exception {
    java.lang.reflect.Field eventThreadField = getField(zkclient.getClass(), "_eventThread");
    eventThreadField.setAccessible(true);
    Object eventThread = eventThreadField.get(zkclient);
    if (eventThread == null) {
      LOG.error("fail to get event-thread from zkclient. skip waiting");
      return false;
    }

    java.lang.reflect.Field eventsField = getField(eventThread.getClass(), "_events");
    eventsField.setAccessible(true);
    BlockingQueue<?> queue = (BlockingQueue<?>) eventsField.get(eventThread);
    if (queue == null) {
      LOG.error("fail to get event-queue from zkclient. skip waiting");
      return false;
    }

    for (int i = 0; i < 20; i++) {
      if (queue.isEmpty()) {
        return true;
      }
      Thread.sleep(100);
      LOG.info("pending zk-events in queue: " + queue);
    }
    // Re-check once more so that a queue that drained during the last sleep counts as cleaned.
    return queue.isEmpty();
  }

  /**
   * Delivers a synthetic {@link Watcher.Event.KeeperState#Expired} event to
   * {@link ZkClient#process(WatchedEvent)} on the helper's executor thread, and blocks until that
   * call returns.
   *
   * @throws ExecutionException if {@code process} throws
   * @throws InterruptedException if interrupted while waiting
   */
  public static void injectExpire(RealmAwareZkClient client)
      throws ExecutionException, InterruptedException {
    final ZkClient zkClient = (ZkClient) client;
    Future<?> future = _executor.submit(new Runnable() {
      @Override
      public void run() {
        WatchedEvent event =
            new WatchedEvent(Watcher.Event.EventType.None, Watcher.Event.KeeperState.Expired, null);
        zkClient.process(event);
      }
    });
    future.get();
  }
}
