| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.nlpcraft.server.sugsyn |
| |
| import com.google.gson.Gson |
| import com.google.gson.reflect.TypeToken |
| import io.opencensus.trace.Span |
| import org.apache.http.HttpResponse |
| import org.apache.http.client.ResponseHandler |
| import org.apache.http.util.EntityUtils |
| import org.apache.nlpcraft.common._ |
| import org.apache.nlpcraft.common.config.NCConfigurable |
| import org.apache.nlpcraft.common.nlp.core.NCNlpPorterStemmer |
| import org.apache.nlpcraft.server.probe.NCProbeManager |
| |
| import scala.collection.JavaConverters._ |
| import scala.collection.{Seq, mutable} |
| import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, Promise} |
| import scala.util.{Failure, Success} |
| import java.util |
| import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} |
| import java.util.concurrent._ |
| |
| import org.apache.http.client.methods.HttpPost |
| import org.apache.http.entity.StringEntity |
| import org.apache.http.impl.client.HttpClients |
| import org.apache.nlpcraft.common.makro.NCMacroParser |
| |
/**
 * Synonym suggestion manager.
 *
 * Generates new synonym candidates for model elements by taking the model's intent samples,
 * substituting existing single-word synonyms at the positions where they occur, and asking the
 * external 'ctxword' server for contextually similar words at those positions. Raw suggestions
 * are then grouped by stem, scored, normalized and returned.
 */
object NCSuggestSynonymManager extends NCService {
    // For context word server requests.
    private final val MAX_LIMIT: Int = 10000 // Maximum suggestions requested per sentence.
    private final val BATCH_SIZE = 20 // Sentences per single 'ctxword' server request.
    private final val DFLT_MIN_SCORE = 0.0 // Default score threshold if caller provides none.

    // Pool used to process model-info callback (see `suggest`); managed by start()/stop().
    @volatile private var pool: ExecutorService = _
    @volatile private var executor: ExecutionContextExecutor = _

    // For warnings.
    private final val MIN_CNT_INTENT = 5 // Recommended minimum samples per intent.
    private final val MIN_CNT_MODEL = 20 // Recommended minimum samples per model.

    private final val GSON = new Gson
    // Response shape: one list of suggestions per requested sentence.
    private final val TYPE_RESP = new TypeToken[util.List[util.List[Suggestion]]]() {}.getType
    // Characters treated as stand-alone words when tokenizing samples.
    private final val SEPARATORS = Seq('?', ',', '.', '-', '!')

    // Optional 'ctxword' server URL configuration.
    private object Config extends NCConfigurable {
        val urlOpt: Option[String] = getStringOpt("nlpcraft.server.ctxword.url")
    }

    // Parses the 'ctxword' server HTTP response into per-sentence suggestion lists.
    private final val HANDLER: ResponseHandler[Seq[Seq[Suggestion]]] =
        (resp: HttpResponse) ⇒ {
            val code = resp.getStatusLine.getStatusCode
            val e = resp.getEntity

            val js = if (e != null) EntityUtils.toString(e) else null

            if (js == null)
                throw new NCE(s"Unexpected empty HTTP response from 'ctxword' server [code=$code]")

            code match {
                case 200 ⇒
                    val data: util.List[util.List[Suggestion]] = GSON.fromJson(js, TYPE_RESP)

                    // NOTE(review): the head of each non-empty list is dropped - presumably it
                    // echoes the requested word itself; confirm against the 'ctxword' API.
                    data.asScala.map(p ⇒ if (p.isEmpty) Seq.empty else p.asScala.tail)

                case _ ⇒
                    throw new NCE(
                        s"Unexpected HTTP response from 'ctxword' server [" +
                        s"code=$code, " +
                        s"response=$js" +
                        s"]"
                    )
            }
        }

