uw-ipd / RoseTTAFold2NA

RoseTTAFold2 protein/nucleic acid complex prediction
MIT License

RuntimeError: Class values must be smaller than num_classes. #27

Open adimil opened 1 year ago

adimil commented 1 year ago

Hello,

I am getting the following error message when running on GPU:

Traceback (most recent call last):
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 345, in <module>
    pred.predict(inputs=args.inputs, out_prefix=args.prefix, ffdb=ffdb)
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 225, in predict
    self._run_model(Ls, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial))
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 233, in _run_model
    seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
  File "/home/.../RoseTTAFold2NA/network/data_loader.py", line 116, in MSAFeaturize
    raw_profile = raw_profile.float().mean(dim=0)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Or when running on CPU:

Traceback (most recent call last):
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 345, in <module>
    pred.predict(inputs=args.inputs, out_prefix=args.prefix, ffdb=ffdb)
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 225, in predict
    self._run_model(Ls, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial))
  File "/home/.../RoseTTAFold2NA/network/predict.py", line 233, in _run_model
    seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
  File "/home/.../RoseTTAFold2NA/network/data_loader.py", line 115, in MSAFeaturize
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=NAATOKENS)
RuntimeError: Class values must be smaller than num_classes.

Do you have any idea what might cause this and how to fix?

Thanks!
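The GPU trace blames line 116 while the CPU trace blames line 115. That is consistent with the warning in the GPU message: CUDA kernels are launched asynchronously, so a device-side assert can surface at a later call than the one that caused it. A minimal sketch of following the message's advice from Python, assuming the variable is set before `torch` is imported so that it takes effect before CUDA is initialised:

```python
import os

# Force synchronous kernel launches so the GPU traceback points at the
# call that actually failed. Set this before CUDA is initialised.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # assumption: torch is imported only after the variable is set
```

Running on CPU, as done above, achieves the same goal and gives the clearer message.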

mmagnus commented 1 year ago

I have the same problem and haven't found a solution yet.

bifxcore commented 1 year ago

Same here, though I'm getting @adimil's 'CPU' error even though I'm running on GPU:

RoseTTAFold2NA/network/parsers.py:116: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  msa[msa == "U"] = 30
Running on GPU
Traceback (most recent call last):
  File "RoseTTAFold2NA/network/predict.py", line 374, in <module>
    pred.predict(inputs=args.inputs, out_prefix=args.prefix, ffdb=ffdb)
  File "RoseTTAFold2NA/network/predict.py", line 250, in predict
    self._run_model(Ls, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, same_chain, mask_t_2d, "%s_%02d"%(out_prefix, i_trial))
  File "RoseTTAFold2NA/network/predict.py", line 256, in _run_model
    seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
  File "RoseTTAFold2NA/network/data_loader.py", line 118, in MSAFeaturize
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=NAATOKENS)
RuntimeError: Class values must be smaller than num_classes.

bifxcore commented 1 year ago

How do I fix the "Class values must be smaller than num_classes." error?

msa size: torch.Size([6142, 1064]) num_classes: 32

It looks like I need to either reduce the msa size or increase NAATOKENS drastically... that does not sound right?
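The error concerns the values stored in the tensor, not its shape: with `num_classes=32`, `one_hot` accepts only integer tokens from 0 to 31, however large the MSA is. A dependency-free sketch of that condition follows; the helper name `out_of_range` and the toy data are made up for illustration and are not part of RoseTTAFold2NA.

```python
NAATOKENS = 32  # value quoted in this thread


def out_of_range(msa, num_classes):
    """Return (row, column, value) for every token one-hot encoding would reject."""
    return [
        (i, j, tok)
        for i, row in enumerate(msa)
        for j, tok in enumerate(row)
        if not 0 <= tok < num_classes
    ]


# Made-up 2 x 4 integer alignment; 75 stands in for a token that was
# never mapped into the 0..31 range.
toy_msa = [
    [0, 5, 31, 27],
    [30, 75, 2, 9],
]
bad = out_of_range(toy_msa, NAATOKENS)
print(bad)
```

Any entry reported here points at a character in the input alignment that the parser did not translate into a valid token.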

fdimaio commented 1 year ago

Hello, can you try the latest version? This was a bug where 32 was hardcoded. If this does not work, can you post your MSA?

bifxcore commented 1 year ago

@fdimaio I downloaded the package last week, so I just updated network/parsers.py and restarted predict.py.

I'm now getting the error from https://github.com/uw-ipd/RoseTTAFold2NA/issues/15

Running on GPU
Traceback (most recent call last):
  File "RoseTTAFold2NA/network/predict.py", line 374, in <module>
    pred.predict(inputs=args.inputs, out_prefix=args.prefix, ffdb=ffdb)
  File "RoseTTAFold2NA/network/predict.py", line 160, in predict
    msa_i, ins_i = parse_fasta(a3m_i, rna_alphabet=is_rna, dna_alphabet=is_dna)
  File "RoseTTAFold2NA/network/parsers.py", line 119, in parse_fasta
    assert (np.all(msa<=31))
AssertionError

bifxcore commented 1 year ago

RNA MSA attached.

tRNA_Glu.afa.gz

bifxcore commented 1 year ago

@fdimaio my bad. I just realised that my alignment has lots of sequences with illegal bases (e.g. K, M, R, S, Y):

-UCCCGUUCGUCUAGAGGCCUAGGACACCGCCCUUUCACGGCGGUAACAGGGGKUCGACU
CCCMUARGSGM-----
GUCCCCAUCGUCUAGAGGCCUAGGACACYGCCCUUUCACGGCGRYAACCGGGGUUCGAAU

:(
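The offending letters are IUPAC ambiguity codes (for example R for A or G, Y for C or U). A minimal sketch for listing such characters before running a prediction, assuming the legal alphabet is A, C, G, U, N and the gap character; the function name `illegal_characters` is hypothetical and the input is the fragments quoted above.

```python
from collections import Counter

LEGAL = set("ACGUN-")  # assumed legal RNA alphabet plus gap


def illegal_characters(sequences, legal=LEGAL):
    """Count every non-whitespace character that is outside the legal alphabet."""
    counts = Counter()
    for seq in sequences:
        counts.update(
            ch for ch in seq.upper() if ch not in legal and not ch.isspace()
        )
    return counts


fragments = [
    "-UCCCGUUCGUCUAGAGGCCUAGGACACCGCCCUUUCACGGCGGUAACAGGGGKUCGACU",
    "CCCMUARGSGM-----",
    "GUCCCCAUCGUCUAGAGGCCUAGGACACYGCCCUUUCACGGCGRYAACCGGGGUUCGAAU",
]
found = illegal_characters(fragments)
print(sorted(found))
```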

bifxcore commented 1 year ago

Fixed this issue with a quick and dirty hack:

# remove RNA sequences with illegal characters from alignment
perl -ne ' if (/^>/) {$h = $_} else { if (/^[ NUCGA-]+$/i) {print $h, $_} }' my_crappy_RNA_file.afa  > tRNA_Glu.afa
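For readers who prefer Python, the same filter can be sketched as below. This is an illustration rather than a drop-in replacement: the function name `filter_alignment` is hypothetical, and unlike the one-liner it treats a record whose sequence spans several lines as a unit, dropping the whole record if any of its lines contains an illegal character.

```python
import re

# Same character class as the Perl one-liner: A, C, G, U, N, gap and space.
LEGAL = re.compile(r"[ NUCGA-]+", re.IGNORECASE)


def filter_alignment(lines):
    """Return the lines of all records whose sequence uses only legal characters."""
    records = []  # list of (header, [sequence lines])
    for raw in lines:
        line = raw.rstrip("\n")
        if line.startswith(">"):
            records.append((line, []))
        elif records and line:
            records[-1][1].append(line)

    kept = []
    for header, seq_lines in records:
        if seq_lines and all(LEGAL.fullmatch(s) for s in seq_lines):
            kept.append(header)
            kept.extend(seq_lines)
    return kept


# Made-up example: the second record contains the ambiguity code K.
example = [
    ">ok\n", "GUCCC-AUCG\n", "uagag\n",
    ">bad\n", "GGGGKUCGACU\n", "CCCAUAG\n",
]
result = filter_alignment(example)
print("\n".join(result))
```

To apply it to a file, pass an open file handle as `lines` and write the returned lines to the output file, one per line.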

Now fighting https://github.com/uw-ipd/RoseTTAFold2NA/issues/13 ;-)