| # vim: sw=4:expandtab:foldmethod=marker |
| # |
| # Copyright (c) 2007-2009, Mathieu Fenniak |
| # All rights reserved. |
| # |
| # Redistribution and use in source and binary forms, with or without |
| # modification, are permitted provided that the following conditions are |
| # met: |
| # |
| # * Redistributions of source code must retain the above copyright notice, |
| # this list of conditions and the following disclaimer. |
| # * Redistributions in binary form must reproduce the above copyright notice, |
| # this list of conditions and the following disclaimer in the documentation |
| # and/or other materials provided with the distribution. |
| # * The name of the author may not be used to endorse or promote products |
| # derived from this software without specific prior written permission. |
| # |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
| # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
| # POSSIBILITY OF SUCH DAMAGE. |
| |
| __author__ = "Mathieu Fenniak" |
| |
| import socket |
| import select |
| import threading |
| import struct |
| import hashlib |
| from cStringIO import StringIO |
| |
| from errors import * |
| from util import MulticastDelegate |
| import types |
| |
| ## |
| # An SSLRequest message. To initiate an SSL-encrypted connection, an |
| # SSLRequest message is used rather than a {@link StartupMessage |
| # StartupMessage}. A StartupMessage is still sent, but only after SSL |
| # negotiation (if accepted). |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class SSLRequest(object):
    # Fixed-size packet asking the server whether it is willing to negotiate
    # SSL before the StartupMessage is sent.
    def __init__(self):
        # Nothing to store: every SSLRequest is the same 8 bytes.
        pass

    def serialize(self):
        # Int32(8)        - message length, counting this field itself.
        # Int32(80877103) - the SSL request code.
        message_length = 8
        ssl_request_code = 80877103
        return struct.pack("!ii", message_length, ssl_request_code)
| |
| |
| ## |
| # A StartupMessage message. Begins a DB session, identifying the user to be |
| # authenticated as and the database to connect to. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class StartupMessage(object):
    # Opens a session: names the user to authenticate as and, optionally,
    # the database and an "options" startup parameter (the original note
    # ties the options parameter to Greenplum utility mode).
    def __init__(self, user, database=None, options=None):
        self.user = user
        self.database = database
        self.options = options

    def serialize(self):
        # Int32         - message length, counting this field itself.
        # Int32(196608) - protocol version 3.0.
        # Then NUL-terminated name/value string pairs, closed by one NUL.
        # database and options are only sent when they are truthy.
        pairs = [("user", self.user)]
        if self.database:
            pairs.append(("database", self.database))
        if self.options:
            pairs.append(("options", self.options))
        body = struct.pack("!i", 196608)
        for name, value in pairs:
            body += name + "\x00" + value + "\x00"
        body += "\x00"
        return struct.pack("!i", len(body) + 4) + body
| |
| |
| ## |
| # Parse message. Creates a prepared statement in the DB session. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
| # |
| # @param ps Name of the prepared statement to create. |
| # @param qs Query string. |
| # @param type_oids An iterable that contains the PostgreSQL type OIDs for |
| # parameters in the query string. |
class Parse(object):
    # Frontend Parse message: creates a prepared statement.
    #   ps        -- prepared statement name ("" selects the unnamed one).
    #   qs        -- query string.
    #   type_oids -- sized iterable of parameter type OIDs.
    def __init__(self, ps, qs, type_oids):
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    def __repr__(self):
        return "<Parse ps=%r qs=%r>" % (self.ps, self.qs)

    def serialize(self):
        # Byte1('P'), Int32 length (counting itself), String statement name,
        # String query, Int16 number of OIDs, then one Int32 per OID.
        prefix = self.ps + "\x00" + self.qs + "\x00"
        oid_count = struct.pack("!h", len(self.type_oids))
        # Parse does not accept the -1 "NULL" pseudo-OID that other messages
        # use, so substitute 705, PostgreSQL's "unknown" type.
        packed_oids = [struct.pack("!i", 705 if oid == -1 else oid)
                       for oid in self.type_oids]
        body = prefix + oid_count + "".join(packed_oids)
        return "P" + struct.pack("!i", len(body) + 4) + body
| |
| |
| ## |
| # Bind message. Readies a prepared statement for execution. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
| # |
| # @param portal Name of the destination portal. |
| # @param ps Name of the source prepared statement. |
| # @param in_fc An iterable containing the format codes for input |
| # parameters. 0 = Text, 1 = Binary. |
| # @param params The parameters. |
| # @param out_fc An iterable containing the format codes for output |
| # parameters. 0 = Text, 1 = Binary. |
| # @param kwargs Additional arguments to pass to the type conversion |
| # methods. |
class Bind(object):
    # Frontend Bind message: binds parameter values to a prepared statement,
    # producing a portal.
    #   portal -- destination portal name.
    #   ps     -- source prepared statement name.
    #   in_fc  -- input format codes (0 = text, 1 = binary). An empty list
    #             means all text; one code applies to every parameter;
    #             otherwise there is one code per parameter.
    #   params -- parameter values, each converted with the project's
    #             types.pg_value(value, format_code, **kwargs).
    #   out_fc -- result-column format codes.
    def __init__(self, portal, ps, in_fc, params, out_fc, **kwargs):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc

        def _format_code(index):
            # Pick the input format code that applies to parameter `index`.
            code_count = len(in_fc)
            if code_count == 0:
                return 0
            if code_count == 1:
                return in_fc[0]
            return in_fc[index]

        self.params = [types.pg_value(value, _format_code(index), **kwargs)
                       for index, value in enumerate(params)]
        self.out_fc = out_fc

    def __repr__(self):
        return "<Bind p=%r s=%r>" % (self.portal, self.ps)

    def serialize(self):
        # Byte1('B'), Int32 length (counting itself), String portal,
        # String statement, Int16 + Int16[] input format codes,
        # Int16 parameter count, then per parameter an Int32 byte length
        # (-1 for NULL, with no bytes following) and the raw bytes,
        # and finally Int16 + Int16[] result-column format codes.
        buf = StringIO()
        write = buf.write
        write(self.portal + "\x00")
        write(self.ps + "\x00")
        write(struct.pack("!h", len(self.in_fc)))
        for code in self.in_fc:
            write(struct.pack("!h", code))
        write(struct.pack("!h", len(self.params)))
        for value in self.params:
            if value is None:
                write(struct.pack("!i", -1))
                continue
            write(struct.pack("!i", len(value)))
            write(value)
        write(struct.pack("!h", len(self.out_fc)))
        for code in self.out_fc:
            write(struct.pack("!h", code))
        body = buf.getvalue()
        return "B" + struct.pack("!i", len(body) + 4) + body
