# blob: 3f90188ce5eedddb0a0adeaa53b6453794e214e4 [file] [log] [blame]
# vim: sw=4:expandtab:foldmethod=marker
#
# Copyright (c) 2007-2009, Mathieu Fenniak
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
__author__ = "Mathieu Fenniak"
import datetime
import time
import interface
import types
import threading
from errors import *
from warnings import warn
##
# DB-API version implemented by this module. This property is part of the
# DBAPI 2.0 specification.
apilevel = "2.0"
##
# DB-API thread-safety level. The value 3 advertises that threads may share
# the module, connections and cursors. This property is part of the DBAPI
# 2.0 specification.
# NOTE(review): ConnectionWrapper.commit() documents a race between COMMIT and
# the following BEGIN when a connection is shared across threads, so level 3
# may overstate the guarantee -- confirm before relying on it.
threadsafety = 3
##
# Parameter-marker style expected in queries passed to cursor.execute().
# Defaults to "format" (%s markers). This property is part of the DBAPI 2.0
# specification.
# <p>
# Unlike the DBAPI specification, this value is not constant: the cursor reads
# it on every execute, so it may be reassigned to any standard paramstyle
# (i.e. qmark, numeric, named, format or pyformat).
paramstyle = 'format'  # any DB-API paramstyle name is accepted here
def _is_identifier_char(ch):
    # True for characters that can continue an unquoted SQL word. Used to
    # tell an E'...' string prefix apart from a word that merely ends in
    # "e"/"E" directly before a quote (e.g. DATE'2001-01-01').
    return ch.isalnum() or ch in "_$"


def _skip_quoted_percent(src_style, query, i):
    # Called after a '%' inside a quoted string or identifier has already been
    # copied to the output; ``i`` is the index just past that '%'. In the
    # "format" and "pyformat" styles only the escaped form '%%' is accepted
    # inside quotes: the second '%' is skipped so a single '%' is emitted.
    # Other styles leave '%' alone. Returns the index where scanning resumes.
    if src_style not in ("pyformat", "format") or i >= len(query):
        return i
    if query[i] == "%":
        return i + 1
    raise QueryParameterParseError("'%" + query[i] + "' not supported in quoted string")


