We compare the clusters obtained from the latent factors with the LLM-based classification of the corresponding phenotype descriptions.
About
We identify clusters from the loadings of the low-dimensional representation of the PanUKB Z-score matrix. We then compare those clusters with the groups predicted by the pretrained LLMs.
Setup
We first load the required libraries, including scikit-learn and Bokeh.
Code
import osimport numpy as npimport pandas as pdimport pickleimport reimport matplotlib.pyplot as pltfrom pymir import mpl_stylesheetfrom pymir import mpl_utilsmpl_stylesheet.banskt_presentation(splinecolor ='black', dpi =120)import umapfrom bokeh.plotting import figure as bokeh_figurefrom bokeh.plotting import show as bokeh_showfrom bokeh.layouts import column as bokeh_columnfrom bokeh.io import output_notebookfrom bokeh.models import ColumnDataSourcefrom bokeh.models import HoverToolfrom bokeh.models import CategoricalColorMapperfrom sklearn.neighbors import kneighbors_graphfrom sklearn.manifold import SpectralEmbedding, TSNE, LocallyLinearEmbedding, Isomap, MDSfrom sklearn.cluster import AgglomerativeClusteringfrom sklearn.cluster import KMeansfrom sentence_transformers import util as st_utiloutput_notebook()
Loading BokehJS ...
Load data and results
Here, we explore the low-rank model obtained from nuclear norm minimization with the sparse matrix.
# Forty visually distinct colors used to paint cluster memberships.
hex_colors_40 = [
    "#084609", "#ff4ff4", "#01d94a", "#b700ce", "#91c900",
    "#5f42ed", "#5fa200", "#8d6dff", "#c9f06b", "#0132a7",
    "#ffbb1f", "#0080ed", "#f56600", "#3afaf5", "#c10001",
    "#01e698", "#a20096", "#00e2c1", "#ff5ac8", "#008143",
    "#cd0057", "#4aeeff", "#8c001a", "#b5f2a2", "#5d177d",
    "#a99900", "#e299ff", "#5b6b00", "#96aeff", "#a46f00",
    "#007acb", "#ff9757", "#00a8e0", "#ff708e", "#baefc7",
    "#622b25", "#c8c797", "#885162", "#ffb7a5", "#ffa3c3",
]

# LLM model names; each has its own result subdirectory under llm_outdir.
llm_methods = [
    "ls-da3m0ns/bge_large_medical",
    "medicalai/ClinicalBERT",
    "emilyalsentzer/Bio_ClinicalBERT",
]
# Clustering types stored for every model.
llm_ctypes = ["community", "kmeans", "agglomerative"]
llm_outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"

# llm_clusters[method][ctype] holds the unpickled cluster list for one
# (model, clustering type) pair; every slot starts out as None.
llm_clusters = {model: dict.fromkeys(llm_ctypes) for model in llm_methods}
for method in llm_methods:
    for ctype in llm_ctypes:
        m_filename = os.path.join(llm_outdir, f"{method}/{ctype}_clusters.pkl")
        # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
        with open(m_filename, "rb") as fh:
            llm_clusters[method][ctype] = pickle.load(fh)


def get_llm_cluster_labels(selectidx, method, ctype):
    """Return one integer cluster label per element of ``selectidx``.

    Only the length of ``selectidx`` is used. The clusters in
    ``llm_clusters[method][ctype]`` are enumerated in order, and every
    member index of cluster ``c`` receives label ``c``. Positions not listed
    in any cluster keep the label -1. If an index appears in several
    clusters, the last one wins.
    """
    n_items = selectidx.shape[0]
    labels = np.full([n_items], -1)
    for cluster_id, members in enumerate(llm_clusters[method][ctype]):
        for member in members:
            labels[member] = cluster_id
    return labels


# Cluster labels for every trait in trait_df, per model and clustering type.
llm_cluster_labels = {
    model: {
        kind: get_llm_cluster_labels(np.array(trait_df.index), model, kind)
        for kind in llm_ctypes
    }
    for model in llm_methods
}

