Code
import os
import numpy as np
import pandas as pd
import pickle
import re
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 120)Saikat Banerjee
April 2, 2024
| zindex | trait_type | phenocode | pheno_sex | coding | modifier | description | description_more | coding_description | category | BIN_QT | n_cases_EUR | n_controls_EUR | N | Neff | filename | aws_link | estimates.final.h2_observed | long_description | short_description | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | icd10 | A04 | both_sexes | NaN | NaN | A04 Other bacterial intestinal infections | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 3088 | 417443.0 | 420531 | 6130.649032 | icd10-A04-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0033 | A04 Other bacterial intestinal infections | A04 Bacterial intestinal infections | 
| 1 | 2 | icd10 | A08 | both_sexes | NaN | NaN | A08 Viral and other specified intestinal infec... | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 1107 | 419424.0 | 420531 | 2208.171897 | icd10-A08-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0001 | A08 Viral and other specified intestinal infec... | A08 Viral, other intestinal infections | 
| 2 | 3 | icd10 | A09 | both_sexes | NaN | NaN | A09 Diarrhoea and gastro-enteritis of presumed... | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 9029 | 411502.0 | 420531 | 17670.286180 | icd10-A09-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0035 | A09 Diarrhoea and gastro-enteritis of presumed... | A09 Diarrhoea, infectious gastro-enteritis | 
| 3 | 4 | icd10 | A41 | both_sexes | NaN | NaN | A41 Other septicaemia | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 5512 | 415019.0 | 420531 | 10879.505810 | icd10-A41-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0011 | A41 Other septicaemia | A41 Other septicaemia | 
| 4 | 5 | icd10 | B34 | both_sexes | NaN | NaN | B34 Viral infection of unspecified site | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 2129 | 418402.0 | 420531 | 4236.443249 | icd10-B34-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0003 | B34 Viral infection of unspecified site | B34 Viral infection of unspecified site | 
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | 
| 2478 | 2479 | continuous | Smoking | both_sexes | NaN | Ever_Never | Smoking status, ever vs never | Ever (previous + current smoker) vs never base... | NaN | NaN | QT | 418817 | NaN | 418817 | 418817.000000 | continuous-Smoking-both_sexes-Ever_Never.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.1100 | Smoking status, ever vs never | Smoking status, ever vs never | 
| 2479 | 2480 | continuous | eGFR | both_sexes | NaN | irnt | Estimated glomerular filtration rate, serum cr... | eGFR based on serum creatinine (30700) using t... | NaN | NaN | QT | 401867 | NaN | 401867 | 401867.000000 | continuous-eGFR-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2070 | Estimated glomerular filtration rate, serum cr... | Estimated GFR, serum creatinine | 
| 2480 | 2481 | continuous | eGFRcreacys | both_sexes | NaN | irnt | Estimated glomerular filtration rate, cystain C | eGFR based on cystain C (30720) using the CKD-... | NaN | NaN | QT | 401570 | NaN | 401570 | 401570.000000 | continuous-eGFRcreacys-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2380 | Estimated glomerular filtration rate, cystain C | Estimated GFR, cystain C | 
| 2481 | 2482 | continuous | eGFRcys | both_sexes | NaN | irnt | Estimated glomerular filtration rate, serum cr... | eGFR based on serum creatinine (30700) and cys... | NaN | NaN | QT | 402031 | NaN | 402031 | 402031.000000 | continuous-eGFRcys-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2240 | Estimated glomerular filtration rate, serum cr... | Estimated GFR, serum creatinine + cystain C | 
| 2482 | 2483 | continuous | whr | both_sexes | NaN | irnt | pheno 48 / pheno 49 | NaN | NaN | NaN | QT | 420531 | NaN | 420531 | 420531.000000 | continuous-whr-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.1740 | pheno 48 / pheno 49 | Pheno 48 / pheno 49 | 
2483 rows × 20 columns
First we convert all descriptions to sentences, which will be used for embedding.
I decided to use Sentence Transformers to cluster the long descriptions of the phenotypes. This is an ideal use case for pre-trained language models.
Several pre-trained models are available on Hugging Face, and any of them can be loaded. Some are trained specifically for Sentence Transformers (these can be found by searching the Hugging Face Hub). I found a few that are trained on medical data: 1. https://huggingface.co/menadsa/S-Bio_ClinicalBERT 2. https://huggingface.co/ls-da3m0ns/bge_large_medical
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings, ignoring padding positions.

    Parameters
    ----------
    model_output : sequence
        Model output whose first element holds the per-token embeddings,
        indexed as (batch, token, hidden).
    attention_mask : torch.Tensor
        Mask of shape (batch, token); non-zero entries mark real tokens.

    Returns
    -------
    torch.Tensor
        One (batch, hidden) sentence embedding per input row.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so padded tokens contribute zero.
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    # Floor the token count so an all-padding row divides by a tiny number, not zero.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def get_embeddings(sentences, model_name, use_pooling = False, batch_size = 64, max_length = None):
    """Embed sentences with a pre-trained HuggingFace model.

    Parameters
    ----------
    sentences : list of str
        Texts to embed. In the pooling path a single string is treated as a
        one-element list.
    model_name : str
        HuggingFace Hub identifier of the model to load.
    use_pooling : bool, default False
        If True, load the model with ``AutoTokenizer`` / ``AutoModel`` and
        mean-pool its token embeddings with ``mean_pooling``. If False, load
        it as a ``SentenceTransformer`` and use its ``encode`` method.
    batch_size : int, default 64
        Number of sentences per forward pass (used by both paths).
    max_length : int, optional
        Truncation limit in tokens for the pooling path. When None, the
        tokenizer's ``model_max_length`` is used if configured, otherwise the
        model config's ``max_position_embeddings``. For models whose usable
        length is below ``max_position_embeddings`` (e.g. RoBERTa-style
        position offsets) pass this explicitly.

    Returns
    -------
    torch.Tensor
        Sentence embeddings, one row per input sentence, in input order.

    Raises
    ------
    ValueError
        If the pooling path is given no sentences.
    """
    if not use_pooling:
        model = SentenceTransformer(model_name)
        return model.encode(sentences, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True)

    if isinstance(sentences, str):
        sentences = [sentences]
    if len(sentences) == 0:
        raise ValueError("get_embeddings: 'sentences' must contain at least one text")

    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # With truncation=True but no max_length, a tokenizer that has no configured
    # model_max_length silently skips truncation ("Asking to truncate to max_length
    # but no maximum length is provided ..."), letting over-long inputs reach the
    # model. Resolve a concrete limit instead.
    if max_length is None:
        tokenizer_limit = getattr(tokenizer, "model_max_length", None)
        # transformers stores a huge sentinel (int(1e30)) when no limit is configured.
        if tokenizer_limit is not None and tokenizer_limit < 1_000_000:
            max_length = tokenizer_limit
        else:
            max_length = getattr(model.config, "max_position_embeddings", None)

    pooled_batches = []
    with torch.no_grad():
        # Batch like the SentenceTransformer path instead of one giant forward pass.
        for start in range(0, len(sentences), batch_size):
            batch = list(sentences[start:start + batch_size])
            encoded_input = tokenizer(
                batch,
                padding=True,
                truncation=max_length is not None,
                max_length=max_length,
                return_tensors='pt',
            )
            model_output = model(**encoded_input)
            # Mean pooling over real (non-padding) tokens only.
            pooled_batches.append(mean_pooling(model_output, encoded_input['attention_mask']))
    return torch.cat(pooled_batches, dim=0)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
