blob: dac60fe033bf7bce887be577be2fb698834f5549 [file] [log] [blame]
import logging
from contextlib import contextmanager
import thrift_connector.connection_pool as connection_pool
from django.conf import settings
from thrift.protocol import TBinaryProtocol
from thrift.protocol.TMultiplexedProtocol import TMultiplexedProtocol
from thrift.transport import TSocket, TSSLSocket, TTransport
from airavata.api import Airavata
from airavata.api.sharing import SharingRegistryService
from airavata.service.profile.groupmanager.cpi import GroupManagerService
from airavata.service.profile.groupmanager.cpi.constants import (
GROUP_MANAGER_CPI_NAME
)
from airavata.service.profile.iam.admin.services.cpi import IamAdminServices
from airavata.service.profile.iam.admin.services.cpi.constants import (
IAM_ADMIN_SERVICES_CPI_NAME
)
from airavata.service.profile.tenant.cpi import TenantProfileService
from airavata.service.profile.tenant.cpi.constants import (
TENANT_PROFILE_CPI_NAME
)
from airavata.service.profile.user.cpi import UserProfileService
from airavata.service.profile.user.cpi.constants import USER_PROFILE_CPI_NAME
log = logging.getLogger(__name__)
class ThriftConnectionException(Exception):
    """Raised by ``get_thrift_client`` when the thrift transport cannot be opened."""
class ThriftClientException(Exception):
    """Raised by ``get_thrift_client`` when the body of its ``with`` block fails."""
def get_unsecure_transport(hostname, port):
    """Return a buffered, unencrypted thrift transport to ``hostname:port``.

    The transport is not opened here; the caller is responsible for
    calling ``open()`` on it.
    """
    plain_socket = TSocket.TSocket(hostname, port)
    # Buffer reads/writes rather than hitting the raw socket for each one.
    return TTransport.TBufferedTransport(plain_socket)
def get_secure_transport(hostname, port):
    """Return a buffered TLS (``TSSLSocket``) thrift transport to ``hostname:port``.

    The transport is not opened here; the caller is responsible for
    calling ``open()`` on it.
    """
    # TODO: validate server certificate (validation is disabled below).
    tls_socket = TSSLSocket.TSSLSocket(hostname, port, validate=False)
    # Buffer reads/writes rather than hitting the raw socket for each one.
    return TTransport.TBufferedTransport(tls_socket)
def get_transport(hostname, port, secure=True):
    """Return an unopened buffered transport; TLS unless ``secure`` is false."""
    make_transport = get_secure_transport if secure else get_unsecure_transport
    return make_transport(hostname, port)
def create_airavata_client(transport):
    """Create an Airavata API client that talks binary protocol over *transport*.

    :param transport: an (unopened or opened) thrift transport
    :returns: ``Airavata.Client`` bound to a binary protocol on *transport*
    """
    # Airavata currently uses Binary Protocol (not multiplexed, unlike the
    # profile services); reuse the shared helper like the other factories.
    return Airavata.Client(get_binary_protocol(transport))
def create_sharing_client(transport):
    """Create a Sharing Registry client that talks binary protocol over *transport*.

    :param transport: an (unopened or opened) thrift transport
    :returns: ``SharingRegistryService.Client`` bound to a binary protocol
    """
    # Not multiplexed; reuse the shared helper like the other factories.
    return SharingRegistryService.Client(get_binary_protocol(transport))
def get_binary_protocol(transport):
    """Wrap *transport* in thrift's ``TBinaryProtocol``."""
    binary_protocol = TBinaryProtocol.TBinaryProtocol(transport)
    return binary_protocol
def create_group_manager_client(transport):
    """Create a GroupManagerService client multiplexed over *transport*."""
    multiplexed = TMultiplexedProtocol(get_binary_protocol(transport),
                                       GROUP_MANAGER_CPI_NAME)
    return GroupManagerService.Client(multiplexed)
def create_iamadmin_client(transport):
    """Create an IamAdminServices client multiplexed over *transport*."""
    multiplexed = TMultiplexedProtocol(get_binary_protocol(transport),
                                       IAM_ADMIN_SERVICES_CPI_NAME)
    return IamAdminServices.Client(multiplexed)
def create_tenant_profile_client(transport):
    """Create a TenantProfileService client multiplexed over *transport*."""
    multiplexed = TMultiplexedProtocol(get_binary_protocol(transport),
                                       TENANT_PROFILE_CPI_NAME)
    return TenantProfileService.Client(multiplexed)
