blob: 74773b8c50e29a457a53e3d08074c1586f5a0980 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import abc
import base64
import struct
# import kerberos Optional dependency imported in relevant codeblock
import six
try:
import ujson as json
except ImportError:
import json
from gremlin_python.driver import request
from gremlin_python.driver.resultset import ResultSet
__author__ = 'David M. Brown (davebshow@gmail.com)'
class GremlinServerError(Exception):
    """Error carrying the ``status`` block of a failed Gremlin Server response.

    ``status`` must be a mapping with ``code``, ``message`` and ``attributes``
    entries; a missing entry surfaces as ``KeyError`` from the constructor.
    The exception text is ``"<code>: <message>"``.
    """

    def __init__(self, status):
        code = status['code']
        text = '{0}: {1}'.format(code, status['message'])
        super(GremlinServerError, self).__init__(text)
        # Public numeric status code, e.g. 500 for a dropped connection.
        self.status_code = code
        self._status_attributes = status['attributes']

    @property
    def status_attributes(self):
        """The ``attributes`` mapping taken verbatim from the server status."""
        return self._status_attributes
class ConfigurationError(Exception):
    """Raised when the driver lacks settings needed to talk to the server.

    In this module it signals missing authentication credentials or a Kerberos
    setup that cannot be initialised.
    """
@six.add_metaclass(abc.ABCMeta)
class AbstractBaseProtocol:
    """Abstract message-level protocol sitting on top of a transport.

    Concrete protocols must override all three hooks below; the class cannot be
    instantiated until they do.
    """

    @abc.abstractmethod
    def connection_made(self, transport):
        """Keep a reference to ``transport`` as ``self._transport``.

        Overrides can delegate here via ``super()`` to get this default storage.
        """
        self._transport = transport

    @abc.abstractmethod
    def data_received(self, message, results_dict):
        """Handle one raw ``message`` coming back from the server."""

    @abc.abstractmethod
    def write(self, request_id, request_message):
        """Send ``request_message`` associated with ``request_id``."""
