blob: 75047c7804ad636a4fe6aa7bf8172fc64a953381 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nlpcraft.probe.mgrs.cmd
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import io.opencensus.trace.Span
import org.apache.nlpcraft.common.nlp.NCNlpSentence
import org.apache.nlpcraft.common.{NCService, _}
import org.apache.nlpcraft.model.{NCCustomParser, NCElement, NCModelView, NCToken, NCValue, NCValueLoader}
import org.apache.nlpcraft.probe.mgrs.NCProbeMessage
import org.apache.nlpcraft.probe.mgrs.conn.NCConnectionManager
import org.apache.nlpcraft.probe.mgrs.conversation.NCConversationManager
import org.apache.nlpcraft.probe.mgrs.dialogflow.NCDialogFlowManager
import org.apache.nlpcraft.probe.mgrs.model.NCModelManager
import org.apache.nlpcraft.probe.mgrs.nlp.NCProbeEnrichmentManager
import java.io.{Serializable => JSerializable}
import java.util.{Collections, Optional, List => JList}
import java.{lang, util}
import scala.jdk.CollectionConverters.{ListHasAsScala, MapHasAsJava, MapHasAsScala, SeqHasAsJava, SetHasAsJava, SetHasAsScala}
/**
 * Probe commands processor.
 *
 * Dispatches messages received from the server (`S2P_*`) and, for the model information
 * requests, sends the corresponding `P2S_*` response back via [[NCConnectionManager]].
 */
object NCCommandManager extends NCService {
    // Jackson mapper used to render JSON payloads of the model information responses.
    private final val JS_MAPPER = new ObjectMapper()

    JS_MAPPER.registerModule(DefaultScalaModule)

    /**
     * Starts this service.
     *
     * @param parent Optional parent span.
     */
    override def start(parent: Span): NCService = startScopedSpan("start", parent) { _ =>
        ackStarting()
        ackStarted()
    }

    /**
     * Stops this service.
     *
     * @param parent Optional parent span.
     */
    override def stop(parent: Span): Unit = startScopedSpan("stop", parent) { _ =>
        ackStopping()
        ackStopped()
    }

    /**
     * Builds a message and sends it to the server. If building the message throws, the message
     * produced by `mkErrorMsg` is sent instead.
     *
     * Only message construction is guarded: a failure while sending the successfully built
     * message is not converted into an error message and propagates to the caller.
     *
     * @param mkMsg Factory of the message to send.
     * @param mkErrorMsg Factory of the message to send when `mkMsg` throws.
     * @param parent Optional parent span.
     */
    private def send0(mkMsg: () => NCProbeMessage, mkErrorMsg: Throwable => NCProbeMessage, parent: Span = null): Unit = {
        val msgOpt: Option[NCProbeMessage] =
            try
                Some(mkMsg())
            catch {
                case e: Throwable =>
                    NCConnectionManager.send(mkErrorMsg(e), parent)

                    None
            }

        msgOpt.foreach(m => NCConnectionManager.send(m, parent))
    }

    /**
     * Sends a response to a model information request. The same response type is used for both
     * the successful (`resp`) and the failed (`error`) outcome, so the two cannot diverge.
     *
     * @param respType Type of the response message, e.g. `P2S_MODEL_INFO`.
     * @param req Server request being answered; its GUID is echoed as `reqGuid`.
     * @param parent Optional parent span.
     * @param mkJson Lazily evaluated JSON payload. Any exception it throws is reported to the
     *      server as an `error` response.
     */
    private def sendModelResponse(respType: String, req: NCProbeMessage, parent: Span)(mkJson: => String): Unit =
        send0(
            mkMsg = () =>
                NCProbeMessage(
                    respType,
                    "reqGuid" -> req.getGuid,
                    "resp" -> mkJson
                ),
            mkErrorMsg = (e: Throwable) =>
                NCProbeMessage(
                    respType,
                    "reqGuid" -> req.getGuid,
                    "error" -> e.getLocalizedMessage
                ),
            parent = parent
        )

    /**
     * Renders macros, per-element synonyms and samples of the requested model as JSON.
     *
     * @param req Server request carrying `mdlId`.
     */
    private def mkModelSynsJson(req: NCProbeMessage): String = {
        val mdlId = req.data[String]("mdlId")
        val mdlData = NCModelManager.getModel(mdlId)

        val macros: util.Map[String, String] = mdlData.model.getMacros
        val syns: util.Map[String, util.List[String]] = mdlData.model.getElements.asScala.map(p => p.getId -> p.getSynonyms).toMap.asJava
        val samples: util.Map[String, util.List[util.List[String]]] = mdlData.samples.map(p => p._1 -> p._2.map(_.asJava).asJava).toMap.asJava

        JS_MAPPER.writeValueAsString(
            Map(
                "macros" -> macros.asInstanceOf[JSerializable],
                "synonyms" -> syns.asInstanceOf[JSerializable],
                "samples" -> samples.asInstanceOf[JSerializable]
            ).asJava
        )
    }

