blob: daef17d290bac6bb012b0d228bf72482c62933e8 [file] [log] [blame]
#!/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
util模块,提供一些辅助函数
Date: 2014/08/05 17:19:26
"""
import datetime
import os
import sys
import struct
import random
import inspect
import time
import subprocess
import pexpect
from decimal import Decimal
import hashlib
import functools
import palo_logger
# Module-wide logger used by the helpers in this file.
LOG = palo_logger.Logger.getLogger()
# Short alias used throughout this file as L('message', key=value, ...) to
# build the message objects passed to LOG.info / LOG.error.
L = palo_logger.StructedLogMessage
def pretty(data):
    """
    Convert *data* into a printable string.

    Dicts are rendered as ``{key:value, key:value}`` and lists as
    ``[item, item]``, with values formatted recursively. Any other value is
    rendered with ``str()``, so list elements must implement ``__str__``.
    """
    if isinstance(data, dict):
        # dict.items(): Python 3 dicts have no iteritems()
        return '{' + ', '.join('%s:%s' % (str(key), pretty(value))
                               for key, value in data.items()) + '}'
    if isinstance(data, list):
        # join (rather than append-then-rstrip) keeps a single opening bracket
        # and never strips trailing characters of the last element itself
        return '[' + ', '.join(pretty(val) for val in data) + ']'
    return str(data)
def gen_name_list(prefix=""):
    """
    Build (database_name, table_name, index_name) for the running test case.

    The names combine *prefix*, the base name of the last ``.py`` argument in
    ``sys.argv`` (with its last three characters, i.e. ``.py``, removed) and
    the name of the calling function, followed by ``db`` / ``tb`` / ``index``.
    A name longer than 60 characters is replaced by a one-letter tag
    (``d`` / ``t`` / ``i``), the last 4 hex digits of its md5, ``_`` and its
    last 50 characters. Leading underscores (left by an empty prefix) are
    stripped.

    NOTE(review): when no argument ends in ``.py`` the current directory's
    name is used instead, still with three characters cut off.
    """
    script = ""
    for arg in sys.argv:
        if arg.endswith(".py"):
            script = arg
    base = os.path.basename(os.path.abspath(script))[:-3]
    # frame 1 is the caller of gen_name_list
    caller = inspect.stack()[1][3]
    LOG.info(L('test case', file_name=base, case=caller))

    def _build(suffix, tag):
        # Assemble one name and shorten it when it exceeds 60 characters.
        raw = "%s_%s_%s_%s" % (prefix, base, caller, suffix)
        if len(raw) > 60:
            raw = tag + get_md5(raw)[-4:] + '_' + raw[-50:]
        return raw.lstrip("_")

    return _build('db', 'd'), _build('tb', 't'), _build('index', 'i')
def gen_num_format_name_list(prefix=""):
    """
    Build [database_name, table_name, index_name] for the running test case.

    Each name is ``<prefix>_<script>_<caller>_<db|tb|index>``. Here ``<script>``
    is the base name of the last ``.py`` argument in ``sys.argv`` with its last
    three characters removed, and ``<caller>`` is the calling function's name.
    A name longer than 60 characters becomes ``'s'`` + the last 4 hex digits of
    its md5 + ``'_'`` + its last 50 characters. This keeps names from different
    cases distinct while still starting with a letter, which Palo requires for
    database and table names. Leading underscores are stripped.
    """
    script = ""
    for arg in sys.argv:
        if arg.endswith(".py"):
            script = arg
    base = os.path.basename(os.path.abspath(script))[:-3]
    # resolved here (not inside the comprehension) so frame 1 is our caller
    caller = inspect.stack()[1][3]

    def _shorten(name):
        # Palo database/table names must start with a letter, hence the 's' tag.
        if len(name) > 60:
            name = 's' + get_md5(name)[-4:] + '_' + name[-50:]
        return name.lstrip("_")

    return [_shorten("%s_%s_%s_%s" % (prefix, base, caller, suffix))
            for suffix in ('db', 'tb', 'index')]
def get_label():
    """
    Return a load label of the form ``label_<dd_HH_MM_SS_ffffff>_<random int>``.

    The random part is drawn from ``[0, 2**31 - 1]``.
    """
    stamp = datetime.datetime.now().strftime('%d_%H_%M_%S_%f')
    salt = random.randint(0, 2 ** 31 - 1)
    return "label_%s_%d" % (stamp, salt)
def get_snapshot_label(prefix=None):
    """
    Return a snapshot label ``<prefix>_snapshot_<dd_HH_MM_SS_ffffff>_<random int>``.

    *prefix* defaults to ``'random'`` when not given.
    """
    head = 'random' if prefix is None else prefix
    stamp = datetime.datetime.now().strftime('%d_%H_%M_%S_%f')
    return "%s_snapshot_%s_%d" % (head, stamp, random.randint(0, 2 ** 31 - 1))
def column_to_sql(column, set_null=False):
    """
    Render a column tuple as a Palo column-definition SQL fragment.

    *column* is ``(column_name, column_type[, aggregation_type[, default_value]])``.
    A falsy aggregation_type is left out; key columns that need a default pass
    None there. ``set_null=False`` appends ``NOT NULL``, ``True`` appends
    ``NULL``, and any other value appends neither. A default of None renders
    as ``DEFAULT NULL``; any other default renders as ``DEFAULT "<value>"``.
    """
    pieces = ["%s" % (column[0],), "%s" % (column[1],)]
    # value columns carry an aggregation method
    if len(column) > 2 and column[2]:
        pieces.append("%s" % (column[2],))
    if set_null is False:
        pieces.append('NOT NULL')
    elif set_null is True:
        pieces.append('NULL')
    # columns with a default value
    if len(column) > 3:
        if column[3] is None:
            pieces.append('DEFAULT NULL')
        else:
            pieces.append('DEFAULT "%s"' % (column[3],))
    return ' '.join(pieces)
def column_to_no_agg_sql(column, set_null=False):
    """
    Render a column tuple as a column-definition SQL fragment without aggregation.

    *column* is ``(column_name, column_type[, aggregation_type[, default_value]])``.
    The aggregation_type slot is ignored. ``set_null=False`` appends
    ``NOT NULL``, ``True`` appends ``NULL``, and any other value appends
    neither. A default of None renders as ``DEFAULT NULL``, matching
    column_to_sql, rather than the string literal ``"None"``. Any other
    default renders as ``DEFAULT "<value>"``.
    """
    sql = "%s %s" % (column[0], column[1])
    if set_null is False:
        sql += ' NOT NULL'
    elif set_null is True:
        sql += ' NULL'
    # columns with a default value
    if len(column) > 3:
        if column[3] is None:
            sql += ' DEFAULT NULL'
        else:
            sql += ' DEFAULT "%s"' % (column[3],)
    return sql
def convert_agg_column_to_no_agg_column(column_list):
    """
    Strip the aggregation type from a list of column tuples.

    Input tuples are ``(column_name, column_type[, aggregation_type[, default_value]])``:
      * 2-tuples are kept as-is;
      * 3-tuples become ``(column_name, column_type)``;
      * 4-tuples become ``(column_name, column_type, "", default_value)``.
    Tuples of any other length are dropped.
    """
    converted = []
    for col in column_list:
        width = len(col)
        if width == 4:
            converted.append((col[0], col[1], "", col[3]))
        elif width == 3:
            converted.append((col[0], col[1]))
        elif width == 2:
            converted.append(col)
    return converted
def file_to_insert_sql_value(file_name, to_str=False):
    """
    Convert a tab-separated data file into the VALUES part of an INSERT statement.

    Each line becomes ``(v1,v2,...)`` and the rows are joined with ``,``.
    ``\\N`` fields become ``NULL``, because broker load writes NULL as ``\\N``
    while INSERT needs the keyword. Numeric fields are emitted unquoted unless
    *to_str* is True; everything else is wrapped in double quotes.

    :param file_name: path of the tab-separated file
    :param to_str: when True, quote numeric fields too (default False)
    :return: the comma-joined row tuples as one string ('' for an empty file)
    """
    rows = []
    # `with` guarantees the file is closed (it used to leak)
    with open(file_name, 'r') as fp:
        for line in fp:
            fields = line.split('\n')[0].split('\t')
            values = []
            for item in fields:
                if not to_str and is_number(item):
                    values.append(item)
                elif item == '\\N':
                    values.append('NULL')
                else:
                    values.append('"' + item + '"')
            rows.append('(' + ','.join(values) + ')')
    return ','.join(rows)
def is_number(s):
    """
    Return True if *s* is an int/float value that can be written as an unquoted SQL number.

    ``float()`` also accepts ``nan``, ``inf`` and ``infinity`` (any case,
    optional sign) and digit-group underscores such as ``1_000``. None of
    these are SQL numeric literals, so such strings are rejected. Values that
    ``float()`` cannot convert, including None, return False instead of
    raising.

    :param s: value to test, normally a str
    :return: bool
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    if isinstance(s, str):
        if '_' in s:
            return False
        return s.strip().lstrip('+-').lower() not in ('nan', 'inf', 'infinity')
    return True
