wukevin / tcr-bert

Large language modeling applied to T-cell receptor (TCR) sequences.
Apache License 2.0
47 stars 7 forks

How to decode the embeddings back to the original sequences #12

Open samw0806 opened 4 months ago

samw0806 commented 4 months ago

```python
import model_utils  # found under the "tcr" folder

tcrbert_trb_cls = model_utils.load_classification_pipeline("wukevin/tcr-bert", device=0)

df = model_utils.reformat_classification_pipeline_preds(
    tcrbert_trb_cls(
        [
            "C A S S P V T G G I Y G Y T F",  # Binds to NLVPMVATV CMV antigen
            "C A T S G R A G V E Q F F",  # Binds to GILGFVFTL flu antigen
        ]
    )
)
```

Is there a decoder that can turn the dataframe's output back into the original sequences? Maybe something like this:

```python
model = BertModel.from_pretrained("wukevin/tcr-bert").to(device)
tokenizer = BertTokenizer.from_pretrained("wukevin/tcr-bert")
outputs = model.decoder(embedding_tensor)
logits = outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)

decoded_seq = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
decoded_sequences.append(decoded_seq)
```
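For context on the decode step the question sketches: the `argmax` followed by `tokenizer.decode` only makes sense on per-position logits over the token vocabulary. Judging by the function names, the classification pipeline's dataframe likely holds label scores rather than such logits. Also, the plain `BertModel` class in Hugging Face transformers has no `decoder` attribute. A masked-LM head such as the one in `BertForMaskedLM` is what maps hidden states to vocabulary logits. The following is a minimal, library-free sketch of just the greedy decode step. The vocabulary, its ordering, and the helper names `greedy_decode` and `one_hot_logits` are all hypothetical and for illustration only; the real tokenizer's id-to-token mapping may differ.

```python
from typing import List, Sequence

# Hypothetical toy vocabulary: special tokens followed by the 20 amino acids.
# The real TCR-BERT tokenizer's id-to-token mapping may differ.
SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[SEP]"]
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
ID_TO_TOKEN: List[str] = SPECIAL_TOKENS + AMINO_ACIDS


def greedy_decode(logits: Sequence[Sequence[float]]) -> str:
    """Pick the highest-scoring token at each position and drop special tokens."""
    tokens = []
    for position_scores in logits:
        best_id = max(range(len(position_scores)), key=position_scores.__getitem__)
        token = ID_TO_TOKEN[best_id]
        if token not in SPECIAL_TOKENS:
            tokens.append(token)
    # Space-separated residues, matching the input format used in the question.
    return " ".join(tokens)


def one_hot_logits(tokens: Sequence[str]) -> List[List[float]]:
    """Build fake logits that put all the score on the given tokens (demo only)."""
    rows = []
    for token in tokens:
        row = [0.0] * len(ID_TO_TOKEN)
        row[ID_TO_TOKEN.index(token)] = 1.0
        rows.append(row)
    return rows


# Stand-in for logits of shape [seq_len, vocab_size] from some LM head.
fake_logits = one_hot_logits(["[CLS]", "C", "A", "S", "S", "F", "[SEP]"])
decoded = greedy_decode(fake_logits)
print(decoded)  # C A S S F
```

With a real model, the logits would come from a language-modeling head applied to per-residue hidden states. A single pooled embedding or a vector of class probabilities does not carry one score vector per position, so this kind of decoding would not apply to it directly.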