blob: eee8daad48cd7e17b06fa573ab629dfd5f8ba001 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package org.apache.toree.kernel.protocol.v5.client.socket
import org.apache.pekko.actor.Actor
import org.apache.toree.communication.ZMQMessage
import org.apache.toree.comm.{CommStorage, CommRegistrar, ClientCommWriter}
import org.apache.toree.kernel.protocol.v5.client.{ActorLoader, Utilities}
import Utilities._
import org.apache.toree.kernel.protocol.v5.client.execution.DeferredExecutionManager
import org.apache.toree.kernel.protocol.v5.content._
import org.apache.toree.kernel.protocol.v5.{KMBuilder, KernelMessage}
import org.apache.toree.utils.LogLike
import scala.util.Failure
/**
* The client endpoint for IOPub messages specified in the IPython Kernel Spec
* @param socketFactory A factory to create the ZeroMQ socket connection
*/
class IOPubClient(
socketFactory: SocketFactory, actorLoader: ActorLoader,
signatureEnabled: Boolean,
commRegistrar: CommRegistrar, commStorage: CommStorage
) extends Actor with LogLike {
private val PARENT_HEADER_NULL_MESSAGE = "Parent Header was null in Kernel Message."
private val socket = socketFactory.IOPubClient(context.system, self)
logger.info("Created IOPub socket")
/**
* Constructs and returns a map directing message types to functions.
*
* @param kernelMessage The kernel message used to generate the map
*
* @return The map of
*/
private def getMessageMap(kernelMessage: KernelMessage) = Map[String, () => Unit](
ExecuteResult.toTypeString -> { () =>
receiveKernelMessage(kernelMessage, receiveExecuteResult(_, kernelMessage))
},
StreamContent.toTypeString -> { () =>
receiveKernelMessage(kernelMessage, receiveStreamMessage(_, kernelMessage))
},
CommOpen.toTypeString -> { () =>
receiveKernelMessage(kernelMessage, receiveCommOpen(_, kernelMessage))
},
CommMsg.toTypeString -> { () =>
receiveKernelMessage(kernelMessage, receiveCommMsg(_, kernelMessage))
},
CommClose.toTypeString -> { () =>
receiveKernelMessage(kernelMessage, receiveCommClose(_, kernelMessage))
}
)
private def receiveKernelMessage(
kernelMessage: KernelMessage, func: (String) => Unit
): Unit = {
if(kernelMessage.parentHeader != null){
func(kernelMessage.parentHeader.msg_id)
} else {
logger.warn("Received message with null parent header.")
logger.debug(s"Kernel message is: $kernelMessage")
sender.forward(Failure(new RuntimeException(PARENT_HEADER_NULL_MESSAGE)))
}
}
private def receiveStreamMessage(
parentHeaderId:String, kernelMessage: KernelMessage
): Unit = {
// look up callback in CallbackMap based on msg_id and invoke
val optionalDE = DeferredExecutionManager.get(parentHeaderId)
optionalDE match {
case Some(de) => parseAndHandle(kernelMessage.contentString,
StreamContent.streamContentReads,
(streamContent: StreamContent) => de.emitStreamContent(streamContent))
case None =>
logger.warn(s"No deferred execution found for id $parentHeaderId")
}
}
private def receiveExecuteResult(
parentHeaderId:String, kernelMessage: KernelMessage
): Unit = {
// look up callback in CallbackMap based on msg_id and invoke
val optionalDE = DeferredExecutionManager.get(parentHeaderId)
optionalDE match {
case Some(de) => parseAndHandle(kernelMessage.contentString,
ExecuteResult.executeResultReads,
(executeResult: ExecuteResult) => de.resolveResult(executeResult))
case None =>
logger.warn(s"No deferred execution found for id $parentHeaderId")
}
}
private def receiveCommOpen(
parentHeaderId:String, kernelMessage: KernelMessage
): Unit = {
parseAndHandle(
kernelMessage.contentString,
CommOpen.commOpenReads,
(commOpen: CommOpen) => {
val targetName = commOpen.target_name
val commId = commOpen.comm_id
val commWriter = new ClientCommWriter(
actorLoader, KMBuilder(), commId)
commStorage.getTargetCallbacks(targetName) match {
// Unexpected target (does not exist on this side), so send close
case None =>
commWriter.close()
// Found target, so execute callbacks
case Some(callbacks) =>
callbacks.executeOpenCallbacks(
commWriter, commId, targetName, commOpen.data)
}
}
)
}
private def receiveCommMsg(
parentHeaderId:String, kernelMessage: KernelMessage
): Unit = {
parseAndHandle(
kernelMessage.contentString,
CommMsg.commMsgReads,
(commMsg: CommMsg) => {
val commId = commMsg.comm_id
commStorage.getCommIdCallbacks(commId) match {
case None =>
case Some(callbacks) =>
val commWriter = new ClientCommWriter(
actorLoader, KMBuilder(), commId)
callbacks.executeMsgCallbacks(commWriter, commId, commMsg.data)
}
}
)
}
private def receiveCommClose(
parentHeaderId:String, kernelMessage: KernelMessage
): Unit = {
parseAndHandle(
kernelMessage.contentString,
CommClose.commCloseReads,
(commClose: CommClose) => {
val commId = commClose.comm_id
commStorage.getCommIdCallbacks(commId) match {
case None =>
case Some(callbacks) =>
val commWriter = new ClientCommWriter(
actorLoader, KMBuilder(), commId)
callbacks.executeCloseCallbacks(commWriter, commId, commClose.data)
}
}
)
}
override def receive: Receive = {
case message: ZMQMessage =>
// convert to KernelMessage using implicits in v5
logger.debug("Received IOPub kernel message.")
val kernelMessage: KernelMessage = message
// TODO: Validate incoming message signature
logger.trace(s"Kernel message is $kernelMessage")
val messageTypeString = kernelMessage.header.msg_type
val messageMap = getMessageMap(kernelMessage)
if (messageMap.contains(messageTypeString)) {
messageMap(messageTypeString)()
} else {
logger.warn(s"Received unhandled MessageType $messageTypeString")
}
}
}