Semantic clustering of PanUKB phenotypes

Author

Saikat Banerjee

Published

April 2, 2024

Abstract
We use pre-trained language models to embed the text descriptions of PanUKB phenotypes and cluster the phenotypes by semantic similarity.
Code
import os
import numpy as np
import pandas as pd
import pickle
import re

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils

# Configure matplotlib with pymir's presentation stylesheet (black spline color, 120 dpi).
mpl_stylesheet.banskt_presentation(dpi=120, splinecolor='black')
Code
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util
Code
# Directory holding the PanUKB trait metadata.
data_dir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/data"
# Trait table, one row per phenotype; its `long_description` column is embedded below.
# (Plain string: the former f-prefix had no placeholders.)
trait_df = pd.read_pickle(os.path.join(data_dir, "modselect/traits_all_with_desc.pkl"))
trait_df
zindex trait_type phenocode pheno_sex coding modifier description description_more coding_description category BIN_QT n_cases_EUR n_controls_EUR N Neff filename aws_link estimates.final.h2_observed long_description short_description
0 1 icd10 A04 both_sexes NaN NaN A04 Other bacterial intestinal infections truncated: true NaN Chapter I Certain infectious and parasitic dis... BIN 3088 417443.0 420531 6130.649032 icd10-A04-both_sexes.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.0033 A04 Other bacterial intestinal infections A04 Bacterial intestinal infections
1 2 icd10 A08 both_sexes NaN NaN A08 Viral and other specified intestinal infec... truncated: true NaN Chapter I Certain infectious and parasitic dis... BIN 1107 419424.0 420531 2208.171897 icd10-A08-both_sexes.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.0001 A08 Viral and other specified intestinal infec... A08 Viral, other intestinal infections
2 3 icd10 A09 both_sexes NaN NaN A09 Diarrhoea and gastro-enteritis of presumed... truncated: true NaN Chapter I Certain infectious and parasitic dis... BIN 9029 411502.0 420531 17670.286180 icd10-A09-both_sexes.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.0035 A09 Diarrhoea and gastro-enteritis of presumed... A09 Diarrhoea, infectious gastro-enteritis
3 4 icd10 A41 both_sexes NaN NaN A41 Other septicaemia truncated: true NaN Chapter I Certain infectious and parasitic dis... BIN 5512 415019.0 420531 10879.505810 icd10-A41-both_sexes.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.0011 A41 Other septicaemia A41 Other septicaemia
4 5 icd10 B34 both_sexes NaN NaN B34 Viral infection of unspecified site truncated: true NaN Chapter I Certain infectious and parasitic dis... BIN 2129 418402.0 420531 4236.443249 icd10-B34-both_sexes.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.0003 B34 Viral infection of unspecified site B34 Viral infection of unspecified site
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2478 2479 continuous Smoking both_sexes NaN Ever_Never Smoking status, ever vs never Ever (previous + current smoker) vs never base... NaN NaN QT 418817 NaN 418817 418817.000000 continuous-Smoking-both_sexes-Ever_Never.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.1100 Smoking status, ever vs never Smoking status, ever vs never
2479 2480 continuous eGFR both_sexes NaN irnt Estimated glomerular filtration rate, serum cr... eGFR based on serum creatinine (30700) using t... NaN NaN QT 401867 NaN 401867 401867.000000 continuous-eGFR-both_sexes-irnt.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.2070 Estimated glomerular filtration rate, serum cr... Estimated GFR, serum creatinine
2480 2481 continuous eGFRcreacys both_sexes NaN irnt Estimated glomerular filtration rate, cystain C eGFR based on cystain C (30720) using the CKD-... NaN NaN QT 401570 NaN 401570 401570.000000 continuous-eGFRcreacys-both_sexes-irnt.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.2380 Estimated glomerular filtration rate, cystain C Estimated GFR, cystain C
2481 2482 continuous eGFRcys both_sexes NaN irnt Estimated glomerular filtration rate, serum cr... eGFR based on serum creatinine (30700) and cys... NaN NaN QT 402031 NaN 402031 402031.000000 continuous-eGFRcys-both_sexes-irnt.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.2240 Estimated glomerular filtration rate, serum cr... Estimated GFR, serum creatinine + cystain C
2482 2483 continuous whr both_sexes NaN irnt pheno 48 / pheno 49 NaN NaN NaN QT 420531 NaN 420531 420531.000000 continuous-whr-both_sexes-irnt.tsv.bgz https://pan-ukb-us-east-1.s3.amazonaws.com/sum... 0.1740 pheno 48 / pheno 49 Pheno 48 / pheno 49

2483 rows × 20 columns

Phenotype Description to Sentences

First, we take the long description of each phenotype as one sentence. These sentences are used for embedding.

Code
# Sentences to embed: one long description per phenotype, in trait_df row order,
# so a position in this list is also a row position in trait_df.
long_desc = trait_df.loc[:, "long_description"].tolist()

Get embedding from sentences

I decided to use Sentence Transformers to embed the long descriptions of the phenotypes for clustering. This is an ideal use case for pre-trained language models.

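Several pre-trained models are available on Hugging Face, and any of them can be loaded. Some are trained specifically for Sentence Transformers; these can be found by searching the Hugging Face Hub. I found a few that are trained on medical data:

1. https://huggingface.co/menadsa/S-Bio_ClinicalBERT
2. https://huggingface.co/ls-da3m0ns/bge_large_medical

Models that are not packaged for Sentence Transformers (below, medicalai/ClinicalBERT and emilyalsentzer/Bio_ClinicalBERT) are used by mean-pooling their token embeddings over the attention mask.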

Code
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions of each sequence.

    ``model_output[0]`` holds the per-token embeddings. ``attention_mask`` marks
    real tokens with 1 and padding with 0. The mask is broadcast over the
    embedding dimension, so padded tokens add nothing to the sum. The token
    count is clamped at 1e-9 so an all-zero mask gives zeros instead of NaN.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def get_embeddings(sentences, model_name, use_pooling = False, batch_size = 64):
    """Embed a list of sentences with a pre-trained HuggingFace Hub model.

    Parameters
    ----------
    sentences : list of str
        Texts to embed; output rows follow the same order.
    model_name : str
        HuggingFace Hub identifier of the checkpoint.
    use_pooling : bool
        If True, load the checkpoint with ``AutoTokenizer``/``AutoModel`` and
        mean-pool the token embeddings over the attention mask (for checkpoints
        not packaged as Sentence Transformers). If False, load it with
        ``SentenceTransformer`` and use its ``encode``.
    batch_size : int
        Number of sentences per forward pass, used by both branches.

    Returns
    -------
    torch.Tensor
        One embedding per sentence, always on the CPU so that it can be handed
        directly to scikit-learn (which converts its input through NumPy).

    Raises
    ------
    ValueError
        If ``sentences`` is empty.
    """
    sentences = list(sentences)
    if not sentences:
        raise ValueError("get_embeddings: 'sentences' is empty")

    if use_pooling:
        # Load model from HuggingFace Hub
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # from_pretrained already returns eval mode; explicit here because we only do inference.
        model.eval()

        # Encode in batches. A single forward pass over the whole corpus, padded to its
        # longest sentence, needs memory proportional to len(sentences) * max length.
        # Padding per batch should not change the pooled vectors (up to floating-point
        # rounding), because padded positions are masked out.
        # NOTE: tokenizers without a model_max_length emit the "Asking to truncate to
        # max_length ..." warning and do not truncate.
        pooled = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
            with torch.no_grad():
                model_output = model(**encoded_input)
            pooled.append(mean_pooling(model_output, encoded_input['attention_mask']))
        sentence_embeddings = torch.cat(pooled, dim=0)
    else:
        model = SentenceTransformer(model_name)
        sentence_embeddings = model.encode(sentences, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True)

    # SentenceTransformer returns the tensor on its own device (possibly a GPU);
    # move to CPU so both branches agree and downstream scikit-learn calls work.
    return sentence_embeddings.cpu()
Code
# HuggingFace Hub checkpoints to compare, mapped to how they are loaded:
#   "SentenceTransformer" -> SentenceTransformer(...).encode
#   "Transformer"         -> AutoModel + attention-masked mean pooling
# Insertion order is the order in which the models are embedded and reported.
model_names = {
    "ls-da3m0ns/bge_large_medical": "SentenceTransformer",
    "medicalai/ClinicalBERT": "Transformer",
    "emilyalsentzer/Bio_ClinicalBERT": "Transformer",
}
Code
# Embed every long description with each model. SentenceTransformer checkpoints use
# their own encode() pipeline; all other checkpoints are mean-pooled (use_pooling=True).
embeddings = dict()
for model_name, model_type in model_names.items():
    use_pooling = model_type != "SentenceTransformer"
    embeddings[model_name] = get_embeddings(long_desc, model_name, use_pooling=use_pooling)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.

Compute clusters from embeddings

We compare three approaches: K-Means clustering, agglomerative (hierarchical) clustering, and community detection from sentence-transformers.

Code
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans

def compute_clusters_kmeans(embeddings, n_clusters = 30):
    """Partition the rows of ``embeddings`` into ``n_clusters`` groups with k-means.

    Uses scikit-learn's KMeans with a fixed ``random_state=0`` (deterministic
    initialisation), ``max_iter=10000`` and ``n_init='auto'``.

    Returns
    -------
    list of list
        Element ``k`` holds the row indices whose k-means label is ``k``,
        for ``k = 0 .. max(label)``.
    """
    labels = KMeans(
        n_clusters=n_clusters,
        random_state=0,
        max_iter=10000,
        n_init='auto',
    ).fit(embeddings).labels_
    n_labels = np.max(labels) + 1
    return [list(np.flatnonzero(labels == k)) for k in range(n_labels)]

def compute_clusters_hierarchical(embeddings, n_clusters = 30):
    """Partition the rows of ``embeddings`` into ``n_clusters`` groups by agglomerative clustering.

    scikit-learn's AgglomerativeClustering is used with its default linkage and
    metric; only ``n_clusters`` is set.

    Returns
    -------
    list of list
        Element ``k`` holds the row indices assigned label ``k``,
        for ``k = 0 .. max(label)``.
    """
    fitted = AgglomerativeClustering(n_clusters=n_clusters).fit(embeddings)
    labels = fitted.labels_
    return [list(np.flatnonzero(labels == k)) for k in range(np.max(labels) + 1)]

def compute_cluster_community(embeddings, n_size = 10, n_clusters = 30, thres_step = 0.05, min_threshold = -1.0):
    """Run community detection with the highest threshold that gives >= ``n_clusters`` communities.

    The similarity threshold starts at 1.0 and is lowered by ``thres_step`` until
    ``st_util.community_detection`` returns at least ``n_clusters`` communities.

    Parameters
    ----------
    embeddings
        Embeddings passed straight to ``st_util.community_detection``.
    n_size : int
        ``min_community_size`` passed to community detection.
    n_clusters : int
        Minimum number of communities required.
    thres_step : float
        Positive decrement applied to the threshold at each step.
    min_threshold : float
        Lowest threshold to try before giving up (cosine similarity cannot go
        below -1, so the default covers the full range).

    Returns
    -------
    (clusters, threshold)
        The communities found and the threshold that produced them.

    Raises
    ------
    ValueError
        If ``thres_step`` is not positive. Also raised if no threshold in
        ``[min_threshold, 1.0]`` yields ``n_clusters`` communities; without this
        bound the search would loop forever.
    """
    if thres_step <= 0:
        raise ValueError(f"thres_step must be positive, got {thres_step}")
    n_steps = 0
    threshold = 1.0
    while threshold >= min_threshold:
        clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
        if len(clusters) >= n_clusters:
            return clusters, threshold
        n_steps += 1
        # Recompute from the start value instead of subtracting repeatedly,
        # so floating-point error does not accumulate across steps.
        threshold = 1.0 - n_steps * thres_step
    raise ValueError(
        f"community detection never produced {n_clusters} communities "
        f"(min_community_size={n_size}) for thresholds in [{min_threshold}, 1.0]"
    )
Code
# Cluster each model's embeddings with all three methods and print the number of
# communities plus the k-means / agglomerative cluster sizes.
N_CLUSTERS = 30
clusters_community = dict()
clusters_kmeans = dict()
clusters_agglom = dict()
for model_name in model_names:
    emb = embeddings[model_name]
    # The community helper lowers its similarity threshold from 1.0 in 0.01 steps
    # until it finds at least N_CLUSTERS communities; it reports the threshold used.
    clusters_community[model_name], threshold = compute_cluster_community(emb, thres_step = 0.01, n_clusters = N_CLUSTERS)
    clusters_kmeans[model_name] = compute_clusters_kmeans(emb, n_clusters = N_CLUSTERS)
    clusters_agglom[model_name] = compute_clusters_hierarchical(emb, n_clusters = N_CLUSTERS)

    kmeans_sizes = ", ".join(str(len(x)) for x in clusters_kmeans[model_name])
    agglom_sizes = ", ".join(str(len(x)) for x in clusters_agglom[model_name])
    # compute_clusters_hierarchical uses AgglomerativeClustering with its default
    # (Euclidean / ward) settings, not cosine distance; label the result accordingly.
    print(f"{model_name}\n\tCommunity: {len(clusters_community[model_name])} clusters | auto threshold = {threshold:.2f}\n"
          f"\tKMeans: [{kmeans_sizes}]\n"
          f"\tAgglomerative: [{agglom_sizes}]\n")
ls-da3m0ns/bge_large_medical
    Community: 35 clusters | auto threshold = 0.68
    KMeans: [81, 57, 54, 138, 62, 76, 28, 142, 110, 122, 94, 40, 123, 23, 73, 98, 79, 139, 52, 125, 56, 126, 67, 93, 28, 54, 144, 62, 44, 93]
    Cosine: [154, 188, 99, 49, 86, 150, 123, 111, 45, 186, 56, 54, 187, 189, 65, 30, 50, 35, 52, 43, 98, 55, 23, 24, 12, 62, 100, 35, 37, 85]

medicalai/ClinicalBERT
    Community: 34 clusters | auto threshold = 0.79
    KMeans: [63, 120, 110, 39, 110, 77, 63, 40, 160, 22, 84, 93, 146, 19, 62, 107, 85, 74, 151, 86, 140, 47, 84, 46, 37, 82, 135, 18, 123, 60]
    Cosine: [72, 68, 87, 134, 38, 268, 23, 144, 58, 76, 104, 188, 155, 71, 104, 54, 72, 45, 56, 43, 21, 24, 94, 65, 85, 33, 189, 59, 32, 21]

emilyalsentzer/Bio_ClinicalBERT
    Community: 43 clusters | auto threshold = 0.92
    KMeans: [42, 139, 124, 91, 49, 53, 65, 59, 103, 101, 136, 68, 36, 73, 127, 92, 44, 142, 60, 29, 73, 60, 107, 32, 90, 136, 88, 95, 89, 80]
    Cosine: [231, 130, 50, 106, 73, 34, 163, 47, 178, 115, 74, 62, 69, 122, 98, 109, 31, 55, 88, 115, 29, 76, 58, 114, 27, 43, 40, 39, 31, 76]
Code
# Preview the agglomerative clusters of each model. For the first and last
# n_preview clusters (in label order), print the first and last n_preview member descriptions.
n_preview = 3
for m, citems in clusters_agglom.items():

    print(f"Model name: {m}")

    nc = len(citems)

    # First and last n_preview cluster indices, de-duplicated and in order.
    # This stays correct when there are fewer than 2 * n_preview clusters
    # (no repeated clusters, no out-of-range index).
    preview_ids = sorted(set(range(min(n_preview, nc))) | set(range(max(nc - n_preview, 0), nc)))
    for i in preview_ids:
        cluster = citems[i]
        print("\nCluster {}, #{} Elements ".format(i + 1, len(cluster)))
        if len(cluster) <= 2 * n_preview:
            # Small cluster: head and tail would overlap, so show every member once.
            for sentence_id in cluster:
                print("\t", long_desc[sentence_id])
        else:
            for sentence_id in cluster[:n_preview]:
                print("\t", long_desc[sentence_id])
            print("\t", "...")
            for sentence_id in cluster[-n_preview:]:
                print("\t", long_desc[sentence_id])
    print("---------------------")
    print()
Model name: ls-da3m0ns/bge_large_medical

Cluster 1, #154 Elements 
     G47 Sleep disorders
     R47 Speech disturbances, not elsewhere classified
     R49 Voice disturbances
     ...
     Duration walking for pleasure
     Frequency of strenuous sports in last 4 weeks
     FEV1/FVC ratio

Cluster 2, #188 Elements 
     D68 Other coagulation defects
     D69 Purpura and other haemorrhagic conditions
     E87 Other disorders of fluid, electrolyte and acid-base balance
     ...
     ECG, phase time
     ECG, number of stages in a phase
     Pulse rate (during blood-pressure measurement)

Cluster 3, #99 Elements 
     N41 Inflammatory diseases of prostate
     N42 Other disorders of prostate
     N43 Hydrocele and spermatocele
     ...
     Number of stillbirths
     Number of spontaneous miscarriages
     Number of pregnancy terminations

Cluster 28, #35 Elements 
     E03 Other hypothyroidism
     E04 Other non-toxic goitre
     E05 Thyrotoxicosis [hyperthyroidism]
     ...
     levothyroxine sodium medication
     Oestradiol
     Testosterone

Cluster 29, #37 Elements 
     S00 Superficial injury of head
     S01 Open wound of head
     S09 Other and unspecified injuries of head
     ...
     Contusion
     fracture lower leg / ankle self-reported
     Falls in the last year

Cluster 30, #85 Elements 
     F05 Delirium, not induced by alcohol and other psychoactive substances
     F31 Bipolar affective disorder
     F32 Depressive episode
     ...
     Financial situation satisfaction
     Longest period of depression
     Number of depression episodes
---------------------

Model name: medicalai/ClinicalBERT

Cluster 1, #72 Elements 
     D50 Iron deficiency anaemia
     D51 Vitamin B12 deficiency anaemia
     E53 Deficiency of other B group vitamins
     ...
     Vitamin C
     Vitamin D
     Vitamin E

Cluster 2, #68 Elements 
     G45 Transient cerebral ischaemic attacks and related syndromes
     I21 Acute myocardial infarction
     I24 Other acute ischaemic heart diseases
     ...
     atrial fibrillation self-reported
     Ventricular rate
     Age deep-vein thrombosis (DVT, blood clot in leg) diagnosed

Cluster 3, #87 Elements 
     C43 Malignant melanoma of skin
     D05 Carcinoma in situ of breast
     K04 Diseases of pulp and periapical tissues
     ...
     Trunk fat mass
     Trunk fat-free mass
     Trunk predicted mass

Cluster 28, #59 Elements 
     5-alpha reductase inhibitor|BPH|benign prostatic hyperplasia
     ACE inhibitor|anti-hypertensive
     DPP-4 inhibitor|diabetes
     ...
     sodium/potassium transporter inhibitor|heart failure
     tricyclic antihistamine|selective histamine H1 inverse agonist|antihistamine
     xanthine oxidase inhibitor|anti-gout agent

Cluster 29, #32 Elements 
     E10 Insulin-dependent diabetes mellitus
     E11 Non-insulin-dependent diabetes mellitus
     E14 Unspecified diabetes mellitus
     ...
     Diabetes (mother illness)
     Most recent bowel cancer screening
     Age high blood pressure diagnosed

Cluster 30, #21 Elements 
     Diastolic blood pressure, automated reading
     Systolic blood pressure, automated reading
     Systolic blood pressure, manual reading
     ...
     Systolic blood pressure, combined automated + manual reading
     Systolic blood pressure, combined automated + manual reading, adjusted by medication
     Systolic blood pressure, manual reading, adjusted by medication
---------------------

Model name: emilyalsentzer/Bio_ClinicalBERT

Cluster 1, #231 Elements 
     H00 Hordeolum and chalazion
     Desogestrel
     albuterol
     ...
     Retinol
     Carotene
     pheno 48 / pheno 49

Cluster 2, #130 Elements 
     A41 Other septicaemia
     B37 Candidiasis
     B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters
     ...
     Facial ageing
     Relative age of first facial hair
     Tinnitus severity/nuisance

Cluster 3, #50 Elements 
     Bacterial enteritis
     Viral Enteritis
     Bacterial infection NOS
     ...
     Financial situation satisfaction
     Private healthcare
     Noisy workplace

Cluster 28, #39 Elements 
     R00 Abnormalities of heart beat
     R01 Cardiac murmurs and other cardiac sounds
     R06 Abnormalities of breathing
     ...
     Malaise and fatigue
     Leg pain on walking : action taken
     Multi-site chronic pain

Cluster 29, #31 Elements 
     Pulse rate, automated reading
     Ventricular rate
     Basal metabolic rate
     ...
     Systolic blood pressure, combined automated + manual reading
     Systolic blood pressure, combined automated + manual reading, adjusted by medication
     Systolic blood pressure, manual reading, adjusted by medication

Cluster 30, #76 Elements 
     H04 Disorders of lachrymal system
     H52 Disorders of refraction and accommodation
     I08 Multiple valve diseases
     ...
     Duration of fitness test
     Duration of walks
     LDL direct, adjusted by medication
---------------------

Save the clusters

Code
def save_cluster_list(filepath, clusters):
    """Pickle ``clusters`` to ``filepath``, creating parent directories as needed.

    Parameters
    ----------
    filepath : str
        Destination of the pickle file. A bare filename (no directory part) is
        written to the current working directory.
    clusters : object
        Any picklable object; in this notebook, a list of clusters, each a list
        of phenotype row indices.
    """
    dirname = os.path.dirname(filepath)
    # dirname is "" for a bare filename, and os.makedirs("") would raise.
    # exist_ok=True also avoids the race between an existence check and creation.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)
        
outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"

# One pickle per (model, clustering method): <outdir>/<model name>/<method>_clusters.pkl.
# Methods are saved in the order community, k-means, agglomerative.
results_by_method = (
    ("community_clusters.pkl", clusters_community),
    ("kmeans_clusters.pkl", clusters_kmeans),
    ("agglomerative_clusters.pkl", clusters_agglom),
)
for pkl_name, clusters_by_model in results_by_method:
    for method, clusters in clusters_by_model.items():
        m_filename = os.path.join(outdir, f"{method}/{pkl_name}")
        save_cluster_list(m_filename, clusters)