We can use K-Means clustering, agglomerative (hierarchical) clustering, or community detection; below we compute all three for comparison.
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
def compute_clusters_kmeans(embeddings, n_clusters = 30):
    """Group the rows of ``embeddings`` into clusters with K-Means.

    Parameters
    ----------
    embeddings : array-like
        One embedding per row, as accepted by ``KMeans.fit``.
    n_clusters : int, default 30
        Number of K-Means clusters.

    Returns
    -------
    list of list
        Entry ``k`` holds the row indices assigned label ``k``, for labels
        0 up to the largest label produced.
    """
    # Fixed random_state keeps the clustering reproducible between runs.
    kmeans = KMeans(n_clusters = n_clusters, random_state=0, max_iter = 10000, n_init='auto')
    labels = kmeans.fit(embeddings).labels_
    n_labels = labels.max() + 1
    return [list(np.flatnonzero(labels == k)) for k in range(n_labels)]
def compute_clusters_hierarchical(embeddings, n_clusters = 30):
    """Group the rows of ``embeddings`` with agglomerative clustering.

    Only ``n_clusters`` is passed, so ``AgglomerativeClustering`` runs with
    its library defaults for linkage and distance metric.

    Parameters
    ----------
    embeddings : array-like
        One embedding per row, as accepted by ``AgglomerativeClustering.fit``.
    n_clusters : int, default 30
        Number of clusters to cut the hierarchy into.

    Returns
    -------
    list of list
        Entry ``k`` holds the row indices assigned label ``k``, for labels
        0 up to the largest label produced.
    """
    labels = AgglomerativeClustering(n_clusters = n_clusters).fit(embeddings).labels_
    groups = []
    for k in range(np.max(labels) + 1):
        members, = np.nonzero(labels == k)
        groups.append(list(members))
    return groups
