BernhoferM / TMbed

Transmembrane proteins predicted through Language Model embeddings
Apache License 2.0

Clarification on Cross-Entropy Loss for TMbed Training #12

Open grmos opened 2 days ago

grmos commented 2 days ago

Thank you for sharing TMbed! I noticed there isn't a training script in the repository, and I have a question about the cross-entropy loss used. From the paper, it seems that TMbed predicts both the overall protein type (e.g., alpha-helical, beta-barrel) and the class of each residue in the sequence. Could you clarify whether the cross-entropy loss is:

1. a sum of two losses, one for the overall protein-type classification and another for the per-residue classification;
2. a single cross-entropy loss applied solely to the per-residue classification; or
3. a different approach altogether?

Also, if possible, could you share any additional training code, especially showing how the loss function is implemented?

Thank you very much for your help!

BernhoferM commented 2 days ago

Hi.

TMbed only directly predicts the per-residue state of protein sequences. The global classification into transmembrane or soluble protein is simply implied by the presence of at least one predicted transmembrane segment. Thus, only a single loss function is used during training.
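To illustrate how a protein-level call can be read off a per-residue prediction, here is a minimal sketch. It assumes the predicted states come as one character per residue, with `H` and `B` marking transmembrane helix and transmembrane beta-strand residues, and `i`, `o`, `S` marking non-membrane residues; the helper names `tm_segments` and `is_transmembrane` are hypothetical and not part of the TMbed code base.

```python
from itertools import groupby

# Assumption: 'H' = transmembrane helix residue, 'B' = transmembrane beta-strand residue;
# 'i' / 'o' = non-membrane residues inside / outside, 'S' = signal peptide.
TM_STATES = frozenset("HB")


def tm_segments(states: str) -> list[tuple[int, int, str]]:
    """Return (start, end_exclusive, state) for each contiguous run of a TM state."""
    segments = []
    pos = 0
    for state, run in groupby(states):
        length = len(list(run))
        if state in TM_STATES:
            segments.append((pos, pos + length, state))
        pos += length
    return segments


def is_transmembrane(states: str) -> bool:
    """A protein counts as transmembrane if at least one TM segment is predicted."""
    return len(tm_segments(states)) > 0
```

For example, `tm_segments("SSSooHHHHHii")` yields a single helix segment `(5, 10, "H")`, while `is_transmembrane("ooooiiii")` is `False`. This sketch ignores any minimum segment length or other post-processing that a real pipeline might apply.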

We used the CrossEntropyLoss implemented in PyTorch, applied to the five-state prediction (i, o, S, H, B) of each residue. Note that only the model (plus the Gaussian filter) was trained, not the Viterbi decoder; i.e., for training, simply use the Predictor class from the model.py file.
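A minimal PyTorch sketch of an unweighted, per-residue cross-entropy over the five states. The mapping from state letters to class indices, the padding value, and the helper names (`encode_labels`, `per_residue_loss`) are assumptions for illustration; the sketch does not reproduce the `Predictor` class, and the logits shape `(batch, 5, length)` is assumed rather than taken from model.py.

```python
import torch
import torch.nn as nn

# Assumed index order for the five per-residue states named in the thread.
STATES = "ioSHB"
STATE_TO_IDX = {state: idx for idx, state in enumerate(STATES)}
PAD_LABEL = -100  # positions with this label are skipped by the loss (ignore_index)


def encode_labels(label_strings: list[str], max_len: int) -> torch.Tensor:
    """Turn per-residue state strings into a padded (batch, max_len) index tensor."""
    labels = torch.full((len(label_strings), max_len), PAD_LABEL, dtype=torch.long)
    for row, states in enumerate(label_strings):
        labels[row, : len(states)] = torch.tensor(
            [STATE_TO_IDX[s] for s in states], dtype=torch.long
        )
    return labels


# Plain, unweighted cross-entropy: no `weight` argument, so all classes count equally.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_LABEL)


def per_residue_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """logits: (batch, 5, length) raw scores; labels: (batch, length) class indices."""
    return criterion(logits, labels)


# Usage with random scores standing in for model output (hypothetical data).
torch.manual_seed(0)
labels = encode_labels(["iiHHHoo", "SSoooo"], max_len=7)
logits = torch.randn(2, len(STATES), 7, requires_grad=True)
loss = per_residue_loss(logits, labels)
loss.backward()  # the padded position (row 1, column 6) receives zero gradient
```

`nn.CrossEntropyLoss` expects the class dimension at position 1, so if a model emits scores shaped `(batch, length, 5)`, permute them with `logits.permute(0, 2, 1)` first. Omitting the `weight` argument keeps all five classes equally weighted, matching the setup described in this thread.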

Regards,
Michael

grmos commented 1 day ago

Thank you, Michael, for the quick response!

I have a follow-up question regarding the loss function setup. Given that the "B" class (beta barrel) proteins are underrepresented in the dataset, did you apply any class weights in the cross-entropy loss to balance the data during training? If so, could you share a bit about how you calculated or set these weights?

Appreciate your insights, and thanks again for sharing TMbed with the community!

BernhoferM commented 1 day ago

No, no class weights are applied to the loss; all classes are weighted equally. In fact, in early tests, class weights slightly degraded the overall performance. However, the data splits are stratified to contain roughly the same number of proteins from each of the three types (TMB, TMH, Non-TM), both with and without signal peptides.
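The stratification mentioned in this answer can be sketched as follows, under stated assumptions: each protein is described by an ID, a type label (TMB, TMH or Non-TM) and a signal-peptide flag, and proteins are dealt round-robin into folds within each (type, signal peptide) stratum. The function `stratified_folds` is hypothetical; the thread does not say how the actual TMbed splits were built, and this sketch ignores the sequence-similarity constraints that a real protein benchmark usually also has to control.

```python
import random
from collections import defaultdict


def stratified_folds(
    proteins: list[tuple[str, str, bool]], n_folds: int = 5, seed: int = 0
) -> list[list[str]]:
    """Assign protein IDs to folds so each fold gets roughly the same number
    of proteins from every (protein_type, has_signal_peptide) stratum.

    proteins: (protein_id, protein_type, has_signal_peptide) tuples,
    where protein_type is e.g. "TMB", "TMH" or "Non-TM" (hypothetical input format).
    """
    strata: dict[tuple[str, bool], list[str]] = defaultdict(list)
    for pid, ptype, has_sp in proteins:
        strata[(ptype, has_sp)].append(pid)

    rng = random.Random(seed)
    folds: list[list[str]] = [[] for _ in range(n_folds)]
    offset = 0  # running counter shared across strata to balance total fold sizes
    for key in sorted(strata):  # sorted keys make the assignment deterministic
        members = strata[key]
        rng.shuffle(members)
        for i, pid in enumerate(members):
            folds[(offset + i) % n_folds].append(pid)
        offset += len(members)
    return folds
```

Continuing the round-robin counter across strata keeps both the per-stratum counts and the total fold sizes within one protein of each other.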