Open · qm-intel opened this issue 5 months ago
@HanxunH Is the code for Figure 1.b available, so that the data can be visualized? Thanks!
def plot_detph_grid_search_by_targets_full(exp_raw_dict, graph_file_name=None):
fig = plt.figure(figsize=(16, 10), dpi=300, facecolor='w', edgecolor='k')
plt.rcParams.update({
"text.usetex": False,
"font.family": "sans-serif",
"font.sans-serif": ["Helvetica"]})
pgd_scores_n = []
model_parameters_n = []
pgd_scores_h = []
model_parameters_h = []
h_keys = []
pgd_scores_b = []
model_parameters_b = []
b_keys = []
pgd_scores_low = []
model_parameters_low = []
low_keys = []
for target in exp_raw_dict.keys():
if 'best' in target:
continue
score = float(exp_raw_dict[target]['PGD_20'])
size = float(exp_raw_dict[target]['model_parameters'])
if target == 'd5_5_5':
pgd_scores_b.append(score)
model_parameters_b.append(size)
key = '$^{d}$'+target[1]+'-'+target[3]+'-'+target[5]
b_keys.append(key)
elif score >= 52.5:
pgd_scores_h.append(score)
model_parameters_h.append(size)
key = '$^{d}$'+target[1]+'-'+target[3]+'-'+target[5]
h_keys.append(key)
elif score <= 50:
pgd_scores_low.append(score)
model_parameters_low.append(size)
key = '$^{d}$'+target[1]+'-'+target[3]+'-'+target[5]
low_keys.append(key)
else:
pgd_scores_n.append(score)
model_parameters_n.append(size)
plt.scatter(model_parameters_h, pgd_scores_h, s=100, alpha=0.8, marker='v', color='tab:blue', label='$>=$ 52.5')
plt.scatter(model_parameters_low, pgd_scores_low, s=100, alpha=0.8, marker='x', color='tab:orange', label='$<=$ 50.0')
plt.scatter(model_parameters_b, pgd_scores_b, s=300, alpha=0.8, marker='*', color='tab:red', label='WRN-34-10')
plt.scatter(model_parameters_n, pgd_scores_n, s=100, alpha=0.8, color='tab:green', label='Others')
for i, h_key in enumerate(h_keys):
plt.annotate(h_key, (model_parameters_h[i]-0.55, pgd_scores_h[i]-0.25), fontsize=18)
for i, low_key in enumerate(low_keys):
plt.annotate(low_key, (model_parameters_low[i]-0.55, pgd_scores_low[i]-0.25), fontsize=18)
for i, b_key in enumerate(b_keys):
plt.annotate(b_key, (model_parameters_b[i]-0.55, pgd_scores_b[i]-0.25), fontsize=18)
plt.ylim(47, 55)
plt.ylabel('PGD$^{20}$ Robustness', fontsize=30)
plt.xlabel('Parameters (M)', fontsize=30)
plt.legend()
plt.rcParams.update({'font.size': 30})
if graph_file_name is None:
plt.show()
else:
plt.savefig(graph_file_name + '.pdf', bbox_inches='tight')
@HanxunH Is the code for Figure 1.b available, so that the data can be visualized? Thanks!