ColormaNN manuscript figures - Simulations

Author

Saikat Banerjee

Published

December 12, 2024

Abstract
High quality plots used for simulation figures, using NYGC color palette.
Code
import numpy as np
import pandas as pd
import pickle
import sys
import os
import dsc
from dsc.query_engine import Query_Processor as dscQP
from dsc import dsc_io

import matplotlib
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
Code
# import matplotlib.font_manager as mpl_fm
# font_path = '/gpfs/commons/home/sbanerjee/nygc/Futura'
# mpl_fm.fontManager.addfont(font_path + '/FuturaStd-Book.otf') # Loads "Futura Std"

# mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 300)
# futura_book = FontProperties(fname='/gpfs/commons/home/sbanerjee/nygc/Futura/FuturaStd-Book.otf')

# NYGC Color Palette
nygc_colors = {
    'brown': '#7F0814',
    'darkred': '#d42e12',
    'orange': '#F37239',
    'darkyellow': '#F79320',
    'yellow': '#FFE438',
    'darkblue': '#003059',
    'blue': '#266DB6',
    'lightblue': '#A3D5ED',
    'darkgreen': '#006838',
    'green': '#0A8A42',
    'lightgreen': '#74B74A',
    'yellowgreen': '#BAD75F',
    'darkgray': '#1A1A1A',
    'gray': '#666666',
    'lightgray': '#CCCCCC',
    'khaki': '#ADA194',
    'darkkhaki': '#5E514D',
}

manuscript_colors = {
    'brown': '#7F180D',
    'darkred': '#C10020',
    'darkyellow': '#FF6800',
    # 'blue': '#2D69C4',
    'blue': '#00538A',
    'green': '#0A8A42',
    'lightgreen': '#74B74A',    
    'yellowgreen': '#93AA00',
    'lightblue': '#A6BDD7',
    'purple': '#803E75',
    'olive': '#232C16',
    'khaki': '#CEA262',
}

# Style sheet for manuscript
mpl_stylesheet.banskt_presentation(dpi = 300, fontsize = 22, 
    splinecolor = nygc_colors['darkgray'], black = nygc_colors['darkgray'])
# plt.rcParams['font.family'] = 'Futura Std'
Code
dsc_output = "/gpfs/commons/groups/knowles_lab/sbanerjee/low_rank_matrix_approximation_numerical_experiments/lrma"
dsc_fname  = os.path.basename(os.path.normpath(dsc_output))
db = os.path.join(dsc_output, dsc_fname + ".db")
dscoutpkl = os.path.join("/gpfs/commons/home/sbanerjee/work/npd/lrma-dsc/dsc/results", dsc_fname + "_dscout.pkl")
dscout    = pd.read_pickle(dscoutpkl)
dscout
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit mfmethods score.L_rmse score.F_rmse score.Z_rmse score.L_psnr score.F_psnr score.Z_psnr score.MI score.adj_MI
0 1 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca truncated_svd 0.006350 0.002961 0.493251 27.269818 26.260278 26.340408 0.009893 0.045381
1 2 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca truncated_svd 0.005409 0.002855 0.429779 28.954450 26.236996 26.026502 0.559206 0.791608
2 3 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca truncated_svd 0.006291 0.002851 0.475534 27.536409 25.673455 25.022555 0.012680 0.637053
3 4 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca truncated_svd 0.005608 0.002829 0.432801 28.154152 25.813502 25.538536 0.602172 0.662108
4 5 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca truncated_svd 0.005671 0.002974 0.457956 27.982251 25.952034 25.116839 0.503745 0.534922
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1705 8 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None factorgo 0.012889 0.004925 0.414627 21.462818 21.905958 26.098628 0.006982 0.017398
1706 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None factorgo 0.004534 0.002535 0.340617 30.259896 26.561325 28.425665 0.008989 0.012553
1707 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None factorgo 0.010747 0.004491 0.458848 23.653181 21.569757 25.348989 0.010147 0.002283
1708 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None factorgo 0.005998 0.002837 0.340220 28.284464 25.837819 27.857088 0.013312 0.034239
1709 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None factorgo 0.010672 0.004446 0.459545 24.121927 22.173044 25.245275 0.022239 0.008291

1710 rows × 19 columns

