# blob: 7b2d0ed72fcc65c5248b71f9a27435afb8330494 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from Crypto.Cipher import AES
import base64
import time
import socket
import string
import hashlib
import shutil
import os
from log import LogFactory
# Block size in bytes used by pad(); 16 is the AES block size.
BS = 16
# Module-level logger obtained from the project's LogFactory, shared by every function in this module.
log = LogFactory().get_log(__name__)
def unpad(s):
    """
    Removes PKCS#7-style padding: the value of the last character/byte is the number of
    trailing padding units to strip.
    :param s: Padded text (str) or bytes
    :return: s without its trailing padding. Empty input, and input whose pad length is 0 or
             longer than s, is returned unchanged instead of being truncated to nothing.
    """
    if not s:
        return s
    last = s[-1]
    # Indexing bytes yields an int on Python 3, but a 1-character str on Python 2 / for text
    pad_len = last if isinstance(last, int) else ord(last)
    if pad_len == 0 or pad_len > len(s):
        # Not valid padding: s[:-0] would be "" and silently discard all of the data
        return s
    return s[:-pad_len]
def current_milli_time():
    """
    Returns the current wall-clock time as whole milliseconds since the Unix epoch.
    :rtype: int
    """
    elapsed_ms = time.time() * 1000
    return int(round(elapsed_ms))
def pad(s, block_size=None):
    """
    Appends PKCS#7-style padding so the length of s becomes a multiple of block_size.
    Between 1 and block_size padding units are always added, each with value equal to the pad length.
    :param s: Text (str) or bytes to pad
    :param int block_size: Block size to pad to; defaults to the module-level BS (16, the AES block size)
    :return: The padded value, of the same type as s
    """
    if block_size is None:
        block_size = BS
    pad_len = block_size - len(s) % block_size
    if isinstance(s, bytes) and not isinstance(s, str):
        # Python 3 bytes: str + bytes concatenation is not allowed, so build byte padding
        return s + bytes(bytearray([pad_len] * pad_len))
    return s + chr(pad_len) * pad_len
def _strip_pkcs7_padding(text, block_size=16):
    """
    Removes PKCS#7 padding from the end of text if, and only if, the padding is well formed:
    n copies of chr(n) with 1 <= n <= block_size. Any other text is returned unchanged.
    :param str text: Decrypted plain text
    :param int block_size: Cipher block size in bytes, 16 for AES
    :return: The text without its padding
    :rtype: str
    """
    if not text:
        return text
    pad_len = ord(text[-1])
    if 1 <= pad_len <= min(block_size, len(text)) and text[-pad_len:] == text[-1] * pad_len:
        return text[:-pad_len]
    return text


def decrypt_password(pass_str, secret):
    """
    Decrypts the given password using the given secret. The encryption is assumed to be done
    without IV, in AES (ECB mode).
    :param str pass_str: Encrypted password string in Base64 encoding
    :param str secret: The secret string
    :return: The decrypted password with padding and non-printable characters removed. Returns None
             if pass_str is None, and an empty string if pass_str is blank or decryption fails.
    :rtype: str
    """
    if pass_str is None:
        return None
    if pass_str.strip() == "":
        return ""
    dec_pass = ""
    try:
        log.debug("Decrypting cipher text: %s" % pass_str)
        bdecoded_pass = base64.b64decode(pass_str.strip())
        key = secret.strip()
        if not isinstance(key, bytes):
            # Python 3 str (or Python 2 unicode): the cipher expects a byte-string key
            key = key.encode("utf-8")
        # AES keys must be 16, 24 or 32 bytes long
        cipher = AES.new(key, AES.MODE_ECB)
        dec_pass = cipher.decrypt(bdecoded_pass)
    except Exception:
        log.exception("Exception occurred while decrypting password")
    if not isinstance(dec_pass, str):
        # Python 3: decrypt() returns bytes; latin-1 maps every byte to exactly one character
        dec_pass = dec_pass.decode("latin-1")
    # Pad bytes 9-13 are whitespace contained in string.printable, so the filter below would keep
    # them; strip well-formed PKCS#7 padding explicitly first.
    dec_pass = _strip_pkcs7_padding(dec_pass)
    # Remove any remaining non-printable characters (e.g. zero padding)
    return "".join(c for c in dec_pass if c in string.printable)
def wait_until_ports_active(ip_address, ports, ports_check_timeout=600000):
    """
    Blocks until every port in the given list accepts connections on the given host, or until the
    timeout is exceeded. Ports that are still inactive are re-checked every 5 seconds.
    :param str ip_address: Ip address of the member to be checked
    :param list[str] ports: List of ports to be checked; the caller's list is not modified
    :param int ports_check_timeout: The timeout in milliseconds, defaults to 1000*60*10
    :return: True if all ports became active, False if the timeout was reached first
    :rtype: bool
    """
    if ports_check_timeout is None:
        ports_check_timeout = 1000 * 60 * 10
    # Coerce to int: on Python 2 an int duration compared with a str timeout is never greater
    ports_check_timeout = int(ports_check_timeout)
    log.debug("Port check timeout: %s" % ports_check_timeout)
    # Work on a copy so the caller's list is not emptied as ports become active
    ports_left = list(ports)
    start_time = current_milli_time()
    # check ports until all are active or timeout exceeds
    while True:
        log.info("Waiting for ports to be active: [ip] %s [ports] %s" % (ip_address, ports_left))
        # check each remaining port; iterate over a snapshot because active ports are removed
        for checking_port in list(ports_left):
            if check_port_active(ip_address, checking_port):
                log.debug("Port %s on host %s active" % (checking_port, ip_address))
                ports_left.remove(checking_port)
        # if no ports are left to check for activity, return
        if not ports_left:
            log.info("Ports activated: [ip] %r [ports] %r" % (ip_address, ports))
            return True
        duration = current_milli_time() - start_time
        if duration > ports_check_timeout:
            log.info("Port check timeout reached: [ip] %s [ports] %s [timeout] %s"
                     % (ip_address, ports_left, ports_check_timeout))
            return False
        time.sleep(5)
