def plot_heatmap(ax, X, rank_list, k_list,
                 vmin=0, vcenter=0.9, vmax=1.0):
    """Draw ``X`` as a heatmap on ``ax`` with a colorbar attached to its right.

    Parameters
    ----------
    ax
        Matplotlib axes that receives the image, labels and ticks.
    X
        2-D array handed to ``imshow``.  Presumably shaped
        ``(len(k_list), len(rank_list))`` -- this is not checked here.
    rank_list
        Values used for the x tick labels (one per column), rendered as
        integer strings.
    k_list
        Values used for the y tick labels (one per row), rendered as
        integer strings.
    vmin, vcenter, vmax
        Anchor points of the ``TwoSlopeNorm`` color normalization.

    Returns
    -------
    None
    """
    # Copy before mutating so the "bad" color is set on a private colormap.
    cmap = mpl_cmaps.get_cmap("YlOrRd").copy()
    cmap.set_bad("w")  # masked / invalid cells are drawn in white
    norm = mpl_colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    # origin="lower" puts the first row of X at the bottom of the axes.
    image = ax.imshow(X, cmap=cmap, norm=norm, origin="lower")

    # Carve a narrow axes off the right edge of ``ax`` for the colorbar.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    # Use the figure that owns ``ax`` rather than pyplot's current figure, so
    # the colorbar always lands on the right figure.  ``fraction`` is not
    # passed because it has no effect once an explicit ``cax`` is supplied.
    ax.figure.colorbar(image, cax=cax)

    ax.set_xlabel("Nuclear norm constraint radius r")
    ax.set_ylabel("Rank k")
    ax.set_xticks(np.arange(len(rank_list)))
    ax.set_xticklabels([str(int(r)) for r in rank_list], rotation=90)
    ax.set_yticks(np.arange(len(k_list)))
    ax.set_yticklabels([str(int(k)) for k in k_list])
def make_projection_distance_figure(prefix,
                                    k_choose=24, vmin=0.09, vcenter=0.12, vmax=0.2):
    """Render and save the two-panel projection-distance figure for ``prefix``.

    The left panel is a heatmap of the ``"mean_dist"`` grid returned by
    ``build_by_k_metrics_grid`` (index used as the rank values k, columns as
    the radius values r).  The right panel plots the single grid row
    ``k_choose`` as a line.  The figure is written to
    ``fig_dir / f"{prefix}_projection_distance.png"`` and then closed.

    Reads the module-level names ``stability_out_dir`` and ``fig_dir``.

    Parameters
    ----------
    prefix
        Passed to ``load_stability_jsons`` to select the records and used as
        the stem of the output file name.
    k_choose
        Index label of the grid row shown in the line plot.  Must be an
        integer (it is formatted with ``:d`` in the axis label).
    vmin, vcenter, vmax
        Color-scale anchors forwarded to ``plot_heatmap``.

    Raises
    ------
    KeyError
        If ``k_choose`` is not a row label of the grid.

    Returns
    -------
    None
    """
    records = load_stability_jsons(stability_out_dir, prefix)
    dist_df = build_by_k_metrics_grid(records, "mean_dist", duplicate="first")
    rank_list = dist_df.columns.to_numpy()
    k_list = dist_df.index.to_numpy()

    # Resolve the requested row before any figure exists, so a bad k_choose
    # fails early with a message that says what is available.
    try:
        dist_vals = dist_df.loc[k_choose].to_numpy()
    except KeyError:
        raise KeyError(
            f"k_choose={k_choose!r} is not a row of the {prefix!r} grid; "
            f"available k values: {k_list.tolist()}"
        ) from None

    fig = plt.figure(figsize=(15, 9), constrained_layout=True)
    try:
        gs = fig.add_gridspec(
            nrows=1, ncols=2,
            width_ratios=[1, 0.8],  # heatmap, line plot
            wspace=0.2,
        )
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])

        plot_heatmap(ax1, dist_df.to_numpy(), rank_list, k_list,
                     vmin=vmin, vcenter=vcenter, vmax=vmax)
        ax1.set_title("Mean projection distance across splits", pad=20)

        # Plot a single row of the heatmap as a line plot.
        # Makes it easier to see the plateau onset for one chosen rank.
        positions = np.arange(len(dist_vals))
        ax2.plot(positions, dist_vals, 'o-')
        ax2.set_xlabel("Nuclear norm constraint radius r")
        ax2.set_ylabel(f"Mean projection distance across splits at k={k_choose:d}")
        ax2.set_xticks(positions)
        # Same integer formatting as the heatmap panel, so both x axes carry
        # identical labels.
        ax2.set_xticklabels([str(int(r)) for r in rank_list], rotation=90)
        ax2.set_title(f"Fixed-rank slice at k={k_choose}", pad=20)

        fig.savefig(
            fig_dir / f"{prefix}_projection_distance.png",
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even if plotting or saving failed;
        # otherwise pyplot keeps it alive across loop iterations.
        plt.close(fig)
# Produce one projection-distance figure per prefix.  The
# "pgd_afw_nnm_corr" prefix uses its own slice rank and color scale;
# every other prefix uses the common settings.
for prefix in PREFIXES:
    if prefix == "pgd_afw_nnm_corr":
        k_choose, (vmin, vcenter, vmax) = 20, (0.1, 0.2, 0.4)
    else:
        k_choose, (vmin, vcenter, vmax) = 24, (0.09, 0.12, 0.20)
    make_projection_distance_figure(
        prefix,
        k_choose=k_choose,
        vmin=vmin,
        vcenter=vcenter,
        vmax=vmax,
    )