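```python
import model_utils  # found under the "tcr" folder

tcrbert_trb_cls = model_utils.load_classification_pipeline("wukevin/tcr-bert", device=0)
df = model_utils.reformat_classification_pipeline_preds(tcrbert_trb_cls([
    "C A S S P V T G G I Y G Y T F",  # Binds to NLVPMVATV CMV antigen
    "C A T S G R A G V E Q F F",  # Binds to GILGFVFTL flu antigen
]))
```

Is there a decoder that can decode the output in `df`? Maybe something like:

```python
import torch
from transformers import BertForMaskedLM, BertTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
# BertModel has no .decoder; the masked-LM prediction head (.cls) is on BertForMaskedLM.
# Assumes this checkpoint ships MLM head weights (otherwise they are randomly initialized).
model = BertForMaskedLM.from_pretrained("wukevin/tcr-bert").to(device)
tokenizer = BertTokenizer.from_pretrained("wukevin/tcr-bert")

decoded_sequences = []
# embedding_tensor: final-layer hidden states, shape (batch, seq_len, hidden_size)
logits = model.cls(embedding_tensor)  # returns a tensor of shape (batch, seq_len, vocab_size)
predicted_ids = torch.argmax(logits, dim=-1)
decoded_seq = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
decoded_sequences.append(decoded_seq)
```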
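For reference, the `argmax` + `decode(..., skip_special_tokens=True)` steps in the snippet above amount to greedy per-position decoding of per-token scores (shape `(seq_len, vocab_size)`). Below is a minimal pure-Python sketch of that idea, assuming a hypothetical toy vocabulary (`VOCAB`) and made-up scores; it is not TCR-BERT's real vocabulary or tokenizer.

```python
from typing import List, Sequence

# Hypothetical toy vocabulary for illustration only; TCR-BERT's real vocab differs.
VOCAB = ["[PAD]", "[CLS]", "[SEP]", "A", "C", "S", "F"]
SPECIAL = {"[PAD]", "[CLS]", "[SEP]"}


def greedy_decode(logits: Sequence[Sequence[float]], vocab: List[str]) -> str:
    """Take the highest-scoring token at each position and join the non-special ones."""
    tokens = []
    for scores in logits:
        # argmax over the vocabulary dimension for this position
        best = max(range(len(scores)), key=lambda i: scores[i])
        token = vocab[best]
        if token not in SPECIAL:  # analogous to skip_special_tokens=True
            tokens.append(token)
    return " ".join(tokens)


# Made-up scores for a 5-position sequence: [CLS] C A S [SEP]
example_logits = [
    [0.1, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.2, 1.5, 0.1, 0.0],
    [0.0, 0.0, 0.0, 1.1, 0.3, 0.2, 0.0],
    [0.0, 0.0, 0.0, 0.1, 0.0, 0.9, 0.3],
    [0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0],
]
print(greedy_decode(example_logits, VOCAB))  # C A S
```

This kind of decoding only makes sense for per-token scores; it needs one score per vocabulary entry at every position.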