def check_port_active(ip_address, port, timeout=5):
    """
    Checks whether the given port on the given host accepts a TCP connection
    :param str ip_address: Ip address of the member to be checked
    :param str port: The port to be checked (any value int() accepts)
    :param float timeout: Connection timeout in seconds, defaults to 5
    :return: True if the port is active, False otherwise
    :rtype: bool
    :raises RuntimeError: if port is None or cannot be converted to an int
    """
    if port is None:
        raise RuntimeError("Cannot check invalid port for activity")
    try:
        port_int = int(port)
    except (ValueError, TypeError):
        raise RuntimeError("Cannot check invalid port for activity %s" % (port,))
    s = socket.socket()
    s.settimeout(timeout)
    try:
        s.connect((ip_address, port_int))
    except socket.error:
        log.debug("Port %s is not active" % (port,))
        return False
    finally:
        # Close on both paths; previously the socket leaked whenever connect() failed
        s.close()
    log.debug("Port %s is active" % (port,))
    return True
class Utils(object):
    """
    Static helpers for file-system operations and console output. File-system failures are
    logged rather than raised.
    """

    # On Python 2, shutil.Error and IOError do not derive from OSError; on Python 3 all are OSError.
    _FS_ERRORS = (OSError, IOError, shutil.Error)

    @staticmethod
    def directory_exists(dir):
        """
        Checks if the given directory exists
        :param str dir: directory to check
        :return: True if the directory dir exists, else False
        :rtype: bool
        """
        try:
            return os.path.isdir(dir)
        except OSError as e:
            log.error("Unable to check directory existance [%s]" % e)
            return False

    @staticmethod
    def copy_directory(src, destination):
        """
        Copies the directory 'src', including its contents, to 'destination'. Failures are logged, not raised.
        :param str src: location of directory to copy
        :param str destination: new directory location (shutil.copytree requires that it does not exist yet)
        """
        try:
            shutil.copytree(src, destination)
            log.debug("Directory [%s] copied to [%s]" % (src, destination))
        except Utils._FS_ERRORS as e:
            log.error('Directory not copied. Error: %s' % e)

    @staticmethod
    def strip_trailing_slash(string):
        """
        If the string has a trailing '/', removes it (only one slash is removed)
        :param str string: string to check
        :return: string without a trailing '/'
        :rtype: str
        """
        if string.endswith('/'):
            return string[:-1]
        return string

    @staticmethod
    def move_directory(src, destination):
        """
        Moves the directory 'src' to 'destination'. Failures are logged, not raised.
        :param str src: location of directory to move
        :param str destination: new directory location
        """
        try:
            shutil.move(src, destination)
            log.debug("Directory [%s] moved to [%s]" % (src, destination))
        except Utils._FS_ERRORS as e:
            log.error('Directory not moved. Error: %s' % e)

    @staticmethod
    def print_dictionary(dic):
        """
        Prints a two-level dictionary to stdout: each outer key on its own line, followed by one
        "inner_key : value" line per entry of the nested dictionary.
        :param dict dic: mapping of outer keys to nested dictionaries
        """
        for outer_key in dic:
            print(outer_key)
            inner = dic[outer_key]
            for inner_key in inner:
                # One pre-formatted string prints identically under the Python 2 print statement and
                # the Python 3 print function; a multi-argument call prints a tuple on Python 2.
                print("%s : %s" % (inner_key, inner[inner_key]))
class IncrementalCeilingListIterator(object):
    """
    Iterates through a given list and returns elements. At the end of the list if terminate_at_end is set to false,
    the last element will be returned repeatedly. If terminate_at_end is set to true, an IndexError will be thrown.
    """

    def __init__(self, intervals, terminate_at_end):
        """
        :param intervals: iterable of intervals; it is copied, so later changes by the caller have no effect
        :param bool terminate_at_end: raise IndexError once the intervals are exhausted instead of
                                      repeating the last one
        """
        self.__intervals = list(intervals)
        self.__index = 0
        self.__terminate_at_end = terminate_at_end

    def get_next_retry_interval(self):
        """
        Retrieves the next element in the list. Once the list is exhausted the last element is
        returned again on every call, unless terminate_at_end is set.
        :return: the next interval
        :rtype: int
        :raises IndexError: when the list is exhausted and terminate_at_end is set, or when the list is empty
        """
        if self.__index < len(self.__intervals):
            next_interval = self.__intervals[self.__index]
            self.__index += 1
            return next_interval
        if self.__terminate_at_end:
            raise IndexError("Reached the end of the list")
        if not self.__intervals:
            raise IndexError("No interval to repeat: the list of intervals is empty")
        return self.__intervals[-1]