CloriNN manuscript figures - Simulations

Author

Saikat Banerjee

Published

February 12, 2026

Abstract
High-quality plots for the simulation figures, drawn with the manuscript color palette.
Code
import numpy as np
import pandas as pd
import pickle
import sys
import os
import dsc
from dsc.query_engine import Query_Processor as dscQP
from dsc import dsc_io

import matplotlib
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils

from matplotlib import colormaps as mpl_cmaps
import matplotlib.colors as mpl_colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
Code
# Optional: a Futura font can be registered with matplotlib.font_manager
# (fontManager.addfont on the FuturaStd-Book.otf file) and selected with
# plt.rcParams['font.family'] = 'Futura Std' before plotting.

# Manuscript colour palette (hex strings) shared by every figure below.
manuscript_colors = dict(
    brown = '#7F180D',
    darkred = '#C10020',
    darkyellow = '#FF6800',
    blue = '#00538A',
    green = '#0A8A42',
    lightgreen = '#74B74A',
    yellowgreen = '#93AA00',
    lightblue = '#A6BDD7',
    purple = '#803E75',
    olive = '#232C16',
    khaki = '#CEA262',
    darkgray = '#1A1A1A',
    orange = '#F37239',
)

# Apply the shared presentation style at print resolution; spines and
# "black" elements are drawn in the palette's dark gray.
mpl_stylesheet.banskt_presentation(
    dpi = 300,
    fontsize = 22,
    splinecolor = manuscript_colors['darkgray'],
    black = manuscript_colors['darkgray'],
)
Code
# Paths to the DSC benchmark output. The DSC database and the pickled results
# table are both named after the output directory.
dsc_output = "/gpfs/commons/groups/knowles_lab/sbanerjee/low_rank_matrix_approximation_numerical_experiments/lrma"
dsc_fname = os.path.basename(os.path.normpath(dsc_output))
db = os.path.join(dsc_output, f"{dsc_fname}.db")
dscoutpkl = os.path.join(
    "/gpfs/commons/home/sbanerjee/work/npd/lrma-dsc/dsc/results",
    f"{dsc_fname}_dscout.pkl",
)
# One row per DSC pipeline run (displayed below).
dscout = pd.read_pickle(dscoutpkl)
dscout
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit ... score.MI_raw score.adj_MI_raw score.MI_bal score.adj_MI_bal score.MI_cal score.adj_MI_cal score.MI_bal_cal score.adj_MI_bal_cal score.MI_oracle_aligned score.adj_MI_oracle_aligned
0 1 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca ... 0.009893 0.045381 0.009893 0.045381 0.009893 0.045381 0.009893 0.045381 0.009893 0.058497
1 2 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca ... 0.559206 0.791608 0.559206 0.791608 0.559206 0.791608 0.559206 0.791608 0.559206 0.801725
2 3 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca ... 0.012680 0.637053 0.012680 0.637053 0.012680 0.637053 0.012680 0.637053 0.009966 0.637053
3 4 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca ... 0.602172 0.662108 0.602172 0.662108 0.602172 0.662108 0.602172 0.662108 0.593425 0.834183
4 5 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) rpca ... 0.503745 0.534922 0.503745 0.534922 0.503745 0.534922 0.503745 0.534922 0.503745 0.543721
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1705 8 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.006982 0.017398 1.006754 0.949173 0.006982 0.017398 1.006754 0.949173 0.981545 0.933310
1706 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None ... 0.008989 0.012553 0.010364 0.223061 0.008989 0.012553 0.010364 0.223061 0.010364 0.001357
1707 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.010147 0.002283 1.012892 0.927433 0.010147 0.002283 1.012892 0.927433 1.012892 0.927433
1708 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None ... 0.013312 0.034239 0.008805 0.339672 0.013312 0.034239 0.008805 0.339672 0.025912 0.022569
1709 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.022239 0.008291 1.043526 0.914972 0.022239 0.008291 1.043526 0.914972 1.043526 0.921261

1710 rows × 50 columns

Code
# List every column of the results table: simulation settings (simulate.*),
# method columns (lowrankfit, mfmethods) and scores (score.*).
dscout.columns
Index(['DSC', 'simulate', 'simulate.n', 'simulate.p', 'simulate.k',
       'simulate.h2', 'simulate.h2_shared_frac', 'simulate.aq',
       'simulate.nsample_minmax', 'lowrankfit', 'mfmethods',
       'score.global_scale', 'score.balancing_impact', 'score.L_rmse_raw',
       'score.F_rmse_raw', 'score.Z_rmse_raw', 'score.L_rmse_bal',
       'score.F_rmse_bal', 'score.Z_rmse_bal', 'score.L_rmse_cal',
       'score.F_rmse_cal', 'score.Z_rmse_cal', 'score.L_rmse', 'score.F_rmse',
       'score.Z_rmse', 'score.L_rel_rmse_raw', 'score.F_rel_rmse_raw',
       'score.Z_rel_rmse_raw', 'score.L_rel_rmse_bal', 'score.F_rel_rmse_bal',
       'score.Z_rel_rmse_bal', 'score.L_rel_rmse_cal', 'score.F_rel_rmse_cal',
       'score.Z_rel_rmse_cal', 'score.L_rel_rmse', 'score.F_rel_rmse',
       'score.Z_rel_rmse', 'score.L_psnr', 'score.F_psnr', 'score.Z_psnr',
       'score.MI_raw', 'score.adj_MI_raw', 'score.MI_bal', 'score.adj_MI_bal',
       'score.MI_cal', 'score.adj_MI_cal', 'score.MI_bal_cal',
       'score.adj_MI_bal_cal', 'score.MI_oracle_aligned',
       'score.adj_MI_oracle_aligned'],
      dtype='object')
Code
# Matrix-factorization methods present in the results table.
dscout["mfmethods"].unique()
array(['truncated_svd', 'flashier', 'flashier_sparse', 'guide', 'gleanr',
       'factorgo'], dtype=object)
Code
# Load one simulated dataset from the 'blockdiag' simulations (presumably
# replicate 2); it provides 'Ltrue', 'Z' and 'Ctrue' for the figure header panels.
# NOTE(review): pickle.load can execute code embedded in the file; only load
# pickles produced by this pipeline.
simdata_filename = os.path.join(dsc_output, "blockdiag/blockdiag_2.pkl")
with open(simdata_filename, "rb") as fh:
    simdata = pickle.load(fh)

