blob: ca93b114a131e8e0ea89320b30b954cd1bb37949 [file] [log] [blame]
# -*- coding: utf-8 -*-
#
# Copyright (C) 2005-2009 Edgewall Software
# Copyright (C) 2005-2006 Christopher Lenz <cmlenz@gmx.de>
# Copyright (C) 2005 Jeff Weiss <trac@jeffweiss.org>
# Copyright (C) 2006 Andres Salomon <dilinger@athenacr.com>
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at http://trac.edgewall.org/wiki/TracLicense.
#
# This software consists of voluntary contributions made by many
# individuals. For the exact contribution history, see the revision
# history and logs, available at http://trac.edgewall.org/log/.
import os, re, types
from genshi.core import Markup
from trac.core import *
from trac.config import Option
from trac.db.api import IDatabaseConnector, _parse_db_str
from trac.db.util import ConnectionWrapper, IterableCursor
from trac.util import as_int, get_pkginfo
from trac.util.compat import close_fds
from trac.util.text import exception_to_unicode, to_unicode
from trac.util.translation import _
_like_escape_re = re.compile(r'([/_%])')
try:
    import MySQLdb
    import MySQLdb.cursors
    has_mysqldb = True

    class MySQLUnicodeCursor(MySQLdb.cursors.Cursor):
        """MySQLdb cursor whose fetched rows have their byte-string (`str`)
        values decoded from UTF-8; all other values are returned as-is."""

        def _convert_row(self, row):
            """Return `row` as a tuple with every `str` value decoded from
            UTF-8."""
            return tuple(v.decode('utf-8') if isinstance(v, str) else v
                         for v in row)

        def _convert_rows(self, rows):
            """Convert every row of `rows`; `None` yields an empty list."""
            if rows is None:
                return []
            return [self._convert_row(row) for row in rows]

        def fetchone(self):
            """Return the next converted row, or None when exhausted."""
            row = super(MySQLUnicodeCursor, self).fetchone()
            return self._convert_row(row) if row else None

        def fetchmany(self, num=None):
            """Return up to `num` converted rows.

            As in the DB-API, the size is optional: `None` is passed through
            to MySQLdb, which then uses the cursor's `arraysize`.
            """
            rows = super(MySQLUnicodeCursor, self).fetchmany(num)
            return self._convert_rows(rows)

        def fetchall(self):
            """Return all remaining rows, converted."""
            rows = super(MySQLUnicodeCursor, self).fetchall()
            return self._convert_rows(rows)
except ImportError:
    has_mysqldb = False
