blob: 377074bd651d927cbda90b266a94f3371a21f652 [file] [log] [blame]
# vim: sw=4:expandtab:foldmethod=marker
#
# Copyright (c) 2007-2009, Mathieu Fenniak
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
__author__ = "Mathieu Fenniak"
import socket
import select
import threading
import struct
import hashlib
from cStringIO import StringIO
from errors import *
from util import MulticastDelegate
import types
##
# An SSLRequest message. To initiate an SSL-encrypted connection, an
# SSLRequest message is used rather than a {@link StartupMessage
# StartupMessage}. A StartupMessage is still sent, but only after SSL
# negotiation (if accepted).
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class SSLRequest(object):
    def __init__(self):
        # The message carries no per-instance state.
        pass

    # Wire format:
    #   Int32(8)        - message length, including the length field itself.
    #   Int32(80877103) - the SSL request code.
    def serialize(self):
        message_length = 8
        ssl_request_code = 80877103
        return struct.pack("!ii", message_length, ssl_request_code)
##
# A StartupMessage message. Begins a DB session, identifying the user to be
# authenticated as and the database to connect to.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class StartupMessage(object):
    # Greenplum utility mode
    def __init__(self, user, database=None, options=None):
        self.user, self.database, self.options = user, database, options

    # Wire format:
    #   Int32         - message length, including the length field itself.
    #   Int32(196608) - protocol version number (version 3.0).
    #   Any number of key/value pairs, terminated by a zero byte:
    #     String - a parameter name (user, database, or options)
    #     String - the parameter value
    def serialize(self):
        pieces = [struct.pack("!i", 196608), "user\x00", self.user, "\x00"]
        # database and options are only sent when they have a true value.
        if self.database:
            pieces.extend(["database\x00", self.database, "\x00"])
        if self.options:
            pieces.extend(["options\x00", self.options, "\x00"])
        # Terminator for the key/value list.
        pieces.append("\x00")
        body = "".join(pieces)
        # The length prefix counts its own four bytes.
        return struct.pack("!i", len(body) + 4) + body
##
# Parse message. Creates a prepared statement in the DB session.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param ps Name of the prepared statement to create.
# @param qs Query string.
# @param type_oids An iterable that contains the PostgreSQL type OIDs for
# parameters in the query string.
class Parse(object):
    def __init__(self, ps, qs, type_oids):
        self.ps, self.qs, self.type_oids = ps, qs, type_oids

    def __repr__(self):
        return "<Parse ps=%r qs=%r>" % (self.ps, self.qs)

    # Wire format:
    #   Byte1('P') - identifies the message as a Parse command.
    #   Int32      - message length, including the length field itself.
    #   String     - prepared statement name; an empty string selects the
    #                unnamed prepared statement.
    #   String     - the query string.
    #   Int16      - number of parameter data types specified (can be zero).
    #   For each parameter:
    #     Int32    - the OID of the parameter data type.
    def serialize(self):
        chunks = [self.ps, "\x00", self.qs, "\x00",
                  struct.pack("!h", len(self.type_oids))]
        for oid in self.type_oids:
            # The Parse message doesn't seem to accept the -1 type OID that
            # other messages use for NULL values, so substitute 705, the PG
            # "unknown" type.
            if oid == -1:
                oid = 705
            chunks.append(struct.pack("!i", oid))
        body = "".join(chunks)
        return "P" + struct.pack("!i", len(body) + 4) + body
##
# Bind message. Readies a prepared statement for execution.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param portal Name of the destination portal.
# @param ps Name of the source prepared statement.
# @param in_fc An iterable containing the format codes for input
# parameters. 0 = Text, 1 = Binary.
# @param params The parameters.
# @param out_fc An iterable containing the format codes for output
# parameters. 0 = Text, 1 = Binary.
# @param kwargs Additional arguments to pass to the type conversion
# methods.
class Bind(object):
    def __init__(self, portal, ps, in_fc, params, out_fc, **kwargs):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        self.params = []
        for index, value in enumerate(params):
            # Zero format codes means "all text"; exactly one code applies to
            # every parameter; otherwise there is one code per parameter.
            declared = len(self.in_fc)
            if declared == 0:
                fc = 0
            elif declared == 1:
                fc = self.in_fc[0]
            else:
                fc = self.in_fc[index]
            self.params.append(types.pg_value(value, fc, **kwargs))
        self.out_fc = out_fc

    def __repr__(self):
        return "<Bind p=%r s=%r>" % (self.portal, self.ps)

    # Wire format:
    #   Byte1('B') - identifies the Bind command.
    #   Int32      - message length, including the length field itself.
    #   String     - name of the destination portal.
    #   String     - name of the source prepared statement.
    #   Int16      - number of parameter format codes.
    #   For each parameter format code:
    #     Int16    - the parameter format code.
    #   Int16      - number of parameter values.
    #   For each parameter value:
    #     Int32    - length of the parameter value in bytes, not including
    #                this length field. -1 indicates a NULL parameter value,
    #                in which case no value bytes follow.
    #     Byte[n]  - value of the parameter.
    #   Int16      - the number of result-column format codes.
    #   For each result-column format code:
    #     Int16    - the format code.
    def serialize(self):
        buf = StringIO()
        put = buf.write
        put(self.portal + "\x00")
        put(self.ps + "\x00")
        put(struct.pack("!h", len(self.in_fc)))
        for code in self.in_fc:
            put(struct.pack("!h", code))
        put(struct.pack("!h", len(self.params)))
        for value in self.params:
            if value is None:
                # NULL: a length of -1 and no value bytes.
                put(struct.pack("!i", -1))
                continue
            put(struct.pack("!i", len(value)))
            put(value)
        put(struct.pack("!h", len(self.out_fc)))
        for code in self.out_fc:
            put(struct.pack("!h", code))
        body = buf.getvalue()
        return "B" + struct.pack("!i", len(body) + 4) + body
##
# A Close message, used for closing prepared statements and portals.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal.
# @param name The name of the item to close.
class Close(object):
    def __init__(self, typ, name):
        # typ is sent as a single byte, so anything else is rejected.
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ, self.name = typ, name

    # Wire format:
    #   Byte1('C') - identifies the message as a close command.
    #   Int32      - message length, including the length field itself.
    #   Byte1      - 'S' for prepared statement, 'P' for portal.
    #   String     - the name of the item to close.
    def serialize(self):
        body = self.typ + self.name + "\x00"
        header = struct.pack("!i", len(body) + 4)
        return "C" + header + body