def convert_paramstyle(src_style, query, args):
    """Rewrite a DB-API parameterised query into PostgreSQL's $n form.

    The query is scanned character by character with a small state machine.
    Parameter markers of ``src_style`` ("qmark", "numeric", "named", "format"
    or "pyformat") found outside quotes are replaced by positional ``$n``
    placeholders. The contents of '...' strings, E'...' escape strings and
    "..." quoted identifiers are copied through unchanged, except that in the
    format/pyformat styles an escaped '%%' is reduced to '%'. The PostgreSQL
    cast operator '::' is never treated as a named/numeric marker. In pyformat
    a bare %s switches the rest of the scan to the "format" style.

    Returns ``(new_query, new_args)``, where ``new_args`` is a tuple ordered
    to match the ``$n`` placeholders (for "numeric" it is ``tuple(args)``).
    If ``args`` is None, ``(query, None)`` is returned unchanged.

    Raises QueryParameterIndexError when a positional style has more markers
    than supplied arguments, QueryParameterParseError for malformed markers,
    and whatever ``args[name]`` raises for a missing named parameter.
    """
    if args is None:
        return query, args

    # Scanner states:
    # 0 -- outside quoted string
    # 1 -- inside single-quote string '...'
    # 2 -- inside quoted identifier "..."
    # 3 -- inside escaped single-quote string, E'...'
    state = 0
    length = len(query)
    out = []  # pieces of the rewritten query, joined once at the end
    output_args = args if src_style == "numeric" else []
    mapping_to_idx = {}  # named/pyformat: parameter name -> $n index

    def bind_name(name):
        # Each distinct name is bound to one positional argument; repeated
        # uses of the same name share its $n placeholder.
        idx = mapping_to_idx.get(name)
        if idx is None:
            output_args.append(args[name])
            idx = len(output_args)
            mapping_to_idx[name] = idx
        return "$" + str(idx)

    i = 0
    while i < length:
        c = query[i]
        if state == 0:
            if c == "'":
                out.append(c)
                i += 1
                state = 1
            elif c == '"':
                out.append(c)
                i += 1
                state = 2
            elif (c in "Ee" and query.startswith("'", i + 1)
                  and (i == 0 or not _is_identifier_char(query[i - 1]))):
                # Start of an escape string constant E'...' / e'...'.
                out.append(c + "'")
                i += 2
                state = 3
            elif (c == ":" and src_style in ("numeric", "named")
                  and query.startswith("::", i)):
                # PostgreSQL's cast operator is never a parameter marker.
                out.append("::")
                i += 2
            elif src_style == "qmark" and c == "?":
                i += 1
                param_idx = len(output_args)
                if param_idx == len(args):
                    raise QueryParameterIndexError("too many parameter fields, not enough parameters")
                output_args.append(args[param_idx])
                out.append("$" + str(param_idx + 1))
            elif src_style == "numeric" and c == ":":
                i += 1
                if i < length and query[i].isdigit():
                    # Only one digit is consumed; any further digits are
                    # copied verbatim and so still form the right $n.
                    out.append("$" + query[i])
                    i += 1
                else:
                    raise QueryParameterParseError("numeric parameter : does not have numeric arg")
            elif src_style == "named" and c == ":":
                i += 1
                start = i
                while i < length and (query[i].isalnum() or query[i] == "_"):
                    i += 1
                name = query[start:i]
                if name == "":
                    raise QueryParameterParseError("empty name of named parameter")
                out.append(bind_name(name))
            elif src_style == "format" and c == "%":
                i += 1
                if i >= length:
                    raise QueryParameterParseError("format parameter % does not have format code")
                code = query[i]
                if code == "s":
                    param_idx = len(output_args)
                    if param_idx == len(args):
                        raise QueryParameterIndexError("too many parameter fields, not enough parameters")
                    output_args.append(args[param_idx])
                    out.append("$" + str(param_idx + 1))
                elif code == "%":
                    out.append("%")
                else:
                    raise QueryParameterParseError("Only %s and %% are supported")
                i += 1
            elif src_style == "pyformat" and c == "%":
                i += 1
                if i >= length:
                    raise QueryParameterParseError("format parameter % does not have format code")
                code = query[i]
                if code == "(":
                    i += 1
                    # begin mapping name
                    end_idx = query.find(")", i)
                    if end_idx == -1:
                        raise QueryParameterParseError("began pyformat dict read, but couldn't find end of name")
                    name = query[i:end_idx]
                    i = end_idx + 1
                    if i < length and query[i] == "s":
                        i += 1
                        out.append(bind_name(name))
                    else:
                        raise QueryParameterParseError("format not specified or not supported (only %(...)s supported)")
                elif code == "%":
                    # Escaped percent: emit one '%' and consume both.
                    out.append("%")
                    i += 1
                elif code == "s":
                    # A %s in a pyformat query string: assume the caller
                    # meant the "format" style and re-scan this '%' as such.
                    i -= 1
                    src_style = "format"
                else:
                    raise QueryParameterParseError("Only %(name)s, %s and %% are supported")
            else:
                out.append(c)
                i += 1
        elif state == 1:
            out.append(c)
            i += 1
            if c == "'":
                if i < length and query[i] == "'":
                    # '' is an escaped quote; the string continues.
                    out.append("'")
                    i += 1
                else:
                    state = 0
            elif c == "%":
                i = _skip_quoted_percent(src_style, query, i)
        elif state == 2:
            out.append(c)
            i += 1
            if c == '"':
                state = 0
            elif c == "%":
                i = _skip_quoted_percent(src_style, query, i)
        elif state == 3:
            out.append(c)
            i += 1
            if c == "\\":
                # A backslash escapes a following quote or backslash, so
                # E'\'' and E'\\' are both handled correctly.
                if i < length and query[i] in ("'", "\\"):
                    out.append(query[i])
                    i += 1
            elif c == "'":
                if i < length and query[i] == "'":
                    # '' is also an escaped quote inside E'...'.
                    out.append("'")
                    i += 1
                else:
                    state = 0
            elif c == "%":
                i = _skip_quoted_percent(src_style, query, i)
    return "".join(out), tuple(output_args)
