# blob: 686f1173499e426005e408899aa31fd3db28a0be [file] [log] [blame]
#
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import argparse
import base64
import functools
import json
import sys
import uuid
import weakref
from threading import Thread, Lock
try:
from urllib.parse import urlencode
except ImportError:
from urllib import urlencode
import pika
import tornado.websocket
import tornado.ioloop
import tornado.auth
import tornado.escape
import tornado.concurrent
# Module-wide OAuth/cookie settings. Starts empty; the __main__ block fills in
# "oauth_client_key", "oauth_client_secret", "oauth_grant_type" and
# "maximum_cookie_age" from the JSON config file. Read by Wso2OAuth2Mixin and
# AuthHandler.
SETTINGS = {}
class Error(Exception):
    """Base error class for exceptions in this module."""


class ConsumerConfigError(Error):
    """Raised when an issue with consumer configuration occurs.

    Attributes:
    message -- human-readable description of the problem
    """

    def __init__(self, message):
        # Forward to Exception so str(), repr() and .args carry the message.
        super(ConsumerConfigError, self).__init__(message)
        self.message = message


class ConsumerKeyError(Error):
    """Raised when a consumer is looked up under a key that is not known.

    Attributes:
    message -- human-readable description of the problem
    key -- the consumer key that caused the error
    """

    def __init__(self, message, key):
        # Pass every constructor argument on so .args mirrors the signature.
        super(ConsumerKeyError, self).__init__(message, key)
        self.message = message
        self.key = key

    def __str__(self):
        return "%s (key=%r)" % (self.message, self.key)


class AuthError(Error):
    """Raised when something went wrong during authentication.

    Attributes:
    message -- the error supplied by the caller (may be an exception object)
    code -- the status code associated with the failure
    """

    def __init__(self, error, code):
        # Pass every constructor argument on so .args mirrors the signature.
        super(AuthError, self).__init__(error, code)
        self.message = error
        self.code = code

    def __str__(self):
        return "%s (code %s)" % (self.message, self.code)