def compute_cluster_community(embeddings, n_size = 10, n_clusters = 30, thres_step = 0.05, min_threshold = -1.0):
    """Lower the community-detection threshold until enough communities appear.

    Starting at 1.0, the similarity threshold passed to
    ``st_util.community_detection`` is lowered by ``thres_step`` until at least
    ``n_clusters`` communities (each with at least ``n_size`` members) are found.
    ``st_util`` is presumably ``sentence_transformers.util`` imported elsewhere
    in the notebook — confirm.

    Parameters
    ----------
    embeddings : array-like or torch.Tensor
        Embeddings passed through to ``community_detection``.
    n_size : int, default 10
        ``min_community_size`` for ``community_detection``.
    n_clusters : int, default 30
        Minimum number of communities wanted.
    thres_step : float, default 0.05
        Amount the threshold is lowered per attempt; must be positive.
    min_threshold : float, default -1.0
        Lowest threshold tried. Cosine similarity cannot go below -1, so lower
        values would never change the result.

    Returns
    -------
    (clusters, threshold)
        The first community list with ``len >= n_clusters`` and the threshold
        that produced it. If no threshold down to ``min_threshold`` reaches
        ``n_clusters``, a warning is issued and the attempt with the most
        communities is returned instead of looping forever.

    Raises
    ------
    ValueError
        If ``thres_step`` is not positive (the search would never progress).
    """
    import warnings

    if thres_step <= 0:
        raise ValueError(f"thres_step must be positive, got {thres_step}")

    best_clusters, best_threshold = None, None
    n_steps = 0
    threshold = 1.0
    while threshold >= min_threshold:
        clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
        if len(clusters) >= n_clusters:
            return clusters, threshold
        if best_clusters is None or len(clusters) > len(best_clusters):
            best_clusters, best_threshold = clusters, threshold
        n_steps += 1
        # Recompute from 1.0 rather than subtracting repeatedly, to avoid float drift.
        threshold = 1.0 - n_steps * thres_step

    n_best = 0 if best_clusters is None else len(best_clusters)
    warnings.warn(
        f"community_detection found at most {n_best} communities (< {n_clusters}) "
        f"for thresholds down to {min_threshold}; returning the best attempt "
        f"(threshold = {best_threshold})."
    )
    return best_clusters, best_threshold


N_CLUSTERS = 30  # same target for every method so the results are comparable

clusters_community = dict()
clusters_kmeans = dict()
clusters_agglom = dict()
for model_name in model_names.keys():
    clusters_community[model_name], threshold = compute_cluster_community(embeddings[model_name], thres_step = 0.01, n_clusters = N_CLUSTERS)
    clusters_kmeans[model_name] = compute_clusters_kmeans(embeddings[model_name], n_clusters = N_CLUSTERS)
    clusters_agglom[model_name] = compute_clusters_hierarchical(embeddings[model_name], n_clusters = N_CLUSTERS)

    kmeans_sizes = ", ".join(str(len(x)) for x in clusters_kmeans[model_name])
    agglom_sizes = ", ".join(str(len(x)) for x in clusters_agglom[model_name])
    # compute_clusters_hierarchical uses AgglomerativeClustering with its default
    # (Ward / Euclidean) settings, so label the row by method, not "Cosine".
    print(f"{model_name}\n\tCommunity: {len(clusters_community[model_name])} clusters | auto threshold = {threshold:.2f}\n"
          f"\tKMeans: [{kmeans_sizes}]\n"
          f"\tAgglomerative: [{agglom_sizes}]\n")
    Community: 35 clusters | auto threshold = 0.68
    KMeans: [81, 57, 54, 138, 62, 76, 28, 142, 110, 122, 94, 40, 123, 23, 73, 98, 79, 139, 52, 125, 56, 126, 67, 93, 28, 54, 144, 62, 44, 93]
    Cosine: [154, 188, 99, 49, 86, 150, 123, 111, 45, 186, 56, 54, 187, 189, 65, 30, 50, 35, 52, 43, 98, 55, 23, 24, 12, 62, 100, 35, 37, 85]
