zhengchen1999 / DAT

PyTorch code for our ICCV 2023 paper "Dual Aggregation Transformer for Image Super-Resolution"
Apache License 2.0
350 stars, 27 forks

Got KeyError: 'RANK' when training on a single GPU #31

Open · imaklex5 opened this issue 4 months ago

imaklex5 commented 4 months ago

Hi!

I followed the earlier reply to train the DAT-light model on a single 3090 GPU, but got KeyError: 'RANK' as follows. How can I fix this?

root@autodl-container-2bae408949-4f297d06:~/DAT# python basicsr/train.py -opt options/Train/train_DAT_light_x2.yml --launcher pytorch
Traceback (most recent call last):
  File "basicsr/train.py", line 215, in <module>
    train_pipeline(root_path)
  File "basicsr/train.py", line 93, in train_pipeline
    opt, args = parse_options(root_path, is_train=True)
  File "/root/DAT/basicsr/utils/options.py", line 106, in parse_options
    init_dist(args.launcher)
  File "/root/DAT/basicsr/utils/dist_util.py", line 14, in init_dist
    _init_dist_pytorch(backend, **kwargs)
  File "/root/DAT/basicsr/utils/dist_util.py", line 22, in _init_dist_pytorch
    rank = int(os.environ['RANK'])
  File "/root/miniconda3/lib/python3.8/os.py", line 675, in __getitem__
    raise KeyError(key) from None
KeyError: 'RANK'
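
For context, the failing line is an ordinary lookup on os.environ. With --launcher pytorch, the script assumes a distributed launcher (such as python -m torch.distributed.launch) has already exported RANK; when started as a plain python process, the variable is missing and the lookup raises. A minimal sketch reproducing this, where read_rank is a hypothetical stand-in for the first line of _init_dist_pytorch:

```python
import os


def read_rank() -> int:
    # Hypothetical stand-in for the first line of _init_dist_pytorch:
    # a bare lookup that raises KeyError when RANK was never exported.
    return int(os.environ['RANK'])


def read_rank_or_none():
    # Returns None instead of raising, so the two cases can be compared.
    try:
        return read_rank()
    except KeyError:
        return None


if __name__ == '__main__':
    os.environ.pop('RANK', None)   # simulate a plain "python basicsr/train.py ..." run
    print(read_rank_or_none())     # None: RANK is missing
    os.environ['RANK'] = '0'       # what a distributed launcher would export
    print(read_rank_or_none())     # 0
```
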

mychdream commented 4 months ago

Hello, I also encountered this problem. Have you managed to solve it?

zhengchen1999 commented 4 months ago

Set num_gpu to 1 in the yml file.

Then launch training with the following command, which drops the --launcher pytorch flag: python basicsr/train.py -opt options/Train/train_DAT_light_x2.yml
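
Why this helps: DAT's training code is based on BasicSR (note the basicsr/ paths in the traceback), and in BasicSR-style scripts the --launcher option defaults to none, so init_dist is only called when a distributed launcher is requested. The simplified sketch below assumes that behaviour; parse_launch_options is a hypothetical helper, not DAT's actual parse_options.

```python
import argparse


def parse_launch_options(argv):
    # Simplified, hypothetical version of the launcher handling in a
    # BasicSR-style parse_options(); the real function does much more
    # (loading the yml, setting seeds, creating experiment folders, ...).
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='path to the option yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    args = parser.parse_args(argv)

    # Only a distributed launcher would trigger init_dist(); with 'none'
    # nothing reads os.environ['RANK'], so no KeyError can occur.
    distributed = args.launcher != 'none'
    return args.launcher, distributed


if __name__ == '__main__':
    print(parse_launch_options(['-opt', 'options/Train/train_DAT_light_x2.yml']))
    # ('none', False)
    print(parse_launch_options(['-opt', 'options/Train/train_DAT_light_x2.yml',
                                '--launcher', 'pytorch']))
    # ('pytorch', True)
```
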

mychdream commented 4 months ago

Thanks for the author's quick reply, but now I have a new problem: [screenshot: Snipaste_2024-03-07_19-45-21]. I am running on a single 2080Ti and have set both batch_size and num_worker to 1.

zhengchen1999 commented 4 months ago

The CUDA and PyTorch versions may not match. Rebuild a fresh Python environment with the following commands:

conda create -n DAT python=3.8
conda activate DAT
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
# first delete the torch==1.8.0 and torchvision lines from requirements.txt
pip install -r requirements.txt
python setup.py develop
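
After rebuilding, a quick sanity check can confirm which build is active. The +cu111 suffix in the pinned versions marks wheels built against CUDA 11.1. The sketch below (cuda_tag is a hypothetical helper) extracts that tag from a version string and, if PyTorch is importable, prints the installed version, the CUDA version it was built with, and whether a GPU is visible.

```python
def cuda_tag(version: str):
    """Return the local build tag of a PyTorch version string, or None.

    '1.8.0+cu111' -> 'cu111'; '1.8.0' (no local tag) -> None.
    """
    _, sep, local = version.partition('+')
    return local if sep else None


if __name__ == '__main__':
    try:
        import torch
    except ImportError:
        print('PyTorch is not installed in this environment')
    else:
        print('torch version :', torch.__version__)
        print('build tag     :', cuda_tag(torch.__version__))  # expected cu111 after the reinstall
        print('built for CUDA:', torch.version.cuda)           # None for CPU-only builds
        print('GPU visible   :', torch.cuda.is_available())
```
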

imaklex5 commented 4 months ago

I just added these lines to basicsr.utils.dist_util._init_dist_pytorch and used the command the author mentioned before; then it works.

# at the start of _init_dist_pytorch, before rank = int(os.environ['RANK'])
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '1234'
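
A slightly safer variant of the same workaround, sketched as a hypothetical helper: os.environ.setdefault only fills in variables that are missing, so a real multi-GPU run started through a distributed launcher keeps the values that launcher exported. RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are the variables read by torch.distributed's default env:// initialization; the values mirror the snippet above (rank 0 in a world of size 1), and port 1234 is arbitrary as long as it is free.

```python
import os

# Defaults for a single-process "distributed" run: one process, rank 0,
# rendezvous on localhost. Port 1234 is arbitrary; any free port works.
SINGLE_PROCESS_DEFAULTS = {
    'RANK': '0',
    'WORLD_SIZE': '1',
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '1234',
}


def ensure_single_process_env(environ=os.environ):
    """Fill in missing env:// variables without overwriting values that a
    launcher has already set (hypothetical helper, not part of DAT)."""
    for key, value in SINGLE_PROCESS_DEFAULTS.items():
        environ.setdefault(key, value)
    return environ
```

It would be called at the top of _init_dist_pytorch, before the RANK lookup.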

imaklex5 commented 4 months ago

Thanks very much ^_^