from matplotlib import colormaps as mpl_cmaps
import matplotlib.colors as mpl_colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
# ---- Experiment configuration ---------------------------------------------
# Swept simulation parameters that get a column of score panels below.
variables = ['p', 'h2']
# Values of each swept parameter, in the order they appear on the x-axis
# (and the order in which scores are requested from the results table).
variable_values = {
    'p': [500, 1000, 2000, 5000, 10000],
    'k': [2, 5, 10, 15, 20],
    # Kept ascending like every other sweep so the x-axis is monotone.
    'h2': [0.05, 0.1, 0.2, 0.3, 0.4],
    'h2_shared_frac': [0.2, 0.4, 0.6, 0.8, 1.0],
}
# x-axis label for each swept parameter.
xlabel = {
    'p': "No. of variants",
    'k': "No. of factors",
    'h2': "Heritability",
    'h2_shared_frac': "Fraction of shared heritability",
}
# Score column name -> y-axis label; one row of panels per entry.
score_names = {
    'adj_MI': "Adjusted MI Score",
    # 'MI': "Mutual Information Score",
    'L_rmse': r"$\| L - \hat{L}\|_F$",
    'F_rmse': r"$\| F - \hat{F}\|_F$",
    'Z_rmse': r"$\| LF^{T} - \hat{L}\hat{F}^{T}\|_F$",
}
# `methods` is defined elsewhere; presumably a dict keyed by method id (TODO confirm).
nmethods = len(methods)
nscores = len(score_names)

fig = plt.figure(figsize=(18, 24))
# Last boxplot artists drawn for each method; reused as legend handles.
boxs = {x: None for x in methods.keys()}

# Layout: row 0 = legend + heatmap (panel a), row 1 = spacer,
# then one row per score with two columns of panels (one per variable).
gs = fig.add_gridspec(
    2 + nscores, 6,
    height_ratios=(0.7, 0.5) + (1,) * nscores,
    wspace=0, hspace=0,
)
ax0_dummy = fig.add_subplot(gs[0, :])
ax_dummy = fig.add_subplot(gs[1, :])
# Split row 0 into a wide legend area (ax0l) and the heatmap (ax0r).
ax0 = gs[0, :].subgridspec(1, 3, width_ratios=[1, 1, 0.5], wspace=0.05)
ax0l = fig.add_subplot(ax0[0, 0:2])
ax0r = fig.add_subplot(ax0[0, 2])

