nyuad-cai / CXR-ML-GZSL


How to obtain BioBERT embeddings for all the trainable labels #5

Open alee355 opened 2 years ago

alee355 commented 2 years ago

Hi, I know one person has already asked about this, but I'm still having trouble obtaining the BioBERT embeddings for all the trainable labels on my own. I saw your reply: "we used transformers library https://github.com/huggingface/transformers with 'dmis-lab/biobert-large-cased-v1.1' pre-trained weights and the embeddings were extracted through "last_hidden_state" of the output." If you don't mind, could you share the script you used to acquire the BioBERT embeddings, or explain how you acquired them?
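For context, the quoted reply suggests something along these lines (a minimal sketch using the standard `transformers` API; mean-pooling the token vectors into one label embedding is an assumption here, not something the authors confirmed):

```python
import torch
from transformers import AutoTokenizer, AutoModel

model_name = 'dmis-lab/biobert-large-cased-v1.1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Encode a single label and read off last_hidden_state, as described
# in the quoted reply. Averaging over tokens is an assumption.
inputs = tokenizer("Pneumonia", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)
```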

chinmay5 commented 2 years ago

```python
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel, pipeline


class BERTEmbedModel:
    def __init__(self):
        model_name_or_path = 'dmis-lab/biobert-large-cased-v1.1'

        config = AutoConfig.from_pretrained(model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            use_fast=False,
        )
        model = AutoModel.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=config,
            cache_dir=None,
        )
        # The feature-extraction pipeline returns the last hidden state
        # for every token of the input.
        self.nlp = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer)

    def get_embeds(self, text):
        # output[0] holds one vector per token; average them into a
        # single fixed-size embedding for the label.
        output = self.nlp(text)
        wts = []
        for idx in range(len(output[0])):
            wts.append(torch.tensor(output[0][idx]))
        wts = torch.stack(wts).mean(dim=0)
        return wts


bert_model = BERTEmbedModel()


def get_bert_text_embeddings(node_names):
    vector_repr_dict = dict()
    for disease in node_names:
        vector_repr_dict[disease] = bert_model.get_embeds(disease)
    return vector_repr_dict
```

This is the code I used. You may need to adapt the second method to your own label list, but hopefully it works out of the box otherwise. Please drop me a message in case there is some issue with the code.
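For anyone landing here later, a hypothetical usage sketch of the above (the label names below are illustrative ChestX-ray14 examples, not necessarily the exact trainable-label list used in this repo):

```python
# Illustrative labels only; substitute the actual trainable labels
# from the CXR-ML-GZSL setup.
labels = ["Atelectasis", "Cardiomegaly", "Effusion"]
embeddings = get_bert_text_embeddings(labels)

# biobert-large-cased-v1.1 is a BERT-large model, so each embedding
# should be 1024-dimensional.
print(embeddings["Atelectasis"].shape)  # torch.Size([1024])
```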