Code
dscout["mfmethods"].unique()
array(['truncated_svd', 'flashier', 'flashier_sparse', 'guide', 'gleanr',
       'factorgo'], dtype=object)
Code
def stratify_dfcol(df, colname, value):
    #return pd_utils.select_dfrows(df, [f"$({colname}) == {value}"])
    if value is None:
        return df.loc[df[colname].isnull()]
    else:
        return df.loc[df[colname] == value]

def stratify_dfcols(df, condition_list):
    for (colname, value) in condition_list:
        df = stratify_dfcol(df, colname, value)
    return df

def stratify_dfcols_in_list(df, colname, values):
    return df.loc[df[colname].isin(values)]

methods = {
    "rpca" : ["rpca", "truncated_svd"],
    "nnm"  : ["nnm", "truncated_svd"],
    "nnm_sparse" : ["nnm_sparse", "truncated_svd"],
    "truncated_svd" : [None, "truncated_svd"],
    "factorgo" : [None, "factorgo"],
    "flashier" : [None, "flashier"],
    "flashier_sparse" : [None, "flashier_sparse"],
    "guide" : [None, "guide"],
    "gleanr" : [None, "gleanr"],
}

method_labels = {
    "rpca" : "Clorinn (RPCA)",
    "nnm" : "Clorinn (NNM)",
    "nnm_sparse" : "Clorinn (NNM-Sparse)",
    "truncated_svd": "Truncated SVD",
    "factorgo": "FactorGO",
    "flashier" : "Flashier (Default)",
    "flashier_sparse" : "Flashier (Sparse)",
    "guide" : "GUIDE",
    "gleanr" : "GLEANR",
}

# method_colors = {
#     "rpca" : nygc_colors['brown'],
#     "nnm" : nygc_colors['darkred'],
#     "nnm_sparse" : nygc_colors['darkyellow'],
#     "truncated_svd" : nygc_colors['lightblue'],
#     "factorgo" : nygc_colors['khaki'],
#     "flashier" : nygc_colors['green'],
#     "flashier_sparse" : nygc_colors['lightgreen'],
#     "guide" : nygc_colors['blue'],
#     "gleanr" : nygc_colors['gray'],
# }

# color keys are: brown, darkred, darkyellow
# blue, green, lightgreen, yellowgreen,
# lightblue, purple, olive

method_colors = {
    "rpca" : manuscript_colors['brown'],
    "nnm" : manuscript_colors['darkred'],
    "nnm_sparse" : manuscript_colors['darkyellow'],
    "truncated_svd" : manuscript_colors['blue'],
    "factorgo" : manuscript_colors['lightblue'],
    "flashier" : manuscript_colors['green'],
    "flashier_sparse" : manuscript_colors['yellowgreen'],
    "guide" : manuscript_colors['purple'],
    "gleanr" : manuscript_colors['khaki'],
}

# Base parameters
simparams = {'p': 2000, 'k': 10, 'h2': 0.2, 'h2_shared_frac': 0.6, 'aq': 0.6}
score_names = {
    'adj_MI': "Adjusted MI Score",
    # 'MI': "Mutual Information Score",
    'L_rmse': r"$\| L - \hat{L}\|_F$",
    'F_rmse': r"$\| F - \hat{F}\|_F$",
    'Z_rmse': r"$\| LF^{T} - \hat{L}\hat{F}^{T}\|_F$",
}

def get_simulation_with_variable(df, var_name, var_values):
    condition = [(f'simulate.{k}', v) for k, v in simparams.items() if k != var_name]
    df1 = stratify_dfcols(df, condition)
    df2 = stratify_dfcols_in_list(df1, f'simulate.{var_name}', var_values)
    return df2

def get_scores_from_dataframe(df, score_name, variable_name, variable_values, 
        methods = methods):
    simdf = get_simulation_with_variable(df, variable_name, variable_values)
    scores = {key: list() for key in methods.keys()}
    for method, mlist in methods.items():
        mrows = stratify_dfcols(simdf, [('lowrankfit', mlist[0]), ('mfmethods', mlist[1])])
        for value in variable_values:
            vrows = stratify_dfcol(mrows, f'simulate.{variable_name}', value)
            scores[method].append(vrows[f'score.{score_name}'].to_numpy())
    return scores

