In case others face the same issue: for inference one should also use the `generate` method to compute the logits.
It goes as follows:

```python
import torch

# encode the 3Di sequence with the translation-direction prefix
ids_backtranslation = tokenizer.batch_encode_plus(
    ["<fold2AA>" + " " + seq_3di],
    add_special_tokens=True, padding="longest", return_tensors="pt",
).to(model.device)

# generate the back-translated amino-acid sequence and keep the per-step scores
outputs = model.generate(
    ids_backtranslation.input_ids,
    attention_mask=ids_backtranslation.attention_mask,
    max_length=len(seq_3di) + 1,
    min_length=len(seq_3di) + 1,
    output_scores=True,
    return_dict_in_generate=True,
    repetition_penalty=repetition_penalty,
)

# stack the per-step scores into a [L, vocab_size] tensor (batch size 1 here)
logits = torch.cat(outputs.scores).cpu()
```
One thing that seemed off in your example: if I didn't add +1 to the expected length (here only one example), the output would be one residue shorter than the expected length from the 3Di encoding...
Any corrections to what I came up with would be greatly appreciated. As a sanity check, the recovery against the sequence I computed the 3Di from is reasonable, i.e. >40%, so it does not seem buggy to me.
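For reference, the sanity check itself looks roughly like this; `seq_aa_native` is a placeholder for the amino-acid sequence the 3Di states were computed from:

```python
# greedy tokens from the processed scores (matches outputs.sequences for greedy decoding)
pred_ids = logits.argmax(dim=-1)                                         # [L]
pred_seq = tokenizer.decode(pred_ids, skip_special_tokens=True).replace(" ", "")

# sequence recovery against the native amino-acid sequence (placeholder variable)
matches = sum(a == b for a, b in zip(pred_seq, seq_aa_native))
recovery = matches / max(len(seq_aa_native), 1)
print(f"sequence recovery: {recovery:.1%}")
```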
Thanks for sharing the details on how you got the scores. From what I remember, I used similar logic at one point, so I would not immediately see anything to change.
Only one thing on the +1 offset: maybe double-check, but the decoder should not need those special prefixes which indicate the direction of translation ("<s>" if I am not mistaken); this should get stripped off automatically when you do something like `decoded_translations = tokenizer.batch_decode(translations, skip_special_tokens=True)`.
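A quick check along these lines should show where the extra position comes from (illustrative only; if I remember correctly, `max_length` in `generate` also counts the decoder start token, which would explain the +1):

```python
# the generated sequences start with the decoder start token (and may end with </s>),
# while outputs.scores only covers the actually generated residues
print(tokenizer.convert_ids_to_tokens(outputs.sequences[0][:3].tolist()))

# decoding with skip_special_tokens=True drops those tokens again
decoded_translations = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
decoded_translations = [t.replace(" ", "") for t in decoded_translations]
```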
For inference, it looks like it is easiest to stick to using `model.generate()` along with the special tokens which indicate what processing is expected (e.g. encoder or decoder direction). Additionally, thanks for sharing the `batch_decode` method, which takes care of dropping the special tokens from the sequence outputs.
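For completeness, an end-to-end version of the inverse-folding inference; the checkpoint name, the lower-cased and space-separated 3Di input, and the sampling settings follow my reading of the README and should be treated as assumptions:

```python
import torch
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device).eval()

# 3Di states lower-cased and whitespace-separated, with the direction prefix
seq_3di = "d p v q l v d"  # placeholder 3Di string
src = "<fold2AA>" + " " + seq_3di.lower()

ids = tokenizer.batch_encode_plus(
    [src], add_special_tokens=True, padding="longest", return_tensors="pt"
).to(device)

n_res = len(seq_3di.split())
with torch.no_grad():
    outputs = model.generate(
        ids.input_ids,
        attention_mask=ids.attention_mask,
        max_length=n_res + 1,
        min_length=n_res + 1,
        do_sample=True,
        top_p=0.85,
        output_scores=True,
        return_dict_in_generate=True,
    )

# drop special tokens and the whitespace between residues
backtranslated = [
    s.replace(" ", "")
    for s in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
]
print(backtranslated[0])
```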
I am considering finetuning ProstT5 for other protein types; I might come back with a few questions if you don't mind!
Sounds interesting, let me know in case you hit any problems along the way. Regarding finetuning, I would recommend considering a parameter-efficient approach, which we have had good experience with previously.
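In case it helps, a minimal sketch of what such a parameter-efficient setup could look like with the `peft` library and LoRA; the rank, target modules and other hyperparameters are only illustrative assumptions, not an exact recipe:

```python
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5")

# LoRA adapters on the T5 attention projections; only these adapters are trained
lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q", "k", "v", "o"],  # attention projection module names in T5 blocks
)

model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()  # prints the fraction of trainable parameters
```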
@mheinzinger Thanks for the advice. Here specifically I am thinking about finetuning the encoder and decoder together on sequence-3Di pairs, not about e.g. supervised fitness prediction with the encoder alone (as with e.g. ESM2). It could be interesting to tune the models to antibodies, for example, or to retrain from scratch, but I guess starting from the general pretrained model would still be an advantage.
Hello @mheinzinger and thanks for sharing this interesting model.
I don't have practical experience with encoder-decoder LMs such as T5, so I am still trying to figure out the right way to get logits for the inverse folding task, instead of sampling sequences as in the example shown in the README.
I have tried running the model to get logits; the shape of `logits` seems right, i.e. [batch, L+2, vocab_size], but when I sample from these logits as a sanity check I get roughly random tokens and no recovery of the target sequence from which I extracted the 3Di tokens... If calling the model itself, I get an error,
but in the case of inverse folding, we should not provide anything to the encoder and only get the structure information through the decoder, at least as far as I understand ...
Can you help please?
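In case it is useful for others finding this thread: besides `generate`, per-position logits can also be obtained from a single teacher-forced forward pass. A rough sketch, where `seq_aa_native` stands for the target amino-acid sequence and the exact target formatting is an assumption:

```python
import torch

# 3Di goes to the encoder (with the direction prefix); the amino-acid target is
# passed via labels, which triggers teacher forcing in the decoder
src = tokenizer.batch_encode_plus(
    ["<fold2AA>" + " " + seq_3di], add_special_tokens=True,
    padding="longest", return_tensors="pt",
).to(model.device)
tgt = tokenizer.batch_encode_plus(
    [" ".join(seq_aa_native)], add_special_tokens=True,
    padding="longest", return_tensors="pt",
).to(model.device)

with torch.no_grad():
    out = model(
        input_ids=src.input_ids,
        attention_mask=src.attention_mask,
        labels=tgt.input_ids,
    )

logits = out.logits                            # [batch, target_len, vocab_size]
log_probs = torch.log_softmax(logits, dim=-1)  # per-position log-probabilities
```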