# Shape of the simulated L matrix (rows x factors).
simdata['Ltrue'].shape
(200, 10)
Code
# Each method key maps to [lowrankfit, mfmethods]: the values matched against
# the 'lowrankfit' and 'mfmethods' columns of dscout. A None lowrankfit selects
# rows where that column is null (presumably: no low-rank preprocessing step).
# Dict order is the order in which methods are drawn and listed in legends.
methods = {
    "rpca" : ["rpca", "truncated_svd"],
    "nnm"  : ["nnm", "truncated_svd"],
    "nnm_sparse" : ["nnm_sparse", "truncated_svd"],
    "truncated_svd" : [None, "truncated_svd"],
    "factorgo" : [None, "factorgo"],
    "flashier" : [None, "flashier"],
    "flashier_sparse" : [None, "flashier_sparse"],
    "guide" : [None, "guide"],
    "gleanr" : [None, "gleanr"],
}

# Legend labels, keyed like `methods`.
method_labels = {
    "rpca" : "Clorinn (RPCA)",
    "nnm" : "Clorinn (NNM)",
    "nnm_sparse" : "Clorinn (NNM-Sparse)",
    "truncated_svd": "Truncated SVD",
    "factorgo": "FactorGO",
    "flashier" : "Flashier (Default)",
    "flashier_sparse" : "Flashier (Sparse)",
    "guide" : "GUIDE",
    "gleanr" : "GLEANR",
}

# Box / marker colours from the manuscript palette, keyed like `methods`.
method_colors = {
    "rpca" : manuscript_colors['brown'],
    "nnm" : manuscript_colors['darkred'],
    "nnm_sparse" : manuscript_colors['darkyellow'],
    "truncated_svd" : manuscript_colors['blue'],
    "factorgo" : manuscript_colors['lightblue'],
    "flashier" : manuscript_colors['green'],
    "flashier_sparse" : manuscript_colors['yellowgreen'],
    "guide" : manuscript_colors['purple'],
    "gleanr" : manuscript_colors['khaki'],
}

# Axis labels for each metric, keyed by the score column name without the
# "score." prefix. "global_scale_log10" is a derived column created later in
# the notebook from score.global_scale.
metric_labels = {
  "global_scale_log10": "Global calibration scalar (log10)",
  "balancing_impact": "Balancing impact (log10 scale)",
  "L_rmse_raw":     r"RMSE ($\hat{L}$) - align",
  "F_rmse_raw":     r"RMSE ($\hat{F}$) - align",
  "Z_rmse_raw":     r"RMSE ($\hat{L}\hat{F}^{T}$) - align",
  "L_rmse_bal":     r"RMSE ($\hat{L}$) - balance + align",
  "F_rmse_bal":     r"RMSE ($\hat{F}$) - balance + align",
  "Z_rmse_bal":     r"RMSE ($\hat{L}\hat{F}^{T}$) - balance + align",
  "L_rmse_cal":     r"RMSE ($\hat{L}$) - calibrate + align",
  "F_rmse_cal":     r"RMSE ($\hat{F}$) - calibrate + align",
  "Z_rmse_cal":     r"RMSE ($\hat{L}\hat{F}^{T}$) - calibrate + align",
  "L_rmse":         r"RMSE ($\hat{L}$) - balance + calibrate + align",
  "F_rmse":         r"RMSE ($\hat{F}$) - balance + calibrate + align",
  "Z_rmse":         r"RMSE ($\hat{L}\hat{F}^{T}$) - balance + calibrate + align",
  "L_rel_rmse_raw": r"RelErr ($\hat{L}$) - align",
  "F_rel_rmse_raw": r"RelErr ($\hat{F}$) - align",
  "Z_rel_rmse_raw": r"RelErr ($\hat{L}\hat{F}^{T}$) - align",
  "L_rel_rmse_bal": r"RelErr ($\hat{L}$) - balance + align",
  "F_rel_rmse_bal": r"RelErr ($\hat{F}$) - balance + align",
  "Z_rel_rmse_bal": r"RelErr ($\hat{L}\hat{F}^{T}$) - balance + align",
  "L_rel_rmse_cal": r"RelErr ($\hat{L}$) - calibrate + align",
  "F_rel_rmse_cal": r"RelErr ($\hat{F}$) - calibrate + align",
  "Z_rel_rmse_cal": r"RelErr ($\hat{L}\hat{F}^{T}$) - calibrate + align",
  "L_rel_rmse":     r"RelErr ($\hat{L}$) - balance + calibrate + align",
  "F_rel_rmse":     r"RelErr ($\hat{F}$) - balance + calibrate + align",
  "Z_rel_rmse":     r"RelErr ($\hat{L}\hat{F}^{T}$) - balance + calibrate + align",
  "MI_raw":                "MI - native",
  "adj_MI_raw":            "Adj. MI - native",
  "MI_bal":                "MI - balance",
  "adj_MI_bal":            "Adj. MI - balance",
  "MI_cal":                "MI - calibrate",
  "adj_MI_cal":            "Adj. MI - calibrate",
  "MI_bal_cal":                "MI - balance + calibrate",
  "adj_MI_bal_cal":            "Adj. MI - balance + calibrate",
  "MI_oracle_aligned":     "MI - balance + calibrate + align",
  "adj_MI_oracle_aligned": "Adj. MI - balance + calibrate + align",
}
Code
def stratify_dfcol(df, colname, value):
    """
    Select the rows of `df` whose `colname` entry matches `value`.

    A `value` of None matches null entries (via `isnull`); any other value is
    compared with `==`.
    """
    selector = df[colname].isnull() if value is None else df[colname] == value
    return df.loc[selector]

def stratify_dfcols(df, condition_list):
    """
    Apply `stratify_dfcol` once per (column, value) pair, narrowing `df` each time.

    Returns the rows that satisfy every condition in `condition_list`.
    """
    filtered = df
    for column, wanted in condition_list:
        filtered = stratify_dfcol(filtered, column, wanted)
    return filtered

def stratify_dfcols_in_list(df, colname, values):
    """Select the rows of `df` whose `colname` entry is one of `values`."""
    membership = df[colname].isin(values)
    return df.loc[membership]

def get_simulation_with_variable(df, var_name, var_values):
    """
    Select the simulations in which only `var_name` varies.

    Every other parameter listed in the global `simparams` is fixed to its base
    value (matched on the corresponding 'simulate.<name>' column), and
    'simulate.<var_name>' is restricted to `var_values`.
    """
    fixed_conditions = [
        (f'simulate.{name}', base_value)
        for name, base_value in simparams.items()
        if name != var_name
    ]
    at_base = stratify_dfcols(df, fixed_conditions)
    return stratify_dfcols_in_list(at_base, f'simulate.{var_name}', var_values)

def get_scores_from_dataframe(df, score_name, variable_name, variable_values, methods = methods):
    """
    Collect per-method score arrays for a one-parameter sweep.

    Parameters
    ----------
    df : pandas.DataFrame
        DSC results table.
    score_name : str
        Metric name; read from the column 'score.<score_name>'.
    variable_name : str
        Simulation parameter being varied (column 'simulate.<variable_name>').
    variable_values : list
        Values of the varied parameter, in plotting order.
    methods : dict
        method key -> [lowrankfit, mfmethods] selector.

    Returns
    -------
    dict
        method key -> list of ndarrays, one per entry of `variable_values`.
    """
    value_column = f'simulate.{variable_name}'
    score_column = f'score.{score_name}'
    sweep = get_simulation_with_variable(df, variable_name, variable_values)
    collected = {}
    for method, selector in methods.items():
        rows = stratify_dfcols(sweep, [('lowrankfit', selector[0]), ('mfmethods', selector[1])])
        collected[method] = [
            stratify_dfcol(rows, value_column, level)[score_column].to_numpy()
            for level in variable_values
        ]
    return collected