##
# A specialized Close message for a portal.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ClosePortal(Close):
    def __init__(self, name):
        # 'P' selects a portal as the thing being closed.
        super(ClosePortal, self).__init__("P", name)
##
# A specialized Close message for a prepared statement.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ClosePreparedStatement(Close):
    def __init__(self, name):
        # 'S' selects a prepared statement as the thing being closed.
        super(ClosePreparedStatement, self).__init__("S", name)
##
# A Describe message, used for obtaining information on prepared statements
# and portals.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal.
# @param name The name of the item to describe.
class Describe(object):
    def __init__(self, typ, name):
        # typ is sent as a single byte, so anything else is rejected.
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ, self.name = typ, name

    # Wire format:
    #   Byte1('D') - identifies the message as a describe command.
    #   Int32      - message length, including the length field itself.
    #   Byte1      - 'S' for prepared statement, 'P' for portal.
    #   String     - the name of the item to describe.
    def serialize(self):
        body = self.typ + self.name + "\x00"
        header = struct.pack("!i", len(body) + 4)
        return "D" + header + body
##
# A specialized Describe message for a portal.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class DescribePortal(Describe):
    def __init__(self, name):
        # 'P' selects a portal as the thing being described.
        super(DescribePortal, self).__init__("P", name)

    def __repr__(self):
        return "<DescribePortal %r>" % self.name
##
# A specialized Describe message for a prepared statement.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class DescribePreparedStatement(Describe):
    def __init__(self, name):
        # 'S' selects a prepared statement as the thing being described.
        super(DescribePreparedStatement, self).__init__("S", name)

    def __repr__(self):
        return "<DescribePreparedStatement %r>" % self.name
##
# A Flush message forces the backend to deliver any data pending in its
# output buffers.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Flush(object):
    def __repr__(self):
        return "<Flush>"

    # Wire format:
    #   Byte1('H') - identifies the message as a flush command.
    #   Int32(4)   - message length, including the length field itself.
    def serialize(self):
        return "H\x00\x00\x00\x04"
##
# Causes the backend to close the current transaction (if not in a BEGIN/COMMIT
# block), and issue ReadyForQuery.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Sync(object):
    def __repr__(self):
        return "<Sync>"

    # Wire format:
    #   Byte1('S') - identifies the message as a sync command.
    #   Int32(4)   - message length, including the length field itself.
    def serialize(self):
        return "S\x00\x00\x00\x04"
##
# Transmits a password.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class PasswordMessage(object):
    def __init__(self, pwd):
        self.pwd = pwd

    # Wire format:
    #   Byte1('p') - identifies the message as a password message.
    #   Int32      - message length, including the length field itself.
    #   String     - the password; it may be encrypted.
    def serialize(self):
        body = self.pwd + "\x00"
        header = struct.pack("!i", len(body) + 4)
        return "p" + header + body
##
# Requests that the backend execute a portal and retrieve any number of rows.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
# @param row_count The number of rows to return. Can be zero to indicate the
# backend should return all rows. If the portal represents a
# query that does not return rows, no rows will be returned
# no matter what the row_count.
class Execute(object):
    def __init__(self, portal, row_count):
        self.portal, self.row_count = portal, row_count

    # Wire format:
    #   Byte1('E') - identifies the message as an execute message.
    #   Int32      - message length, including the length field itself.
    #   String     - the name of the portal to execute.
    #   Int32      - maximum number of rows to return, if the portal contains
    #                a query that returns rows. 0 = no limit.
    def serialize(self):
        body = self.portal + "\x00" + struct.pack("!i", self.row_count)
        header = struct.pack("!i", len(body) + 4)
        return "E" + header + body
##
# Informs the backend that the connection is being closed.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Terminate(object):
    def __init__(self):
        # The message carries no per-instance state.
        pass

    # Wire format:
    #   Byte1('X') - identifies the message as a terminate message.
    #   Int32(4)   - message length, including the length field itself.
    def serialize(self):
        return "X\x00\x00\x00\x04"
##
# Base class of all Authentication[*] messages.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationRequest(object):
    def __init__(self, data):
        # The base class ignores any trailing message data.
        pass

    # Wire format:
    #   Byte1('R') - identifies the message as an authentication request.
    #   Int32(8)   - message length, including the length field itself.
    #   Int32      - an authentication code selecting the kind of request:
    #                  0 = AuthenticationOk
    #                  5 = MD5 pwd
    #                  2 = Kerberos v5 (not supported by pg8000)
    #                  3 = Cleartext pwd (not supported by pg8000)
    #                  4 = crypt() pwd (not supported by pg8000)
    #                  6 = SCM credential (not supported by pg8000)
    #                  7 = GSSAPI (not supported by pg8000)
    #                  8 = GSSAPI data (not supported by pg8000)
    #                  9 = SSPI (not supported by pg8000)
    # Some authentication messages carry additional data after the code; that
    # data is documented in the class handling the code.
    @staticmethod
    def createFromData(data):
        (ident,) = struct.unpack("!i", data[:4])
        klass = authentication_codes.get(ident)
        if klass is None:
            raise NotSupportedError("authentication method %r not supported" % (ident,))
        # Whatever follows the code is handed to the specific message class.
        return klass(data[4:])

    def ok(self, conn, user, **kwargs):
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")
##
# A message indicating that the backend accepted the provided username
# without any challenge.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationOk(AuthenticationRequest):
    # No further exchange is performed for this message; authentication is
    # reported as successful straight away.
    def ok(self, conn, user, **kwargs):
        return True