def random_jitter(xvals, yvals, d = 0.1):
    xjitter = [x + np.random.randn(len(y)) * d for x, y in zip(xvals, yvals)]
    return xjitter

def boxplot_scores(variable, variable_values, 
        methods = methods, score_names = score_names,
        dscout = dscout, method_colors = method_colors,
        custom_font = 'Futura Std', xlabel = "No. of variants"):
    
    nmethods = len(methods)
    nvariables = len(variable_values)
    nscores = len(score_names)
    
    figh = 8
    figw = (nscores * figh) + (nscores - 1)
    fig = plt.figure(figsize = (figw, figh))
    axs = [fig.add_subplot(1, nscores, x+1) for x in range(nscores)]
    boxs = {x: None for x in methods.keys()}
    
    for i, (score_name, score_label) in enumerate(score_names.items()):
        scores = get_scores_from_dataframe(dscout, score_name, variable, variable_values)
        for j, mkey in enumerate(methods.keys()):
            boxcolor = method_colors[mkey]
            boxface = f'#{boxcolor[1:]}80'
            medianprops = dict(linewidth=0, color = boxcolor)
            whiskerprops = dict(linewidth=2, color = boxcolor)
            boxprops = dict(linewidth=2, color = boxcolor, facecolor = boxface)
            flierprops = dict(marker='o', markerfacecolor=boxface, markersize=3, markeredgecolor = boxcolor)

            xpos = [x * (nmethods + 1) + j for x in range(nvariables)]
            boxs[mkey] = axs[i].boxplot(scores[mkey], positions = xpos,
                showcaps = False, showfliers = False,
                widths = 0.7, patch_artist = True, notch = False,
                flierprops = flierprops, boxprops = boxprops,
                medianprops = medianprops, whiskerprops = whiskerprops)
            
            axs[i].scatter(random_jitter(xpos, scores[mkey]), scores[mkey], 
                           edgecolor = boxcolor, facecolor = boxface, linewidths = 1, 
                           s = 10)

        xcenter = [x * (nmethods + 1) + (nmethods - 1) / 2 for x in range(nvariables)]
        axs[i].set_xticks(xcenter)
        axs[i].set_xticklabels(variable_values)
        axs[i].set_xlabel(xlabel)
        axs[i].set_ylabel(score_label)
        xlim_low = 0 - (nvariables - 1) / 2
        #xlim_high = (nvariables - 1) * (nmethods + 1) + (nmethods - 1) + (nvariables - 1) / 2
        xlim_high = (nmethods + 1.5) * nvariables - 2.5
        axs[i].set_xlim( xlim_low, xlim_high )

    plt.tight_layout()
    return axs, boxs

variable = 'p'
variable_values = [500, 1000, 2000, 5000, 10000]

axs, boxs = boxplot_scores(variable, variable_values)

# plt.savefig('../plots/colormann/numerical_experiments.png', bbox_inches='tight')
plt.show()