# NOTE(review): the full-trait label arrays are indexed by the *values* of
# trait_df_noRx.index, which assumes those values are row positions in
# trait_df -- confirm against how trait_df_noRx is built.
llm_cluster_labels_noRx = {
    model: {
        kind: [llm_cluster_labels[model][kind][i] for i in trait_df_noRx.index]
        for kind in llm_ctypes
    }
    for model in llm_methods
}

# One entry per factorization / trait-set combination to be plotted.
datadict = {
    'nnm_noRx': {
        'loadings': nnm_loadings_noRx,
        'cluster_labels': llm_cluster_labels_noRx,
        'traits': trait_df_noRx,
        'label': 'RPCA (no Rx)',
    },
    'nnm': {
        'loadings': nnm_loadings,
        'cluster_labels': llm_cluster_labels,
        'traits': trait_df,
        'label': 'RPCA',
    },
    'tsvd_noRx': {
        'loadings': tsvd_loadings_noRx,
        'cluster_labels': llm_cluster_labels_noRx,
        'traits': trait_df_noRx,
        'label': 'tSVD (no Rx)',
    },
    'tsvd': {
        'loadings': tsvd_loadings,
        'cluster_labels': llm_cluster_labels,
        'traits': trait_df,
        'label': 'tSVD',
    },
}
Interactive Plots
We show interactive t-SNE plots in Figure 1. Each point is a PanUKB phenotype, colored according to the clusters identified by the LLMs from the trait descriptions. The opacity of each point is proportional to the estimated heritability of the trait, up to a fixed maximum opacity.
Code
def _cluster_palette(unique_labels, color_palette, unassigned_color = "#e3e3e3"):
    """Return one color per entry of ``unique_labels``, in the same order.

    Integer labels are colored by *value*: -1 (not assigned to any cluster)
    gets ``unassigned_color`` and label ``k`` gets
    ``color_palette[k % len(color_palette)]``. Keying on the value rather than
    on the label's position among the labels present keeps a cluster's color
    stable across plots of different trait subsets. Wrapping around the
    palette means clusters beyond its length are never left uncolored.
    Non-integer labels are colored by position (skipping -1, which stays
    gray), also wrapping around the palette.

    Raises
    ------
    ValueError
        If ``color_palette`` is empty.
    """
    n_colors = len(color_palette)
    if n_colors == 0:
        raise ValueError("color_palette must not be empty")
    if np.issubdtype(np.asarray(unique_labels).dtype, np.integer):
        return [
            unassigned_color if label == -1 else color_palette[label % n_colors]
            for label in unique_labels
        ]
    colors = []
    position = 0
    for label in unique_labels:
        if label == -1:
            colors.append(unassigned_color)
        else:
            colors.append(color_palette[position % n_colors])
            position += 1
    return colors