def create_user_profile_client(transport):
    """Create a UserProfileService client multiplexed over *transport*."""
    multiplexed = TMultiplexedProtocol(get_binary_protocol(transport),
                                       USER_PROFILE_CPI_NAME)
    return UserProfileService.Client(multiplexed)
def get_airavata_client():
    """Return a context manager yielding an Airavata API client.

    Usage: ``with get_airavata_client() as client: ...``
    """
    return get_thrift_client(host=settings.AIRAVATA_API_HOST,
                             port=settings.AIRAVATA_API_PORT,
                             is_secure=settings.AIRAVATA_API_SECURE,
                             client_generator=create_airavata_client)
def get_sharing_client():
    """Return a context manager yielding a Sharing Registry client.

    Usage: ``with get_sharing_client() as client: ...``
    """
    return get_thrift_client(host=settings.SHARING_API_HOST,
                             port=settings.SHARING_API_PORT,
                             is_secure=settings.SHARING_API_SECURE,
                             client_generator=create_sharing_client)
def get_group_manager_client():
    """Return a context manager yielding a Group Manager profile-service client.

    Usage: ``with get_group_manager_client() as client: ...``
    """
    return get_thrift_client(host=settings.PROFILE_SERVICE_HOST,
                             port=settings.PROFILE_SERVICE_PORT,
                             is_secure=settings.PROFILE_SERVICE_SECURE,
                             client_generator=create_group_manager_client)
def get_iam_admin_client():
    """Return a context manager yielding an IAM Admin profile-service client.

    Usage: ``with get_iam_admin_client() as client: ...``
    """
    return get_thrift_client(host=settings.PROFILE_SERVICE_HOST,
                             port=settings.PROFILE_SERVICE_PORT,
                             is_secure=settings.PROFILE_SERVICE_SECURE,
                             client_generator=create_iamadmin_client)
def get_tenant_profile_client():
    """Return a context manager yielding a Tenant Profile service client.

    Usage: ``with get_tenant_profile_client() as client: ...``
    """
    return get_thrift_client(host=settings.PROFILE_SERVICE_HOST,
                             port=settings.PROFILE_SERVICE_PORT,
                             is_secure=settings.PROFILE_SERVICE_SECURE,
                             client_generator=create_tenant_profile_client)
def get_user_profile_client():
    """Return a context manager yielding a User Profile service client.

    Usage: ``with get_user_profile_client() as client: ...``
    """
    return get_thrift_client(host=settings.PROFILE_SERVICE_HOST,
                             port=settings.PROFILE_SERVICE_PORT,
                             is_secure=settings.PROFILE_SERVICE_SECURE,
                             client_generator=create_user_profile_client)
@contextmanager
def get_thrift_client(host, port, is_secure, client_generator):
    """Open a thrift connection and yield a client bound to it.

    Intended for use in a ``with`` statement; the transport is closed when
    the block exits.

    :param host: hostname of the thrift server
    :param port: port of the thrift server
    :param is_secure: passed to ``get_transport`` to choose TLS vs plain socket
    :param client_generator: callable taking the transport and returning a
        thrift client (e.g. ``create_airavata_client``)
    :raises ThriftConnectionException: if the transport cannot be opened
    :raises ThriftClientException: if the body of the ``with`` block raises;
        the original exception is chained as ``__cause__``
    """
    transport = get_transport(host, port, is_secure)
    client = client_generator(transport)

    # Only failures of open() are connection failures.
    try:
        transport.open()
    except Exception as e:
        msg = "Failed to open thrift connection to {}:{}, secure={}".format(
            host, port, is_secure)
        log.debug(msg)
        raise ThriftConnectionException(msg) from e
    log.debug("Thrift connection opened to %s:%s, secure=%s",
              host, port, is_secure)

    try:
        yield client
    except Exception as e:
        log.exception("Thrift client error occurred")
        raise ThriftClientException(
            "Thrift client error occurred: " + str(e)) from e
    finally:
        # Close best-effort. A close failure must not mask the outcome of the
        # with-body, and must not be misreported as a failure to *open*.
        try:
            if transport.isOpen():
                transport.close()
                log.debug("Thrift connection closed to %s:%s, secure=%s",
                          host, port, is_secure)
        except Exception:
            log.warning("Failed to close thrift connection to %s:%s, "
                        "secure=%s", host, port, is_secure, exc_info=True)
