Stability heatmaps for PGC split-replication CV

Author

Saikat Banerjee

Published

April 30, 2026

Abstract
Visualize split-replication stability metrics from the Clorinn cross-validation pipeline. The notebook reads per-radius JSON summaries, converts by-k metrics into k-by-nuclear-norm grids, and plots heatmaps for projection distance, cumulative singular-value energy, and largest principal angle.
Code
import json
import pickle
from pathlib import Path
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 300, fontsize=14)

from matplotlib import colormaps as mpl_cmaps
import matplotlib.colors as mpl_colors
from mpl_toolkits.axes_grid1 import make_axes_locatable

Paths and analysis settings

data_root should point to the Snakemake output directory for the split-replication CV run. The notebook expects:

  • per-radius stability JSON files under stability/
  • the aggregate stability metric CSV under summary/

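Each entry of PREFIXES identifies a model/solver combination; for example, pgd_fw_nnm matches files such as pgd_fw_nnm_r8192.json. Figures are produced for every prefix, while the aggregate summary CSV and the grids loaded in the next sections use only the first prefix.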

Code
# Root of the Snakemake output tree for the split-replication CV run.
data_root = "/gpfs/commons/groups/knowles_lab/data/PsychGen/analysis/clorinn/cv_split_replication"

# Model/solver prefixes used by the Snakemake output filenames; the figure
# cells below loop over every entry.
PREFIXES = ["pgd_afw_nnm_corr", "pgd_fw_nnm"]

# Input locations, kept as plain strings.
stability_out_dir = "/".join([data_root, "stability"])
summary_out = "/".join([data_root, "summary", PREFIXES[0] + "_stability_metrics.csv"])

# Figures are written relative to the current working directory.
fig_dir = Path("figures") / "split_replication_cv"
fig_dir.mkdir(exist_ok=True, parents=True)

Read aggregate summary table

This CSV is not required for the heatmaps below, but displaying it is a useful sanity check that the expected pipeline outputs were generated.

Code
# Sanity check: load the aggregate stability table and preview its first rows.
summary_df = pd.read_csv(filepath_or_buffer=summary_out)
summary_df.head(n=5)
nucnorm mean_dist_k1 mean_dist_k2 mean_dist_k3 mean_dist_k4 mean_dist_k5 mean_dist_k6 mean_dist_k7 mean_dist_k8 mean_dist_k9 ... se_energy_k22 se_energy_k23 se_energy_k24 se_energy_k25 se_energy_k26 se_energy_k27 se_energy_k28 se_energy_k29 se_energy_k30 grid
0 1.0 0.021770 0.203261 0.529161 0.600799 0.648400 0.660207 0.659052 0.660754 0.667229 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 coarse
1 2.0 0.022354 0.274409 0.577392 0.564718 0.588592 0.624654 0.654132 0.668207 0.659147 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 coarse
2 4.0 0.023633 0.266199 0.536371 0.635676 0.662308 0.665486 0.668800 0.676398 0.675788 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 coarse
3 8.0 0.026768 0.326992 0.566602 0.625170 0.634612 0.643133 0.661838 0.670966 0.670403 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 coarse
4 16.0 0.034373 0.054074 0.169477 0.484521 0.601823 0.635789 0.658699 0.674419 0.663070 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 coarse

5 rows × 152 columns

Helper functions

The stability JSON files are stored one file per radius. Each file contains a nucnorm field giving the nuclear-norm radius and a by_k list with metrics evaluated at several ranks. The helper functions below convert this nested JSON structure into rectangular matrices suitable for heatmaps.

Code
def load_stability_jsons(stability_dir, prefix):
    """Load every per-radius stability JSON file for one model/solver prefix.

    Files are matched by the glob ``{prefix}_r*.json`` inside
    ``stability_dir`` and read in lexicographic path order (so, e.g.,
    ``..._r16384.json`` sorts before ``..._r2.json``). The returned order
    therefore decides which record counts as "first" when duplicates are
    later resolved positionally.

    Parameters
    ----------
    stability_dir : str or pathlib.Path
        Directory containing the stability JSON files.
    prefix : str
        Model/solver prefix of the filenames.

    Returns
    -------
    list of dict
        One parsed JSON object per matching file, each augmented with a
        ``"_path"`` key holding the source file path as a string. Empty when
        no file matches. Each file is assumed to hold a JSON object.
    """
    records = []
    for path in sorted(Path(stability_dir).glob(f"{prefix}_r*.json")):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, encoding="utf-8") as fh:
            rec = json.load(fh)
        rec["_path"] = str(path)
        records.append(rec)
    return records


def build_by_k_metrics_grid(records, metric, *, duplicate="error"):
    """Convert per-radius stability records into a k-by-nucnorm grid.

    Parameters
    ----------
    records : list of dict
        Parsed stability records (e.g. from ``load_stability_jsons``). Each
        record needs a ``"nucnorm"`` value and a ``"by_k"`` list whose entries
        contain ``"k"`` and ``metric``.
    metric : str
        Key of the metric to extract from each ``by_k`` entry.
    duplicate : {"error", "first", "last"}, optional
        How to handle several entries sharing the same (rounded) nucnorm and
        k, which can happen when the coarse and fine grids overlap.
        ``"error"`` (default) raises; ``"first"``/``"last"`` keep the entry
        that appears first/last in ``records`` order.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by ``k`` and columns by integer ``nucnorm``, both sorted
        ascending. Missing (k, nucnorm) combinations are NaN.

    Raises
    ------
    ValueError
        If ``duplicate`` is not a recognized option, if ``records`` contain
        no ``by_k`` entries, or if duplicates exist and ``duplicate="error"``.
    """
    if duplicate not in ("error", "first", "last"):
        raise ValueError(
            f"duplicate must be 'error', 'first' or 'last', got {duplicate!r}"
        )

    rows = [
        {
            # Radii are rounded to integers and used as integer column
            # labels. NOTE: non-integer radii would be merged by rounding.
            "nucnorm": int(round(float(rec["nucnorm"]))),
            "k": int(entry["k"]),
            metric: float(entry[metric]),
        }
        for rec in records
        for entry in rec["by_k"]
    ]
    if not rows:
        raise ValueError(f"No by_k entries found to build a grid for {metric!r}.")

    long_df = pd.DataFrame(rows)

    # Duplicates can occur when coarse and fine grids contain the same nucnorm.
    dup_mask = long_df.duplicated(subset=["nucnorm", "k"], keep=False)
    if dup_mask.any():
        if duplicate == "error":
            n_dup = len(long_df.loc[dup_mask, ["nucnorm", "k"]].drop_duplicates())
            raise ValueError(
                f"{n_dup} duplicate (nucnorm, k) pairs for metric {metric!r}; "
                "pass duplicate='first' or duplicate='last' to resolve them."
            )
        # Drop before any reordering so "first"/"last" refer to record order.
        long_df = long_df.drop_duplicates(subset=["nucnorm", "k"], keep=duplicate)

    grid_df = (
        long_df
        .pivot(index="k", columns="nucnorm", values=metric)
        .sort_index(axis=0)
        .sort_index(axis=1)
    )

    grid_df.index.name = "k"
    grid_df.columns.name = "nucnorm"

    return grid_df

Load stability metrics as heatmap grids

The three metrics below summarize complementary aspects of split-replication stability:

  • mean_dist: mean projection distance between the top-k trait subspaces from two SNP splits. Lower is more stable.
  • mean_energy: mean cumulative singular-value energy captured by the top-k directions. Higher means the selected k captures more fitted signal.
  • mean_gap_angle: mean largest principal angle between split-specific top-k subspaces. Lower angles indicate more similar subspaces.

For plotting, duplicate radii are resolved with duplicate="first" because the coarse and fine radius grids can overlap.

Code
# Grids for the first prefix only: rows are ranks k, columns are radii.
records = load_stability_jsons(stability_out_dir, PREFIXES[0])
dist_df, energy_df, gap_angle_df = (
    build_by_k_metrics_grid(records, metric_name, duplicate="first")
    for metric_name in ("mean_dist", "mean_energy", "mean_gap_angle")
)

Main cross-validation plot

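Use this figure to decide where the projection distance stops improving materially as the nuclear-norm radius r increases. Panel (a) shows the full k-by-r grid of mean projection distance. Panel (b) shows slices over r at selected ranks k, with shaded ±se_dist bands. Panel (c) shows slices over k at selected radii, smoothed over neighboring ranks. Panel (d) shows the same smoothed values as slices over r at selected k. One figure is saved per prefix.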

Code
def get_r_scaled(x, scale="log2"):
    """Map radii onto a plotting axis and build power-of-two ticks.

    Parameters
    ----------
    x : array-like
        Positive radii (nuclear-norm constraint values).
    scale : str, optional
        ``"log2"`` (default) or ``"log10"``; any other value leaves ``x``
        on a linear axis.

    Returns
    -------
    xscale : numpy.ndarray
        ``x`` in the requested transform.
    xticks : numpy.ndarray
        Every power of two within ``[min(x), max(x)]``, in the same transform.
    xlabel_str : list of str
        Tick labels. Integral powers are printed in full (e.g. ``"2097152"``,
        never an exponent form); fractional ones use ``:g`` (e.g. ``"0.5"``).

    Raises
    ------
    ValueError
        If ``x`` is empty or contains non-positive values (log2 undefined).
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0 or np.any(x <= 0):
        raise ValueError("get_r_scaled requires a non-empty array of positive radii")

    powers = np.arange(np.ceil(np.log2(x.min())), np.floor(np.log2(x.max())) + 1)
    xlabels = 2 ** powers

    if scale == "log10":
        xscale = np.log10(x)
        xticks = np.log10(xlabels)
    elif scale == "log2":
        xscale = np.log2(x)
        xticks = np.log2(xlabels)
    else:
        xscale = x
        xticks = xlabels

    # ":g" alone would turn 2**20 into "1.04858e+06" (lossy, not int-parsable).
    xlabel_str = [
        f"{int(v):d}" if float(v).is_integer() else f"{v:g}" for v in xlabels
    ]
    return xscale, xticks, xlabel_str


def centers_to_edges(x):
    """Convert 1-D cell centres into the cell edges expected by pcolormesh.

    Interior edges are midpoints between neighbouring centres; the outer
    edges extend the first/last cell by half of its neighbour spacing.
    A single centre gets a unit-width cell because there is no neighbour
    from which to infer a width.

    Parameters
    ----------
    x : array-like
        Non-empty 1-D sequence of (possibly unevenly spaced) centres.

    Returns
    -------
    numpy.ndarray
        Array of ``len(x) + 1`` edges.

    Raises
    ------
    ValueError
        If ``x`` is not a non-empty 1-D sequence.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 1 or x.size == 0:
        raise ValueError("centers_to_edges expects a non-empty 1-D sequence")
    if x.size == 1:
        return np.array([x[0] - 0.5, x[0] + 0.5])

    edges = np.empty(x.size + 1)
    edges[1:-1] = 0.5 * (x[:-1] + x[1:])
    edges[0] = x[0] - 0.5 * (x[1] - x[0])
    edges[-1] = x[-1] + 0.5 * (x[-1] - x[-2])
    return edges


def plot_heatmap(ax, X, rank_list, k_list,
        vmin=0, vcenter=0.9, vmax=1.0,
        scale="log2"):
    """Draw a k-by-radius heatmap with a diverging colour scale and colorbar.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes; a colorbar axes is appended to its right.
    X : array-like, shape (len(k_list), len(rank_list))
        Values to plot; rows follow ``k_list`` and columns follow
        ``rank_list``. NaN cells are drawn white.
    rank_list : array-like
        Radii at the column centres (transformed by ``get_r_scaled``).
    k_list : array-like
        Ranks used as row labels; rows are drawn top-to-bottom in this order.
    vmin, vcenter, vmax : float
        Limits for ``TwoSlopeNorm`` (must satisfy vmin < vcenter < vmax).
    scale : str
        Radius-axis transform passed to ``get_r_scaled``.

    Returns
    -------
    matplotlib.collections.QuadMesh
        The mesh returned by ``pcolormesh``.
    """
    cmap1 = mpl_cmaps.get_cmap("YlOrRd").copy()
    cmap1.set_bad("w")
    norm1 = mpl_colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

    r, rticks, rlabels = get_r_scaled(rank_list, scale=scale)

    # pcolormesh takes cell edges; radii may be unevenly spaced on the axis.
    x_edges = centers_to_edges(r)
    y_edges = np.arange(len(k_list) + 1) - 0.5

    im1 = ax.pcolormesh(x_edges, y_edges, X, cmap=cmap1, norm=norm1, shading="auto")

    # match imshow(origin="upper")
    ax.set_ylim(len(k_list) - 0.5, -0.5)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    # Bind to this axes' own figure rather than pyplot's current one;
    # `fraction` is ignored when `cax` is supplied, so it is not passed.
    ax.figure.colorbar(im1, cax=cax)

    ax.set_xlabel("Nuclear norm constraint radius r")
    ax.set_ylabel("Rank k")

    ax.set_yticks(np.arange(len(k_list)))
    ax.set_yticklabels([str(int(k)) for k in k_list])

    ax.set_xticks(rticks)
    # get_r_scaled already returns label strings; re-parsing them with int()
    # fails for fractional ("0.5") or exponent-formatted labels.
    ax.set_xticklabels(rlabels, rotation=90)

    return im1

def _format_r_axis(ax, rticks, rlabels):
    """Label an x axis whose positions are scaled nuclear-norm radii."""
    ax.set_xlabel("Nuclear norm constraint radius r")
    ax.set_xticks(rticks)
    ax.set_xticklabels(rlabels, rotation=90)


def _plot_r_slices(ax, df, ks, r_scaled, se_df=None):
    """Plot rows ``df.loc[k]`` for each k in ``ks`` against ``r_scaled``.

    When ``se_df`` is given, shade ``y +/- se_df.loc[k]`` around each line.
    """
    for k in ks:
        y = df.loc[k].to_numpy()
        ax.plot(r_scaled, y, "o-", label=f"{k:d}")
        if se_df is not None:
            se = se_df.loc[k].to_numpy()
            ax.fill_between(r_scaled, y - se, y + se, alpha=0.2)


def make_projection_distance_figure(prefix,
        vmin=0.09, vcenter=0.12, vmax=0.2,
        *, smooth_window=5,
        k_choose=(8, 10, 15, 20, 25),
        r_choose=(128, 256, 512, 1024, 2048, 4096, 8192, 16384)):
    """Plot and save the four-panel projection-distance figure for one prefix.

    Panels:
      (a) heatmap of ``mean_dist`` over rank k and radius r;
      (b) slices over r at the ranks in ``k_choose``, shaded by +/- ``se_dist``;
      (c) slices over k at the radii in ``r_choose``, using ``mean_dist``
          smoothed by a centred rolling mean over ``smooth_window`` adjacent
          rows (ranks), shaded by the unsmoothed +/- ``se_dist``;
      (d) the smoothed values as slices over r at the ranks in ``k_choose``.

    The figure is written to ``fig_dir / f"{prefix}_projection_distance.png"``
    and then closed.

    Parameters
    ----------
    prefix : str
        Model/solver prefix of the stability JSON files.
    vmin, vcenter, vmax : float
        Colour limits for the heatmap in panel (a).
    smooth_window : int, optional
        Rolling-window length (in rows of the grid) for panels (c) and (d).
    k_choose, r_choose : sequence of int, optional
        Ranks / radii to slice at; values absent from the grid are skipped.
    """
    records = load_stability_jsons(stability_out_dir, prefix)
    dist_df = build_by_k_metrics_grid(records, "mean_dist", duplicate="first")
    se_dist_df = build_by_k_metrics_grid(records, "se_dist", duplicate="first")
    # The rolling window runs down the rows, i.e. over neighbouring ranks k.
    smoothed_dist_df = dist_df.rolling(
        window=smooth_window, center=True, min_periods=1
    ).mean()
    rank_list = dist_df.columns.to_numpy()  # nuclear-norm radii r
    k_list = dist_df.index.to_numpy()

    # Slice only at values present in the grid; lookups would otherwise
    # raise KeyError when a radius or rank was not run.
    ks = [k for k in k_choose if k in dist_df.index]
    rs = [r for r in r_choose if r in dist_df.columns]

    r_scaled, rticks, rlabels = get_r_scaled(rank_list, scale="log2")

    fig = plt.figure(figsize=(20, 15), constrained_layout=True)
    gs = fig.add_gridspec(nrows=2, ncols=2,
                          width_ratios=[1, 1],  # heatmap, line plot
                          wspace=0.2, hspace=0.05,
                         )
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    # (a) Full k-by-r grid.
    plot_heatmap(ax1, dist_df.to_numpy(), rank_list, k_list,
                 vmin=vmin, vcenter=vcenter, vmax=vmax)
    ax1.set_title("Mean projection distance across splits", pad=20)
    ax1.text(-0.1, 1.1, "(a)", transform=ax1.transAxes, fontweight="bold")

    # (b) Rows of the heatmap: makes the plateau onset visible for chosen ranks.
    _plot_r_slices(ax2, dist_df, ks, r_scaled, se_df=se_dist_df)
    _format_r_axis(ax2, rticks, rlabels)
    ax2.set_ylabel("Mean projection distance")
    ax2.set_title("Slices over r at selected k", pad=20)
    ax2.legend(loc="upper right", frameon=False, handlelength=2, ncol=2, title="k")
    ax2.text(-0.1, 1.1, "(b)", transform=ax2.transAxes, fontweight="bold")

    # (c) Columns of the heatmap: makes the variation over ranks visible.
    k_pos = np.arange(len(k_list))
    for r in rs:
        y = smoothed_dist_df[r].to_numpy()
        se = se_dist_df[r].to_numpy()
        ax3.plot(k_pos, y, "o-", label=f"{r:d}")
        ax3.fill_between(k_pos, y - se, y + se, alpha=0.2)
    ax3.set_xlabel("Rank k")
    ax3.set_ylabel("Mean projection distance")
    ax3.set_xticks(k_pos)
    ax3.set_xticklabels(k_list, rotation=90)
    ax3.set_title(
        f"Slices over k at selected r, smoothed over {smooth_window:d} k-neighbors",
        pad=20,
    )
    ax3.legend(loc="upper left", bbox_to_anchor=(0.2, 0.95), frameon=False,
               handlelength=2, ncol=2, title="r")
    ax3.text(-0.1, 1.1, "(c)", transform=ax3.transAxes, fontweight="bold")

    # (d) Smoothed rows of the heatmap.
    _plot_r_slices(ax4, smoothed_dist_df, ks, r_scaled)
    _format_r_axis(ax4, rticks, rlabels)
    ax4.set_ylabel("Mean projection distance")
    ax4.set_title(
        f"Slices over r at selected k, smoothed over {smooth_window:d} k-neighbors",
        pad=20,
    )
    ax4.legend(loc="upper right", frameon=False, handlelength=2, ncol=2, title="k")
    ax4.text(-0.1, 1.1, "(d)", transform=ax4.transAxes, fontweight="bold")

    fig.savefig(
        fig_dir / f"{prefix}_projection_distance.png",
        bbox_inches="tight",
    )
    plt.close(fig)
    

# Heatmap colour limits (vmin, vcenter, vmax) for panel (a); the AFW/corr
# prefix uses wider limits than the default.
_DEFAULT_DIST_LIMITS = (0.09, 0.18, 0.30)
_DIST_LIMITS_BY_PREFIX = {"pgd_afw_nnm_corr": (0.05, 0.3, 0.6)}

for prefix in PREFIXES:
    vmin, vcenter, vmax = _DIST_LIMITS_BY_PREFIX.get(prefix, _DEFAULT_DIST_LIMITS)
    make_projection_distance_figure(prefix, vmin=vmin, vcenter=vcenter, vmax=vmax)

Diagnostics

These plots are secondary checks. They help distinguish a real stability plateau from artifacts caused by changing effective rank or unstable high-dimensional directions.

Code
def make_diagnostics_energy_angle_figure(prefix):
    """Save side-by-side energy and principal-angle heatmaps for one prefix.

    Left panel: ``mean_energy`` with colour limits 0.8 / 0.9 / 1.0.
    Right panel: ``mean_gap_angle`` with colour limits 0 / 45 / 90.
    The figure is written to ``fig_dir / f"{prefix}_energy_gap_angle.png"``
    and then closed.
    """
    records = load_stability_jsons(stability_out_dir, prefix)

    # (metric key, panel title, colour-norm limits) for each column.
    panel_specs = [
        ("mean_energy", "Mean cumulative singular-value energy",
         dict(vmin=0.8, vcenter=0.9, vmax=1.0)),
        ("mean_gap_angle", "Mean largest principal angle across splits",
         dict(vmin=0, vcenter=45, vmax=90.0)),
    ]
    grids = [
        build_by_k_metrics_grid(records, metric, duplicate="first")
        for metric, _, _ in panel_specs
    ]

    fig = plt.figure(figsize=(20, 7.5), constrained_layout=True)
    gs = fig.add_gridspec(nrows=1, ncols=2, width_ratios=[1, 1], wspace=0.2)
    axes = [fig.add_subplot(gs[0, col]) for col in range(2)]

    for ax, grid, (_, title, limits) in zip(axes, grids, panel_specs):
        plot_heatmap(ax, grid.to_numpy(), grid.columns.to_numpy(),
                     grid.index.to_numpy(), **limits)
        ax.set_title(title, pad=20)

    fig.savefig(fig_dir / f"{prefix}_energy_gap_angle.png", bbox_inches="tight")
    plt.close(fig)
    
# Secondary diagnostic figure (energy and principal angle) for every prefix.
for prefix in PREFIXES:
    make_diagnostics_energy_angle_figure(prefix=prefix)