# Left column: variable 'p'; right column: 'h2', sharing y with its row.
ax_p = [fig.add_subplot(gs[2 + row, 0:3]) for row in range(nscores)]
ax_h2 = [fig.add_subplot(gs[2 + row, 3:], sharey=ax_p[row]) for row in range(nscores)]
# axs[variable index][score index], matching the order of `variables`.
axs = [ax_p, ax_h2]
def make_plot_pca(ax, comp, labels, unique_labels, colorlist, bgcolor="#F0F0F0"):
    """Scatter the first two columns of ``comp``, one colour per label group.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into. Its ticks, tick labels and spines are hidden.
    comp : 2-D array
        Point coordinates. Column 0 is plotted on x and column 1 on y
        (presumably PCA scores; TODO confirm with callers).
    labels : sequence
        Group label of each row of ``comp``.
    unique_labels : sequence
        Groups to draw, in order; group ``i`` is drawn with ``colorlist[i]``.
        A group with no matching rows is drawn as an empty series
        instead of raising.
    colorlist : sequence
        Colours, at least ``len(unique_labels)`` of them.
    bgcolor : str
        Assigned as the axes face colour, but the patch alpha is then set
        to 0, so it does not show.
    """
    pc1 = comp[:, 0]
    pc2 = comp[:, 1]
    for i, label in enumerate(unique_labels):
        # Use a boolean mask rather than a list of indices: an empty index
        # list becomes a float64 array, which NumPy rejects as an index.
        mask = np.array([x == label for x in labels], dtype=bool)
        ax.scatter(pc1[mask], pc2[mask], s=100, alpha=0.7, label=label, color=colorlist[i])
    ax.tick_params(bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    ax.patch.set_facecolor(bgcolor)
    ax.patch.set_alpha(0.0)
    for border in ax.spines.values():
        border.set_visible(False)
def make_plot_covariance_heatmap(ax, X, vmax=1, ticks=(0, 100, 200),
                                 cbar_label="Correlation $r^2$"):
    """Draw a matrix as a YlOrRd heatmap on ``ax`` with a colorbar on its right.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into. A colorbar axes is carved off its right side.
    X : 2-D array
        Matrix to display. It is transposed and drawn with ``origin='lower'``,
        so ``X[i, j]`` appears at x=i, y=j.
    vmax : float
        Top of the colour scale, which runs from 0 to ``vmax``. Must be > 0.
    ticks : sequence of float
        Tick positions used on both axes. The default keeps the previous
        hard-coded ``[0, 100, 200]``.
    cbar_label : str
        Label of the colorbar.

    Returns
    -------
    (AxesImage, Colorbar)
        The image and its colorbar, for further customisation.
    """
    # Indexing the registry returns a copy, so set_bad does not modify the
    # globally registered colormap.
    cmap = mpl_cmaps["YlOrRd"]
    cmap.set_bad("w")  # NaN / masked cells are drawn white
    # With vcenter at the midpoint this maps [0, vmax] linearly onto the colormap.
    norm = mpl_colors.TwoSlopeNorm(vmin=0., vcenter=vmax / 2., vmax=vmax)
    im = ax.imshow(X.T, cmap=cmap, norm=norm, interpolation='nearest', origin='lower')
    ax.tick_params(bottom=True, top=False, left=True, right=False,
                   labelbottom=True, labeltop=False, labelleft=True, labelright=False)
    ax.set_xticks(list(ticks))
    ax.set_yticks(list(ticks))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="10%", pad=0.2)
    # Use the axes' own figure rather than pyplot's current figure. The
    # `fraction` argument is not passed because it is ignored when `cax` is given.
    cbar = ax.figure.colorbar(im, cax=cax)
    cax.set_ylabel(cbar_label)
    return im, cbar
# ---- Panel (a): true-loading covariance heatmap + legend area --------------
# Palette for PCA scatter plots (`nygc_colors` is defined elsewhere).
colorlist = [nygc_colors[x] for x in ['orange', 'blue', 'yellowgreen']]
# Dimensions of the simulated data. These stay as module-level names in case
# later cells read them; `k` is presumably the number of true factors.
nsample, p = simdata['Z'].shape
k = simdata['Ltrue'].shape[1]
# np.cov treats each ROW of Ltrue as a variable, so this is an
# (n_rows x n_rows) matrix, scaled by sqrt(n_rows). It was previously
# computed twice, and the first result was immediately overwritten.
# NOTE(review): the colorbar says "Correlation r^2", but this is a scaled
# covariance; confirm the intended quantity.
Ltrue_cov = np.cov(simdata['Ltrue']) * np.sqrt(simdata['Ltrue'].shape[0])
make_plot_covariance_heatmap(ax0r, Ltrue_cov, vmax=0.2)

