Open grmos opened 2 days ago
Hi.
TMbed directly predicts only the per-residue states of a protein sequence. The global classification as a transmembrane or soluble protein is simply implied by the presence of at least one predicted transmembrane segment. Thus, only a single loss function is used during training.
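To illustrate how a global class can follow from per-residue labels alone, here is a minimal sketch. It assumes the prediction is available as a string over the five states (i, o, S, H, B); the function name `protein_type` and the returned class names are hypothetical and not part of TMbed's code.

```python
def protein_type(labels: str) -> tuple[str, bool]:
    """Derive a protein-level class from per-residue labels.

    Hypothetical helper: `labels` is assumed to hold one character per
    residue, drawn from the five states i, o, S, H, B.
    Returns (class name, has signal peptide).
    """
    has_helix = "H" in labels    # at least one helical TM residue
    has_barrel = "B" in labels   # at least one beta-strand TM residue

    if has_helix and has_barrel:
        kind = "TMH+TMB"
    elif has_helix:
        kind = "TMH"
    elif has_barrel:
        kind = "TMB"
    else:
        kind = "Non-TM"          # no TM segment predicted: soluble

    return kind, "S" in labels


if __name__ == "__main__":
    print(protein_type("SSSSiiiiHHHHHHoooo"))
    print(protein_type("iiiiiiiiii"))
```

No separate classification head or second loss term is needed for this step; it is pure post-processing of the per-residue output.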
We used PyTorch's CrossEntropyLoss, applied to the five-state prediction (i, o, S, H, B) for each residue. Note that only the model (plus Gaussian filter) was trained, not the Viterbi decoder; i.e., for training, simply use the Predictor class from the model.py file.
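A per-residue loss of this kind could be sketched as follows. This is not the actual TMbed training code: the tensor shapes, the class order, the padding value, and the helper name `per_residue_loss` are assumptions made for illustration, and the logits here are dummy tensors rather than output of the Predictor class.

```python
import math

import torch
import torch.nn as nn

NUM_CLASSES = 5   # assumed class order: i, o, S, H, B
PAD = -100        # assumed label for padded positions (ignored by the loss)


def per_residue_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Unweighted cross-entropy averaged over all non-padded residues.

    logits:  (batch, length, NUM_CLASSES) raw scores per residue
    targets: (batch, length) integer class indices, PAD where padded
    """
    criterion = nn.CrossEntropyLoss(ignore_index=PAD)
    # nn.CrossEntropyLoss expects the class dimension at position 1.
    return criterion(logits.transpose(1, 2), targets)


if __name__ == "__main__":
    # Dummy batch: two sequences of length 4, the second padded after 2 residues.
    logits = torch.zeros(2, 4, NUM_CLASSES, requires_grad=True)
    targets = torch.tensor([[0, 3, 3, 1], [2, 0, PAD, PAD]])
    loss = per_residue_loss(logits, targets)
    loss.backward()
    # Uniform logits give a loss of log(NUM_CLASSES).
    print(loss.item(), math.log(NUM_CLASSES))
```

In a real training loop, the logits would come from the trainable model, and the Viterbi decoder would be applied only at inference time.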
Regards, Michael
Thank you, Michael, for the quick response!
I have a follow-up question regarding the loss function setup. Given that the "B" class (beta barrel) proteins are underrepresented in the dataset, did you apply any class weights in the cross-entropy loss to balance the data during training? If so, could you share a bit about how you calculated or set these weights?
Appreciate your insights, and thanks again for sharing TMbed with the community!
No, there are no class weights applied to the loss; all classes are equal. In fact, in early tests class weights did degrade the overall performance slightly. However, the data splits are stratified to roughly contain the same number of proteins from the three different types (TMB, TMH, Non-TM) with and without signal peptides.
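One way such stratified splits could be built is sketched below. This is an illustration under assumptions, not the procedure actually used for TMbed: the input format, the function name `stratified_folds`, and the number of folds are hypothetical, and any redundancy reduction between folds is left out.

```python
import random
from collections import defaultdict


def stratified_folds(proteins, n_folds, seed=0):
    """Distribute proteins over folds, balancing each stratum.

    proteins: iterable of (protein_id, protein_type, has_signal_peptide),
              with protein_type one of "TMB", "TMH", "Non-TM" (assumed format).
    Returns a list of n_folds lists of protein ids.
    """
    rng = random.Random(seed)

    # Group ids by stratum: protein type combined with signal peptide flag.
    strata = defaultdict(list)
    for pid, ptype, has_sp in proteins:
        strata[(ptype, has_sp)].append(pid)

    folds = [[] for _ in range(n_folds)]
    offset = 0  # rotate the starting fold so leftovers spread evenly
    for key in sorted(strata):
        ids = strata[key]
        rng.shuffle(ids)
        for i, pid in enumerate(ids):
            folds[(i + offset) % n_folds].append(pid)
        offset += len(ids)
    return folds


if __name__ == "__main__":
    demo = [
        (f"{ptype}_{sp}_{n}", ptype, sp)
        for ptype in ("TMB", "TMH", "Non-TM")
        for sp in (True, False)
        for n in range(10)
    ]
    for fold in stratified_folds(demo, n_folds=5):
        print(len(fold))
```

Balancing the splits this way keeps the rare types present in every fold without reweighting the loss itself.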
Thank you for sharing TMbed! I noticed there isn't a training script in the repository, and I have a question about the cross-entropy loss function used. From the paper, it seems that TMbed might predict both the overall protein type (e.g., alpha-helical, beta-barrel) and the class of each residue within the sequence. Could you clarify whether the loss function is:
- a sum of two losses, one for the overall protein type classification and another for per-residue classification;
- a single cross-entropy loss applied solely to per-residue classification; or
- a different approach altogether?

Also, if possible, could you share any additional code related to training, especially how the loss function is implemented?
Thank you very much for your help!