    // Single raw suggestion from the 'ctxword' server.
    case class Suggestion(word: String, score: Double)
    // One substitution request: sample sentence with a synonym spliced in at `index`.
    case class RequestData(sentence: String, ex: String, elmId: String, index: Int)
    // REST payload element: sentence text plus word indexes to suggest for.
    case class RestRequestSentence(text: String, indexes: util.List[Int])
    // Top-level REST payload for the 'ctxword' server.
    case class RestRequest(sentences: util.List[RestRequestSentence], limit: Int, minScore: Double)
    // Single synonym word together with its stemmed form.
    case class Word(word: String, stem: String) {
        require(!word.contains(" "), s"Word cannot contains spaces: $word")
        require(
            word.forall(ch ⇒
                ch.isLetterOrDigit ||
                ch == '\'' ||
                SEPARATORS.contains(ch)
            ),
            s"Unsupported symbols: $word"
        )
    }
    // Final scored suggestion for a model element.
    case class SuggestionResult(synonym: String, score: Double)

    // Splits by spaces, dropping empty tokens.
    private def split(s: String): Seq[String] = s.split(" ").toSeq.map(_.trim).filter(_.nonEmpty)
    // Stems each word of a possibly multi-word string.
    private def toStem(s: String): String = split(s).map(NCNlpPorterStemmer.stem).mkString(" ")
    // Stems a single word.
    private def toStemWord(s: String): String = NCNlpPorterStemmer.stem(s)

    /**
     * Starts this service by initializing the thread pool used for request processing.
     *
     * @param parent Optional parent span.
     * @return This service.
     */
    override def start(parent: Span): NCService = startScopedSpan("start", parent) { _ ⇒
        pool = Executors.newCachedThreadPool()
        executor = ExecutionContext.fromExecutor(pool)

        ackStart()
    }

    /**
     * Stops this service and shuts down its thread pool.
     *
     * @param parent Optional parent span.
     */
    override def stop(parent: Span): Unit = startScopedSpan("stop", parent) { _ ⇒
        U.shutdownPools(pool)

        pool = null
        executor = null

        ackStop()
    }

    /**
     * Finds start indexes of all occurrences of `seq2` as a contiguous slice of `seq1`.
     *
     * @param seq1 Sequence to search in.
     * @param seq2 Slice to search for.
     * @return Start indexes of all (possibly overlapping) matches, in order.
     */
    private def getAllSlices(seq1: Seq[String], seq2: Seq[String]): Seq[Int] = {
        val seq = mutable.Buffer.empty[Int]

        var i = seq1.indexOfSlice(seq2)

        while (i >= 0) {
            seq += i

            i = seq1.indexOfSlice(seq2, i + 1)
        }

        seq
    }

