/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.io.network;

import org.apache.commons.lang3.StringUtils;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.network.util.*;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.wake.Identifier;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.Codec;
import org.apache.reef.wake.remote.address.LocalAddressProvider;
import org.apache.reef.wake.remote.impl.ObjectSerializableCodec;
import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Default Network connection service test.
 */
public class NetworkConnectionServiceTest {
  private static final Logger LOG = Logger.getLogger(NetworkConnectionServiceTest.class.getName());

  private final LocalAddressProvider localAddressProvider;
  private final String localAddress;
  private final Identifier groupCommClientId;
  private final Identifier shuffleClientId;

  public NetworkConnectionServiceTest() throws InjectionException {
    localAddressProvider = Tang.Factory.getTang().newInjector().getInstance(LocalAddressProvider.class);
    localAddress = localAddressProvider.getLocalAddress();

    final IdentifierFactory idFac = new StringIdentifierFactory();
    this.groupCommClientId = idFac.getNewInstance("groupComm");
    this.shuffleClientId = idFac.getNewInstance("shuffle");
  }

  @Rule
  public TestName name = new TestName();

  private void runMessagingNetworkConnectionService(final Codec<String> codec) throws Exception {
    final int numMessages = 2000;
    final Monitor monitor = new Monitor();
    try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {
      messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages, monitor, codec);

      try (final Connection<String> conn = messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {
        try {
          conn.open();
          for (int count = 0; count < numMessages; ++count) {
            // send messages to the receiver.
            conn.write("hello" + count);
          }
          monitor.mwait();
        } catch (final NetworkException e) {
          e.printStackTrace();
          throw new RuntimeException(e);
        }
      }
    }
  }

  /**
   * NetworkConnectionService messaging test.
   */
  @Test
  public void testMessagingNetworkConnectionService() throws Exception {
    LOG.log(Level.FINEST, name.getMethodName());
    runMessagingNetworkConnectionService(new StringCodec());
  }

  /**
   * NetworkConnectionService streaming messaging test.
   */
  @Test
  public void testStreamingMessagingNetworkConnectionService() throws Exception {
    LOG.log(Level.FINEST, name.getMethodName());
    runMessagingNetworkConnectionService(new StreamingStringCodec());
  }