def get_scores_all_simulations(df, score_name, methods = methods):
    """
    Collect one pooled score array per method across every simulation in `df`.

    Returns a dict method key -> single-element list holding the ndarray of
    'score.<score_name>' for that method's rows (the list wrapper matches the
    layout of get_scores_from_dataframe).
    """
    score_column = f'score.{score_name}'
    return {
        method: [
            stratify_dfcols(df, [('lowrankfit', selector[0]), ('mfmethods', selector[1])])[score_column].to_numpy()
        ]
        for method, selector in methods.items()
    }

def random_jitter(xvals, yvals, d = 0.1):
    """
    Return jittered x-positions for scatter-plotting `yvals`.

    Parameters
    ----------
    xvals : scalar, or 1-D array-like
        Base x-position(s). A scalar is used for every y-value. For a 2-D
        `yvals` or a list of arrays, pass one x-position per row / per array.
    yvals : ndarray (ndim 1 or 2) or list of 1-D arrays
        Values to be plotted; only their shapes/lengths are used.
    d : float
        Standard deviation of the Gaussian jitter.

    Returns
    -------
    ndarray with the shape of `yvals` (ndarray input), or a list of 1-D
    ndarrays, one per element of `yvals` (list input).

    Raises
    ------
    ValueError
        If `yvals` is an ndarray that is neither 1-D nor 2-D.
    """
    if isinstance(yvals, np.ndarray):
        base = np.asarray(xvals, dtype=float)
        if yvals.ndim == 1:
            # Scalars and length-matching 1-D xvals both broadcast here.
            return base + np.random.randn(*yvals.shape) * d
        if yvals.ndim == 2:
            if base.ndim == 1:
                base = base[:, np.newaxis]   # one x-position per row
            return base + np.random.randn(*yvals.shape) * d
        raise ValueError(f"yvals must be 1-D or 2-D, got ndim={yvals.ndim}")

    # List (or other sequence) of 1-D arrays: one x-position per array.
    base = np.asarray(xvals, dtype=float)
    if base.ndim == 0:
        base = np.full(len(yvals), float(base))
    return [x + np.random.randn(len(y)) * d for x, y in zip(base, yvals)]

def remove_iqr_outliers(arr, k=1.5, label="score", debug = False):
    """
    Remove non-finite values and outliers using Tukey's IQR rule.

    Keeps finite values within [Q1 - k*IQR, Q3 + k*IQR], where the quartiles
    are computed over the finite entries only.

    Parameters
    ----------
    arr : array-like
        Values to filter; converted to a float ndarray.
    k : float
        Fence multiplier.
    label : str
        Name shown in the debug printout.
    debug : bool
        If True, print a summary of the removed values.

    Returns
    -------
    ndarray
        Kept values in their original order; empty when `arr` has no finite
        values.
    """
    arr = np.asarray(arr, dtype=float)

    finite = np.isfinite(arr)
    if not finite.any():
        # np.percentile cannot compute quartiles of an empty array.
        return arr[finite]
    arr_finite = arr[finite]

    q1 = np.percentile(arr_finite, 25)
    q3 = np.percentile(arr_finite, 75)
    iqr = q3 - q1

    lower = q1 - k * iqr
    upper = q3 + k * iqr

    keep = finite & (arr >= lower) & (arr <= upper)
    bad = ~keep

    if bad.any() and debug:
        print(f"\n[{label}] removed {bad.sum()} / {arr.size} values")
        print(f"  Q1={q1:.6g}, Q3={q3:.6g}, IQR={iqr:.6g}")
        print(f"  allowed range: [{lower:.6g}, {upper:.6g}]")
        print("  bad indices:", np.argwhere(bad).ravel()[:20])
        print("  bad values:", arr[bad][:20])

    return arr[keep]

def should_use_broken_axis(arrays_by_method, gap_factor=100.0):
    """
    Decide whether a broken y-axis is needed because some methods' scores sit
    far from the rest.

    Each method is summarised by the nanmedian of its scores. A method is
    flagged when its median lies more than `gap_factor` robust standard
    deviations from the median of all method medians. The robust scale is
    1.4826 * MAD; if the MAD is zero (e.g. most medians are identical), it
    falls back to 1.2533 * mean absolute deviation. If the scale is still zero
    or not finite, no method is flagged.

    Parameters
    ----------
    arrays_by_method : dict
        method key -> array-like of scores.
    gap_factor : float
        Robust-z threshold above which a method is flagged.

    Returns
    -------
    (bool, list)
        Whether any method was flagged, and the flagged method keys in dict
        order.
    """
    method_names = list(arrays_by_method.keys())
    if not method_names:
        return False, []

    medians = np.asarray([np.nanmedian(arrays_by_method[m]) for m in method_names], dtype=float)
    finite = np.isfinite(medians)
    if not finite.any():
        return False, []

    center = np.median(medians[finite])
    abs_dev = np.abs(medians - center)
    scale = 1.4826 * np.median(abs_dev[finite])
    if scale == 0:
        # A zero MAD would turn any non-zero deviation into an infinite z-score.
        scale = 1.2533 * np.mean(abs_dev[finite])
    if not np.isfinite(scale) or scale == 0:
        return False, []

    mask = finite & (abs_dev / scale > gap_factor)
    pathological = [m for m, flagged in zip(method_names, mask) if flagged]
    return bool(pathological), pathological
Code
def style_metric_axis(ax):
    """Hide all tick marks and tick labels on every side of `ax`."""
    sides = ("bottom", "top", "left", "right")
    settings = {side: False for side in sides}
    settings.update({f"label{side}": False for side in sides})
    ax.tick_params(**settings)
    