Code
stratify_dfcols(dscout, [('lowrankfit', None), ('mfmethods', 'gleanr'), ('simulate.p', 5000)])
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit mfmethods score.L_rmse score.F_rmse score.Z_rmse score.L_psnr score.F_psnr score.Z_psnr score.MI score.adj_MI
752 1 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.013260 0.003533 0.713833 20.352956 21.582428 21.988351 0.030719 0.712974
756 2 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.016702 0.003912 NaN 18.611814 19.367315 NaN 0.380503 0.404146
760 3 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.013229 0.003659 NaN 21.282401 20.541539 NaN 0.927977 0.831983
764 4 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.014972 0.003710 NaN 20.748471 20.136834 NaN 0.017752 0.391726
768 5 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.015011 0.003938 NaN 20.604412 20.750895 NaN 0.729495 0.648164
772 6 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.014764 0.003709 NaN 20.613805 20.666924 NaN 0.015040 0.361918
776 7 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.013321 0.003471 NaN 21.218398 21.298047 NaN 0.022656 0.560445
780 8 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.012570 0.003270 0.694290 21.095071 20.863476 21.789366 0.011785 -0.000274
784 9 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.015797 0.003847 NaN 19.397441 19.917950 NaN 0.787489 0.675628
788 10 blockdiag_p 200.0 5000.0 10.0 0.2 0.6 0.6 (10000,40000) None gleanr 0.013716 0.003663 0.753525 22.007271 20.022504 21.116617 0.839608 0.746152
Code
stratify_dfcols(dscout, [('lowrankfit', 'nnm_sparse'), ('mfmethods', 'truncated_svd'), ('simulate.h2_shared_frac', 0.2)])
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit mfmethods score.L_rmse score.F_rmse score.Z_rmse score.L_psnr score.F_psnr score.Z_psnr score.MI score.adj_MI
470 1 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.009277 0.004790 0.726738 24.401062 22.246502 21.664802 0.593343 0.633928
474 2 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.009308 0.004785 0.755534 24.682219 21.377632 21.653191 0.017471 0.027549
478 3 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.008917 0.004764 0.762006 23.946771 21.958810 21.188775 0.011672 0.047220
482 4 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.009180 0.004735 0.733231 24.514610 21.161575 20.430464 0.517332 0.583975
486 5 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.008625 0.004616 0.716311 25.626320 21.871928 21.820822 0.017740 0.596206
490 6 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.009032 0.004634 0.758134 24.903905 21.509997 20.886909 0.013898 0.010379
494 7 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.010758 0.004864 0.783047 22.104810 21.177329 20.756933 0.031476 0.036359
498 8 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.008429 0.004589 0.682947 24.744418 21.120524 22.185621 0.939227 0.847907
502 9 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.009510 0.004722 0.757983 23.847346 21.314586 22.272501 0.023509 0.036652
506 10 blockdiag_h2shared 200.0 2000.0 10.0 0.2 0.2 0.6 (10000,40000) nnm_sparse truncated_svd 0.008877 0.004559 0.728476 26.051317 21.413803 21.618791 0.008403 0.673081
Code
score_names
{'adj_MI': 'Adjusted MI Score',
 'L_rmse': '$\\| L - \\hat{L}\\|_F$',
 'F_rmse': '$\\| F - \\hat{F}\\|_F$',
 'Z_rmse': '$\\| LF^{T} - \\hat{L}\\hat{F}^{T}\\|_F$'}
Code
simdata_filename = os.path.join(dsc_output, "blockdiag/blockdiag_5.pkl")
with open(simdata_filename, "rb") as fh:
    simdata = pickle.load(fh)

sys.path.append("../utils/")
import plot_functions as mpy_plotfn
Code
simdata['Ltrue'].shape
(200, 10)
Code
from matplotlib import colormaps as mpl_cmaps
import matplotlib.colors as mpl_colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

variables = ['p', 'k', 'h2']
variable_values = {
    'p' : [500, 1000, 2000, 5000, 10000],
    'k' : [2, 5, 10, 15, 20],
    'h2' : [0.05, 0.2, 0.1, 0.3, 0.4],
}

xlabel = {
    'p': "No. of variants",
    'k': "No. of factors",
    'h2': "Heritability",
}

nmethods = len(methods)
nscores = len(score_names)

fig = plt.figure(figsize = (18, 24))

boxs = {x: None for x in methods.keys()}

gs = fig.add_gridspec(6, 6, height_ratios=(0.7, 0.5, 1, 1, 1, 1), wspace=0, hspace=0)

ax0 = fig.add_subplot(gs[0,0:2])
ax_dummy = fig.add_subplot(gs[1,:])

ax_p  = [fig.add_subplot(gs[2, 0:2]), 
         fig.add_subplot(gs[3, 0:2]), 
         fig.add_subplot(gs[4, 0:2]),
         fig.add_subplot(gs[5, 0:2])]
ax_k  = [fig.add_subplot(gs[2, 2:4], sharey = ax_p[0]), 
         fig.add_subplot(gs[3, 2:4], sharey = ax_p[1]), 
         fig.add_subplot(gs[4, 2:4], sharey = ax_p[2]), 
         fig.add_subplot(gs[5, 2:4], sharey = ax_p[3])]
ax_h2 = [fig.add_subplot(gs[2, 4:], sharey = ax_k[0]), 
         fig.add_subplot(gs[3, 4:], sharey = ax_k[1]), 
         fig.add_subplot(gs[4, 4:], sharey = ax_k[2]), 
         fig.add_subplot(gs[5, 4:], sharey = ax_k[3])]


axs  = [ax_p, ax_k, ax_h2]