##
# A message representing the backend requesting an MD5 hashed password
# response. The response will be sent as md5(md5(pwd + login) + salt).
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationMD5Password(AuthenticationRequest):
    # Additional message data:
    #   Byte4 - hash salt.
    def __init__(self, data):
        self.salt = "".join(struct.unpack("4c", data))

    # Sends "md5" + md5(md5(password + user) + salt) as a PasswordMessage and
    # then waits for the follow-up authentication message, whose own ok()
    # result becomes the return value.
    def ok(self, conn, user, password=None, **kwargs):
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        inner_digest = hashlib.md5(password + user).hexdigest()
        outer_digest = hashlib.md5(inner_digest + self.salt).hexdigest()
        conn._send(PasswordMessage("md5" + outer_digest))
        conn._flush()

        reader = MessageReader(conn)

        def on_auth_message(msg, rdr):
            # return_value() is used so that a false result from msg.ok()
            # still ends the message loop.
            return rdr.return_value(msg.ok(conn, user))

        reader.add_message(AuthenticationRequest, on_auth_message, reader)
        reader.add_message(ErrorResponse, self._ok_error)
        return reader.handle_messages()

    def _ok_error(self, msg):
        # SQLSTATE 28000 is reported as an md5 authentication failure; any
        # other error is raised as the exception the message itself creates.
        if msg.code != "28000":
            raise msg.createException()
        raise InterfaceError("md5 password authentication failed")
# Maps the Int32 authentication code of an authentication request to the class
# that handles it. AuthenticationRequest.createFromData raises
# NotSupportedError for codes missing from this table.
authentication_codes = {
    0: AuthenticationOk,
    5: AuthenticationMD5Password,
}
##
# ParameterStatus message sent from backend, used to inform the frontend of
# runtime configuration parameter changes.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ParameterStatus(object):
    def __init__(self, key, value):
        self.key, self.value = key, value

    # Wire format:
    #   Byte1('S') - identifies ParameterStatus.
    #   Int32      - message length, including the length field itself.
    #   String     - runtime parameter name.
    #   String     - runtime parameter value.
    @staticmethod
    def createFromData(data):
        # The name ends at the first NUL; the value is everything after it
        # except the final byte (the value's own terminator).
        separator = data.find("\x00")
        return ParameterStatus(data[:separator], data[separator + 1:-1])
##
# BackendKeyData message sent from backend. Contains a connection's process
# ID and a secret key. Can be used to terminate the connection's current
# actions, such as a long running query. Not supported by pg8000 yet.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class BackendKeyData(object):
    def __init__(self, process_id, secret_key):
        self.process_id, self.secret_key = process_id, secret_key

    # Wire format:
    #   Byte1('K') - identifier.
    #   Int32(12)  - message length, including the length field itself.
    #   Int32      - process ID.
    #   Int32      - secret key.
    @staticmethod
    def createFromData(data):
        # The payload is exactly the two Int32 values, in constructor order.
        return BackendKeyData(*struct.unpack("!2i", data))
##
# Message representing a query with no data.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class NoData(object):
    # Wire format:
    #   Byte1('n') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected.
    @staticmethod
    def createFromData(data):
        return NoData()
##
# Message representing a successful Parse.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ParseComplete(object):
    # Wire format:
    #   Byte1('1') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected.
    @staticmethod
    def createFromData(data):
        return ParseComplete()
##
# Message representing a successful Bind.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class BindComplete(object):
    # Wire format:
    #   Byte1('2') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected.
    @staticmethod
    def createFromData(data):
        return BindComplete()
##
# Message representing a successful Close.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class CloseComplete(object):
    # Wire format:
    #   Byte1('3') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected.
    @staticmethod
    def createFromData(data):
        return CloseComplete()
##
# Message representing data from an Execute has been received, but more data
# exists in the portal.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class PortalSuspended(object):
    # Wire format:
    #   Byte1('s') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected.
    @staticmethod
    def createFromData(data):
        return PortalSuspended()
##
# Message representing the backend is ready to process a new query.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ReadyForQuery(object):
    def __init__(self, status):
        self._status = status

    ##
    # I = Idle, T = Idle in Transaction, E = idle in failed transaction.
    @property
    def status(self):
        return self._status

    def __repr__(self):
        descriptions = {
            "I": "Idle",
            "T": "Idle in Transaction",
            "E": "Idle in Failed Transaction",
        }
        return "<ReadyForQuery %s>" % descriptions[self.status]

    # Wire format:
    #   Byte1('Z') - identifier.
    #   Int32(5)   - message length, including the length field itself.
    #   Byte1      - status indicator.
    @staticmethod
    def createFromData(data):
        # The whole payload is the status indicator.
        return ReadyForQuery(data)
##
# Represents a notice sent from the server. This is not the same as a
# notification. A notice is just additional information about a query, such
# as a notice that a primary key has automatically been created for a table.
# <p>
# A NoticeResponse instance will have properties containing the data sent
# from the server:
# <ul>
# <li>severity -- "ERROR", "FATAL", "PANIC", "WARNING", "NOTICE", "DEBUG",
# "INFO", or "LOG". Always present.</li>
# <li>code -- the SQLSTATE code for the error. See Appendix A of the
# PostgreSQL documentation for specific error codes. Always present.</li>
# <li>msg -- human-readable error message. Always present.</li>
# <li>detail -- Optional additional information.</li>
# <li>hint -- Optional suggestion about what to do about the issue.</li>
# <li>position -- Optional index into the query string.</li>
# <li>where -- Optional context.</li>
# <li>file -- Source-code file.</li>
# <li>line -- Source-code line.</li>
# <li>routine -- Source-code routine.</li>
# </ul>
# <p>
# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class NoticeResponse(object):
    # Maps the one-byte field type codes used on the wire to the attribute
    # names exposed on the message object.
    responseKeys = {
        "S": "severity",  # always present
        "C": "code",      # always present
        "M": "msg",       # always present
        "D": "detail",
        "H": "hint",
        "P": "position",
        "p": "_position",
        "q": "_query",
        "W": "where",
        "F": "file",
        "L": "line",
        "R": "routine",
    }

    def __init__(self, **kwargs):
        # Every keyword argument becomes an attribute of the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        return "<NoticeResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Splits a notice/error payload into a dict. Each NUL-separated field is
    # a one-byte type code followed by its value; codes missing from
    # responseKeys are kept as the key unchanged. Empty fields are skipped.
    @staticmethod
    def dataIntoDict(data):
        retval = {}
        for field in data.split("\x00"):
            if field:
                code = field[0]
                retval[NoticeResponse.responseKeys.get(code, code)] = field[1:]
        return retval

    # Wire format:
    #   Byte1('N') - identifier.
    #   Int32      - message length.
    #   Any number of these, followed by a zero byte:
    #     Byte1    - code identifying the field type (see responseKeys)
    #     String   - field value
    @staticmethod
    def createFromData(data):
        return NoticeResponse(**NoticeResponse.dataIntoDict(data))