def draw_box_scatter(ax, arraylist_by_method, boxs, nvariables, gap, metric_name, show_scatter=True):
    """
    Draw grouped box plots: one group per value of the varied parameter, one
    box per method inside each group.

    Parameters
    ----------
    ax : matplotlib Axes
    arraylist_by_method : dict
        method key -> list of `nvariables` score arrays.
    boxs : dict
        Updated in place: method key -> the dict returned by `ax.boxplot`.
    nvariables : int
        Number of groups.
    gap : int
        Empty slots between consecutive groups.
    metric_name : str
        Not used here; kept so existing calls still work.
    show_scatter : bool
        Also draw jittered points on top of the boxes.
    """
    group_stride = len(methods) + gap
    for slot, mkey in enumerate(methods):
        edge = method_colors[mkey]
        face = '#' + edge[1:] + '80'   # same hex colour with alpha 0x80
        positions = [group * group_stride + slot for group in range(nvariables)]
        data = arraylist_by_method[mkey]
        boxs[mkey] = ax.boxplot(
            data,
            positions = positions,
            showcaps = False, showfliers = False,
            widths = 0.7, patch_artist = True, notch = False,
            flierprops = dict(marker='o', markerfacecolor=face, markersize=3, markeredgecolor = edge),
            boxprops = dict(linewidth=2, color = edge, facecolor = face),
            medianprops = dict(linewidth=0, color = edge),
            whiskerprops = dict(linewidth=2, color = edge))
        if not show_scatter:
            continue
        ax.scatter(
            np.concatenate(random_jitter(positions, data)),
            np.concatenate(data),
            edgecolor = edge, facecolor = face, linewidths = 1, s = 5)

def make_plot_pca(ax, comp, labels, unique_labels, colorlist, bgcolor = "#F0F0F0"):
    """
    Scatter the first two columns of `comp`, one colour per label group.

    Parameters
    ----------
    ax : matplotlib Axes
    comp : 2-D array; columns 0 and 1 are used as x and y.
    labels : sequence with one label per row of `comp`.
    unique_labels : labels to draw, in order; the i-th uses colorlist[i].
    colorlist : list of colours.
    bgcolor : axes face colour (the patch is then made fully transparent).

    Ticks, tick labels and spines are hidden.
    """
    xcoord, ycoord = comp[:, 0], comp[:, 1]
    for color_idx, group in enumerate(unique_labels):
        members = np.array([row for row, lab in enumerate(labels) if lab == group])
        ax.scatter(xcoord[members], ycoord[members], s = 100, alpha = 0.7,
                   label = group, color = colorlist[color_idx])
    ax.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
    ax.patch.set_facecolor(bgcolor)
    ax.patch.set_alpha(0.0)
    for spine in ax.spines.values():
        spine.set_visible(False)


def make_plot_covariance_heatmap(ax, X, vmax = 1, tick_step = 100, cbar_label = "Correlation $r^2$"):
    """
    Draw a 2-D matrix as a heatmap (YlOrRd, 0..vmax) with a colorbar on the right.

    Parameters
    ----------
    ax : matplotlib Axes
    X : 2-D array
        Plotted as X.T with the origin at the lower left, so X's rows run along
        the x-axis and X's columns along the y-axis.
    vmax : float
        Upper end of the colour scale; the lower end is 0.
    tick_step : int
        Spacing of tick marks, from 0 up to the matrix size on each axis. The
        default gives ticks [0, 100, 200] for a 200 x 200 matrix, matching the
        previously hard-coded ticks.
    cbar_label : str
        Label of the colorbar.
    """
    cmap1 = mpl_cmaps.get_cmap("YlOrRd").copy()
    cmap1.set_bad("w")   # masked / NaN cells are drawn white
    norm1 = mpl_colors.TwoSlopeNorm(vmin = 0., vcenter = vmax / 2., vmax = vmax)
    im1 = ax.imshow(X.T, cmap = cmap1, norm = norm1, interpolation='nearest', origin = 'lower')

    ax.tick_params(bottom = True, top = False, left = True, right = False,
                    labelbottom = True, labeltop = False, labelleft = True, labelright = False)
    nrows, ncols = np.shape(X)
    ax.set_xticks(np.arange(0, nrows + 1, tick_step))
    ax.set_yticks(np.arange(0, ncols + 1, tick_step))

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="10%", pad=0.2)
    plt.colorbar(im1, cax=cax, fraction = 0.1)
    cax.set_ylabel(cbar_label)
    
def remove_top_k_values(array_list, k):
    """
    Drop the k largest finite values, pooled over all arrays, from each array.

    The cutoff is the k-th largest finite value of the concatenated arrays.
    Every value >= cutoff is removed (so ties with the cutoff go too), as are
    NaN and +inf entries. NaNs no longer poison the cutoff.

    Parameters
    ----------
    array_list : list of 1-D array-like
    k : int
        Number of largest values to remove. k <= 0 returns the inputs
        unchanged (as arrays); k >= the number of finite values empties them.

    Returns
    -------
    list of ndarray
        One filtered array per input, original element order preserved.
    """
    arrays = [np.asarray(a) for a in array_list]
    if k <= 0:
        return arrays
    pooled = np.concatenate(arrays) if arrays else np.array([])
    finite_sorted = np.sort(pooled[np.isfinite(pooled)])
    if k >= finite_sorted.size:
        return [a[:0] for a in arrays]
    cutoff = finite_sorted[-k]
    return [a[a < cutoff] for a in arrays]

def sanitize_scores_manually(scores_by_method, metric_name, remove = ("gleanr", 3)):
    """
    Drop extreme values for one method on the loading-error (L_*) metrics.

    For method `remove[0]`, and only when `metric_name` is one of
    L_rmse_raw / L_rel_rmse_raw / L_rmse_cal / L_rel_rmse_cal, the list of
    score arrays is passed through remove_top_k_values with k = `remove[1]`.
    All other entries are returned as-is (same objects) in a new dict.
    """
    affected_metrics = ("L_rmse_raw", "L_rel_rmse_raw", "L_rmse_cal", "L_rel_rmse_cal")
    return {
        method: (
            remove_top_k_values(scores, remove[1])
            if method == remove[0] and metric_name in affected_metrics
            else scores
        )
        for method, scores in scores_by_method.items()
    }

# ---------------------------------------------------------------------------
# Figure: benchmark scores while varying the number of variants (p, left
# column of panels) and the heritability (h2, right column of panels).
# ---------------------------------------------------------------------------

# Base parameters: every simulation subset keeps all of these fixed except the
# one being varied (get_simulation_with_variable reads this global).
simparams = {'p': 2000, 'k': 10, 'h2': 0.2, 'h2_shared_frac': 0.6, 'aq': 0.6}
variables = ['p', 'h2']
variable_values = {
    'p' : [500, 1000, 2000, 5000, 10000],
    'k' : [2, 5, 10, 15, 20],
    'h2' : [0.05, 0.1, 0.2, 0.3, 0.4],
    'h2_shared_frac' : [0.2, 0.4, 0.6, 0.8, 1.0],
}

xlabel = {
    'p': "No. of variants",
    'k': "No. of factors",
    'h2': "Heritability",
    'h2_shared_frac': "Fraction of shared heritability",
}

# One row of panels per metric, top to bottom.
metric_labels_plot = {
  "adj_MI_raw":      "Adjusted MI",
  "L_rel_rmse_cal": r"rRMSE ($\hat{L}$)",
  "F_rel_rmse_cal": r"rRMSE ($\hat{F}$)",
  "Z_rel_rmse_cal": r"rRMSE ($\hat{L}\hat{F}^{T}$)",
}

