blob: 51f2236c972ac570985acf9f3d00c3b121f65093 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nlpcraft.probe.mgrs.conversation
import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Predicate
import com.typesafe.scalalogging.LazyLogging
import io.opencensus.trace.Span
import org.apache.nlpcraft.common._
import org.apache.nlpcraft.common.ascii.NCAsciiTable
import org.apache.nlpcraft.common.opencensus.NCOpenCensusTrace
import org.apache.nlpcraft.model._
import scala.collection.JavaConverters._
import scala.collection.mutable
/**
* An active conversation is an ordered set of utterances for the specific user and data model.
*/
case class NCConversation(
usrId: Long,
mdlId: String,
timeoutMs: Long,
maxDepth: Int
) extends LazyLogging with NCOpenCensusTrace {
private final val data = new ConcurrentHashMap[String, Object]()
/**
*
* @param token
* @param tokenTypeUsageTime
*/
case class TokenHolder(token: NCToken, var tokenTypeUsageTime: Long = 0)
/**
*
* @param holders Tokens holders.
* @param srvReqId Server request ID. Used just for logging.
* @param tstamp Request timestamp. Used just for logging.
*/
case class ConversationItem(holders: mutable.ArrayBuffer[TokenHolder], srvReqId: String, tstamp: Long)
// Short-Term-Memory.
private val stm = mutable.ArrayBuffer.empty[ConversationItem]
private val lastToks = mutable.ArrayBuffer.empty[Iterable[NCToken]]
@volatile private var ctx: util.List[NCToken] = new util.ArrayList[NCToken]()
@volatile private var lastUpdateTstamp = U.nowUtcMs()
@volatile private var depth = 0
/**
*
*/
private def squeezeTokens(): Unit = {
require(Thread.holdsLock(stm))
stm --= stm.filter(_.holders.isEmpty)
}
/**
* Gets called on each input request for given user and model.
*
* @param parent Optional parent span.
*/
def updateTokens(parent: Span = null): Unit = startScopedSpan("updateTokens", parent) { _ ⇒
val now = U.nowUtcMs()
stm.synchronized {
depth += 1
// Conversation cleared by timeout or when there are too much unsuccessful requests.
if (now - lastUpdateTstamp > timeoutMs) {
stm.clear()
logger.trace(s"Conversation is reset by timeout [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId" +
s"]")
}
else if (depth > maxDepth) {
stm.clear()
logger.trace(s"Conversation is reset after reaching max depth [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId" +
s"]")
}
else {
val minUsageTime = now - timeoutMs
val toks = lastToks.flatten
for (item ← stm) {
val delHs =
// Deleted by timeout for tokens type or when token type used too many requests ago.
item.holders.filter(h ⇒ h.tokenTypeUsageTime < minUsageTime || !toks.contains(h.token))
if (delHs.nonEmpty) {
item.holders --= delHs
logger.trace(
s"Conversation overridden tokens removed [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId, " +
s"srvReqId=${item.srvReqId}, " +
s"toks=${delHs.map(_.token).mkString("[", ", ", "]")}" +
s"]"
)
}
}
squeezeTokens()
}
lastUpdateTstamp = now
ctx = new util.ArrayList[NCToken](stm.flatMap(_.holders.map(_.token)).asJava)
ack()
}
}
/**
* Clears all tokens from this conversation satisfying given predicate.
*
* @param p Java-side predicate.
* @param parent Optional parent span.
*/
def clearTokens(p: Predicate[NCToken], parent: Span = null): Unit =
startScopedSpan("clearTokens", parent) { _ ⇒
stm.synchronized {
for (item ← stm)
item.holders --= item.holders.filter(h ⇒ p.test(h.token))
squeezeTokens()
ctx = ctx.asScala.filter(tok ⇒ !p.test(tok)).asJava
}
logger.trace(s"Conversation is cleared [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId" +
s"]")
}
/**
* Clears all tokens from this conversation satisfying given predicate.
*
* @param p Scala-side predicate.
*/
def clearTokens(p: NCTokenBoolean): Unit =
clearTokens(new Predicate[NCToken] {
override def test(t: NCToken): Boolean = p(t)
})
/**
* Adds given tokens to the conversation.
*
* @param srvReqId Server request ID.
* @param toks Tokens to add to the conversation STM.
* @param parent Optional parent span.
*/
def addTokens(srvReqId: String, toks: Seq[NCToken], parent: Span = null): Unit =
startScopedSpan("addTokens", parent, "srvReqId" → srvReqId) { _ ⇒
stm.synchronized {
depth = 0
// Last used tokens processing.
lastToks += toks
val delCnt = lastToks.length - maxDepth
if (delCnt > 0)
lastToks.remove(0, delCnt)
val senToks = toks.
filter(_.getServerRequestId == srvReqId).
filter(t ⇒ !t.isFreeWord && !t.isStopWord)
if (senToks.nonEmpty) {
// Adds new conversation element.
stm += ConversationItem(
mutable.ArrayBuffer.empty[TokenHolder] ++ senToks.map(TokenHolder(_)),
srvReqId,
lastUpdateTstamp
)
logger.trace(
s"Added new tokens to the conversation [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId, " +
s"srvReqId=$srvReqId, " +
s"toks=${toks.mkString("[", ", ", "]")}" +
s"]"
)
val registered = mutable.HashSet.empty[Seq[String]]
for (
item ← stm.reverse;
(gs, hs) ← item.holders.groupBy(
t ⇒ if (t.token.getGroups != null) t.token.getGroups.asScala else Seq.empty
)
) {
val grps = gs.sorted
// Reversed iteration.
// N : (A, B) → registered.
// N-1 : (C) → registered.
// N-2 : (A, B) or (A, B, X) etc → deleted, because registered has less groups.
registered.find(grps.containsSlice) match {
case Some(_)
item.holders --= hs
logger.trace(
"Conversation tokens are overridden [" +
s"usrId=$usrId, " +
s"mdlId=$mdlId, " +
s"srvReqId=$srvReqId, " +
s"groups=${grps.mkString(", ")}, " +
s"toks=${hs.map(_.token).mkString("[", ", ", "]")}" +
s"]"
)
case None ⇒ registered += grps
}
}
// Updates tokens usage time.
stm.foreach(_.holders.filter(h ⇒ toks.contains(h.token)).foreach(_.tokenTypeUsageTime = lastUpdateTstamp))
squeezeTokens()
}
}
}
/**
* Prints out ASCII table for current STM.
*/
private def ack(): Unit = {
require(Thread.holdsLock(stm))
if (ctx.isEmpty)
logger.trace(s"Conversation context is empty for [" +
s"mdlId=$mdlId, " +
s"usrId=$usrId" +
s"]")
else {
val tbl = NCAsciiTable("Token ID", "Groups", "Text", "Value", "From request")
ctx.asScala.foreach(tok ⇒ tbl += (
tok.getId,
tok.getGroups.asScala.mkString(", "),
tok.normText,
tok.getValue,
tok.getServerRequestId
))
logger.trace(s"Conversation tokens [" +
s"mdlId=$mdlId, " +
s"usrId=$usrId" +
s"]:\n${tbl.toString()}")
}
}
/**
*
* @param parent Optional parent span.
* @return
*/
def getTokens(parent: Span = null): util.List[NCToken] = startScopedSpan("getTokens", parent) { _ ⇒
stm.synchronized {
val srvReqIds = ctx.asScala.map(_.getServerRequestId).distinct.zipWithIndex.toMap
val toks = ctx.asScala.groupBy(_.getServerRequestId).toSeq.sortBy(p ⇒ srvReqIds(p._1)).reverse.flatMap(_._2)
new util.ArrayList[NCToken](toks.asJava)
}
}
/**
*
*/
def getUserData: util.Map[String, Object] = data
}