# Layout-only axes (the legend is drawn into ax0l later): hide ticks, tick
# labels and frames.
for ax in [ax_dummy, ax0l, ax0_dummy]:
    ax.tick_params(bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    for border in ax.spines.values():
        border.set_visible(False)
# ---- Panels (b) and (c): score box plots -----------------------------------
# One column of panels per swept variable and one row per score. Within each
# panel the methods' boxes sit side by side for every value of the variable,
# and neighbouring groups are separated by `gap` empty slots.
panel_label = ["(b)", "(c)"]
gap = 2
# The bottom row of panels carries the x ticks and labels. Previously this was
# hard-coded as `i == 3`, which breaks when the number of scores changes.
last_row = nscores - 1
# Loop over `ivar` rather than `k`, so the module-level `k` (read from
# simdata) is not clobbered.
for ivar, var in enumerate(variables):
    var_values = variable_values[var]
    nvariables = len(var_values)
    # The geometry depends only on the variable, so compute it once per column.
    xcenter = [x * (nmethods + gap) + (nmethods - 1) / 2 for x in range(nvariables)]
    xlim_low = 0 - (nvariables - 1) / 2
    xlim_high = (nvariables - 1) * (nmethods + gap) + (nmethods - 1) + (nvariables - 1) / 2
    for i, (score_name, score_label) in enumerate(score_names.items()):
        panel_ax = axs[ivar][i]
        # Indexed by method key below; presumably one sample list per value (TODO confirm).
        scores = get_scores_from_dataframe(dscout, score_name, var, var_values)
        for j, mkey in enumerate(methods.keys()):
            boxcolor = method_colors[mkey]
            # Same colour with alpha 0x80 for the fill; assumes '#RRGGBB' input.
            boxface = f'#{boxcolor[1:]}80'
            medianprops = dict(linewidth=0, color=boxcolor)  # median line hidden
            whiskerprops = dict(linewidth=2, color=boxcolor)
            boxprops = dict(linewidth=2, color=boxcolor, facecolor=boxface)
            flierprops = dict(marker='o', markerfacecolor=boxface, markersize=3,
                              markeredgecolor=boxcolor)
            # Method j occupies slot j inside every group.
            xpos = [x * (nmethods + gap) + j for x in range(nvariables)]
            boxs[mkey] = panel_ax.boxplot(scores[mkey], positions=xpos,
                                          showcaps=False, showfliers=False,
                                          widths=0.7, patch_artist=True, notch=False,
                                          flierprops=flierprops, boxprops=boxprops,
                                          medianprops=medianprops, whiskerprops=whiskerprops)
        panel_ax.tick_params(bottom=False, top=False, left=False, right=False,
                             labelbottom=False, labeltop=False, labelleft=False, labelright=False)
        if i == last_row:
            panel_ax.tick_params(bottom=True, labelbottom=True)
            panel_ax.set_xticks(xcenter)
            panel_ax.set_xticklabels(var_values)
            panel_ax.set_xlabel(xlabel[var])
        # Only the first column shows y labels; the other column shares its y-axis.
        if ivar == 0:
            panel_ax.tick_params(left=True, labelleft=True)
            panel_ax.set_ylabel(score_label)
        panel_ax.set_xlim(xlim_low, xlim_high)
        panel_ax.patch.set_alpha(0.0)
        if i == 0:
            panel_ax.text(0, 1.1, panel_label[ivar], transform=panel_ax.transAxes,
                          fontweight='bold')
        # ---- Group shading + separators ----
        for j in range(nvariables):
            left = j * (nmethods + gap) - 0.5
            right = left + nmethods
            # Faint background on alternating (odd) groups.
            if j % 2 != 0:
                panel_ax.axvspan(left - gap / 2, right + gap / 2, color="k", alpha=0.02, zorder=0)
            # Dashed separator halfway through the gap after each group but the last.
            if j < (nvariables - 1):
                panel_ax.axvline(right + gap / 2, ls="--", lw=1, alpha=0.4, color="k", zorder=0)
# ---- Legend, panel (a) label, and output ------------------------------------
# One legend entry per method. The handle is the first box of the last
# boxplot drawn for that method.
handles, labels = [], []
for mkey in methods.keys():
    handles.append(boxs[mkey]["boxes"][0])
    labels.append(method_labels[mkey])
ax0l.legend(
    handles=handles,
    labels=labels,
    loc='upper left',
    frameon=False,
    handlelength=2,
    ncol=2,
)

# Give every layout axes in the top rows a transparent background.
for ax in (ax0_dummy, ax0l, ax0r, ax_dummy):
    ax.patch.set_alpha(0.0)

# Panel (a) label, placed in ax0r's axes coordinates, to the left of the heatmap.
ax0r.text(-0.8, 0.95, "(a)", transform=ax0r.transAxes, fontweight='bold')

plt.savefig('../plots/colormann-manuscript/sim_expt_variants_02.pdf', bbox_inches='tight')
plt.show()