blob: 645f72f437c1ff47caf4e6f3040a9f2cfd06af84 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
from src.tools.io_tools import read_json
# ---------------------------------------------------------------------------
# Shared plot styling.
# ---------------------------------------------------------------------------
# Marker size for line plots.
set_marker_size = 15
# Marker size for scatter points.
set_marker_point = 14
# General font size.
set_font_size = 25
# Font size of legend entries (used by export_legend).
set_lgend_size = 15
# Font size of tick labels; axis labels reuse it below.
set_tick_size = 20
# NOTE(review): meaning not evident from this file (presumably a font size for
# text inside boxes) -- confirm before relying on it.
frontinsidebox = 23

# Apply tick-label and axis-label sizes globally (equivalent to
# matplotlib.rc('xtick'/'ytick', labelsize=...) plus the axes.labelsize key).
plt.rcParams.update({
    'xtick.labelsize': set_tick_size,
    'ytick.labelsize': set_tick_size,
    'axes.labelsize': set_tick_size,
})

# Marker glyphs and their sizes; the two lists are aligned index by index.
mark_list = ["o", "*", "<", "^", "s", "d", "D", ">", "h"]
mark_size_list = [set_marker_size + bump for bump in (0, 1, 1, 0, 0, 0, 0, 1, 2)]
# Line styles cycled through by plots.
line_shape_list = ['-.', '--', '-', ':']
# Alpha used for shaded regions.
shade_degree = 0.2
# Root of the experiment data; relative to the current working directory.
base_dir = "../exp_data/"
def export_legend(ori_fig, filename="any_time_legend", colnum=9, unique_labels=None):
    """Save the legend of ``ori_fig`` on its own as ``<filename>.pdf``.

    Legend handles and labels are collected from every axes of ``ori_fig``
    and drawn centred in a separate small figure. This lets one legend file
    be shared by several plots.

    Args:
        ori_fig: Figure whose axes' legend entries are exported.
        filename: Output path without extension; ".pdf" is appended.
        colnum: Number of legend columns (``ncol``).
        unique_labels: Labels to show, in this order. When None or empty,
            every distinct label found in ``ori_fig`` is used, in first-seen
            order.

    Raises:
        KeyError: If ``unique_labels`` names a label absent from ``ori_fig``.
    """
    # Gather handles/labels across all axes. A plain loop also copes with a
    # figure that has no axes, where zip(*...) unpacking would fail.
    lines, labels = [], []
    for ax in ori_fig.axes:
        ax_lines, ax_labels = ax.get_legend_handles_labels()
        lines.extend(ax_lines)
        labels.extend(ax_labels)

    if not unique_labels:
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        # A set would order labels by string hash, which is randomized per
        # interpreter run, so the exported legend would reshuffle between runs.
        unique_labels = list(dict.fromkeys(labels))

    # label -> handle (for a repeated label the last handle wins).
    legend_dict = dict(zip(labels, lines))
    unique_lines = [legend_dict[label] for label in unique_labels]

    fig2 = plt.figure(figsize=(5, 0.3))
    try:
        fig2.legend(unique_lines, unique_labels, loc='center',
                    ncol=colnum,
                    fancybox=True,
                    shadow=True, scatterpoints=1, fontsize=set_lgend_size)
        fig2.tight_layout()
        fig2.savefig(f"{filename}.pdf", bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig2)
def _ecdf(values):
    """Return ``(x, y)`` for the empirical CDF of ``values``.

    ``x`` is ``values`` sorted ascending and ``y[i] = (i + 1) / len(values)``.
    """
    x = np.sort(values)
    y = np.arange(1, len(x) + 1) / len(x)
    return x, y