def exec_cmd(cmd, user=None, password=None, host=None, timeout=30):
    """
    Run a shell command and return ``(status, output)``.

    When *user*, *password* and *host* are all given, the command runs on the
    remote host as ``ssh user@host '<cmd>'``. pexpect drives the session,
    answering "yes" to the host-key prompt and *password* to the password
    prompt, and the exchange is logged. Otherwise the command runs locally
    through ``subprocess.getstatusoutput``.

    NOTE(review): the remote branch returns whatever pexpect.run produces
    (presumably bytes), while the local branch returns str — confirm callers
    handle both.

    :raises Exception: if a remote command contains a single quote, since it
        is wrapped in single quotes for ssh.
    """
    remote = all(arg is not None for arg in (user, password, host))
    if not remote:
        return subprocess.getstatusoutput(cmd)
    if "'" in cmd:
        raise Exception("We do not support quote ' now!")
    prompts = {"continue connecting": "yes\n", "password:": "%s\n" % password}
    output, status = pexpect.run("ssh %s@%s '%s'" % (user, host, cmd),
                                 timeout=timeout,
                                 withexitstatus=True,
                                 events=prompts)
    LOG.info(L('exec remote cmd', cmd=cmd, output=output, status=status))
    return status, output
def compare(a, b):
    """
    Three-way compare two rows, ordering None before any other value.

    Intended for ``functools.cmp_to_key``. Cells are compared pairwise: two
    Nones are equal, and a single None sorts first. If all shared cells are
    equal, the shorter row sorts first, as with tuple comparison; the old
    code raised IndexError when *a* was longer than *b*.

    :return: -1, 0 or 1
    :raises AssertionError: if either argument is not a tuple or list
    """
    assert isinstance(a, (tuple, list))
    assert isinstance(b, (tuple, list))
    for x, y in zip(a, b):
        if x is None or y is None:
            if x is None and y is None:
                continue
            return -1 if x is None else 1
        if x > y:
            return 1
        if x < y:
            return -1
    return (len(a) > len(b)) - (len(a) < len(b))
