blob: ef81f8d4065c3bdd9ef829df0a52daa6064556a5 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
from six.moves import queue
from gremlin_python.driver import connection, protocol, request, serializer
from gremlin_python.process import traversal
# This is until concurrent.futures backport 3.1.0 release
try:
from multiprocessing import cpu_count
except ImportError:
# some platforms don't have multiprocessing
def cpu_count():
return None
__author__ = 'David M. Brown (davebshow@gmail.com), Lyndon Bauto (lyndonb@bitquilltech.com)'
class Client:
def __init__(self, url, traversal_source, protocol_factory=None,
transport_factory=None, pool_size=None, max_workers=None,
message_serializer=None, username="", password="",
kerberized_service="", headers=None, session=None,
enable_user_agent_on_connect=True, **transport_kwargs):
logging.info("Creating Client with url '%s'", url)
self._closed = False
self._url = url
self._headers = headers
self._enable_user_agent_on_connect = enable_user_agent_on_connect
self._traversal_source = traversal_source
if "max_content_length" not in transport_kwargs:
transport_kwargs["max_content_length"] = 10 * 1024 * 1024
if message_serializer is None:
message_serializer = serializer.GraphSONSerializersV3d0()
self._message_serializer = message_serializer
self._username = username
self._password = password
self._session = session
self._session_enabled = (session is not None and session != "")
if transport_factory is None:
try:
from gremlin_python.driver.aiohttp.transport import (
AiohttpTransport)
except ImportError:
raise Exception("Please install AIOHTTP or pass "
"custom transport factory")
else:
def transport_factory():
return AiohttpTransport(**transport_kwargs)
self._transport_factory = transport_factory
if protocol_factory is None:
def protocol_factory(): return protocol.GremlinServerWSProtocol(
self._message_serializer,
username=self._username,
password=self._password,
kerberized_service=kerberized_service)
self._protocol_factory = protocol_factory
if self._session_enabled:
if pool_size is None:
pool_size = 1
elif pool_size != 1:
raise Exception("PoolSize must be 1 on session mode!")
if pool_size is None:
pool_size = 4
self._pool_size = pool_size
# This is until concurrent.futures backport 3.1.0 release
if max_workers is None:
# If your application is overlapping Gremlin I/O on multiple threads
# consider passing kwarg max_workers = (cpu_count() or 1) * 5
max_workers = pool_size
self._executor = ThreadPoolExecutor(max_workers=max_workers)
# Threadsafe queue
self._pool = queue.Queue()
self._fill_pool()
@property
def available_pool_size(self):
return self._pool.qsize()
@property
def executor(self):
return self._executor
@property
def traversal_source(self):
return self._traversal_source
def _fill_pool(self):
for i in range(self._pool_size):
conn = self._get_connection()
self._pool.put_nowait(conn)
def is_closed(self):
return self._closed
def close(self):
# prevent the Client from being closed more than once. it raises errors if new jobby jobs
# get submitted to the executor when it is shutdown
if self._closed:
return
if self._session_enabled:
self._close_session()
logging.info("Closing Client with url '%s'", self._url)
while not self._pool.empty():
conn = self._pool.get(True)
conn.close()
self._executor.shutdown()
self._closed = True
def _close_session(self):
message = request.RequestMessage(
processor='session', op='close',
args={'session': str(self._session)})
conn = self._pool.get(True)
try:
write_result_set = conn.write(message).result()
return write_result_set.all().result() # wait for _receive() to finish
except protocol.GremlinServerError:
pass
def _get_connection(self):
protocol = self._protocol_factory()
return connection.Connection(
self._url, self._traversal_source, protocol,
self._transport_factory, self._executor, self._pool,
headers=self._headers, enable_user_agent_on_connect=self._enable_user_agent_on_connect)
def submit(self, message, bindings=None, request_options=None):
return self.submit_async(message, bindings=bindings, request_options=request_options).result()
def submitAsync(self, message, bindings=None, request_options=None):
warnings.warn(
"gremlin_python.driver.client.Client.submitAsync will be replaced by "
"gremlin_python.driver.client.Client.submit_async.",
DeprecationWarning)
return self.submit_async(message, bindings, request_options)
def submit_async(self, message, bindings=None, request_options=None):
if self.is_closed():
raise Exception("Client is closed")
logging.debug("message '%s'", str(message))
args = {'gremlin': message, 'aliases': {'g': self._traversal_source}}
processor = ''
op = 'eval'
if isinstance(message, traversal.Bytecode):
op = 'bytecode'
processor = 'traversal'
if isinstance(message, str) and bindings:
args['bindings'] = bindings
if self._session_enabled:
args['session'] = str(self._session)
processor = 'session'
if isinstance(message, traversal.Bytecode) or isinstance(message, str):
logging.debug("processor='%s', op='%s', args='%s'", str(processor), str(op), str(args))
message = request.RequestMessage(processor=processor, op=op, args=args)
conn = self._pool.get(True)
if request_options:
message.args.update(request_options)
return conn.write(message)