# blob: 2fe4d73bf81e09c65d8ccf89efbbfa2d57707086 [file] [log] [blame]
import logging
import ssl
from contextlib import contextmanager
import thrift_connector.connection_pool as connection_pool
from airavata.api import Airavata
from airavata.service.profile.groupmanager.cpi import GroupManagerService
from airavata.service.profile.groupmanager.cpi.constants import (
GROUP_MANAGER_CPI_NAME
)
from airavata.service.profile.iam.admin.services.cpi import IamAdminServices
from airavata.service.profile.iam.admin.services.cpi.constants import (
IAM_ADMIN_SERVICES_CPI_NAME
)
from airavata.service.profile.tenant.cpi import TenantProfileService
from airavata.service.profile.tenant.cpi.constants import (
TENANT_PROFILE_CPI_NAME
)
from airavata.service.profile.user.cpi import UserProfileService
from airavata.service.profile.user.cpi.constants import USER_PROFILE_CPI_NAME
from django.conf import settings
from thrift.protocol import TBinaryProtocol
from thrift.protocol.TMultiplexedProtocol import TMultiplexedProtocol
from thrift.transport import TSocket, TSSLSocket, TTransport
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class ThriftConnectionException(Exception):
    """Signals a Thrift connection-level failure.

    Raised by ``get_thrift_client`` when the connection to the remote
    service fails.
    """
class ThriftClientException(Exception):
    """Signals an error raised while a Thrift client was in use.

    Raised by ``get_thrift_client`` to wrap any exception that escapes the
    body of the ``with`` statement using the client.
    """
def get_unsecure_transport(hostname, port):
    """Return a buffered, plain (non-TLS) Thrift transport.

    Args:
        hostname: Host of the Airavata server to connect to.
        port: Port of the Airavata server to connect to.

    Returns:
        A ``TBufferedTransport`` wrapping a ``TSocket``. This function does
        not open the transport; the caller is expected to do so.
    """
    # Buffering on top of the raw socket speeds things up.
    return TTransport.TBufferedTransport(TSocket.TSocket(hostname, port))
def get_secure_transport(hostname, port):
    """Return a buffered TLS Thrift transport.

    The server certificate is currently NOT verified (see TODO below).

    Args:
        hostname: Host of the Airavata server to connect to.
        port: Port of the Airavata server to connect to.

    Returns:
        A ``TBufferedTransport`` wrapping a ``TSSLSocket``. This function
        does not open the transport; the caller is expected to do so.
    """
    # TODO: validate server certificate
    # ``cert_reqs=ssl.CERT_NONE`` replaces the deprecated ``validate=False``
    # keyword and matches how CustomThriftClient builds its TLS sockets.
    ssl_socket = TSSLSocket.TSSLSocket(
        hostname, port, cert_reqs=ssl.CERT_NONE)
    # Buffering on top of the raw socket speeds things up.
    return TTransport.TBufferedTransport(ssl_socket)
def get_transport(hostname, port, secure=True):
    """Return a buffered Thrift transport, TLS or plain.

    Args:
        hostname: Host to connect to.
        port: Port to connect to.
        secure: When truthy (the default) a TLS transport is built via
            ``get_secure_transport``; otherwise a plain one via
            ``get_unsecure_transport``.

    Returns:
        The transport produced by the selected builder.
    """
    build = get_secure_transport if secure else get_unsecure_transport
    return build(hostname, port)
def create_airavata_client(transport):
    """Build an Airavata API client that communicates over ``transport``.

    Args:
        transport: Thrift transport the client should use.

    Returns:
        An ``Airavata.Client`` bound to a binary protocol on ``transport``.
    """
    # Airavata currently uses the (non-multiplexed) binary protocol.
    return Airavata.Client(TBinaryProtocol.TBinaryProtocol(transport))
def get_binary_protocol(transport):
    """Wrap ``transport`` in a Thrift ``TBinaryProtocol`` and return it."""
    binary_protocol = TBinaryProtocol.TBinaryProtocol(transport)
    return binary_protocol