def check(palo_result, mysql_result, force_order=False):
    """
    Check that the palo result matches the expected (mysql) result.

    Verifies, in order: the row count, the column count of each row, then
    every cell. A cell that differs is still accepted when:
      * both are str/bytes and are equal after decoding bytes as UTF-8;
      * the palo value is a str equal to ``str(expected)`` (e.g. Chinese text
        or LARGEINT returned as str);
      * the palo value is a blank string, or None, and the expected value is "";
      * both are Decimal/float and pass check_float;
      * the expected value is a list and passes check_list.

    Every cell of every row is checked. Previously a str/bytes cell that
    matched after decoding ended the whole check early. Failures raise
    AssertionError explicitly, so the check still runs under ``python -O``.

    :param palo_result: rows returned by palo
    :param mysql_result: expected rows
    :param force_order: sort both results before comparing. Rows the default
        ordering cannot handle (TypeError, e.g. None cells) are sorted with
        compare() instead.
    :raises AssertionError: on the first mismatch, after logging it
    """
    if force_order and palo_result != ():
        try:
            palo_result = list(palo_result)
            mysql_result = list(mysql_result)
            palo_result.sort()
            mysql_result.sort()
        except TypeError:
            # e.g. None mixed with values: compare() orders None first
            palo_result.sort(key=functools.cmp_to_key(compare))
            mysql_result.sort(key=functools.cmp_to_key(compare))
        except Exception:  # pylint: disable=broad-except
            # NOTE(review): other sort failures are ignored and the rows are
            # compared in their current order (kept from the original code).
            pass
    if mysql_result == palo_result:
        return
    if len(palo_result) != len(mysql_result):
        LOG.error(L('check error', palo_length=len(palo_result),
                    expect_length=len(mysql_result)))
        raise AssertionError("\npalo_result length: %d\nmysql_result length:%d"
                             % (len(palo_result), len(mysql_result)))
    for palo_line, mysql_line in zip(palo_result, mysql_result):
        if len(palo_line) != len(mysql_line):
            LOG.error(L('check error, number of columns not match',
                        palo_colums_number=len(palo_line),
                        expect_columns_number=len(mysql_line)))
            LOG.error(L('check error', palo_line=str(palo_line)))
            LOG.error(L('check error', expect_line=str(mysql_line)))
            raise AssertionError("\npalo line: %s\nmysql line:%s"
                                 % (str(palo_line), str(mysql_line)))
        if palo_line == mysql_line:
            continue
        for palo_data, mysql_data in zip(palo_line, mysql_line):
            if palo_data != mysql_data:
                # str vs bytes: compare after decoding bytes as UTF-8
                if isinstance(palo_data, (str, bytes)) and isinstance(mysql_data, (str, bytes)):
                    if isinstance(palo_data, bytes):
                        palo_data = str(palo_data, "utf8")
                    if isinstance(mysql_data, bytes):
                        mysql_data = str(mysql_data, "utf8")
                    if palo_data == mysql_data:
                        # only this cell is settled; keep checking the rest
                        continue
                same = False
                if isinstance(palo_data, str):
                    # Chinese text & palo LARGEINT come back as str
                    if palo_data == str(mysql_data):
                        same = True
                    # blank string vs empty string
                    if palo_data.strip() == "" and mysql_data == "":
                        same = True
                if palo_data == mysql_data:
                    same = True
                # None vs empty string
                if palo_data is None and mysql_data == "":
                    same = True
                # float / Decimal: approximate comparison
                elif (isinstance(mysql_data, (Decimal, float))
                      and isinstance(palo_data, (Decimal, float))):
                    same = check_float(float(palo_data), float(mysql_data))
                # list: element-wise comparison
                elif isinstance(mysql_data, list):
                    same = check_list(palo_data, mysql_data)
                if not same:
                    LOG.error(L('check error', palo_data=palo_data))
                    LOG.error(L('check error', expect_data=mysql_data))
                    LOG.error(L('check error', palo_line=palo_line))
                    LOG.error(L('check error', expect_line=mysql_line))
                    raise AssertionError(
                        "\ndiff data: \npalo:%s; \nexpect:%s;\npalo line:%s\nexpect line: %s"
                        % (palo_data, mysql_data, palo_line, mysql_line))