medicalai/ClinicalBERT
    Community: 34 clusters | auto threshold = 0.79
    KMeans: [63, 120, 110, 39, 110, 77, 63, 40, 160, 22, 84, 93, 146, 19, 62, 107, 85, 74, 151, 86, 140, 47, 84, 46, 37, 82, 135, 18, 123, 60]
    Cosine: [72, 68, 87, 134, 38, 268, 23, 144, 58, 76, 104, 188, 155, 71, 104, 54, 72, 45, 56, 43, 21, 24, 94, 65, 85, 33, 189, 59, 32, 21]
emilyalsentzer/Bio_ClinicalBERT
    Community: 43 clusters | auto threshold = 0.92
    KMeans: [42, 139, 124, 91, 49, 53, 65, 59, 103, 101, 136, 68, 36, 73, 127, 92, 44, 142, 60, 29, 73, 60, 107, 32, 90, 136, 88, 95, 89, 80]
    Cosine: [231, 130, 50, 106, 73, 34, 163, 47, 178, 115, 74, 62, 69, 122, 98, 109, 31, 55, 88, 115, 29, 76, 58, 114, 27, 43, 40, 39, 31, 76]
# Preview the agglomerative clusters of each model: the first and last few clusters,
# and within each cluster its first and last few member descriptions.
n_preview = 3
for m, citems in clusters_agglom.items():

    print(f"Model name: {m}")

    nc = len(citems)

    # First and last n_preview cluster indices. With 2 * n_preview clusters or fewer,
    # show each cluster once (the old slice repeated clusters, or indexed past the end).
    if nc <= 2 * n_preview:
        cluster_ids = list(range(nc))
    else:
        cluster_ids = list(range(n_preview)) + list(range(nc - n_preview, nc))

    for i in cluster_ids:
        cluster = citems[i]
        print("\nCluster {}, #{} Elements ".format(i + 1, len(cluster)))
        if len(cluster) <= 2 * n_preview:
            # Small cluster: list every member once, without an ellipsis.
            for sentence_id in cluster:
                print("\t", long_desc[sentence_id])
        else:
            for sentence_id in cluster[:n_preview]:
                print("\t", long_desc[sentence_id])
            print("\t", "...")
            for sentence_id in cluster[-n_preview:]:
                print("\t", long_desc[sentence_id])
    print("---------------------")
    print()
Cluster 1, #154 Elements 
     G47 Sleep disorders
     R47 Speech disturbances, not elsewhere classified
     R49 Voice disturbances
     ...
     Duration walking for pleasure
     Frequency of strenuous sports in last 4 weeks
     FEV1/FVC ratio
Cluster 2, #188 Elements 
     D68 Other coagulation defects
     D69 Purpura and other haemorrhagic conditions
     E87 Other disorders of fluid, electrolyte and acid-base balance
     ...
     ECG, phase time
     ECG, number of stages in a phase
     Pulse rate (during blood-pressure measurement)
Cluster 3, #99 Elements 
     N41 Inflammatory diseases of prostate
     N42 Other disorders of prostate
     N43 Hydrocele and spermatocele
     ...
     Number of stillbirths
     Number of spontaneous miscarriages
     Number of pregnancy terminations
Cluster 28, #35 Elements 
     E03 Other hypothyroidism
     E04 Other non-toxic goitre
     E05 Thyrotoxicosis [hyperthyroidism]
     ...
     levothyroxine sodium medication
     Oestradiol
     Testosterone
Cluster 29, #37 Elements 
     S00 Superficial injury of head
     S01 Open wound of head
     S09 Other and unspecified injuries of head
     ...
     Contusion
     fracture lower leg / ankle self-reported
     Falls in the last year
Cluster 30, #85 Elements 
     F05 Delirium, not induced by alcohol and other psychoactive substances
     F31 Bipolar affective disorder
     F32 Depressive episode
     ...
     Financial situation satisfaction
     Longest period of depression
     Number of depression episodes
---------------------
Model name: medicalai/ClinicalBERT
Cluster 1, #72 Elements 
     D50 Iron deficiency anaemia
     D51 Vitamin B12 deficiency anaemia
     E53 Deficiency of other B group vitamins
     ...
     Vitamin C
     Vitamin D
     Vitamin E
