blob: 7d4acb3e2ee9d4f8585655213e6c0921a2f9c2e5 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import palettable
from matplotlib.ticker import MaxNLocator
import numpy
from src.common.constant import Config
import matplotlib
# lines' mark size
set_marker_size = 15
# points' mark size
set_marker_point = 14
# points' mark size
set_font_size = 40
set_lgend_size = 15
set_tick_size = 20
frontinsidebox = 23
# update tick size
matplotlib.rc('xtick', labelsize=set_tick_size)
matplotlib.rc('ytick', labelsize=set_tick_size)
plt.rcParams['axes.labelsize'] = set_tick_size
mark_list = ["o", "*", "<", "^", "s", "d", "D", ">", "h"]
mark_size_list = [set_marker_size, set_marker_size + 1, set_marker_size + 1, set_marker_size,
set_marker_size, set_marker_size, set_marker_size, set_marker_size + 1, set_marker_size + 2]
line_shape_list = ['-.', '--', '-', ':']
# this is for draw figure3 only
def get_plot_compare_with_base_line_cfg(search_space, dataset, if_with_phase1=False):
if search_space == Config.NB201:
run_range_ = range(0, 100, 1)
if if_with_phase1:
draw_graph = draw_anytime_result_with_p1
else:
draw_graph = draw_anytime_result
# min, this is for plot only
if dataset == Config.c10:
# C10 array
budget_array = [0.017, 0.083] + list(range(1, 350, 4))
sub_graph_y1 = [91, 94.5]
sub_graph_y2 = [53.5, 55]
sub_graph_split = 60
elif dataset == Config.c100:
# C10 array
budget_array = [0.017, 0.083] + list(range(1, 350, 4))
sub_graph_y1 = [64, 73.5]
sub_graph_y2 = [15, 16]
sub_graph_split = 20
else:
# ImgNet X array
budget_array = [0.017, 0.083] + list(range(1, 350, 4))
sub_graph_y1 = [33, 48]
sub_graph_y2 = [15.5, 17]
sub_graph_split = 34
else:
# this is NB101 + C10, because only 101 has 20 run. others have 100 run.
run_range_ = range(0, 20, 1)
if if_with_phase1:
draw_graph = draw_anytime_result_one_graph_with_p1
# budget_array = list(range(1, 16, 1))
budget_array = numpy.arange(0.02, 15, 0.02).tolist()
else:
draw_graph = draw_anytime_result_one_graph
budget_array = [0.017, 0.083] + list(range(1, 2000, 8))
if dataset == Config.c10:
# C10 array
# budget_array = list(range(0, 2000, 1))
sub_graph_y1 = [90, 94.5]
sub_graph_y2 = [52, 55]
sub_graph_split = 60
else:
raise Exception
return run_range_, budget_array, sub_graph_y1, sub_graph_y2, sub_graph_split, draw_graph
def draw_anytime_result(result_dir, y_acc_list_arr, x_T_list,
x_acc_train, y_acc_train_l, y_acc_train_m, y_acc_train_h,
annotations, lv,
name_img, dataset,
x1_lim=[], x2_lim=[],
):
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, dpi=100, gridspec_kw={'height_ratios': [4, 1]})
exp = np.array(y_acc_list_arr)
sys_acc_h = np.quantile(exp, .75, axis=0)
sys_acc_m = np.quantile(exp, .5, axis=0)
sys_acc_l = np.quantile(exp, .25, axis=0)
# plot simulate result of system
ax1.fill_between(x_T_list, sys_acc_l, sys_acc_h, alpha=0.1)
ax1.plot(x_T_list, sys_acc_m, mark_list[-1], label="TRAILS")
ax2.fill_between(x_T_list, sys_acc_l, sys_acc_h, alpha=0.1)
# plot simulate result of train-based line
ax1.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=0.3)
ax1.plot(x_acc_train, y_acc_train_m, mark_list[-2], label="Training-based MS")
ax2.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=0.3)
for i in range(len(annotations)):
ele = annotations[i]
if ele[1] < lv:
# convert to mins
ax2.plot(ele[2] / 60, ele[1], mark_list[i], label=ele[0], fontsize=set_marker_size)
else:
ax1.plot(ele[2] / 60, ele[1], mark_list[i], label=ele[0], fontsize=set_marker_size)
# ax2.scatter(ele[2]/60, ele[1]* 0.01, s=100, color="red")
# ax2.annotate(ele[0], (ele[2]/60, ele[1] * 0.01))
if len(x1_lim) > 0 and len(x2_lim) > 0:
ax1.set_ylim(x1_lim[0], x1_lim[1]) # 子图1设置y轴范围,只显示部分图
ax2.set_ylim(x2_lim[0], x2_lim[1]) # 子图2设置y轴范围,只显示部分图
ax1.spines['bottom'].set_visible(False) # 关闭子图1中底部脊
ax2.spines['top'].set_visible(False) ##关闭子图2中顶部脊
ax2.set_xticks(range(0, 31, 1))
d = .85 # 设置倾斜度
# 绘制断裂处的标记
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=set_marker_size,
linestyle='none', color='r', mec='r', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
plt.tight_layout()
plt.xscale("symlog")
ax1.grid()
ax2.grid()
plt.xlabel("Time Budget given by user (min)", fontsize=set_font_size)
ax1.set_ylabel(f"Test accuracy on {dataset}", fontsize=set_font_size)
ax1.legend(ncol=1, fontsize=set_lgend_size)
ax2.legend(fontsize=set_lgend_size)
# plt.show()
plt.savefig(f"{result_dir}/any_time_{name_img}.pdf", bbox_inches='tight')
def draw_anytime_result_one_graph(y_acc_list_arr, x_T_list,
x_acc_train, y_acc_train_l, y_acc_train_m, y_acc_train_h,
annotations, lv,
name_img, dataset,
x1_lim=[], x2_lim=[],
):
# fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, dpi=100, gridspec_kw={'height_ratios': [5, 1]})
exp = np.array(y_acc_list_arr) * 100
sys_acc_h = np.quantile(exp, .75, axis=0)
sys_acc_m = np.quantile(exp, .5, axis=0)
sys_acc_l = np.quantile(exp, .25, axis=0)
# exp_time = np.array(real_time_used_arr)
# time_mean = np.quantile(exp_time, .5, axis=0)
time_mean = x_T_list
# plot simulate result of system
plt.fill_between(time_mean, sys_acc_l, sys_acc_h, alpha=0.1)
plt.plot(time_mean, sys_acc_m, "o-", label="TRAILS")
# plt.plot(time_mean, sys_acc_m, label="TRAILS")
# plot simulate result of train-based line
plt.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=0.3)
plt.plot(x_acc_train, y_acc_train_m, "o-", label="Training-based MS")
# plt.plot(x_acc_train, y_acc_train_m, label="Training-based MS")
if len(x1_lim) > 0:
plt.ylim(x1_lim[0], x1_lim[1]) # 子图1设置y轴范围,只显示部分图
d = .85 # 设置倾斜度
# 绘制断裂处的标记
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=set_marker_size,
linestyle='none', color='r', mec='r', mew=1, clip_on=False)
# plt.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
# plt.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
plt.tight_layout()
# plt.xscale("symlog")
plt.grid()
plt.xlabel("Time Budget given by user (min)", fontsize=set_font_size)
plt.ylabel(f"Test accuracy on {dataset}", fontsize=set_font_size)
plt.legend(ncol=1, fontsize=set_lgend_size)
plt.show()
# plt.savefig(f"amy_time_{name_img}.pdf", bbox_inches='tight')
# those two function will plot phase 1 and phase 2
def draw_anytime_result_with_p1(result_dir, y_acc_list_arr, x_T_list, y_acc_list_arr_p1, x_T_list_p1,
x_acc_train, y_acc_train_l, y_acc_train_m, y_acc_train_h,
annotations, lv,
name_img, dataset, max_value,
x1_lim=[], x2_lim=[],
):
fig, (ax1, ax2) = plt.subplots(
2, 1,
sharex=True,
dpi=100,
gridspec_kw={'height_ratios': [6, 1]})
shade_degree = 0.2
# plot simulate result of train-based line
ax1.plot(x_acc_train, y_acc_train_m, mark_list[-3] + line_shape_list[0], label="Training-Based MS",
markersize=mark_size_list[-3])
ax1.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=shade_degree)
ax2.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=shade_degree)
# plot simulate result of system
exp = np.array(y_acc_list_arr_p1)
sys_acc_p1_h = np.quantile(exp, .75, axis=0)
sys_acc_p1_m = np.quantile(exp, .5, axis=0)
sys_acc_p1_l = np.quantile(exp, .25, axis=0)
ax1.plot(x_T_list_p1, sys_acc_p1_m, mark_list[-2] + line_shape_list[1], label="Training-Free MS",
markersize=mark_size_list[-2])
ax1.fill_between(x_T_list_p1, sys_acc_p1_l, sys_acc_p1_h, alpha=shade_degree)
ax2.fill_between(x_T_list_p1, sys_acc_p1_l, sys_acc_p1_h, alpha=shade_degree)
# plot simulate result of system
exp = np.array(y_acc_list_arr)
sys_acc_h = np.quantile(exp, .75, axis=0)
sys_acc_m = np.quantile(exp, .5, axis=0)
sys_acc_l = np.quantile(exp, .25, axis=0)
ax1.plot(x_T_list, sys_acc_m, mark_list[-1] + line_shape_list[2], label="2Phase-MS", markersize=mark_size_list[-1])
ax1.fill_between(x_T_list, sys_acc_l, sys_acc_h, alpha=shade_degree)
ax2.fill_between(x_T_list, sys_acc_l, sys_acc_h, alpha=shade_degree)
print(f"speed-up on {dataset} = {x_acc_train[-1] / x_T_list[-2]}, "
f"t_train = {x_acc_train[-1]}, t_f = {x_T_list[-2]}")
for i in range(len(annotations)):
ele = annotations[i]
if ele[1] < lv:
# convert to mins
ax2.plot(ele[2] / 60, ele[1], mark_list[i], label=ele[0], markersize=set_marker_point)
else:
ax1.plot(ele[2] / 60, ele[1], mark_list[i], label=ele[0], markersize=set_marker_point)
# ax2.scatter(ele[2]/60, ele[1]* 0.01, s=100, color="red")
# ax2.annotate(ele[0], (ele[2]/60, ele[1] * 0.01))
if len(x1_lim) > 0 and len(x2_lim) > 0:
ax1.set_ylim(x1_lim[0], x1_lim[1]) # 子图1设置y轴范围,只显示部分图
ax2.set_ylim(x2_lim[0], x2_lim[1]) # 子图2设置y轴范围,只显示部分图
ax1.spines['bottom'].set_visible(False) # 关闭子图1中底部脊
ax2.spines['top'].set_visible(False) ##关闭子图2中顶部脊
ax2.set_xticks(range(0, 31, 1))
d = .85 # 设置倾斜度
# 绘制断裂处的标记
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=set_marker_size,
linestyle='none', color='r', mec='r', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
plt.xscale("log")
ax1.grid()
ax2.grid()
plt.xlabel(r"Response Time Threshold $T_{max}$ (min)", fontsize=set_font_size)
ax1.set_ylabel(f"Test Acc on {'In-16'}", fontsize=set_font_size)
# ax1.legend(ncol=1, fontsize=set_lgend_size)
# ax2.legend(fontsize=set_lgend_size)
ax1.xaxis.label.set_size(set_tick_size)
ax1.yaxis.label.set_size(set_tick_size)
# ax1.set_xticks([])
ax2.xaxis.label.set_size(set_tick_size)
ax2.yaxis.label.set_size(set_tick_size)
ax1.yaxis.set_major_locator(MaxNLocator(nbins=4, integer=True))
ax1.axhline(max_value, color='r', linestyle='-', label='Global Best Accuracy')
tick_values = [0.01, 0.1, 1, 10, 100, 1000]
ax2.set_xticks(tick_values)
ax2.set_xticklabels([f'$10^{{{int(np.log10(val))}}}$' for val in tick_values])
# this is for unique hash
export_legend(
fig,
colnum=3,
unique_labels=['TE-NAS (Training-Free)', 'ENAS (Weight sharing)',
'KNAS (Training-Free)', 'DARTS-V1 (Weight sharing)', 'DARTS-V2 (Weight sharing)',
'Training-Based MS', 'Training-Free MS', '2Phase-MS', 'Global Best Accuracy'])
plt.tight_layout()
fig.savefig(f"{result_dir}/any_time_{name_img}_p1_from_0.1_sec.pdf", bbox_inches='tight')
def export_legend(ori_fig, filename="any_time_legend", colnum=9, unique_labels=[]):
fig2 = plt.figure(figsize=(5, 0.3))
lines_labels = [ax.get_legend_handles_labels() for ax in ori_fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
# grab unique labels
if len(unique_labels) == 0:
unique_labels = set(labels)
# assign labels and legends in dict
legend_dict = dict(zip(labels, lines))
# query dict based on unique labels
unique_lines = [legend_dict[x] for x in unique_labels]
fig2.legend(unique_lines, unique_labels, loc='center',
ncol=colnum,
fancybox=True,
shadow=True, scatterpoints=1, fontsize=set_lgend_size)
fig2.tight_layout()
fig2.savefig(f"{filename}.pdf", bbox_inches='tight')
def draw_anytime_result_one_graph_with_p1(y_acc_list_arr, x_T_list, y_acc_list_arr_p1, x_T_list_p1,
x_acc_train, y_acc_train_l, y_acc_train_m, y_acc_train_h,
annotations, lv,
name_img, dataset,
x1_lim=[], x2_lim=[],
):
# fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, dpi=100, gridspec_kw={'height_ratios': [5, 1]})
# plot simulate result of system
exp = np.array(y_acc_list_arr_p1) * 100
sys_acc_p1_h = np.quantile(exp, .75, axis=0)
sys_acc_p1_m = np.quantile(exp, .5, axis=0)
sys_acc_p1_l = np.quantile(exp, .25, axis=0)
plt.fill_between(x_T_list_p1, sys_acc_p1_l, sys_acc_p1_h, alpha=0.1)
plt.plot(x_T_list_p1, sys_acc_p1_m, "o-", label="TRAILS-P1")
# plt.fill_between(x_T_list_p1, sys_acc_p1_l, sys_acc_p1_h, alpha=0.1)
exp = np.array(y_acc_list_arr) * 100
sys_acc_h = np.quantile(exp, .75, axis=0)
sys_acc_m = np.quantile(exp, .5, axis=0)
sys_acc_l = np.quantile(exp, .25, axis=0)
# exp_time = np.array(real_time_used_arr)
# time_mean = np.quantile(exp_time, .5, axis=0)
time_mean = x_T_list
# plot simulate result of system
plt.fill_between(time_mean, sys_acc_l, sys_acc_h, alpha=0.1)
plt.plot(time_mean, sys_acc_m, "o-", label="TRAILS")
# plt.plot(time_mean, sys_acc_m, label="TRAILS")
# plot simulate result of train-based line
plt.fill_between(x_acc_train, y_acc_train_l, y_acc_train_h, alpha=0.3)
plt.plot(x_acc_train, y_acc_train_m, "o-", label="Training-based MS")
# plt.plot(x_acc_train, y_acc_train_m, label="Training-based MS")
if len(x1_lim) > 0:
plt.ylim(x1_lim[0], x1_lim[1]) # 子图1设置y轴范围,只显示部分图
d = .85 # 设置倾斜度
# 绘制断裂处的标记
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=set_marker_size,
linestyle='none', color='r', mec='r', mew=1, clip_on=False)
# plt.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
# plt.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
plt.tight_layout()
plt.xscale("symlog")
plt.grid()
plt.xlabel("Time Budget given by user (min)", fontsize=set_font_size)
plt.ylabel(f"Test accuracy on {dataset}", fontsize=set_font_size)
plt.legend(ncol=1, fontsize=set_lgend_size)
# plt.show()
plt.savefig(f"amy_time_{name_img}.pdf", bbox_inches='tight')
# for K, U N trade-off
def draw_grid_graph_with_budget(
acc, bt, b1, b2,
img_name: str, y_array: list, x_array: list):
"""
:param acc: Two array list
:param bt: Two array list
:param img_name: img name string
:return:
"""
acc_new = np.array(acc)
acc = acc_new.tolist()
mask = np.array(acc)
mask[mask > 0] = 0
mask[mask < 0] = 1
bt = np.round(np.array(bt), 2).tolist()
mask2 = np.array(bt)
mask2[mask2 > 0] = 0
mask2[mask2 < 0] = 1
mask3 = np.array(b1)
mask3[mask3 > 0] = 0
mask3[mask3 < 0] = 1
mask4 = np.array(b2)
mask4[mask4 > 0] = 0
mask4[mask4 < 0] = 1
fig, ax = plt.subplots(2, 2, figsize=(15, 14))
linewidths = 0.5
sns.set(font_scale=3)
sns.heatmap(
data=acc,
vmax=99,
vmin=93,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'bottom'},
mask=mask,
square=True, linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .5},
ax=ax[0, 0]
)
sns.heatmap(
data=bt,
# vmax=,
vmin=-9,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'top'},
mask=mask2,
square=True, linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .5},
ax=ax[0, 1]
)
sns.heatmap(
data=b1,
vmax=17000,
vmin=15000,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".0f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'top'},
mask=mask4,
square=True, linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .5},
ax=ax[1, 0]
)
sns.heatmap(
data=b2,
# vmax=,
# vmin=-9,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".0f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'top'},
mask=mask4,
square=True, linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .5},
ax=ax[1, 1]
)
plt.tight_layout()
plt.xlabel("U (epoch)", fontsize=set_font_size)
plt.ylabel("K (# models)", fontsize=set_font_size)
for i in [0, 1]:
for j in [0, 1]:
ax[i, j].set_xticklabels(x_array, fontsize=set_font_size)
ax[i, j].set_yticklabels(y_array, fontsize=set_font_size)
ax[i, j].set_xlabel("U (# epoch)", fontsize=set_font_size)
ax[i, j].set_ylabel("K (# models)", fontsize=set_font_size)
ax[0, 0].set_title('Test Accuracy (%)', fontsize=set_font_size)
ax[0, 1].set_title(r'Time Budget $T$ (min)', fontsize=set_font_size)
ax[1, 0].set_title(r'$N$', fontsize=set_font_size)
ax[1, 1].set_title(r"$K \cdot U \cdot \log_{\eta}K$", fontsize=set_font_size)
plt.tight_layout()
fig.subplots_adjust(wspace=0.001, hspace=0.3)
# plt.show()
base_dr = os.getcwd()
path_gra = os.path.join(base_dr, f"{img_name}.pdf")
fig.savefig(path_gra, bbox_inches='tight')
def draw_grid_graph_with_budget_only_Acc_and_T(
acc, bt, b1, b2,
img_name: str, y_array: list, x_array: list):
"""
:param acc: Two array list
:param bt: Two array list
:param img_name: img name string
:return:
"""
acc_new = np.array(acc)
acc = acc_new.tolist()
mask = np.array(acc)
mask[mask > 0] = 0
mask[mask < 0] = 1
bt = np.round(np.array(bt), 2).tolist()
mask2 = np.array(bt)
mask2[mask2 > 0] = 0
mask2[mask2 < 0] = 1
mask3 = np.array(b1)
mask3[mask3 > 0] = 0
mask3[mask3 < 0] = 1
mask4 = np.array(b2)
mask4[mask4 > 0] = 0
mask4[mask4 < 0] = 1
fig, ax = plt.subplots(1, 2, figsize=(15, 14))
linewidths = 0.5
sns.set(font_scale=2)
sns.heatmap(
data=acc,
vmax=99,
vmin=93,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'bottom'},
mask=mask,
square=True,
linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .4},
ax=ax[0]
)
sns.heatmap(
data=bt,
vmax=600,
# vmin=-9,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'top'},
mask=mask2,
square=True,
linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .4},
ax=ax[1]
)
plt.tight_layout()
plt.xlabel("U (epoch)", fontsize=set_font_size)
plt.ylabel("K (# models)", fontsize=set_font_size)
for j in [0, 1]:
ax[j].set_xticklabels(x_array, fontsize=set_font_size)
ax[j].set_yticklabels(y_array, fontsize=set_font_size)
ax[j].set_xlabel("U (# epoch)", fontsize=set_font_size)
ax[j].set_ylabel("K (# models)", fontsize=set_font_size)
ax[0].set_title('Test Accuracy (%)', fontsize=set_font_size)
ax[1].set_title(r'Time Budget $T$ (min)', fontsize=set_font_size)
plt.tight_layout()
fig.subplots_adjust(wspace=0.3, hspace=0.3)
# plt.show()
base_dr = os.getcwd()
path_gra = os.path.join(base_dr, f"{img_name}.pdf")
fig.savefig(path_gra, bbox_inches='tight')
def draw_grid_graph_with_budget_only_Acc(
acc, bt, b1, b2,
img_name: str, y_array: list, x_array: list):
"""
:param acc: Two array list
:param bt: Two array list
:param img_name: img name string
:return:
"""
acc_new = np.array(acc)
acc = acc_new.tolist()
mask = np.array(acc)
mask[mask > 0] = 0
mask[mask < 0] = 1
fig = plt.figure(figsize=(7, 14))
linewidths = 0.5
sns.set(font_scale=2)
sns.heatmap(
data=acc,
vmax=99,
vmin=93,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'bottom'},
mask=mask,
square=True,
linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .4},
ax=fig
)
plt.tight_layout()
plt.xlabel("U (epoch)", fontsize=set_font_size)
plt.ylabel("K (# models)", fontsize=set_font_size)
plt.xticks(x_array, fontsize=set_font_size)
plt.yticks(y_array, fontsize=set_font_size)
plt.title('Test Accuracy (%)', fontsize=set_font_size)
plt.tight_layout()
# fig.subplots_adjust(wspace=0.3, hspace=0.3)
# plt.show()
base_dr = os.getcwd()
path_gra = os.path.join(base_dr, f"{img_name}.pdf")
fig.savefig(path_gra, bbox_inches='tight')
def draw_grid_graph_with_budget_only_T(
acc, bt, b1, b2,
img_name: str, y_array: list, x_array: list):
"""
:param acc: Two array list
:param bt: Two array list
:param img_name: img name string
:return:
"""
acc_new = np.array(acc)
acc = acc_new.tolist()
mask = np.array(acc)
mask[mask > 0] = 0
mask[mask < 0] = 1
bt = np.round(np.array(bt), 2).tolist()
mask2 = np.array(bt)
mask2[mask2 > 0] = 0
mask2[mask2 < 0] = 1
mask3 = np.array(b1)
mask3[mask3 > 0] = 0
mask3[mask3 < 0] = 1
mask4 = np.array(b2)
mask4[mask4 > 0] = 0
mask4[mask4 < 0] = 1
fig, ax = plt.subplots(1, 2, figsize=(15, 14))
linewidths = 0.5
sns.set(font_scale=2)
sns.heatmap(
data=acc,
vmax=99,
vmin=93,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'bottom'},
mask=mask,
square=True,
linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .4},
ax=ax[0]
)
sns.heatmap(
data=bt,
vmax=600,
# vmin=-9,
cmap=palettable.cmocean.diverging.Curl_10.mpl_colors,
annot=True,
fmt=".2f",
annot_kws={'size': frontinsidebox, 'weight': 'normal', 'color': 'w', 'va': 'top'},
mask=mask2,
square=True,
linewidths=linewidths, # 每个方格外框显示,外框宽度设置
cbar_kws={"shrink": .4},
ax=ax[1]
)
plt.tight_layout()
plt.xlabel("U (epoch)", fontsize=set_font_size)
plt.ylabel("K (# models)", fontsize=set_font_size)
for j in [0, 1]:
ax[j].set_xticklabels(x_array, fontsize=set_font_size)
ax[j].set_yticklabels(y_array, fontsize=set_font_size)
ax[j].set_xlabel("U (# epoch)", fontsize=set_font_size)
ax[j].set_ylabel("K (# models)", fontsize=set_font_size)
ax[0].set_title('Test Accuracy (%)', fontsize=set_font_size)
ax[1].set_title(r'Time Budget $T$ (min)', fontsize=set_font_size)
plt.tight_layout()
fig.subplots_adjust(wspace=0.3, hspace=0.3)
# plt.show()
base_dr = os.getcwd()
path_gra = os.path.join(base_dr, f"{img_name}.pdf")
fig.savefig(path_gra, bbox_inches='tight')