  public void runNetworkConnServiceWithMultipleConnFactories(final Codec<String> stringCodec,
                                                             final Codec<Integer> integerCodec)
      throws Exception {
    final ExecutorService executor = Executors.newFixedThreadPool(5);

    final int groupcommMessages = 1000;
    final Monitor monitor = new Monitor();
    try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {

      messagingTestService.registerTestConnectionFactory(groupCommClientId, groupcommMessages, monitor, stringCodec);

      final int shuffleMessages = 2000;
      final Monitor monitor2 = new Monitor();
      messagingTestService.registerTestConnectionFactory(shuffleClientId, shuffleMessages, monitor2, integerCodec);

      executor.submit(new Runnable() {
        @Override
        public void run() {
          try (final Connection<String> conn =
                   messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {
            conn.open();
            for (int count = 0; count < groupcommMessages; ++count) {
              // send messages to the receiver.
              conn.write("hello" + count);
            }
            monitor.mwait();
          } catch (final Exception e) {
            e.printStackTrace();
            throw new RuntimeException(e);
          }
        }
      });

      executor.submit(new Runnable() {
        @Override
        public void run() {
          try (final Connection<Integer> conn =
                   messagingTestService.getConnectionFromSenderToReceiver(shuffleClientId)) {
            conn.open();
            for (int count = 0; count < shuffleMessages; ++count) {
              // send messages to the receiver.
              conn.write(count);
            }
            monitor2.mwait();
          } catch (final Exception e) {
            e.printStackTrace();
            throw new RuntimeException(e);
          }
        }
      });

      monitor.mwait();
      monitor2.mwait();
      executor.shutdown();
    }
  }

  /**
   * Test NetworkService registering multiple connection factories.
   */
  @Test
  public void testMultipleConnectionFactoriesTest() throws Exception {
    LOG.log(Level.FINEST, name.getMethodName());
    runNetworkConnServiceWithMultipleConnFactories(new StringCodec(), new ObjectSerializableCodec<Integer>());
  }

  /**
   * Test NetworkService registering multiple connection factories with Streamingcodec.
   */
  @Test
  public void testMultipleConnectionFactoriesStreamingTest() throws Exception {
    LOG.log(Level.FINEST, name.getMethodName());
    runNetworkConnServiceWithMultipleConnFactories(new StreamingStringCodec(), new StreamingIntegerCodec());
  }

  /**
   * NetworkService messaging rate benchmark.
   */
  @Test
  public void testMessagingNetworkConnServiceRate() throws Exception {

    Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

    LOG.log(Level.FINEST, name.getMethodName());

    final int[] messageSizes = {1, 16, 32, 64, 512, 64 * 1024, 1024 * 1024};

    for (final int size : messageSizes) {
      final String message = StringUtils.repeat('1', size);
      final int numMessages = 300000 / (Math.max(1, size / 512));
      final Monitor monitor = new Monitor();
      final Codec<String> codec = new StringCodec();
      try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {
        messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages, monitor, codec);

        try (final Connection<String> conn =
                 messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {

          final long start = System.currentTimeMillis();
          try {
            conn.open();
            for (int count = 0; count < numMessages; ++count) {
              // send messages to the receiver.
              conn.write(message);
            }
            monitor.mwait();
          } catch (final NetworkException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
          }
          final long end = System.currentTimeMillis();

          final double runtime = ((double) end - start) / 1000;
          LOG.log(Level.INFO, "size: " + size + "; messages/s: " + numMessages / runtime +
              " bandwidth(bytes/s): " + ((double) numMessages * 2 * size) / runtime); // x2 for unicode chars
        }
      }
    }
  }

  /**
   * NetworkService messaging rate benchmark.
   */
  @Test
  public void testMessagingNetworkConnServiceRateDisjoint() throws Exception {

    Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

    LOG.log(Level.FINEST, name.getMethodName());

    final BlockingQueue<Object> barrier = new LinkedBlockingQueue<>();

    final int numThreads = 4;
    final int size = 2000;
    final int numMessages = 300000 / (Math.max(1, size / 512));
    final int totalNumMessages = numMessages * numThreads;
    final String message = StringUtils.repeat('1', size);

    final ExecutorService e = Executors.newCachedThreadPool();
    for (int t = 0; t < numThreads; t++) {
      final int tt = t;

      e.submit(new Runnable() {
        public void run() {
          try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {
            final Monitor monitor = new Monitor();
            final Codec<String> codec = new StringCodec();

            messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages, monitor, codec);
            try (final Connection<String> conn =
                     messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {
              try {
                conn.open();
                for (int count = 0; count < numMessages; ++count) {
                  // send messages to the receiver.
                  conn.write(message);
                }
                monitor.mwait();
              } catch (final NetworkException e) {
                e.printStackTrace();
                throw new RuntimeException(e);
              }
            }
          } catch (final Exception e) {
            throw new RuntimeException(e);
          }
        }
      });
    }

    // start and time
    final long start = System.currentTimeMillis();
    final Object ignore = new Object();
    for (int i = 0; i < numThreads; i++) {
      barrier.add(ignore);
    }
    e.shutdown();
    e.awaitTermination(100, TimeUnit.SECONDS);
    final long end = System.currentTimeMillis();
    final double runtime = ((double) end - start) / 1000;
    LOG.log(Level.INFO, "size: " + size + "; messages/s: " + totalNumMessages / runtime +
        " bandwidth(bytes/s): " + ((double) totalNumMessages * 2 * size) / runtime); // x2 for unicode chars
  }

  @Test
  public void testMultithreadedSharedConnMessagingNetworkConnServiceRate() throws Exception {

    Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

    LOG.log(Level.FINEST, name.getMethodName());

    final int[] messageSizes = {2000}; // {1,16,32,64,512,64*1024,1024*1024};

    for (final int size : messageSizes) {
      final String message = StringUtils.repeat('1', size);
      final int numMessages = 300000 / (Math.max(1, size / 512));
      final int numThreads = 2;
      final int totalNumMessages = numMessages * numThreads;
      final Monitor monitor = new Monitor();
      final Codec<String> codec = new StringCodec();
      try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {
        messagingTestService.registerTestConnectionFactory(groupCommClientId, totalNumMessages, monitor, codec);

        final ExecutorService e = Executors.newCachedThreadPool();

        try (final Connection<String> conn =
                 messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {

          final long start = System.currentTimeMillis();
          for (int i = 0; i < numThreads; i++) {
            e.submit(new Runnable() {
              @Override
              public void run() {

                try {
                  conn.open();
                  for (int count = 0; count < numMessages; ++count) {
                    // send messages to the receiver.
                    conn.write(message);
                  }
                } catch (final Exception e) {
                  throw new RuntimeException(e);
                }
              }
            });
          }

          e.shutdown();
          e.awaitTermination(30, TimeUnit.SECONDS);
          monitor.mwait();
          final long end = System.currentTimeMillis();
          final double runtime = ((double) end - start) / 1000;
          LOG.log(Level.INFO, "size: " + size + "; messages/s: " + totalNumMessages / runtime + 
              " bandwidth(bytes/s): " + ((double) totalNumMessages * 2 * size) / runtime); // x2 for unicode chars
        }
      }
    }
  }

  /**
   * NetworkService messaging rate benchmark.
   */
  @Test
  public void testMessagingNetworkConnServiceBatchingRate() throws Exception {

    Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

    LOG.log(Level.FINEST, name.getMethodName());

    final int batchSize = 1024 * 1024;
    final int[] messageSizes = {32, 64, 512};

    for (final int size : messageSizes) {
      final String message = StringUtils.repeat('1', batchSize);
      final int numMessages = 300 / (Math.max(1, size / 512));
      final Monitor monitor = new Monitor();
      final Codec<String> codec = new StringCodec();
      try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(localAddress)) {
        messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages, monitor, codec);
        try (final Connection<String> conn =
                 messagingTestService.getConnectionFromSenderToReceiver(groupCommClientId)) {
          final long start = System.currentTimeMillis();
          try {
            conn.open();
            for (int i = 0; i < numMessages; i++) {
              conn.write(message);
            }
            monitor.mwait();
          } catch (final NetworkException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
          }

          final long end = System.currentTimeMillis();
          final double runtime = ((double) end - start) / 1000;
          final long numAppMessages = numMessages * batchSize / size;
          LOG.log(Level.INFO, "size: " + size + "; messages/s: " + numAppMessages / runtime +
              " bandwidth(bytes/s): " + ((double) numAppMessages * 2 * size) / runtime); // x2 for unicode chars
        }
      }
    }
  }
}