class PikaAsyncConsumer(Thread):
    """
    Consume messages from a RabbitMQ queue and fan them out to clients.

    Each instance is a daemon thread: run() opens a pika SelectConnection and
    starts its ioloop. The body of every delivered message is passed to
    write_message() on each registered client, then the message is
    acknowledged.

    NOTE(review): the pika calls below pass the callback as the first
    positional argument and use stop_ioloop_on_close, which looks like the
    pika 0.x API -- confirm against the pinned pika version.
    """

    def __init__(self, rabbitmq_url, exchange_name, queue_name,
                 exchange_type="direct", routing_key="#"):
        """
        Create a new instance of PikaAsyncConsumer.

        Arguments:
        rabbitmq_url -- URL to RabbitMQ server
        exchange_name -- name of RabbitMQ exchange to join
        queue_name -- name of RabbitMQ queue to join

        Keyword Arguments:
        exchange_type -- one of 'direct', 'topic', 'fanout', 'headers'
                         (default 'direct')
        routing_key -- the routing key that this consumer listens for
                       (default '#')
        """
        print("Creating new consumer")
        super(PikaAsyncConsumer, self).__init__(daemon=True)
        self._connection = None
        self._channel = None
        # Set by stop(); tells the close callbacks not to reconnect.
        self._shut_down = False
        self._consumer_tag = None
        self._url = rabbitmq_url
        # Weak references to the registered clients, guarded by _lock.
        self._client_list = []
        self._lock = Lock()
        # The following are necessary to guarantee that both the RabbitMQ
        # server and this consumer know where to look for messages. These
        # names will be decided before dispatch and should be recorded in a
        # config file or else on a per-job basis.
        self._exchange = exchange_name
        self._exchange_type = exchange_type
        self._queue = queue_name
        self._routing_key = routing_key

    def add_client(self, client):
        """Add a new client to the recipient list.

        Arguments:
        client -- a reference to the client object to add
        """
        # Store a weakref so that this list alone never keeps a handler
        # object alive.
        with self._lock:
            self._client_list.append(weakref.ref(client))

    def remove_client(self, client):
        """Remove a client from the recipient list.

        Only the first entry referring to the client is removed; nothing
        happens when the client is not registered.

        Arguments:
        client -- a reference to the client object to remove
        """
        with self._lock:
            for index, ref in enumerate(self._client_list):
                # Calling the weakref yields the strong reference (or None).
                if ref() is client:
                    del self._client_list[index]
                    break

    def connect(self):
        """
        Create an asynchronous connection to the RabbitMQ server at URL.

        Returns the new pika.SelectConnection; its ioloop is not started here.
        """
        return pika.SelectConnection(pika.URLParameters(self._url),
                                     on_open_callback=self.on_connection_open,
                                     on_close_callback=self.on_connection_close,
                                     stop_ioloop_on_close=False)

    def on_connection_open(self, unused_connection):
        """
        Actions to perform when the connection opens. This may not happen
        immediately, so defer action to this callback.

        Arguments:
        unused_connection -- the created connection (by this point already
                             available as self._connection)
        """
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_connection_close(self, connection, code, text):
        """
        Actions to perform when the connection is closed.

        After stop() the ioloop is stopped; otherwise a reconnect is
        scheduled through add_timeout(5, ...).

        Arguments:
        connection -- the connection that was closed (same as self._connection)
        code -- response code from the RabbitMQ server
        text -- response body from the RabbitMQ server
        """
        self._channel = None
        if self._shut_down:
            self._connection.ioloop.stop()
        else:
            self._connection.add_timeout(5, self.reconnect)

    def reconnect(self):
        """
        Attempt to reestablish a connection with the RabbitMQ server.
        """
        # Stop the old ioloop before replacing the connection.
        self._connection.ioloop.stop()
        if not self._shut_down:
            self._connection = self.connect()
            self._connection.ioloop.start()

    def on_channel_open(self, channel):
        """
        Store the opened channel for future use and set up the exchange and
        queue to be used.

        Arguments:
        channel -- the Channel instance opened by the Channel.Open RPC
        """
        self._channel = channel
        self._channel.add_on_close_callback(self.on_channel_close)
        self.declare_exchange()

    def on_channel_close(self, channel, code, text):
        """
        Actions to perform when the channel is closed: close the connection.

        Arguments:
        channel -- the channel that was closed
        code -- response code from the RabbitMQ server
        text -- response body from the RabbitMQ server
        """
        self._connection.close()

    def declare_exchange(self):
        """
        Set up the exchange that will route messages to this consumer. Each
        RabbitMQ exchange is uniquely identified by its name, so it does not
        matter if the exchange has already been declared.
        """
        self._channel.exchange_declare(self.declare_exchange_success,
                                       self._exchange,
                                       self._exchange_type)

    def declare_exchange_success(self, unused_connection):
        """
        Actions to perform on successful exchange declaration.

        Arguments:
        unused_connection -- the value pika passes to the callback (ignored)
        """
        self.declare_queue()

    def declare_queue(self):
        """
        Set up the queue that will route messages to this consumer. Each
        RabbitMQ queue can be defined with routing keys to use only one
        queue for multiple jobs.
        """
        self._channel.queue_declare(self.declare_queue_success,
                                    self._queue)

    def declare_queue_success(self, method_frame):
        """
        Actions to perform on successful queue declaration: bind the queue to
        the exchange under this consumer's routing key.

        Arguments:
        method_frame -- the value pika passes to the callback (ignored)
        """
        self._channel.queue_bind(self.munch,
                                 self._queue,
                                 self._exchange,
                                 self._routing_key)

    def munch(self, unused):
        """
        Begin consuming messages once the queue has been bound.

        Arguments:
        unused -- the value pika passes to the callback (ignored)
        """
        self._channel.add_on_cancel_callback(self.cancel_channel)
        self._consumer_tag = self._channel.basic_consume(self._process_message)

    def cancel_channel(self, method_frame):
        """
        Close the channel when the consumer is cancelled.

        Arguments:
        method_frame -- the value pika passes to the callback (ignored)
        """
        if self._channel is not None:
            # Use the public close(); the private _close() is not part of the
            # channel API.
            self._channel.close()

    def _process_message(self, ch, method, properties, body):
        """
        Forward a delivered message to every live client, then acknowledge it.

        Clients whose weak reference has died are dropped from the recipient
        list instead of raising.

        Arguments:
        ch -- the channel that routed the message
        method -- delivery information
        properties -- message properties
        body -- the message
        """
        print("Received Message: %s" % body)
        # NOTE(review): write_message() is invoked from this consumer thread,
        # not from the thread running the Tornado IOLoop -- confirm that is
        # safe for the client type in use.
        with self._lock:
            live_refs = []
            for ref in self._client_list:
                client = ref()
                if client is None:
                    # Referent was garbage collected; forget it.
                    continue
                live_refs.append(ref)
                client.write_message(body)
            self._client_list = live_refs
        self._channel.basic_ack(delivery_tag=method.delivery_tag)

    def stop_consuming(self):
        """
        Stop the consumer if active.
        """
        if self._channel:
            self._channel.basic_cancel(self.close_channel, self._consumer_tag)

    def close_channel(self, unused):
        """
        Delete the queue and close the channel to shut down the consumer and
        connection.

        Arguments:
        unused -- the value pika passes to the callback (ignored)
        """
        self._channel.queue_delete(queue=self._queue)
        self._channel.close()

    def run(self):
        """
        Start a connection with the RabbitMQ server and run its ioloop.
        """
        self._connection = self.connect()
        self._connection.ioloop.start()

    def stop(self):
        """
        Stop an active connection with the RabbitMQ server.
        """
        # Must be the same flag that on_connection_close() and reconnect()
        # test, otherwise a deliberate stop triggers a reconnect.
        self._shut_down = True
        self.stop_consuming()