| |
| |
| ## |
| # A Close message, used for closing prepared statements and portals. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
| # |
| # @param typ 'S' for prepared statement, 'P' for portal. |
| # @param name The name of the item to close. |
class Close(object):
    # Frontend Close message for a prepared statement or a portal.
    #   typ  -- one character: 'S' (prepared statement) or 'P' (portal).
    #   name -- name of the item to close.
    # Raises InternalError when typ is not exactly one character.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ = typ
        self.name = name

    def serialize(self):
        # Byte1('C'), Int32 length (counting itself), Byte1 item type,
        # String item name.
        body = "".join((self.typ, self.name, "\x00"))
        return "".join(("C", struct.pack("!i", len(body) + 4), body))
| |
| |
| ## |
| # A specialized Close message for a portal. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class ClosePortal(Close):
    # Close message preset for a portal ('P').
    def __init__(self, name):
        super(ClosePortal, self).__init__("P", name)
| |
| |
| ## |
| # A specialized Close message for a prepared statement. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class ClosePreparedStatement(Close):
    # Close message preset for a prepared statement ('S').
    def __init__(self, name):
        super(ClosePreparedStatement, self).__init__("S", name)
| |
| |
| ## |
| # A Describe message, used for obtaining information on prepared statements |
| # and portals. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
| # |
| # @param typ 'S' for prepared statement, 'P' for portal. |
| # @param name The name of the item to describe. |
class Describe(object):
    # Frontend Describe message for a prepared statement or a portal.
    #   typ  -- one character: 'S' (prepared statement) or 'P' (portal).
    #   name -- name of the item to describe.
    # Raises InternalError when typ is not exactly one character.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ = typ
        self.name = name

    def serialize(self):
        # Byte1('D'), Int32 length (counting itself), Byte1 item type,
        # String name of the item to describe.
        body = "".join((self.typ, self.name, "\x00"))
        return "".join(("D", struct.pack("!i", len(body) + 4), body))
| |
| |
| ## |
| # A specialized Describe message for a portal. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class DescribePortal(Describe):
    # Describe message preset for a portal ('P').
    def __init__(self, name):
        super(DescribePortal, self).__init__("P", name)

    def __repr__(self):
        return "<DescribePortal %r>" % self.name
| |
| |
| ## |
| # A specialized Describe message for a prepared statement. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class DescribePreparedStatement(Describe):
    # Describe message preset for a prepared statement ('S').
    def __init__(self, name):
        super(DescribePreparedStatement, self).__init__("S", name)

    def __repr__(self):
        return "<DescribePreparedStatement %r>" % self.name
| |
| |
| ## |
| # A Flush message forces the backend to deliver any data pending in its |
| # output buffers. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class Flush(object):
    # Asks the backend to deliver any output it still has buffered.
    def __repr__(self):
        return "<Flush>"

    def serialize(self):
        # Byte1('H') then Int32(4): the message is nothing but its header.
        return "H" + "\x00\x00\x00\x04"
| |
| ## |
| # Causes the backend to close the current transaction (if not in a BEGIN/COMMIT |
| # block), and issue ReadyForQuery. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class Sync(object):
    # Ends an extended-query cycle; the backend answers with ReadyForQuery.
    def __repr__(self):
        return "<Sync>"

    def serialize(self):
        # Byte1('S') then Int32(4): no body follows the header.
        return "S" + "\x00\x00\x00\x04"
| |
| |
| ## |
| # Transmits a password. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class PasswordMessage(object):
    # Frontend password message; pwd may already be hashed (e.g. the
    # "md5..." string built by AuthenticationMD5Password).
    def __init__(self, pwd):
        self.pwd = pwd

    def serialize(self):
        # Byte1('p'), Int32 length (counting itself), String password.
        payload = self.pwd + "\x00"
        header = "p" + struct.pack("!i", len(payload) + 4)
        return header + payload
| |
| |
| ## |
| # Requests that the backend execute a portal and retrieve any number of rows. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
| # @param row_count The number of rows to return. Can be zero to indicate the |
| # backend should return all rows. If the portal represents a |
| # query that does not return rows, no rows will be returned |
| # no matter what the row_count. |
class Execute(object):
    # Runs a bound portal.
    #   portal    -- portal name.
    #   row_count -- maximum rows to return; 0 means no limit.
    def __init__(self, portal, row_count):
        self.portal = portal
        self.row_count = row_count

    def serialize(self):
        # Byte1('E'), Int32 length (counting itself), String portal name,
        # Int32 row limit.
        body = self.portal + "\x00" + struct.pack("!i", self.row_count)
        header = "E" + struct.pack("!i", len(body) + 4)
        return header + body
| |
| |
| ## |
| # Informs the backend that the connection is being closed. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class Terminate(object):
    # Tells the backend the client is closing the connection.
    def __init__(self):
        pass  # no fields

    def serialize(self):
        # Byte1('X') then Int32(4): header only.
        return "X" + "\x00\x00\x00\x04"
| |
| ## |
| # Base class of all Authentication[*] messages. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class AuthenticationRequest(object):
    # Base class of the backend Authentication* ('R') messages.
    def __init__(self, data):
        # The base class keeps nothing; subclasses parse their own payload.
        pass

    # Byte1('R') - Identifies the message as an authentication request.
    # Int32      - Message length, including self.
    # Int32      - Authentication code:
    #              0 = AuthenticationOk
    #              5 = MD5 pwd
    #              2 = Kerberos v5 (not supported by pg8000)
    #              3 = Cleartext pwd (not supported by pg8000)
    #              4 = crypt() pwd (not supported by pg8000)
    #              6 = SCM credential (not supported by pg8000)
    #              7 = GSSAPI (not supported by pg8000)
    #              8 = GSSAPI data (not supported by pg8000)
    #              9 = SSPI (not supported by pg8000)
    # Any bytes after the code are handed to the chosen subclass.
    # Raises NotSupportedError for codes missing from authentication_codes.
    @staticmethod
    def createFromData(data):
        code = struct.unpack("!i", data[:4])[0]
        handler_class = authentication_codes.get(code)
        if handler_class is None:
            raise NotSupportedError("authentication method %r not supported" % (code,))
        return handler_class(data[4:])

    def ok(self, conn, user, **kwargs):
        # Subclasses must implement the client side of the exchange.
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")
| |
| ## |
| # A message representing that the backend accepting the provided username |
| # without any challenge. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class AuthenticationOk(AuthenticationRequest):
    # The server accepted the login without a challenge: nothing to send.
    def ok(self, conn, user, **kwargs):
        accepted = True
        return accepted
