| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| |
| # http://www.apache.org/licenses/LICENSE-2.0.html |
| |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import unittest |
| import math |
| import pickle |
| import statistics |
| import yaml |
| import argparse |
| import re |
| import hashlib |
| |
| import pyspark.sql.functions as fn |
| import numpy as np |
| |
| from pyspark import SparkContext |
| from pyspark.sql import SparkSession, HiveContext |
| from pyspark.sql.types import IntegerType, StringType, MapType |
| from datetime import datetime, timedelta |
| |
'''

This script performs the following actions:

1. Calls the model API with N randomly picked dense uckeys from trainready (the same data that was used to train the model).
2. Calculates the accuracy of the model from the returned predictions.


Run with:
spark-submit --master yarn --num-executors 5 --executor-cores 3 --executor-memory 16G --driver-memory 16G check_model.py

'''
| |
| from client_rest_dl2 import predict |
| |
| |
def c_error(x, y):
    """Return the relative error of ``y`` measured against ``x``.

    The value is ``abs(x - y) / x`` rounded to three decimal places.
    When ``x`` is zero the ratio is undefined and the sentinel ``-1`` is
    returned instead.
    """
    reference = x * 1.0  # promote to float so the division is never integral
    if reference == 0:
        return round(-1, 3)
    return round(abs(reference - y) / reference, 3)
| |
| |
def error_m(a, p):
    """Compare actual values ``a`` with predicted values ``p``.

    Returns a tuple ``(total_error, item_errors)``:
    ``item_errors[i]`` is ``c_error(a[i], p[i])`` and ``total_error`` is
    ``c_error(sum(a), sum(p))``.
    """
    # ``p`` is indexed (not zipped) so that a prediction list shorter than
    # ``a`` still raises IndexError instead of being silently truncated.
    item_errors = [c_error(actual, p[idx]) for idx, actual in enumerate(a)]
    total_error = c_error(sum(a), sum(p))
    return (total_error, item_errors)
| |
| |
def normalize_ts(ts):
    """Return a new list holding ``log(value + 1)`` for every value in ``ts``."""
    scaled = []
    for count in ts:
        scaled.append(math.log(count + 1))
    return scaled
| |
| |
def dl_daily_forecast(serving_url, model_stats, day_list, ucdoc_attribute_map):
    """Request a forecast from the serving endpoint through ``predict``.

    All arguments are forwarded unchanged to ``predict`` together with
    ``forward_offset=0``.

    Returns:
        A tuple ``(ts, days)`` where ``ts`` is the first element of the
        first value returned by ``predict`` and ``days`` is the second
        value returned by ``predict``.
    """
    forecasts, days = predict(
        serving_url=serving_url,
        model_stats=model_stats,
        day_list=day_list,
        ucdoc_attribute_map=ucdoc_attribute_map,
        forward_offset=0)
    return forecasts[0], days
| |
| |
def get_model_stats(hive_context, model_stat_table):
    """Load the model description and statistics from a one-row Hive table.

    Args:
        hive_context: object whose ``sql(query)`` result offers ``collect()``.
        model_stat_table: name of the table to read.

    Returns:
        A dict ``{'model': <model_info column>, 'stats': <stats column>}``
        built from the single row of the table, for example:

            {
                "model": {
                    "name": "s32",
                    "version": 1,
                    "duration": 90,
                    "train_window": 60,
                    "predict_window": 10
                },
                "stats": {
                    "g_g_m": [0.32095959595959594, 0.4668649491714752],
                    "g_g_f": [0.3654040404040404, 0.4815635452904544],
                    "g_g_x": [0.31363636363636366, 0.46398999646418304]
                }
            }

    Raises:
        Exception: if the table does not contain exactly one row.
    """
    query = """
    SELECT * FROM {}
    """.format(model_stat_table)
    rows = hive_context.sql(query).collect()
    if len(rows) != 1:
        raise Exception('Bad model stat table {} '.format(model_stat_table))
    only_row = rows[0]
    return {
        'model': only_row['model_info'],
        'stats': only_row['stats'],
    }
| |
| |
def predict_daily_uckey(sample, days, serving_url, model_stats, columns):
    """Forecast daily values for a single uckey sample via the model API.

    Args:
        sample: mapping that provides a value for every name in ``columns``.
        days: list of days handed to the model; a copy is used.
        serving_url: URL of the model serving endpoint.
        model_stats: model description forwarded to the prediction call.
        columns: attribute names to copy from ``sample``; must contain
            'ts', 'price_cat' and 'uckey'.

    Returns:
        A dict mapping each day returned by the model to its predicted
        value, e.g. ``{'2019-11-02': 220.0, '2019-11-03': 305.0}``.
    """

    def _denoise(ts):
        # Average of the series over its non-zero entries (0.0 if there are
        # none); entries at or below a tenth of that average become 0.
        non_zero = [value for value in ts if value != 0]
        average = 1.0 * sum(ts) / len(non_zero) if non_zero else 0.0
        floor = average / 10.0
        return [value if value > floor else 0 for value in ts]

    day_list = days[:]
    ucdoc_attribute_map = {feature: sample[feature] for feature in columns}

    # NOTE: dlpredictor receives 'ts' in a different format, a map from day
    # to 'price_cat:count' strings, e.g.
    #   {u'2019-11-02': [u'1:862', u'3:49', u'2:1154'],
    #    u'2019-11-03': [u'1:596', u'3:67', u'2:1024']}
    # and extracts the count of the matching price_cat for every day.
    # Here 'ts' is already a flat list of counts, e.g. [0, 0, 65, 47, ...],
    # so it is used directly.
    raw_ts = ucdoc_attribute_map['ts']
    price_cat = ucdoc_attribute_map['price_cat']

    # The replace_with_median step was removed on 06/21/2021; only the
    # denoising and the log normalisation remain.
    ucdoc_attribute_map['ts_n'] = normalize_ts(_denoise(raw_ts))

    # page_ix is the uckey and the price category joined by '-'.
    ucdoc_attribute_map['page_ix'] = (
        ucdoc_attribute_map['uckey'] + '-' + price_cat)

    rs_ts, rs_days = dl_daily_forecast(
        serving_url=serving_url,
        model_stats=model_stats,
        day_list=day_list,
        ucdoc_attribute_map=ucdoc_attribute_map)

    # Pair every returned day with the prediction at the same position.
    return {day: rs_ts[position] for position, day in enumerate(rs_days)}