# Manual outlier removal per varied parameter, passed as `remove` to
# sanitize_scores_manually: (method, number of largest values to drop).
# Parameters not listed here are plotted without sanitization.
manual_outlier_removal = {
    'p': ("gleanr", 3),
    'h2': ("gleanr", 3),
}

nmethods = len(methods)
nscores = len(metric_labels_plot)
fig = plt.figure(figsize = (24, 24))

boxs = {x: None for x in methods.keys()}
# Rows: header (legend + heatmap), spacer, then one row per metric.
gs = fig.add_gridspec(6, 6, height_ratios=(1, 0.2, 1, 1, 1, 1), wspace=0, hspace=0)

ax0_dummy = fig.add_subplot(gs[0, :])
ax_dummy = fig.add_subplot(gs[1,:])

ax0 = gs[0, :].subgridspec(1, 2, width_ratios=[1.2, 0.8], wspace=0.05)
ax0l = fig.add_subplot(ax0[0, 0])   # legend
ax0r = fig.add_subplot(ax0[0, 1])   # r^2 heatmap of the simulated loadings

ax_p  = [fig.add_subplot(gs[2, 0:3]), 
         fig.add_subplot(gs[3, 0:3]), 
         fig.add_subplot(gs[4, 0:3]),
         fig.add_subplot(gs[5, 0:3])]

# Right column shares each row's y-axis with the left column.
ax_h2 = [fig.add_subplot(gs[2, 3:], sharey = ax_p[0]), 
         fig.add_subplot(gs[3, 3:], sharey = ax_p[1]), 
         fig.add_subplot(gs[4, 3:], sharey = ax_p[2]), 
         fig.add_subplot(gs[5, 3:], sharey = ax_p[3])]

axs  = [ax_p, ax_h2]

colorlist = [manuscript_colors[x] for x in ['orange', 'blue', 'yellowgreen']]

nsample, p = simdata['Z'].shape
k = simdata['Ltrue'].shape[1]
Ltrue_cov = np.cov(simdata['Ltrue']) * np.sqrt(simdata['Ltrue'].shape[0])
# Squared pairwise correlations between the rows of Ltrue.
Ltrue_r2 = np.corrcoef(simdata['Ltrue']) ** 2
make_plot_covariance_heatmap(ax0r, Ltrue_r2, vmax = 0.5)

for ax in [ax_dummy, ax0l, ax0_dummy]:
    style_metric_axis(ax)

for side, border in list(ax0_dummy.spines.items()) + \
                    list(ax0l.spines.items()) + \
                    list(ax_dummy.spines.items()):
    border.set_visible(False)


panel_label = ["(b)", "(c)"]
gap = 2

# `col` (not `k`) indexes the panel column so the factor count `k` is kept.
for col, var in enumerate(variables):
    nvariables = len(variable_values[var])
    removal = manual_outlier_removal.get(var)
    # Groups are (nmethods + gap) slots wide; ticks sit at each group's centre.
    xcenter = [x * (nmethods + gap) + (nmethods - 1) / 2 for x in range(nvariables)]
    xlim_low = 0 - (nvariables - 1) / 2
    xlim_high = (nvariables - 1) * (nmethods + gap) + (nmethods - 1) + (nvariables - 1) / 2

    for i, (metric_name, metric_label) in enumerate(metric_labels_plot.items()):
        panel = axs[col][i]
        scores = get_scores_from_dataframe(dscout, metric_name, var, variable_values[var])
        if removal is None:
            clean_scores = scores
        else:
            clean_scores = sanitize_scores_manually(scores, metric_name, remove = removal)

        draw_box_scatter(panel, clean_scores, boxs, nvariables, gap, metric_name, show_scatter=False)
        style_metric_axis(panel)

        if i == nscores - 1:
            # Only the bottom row carries x tick labels.
            panel.tick_params(bottom = True, labelbottom = True)
            panel.set_xticks(xcenter)
            panel.set_xticklabels(variable_values[var])
            panel.set_xlabel(xlabel[var])
        if col == 0:
            # y-axes are shared per row, so only the left column is labelled.
            panel.tick_params(left = True, labelleft = True)
            panel.set_ylabel(metric_label)

        panel.set_xlim( xlim_low, xlim_high )

        panel.patch.set_alpha(0.0)
        if i == 0:
            panel.text(0, 1.1, panel_label[col], transform=panel.transAxes, fontweight='bold')

        # ---- Group shading + separators ----
        for j in range(nvariables):
            left  = j * (nmethods + gap) - 0.5
            right = left + nmethods

            # alternating faint background blocks
            if j % 2 != 0:
                panel.axvspan(left - gap/2, right + gap/2, color="k", alpha=0.02, zorder=0)

            # dashed separator between blocks
            if j < (nvariables - 1):
                panel.axvline(right + gap/2, ls="--", lw=1, alpha=0.4, color="k", zorder=0)

# Legend built from the first box drawn for each method.
handles = [boxs[mkey]["boxes"][0] for mkey in methods.keys()]
labels = [method_labels[mkey] for mkey in methods.keys()]
ax0l.legend(handles = handles, labels = labels, 
           loc = 'upper left', frameon = False, handlelength = 2, ncol = 2)

for ax in [ax0_dummy, ax0l, ax0r, ax_dummy]:
    ax.patch.set_alpha(0.0)
    
ax0r.text(-0.8, 0.95, "(a)", transform=ax0r.transAxes, fontweight='bold')

# plt.tight_layout(h_pad=1.5, w_pad=1.5)
# gs.tight_layout(fig)
# plt.savefig('../plots/colormann-manuscript/sim_vary_p_h2_rel_rmse.pdf', bbox_inches='tight')
# plt.savefig('../plots/colormann-manuscript/sim_vary_p_h2_rel_rmse.png', bbox_inches='tight')
plt.show()

