open-mmlab / mmflow

OpenMMLab optical flow toolbox and benchmark
https://mmflow.readthedocs.io/en/latest/
Apache License 2.0

Runtime error in gradient computation when running train.py #248

Closed Salvatore-tech closed 2 years ago

Salvatore-tech commented 2 years ago

Describe the bug Good evening, I wanted to launch train.py using a default config file for RAFT on a standard dataset (KITTI_2015). I followed the instructions to install MMFlow from source successfully.

Reproduction

python tools/train.py configs/raft/raft_8x2_50k_kitti2015_and_Aug_288x960.py \
--load-from /home/s.starace/FlowNets/mmflow/checkpoints/raft/raft_8x2_100k_mixed_368x768.pth
  1. Did you make any modifications on the code or config? Did you understand what you have modified? I only changed the name of the symlink that I created under /data (uppercase).

  2. What dataset did you use? KITTI_2015

Environment I launched the command on my PC and also on a small cluster, and the error output is the same.
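For completeness, the environment on each machine can be printed programmatically. This is only a sketch using mmcv's helper (mmcv is a hard dependency of MMFlow), not the full bug-report template:

from mmcv.utils import collect_env

# Print the environment table (PyTorch, CUDA, GCC, mmcv version, ...)
for name, value in collect_env().items():
    print(f'{name}: {value}')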

Error traceback See log attached: slurm-53090.out.txt

Bug fix Not sure about it; it could either be a configuration issue in the encoder/decoder or a regression. I'll try train.py with other models as well and update the report if I understand the problem better.

Salvatore-tech commented 2 years ago

Training models other than RAFT (PWC-Net, LiteFlowNet, etc.) with the same script and the same dataset does NOT show the issue reported above.
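For anyone hitting this, a generic way to localize this kind of failure (assuming the traceback is the usual "modified by an inplace operation" autograd error) is PyTorch's anomaly detection, which makes the backward-pass RuntimeError also print a traceback of the forward operation that produced the offending tensor. The toy snippet below only illustrates the failure class, it is not MMFlow code:

import torch

torch.autograd.set_detect_anomaly(True)  # report the forward op behind a backward error

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)     # exp saves its output for the backward pass
y.add_(1.0)          # in-place edit invalidates the saved tensor
y.sum().backward()   # RuntimeError: ... modified by an inplace operation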

wz940216 commented 2 years ago

The same error is also reported when training GMA.

Salvatore-tech commented 2 years ago

Sorry @wz940216, but I did not get your answer. Do you need additional details about the issue?

wz940216 commented 2 years ago

@Salvatore-tech I'm not a developer, but I had the same problem as you when training GMA and RAFT while using mmflow.

Salvatore-tech commented 2 years ago

@wz940216 that's interesting. I hope the maintainers of this repository can give us a clue (I'd like to use RAFT in my use case because it should give better performance). @MeowZheng @Zachary-66

Zachary-66 commented 2 years ago

I can reproduce the same bug using PyTorch 1.12.1. Below is my log: raft_kitti_bug.txt

However, when I use PyTorch 1.8.0, this bug no longer appears. Below is the normal log: raft_kitti_normal.txt

I notice you are using PyTorch 1.12.1, which is the latest version and may have some unexpected incompatibilities. To save you time, I recommend following the official command below to install PyTorch 1.8.0 instead:

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
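After reinstalling, a quick sanity check (just a suggestion) confirms that the downgraded build is the one actually picked up and that CUDA is still usable:

import torch

print(torch.__version__)          # expected: 1.8.0
print(torch.version.cuda)         # expected: 10.2 for the command above
print(torch.cuda.is_available())  # expected: True on a GPU machine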

We will modify MMFlow as soon as possible to avoid similar bugs. Thanks a lot!

Salvatore-tech commented 2 years ago

Thanks @Zachary-66, your fix did the job. I had not noticed that operator, otherwise I would have opened a pull request. I'm closing the issue.
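For future readers: errors of this kind are usually resolved by switching an in-place operation to its out-of-place form. The snippet below is only an illustration of that pattern, not the actual MMFlow patch:

import torch
import torch.nn.functional as F

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)                    # exp saves its output for the backward pass

# Broken: an in-place ReLU overwrites the tensor exp saved for backward,
# so backward() raises the "modified by an inplace operation" error.
# z = F.relu(y, inplace=True)

z = F.relu(y)                       # out-of-place form keeps the saved tensor intact
z.sum().backward()                  # runs cleanly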