[New Scheduler] Add container counter (#5072)

* [New Scheduler] Add container counter

Get container count from ETCD when related data get updated in ETCD

* Fix tests

* Fix tests
diff --git a/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala b/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala
new file mode 100644
index 0000000..7859a19
--- /dev/null
+++ b/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.scheduler.queue
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import org.apache.openwhisk.common.Logging
+import org.apache.openwhisk.core.etcd.EtcdClient
+import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys
+import org.apache.openwhisk.core.service.{DeleteEvent, PutEvent, UnwatchEndpoint, WatchEndpoint, WatchEndpointOperation}
+
+import scala.collection.concurrent.TrieMap
+import scala.concurrent.{ExecutionContext, Future}
+
+class ContainerCounter(invocationNamespace: String, etcdClient: EtcdClient, watcherService: ActorRef)(
+  implicit val actorSystem: ActorSystem,
+  ec: ExecutionContext,
+  logging: Logging) {
+  private[queue] var existingContainerNumByNamespace: Int = 0
+  private[queue] var inProgressContainerNumByNamespace: Int = 0
+  private[queue] val references = new AtomicInteger(0)
+  private val watcherName = s"container-counter-$invocationNamespace"
+
+  private val inProgressContainerPrefixKeyByNamespace =
+    ContainerKeys.inProgressContainerPrefixByNamespace(invocationNamespace)
+  private val existingContainerPrefixKeyByNamespace =
+    ContainerKeys.existingContainersPrefixByNamespace(invocationNamespace)
+
+  private val watchedKeys = Seq(inProgressContainerPrefixKeyByNamespace, existingContainerPrefixKeyByNamespace)
+
+  private val watcher =
+    actorSystem.actorOf(Props(new Actor {
+      private var countingKeys = Set.empty[String]
+      private var waitingForCountKeys = Set.empty[String]
+
+      override def receive: Receive = {
+        case operation: WatchEndpointOperation if operation.isPrefix =>
+          if (countingKeys
+                .contains(operation.watchKey))
+            waitingForCountKeys += operation.watchKey
+          else {
+            countingKeys += operation.watchKey
+            refreshContainerCount(operation.watchKey)
+          }
+
+        case ReadyToGetCount(key) =>
+          if (waitingForCountKeys.contains(key)) {
+            waitingForCountKeys -= key
+            refreshContainerCount(key)
+          } else
+            countingKeys -= key
+      }
+    }))
+
+  private def refreshContainerCount(key: String): Future[Unit] = {
+    etcdClient
+      .getCount(key)
+      .map { count =>
+        key match {
+          case `inProgressContainerPrefixKeyByNamespace` => inProgressContainerNumByNamespace = count.toInt
+          case `existingContainerPrefixKeyByNamespace`   => existingContainerNumByNamespace = count.toInt
+        }
+        watcher ! ReadyToGetCount(key)
+      }
+      .recover {
+        case t: Throwable =>
+          logging.error(
+            this,
+            s"failed to get the number of existing containers for ${invocationNamespace} due to ${t}.")
+          watcher ! ReadyToGetCount(key)
+      }
+  }
+
+  def increaseReference(): ContainerCounter = {
+    if (references.incrementAndGet() == 1) {
+      watchedKeys.foreach { key =>
+        watcherService.tell(WatchEndpoint(key, "", true, watcherName, Set(PutEvent, DeleteEvent)), watcher)
+      }
+
+    }
+    this
+  }
+
+  def close(): Unit = {
+    if (references.decrementAndGet() == 0) {
+      watchedKeys.foreach { key =>
+        watcherService ! UnwatchEndpoint(key, true, watcherName)
+      }
+      NamespaceContainerCount.instances.remove(invocationNamespace)
+    }
+  }
+}
+
+object NamespaceContainerCount {
+  private[queue] val instances = TrieMap[String, ContainerCounter]()
+  def apply(namespace: String, etcdClient: EtcdClient, watcherService: ActorRef)(implicit actorSystem: ActorSystem,
+                                                                                 ec: ExecutionContext,
+                                                                                 logging: Logging): ContainerCounter = {
+    instances
+      .getOrElseUpdate(namespace, new ContainerCounter(namespace, etcdClient, watcherService))
+      .increaseReference()
+  }
+}
+
+case class ReadyToGetCount(key: String)
diff --git a/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala
new file mode 100644
index 0000000..e9e6694
--- /dev/null
+++ b/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.scheduler.queue.test
+
+import java.{lang, util}
+import java.util.concurrent.Executor
+
+import akka.actor.ActorSystem
+import akka.testkit.{TestKit, TestProbe}
+import com.google.protobuf.ByteString
+import com.ibm.etcd.api.Event.EventType
+import com.ibm.etcd.api.{Event, KeyValue, LeaseKeepAliveResponse, ResponseHeader, TxnResponse}
+import com.ibm.etcd.client.kv.KvClient.Watch
+import com.ibm.etcd.client.kv.WatchUpdate
+import com.ibm.etcd.client.{EtcdClient => Client}
+import common.StreamLogging
+import org.apache.openwhisk.core.entity.{
+  CreationId,
+  DocRevision,
+  EntityName,
+  EntityPath,
+  FullyQualifiedEntityName,
+  SchedulerInstanceId
+}
+import org.apache.openwhisk.core.etcd.EtcdClient
+import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys
+import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys.inProgressContainer
+import org.apache.openwhisk.core.scheduler.queue.NamespaceContainerCount
+import org.apache.openwhisk.core.service.{DeleteEvent, PutEvent, UnwatchEndpoint, WatchEndpoint, WatcherService}
+import org.junit.runner.RunWith
+import org.scalamock.scalatest.MockFactory
+import org.scalatest.concurrent.ScalaFutures
+import org.scalatest.{FlatSpecLike, Matchers}
+import org.scalatest.junit.JUnitRunner
+
+import scala.concurrent.Future
+import scala.concurrent.duration.TimeUnit
+
+@RunWith(classOf[JUnitRunner])
+class ContainerCounterTests
+    extends TestKit(ActorSystem("ContainerCounter"))
+    with FlatSpecLike
+    with Matchers
+    with MockFactory
+    with ScalaFutures
+    with StreamLogging {
+
+  private implicit val ec = system.dispatcher
+
+  private val namespace = "testNamespace"
+  private val namespace2 = "testNamespace2"
+  private val action = "testAction"
+  private val action2 = "testAction2"
+  private val schedulerId = SchedulerInstanceId("0")
+  private val fqn = FullyQualifiedEntityName(EntityPath(namespace), EntityName(action))
+  private val revision = DocRevision("1-testRev1")
+  private val fqn2 = FullyQualifiedEntityName(EntityPath(namespace), EntityName(action2))
+  private val revision2 = DocRevision("1-testRev2")
+  private val fqn3 = FullyQualifiedEntityName(EntityPath(namespace2), EntityName(action2))
+  private val revision3 = DocRevision("1-testRev3")
+  private val watcherName = s"container-counter-$namespace"
+  private val inProgressContainerPrefixKeyByNamespace =
+    ContainerKeys.inProgressContainerPrefixByNamespace(namespace)
+  private val existingContainerPrefixKeyByNamespace =
+    ContainerKeys.existingContainersPrefixByNamespace(namespace)
+
+  val client: Client = {
+    val hostAndPorts = "172.17.0.1:2379"
+    Client.forEndpoints(hostAndPorts).withPlainText().build()
+  }
+
+  it should "be shared for a same namespace" in {
+    val etcd = mock[EtcdClient]
+    val watcher = TestProbe()
+    val res = Future.sequence {
+      (0 to 99).map { _ =>
+        Future {
+          NamespaceContainerCount(namespace, etcd, watcher.ref)
+        }
+      }
+    }.futureValue
+
+    // only create one instance
+    res.toSet.size shouldBe 1
+    res.head.references.intValue shouldBe 100
+
+    // only register watch endpoint once
+    watcher.expectMsgAllOf(
+      WatchEndpoint(inProgressContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)),
+      WatchEndpoint(existingContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+    NamespaceContainerCount.instances.clear()
+  }
+
+  it should "and only should be closed when all references are closed" in {
+    val etcd = mock[EtcdClient]
+    val watcher = TestProbe()
+    val res = Future.sequence {
+      (0 to 99).map { _ =>
+        Future {
+          NamespaceContainerCount(namespace, etcd, watcher.ref)
+        }
+      }
+    }.futureValue
+
+    // only create one instance
+    res.toSet.size shouldBe 1
+    res.head.references.intValue shouldBe 100
+
+    // only register watch endpoint once
+    watcher.expectMsgAllOf(
+      WatchEndpoint(inProgressContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)),
+      WatchEndpoint(existingContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+
+    // close 50 times
+    Future.sequence {
+      (0 to 49).map { _ =>
+        Future(res.head.close())
+      }
+    }.futureValue
+    res.head.references.intValue shouldBe 50
+
+    // should not unregister watch endpoint
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+
+    // close left 50 times
+    Future.sequence {
+      (0 to 49).map { _ =>
+        Future(res.head.close())
+      }
+    }.futureValue
+    res.head.references.intValue shouldBe 0
+
+    // only unregister watch endpoint once
+    watcher.expectMsgAllOf(
+      UnwatchEndpoint(inProgressContainerPrefixKeyByNamespace, true, watcherName),
+      UnwatchEndpoint(existingContainerPrefixKeyByNamespace, true, watcherName))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 0
+  }
+
+  it should "update the number of containers based on Watch event" in {
+    val mockEtcdClient = new MockEtcdClient(client, true)
+    val watcher = system.actorOf(WatcherService.props(mockEtcdClient))
+
+    val ns = NamespaceContainerCount(namespace, mockEtcdClient, watcher)
+    Thread.sleep(1000)
+
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 0
+
+    val invoker = "invoker0"
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace, fqn, revision, schedulerId, CreationId("testId")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace, fqn, DocRevision.empty)}/${invoker}/test-container",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 1
+    ns.existingContainerNumByNamespace shouldBe 1
+
+    // other action's containers under same namespace should have effect
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace, fqn2, revision2, schedulerId, CreationId("testId2")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace, fqn2, DocRevision.empty)}/${invoker}/test-container2",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 2
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // other namespace's containers should have no influence
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace2, fqn3, revision3, schedulerId, CreationId("testId3")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace2, fqn3, DocRevision.empty)}/${invoker}/test-container3",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 2
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // inProgress containers should have no effect on existing containers
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      inProgressContainer(namespace, fqn, revision, schedulerId, CreationId("testId")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      inProgressContainer(namespace, fqn2, revision2, schedulerId, CreationId("testId2")),
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // existing containers should have no effect on inProgress containers
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      s"${ContainerKeys.existingContainers(namespace, fqn, DocRevision.empty)}/${invoker}/test-container",
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      s"${ContainerKeys.existingContainers(namespace, fqn2, DocRevision.empty)}/${invoker}/test-container2",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 0
+
+    NamespaceContainerCount.instances.clear()
+  }
+
+  class MockEtcdClient(client: Client, isLeader: Boolean, leaseNotFound: Boolean = false, failedCount: Int = 1)
+      extends EtcdClient(client)(ec) {
+    var count = 0
+    var storedValues = List.empty[(String, String, Long, Long)]
+    var dataMap = Map[String, String]()
+
+    override def putTxn[T](key: String, value: T, cmpVersion: Long, leaseId: Long): Future[TxnResponse] = {
+      if (isLeader) {
+        storedValues = (key, value.toString, cmpVersion, leaseId) :: storedValues
+      }
+      Future.successful(TxnResponse.newBuilder().setSucceeded(isLeader).build())
+    }
+
+    /*
+     * this method count the number of entries whose key starts with the given prefix
+     */
+    override def getCount(prefixKey: String): Future[Long] = {
+      Future.successful { dataMap.count(data => data._1.startsWith(prefixKey)) }
+    }
+
+    var watchCallbackMap = Map[String, WatchUpdate => Unit]()
+
+    override def keepAliveOnce(leaseId: Long): Future[LeaseKeepAliveResponse] =
+      Future.successful(LeaseKeepAliveResponse.newBuilder().setID(leaseId).build())
+
+    /*
+     * this method adds one callback for the given key in watchCallbackMap.
+     *
+     * Note: Currently it only supports prefix-based watch.
+     */
+    override def watchAllKeys(next: WatchUpdate => Unit, error: Throwable => Unit, completed: () => Unit): Watch = {
+
+      watchCallbackMap += "" -> next
+      new Watch {
+        override def close(): Unit = {}
+
+        override def addListener(listener: Runnable, executor: Executor): Unit = {}
+
+        override def cancel(mayInterruptIfRunning: Boolean): Boolean = true
+
+        override def isCancelled: Boolean = true
+
+        override def isDone: Boolean = true
+
+        override def get(): lang.Boolean = true
+
+        override def get(timeout: Long, unit: TimeUnit): lang.Boolean = true
+      }
+    }
+
+    /*
+     * This method stores the data in dataMap to simulate etcd.put()
+     * After then, it calls the registered watch callback for the given key
+     * So we don't need to call put() to simulate watch API.
+     * Expected order of calls is 1. watch(), 2.publishEvents(). Data will be stored in dataMap and
+     * callbacks in the callbackMap for the given prefix will be called by publishEvents()
+     *
+     * Note: watch callback is currently registered based on prefix only.
+     */
+    def publishEvents(eventType: EventType, key: String, value: String): Unit = {
+      val eType = eventType match {
+        case EventType.PUT =>
+          dataMap += key -> value
+          EventType.PUT
+
+        case EventType.DELETE =>
+          dataMap -= key
+          EventType.DELETE
+
+        case EventType.UNRECOGNIZED => Event.EventType.UNRECOGNIZED
+      }
+      val event = Event
+        .newBuilder()
+        .setType(eType)
+        .setPrevKv(
+          KeyValue
+            .newBuilder()
+            .setKey(ByteString.copyFromUtf8(key))
+            .setValue(ByteString.copyFromUtf8(value))
+            .build())
+        .setKv(
+          KeyValue
+            .newBuilder()
+            .setKey(ByteString.copyFromUtf8(key))
+            .setValue(ByteString.copyFromUtf8(value))
+            .build())
+        .build()
+
+      // find the callbacks which has the proper prefix for the given key
+      watchCallbackMap.filter(callback => key.startsWith(callback._1)).foreach { callback =>
+        callback._2(new mockWatchUpdate().addEvents(event))
+      }
+    }
+  }
+
+  class mockWatchUpdate extends WatchUpdate {
+    private var eventLists: util.List[Event] = new util.ArrayList[Event]()
+    override def getHeader: ResponseHeader = ???
+
+    def addEvents(event: Event): WatchUpdate = {
+      eventLists.add(event)
+      this
+    }
+
+    override def getEvents: util.List[Event] = eventLists
+  }
+}