class CustomThriftClient(connection_pool.ThriftClient):
    """thrift_connector connection class that can optionally use TLS.

    Subclasses set ``secure`` to make ``get_socket_factory`` build
    ``TSSLSocket`` instances instead of the base class's socket.
    """

    # When true, get_socket_factory() returns a TSSLSocket factory.
    secure = False
    # Forwarded as TSSLSocket's ``validate`` argument (only used when secure).
    validate = False

    @classmethod
    def get_socket_factory(cls):
        if cls.secure:
            def tls_socket_factory(host, port):
                return TSSLSocket.TSSLSocket(host, port, validate=cls.validate)
            return tls_socket_factory
        return super().get_socket_factory()

    def ping(self):
        """Call ``getAPIVersion()`` on the wrapped client; log and re-raise on failure."""
        try:
            self.client.getAPIVersion()
        except Exception as err:
            log.debug("getAPIVersion failed: {}".format(str(err)))
            raise
class MultiplexThriftClientMixin:
    """Mixin that wraps a ``TBinaryProtocol`` in a ``TMultiplexedProtocol``.

    List it before ``CustomThriftClient`` in the bases so its protocol
    factory takes precedence, and set ``service_name`` in the subclass.
    """

    # Multiplexed service name; set by each concrete subclass.
    service_name = None

    # NOTE(review): "protoco" is spelled this way on purpose -- it presumably
    # has to match the hook name the thrift_connector base class looks up.
    # Confirm against the library before "fixing" the spelling.
    @classmethod
    def get_protoco_factory(cls):
        def multiplexed_protocol_factory(transport):
            return TMultiplexedProtocol(
                TBinaryProtocol.TBinaryProtocol(transport), cls.service_name)
        return multiplexed_protocol_factory
class AiravataAPIThriftClient(CustomThriftClient):
    """Pooled connection class for the (non-multiplexed) Airavata API."""

    secure = settings.AIRAVATA_API_SECURE
class GroupManagerServiceThriftClient(MultiplexThriftClientMixin,
                                      CustomThriftClient):
    """Pooled connection class for the multiplexed Group Manager service."""

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = GROUP_MANAGER_CPI_NAME
class IAMAdminServiceThriftClient(MultiplexThriftClientMixin,
                                  CustomThriftClient):
    """Pooled connection class for the multiplexed IAM Admin service."""

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = IAM_ADMIN_SERVICES_CPI_NAME
class TenantProfileServiceThriftClient(MultiplexThriftClientMixin,
                                       CustomThriftClient):
    """Pooled connection class for the multiplexed Tenant Profile service."""

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = TENANT_PROFILE_CPI_NAME
class UserProfileServiceThriftClient(MultiplexThriftClientMixin,
                                     CustomThriftClient):
    """Pooled connection class for the multiplexed User Profile service."""

    secure = settings.PROFILE_SERVICE_SECURE
    service_name = USER_PROFILE_CPI_NAME
class SharingAPIThriftClient(CustomThriftClient):
    """Pooled connection class for the (non-multiplexed) Sharing Registry API."""

    secure = settings.SHARING_API_SECURE
def _new_client_pool(service, host, port, connection_class):
    """Build a thrift_connector ClientPool using the shared keepalive setting."""
    return connection_pool.ClientPool(
        service,
        host,
        port,
        connection_class=connection_class,
        keepalive=settings.THRIFT_CLIENT_POOL_KEEPALIVE,
    )


# One pool per thrift service. The profile services share one host/port and
# differ only in their connection class (which sets the multiplexed name).
airavata_api_client_pool = _new_client_pool(
    Airavata, settings.AIRAVATA_API_HOST, settings.AIRAVATA_API_PORT,
    AiravataAPIThriftClient)
group_manager_client_pool = _new_client_pool(
    GroupManagerService, settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT, GroupManagerServiceThriftClient)
iamadmin_client_pool = _new_client_pool(
    IamAdminServices, settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT, IAMAdminServiceThriftClient)
tenant_profile_client_pool = _new_client_pool(
    TenantProfileService, settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT, TenantProfileServiceThriftClient)
user_profile_client_pool = _new_client_pool(
    UserProfileService, settings.PROFILE_SERVICE_HOST,
    settings.PROFILE_SERVICE_PORT, UserProfileServiceThriftClient)
sharing_api_client_pool = _new_client_pool(
    SharingRegistryService, settings.SHARING_API_HOST,
    settings.SHARING_API_PORT, SharingAPIThriftClient)