Code
# Inspect the default flashier runs (rows where lowrankfit is null) from
# simulations with k = 10 factors.
df = stratify_dfcols(dscout, [('lowrankfit', None), ('mfmethods', 'flashier'), ('simulate.k', 10)])
df
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit ... score.MI_raw score.adj_MI_raw score.MI_bal score.adj_MI_bal score.MI_cal score.adj_MI_cal score.MI_bal_cal score.adj_MI_bal_cal score.MI_oracle_aligned score.adj_MI_oracle_aligned
570 1 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) None ... 0.009893 0.006935 0.009893 0.057017 0.009893 0.006935 0.009893 0.057017 0.009893 0.068062
571 2 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) None ... 0.011190 0.010333 0.016863 0.600236 0.011190 0.010333 0.016863 0.600236 0.016863 0.600236
572 3 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) None ... 0.007394 0.022740 0.031577 0.022466 0.007394 0.022740 0.031577 0.022466 0.031577 0.022466
573 4 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) None ... 0.009358 0.031368 0.014479 0.665729 0.009358 0.031368 0.014479 0.665729 0.014479 0.665729
574 5 blockdiag 200.0 2000.0 10.0 0.2 0.6 0.6 (10000,40000) None ... 0.011757 0.032764 0.526487 0.569449 0.011757 0.032764 0.526487 0.569449 0.526487 0.569449
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1605 8 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.586263 0.942010 1.034805 0.966005 0.586263 0.942010 1.034805 0.966005 1.034805 0.966005
1606 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None ... 0.021820 0.054018 0.022875 0.021522 0.021820 0.054018 0.022875 0.021522 0.022875 0.021522
1607 9 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.024300 0.009859 1.041024 0.951072 0.024300 0.009859 1.041024 0.951072 1.065732 0.975363
1608 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.4 (10000,40000) None ... 0.012678 -0.001681 0.008805 0.073802 0.012678 -0.001681 0.008805 0.073802 0.008805 0.073802
1609 10 blockdiag_aq 200.0 2000.0 10.0 0.2 0.6 0.8 (10000,40000) None ... 0.021674 0.030892 1.024265 0.891986 0.021674 0.030892 1.024265 0.891986 1.043526 0.908815

150 rows × 50 columns

Code
# ---------------------------------------------------------------------------
# Figure: benchmark scores while varying the fraction of shared heritability
# (left column of panels) and the number of factors k (right column).
# The fixed base parameters come from the global `simparams`.
# ---------------------------------------------------------------------------

variables = ['h2_shared_frac', 'k']
variable_values = {
    'k' : [2, 5, 10, 15, 20],
    'h2_shared_frac' : [0.2, 0.4, 0.6, 0.8, 1.0],
}

xlabel = {
    'k': "No. of factors",
    'h2_shared_frac': "Fraction of shared heritability",
}

# One row of panels per metric, top to bottom.
metric_labels_plot = {
  "adj_MI_raw":      "Adjusted MI",
  "L_rmse_cal": r"RMSE ($\hat{L}$)",
  "F_rmse_cal": r"RMSE ($\hat{F}$)",
  "Z_rmse_cal": r"RMSE ($\hat{L}\hat{F}^{T}$)",
}

# Manual outlier removal per varied parameter, passed as `remove` to
# sanitize_scores_manually: (method, number of largest values to drop).
# Parameters not listed here (k) are plotted without sanitization.
manual_outlier_removal = {
    'h2_shared_frac': ("gleanr", 6),
}

nmethods = len(methods)
nscores = len(metric_labels_plot)
fig = plt.figure(figsize = (24, 24))

boxs = {x: None for x in methods.keys()}
# Rows: header (legend + factor scatter), spacer, then one row per metric.
gs = fig.add_gridspec(6, 6, height_ratios=(1, 0.2, 1, 1, 1, 1), wspace=0, hspace=0)

ax0_dummy = fig.add_subplot(gs[0, :])
ax_dummy = fig.add_subplot(gs[1,:])

ax0 = gs[0, :].subgridspec(1, 2, width_ratios=[1.2, 0.8], wspace=0.05)
ax0l = fig.add_subplot(ax0[0, 0])   # legend
ax0r = fig.add_subplot(ax0[0, 1])   # scatter of the first two factors

# Names kept from the previous figure: here the left column (ax_p) shows
# h2_shared_frac and the right column (ax_h2) shows k.
ax_p  = [fig.add_subplot(gs[2, 0:3]), 
         fig.add_subplot(gs[3, 0:3]), 
         fig.add_subplot(gs[4, 0:3]),
         fig.add_subplot(gs[5, 0:3])]

ax_h2 = [fig.add_subplot(gs[2, 3:], sharey = ax_p[0]), 
         fig.add_subplot(gs[3, 3:], sharey = ax_p[1]), 
         fig.add_subplot(gs[4, 3:], sharey = ax_p[2]), 
         fig.add_subplot(gs[5, 3:], sharey = ax_p[3])]

axs  = [ax_p, ax_h2]

colorlist = [manuscript_colors[x] for x in ['orange', 'blue', 'yellowgreen']]

# Sorting the unique labels makes the label -> colour assignment reproducible;
# set iteration order is not guaranteed to be stable across runs.
make_plot_pca(ax0r, simdata['Ltrue'], simdata['Ctrue'], sorted(set(simdata['Ctrue'])), colorlist)
ax0r.set_xlabel("Factor 1")
ax0r.set_ylabel("Factor 2")
ax0r.set_aspect(1.0)

for ax in [ax_dummy, ax0l, ax0_dummy]:
    style_metric_axis(ax)

for side, border in list(ax0_dummy.spines.items()) + \
                    list(ax0l.spines.items()) + \
                    list(ax_dummy.spines.items()):
    border.set_visible(False)


panel_label = ["(b)", "(c)"]
gap = 2

for col, var in enumerate(variables):
    nvariables = len(variable_values[var])
    removal = manual_outlier_removal.get(var)
    # Groups are (nmethods + gap) slots wide; ticks sit at each group's centre.
    xcenter = [x * (nmethods + gap) + (nmethods - 1) / 2 for x in range(nvariables)]
    xlim_low = 0 - (nvariables - 1) / 2
    xlim_high = (nvariables - 1) * (nmethods + gap) + (nmethods - 1) + (nvariables - 1) / 2

    for i, (metric_name, metric_label) in enumerate(metric_labels_plot.items()):
        panel = axs[col][i]
        scores = get_scores_from_dataframe(dscout, metric_name, var, variable_values[var])
        if removal is None:
            clean_scores = scores
        else:
            clean_scores = sanitize_scores_manually(scores, metric_name, remove = removal)

        draw_box_scatter(panel, clean_scores, boxs, nvariables, gap, metric_name, show_scatter=False)
        style_metric_axis(panel)

        if i == nscores - 1:
            # Only the bottom row carries x tick labels.
            panel.tick_params(bottom = True, labelbottom = True)
            panel.set_xticks(xcenter)
            panel.set_xticklabels(variable_values[var])
            panel.set_xlabel(xlabel[var])
        if col == 0:
            # y-axes are shared per row, so only the left column is labelled.
            panel.tick_params(left = True, labelleft = True)
            panel.set_ylabel(metric_label)

        panel.set_xlim( xlim_low, xlim_high )

        panel.patch.set_alpha(0.0)
        if i == 0:
            panel.text(0, 1.1, panel_label[col], transform=panel.transAxes, fontweight='bold')

        # ---- Group shading + separators ----
        for j in range(nvariables):
            left  = j * (nmethods + gap) - 0.5
            right = left + nmethods

            # alternating faint background blocks
            if j % 2 != 0:
                panel.axvspan(left - gap/2, right + gap/2, color="k", alpha=0.02, zorder=0)

            # dashed separator between blocks
            if j < (nvariables - 1):
                panel.axvline(right + gap/2, ls="--", lw=1, alpha=0.4, color="k", zorder=0)