def draw_edcf(data=None, sampled_epochs=None):
    """Plot train/validation AUC ECDFs, one figure per dataset.

    For each dataset the code does the following:

    * collects the metrics recorded at epoch ``sampled_epochs[dataset]`` from
      every architecture;
    * plots the ECDF of ``train_auc`` and ``valid_auc``;
    * prints the median and best validation AUC;
    * exports the shared legend via ``export_legend``;
    * saves the figure as ``space_<dataset>.pdf``.

    Args:
        data: Nested mapping ``{dataset: {architecture: {epoch: metrics}}}``.
            Each ``metrics`` entry must hold ``"train_auc"`` and
            ``"valid_auc"``. Epoch keys are matched as strings. Defaults to
            the module-level ``data_dict``.
        sampled_epochs: Mapping ``{dataset: epoch}``. Defaults to the
            module-level ``epoch_sampled``.

    Raises:
        KeyError: If a dataset has no entry in ``sampled_epochs``, or a metrics
            entry lacks ``"train_auc"``/``"valid_auc"``.
    """
    if data is None:
        data = data_dict
    if sampled_epochs is None:
        sampled_epochs = epoch_sampled

    for dataset, architectures in data.items():
        print(dataset)
        # Epoch keys are strings (JSON object keys), so compare by str(epoch).
        epoch_key = str(sampled_epochs[dataset])
        sampled = [epochs[epoch_key] for epochs in architectures.values()
                   if epoch_key in epochs]
        if not sampled:
            # Without samples the quantile/max below would fail on empty data.
            print(f"no architecture of {dataset} has epoch {epoch_key}; skipping")
            continue
        train_auc = [metrics["train_auc"] for metrics in sampled]
        valid_auc = [metrics["valid_auc"] for metrics in sampled]

        fig, ax = plt.subplots(figsize=(6.4, 3.5))
        try:
            sorted_train_auc, y_train = _ecdf(train_auc)
            ax.plot(sorted_train_auc, y_train, label='Training AUC', linewidth=3, linestyle='--')
            sorted_valid_auc, y_valid = _ecdf(valid_auc)
            ax.plot(sorted_valid_auc, y_valid, label='Validation AUC', linewidth=3, linestyle='-')

            y_m = np.quantile(sorted_valid_auc, .5, axis=0)
            print("median", y_m, "best", sorted_valid_auc.max())

            ax.grid()
            # NOTE(review): x values are AUC scores but the axis is labelled
            # "Accuracy"; confirm this is intended.
            ax.set_xlabel('Accuracy')
            ax.set_ylabel('ECDF')
            fig.tight_layout()
            # The legend goes to a separate file instead of onto this figure.
            export_legend(ori_fig=fig, colnum=5)
            fig.savefig(f"space_{dataset}.pdf", bbox_inches='tight')
        finally:
            # Free the figure; otherwise one stays open per dataset.
            plt.close(fig)
# Select which dataset's training log to load by uncommenting one line.
# dataset_used = "frappe"
dataset_used = "uci_diabetes"
# dataset_used = "criteo"
# Epoch whose metrics draw_edcf plots for each dataset.
epoch_sampled = {"frappe": 19, "uci_diabetes": 35, "criteo": 9}
if dataset_used == "frappe":
    mlp_train_frappe = os.path.join(
        base_dir,
        "tab_data/frappe/all_train_baseline_frappe.json")
    data_dict = read_json(mlp_train_frappe)
elif dataset_used == "uci_diabetes":
    mlp_train_uci_diabetes = os.path.join(
        base_dir,
        "tab_data/uci_diabetes/all_train_baseline_uci_160k_40epoch.json")
    data_dict = read_json(mlp_train_uci_diabetes)
elif dataset_used == "criteo":
    mlp_train_criteo = os.path.join(
        base_dir,
        "tab_data/criteo/all_train_baseline_criteo.json")
    data_dict = read_json(mlp_train_criteo)
else:
    # Fail fast with a clear message instead of letting draw_edcf() hit a
    # NameError on the never-assigned data_dict.
    raise ValueError(
        f"unsupported dataset_used {dataset_used!r}; "
        f"expected one of {sorted(epoch_sampled)}")
draw_edcf()