| |
| |
| ## |
| # A message representing the backend requesting an MD5 hashed password |
| # response. The response will be sent as md5(md5(pwd + login) + salt). |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class AuthenticationMD5Password(AuthenticationRequest):
    # MD5 challenge: reply with "md5" + md5(md5(password + user) + salt).
    def __init__(self, data):
        # The payload is exactly the 4-byte salt (struct enforces the size).
        salt_chars = struct.unpack("4c", data)
        self.salt = "".join(salt_chars)

    def ok(self, conn, user, password=None, **kwargs):
        # Sends the hashed password, then reads until the server answers
        # with another authentication request (whose ok() result is
        # returned) or an error (raised by _ok_error).
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        inner_hash = hashlib.md5(password + user).hexdigest()
        outer_hash = hashlib.md5(inner_hash + self.salt).hexdigest()
        conn._send(PasswordMessage("md5" + outer_hash))
        conn._flush()

        def _on_auth_request(msg, reader):
            # Record the follow-up request's result so the loop returns it.
            reader.return_value(msg.ok(conn, user))

        reader = MessageReader(conn)
        reader.add_message(AuthenticationRequest, _on_auth_request, reader)
        reader.add_message(ErrorResponse, self._ok_error)
        return reader.handle_messages()

    def _ok_error(self, msg):
        # SQLSTATE 28000 is reported as a failed MD5 login; any other error
        # is raised as the server's own exception.
        if msg.code != "28000":
            raise msg.createException()
        raise InterfaceError("md5 password authentication failed")
| |
# Maps the Int32 code of an authentication request to the class handling it;
# AuthenticationRequest.createFromData rejects codes missing from this table.
authentication_codes = {}
authentication_codes[0] = AuthenticationOk
authentication_codes[5] = AuthenticationMD5Password
| |
| |
| ## |
| # ParameterStatus message sent from backend, used to inform the frontend of |
| # runtime configuration parameter changes. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class ParameterStatus(object):
    # Backend report of a runtime parameter's current value.
    def __init__(self, key, value):
        self.key = key
        self.value = value

    # Byte1('S') - Identifies ParameterStatus.
    # Int32      - Message length, including self.
    # String     - Parameter name.
    # String     - Parameter value.
    @staticmethod
    def createFromData(data):
        # The name ends at the first NUL; the value runs to the trailing NUL.
        split_at = data.find("\x00")
        return ParameterStatus(data[:split_at], data[split_at + 1:-1])
| |
| |
| ## |
| # BackendKeyData message sent from backend. Contains a connection's process |
| # ID and a secret key. Can be used to terminate the connection's current |
| # actions, such as a long running query. Not supported by pg8000 yet. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class BackendKeyData(object):
    # Process ID and secret key identifying this backend session.
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    # Byte1('K') - Identifier.
    # Int32(12)  - Message length, including self.
    # Int32      - Process ID.
    # Int32      - Secret key.
    @staticmethod
    def createFromData(data):
        return BackendKeyData(*struct.unpack("!ii", data))
| |
| |
| ## |
| # Message representing a query with no data. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class NoData(object):
    # Backend NoData ('n') message: the query produces no row data.
    # The message has no body, so `data` is ignored.
    @staticmethod
    def createFromData(data):
        return NoData()
| |
| |
| ## |
| # Message representing a successful Parse. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class ParseComplete(object):
    # Backend ParseComplete ('1') message; bodiless, so `data` is ignored.
    @staticmethod
    def createFromData(data):
        return ParseComplete()
| |
| |
| ## |
| # Message representing a successful Bind. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class BindComplete(object):
    # Backend BindComplete ('2') message; bodiless, so `data` is ignored.
    @staticmethod
    def createFromData(data):
        return BindComplete()
| |
| |
| ## |
| # Message representing a successful Close. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class CloseComplete(object):
    # Backend CloseComplete ('3') message; bodiless, so `data` is ignored.
    @staticmethod
    def createFromData(data):
        return CloseComplete()
| |
| |
| ## |
| # Message indicating that rows from an Execute have been received, but more data |
| # exists in the portal. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class PortalSuspended(object):
    # Backend PortalSuspended ('s') message: the Execute row limit was hit
    # and the portal still has rows. Bodiless, so `data` is ignored.
    @staticmethod
    def createFromData(data):
        return PortalSuspended()
| |
| |
| ## |
| # Message representing the backend is ready to process a new query. |
| # <p> |
| # Stability: This is an internal class. No stability guarantee is made. |
class ReadyForQuery(object):
    # Backend ReadyForQuery ('Z') message carrying the transaction status.
    def __init__(self, status):
        self._status = status

    # Read-only status byte: I = idle, T = idle in a transaction block,
    # E = idle in a failed transaction block.
    @property
    def status(self):
        return self._status

    def __repr__(self):
        names = {
            "I": "Idle",
            "T": "Idle in Transaction",
            "E": "Idle in Failed Transaction",
        }
        return "<ReadyForQuery %s>" % names[self.status]

    # Byte1('Z') - Identifier.
    # Int32(5)   - Message length, including self.
    # Byte1      - Status indicator (the whole payload).
    @staticmethod
    def createFromData(data):
        return ReadyForQuery(data)
