blob: 8d75f1d87e300e0ae996ce3467d0e59ea5615579 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
from concurrent.futures import Future
import warnings
from gremlin_python.driver import client, serializer
from gremlin_python.driver.remote_connection import (
RemoteConnection, RemoteTraversal)
from gremlin_python.process.strategies import OptionsStrategy
from gremlin_python.process.traversal import Bytecode
import uuid
# Module logger shared by DriverRemoteConnection; records go to the "gremlinpython" logger.
log = logging.getLogger("gremlinpython")
__author__ = 'David M. Brown (davebshow@gmail.com), Lyndon Bauto (lyndonb@bitquilltech.com)'
class DriverRemoteConnection(RemoteConnection):
    """RemoteConnection that submits traversal Bytecode through a driver ``client.Client``.

    Every constructor argument is remembered so that :meth:`create_session` can
    build a child connection with the same configuration bound to a new session.
    Child connections created that way are tracked and closed together with this
    connection in :meth:`close`.
    """

    # Keys copied from an OptionsStrategy configuration into the per-request
    # options handed to the client (order preserved in the resulting dict).
    _REQUEST_OPTION_KEYS = ('evaluationTimeout', 'scriptEvaluationTimeout', 'batchSize',
                            'requestId', 'userAgent')

    def __init__(self, url, traversal_source="g", protocol_factory=None,
                 transport_factory=None, pool_size=None, max_workers=None,
                 username="", password="", kerberized_service='',
                 message_serializer=None, graphson_reader=None,
                 graphson_writer=None, headers=None, session=None,
                 enable_user_agent_on_connect=True, **transport_kwargs):
        """Create the underlying ``client.Client`` for ``url``.

        :param url: server URL passed to ``client.Client``.
        :param traversal_source: name of the traversal source (default ``"g"``).
        :param message_serializer: serializer passed to the client. When it is
            ``None`` and both ``graphson_reader`` and ``graphson_writer`` are
            given, a ``GraphSONMessageSerializer`` built from them is used.
        :param session: session identifier passed to the client; a non-``None``
            value makes this connection session bound (see :meth:`is_session_bound`).

        All other arguments, including ``transport_kwargs``, are forwarded
        unchanged to ``client.Client``.
        """
        log.info("Creating DriverRemoteConnection with url '%s'", url)
        # Keep the raw configuration so create_session() can clone it.
        self.__url = url
        self.__traversal_source = traversal_source
        self.__protocol_factory = protocol_factory
        self.__transport_factory = transport_factory
        self.__pool_size = pool_size
        self.__max_workers = max_workers
        self.__username = username
        self.__password = password
        self.__kerberized_service = kerberized_service
        self.__message_serializer = message_serializer
        self.__graphson_reader = graphson_reader
        self.__graphson_writer = graphson_writer
        self.__headers = headers
        self.__session = session
        self.__enable_user_agent_on_connect = enable_user_agent_on_connect
        self.__transport_kwargs = transport_kwargs
        # keeps a list of sessions that have been spawned from this DriverRemoteConnection
        # so that they can be closed if this parent session is closed.
        self.__spawned_sessions = []

        if message_serializer is None and graphson_reader is not None and graphson_writer is not None:
            message_serializer = serializer.GraphSONMessageSerializer(
                reader=graphson_reader,
                writer=graphson_writer)
        self._client = client.Client(url, traversal_source,
                                     protocol_factory=protocol_factory,
                                     transport_factory=transport_factory,
                                     pool_size=pool_size,
                                     max_workers=max_workers,
                                     message_serializer=message_serializer,
                                     username=username,
                                     password=password,
                                     kerberized_service=kerberized_service,
                                     headers=headers,
                                     session=session,
                                     enable_user_agent_on_connect=enable_user_agent_on_connect,
                                     **transport_kwargs)
        self._url = self._client._url
        self._traversal_source = self._client._traversal_source

    def close(self):
        """Close every session connection spawned from this one, then this connection.

        The underlying client is closed even if closing a spawned session
        raises; that exception is re-raised afterwards.
        """
        try:
            if self.__spawned_sessions:
                log.info("closing spawned sessions from DriverRemoteConnection with url '%s'", self._url)
                for spawned_session in self.__spawned_sessions:
                    spawned_session.close()
                self.__spawned_sessions.clear()
        finally:
            # Always release this connection's own client, whatever happened above.
            if self.__session:
                log.info("closing DriverRemoteConnection with url '%s' with session '%s'",
                         self._url, self.__session)
            else:
                log.info("closing DriverRemoteConnection with url '%s'", self._url)
            self._client.close()

    def submit(self, bytecode):
        """Submit ``bytecode`` and block until all of its results have arrived.

        Request options are taken from an ``OptionsStrategy`` in the bytecode
        (see :meth:`_extract_request_options`).

        :returns: a ``RemoteTraversal`` over the fully collected results.
        """
        # Pass the object itself so it is only stringified if DEBUG is enabled.
        log.debug("submit with bytecode '%s'", bytecode)
        result_set = self._client.submit(bytecode, request_options=self._extract_request_options(bytecode))
        results = result_set.all().result()
        return RemoteTraversal(iter(results))

    def submitAsync(self, message, bindings=None, request_options=None):
        """Deprecated alias of :meth:`submit_async`.

        ``bindings`` and ``request_options`` are accepted only so existing call
        sites keep working. ``submit_async`` takes no such arguments, so they are
        ignored; request options come from the bytecode's ``OptionsStrategy``.

        :returns: the ``Future`` produced by :meth:`submit_async`.
        """
        warnings.warn(
            "gremlin_python.driver.driver_remote_connection.DriverRemoteConnection.submitAsync will be replaced by "
            "gremlin_python.driver.driver_remote_connection.DriverRemoteConnection.submit_async.",
            DeprecationWarning,
            stacklevel=2)
        # submit_async only accepts the bytecode; forwarding the extra arguments
        # raised TypeError, and the Future was previously discarded.
        return self.submit_async(message)

    def submit_async(self, bytecode):
        """Submit ``bytecode`` without blocking the caller.

        :returns: a ``concurrent.futures.Future`` that resolves to a
            ``RemoteTraversal`` over all results, or holds the exception raised
            by the submission future or while collecting the results.
        """
        log.debug("submit_async with bytecode '%s'", bytecode)
        future = Future()
        future_result_set = self._client.submit_async(bytecode, request_options=self._extract_request_options(bytecode))

        def cb(f):
            # Runs once the client's submission future completes: gather every
            # result, then resolve (or fail) the future handed to the caller.
            try:
                result_set = f.result()
                results = result_set.all().result()
                future.set_result(RemoteTraversal(iter(results)))
            except Exception as e:
                future.set_exception(e)

        future_result_set.add_done_callback(cb)
        return future

    def is_closed(self):
        """Return whether the underlying client reports itself closed."""
        return self._client.is_closed()

    def is_session_bound(self):
        """Return ``True`` when this connection was created with a session identifier."""
        return self.__session is not None

    def create_session(self):
        """Spawn a session-bound connection with this connection's configuration.

        The child uses a fresh ``uuid.uuid4()`` as its session and is tracked so
        that :meth:`close` also closes it.

        :raises Exception: if this connection is itself session bound.
        """
        log.info("Creating session based connection")
        if self.is_session_bound():
            raise Exception('Connection is already bound to a session - child sessions are not allowed')
        conn = DriverRemoteConnection(self.__url,
                                      traversal_source=self.__traversal_source,
                                      protocol_factory=self.__protocol_factory,
                                      transport_factory=self.__transport_factory,
                                      pool_size=self.__pool_size,
                                      max_workers=self.__max_workers,
                                      username=self.__username,
                                      password=self.__password,
                                      kerberized_service=self.__kerberized_service,
                                      message_serializer=self.__message_serializer,
                                      graphson_reader=self.__graphson_reader,
                                      graphson_writer=self.__graphson_writer,
                                      headers=self.__headers,
                                      session=uuid.uuid4(),
                                      enable_user_agent_on_connect=self.__enable_user_agent_on_connect,
                                      **self.__transport_kwargs)
        self.__spawned_sessions.append(conn)
        return conn

    def remove_session(self, session_based_connection):
        """Close ``session_based_connection`` and stop tracking it.

        :raises ValueError: (from ``list.remove``) if the connection was not
            spawned by this one; it has already been closed at that point.
        """
        session_based_connection.close()
        self.__spawned_sessions.remove(session_based_connection)

    def commit(self):
        """Submit a commit graph operation; returns what ``Client.submit`` returns."""
        log.info("Submitting commit graph operation.")
        return self._client.submit(Bytecode.GraphOp.commit())

    def rollback(self):
        """Submit a rollback graph operation; returns what ``Client.submit`` returns."""
        log.info("Submitting rollback graph operation.")
        return self._client.submit(Bytecode.GraphOp.rollback())

    @staticmethod
    def _extract_request_options(bytecode):
        """Build per-request options from the first ``OptionsStrategy`` in ``bytecode``.

        Only ``withStrategies`` source instructions whose first argument is
        exactly an ``OptionsStrategy`` (subclasses are not matched) are
        considered.

        :returns: ``None`` when no such strategy exists; otherwise a dict of the
            ``_REQUEST_OPTION_KEYS`` present in its configuration (possibly empty).
        """
        # len(x) > 1 guards against an argument-less withStrategies instruction,
        # which would otherwise raise IndexError on x[1].
        options_instruction = next((x for x in bytecode.source_instructions
                                    if len(x) > 1 and x[0] == "withStrategies"
                                    and type(x[1]) is OptionsStrategy), None)
        if options_instruction is None:
            return None
        configuration = options_instruction[1].configuration
        return {key: configuration[key] for key in DriverRemoteConnection._REQUEST_OPTION_KEYS
                if key in configuration}