def require_open_cursor(fn):
    """Decorator for cursor methods that need a live underlying cursor.

    The wrapped method raises CursorClosedError when ``self.cursor`` is None
    (the state left behind after the cursor has been closed); otherwise it
    calls ``fn`` with the same arguments and returns its result. The wrapper
    keeps ``fn``'s name and docstring via functools.wraps.
    """
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        if self.cursor is None:
            raise CursorClosedError()
        return fn(self, *args, **kwargs)
    return _fn
##
# The class of object returned by the {@link #ConnectionWrapper.cursor cursor method}.
class CursorWrapper(object):
    """DB-API 2.0 cursor layered over an ``interface.Cursor``.

    Queries are rewritten from the module-level ``paramstyle`` into
    PostgreSQL's ``$n`` placeholders before execution. Any failure other than
    ConnectionClosedError rolls back the owning connection's transaction.
    """

    def __init__(self, conn, connection):
        # conn: the interface connection the underlying cursor runs on.
        # connection: the wrapper object that created this cursor; its
        # rollback() is called after a failed statement.
        self.cursor = interface.Cursor(conn)
        self.arraysize = 1
        self._connection = connection
        # When not None, reported instead of cursor.row_count (executemany
        # accumulates the per-statement counts here).
        self._override_rowcount = None

    ##
    # This read-only attribute returns a reference to the connection object on
    # which the cursor was created.
    # <p>
    # Stability: Part of a DBAPI 2.0 extension. A warning "DB-API extension
    # cursor.connection used" will be fired.
    connection = property(lambda self: self._getConnection())

    def _getConnection(self):
        warn("DB-API extension cursor.connection used", stacklevel=3)
        return self._connection

    ##
    # This read-only attribute specifies the number of rows that the last
    # .execute*() produced (for DQL statements like 'select') or affected (for
    # DML statements like 'update' or 'insert').
    # <p>
    # The attribute is -1 in case no .execute*() has been performed on the
    # cursor or the rowcount of the last operation cannot be determined by
    # the interface.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    rowcount = property(lambda self: self._getRowCount())

    @require_open_cursor
    def _getRowCount(self):
        if self._override_rowcount is not None:
            return self._override_rowcount
        return self.cursor.row_count

    ##
    # This read-only attribute is a sequence of 7-item sequences. Each value
    # contains information describing one result column. The 7 items returned
    # for each column are (name, type_code, display_size, internal_size,
    # precision, scale, null_ok). Only the first two values are provided by
    # this interface implementation; type_code is the column's type OID.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    description = property(lambda self: self._getDescription())

    @require_open_cursor
    def _getDescription(self):
        if self.cursor.row_description is None:
            return None
        columns = []
        for col in self.cursor.row_description:
            columns.append((col["name"], col["type_oid"], None, None, None, None, None))
        return columns

    ##
    # Executes a database operation. Parameters may be provided as a sequence
    # or mapping and will be bound to variables in the operation.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def execute(self, operation, args=()):
        self._override_rowcount = None
        self._execute(operation, args)

    def _execute(self, operation, args=()):
        # The module-level paramstyle is read at call time, so reassigning it
        # affects subsequent executions.
        new_query, new_args = convert_paramstyle(paramstyle, operation, args)
        if new_args is None:
            # convert_paramstyle passes args=None straight through; there are
            # no parameters to bind, and *None would raise TypeError.
            new_args = ()
        try:
            self.cursor.execute(new_query, *new_args)
        except ConnectionClosedError:
            # can't rollback in this case
            raise
        except:
            # Bare except on purpose: whatever the failure, roll back the
            # transaction to-date before re-raising it.
            self._connection.rollback()
            raise

    @staticmethod
    def _build_copy_query(direction, table, sep, null):
        # Build "COPY <table> <direction> stdout DELIMITER '<sep>' [NULL
        # '<null>']". Embedded single quotes in sep/null are doubled so the
        # quoted literals stay well-formed. The table name is used verbatim.
        if table is None:
            raise CopyQueryOrTableRequiredError()

        def literal(value):
            return ("%s" % (value,)).replace("'", "''")

        query = "COPY %s %s stdout DELIMITER '%s'" % (table, direction, literal(sep))
        if null is not None:
            query += " NULL '%s'" % (literal(null),)
        return query

    ##
    # Loads data from fileobj into the database with COPY ... FROM. Either an
    # explicit query, or a table name from which one is built, is required
    # (CopyQueryOrTableRequiredError otherwise).
    def copy_from(self, fileobj, table=None, sep='\t', null=None, query=None):
        if query is None:
            query = self._build_copy_query("FROM", table, sep, null)
        self.copy_execute(fileobj, query)

    ##
    # Writes data from the database into fileobj with COPY ... TO. Either an
    # explicit query, or a table name from which one is built, is required
    # (CopyQueryOrTableRequiredError otherwise).
    def copy_to(self, fileobj, table=None, sep='\t', null=None, query=None):
        if query is None:
            query = self._build_copy_query("TO", table, sep, null)
        self.copy_execute(fileobj, query)

    ##
    # Runs a COPY statement, streaming through fileobj. Errors other than
    # ConnectionClosedError roll back the transaction and are re-raised.
    @require_open_cursor
    def copy_execute(self, fileobj, query):
        try:
            self.cursor.execute(query, stream=fileobj)
        except ConnectionClosedError:
            # can't rollback in this case
            raise
        except:
            # Bare except on purpose: whatever the failure, roll back the
            # transaction to-date before re-raising it.
            self._connection.rollback()
            raise

    ##
    # Prepare a database operation and then execute it against all parameter
    # sequences or mappings provided. rowcount becomes the sum of the
    # individual row counts, or -1 if any of them is unknown.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def executemany(self, operation, parameter_sets):
        self._override_rowcount = 0
        for parameters in parameter_sets:
            self._execute(operation, parameters)
            if self.cursor.row_count == -1 or self._override_rowcount == -1:
                self._override_rowcount = -1
            else:
                self._override_rowcount += self.cursor.row_count

    ##
    # Fetch the next row of a query result set, returning a single sequence, or
    # None when no more data is available.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def fetchone(self):
        return self.cursor.read_tuple()

    ##
    # Fetch the next set of rows of a query result, returning a sequence of
    # sequences. An empty sequence is returned when no more rows are
    # available.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    # @param size The number of rows to fetch when called. If not provided,
    # the arraysize property value is used instead.
    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize
        rows = []
        for _ in range(size):
            row = self.fetchone()
            if row is None:
                break
            rows.append(row)
        return rows

    ##
    # Fetch all remaining rows of a query result, returning them as a sequence
    # of sequences.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def fetchall(self):
        return tuple(self.cursor.iterate_tuple())

    ##
    # Close the cursor. Further use of methods that require an open cursor
    # raises CursorClosedError.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def close(self):
        self.cursor.close()
        self.cursor = None
        self._override_rowcount = None

    # DB-API extension: iterator protocol (Python 2 style next()).
    def next(self):
        warn("DB-API extension cursor.next() used", stacklevel=2)
        retval = self.fetchone()
        if retval is None:
            raise StopIteration()
        return retval

    def __iter__(self):
        warn("DB-API extension cursor.__iter__() used", stacklevel=2)
        return self

    # Required by the DB-API; intentionally no-ops here.
    def setinputsizes(self, sizes):
        pass

    def setoutputsize(self, size, column=None):
        pass

    @require_open_cursor
    def fileno(self):
        return self.cursor.fileno()

    @require_open_cursor
    def isready(self):
        return self.cursor.isready()
