# blob: 81144d7f21b6d62afcb20d5f7234139c5c510ec4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0.html
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import math
import pickle
import statistics
import yaml
import argparse
import re
import hashlib
import pyspark.sql.functions as fn
import numpy as np
from pyspark import SparkContext
from pyspark.sql import SparkSession, HiveContext
from pyspark.sql.types import IntegerType, StringType, MapType
from datetime import datetime, timedelta
'''
This script performs the following actions:
1. Call the model API with N randomly picked dense uckeys from trainready (the same data that is used to train the model).
2. Calculate the accuracy of the model.
Run with:
spark-submit --master yarn --num-executors 5 --executor-cores 3 --executor-memory 16G --driver-memory 16G check_model.py
'''
from client_rest_dl2 import predict
def c_error(x, y):
    """Return the relative error of ``y`` against the reference ``x``.

    The error is ``abs(x - y) / x`` computed with ``x`` promoted to float and
    rounded to three decimals. When ``x`` is zero the ratio is undefined and
    the sentinel ``-1`` is returned instead.
    """
    reference = x * 1.0
    # -1 is the sentinel for "no reference value to compare against".
    ratio = abs(reference - y) / reference if reference != 0 else -1
    return round(ratio, 3)
def error_m(a, p):
    """Compare actual values ``a`` with predicted values ``p``.

    Returns:
        A tuple ``(total_error, per_item_errors)``. ``per_item_errors`` holds
        ``c_error`` for every position of ``a``; ``total_error`` is ``c_error``
        applied to ``sum(a)`` and ``sum(p)``.
    """
    # Index into ``p`` by position so a too-short ``p`` raises IndexError
    # rather than being silently truncated.
    per_item = [c_error(actual, p[idx]) for idx, actual in enumerate(a)]
    overall = c_error(sum(a), sum(p))
    return (overall, per_item)
def normalize_ts(ts):
    """Return the natural log of ``value + 1`` for every value in ``ts``."""
    normalized = []
    for value in ts:
        normalized.append(math.log(value + 1))
    return normalized
def dl_daily_forecast(serving_url, model_stats, day_list, ucdoc_attribute_map):
    """Request a forecast from ``predict`` with a forward offset of zero.

    Returns:
        A tuple ``(ts, days)``: the first element of ``predict``'s first
        return value, and ``predict``'s second return value unchanged.
    """
    forecasts, days = predict(
        serving_url=serving_url,
        model_stats=model_stats,
        day_list=day_list,
        ucdoc_attribute_map=ucdoc_attribute_map,
        forward_offset=0,
    )
    return forecasts[0], days
def get_model_stats(hive_context, model_stat_table):
    """Load the model description and statistics from ``model_stat_table``.

    The table is read with ``SELECT *`` and must contain exactly one row.

    Returns:
        A dict ``{'model': <row 'model_info'>, 'stats': <row 'stats'>}``.
        Example shape (abridged, from the original author's note)::

            {
                "model": {
                    "name": "s32",
                    "version": 1,
                    "duration": 90,
                    "train_window": 60,
                    "predict_window": 10
                },
                "stats": {
                    "g_g_m": [0.32095959595959594, 0.4668649491714752],
                    "g_g_f": [0.3654040404040404, 0.4815635452904544],
                    "g_g_x": [0.31363636363636366, 0.46398999646418304]
                }
            }

    Raises:
        Exception: if the table does not hold exactly one row.
    """
    query = "\nSELECT * FROM {}\n".format(model_stat_table)
    rows = hive_context.sql(query).collect()
    if len(rows) != 1:
        raise Exception('Bad model stat table {} '.format(model_stat_table))
    only_row = rows[0]
    return {'model': only_row['model_info'], 'stats': only_row['stats']}
def predict_daily_uckey(sample, days, serving_url, model_stats, columns):
    """Forecast daily values for one sample through the serving endpoint.

    Args:
        sample: mapping that provides a value for every name in ``columns``.
        days: list of days; a shallow copy is handed to the forecast call.
        serving_url: URL forwarded to ``dl_daily_forecast``.
        model_stats: model statistics forwarded to ``dl_daily_forecast``.
        columns: names of the fields copied from ``sample``; must include
            'ts', 'price_cat' and 'uckey', which are read below.

    Returns:
        A dict mapping each returned day to its forecast value, e.g.
        ``{'2019-11-02': 220.0, '2019-11-03': 305.0}``.
    """

    def _denoise(series):
        # Zero out every entry that is not above one tenth of
        # (sum of the series / number of non-zero entries).
        nonzero_count = sum(1 for value in series if value != 0)
        nonzero_mean = 1.0 * sum(series) / nonzero_count if nonzero_count > 0 else 0.0
        threshold = nonzero_mean / 10.0
        return [value if value > threshold else 0 for value in series]

    day_list = days[:]
    ucdoc_attribute_map = {feature: sample[feature] for feature in columns}

    # Per the original author's note, dlpredictor receives 'ts' as a per-day
    # map such as {u'2019-11-02': [u'1:862', u'3:49', u'2:1154'], ...} and
    # extracts the count matching price_cat. Here 'ts' is already a flat list
    # of daily counts, e.g. [0, 0, 0, 0, 0, 65, 47, 10, ...], so it is used
    # directly.
    raw_ts = ucdoc_attribute_map['ts']
    price_cat = ucdoc_attribute_map['price_cat']

    # The replace_with_median step was removed on 06/21/2021; only denoising
    # is applied before the log normalisation.
    ucdoc_attribute_map['ts_n'] = normalize_ts(_denoise(raw_ts))

    # page_ix is the uckey and the price category joined by '-'.
    ucdoc_attribute_map['page_ix'] = ucdoc_attribute_map['uckey'] + '-' + price_cat

    rs_ts, rs_days = dl_daily_forecast(
        serving_url=serving_url,
        model_stats=model_stats,
        day_list=day_list,
        ucdoc_attribute_map=ucdoc_attribute_map,
    )

    # response = {'2019-11-02': 220.0, '2019-11-03': 305.0}
    return {day: rs_ts[idx] for idx, day in enumerate(rs_days)}