class GremlinServerWSProtocol(AbstractBaseProtocol):
    """Protocol handler exchanging Gremlin Server messages over a transport.

    Outgoing requests are serialized with ``message_serializer``; incoming
    responses are deserialized and routed into the ``ResultSet`` registered
    under their request id. Authentication challenges (status 407) are answered
    with a ``\\0username\\0password`` SASL token when both ``username`` and
    ``password`` are set, otherwise with a Kerberos (GSSAPI) handshake when
    ``kerberized_service`` is set.
    """

    # Max buffer size advertised in the GSSAPI security-layer reply. It is
    # OR-ed into the low 3 bytes of a 32-bit word, so it must stay below 2**24.
    MAX_CONTENT_LENGTH = 65536
    # qop bit for plain "auth" (1==auth, 2==auth-int, 4==auth-conf).
    QOP_AUTH_BIT = 1
    # GSSAPI client context; stays None until the first Kerberos challenge.
    _kerberos_context = None

    def __init__(self, message_serializer, username='', password='', kerberized_service=''):
        self._message_serializer = message_serializer
        self._username = username
        self._password = password
        self._kerberized_service = kerberized_service

    def connection_made(self, transport):
        """Store ``transport`` via the base-class implementation."""
        super(GremlinServerWSProtocol, self).connection_made(transport)

    def write(self, request_id, request_message):
        """Serialize ``request_message`` for ``request_id`` and send it."""
        message = self._message_serializer.serialize_message(
            request_id, request_message)
        self._transport.write(message)

    def data_received(self, message, results_dict):
        """Process one raw server response.

        :param message: raw payload read from the transport, or ``None`` when
            the server has cut the connection.
        :param results_dict: maps request ids to their pending ``ResultSet``.
            Finished requests (status 200, 204 or an error) are removed.
        :returns: the response status code (200, 204 or 206). For a 407
            challenge, the authentication reply is sent, the next response is
            read and its status is returned.
        :raises GremlinServerError: on disconnect or any other status code.
        :raises ConfigurationError: if the server demands authentication but
            no usable credentials were configured.
        """
        # if Gremlin Server cuts off then we get a None for the message
        if message is None:
            raise GremlinServerError({'code': 500,
                                      'message': 'Server disconnected - please try to reconnect',
                                      'attributes': {}})
        message = self._message_serializer.deserialize_message(message)
        request_id = message['requestId']
        # Untracked ids still get a (discarded) ResultSet so every response can
        # be handled uniformly below.
        result_set = results_dict[request_id] if request_id in results_dict else ResultSet(None, None)
        status_code = message['status']['code']
        aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
        data = message['result']['data']
        result_set.aggregate_to = aggregate_to
        if status_code == 407:
            self.write(request_id, self._authentication_request(message))
            reply = self._transport.read()
            # Allow for auth handshake with multiple steps
            return self.data_received(reply, results_dict)
        elif status_code == 204:
            result_set.stream.put_nowait([])
            # pop() rather than del: the id may be untracked (see the fallback
            # above), and a KeyError here would hide the real outcome.
            results_dict.pop(request_id, None)
            return status_code
        elif status_code in (200, 206):
            result_set.stream.put_nowait(data)
            if status_code == 200:
                # 200 is the final chunk; 206 means more chunks follow.
                result_set.status_attributes = message['status']['attributes']
                results_dict.pop(request_id, None)
            return status_code
        else:
            results_dict.pop(request_id, None)
            raise GremlinServerError(message['status'])

    def _authentication_request(self, message):
        """Build the reply to a 407 challenge from the configured credentials.

        :raises ConfigurationError: if neither username/password nor a
            kerberized service is configured.
        """
        if self._username and self._password:
            auth_bytes = b''.join([b'\x00', self._username.encode('utf-8'),
                                   b'\x00', self._password.encode('utf-8')])
            auth = base64.b64encode(auth_bytes)
            return request.RequestMessage(
                'traversal', 'authentication', {'sasl': auth.decode()})
        if self._kerberized_service:
            return self._kerberos_received(message)
        raise ConfigurationError(
            'Gremlin server requires authentication credentials in DriverRemoteConnection. '
            'For basic authentication provide username and password. '
            'For kerberos authentication provide the kerberized_service parameter.')

    def _kerberos_received(self, message):
        """Produce the next client message of the Kerberos (GSSAPI) SASL exchange.

        Called once per 407 challenge. The step that runs depends on state kept
        on ``self``:

        1. No GSS context yet: create one for ``kerberized_service`` and send
           the initial token.
        2. ``self._username`` still empty: feed the server token (from the
           status attributes) into the context. Once the step completes,
           record the name returned by ``authGSSClientUserName``.
        3. Otherwise: answer the server's security-layer offer, accepting
           qop=auth only.

        :raises ImportError: if the optional ``kerberos`` package is missing.
        :raises ConfigurationError: if the initial GSS setup fails.
        """
        # Inspired by: https://github.com/thobbs/pure-sasl/blob/0.6.2/puresasl/mechanisms.py
        # https://github.com/thobbs/pure-sasl/blob/0.6.2/LICENSE
        try:
            import kerberos
        except ImportError:
            raise ImportError('Please install gremlinpython[kerberos].')
        # First pass: get service granting ticket and return it to gremlin-server
        if not self._kerberos_context:
            try:
                _, kerberos_context = kerberos.authGSSClientInit(
                    self._kerberized_service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
                kerberos.authGSSClientStep(kerberos_context, '')
                auth = kerberos.authGSSClientResponse(kerberos_context)
                self._kerberos_context = kerberos_context
            except kerberos.KrbError as e:
                raise ConfigurationError(
                    'Kerberos authentication requires a valid service name in DriverRemoteConnection, '
                    'as well as a valid tgt (export KRB5CCNAME) or keytab (export KRB5_KTNAME): ' + str(e))
            return request.RequestMessage('', 'authentication', {'sasl': auth})
        # Second pass: completion of authentication
        sasl_response = message['status']['attributes']['sasl']
        if not self._username:
            result_code = kerberos.authGSSClientStep(self._kerberos_context, sasl_response)
            if result_code == kerberos.AUTH_GSS_COMPLETE:
                self._username = kerberos.authGSSClientUserName(self._kerberos_context)
            return request.RequestMessage('', 'authentication', {'sasl': ''})
        # Third pass: sasl quality of protection (qop) handshake
        # Gremlin-server Krb5Authenticator only supports qop=QOP_AUTH; use ssl for confidentiality.
        # Handshake content format:
        # byte 0: the selected qop. 1==auth, 2==auth-int, 4==auth-conf
        # byte 1-3: the max length for any buffer sent back and forth on this connection. (big endian)
        # the rest of the buffer: the authorization user name in UTF-8 - not null terminated.
        kerberos.authGSSClientUnwrap(self._kerberos_context, sasl_response)
        data = kerberos.authGSSClientResponse(self._kerberos_context)
        plaintext_data = base64.b64decode(data)
        assert len(plaintext_data) == 4, "Unexpected response from gremlin server sasl handshake"
        word, = struct.unpack('!I', plaintext_data)
        qop_bits = word >> 24
        assert self.QOP_AUTH_BIT & qop_bits, "Unexpected sasl qop level received from gremlin server"
        # struct's 's' code truncates to the declared count. Size the field by
        # the UTF-8 byte length, not the character count, or non-ASCII user
        # names would be cut short.
        username_bytes = self._username.encode('utf-8')
        fmt = '!I' + str(len(username_bytes)) + 's'
        word = self.QOP_AUTH_BIT << 24 | self.MAX_CONTENT_LENGTH
        out = struct.pack(fmt, word, username_bytes)
        encoded = base64.b64encode(out).decode('ascii')
        kerberos.authGSSClientWrap(self._kerberos_context, encoded)
        auth = kerberos.authGSSClientResponse(self._kerberos_context)
        return request.RequestMessage('', 'authentication', {'sasl': auth})