mheinzinger / ProstT5

Bilingual Language Model for Protein Sequence and Structure

computing residue logits from 3Di input #22

Closed · adrienchaton closed this 1 week ago

adrienchaton commented 2 weeks ago

Hello @mheinzinger and thanks for sharing this interesting model.

I don't have practical experience with encoder-decoder LMs such as T5, so I am still trying to figure out the right way to get logits for the inverse folding task, instead of sampling sequences as in the example shown in the README.

I have tried to run

# feed the 3Di token ids directly to the decoder, without any encoder pass
outputs = model.decoder(ids_backtranslation.input_ids, attention_mask=ids_backtranslation.attention_mask)
# project the decoder hidden states onto the vocabulary to get logits
logits = model.lm_head(outputs.last_hidden_state)

The shape of the logits seems right, i.e. [batch, L+2, vocab_size], but when I sample from these logits as a sanity check I get essentially random tokens and no recovery of the target sequence from which I extracted the 3Di tokens ...

If I call the model itself, I get the error

ValueError: You have to specify either input_ids or inputs_embeds
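
Roughly, with a call along these lines, where the 3Di ids are only given to the decoder side (a sketch of the failing call):

# passing the 3Di ids only to the decoder: the encoder receives nothing, hence the ValueError
outputs = model(decoder_input_ids=ids_backtranslation.input_ids,
                decoder_attention_mask=ids_backtranslation.attention_mask)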

But in the case of inverse folding, we should not provide anything to the encoder and should only pass the structure information through the decoder, at least as far as I understand ...

Can you help please?

adrienchaton commented 2 weeks ago

In case others face the same issue: one should also use the generate method at inference time to compute the logits.

It goes as follows:

# tokenize the 3Di string with the <fold2AA> prefix (3Di -> amino acid translation direction)
ids_backtranslation = tokenizer.batch_encode_plus(["<fold2AA>"+" "+seq_3di], add_special_tokens=True, padding="longest", return_tensors='pt').to(model.device)
# generate the back-translated amino-acid sequence and keep the per-step scores
outputs = model.generate(ids_backtranslation.input_ids, attention_mask=ids_backtranslation.attention_mask, max_length=len(seq_3di)+1, min_length=len(seq_3di)+1, output_scores=True, return_dict_in_generate=True, repetition_penalty=repetition_penalty)
# stack the per-step scores into a single [length, vocab_size] tensor of logits
logits = torch.cat(outputs.scores).cpu()

One thing which seemed off in your example is that if I didn't add +1 to the expected length (here with only one example in the batch), the output would be one residue shorter than the length expected from the 3Di encoding ...
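
A shape check along these lines makes the offset visible (using the outputs from the generate call above):

# the generated sequences include the decoder start token, hence the +1 on max_length/min_length
print(outputs.sequences.shape)                        # should be [1, len(seq_3di) + 1]
# one score tensor over the vocabulary per generated token
print(len(outputs.scores), outputs.scores[0].shape)   # len(seq_3di), [1, vocab_size]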

Any corrections to what I came up with would be greatly appreciated. As a sanity check, the recovery against the sequence I computed the 3Di from is reasonable, i.e. >40%, so it does not seem buggy to me ...

mheinzinger commented 1 week ago

Thanks for sharing the details on how you got the scores. From what I remember, I used similar logic at one point, so I do not immediately see what to change.

Only one thing, on the +1 offset: maybe double-check, but the decoder should not need those special prefixes which indicate the direction of translation ("<fold2AA>" etc.). Those prefixes are only added to the encoder input to tell the model how to interpret the encoder input and how to optimally embed it for the translation direction you are interested in. That being said: I think there is a special token added to the decoder to kick off the translation (<s> if I am not mistaken), but this should get stripped off automatically when you do something like decoded_translations = tokenizer.batch_decode(translations, skip_special_tokens=True)
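
With your generate call from above, that would look roughly like this (where translations corresponds to your outputs.sequences):

# the first token of each generated sequence is the decoder start token mentioned above
print(tokenizer.convert_ids_to_tokens(outputs.sequences[0][:3].tolist()))
# skip_special_tokens=True drops it (plus any EOS/padding) from the decoded string
decoded_translations = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
print(decoded_translations[0])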

adrienchaton commented 1 week ago

For inference it looks like it is easiest to stick to model.generate(), along with the special prefix tokens which indicate the expected translation direction (i.e. what the encoder receives and what the decoder should produce).

Additionally, thanks for pointing out the batch_decode method, which takes care of dropping the special tokens from the sequence outputs.
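
E.g. the sanity check from above then simply becomes (a rough sketch; seq_aa here stands for the original amino-acid sequence, without spaces):

# decode the generated ids, drop special tokens, and remove the whitespace between residues
pred_aa = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0].replace(" ", "")
# naive per-position recovery against the original sequence (seq_aa is a placeholder name)
recovery = sum(p == t for p, t in zip(pred_aa, seq_aa)) / len(seq_aa)
print(f"recovery: {recovery:.1%}")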

I am considering trying to fine-tune ProstT5 for other protein types; I might come back with a few questions if you don't mind!

mheinzinger commented 1 week ago

Sounds interesting, let me know in case you hit any problems along the way. Regarding fine-tuning, I would recommend considering a parameter-efficient variant, which we have had good experience with previously.
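
For example, something along these lines with the peft library (a rough sketch with illustrative hyperparameters, not our exact setup; "Rostlab/ProstT5" is the checkpoint on Hugging Face):

from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import LoraConfig, TaskType, get_peft_model

tokenizer = T5Tokenizer.from_pretrained("Rostlab/ProstT5", do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained("Rostlab/ProstT5")

# LoRA adapters on the attention projections of the T5 blocks; only these small
# low-rank matrices are trained, the pretrained weights stay frozen
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q", "v"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints the (small) fraction of trainable weights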

adrienchaton commented 1 week ago

@mheinzinger Thanks for the advice. Here I am specifically thinking about fine-tuning the encoder and decoder together on sequence-3Di pairs, not about e.g. supervised fitness prediction with the encoder alone (as with e.g. ESM2). It could be interesting to tune the model to antibodies for example, or to retrain from scratch, but I guess starting from the general pretrained model would still be an advantage.
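
Concretely, I am thinking of building the training pairs roughly like this (a sketch; assuming amino acids in upper case and 3Di in lower case, residues separated by whitespace, as in the ProstT5 examples):

# build one training example per direction from an (amino-acid, 3Di) pair
def make_example(aa_seq: str, tdi_seq: str, direction: str):
    aa_spaced = " ".join(list(aa_seq.upper()))    # amino acids in upper case
    tdi_spaced = " ".join(list(tdi_seq.lower()))  # 3Di states in lower case
    if direction == "fold2AA":
        # structure -> sequence (inverse folding)
        return "<fold2AA> " + tdi_spaced, aa_spaced
    # sequence -> structure
    return "<AA2fold> " + aa_spaced, tdi_spaced

# the first element goes to the encoder (model input), the second is tokenized as labels
src, tgt = make_example("MKTAYIAK", "dvvvvddd", "fold2AA")  # toy strings, just for illustration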