class Wso2OAuth2Mixin(tornado.auth.OAuth2Mixin):
    """
    OAuth2 mixin that exchanges a username/password pair for an access token
    by POSTing to the token endpoint configured below.

    NOTE(review): relies on the private tornado.auth._auth_return_future
    decorator and on callback-style http.fetch -- confirm both exist in the
    pinned Tornado version.
    """
    _OAUTH_AUTHORIZE_URL = "https://idp.scigap.org:9443/oauth2/authorize"
    _OAUTH_ACCESS_TOKEN_URL = "https://idp.scigap.org:9443/oauth2/token"

    @tornado.auth._auth_return_future
    def get_authenticated_user(self, username, password, callback=None):
        """
        Request an access token for the given credentials.

        Arguments:
        username -- the user's login name
        password -- the user's password

        Keyword Arguments:
        callback -- forwarded as the first argument of _on_access_token
                    (presumably replaced by a future by the decorator --
                    TODO confirm)
        """
        print("Authenticating user %s" % (username))
        client = self.get_auth_http_client()
        # Client credentials and grant type come from the module SETTINGS.
        form_fields = {
            "client_id": SETTINGS["oauth_client_key"],
            "client_secret": SETTINGS["oauth_client_secret"],
            "grant_type": SETTINGS["oauth_grant_type"],
            "username": username,
            "password": password
        }
        encoded_form = urlencode(form_fields)
        on_response = functools.partial(self._on_access_token, callback)
        client.fetch(self._OAUTH_ACCESS_TOKEN_URL,
                     on_response,
                     method="POST",
                     body=encoded_form)

    def _on_access_token(self, future, response):
        """
        Resolve the future from the token endpoint's response.

        Arguments:
        future -- receives the decoded JSON body on success, or an AuthError
                  built from response.error and response.code on failure
        response -- the HTTP response from the token endpoint
        """
        if not response.error:
            print(response.body)
            decoded = tornado.escape.json_decode(response.body)
            future.set_result(decoded)
            return
        # Failure path: dump the response for diagnosis, then fail the future.
        print(str(response))
        print(response.body)
        print(response.error)
        failure = AuthError(response.error, response.code)
        future.set_exception(failure)