##
# A message sent in case of a server-side error. Contains the same properties
# that {@link NoticeResponse NoticeResponse} contains.
# <p>
# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class ErrorResponse(object):
    def __init__(self, **kwargs):
        # Every keyword argument becomes an attribute of the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Builds (but does not raise) the exception describing this error.
    def createException(self):
        return ProgrammingError(self.severity, self.code, self.msg)

    # The payload has the same field layout as a NoticeResponse, so the same
    # parser is reused.
    @staticmethod
    def createFromData(data):
        return ErrorResponse(**NoticeResponse.dataIntoDict(data))
##
# A message sent if this connection receives a NOTIFY that it was LISTENing for.
# <p>
# Stability: Added in pg8000 v1.03. When limited to accessing properties from
# a notification event dispatch, stability is guaranteed for v1.xx.
class NotificationResponse(object):
    def __init__(self, backend_pid, condition, additional_info):
        self._backend_pid = backend_pid
        self._condition = condition
        self._additional_info = additional_info

    ##
    # An integer representing the process ID of the backend that triggered
    # the NOTIFY.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def backend_pid(self):
        return self._backend_pid

    ##
    # The name of the notification fired.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def condition(self):
        return self._condition

    ##
    # Currently unspecified by the PostgreSQL documentation as of v8.3.1.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def additional_info(self):
        return self._additional_info

    def __repr__(self):
        return "<NotificationResponse %s %s %r>" % (self.backend_pid, self.condition, self.additional_info)

    # Payload layout as parsed here: an Int32 backend PID followed by two
    # NUL-terminated strings (condition name, then additional info).
    @staticmethod
    def createFromData(data):
        (backend_pid,) = struct.unpack("!i", data[:4])
        rest = data[4:]
        end = rest.find("\x00")
        condition = rest[:end]
        rest = rest[end + 1:]
        end = rest.find("\x00")
        additional_info = rest[:end]
        return NotificationResponse(backend_pid, condition, additional_info)
class ParameterDescription(object):
    def __init__(self, type_oids):
        self.type_oids = type_oids

    # Payload layout as parsed here: an Int16 count followed by that many
    # Int32 type OIDs. type_oids is stored as the tuple struct.unpack returns.
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        oid_format = "!" + "i" * count
        return ParameterDescription(struct.unpack(oid_format, data[2:]))
class RowDescription(object):
    def __init__(self, fields):
        self.fields = fields

    # Payload layout as parsed here: an Int16 field count, then for each
    # field a NUL-terminated name followed by 18 bytes holding table OID
    # (Int32), column attribute number (Int16), type OID (Int32), type size
    # (Int16), type modifier (Int32) and format code (Int16). Each field is
    # returned as a dict.
    @staticmethod
    def createFromData(data):
        attr_names = ("table_oid", "column_attrnum", "type_oid",
                      "type_size", "type_modifier", "format")
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        fields = []
        for _ in range(count):
            name_end = rest.find("\x00")
            field = {"name": rest[:name_end]}
            rest = rest[name_end + 1:]
            field.update(zip(attr_names, struct.unpack("!ihihih", rest[:18])))
            rest = rest[18:]
            fields.append(field)
        return RowDescription(fields)
class CommandComplete(object):
    def __init__(self, command, rows=None, oid=None):
        self.command, self.rows, self.oid = command, rows, oid

    # The payload is a NUL-terminated command tag. For the tags listed below
    # the last word is parsed as the row count, and for INSERT the second
    # word is parsed as the OID. For any other tag the full tag text is kept
    # as the command, with rows and oid left as None.
    @staticmethod
    def createFromData(data):
        tag = data[:-1]
        words = tag.split(" ")
        command = words[0]
        if command not in ("INSERT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY"):
            return CommandComplete(command=tag)
        args = {"command": command, "rows": int(words[-1])}
        if command == "INSERT":
            args["oid"] = int(words[1])
        return CommandComplete(**args)
class DataRow(object):
    def __init__(self, fields):
        self.fields = fields

    # Payload layout as parsed here: an Int16 column count, then for each
    # column an Int32 length followed by that many value bytes. A length of
    # -1 marks a NULL column, which has no value bytes and is stored as None.
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        fields = []
        for _ in range(count):
            (val_len,) = struct.unpack("!i", rest[:4])
            rest = rest[4:]
            if val_len == -1:
                fields.append(None)
                continue
            fields.append(rest[:val_len])
            rest = rest[val_len:]
        return DataRow(fields)
class CopyData(object):
    # Message type code "d". The payload is kept untouched in self.data.
    def __init__(self, data):
        self.data = data

    @staticmethod
    def createFromData(data):
        return CopyData(data)

    # Serialized as 'd', an Int32 length (payload plus the four length
    # bytes), then the payload itself.
    def serialize(self):
        header = struct.pack('!i', len(self.data) + 4)
        return 'd' + header + self.data
class CopyDone(object):
    # Wire format:
    #   Byte1('c') - identifier.
    #   Int32(4)   - message length, including the length field itself.
    # The payload is not inspected when reading.
    @staticmethod
    def createFromData(data):
        return CopyDone()

    def serialize(self):
        return "c\x00\x00\x00\x04"
class CopyOutResponse(object):
    # Wire format:
    #   Byte1('H')
    #   Int32    - length of message contents in bytes, including self.
    #   Int8     - overall format: 0 textual, 1 binary.
    #   Int16    - number of columns.
    #   Int16[N] - format code for each column (0 text, 1 binary).
    def __init__(self, is_binary, column_formats):
        self.is_binary, self.column_formats = is_binary, column_formats

    @staticmethod
    def createFromData(data):
        # First three bytes: the Int8 format flag and the Int16 column count.
        header, formats = data[:3], data[3:]
        is_binary, num_cols = struct.unpack("!bh", header)
        column_formats = struct.unpack("!" + "h" * num_cols, formats)
        return CopyOutResponse(is_binary, column_formats)
