zizheng-guo / RhythmMamba

RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR #4

Closed 408550969 closed 4 months ago

408550969 commented 4 months ago

After building the conda environment as required, training the model fails with the following error:

====Training Epoch: 0====
Train epoch 0:   0%|          | 0/25 [00:07<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 296, in <module>
    train_and_test(config, data_loader_dict)
  File "main.py", line 67, in train_and_test
    model_trainer.train(data_loader_dict)
  File "/home/chenlili/RhythmMamba-main/neural_methods/trainer/RhythmMambaTrainer.py", line 71, in train
    pred_ppg = self.model(data)
  File "/home/chenlili/anaconda3/envs/rppg-toolbox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenlili/anaconda3/envs/rppg-toolbox/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/chenlili/anaconda3/envs/rppg-toolbox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenlili/RhythmMamba-main/neural_methods/model/RhythmMamba.py", line 332, in forward
    x = blk(x)
  File "/home/chenlili/anaconda3/envs/rppg-toolbox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenlili/RhythmMamba-main/neural_methods/model/RhythmMamba.py", line 215, in forward
    x = x + self.drop_path(self.mlp(self.norm2(x)))
  File "/home/chenlili/anaconda3/envs/rppg-toolbox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chenlili/RhythmMamba-main/neural_methods/model/RhythmMamba.py", line 120, in forward
    x_fre = torch.fft.fft(x, dim=1, norm='ortho') # FFT on N dimension
RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR

zizheng-guo commented 4 months ago

The torch.fft library used in the code does not currently seem to support the RTX 4090. You can use a different GPU, or replace the torch.fft call with another spectral transform method.
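As one illustration of "another spectral transform method", here is a minimal sketch that computes the same orthonormal DFT as `torch.fft.fft(x, dim=1, norm='ortho')` with a plain matrix product, so cuFFT is never called. The function name `dft_ortho` is hypothetical and not part of the RhythmMamba code. It also assumes a real or complex input small enough that O(N²) work along the transformed dimension is acceptable, and it has not been tested on an RTX 4090.

```python
import math
import torch

def dft_ortho(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Orthonormal DFT along `dim` via matrix multiplication (hypothetical helper)."""
    n = x.shape[dim]
    idx = torch.arange(n, device=x.device)
    # Reduce j*k modulo n in integer arithmetic to keep the angles small and accurate.
    jk = (idx[:, None] * idx[None, :]) % n
    angle = -2.0 * math.pi * jk.to(torch.float32) / n
    # W[j, k] = exp(-2*pi*i*j*k/n) / sqrt(n), the 'ortho'-normalised DFT matrix.
    w = torch.polar(torch.ones_like(angle), angle) / math.sqrt(n)
    # Move the transformed dimension last, multiply, then move it back.
    xc = x.to(torch.complex64).movedim(dim, -1)
    return (xc @ w).movedim(-1, dim)

# Quick check on CPU against the built-in FFT.
x = torch.randn(2, 8, 3)
reference = torch.fft.fft(x, dim=1, norm="ortho")
print(torch.allclose(dft_ortho(x, dim=1), reference, atol=1e-4))
```

The result is cast to `complex64`, so a `float64` input loses precision. If the model later applies an inverse transform, that call would need a matching replacement.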

408550969 commented 4 months ago

Thanks!

408550969 commented 4 months ago

Switching CUDA to 11.8 solves the problem. For details, see this link: https://blog.csdn.net/weixin_44007713/article/details/136475398

Note that after installing the CUDA toolkit, you need to add it to your environment:

export CPATH=/usr/local/cuda-11.8/targets/x86_64-linux/include:$CPATH
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
export PATH=/usr/local/cuda-11.8/bin:$PATH
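If you would rather not change your shell profile, the same three variables can be set for a single training run from Python. This is only a sketch: the helper name `with_cuda_paths` is hypothetical, the default install root is taken from the paths above, and the arguments to `main.py` are left out because they depend on your config.

```python
import os

def with_cuda_paths(env, cuda_root="/usr/local/cuda-11.8"):
    """Return a copy of `env` with the CUDA 11.8 directories prepended (hypothetical helper)."""
    target = os.path.join(cuda_root, "targets", "x86_64-linux")
    prepend = {
        "CPATH": os.path.join(target, "include"),
        "LD_LIBRARY_PATH": os.path.join(target, "lib"),
        "PATH": os.path.join(cuda_root, "bin"),
    }
    new_env = dict(env)
    for name, path in prepend.items():
        old = new_env.get(name, "")
        # Avoid a trailing separator when the variable was previously unset.
        new_env[name] = path + (os.pathsep + old if old else "")
    return new_env

env = with_cuda_paths(os.environ)
print(env["PATH"].split(os.pathsep)[0])
# To launch training with these settings (arguments omitted on purpose):
# import subprocess, sys
# subprocess.run([sys.executable, "main.py"], env=env)
```

The variables have to be set before the training process starts, which is why the sketch passes them to a child process instead of modifying `os.environ` inside an already running script.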