def require_open_connection(fn):
    """Decorator for connection methods that need a live underlying connection.

    The wrapped method raises ConnectionClosedError when ``self.conn`` is
    None (the state left behind after the connection has been closed);
    otherwise it calls ``fn`` with the same arguments and returns its result.
    The wrapper keeps ``fn``'s name and docstring via functools.wraps.
    """
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        if self.conn is None:
            raise ConnectionClosedError()
        return fn(self, *args, **kwargs)
    return _fn
##
# The class of object returned by the {@link #connect connect method}.
class ConnectionWrapper(object):
    """DB-API 2.0 connection wrapping an ``interface.Connection``.

    A transaction is begun on construction and again after every commit() or
    rollback() (see the race noted in commit()). Asynchronous notifications
    are collected in ``self.notifies``.
    """

    # DBAPI Extension: supply exceptions as attributes on the connection.
    # Each access emits a "DB-API extension connection.<name> used" warning.
    Warning = property(lambda self: self._getError(Warning))
    Error = property(lambda self: self._getError(Error))
    InterfaceError = property(lambda self: self._getError(InterfaceError))
    DatabaseError = property(lambda self: self._getError(DatabaseError))
    OperationalError = property(lambda self: self._getError(OperationalError))
    IntegrityError = property(lambda self: self._getError(IntegrityError))
    InternalError = property(lambda self: self._getError(InternalError))
    ProgrammingError = property(lambda self: self._getError(ProgrammingError))
    NotSupportedError = property(lambda self: self._getError(NotSupportedError))

    def _getError(self, error):
        warn("DB-API extension connection.%s used" % error.__name__, stacklevel=3)
        return error

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded unchanged to interface.Connection.
        self.conn = interface.Connection(**kwargs)
        # (backend_pid, condition) tuples appended by _notificationReceived;
        # guarded by notifies_lock.
        self.notifies = []
        self.notifies_lock = threading.Lock()
        self.conn.NotificationReceived += self._notificationReceived
        self.conn.begin()

    def _notificationReceived(self, notice):
        # psycopg2 compatible notification interface.
        # Acquire before entering try so the finally clause never releases a
        # lock that was not actually acquired.
        self.notifies_lock.acquire()
        try:
            self.notifies.append((notice.backend_pid, notice.condition))
        finally:
            self.notifies_lock.release()

    ##
    # Creates a {@link #CursorWrapper CursorWrapper} object bound to this
    # connection.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def cursor(self):
        return CursorWrapper(self.conn, self)

    ##
    # Commits the current database transaction and immediately begins a new
    # one.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def commit(self):
        # There's a threading bug here. If a query is sent after the
        # commit, but before the begin, it will be executed immediately
        # without a surrounding transaction. Like all threading bugs -- it
        # sounds unlikely, until it happens every time in one
        # application... however, to fix this, we need to lock the
        # database connection entirely, so that no cursors can execute
        # statements on other threads. Support for that type of lock will
        # be done later.
        self.conn.commit()
        self.conn.begin()

    ##
    # Rolls back the current database transaction and immediately begins a new
    # one.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def rollback(self):
        # see bug description in commit.
        self.conn.rollback()
        self.conn.begin()

    ##
    # Closes the database connection. Further use of methods that require an
    # open connection raises ConnectionClosedError.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def close(self):
        self.conn.close()
        self.conn = None

    # Delegates to interface.Connection.recache_record_types().
    @require_open_connection
    def recache_record_types(self):
        self.conn.recache_record_types()
