Closed: josh0tt closed this issue 1 year ago
I'm not able to confirm this, but it's very likely a pytorch lightning versioning issue. Try v1.6. I had to upgrade several times over the course of the project. I thought the pin was in the requirements.txt, but that's apparently missing from the public version. There is one more code update coming with the final version of this paper over the next few weeks, but I don't expect to future-proof this repo against pytorch lightning, because there are just too many breaking changes. This probably applies to #64 as well. These dependencies change fast, and you need to use a late-2021 or early-2022 version of everything.
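Since the public repo is missing the requirements.txt, something like the following pin is a reasonable starting point (the exact patch version is a guess on my part, not from the repo):

```shell
# Pin pytorch-lightning to the 1.6 series mentioned above.
# You may also need to pin torch to a late-2021/early-2022 release to match.
pip install "pytorch-lightning==1.6.*"
```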
Edit: if you have to use a newer version of pytorch lightning, this could be a ddp vs. dp problem. ddp launches multiple processes and pickles everything passed between them. Because of the efficient attention layers, I only ever trained models on a single multi-gpu node, so the training scripts are not tested with distributed training. If you are training on one node, make sure you're using the slower but simpler data parallel mode.
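For reference, here's a hypothetical sketch of forcing single-node data parallel in the pytorch lightning 1.6 Trainer API (flag names changed across versions, and the device count is just an example, so adjust to your setup):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,         # all GPUs on the single node
    strategy="dp",     # DataParallel: one process, no inter-process pickling
    # strategy="ddp",  # avoid here: spawns one process per GPU and pickles state
)
```

On older lightning versions this was spelled differently (e.g. `accelerator="dp"` / `gpus=4`), which is one more reason to stick with the pinned version if you can.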