// blob: 8121f889096a91633e91588070d27c06ea5bfcec [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nlpcraft.probe.mgrs.conn
import java.io.{EOFException, IOException, InterruptedIOException}
import java.net.{InetAddress, NetworkInterface}
import java.util
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Properties, TimeZone}
import io.opencensus.trace.Span
import org.apache.nlpcraft.common._
import org.apache.nlpcraft.common.config.NCConfigurable
import org.apache.nlpcraft.common.crypto._
import org.apache.nlpcraft.common.nlp.core.NCNlpCoreManager
import org.apache.nlpcraft.common.socket._
import org.apache.nlpcraft.common.version.NCVersion
import org.apache.nlpcraft.probe.mgrs.NCProbeMessage
import org.apache.nlpcraft.probe.mgrs.cmd.NCCommandManager
import org.apache.nlpcraft.probe.mgrs.model.NCModelManager
import scala.collection.mutable
/**
* Probe down/up link connection manager.
*/
object NCConnectionManager extends NCService {
// Uplink retry timeout: delay, in milliseconds, slept (via 'Thread.sleep') before the control
// thread re-attempts the server connection.
private final val RETRY_TIMEOUT = 10 * 1000
// SO_TIMEOUT: read timeout, in milliseconds, applied to the uplink socket via 'setSoTimeout'.
private final val SO_TIMEOUT = 5 * 1000
// Ping timeout: maximum time, in milliseconds, the downlink thread waits on an empty queue
// before it sends a 'P2S_PING' message.
private final val PING_TIMEOUT = 5 * 1000
// Internal probe GUID. Regenerated with 'U.genGuid()' on every 'start(...)' and stamped on
// every outgoing message.
@volatile private var probeGuid: String = _
// JVM system properties; source of the OS/Java/user values reported in the uplink handshake.
private final val sysProps: Properties = System.getProperties
// Local host address; source of the host name/address reported in the uplink handshake and
// used to look up the local network interface.
private final val localHost: InetAddress = InetAddress.getLocalHost
// Dash-separated upper-case hex form of the local interface hardware address, or an empty
// string when it cannot be determined. Assigned by 'openUplinkSocket()'.
@volatile private var hwAddrs: String = _
// Holding downlink queue: 'send(...)' appends messages, the downlink thread writes and removes
// them. The queue instance also serves as the monitor for 'synchronized'/'wait'/'notifyAll'.
@volatile private var dnLinkQueue: mutable.Queue[Serializable] = _
// Control thread: created in 'start(...)', owns the connect/retry loop and the uplink and
// downlink worker threads.
@volatile private var ctrlThread: Thread = _
/**
 * Probe configuration accessors. Every value is looked up under the `nlpcraft.probe` prefix
 * each time the accessor is called.
 */
private object Config extends NCConfigurable {
    private final val pre = "nlpcraft.probe"

    /** Probe ID (`nlpcraft.probe.id`). */
    def id: String = getString(s"$pre.id")

    /** Probe token (`nlpcraft.probe.token`). */
    def token: String = getString(s"$pre.token")

    /** Uplink endpoint as a (host, port) pair (`nlpcraft.probe.upLink`). */
    def upLink: (String, Integer) = getHostPort(s"$pre.upLink")

    /** Downlink endpoint as a (host, port) pair (`nlpcraft.probe.downLink`). */
    def downLink: (String, Integer) = getHostPort(s"$pre.downLink")

    /** Uplink endpoint rendered as `host:port`. */
    def upLinkString: String = {
        // Single lookup per call: destructure once instead of reading the property twice.
        val (host, port) = upLink

        s"$host:$port"
    }