# Legend built from the first box drawn for each method.
handles = [boxs[mkey]["boxes"][0] for mkey in methods.keys()]
labels = [method_labels[mkey] for mkey in methods.keys()]
ax0l.legend(handles = handles, labels = labels, 
           loc = 'upper left', frameon = False, handlelength = 2, ncol = 2)

for ax in [ax0_dummy, ax0l, ax0r, ax_dummy]:
    ax.patch.set_alpha(0.0)
    
ax0r.text(-0.8, 0.95, "(a)", transform=ax0r.transAxes, fontweight='bold')

# plt.tight_layout(h_pad=1.5, w_pad=1.5)
# gs.tight_layout(fig)
# plt.savefig('../plots/colormann-manuscript/sim_vary_k_h2shared_rmse.pdf', bbox_inches='tight')
# plt.savefig('../plots/colormann-manuscript/sim_vary_k_h2shared_rmse.png', bbox_inches='tight')
plt.show()

Code
# Inspect the default flashier runs (rows where lowrankfit is null) from the
# k = 2 and k = 5 simulations, showing only the metadata columns and the
# factor scores (score.F_*).
df = stratify_dfcols(dscout, 
    [('lowrankfit', None), 
     ('mfmethods', 'flashier'), 
     # ('simulate.h2_shared_frac', 0.2),
    ])
df = stratify_dfcols_in_list(df, "simulate.k", (2.0, 5.0,))
# Keep every non-score column plus the score.F_* columns.
cols = ~df.columns.str.startswith("score.") | df.columns.str.startswith("score.F_")
df.loc[:, cols]
DSC simulate simulate.n simulate.p simulate.k simulate.h2 simulate.h2_shared_frac simulate.aq simulate.nsample_minmax lowrankfit mfmethods score.F_rmse_raw score.F_rmse_bal score.F_rmse_cal score.F_rmse score.F_rel_rmse_raw score.F_rel_rmse_bal score.F_rel_rmse_cal score.F_rel_rmse score.F_psnr
870 1 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003475 0.003450 0.003475 0.003450 0.155404 0.154297 0.155404 0.154297 33.840156
871 1 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006007 0.005555 0.006007 0.005555 0.268657 0.248407 0.268657 0.248407 29.408689
874 2 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003385 0.003124 0.003385 0.003124 0.151402 0.139727 0.151402 0.139727 33.369204
875 2 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.008079 0.005756 0.008079 0.005756 0.361326 0.257403 0.361326 0.257403 29.985767
878 3 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.004920 0.003026 0.004920 0.003026 0.220032 0.135306 0.220032 0.135306 35.014616
879 3 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.005384 0.004662 0.005384 0.004662 0.240790 0.208508 0.240790 0.208508 31.545110
882 4 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.004064 0.003417 0.004064 0.003417 0.181745 0.152793 0.181745 0.152793 33.519885
883 4 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006711 0.005185 0.006711 0.005185 0.300144 0.231865 0.300144 0.231865 30.813012
886 5 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003734 0.003333 0.003734 0.003333 0.166995 0.149065 0.166995 0.149065 34.449152
887 5 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006036 0.005245 0.006036 0.005245 0.269956 0.234579 0.269956 0.234579 30.107015
890 6 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.005140 0.003325 0.005140 0.003325 0.229853 0.148707 0.229853 0.148707 32.672731
891 6 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006505 0.005188 0.006505 0.005188 0.290903 0.232023 0.290903 0.232023 29.894746
894 7 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003816 0.003711 0.003816 0.003711 0.170661 0.165974 0.170661 0.165974 32.745931
895 7 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006779 0.005038 0.006779 0.005038 0.303156 0.225295 0.303156 0.225295 31.087895
898 8 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003252 0.002727 0.003252 0.002727 0.145426 0.121955 0.145426 0.121955 35.811457
899 8 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006549 0.005308 0.006549 0.005308 0.292899 0.237393 0.292899 0.237393 30.645232
902 9 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003278 0.003238 0.003278 0.003238 0.146595 0.144803 0.146595 0.144803 34.327362
903 9 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.006782 0.005183 0.006782 0.005183 0.303297 0.231792 0.303297 0.231792 30.232479
906 10 blockdiag_k 200.0 2000.0 2.0 0.2 0.6 0.6 (10000,40000) None flashier 0.003486 0.003270 0.003486 0.003270 0.155897 0.146234 0.155897 0.146234 33.462423
907 10 blockdiag_k 200.0 2000.0 5.0 0.2 0.6 0.6 (10000,40000) None flashier 0.005929 0.005221 0.005929 0.005221 0.265146 0.233512 0.265146 0.233512 29.872346
Code
def get_scores_all_simulations(df, score_name, methods = methods):
    """Gather the values of one score column for every method.

    Parameters
    ----------
    df : DataFrame
        DSC output table with ``lowrankfit``, ``mfmethods`` and
        ``score.<name>`` columns.
    score_name : str
        Score suffix; the column read is ``f'score.{score_name}'``.
    methods : dict
        Method label -> sequence whose items ``[0]`` and ``[1]`` are the
        ``lowrankfit`` and ``mfmethods`` values identifying that method.
        Defaults to the module-level ``methods`` as it was when this function
        was defined.

    Returns
    -------
    dict
        Method label -> one-element list holding the selected rows' score
        values as a NumPy array (``draw_box_scatter`` reads item ``[0]``).
    """
    return {
        label: [
            stratify_dfcols(
                df, [('lowrankfit', spec[0]), ('mfmethods', spec[1])]
            )[f'score.{score_name}'].to_numpy()
        ]
        for label, spec in methods.items()
    }

def draw_box_scatter(ax, arrays_by_method, boxs = None):
    """Draw, for each method, a box plot overlaid with jittered score points.

    Methods are placed at x = 0, 1, 2, ... following the iteration order of
    the module-level ``methods`` dict, and coloured with ``method_colors``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    arrays_by_method : dict
        Method key -> sequence whose first item holds that method's scores
        (the layout returned by ``get_scores_all_simulations``). NaN values
        are dropped before plotting.
    boxs : dict, optional
        When given, the artists dict returned by ``ax.boxplot`` for each
        method is stored here under the method key (the original in/out
        behaviour). A new dict is created when omitted.

    Returns
    -------
    dict
        The ``boxs`` mapping, method key -> boxplot artists.
    """
    if boxs is None:
        boxs = dict()
    for j, mkey in enumerate(methods.keys()):
        boxcolor = method_colors[mkey]
        # Append an alpha byte (0x80, ~50% opacity); assumes boxcolor is '#RRGGBB'.
        boxface = f'#{boxcolor[1:]}80'
        # Median line is hidden (linewidth 0); only the box and whiskers show.
        medianprops = dict(linewidth=0, color = boxcolor)
        whiskerprops = dict(linewidth=2, color = boxcolor)
        boxprops = dict(linewidth=2, color = boxcolor, facecolor = boxface)
        flierprops = dict(marker='o', markerfacecolor=boxface, markersize=3, markeredgecolor = boxcolor)

        # Coerce to float so object-dtype columns (e.g. holding None) can be
        # NaN-filtered; None becomes NaN and is then dropped.
        m_score_arr = np.asarray(arrays_by_method[mkey][0], dtype=float)
        m_score_arr = m_score_arr[~np.isnan(m_score_arr)]
        xpos = [j]
        boxs[mkey] = ax.boxplot(m_score_arr, positions = xpos,
            showcaps = False, showfliers = False,
            widths = 0.7, patch_artist = True, notch = False,
            flierprops = flierprops, boxprops = boxprops,
            medianprops = medianprops, whiskerprops = whiskerprops)
        # Individual observations, horizontally jittered around x = j.
        ax.scatter(random_jitter(j, m_score_arr), m_score_arr,
               edgecolor = boxcolor, facecolor = boxface, linewidths = 1,
               s = 2)
    return boxs