gs_inner = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=gs[0,0:6], 
                                            width_ratios = [0.7, 1, 1])
# ax0l = fig.add_subplot(gs_inner[0, 0])
ax0c = fig.add_subplot(gs_inner[0, 0])
ax0r = fig.add_subplot(gs_inner[0, 1:])


def make_plot_pca(ax, comp, labels, unique_labels, colorlist, bgcolor = "#F0F0F0"):
    pc1 = comp[:, 0]
    pc2 = comp[:, 1]
    for i, label in enumerate(unique_labels):
        idx = np.array([k for k, x in enumerate(labels) if x == label])
        ax.scatter(pc1[idx], pc2[idx], s = 100, alpha = 0.7, label = label, color = colorlist[i])
    ax.tick_params(bottom = False, top = False, left = False, right = False,
        labelbottom = False, labeltop = False, labelleft = False, labelright = False)
    ax.patch.set_facecolor(bgcolor)
    ax.patch.set_alpha(0.0)
    for side, border in ax.spines.items():
        border.set_visible(False)
    return


def make_plot_covariance_heatmap(ax, X, vmax = 1):
    cmap1 = mpl_cmaps.get_cmap("YlOrRd").copy()
    cmap1.set_bad("w")
    norm1 = mpl_colors.TwoSlopeNorm(vmin = 0., vcenter = vmax / 2., vmax = vmax)
    im1 = ax.imshow(X.T, cmap = cmap1, norm = norm1, interpolation='nearest', origin = 'upper')

    ax.tick_params(bottom = False, top = True, left = True, right = False,
                    labelbottom = False, labeltop = True, labelleft = True, labelright = False)
    ax.set_xticks([0,100,200])
    ax.set_yticks([0,100,200])

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="10%", pad=0.2)
    cbar = plt.colorbar(im1, cax=cax, fraction = 0.1)
    cax.set_ylabel("Correlation $r^2$")
    # shift the box to the left
    box = ax.get_position()
    box.x0 = box.x0 - 0.03
    box.x1 = box.x1 - 0.03
    box.y0 = box.y0 - 0.01
    box.y1 = box.y1 - 0.01
    ax.set_position(box)

colorlist = [nygc_colors[x] for x in ['orange', 'blue', 'yellowgreen']]
# make_plot_pca(ax0l, simdata['Ltrue'], simdata['Ctrue'], list(set(simdata['Ctrue'])), colorlist)
# ax0l.set_xlabel("Factor 1")
# ax0l.set_ylabel("Factor 2")
# ax0l.set_aspect(1.0)
ax0.text(-0.25, 1.0, "(a)", transform=ax0.transAxes, fontweight='bold')

nsample, p = simdata['Z'].shape
k = simdata['Ltrue'].shape[1]
make_plot_covariance_heatmap(ax0c, np.cov(simdata['Ltrue'] * np.sqrt(k)), vmax = 0.2)

for ax in [ax0, ax_dummy, ax0r]:
    ax.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
# for side, border in list(ax0.spines.items()) + \
#                     list(ax0l.spines.items()) + \
#                     list(ax0r.spines.items()) + \
#                     list(ax_dummy.spines.items()):
for side, border in list(ax0.spines.items()) + \
                    list(ax0r.spines.items()) + \
                    list(ax_dummy.spines.items()):
    border.set_visible(False)
for side, border in list(ax0.spines.items()):
    border.set_visible(False)
    
panel_label = ["(b)", "(c)", "(d)"]