    /**
     * Renders synonyms and values of the requested element, plus the model macros, as JSON.
     *
     * @param req Server request carrying `mdlId` and `elmId`.
     * @throws NCE If the model has no element with the requested ID.
     */
    private def mkModelElementJson(req: NCProbeMessage): String = {
        val mdlId = req.data[String]("mdlId")
        val elmId = req.data[String]("elmId")

        val mdl = NCModelManager.getModel(mdlId).model
        val elm = mdl.getElements.asScala.find(_.getId == elmId).getOrElse(throw new NCE(s"Element not found in model: $elmId"))

        // Element values may be null - treated as "no values".
        val vals: util.Map[String, JList[String]] =
            Option(elm.getValues) match {
                case Some(elmVals) => elmVals.asScala.map(e => e.getName -> e.getSynonyms).toMap.asJava
                case None => Collections.emptyMap()
            }

        JS_MAPPER.writeValueAsString(
            Map(
                "synonyms" -> elm.getSynonyms.asInstanceOf[JSerializable],
                "values" -> vals.asInstanceOf[JSerializable],
                "macros" -> mdl.getMacros.asInstanceOf[JSerializable]
            ).asJava
        )
    }

    /**
     * Renders the requested model as JSON. The model is wrapped into a view that delegates
     * all getters as is, except for the custom parsers and element value loaders which are
     * nulled out (presumably because they are not meaningful in JSON form - TODO confirm).
     *
     * @param req Server request carrying `mdlId`.
     */
    private def mkModelInfoJson(req: NCProbeMessage): String = {
        val mdlId = req.data[String]("mdlId")
        val mdl = NCModelManager.getModel(mdlId).model

        val jsonMdl: NCModelView =
            new NCModelView {
                // As is.
                override def getId: String = mdl.getId
                override def getName: String = mdl.getName
                override def getVersion: String = mdl.getVersion
                override def getDescription: String = mdl.getDescription
                override def getOrigin: String = mdl.getOrigin
                override def getMaxUnknownWords: Int = mdl.getMaxUnknownWords
                override def getMaxFreeWords: Int = mdl.getMaxFreeWords
                override def getMaxSuspiciousWords: Int = mdl.getMaxSuspiciousWords
                override def getMinWords: Int = mdl.getMinWords
                override def getMaxWords: Int = mdl.getMaxWords
                override def getMinTokens: Int = mdl.getMinTokens
                override def getMaxTokens: Int = mdl.getMaxTokens
                override def getMinNonStopwords: Int = mdl.getMinNonStopwords
                override def isNonEnglishAllowed: Boolean = mdl.isNonEnglishAllowed
                override def isNotLatinCharsetAllowed: Boolean = mdl.isNotLatinCharsetAllowed
                override def isSwearWordsAllowed: Boolean = mdl.isSwearWordsAllowed
                override def isNoNounsAllowed: Boolean = mdl.isNoNounsAllowed
                override def isPermutateSynonyms: Boolean = mdl.isPermutateSynonyms
                override def isDupSynonymsAllowed: Boolean = mdl.isDupSynonymsAllowed
                override def getMaxTotalSynonyms: Int = mdl.getMaxTotalSynonyms
                override def isNoUserTokensAllowed: Boolean = mdl.isNoUserTokensAllowed
                override def isSparse: Boolean = mdl.isSparse
                override def getMetadata: util.Map[String, AnyRef] = mdl.getMetadata
                override def getAdditionalStopWords: util.Set[String] = mdl.getAdditionalStopWords
                override def getExcludedStopWords: util.Set[String] = mdl.getExcludedStopWords
                override def getSuspiciousWords: util.Set[String] = mdl.getSuspiciousWords
                override def getMacros: util.Map[String, String] = mdl.getMacros
                override def getEnabledBuiltInTokens: util.Set[String] = mdl.getEnabledBuiltInTokens
                override def getAbstractTokens: util.Set[String] = mdl.getAbstractTokens
                override def getMaxElementSynonyms: Int = mdl.getMaxElementSynonyms
                override def isMaxSynonymsThresholdError: Boolean = mdl.isMaxSynonymsThresholdError
                override def getConversationTimeout: Long = mdl.getConversationTimeout
                override def getConversationDepth: Int = mdl.getConversationDepth
                override def getRestrictedCombinations: util.Map[String, util.Set[String]] = mdl.getRestrictedCombinations

                // Cleared.
                override def getParsers: JList[NCCustomParser] = null

                // Converted.
                override def getElements: util.Set[NCElement] = mdl.getElements.asScala.map(e =>
                    new NCElement {
                        // As is.
                        override def getId: String = e.getId
                        override def getGroups: JList[String] = e.getGroups
                        override def getMetadata: util.Map[String, AnyRef] = e.getMetadata
                        override def getDescription: String = e.getDescription
                        override def getParentId: String = e.getParentId
                        override def getSynonyms: JList[String] = e.getSynonyms
                        override def isPermutateSynonyms: Optional[lang.Boolean] = e.isPermutateSynonyms
                        override def isSparse: Optional[lang.Boolean] = e.isSparse
                        override def getValues: JList[NCValue] = e.getValues

                        // Cleared.
                        override def getValueLoader: Optional[NCValueLoader] = null
                    }
                ).asJava
            }

        JS_MAPPER.writeValueAsString(jsonMdl)
    }

