blob: 6a474ce5fb755225bfc62fede5b3f3aebb8a4832 [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""
TODO: module docs
"""
import sys
import os
import stat
try:
from pygresql import pgdb
from gppylib.commands.unix import UserId
except ImportError, e:
sys.exit('Error: unable to import module: ' + str(e))
from gppylib import gplog
logger = gplog.get_default_logger()
class ConnectionError(StandardError): pass
class Pgpass():
""" Class for handling .pgpass file.
"""
entries = []
valid_pgpass = True
def __init__(self):
HOME = os.getenv('HOME')
PGPASSFILE = os.getenv('PGPASSFILE', '%s/.pgpass' % HOME)
if not os.path.exists(PGPASSFILE):
return
st_info = os.stat(PGPASSFILE)
mode = str(oct(st_info[stat.ST_MODE] & 0777))
if mode != "0600":
print 'WARNING: password file "%s" has group or world access; permissions should be u=rw (0600) or less' % PGPASSFILE
self.valid_pgpass = False
return
try:
fp = open(PGPASSFILE, 'r')
try:
lineno = 1
for line in fp:
line = line.strip()
if line.startswith('#'):
continue
try:
(hostname, port, database, username, password) = line.strip().split(':')
entry = {'hostname': hostname,
'port': port,
'database': database,
'username': username,
'password': password }
self.entries.append(entry)
except:
print 'Invalid line in .pgpass file. Line number %d' % lineno
lineno += 1
except IOError:
pass
finally:
if fp: fp.close()
except OSError:
pass
def get_password(self, username, hostname, port, database):
for entry in self.entries:
if ((entry['hostname'] == hostname or entry['hostname'] == '*') and
(entry['port'] == str(port) or entry['port'] == '*') and
(entry['database'] == database or entry['database'] == '*') and
(entry['username'] == username or entry['username'] == '*')):
return entry['password']
return None
def pgpass_valid(self):
return self.valid_pgpass
class DbURL:
""" DbURL is used to store all of the data required to get at a PG
or GP database.
"""
pghost='foo'
pgport=5432
pgdb='template1'
pguser='username'
pgpass='pass'
timeout=None
retries=None
def __init__(self,hostname=None,port=0,dbname=None,username=None,password=None,timeout=None,retries=None):
if hostname is None:
self.pghost = os.environ.get('PGHOST', 'localhost')
else:
self.pghost = hostname
if port is 0:
self.pgport = int(os.environ.get('PGPORT', '5432'))
else:
self.pgport = int(port)
if dbname is None:
self.pgdb = os.environ.get('PGDATABASE', 'template1')
else:
self.pgdb = dbname
if username is None:
self.pguser = os.environ.get('PGUSER', os.environ.get('USER', UserId.local('Get uid')))
if self.pguser is None or self.pguser == '':
raise Exception('Both $PGUSER and $USER env variables are not set!')
else:
self.pguser = username
if password is None:
pgpass = Pgpass()
if pgpass.pgpass_valid():
password = pgpass.get_password(self.pguser, self.pghost, self.pgport, self.pgdb)
if password:
self.pgpass = password
else:
self.pgpass = os.environ.get('PGPASSWORD', None)
else:
self.pgpass = password
if timeout is not None:
self.timeout = int(timeout)
if retries is None:
self.retries = 1
else:
self.retries = int(retries)
def __str__(self):
# MPP-13617
def canonicalize(s):
if ':' not in s: return s
return '[' + s + ']'
return "%s:%d:%s:%s:%s" % \
(canonicalize(self.pghost),self.pgport,self.pgdb,self.pguser,self.pgpass)
def connect(dburl, utility=False, verbose=False, encoding=None, allowSystemTableMods=None, upgrade=False):
if utility:
options = '-c gp_session_role=utility'
else:
options = ''
# MPP-13779, et al
if allowSystemTableMods in ['dml', 'ddl', 'all']:
options += ' -c allow_system_table_mods=' + allowSystemTableMods
elif allowSystemTableMods is not None:
raise Exception('allowSystemTableMods invalid: %s' % allowSystemTableMods)
# gpmigrator needs gpstart to make master connection in maintenance mode
if upgrade:
options += ' -c gp_maintenance_conn=true'
# bypass pgdb.connect() and instead call pgdb._connect_
# to avoid silly issues with : in ipv6 address names and the url string
#
dbbase = dburl.pgdb
dbhost = dburl.pghost
dbport = int(dburl.pgport)
dbopt = options
dbtty = "1"
dbuser = dburl.pguser
dbpasswd = dburl.pgpass
timeout = dburl.timeout
cnx = None
# MPP-14121, use specified connection timeout
#
if timeout is not None:
cstr = "dbname=%s connect_timeout=%s" % (dbbase, timeout)
retries = dburl.retries
else:
cstr = "dbname=%s" % dbbase
retries = 1
(logger.info if timeout is not None else logger.debug)("Connecting to %s" % cstr)
for i in range(retries):
try:
cnx = pgdb._connect_(cstr, dbhost, dbport, dbopt, dbtty, dbuser, dbpasswd)
break
except pgdb.InternalError, e:
if 'timeout expired' in str(e):
logger.warning('Timeout expired connecting to %s, attempt %d/%d' % (dbbase, i+1, retries))
continue
raise
if cnx is None:
raise ConnectionError('Failed to connect to %s' % dbbase)
conn = pgdb.pgdbCnx(cnx)
#by default, libpq will print WARNINGS to stdout
if not verbose:
cursor=conn.cursor()
cursor.execute("SET CLIENT_MIN_MESSAGES='ERROR'")
conn.commit()
cursor.close()
# set client encoding if needed
if encoding:
cursor=conn.cursor()
cursor.execute("SET CLIENT_ENCODING='%s'" % encoding)
conn.commit()
cursor.close()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
conn.__class__.__enter__, conn.__class__.__exit__ = __enter__, __exit__
return conn
def execSQL(conn,sql):
"""
If necessary, user must invoke conn.commit().
Do *NOT* violate that API here without considering
the existing callers of this function.
"""
cursor=conn.cursor()
cursor.execute(sql)
return cursor
def execSQLForSingletonRow(conn, sql):
"""
Run SQL that returns exactly one row, and return that one row
TODO: Handle like gppylib.system.comfigurationImplGpdb.fetchSingleOutputRow().
In the event of the wrong number of rows/columns, some logging would be helpful...
"""
cursor=conn.cursor()
cursor.execute(sql)
if cursor.rowcount != 1 :
raise UnexpectedRowsError(1, cursor.rowcount, sql)
res = cursor.fetchall()[0]
cursor.close()
return res
class UnexpectedRowsError(Exception):
def __init__(self, expected, actual, sql):
self.expected, self.actual, self.sql = expected, actual, sql
Exception.__init__(self, "SQL retrieved %d rows but %d was expected:\n%s" % \
(self.actual, self.expected, self.sql))
def execSQLForSingleton(conn, sql):
"""
Run SQL that returns exactly one row and one column, and return that cell
TODO: Handle like gppylib.system.comfigurationImplGpdb.fetchSingleOutputRow().
In the event of the wrong number of rows/columns, some logging would be helpful...
"""
row = execSQLForSingletonRow(conn, sql)
if len(row) > 1:
raise Exception("SQL retrieved %d columns but 1 was expected:\n%s" % \
(len(row), sql))
return row[0]
def executeUpdateOrInsert(conn, sql, expectedRowUpdatesOrInserts):
cursor=conn.cursor()
cursor.execute(sql)
if cursor.rowcount != expectedRowUpdatesOrInserts :
raise Exception("SQL affected %s rows but %s were expected:\n%s" % \
(cursor.rowcount, expectedRowUpdatesOrInserts, sql))
return cursor