| |
| |
| ## |
| # Represents a notice sent from the server. This is not the same as a |
| # notification. A notice is just additional information about a query, such |
| # as a notice that a primary key has automatically been created for a table. |
| # <p> |
| # A NoticeResponse instance will have properties containing the data sent |
| # from the server: |
| # <ul> |
| # <li>severity -- "ERROR", "FATAL", "PANIC", "WARNING", "NOTICE", "DEBUG", |
| # "INFO", or "LOG". Always present.</li> |
| # <li>code -- the SQLSTATE code for the error. See Appendix A of the |
| # PostgreSQL documentation for specific error codes. Always present.</li> |
| # <li>msg -- human-readable error message. Always present.</li> |
| # <li>detail -- Optional additional information.</li> |
| # <li>hint -- Optional suggestion about what to do about the issue.</li> |
| # <li>position -- Optional index into the query string.</li> |
| # <li>where -- Optional context.</li> |
| # <li>file -- Source-code file.</li> |
| # <li>line -- Source-code line.</li> |
| # <li>routine -- Source-code routine.</li> |
| # </ul> |
| # <p> |
| # Stability: Added in pg8000 v1.03. Required properties severity, code, and |
| # msg are guaranteed for v1.xx. Other properties should be checked with |
| # hasattr before accessing. |
class NoticeResponse(object):
    # Maps the single-character field codes of notice/error messages to the
    # attribute names set on NoticeResponse / ErrorResponse instances.
    # Codes not listed here keep their one-character code as the name.
    responseKeys = {
        "S": "severity",  # always present
        "C": "code",      # always present
        "M": "msg",       # always present
        "D": "detail",
        "H": "hint",
        "P": "position",
        "p": "_position",
        "q": "_query",
        "W": "where",
        "F": "file",
        "L": "line",
        "R": "routine",
    }

    def __init__(self, **kwargs):
        # Every parsed field becomes an instance attribute.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        return "<NoticeResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Splits a body of NUL-terminated "<code><value>" fields into a dict
    # keyed by attribute name; empty chunks (the terminators) are skipped.
    @staticmethod
    def dataIntoDict(data):
        names = NoticeResponse.responseKeys
        return dict(
            (names.get(field[0], field[0]), field[1:])
            for field in data.split("\x00")
            if field
        )

    # Byte1('N') - Identifier.
    # Int32      - Message length.
    # Repeated Byte1 field code + String value, terminated by a zero byte.
    @staticmethod
    def createFromData(data):
        fields = NoticeResponse.dataIntoDict(data)
        return NoticeResponse(**fields)
| |
| |
| ## |
| # A message sent in case of a server-side error. Contains the same properties |
| # that {@link NoticeResponse NoticeResponse} contains. |
| # <p> |
| # Stability: Added in pg8000 v1.03. Required properties severity, code, and |
| # msg are guaranteed for v1.xx. Other properties should be checked with |
| # hasattr before accessing. |
class ErrorResponse(object):
    # Backend error message; fields are parsed exactly like NoticeResponse.
    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)

    def createException(self):
        # Builds (does not raise) the exception representing this error.
        return ProgrammingError(self.severity, self.code, self.msg)

    @staticmethod
    def createFromData(data):
        fields = NoticeResponse.dataIntoDict(data)
        return ErrorResponse(**fields)
| |
| |
| ## |
| # A message sent if this connection receives a NOTIFY that it was LISTENing for. |
| # <p> |
| # Stability: Added in pg8000 v1.03. When limited to accessing properties from |
| # a notification event dispatch, stability is guaranteed for v1.xx. |
class NotificationResponse(object):
    # A NOTIFY delivered to this connection for a channel it LISTENs on.
    def __init__(self, backend_pid, condition, additional_info):
        self._backend_pid = backend_pid
        self._condition = condition
        self._additional_info = additional_info

    # Process ID of the backend that issued the NOTIFY (read-only).
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def backend_pid(self):
        return self._backend_pid

    # Name of the notification that fired (read-only).
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def condition(self):
        return self._condition

    # Extra string; unspecified by the PostgreSQL docs as of v8.3.1 (read-only).
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def additional_info(self):
        return self._additional_info

    def __repr__(self):
        return "<NotificationResponse %s %s %r>" % (self.backend_pid, self.condition, self.additional_info)

    # Payload: Int32 backend PID followed by two NUL-terminated strings
    # (condition name, additional info).
    @staticmethod
    def createFromData(data):
        (pid,) = struct.unpack("!i", data[:4])
        rest = data[4:]
        end = rest.find("\x00")
        condition = rest[:end]
        rest = rest[end + 1:]
        end = rest.find("\x00")
        return NotificationResponse(pid, condition, rest[:end])
| |
| |
class ParameterDescription(object):
    # Backend description of a prepared statement's parameter types.
    def __init__(self, type_oids):
        self.type_oids = type_oids

    # Payload: Int16 parameter count, then one Int32 type OID per parameter.
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        oid_format = "!" + count * "i"
        return ParameterDescription(struct.unpack(oid_format, data[2:]))
| |
| |
class RowDescription(object):
    # Backend description of result columns; `fields` is a list of dicts.
    def __init__(self, fields):
        self.fields = fields

    # Payload: Int16 column count, then per column a NUL-terminated name and
    # 18 bytes: Int32 table OID, Int16 column number, Int32 type OID,
    # Int16 type size, Int32 type modifier, Int16 format code.
    @staticmethod
    def createFromData(data):
        keys = ("table_oid", "column_attrnum", "type_oid",
                "type_size", "type_modifier", "format")
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        columns = []
        for _ in range(count):
            nul = rest.find("\x00")
            name = rest[:nul]
            rest = rest[nul + 1:]
            column = dict(zip(keys, struct.unpack("!ihihih", rest[:18])))
            column["name"] = name
            rest = rest[18:]
            columns.append(column)
        return RowDescription(columns)
| |
class CommandComplete(object):
    # Backend CommandComplete message.
    #   command -- the command word (e.g. "INSERT") for tags that can carry
    #              counts, otherwise the whole tag (e.g. "CREATE TABLE").
    #   rows    -- row count reported by the tag, or None when absent.
    #   oid     -- the OID field of an "INSERT oid rows" tag, else None.
    def __init__(self, command, rows=None, oid=None):
        self.command = command
        self.rows = rows
        self.oid = oid

    # Byte1('C') - Identifier.
    # Int32      - Message length, including self.
    # String     - Command tag, e.g. "INSERT 0 5", "DELETE 3", "COPY 12".
    def createFromData(data):
        tag = data[:-1]  # strip the terminating NUL byte
        values = tag.split(" ")
        command = values[0]
        if command not in ("INSERT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY"):
            return CommandComplete(tag)
        args = {'command': command}
        # Servers before PostgreSQL 8.2 send a bare "COPY" tag with no row
        # count; converting the command word itself raised ValueError.
        # Only parse a trailing count that is actually present.
        if len(values) > 1 and values[-1].isdigit():
            args['rows'] = int(values[-1])
        # INSERT tags have the form "INSERT oid rows".
        if command == "INSERT" and len(values) > 2:
            args['oid'] = int(values[1])
        return CommandComplete(**args)
    createFromData = staticmethod(createFromData)