class CopyInResponse(object):
    # Wire format:
    #   Byte1('G')
    #   Int32    - length of message contents in bytes, including self.
    #   Int8     - overall format: 0 textual, 1 binary.
    #   Int16    - number of columns.
    #   Int16[N] - format code for each column (0 text, 1 binary).
    def __init__(self, is_binary, column_formats):
        self.is_binary, self.column_formats = is_binary, column_formats

    @staticmethod
    def createFromData(data):
        # First three bytes: the Int8 format flag and the Int16 column count.
        header, formats = data[:3], data[3:]
        is_binary, num_cols = struct.unpack("!bh", header)
        column_formats = struct.unpack("!" + "h" * num_cols, formats)
        return CopyInResponse(is_binary, column_formats)
class SSLWrapper(object):
    # Adapts an SSL object exposing write()/read() to the subset of the
    # socket interface that Connection uses: send(), sendall() and recv().
    def __init__(self, sslobj):
        self.sslobj = sslobj

    # Writes data once and returns whatever the SSL object's write() returns
    # (assumed to be the number of bytes written -- confirm for the SSL
    # implementation in use).
    def send(self, data):
        return self.sslobj.write(data)

    # Writes all of data, looping over partial writes. Connection._flush()
    # calls sendall() on its socket, so the wrapper has to provide it.
    #
    # Raises socket.error if a write reports that no bytes were accepted.
    def sendall(self, data):
        while data:
            written = self.sslobj.write(data)
            if written is None:
                # The SSL object does not report progress; treat the write
                # as complete rather than resending the same bytes.
                return
            if written <= 0:
                raise socket.error("SSL write made no progress")
            data = data[written:]

    # Reads up to num bytes from the SSL object.
    def recv(self, num):
        return self.sslobj.read(num)
class MessageReader(object):
    def __init__(self, connection):
        self._conn = connection
        self._msgs = []
        # If true, raise the exception from an ErrorResponse only after
        # message processing has finished. This can be used to leave the
        # connection in a usable state after an error response, rather than
        # having unconsumed messages that won't be understood in another
        # context.
        self.delay_raising_exception = False
        self.ignore_unhandled_messages = False

    # Registers handler for messages that are instances of msg_class. The
    # handler is called as handler(msg, *args, **kwargs).
    def add_message(self, msg_class, handler, *args, **kwargs):
        self._msgs.append((msg_class, handler, args, kwargs))

    def clear_messages(self):
        self._msgs = []

    # Lets a handler end the message loop with a value that is not true
    # (a true value can simply be returned from the handler).
    def return_value(self, value):
        self._retval = value

    # Reads messages from the connection until a handler ends the loop,
    # either by returning a true value or by calling return_value(). Messages
    # without a registered handler are dispatched to the connection's
    # notice / parameter status / notification hooks; an ErrorResponse is
    # raised immediately, or remembered and raised when the loop ends if
    # delay_raising_exception is set.
    def handle_messages(self):
        exc = None
        while True:
            msg = self._conn._read_message()
            matched = False
            # Every registered handler whose class matches is invoked.
            for msg_class, handler, args, kwargs in self._msgs:
                if not isinstance(msg, msg_class):
                    continue
                matched = True
                retval = handler(msg, *args, **kwargs)
                if not retval and not hasattr(self, "_retval"):
                    # The handler did not ask for the loop to end.
                    continue
                # The loop is ending: a delayed error takes precedence over
                # the return value.
                if exc is not None:
                    raise exc
                if retval:
                    return retval
                return self._retval
            if matched:
                continue
            if isinstance(msg, ErrorResponse):
                exc = msg.createException()
                if not self.delay_raising_exception:
                    raise exc
            elif isinstance(msg, NoticeResponse):
                self._conn.handleNoticeResponse(msg)
            elif isinstance(msg, ParameterStatus):
                self._conn.handleParameterStatus(msg)
            elif isinstance(msg, NotificationResponse):
                self._conn.handleNotificationResponse(msg)
            elif not self.ignore_unhandled_messages:
                raise InternalError("Unexpected response msg %r" % (msg))
# Decorator for Connection methods that talk to the server. The wrapped
# method runs with self._sock_lock held; if it raises, self._sync() is called
# (still under the lock) before the exception is re-raised. The lock is
# released on the way out in all cases.
def sync_on_error(fn):
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        # Acquire outside the try block: if acquiring fails, the lock is not
        # held, so neither _sync() nor release() may be called.
        self._sock_lock.acquire()
        try:
            return fn(self, *args, **kwargs)
        except:
            # Deliberately catches everything -- the exception is always
            # re-raised after the sync attempt.
            self._sync()
            raise
        finally:
            self._sock_lock.release()
    return _fn
class Connection(object):
# Opens the socket to the server and prepares connection state.
#
# @param unix_sock       Path of a unix-domain socket; takes precedence over
#                        host when both are given.
# @param host            Host name for a TCP connection.
# @param port            TCP port, used only with host.
# @param socket_timeout  Timeout applied to the socket (non-SSL only).
# @param ssl             If true, negotiate SSL before anything else.
# @param records         Stored as user_wants_records.
#
# Raises ProgrammingError when neither host nor unix_sock is given, and
# InterfaceError when unix sockets are unavailable or the server refuses SSL.
def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60, ssl=False, records=False):
    self._client_encoding = "ascii"
    self._integer_datetimes = False
    self._record_field_names = {}
    self._sock_buf = ""
    self._sock_buf_pos = 0
    self._send_sock_buf = []
    self._block_size = 8192
    self.user_wants_records = records
    # The lock must exist before the SSL handshake below, because _send()
    # and _flush() assert that it is held.
    self._sock_lock = threading.Lock()

    if unix_sock is not None:
        if not hasattr(socket, "AF_UNIX"):
            raise InterfaceError("attempt to connect to unix socket on unsupported platform")
        self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        address = unix_sock
    elif host is not None:
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        address = (host, port)
    else:
        raise ProgrammingError("one of host or unix_sock must be provided")

    try:
        self._sock.connect(address)
    except:
        # Don't leak the socket when the connection attempt fails.
        self._sock.close()
        raise

    if ssl:
        self._sock_lock.acquire()
        try:
            self._send(SSLRequest())
            self._flush()
            # The server answers the SSL request with a single byte.
            resp = self._sock.recv(1)
            if resp == 'S':
                self._sock = SSLWrapper(socket.ssl(self._sock))
            else:
                raise InterfaceError("server refuses SSL")
        finally:
            self._sock_lock.release()
    else:
        # settimeout causes ssl failure, on windows.  Python bug 1462352.
        self._sock.settimeout(socket_timeout)

    self._state = "noauth"
    self._backend_key_data = None
    self.NoticeReceived = MulticastDelegate()
    self.ParameterStatusReceived = MulticastDelegate()
    self.NotificationReceived = MulticastDelegate()
    self.ParameterStatusReceived += self._onParameterStatusReceived