# Box/scatter plot of log10(global scale) per method.
dscout_plot = dscout.copy()
# Mask non-positive scales to NaN before taking log10. Exact zeros occur
# (e.g. row 777), and log10 of zero/negative values would emit RuntimeWarnings;
# the resulting NaNs are dropped by draw_box_scatter.
dscout_plot["score.global_scale_log10"] = np.log10(
    dscout_plot["score.global_scale"].where(dscout_plot["score.global_scale"] > 0)
)

fig = plt.figure(figsize = (8,8))
ax = fig.add_subplot(111)

metric_name = "global_scale_log10"
scores = get_scores_all_simulations(dscout_plot, metric_name)
scores_by_method = {mkey: scores[mkey] for mkey in methods.keys()}
score_label = metric_labels[metric_name]
# Container filled by draw_box_scatter with each method's boxplot artists;
# created here so this cell does not depend on an earlier definition.
boxs = dict()
draw_box_scatter(ax, scores_by_method, boxs)

ax.set_xticks(np.arange(len(methods)))
ax.set_xticklabels([method_labels[m] for m in methods.keys()], rotation=90)
ax.set_ylabel(score_label)
outfile = '../plots/colormann-manuscript/sim_global_amplitude_calibration.pdf'
os.makedirs(os.path.dirname(outfile), exist_ok=True)
fig.savefig(outfile, bbox_inches='tight')

plt.show()

Code
# Show row 777, one of the fits whose estimated global scale is exactly zero (its log10 is NaN).
dscout_plot.loc[dscout_plot["score.global_scale"].eq(0)].loc[777]
DSC                                        7
simulate                         blockdiag_p
simulate.n                             200.0
simulate.p                           10000.0
simulate.k                              10.0
simulate.h2                              0.2
simulate.h2_shared_frac                  0.6
simulate.aq                              0.6
simulate.nsample_minmax        (10000,40000)
lowrankfit                              None
mfmethods                             gleanr
score.global_scale                       0.0
score.balancing_impact                   NaN
score.L_rmse_raw                    0.091683
score.F_rmse_raw                        0.01
score.Z_rmse_raw                    0.002899
score.L_rmse_bal                    0.091683
score.F_rmse_bal                        0.01
score.Z_rmse_bal                    0.002899
score.L_rmse_cal                    0.091683
score.F_rmse_cal                        0.01
score.Z_rmse_cal                    0.002899
score.L_rmse                        0.091683
score.F_rmse                            0.01
score.Z_rmse                        0.002899
score.L_rel_rmse_raw                     1.0
score.F_rel_rmse_raw                     1.0
score.Z_rel_rmse_raw                     1.0
score.L_rel_rmse_bal                     1.0
score.F_rel_rmse_bal                     1.0
score.Z_rel_rmse_bal                     1.0
score.L_rel_rmse_cal                     1.0
score.F_rel_rmse_cal                     1.0
score.Z_rel_rmse_cal                     1.0
score.L_rel_rmse                         1.0
score.F_rel_rmse                         1.0
score.Z_rel_rmse                         1.0
score.L_psnr                       15.727685
score.F_psnr                       18.479476
score.Z_psnr                       22.945266
score.MI_raw                        0.010307
score.adj_MI_raw                    -0.00207
score.MI_bal                        0.010307
score.adj_MI_bal                    -0.00207
score.MI_cal                        0.010307
score.adj_MI_cal                    -0.00207
score.MI_bal_cal                    0.010307
score.adj_MI_bal_cal                -0.00207
score.MI_oracle_aligned             0.010307
score.adj_MI_oracle_aligned         -0.00207
score.global_scale_log10                 NaN
Name: 777, dtype: object
Code
# One PCA panel of the true loadings per blockdiag simulation replicate.
nsim = 10
ncol = 3
nrow = int(np.ceil(nsim / ncol))  # just enough rows for nsim panels (4 for 10)

fig = plt.figure(figsize = (4 * ncol, 4 * nrow))
axs = [fig.add_subplot(nrow, ncol, i+1) for i in range(nsim)]

for i, ax in enumerate(axs):
    _filename = os.path.join(dsc_output, f"blockdiag/blockdiag_{i+1}.pkl")
    # Locally generated DSC output; pickle must only be used on trusted files.
    with open(_filename, "rb") as fh:
        _simdata = pickle.load(fh)

    # Class labels come from this replicate's own data. They are sorted so the
    # label order (which make_plot_pca presumably pairs with colorlist by
    # position - confirm) is deterministic across panels and runs.
    make_plot_pca(ax, _simdata['Ltrue'], _simdata['Ctrue'],
                  sorted(set(_simdata['Ctrue'])), colorlist)
    ax.set_title(f"blockdiag_{i+1}")

fig.tight_layout()
plt.show()

Code
def get_scores_all_simulations(df, score_name, methods = methods):
    """Gather the values of one score column for every method.

    Parameters
    ----------
    df : DataFrame
        DSC output table with ``lowrankfit``, ``mfmethods`` and
        ``score.<name>`` columns.
    score_name : str
        Score suffix; the column read is ``f'score.{score_name}'``.
    methods : dict
        Method label -> sequence whose items ``[0]`` and ``[1]`` are the
        ``lowrankfit`` and ``mfmethods`` values identifying that method.
        Defaults to the module-level ``methods`` as it was when this function
        was defined.

    Returns
    -------
    dict
        Method label -> one-element list holding the selected rows' score
        values as a NumPy array (``draw_box_scatter`` reads item ``[0]``).
    """
    return {
        label: [
            stratify_dfcols(
                df, [('lowrankfit', spec[0]), ('mfmethods', spec[1])]
            )[f'score.{score_name}'].to_numpy()
        ]
        for label, spec in methods.items()
    }