Manifold embedding of NNM loadings

Author

Saikat Banerjee

Published

April 3, 2024

Abstract
We compare the clusters obtained from the latent factors with LLM classification of the corresponding phenotype descriptions.

About

We identify clusters from the loadings of the low-dimensional representation of the PanUKB Z-score matrix. We then compare those clusters with the groups predicted by pretrained LLMs from the trait descriptions.

Setup

We first load the required tools, including scikit-learn and Bokeh.

Code
import os
import numpy as np
import pandas as pd
import pickle
import re

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils

mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 120)

import umap
from bokeh.plotting import figure as bokeh_figure
from bokeh.plotting import show as bokeh_show
from bokeh.layouts import column as bokeh_column
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.models import HoverTool
from bokeh.models import CategoricalColorMapper

from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import SpectralEmbedding, TSNE, LocallyLinearEmbedding, Isomap, MDS
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans

from sentence_transformers import util as st_util

output_notebook()

Load data and results

Here, we explore the low-rank model obtained by nuclear norm minimization of the sparse input matrix.

Code
data_dir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/data"
result_dir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/nnsparsh"

zscore_df = pd.read_pickle(os.path.join(data_dir, "modselect/zscore_all.pkl"))
trait_df  = pd.read_pickle(os.path.join(data_dir, "modselect/traits_all_with_desc.pkl"))
zscore_df_noRx = pd.read_pickle(os.path.join(data_dir, "modselect/zscore_noRx.pkl"))
trait_df_noRx  = trait_df.query('trait_type != "prescriptions"')

method = 'rpca'

res_filename = os.path.join(result_dir, "full", f"{method}_model.pkl")
with open(res_filename, "rb") as fh:
    lowrank_model = pickle.load(fh)

res_filename_noRx = os.path.join(result_dir, "noRx", f"{method}_model.pkl")
with open(res_filename_noRx, "rb") as fh:
    lowrank_model_noRx = pickle.load(fh)

X = np.array(zscore_df.drop(labels = ['rsid'], axis = 1).values.T)
X_cent = X - np.mean(X, axis = 0, keepdims = True)

X_noRx = np.array(zscore_df_noRx.values.T)
X_noRx_cent = X_noRx - np.mean(X_noRx, axis = 0, keepdims = True)

lowX = lowrank_model['L_']
lowX_cent = lowX - np.mean(lowX, axis = 0, keepdims = True)
lowX_std = lowX_cent / np.sqrt(np.prod(lowX_cent.shape))

lowX_noRx = lowrank_model_noRx['L_']
lowX_noRx_cent = lowX_noRx - np.mean(lowX_noRx, axis = 0, keepdims = True)
lowX_noRx_std = lowX_noRx_cent / np.sqrt(np.prod(lowX_noRx_cent.shape))

print ("Nuclear Norms")
print (f"Low rank model: {np.linalg.norm(lowX, ord = 'nuc'):.3f}")
print (f"Low rank model without Rx: {np.linalg.norm(lowX_noRx, ord = 'nuc'):.3f}")
print (f"Input data: {np.linalg.norm(X, ord = 'nuc'):.3f}")
print (f"Input data (mean centered): {np.linalg.norm(X_cent, ord = 'nuc'):.3f}")
Nuclear Norms
Low rank model: 180216.129
Low rank model without Rx: 147300.307
Input data: 496751.155
Input data (mean centered): 495872.387
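
As a sanity check, the nuclear norm is the sum of the singular values, so the values printed above can be reproduced directly from the SVD. A minimal sketch (not part of the pipeline):

Code
# The nuclear norm of a matrix equals the sum of its singular values.
S_check = np.linalg.svd(lowX, compute_uv = False)
assert np.isclose(np.sum(S_check), np.linalg.norm(lowX, ord = 'nuc'))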

Compute loadings

We compute the loadings from the SVD of the low-rank approximation of the input data.

To-Do: Save the loadings and factors so that we don't have to calculate them every time (see the caching sketch after the code below).

Code
def compute_loadings_factors(X, k = None):
    U, S, Vt = np.linalg.svd(X, full_matrices = False)
    S2 = np.square(S)
    explained_variance = S2 / np.sum(S2)
    if k is None:
        # Truncate just before the first component whose explained variance
        # falls below 1e-4; keep all components if none fall below.
        below = np.where(explained_variance < 1e-4)[0]
        k = below[0] - 1 if below.size > 0 else S.shape[0]
    U_low = U[:, :k]
    S_low = S[:k]
    factors = Vt[:k, :].T
    loadings = U_low @ np.diag(S_low)
    return U_low, S_low, loadings, factors

U0, S0, loadings0, factors0 = compute_loadings_factors(lowX_std)
U1, S1, loadings1, factors1 = compute_loadings_factors(lowX_noRx_std)
U2, S2, loadings2, factors2 = compute_loadings_factors(X_cent)
U3, S3, loadings3, factors3 = compute_loadings_factors(X_noRx_cent)

nnm_loadings  = loadings0.copy()
nnm_loadings_noRx = loadings1.copy()
tsvd_loadings = U2[:, :loadings0.shape[1]] @ np.diag(S2[:loadings0.shape[1]])
tsvd_loadings_noRx = U3[:, :loadings1.shape[1]] @ np.diag(S3[:loadings1.shape[1]])
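
One way to address the to-do above is to cache the loadings and factors on disk and reload them on subsequent runs. A minimal sketch; the cache path passed to the helper is hypothetical:

Code
def cached_loadings_factors(X, cache_file, k = None):
    # Load cached SVD results if present; otherwise compute and save them.
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as fh:
            return pickle.load(fh)
    res = compute_loadings_factors(X, k = k)
    with open(cache_file, "wb") as fh:
        pickle.dump(res, fh)
    return res

# Example usage (hypothetical path):
# U0, S0, loadings0, factors0 = cached_loadings_factors(
#     lowX_std, os.path.join(result_dir, "full", "rpca_loadings.pkl"))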
Code
hex_colors_40 = [
    "#084609",
    "#ff4ff4",
    "#01d94a",
    "#b700ce",
    "#91c900",
    "#5f42ed",
    "#5fa200",
    "#8d6dff",
    "#c9f06b",
    "#0132a7",
    "#ffbb1f",
    "#0080ed",
    "#f56600",
    "#3afaf5",
    "#c10001",
    "#01e698",
    "#a20096",
    "#00e2c1",
    "#ff5ac8",
    "#008143",
    "#cd0057",
    "#4aeeff",
    "#8c001a",
    "#b5f2a2",
    "#5d177d",
    "#a99900",
    "#e299ff",
    "#5b6b00",
    "#96aeff",
    "#a46f00",
    "#007acb",
    "#ff9757",
    "#00a8e0",
    "#ff708e",
    "#baefc7",
    "#622b25",
    "#c8c797",
    "#885162",
    "#ffb7a5",
    "#ffa3c3"]

llm_methods = [
    "ls-da3m0ns/bge_large_medical",
    "medicalai/ClinicalBERT",
    "emilyalsentzer/Bio_ClinicalBERT",
]

llm_ctypes = [
    "community", 
    "kmeans", 
    "agglomerative"]

llm_outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"

llm_clusters = {method : { x : None for x in llm_ctypes } for method in llm_methods}
for method in llm_methods:
    for ctype in llm_ctypes:
        m_filename = os.path.join(llm_outdir, f"{method}/{ctype}_clusters.pkl")
        with open(m_filename, "rb") as fh:
            llm_clusters[method][ctype] = pickle.load(fh)
            
def get_llm_cluster_labels(selectidx, method, ctype):
    clusteridx = np.full([selectidx.shape[0],], -1)
    for i, ccomps in enumerate(llm_clusters[method][ctype]):
        for idx in ccomps:
            clusteridx[idx] = i
    return clusteridx

llm_cluster_labels = {
    method : {
        ctype: get_llm_cluster_labels(np.array(trait_df.index), method, ctype) for ctype in llm_ctypes
    } for method in llm_methods
}

llm_cluster_labels_noRx = {
    method : {
        ctype: [llm_cluster_labels[method][ctype][i] for i in trait_df_noRx.index] for ctype in llm_ctypes
    } for method in llm_methods
}

datadict = {
    'nnm_noRx': {
        'loadings': nnm_loadings_noRx,
        'cluster_labels': llm_cluster_labels_noRx,
        'traits': trait_df_noRx,
        'label': 'RPCA (no Rx)',
    },
    'nnm': {
        'loadings': nnm_loadings,
        'cluster_labels': llm_cluster_labels,
        'traits': trait_df,
        'label': 'RPCA',
    },
    'tsvd_noRx': {
        'loadings': tsvd_loadings_noRx,
        'cluster_labels': llm_cluster_labels_noRx,
        'traits': trait_df_noRx,
        'label': 'tSVD (no Rx)',
    },
    'tsvd': {
        'loadings': tsvd_loadings,
        'cluster_labels': llm_cluster_labels,
        'traits': trait_df,
        'label': 'tSVD',
    },
}   
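
For clarity, get_llm_cluster_labels above inverts the LLM cluster membership lists (one list of trait indices per cluster) into a flat label array, where -1 marks traits that belong to no cluster. A toy example of the same inversion:

Code
# Toy example: two clusters over six traits; trait 4 is unassigned (-1).
toy_clusters = [[0, 2], [1, 3, 5]]
toy_labels = np.full([6,], -1)
for i, ccomps in enumerate(toy_clusters):
    for idx in ccomps:
        toy_labels[idx] = i
print(toy_labels)   # [ 0  1  0  1 -1  1]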

Interactive Plots

We show interactive t-SNE plots in Figure 1. Each point is a PanUKB phenotype, colored according to the clusters identified by the LLMs from the trait descriptions. The opacity of each point is proportional to the estimated heritability of the trait.

Code
def get_bokeh_plot(
        embedding, cluster_labels, trait_df, plot_title, 
        color_palette = hex_colors_40, alpha_factor = 10):
    
    h2_list = [max(1e-6, x) for x in trait_df['estimates.final.h2_observed'].fillna(1e-6).tolist()]
    fill_alpha_list = [min(0.6, alpha_factor * x) for x in h2_list]
    line_alpha_list = [min(0.8, 1.3 * alpha_factor * x) for x in h2_list]
    
    color_palette_mod = color_palette.copy()
    if -1 in np.unique(cluster_labels):
        color_palette_mod.insert(0,  "#e3e3e3")
    
    plot_dict = dict(
        x = embedding[:, 0],
        y = embedding[:, 1],
        trait_id = [f"{x}" for x in cluster_labels],
        fill_alpha = fill_alpha_list,
        line_alpha = line_alpha_list,
        fulldesc = [f"{i}"
                    + f" | {trait_df.loc[i, 'short_description']}" 
                    + f" | {trait_df.loc[i, 'estimates.final.h2_observed']:.3f}"
                    + f" | {trait_df.loc[i, 'Neff']:.2f}" 
                    for i in trait_df.index],
    )

    color_mapping = CategoricalColorMapper(factors = [f"{x}" for x in np.unique(cluster_labels)], palette = color_palette_mod)

    plot_tooltips = [
        ("Desc", "@fulldesc"),
    ]

    ax = bokeh_figure(
        width = 800, height = 800, 
        tooltips = plot_tooltips,
        title = plot_title,
    )

    ax.circle('x', 'y', size = 8, 
        source = ColumnDataSource(plot_dict), 
        color = dict(field='trait_id', transform = color_mapping),
        line_alpha = dict(field='line_alpha'),
        fill_alpha = dict(field='fill_alpha'),
    )
    ax.title.text_font_size = '16pt'
    ax.title.text_font_style = 'normal'
    ax.title.text_font = 'tahoma'
    ax.axis.major_label_text_font_size = '20pt'
    ax.axis.axis_line_width = 2
    ax.axis.major_tick_line_width = 2
    ax.grid.visible = False
    return ax
Code
axlist = list()
# Cache t-SNE embeddings across notebook re-runs
if 'data_transformed' not in globals():
    data_transformed = dict()
llm_method = "ls-da3m0ns/bge_large_medical"
llm_ctype = "agglomerative"
alpha_factor = 30

for key, data in datadict.items():

    # Embed with t-SNE on precomputed cosine distances from a kNN graph;
    # skip the computation if a cached embedding exists for this key.
    if key not in data_transformed:
        dist_matrix = kneighbors_graph(data['loadings'], n_neighbors = 2000, mode = 'distance', metric = 'cosine')
        embedding = TSNE(n_components = 2, init = "random", perplexity = 20, early_exaggeration = 12,
                         learning_rate = 'auto', random_state = 42, metric = 'precomputed')
        data_transformed[key] = embedding.fit_transform(dist_matrix)

    ax = get_bokeh_plot(
        data_transformed[key], data['cluster_labels'][llm_method][llm_ctype], data['traits'], 
        f"{data['label']}, {llm_method}, {llm_ctype} clustering",
        alpha_factor = alpha_factor
    )
    axlist.append(ax)


# show the results
p = bokeh_column(*axlist)
bokeh_show(p)
Figure 1: t-SNE embeddings of the NNM loadings.
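
t-SNE is only one of the manifold learning methods imported above; the others (UMAP, Isomap, spectral embedding, MDS) could be swapped in on the same loadings. A minimal sketch with UMAP, assuming the same cosine metric; the parameters are illustrative, not tuned:

Code
# Hypothetical alternative embedding: UMAP on the RPCA loadings.
reducer = umap.UMAP(n_components = 2, n_neighbors = 50, metric = 'cosine', random_state = 42)
embedding_umap = reducer.fit_transform(datadict['nnm']['loadings'])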
Code
def compute_clusters_kmeans(embeddings, n_clusters = 30):
    model = KMeans(n_clusters = n_clusters, random_state=0, max_iter = 10000, n_init='auto')
    model.fit(embeddings)
    cluster_assignment = model.labels_
    phenotype_clusters = [list(np.where(cluster_assignment == i)[0]) for i in range(np.max(cluster_assignment) + 1)]
    return phenotype_clusters

def compute_clusters_hierarchical(embeddings, n_clusters = 30):
    clustering_model = AgglomerativeClustering(n_clusters = n_clusters)
    clustering_model.fit(embeddings)
    cluster_assignment = clustering_model.labels_
    phenotype_clusters = [list(np.where(cluster_assignment == i)[0]) for i in range(np.max(cluster_assignment) + 1)]
    return phenotype_clusters

def compute_cluster_community(embeddings, n_size = 10, n_clusters = 30, thres_step = 0.05):
    # Relax the similarity threshold until enough communities are found,
    # stopping before the threshold becomes non-positive.
    threshold = 1.0
    clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    while len(clusters) < n_clusters and threshold - thres_step > 0:
        threshold -= thres_step
        clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    return clusters, threshold
Code
Z_clusters = { key:{} for key in datadict.keys() }
clustering_methods = ['kmeans', 'agglomerative', 'community']

for method in clustering_methods:
    for key, data in datadict.items():
        if method == 'kmeans': 
            Z_clusters[key][method] = compute_clusters_kmeans(data['loadings'], n_clusters = 50)
        if method == 'agglomerative': 
            Z_clusters[key][method] = compute_clusters_hierarchical(data['loadings'], n_clusters = 50)
        if method == 'community': 
            Z_clusters[key][method], threshold = compute_cluster_community(data['loadings'], n_clusters = 50)
            # print ("Used threshold:", threshold)
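
Beyond visual comparison, the agreement between the loadings-based clusters and the LLM clusters can be quantified, for example with the adjusted Rand index. A minimal sketch, assuming the k-means clusters and treating unassigned traits (-1) as one extra cluster:

Code
from sklearn.metrics import adjusted_rand_score

def clusters_to_labels(clusters, n):
    # Flatten a list of member-index lists into a label array (-1 = unassigned).
    labels = np.full([n,], -1)
    for i, members in enumerate(clusters):
        for idx in members:
            labels[idx] = i
    return labels

for key, data in datadict.items():
    z_labels = clusters_to_labels(Z_clusters[key]['kmeans'], data['traits'].shape[0])
    llm_labels = np.array(data['cluster_labels'][llm_method][llm_ctype])
    print(key, adjusted_rand_score(llm_labels, z_labels))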
Code
def get_bokeh_plot_Zclusters(
        Z_clusters, llm_clusters, traits, 
        color_map = hex_colors_40, alpha_factor = 30,
        plot_title = "", nmax = 100
    ):
    
    # Put cluster indices on the x-axis; community clusters
    # may not include every trait index (unassigned traits stay at -1)
    xvals = np.full([traits.shape[0],], -1)
    for i, clusters in enumerate(Z_clusters):
        for cidx in clusters:
            xvals[cidx] = i

    # Stack the members of each cluster along the y-axis
    yvals = np.zeros(traits.shape[0])
    for x in np.unique(xvals):
        binidx = np.where(xvals == x)[0]
        yvals[binidx] = np.arange(binidx.shape[0])
        
    h2_list   = traits['estimates.final.h2_observed'].fillna(1e-6).tolist()
    alphalist = [alpha_factor * max(1e-6, y) for y in h2_list]
    colorlist = [color_map[(i + 1) % len(color_map)] for i in llm_clusters]  # wrap around the palette; -1 maps to color_map[0]
    
    # Mask the large clusters
    members, member_counts = np.unique(xvals, return_counts=True)
    mask = np.ones_like(xvals, dtype = bool)
    for m, c in zip(members, member_counts):
        if c > nmax:
            mask[np.where(xvals == m)[0]] = False

    # Put everything in the dictionary
    plot_dict = dict(
        x = xvals[mask],
        y = yvals[mask],
        trait_code = [f"{x}" for x, y in zip(llm_clusters, mask) if y],
        h2_fill_alpha = [min(0.6, x) for x, y in zip(alphalist, mask) if y],
        h2_line_alpha = [min(0.8, 1.3 * x) for x, y in zip(alphalist, mask) if y],
        fulldesc = [f"{i}"
                    + f"| {traits.loc[i, 'short_description']}"
                    + f"| {traits.loc[i, 'estimates.final.h2_observed']:.3f}"
                    + f"| {traits.loc[i, 'Neff']:.2f}" 
                    for i, y in zip(traits.index, mask) if y],
        color = [x for x,y in zip(colorlist, mask) if y],
    )
    
    plot_tooltips = [
        ("Desc", "@fulldesc"),
    ]

    ax = bokeh_figure(
        width = 800, height = 500,
        tooltips = plot_tooltips,
        title = plot_title,
        x_axis_label="Clusters identified from loadings",
        y_axis_label="Count"
    )

    ax.square('x', 'y', size = 7, 
        source = ColumnDataSource(plot_dict), 
        color = dict(field='color'),
        line_alpha = dict(field='h2_line_alpha'),
        fill_alpha = dict(field='h2_fill_alpha'),
    )
    ax.title.text_font_size = '16pt'
    ax.title.text_font_style = 'normal'
    ax.title.text_font = 'tahoma'
    ax.grid.visible = False
    ax.xaxis.major_label_text_font_size = '0pt'  # turn off x-axis tick labels
    #ax.yaxis.major_label_text_font_size = '0pt'  # turn off y-axis tick labels
    # ax.axis.axis_line_width = 0
    # ax.axis.major_tick_line_width = 0
    ax.xaxis.axis_line_color = None
    ax.xaxis.major_tick_line_color = None
    ax.xaxis.minor_tick_line_color = None
    ax.axis.axis_label_text_font_size = '16pt'
    ax.axis.axis_label_text_font_style = 'normal'
    ax.axis.axis_label_text_font = 'tahoma'
    return ax
Code
# lowX_method  = 'nnm_noRx'
lowX_ctype   = 'community'
llm_method   = "ls-da3m0ns/bge_large_medical"
llm_ctype    = "agglomerative"
alpha_factor = 20

axlist = list()
for key, data in datadict.items():
    ax = get_bokeh_plot_Zclusters(
        Z_clusters[key][lowX_ctype], # Estimated clusters from Z-scores
        data['cluster_labels'][llm_method][llm_ctype], # Estimated clusters from LLM
        data['traits'],
        color_map = hex_colors_40, alpha_factor = alpha_factor,
        plot_title = data['label'], nmax = 40
    )
    axlist.append(ax)

p = bokeh_column(*axlist)
bokeh_show(p)
Figure 2: Comparison of clusters from Z-scores and LLM.