Cluster 2, #68 Elements 
     G45 Transient cerebral ischaemic attacks and related syndromes
     I21 Acute myocardial infarction
     I24 Other acute ischaemic heart diseases
     ...
     atrial fibrillation self-reported
     Ventricular rate
     Age deep-vein thrombosis (DVT, blood clot in leg) diagnosed
Cluster 3, #87 Elements 
     C43 Malignant melanoma of skin
     D05 Carcinoma in situ of breast
     K04 Diseases of pulp and periapical tissues
     ...
     Trunk fat mass
     Trunk fat-free mass
     Trunk predicted mass
Cluster 28, #59 Elements 
     5-alpha reductase inhibitor|BPH|benign prostatic hyperplasia
     ACE inhibitor|anti-hypertensive
     DPP-4 inhibitor|diabetes
     ...
     sodium/potassium transporter inhibitor|heart failure
     tricyclic antihistamine|selective histamine H1 inverse agonist|antihistamine
     xanthine oxidase inhibitor|anti-gout agent
Cluster 29, #32 Elements 
     E10 Insulin-dependent diabetes mellitus
     E11 Non-insulin-dependent diabetes mellitus
     E14 Unspecified diabetes mellitus
     ...
     Diabetes (mother illness)
     Most recent bowel cancer screening
     Age high blood pressure diagnosed
Cluster 30, #21 Elements 
     Diastolic blood pressure, automated reading
     Systolic blood pressure, automated reading
     Systolic blood pressure, manual reading
     ...
     Systolic blood pressure, combined automated + manual reading
     Systolic blood pressure, combined automated + manual reading, adjusted by medication
     Systolic blood pressure, manual reading, adjusted by medication
---------------------
Model name: emilyalsentzer/Bio_ClinicalBERT
Cluster 1, #231 Elements 
     H00 Hordeolum and chalazion
     Desogestrel
     albuterol
     ...
     Retinol
     Carotene
     pheno 48 / pheno 49
Cluster 2, #130 Elements 
     A41 Other septicaemia
     B37 Candidiasis
     B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters
     ...
     Facial ageing
     Relative age of first facial hair
     Tinnitus severity/nuisance
Cluster 3, #50 Elements 
     Bacterial enteritis
     Viral Enteritis
     Bacterial infection NOS
     ...
     Financial situation satisfaction
     Private healthcare
     Noisy workplace
Cluster 28, #39 Elements 
     R00 Abnormalities of heart beat
     R01 Cardiac murmurs and other cardiac sounds
     R06 Abnormalities of breathing
     ...
     Malaise and fatigue
     Leg pain on walking : action taken
     Multi-site chronic pain
Cluster 29, #31 Elements 
     Pulse rate, automated reading
     Ventricular rate
     Basal metabolic rate
     ...
     Systolic blood pressure, combined automated + manual reading
     Systolic blood pressure, combined automated + manual reading, adjusted by medication
     Systolic blood pressure, manual reading, adjusted by medication
Cluster 30, #76 Elements 
     H04 Disorders of lachrymal system
     H52 Disorders of refraction and accommodation
     I08 Multiple valve diseases
     ...
     Duration of fitness test
     Duration of walks
     LDL direct, adjusted by medication
---------------------
def save_cluster_list(filepath, clusters):
    """Pickle ``clusters`` to ``filepath``, creating parent directories as needed.

    Parameters
    ----------
    filepath : str or os.PathLike
        Destination file. A bare filename (no directory part) is written to
        the current working directory.
    clusters : object
        Any picklable object; in this notebook, the cluster lists computed
        for each model.
    """
    dirname = os.path.dirname(filepath)
    # dirname is "" for a bare filename and os.makedirs("") raises, so skip it then.
    # exist_ok=True also removes the race between an exists() check and makedirs().
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)
        
# Write every model's clusters to <outdir>/<model name>/<method>_clusters.pkl.
outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"

# Output filename for each clustering method, in the order the files are written.
cluster_results = {
    "community_clusters.pkl": clusters_community,
    "kmeans_clusters.pkl": clusters_kmeans,
    "agglomerative_clusters.pkl": clusters_agglom,
}
for pkl_name, clusters_by_model in cluster_results.items():
    for method, clusters in clusters_by_model.items():
        m_filename = os.path.join(outdir, f"{method}/{pkl_name}")
        save_cluster_list(m_filename, clusters)