hpi-dhc / xmen

✖️MEN - A Modular Toolkit for Cross-Lingual Medical Entity Normalization
Apache License 2.0

How to use the trained Crossencoder for predicting the ICD code for entities in a text? #23

kunalr97 closed this issue 11 months ago

kunalr97 commented 11 months ago

I have successfully trained a cross-encoder on the BRONCO dataset. I want to use it to predict the ICD code for each entity in an input text that I feed it. I am loading the trained cross-encoder this way:

model_icd = CrossEncoderReranker.load(f'../outputs/{label2dict[label]}_index/cross_encoder_training/',device=1)

I want to use this model the same way I use a linker to predict ICD codes. With the linker, I do it like this, using my spaCy model:

nlp_ner = spacy.load('output/model-best/')
nlp_ner.add_pipe('sentencizer')
doc = list(nlp_ner.pipe(["CT untere Extremitaet: leichte progrediente Metastase im Musculus gluteus"]))
sample = from_spacy(doc)
label = "DIAG"
label2dict = {
    "TREAT": "ops",
    "MED": "atc",
    "DIAG": "icd10gm"
}

def filter_entities(bigbio_entities, valid_entities):
    filtered_entities = []
    for ent in bigbio_entities:
        if ent['type'] in valid_entities:
            filtered_entities.append(ent)
    return filtered_entities

sample_diag = sample.map(lambda row: {'entities': filter_entities(row['entities'], [label])})

icd_sapbert_ngram = default_ensemble(index_base_path='/home/IAIS/krunwal/.cache/xmen/icd10gm/index/')
candidates_sap_ngram = icd_sapbert_ngram.predict_batch(sample_diag)
candidates_sap_ngram['entities']

Here I can see the predicted ICD codes in the 'normalized' key of the dictionary.

[[{'id': '0', 'offsets': [[18, 39]], 'text': ['akutes nierenversagen'], 'type': 'DIAG', 'normalized': [{'db_id': 'N17.99', 'db_name': 'UMLS', 'score': 0.9999999403953552, 'predicted_by': ['ngram', 'sapbert']}, {'db_id': 'N17', 'db_name': 'UMLS', 'score': 0.9999999403953552, 'predicted_by': ['ngram', 'sapbert']}, {'db_id': 'N17.8', 'db_name': 'UMLS', 'score': 0.8631265163421631, 'predicted_by': ['ngram', 'sapbert']}, {'db_id': 'N19', 'db_name': 'UMLS', 'score': 0.8424741625785828, 'predicted_by': ['ngram', 'sapbert']}, {'db_id': 'J96.09', 'db_name': 'UMLS', 'score': 0.8196000456809998, 'predicted_by': ['ngram', 'sapbert']}, {'db_id': 'E27.2', ... {'db_id': 'F52.2', 'db_name': 'UMLS', 'score': 0.3720674514770508, 'predicted_by': ['ngram']}], 'long_form': None}]]
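To turn a candidate list like the one above into a single predicted code per entity, it is enough to pick the highest-scoring entry of `normalized`. The following is a minimal sketch in plain Python that assumes only the dictionary layout shown in the output above. The helper `top_prediction` is a hypothetical name and not part of xMEN, and the sample data is a shortened copy of the output (first three candidates only).

```python
from typing import Dict, Optional

# Shortened copy of the output above: one document with one DIAG entity.
entities = [[
    {
        "id": "0",
        "offsets": [[18, 39]],
        "text": ["akutes nierenversagen"],
        "type": "DIAG",
        "normalized": [
            {"db_id": "N17.99", "db_name": "UMLS", "score": 0.9999999403953552,
             "predicted_by": ["ngram", "sapbert"]},
            {"db_id": "N17", "db_name": "UMLS", "score": 0.9999999403953552,
             "predicted_by": ["ngram", "sapbert"]},
            {"db_id": "N17.8", "db_name": "UMLS", "score": 0.8631265163421631,
             "predicted_by": ["ngram", "sapbert"]},
        ],
        "long_form": None,
    }
]]


def top_prediction(entity: Dict) -> Optional[str]:
    """Return the db_id of the highest-scoring candidate, or None if there is none.

    Hypothetical helper: on a tie, max() keeps the candidate that comes first.
    """
    candidates = entity.get("normalized", [])
    if not candidates:
        return None
    return max(candidates, key=lambda c: c["score"])["db_id"]


# One (mention text, predicted code) pair per entity, grouped by document.
predictions = [
    [(ent["text"][0], top_prediction(ent)) for ent in doc_entities]
    for doc_entities in entities
]
print(predictions)
```

The same helper should work on any dataset whose entities follow this layout, so it could also be applied to re-ranked output, provided that output keeps the `normalized` structure.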

I want to achieve something similar with the cross-encoder trained on the BRONCO dataset. So far I have only managed to do re-ranking, but I simply want to predict the ICD code for an input text.

Thanks in advance!

phlobo commented 11 months ago

Thank you for your question!

It looks like you have successfully generated candidates for your dataset.

From there on, you should be able to follow the section "Using a Pre-trained Model for Reranking" in the Getting Started notebook.

Again, you want to make sure to limit the number of candidates that you use for re-ranking to control your memory usage.

To this end, you could either:
1) pass top_k to icd_sapbert_ngram.predict_batch(sample_diag),
2) use filter_and_apply_threshold just before preparing the dataset for the cross-encoder, or
3) pass k explicitly to the cross-encoder.
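All three options have the same effect on the data: fewer candidates per mention, and therefore fewer mention/candidate pairs for the cross-encoder to score. As a rough illustration of that effect only, here is a sketch in plain Python that applies a score threshold and a top-k cut to one entity. `limit_candidates` is a hypothetical helper and does not reproduce the behavior of xMEN's `top_k`, `filter_and_apply_threshold`, or `k` arguments. The scores are taken from the example output earlier in the thread.

```python
from typing import Dict


def limit_candidates(entity: Dict, k: int, min_score: float = 0.0) -> Dict:
    """Return a copy of the entity that keeps at most k candidates scoring >= min_score.

    Hypothetical helper for illustration; candidates are ordered by descending score.
    """
    kept = [c for c in entity["normalized"] if c["score"] >= min_score]
    # list.sort is stable, so candidates with equal scores keep their original order.
    kept.sort(key=lambda c: c["score"], reverse=True)
    return {**entity, "normalized": kept[:k]}


entity = {
    "id": "0",
    "text": ["akutes nierenversagen"],
    "type": "DIAG",
    "normalized": [
        {"db_id": "N17.99", "score": 0.9999999403953552},
        {"db_id": "N17", "score": 0.9999999403953552},
        {"db_id": "N17.8", "score": 0.8631265163421631},
        {"db_id": "N19", "score": 0.8424741625785828},
        {"db_id": "J96.09", "score": 0.8196000456809998},
        {"db_id": "F52.2", "score": 0.3720674514770508},
    ],
}

# Keep at most three candidates, and only those scoring at least 0.85.
limited = limit_candidates(entity, k=3, min_score=0.85)
kept_ids = [c["db_id"] for c in limited["normalized"]]
print(kept_ids)
```

With six candidates reduced to three, the cross-encoder would score half as many pairs for this mention, which is where the memory saving comes from.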

Your code (using the third option) might end up looking something like this:

YOUR_K = 16
kb = load_kb(...)
# Prepare dataset to create samples of mention/context - candidate pairs
ce_candidates = CrossEncoderReranker.prepare_data(candidates_sap_ngram, None, kb, k=YOUR_K)

# run re-ranking
result = model_icd.rerank_batch(candidates_sap_ngram, ce_candidates, k=YOUR_K)

kunalr97 commented 11 months ago

Thank you so much for your concise and quick reply! This helps a lot.