blob: 492fffe27795041f8be243b295a43b2b3adf63b1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gearpump.experimental.rabbitmq
import java.nio.charset.StandardCharsets

import scala.util.control.NonFatal

import com.rabbitmq.client.Channel
import com.rabbitmq.client.{Connection, ConnectionFactory}

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.streaming.sink.DataSink
import org.apache.gearpump.streaming.task.TaskContext
import org.apache.gearpump.util.LogUtil
/**
 * A [[DataSink]] that publishes messages to a single RabbitMQ queue.
 *
 * The queue name is read from `RMQSink.QUEUE_NAME` in `userConfig`; the queue is declared in
 * `open` and every message is published to it through the default ("") exchange, using the
 * queue name as routing key. Supported payloads are `String` (sent as UTF-8 bytes),
 * `Array[Byte]`, and `Seq`s of those (each element published separately, recursively);
 * anything else is logged and dropped.
 *
 * @param userConfig  configuration holding the RabbitMQ connection and queue settings
 * @param connFactory builds, from `userConfig`, the [[ConnectionFactory]] used by `open`
 */
class RMQSink(userConfig: UserConfig,
    val connFactory: (UserConfig) => ConnectionFactory) extends DataSink {
  private val LOG = LogUtil.getLogger(getClass)

  /** Factory used by `open` to create the connection. */
  var connectionFactory: ConnectionFactory = connFactory(userConfig)

  /** Connection to the broker; null until `open` has created it. */
  var connection: Connection = null

  /** Channel used for publishing; null until `open` has created it. */
  var channel: Channel = null

  /** Name of the target queue; set by `setupQueue`. */
  var queueName: String = null

  /** Creates a sink whose connection factory is built by `RMQSink.getConnectionFactory`. */
  def this(userConfig: UserConfig) = {
    this(userConfig, RMQSink.getConnectionFactory)
  }

  /**
   * Opens a connection and a channel, then declares the target queue.
   *
   * If anything after the connection is established fails, the partially opened resources are
   * released before the original exception is rethrown, so a failed open does not leak the
   * connection.
   *
   * @throws RuntimeException if no channel is available or no queue name is configured
   */
  override def open(context: TaskContext): Unit = {
    connection = connectionFactory.newConnection()
    try {
      channel = connection.createChannel()
      // The RabbitMQ client returns null here when no channel number is available.
      if (channel == null) {
        throw new RuntimeException("None of RabbitMQ channels are available.")
      }
      setupQueue()
    } catch {
      case NonFatal(e) =>
        try {
          close()
        } catch {
          // Keep the root cause visible; attach the cleanup failure to it.
          case NonFatal(cleanupError) => e.addSuppressed(cleanupError)
        }
        throw e
    }
  }

  /** Publishes the message's payload; see [[publish]] for the supported payload types. */
  override def write(message: Message): Unit = {
    publish(message.msg)
  }

  /**
   * Closes the channel and then the connection.
   *
   * Null or already-closed resources are skipped, so this is safe to call when `open` was never
   * called, failed part-way, or `close` was already called. The connection is closed even if
   * closing the channel throws.
   */
  override def close(): Unit = {
    try {
      if (channel != null && channel.isOpen) {
        channel.close()
      }
    } finally {
      if (connection != null && connection.isOpen) {
        connection.close()
      }
    }
  }

  /**
   * Resolves the queue name from `userConfig`, stores it in `queueName` and declares the queue
   * (durable = false, exclusive = false, autoDelete = false, no extra arguments).
   *
   * @throws RuntimeException if `RMQSink.QUEUE_NAME` is not configured
   */
  protected def setupQueue(): Unit = {
    val queue = RMQSink.getQueueName(userConfig).getOrElse(
      throw new RuntimeException("can not get a RabbitMQ queue name"))
    queueName = queue
    channel.queueDeclare(queue, false, false, false, null)
  }

  /**
   * Publishes `msg` to `queueName` through the default exchange.
   *
   * `String`s are encoded as UTF-8, byte arrays are sent as-is, and each element of a `Seq` is
   * published recursively. Unsupported payloads are logged and skipped.
   */
  def publish(msg: Any): Unit = msg match {
    case seq: Seq[_] =>
      seq.foreach(publish)
    case text: String =>
      channel.basicPublish("", queueName, null, text.getBytes(StandardCharsets.UTF_8))
    case bytes: Array[Byte] =>
      channel.basicPublish("", queueName, null, bytes)
    case _ =>
      LOG.warn("matched unsupported message!")
  }
}
/** Configuration keys and factory helpers used by the [[RMQSink]] class. */
object RMQSink {
  val RMQSINK = "rmqsink"

