import os
import numpy as np
import pandas as pd
import pickle
import re
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 120)
Saikat Banerjee
April 2, 2024
 | zindex | trait_type | phenocode | pheno_sex | coding | modifier | description | description_more | coding_description | category | BIN_QT | n_cases_EUR | n_controls_EUR | N | Neff | filename | aws_link | estimates.final.h2_observed | long_description | short_description
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 1 | icd10 | A04 | both_sexes | NaN | NaN | A04 Other bacterial intestinal infections | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 3088 | 417443.0 | 420531 | 6130.649032 | icd10-A04-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0033 | A04 Other bacterial intestinal infections | A04 Bacterial intestinal infections |
1 | 2 | icd10 | A08 | both_sexes | NaN | NaN | A08 Viral and other specified intestinal infec... | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 1107 | 419424.0 | 420531 | 2208.171897 | icd10-A08-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0001 | A08 Viral and other specified intestinal infec... | A08 Viral, other intestinal infections |
2 | 3 | icd10 | A09 | both_sexes | NaN | NaN | A09 Diarrhoea and gastro-enteritis of presumed... | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 9029 | 411502.0 | 420531 | 17670.286180 | icd10-A09-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0035 | A09 Diarrhoea and gastro-enteritis of presumed... | A09 Diarrhoea, infectious gastro-enteritis |
3 | 4 | icd10 | A41 | both_sexes | NaN | NaN | A41 Other septicaemia | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 5512 | 415019.0 | 420531 | 10879.505810 | icd10-A41-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0011 | A41 Other septicaemia | A41 Other septicaemia |
4 | 5 | icd10 | B34 | both_sexes | NaN | NaN | B34 Viral infection of unspecified site | truncated: true | NaN | Chapter I Certain infectious and parasitic dis... | BIN | 2129 | 418402.0 | 420531 | 4236.443249 | icd10-B34-both_sexes.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.0003 | B34 Viral infection of unspecified site | B34 Viral infection of unspecified site |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2478 | 2479 | continuous | Smoking | both_sexes | NaN | Ever_Never | Smoking status, ever vs never | Ever (previous + current smoker) vs never base... | NaN | NaN | QT | 418817 | NaN | 418817 | 418817.000000 | continuous-Smoking-both_sexes-Ever_Never.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.1100 | Smoking status, ever vs never | Smoking status, ever vs never |
2479 | 2480 | continuous | eGFR | both_sexes | NaN | irnt | Estimated glomerular filtration rate, serum cr... | eGFR based on serum creatinine (30700) using t... | NaN | NaN | QT | 401867 | NaN | 401867 | 401867.000000 | continuous-eGFR-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2070 | Estimated glomerular filtration rate, serum cr... | Estimated GFR, serum creatinine |
2480 | 2481 | continuous | eGFRcreacys | both_sexes | NaN | irnt | Estimated glomerular filtration rate, cystain C | eGFR based on cystain C (30720) using the CKD-... | NaN | NaN | QT | 401570 | NaN | 401570 | 401570.000000 | continuous-eGFRcreacys-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2380 | Estimated glomerular filtration rate, cystain C | Estimated GFR, cystain C |
2481 | 2482 | continuous | eGFRcys | both_sexes | NaN | irnt | Estimated glomerular filtration rate, serum cr... | eGFR based on serum creatinine (30700) and cys... | NaN | NaN | QT | 402031 | NaN | 402031 | 402031.000000 | continuous-eGFRcys-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.2240 | Estimated glomerular filtration rate, serum cr... | Estimated GFR, serum creatinine + cystain C |
2482 | 2483 | continuous | whr | both_sexes | NaN | irnt | pheno 48 / pheno 49 | NaN | NaN | NaN | QT | 420531 | NaN | 420531 | 420531.000000 | continuous-whr-both_sexes-irnt.tsv.bgz | https://pan-ukb-us-east-1.s3.amazonaws.com/sum... | 0.1740 | pheno 48 / pheno 49 | Pheno 48 / pheno 49 |
2483 rows × 20 columns
First, we collect the long description of each phenotype into a list of sentences, which will then be embedded.
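The notebook does not show this step. A minimal sketch, assuming the table above is a pandas DataFrame named `trait_df` and that the `long_description` column is used as-is (both the DataFrame name and the helper `descriptions_to_sentences` are hypothetical), could look like this. Keeping the row order matters, because the clusters computed later are lists of row indices into this list.

```python
import pandas as pd

def descriptions_to_sentences(df, column="long_description"):
    """Return one cleaned sentence per row of df (hypothetical helper)."""
    # Missing descriptions become empty strings so every phenotype keeps its position
    text = df[column].fillna("").astype(str)
    # Collapse repeated whitespace and strip leading/trailing spaces
    text = text.str.replace(r"\s+", " ", regex=True).str.strip()
    return text.tolist()

# Toy example built from two descriptions in the table above
trait_df = pd.DataFrame({
    "long_description": ["A04 Other bacterial intestinal infections",
                         "Smoking status,  ever vs never "]
})
long_desc = descriptions_to_sentences(trait_df)
print(long_desc)  # ['A04 Other bacterial intestinal infections', 'Smoking status, ever vs never']
```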
I decided to use Sentence Transformers to embed the long descriptions of the phenotypes and then cluster the embeddings. This is a natural use case for pretrained language models.
Several pre-trained models are available on Hugging Face, and any of them can be loaded. Some are trained specifically for Sentence Transformers and can be found by searching the Hub; models that are not can still be used by mean-pooling their token embeddings. I found a few that are trained on medical data:
1. https://huggingface.co/menadsa/S-Bio_ClinicalBERT
2. https://huggingface.co/ls-da3m0ns/bge_large_medical
import torch

# Mean pooling: average the token embeddings, using the attention mask so that padding tokens are ignored
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
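To make the masked average in mean_pooling concrete, here is the same computation written in NumPy on a toy batch. This is an illustrative re-implementation (the name `mean_pooling_np` is hypothetical), not code from the original notebook: the padded token contributes nothing, and the sum is divided by the number of real tokens only.

```python
import numpy as np

def mean_pooling_np(token_embeddings, attention_mask):
    """Masked mean over the token axis.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., np.newaxis].astype(float)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)         # padded tokens contribute 0
    counts = np.clip(mask.sum(axis=1), 1e-9, None)         # avoid division by zero
    return summed / counts

# One sentence with 3 tokens; the last token is padding
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pooling_np(tokens, mask))  # [[2. 3.]]
```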
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

def get_embeddings(sentences, model_name, use_pooling = False):
    if use_pooling:
        # Plain HuggingFace model: load it from the Hub and mean-pool the token embeddings
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # Tokenize sentences
        encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        # Compute token embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Perform pooling. In this case, mean pooling.
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    else:
        # Model trained for Sentence Transformers: tokenization and pooling are handled internally
        model = SentenceTransformer(model_name)
        sentence_embeddings = model.encode(sentences, batch_size=64, show_progress_bar=True, convert_to_tensor=True)
    return sentence_embeddings
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
We compare three clustering approaches: K-Means, agglomerative (hierarchical) clustering, and the community detection provided by Sentence Transformers.
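Unlike K-Means and agglomerative clustering, community detection does not take a number of clusters: it groups sentences whose cosine similarity exceeds a threshold, and sentences that are not in a dense enough group stay unassigned. The following is a simplified NumPy sketch of that idea, loosely modelled on `sentence_transformers.util.community_detection`; it is not the library's exact implementation, and the name `simple_community_detection` is hypothetical.

```python
import numpy as np

def simple_community_detection(embeddings, threshold=0.75, min_community_size=10):
    """Simplified threshold-based community detection (illustrative sketch)."""
    emb = np.asarray(embeddings, dtype=float)
    # Normalise rows so that a dot product equals cosine similarity
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = emb @ emb.T

    # Each point with enough neighbours above the threshold seeds a candidate community
    candidates = []
    for i in range(sim.shape[0]):
        members = np.flatnonzero(sim[i] >= threshold)
        if len(members) >= min_community_size:
            # Order members by similarity to the seed point
            members = members[np.argsort(-sim[i, members])]
            candidates.append(members.tolist())

    # Largest candidates first; points already assigned are dropped from later ones
    candidates.sort(key=len, reverse=True)
    taken, communities = set(), []
    for community in candidates:
        remaining = [idx for idx in community if idx not in taken]
        if len(remaining) >= min_community_size:
            communities.append(remaining)
            taken.update(remaining)
    return communities

# Two well-separated groups of directions in 2D
emb = np.array([[1.0, 0.0], [0.99, 0.1], [0.98, 0.15],
                [0.0, 1.0], [0.1, 0.99], [0.15, 0.98]])
print(simple_community_detection(emb, threshold=0.9, min_community_size=3))  # [[0, 1, 2], [3, 4, 5]]
```

Raising the threshold shrinks the candidate groups (at 0.999 no point here has three neighbours, so nothing is returned), which is why the notebook sweeps the threshold downwards until enough communities appear.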
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
from sentence_transformers import util as st_util

def compute_clusters_kmeans(embeddings, n_clusters = 30):
    model = KMeans(n_clusters = n_clusters, random_state=0, max_iter = 10000, n_init='auto')
    model.fit(embeddings)
    cluster_assignment = model.labels_
    # One list of phenotype indices per cluster label
    phenotype_clusters = [list(np.where(cluster_assignment == i)[0]) for i in range(np.max(cluster_assignment) + 1)]
    return phenotype_clusters

def compute_clusters_hierarchical(embeddings, n_clusters = 30):
    # With default settings, AgglomerativeClustering uses Ward linkage on Euclidean distances
    clustering_model = AgglomerativeClustering(n_clusters = n_clusters)
    clustering_model.fit(embeddings)
    cluster_assignment = clustering_model.labels_
    phenotype_clusters = [list(np.where(cluster_assignment == i)[0]) for i in range(np.max(cluster_assignment) + 1)]
    return phenotype_clusters

def compute_cluster_community(embeddings, n_size = 10, n_clusters = 30, thres_step = 0.05):
    # Start with a strict cosine-similarity threshold and relax it until at least n_clusters communities are found
    threshold = 1.0
    clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    # Stop before the threshold drops to zero, to avoid looping forever
    while len(clusters) < n_clusters and threshold - thres_step > 0:
        threshold -= thres_step
        clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    return clusters, threshold
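All three functions return a list of index lists, but community detection only assigns phenotypes that belong to a sufficiently dense community, so its clusters need not cover all 2483 phenotypes. To compare methods phenotype by phenotype, it can help to convert each list back into one label per phenotype, with -1 marking unassigned ones. The helper below, `clusters_to_labels`, is hypothetical and not part of the original notebook.

```python
import numpy as np

def clusters_to_labels(clusters, n_items):
    """Convert a list of index lists into a label array; -1 marks unassigned items."""
    labels = np.full(n_items, -1, dtype=int)
    for label, members in enumerate(clusters):
        labels[np.asarray(members, dtype=int)] = label
    return labels

labels = clusters_to_labels([[0, 2], [3]], n_items=5)
print(labels)                                 # [ 0 -1  0  1 -1]
print(f"unassigned: {np.sum(labels == -1)}")  # unassigned: 2
```

For the notebook's results, a call such as `clusters_to_labels(clusters_community[model_name], len(long_desc))` would show how many phenotypes each community run leaves out.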
clusters_community = dict()
clusters_kmeans = dict()
clusters_agglom = dict()
for model_name in model_names.keys():
    clusters_community[model_name], threshold = compute_cluster_community(embeddings[model_name], thres_step = 0.01, n_clusters = 30)
    clusters_kmeans[model_name] = compute_clusters_kmeans(embeddings[model_name], n_clusters = 30)
    clusters_agglom[model_name] = compute_clusters_hierarchical(embeddings[model_name], n_clusters = 30)
    kmeans_sizes = ", ".join([f"{len(x)}" for x in clusters_kmeans[model_name]])
    agglom_sizes = ", ".join([f"{len(x)}" for x in clusters_agglom[model_name]])
    # "Cosine" in the printout labels the agglomerative (hierarchical) clusters
    print (f"{model_name}\n\tCommunity: {len(clusters_community[model_name])} clusters | auto threshold = {threshold:.2f}\n" +
           f"\tKMeans: [{kmeans_sizes}]\n" +
           f"\tCosine: [{agglom_sizes}]\n")
ls-da3m0ns/bge_large_medical
Community: 35 clusters | auto threshold = 0.68
KMeans: [81, 57, 54, 138, 62, 76, 28, 142, 110, 122, 94, 40, 123, 23, 73, 98, 79, 139, 52, 125, 56, 126, 67, 93, 28, 54, 144, 62, 44, 93]
Cosine: [154, 188, 99, 49, 86, 150, 123, 111, 45, 186, 56, 54, 187, 189, 65, 30, 50, 35, 52, 43, 98, 55, 23, 24, 12, 62, 100, 35, 37, 85]
medicalai/ClinicalBERT
Community: 34 clusters | auto threshold = 0.79
KMeans: [63, 120, 110, 39, 110, 77, 63, 40, 160, 22, 84, 93, 146, 19, 62, 107, 85, 74, 151, 86, 140, 47, 84, 46, 37, 82, 135, 18, 123, 60]
Cosine: [72, 68, 87, 134, 38, 268, 23, 144, 58, 76, 104, 188, 155, 71, 104, 54, 72, 45, 56, 43, 21, 24, 94, 65, 85, 33, 189, 59, 32, 21]
emilyalsentzer/Bio_ClinicalBERT
Community: 43 clusters | auto threshold = 0.92
KMeans: [42, 139, 124, 91, 49, 53, 65, 59, 103, 101, 136, 68, 36, 73, 127, 92, 44, 142, 60, 29, 73, 60, 107, 32, 90, 136, 88, 95, 89, 80]
Cosine: [231, 130, 50, 106, 73, 34, 163, 47, 178, 115, 74, 62, 69, 122, 98, 109, 31, 55, 88, 115, 29, 76, 58, 114, 27, 43, 40, 39, 31, 76]
for m, citems in clusters_agglom.items():
    print (f"Model name: {m}")
    nc = len(citems)
    # Print the first 3 and last 3 sentences of the first 3 and last 3 clusters
    for i in list(range(3)) + list(range(nc))[-3:]:
        cluster = citems[i]
        print("\nCluster {}, #{} Elements ".format(i + 1, len(cluster)))
        for sentence_id in cluster[0:3]:
            print("\t", long_desc[sentence_id])
        print("\t", "...")
        for sentence_id in cluster[-3:]:
            print("\t", long_desc[sentence_id])
    print ("---------------------")
    print ()
Model name: ls-da3m0ns/bge_large_medical
Cluster 1, #154 Elements
G47 Sleep disorders
R47 Speech disturbances, not elsewhere classified
R49 Voice disturbances
...
Duration walking for pleasure
Frequency of strenuous sports in last 4 weeks
FEV1/FVC ratio
Cluster 2, #188 Elements
D68 Other coagulation defects
D69 Purpura and other haemorrhagic conditions
E87 Other disorders of fluid, electrolyte and acid-base balance
...
ECG, phase time
ECG, number of stages in a phase
Pulse rate (during blood-pressure measurement)
Cluster 3, #99 Elements
N41 Inflammatory diseases of prostate
N42 Other disorders of prostate
N43 Hydrocele and spermatocele
...
Number of stillbirths
Number of spontaneous miscarriages
Number of pregnancy terminations
Cluster 28, #35 Elements
E03 Other hypothyroidism
E04 Other non-toxic goitre
E05 Thyrotoxicosis [hyperthyroidism]
...
levothyroxine sodium medication
Oestradiol
Testosterone
Cluster 29, #37 Elements
S00 Superficial injury of head
S01 Open wound of head
S09 Other and unspecified injuries of head
...
Contusion
fracture lower leg / ankle self-reported
Falls in the last year
Cluster 30, #85 Elements
F05 Delirium, not induced by alcohol and other psychoactive substances
F31 Bipolar affective disorder
F32 Depressive episode
...
Financial situation satisfaction
Longest period of depression
Number of depression episodes
---------------------
Model name: medicalai/ClinicalBERT
Cluster 1, #72 Elements
D50 Iron deficiency anaemia
D51 Vitamin B12 deficiency anaemia
E53 Deficiency of other B group vitamins
...
Vitamin C
Vitamin D
Vitamin E
Cluster 2, #68 Elements
G45 Transient cerebral ischaemic attacks and related syndromes
I21 Acute myocardial infarction
I24 Other acute ischaemic heart diseases
...
atrial fibrillation self-reported
Ventricular rate
Age deep-vein thrombosis (DVT, blood clot in leg) diagnosed
Cluster 3, #87 Elements
C43 Malignant melanoma of skin
D05 Carcinoma in situ of breast
K04 Diseases of pulp and periapical tissues
...
Trunk fat mass
Trunk fat-free mass
Trunk predicted mass
Cluster 28, #59 Elements
5-alpha reductase inhibitor|BPH|benign prostatic hyperplasia
ACE inhibitor|anti-hypertensive
DPP-4 inhibitor|diabetes
...
sodium/potassium transporter inhibitor|heart failure
tricyclic antihistamine|selective histamine H1 inverse agonist|antihistamine
xanthine oxidase inhibitor|anti-gout agent
Cluster 29, #32 Elements
E10 Insulin-dependent diabetes mellitus
E11 Non-insulin-dependent diabetes mellitus
E14 Unspecified diabetes mellitus
...
Diabetes (mother illness)
Most recent bowel cancer screening
Age high blood pressure diagnosed
Cluster 30, #21 Elements
Diastolic blood pressure, automated reading
Systolic blood pressure, automated reading
Systolic blood pressure, manual reading
...
Systolic blood pressure, combined automated + manual reading
Systolic blood pressure, combined automated + manual reading, adjusted by medication
Systolic blood pressure, manual reading, adjusted by medication
---------------------
Model name: emilyalsentzer/Bio_ClinicalBERT
Cluster 1, #231 Elements
H00 Hordeolum and chalazion
Desogestrel
albuterol
...
Retinol
Carotene
pheno 48 / pheno 49
Cluster 2, #130 Elements
A41 Other septicaemia
B37 Candidiasis
B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters
...
Facial ageing
Relative age of first facial hair
Tinnitus severity/nuisance
Cluster 3, #50 Elements
Bacterial enteritis
Viral Enteritis
Bacterial infection NOS
...
Financial situation satisfaction
Private healthcare
Noisy workplace
Cluster 28, #39 Elements
R00 Abnormalities of heart beat
R01 Cardiac murmurs and other cardiac sounds
R06 Abnormalities of breathing
...
Malaise and fatigue
Leg pain on walking : action taken
Multi-site chronic pain
Cluster 29, #31 Elements
Pulse rate, automated reading
Ventricular rate
Basal metabolic rate
...
Systolic blood pressure, combined automated + manual reading
Systolic blood pressure, combined automated + manual reading, adjusted by medication
Systolic blood pressure, manual reading, adjusted by medication
Cluster 30, #76 Elements
H04 Disorders of lachrymal system
H52 Disorders of refraction and accommodation
I08 Multiple valve diseases
...
Duration of fitness test
Duration of walks
LDL direct, adjusted by medication
---------------------
def save_cluster_list(filepath, clusters):
    # Create the output directory if needed, then pickle the list of clusters
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)

outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"
# The dictionary keys are model names, so each model gets its own subdirectory
for method, clusters in clusters_community.items():
    m_filename = os.path.join(outdir, f"{method}/community_clusters.pkl")
    save_cluster_list(m_filename, clusters)
for method, clusters in clusters_kmeans.items():
    m_filename = os.path.join(outdir, f"{method}/kmeans_clusters.pkl")
    save_cluster_list(m_filename, clusters)
for method, clusters in clusters_agglom.items():
    m_filename = os.path.join(outdir, f"{method}/agglomerative_clusters.pkl")
    save_cluster_list(m_filename, clusters)
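Reading the clusters back is the mirror image of save_cluster_list. A minimal sketch, with a hypothetical `load_cluster_list` helper and a temporary directory standing in for the results path above, round-trips a small cluster list:

```python
import os
import pickle
import tempfile

def save_cluster_list(filepath, clusters):
    # Same behaviour as the notebook's helper: create the directory, then pickle
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)

def load_cluster_list(filepath):
    """Read a cluster list written by save_cluster_list (hypothetical helper)."""
    with open(filepath, "rb") as fh:
        return pickle.load(fh)

# Round trip in a temporary directory instead of the original results path
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "some_model", "kmeans_clusters.pkl")
    save_cluster_list(path, [[0, 2], [1]])
    restored = load_cluster_list(path)

print(restored)  # [[0, 2], [1]]
```

Because each cluster holds row indices, the descriptions of a reloaded cluster can be recovered with `[long_desc[i] for i in cluster]`, provided `long_desc` is rebuilt in the same row order.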