| |
| |
class DataRow(object):
    # One result row; `fields` holds raw column bytes, None for SQL NULL.
    def __init__(self, fields):
        self.fields = fields

    # Payload: Int16 column count, then per column an Int32 length
    # (-1 = NULL, no bytes follow) and that many bytes of value.
    @staticmethod
    def createFromData(data):
        (column_count,) = struct.unpack("!h", data[:2])
        remaining = data[2:]
        values = []
        for _ in range(column_count):
            (size,) = struct.unpack("!i", remaining[:4])
            remaining = remaining[4:]
            if size == -1:
                values.append(None)
                continue
            values.append(remaining[:size])
            remaining = remaining[size:]
        return DataRow(values)
| |
| |
class CopyData(object):
    # CopyData ('d') message: a chunk of COPY stream data. It is both parsed
    # from the backend and serialized to it.
    def __init__(self, data):
        self.data = data

    @staticmethod
    def createFromData(data):
        return CopyData(data)

    def serialize(self):
        # Byte1('d'), Int32 length (counting itself), raw data bytes.
        length = struct.pack('!i', len(self.data) + 4)
        return 'd' + length + self.data
| |
| |
class CopyDone(object):
    # CopyDone ('c') message ending a COPY stream; header only.
    @staticmethod
    def createFromData(data):
        return CopyDone()

    def serialize(self):
        # Byte1('c') then Int32(4).
        return 'c' + '\x00\x00\x00\x04'
| |
class CopyOutResponse(object):
    # Backend CopyOutResponse ('H'): a COPY TO STDOUT is starting.
    # Payload layout:
    #   Int8     - overall format: 0 textual, 1 binary
    #   Int16    - number of columns N
    #   Int16[N] - per-column format codes (0 text, 1 binary)
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    @staticmethod
    def createFromData(data):
        header, codes = data[:3], data[3:]
        is_binary, num_cols = struct.unpack('!bh', header)
        column_formats = struct.unpack('!' + num_cols * 'h', codes)
        return CopyOutResponse(is_binary, column_formats)
| |
| |
class CopyInResponse(object):
    # Backend CopyInResponse ('G'): a COPY FROM STDIN is starting.
    # Payload layout is the same as CopyOutResponse: Int8 overall format,
    # Int16 column count N, Int16[N] per-column format codes.
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    @staticmethod
    def createFromData(data):
        header, codes = data[:3], data[3:]
        is_binary, num_cols = struct.unpack('!bh', header)
        column_formats = struct.unpack('!' + num_cols * 'h', codes)
        return CopyInResponse(is_binary, column_formats)
| |
class SSLWrapper(object):
    # Adapts an SSL object (the one returned by socket.ssl() in
    # Connection.__init__) to the socket-style send/sendall/recv calls made
    # elsewhere in this module.
    def __init__(self, sslobj):
        self.sslobj = sslobj

    def send(self, data):
        self.sslobj.write(data)

    def sendall(self, data):
        # Connection._flush calls sendall(). Without this method, every flush
        # after SSL negotiation failed with AttributeError. write() returns
        # the number of bytes written, so loop until everything has gone out.
        while data:
            written = self.sslobj.write(data)
            if not written:
                raise socket.error("SSL write made no progress")
            data = data[written:]

    def recv(self, num):
        return self.sslobj.read(num)
| |
| |
class MessageReader(object):
    # Reads backend messages from a connection and dispatches each one to
    # the handlers registered with add_message(). The loop continues until a
    # handler signals that it should end.
    def __init__(self, connection):
        self._conn = connection
        # (msg_class, handler, args, kwargs) tuples, tried in registration order.
        self._msgs = []

        # If true, an unhandled ErrorResponse is remembered and raised only
        # once a handler ends the loop. This leaves the connection usable
        # instead of stranding unconsumed messages.
        self.delay_raising_exception = False

        # If true, messages with no handler and no default treatment are
        # skipped instead of raising InternalError.
        self.ignore_unhandled_messages = False

    def add_message(self, msg_class, handler, *args, **kwargs):
        # handler(msg, *args, **kwargs) runs for instances of msg_class.
        entry = (msg_class, handler, args, kwargs)
        self._msgs.append(entry)

    def clear_messages(self):
        self._msgs = []

    def return_value(self, value):
        # Lets a handler end the loop with a value that may be falsy.
        # NOTE(review): _retval is never reset, so a reused reader stops at
        # its next handled message.
        self._retval = value

    def handle_messages(self):
        # A handler ends the loop by returning a true value (which is
        # returned) or by calling return_value() (whose value is returned).
        # Any delayed ErrorResponse exception is raised at that point instead.
        pending_exc = None
        while True:
            msg = self._conn._read_message()
            matched = False
            for msg_class, handler, args, kwargs in self._msgs:
                if not isinstance(msg, msg_class):
                    continue
                matched = True
                outcome = handler(msg, *args, **kwargs)
                if outcome:
                    result = outcome
                elif hasattr(self, "_retval"):
                    result = self._retval
                else:
                    continue
                if pending_exc is not None:
                    raise pending_exc
                return result
            if matched:
                continue
            # Default treatment for messages no handler claimed.
            if isinstance(msg, ErrorResponse):
                pending_exc = msg.createException()
                if not self.delay_raising_exception:
                    raise pending_exc
            elif isinstance(msg, NoticeResponse):
                self._conn.handleNoticeResponse(msg)
            elif isinstance(msg, ParameterStatus):
                self._conn.handleParameterStatus(msg)
            elif isinstance(msg, NotificationResponse):
                self._conn.handleNotificationResponse(msg)
            elif not self.ignore_unhandled_messages:
                raise InternalError("Unexpected response msg %r" % (msg))