    /**
     * Processes a message received from the server. Any error raised while handling the
     * message is logged and otherwise ignored. Unknown message types are logged as errors.
     *
     * @param msg Server message to process.
     * @param parent Optional parent span.
     */
    def processServerMessage(msg: NCProbeMessage, parent: Span = null): Unit = {
        require(isStarted)

        // NOTE(review): the span tag below reads "userId" while the S2P_CLEAR_* handlers read
        // "usrId" - confirm against the server which key each message type actually carries.
        startScopedSpan("processServerMessage", parent,
            "msgType" -> msg.getType,
            "srvReqId" -> msg.dataOpt[String]("srvReqId").getOrElse(""),
            "usrId" -> msg.dataOpt[Long]("userId").getOrElse(-1),
            "mdlId" -> msg.dataOpt[String]("mdlId").getOrElse("")) { span =>
            // Pings are not traced.
            if (msg.getType != "S2P_PING")
                logger.trace(s"Probe server message received: $msg")

            try
                msg.getType match {
                    case "S2P_PING" => ()

                    case "S2P_CLEAR_CONV" =>
                        val usrId = msg.data[Long]("usrId")
                        val mdlId = msg.data[String]("mdlId")

                        NCConversationManager.getConversation(usrId, mdlId, span).clearTokens((_: NCToken) => true)

                        logger.trace(s"Cleared conversation history upon server request [usrId=$usrId, mdlId=$mdlId]")

                    case "S2P_CLEAR_DLG" =>
                        val usrId = msg.data[Long]("usrId")
                        val mdlId = msg.data[String]("mdlId")

                        NCDialogFlowManager.clear(usrId, mdlId, span)

                        logger.trace(s"Cleared dialog history upon server request [usrId=$usrId, mdlId=$mdlId]")

                    case "S2P_ASK" =>
                        NCProbeEnrichmentManager.ask(
                            srvReqId = msg.data[String]("srvReqId"),
                            txt = msg.data[String]("txt"),
                            nlpSens = msg.data[JList[NCNlpSentence]]("nlpSens").asScala,
                            usrId = msg.data[Long]("userId"),
                            senMeta = msg.data[java.util.Map[String, JSerializable]]("senMeta").asScala.toMap,
                            mdlId = msg.data[String]("mdlId"),
                            enableLog = msg.data[Boolean]("enableLog"),
                            span
                        )

                    case "S2P_MODEL_SYNS_INFO" =>
                        sendModelResponse("P2S_MODEL_SYNS_INFO", msg, span)(mkModelSynsJson(msg))

                    case "S2P_MODEL_ELEMENT_INFO" =>
                        sendModelResponse("P2S_MODEL_ELEMENT_INFO", msg, span)(mkModelElementJson(msg))

                    case "S2P_MODEL_INFO" =>
                        // The error response carries the "P2S_MODEL_INFO" type, same as the successful one.
                        sendModelResponse("P2S_MODEL_INFO", msg, span)(mkModelInfoJson(msg))

                    case _ =>
                        logger.error(s"Received unknown server message (you need to update the probe): ${msg.getType}")
                }
            catch {
                case e: Throwable => U.prettyError(logger, s"Error while processing server message (ignoring): ${msg.getType}", e)
            }
        }
    }
}