def get_bokeh_plot(
    embedding,
    cluster_labels,
    trait_df,
    plot_title,
    color_palette = hex_colors_40,
    alpha_factor = 10,
):
    """Build an interactive Bokeh scatter plot of a 2D trait embedding.

    Parameters
    ----------
    embedding : numpy.ndarray
        Row ``k`` holds the coordinates of the ``k``-th row of ``trait_df``;
        only columns 0 and 1 are plotted.
    cluster_labels : sequence of int
        Cluster label of each row of ``trait_df``; -1 marks an unassigned
        trait and is drawn in light gray.
    trait_df : pandas.DataFrame
        Must provide the columns 'estimates.final.h2_observed',
        'short_description' and 'Neff' (used for opacity and tooltips).
    plot_title : str
        Figure title.
    color_palette : list of str
        Hex colors assigned to cluster labels (wrapped around if there are
        more clusters than colors). The list is not modified.
    alpha_factor : float
        Scale from observed h2 to point opacity. Fill alpha is capped at 0.6
        and line alpha at 0.8.

    Returns
    -------
    bokeh.plotting.figure
    """
    # Missing or non-positive h2 estimates get a tiny floor so the point is
    # still drawn, just nearly transparent.
    h2_list = [max(1e-6, x) for x in trait_df['estimates.final.h2_observed'].fillna(1e-6).tolist()]
    fill_alpha_list = [min(0.6, alpha_factor * x) for x in h2_list]
    line_alpha_list = [min(0.8, 1.3 * alpha_factor * x) for x in h2_list]

    unique_labels = np.unique(cluster_labels)
    color_mapping = CategoricalColorMapper(
        factors = [f"{x}" for x in unique_labels],
        palette = _cluster_palette(unique_labels, color_palette),
    )

    plot_dict = dict(
        x = embedding[:, 0],
        y = embedding[:, 1],
        trait_id = [f"{x}" for x in cluster_labels],
        fill_alpha = fill_alpha_list,
        line_alpha = line_alpha_list,
        fulldesc = [
            f"{i}"
            + f" | {trait_df.loc[i, 'short_description']}"
            + f" | {trait_df.loc[i, 'estimates.final.h2_observed']:.3f}"
            + f" | {trait_df.loc[i, 'Neff']:.2f}"
            for i in trait_df.index
        ],
    )

    plot_tooltips = [
        ("Desc", "@fulldesc"),
    ]

    ax = bokeh_figure(
        width = 800,
        height = 800,
        tooltips = plot_tooltips,
        title = plot_title,
    )
    # scatter(marker=...) replaces circle(size=...), which Bokeh 3.4 deprecated.
    ax.scatter(
        'x', 'y',
        marker = 'circle',
        size = 8,
        source = ColumnDataSource(plot_dict),
        color = dict(field = 'trait_id', transform = color_mapping),
        line_alpha = dict(field = 'line_alpha'),
        fill_alpha = dict(field = 'fill_alpha'),
    )

    ax.title.text_font_size = '16pt'
    ax.title.text_font_style = 'normal'
    ax.title.text_font = 'tahoma'
    ax.axis.major_label_text_font_size = '20pt'
    ax.axis.axis_line_width = 2
    ax.axis.major_tick_line_width = 2
    ax.grid.visible = False
    return ax
# Z_clusters[key][method] holds the clusters found from the loadings of
# datadict[key] by clustering method `method`. The clustering helpers
# (compute_clusters_kmeans, compute_clusters_hierarchical,
# compute_cluster_community) are defined outside this section.
Z_clusters = {key: {} for key in datadict}
clustering_methods = ['kmeans', 'agglomerative', 'community']
for method in clustering_methods:
    for key, data in datadict.items():
        if method == 'kmeans':
            Z_clusters[key][method] = compute_clusters_kmeans(data['loadings'], n_clusters = 50)
        elif method == 'agglomerative':
            Z_clusters[key][method] = compute_clusters_hierarchical(data['loadings'], n_clusters = 50)
        elif method == 'community':
            # The community routine also returns the threshold it used.
            Z_clusters[key][method], threshold = compute_cluster_community(data['loadings'], n_clusters = 50)