class AuthHandler(tornado.web.RequestHandler, Wso2OAuth2Mixin):
    """
    HTTP endpoint for checking (GET) and establishing (POST) authentication.

    POST exchanges a username/password for an access token through
    Wso2OAuth2Mixin and stores it in the "ws-auth-token" secure cookie.
    """

    def get_current_user(self):
        """
        Return the "ws-auth-token" secure cookie value, or None.

        The "expires-in" secure cookie is read first (valid for
        SETTINGS['maximum_cookie_age'] days); its value is used as the
        max_age_days limit when reading the token cookie.
        """
        expires_in = self.get_secure_cookie("expires-in", max_age_days=SETTINGS['maximum_cookie_age'])
        print(expires_in)
        if expires_in:
            return self.get_secure_cookie("ws-auth-token", max_age_days=float(expires_in))
        return None

    def set_default_headers(self):
        """Send plain-text responses and permissive CORS headers by default."""
        self.set_header("Content-Type", "text/plain")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "x-requested-with")
        self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def get(self):
        """Report whether the request carries a valid auth cookie (200/403)."""
        if self.get_current_user():
            self.set_status(200)
            print("Authenticated")
            self.write("Authenticated")
        else:
            self.set_status(403)
            print("Not Authenticated")
            self.write("Not Authenticated")

    @tornado.gen.coroutine
    def post(self):
        """
        Authenticate the posted credentials and redirect on success.

        Body arguments:
        username -- the user's login name (must be non-empty)
        password -- the user's password (must be non-empty)
        redirect -- URL to redirect to after successful authentication

        Responds with 400 when an argument is missing or empty, with the
        AuthError's code when authentication fails, and with a redirect
        otherwise.
        """
        try:
            username = self.get_body_argument("username")
            password = self.get_body_argument("password")
            redirect = self.get_body_argument("redirect")
            if username == "" or password == "":
                # MissingArgumentError requires the argument name; raising
                # the bare class would fail with TypeError instead.
                raise tornado.web.MissingArgumentError(
                    "username" if username == "" else "password")
            access = yield self.get_authenticated_user(username, password)
        except tornado.web.MissingArgumentError:
            print("Missing an argument")
            self.set_status(400)
            self.write("Authentication information missing")
            # Stop here: `redirect` may be unbound and the error status must
            # not be replaced by a redirect.
            return
        except AuthError as e:
            print("The future freaks me out")
            # `access` is unbound when authentication raised; the details
            # live on the exception.
            self.set_status(e.code)
            self.set_header("Content-Type", "text/html")
            # e.message may be an exception object; write() needs text, and
            # the body is served as HTML so it is escaped.
            self.write(tornado.escape.xhtml_escape(str(e.message)))
            return

        days = (access["expires_in"] / 3600) / 24  # Convert to days
        print(days)
        self.set_secure_cookie("ws-auth-token",
                               access["access_token"],
                               expires_days=days)
        # NOTE(review): the stored value is the constant "1", so
        # get_current_user() always validates the token cookie with
        # max_age_days=1.0 regardless of `days` -- confirm whether str(days)
        # was intended.
        self.set_secure_cookie("expires-in",
                               str(1),
                               expires_days=SETTINGS['maximum_cookie_age'])
        self.write("Success")
        # NOTE(security): `redirect` is taken from the request body and used
        # unvalidated (open redirect). Restrict it once the set of allowed
        # target hosts is known.
        self.redirect(redirect)
class AMQPWSHandler(tornado.websocket.WebSocketHandler):#, Wso2OAuth2Mixin):
    """
    Pass messages to a connected WebSockets client.

    A subclass of the Tornado WebSocketHandler class, this class takes no
    action when receiving a message from the client. Instead, it is associated
    with an AMQP consumer and writes a message to the client each time one is
    consumed in the queue.
    """

    # def set_default_headers(self):
    #     self.set_header("Access-Control-Allow-Origin", "*")
    #     self.set_header("Access-Control-Allow-Headers", "x-requested-with")
    #     self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')

    def check_origin(self, origin):
        """Check the domain origin of the connection request.

        This can be made more robust to ensure that connections are only
        accepted from verified PGAs.

        NOTE(security): every origin is currently accepted.

        Arguments:
        origin -- the value of the Origin HTTP header
        """
        return True

    def open(self, resource_type, resource_id):
        """Associate a new connection with a consumer.

        When a new connection is opened, it is a request to retrieve data
        from an AMQP queue. The open operation should also do some kind of
        authentication (currently disabled, see the commented block below).

        Arguments:
        resource_type -- "experiment" or "project" or "data" (not used yet)
        resource_id -- the Airavata id for the resource
        """
        # Record the id first so on_close() can always read it.
        self.resource_id = resource_id
        # Public WebSocketHandler API rather than reaching into self.stream.
        self.set_nodelay(True)
        self.write_message("Opened the connection")
        self.add_to_consumer()
        # expires_in = self.get_secure_cookie("expires_in", max_age_days=SETTINGS["maximum_cookie_age"])
        # if expires_in is not None and self.get_secure_cookie("ws-auth-token", max_age_days=float(expires_in)):
        #     print("Found secure cookie")
        #     self.write_message("Authenticated")
        #     self.add_to_consumer()
        # else:
        #     print("Closing connection")
        #     self.close()

    def on_message(self, message):
        """Handle incoming messages from the client.

        Tornado requires subclasses to override this method, however in this
        case we do not wish to take any action when receiving a message from
        the client. The purpose of this class is only to push messages to the
        client.

        The previous body was an undecorated generator (it used `yield`), so
        it never ran as written; it also called get_authenticated_user, which
        this class does not inherit, and tried to set cookies on an already
        established WebSocket. Authentication belongs to AuthHandler.
        """
        # Intentionally a no-op; the message is not logged because it may
        # contain credentials.
        return None

    def on_close(self):
        """Detach this client from its consumer when the socket closes."""
        try:
            print("Closing connection")
            self.application.remove_client_from_consumer(self.resource_id, self)
        except KeyError:
            print("Error: resource %s does not exist" % self.resource_id)
        finally:
            self.close()

    def add_to_consumer(self):
        """Register this client with the consumer for self.resource_id."""
        try:
            self.application.add_client_to_consumer(self.resource_id, self)
        except AttributeError as e:
            print("Error: tornado.web.Application object is not AMQPWSTunnel")
            print(e)