def run(cfg, hive_context):
    """Call the model for a handful of samples and print its error.

    For every sample the last ``predict_window`` values of its 'ts' series
    are held back as the expected values, the remaining values are sent to
    the model, and the relative error between expected and predicted totals
    is printed. The mean error over all samples (in percent) is printed last.

    Args:
        cfg: dict with the keys 'model_stat_table', 'trainready_table',
            'dist_table', 'max_calls' and 'serving_url'. An optional
            'local_samples' entry (a list of dicts that each contain at
            least 'ts', 'price_cat' and 'uckey') bypasses the Hive queries
            so the check can run on in-memory samples.
        hive_context: object whose ``sql(query)`` returns a DataFrame.
    """
    model_stats = get_model_stats(hive_context, cfg['model_stat_table'])

    predict_window = model_stats['model']['predict_window']
    day_list = model_stats['model']['days']
    # Sorted in place on purpose: model_stats is handed to the prediction
    # call below and keeps seeing the same, ordered list.
    day_list.sort()

    samples = cfg.get('local_samples')
    if samples is None:
        # NOTE(review): the uckey and price_cat are pinned in the query, so
        # only rows for this one uckey are checked.
        df_trainready = hive_context.sql(
            'SELECT * FROM {} WHERE uckey="native,b6le0s4qo8,4G,g_f,5,CPC,,1156320000" and price_cat="1" '.format(cfg['trainready_table']))
        df_dist = hive_context.sql(
            'SELECT * FROM {} WHERE ratio=1'.format(cfg['dist_table']))
        df = df_trainready.join(
            df_dist, on=['uckey', 'price_cat'], how='inner')
        columns = df.columns
        samples = df.take(cfg['max_calls'])
    else:
        # Offline mode: the field names come from the first supplied sample.
        columns = list(samples[0]) if samples else []

    # The trailing predict_window days are what the model has to forecast,
    # so they are not part of the input day list.
    day_list = day_list[:-predict_window]

    errs = []
    for record in samples:
        sample = {feature: record[feature] for feature in columns}

        # Split the series: the tail is the ground truth, the head is input.
        whole_ts = sample['ts'][:]
        expected = whole_ts[-predict_window:]
        sample['ts'] = whole_ts[:-predict_window]

        print("-------------------------------------")
        print("-------------------------------------")

        response = predict_daily_uckey(
            sample=sample, days=day_list, serving_url=cfg['serving_url'],
            model_stats=model_stats, columns=columns)

        # Order the forecast by day so it lines up with ``expected``.
        predicted = [response[day] for day in sorted(response)]

        # Days with zero actual value are taken out of the comparison by
        # forcing the prediction for that day to zero as well.
        for i, v in enumerate(expected):
            if v == 0:
                predicted[i] = 0

        # list(...) so the pairs are printed on Python 3 too, where zip()
        # returns a lazy iterator.
        print(list(zip(expected, predicted)))

        e = error_m(expected, predicted)[0]
        print(e * 100)
        # A negative error is the "undefined" sentinel (expected total of
        # zero); count it as zero in the average.
        if e < 0:
            e = 0
        errs.append(e)

    if errs:
        print(sum(errs) / (len(errs) * 1.0) * 100)
    else:
        # No rows matched, so there is nothing to average.
        print('No samples were evaluated; average error is undefined.')
if __name__ == '__main__':
    # Settings for this check run: Spark log level, the Hive tables to read,
    # the model serving endpoint and the maximum number of samples to take.
    # 'yesterday' is a placeholder; nothing in this file reads it.
    cfg = dict(
        log_level='warn',
        trainready_table='dlpm_111021_no_residency_no_mapping_trainready_test_12212021',
        dist_table='dlpm_111021_no_residency_no_mapping_tmp_distribution_test_12212021',
        serving_url='http://10.193.217.126:8503/v1/models/dl_test_1221:predict',
        max_calls=4,
        model_stat_table='dlpm_111021_no_residency_no_mapping_model_stat_test_12212021',
        yesterday='WILL BE SET IN PROGRAM',
    )

    sc = SparkContext.getOrCreate()
    hive_context = HiveContext(sc)
    sc.setLogLevel(cfg['log_level'])

    run(cfg=cfg, hive_context=hive_context)