def check_float(data1, data2):
    """
    Return True if two numbers are approximately equal.

    Both values are converted with ``float()``. They match when they are
    equal, when they differ by less than 0.001 in absolute terms, or when
    their ratio is within 0.001 of 1 (relative tolerance for large
    magnitudes). The ratio test is skipped when the converted *data2* is zero.

    :param data1: actual value (anything float() accepts)
    :param data2: expected value (anything float() accepts)
    :return: bool
    """
    actual = float(data1)
    expected = float(data2)
    if actual == expected:
        return True
    if abs(actual - expected) < 0.001:
        return True
    # Guard on the converted value: data2 may be e.g. the string "0.0", which
    # is != 0 but converts to 0.0 and would otherwise divide by zero.
    return expected != 0 and abs(actual / expected - 1.0) < 0.001
def check_list(data1, data2):
    """
    Compare an actual list (*data1*) with an expected list (*data2*).

    None matches only None. If the expected list starts with a float, the
    lists must have the same length. They are then compared element-wise with
    check_float, and None elements must line up with None. Otherwise plain
    ``==`` is used.

    :return: True when the lists match
    """
    if data1 is None or data2 is None:
        return data1 is None and data2 is None
    if len(data2) > 0 and isinstance(data2[0], float):
        # zip() would silently ignore extra elements on either side
        if len(data1) != len(data2):
            return False
        for d1, d2 in zip(data1, data2):
            if d1 is None or d2 is None:
                # float(None) would raise; None only matches None
                if d1 is not d2:
                    return False
            elif not check_float(d1, d2):
                return False
        return True
    return data1 == data2