def verifyState(self, state):
if self._state != state:
raise InternalError("connection state must be %s, is %s" % (state, self._state))
def _send(self, msg):
assert self._sock_lock.locked()
#print "_send(%r)" % msg
data = msg.serialize()
self._send_sock_buf.append(data)
def _flush(self):
assert self._sock_lock.locked()
self._sock.sendall("".join(self._send_sock_buf))
del self._send_sock_buf[:]
def _read_bytes(self, byte_count):
retval = []
bytes_read = 0
while bytes_read < byte_count:
if self._sock_buf_pos == len(self._sock_buf):
self._sock_buf = self._sock.recv(1024)
self._sock_buf_pos = 0
rpos = min(len(self._sock_buf), self._sock_buf_pos + (byte_count - bytes_read))
addt_data = self._sock_buf[self._sock_buf_pos:rpos]
bytes_read += (rpos - self._sock_buf_pos)
assert bytes_read <= byte_count
self._sock_buf_pos = rpos
retval.append(addt_data)
return "".join(retval)
# Reads one complete message from the server and returns the parsed message
# object. Every message starts with a five byte header: a one byte type code
# and an Int32 length that counts itself but not the type code. The caller
# must hold _sock_lock.
def _read_message(self):
    assert self._sock_lock.locked()
    header = self._read_bytes(5)
    message_code = header[0]
    (total_len,) = struct.unpack("!i", header[1:])
    data_len = total_len - 4
    payload = self._read_bytes(data_len)
    assert len(payload) == data_len
    # message_types maps the type code to the class that parses the payload.
    return message_types[message_code].createFromData(payload)
# Starts the session as the given user. Sends the StartupMessage (with the
# optional "database" and "options" keyword arguments), answers the server's
# authentication request, then consumes messages until ReadyForQuery.
# Afterwards the record attribute name cache is filled.
#
# Raises the server's error if the startup is rejected, InternalError if the
# reply is not an authentication request, and InterfaceError if
# authentication fails.
def authenticate(self, user, **kwargs):
    self.verifyState("noauth")
    self._sock_lock.acquire()
    try:
        startup = StartupMessage(
            user,
            database=kwargs.get("database"),
            options=kwargs.get("options"),
        )
        self._send(startup)
        self._flush()

        msg = self._read_message()
        if isinstance(msg, ErrorResponse):
            raise msg.createException()
        if not isinstance(msg, AuthenticationRequest):
            raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % msg)
        # All keyword arguments (e.g. password) are passed on to the
        # authentication handler.
        if not msg.ok(self, user, **kwargs):
            raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)

        self._state = "auth"

        reader = MessageReader(self)
        reader.add_message(ReadyForQuery, self._ready_for_query)
        reader.add_message(BackendKeyData, self._receive_backend_key_data)
        reader.handle_messages()
    finally:
        self._sock_lock.release()

    # Runs after the lock is released.
    self._cache_record_attnames()
def _ready_for_query(self, msg):
self._state = "ready"
return True
def _receive_backend_key_data(self, msg):
self._backend_key_data = msg
def _cache_record_attnames(self):
if not self.user_wants_records:
return
parse_retval = self.parse("",
"""SELECT
pg_type.oid, attname
FROM
pg_type
INNER JOIN pg_attribute ON (attrelid = pg_type.typrelid)
WHERE typreceive = 'record_recv'::regproc
ORDER BY pg_type.oid, attnum""",
[])
row_desc, cmd = self.bind("tmp", "", (), parse_retval, None)
eod, rows = self.fetch_rows("tmp", 0, row_desc)
self._record_field_names = {}
typoid, attnames = None, []
for row in rows:
new_typoid, attname = row
if new_typoid != typoid and typoid != None:
self._record_field_names[typoid] = attnames
attnames = []
typoid = new_typoid
attnames.append(attname)
self._record_field_names[typoid] = attnames
@sync_on_error
def parse(self, statement, qs, param_types):
    """Prepare query string ``qs`` under the name ``statement``.

    The connection must be in the "ready" state.  Each entry of
    ``param_types`` is converted with ``types.pg_type_info``; the first
    element of each result is sent in the Parse message and the second
    elements are collected as ``param_fc``.  Parse,
    DescribePreparedStatement and Flush are sent, then replies are read.

    Returns ``(row_description, param_fc)``, where ``row_description`` is
    the RowDescription message, or None when the server answers NoData.
    """
    self.verifyState("ready")

    # Split the type info pairs by hand; zip(*type_info) fails on an empty
    # parameter list.
    type_info = [types.pg_type_info(entry) for entry in param_types]
    pg_param_types = [info[0] for info in type_info]
    param_fc = [info[1] for info in type_info]

    self._send(Parse(statement, qs, pg_param_types))
    self._send(DescribePreparedStatement(statement))
    self._send(Flush())
    self._flush()

    def keep_reading(msg):
        # ParseComplete is good, and ParameterDescription is of no interest:
        # we send whatever we want and let the database deal with it.
        return 0

    def without_output(msg):
        # No row description is coming.  Return something distinctive to let
        # bind know that there is no output.
        return (None, param_fc)

    def with_output(msg):
        # Common row description response.
        return (msg, param_fc)

    reader = MessageReader(self)
    reader.add_message(ParseComplete, keep_reading)
    reader.add_message(ParameterDescription, keep_reading)
    reader.add_message(NoData, without_output)
    reader.add_message(RowDescription, with_output)
    return reader.handle_messages()