class AMQPWSTunnel(tornado.web.Application):
    """
    Send messages from an AMQP queue to WebSockets clients.

    In addition to the standard Tornado Application class functionality, this
    class maintains a mapping of active AMQP consumers (keyed by resource id)
    and maps WebSocketHandlers to the correct consumers.
    """

    def __init__(self, consumer_list=None, consumer_config=None, handlers=None,
                 default_host='', transforms=None, **settings):
        """
        Create the application.

        Keyword Arguments:
        consumer_list -- existing {resource_id: consumer} mapping to reuse
                         (default: a new empty dict)
        consumer_config -- dict providing "rabbitmq_url", "exchange_name",
                           "queue_name" and "exchange_type"; required
        handlers, default_host, transforms, **settings -- forwarded to
                           tornado.web.Application

        Raises:
        ConsumerConfigError -- if consumer_config is None
        """
        print("Starting AMQP-WS-Tunnel application")
        # Fail before the parent Application is initialised.
        if consumer_config is None:
            raise ConsumerConfigError("No consumer configuration provided")
        super(AMQPWSTunnel, self).__init__(handlers=handlers,
                                           default_host=default_host,
                                           transforms=transforms,
                                           **settings)
        self.consumer_list = {} if consumer_list is None else consumer_list
        self.consumer_config = consumer_config

    def consumer_exists(self, resource_id):
        """Determine if a consumer exists for a particular resource.

        Arguments:
        resource_id -- the consumer to find
        """
        return resource_id in self.consumer_list

    def add_client_to_consumer(self, resource_id, client):
        """Add a new client to a consumer's messaging list.

        A consumer thread is created and started the first time a resource id
        is seen; it uses resource_id as its routing key.

        Arguments:
        resource_id -- the consumer to add to
        client -- the client to add
        """
        if not self.consumer_exists(resource_id):
            print("Creating new consumer")
            # The config dict is deliberately not printed: the __main__ block
            # passes the whole JSON config, which includes the OAuth client
            # secret.
            config = self.consumer_config
            consumer = PikaAsyncConsumer(config["rabbitmq_url"],
                                         config["exchange_name"],
                                         config["queue_name"],
                                         exchange_type=config["exchange_type"],
                                         routing_key=resource_id)
            print("Adding to consumer list")
            self.consumer_list[resource_id] = consumer
            print("Starting consumer")
            consumer.start()
        print("Adding new client to %s" % (resource_id))
        self.consumer_list[resource_id].add_client(client)

    def remove_client_from_consumer(self, resource_id, client):
        """Remove a client from a consumer's messaging list.

        Unknown resource ids are ignored.

        Arguments:
        resource_id -- the consumer to remove from
        client -- the client to remove
        """
        if self.consumer_exists(resource_id):
            print("Removing client from %s" % (resource_id))
            self.consumer_list[resource_id].remove_client(client)
        #else:
        #    raise ConsumerKeyError("Trying to remove client from nonexistent consumer", resource_id)

    def shutdown(self):
        """Shut down the application by stopping every consumer.
        """
        for consumer in self.consumer_list.values():
            consumer.stop()
            #consumer.join()
        #self.consumer_list = {}
if __name__ == "__main__":
    # Usage: <script> CONFIG.json
    # The JSON file supplies the OAuth settings copied into SETTINGS below
    # and is also handed to AMQPWSTunnel as the consumer configuration.
    parser = argparse.ArgumentParser(
        description="Tunnel AMQP messages to WebSocket clients")
    parser.add_argument("config", help="path to the JSON configuration file")
    args = parser.parse_args()

    # `with` closes the file even when the JSON is malformed.
    with open(args.config) as config_file:
        config = json.load(config_file)

    for key in ("oauth_client_key", "oauth_client_secret",
                "oauth_grant_type", "maximum_cookie_age"):
        SETTINGS[key] = config[key]

    settings = {
        # A fresh random secret per process: secure cookies issued by an
        # earlier run are signed with a different key.
        "cookie_secret": base64.b64encode(uuid.uuid4().bytes + uuid.uuid4().bytes),
        #"xsrf_cookies": True
    }

    application = AMQPWSTunnel(handlers=[
                                   (r"/auth", AuthHandler),
                                   (r"/(experiment)/(.+)", AMQPWSHandler)
                               ],
                               consumer_config=config,
                               debug=True,
                               **settings)
    application.listen(8888)

    try:
        tornado.ioloop.IOLoop.current().start()
    except KeyboardInterrupt:
        application.shutdown()