def convert_dict2property(properties):
    """
    Render a dict as a parenthesised SQL property list.

    Each entry whose value is not None becomes `` "key" = "value"``. Entries
    are separated by ``, `` and the list is wrapped in parentheses. A dict
    that is empty, or has only None values, yields ``()``.
    """
    entries = [' "%s" = "%s"' % (key, value)
               for key, value in properties.items()
               if value is not None]
    return '(' + ', '.join(entries) + ')'
def get_timestamp(palo_data):
    """
    Convert the first cell of a result set (a datetime) to a timestamp.

    The datetime's ``timetuple()`` is passed to ``time.mktime``, so it is
    interpreted as local time.
    """
    first_cell = palo_data[0][0]
    return time.mktime(first_cell.timetuple())
def check2_time_zone(palo_re, mysql_res, gap=2):
    """
    Compare the first datetime cell of two result sets as timestamps.

    Passes when the two timestamps differ by at most ``gap + 2`` seconds. The
    extra 2 seconds are a fixed slack added on top of *gap*
    (NOTE(review): presumably to absorb the delay between the two queries;
    confirm). The failure message reports that effective tolerance.

    :raises AssertionError: when the difference is larger, even under ``python -O``
    """
    palo_data = get_timestamp(palo_re)
    mysql_data = get_timestamp(mysql_res)
    allowed = 2 + gap
    if not abs(palo_data - mysql_data) <= allowed:
        raise AssertionError("palo_res %s, mysql_res %s, expected gap <= %s"
                             % (palo_data, mysql_data, allowed))
def assert_return(expected_flag, expected_msg, func, *args, **kwargs):
    """
    Assert how ``func(*args, **kwargs)`` turns out.

    If the call succeeds, *expected_flag* must be truthy. If it raises,
    *expected_flag* must be falsy and *expected_msg* must be a substring of
    the exception text. Before asserting, the exception text and
    *expected_msg* are printed and the error is logged.

    :param expected_flag: True if the call is expected to succeed
    :param expected_msg: text expected in the error message on failure
    :param func: callable under test
    :param args: positional arguments for *func*
    :param kwargs: keyword arguments for *func*
    """
    try:
        func(*args, **kwargs)
    except Exception as err:
        detail = str(err)
        print(detail)
        print(expected_msg)
        LOG.info(L('get an error', msg=detail))
        assert not expected_flag, "real sql status is False, expect %s" % expected_flag
        assert expected_msg in detail, "expect:%s, doris:%s" % (expected_msg, detail)
        return
    assert expected_flag, "real sql status is True, expect %s" % expected_flag
