valeoai / WaffleIron


Loading state_dict of a model that uses 1 GPU does not work #10

Closed SpeedyGonzales949 closed 3 weeks ago

SpeedyGonzales949 commented 1 month ago

I have trained a model for 1 epoch on the NuScenes dataset on a single GPU, using the flag "--gpu 0".

When this flag is given, the model is not wrapped in DataParallel, because of this if-statement: https://github.com/valeoai/WaffleIron/blob/40f097ff203dd561ce6f9f642dcdb0b76d3ca05a/launch_train.py#L197-L200

So the keys in the state_dict of my trained model will not contain the "module." prefix.
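This prefix mismatch can be reproduced in isolation. A minimal sketch (using a toy nn.Linear rather than WaffleIron's actual model) showing how wrapping a module in DataParallel renames its state_dict keys:

```python
import torch.nn as nn

# A plain module: its state_dict keys have no prefix.
plain = nn.Linear(4, 2)
print(list(plain.state_dict().keys()))       # ["weight", "bias"]

# Wrapping in DataParallel registers the module under the attribute
# "module", so every key gains a "module." prefix.
wrapped = nn.DataParallel(plain)
print(list(wrapped.state_dict().keys()))     # ["module.weight", "module.bias"]
```

Calling `wrapped.load_state_dict(plain.state_dict())` therefore fails with a missing/unexpected keys error, which is exactly the crash described below.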

Then I run this command for validation:

python launch_train.py \
--dataset nuscenes \
--path_dataset /path/to/nuscenes/ \
--log_path ./logs/WaffleIron-48-384__nuscenes/ \
--config ./configs/WaffleIron-48-384__nuscenes.yaml \
--fp16 \
--restart \
--eval

Because in this validation command there is no "--gpu 0" argument, the program will end up in this else-statement: https://github.com/valeoai/WaffleIron/blob/40f097ff203dd561ce6f9f642dcdb0b76d3ca05a/launch_train.py#L201-L203

As you can see, the model is now wrapped in DataParallel, which means it expects the parameter names to contain the "module." prefix.

Because my checkpoint was saved without DataParallel, loading it will break and throw an exception that the state_dict keys do not match.

I would suggest adding to launch_train.py a variation of the lines that you added in eval_nuscenes.py: https://github.com/valeoai/WaffleIron/blob/40f097ff203dd561ce6f9f642dcdb0b76d3ca05a/eval_nuscenes.py#L95-L102
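For reference, such a variation could look like the sketch below. The helper name `match_module_prefix` and its exact logic are my assumption of what the linked eval_nuscenes.py lines do (adding or stripping the "module." prefix before loading), not the repository's actual code:

```python
import torch.nn as nn

def match_module_prefix(ckpt_state_dict, model):
    # Hypothetical helper: align the "module." prefix between a checkpoint's
    # state_dict and the target model before calling load_state_dict.
    model_wrapped = next(iter(model.state_dict())).startswith("module.")
    ckpt_wrapped = next(iter(ckpt_state_dict)).startswith("module.")
    if model_wrapped and not ckpt_wrapped:
        # Checkpoint saved without DataParallel, model is wrapped: add prefix.
        return {"module." + k: v for k, v in ckpt_state_dict.items()}
    if ckpt_wrapped and not model_wrapped:
        # Checkpoint saved with DataParallel, model is plain: strip prefix.
        return {k[len("module."):]: v for k, v in ckpt_state_dict.items()}
    return ckpt_state_dict

# Example: load a single-GPU checkpoint into a DataParallel-wrapped model.
plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(nn.Linear(4, 2))
wrapped.load_state_dict(match_module_prefix(plain.state_dict(), wrapped))
```

Handling both directions would also cover the opposite case, where a multi-GPU checkpoint is evaluated with "--gpu 0".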

If you think my suggested change is meaningful and helpful, I could also make a pull request for that.

SpeedyGonzales949 commented 1 month ago

An easier solution would also be to add "--gpu 0" to the validation command.

gpuy commented 3 weeks ago

Thank you for reporting this and for the update in the README file.

Indeed, if the model is trained on one GPU with --gpu 0, then --gpu 0 must also be included when restarting from the checkpoint, both when resuming a stopped training and when evaluating on the val set via launch_train.py.