#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""palo client verify"""
import petl
import math
from decimal import Decimal
from collections import OrderedDict
import palo_logger
LOG = palo_logger.Logger.getLogger()
L = palo_logger.StructedLogMessage
class VerifyFile(object):
"""
VerifyFile
"""
def __init__(self, file_name, delimiter='\t'):
self.file_name = file_name
self.delimiter = delimiter
def get_file_name(self):
"""
get file name
"""
return self.file_name
def get_delimiter(self):
"""
get delimiter
"""
return self.delimiter
def __str__(self):
return str(self.file_name)
class Verify(object):
"""verify class"""
def __init__(self, expected_file_list, datas, schema, table_name, database_name, encoding=None):
"""
expected_file_list: expected-data file(s); may be a str, a list such as
['file1', 'file2'], or a VerifyFile
datas: the result of executing the SQL statement
schema: for verify(), the table's DESC result is used as the schema;
for verify_by_sql(), a list of 4-tuples (name, type, agg_type,
default_value) describing the query result
table_name, database_name: used to build default names for the saved
verification files
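
Example (illustrative; `client` and its methods are hypothetical):
    schema = client.desc_table('db1', 'tb1')   # DESC result of the table
    ret = client.execute('SELECT * FROM db1.tb1')
    assert Verify('expect.txt', ret, schema, 'tb1', 'db1').verify()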
"""
self.expected_file_list = expected_file_list
self.table_name = table_name
self.database_name = database_name
self.schema = schema
self.datas = datas
self.encoding = encoding
@staticmethod
def __get_type_convert_handler(field_type):
""""""
def __int_type(min_value, max_value):
"""Return a parser that converts a value to int and range-checks it."""
def f(v):
"""Parse v as an int; values outside [min_value, max_value] become None."""
value = int(v)
if min_value <= value <= max_value:
return value
# out-of-range values are treated as NULL downstream
return None
return f
def __char_type():
"""regurn a function"""
def f(v):
"""check v if null"""
if v == "None":
v = None
return v
return f
tinyint = __int_type(-2 ** 7, 2 ** 7 - 1)
smallint = __int_type(-2 ** 15, 2 ** 15 - 1)
paloint = __int_type(-2 ** 31, 2 ** 31 - 1)
bigint = __int_type(-2 ** 63, 2 ** 63 - 1)
largeint = __int_type(-2 ** 127, 2 ** 127 - 1)
datetime = petl.datetimeparser('%Y-%m-%d %H:%M:%S')
date = petl.dateparser('%Y-%m-%d')
char = __char_type()
varchar = __char_type()
# strip generic/precision suffixes, e.g. 'ARRAY<INT>' -> 'array', 'DECIMAL(9,3)' -> 'decimal'
field_type = field_type.lower().split('<')[0].split('(')[0]
field_type_handler_dict = {'char': char, 'varchar': varchar, 'decimal': Decimal,
'tinyint': tinyint, 'smallint': smallint, 'int': paloint,
'bigint': bigint, 'largeint': largeint, 'text': varchar,
'float': float, 'double': float, 'datetime': datetime, 'date': date,
'boolean': tinyint,
'array': varchar, 'decimalv3': Decimal}
return field_type_handler_dict[field_type]
def __get_convert_dict(self):
"""get column type from schema, and get convert func"""
convert_dict = {field[0]: self.__get_type_convert_handler(field[1]) for field in self.schema}
return convert_dict
def __get_field_list(self):
"""get column name from schema"""
field_list = [field[0] for field in self.schema]
return field_list
def __get_key_list(self):
"""get key column from schema"""
key_list = [field[0] for field in self.schema if field[3] == 'true']
return tuple(key_list)
def __get_type_list(self):
"""get column type from schema"""
type_list = [field[1] for field in self.schema]
return type_list
@staticmethod
def __get_aggregate_key(key_list):
"""get key"""
if len(key_list) == 1:
return key_list[0]
else:
return key_list
def __get_aggregation_ordereddict(self):
"""aggregation table value agg func"""
def _sum(l):
"""Sum the non-None values; return None when every input is None."""
items = []
for i in l:
if i is not None:
items.append(i)
if len(items) == 0:
return None
else:
return sum(items)
def __agg_replace(l):
"""REPLACE semantics: keep the last value seen."""
return list(l)[-1]
def __agg_replace_if_not_null(l):
""" replace if not null """
items = []
for i in l:
if i is not None:
items.append(i)
if len(items) == 0:
return None
else:
return items[-1]
agg_function_dict = {'max': max, 'min': min, 'sum': _sum, 'replace': __agg_replace,
'replace_if_not_null': __agg_replace_if_not_null}
aggregation = OrderedDict()
aggtype_list = [(field[0], field[5]) for field in self.schema if field[5] != '']
for item in aggtype_list:
aggregation[item[0]] = item[0], agg_function_dict[item[1].lower()]
return aggregation
def __write_data_to_file(self, data_from_database, data_from_file, save_verifyfile_list):
"""将文件中的数据写入tmp文件中"""
# petl.tocsv defaults encoding to None, so one code path handles both cases
if save_verifyfile_list[0] is not None:
petl.tocsv(data_from_database, save_verifyfile_list[0].get_file_name(),
encoding=self.encoding, delimiter=save_verifyfile_list[0].get_delimiter())
if save_verifyfile_list[1] is not None:
petl.tocsv(data_from_file, save_verifyfile_list[1].get_file_name(),
encoding=self.encoding, delimiter=save_verifyfile_list[1].get_delimiter())
@staticmethod
def __check_float(field_of_database, field_of_file, field_type):
def __adjust_data(num):
if num is None:
return None
else:
num = float(num)
if num == 0.0:
return 0.0
else:
return num / 10 ** (math.floor(math.log10(abs(num))) + 1)
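# __adjust_data normalizes a number to its mantissa, e.g. 123.456 -> 0.123456,
# so the fixed precision thresholds below apply regardless of magnitude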
data_of_database = __adjust_data(field_of_database)
data_of_file = __adjust_data(field_of_file)
# the last significant digit may differ by 1, e.g.: 0.123456001 == 0.123456999 => True
# 0.123456001 == 0.123457999 => True    0.123456001 == 0.123458999 => False
# 0.123456001 == 0.123455999 => True    0.123456001 == 0.123454999 => False
precision = None
if field_type.lower() == 'float':
precision = 2e-6
elif field_type.lower() == 'double':
precision = 2e-15
diff = math.fabs(data_of_database - data_of_file)
if diff < precision:
return True
# fall back to a relative check, guarding against division by zero
return data_of_file != 0 and diff / math.fabs(data_of_file) < 2e-3
def __check_data(self, data_from_database, data_from_file):
"""Compare the two petl tables row by row; return True when they match."""
rows_number_of_database = petl.nrows(data_from_database)
rows_number_of_file = petl.nrows(data_from_file)
if rows_number_of_database != rows_number_of_file:
LOG.warning(L("verify data error", lines_of_database=rows_number_of_database,
lines_of_file=rows_number_of_file))
return False
result_of_database = petl.records(data_from_database)
result_of_file = petl.records(data_from_file)
type_list = self.__get_type_list()
for record_of_database, record_of_file in zip(result_of_database, result_of_file):
for field_of_database, field_of_file, field_type in \
zip(record_of_database, record_of_file, type_list):
if field_of_database is None and field_of_file is None:
continue
else:
if field_of_database is None or field_of_file is None:
return False
if field_type.lower() == 'float' or field_type.lower() == 'double':
if not self.__check_float(field_of_database, field_of_file,
field_type=field_type.lower()):
LOG.error(L("FLOAT VERIFY FAIL", field_of_database=field_of_database,
field_of_file=field_of_file, record_of_database=record_of_database,
record_of_file=record_of_file))
return False
elif field_of_database != field_of_file:
LOG.error(L("VERIFY FAIL", field_of_database=field_of_database,
field_of_file=field_of_file, record_of_database=record_of_database,
record_of_file=record_of_file))
return False
return True
def __get_data_from_database(self):
"""
Process the rows from the database; `datas` is the result of client.execute(sql).
"""
header = self.__get_field_list()
field_list = self.__get_field_list()
convert_dict = {}
for field in self.schema:
if field[1].lower().startswith('largeint'):
convert_dict[field[0]] = self.__get_type_convert_handler(field[1])
dict_list = []
for row in self.datas:
field_value_dict = {}
for field, value in zip(header, row):
field_value_dict[field] = value
dict_list.append(field_value_dict)
table_database_from = petl.fromdicts(dict_list, header)
table_database_convert = petl.convert(table_database_from, convert_dict)
table_database_sort = petl.sort(table_database_convert, field_list)
table_database_merge_sort = petl.mergesort(table_database_sort,
key=field_list, presorted=False)
return table_database_merge_sort
def __get_data_from_file(self):
"""
Load the expected data from file(s), sort it, and aggregate it according to the table's aggregation model.
"""
# normalize the input for backward compatibility with older callers
if isinstance(self.expected_file_list, str):
from_verifyfile_list = [VerifyFile(self.expected_file_list, '\t')]
elif isinstance(self.expected_file_list, list) and isinstance(self.expected_file_list[0], str):
from_verifyfile_list = [VerifyFile(file, '\t') for file in self.expected_file_list]
elif isinstance(self.expected_file_list, VerifyFile):
from_verifyfile_list = [self.expected_file_list]
else:
# assume a list of VerifyFile objects (as documented for verify_by_sql)
from_verifyfile_list = list(self.expected_file_list)
header = self.__get_field_list()
key_list = self.__get_key_list()
field_list = self.__get_field_list()
convert_dict = self.__get_convert_dict()
dup = False
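# duplicate-key model: DESC reports the aggregation type as 'NONE'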
for col in self.schema:
if col[5] == 'NONE':
dup = True
table_file_to_merge_list = []
for etl_file in from_verifyfile_list:
# read the CSV data from the expected file
table_file_from = petl.fromcsv(etl_file.get_file_name(),
encoding='utf8', delimiter=etl_file.get_delimiter())
# push the expected header onto the headerless rows
table_file_push = petl.pushheader(table_file_from, header)
# convert each field to its typed value
table_file_convert = petl.convert(table_file_push, convert_dict)
table_file_to_merge_list.append(table_file_convert)
if not dup:
table_file_merge_sort = petl.mergesort(*table_file_to_merge_list,
key=key_list, presorted=False)
aggregation = self.__get_aggregation_ordereddict()
aggregate_key = self.__get_aggregate_key(key_list)
# aggregate-model table: aggregate rows by key with the per-column agg functions
table_file_aggregate = petl.aggregate(table_file_merge_sort,
key=aggregate_key, aggregation=aggregation,
presorted=True)
table_file_merge_sort = petl.mergesort(table_file_aggregate,
key=key_list, presorted=True)
return table_file_merge_sort
else:
table_file_merge_sort = petl.mergesort(*table_file_to_merge_list,
key=field_list, presorted=False)
return table_file_merge_sort
def __generate_default_save_verifyfile_list(self):
"""Build default verification-file names from the database and table names."""
name_prefix = ".%s.%s" % (self.database_name, self.table_name)
name_for_database = "%s.%s" % (name_prefix, 'DB')
name_for_file = "%s.%s" % (name_prefix, 'FILE')
return [VerifyFile(name_for_database), VerifyFile(name_for_file)]
def verify(self, save_file_list=None):
"""
The verification entry point: compare the database data with the expected file(s).
"""
LOG.info(L("check file:", file=self.expected_file_list))
self.__adjust_schema_for_verify()
# fetch the data from the database
data_from_database = self.__get_data_from_database()
# load the expected data from the file(s)
data_from_file = self.__get_data_from_file()
if save_file_list is None:
save_file_list = self.__generate_default_save_verifyfile_list()
# write both datasets out for later inspection
self.__write_data_to_file(data_from_database, data_from_file, save_file_list)
# return the check result, True / False
return self.__check_data(data_from_database, data_from_file)
def __adjust_schema_for_verify(self):
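"""Normalize the DESC schema for verify(): for non-key columns keep only
the leading token of the Extra field (the aggregation type), or '' when
Extra is absent."""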
adjust_schema = []
for field in self.schema:
adjust_field = list(field)
if adjust_field[3] == 'false':
if adjust_field[5] is not None:
adjust_field[5] = adjust_field[5].split(',')[0]
else:
adjust_field[5] = ''
adjust_schema.append(tuple(adjust_field))
self.schema = tuple(adjust_schema)
return self.schema
def __adjust_schema_for_self_defined_sql(self):
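"""Expand a 4-tuple schema (name, type, agg_type, default_value) into the
6-field DESC-like layout the rest of the class expects."""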
# TODO: this function may be buggy; revisit later
adjust_schema = []
for column in self.schema:
adjust_column = []
adjust_column.append(column[0])
adjust_column.append(column[1])
adjust_column.append('No')
if len(column) > 2 and column[2] is not None:
adjust_column.append('false')
else:
adjust_column.append('true')
adjust_column.append('N/A')
if len(column) > 2 and column[2] is not None:
adjust_column.append(column[2])
else:
adjust_column.append('')
adjust_schema.append(tuple(adjust_column))
self.schema = adjust_schema
return self.schema
def verify_by_sql(self, save_file_list=None):
"""
Verify the query result of a self-defined SQL statement.
expected_file_list: list of VerifyFile objects
sql: the SQL statement string
schema: schema of the query result, a list of 4-tuples
(name, type, agg_type, default_value); the last two items may be omitted.
Note: when giving a key column a default value, set agg_type to None.
Example -> [("k1", "INT"), ("k2", "CHAR", None, ""), ("v", "DATE", "REPLACE")]
save_file_list: list of VerifyFile objects
"""
self.__adjust_schema_for_self_defined_sql()
data_from_database = self.__get_data_from_database()
data_from_file = self.__get_data_from_file()
if save_file_list is not None:
self.__write_data_to_file(data_from_database, data_from_file, save_file_list)
return self.__check_data(data_from_database, data_from_file)
def verify(file, sql_ret, schema, table_name, database_name, encoding, save_file_list):
"""
verify: the schema is the result of Palo DESC.
Suitable for:
1. multiple expected files: they are concatenated, sorted, and read together
2. raw source files: when Palo filters or aggregates the raw data during load,
no extra expected file needs to be kept; the raw file itself is processed
into the verification data
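
Example (illustrative; `client` is a hypothetical Palo client object):
    schema = client.desc_table('db1', 'tb1')
    ret = client.execute('SELECT * FROM db1.tb1')
    ok = verify(['expect_1.txt', 'expect_2.txt'], ret, schema,
                'tb1', 'db1', None, None)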
"""
verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding)
return verifier.verify(save_file_list)
def verify_by_sql(file, sql_ret, schema, table_name, database_name, encoding, save_file_list):
"""
verify by sql
the schema is given as a list of 4-tuples (name, type, agg_type, default_value)
"""
verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding)
return verifier.verify_by_sql(save_file_list)
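# A minimal usage sketch for verify_by_sql (assumed names: `client` is a
# hypothetical Palo client object and the file path is a placeholder):
#
#   schema = [("k1", "INT"), ("k2", "CHAR", None, ""), ("v", "DATE", "REPLACE")]
#   ret = client.execute("SELECT k1, k2, v FROM db1.tb1")
#   ok = verify_by_sql([VerifyFile("expect.txt")], ret, schema,
#                      "tb1", "db1", None, None)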