songlab-cal / tape

Tasks Assessing Protein Embeddings (TAPE), a set of five biologically relevant semi-supervised learning tasks spread across different domains of protein biology.
https://www.biorxiv.org/content/10.1101/676825v1
BSD 3-Clause "New" or "Revised" License

Training UniRep language model error #33

Closed faruuk closed 3 years ago

faruuk commented 4 years ago

Hi,

When I try to train a language model with UniRep I get the following error: "RuntimeError: shape '[-1,26]' is invalid for input of size 2800"

Looking at the UniRepForLM Class, it says in a comment: "# TODO: Fix this for UniRep - UniRep changes the size of the targets"

Does that mean that currently, it is not possible to train a UniRep model for a language modelling task?

rmrao commented 4 years ago

I have not extensively tested and debugged UniRep training with the current repository. It is meant more to provide the default embeddings. If you're interested in training a UniRep language model, I'd be happy to accept a pull request that does this.

spark157 commented 4 years ago

Hi, I've been trying to do some further training of the UniRep model (essentially 'evoTuning') and have been looking into this error/issue. I wanted to share what I've uncovered to see if it makes sense, as I'm new to the codebase. If it all sounds reasonable I'll try to get a pull request organized (I'm a newbie on GitHub). Basically I see three things that need small tweaks to solve the problem:

  1. On line 179 of modeling_unirep.py, taking .view of prediction_scores with self.config.vocab_size causes a problem, since the output of the Linear self.feedforward has dimension config.vocab_size - 1. (I'm not sure why this is the case, but if you change the output dim of self.feedforward to config.vocab_size you can't load the pretrained model, since it must have been trained with config.vocab_size - 1.) The solution would then be to change the .view to use self.config.vocab_size - 1, like so:

        lm_loss = loss_fct(
            prediction_scores.view(-1, self.config.vocab_size-1), targets.view(-1))                        
  2. However, the targets and prediction_scores need to be made contiguous first, otherwise an error is thrown (since they were sliced).

  3. When encoding the sequences, a special stop token has been added, but this now throws an error in the CrossEntropyLoss (since there is a mismatch in the number of possible classes). Therefore the targets need to be adjusted to remove the stop token id (which is 25 for UniRep). What I did was replace 25 with -1 (the ignore index), which I'm guessing will not cause any further problems. (See the sketch after this list for how the three tweaks fit together.)
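
To make this concrete, here is a rough sketch of how the three tweaks might fit together. Toy tensors stand in for the model's actual prediction_scores and targets, so treat this as an illustration rather than the exact repository code:

    import torch
    import torch.nn as nn

    vocab_size = 26                                         # UniRep config.vocab_size
    prediction_scores = torch.randn(2, 5, vocab_size - 1)   # (1) self.feedforward emits vocab_size - 1 classes
    targets = torch.randint(0, vocab_size, (2, 5))          # encoded sequence, may contain the stop token 25

    loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    # (3) Map the stop token id (25) to -1 so CrossEntropyLoss ignores those positions.
    targets = targets.masked_fill(targets == 25, -1)

    # (1) View with vocab_size - 1, and (2) make the sliced tensors contiguous first.
    lm_loss = loss_fct(
        prediction_scores.contiguous().view(-1, vocab_size - 1),
        targets.contiguous().view(-1),
    )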

I'm currently training with this solution but have yet to see whether the final results on a downstream task are valid (as a test).

Also, UniRep is a beast on memory, and getting it to load and train is finicky. In the original paper they limited prediction to 280 amino acids (for the evoTuning task). To get my version running on a 16GB NVIDIA V100 I had to limit the sequences to 200 amino acids, which I did in __getitem__ by just slicing the sequence (rough sketch below). Others who try to train are likely to run into similar resource constraints they will have to resolve.
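
Something along these lines, assuming the sequences are plain amino acid strings and using a hypothetical wrapper dataset rather than the actual TAPE dataset class:

    from torch.utils.data import Dataset

    class TruncatedSequenceDataset(Dataset):
        """Hypothetical wrapper that caps sequence length before tokenization."""

        def __init__(self, sequences, max_length=200):
            self.sequences = sequences    # e.g. a list of amino acid strings
            self.max_length = max_length  # 200 aa fit on a 16GB V100 in my runs

        def __len__(self):
            return len(self.sequences)

        def __getitem__(self, index):
            # Slice the sequence so the unrolled UniRep RNN stays within memory.
            return self.sequences[index][: self.max_length]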

rmrao commented 4 years ago

Yes - I implemented UniRep in a somewhat questionable way in pytorch: it correctly mimics the behavior of the forward pass, but it simply uses a for-loop to iterate through the sequence, which is not a recommended way of implementing an RNN.
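
For context, this is roughly what iterating through the sequence with a for-loop looks like (a generic RNN cell for illustration, not the actual mLSTM code in modeling_unirep.py):

    import torch
    import torch.nn as nn

    cell = nn.RNNCell(input_size=10, hidden_size=64)
    inputs = torch.randn(32, 100, 10)   # (batch, seq_len, features)
    hidden = torch.zeros(32, 64)

    outputs = []
    for t in range(inputs.size(1)):
        # Each step depends on the previous hidden state, so the loop runs
        # sequentially in Python rather than in a fused kernel.
        hidden = cell(inputs[:, t], hidden)
        outputs.append(hidden)
    outputs = torch.stack(outputs, dim=1)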

I mainly did this since I have no intention of training UniRep myself, and don't have the time to figure out a better way of implementing this in pytorch. If you're planning on extensive UniRep training, I'd recommend either using the tensorflow implementation, or figuring out a better way of implementing it in pytorch.

If you're able to improve the pytorch implementation I'd be happy to take a pull request!

spetti commented 3 years ago

> 3. When encoding the sequences, a special stop token has been added, but this now throws an error in the CrossEntropyLoss (since there is a mismatch in the number of possible classes). Therefore the targets need to be adjusted to remove the stop token id (which is 25 for UniRep). What I did was replace 25 with -1 (the ignore index), which I'm guessing will not cause any further problems.

For the record, I think the proper adjustment is to remove the padding token "0", rather than the stop token "25". The next-residue probability vector produced after inputting the start token "24" has the form [.95, something, something, ...]. If you drop the padding token "0", the .95 corresponds to the aa labeled "1", which is M. Guessing that the first aa is methionine with probability 0.95 makes sense biologically (it is the aa encoded by the start codon). Another way to check is to note that the probability vectors always have very low values at 0-indexed positions 21 and 11, which correspond to the aa labeled "22" (O) and "12" (U), both of which are very uncommon.
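
Under that interpretation, the remapping would look something like the sketch below: output class i corresponds to token id i + 1, so shifting the targets down by one sends the padding id 0 to -1, which CrossEntropyLoss then ignores (toy tensors again; this is an assumption about the mapping, not repository code):

    import torch
    import torch.nn as nn

    vocab_size = 26                                         # full UniRep vocabulary, including padding "0"
    prediction_scores = torch.randn(2, 5, vocab_size - 1)   # padding class dropped from the output
    targets = torch.randint(0, vocab_size, (2, 5))          # encoded tokens, 0 = padding

    loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    # Shift so that token id 1 (M) becomes class 0, ..., and padding 0 becomes -1 (ignored).
    shifted_targets = targets - 1

    lm_loss = loss_fct(
        prediction_scores.reshape(-1, vocab_size - 1),
        shifted_targets.reshape(-1),
    )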

rmrao commented 3 years ago

I am closing this issue - for any future users who would like to train UniRep, I would point you to the following repository, which has a much faster implementation than you will ever get in pytorch: https://github.com/ElArkk/jax-unirep.