def get_md5(target):
    """
    Return the salted md5 hex digest of *target*.

    The digest covers a fixed salt followed by *target* encoded as UTF-8.

    :param target: str to hash
    :return: 32-character lowercase hex string
    """
    salt = b"fkldsajlkfjlaksdjfkladsjfkladsjkldsjfklfjs"
    return hashlib.md5(salt + target.encode("utf-8")).hexdigest()
def assert_return_flag(expected_flag, func, *args, **kwargs):
    """
    Assert that ``func(*args, **kwargs)`` returns a value equal to *expected_flag*.

    :param expected_flag: expected return value
    :param func: callable under test
    :param args: positional arguments for *func*
    :param kwargs: keyword arguments for *func*
    """
    actual = func(*args, **kwargs)
    assert expected_flag == actual
def bitmap_index_to_sql(bitmap_index, set_null=False):
    """
    Render an index definition as ``INDEX <name> (<column>) USING <type>``.

    *bitmap_index* is ``(index_name, column_name, index_type)``. Only the first
    three items are used. *set_null* is accepted but ignored.
    """
    index_name = bitmap_index[0]
    column_name = bitmap_index[1]
    index_type = bitmap_index[2]
    return "INDEX %s (%s) USING %s" % (index_name, column_name, index_type)
def get_attr(ret, column_idx):
    """
    Return column *column_idx* of every row in *ret* as a list, and log it.

    Example::

        ret = client.show_backend()
        be_ip = get_attr(ret, palo_job.BackendShowInfo.IP)
    """
    column = [record[column_idx] for record in ret]
    LOG.info(L('get column from ret', idx=column_idx, ret=column))
    return column
def get_attr_condition_value(ret, condition_col_idx, condition_value, retrun_col_idx=None):
    """
    Find the first row of *ret* whose column *condition_col_idx* equals *condition_value*.

    Returns that row's column *retrun_col_idx* (default: *condition_col_idx*),
    or None when no row matches. Either outcome is logged. Only the first
    match is returned.
    """
    value_idx = condition_col_idx if retrun_col_idx is None else retrun_col_idx
    hit = next((record for record in ret
                if record[condition_col_idx] == condition_value), None)
    if hit is None:
        LOG.info(L('can not get result from ret', searched_key=condition_value))
        return None
    value = hit[value_idx]
    LOG.info(L('get result from ret', searced_key=condition_value, value=value))
    return value
def get_attr_condition_list(ret, condition_col_idx, condition_value, retrun_col_idx=None):
    """
    Collect column *retrun_col_idx* from every row of *ret* whose column
    *condition_col_idx* equals *condition_value*.

    *retrun_col_idx* defaults to *condition_col_idx*. Returns None when no row
    matches; otherwise the matches are logged and returned as a list.
    """
    if retrun_col_idx is None:
        retrun_col_idx = condition_col_idx
    matches = [record[retrun_col_idx] for record in ret
               if record[condition_col_idx] == condition_value]
    if not matches:
        return None
    LOG.info(L('get all result from ret', searced_key=condition_value, value=matches))
    return matches
def gen_tuple_num_str(begin, end):
    """Return the integers in ``range(begin, end)`` as a tuple of strings, e.g. (1, 3) -> ('1', '2')."""
    return tuple(str(num) for num in range(begin, end))
def get_string_md5(st):
    """Return the plain (unsalted) md5 hex digest of *st* encoded as UTF-8."""
    digest = hashlib.md5(st.encode(encoding='utf-8'))
    return digest.hexdigest()