@sync_on_error
def bind(self, portal, statement, params, parse_data, copy_stream):
    """Bind ``params`` to prepared ``statement``, creating ``portal``.

    The connection must be in the "ready" state.  ``parse_data`` is the
    ``(row_desc, param_fc)`` pair produced by ``parse``; ``row_desc`` is None
    when the statement produces no rows.  ``copy_stream`` is handed on to
    ``_bind_nodata``, which passes it to the COPY handlers.

    Returns the reader's result: ``(RowDescription message, None)`` for a
    row-returning statement.  On NoData, ``_bind_nodata`` executes the
    portal immediately and supplies the return value through the reader.
    """
    self.verifyState("ready")
    row_desc, param_fc = parse_data
    # Identity test: None is a singleton and must not go through __eq__.
    if row_desc is None:
        # no data coming out
        output_fc = ()
    else:
        # We've got row_desc that allows us to identify what we're going to
        # get back from this statement.
        output_fc = [types.py_type_info(f, self._record_field_names) for f in row_desc.fields]
    self._send(Bind(portal, statement, param_fc, params, output_fc, client_encoding=self._client_encoding, integer_datetimes=self._integer_datetimes))
    # We need to describe the portal after bind, since the return
    # format codes will be different (hopefully, always what we
    # requested).
    self._send(DescribePortal(portal))
    self._send(Flush())
    self._flush()

    # Read responses from server...
    reader = MessageReader(self)
    # BindComplete is good -- just ignore
    reader.add_message(BindComplete, lambda msg: 0)
    # NoData in this case means we're not executing a query.  As a
    # result, we won't be fetching rows, so we'll never execute the
    # portal we just created... unless we execute it right away, which
    # we'll do.
    reader.add_message(NoData, self._bind_nodata, portal, reader, copy_stream)
    # Return the new row desc, since it will have the format types we
    # asked the server for
    reader.add_message(RowDescription, lambda msg: (msg, None))
    return reader.handle_messages()
def _copy_in_response(self, copyin, fileobj, old_reader):
    """CopyInResponse handler: stream ``fileobj`` to the server.

    Data is read from ``fileobj`` in chunks of ``_block_size`` and each
    chunk is sent as a CopyData message and flushed, until ``read`` returns
    an empty value.  CopyDone and Sync are then sent and flushed.
    ``copyin`` and ``old_reader`` are not used.

    Raises CopyQueryWithoutStreamError when no stream was supplied.
    Returns None.
    """
    # Identity test, so a stream object defining __eq__ cannot be mistaken
    # for a missing stream.
    if fileobj is None:
        raise CopyQueryWithoutStreamError()
    while True:
        data = fileobj.read(self._block_size)
        if not data:
            break
        self._send(CopyData(data))
        self._flush()
    self._send(CopyDone())
    self._send(Sync())
    self._flush()
def _copy_out_response(self, copyout, fileobj, old_reader):
    """CopyOutResponse handler: write incoming COPY data to ``fileobj``.

    A nested message reader passes every CopyData message to ``_copy_data``
    and stops at CopyDone.  ``copyout`` and ``old_reader`` are not used.

    Raises CopyQueryWithoutStreamError when no stream was supplied.
    Returns None.
    """
    # Identity test, so a stream object defining __eq__ cannot be mistaken
    # for a missing stream.
    if fileobj is None:
        raise CopyQueryWithoutStreamError()
    reader = MessageReader(self)
    reader.add_message(CopyData, self._copy_data, fileobj)
    reader.add_message(CopyDone, lambda msg: 1)
    reader.handle_messages()
def _copy_data(self, copydata, fileobj):
    """CopyData handler: write the message's ``data`` to ``fileobj``.

    Returns None.
    """
    chunk = copydata.data
    fileobj.write(chunk)
def _bind_nodata(self, msg, portal, old_reader, copy_stream):
    """NoData handler used by ``bind``: execute the portal right away.

    Bind was answered with NoData, so Execute (with a row count of 0) and
    Sync are sent.  A nested reader then handles CopyOutResponse and
    CopyInResponse (both given ``copy_stream``), remembers the first
    CommandComplete message and stops at ReadyForQuery.  The nested reader
    has ``delay_raising_exception`` set.

    The result ``(None, <CommandComplete message>)`` is handed to
    ``old_reader`` via ``return_value``.  ``msg`` is not used.
    """
    self._send(Execute(portal, 0))
    self._send(Sync())
    self._flush()

    captured = {}
    nested = MessageReader(self)
    nested.add_message(CopyOutResponse, self._copy_out_response, copy_stream, nested)
    nested.add_message(CopyInResponse, self._copy_in_response, copy_stream, nested)
    # Store the message but produce a false value so that reading continues.
    nested.add_message(CommandComplete, lambda cc, store: store.setdefault('msg', cc) and False, captured)
    nested.add_message(ReadyForQuery, lambda rfq: 1)
    nested.delay_raising_exception = True
    nested.handle_messages()

    result = (None, captured['msg'])
    old_reader.return_value(result)
@sync_on_error
def fetch_rows(self, portal, row_count, row_desc):
    """Execute ``portal`` and collect the rows it returns.

    The connection must be in the "ready" state.  Execute (limited to
    ``row_count``) and Flush are sent; each DataRow is converted by
    ``_fetch_datarow`` using ``row_desc``.  Reading stops at
    PortalSuspended or CommandComplete.

    Returns ``(end_of_data, rows)``.  ``end_of_data`` is True when the
    reader returned 2, the value ``_fetch_commandcomplete`` uses to signal
    that the command has no more data.
    """
    self.verifyState("ready")
    self._send(Execute(portal, row_count))
    self._send(Flush())
    self._flush()

    collected = []
    reader = MessageReader(self)
    reader.add_message(DataRow, self._fetch_datarow, collected, row_desc)
    reader.add_message(PortalSuspended, lambda msg: 1)
    reader.add_message(CommandComplete, self._fetch_commandcomplete, portal)
    outcome = reader.handle_messages()

    end_of_data = (outcome == 2)
    return end_of_data, collected
def _fetch_datarow(self, msg, rows, row_desc):
    """DataRow handler: convert one row and append it to ``rows``.

    Every entry of ``msg.fields`` is converted with ``types.py_value``,
    using the description at the same index in ``row_desc.fields`` and the
    connection's client encoding, integer-datetimes flag and record field
    names.  The converted values are appended to ``rows`` as one list.
    Returns None.
    """
    converted = []
    for index, raw_value in enumerate(msg.fields):
        converted.append(
            types.py_value(
                raw_value,
                row_desc.fields[index],
                client_encoding=self._client_encoding,
                integer_datetimes=self._integer_datetimes,
                record_field_names=self._record_field_names
            )
        )
    rows.append(converted)
