eloialonso / iris

Transformers are Sample-Efficient World Models. ICLR 2023, notable top 5%.
https://openreview.net/forum?id=vhFu1Acb0xb
GNU General Public License v3.0

Training on a multi-GPU system #9

lyp741 closed this issue 1 year ago.

lyp741 commented 1 year ago

Thanks for your work! Can you tell me how to train the algorithm on a multi-GPU system? I have skimmed the code, but I didn't find anything for distributed training. I would also like to know how long training takes on a desktop PC with a single 3090. Thank you for your help!

vmicheli commented 1 year ago

Thanks! It should be straightforward to adapt the code for data parallelism (https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html).
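A minimal sketch of the pattern from the linked tutorial, assuming a generic PyTorch module. `ToyModel`, `build_model` and `train_step` are hypothetical names for illustration, not IRIS code. `nn.DataParallel` only splits the work done inside the wrapped module's `forward()`, so any computation you want parallelized has to go through that call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    """Hypothetical stand-in for a model being trained (not part of IRIS)."""

    def __init__(self, in_dim: int = 16, out_dim: int = 4) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def build_model() -> nn.Module:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ToyModel()
    if torch.cuda.device_count() > 1:
        # Replicates the module on every visible GPU and splits each
        # input batch along dim 0; outputs are gathered on the first GPU.
        model = nn.DataParallel(model)
    return model.to(device)


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               x: torch.Tensor, y: torch.Tensor) -> float:
    device = next(model.parameters()).device
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()  # gradients from all replicas are reduced onto the original parameters
    optimizer.step()
    return loss.item()
```

Usage: `model = build_model()`, `optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)`, then call `train_step(model, optimizer, x, y)` on each batch. On a machine with one GPU or none, the model is simply left unwrapped. Two things to keep in mind: a wrapped model's `state_dict()` keys gain a `module.` prefix, and the PyTorch documentation recommends `DistributedDataParallel` over `DataParallel` for multi-GPU training, even on a single node.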

Training IRIS with the default configuration takes roughly 4 days on a 3090 GPU. If you encounter memory errors, you can reduce the memory requirements by decreasing the batch size (`batch_num_samples`) and increasing the number of gradient accumulation steps (`grad_acc_steps`) in `config/trainer.yaml`.
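To see why this trade works, here is a minimal sketch of standard gradient accumulation. It assumes a mean-reduced loss and equal-sized micro-batches; the function `accumulated_step` is hypothetical and only borrows the two config names as parameter names, so it is not the IRIS trainer. Each optimizer step sees `batch_num_samples * grad_acc_steps` samples, so halving `batch_num_samples` and doubling `grad_acc_steps` keeps the effective batch size while lowering peak memory.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accumulated_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                     x: torch.Tensor, y: torch.Tensor,
                     batch_num_samples: int, grad_acc_steps: int) -> float:
    """One optimizer step spread over grad_acc_steps micro-batches (hypothetical helper)."""
    assert x.shape[0] == batch_num_samples * grad_acc_steps
    optimizer.zero_grad()
    total = 0.0
    for i in range(grad_acc_steps):
        sl = slice(i * batch_num_samples, (i + 1) * batch_num_samples)
        loss = F.mse_loss(model(x[sl]), y[sl])
        # Dividing by grad_acc_steps makes the summed gradients equal the
        # gradient of the mean loss over the whole effective batch.
        (loss / grad_acc_steps).backward()
        total += loss.item()
    optimizer.step()  # a single update after all micro-batches
    return total / grad_acc_steps
```

Only one micro-batch of activations is alive at a time, which is where the memory saving comes from. The trade-off is wall-clock time, since the micro-batches run sequentially. Layers that depend on batch statistics, such as batch normalization, will not behave exactly like a single large batch.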

Hope that helps!