Code
def get_bokeh_plot_Zclusters(
    Z_clusters,
    llm_clusters,
    traits,
    color_map = hex_colors_40,
    alpha_factor = 30,
    plot_title = "",
    nmax = 100,
):
    """Plot loading-based clusters as stacked columns colored by LLM cluster.

    Each trait is drawn as a square. Its x position is the index of the
    loading-based cluster it belongs to (-1 if none), and its y position is
    its running count within that cluster. Its color is the LLM cluster it
    was assigned to.

    Parameters
    ----------
    Z_clusters : iterable of iterables of int
        Clusters found from the loadings; each holds row positions into
        ``traits``. Traits not in any cluster are drawn at x = -1.
    llm_clusters : sequence of int
        LLM cluster label of each row of ``traits`` (-1 = unassigned).
    traits : pandas.DataFrame
        Must provide 'estimates.final.h2_observed', 'short_description'
        and 'Neff'.
    color_map : list of str
        Hex colors. Label -1 uses ``color_map[0]`` and label ``k`` uses
        ``color_map[k + 1]``, wrapping around the list. Not modified.
    alpha_factor : float
        Scale from observed h2 to opacity. Fill alpha is capped at 0.6 and
        line alpha at 0.8.
    plot_title : str
        Figure title.
    nmax : int
        Columns (x positions) with more than ``nmax`` traits are omitted.

    Returns
    -------
    bokeh.plotting.figure
    """
    n_traits = traits.shape[0]

    # x-axis: index of the loading-based cluster of each trait.
    # Community clusters do not include all indices, so uncovered traits stay at -1.
    xvals = np.full([n_traits,], -1)
    for i, clusters in enumerate(Z_clusters):
        for cidx in clusters:
            xvals[cidx] = i

    # y-axis: running count inside each column, so members stack vertically.
    yvals = np.zeros(n_traits)
    for x in np.unique(xvals):
        binidx = np.where(xvals == x)[0]
        yvals[binidx] = np.arange(binidx.shape[0])

    h2_list = traits['estimates.final.h2_observed'].fillna(1e-6).tolist()
    alphalist = [alpha_factor * max(1e-6, y) for y in h2_list]
    # Offset by one so label -1 maps to color_map[0]. The modulo avoids an
    # IndexError when there are more LLM clusters than colors.
    n_colors = len(color_map)
    colorlist = [color_map[(i + 1) % n_colors] for i in llm_clusters]

    # Mask the members of overly large columns so they do not dominate the y-axis.
    members, member_counts = np.unique(xvals, return_counts = True)
    mask = ~np.isin(xvals, members[member_counts > nmax])

    # Put everything in the dictionary
    plot_dict = dict(
        x = xvals[mask],
        y = yvals[mask],
        trait_code = [f"{x}" for x, y in zip(llm_clusters, mask) if y],
        h2_fill_alpha = [min(0.6, x) for x, y in zip(alphalist, mask) if y],
        h2_line_alpha = [min(0.8, 1.3 * x) for x, y in zip(alphalist, mask) if y],
        fulldesc = [
            f"{i}"
            + f"| {traits.loc[i, 'short_description']}"
            + f"| {traits.loc[i, 'estimates.final.h2_observed']:.3f}"
            + f"| {traits.loc[i, 'Neff']:.2f}"
            for i, y in zip(traits.index, mask) if y
        ],
        color = [x for x, y in zip(colorlist, mask) if y],
    )

    plot_tooltips = [
        ("Desc", "@fulldesc"),
    ]

    ax = bokeh_figure(
        width = 800,
        height = 500,
        tooltips = plot_tooltips,
        title = plot_title,
        x_axis_label = "Clusters identified from loadings",
        y_axis_label = "Count",
    )
    # scatter(marker='square') replaces figure.square(), which Bokeh 3.4 deprecated.
    ax.scatter(
        'x', 'y',
        marker = 'square',
        size = 7,
        source = ColumnDataSource(plot_dict),
        color = dict(field = 'color'),
        line_alpha = dict(field = 'h2_line_alpha'),
        fill_alpha = dict(field = 'h2_fill_alpha'),
    )

    ax.title.text_font_size = '16pt'
    ax.title.text_font_style = 'normal'
    ax.title.text_font = 'tahoma'
    ax.grid.visible = False
    # Column indices carry no meaning on their own: hide x tick labels and ticks.
    ax.xaxis.major_label_text_font_size = '0pt'
    ax.xaxis.axis_line_color = None
    ax.xaxis.major_tick_line_color = None
    ax.xaxis.minor_tick_line_color = None
    ax.axis.axis_label_text_font_size = '16pt'
    ax.axis.axis_label_text_font_style = 'normal'
    ax.axis.axis_label_text_font = 'tahoma'
    return ax