/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.apollo.openwire
import OpenwireConstants._
import org.fusesource.hawtdispatch._
import org.fusesource.hawtbuf._
import collection.mutable.{ListBuffer, HashMap}
import java.io.IOException
import org.apache.activemq.apollo.selector.SelectorParser
import org.apache.activemq.apollo.filter.{BooleanExpression, FilterException}
import org.apache.activemq.apollo.broker.store._
import org.apache.activemq.apollo.util._
import java.util.concurrent.TimeUnit
import java.util.Map.Entry
import scala.util.continuations._
import org.fusesource.hawtdispatch.transport._
import codec.OpenWireFormat
import command._
import org.apache.activemq.apollo.openwire.dto.{OpenwireConnectionStatusDTO,OpenwireDTO}
import org.apache.activemq.apollo.dto.{AcceptingConnectorDTO, TopicDestinationDTO, DurableSubscriptionDestinationDTO, DestinationDTO}
import org.apache.activemq.apollo.openwire.DestinationConverter._
import org.apache.activemq.apollo.broker._
import protocol._
import security.SecurityContext
object OpenwireProtocolHandler extends Log {
def unit:Unit = {}
val DEFAULT_DIE_DELAY = 5 * 1000L
var die_delay = DEFAULT_DIE_DELAY
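// The wire format settings we advertise to the client during the initial
// WireFormatInfo negotiation.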
val preferred_wireformat_settings = new WireFormatInfo();
preferred_wireformat_settings.setVersion(OpenWireFormat.DEFAULT_VERSION);
preferred_wireformat_settings.setStackTraceEnabled(true);
preferred_wireformat_settings.setCacheEnabled(true);
preferred_wireformat_settings.setTcpNoDelayEnabled(true);
preferred_wireformat_settings.setTightEncodingEnabled(true);
preferred_wireformat_settings.setSizePrefixDisabled(false);
preferred_wireformat_settings.setMaxInactivityDuration(30 * 1000 * 1000);
preferred_wireformat_settings.setMaxInactivityDurationInitalDelay(10 * 1000 * 1000);
preferred_wireformat_settings.setCacheSize(1024);
preferred_wireformat_settings.setMaxFrameSize(OpenWireFormat.DEFAULT_MAX_FRAME_SIZE);
}
/**
 * A ProtocolHandler implementation for the OpenWire protocol used by
 * ActiveMQ 5.x JMS clients.
 */
class OpenwireProtocolHandler extends ProtocolHandler {
var minimum_protocol_version = 1
import OpenwireProtocolHandler._
def dispatchQueue: DispatchQueue = connection.dispatch_queue
def protocol = PROTOCOL
var sink_manager:SinkMux[Command] = null
var connection_session:Sink[Command] = null
var closed = false
var last_command_id=0
def next_command_id = {
last_command_id += 1
last_command_id
}
def broker = connection.connector.broker
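// Producer routes are cached per destination list; when an entry is evicted
// from the LRU cache its route is disconnected from the router.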
var producerRoutes = new LRUCache[List[DestinationDTO], DeliveryProducerRoute](10) {
override def onCacheEviction(eldest: Entry[List[DestinationDTO], DeliveryProducerRoute]) = {
host.router.disconnect(eldest.getKey.toArray, eldest.getValue)
}
}
var host: VirtualHost = null
private def queue = connection.dispatch_queue
var session_id: AsciiBuffer = _
var wire_format: OpenWireFormat = _
var login: Option[AsciiBuffer] = None
var passcode: Option[AsciiBuffer] = None
var dead = false
val security_context = new SecurityContext
var config:OpenwireDTO = _
var heart_beat_monitor = new HeartBeatMonitor
var waiting_on: String = "client request"
var current_command: Object = _
var codec:OpenwireCodec = _
override def create_connection_status = {
var rc = new OpenwireConnectionStatusDTO
rc.protocol_version = ""+(if (wire_format == null) 0 else wire_format.getVersion)
rc.user = login.map(_.toString).getOrElse(null)
rc.subscription_count = all_consumers.size
rc.waiting_on = waiting_on
rc
}
override def set_connection(connection: BrokerConnection) = {
super.set_connection(connection)
import collection.JavaConversions._
codec = connection.transport.getProtocolCodec.asInstanceOf[OpenwireCodec]
var connector_config = connection.connector.config.asInstanceOf[AcceptingConnectorDTO]
config = connector_config.protocols.find( _.isInstanceOf[OpenwireDTO]).map(_.asInstanceOf[OpenwireDTO]).getOrElse(new OpenwireDTO)
// protocol_filters = ProtocolFilter.create_filters(config.protocol_filters.toList, this)
//
// config.max_data_length.foreach( codec.max_data_length = _ )
// config.max_header_length.foreach( codec.max_header_length = _ )
// config.max_headers.foreach( codec.max_headers = _ )
if( config.destination_separator!=null ||
config.path_separator!= null ||
config.any_child_wildcard != null ||
config.any_descendant_wildcard!= null ) {
// destination_parser = new DestinationParser().copy(Stomp.destination_parser)
// if( config.destination_separator!=null ) { destination_parser.destination_separator = config.destination_separator }
// if( config.path_separator!=null ) { destination_parser.path_separator = config.path_separator }
// if( config.any_child_wildcard!=null ) { destination_parser.any_child_wildcard = config.any_child_wildcard }
// if( config.any_descendant_wildcard!=null ) { destination_parser.any_descendant_wildcard = config.any_descendant_wildcard }
}
}
def suspend_read(reason: String) = {
waiting_on = reason
connection.transport.suspendRead
heart_beat_monitor.suspendRead
}
def resume_read() = {
waiting_on = "client request"
connection.transport.resumeRead
heart_beat_monitor.resumeRead
}
def ack(command: Command):Unit = {
if (command.isResponseRequired()) {
val rc = new Response();
rc.setCorrelationId(command.getCommandId());
connection_session.offer(rc);
}
}
override def on_transport_failure(error: IOException) = {
if (!connection.stopped) {
error.printStackTrace
suspend_read("shutdown")
debug(error, "Shutting connection down due to: %s", error)
connection.stop
}
}
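// Called once the transport is connected: record the peer addresses in the
// security context, set up the outbound command sink (stamping command ids),
// advertise our preferred wire format settings and asynchronously look up
// the default virtual host.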
override def on_transport_connected():Unit = {
security_context.connection_id = Some(connection.id)
security_context.local_address = connection.transport.getLocalAddress
security_context.remote_address = connection.transport.getRemoteAddress
sink_manager = new SinkMux[Command]( connection.transport_sink.map {x=>
x.setCommandId(next_command_id)
debug("sending openwire command: %s", x.toString())
x
})
connection_session = new OverflowSink(sink_manager.open());
// Send our preferred wire format settings..
connection.transport.offer(preferred_wireformat_settings)
resume_read
reset {
suspend_read("virtual host lookup")
this.host = broker.get_default_virtual_host
resume_read
if(host==null) {
async_die("Could not find default virtual host")
}
}
}
override def on_transport_disconnected():Unit = {
if (!closed) {
closed = true;
dead = true;
heart_beat_monitor.stop
import collection.JavaConversions._
producerRoutes.foreach{
case (dests, route) => host.router.disconnect(dests.toArray, route)
}
producerRoutes.clear
// consumers.foreach{
// case (_, consumer) =>
// if (consumer.binding == null) {
// host.router.unbind(consumer.destination, consumer)
// } else {
// host.router.get_queue(consumer.binding) {
// queue =>
// queue.foreach(_.unbind(consumer :: Nil))
// }
// }
// }
// consumers = Map()
trace("openwire protocol resources released")
}
}
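// Entry point for every command received from the client. Until the wire
// format has been negotiated only a WireFormatInfo is accepted; afterwards
// commands are dispatched to the on_* handlers below.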
override def on_transport_command(command: Object):Unit = {
if( dead ) {
// We stop processing client commands once we are dead
return;
}
try {
current_command = command
trace("received: %s", command)
if (wire_format == null) {
command match {
case codec: OpenwireCodec =>
// this is passed on to us by the protocol discriminator
// so we know which wire format is being used.
case command: WireFormatInfo =>
on_wire_format_info(command)
case _ =>
die("Unexpected command: " + command.getClass);
}
} else {
command match {
case msg:ActiveMQMessage=> on_message(msg)
case ack:MessageAck=> on_message_ack(ack)
case info:TransactionInfo => on_transaction_info(info)
case info:ProducerInfo=> on_producer_info(info)
case info:ConsumerInfo=> on_consumer_info(info)
case info:SessionInfo=> on_session_info(info)
case info:ConnectionInfo=> on_connection_info(info)
case info:RemoveInfo=> on_remove_info(info)
case info:KeepAliveInfo=> ack(info)
case info:ShutdownInfo=> ack(info); connection.stop
case info:FlushCommand=> ack(info)
case info:DestinationInfo=> on_destination_info(info)
// case info:ConnectionControl=>
// case info:ConnectionError=>
// case info:ConsumerControl=>
// case info:RemoveSubscriptionInfo=>
// case info:ControlCommand=>
///////////////////////////////////////////////////////////////////
// Methods for cluster operations
// These commands are sent to the broker when it's acting like a
// client to another broker.
///////////////////////////////////////////////////////////////////
// case info:BrokerInfo=>
// case info:MessageDispatch=>
// case info:MessageDispatchNotification=>
// case info:ProducerAck=>
case _ =>
die("Unspported command: " + command.getClass);
}
}
} catch {
case e: Break =>
case e: Exception =>
e.printStackTrace
async_die("Internal Server Error")
} finally {
current_command = null
}
}
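// Error handling: fail() reports a protocol error to the client (as an
// ExceptionResponse when the offending command expects a response, otherwise
// as a ConnectionError) and aborts processing by throwing Break. die()
// additionally marks the handler dead and schedules the connection to be
// stopped after die_delay.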
class ProtocolException(msg:String) extends RuntimeException(msg)
class Break extends RuntimeException
def async_fail(msg: String, actual:Command=null):Unit = try {
fail(msg, actual)
} catch {
case x:Break=>
}
def fail[T](msg: String, actual:Command=null):T = {
def respond(command:Command) = {
if(command.isResponseRequired()) {
val e = new ProtocolException(msg)
e.fillInStackTrace
val rc = new ExceptionResponse()
rc.setCorrelationId(command.getCommandId())
rc.setException(e)
connection_session.offer(rc)
} else {
connection_error()
}
}
def connection_error() = {
val e = new ProtocolException(msg)
e.fillInStackTrace()
val err = new ConnectionError()
err.setException(e)
connection_session.offer(err)
}
(current_command,actual) match {
case (null, null)=>
connection_error()
case (null, command:Command)=>
respond(command)
case (command:Command, null)=>
connection_error()
case (command:Command, command2:Command)=>
respond(command)
}
throw new Break()
}
def async_die(msg: String, actual:Command=null):Unit = try {
die(msg, actual)
} catch {
case x:Break=>
}
/**
* A protocol error that cannot be recovered from. It results in the connection being terminated.
*/
def die[T](msg: String, actual:Command=null):T = {
if (!dead) {
dead = true
debug("Shutting connection down due to: " + msg)
// TODO: if there are too many open connections we should just close the connection
// without waiting for the error to get sent to the client.
queue.after(die_delay, TimeUnit.MILLISECONDS) {
connection.stop()
}
fail(msg, actual)
}
throw new Break()
}
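// Wire format negotiation: validate the client's WireFormatInfo, negotiate
// the format against our preferred settings, apply the TCP_NODELAY option,
// start the inactivity (heart beat) monitor and send back this broker's
// BrokerInfo.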
def on_wire_format_info(info: WireFormatInfo) = {
if (!info.isValid()) {
die("Remote wire format magic is invalid")
} else if (info.getVersion() < minimum_protocol_version) {
die("Remote wire format (%s) is lower the minimum version required (%s)".format(info.getVersion(), minimum_protocol_version))
}
wire_format = connection.transport.getProtocolCodec.asInstanceOf[OpenwireCodec].format
wire_format.renegotiateWireFormat(info, preferred_wireformat_settings)
connection.transport match {
case x: TcpTransport =>
x.getSocketChannel.socket.setTcpNoDelay(wire_format.isTcpNoDelayEnabled())
case _ =>
}
val inactive_time = preferred_wireformat_settings.getMaxInactivityDuration().min(info.getMaxInactivityDuration())
val initial_delay = preferred_wireformat_settings.getMaxInactivityDurationInitalDelay().min(info.getMaxInactivityDurationInitalDelay())
if (inactive_time > 0) {
heart_beat_monitor.setReadInterval((inactive_time.min(5000)*1.5).toLong)
heart_beat_monitor.setOnDead(^{
async_die("Stale connection. Missed heartbeat.")
})
heart_beat_monitor.setWriteInterval(inactive_time / 2)
heart_beat_monitor.setOnKeepAlive(^{
// we don't care if the offer gets rejected.. since that just
// means there is other traffic getting transmitted.
connection.transport.offer(new KeepAliveInfo)
})
}
heart_beat_monitor.setInitialReadCheckDelay(initial_delay)
heart_beat_monitor.setInitialWriteCheckDelay(initial_delay)
heart_beat_monitor.suspendRead()
heart_beat_monitor.setTransport(connection.transport)
heart_beat_monitor.start
// Give the client some info about this broker.
val brokerInfo = new BrokerInfo();
brokerInfo.setBrokerId(new BrokerId(host.config.id));
brokerInfo.setBrokerName(host.config.id);
brokerInfo.setBrokerURL(host.broker.get_connect_address);
connection_session.offer(brokerInfo);
}
///////////////////////////////////////////////////////////////////
// Connection / Session / Consumer / Producer state tracking.
///////////////////////////////////////////////////////////////////
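// Registers the client connection. If the virtual host has an authenticator
// and authorizer configured, the connect is authenticated and authorized
// against both the connector and the virtual host before it is acked.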
def on_connection_info(info: ConnectionInfo) = {
val id = info.getConnectionId()
if (!all_connections.contains(id)) {
new ConnectionContext(info).attach
security_context.user = info.getUserName
security_context.password = info.getPassword
reset {
if( host.authenticator!=null && host.authorizer!=null ) {
suspend_read("authenticating and authorizing connect")
val auth_failure = host.authenticator.authenticate(security_context)
if( auth_failure!=null ) {
async_die(auth_failure+". Credentials="+security_context.credential_dump)
noop // to make the cps compiler plugin happy.
} else if( !host.authorizer.can(security_context, "connect", connection.connector) ) {
async_die("Not authorized to connect to connector '%s'. Principals=".format(connection.connector.id, security_context.principal_dump))
noop // to make the cps compiler plugin happy.
} else if( !host.authorizer.can(security_context, "connect", this.host) ) {
async_die("Not authorized to connect to virtual host '%s'. Principals=".format(this.host.id, security_context.principal_dump))
noop // to make the cps compiler plugin happy.
} else {
resume_read
ack(info);
noop
}
} else {
ack(info);
noop
}
}
} else {
ack(info);
}
}
def on_session_info(info: SessionInfo) = {
val id = info.getSessionId();
if (!all_sessions.contains(id)) {
val parent = all_connections.get(id.getParentId()).getOrElse(die("Cannot add a session to a connection that has not been registered."))
new SessionContext(parent, info).attach
}
ack(info);
}
def on_producer_info(info: ProducerInfo) = {
val id = info.getProducerId
if (!all_producers.contains(id)) {
val parent = all_sessions.get(id.getParentId()).getOrElse(die("Cannot add a producer to a session that has not been registered."))
new ProducerContext(parent, info).attach
}
ack(info);
}
def on_consumer_info(info: ConsumerInfo) = {
val id = info.getConsumerId
if (!all_consumers.contains(id)) {
val parent = all_sessions.get(id.getParentId()).getOrElse(die("Cannot add a consumer to a session that has not been registered."))
new ConsumerContext(parent, info).attach
} else {
ack(info);
}
}
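// Handles client requests to add or remove a destination by creating or
// deleting it through the router.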
def on_destination_info(info:DestinationInfo) = {
val destinations = to_destination_dto(info.getDestination)
// if( info.getDestination.isTemporary ) {
// destinations.foreach(_.temp_owner = connection.id)
// }
reset{
val rc = info.getOperationType match {
case DestinationInfo.ADD_OPERATION_TYPE=>
host.router.create(destinations, security_context)
case DestinationInfo.REMOVE_OPERATION_TYPE=>
host.router.delete(destinations, security_context)
}
rc match {
case None =>
ack(info)
case Some(error)=>
ack(info)
}
}
}
def on_remove_info(info: RemoveInfo) = {
info.getObjectId match {
case id: ConnectionId => all_connections.get(id).foreach(_.dettach)
case id: SessionId => all_sessions.get(id).foreach(_.dettach)
case id: ProducerId => all_producers.get(id).foreach(_.dettach)
case id: ConsumerId => all_consumers.get(id).foreach(_.dettach )
// case id: DestinationInfo =>
case _ => die("Invalid object id.")
}
ack(info)
}
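// Transaction commands: BEGIN, COMMIT and ROLLBACK are handled locally; the
// XA verbs (END, PREPARE, FORGET) are acked without XA semantics and RECOVER
// reports no recovered transactions.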
def on_transaction_info(info:TransactionInfo) = {
val parent = all_connections.get(info.getConnectionId()).getOrElse(die("Cannot start a transaction on a connection that has not been registered."))
val id = info.getTransactionId
info.getType match {
case TransactionInfo.BEGIN =>
get_or_create_tx_ctx(parent, id)
ack(info)
case TransactionInfo.COMMIT_ONE_PHASE =>
get_tx_ctx(id).commit {
ack(info)
}
case TransactionInfo.ROLLBACK =>
get_tx_ctx(id).rollback
ack(info)
case TransactionInfo.FORGET =>
//die("XA not yet supported")
// get_tx_ctx(id).forget
ack(info)
case TransactionInfo.END =>
//die("XA not yet supported")
// get_tx_ctx(id).end
ack(info)
case TransactionInfo.PREPARE =>
// die("XA not yet supported")
// get_tx_ctx(id).prepare
ack(info)
case TransactionInfo.COMMIT_TWO_PHASE =>
// die("XA not yet supported")
get_tx_ctx(id).commit {
ack(info)
}
case TransactionInfo.RECOVER =>
// die("XA not yet supported")
val receipt = new DataArrayResponse
var data = Array[DataStructure]()
receipt.setData(data)
receipt.setCorrelationId(info.getCommandId)
connection_session.offer(receipt);
case _ =>
fail("Transaction info type unknown: " + info.getType)
}
}
///////////////////////////////////////////////////////////////////
// Core message processing
///////////////////////////////////////////////////////////////////
def on_message(msg: ActiveMQMessage) = {
val producer = all_producers.get(msg.getProducerId).getOrElse(die("Producer associated with the message has not been registered."))
if (msg.getOriginalDestination() == null) {
msg.setOriginalDestination(msg.getDestination());
}
if( msg.getTransactionId==null ) {
perform_send(msg)
} else {
get_or_create_tx_ctx(producer.parent.parent, msg.getTransactionId) { (uow)=>
perform_send(msg, uow)
}
}
}
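// Sends a message to the router, lazily creating a DeliveryProducerRoute for
// the destination and caching it in producerRoutes so later sends to the
// same destination reuse it.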
def perform_send(msg:ActiveMQMessage, uow:StoreUOW=null): Unit = {
val destination = to_destination_dto(msg.getDestination)
val key = destination.toList
producerRoutes.get(key) match {
case null =>
// create the producer route...
val route = new DeliveryProducerRoute(host.router) {
override def connection = Some(OpenwireProtocolHandler.this.connection)
override def dispatch_queue = queue
refiller = ^ {
resume_read
}
}
// don't process frames until producer is connected...
connection.transport.suspendRead
reset {
val rc = host.router.connect(destination, route, security_context)
rc match {
case Some(failure) =>
async_die(failure, msg)
case None =>
if (!connection.stopped) {
resume_read
producerRoutes.put(key, route)
send_via_route(route, msg, uow)
}
}
}
case route =>
// we can re-use the existing producer route
send_via_route(route, msg, uow)
}
}
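// Wraps the OpenWire message in a Delivery and offers it to the route. When
// the client requires a response, the delivery's ack callback sends it once
// the delivery is settled; if the route has no targets the message is
// dropped and acked immediately.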
def send_via_route(route:DeliveryProducerRoute, message:ActiveMQMessage, uow:StoreUOW) = {
if( !route.targets.isEmpty ) {
// We may need to add some headers..
val delivery = new Delivery
delivery.message = new OpenwireMessage(message)
delivery.size = message.getSize
delivery.uow = uow
if( message.isResponseRequired ) {
delivery.ack = { (consumed, uow) =>
dispatchQueue <<| ^{
ack(message)
}
}
}
// routes can always accept at least 1 delivery...
assert( !route.full )
route.offer(delivery)
if( route.full ) {
// but once it gets full, suspend reads so that we don't receive
// more messages until it drains.
suspend_read("blocked destination: "+route.overflowSessions.mkString(", "))
}
} else {
// info("Dropping message. No consumers interested in message.")
ack(message)
}
// message.release
}
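// Client message acks: credit the consumer's flow-control window, then
// perform the ack directly or defer it into the enclosing transaction.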
def on_message_ack(info:MessageAck) = {
val consumer = all_consumers.get(info.getConsumerId).getOrElse(die("Cannot ack a message on a consumer that has not been registered."))
consumer.ack_handler.credit(info)
info.getTransactionId match {
case null =>
consumer.ack_handler.perform_ack(info)
case txid =>
get_or_create_tx_ctx(consumer.parent.parent, txid){ (uow)=>
consumer.ack_handler.perform_ack(info, uow)
}
}
ack(info)
}
// public Response processAddDestination(DestinationInfo info) throws Exception {
// ActiveMQDestination destination = info.getDestination();
// if (destination.isTemporary()) {
// // Keep track of it so that we can remove them when this connection
// // shuts down.
// temporaryDestinations.add(destination);
// }
// host.createQueue(destination);
// return ack(info);
// }
val all_connections = new HashMap[ConnectionId, ConnectionContext]();
val all_sessions = new HashMap[SessionId, SessionContext]();
val all_producers = new HashMap[ProducerId, ProducerContext]();
val all_consumers = new HashMap[ConsumerId, ConsumerContext]();
val all_transactions = new HashMap[TransactionId, TransactionContext]();
val all_temp_dests = List[ActiveMQDestination]();
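// The context classes below mirror the OpenWire object model: a connection
// owns sessions and transactions, and a session owns producers and
// consumers. attach/dettach register and unregister each context in the
// per-handler maps above.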
class ConnectionContext(val info: ConnectionInfo) {
val sessions = new HashMap[SessionId, SessionContext]();
val transactions = new HashMap[TransactionId, TransactionContext]();
def default_session_id = new SessionId(info.getConnectionId(), -1)
def attach = {
// create the default session.
new SessionContext(this, new SessionInfo(default_session_id)).attach
all_connections.put(info.getConnectionId, this)
}
def dettach = {
sessions.values.toArray.foreach(_.dettach)
transactions.values.toArray.foreach(_.dettach)
all_connections.remove(info.getConnectionId)
}
}
class SessionContext(val parent: ConnectionContext, val info: SessionInfo) {
val producers = new HashMap[ProducerId, ProducerContext]();
val consumers = new HashMap[ConsumerId, ConsumerContext]();
def attach = {
parent.sessions.put(info.getSessionId, this)
all_sessions.put(info.getSessionId, this)
}
def dettach = {
producers.values.toArray.foreach(_.dettach)
consumers.values.toArray.foreach(_.dettach)
parent.sessions.remove(info.getSessionId)
all_sessions.remove(info.getSessionId)
}
}
def noop = shift { k: (Unit=>Unit) => k() }
class ProducerContext(val parent: SessionContext, val info: ProducerInfo) {
def attach = {
parent.producers.put(info.getProducerId, this)
all_producers.put(info.getProducerId, this)
}
def dettach = {
parent.producers.remove(info.getProducerId)
all_producers.remove(info.getProducerId)
}
}
class ConsumerContext(val parent: SessionContext, val info: ConsumerInfo) extends BaseRetained with DeliveryConsumer {
// The following comes in handy if we need to debug the
// reference counts of the consumers.
// val r = new BaseRetained
//
// def setDisposer(p1: Runnable): Unit = r.setDisposer(p1)
// def retained: Int =r.retained
//
// def printST(name:String) = {
// val e = new Exception
// println(name+": "+connection.map(_.id))
// println(" "+e.getStackTrace.drop(1).take(4).mkString("\n "))
// }
//
// def retain: Unit = {
// printST("retain")
// r.retain
// }
// def release: Unit = {
// printST("release")
// r.release
// }
override def toString = "openwire consumer id:"+info.getConsumerId+", remote address: "+security_context.remote_address
var selector_expression:BooleanExpression = _
var destination:Array[DestinationDTO] = _
val consumer_sink = sink_manager.open()
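// Outbound deliveries are converted to MessageDispatch commands and flow
// through a credit window sized by the consumer's prefetch; credits are
// replenished as the client acks messages (see ack_source / ack_handler).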
val credit_window_filter = new CreditWindowFilter[Delivery](consumer_sink.map { delivery =>
val dispatch = new MessageDispatch
dispatch.setConsumerId(info.getConsumerId)
if( delivery.message eq EndOfBrowseMessage ) {
// Then send the end of browse message.
dispatch
} else {
var msg = delivery.message.asInstanceOf[OpenwireMessage].message
ack_handler.track(msg.getMessageId, delivery.ack)
dispatch.setDestination(msg.getDestination)
dispatch.setMessage(msg)
}
dispatch
}, Delivery)
credit_window_filter.credit(0, info.getPrefetchSize)
val session_manager = new SessionSinkMux[Delivery](credit_window_filter, dispatchQueue, Delivery) {
override def time_stamp = broker.now
}
override def dispose() = dispatchQueue {
ack_handler.close
super.dispose()
sink_manager.close(consumer_sink,(frame)=>{
// No point in sending the frame down to the socket..
})
}
override def exclusive = info.isExclusive
override def browser = info.isBrowser
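// Binds this consumer to the router. Parses the JMS selector (if any), and
// for durable subscriptions rewrites the destination into a
// DurableSubscriptionDestinationDTO keyed by "clientId:subscriptionName".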
def attach = {
if( info.getDestination == null ) fail("destination was not set")
destination = to_destination_dto(info.getDestination)
// if they are temp dests.. attach our owner id so that we don't
// get rejected.
// if( info.getDestination.isTemporary ) {
// destination.foreach(_.temp_owner = connection.get.id)
// }
parent.consumers.put(info.getConsumerId, this)
all_consumers.put(info.getConsumerId, this)
var is_durable_sub = info.getSubscriptionName!=null
selector_expression = info.getSelector match {
case null=> null
case x=>
try {
SelectorParser.parse(x)
} catch {
case e:FilterException =>
fail("Invalid selector expression: "+e.getMessage)
}
}
if( is_durable_sub ) {
var subscription_id = ""
if( parent.parent.info.getClientId != null ) {
subscription_id += parent.parent.info.getClientId + ":"
}
subscription_id += info.getSubscriptionName
val rc = new DurableSubscriptionDestinationDTO(subscription_id)
rc.selector = info.getSelector
destination.foreach { _ match {
case x:TopicDestinationDTO=>
rc.topics.add(new TopicDestinationDTO(x.path))
case _ => die("A durable subscription can only be used on a topic destination")
}
}
destination = Array(rc)
}
reset {
val rc = host.router.bind(destination, this, security_context)
rc match {
case None =>
ack(info)
noop
case Some(reason) =>
async_fail(reason, info)
noop
}
}
this.release
}
def dettach = {
host.router.unbind(destination, this, false , security_context)
parent.consumers.remove(info.getConsumerId)
all_consumers.remove(info.getConsumerId)
}
///////////////////////////////////////////////////////////////////
// DeliveryConsumer impl
///////////////////////////////////////////////////////////////////
def dispatch_queue = OpenwireProtocolHandler.this.dispatchQueue
override def connection = Some(OpenwireProtocolHandler.this.connection)
def is_persistent = false
override def receive_buffer_size = codec.write_buffer_size
def matches(delivery:Delivery) = {
if( delivery.message.protocol eq OpenwireProtocol ) {
if( selector_expression!=null ) {
selector_expression.matches(delivery.message)
} else {
true
}
} else {
false
}
}
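// A per-producer delivery session. Closing a browsing consumer's session
// first delivers an end-of-browse marker; disposing the session naks any
// pending deliveries and deletes temporary destinations.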
class OpenwireConsumerSession(val producer:DeliveryProducer) extends DeliverySession with SessionSinkFilter[Delivery] {
producer.dispatch_queue.assertExecuting()
retain
val downstream = session_manager.open(producer.dispatch_queue, receive_buffer_size)
var closed = false
def consumer = ConsumerContext.this
def close = {
assert(producer.dispatch_queue.isExecuting)
if( !closed ) {
closed = true
if( browser ) {
val delivery = new Delivery()
delivery.message = EndOfBrowseMessage
if( downstream.full ) {
// session is full, so use an overflow sink to hold the message,
// and then trigger closing the session once it empties out.
val sink = new OverflowSink(downstream)
sink.refiller = ^{
dispose
}
sink.offer(delivery)
} else {
downstream.offer(delivery)
dispose
}
} else {
dispose
}
}
}
def dispose = {
session_manager.close(downstream,(delivery)=>{
// We have been closed so we have to nak any deliveries.
if( delivery.ack!=null ) {
delivery.ack(Undelivered, delivery.uow)
}
})
if( info.getDestination.isTemporary ) {
reset {
val rc = host.router.delete(destination, security_context)
rc match {
case Some(error) =>
async_die(error)
case None =>
unit
}
}
}
release
}
// Delegate all the flow control stuff to the session
def offer(delivery:Delivery) = {
if( full ) {
false
} else {
delivery.message.retain()
val rc = downstream.offer(delivery)
assert(rc, "offer should be accepted since it was not full")
true
}
}
}
def connect(p:DeliveryProducer) = new OpenwireConsumerSession(p)
class TrackedAck(val ack:(DeliveryResult, StoreUOW)=>Unit) {
var credited = false
}
val ack_source = createSource(EventAggregators.INTEGER_ADD, dispatch_queue)
ack_source.setEventHandler(^ {
val data = ack_source.getData
credit_window_filter.credit(0, data)
});
ack_source.resume
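// Tracks the broker-side ack callback for each dispatched message so that
// client MessageAck commands can be mapped back to delivery results, for
// both individual and standard (range) ack types.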
object ack_handler {
// TODO: Need to validate all the range ack cases...
var consumer_acks = ListBuffer[(MessageId,TrackedAck)]()
def close = {
queue.assertExecuting()
consumer_acks.foreach { case(_, tack) =>
if( tack.ack !=null ) {
tack.ack(Delivered, null)
}
}
consumer_acks = null
}
def track(msgid:MessageId, ack:(DeliveryResult, StoreUOW)=>Unit) = {
queue.assertExecuting()
if( consumer_acks==null ) {
// It can happen if we get closed.. but destination is still sending data..
if( ack!=null ) {
ack(Undelivered, null)
}
} else {
consumer_acks += msgid -> new TrackedAck(ack)
}
}
def credit(messageAck: MessageAck):Unit = {
queue.assertExecuting()
val msgid: MessageId = messageAck.getLastMessageId
if( messageAck.getAckType == MessageAck.INDIVIDUAL_ACK_TYPE) {
for( (id, delivery) <- consumer_acks.find(_._1 == msgid) ) {
if ( !delivery.credited ) {
ack_source.merge(1)
delivery.credited = true;
}
}
} else {
var found = false
val (acked, not_acked) = consumer_acks.partition{ case (id, ack)=>
if( id == msgid ) {
found = true
true
} else {
!found
}
}
for( (id, delivery) <- acked ) {
// only credit once...
if( !delivery.credited ) {
ack_source.merge(1)
delivery.credited = true;
}
}
}
}
def perform_ack(messageAck: MessageAck, uow:StoreUOW=null) = {
queue.assertExecuting()
val msgid = messageAck.getLastMessageId
val consumed = messageAck.getAckType match {
case MessageAck.DELIVERED_ACK_TYPE => Consumed
case MessageAck.INDIVIDUAL_ACK_TYPE => Consumed
case MessageAck.STANDARD_ACK_TYPE => Consumed
case MessageAck.POSION_ACK_TYPE => Poisoned
case MessageAck.REDELIVERED_ACK_TYPE => Delivered
case MessageAck.UNMATCHED_ACK_TYPE => Consumed
}
if( messageAck.getAckType == MessageAck.INDIVIDUAL_ACK_TYPE) {
consumer_acks = consumer_acks.filterNot{ case (id, delivery)=>
if( id == msgid) {
if( delivery.ack!=null ) {
delivery.ack(consumed, uow)
}
true
} else {
false
}
}
} else {
// session acks ack all previously received messages..
var found = false
val (acked, not_acked) = consumer_acks.partition{ case (id, ack)=>
if( id == msgid ) {
found = true
true
} else {
!found
}
}
if( !found ) {
trace("%s: ACK failed, invalid message id: %s, dest: %s".format(security_context.remote_address, msgid, destination.mkString(",")))
} else {
consumer_acks = not_acked
acked.foreach{case (id, delivery)=>
if( delivery.ack!=null ) {
delivery.ack(consumed, uow)
}
}
}
}
}
//
// def apply(messageAck: MessageAck, uow:StoreUOW=null) = {
//
// var found = false
// val (acked, not_acked) = consumer_acks.partition{ case (id, _)=>
// if( found ) {
// false
// } else {
// if( id == messageAck.getLastMessageId ) {
// found = true
// }
// true
// }
// }
//
// if( acked.isEmpty ) {
// async_fail("ACK failed, invalid message id: %s".format(messageAck.getLastMessageId), messageAck)
// } else {
// consumer_acks = not_acked
// acked.foreach{case (_, callback)=>
// if( callback!=null ) {
// callback(Delivered, uow)
// }
// }
// }
// }
}
}
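// Local (non-XA) transaction support: work enlisted in a transaction is
// recorded as actions that are replayed against a single StoreUOW on
// commit, or simply discarded on rollback.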
class TransactionContext(val parent: ConnectionContext, val id: TransactionId) {
// TODO: eventually we want to back this with a broker Queue which
// can provide persistence and memory swapping.
// Buffer xid = null;
// if (tid.isXATransaction()) {
// xid = XidImpl.toBuffer((Xid) tid);
// }
// t = host.getTransactionManager().createTransaction(xid);
// transactions.put(tid, t);
val actions = ListBuffer[(StoreUOW)=>Unit]()
def attach = {
parent.transactions.put(id, this)
all_transactions.put(id, this)
}
def dettach = {
actions.clear
parent.transactions.remove(id)
all_transactions.remove(id)
}
def apply(proc:(StoreUOW)=>Unit) = {
actions += proc
}
def commit(onComplete: => Unit) = {
val uow = if( host.store!=null ) {
host.store.create_uow
} else {
null
}
actions.foreach { proc =>
proc(uow)
}
if( uow!=null ) {
uow.on_complete(onComplete)
uow.release
} else {
onComplete
}
}
def rollback() = {
actions.clear
}
}
def create_tx_ctx(connection:ConnectionContext, txid:TransactionId):TransactionContext= {
if ( all_transactions.contains(txid) ) {
die("transaction allready started")
} else {
val context = new TransactionContext(connection, txid)
context.attach
context
}
}
def get_or_create_tx_ctx(connection:ConnectionContext, txid:TransactionId):TransactionContext = {
all_transactions.get(txid) match {
case Some(ctx)=> ctx
case None=>
val context = new TransactionContext(connection, txid)
context.attach
context
}
}
def get_tx_ctx(txid:TransactionId):TransactionContext = {
all_transactions.get(txid) match {
case Some(ctx)=> ctx
case None=> die("transaction not active: %d".format(txid))
}
}
def remove_tx_ctx(txid:TransactionId):TransactionContext= {
all_transactions.get(txid) match {
case None=>
die("transaction not active: %d".format(txid))
case Some(tx)=>
tx.dettach
tx
}
}
}