| |
| |
def run(cfg, hive_context, samples=None):
    """Call the model for a few trainready samples and print its error.

    For every sample the last ``predict_window`` points of its 'ts' series
    are held back, the model is asked to forecast from the remaining
    points, and the relative error between the held-back total and the
    forecast total is printed (in percent). The mean error over all samples
    is printed at the end.

    Args:
        cfg: dict with the keys 'model_stat_table', 'trainready_table',
            'dist_table', 'max_calls' and 'serving_url'.
        hive_context: context whose ``sql()`` method runs the Hive queries.
        samples: optional list of ready-made sample dicts, each carrying
            the same fields as a joined trainready/distribution row
            (including 'ts', 'uckey' and 'price_cat'). When given, the
            trainready and distribution tables are not queried, which
            allows debugging without that data. Defaults to None, meaning
            the samples are read from Hive.
    """
    model_stats = get_model_stats(hive_context, cfg['model_stat_table'])

    predict_window = model_stats['model']['predict_window']
    day_list = model_stats['model']['days']
    # NOTE: this sorts in place, so model_stats['model']['days'] is sorted
    # as well; model_stats is passed on to predict_daily_uckey below.
    day_list.sort()

    if samples is None:
        # NOTE(review): the uckey/price_cat filter below is hard-coded to a
        # single uckey; confirm whether that is still intended.
        df_trainready = hive_context.sql(
            'SELECT * FROM {} WHERE uckey="native,b6le0s4qo8,4G,g_f,5,CPC,,1156320000" and price_cat="1" '.format(cfg['trainready_table']))
        df_dist = hive_context.sql(
            'SELECT * FROM {} WHERE ratio=1'.format(cfg['dist_table']))
        df = df_trainready.join(
            df_dist, on=['uckey', 'price_cat'], how='inner')
        columns = df.columns
        samples = df.take(cfg['max_calls'])
    else:
        # Caller-supplied samples: use the fields of the first one.
        columns = list(samples[0].keys()) if samples else []

    # Only the days before the prediction window are fed to the model.
    day_list = day_list[:-predict_window]

    errs = []
    for row in samples:
        sample = {feature: row[feature] for feature in columns}

        # Split the series: the tail is the ground truth, the head is the
        # model input.
        whole_ts = sample['ts'][:]
        expected = whole_ts[-predict_window:]
        sample['ts'] = whole_ts[:-predict_window]

        print("-------------------------------------")
        print("-------------------------------------")

        response = predict_daily_uckey(
            sample=sample, days=day_list, serving_url=cfg['serving_url'], model_stats=model_stats, columns=columns)
        predicted = [response[day] for day in sorted(response)]

        # Where the actual value is zero the prediction is zeroed as well,
        # so those days do not contribute to the error.
        for i, v in enumerate(expected):
            if v == 0:
                predicted[i] = 0
        # list() makes the pairs visible under Python 3 too, where zip()
        # returns a lazy iterator instead of a list.
        print(list(zip(expected, predicted)))
        e = error_m(expected, predicted)[0]
        print(e*100)
        # A negative error (e.g. the -1 reported when the expected total is
        # zero) is counted as 0 in the average.
        if e < 0:
            e = 0
        errs.append(e)

    # Without this guard an empty sample set ends in ZeroDivisionError.
    if not errs:
        print('No samples to evaluate')
        return

    print(sum(errs)/(len(errs)*1.0)*100)
| |
| |
if __name__ == '__main__':

    # Static job configuration: Hive tables, serving endpoint and the
    # maximum number of samples to send to the model.
    # 'yesterday' is a placeholder; it is not read by the code in this file.
    cfg = {
        'log_level': 'warn',
        'trainready_table': 'dlpm_111021_no_residency_no_mapping_trainready_test_12212021',
        'dist_table': 'dlpm_111021_no_residency_no_mapping_tmp_distribution_test_12212021',
        'serving_url': 'http://10.193.217.126:8503/v1/models/dl_test_1221:predict',
        'max_calls': 4,
        'model_stat_table': 'dlpm_111021_no_residency_no_mapping_model_stat_test_12212021',
        'yesterday': 'WILL BE SET IN PROGRAM',
    }

    # Reuse the active SparkContext if there is one, then wrap it for Hive.
    spark_context = SparkContext.getOrCreate()
    hive_context = HiveContext(spark_context)
    spark_context.setLogLevel(cfg['log_level'])

    run(cfg=cfg, hive_context=hive_context)