| |
def sync_on_error(fn):
    # Decorator for connection methods. It holds self._sock_lock for the
    # whole call. If the call raises (anything, including KeyboardInterrupt),
    # it calls self._sync() while still holding the lock, then re-raises.
    # functools.wraps keeps the decorated method's name and docstring.
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        # Acquire outside the try: if acquiring fails, we must neither run
        # _sync() nor release a lock this call never obtained.
        self._sock_lock.acquire()
        try:
            return fn(self, *args, **kwargs)
        except:
            self._sync()
            raise
        finally:
            self._sock_lock.release()
    return _fn
| |
class Connection(object):
    """A single frontend/backend connection to a PostgreSQL server.

    Outgoing messages are queued by _send() and written by _flush().
    Incoming messages are read by _read_message() and dispatched with
    MessageReader. Socket traffic is serialised by _sock_lock: the low-level
    helpers assert that it is held, and the public operations acquire it
    (directly or through @sync_on_error).
    """

    def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60, ssl=False, records=False):
        """Open a socket to the server and optionally negotiate SSL.

        unix_sock takes precedence over host/port when both are given.
        socket_timeout is applied only when ssl is false (see the comment
        below). When records is true, authenticate() also caches the
        attribute names of record-typed columns.

        Raises InterfaceError if AF_UNIX is unsupported or the server
        refuses SSL, and ProgrammingError if neither unix_sock nor host is
        given. Errors from connect() propagate. On any failure after the
        socket is created, the socket is closed.
        """
        self._client_encoding = "ascii"
        self._integer_datetimes = False
        self._record_field_names = {}
        self._sock_buf = ""
        self._sock_buf_pos = 0
        self._send_sock_buf = []
        self._block_size = 8192
        self.user_wants_records = records
        # Created before any I/O: _send/_flush assert that this lock is held,
        # and the SSL negotiation below goes through them.
        self._sock_lock = threading.Lock()
        self._state = "noauth"
        self._backend_key_data = None

        if unix_sock is not None:
            if not hasattr(socket, "AF_UNIX"):
                raise InterfaceError("attempt to connect to unix socket on unsupported platform")
            self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            address = unix_sock
        elif host is not None:
            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            address = (host, port)
        else:
            raise ProgrammingError("one of host or unix_sock must be provided")

        # Do not leak the socket if connecting or SSL negotiation fails.
        established = False
        try:
            self._sock.connect(address)
            if ssl:
                self._negotiate_ssl()
            else:
                # settimeout causes ssl failure, on windows. Python bug 1462352.
                self._sock.settimeout(socket_timeout)
            established = True
        finally:
            if not established:
                self._sock.close()

        self.NoticeReceived = MulticastDelegate()
        self.ParameterStatusReceived = MulticastDelegate()
        self.NotificationReceived = MulticastDelegate()

        self.ParameterStatusReceived += self._onParameterStatusReceived

    def _negotiate_ssl(self):
        """Send SSLRequest and wrap the socket if the server answers 'S'.

        Raises InterfaceError if the one-byte reply is anything else.
        """
        self._sock_lock.acquire()
        try:
            self._send(SSLRequest())
            self._flush()
            resp = self._sock.recv(1)
        finally:
            self._sock_lock.release()
        if resp != 'S':
            raise InterfaceError("server refuses SSL")
        self._sock = SSLWrapper(socket.ssl(self._sock))

    def verifyState(self, state):
        """Raise InternalError unless the connection is in *state*."""
        if self._state != state:
            raise InternalError("connection state must be %s, is %s" % (state, self._state))

    def _send(self, msg):
        """Serialize *msg* and queue it. Nothing is written until _flush()."""
        assert self._sock_lock.locked()
        self._send_sock_buf.append(msg.serialize())

    def _flush(self):
        """Write every queued message to the socket and empty the queue."""
        assert self._sock_lock.locked()
        self._sock.sendall("".join(self._send_sock_buf))
        del self._send_sock_buf[:]

    def _read_bytes(self, byte_count):
        """Return exactly *byte_count* bytes from the socket.

        The socket is read in 1024-byte chunks into self._sock_buf. Bytes
        beyond the current request stay buffered for the next call.

        Raises InterfaceError if the server closes the connection first.
        """
        chunks = []
        remaining = byte_count
        while remaining > 0:
            if self._sock_buf_pos == len(self._sock_buf):
                self._sock_buf = self._sock.recv(1024)
                self._sock_buf_pos = 0
                if not self._sock_buf:
                    # An empty recv() means end-of-stream. Without this check
                    # the loop would make no progress and spin forever.
                    raise InterfaceError("server closed the connection unexpectedly")
            end = min(len(self._sock_buf), self._sock_buf_pos + remaining)
            chunks.append(self._sock_buf[self._sock_buf_pos:end])
            remaining -= end - self._sock_buf_pos
            self._sock_buf_pos = end
        return "".join(chunks)

    def _read_message(self):
        """Read one backend message and decode it via message_types.

        A message is a one-byte type code followed by a big-endian int32
        length. That length includes its own four bytes but not the type
        byte.

        Raises InternalError for a type code missing from message_types. The
        message body has already been consumed at that point, so the stream
        stays aligned.
        """
        assert self._sock_lock.locked()
        header = self._read_bytes(5)
        message_code = header[0]
        data_len = struct.unpack("!i", header[1:])[0] - 4
        payload = self._read_bytes(data_len)
        assert len(payload) == data_len
        try:
            msg_class = message_types[message_code]
        except KeyError:
            raise InternalError("unexpected message type %r from server" % (message_code,))
        return msg_class.createFromData(payload)

    def authenticate(self, user, **kwargs):
        """Run the startup and authentication exchange for *user*.

        kwargs "database" and "options" are put into the StartupMessage. All
        kwargs are also passed to the AuthenticationRequest's ok() method
        (presumably to supply credentials such as a password; verify).

        Waits for ReadyForQuery, recording BackendKeyData on the way, then
        caches record attribute names if records=True was requested.
        """
        self.verifyState("noauth")
        self._sock_lock.acquire()
        try:
            self._send(StartupMessage(user, database=kwargs.get("database", None), options=kwargs.get("options", None)))
            self._flush()
            msg = self._read_message()
            if isinstance(msg, ErrorResponse):
                raise msg.createException()
            if not isinstance(msg, AuthenticationRequest):
                raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % msg)
            if not msg.ok(self, user, **kwargs):
                raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)

            self._state = "auth"

            reader = MessageReader(self)
            reader.add_message(ReadyForQuery, self._ready_for_query)
            reader.add_message(BackendKeyData, self._receive_backend_key_data)
            reader.handle_messages()
        finally:
            self._sock_lock.release()

        self._cache_record_attnames()

    def _ready_for_query(self, msg):
        """Handler: mark the connection ready and end the message loop."""
        self._state = "ready"
        return True

    def _receive_backend_key_data(self, msg):
        """Handler: remember the BackendKeyData message."""
        self._backend_key_data = msg

    def _cache_record_attnames(self):
        """Map each record_recv type oid to its attribute names.

        The result goes into self._record_field_names, with each oid mapped
        to a list of attnames in attnum order. Does nothing unless
        records=True was passed to the constructor.
        """
        if not self.user_wants_records:
            return

        parse_retval = self.parse("",
            """SELECT
                pg_type.oid, attname
            FROM
                pg_type
                INNER JOIN pg_attribute ON (attrelid = pg_type.typrelid)
            WHERE typreceive = 'record_recv'::regproc
            ORDER BY pg_type.oid, attnum""",
            [])
        row_desc, cmd = self.bind("tmp", "", (), parse_retval, None)
        eod, rows = self.fetch_rows("tmp", 0, row_desc)

        self._record_field_names = {}
        typoid, attnames = None, []
        # Rows are ordered by oid, so each type's attributes are contiguous.
        for new_typoid, attname in rows:
            if new_typoid != typoid and typoid is not None:
                self._record_field_names[typoid] = attnames
                attnames = []
            typoid = new_typoid
            attnames.append(attname)
        # With no rows typoid is still None; do not store a bogus None key.
        if typoid is not None:
            self._record_field_names[typoid] = attnames

    @sync_on_error
    def parse(self, statement, qs, param_types):
        """Prepare *qs* as *statement* (Parse + Describe + Flush).

        Returns (RowDescription, param_fc), or (None, param_fc) when the
        statement produces no rows. param_fc holds the per-parameter second
        element of types.pg_type_info().
        """
        self.verifyState("ready")

        type_info = [types.pg_type_info(x) for x in param_types]
        param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info] # zip(*type_info) -- fails on empty arr
        self._send(Parse(statement, qs, param_types))
        self._send(DescribePreparedStatement(statement))
        self._send(Flush())
        self._flush()

        reader = MessageReader(self)

        # ParseComplete is good.
        reader.add_message(ParseComplete, lambda msg: 0)

        # Well, we don't really care -- we're going to send whatever we
        # want and let the database deal with it.  But thanks anyways!
        reader.add_message(ParameterDescription, lambda msg: 0)

        # We're not waiting for a row description.  Return something
        # distinctive to let bind know that there is no output.
        reader.add_message(NoData, lambda msg: (None, param_fc))

        # Common row description response
        reader.add_message(RowDescription, lambda msg: (msg, param_fc))

        return reader.handle_messages()

    @sync_on_error
    def bind(self, portal, statement, params, parse_data, copy_stream):
        """Bind *params* to *statement* as *portal* and describe the portal.

        Returns (RowDescription, None) when the portal produces rows.
        Otherwise the portal is executed immediately (see _bind_nodata), and
        (None, CommandComplete) is returned. *copy_stream* is the file object
        used if that execution enters COPY IN / COPY OUT.
        """
        self.verifyState("ready")

        row_desc, param_fc = parse_data
        if row_desc is None:
            # no data coming out
            output_fc = ()
        else:
            # We've got row_desc that allows us to identify what we're going to
            # get back from this statement.
            output_fc = [types.py_type_info(f, self._record_field_names) for f in row_desc.fields]
        self._send(Bind(portal, statement, param_fc, params, output_fc, client_encoding = self._client_encoding, integer_datetimes = self._integer_datetimes))
        # We need to describe the portal after bind, since the return
        # format codes will be different (hopefully, always what we
        # requested).
        self._send(DescribePortal(portal))
        self._send(Flush())
        self._flush()

        # Read responses from server...
        reader = MessageReader(self)

        # BindComplete is good -- just ignore
        reader.add_message(BindComplete, lambda msg: 0)

        # NoData in this case means we're not executing a query.  As a
        # result, we won't be fetching rows, so we'll never execute the
        # portal we just created... unless we execute it right away, which
        # we'll do.
        reader.add_message(NoData, self._bind_nodata, portal, reader, copy_stream)

        # Return the new row desc, since it will have the format types we
        # asked the server for
        reader.add_message(RowDescription, lambda msg: (msg, None))

        return reader.handle_messages()

    def _copy_in_response(self, copyin, fileobj, old_reader):
        """Handler for CopyInResponse: stream *fileobj* to the server.

        Sends the file in _block_size chunks as CopyData, then CopyDone +
        Sync. Raises CopyQueryWithoutStreamError if no stream was given.
        """
        if fileobj is None:
            raise CopyQueryWithoutStreamError()
        while True:
            data = fileobj.read(self._block_size)
            if not data:
                break
            self._send(CopyData(data))
            self._flush()
        self._send(CopyDone())
        self._send(Sync())
        self._flush()

    def _copy_out_response(self, copyout, fileobj, old_reader):
        """Handler for CopyOutResponse: write CopyData into *fileobj* until CopyDone.

        Raises CopyQueryWithoutStreamError if no stream was given.
        """
        if fileobj is None:
            raise CopyQueryWithoutStreamError()
        reader = MessageReader(self)
        reader.add_message(CopyData, self._copy_data, fileobj)
        reader.add_message(CopyDone, lambda msg: 1)
        reader.handle_messages()

    def _copy_data(self, copydata, fileobj):
        """Handler: append one CopyData payload to *fileobj*."""
        fileobj.write(copydata.data)

    def _bind_nodata(self, msg, portal, old_reader, copy_stream):
        """Handler for NoData after Bind: execute *portal* right away.

        Handles any COPY traffic, waits for ReadyForQuery, and makes
        *old_reader* return (None, CommandComplete). A server error seen
        meanwhile is raised once ReadyForQuery arrives.
        """
        self._send(Execute(portal, 0))
        self._send(Sync())
        self._flush()

        output = {}

        def _remember_command_complete(complete_msg, out):
            # Keep the first CommandComplete; a false return keeps the loop
            # running until ReadyForQuery.
            out.setdefault('msg', complete_msg)
            return False

        reader = MessageReader(self)
        reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader)
        reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader)
        reader.add_message(CommandComplete, _remember_command_complete, output)
        reader.add_message(ReadyForQuery, lambda msg: 1)
        reader.delay_raising_exception = True
        reader.handle_messages()

        old_reader.return_value((None, output['msg']))

    @sync_on_error
    def fetch_rows(self, portal, row_count, row_desc):
        """Execute *portal* for up to *row_count* rows.

        A row_count of 0 means no limit in the Execute message, per the
        PostgreSQL protocol. Returns (end_of_data, rows). end_of_data is True
        once CommandComplete was seen; the portal has then been closed.
        """
        self.verifyState("ready")

        self._send(Execute(portal, row_count))
        self._send(Flush())
        self._flush()
        rows = []

        reader = MessageReader(self)
        reader.add_message(DataRow, self._fetch_datarow, rows, row_desc)
        reader.add_message(PortalSuspended, lambda msg: 1)
        reader.add_message(CommandComplete, self._fetch_commandcomplete, portal)
        retval = reader.handle_messages()

        # retval = 2 when command complete, indicating that we've hit the
        # end of the available data for this command
        return (retval == 2), rows

    def _fetch_datarow(self, msg, rows, row_desc):
        """Handler: decode one DataRow using *row_desc* and append it to *rows*."""
        encoding = self._client_encoding
        integer_datetimes = self._integer_datetimes
        record_field_names = self._record_field_names
        rows.append([
            types.py_value(
                value,
                row_desc.fields[i],
                client_encoding=encoding,
                integer_datetimes=integer_datetimes,
                record_field_names=record_field_names,
            )
            for i, value in enumerate(msg.fields)
        ])

    def _fetch_commandcomplete(self, msg, portal):
        """Handler for CommandComplete: close *portal*, sync, and signal end-of-data (2)."""
        self._send(ClosePortal(portal))
        self._send(Sync())
        self._flush()

        reader = MessageReader(self)
        reader.add_message(ReadyForQuery, self._fetch_commandcomplete_rfq)
        reader.add_message(CloseComplete, lambda msg: False)
        reader.handle_messages()

        return 2 # signal end-of-data

    def _fetch_commandcomplete_rfq(self, msg):
        """Handler: mark the connection ready and end the message loop."""
        self._state = "ready"
        return True

    def _sync(self):
        """Send Sync, then read and discard messages until ReadyForQuery.

        The caller must hold _sock_lock (sync_on_error and isready do).
        """
        self._send(Sync())
        self._flush()
        reader = MessageReader(self)
        reader.ignore_unhandled_messages = True
        reader.add_message(ReadyForQuery, lambda msg: True)
        reader.handle_messages()

    def _close_object(self, close_msg_class, name):
        """Send close_msg_class(name) + Sync and wait for ReadyForQuery.

        Does nothing once the connection is closed.
        """
        if self._state == "closed":
            return
        self.verifyState("ready")
        self._sock_lock.acquire()
        try:
            self._send(close_msg_class(name))
            self._send(Sync())
            self._flush()

            reader = MessageReader(self)
            reader.add_message(CloseComplete, lambda msg: 0)
            reader.add_message(ReadyForQuery, lambda msg: 1)
            reader.handle_messages()
        finally:
            self._sock_lock.release()

    def close_statement(self, statement):
        """Close the prepared statement named *statement* (no-op when closed)."""
        self._close_object(ClosePreparedStatement, statement)

    def close_portal(self, portal):
        """Close the portal named *portal* (no-op when closed)."""
        self._close_object(ClosePortal, portal)

    def close(self):
        """Send Terminate and close the socket. Later calls do nothing.

        The socket is closed and the state set to "closed" even if sending
        Terminate fails.
        """
        self._sock_lock.acquire()
        try:
            if self._state == "closed":
                return
            try:
                self._send(Terminate())
                self._flush()
            finally:
                del self._send_sock_buf[:]
                self._sock.close()
                self._state = "closed"
        finally:
            self._sock_lock.release()

    def _onParameterStatusReceived(self, msg):
        """Track the client_encoding and integer_datetimes server parameters."""
        if msg.key == "client_encoding":
            self._client_encoding = msg.value
        elif msg.key == "integer_datetimes":
            self._integer_datetimes = (msg.value == "on")

    def handleNoticeResponse(self, msg):
        """Forward *msg* to the NoticeReceived delegate."""
        self.NoticeReceived(msg)

    def handleParameterStatus(self, msg):
        """Forward *msg* to the ParameterStatusReceived delegate."""
        self.ParameterStatusReceived(msg)

    def handleNotificationResponse(self, msg):
        """Forward *msg* to the NotificationReceived delegate."""
        self.NotificationReceived(msg)

    def fileno(self):
        """Return the socket's file descriptor (lets select() poll the connection)."""
        # This should be safe to do without a lock
        return self._sock.fileno()

    def isready(self):
        """Return True (after syncing) if the server has sent anything, else False.

        When data is available, _sync() runs so that pending NoticeResponse,
        ParameterStatus and NotificationResponse messages reach their
        handlers.
        """
        self._sock_lock.acquire()
        try:
            # Bytes left in _sock_buf by an earlier read are invisible to
            # select(), so only poll the socket when that buffer is drained.
            if self._sock_buf_pos >= len(self._sock_buf):
                rlst, _wlst, _xlst = select.select([self], [], [], 0)
                if not rlst:
                    return False

            self._sync()
            return True
        finally:
            self._sock_lock.release()
| |
# Backend message type byte -> class whose createFromData() decodes the
# message body. _read_message looks up the first byte of every message here.
message_types = dict([
    ("N", NoticeResponse),
    ("R", AuthenticationRequest),
    ("S", ParameterStatus),
    ("K", BackendKeyData),
    ("Z", ReadyForQuery),
    ("T", RowDescription),
    ("E", ErrorResponse),
    ("D", DataRow),
    ("C", CommandComplete),
    ("1", ParseComplete),
    ("2", BindComplete),
    ("3", CloseComplete),
    ("s", PortalSuspended),
    ("n", NoData),
    ("t", ParameterDescription),
    ("A", NotificationResponse),
    # COPY sub-protocol
    ("c", CopyDone),
    ("d", CopyData),
    ("G", CopyInResponse),
    ("H", CopyOutResponse),
])
| |
| |