  // Queue and connection-target keys.
  val QUEUE_NAME = "rabbitmq.queue.name"
  val SERVER_HOST = "rabbitmq.connection.host"
  val SERVER_PORT = "rabbitmq.connection.port"
  val CONNECTION_URI = "rabbitmq.connection.uri"
  val VIRTUAL_HOST = "rabbitmq.virtualhost"

  // Authentication keys.
  val AUTH_USERNAME = "rabbitmq.auth.username"
  val AUTH_PASSWORD = "rabbitmq.auth.password"

  // Optional connection-tuning keys, each applied only when present.
  val AUTOMATIC_RECOVERY = "rabbitmq.automatic.recovery"
  val CONNECTION_TIMEOUT = "rabbitmq.connection.timeout"
  val NETWORK_RECOVERY_INTERVAL = "rabbitmq.network.recovery.interval"
  val REQUESTED_HEARTBEAT = "rabbitmq.requested.heartbeat"
  val TOPOLOGY_RECOVERY_ENABLED = "rabbitmq.topology.recoveryenabled"
  val REQUESTED_CHANNEL_MAX = "rabbitmq.channel.max"
  val REQUESTED_FRAME_MAX = "rabbitmq.frame.max"

  /**
   * Builds a RabbitMQ [[ConnectionFactory]] from `userConfig`.
   *
   * When `CONNECTION_URI` is set the factory is initialised from it; otherwise both
   * `SERVER_HOST` and `SERVER_PORT` are required. Every other key is optional and is applied
   * afterwards, so for example an explicit `VIRTUAL_HOST` replaces one carried by the URI.
   *
   * @param userConfig configuration to read the RabbitMQ settings from
   * @return the configured factory (no connection is opened here)
   * @throws RuntimeException if no URI is configured and the host or port key is missing
   */
  def getConnectionFactory(userConfig : UserConfig): ConnectionFactory = {
    val factory = new ConnectionFactory

    userConfig.getString(CONNECTION_URI) match {
      case Some(connectionUri) =>
        factory.setUri(connectionUri)
      case None =>
        val hostOpt = userConfig.getString(SERVER_HOST)
        val portOpt = userConfig.getInt(SERVER_PORT)
        val host = hostOpt.getOrElse(throw new RuntimeException("missed config key : " + SERVER_HOST))
        val port = portOpt.getOrElse(throw new RuntimeException("missed config key : " + SERVER_PORT))
        factory.setHost(host)
        factory.setPort(port)
    }

    userConfig.getString(VIRTUAL_HOST).foreach(factory.setVirtualHost(_))
    userConfig.getString(AUTH_USERNAME).foreach(factory.setUsername(_))
    userConfig.getString(AUTH_PASSWORD).foreach(factory.setPassword(_))
    userConfig.getBoolean(AUTOMATIC_RECOVERY).foreach(factory.setAutomaticRecoveryEnabled(_))
    userConfig.getInt(CONNECTION_TIMEOUT).foreach(factory.setConnectionTimeout(_))
    userConfig.getInt(NETWORK_RECOVERY_INTERVAL).foreach(factory.setNetworkRecoveryInterval(_))
    userConfig.getInt(REQUESTED_HEARTBEAT).foreach(factory.setRequestedHeartbeat(_))
    userConfig.getBoolean(TOPOLOGY_RECOVERY_ENABLED).foreach(factory.setTopologyRecoveryEnabled(_))
    userConfig.getInt(REQUESTED_CHANNEL_MAX).foreach(factory.setRequestedChannelMax(_))
    userConfig.getInt(REQUESTED_FRAME_MAX).foreach(factory.setRequestedFrameMax(_))

    factory
  }

  /** Returns the queue name configured under `QUEUE_NAME`, or `None` when it is absent. */
  def getQueueName(userConfig: UserConfig): Option[String] =
    userConfig.getString(QUEUE_NAME)
}