for k, var in enumerate(variables):
    for i, (score_name, score_label) in enumerate(score_names.items()):
        nvariables = len(variable_values[var])
        scores = get_scores_from_dataframe(dscout, score_name, var, variable_values[var])
        for j, mkey in enumerate(methods.keys()):
            boxcolor = method_colors[mkey]
            boxface = f'#{boxcolor[1:]}80'
            medianprops = dict(linewidth=0, color = boxcolor)
            whiskerprops = dict(linewidth=2, color = boxcolor)
            boxprops = dict(linewidth=2, color = boxcolor, facecolor = boxface)
            flierprops = dict(marker='o', markerfacecolor=boxface, markersize=3, markeredgecolor = boxcolor)

            xpos = [x * (nmethods + 1) + j for x in range(nvariables)]
            boxs[mkey] = axs[k][i].boxplot(scores[mkey], positions = xpos,
                showcaps = False, showfliers = False,
                widths = 0.7, patch_artist = True, notch = False,
                flierprops = flierprops, boxprops = boxprops,
                medianprops = medianprops, whiskerprops = whiskerprops)

            # axs[k][i].scatter(random_jitter(xpos, scores[mkey]), scores[mkey], 
            #                edgecolor = boxcolor, facecolor = boxface, linewidths = 1, 
            #                s = 10)

        axs[k][i].tick_params(bottom = False, top = False, left = False, right = False,
               labelbottom = False, labeltop = False, labelleft = False, labelright = False)
        if i == 3:
            axs[k][i].tick_params(bottom = True, labelbottom = True)
            xcenter = [x * (nmethods + 1) + (nmethods - 1) / 2 for x in range(nvariables)]
            axs[k][i].set_xticks(xcenter)
            axs[k][i].set_xticklabels(variable_values[var])
            axs[k][i].set_xlabel(xlabel[var])
        if k == 0:
            axs[k][i].tick_params(left = True, labelleft = True)
            axs[k][i].set_ylabel(score_label)
        xlim_low = 0 - (nvariables - 1) / 2
        #xlim_high = (nvariables - 1) * (nmethods + 1) + (nmethods - 1) + (nvariables - 1) / 2
        xlim_high = (nmethods + 1.5) * nvariables - 2.5
        axs[k][i].set_xlim( xlim_low, xlim_high )
        # axs[k][i].text(-0.25, 1.0, panel_label[i], transform=axs[i].transAxes, fontweight='bold')
        
        # for side, border in list(axs[k][i].spines.items()):
        #     border.set_visible(False)
        axs[k][i].patch.set_alpha(0.0)
        if i == 0:
            axs[k][i].text(0.05, 1.1, panel_label[k], transform=axs[k][i].transAxes, fontweight='bold')
    
handles = [boxs[mkey]["boxes"][0] for mkey in methods.keys()]
labels = [method_labels[mkey] for mkey in methods.keys()]
ax0r.legend(handles = handles, labels = labels, 
           loc = 'upper right', frameon = False, handlelength = 2, ncol = 2)

# for ax in [ax0, ax0l, ax0r, ax_dummy]:
#     ax.patch.set_alpha(0.0)
    
# plt.tight_layout(h_pad=1.5, w_pad=1.5)
# gs.tight_layout(fig)
# plt.savefig('../plots/colormann-manuscript/sim_expt_variants_02.pdf', bbox_inches='tight')
plt.show()

Code
from matplotlib import colormaps as mpl_cmaps
import matplotlib.colors as mpl_colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

variables = ['p', 'h2']
variable_values = {
    'p' : [500, 1000, 2000, 5000, 10000],
    'k' : [2, 5, 10, 15, 20],
    'h2' : [0.05, 0.2, 0.1, 0.3, 0.4],
    'h2_shared_frac' : [0.2, 0.4, 0.6, 0.8, 1.0],
}

xlabel = {
    'p': "No. of variants",
    'k': "No. of factors",
    'h2': "Heritability",
    'h2_shared_frac': "Fraction of shared heritability",
}

score_names = {
    'adj_MI': "Adjusted MI Score",
    # 'MI': "Mutual Information Score",
    'L_rmse': r"$\| L - \hat{L}\|_F$",
    'F_rmse': r"$\| F - \hat{F}\|_F$",
    'Z_rmse': r"$\| LF^{T} - \hat{L}\hat{F}^{T}\|_F$",
}

nmethods = len(methods)
nscores = len(score_names)

fig = plt.figure(figsize = (18, 24))

boxs = {x: None for x in methods.keys()}

gs = fig.add_gridspec(6, 6, height_ratios=(0.7, 0.5, 1, 1, 1, 1), wspace=0, hspace=0)

ax0_dummy = fig.add_subplot(gs[0, :])
ax_dummy = fig.add_subplot(gs[1,:])

ax0 = gs[0, :].subgridspec(1, 3, width_ratios=[1, 1, 0.5], wspace=0.05)
ax0l = fig.add_subplot(ax0[0, 0:2])
ax0r = fig.add_subplot(ax0[0, 2])