##
# Creates a DBAPI 2.0 compatible interface to a PostgreSQL database.
# <p>
# Stability: Part of the DBAPI 2.0 specification.
#
# @param user The username to connect to the PostgreSQL server with. This
# parameter is required.
#
# @keyparam host The hostname of the PostgreSQL server to connect with.
# Providing this parameter is necessary for TCP/IP connections. One of either
# host, or unix_sock, must be provided.
#
# @keyparam unix_sock The path to the UNIX socket to access the database
# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host
# must be provided. The port parameter will have no effect if unix_sock is
# provided.
#
# @keyparam port The TCP/IP port of the PostgreSQL server instance. This
# parameter defaults to 5432, the registered and common port of PostgreSQL
# TCP/IP servers.
#
# @keyparam database The name of the database instance to connect with. This
# parameter is optional, if omitted the PostgreSQL server will assume the
# database name is the same as the username.
#
# @keyparam password The user password to connect to the server with. This
# parameter is optional. If omitted, and the database server requests password
# based authentication, the connection will fail. On the other hand, if this
# parameter is provided and the database does not request password
# authentication, then the password will not be used.
#
# @keyparam socket_timeout Socket connect timeout measured in seconds.
# Defaults to 60 seconds.
#
# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False.
#
# @keyparam options Passed through unchanged to the underlying
# interface.Connection, which defines its meaning. Defaults to None.
#
# @return An instance of {@link #ConnectionWrapper ConnectionWrapper}.
def connect(user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False, options=None):
    # Every argument is forwarded by keyword; ConnectionWrapper hands the
    # whole set on to the low-level connection.
    connection_params = dict(
        user=user,
        host=host,
        unix_sock=unix_sock,
        port=port,
        database=database,
        password=password,
        socket_timeout=socket_timeout,
        ssl=ssl,
        options=options,
    )
    return ConnectionWrapper(**connection_params)