def create_group_manager_client(transport):
    """Build a Group Manager service client over ``transport``.

    The binary protocol is multiplexed under ``GROUP_MANAGER_CPI_NAME``.
    """
    return GroupManagerService.Client(
        TMultiplexedProtocol(get_binary_protocol(transport),
                             GROUP_MANAGER_CPI_NAME))
def create_iamadmin_client(transport):
    """Build an IAM Admin services client over ``transport``.

    The binary protocol is multiplexed under ``IAM_ADMIN_SERVICES_CPI_NAME``.
    """
    return IamAdminServices.Client(
        TMultiplexedProtocol(get_binary_protocol(transport),
                             IAM_ADMIN_SERVICES_CPI_NAME))
def create_tenant_profile_client(transport):
    """Build a Tenant Profile service client over ``transport``.

    The binary protocol is multiplexed under ``TENANT_PROFILE_CPI_NAME``.
    """
    return TenantProfileService.Client(
        TMultiplexedProtocol(get_binary_protocol(transport),
                             TENANT_PROFILE_CPI_NAME))
def create_user_profile_client(transport):
    """Build a User Profile service client over ``transport``.

    The binary protocol is multiplexed under ``USER_PROFILE_CPI_NAME``.
    """
    return UserProfileService.Client(
        TMultiplexedProtocol(get_binary_protocol(transport),
                             USER_PROFILE_CPI_NAME))
def get_airavata_client():
    """Return a context manager yielding an Airavata API client.

    Connection parameters come from the Django settings
    ``AIRAVATA_API_HOST``, ``AIRAVATA_API_PORT`` and ``AIRAVATA_API_SECURE``.
    Use the result in a ``with`` statement.
    """
    return get_thrift_client(
        host=settings.AIRAVATA_API_HOST,
        port=settings.AIRAVATA_API_PORT,
        is_secure=settings.AIRAVATA_API_SECURE,
        client_generator=create_airavata_client,
    )
def get_group_manager_client():
    """Return a context manager yielding a Group Manager client.

    Connection parameters come from the Django settings
    ``PROFILE_SERVICE_HOST``, ``PROFILE_SERVICE_PORT`` and
    ``PROFILE_SERVICE_SECURE``. Use the result in a ``with`` statement.
    """
    return get_thrift_client(
        host=settings.PROFILE_SERVICE_HOST,
        port=settings.PROFILE_SERVICE_PORT,
        is_secure=settings.PROFILE_SERVICE_SECURE,
        client_generator=create_group_manager_client,
    )
def get_iam_admin_client():
    """Return a context manager yielding an IAM Admin services client.

    Connection parameters come from the Django settings
    ``PROFILE_SERVICE_HOST``, ``PROFILE_SERVICE_PORT`` and
    ``PROFILE_SERVICE_SECURE``. Use the result in a ``with`` statement.
    """
    return get_thrift_client(
        host=settings.PROFILE_SERVICE_HOST,
        port=settings.PROFILE_SERVICE_PORT,
        is_secure=settings.PROFILE_SERVICE_SECURE,
        client_generator=create_iamadmin_client,
    )
def get_tenant_profile_client():
    """Return a context manager yielding a Tenant Profile client.

    Connection parameters come from the Django settings
    ``PROFILE_SERVICE_HOST``, ``PROFILE_SERVICE_PORT`` and
    ``PROFILE_SERVICE_SECURE``. Use the result in a ``with`` statement.
    """
    return get_thrift_client(
        host=settings.PROFILE_SERVICE_HOST,
        port=settings.PROFILE_SERVICE_PORT,
        is_secure=settings.PROFILE_SERVICE_SECURE,
        client_generator=create_tenant_profile_client,
    )