# Mapping from "abstract" SQL types to DB-specific types
_type_map = {
'int64': 'bigint',
}
class MySQLConnector(Component):
    """Database connector for MySQL version 4.1 and greater.

    Database URLs should be of the form:
    {{{
    mysql://user[:password]@host[:port]/database[?param1=value&param2=value]
    }}}
    The following parameters are supported:
     * `compress`: Enable compression (0 or 1)
     * `init_command`: Command to run once the connection is created
     * `named_pipe`: Use a named pipe to connect on Windows (0 or 1)
     * `read_default_file`: Read default client values from the given file
     * `read_default_group`: Configuration group to use from the default file
     * `unix_socket`: Use a Unix socket at the given path to connect
    """
    implements(IDatabaseConnector)

    mysqldump_path = Option('trac', 'mysqldump_path', 'mysqldump',
        """Location of mysqldump for MySQL database backups""")

    def __init__(self):
        # MySQLdb version string; filled in by the first get_connection().
        self._version = None
        # Set by get_supported_schemes() when MySQLdb cannot be imported.
        self.error = None

    def get_supported_schemes(self):
        """Yield the `mysql` scheme with priority 1, or with priority -1
        (and an explanatory `self.error`) when the MySQLdb bindings could
        not be imported."""
        if not has_mysqldb:
            self.error = _("Cannot load Python bindings for MySQL")
        yield ('mysql', -1 if self.error else 1)

    def get_connection(self, path, log=None, user=None, password=None,
                       host=None, port=None, params=None):
        """Open and return a new `MySQLConnection`.

        The first successful call also records the MySQL server/client
        versions and the MySQLdb version in `self.env.systeminfo`.
        """
        cnx = MySQLConnection(path, log, user, password, host, port,
                              params or {})
        if not self._version:
            self._version = get_pkginfo(MySQLdb).get('version',
                                                     MySQLdb.__version__)
            mysql_info = 'server: "%s", client: "%s", thread-safe: %s' % \
                         (cnx.cnx.get_server_info(),
                          MySQLdb.get_client_info(),
                          MySQLdb.thread_safe())
            self.env.systeminfo.extend([('MySQL', mysql_info),
                                        ('MySQLdb', self._version)])
            self.required = True
        return cnx

    def init_db(self, path, schema=None, log=None, user=None, password=None,
                host=None, port=None, params=None):
        """Create every table of `schema` (Trac's default schema when None)
        in the database designated by the connection arguments.

        The connection opened for this purpose is always closed afterwards.
        """
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params or {})
        try:
            cursor = cnx.cursor()
            # Bytes per character used to size index prefixes: utf8mb4 needs
            # 4; every other charset is sized like utf8 (3) so that
            # _collist() always receives an integer.
            utf8_size = 4 if cnx.charset == 'utf8mb4' else 3
            if schema is None:
                from trac.db_default import schema
            for table in schema:
                for stmt in self.to_sql(table, utf8_size=utf8_size):
                    self.log.debug(stmt)
                    cursor.execute(stmt)
            cnx.commit()
        finally:
            cnx.close()

    def _collist(self, table, columns, utf8_size=3):
        """Take a list of columns and impose limits on each so that indexing
        works properly.

        Some Versions of MySQL limit each index prefix to 1000 bytes total,
        with a max of 767 bytes per column.  `utf8_size` is the number of
        bytes per character used to turn those byte limits into character
        prefix lengths.

        Returns the comma-separated list of backquoted column names, with a
        prefix length appended to text columns.
        """
        if not columns:
            return ''
        # Floor division: prefix lengths must be whole characters.
        limit_col = 767 // utf8_size
        limit = min(1000 // (utf8_size * len(columns)), limit_col)
        cols = []
        for c in columns:
            name = '`%s`' % c
            table_col = [col for col in table.columns if col.name == c]
            # Auto-increment columns are created as integers by to_sql()
            # whatever their declared type, so they never get a prefix.
            if len(table_col) == 1 and not table_col[0].auto_increment \
                    and table_col[0].type.lower() == 'text':
                if table_col[0].key_size is not None:
                    name += '(%d)' % min(table_col[0].key_size, limit_col)
                elif name == '`rev`':
                    name += '(20)'
                elif name == '`path`':
                    name += '(%d)' % limit_col
                elif name == '`change_type`':
                    name += '(2)'
                else:
                    name += '(%d)' % limit
            # For non-text columns, we simply throw away the extra bytes.
            # That could certainly be optimized better, but for now let's KISS.
            cols.append(name)
        return ','.join(cols)

    def to_sql(self, table, utf8_size=3):
        """Yield the DDL creating `table`: one CREATE TABLE statement, then
        one CREATE INDEX statement per index.

        `utf8_size` is forwarded to `_collist()` to size index prefixes.
        The `table` schema object is left unmodified.
        """
        sql = ['CREATE TABLE %s (' % table.name]
        coldefs = []
        for column in table.columns:
            if column.auto_increment:
                # A text field cannot use AUTO_INCREMENT, so always use an
                # unsigned integer here (_collist() accounts for this).
                ctype = 'INT UNSIGNED NOT NULL AUTO_INCREMENT'
            else:
                ctype = _type_map.get(column.type, column.type)
            coldefs.append(' `%s` %s' % (column.name, ctype))
        if table.key:
            coldefs.append(' PRIMARY KEY (%s)' %
                           self._collist(table, table.key,
                                         utf8_size=utf8_size))
        sql.append(',\n'.join(coldefs) + '\n)')
        yield '\n'.join(sql)
        for index in table.indices:
            unique = 'UNIQUE' if index.unique else ''
            yield 'CREATE %s INDEX %s_%s_idx ON %s (%s);' % (
                unique, table.name, '_'.join(index.columns), table.name,
                self._collist(table, index.columns, utf8_size=utf8_size))

    def alter_column_types(self, table, columns):
        """Yield SQL statements altering the type of one or more columns of
        a table.

        Type changes are specified as a `columns` dict mapping column names
        to `(from, to)` SQL type tuples.  Columns whose MySQL type would not
        actually change are skipped; the remaining changes are combined into
        a single ALTER TABLE statement (nothing is yielded if none remain).
        """
        alterations = []
        for name, (from_, to) in sorted(columns.iteritems()):
            to = _type_map.get(to, to)
            if to != _type_map.get(from_, from_):
                alterations.append((name, to))
        if alterations:
            yield "ALTER TABLE %s %s" % (table,
                ', '.join("MODIFY %s %s" % each for each in alterations))

    def backup(self, dest_file):
        """Dump the database configured in `[trac] database` into
        `dest_file` with mysqldump and return `dest_file`.

        Raises `TracError` when mysqldump cannot be started, exits with a
        non-zero status, or does not create `dest_file`.
        """
        from subprocess import Popen, PIPE
        db_url = self.env.config.get('trac', 'database')
        scheme, db_prop = _parse_db_str(db_url)
        db_params = db_prop.setdefault('params', {})
        db_name = os.path.basename(db_prop['path'])

        args = [self.mysqldump_path]
        if 'host' in db_prop:
            args.extend(['-h', db_prop['host']])
        if 'port' in db_prop:
            args.extend(['-P', str(db_prop['port'])])
        if 'user' in db_prop:
            args.extend(['-u', db_prop['user']])
        for name, value in db_params.iteritems():
            # Flag parameters are valid even when disabled (0); only the
            # enabled ones translate into mysqldump options.
            if name == 'compress':
                if as_int(value, 0):
                    args.append('--compress')
            elif name == 'named_pipe':
                if as_int(value, 0):
                    args.append('--protocol=pipe')
            elif name == 'read_default_file': # Must be first
                args.insert(1, '--defaults-file=' + value)
            elif name == 'unix_socket':
                args.extend(['--protocol=socket', '--socket=' + value])
            elif name not in ('init_command', 'read_default_group'):
                self.log.warning("Invalid connection string parameter '%s'",
                                 name)
        args.extend(['-r', dest_file, db_name])

        # The password goes through the environment, not the argument list.
        environ = os.environ.copy()
        if 'password' in db_prop:
            environ['MYSQL_PWD'] = str(db_prop['password'])
        try:
            p = Popen(args, env=environ, stderr=PIPE, close_fds=close_fds)
        except OSError as e:
            raise TracError(_("Unable to run %(path)s: %(msg)s",
                              path=self.mysqldump_path,
                              msg=exception_to_unicode(e)))
        errmsg = p.communicate()[1]
        if p.returncode != 0:
            raise TracError(_("mysqldump failed: %(msg)s",
                              msg=to_unicode(errmsg.strip())))
        if not os.path.exists(dest_file):
            raise TracError(_("No destination file created"))
        return dest_file
class MySQLConnection(ConnectionWrapper):
    """Connection wrapper for MySQL."""

    poolable = True

    def __init__(self, path, log, user=None, password=None, host=None,
                 port=None, params=None):
        """Connect to database `path` (a leading '/' is stripped) using the
        utf8 client charset, then issue SET NAMES with the database's own
        character set when that is not utf8.

        `password` defaults to '' and `port` to 3306.  `params` holds the
        connection string parameters documented on `MySQLConnector`;
        unsupported ones are reported through `log` when one is given.
        """
        if path.startswith('/'):
            path = path[1:]
        if password is None:
            password = ''
        if port is None:
            port = 3306
        opts = {}
        for name, value in (params or {}).iteritems():
            if name in ('init_command', 'read_default_file',
                        'read_default_group', 'unix_socket'):
                opts[name] = value
            elif name in ('compress', 'named_pipe'):
                opts[name] = as_int(value, 0)
            elif log:
                # `self.log` is not set up before ConnectionWrapper.__init__()
                # runs at the end of this constructor, so use the `log`
                # argument directly (it may be None).
                log.warning("Invalid connection string parameter '%s'", name)
        cnx = MySQLdb.connect(db=path, user=user, passwd=password, host=host,
                              port=port, charset='utf8', **opts)
        if hasattr(cnx, 'encoders'):
            # 'encoders' undocumented but present since 1.2.1 (r422)
            cnx.encoders[Markup] = cnx.encoders[types.UnicodeType]
        cursor = cnx.cursor()
        try:
            cursor.execute("SHOW VARIABLES WHERE "
                           " variable_name='character_set_database'")
            self.charset = cursor.fetchone()[1]
        finally:
            cursor.close()
        if self.charset != 'utf8':
            cnx.query("SET NAMES %s" % self.charset)
            cnx.store_result()
        ConnectionWrapper.__init__(self, cnx, log)
        self._is_closed = False

    def cast(self, column, type):
        """Return SQL casting `column` to `type`; 'int'/'int64' become
        SIGNED and 'text' becomes CHAR, other types are used as given."""
        if type == 'int' or type == 'int64':
            type = 'signed'
        elif type == 'text':
            type = 'char'
        return 'CAST(%s AS %s)' % (column, type)

    def concat(self, *args):
        """Return SQL concatenating the given SQL expressions."""
        return 'concat(%s)' % ', '.join(args)

    def like(self):
        """Return a case-insensitive LIKE clause.

        Uses the `<charset>_general_ci` collation and '/' as the escape
        character (matching `like_escape()`).
        """
        return "LIKE %%s COLLATE %s_general_ci ESCAPE '/'" % self.charset

    def like_escape(self, text):
        """Escape '%', '_' and '/' in `text` with a '/' prefix."""
        return _like_escape_re.sub(r'/\1', text)

    def quote(self, identifier):
        """Return the quoted identifier."""
        return "`%s`" % identifier.replace('`', '``')

    def get_last_id(self, cursor, table, column='id'):
        """Return the id generated by the last INSERT on `cursor`."""
        return cursor.lastrowid

    def update_sequence(self, cursor, table, column='id'):
        # MySQL handles sequence updates automagically
        pass

    def rollback(self):
        """Roll back the current transaction, marking the connection as
        closed if MySQLdb reports a ProgrammingError."""
        # NOTE(review): ping() presumably verifies the server connection
        # before the rollback is attempted -- confirm against MySQLdb docs.
        self.cnx.ping()
        try:
            self.cnx.rollback()
        except MySQLdb.ProgrammingError:
            self._is_closed = True

    def close(self):
        """Close the underlying connection once; later calls are no-ops."""
        if not self._is_closed:
            try:
                self.cnx.close()
            except MySQLdb.ProgrammingError:
                pass # this error would mean it's already closed. So, ignore
            self._is_closed = True

    def cursor(self):
        """Return an iterable cursor that decodes UTF-8 row values."""
        return IterableCursor(MySQLUnicodeCursor(self.cnx), self.log)