    /** Downlink endpoint rendered as `host:port`. */
    def downLinkString: String = {
        // Single lookup per call: destructure once instead of reading the property twice.
        val (host, port) = downLink

        s"$host:$port"
    }
}
/**
 * Schedules message for sending to the server.
 *
 * The message is stamped with this probe's token, ID and GUID, appended to the downlink queue
 * and all threads waiting on the queue are notified. If the manager is already stopping the
 * message is not queued and a trace record is logged instead.
 *
 * @param msg Message to send to server.
 * @param parent Optional parent span.
 */
def send(msg: NCProbeMessage, parent: Span = null): Unit = startScopedSpan("send", parent) { span ⇒
    // Read the configuration once; both values are used for tagging and for stamping.
    val probeId = Config.id
    val probeToken = Config.token

    // NOTE(review): the probe token is attached to the tracing span as-is - confirm that
    // exposing it to the tracing backend is intended.
    addTags(
        span,
        "probeId" → probeId,
        "token" → probeToken,
        "probeGuid" → probeGuid,
        "msgType" → msg.getType,
        "msgGuid" → msg.getGuid
    )

    // Set probe identification for each message, if necessary.
    msg.setProbeToken(probeToken)
    msg.setProbeId(probeId)
    msg.setProbeGuid(probeGuid)

    // The queue itself is the monitor shared with the downlink thread.
    dnLinkQueue.synchronized {
        if (!isStopping) {
            dnLinkQueue += msg

            // Wake up the downlink thread if it is waiting for work.
            dnLinkQueue.notifyAll()
        }
        else
            logger.trace(s"Message sending ignored b/c of stopping: $msg")
    }
}
/**
 * Error raised when the handshake with the REST server fails or is rejected. It is thrown by
 * the uplink/downlink socket openers and makes the control thread abort the probe.
 *
 * @param msg Error message.
 */
class HandshakeError(msg: String) extends RuntimeException(msg)
/**
 * Opens down link socket.
 *
 * Connects to the configured downlink endpoint, sends the SHA-256 hash of the probe token
 * followed by an encrypted `INIT_HANDSHAKE` message and validates the server response.
 * If anything fails after the socket was connected, the socket is closed before the error is
 * propagated, so no connection is leaked across retries.
 *
 * @return Connected socket with a completed handshake.
 * @throws HandshakeError If the server response type is anything other than `P2S_PROBE_OK`.
 */
@throws[Exception]
private def openDownLinkSocket(): NCSocket = {
    val (host, port) = Config.downLink
    val cryptoKey = NCCipher.makeTokenKey(Config.token)

    logger.trace(s"Opening downlink to '$host:$port'")

    // Connect down socket.
    val sock = NCSocket(host, port)

    try {
        sock.write(U.mkSha256Hash(Config.token)) // Hash.
        sock.write(NCProbeMessage( // Handshake.
            // Type.
            "INIT_HANDSHAKE",
            // Payload.
            // Probe identification.
            "PROBE_TOKEN" → Config.token,
            "PROBE_ID" → Config.id,
            "PROBE_GUID" → probeGuid
        ), cryptoKey)

        val resp = sock.read[NCProbeMessage](cryptoKey) // Get handshake response.

        resp.getType match {
            case "P2S_PROBE_OK" ⇒ logger.trace("Downlink handshake OK.") // Bingo!
            case "P2S_PROBE_NOT_FOUND" ⇒ throw new HandshakeError("Probe failed to start due to unknown error.")
            case _ ⇒ throw new HandshakeError(s"Unexpected REST server message: ${resp.getType}")
        }

        sock
    }
    catch {
        // Caught only to release the socket; the original error is always re-thrown.
        case e: Throwable ⇒
            // The caller never receives the socket on failure, so it has to be closed here.
            try
                sock.close()
            catch {
                case closeErr: Exception ⇒ e.addSuppressed(closeErr)
            }

            throw e
    }
}
/**
 * Opens uplink socket.
 *
 * Steps:
 *  1. Resolves the hardware address of the local network interface into `hwAddrs`
 *     (dash-separated upper-case hex, empty string when unavailable).
 *  1. Connects to the configured uplink endpoint and sends the SHA-256 hash of the probe
 *     token in clear text.
 *  1. On `S2P_HASH_CHECK_OK` sends the encrypted `INIT_HANDSHAKE` message carrying probe,
 *     host, JVM, time zone and model information, then validates the server response.
 *
 * If anything fails after the socket was connected, the socket is closed before the error is
 * propagated, so no connection is leaked across retries.
 *
 * @return Connected socket with a completed handshake.
 * @throws HandshakeError If the hash check or the handshake is rejected, or the server
 *      answers with an unexpected message type.
 */
@throws[Exception]
private def openUplinkSocket(): NCSocket = {
    // Both the interface lookup and the hardware address may legitimately be 'null'.
    hwAddrs =
        Option(NetworkInterface.getByInetAddress(localHost))
            .flatMap(itf ⇒ Option(itf.getHardwareAddress))
            .map(_.map(b ⇒ f"$b%02X").mkString("-"))
            .getOrElse("")

    val (host, port) = Config.upLink
    val cryptoKey = NCCipher.makeTokenKey(Config.token)

    logger.trace(s"Opening uplink to '$host:$port'")

    // Connect up socket.
    val sock = NCSocket(host, port)

    try {
        sock.write(U.mkSha256Hash(Config.token)) // Hash, sent clear text.

        val hashResp = sock.read[NCProbeMessage]()

        hashResp.getType match { // Get hash check response.
            case "S2P_HASH_CHECK_OK" ⇒ () // Proceed with the handshake below.
            // NOTE(review): this message includes the probe token verbatim - confirm that
            // writing it to the log is acceptable.
            case "S2P_HASH_CHECK_UNKNOWN" ⇒ throw new HandshakeError(s"Sever does not recognize probe token: ${Config.token}.")
            // Any other type is a protocol violation: report it as a handshake error rather
            // than letting a raw 'MatchError' escape.
            case other ⇒ throw new HandshakeError(s"Unexpected REST server message: $other")
        }

        val ver = NCVersion.getCurrent
        val tmz = TimeZone.getDefault

        val srvNlpEng =
            hashResp.getOrElse(
                "NLP_ENGINE",
                throw new HandshakeError("NLP engine parameter missed in response.")
            )

        val probeNlpEng = NCNlpCoreManager.getEngine

        // Engine mismatch is reported but does not stop the handshake.
        if (srvNlpEng != probeNlpEng)
            logger.warn(s"Invalid NLP engines configuration [server=$srvNlpEng, probe=$probeNlpEng]")

        sock.write(NCProbeMessage( // Handshake.
            // Type.
            "INIT_HANDSHAKE",
            // Payload.
            // Probe identification.
            "PROBE_TOKEN" → Config.token,
            "PROBE_ID" → Config.id,
            "PROBE_GUID" → probeGuid,
            // Handshake data,
            "PROBE_API_DATE" → ver.date,
            "PROBE_API_VERSION" → ver.version,
            "PROBE_OS_VER" → sysProps.getProperty("os.version"),
            "PROBE_OS_NAME" → sysProps.getProperty("os.name"),
            "PROBE_OS_ARCH" → sysProps.getProperty("os.arch"),
            "PROBE_START_TSTAMP" → U.nowUtcMs(),
            "PROBE_TMZ_ID" → tmz.getID,
            "PROBE_TMZ_ABBR" → tmz.getDisplayName(false, TimeZone.SHORT),
            "PROBE_TMZ_NAME" → tmz.getDisplayName(),
            "PROBE_SYS_USERNAME" → sysProps.getProperty("user.name"),
            "PROBE_JAVA_VER" → sysProps.getProperty("java.version"),
            "PROBE_JAVA_VENDOR" → sysProps.getProperty("java.vendor"),
            "PROBE_HOST_NAME" → localHost.getHostName,
            "PROBE_HOST_ADDR" → localHost.getHostAddress,
            "PROBE_HW_ADDR" → hwAddrs,
            "PROBE_MODELS" →
                NCModelManager.getAllModels().map(wrapper ⇒ {
                    val mdl = wrapper.model

                    // Model already validated.
                    // util.HashSet created to avoid scala collections serialization error.
                    // Seems to be a Scala bug.
                    (
                        mdl.getId,
                        mdl.getName,
                        mdl.getVersion,
                        new util.HashSet[String](mdl.getEnabledBuiltInTokens)
                    )
                })
        ), cryptoKey)

        val resp = sock.read[NCProbeMessage](cryptoKey) // Get handshake response.

        resp.getType match {
            case "S2P_PROBE_MULTIPLE_INSTANCES" ⇒ throw new HandshakeError("Duplicate probes ID detected. Each probe has to have a unique ID.")
            case "S2P_PROBE_NOT_FOUND" ⇒ throw new HandshakeError("Probe failed to start due to unknown error.")
            case "S2P_PROBE_VERSION_MISMATCH" ⇒ throw new HandshakeError(s"REST server does not support probe version: ${ver.version}")
            case "S2P_PROBE_UNSUPPORTED_TOKENS_TYPES" ⇒ throw new HandshakeError(s"REST server does not support some model enabled tokes types.")
            case "S2P_PROBE_OK" ⇒ logger.trace("Uplink handshake OK.") // Bingo!
            case _ ⇒ throw new HandshakeError(s"Unknown REST server message: ${resp.getType}")
        }

        sock
    }
    catch {
        // Caught only to release the socket; the original error is always re-thrown.
        case e: Throwable ⇒
            // The caller never receives the socket on failure, so it has to be closed here.
            try
                sock.close()
            catch {
                case closeErr: Exception ⇒ e.addSuppressed(closeErr)
            }

            throw e
    }
}
/**
 * Aborts the probe: interrupts the control thread and terminates the JVM with exit status 1.
 */
private def abort(): Unit = {
    // Flag the control thread first so that it does not start another connection attempt.
    val ctrl = ctrlThread

    ctrl.interrupt()

    // Terminate the process with an error status.
    Runtime.getRuntime.exit(1)
}
/**
 * Starts this manager.
 *
 * Generates a new probe GUID, creates an empty downlink queue and launches the control thread.
 * The control thread repeatedly:
 *  - opens the uplink and downlink sockets (performing both handshakes),
 *  - runs an uplink thread that reads server messages and hands them to the command manager,
 *  - runs a downlink thread that writes queued messages, or a `P2S_PING` when the queue stays
 *    empty for `PING_TIMEOUT` milliseconds,
 *  - waits for either worker to report a failure, closes everything and, unless the manager
 *    is stopping, retries after `RETRY_TIMEOUT` milliseconds.
 *
 * An `IOException` while connecting leads to a retry; a handshake error or any other exception
 * aborts the probe. The calling thread is blocked until the first connection is established.
 *
 * @param parent Optional parent span.
 * @return Result of `ackStarted()`.
 */
override def start(parent: Span = null): NCService = startScopedSpan("start", parent) { _ ⇒
    require(NCCommandManager.isStarted)
    require(NCModelManager.isStarted)

    ackStarting()

    probeGuid = U.genGuid()
    dnLinkQueue = mutable.Queue.empty[Serializable]

    // Released once the first connection to the server is fully established.
    val ctrlLatch = new CountDownLatch(1)

    ctrlThread = U.mkThread("probe-ctrl-thread") { t ⇒
        var dnSock: NCSocket = null
        var upSock: NCSocket = null
        var dnThread: Thread = null
        var upThread: Thread = null

        /**
         * Stops both worker threads and closes both sockets, resetting all references.
         */
        def closeAll(): Unit = {
            U.stopThread(dnThread)
            U.stopThread(upThread)

            dnThread = null
            upThread = null

            if (dnSock != null) dnSock.close()
            if (upSock != null) upSock.close()

            dnSock = null
            upSock = null
        }

        /**
         * Sleeps for the retry timeout unless the control thread is already interrupted.
         */
        def timeout(): Unit = if (!t.isInterrupted) U.ignoreInterrupt {
            Thread.sleep(RETRY_TIMEOUT)
        }

        val cryptoKey = NCCipher.makeTokenKey(Config.token)

        while (!t.isInterrupted)
            try {
                logger.info(s"Connecting to REST server [" +
                    s"uplink=${Config.upLinkString}, " +
                    s"downlink=${Config.downLinkString}" +
                    s"]")

                upSock = openUplinkSocket()
                dnSock = openDownLinkSocket()

                // Bounded reads let the uplink loop re-check its interruption flag regularly.
                upSock.socket.setSoTimeout(SO_TIMEOUT)

                // Released by whichever worker thread fails first.
                val exitLatch = new CountDownLatch(1)

                /**
                 * Logs the error, interrupts the calling worker and releases the exit latch.
                 *
                 * @param caller Caller thread to interrupt.
                 * @param msg Error message.
                 * @param cause Optional cause of the error.
                 */
                def exit(caller: Thread, msg: String, cause: Exception = null): Unit = {
                    if (cause != null)
                        U.prettyError(logger, msg, cause)
                    else
                        logger.error(msg)

                    caller.interrupt() // Interrupt current calling thread.

                    exitLatch.countDown()
                }

                upThread = U.mkThread("probe-uplink") { t ⇒
                    // Main reading loop.
                    while (!t.isInterrupted)
                        try
                            NCCommandManager.processServerMessage(upSock.read[NCProbeMessage](cryptoKey))
                        catch {
                            // Read timeout or interruption: just re-check the loop condition.
                            case _: InterruptedIOException | _: InterruptedException ⇒ ()
                            case _: EOFException ⇒ exit(t, s"Uplink REST server connection closed.")
                            case e: Exception ⇒ exit(t, s"Uplink connection failed.", e)
                        }
                }

                dnThread = U.mkThread("probe-downlink") { t ⇒
                    while (!t.isInterrupted)
                        try {
                            dnLinkQueue.synchronized {
                                if (dnLinkQueue.isEmpty) {
                                    // Wait for 'send(...)' to enqueue something, but no longer
                                    // than the ping timeout.
                                    dnLinkQueue.wait(PING_TIMEOUT)

                                    // Use this thread's own handle 't': the shared 'dnThread'
                                    // variable is reset to 'null' by 'closeAll()'.
                                    if (!t.isInterrupted && dnLinkQueue.isEmpty) {
                                        val pingMsg = NCProbeMessage("P2S_PING")

                                        pingMsg.setProbeToken(Config.token)
                                        pingMsg.setProbeId(Config.id)
                                        pingMsg.setProbeGuid(probeGuid)

                                        dnSock.write(pingMsg, cryptoKey)
                                    }
                                }
                                else {
                                    val msg = dnLinkQueue.head

                                    // Write head first (without actually removing from queue).
                                    dnSock.write(msg, cryptoKey)

                                    // If sent ok - remove from queue.
                                    dnLinkQueue.dequeue()
                                }
                            }
                        }
                        catch {
                            // Interruption: just re-check the loop condition.
                            case _: InterruptedIOException | _: InterruptedException ⇒ ()
                            case _: EOFException ⇒ exit(t, s"Downlink REST server connection closed.")
                            case e: Exception ⇒ exit(t, s"Downlink connection failed.", e)
                        }
                }

                // Bingo - start downlink and uplink!
                upThread.start()
                dnThread.start()

                // Indicate that server connection is established.
                ctrlLatch.countDown()

                logger.info("REST server connected.")

                // Wait until probe connection is closed.
                while (!t.isInterrupted && exitLatch.getCount > 0) U.ignoreInterrupt {
                    exitLatch.await()
                }

                closeAll()

                if (!isStopping) {
                    logger.warn(s"REST server connection closed (retry in ${RETRY_TIMEOUT / 1000}s).")

                    timeout()
                }
                else
                    logger.info(s"REST server connection closed.")
            }
            catch {
                case e: HandshakeError ⇒
                    // Clean up.
                    closeAll()

                    // Ack the handshake error message.
                    U.prettyError(logger, s"Failed REST server connection handshake (aborting).", e)

                    abort()

                case e: IOException ⇒
                    // Clean up.
                    closeAll()

                    // Ack the error message.
                    U.prettyError(logger, s"Failed to connect to REST server (retry in ${RETRY_TIMEOUT / 1000}s).", e)

                    timeout()

                case e: Exception ⇒
                    // Clean up.
                    closeAll()

                    // Ack the error message.
                    U.prettyError(logger, "Unexpected error connecting to REST server.", e)

                    abort()
            }

        closeAll()
    }

    ctrlThread.start()

    // Only return when probe successfully connected to the server.
    ctrlLatch.await()

    ackStarted()
}
/**
 * Stops this manager by stopping the control thread.
 *
 * @param parent Optional parent span.
 */
override def stop(parent: Span = null): Unit =
    startScopedSpan("stop", parent) { _ ⇒
        ackStopping()

        // Stopping the control thread is the only shutdown action taken here.
        val ctrl = ctrlThread

        U.stopThread(ctrl)

        ackStopped()
    }
}