def get_user_profile_client():
    """Return a context manager yielding a User Profile client.

    Connection parameters come from the Django settings
    ``PROFILE_SERVICE_HOST``, ``PROFILE_SERVICE_PORT`` and
    ``PROFILE_SERVICE_SECURE``. Use the result in a ``with`` statement.
    """
    return get_thrift_client(
        host=settings.PROFILE_SERVICE_HOST,
        port=settings.PROFILE_SERVICE_PORT,
        is_secure=settings.PROFILE_SERVICE_SECURE,
        client_generator=create_user_profile_client,
    )
@contextmanager
def get_thrift_client(host, port, is_secure, client_generator):
    """Yield a Thrift client connected to ``host:port``; close it on exit.

    Args:
        host: Host of the Thrift service.
        port: Port of the Thrift service.
        is_secure: Passed to ``get_transport`` to choose TLS vs. plain.
        client_generator: Callable taking the transport and returning the
            client object to yield.

    Yields:
        The client built by ``client_generator``, with its transport open.

    Raises:
        ThriftConnectionException: If the transport cannot be opened, or if
            closing it afterwards fails.
        ThriftClientException: If any ``Exception`` escapes the body of the
            ``with`` statement; the original error is chained as the cause.
    """
    transport = get_transport(host, port, is_secure)
    client = client_generator(transport)

    # Keep the "open" failure path separate from everything else so that
    # only genuine open errors are reported as such.
    try:
        transport.open()
    except Exception as e:
        msg = "Failed to open thrift connection to {}:{}, secure={}".format(
            host, port, is_secure)
        log.debug(msg)
        raise ThriftConnectionException(msg) from e
    log.debug("Thrift connection opened to %s:%s, secure=%s",
              host, port, is_secure)

    try:
        yield client
    except Exception as e:
        log.exception("Thrift client error occurred")
        raise ThriftClientException(
            "Thrift client error occurred: " + str(e)) from e
    finally:
        # Close failures are still surfaced as ThriftConnectionException
        # (as before), but with a message that says what actually failed.
        try:
            if transport.isOpen():
                transport.close()
                log.debug("Thrift connection closed to %s:%s, secure=%s",
                          host, port, is_secure)
        except Exception as e:
            msg = ("Failed to close thrift connection to {}:{}, "
                   "secure={}".format(host, port, is_secure))
            log.debug(msg)
            raise ThriftConnectionException(msg) from e
class CustomThriftClient(connection_pool.ThriftClient):
    """Pooled Thrift client with optional TLS and a ``ping`` health check.

    Subclasses configure behavior through the class attributes below.
    """

    # When truthy, sockets are created as TLS sockets.
    secure = False
    # When truthy (and ``secure`` is set), the server certificate is required.
    validate = False

    @classmethod
    def get_socket_factory(cls):
        """Return a callable that creates the socket for a connection.

        For non-secure clients the parent class's factory is returned
        unchanged. For secure clients a ``factory(host, port)`` function is
        returned that builds a ``TSSLSocket`` whose ``cert_reqs`` depends on
        ``cls.validate``.
        """
        if cls.secure:
            def ssl_socket_factory(host, port):
                if cls.validate:
                    cert_reqs = ssl.CERT_REQUIRED
                else:
                    cert_reqs = ssl.CERT_NONE
                return TSSLSocket.TSSLSocket(host, port, cert_reqs=cert_reqs)
            return ssl_socket_factory
        return super().get_socket_factory()

    def ping(self):
        """Check the connection by calling ``getAPIVersion`` on the client.

        Any exception is logged at debug level and re-raised.
        NOTE(review): presumably invoked by the connection pool as its
        keepalive probe -- confirm against thrift_connector.
        """
        try:
            self.client.getAPIVersion()
        except Exception as exc:
            log.debug("getAPIVersion failed: {}".format(str(exc)))
            raise