ax_p  = [fig.add_subplot(gs[2, 0:3]), 
         fig.add_subplot(gs[3, 0:3]), 
         fig.add_subplot(gs[4, 0:3]),
         fig.add_subplot(gs[5, 0:3])]

ax_h2 = [fig.add_subplot(gs[2, 3:], sharey = ax_p[0]), 
         fig.add_subplot(gs[3, 3:], sharey = ax_p[1]), 
         fig.add_subplot(gs[4, 3:], sharey = ax_p[2]), 
         fig.add_subplot(gs[5, 3:], sharey = ax_p[3])]

axs  = [ax_p, ax_h2]

def make_plot_pca(ax, comp, labels, unique_labels, colorlist, bgcolor = "#F0F0F0"):
    pc1 = comp[:, 0]
    pc2 = comp[:, 1]
    for i, label in enumerate(unique_labels):
        idx = np.array([k for k, x in enumerate(labels) if x == label])
        ax.scatter(pc1[idx], pc2[idx], s = 100, alpha = 0.7, label = label, color = colorlist[i])
    ax.tick_params(bottom = False, top = False, left = False, right = False,
        labelbottom = False, labeltop = False, labelleft = False, labelright = False)
    ax.patch.set_facecolor(bgcolor)
    ax.patch.set_alpha(0.0)
    for side, border in ax.spines.items():
        border.set_visible(False)
    return


def make_plot_covariance_heatmap(ax, X, vmax = 1):
    cmap1 = mpl_cmaps.get_cmap("YlOrRd").copy()
    cmap1.set_bad("w")
    norm1 = mpl_colors.TwoSlopeNorm(vmin = 0., vcenter = vmax / 2., vmax = vmax)
    im1 = ax.imshow(X.T, cmap = cmap1, norm = norm1, interpolation='nearest', origin = 'lower')

    ax.tick_params(bottom = True, top = False, left = True, right = False,
                    labelbottom = True, labeltop = False, labelleft = True, labelright = False)
    ax.set_xticks([0,100,200])
    ax.set_yticks([0,100,200])

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="10%", pad=0.2)
    cbar = plt.colorbar(im1, cax=cax, fraction = 0.1)
    cax.set_ylabel("Correlation $r^2$")
    # shift the box to the left
    # box = ax.get_position()
    # box.x0 = box.x0 - 0.03
    # box.x1 = box.x1 - 0.03
    # box.y0 = box.y0 - 0.01
    # box.y1 = box.y1 - 0.01
    # ax.set_position(box)

colorlist = [nygc_colors[x] for x in ['orange', 'blue', 'yellowgreen']]

nsample, p = simdata['Z'].shape
k = simdata['Ltrue'].shape[1]
Ltrue_cov = np.cov(simdata['Ltrue']) * np.sqrt(np.prod(simdata['Ltrue'].shape))
Ltrue_cov = np.cov(simdata['Ltrue']) * np.sqrt(simdata['Ltrue'].shape[0])
make_plot_covariance_heatmap(ax0r, Ltrue_cov, vmax = 0.2)

for ax in [ax_dummy, ax0l, ax0_dummy]:
    ax.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)

for side, border in list(ax0_dummy.spines.items()) + \
                    list(ax0l.spines.items()) + \
                    list(ax_dummy.spines.items()):
    border.set_visible(False)
    
# for side, border in list(ax0_dummy.spines.items()) + \
#                     list(ax0l.spines.items()) + \
#                     list(ax0r.spines.items()):
#     border.set_visible(True)
    
    
panel_label = ["(b)", "(c)"]
gap = 2