    /**
     * Asynchronously generates synonym suggestions for all elements of the given model.
     *
     * @param mdlId Model ID.
     * @param minScoreOpt Optional minimum normalized score threshold for returned suggestions.
     * @param parent Optional parent span.
     * @return Future completed with the result (its `error` field is set for expected problems
     *     such as missing samples) or failed for unexpected errors.
     */
    def suggest(mdlId: String, minScoreOpt: Option[Double], parent: Span = null): Future[NCSuggestSynonymResult] =
        startScopedSpan("inspect", parent, "mdlId" → mdlId) { _ ⇒
            val now = System.currentTimeMillis()

            val promise = Promise[NCSuggestSynonymResult]()

            NCProbeManager.getModelInfo(mdlId, parent).onComplete {
                case Success(m) ⇒
                    try {
                        // Model info is expected to carry macros, synonyms and samples.
                        require(
                            m.containsKey("macros") &&
                            m.containsKey("synonyms") &&
                            m.containsKey("samples")
                        )

                        val mdlMacros = m.get("macros").
                            asInstanceOf[util.Map[String, String]].asScala
                        val mdlSyns = m.get("synonyms").
                            asInstanceOf[util.Map[String, util.List[String]]].asScala.map(p ⇒ p._1 → p._2.asScala)
                        val mdlExs = m.get("samples").
                            asInstanceOf[util.Map[String, util.List[String]]].asScala.map(p ⇒ p._1 → p._2.asScala)

                        val minScore = minScoreOpt.getOrElse(DFLT_MIN_SCORE)

                        // Completes the promise successfully but with an error message payload.
                        def onError(err: String): Unit =
                            promise.success(
                                NCSuggestSynonymResult(
                                    modelId = mdlId,
                                    minScore = minScore,
                                    durationMs = System.currentTimeMillis() - now,
                                    timestamp = now,
                                    error = err,
                                    suggestions = Seq.empty.asJava,
                                    warnings = Seq.empty.asJava
                                )
                            )

                        if (mdlExs.isEmpty)
                            onError(s"Missed intents samples for: '$mdlId'")
                        else {
                            val url = s"${Config.urlOpt.getOrElse(throw new NCE("Context word server is not configured."))}/suggestions"

                            val allSamplesCnt = mdlExs.map { case (_, samples) ⇒ samples.size }.sum

                            val warns = mutable.ArrayBuffer.empty[String]

                            // NOTE(review): the per-intent check is skipped when the model-level
                            // warning fires - presumably to avoid redundant warnings; confirm.
                            if (allSamplesCnt < MIN_CNT_MODEL)
                                warns +=
                                    s"Model has too few ($allSamplesCnt) intents samples. " +
                                    s"It will negatively affect the quality of suggestions. " +
                                    s"Try to increase overall sample count to at least $MIN_CNT_MODEL."
                            else {
                                val ids =
                                    mdlExs.
                                        filter { case (_, samples) ⇒ samples.size < MIN_CNT_INTENT }.
                                        map { case (intentId, _) ⇒ intentId }

                                if (ids.nonEmpty)
                                    warns +=
                                        s"Following model intent have too few samples (${ids.mkString(", ")}). " +
                                        s"It will negatively affect the quality of suggestions. " +
                                        s"Try to increase overall sample count to at least $MIN_CNT_INTENT."
                            }

                            val parser = new NCMacroParser()

                            mdlMacros.foreach { case (name, str) ⇒ parser.addMacro(name, str) }

                            // Note that we don't use the system tokenizer, because the 'ctxword'
                            // module doesn't have this tokenizer. We split example words by spaces
                            // and also treat separators as separate words.
                            val exs = mdlExs.
                                flatMap { case (_, samples) ⇒ samples }.
                                map(ex ⇒ SEPARATORS.foldLeft(ex)((s, ch) ⇒ s.replaceAll(s"\\$ch", s" $ch "))).
                                map(ex ⇒ {
                                    val seq = ex.split(" ")

                                    seq → seq.map(toStemWord)
                                }).
                                toMap

                            // Element ID → macro-expanded synonyms, each as a sequence of words with stems.
                            val elemSyns =
                                mdlSyns.map { case (elemId, syns) ⇒ elemId → syns.flatMap(parser.expand) }.
                                map { case (id, seq) ⇒ id → seq.map(txt ⇒ split(txt).map(p ⇒ Word(p, toStemWord(p)))) }

                            // Element ID → substitution requests: for every sample position where any
                            // single-word synonym matches (by stem), substitute each single-word synonym.
                            val allReqs =
                                elemSyns.map {
                                    case (elemId, syns) ⇒
                                        // Only single-word synonyms are usable for substitution.
                                        val normSyns: Seq[Seq[Word]] = syns.filter(_.size == 1)
                                        val synsStems = normSyns.map(_.map(_.stem))
                                        val synsWords = normSyns.map(_.map(_.word))

                                        val reqs =
                                            exs.flatMap { case (exWords, exampleStems) ⇒
                                                val exIdxs = synsStems.flatMap(synStems ⇒ getAllSlices(exampleStems, synStems))

                                                // Builds the sample sentence with the synonym spliced
                                                // in at the given position.
                                                def mkRequestData(idx: Int, synStems: Seq[String], synStemsIdx: Int): RequestData = {
                                                    val fromIncl = idx
                                                    val toExcl = idx + synStems.length

                                                    RequestData(
                                                        sentence = exWords.zipWithIndex.flatMap {
                                                            case (exWord, i) ⇒
                                                                i match {
                                                                    case x if x == fromIncl ⇒ synsWords(synStemsIdx)
                                                                    case x if x > fromIncl && x < toExcl ⇒ Seq.empty
                                                                    case _ ⇒ Seq(exWord)
                                                                }
                                                        }.mkString(" "),
                                                        ex = exWords.mkString(" "),
                                                        elmId = elemId,
                                                        index = idx
                                                    )
                                                }

                                                (for (idx ← exIdxs; (synStems, i) ← synsStems.zipWithIndex)
                                                    yield mkRequestData(idx, synStems, i)).distinct
                                            }

                                        elemId → reqs.toSet
                                }.filter(_._2.nonEmpty)

                            // Elements that have synonyms but none occurring in any sample sentence.
                            val noExElems =
                                mdlSyns.
                                    filter { case (elemId, syns) ⇒ syns.nonEmpty && !allReqs.contains(elemId) }.
                                    map { case (elemId, _) ⇒ elemId }

                            if (noExElems.nonEmpty)
                                warns +=
                                    "Some elements don't have synonyms in their intent samples, " +
                                    s"so the service can't suggest any new synonyms for such elements: [${noExElems.mkString(", ")}]"

                            val allReqsCnt = allReqs.map(_._2.size).sum
                            val allSynsCnt = elemSyns.map(_._2.size).sum

                            logger.trace(s"Request is going to execute on 'ctxword' server [" +
                                s"exs=${exs.size}, " +
                                s"syns=$allSynsCnt, " +
                                s"reqs=$allReqsCnt" +
                                s"]")

                            if (allReqsCnt == 0)
                                onError(s"Suggestions cannot be generated for model: '$mdlId'")
                            else {
                                val allSgsts = new ConcurrentHashMap[String, util.List[Suggestion]]()
                                // Released either when all requests completed or on the first error.
                                val cdl = new CountDownLatch(1)
                                val cnt = new AtomicInteger(0)

                                val cli = HttpClients.createDefault
                                val err = new AtomicReference[Throwable]()

                                // FIX: close the HTTP client when done (it was previously leaked).
                                // Also removed the write-only `debugs` map that was mutated from
                                // multiple threads without synchronization and never read.
                                try {
                                    for ((elemId, reqs) ← allReqs; batch ← reqs.sliding(BATCH_SIZE, BATCH_SIZE).map(_.toSeq)) {
                                        U.asFuture(
                                            _ ⇒ {
                                                val post = new HttpPost(url)

                                                post.setHeader("Content-Type", "application/json")
                                                post.setEntity(
                                                    new StringEntity(
                                                        GSON.toJson(
                                                            // Server-side filtering disabled (minScore = 0);
                                                            // scores are filtered client-side after normalization.
                                                            RestRequest(
                                                                sentences = batch.map(p ⇒ RestRequestSentence(p.sentence, Seq(p.index).asJava)).asJava,
                                                                minScore = 0,
                                                                limit = MAX_LIMIT
                                                            )
                                                        ),
                                                        "UTF-8"
                                                    )
                                                )

                                                val resps: Seq[Seq[Suggestion]] = try
                                                    cli.execute(post, HANDLER)
                                                finally
                                                    post.releaseConnection()

                                                require(batch.size == resps.size, s"Batch: ${batch.size}, responses: ${resps.size}")

                                                val i = cnt.addAndGet(batch.size)

                                                logger.debug(s"Executed: $i requests...")

                                                allSgsts.
                                                    computeIfAbsent(elemId, (_: String) ⇒ new CopyOnWriteArrayList[Suggestion]()).
                                                    addAll(resps.flatten.asJava)

                                                // Last completed batch releases the latch.
                                                if (i == allReqsCnt)
                                                    cdl.countDown()
                                            },
                                            (e: Throwable) ⇒ {
                                                // Keep only the first error; release the latch immediately.
                                                err.compareAndSet(null, e)

                                                cdl.countDown()
                                            },
                                            (_: Unit) ⇒ ()
                                        )
                                    }

                                    // Effectively unbounded wait for completion or first error.
                                    cdl.await(Long.MaxValue, TimeUnit.MILLISECONDS)
                                }
                                finally
                                    cli.close()

                                if (err.get() != null)
                                    throw new NCE("Error during work with 'ContextWordServer'.", err.get())

                                // Stems of all already-defined synonyms - used to drop known ones below.
                                val allSynsStems = elemSyns.flatMap(_._2).flatten.map(_.stem).toSet

                                val nonEmptySgsts = allSgsts.asScala.map(p ⇒ p._1 → p._2.asScala).filter(_._2.nonEmpty)

                                val res = mutable.HashMap.empty[String, mutable.ArrayBuffer[SuggestionResult]]

                                nonEmptySgsts.foreach { case (elemId, elemSgsts) ⇒
                                    elemSgsts.
                                        map(sgst ⇒ (sgst, toStem(sgst.word))).
                                        groupBy { case (_, stem) ⇒ stem }.
                                        // Drops already defined.
                                        filter { case (stem, _) ⇒ !allSynsStems.contains(stem) }.
                                        map { case (_, group) ⇒
                                            val seq = group.map { case (sgst, _) ⇒ sgst }.sortBy(-_.score)

                                            // Drops repeated.
                                            (seq.head.word, seq.length, seq.map(_.score).sum / seq.size)
                                        }.
                                        toSeq.
                                        // Weight the average score by how often the suggestion occurred.
                                        map { case (sgst, cnt, score) ⇒ (sgst, cnt, score * cnt / elemSgsts.size) }.
                                        sortBy { case (_, _, sumFactor) ⇒ -sumFactor }.
                                        zipWithIndex.
                                        foreach { case ((word, _, sumFactor), _) ⇒
                                            val seq =
                                                res.get(elemId) match {
                                                    case Some(seq) ⇒ seq
                                                    case None ⇒
                                                        val buf = mutable.ArrayBuffer.empty[SuggestionResult]

                                                        res += elemId → buf

                                                        buf
                                                }

                                            seq += SuggestionResult(word, sumFactor)
                                        }
                                }

                                // Normalizes scores per element and converts to Java collections.
                                val resJ: util.Map[String, util.List[util.HashMap[String, Any]]] =
                                    res.map { case (id, data) ⇒
                                        val norm =
                                            if (data.nonEmpty) {
                                                val factors = data.map(_.score)

                                                val min = factors.min
                                                val max = factors.max
                                                var delta = max - min

                                                // All scores equal - avoid division by zero.
                                                // FIX: also guard against `max == 0`, which previously
                                                // produced NaN scores that were then silently dropped
                                                // by the >= filter below.
                                                if (delta == 0)
                                                    delta = if (max != 0) max else 1

                                                def normalize(v: Double): Double = (v - min) / delta

                                                data.
                                                    map(s ⇒ SuggestionResult(s.synonym, normalize(s.score))).
                                                    filter(_.score >= minScore)
                                            }
                                            else
                                                Seq.empty

                                        id → norm.map(d ⇒ {
                                            val m = new util.HashMap[String, Any]()

                                            m.put("synonym", d.synonym.toLowerCase)
                                            m.put("score", d.score)

                                            m
                                        }).asJava
                                    }.asJava

                                promise.success(
                                    NCSuggestSynonymResult(
                                        modelId = mdlId,
                                        minScore = minScore,
                                        durationMs = System.currentTimeMillis() - now,
                                        timestamp = now,
                                        error = null,
                                        // Single-element list wrapping the whole per-element map.
                                        suggestions = Seq(resJ.asInstanceOf[AnyRef]).asJava,
                                        warnings = warns.asJava
                                    )
                                )
                            }
                        }
                    }
                    catch {
                        // NCE thrown above (missing config, ctxword error) fails the future as-is.
                        case e: NCE ⇒ promise.failure(e)
                        case e: Throwable ⇒
                            U.prettyError(logger, "Unexpected error:", e)

                            promise.failure(e)
                    }
                case Failure(e) ⇒ promise.failure(e)
            }(executor)

            promise.future
        }
}