nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

About Training #4

Closed: VitaLemonTea1 closed this issue 5 months ago

VitaLemonTea1 commented 7 months ago

Hi, thanks for releasing your code. I see that you and your team trained the model on RTX 3090 GPUs and reported training times for a single GPU. I would like to ask whether your model can be trained on multiple GPUs, and if so, how long training takes. Thanks!

nicklashansen commented 7 months ago

Hi @VitaLemonTea1, the code currently only supports single-GPU training. We plan to release multi-GPU training code at some point, but I don't have a date for that yet. To add support for it yourself, you would more or less need to follow the steps outlined in the PyTorch DDP tutorial: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html. I'm happy to work with you on getting it up and running.
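The core idea behind the DDP approach in that tutorial is synchronous data parallelism: each GPU runs a replica of the model on its own shard of the batch, the per-replica gradients are averaged with an all-reduce, and every replica applies the same update. The toy sketch below simulates that step with the Python standard library only. It is not TD-MPC2 code and not the PyTorch API; the 1-D linear model and the helper names (grad_mse, all_reduce_mean, data_parallel_step) are hypothetical, chosen only to show that, with equal-sized shards, averaging shard gradients reproduces the full-batch gradient.

```python
# Toy, stdlib-only simulation of synchronous data-parallel training (the idea
# behind PyTorch DDP). Hypothetical helpers, not TD-MPC2 or PyTorch code.

def grad_mse(w, b, xs, ys):
    """Gradient of mean squared error for y_hat = w * x + b over one shard."""
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db

def all_reduce_mean(grads):
    """Average per-worker gradient tuples, as DDP's all-reduce does."""
    k = len(grads)
    return tuple(sum(g[i] for g in grads) / k for i in range(len(grads[0])))

def data_parallel_step(w, b, xs, ys, world_size, lr=0.01):
    """One SGD step where the batch is split across `world_size` workers."""
    # Assumption: the batch divides evenly, so every shard has the same size.
    # Only then does the mean of shard gradients equal the full-batch gradient.
    if len(xs) % world_size != 0:
        raise ValueError("batch size must be divisible by world_size")
    shard = len(xs) // world_size
    grads = [
        grad_mse(w, b, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    dw, db = all_reduce_mean(grads)  # every worker receives the same averaged gradient
    return w - lr * dw, b - lr * db  # so every replica stays in sync

xs = [float(i) for i in range(8)]
ys = [3 * x + 1 for x in xs]
print(data_parallel_step(0.0, 0.0, xs, ys, world_size=1))
print(data_parallel_step(0.0, 0.0, xs, ys, world_size=4))
```

Both calls print (numerically) the same updated parameters. Each worker only processes batch_size / world_size samples per step, which is where the wall-clock savings come from; the all-reduce adds communication cost, so the real speed-up depends on the hardware and model size.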

Best, Nick

nicklashansen commented 6 months ago

@VitaLemonTea1 Update: I now have a working distributed implementation, which will be integrated into our public repo as a separate branch (https://github.com/nicklashansen/tdmpc2/tree/distributed) over the next week or so.

nicklashansen commented 5 months ago

@VitaLemonTea1 Experimental support for distributed training is now available on the distributed branch; use the argument world_size=N to train on N GPUs. Feedback is appreciated!
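This thread does not say how the distributed branch divides work among the N processes started for world_size=N. As one common reference point, PyTorch's torch.utils.data.DistributedSampler (with shuffle=False and drop_last=False) pads the index list by wrapping around to the start until its length is divisible by the number of replicas, then gives each rank every world_size-th index. The stdlib sketch below mimics that rule; the function name shard_indices is hypothetical, and whether TD-MPC2 shards its data this way is an assumption, not something stated here.

```python
import math

def shard_indices(num_samples, world_size, rank):
    """Indices one rank would see, mimicking DistributedSampler's partitioning
    with shuffle=False and drop_last=False (hypothetical helper)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Round the total up so every rank gets the same number of samples.
    total = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    padding = total - num_samples
    # Pad by repeating indices from the start of the list.
    indices += (indices * math.ceil(padding / max(num_samples, 1)))[:padding]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]

for rank in range(4):
    print(rank, shard_indices(10, world_size=4, rank=rank))
```

With 10 samples and world_size=4, the total is padded to 12, so ranks 2 and 3 each receive one repeated index (0 and 1). Equal shard sizes keep every rank doing the same amount of work per step, which matters for synchronous all-reduce training.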