class MultiplexThriftClientMixin:
    """Mixin that makes a pooled client use a multiplexed binary protocol."""

    # Name of the multiplexed service; concrete classes must override this.
    service_name = None

    @classmethod
    def get_protoco_factory(cls):
        """Return a callable building the protocol for a given transport.

        The returned function wraps the transport in a binary protocol and
        multiplexes it under ``cls.service_name``.
        NOTE(review): the spelling "protoco" presumably matches the hook
        name expected by thrift_connector -- confirm before renaming.
        """
        def build_protocol(transport):
            return TMultiplexedProtocol(
                TBinaryProtocol.TBinaryProtocol(transport), cls.service_name)
        return build_protocol
class AiravataAPIThriftClient(CustomThriftClient):
    """Pooled client for the Airavata API.

    TLS is enabled according to ``settings.AIRAVATA_API_SECURE``.
    """

    secure = settings.AIRAVATA_API_SECURE
class GroupManagerServiceThriftClient(
        MultiplexThriftClientMixin, CustomThriftClient):
    """Pooled client for the Group Manager service.

    Multiplexed under ``GROUP_MANAGER_CPI_NAME``; TLS is enabled according
    to ``settings.PROFILE_SERVICE_SECURE``.
    """

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = GROUP_MANAGER_CPI_NAME
class IAMAdminServiceThriftClient(
        MultiplexThriftClientMixin, CustomThriftClient):
    """Pooled client for the IAM Admin services.

    Multiplexed under ``IAM_ADMIN_SERVICES_CPI_NAME``; TLS is enabled
    according to ``settings.PROFILE_SERVICE_SECURE``.
    """

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = IAM_ADMIN_SERVICES_CPI_NAME
class TenantProfileServiceThriftClient(
        MultiplexThriftClientMixin, CustomThriftClient):
    """Pooled client for the Tenant Profile service.

    Multiplexed under ``TENANT_PROFILE_CPI_NAME``; TLS is enabled according
    to ``settings.PROFILE_SERVICE_SECURE``.
    """

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = TENANT_PROFILE_CPI_NAME
class UserProfileServiceThriftClient(
        MultiplexThriftClientMixin, CustomThriftClient):
    """Pooled client for the User Profile service.

    Multiplexed under ``USER_PROFILE_CPI_NAME``; TLS is enabled according
    to ``settings.PROFILE_SERVICE_SECURE``.
    """

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = USER_PROFILE_CPI_NAME
# Connection pool for the Airavata API, created when this module is imported.
# Host, port and keepalive come from Django settings; connections are made
# with AiravataAPIThriftClient.
airavata_api_client_pool = connection_pool.ClientPool(
    Airavata,
    settings.AIRAVATA_API_HOST,
    settings.AIRAVATA_API_PORT,
    connection_class=AiravataAPIThriftClient,
    keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE
)
# Connection pool for the Group Manager service, created when this module is
# imported. Host, port and keepalive come from Django settings; connections
# are made with GroupManagerServiceThriftClient.
group_manager_client_pool = connection_pool.ClientPool(
    GroupManagerService,
    settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT,
    connection_class=GroupManagerServiceThriftClient,
    keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE
)
# Connection pool for the IAM Admin services, created when this module is
# imported. Host, port and keepalive come from Django settings; connections
# are made with IAMAdminServiceThriftClient.
iamadmin_client_pool = connection_pool.ClientPool(
    IamAdminServices,
    settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT,
    connection_class=IAMAdminServiceThriftClient,
    keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE
)
# Connection pool for the Tenant Profile service, created when this module is
# imported. Host, port and keepalive come from Django settings; connections
# are made with TenantProfileServiceThriftClient.
tenant_profile_client_pool = connection_pool.ClientPool(
    TenantProfileService,
    settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT,
    connection_class=TenantProfileServiceThriftClient,
    keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE
)
# Connection pool for the User Profile service, created when this module is
# imported. Host, port and keepalive come from Django settings; connections
# are made with UserProfileServiceThriftClient.
user_profile_client_pool = connection_pool.ClientPool(
    UserProfileService,
    settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT,
    connection_class=UserProfileServiceThriftClient,
    keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE
)