mcpaulgeorge / WalMaFa

[ACCV 2024] Source code of WalMaFa
MIT License

GPU training problem #8

Open suxiaolin-collab opened 4 days ago

suxiaolin-collab commented 4 days ago

Hello, I ran into the following problem when running your training script. Is there a way to solve it?
/home/amax/anaconda3/bin/conda run -n WalMaFa --no-capture-output python /data1/WalMaFa/train.py
load training yaml file: ./configs/LOL/train/training_LOL.yaml
==> Build the model

Let's use 3 GPUs!

/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of lr_scheduler.step() before optimizer.step(). In PyTorch 1.1.0 and later, you should call them in the opposite order: optimizer.step() before lr_scheduler.step(). Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of lr_scheduler.step() before optimizer.step(). "
==> Loading datasets

==> Training details:

Restoration mode:   Walmafa_LOL_v1
Train patches size: 128x128
Val patches size:   128x128
Model parameters:   11.09M
Start/End epochs:   1~5000
Best_PSNR_Epoch:    0
Best_PSNR:          0.0000
Best_SSIM_Epoch:    0
Best_SSIM:          0.0000
Batch sizes:        12
Learning rate:      0.0008
GPU:                GPU[0, 1, 2]

==> Training start: 0%| | 0/41 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/data1/WalMaFa/train.py", line 191, in <module>
    restored = model_restored(input_)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1/WalMaFa/model/Walmafa.py", line 444, in forward
    out_enc_level1_0 = self.decoder_level1_0(inp_enc_level1_0)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1/WalMaFa/model/Walmafa.py", line 286, in forward
    input_high = self.mb(input_high)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1/WalMaFa/model/Walmafa.py", line 260, in forward
    y = self.model1(x).permute(0, 2, 1)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/mamba_ssm/modules/mamba_simple.py", line 189, in forward
    y = selective_scan_fn(
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 88, in selective_scan_fn
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/amax/anaconda3/envs/WalMaFa/lib/python3.8/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 42, in forward
    out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
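The last frame shows the failure inside `selective_scan_cuda.fwd`, the compiled CUDA extension shipped with `mamba_ssm`. "no kernel image is available for execution on the device" generally means the binary being launched contains neither machine code (SASS) nor PTX that this GPU can run, typically because it was compiled for other compute capabilities. The sketch below encodes that compatibility rule and applies it to the architectures the installed PyTorch reports through `torch.cuda.get_arch_list()`. `has_kernel_for` is a hypothetical helper, and since `mamba_ssm`'s extension was compiled separately from PyTorch, a `True` here only says PyTorch itself covers the card, not that `mamba_ssm` does.

```python
from typing import Iterable, Tuple


def has_kernel_for(capability: Tuple[int, int], arch_list: Iterable[str]) -> bool:
    """Hypothetical helper: can a device with `capability` run code built for `arch_list`?

    `arch_list` uses the same strings as torch.cuda.get_arch_list(), e.g. "sm_86" or "compute_80".
    """
    major, minor = capability
    for arch in arch_list:
        kind, _, digits = arch.partition("_")
        if len(digits) < 2 or not digits.isdigit():
            continue  # skip arch-specific variants such as sm_90a to keep the sketch simple
        a_major, a_minor = int(digits[:-1]), int(digits[-1])
        # SASS (sm_XY) runs only on the same major version with an equal or newer minor.
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        # PTX (compute_XY) can be JIT-compiled for any equal or newer device.
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False


if __name__ == "__main__":
    try:
        import torch  # assumes a CUDA-enabled PyTorch build is installed
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        archs = torch.cuda.get_arch_list()
        for i in range(torch.cuda.device_count()):
            cap = torch.cuda.get_device_capability(i)
            print(torch.cuda.get_device_name(i), cap, "covered by torch build:", has_kernel_for(cap, archs))
```

For the two cards mentioned in this thread, `torch.cuda.get_device_capability()` should report `(6, 0)` for a Tesla P100 and `(8, 6)` for an A10, so a build whose oldest target is, say, `sm_70` would run on the A10 but not on the P100.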

mcpaulgeorge commented 4 days ago

Hello, please try specifying a single GPU instead of running on multiple cards.
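A minimal sketch of pinning the run to one card from Python, assuming you want to do it independently of the `GPU` field in the YAML (how `train.py` consumes that field is not shown in this thread). The standard `CUDA_VISIBLE_DEVICES` environment variable hides every other GPU from the process, provided it is set before CUDA is initialized. `restrict_to_gpus` is a hypothetical helper name.

```python
import os
from typing import Sequence


def restrict_to_gpus(gpu_ids: Sequence[int]) -> str:
    """Expose only the given physical GPUs to this process (hypothetical helper).

    Must run before CUDA is initialized, e.g. at the very top of the training
    script before importing torch. The visible cards are then renumbered from 0,
    so with [2] the physical GPU 2 becomes cuda:0.
    """
    value = ",".join(str(i) for i in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


if __name__ == "__main__":
    restrict_to_gpus([0])  # same effect as: CUDA_VISIBLE_DEVICES=0 python train.py
    print("CUDA_VISIBLE_DEVICES =", os.environ["CUDA_VISIBLE_DEVICES"])
```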

suxiaolin-collab commented 4 days ago

# Training configuration
GPU: [0]

VERBOSE: False

MODEL:
  MODE: 'Walmafa_LOL_v1'

# Optimization arguments.
OPTIM:
  BATCH: 12
  EPOCHS: 5000
  # EPOCH_DECAY: [10]
  LR_INITIAL: 0.0008
  LR_MIN: 1e-6
  # BETA1: 0.9

TRAINING:
  VAL_AFTER_EVERY: 1
  RESUME: False
  TRAIN_PS: 128
  VAL_PS: 128
  TRAIN_DIR: '/data2/TrainingData/LOLdataset/our485/'  # path to training data
  VAL_DIR: '/data2/TrainingData/LOLdataset/eval15/'  # path to validation data
  SAVE_DIR: './checkpoints_walmafa'  # path to save models and images

After specifying a single GPU I still get the same error. Also, is the dataset used here LOL?
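As a quick sanity check on the numbers in the first log: with `BATCH: 12` and the 485 image pairs in LOL's `our485` training split, one epoch is ceil(485 / 12) = 41 iterations, which matches the `0/41` progress bar, assuming the data loader keeps the final partial batch (no `drop_last`). `iterations_per_epoch` is a hypothetical helper for illustration.

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a data loader yields per epoch (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size  # the trailing partial batch is discarded
    return math.ceil(num_samples / batch_size)  # the trailing partial batch is kept


if __name__ == "__main__":
    # 485 training pairs in our485, BATCH: 12
    print(iterations_per_epoch(485, 12))
    print(iterations_per_epoch(485, 12, drop_last=True))
```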

mcpaulgeorge commented 3 days ago

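Yes. The error you reported earlier is usually caused by a mismatch between the CUDA and torch versions, and it may also be related to the GPU driver. Our GPU is an Nvidia A10 (24G), our CUDA version is 12.4, and the torch installed for this project is the CUDA 11.7 build.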
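When comparing environments like this, it helps to separate two different "CUDA versions": the figure `nvidia-smi` prints is the newest CUDA version the installed driver supports (the 12.4 and 12.6 quoted in this thread are likely that figure), while `torch.version.cuda` is the CUDA version the torch wheel was built against (11.7 for the maintainer). A minimal sketch of an environment report, assuming `torch` and `mamba_ssm` were installed with pip so their versions are visible to `importlib.metadata`; `collect_env` and `package_version` are hypothetical helper names.

```python
import platform
from importlib import metadata
from typing import Dict, Optional


def package_version(name: str) -> Optional[str]:
    """Installed distribution version, or None if the package is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_env() -> Dict[str, Optional[str]]:
    report: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "torch": package_version("torch"),
        "mamba_ssm": package_version("mamba_ssm"),
        "torch_cuda": None,  # CUDA version torch was built with, not the driver's
        "gpu": None,
    }
    try:
        import torch
    except ImportError:
        return report
    report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        report["gpu"] = f"{torch.cuda.get_device_name(0)} (sm_{major}{minor})"
    return report


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key:10s} {value}")
```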

suxiaolin-collab commented 3 days ago

The device I'm using is a P100, with CUDA 12.6. Could this device affect whether the project runs?

suxiaolin-collab commented 3 days ago

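Hello, thank you very much for your answer. After changing the torch version and then switching mamba_ssm to version 1.2.0.post1, the problem was solved. Thanks for your help.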

suxiaolin-collab commented 3 days ago

During training, the loss becomes NaN after the fourth epoch. Is this normal?