def Date(year, month, day):
    """DB-API type constructor: a date value."""
    return datetime.date(year=year, month=month, day=day)
def Time(hour, minute, second):
    """DB-API type constructor: a time-of-day value."""
    return datetime.time(hour=hour, minute=minute, second=second)
def Timestamp(year, month, day, hour, minute, second):
    """DB-API type constructor: a date-and-time value."""
    return datetime.datetime(year=year, month=month, day=day,
                             hour=hour, minute=minute, second=second)
def DateFromTicks(ticks):
    """Date for ``ticks`` seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)
def TimeFromTicks(ticks):
    """Time of day for ``ticks`` seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)
def TimestampFromTicks(ticks):
    """Timestamp for ``ticks`` seconds since the epoch, in local time.

    Sub-second precision in ``ticks`` is discarded.
    """
    local = time.localtime(ticks)
    return Timestamp(local.tm_year, local.tm_mon, local.tm_mday,
                     local.tm_hour, local.tm_min, local.tm_sec)
##
# DB-API type constructor: wraps value in an object holding binary (bytea)
# data.
# NOTE(review): ``types`` here must resolve to this package's own types
# module (the standard-library ``types`` has no Bytea) -- confirm the import.
def Binary(value):
    bytea_factory = types.Bytea
    return bytea_factory(value)
# DB-API type objects. They are meant to be compared against the type_code
# reported in cursor.description, which in this module is the column's
# PostgreSQL type OID.
# NOTE(review): PEP 249 expects each type object to compare equal to every
# type code in its category, but each constant here is a single OID, so a
# column of a related type (e.g. text rather than varchar) will not compare
# equal to STRING.
#
# varchar type_oid; chosen because it appears to match the expectations of
# the DB API 2.0 compliance test suite.
STRING = 1043
# bytea type_oid
BINARY = 17
# numeric type_oid
NUMBER = 1700
# timestamp type_oid
DATETIME = 1114
# oid type_oid
ROWID = 26