jasonkyuyim / multiflow

https://arxiv.org/abs/2402.04997
MIT License

Question regarding "aatype_pred_num_tokens" #9

Open smiles724 opened 1 month ago

smiles724 commented 1 month ago

Hi, Jason,

I'm asking another question; thanks in advance for your patience.

In your prediction model, you always set aatype_pred_num_tokens to 21: the 20 standard amino acid types plus an additional mask token. Meanwhile, in the loss computation, you manually set the mask-token dimension to 1e-9, a number close to 0.
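To make the setup concrete, here is a minimal pure-Python sketch of a per-residue cross-entropy with a 21-way output in which the mask column is suppressed before the softmax. It assumes the mask token sits at index 20 and that suppression is done by overwriting its logit with a large negative value; both are hypothetical choices, not taken from the multiflow code, but the effect matches the near-zero mask probability described above.

```python
import math

NUM_TOKENS = 21   # 20 amino acid types + 1 mask token
MASK_INDEX = 20   # hypothetical: mask token assumed to be the last index
NEG_LARGE = -1e9  # hypothetical: logit used to suppress the mask token


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def aatype_ce_loss(logits, target):
    """Cross-entropy for one residue with a 21-way head.

    The mask column is overwritten with a large negative logit, so after
    the softmax it gets (numerically) zero probability and the loss is
    effectively a 20-way cross-entropy over real amino acids.
    """
    assert len(logits) == NUM_TOKENS
    assert 0 <= target < MASK_INDEX, "targets are real amino acids only"
    suppressed = list(logits)
    suppressed[MASK_INDEX] = NEG_LARGE
    return -log_softmax(suppressed)[target]


# Uniform logits over the 20 amino acids give a loss of log(20),
# no matter how large the raw mask logit was.
print(aatype_ce_loss([0.0] * 20 + [100.0], target=3))
```

Because the mask column never receives probability mass, a 21-way head trained this way behaves like a 20-way classifier for the loss, which is why either layout can work.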

So I wonder why the model can't directly predict only the 20 amino acid tokens, without the mask token. Is there any advantage to the current implementation?

The only reason I can guess is that you use an extra padding token (the same as the mask token), so the loss calculation would raise an error if the model predicted only 20 types.
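The padding concern can be illustrated with a 20-way head: padded positions whose target equals the mask/padding index have no corresponding output column, so a plain cross-entropy lookup fails unless those positions are filtered out first. This is a hypothetical pure-Python sketch (PAD_INDEX = 20 and the res_mask argument are assumptions, not names from the repository); in PyTorch, the ignore_index argument of torch.nn.functional.cross_entropy serves a similar purpose.

```python
import math

NUM_AA = 20     # 20-way head: real amino acids only
PAD_INDEX = 20  # hypothetical: padding shares the mask token's index


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def masked_mean_ce(logits_per_res, targets, res_mask):
    """Mean cross-entropy over real residues with a 20-way head.

    Padded positions carry PAD_INDEX, which has no column in a 20-way
    output, so they are skipped before indexing into the log-probs.
    """
    total, count = 0.0, 0
    for logits, t, keep in zip(logits_per_res, targets, res_mask):
        assert len(logits) == NUM_AA
        if not keep:
            continue
        total -= log_softmax(logits)[t]
        count += 1
    return total / max(count, 1)


logits = [[0.0] * NUM_AA, [0.0] * NUM_AA]
targets = [5, PAD_INDEX]  # second residue is padding
print(masked_mean_ce(logits, targets, res_mask=[1, 0]))  # log(20) ≈ 2.9957

# Without the mask, indexing the padded target fails:
try:
    log_softmax(logits[1])[PAD_INDEX]
except IndexError:
    print("IndexError: target 20 has no column in a 20-way head")
```

So a 20-way head is workable, but every loss call must first drop (or ignore) padded positions; keeping the extra column sidesteps that bookkeeping.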

jasonkyuyim commented 1 month ago

Both approaches are fine. I think we were going for shape compatibility by keeping the aatype dimension always equal to 21. This simplified some torch operations.
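One way to picture the shape-compatibility argument: the noisy input sequence contains mask tokens, so its one-hot encoding is naturally 21-dimensional, and a 21-way prediction can be combined with it position by position without extra slicing or padding. The sketch below is a hypothetical pure-Python illustration (one_hot, fill_masked, and pad_mask_column are made-up helpers); the reply does not say which torch operations were simplified, so this is not a description of the multiflow code.

```python
NUM_TOKENS = 21  # 20 amino acids + mask
MASK_INDEX = 20  # hypothetical: mask token at the last index


def one_hot(token, n=NUM_TOKENS):
    """One-hot vector of length n for an integer token id."""
    vec = [0.0] * n
    vec[token] = 1.0
    return vec


def pad_mask_column(probs20):
    """Extra step a 20-way head would need before mixing with 21-dim inputs."""
    return list(probs20) + [0.0]


def fill_masked(noisy_seq, pred_probs):
    """Keep observed residues; use the prediction only at masked positions.

    Because the noisy input (which contains mask tokens) and the prediction
    are both 21-dimensional, rows can be swapped in directly.
    """
    out = []
    for tok, probs in zip(noisy_seq, pred_probs):
        assert len(probs) == NUM_TOKENS
        out.append(list(probs) if tok == MASK_INDEX else one_hot(tok))
    return out


noisy_seq = [3, MASK_INDEX]
pred_probs = [one_hot(9), one_hot(7)]  # 21-way predictions
filled = fill_masked(noisy_seq, pred_probs)
print([row.index(1.0) for row in filled])  # [3, 7]

# A 20-way prediction has to be padded first to fit the same shape.
filled20 = fill_masked([MASK_INDEX], [pad_mask_column([0.05] * 20)])
print(len(filled20[0]))  # 21
```

With a 21-way head the padding step disappears, which is the kind of simplification a fixed aatype dimension buys.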