def _fetch_commandcomplete(self, msg, portal):
    """CommandComplete handler for ``fetch_rows``: close the portal.

    ClosePortal and Sync are sent, then messages are read until
    ReadyForQuery (CloseComplete is accepted and ignored).  ``msg`` is not
    used.

    Returns 2, which ``fetch_rows`` interprets as end-of-data.
    """
    self._send(ClosePortal(portal))
    self._send(Sync())
    self._flush()

    def ignore_close_complete(reply):
        return False

    closing = MessageReader(self)
    closing.add_message(ReadyForQuery, self._fetch_commandcomplete_rfq)
    closing.add_message(CloseComplete, ignore_close_complete)
    closing.handle_messages()

    end_of_data_signal = 2
    return end_of_data_signal
def _fetch_commandcomplete_rfq(self, msg):
    """ReadyForQuery handler registered by ``_fetch_commandcomplete``.

    Sets the connection state to "ready".  ``msg`` is not inspected.
    Returns True.
    """
    state_after_close = "ready"
    self._state = state_after_close
    return True
def _sync(self):
    """Send a Sync message, then read and discard all messages until a
    ReadyForQuery message is received.

    It is assumed the caller already holds ``_sock_lock``: ``_sync`` is
    expected to be called from ``sync_on_error``, which holds the lock
    throughout the call.
    """
    self._send(Sync())
    self._flush()

    def stop_on_ready(msg):
        return True

    drain = MessageReader(self)
    # Anything that is not ReadyForQuery is silently dropped.
    drain.ignore_unhandled_messages = True
    drain.add_message(ReadyForQuery, stop_on_ready)
    drain.handle_messages()
def close_statement(self, statement):
    """Close the prepared statement named ``statement`` on the server.

    Does nothing when the connection state is already "closed"; otherwise
    the connection must be in the "ready" state.  While holding
    ``_sock_lock``, ClosePreparedStatement and Sync are sent and replies are
    read until ReadyForQuery.  Returns None.
    """
    if self._state == "closed":
        return
    self.verifyState("ready")

    def on_close_complete(msg):
        return 0

    def on_ready(msg):
        return 1

    self._sock_lock.acquire()
    try:
        self._send(ClosePreparedStatement(statement))
        self._send(Sync())
        self._flush()
        replies = MessageReader(self)
        replies.add_message(CloseComplete, on_close_complete)
        replies.add_message(ReadyForQuery, on_ready)
        replies.handle_messages()
    finally:
        self._sock_lock.release()
def close_portal(self, portal):
    """Close the portal named ``portal`` on the server.

    Does nothing when the connection state is already "closed"; otherwise
    the connection must be in the "ready" state.  While holding
    ``_sock_lock``, ClosePortal and Sync are sent and replies are read
    until ReadyForQuery.  Returns None.
    """
    if self._state == "closed":
        return
    self.verifyState("ready")

    def on_close_complete(msg):
        return 0

    def on_ready(msg):
        return 1

    self._sock_lock.acquire()
    try:
        self._send(ClosePortal(portal))
        self._send(Sync())
        self._flush()
        replies = MessageReader(self)
        replies.add_message(CloseComplete, on_close_complete)
        replies.add_message(ReadyForQuery, on_ready)
        replies.handle_messages()
    finally:
        self._sock_lock.release()
def close(self):
    """Terminate the session and close the underlying socket.

    While holding ``_sock_lock``, a Terminate message is sent and flushed,
    the socket is closed and the connection state becomes "closed".

    The socket is closed and the state updated even when sending Terminate
    fails (for example because the peer already dropped the connection), so
    the descriptor is not leaked.  The original exception still propagates
    to the caller in that case.
    """
    self._sock_lock.acquire()
    try:
        try:
            self._send(Terminate())
            self._flush()
        finally:
            # Always release the socket, whether or not the goodbye message
            # could be delivered.
            self._sock.close()
            self._state = "closed"
    finally:
        self._sock_lock.release()
def _onParameterStatusReceived(self, msg):
    """Track the server parameters this connection depends on.

    ``client_encoding`` is stored as ``_client_encoding``;
    ``integer_datetimes`` is stored as the boolean ``_integer_datetimes``
    (True only when the value is "on").  Other keys are ignored.
    """
    key = msg.key
    if key == "integer_datetimes":
        self._integer_datetimes = (msg.value == "on")
    elif key == "client_encoding":
        self._client_encoding = msg.value
def handleNoticeResponse(self, msg):
    """Pass a notice message on by calling ``NoticeReceived`` with it."""
    notify = self.NoticeReceived
    notify(msg)
def handleParameterStatus(self, msg):
    """Pass a parameter status message on by calling
    ``ParameterStatusReceived`` with it."""
    notify = self.ParameterStatusReceived
    notify(msg)
def handleNotificationResponse(self, msg):
    """Pass a notification message on by calling ``NotificationReceived``
    with it."""
    notify = self.NotificationReceived
    notify(msg)
def fileno(self):
    """Return the file descriptor of the underlying socket.

    Having this method lets the connection object itself be passed to
    ``select.select`` (as ``isready`` does).
    """
    # This should be safe to do without a lock
    sock = self._sock
    return sock.fileno()
def isready(self):
    """Check, without blocking, whether the socket has data to read.

    While holding ``_sock_lock``, the connection is polled with
    ``select.select`` using a zero timeout.  When nothing is readable,
    False is returned.  Otherwise ``_sync`` is called and True is returned.
    """
    self._sock_lock.acquire()
    try:
        readable = select.select([self], [], [], 0)[0]
        if readable:
            self._sync()
            return True
        return False
    finally:
        self._sock_lock.release()
# Lookup table from the one-character message code found at the start of
# every backend message to the class whose createFromData parses it (see
# _read_message).
message_types = {
    # Session start-up and status
    "R": AuthenticationRequest,
    "S": ParameterStatus,
    "K": BackendKeyData,
    "Z": ReadyForQuery,

    # Notices, notifications and errors
    "N": NoticeResponse,
    "A": NotificationResponse,
    "E": ErrorResponse,

    # Replies to Parse / Bind / Describe / Close
    "1": ParseComplete,
    "2": BindComplete,
    "3": CloseComplete,
    "t": ParameterDescription,
    "n": NoData,

    # Result rows and command completion
    "T": RowDescription,
    "D": DataRow,
    "C": CommandComplete,
    "s": PortalSuspended,

    # COPY
    "G": CopyInResponse,
    "H": CopyOutResponse,
    "d": CopyData,
    "c": CopyDone,
}