for k, var in enumerate(variables):
    for i, (score_name, score_label) in enumerate(score_names.items()):
        nvariables = len(variable_values[var])
        scores = get_scores_from_dataframe(dscout, score_name, var, variable_values[var])
        for j, mkey in enumerate(methods.keys()):
            boxcolor = method_colors[mkey]
            boxface = f'#{boxcolor[1:]}80'
            medianprops = dict(linewidth=0, color = boxcolor)
            whiskerprops = dict(linewidth=2, color = boxcolor)
            boxprops = dict(linewidth=2, color = boxcolor, facecolor = boxface)
            flierprops = dict(marker='o', markerfacecolor=boxface, markersize=3, markeredgecolor = boxcolor)

            xpos = [x * (nmethods + gap) + j for x in range(nvariables)]
            boxs[mkey] = axs[k][i].boxplot(scores[mkey], positions = xpos,
                showcaps = False, showfliers = False,
                widths = 0.7, patch_artist = True, notch = False,
                flierprops = flierprops, boxprops = boxprops,
                medianprops = medianprops, whiskerprops = whiskerprops)

            # axs[k][i].scatter(random_jitter(xpos, scores[mkey]), scores[mkey], 
            #                edgecolor = boxcolor, facecolor = boxface, linewidths = 1, 
            #                s = 10)

        axs[k][i].tick_params(bottom = False, top = False, left = False, right = False,
               labelbottom = False, labeltop = False, labelleft = False, labelright = False)
        if i == 3:
            axs[k][i].tick_params(bottom = True, labelbottom = True)
            xcenter = [x * (nmethods + gap) + (nmethods - 1) / 2 for x in range(nvariables)]
            axs[k][i].set_xticks(xcenter)
            axs[k][i].set_xticklabels(variable_values[var])
            axs[k][i].set_xlabel(xlabel[var])
        if k == 0:
            axs[k][i].tick_params(left = True, labelleft = True)
            axs[k][i].set_ylabel(score_label)
        xlim_low = 0 - (nvariables - 1) / 2
        xlim_high = (nvariables - 1) * (nmethods + gap) + (nmethods - 1) + (nvariables - 1) / 2
        # xlim_high = (nmethods + 1.5) * nvariables - 2.5
        axs[k][i].set_xlim( xlim_low, xlim_high )
        # axs[k][i].text(-0.25, 1.0, panel_label[i], transform=axs[i].transAxes, fontweight='bold')
        
        # for side, border in list(axs[k][i].spines.items()):
        #     border.set_visible(False)
        axs[k][i].patch.set_alpha(0.0)
        if i == 0:
            axs[k][i].text(0, 1.1, panel_label[k], transform=axs[k][i].transAxes, fontweight='bold')

        # ---- Group shading + separators ----
        for j in range(nvariables):
            left  = j * (nmethods + gap) - 0.5
            right = left + nmethods
            

            # alternating faint background blocks
            if j % 2 != 0:
                axs[k][i].axvspan(left - gap/2, right + gap/2, color="k", alpha=0.02, zorder=0)

            # dashed separator between blocks
            if j < (nvariables - 1):
                axs[k][i].axvline(right + gap/2, ls="--", lw=1, alpha=0.4, color="k", zorder=0)

handles = [boxs[mkey]["boxes"][0] for mkey in methods.keys()]
labels = [method_labels[mkey] for mkey in methods.keys()]
ax0l.legend(handles = handles, labels = labels, 
           loc = 'upper left', frameon = False, handlelength = 2, ncol = 2)

for ax in [ax0_dummy, ax0l, ax0r, ax_dummy]:
    ax.patch.set_alpha(0.0)
    
# ax0.text(-0.25, 1.0, "(a)", transform=ax0.transAxes, fontweight='bold')
ax0r.text(-0.8, 0.95, "(a)", transform=ax0r.transAxes, fontweight='bold')
    
# plt.tight_layout(h_pad=1.5, w_pad=1.5)
# gs.tight_layout(fig)
plt.savefig('../plots/colormann-manuscript/sim_expt_variants_02.pdf', bbox_inches='tight')
plt.show()

Code
np.cov(simdata['Ltrue']) * np.sqrt(np.prod(simdata['Ltrue'].shape))
array([[ 0.34143554, -0.02347945,  0.21674217, ...,  0.25737923,
         0.22172796,  0.21296168],
       [-0.02347945,  0.20900267,  0.14567344, ..., -0.05553266,
        -0.08736217,  0.07033388],
       [ 0.21674217,  0.14567344,  0.62601644, ...,  0.0948575 ,
         0.02418137,  0.26304167],
       ...,
       [ 0.25737923, -0.05553266,  0.0948575 , ...,  0.39770126,
         0.36152045,  0.21066828],
       [ 0.22172796, -0.08736217,  0.02418137, ...,  0.36152045,
         0.37908465,  0.2054572 ],
       [ 0.21296168,  0.07033388,  0.26304167, ...,  0.21066828,
         0.2054572 ,  0.522065  ]])
Code
simdata['Ltrue'].shape
(200, 10)