nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Possible training speedups #18

Open josephrocca opened 9 months ago

josephrocca commented 9 months ago

In the last few days I've been playing around trying to see how fast I can get a 19M model training on a single 4090. My somewhat arbitrary goal is 1 hour, down from about 24 hours (just on humanoid-walk for now). I was pretty surprised at how well simple torch.mean weight-merging worked (y-axis is reward):

image

The yellow line is the "control" run - i.e. no sharing/merging, and the bunched lines are 'workers' (separate Python processes) which share weights via the filesystem every so often. It seems like 4 workers with merging every 2.5k steps works quite well. The above chart shows 8 workers, but it only very marginally improves on 4 workers in my tests.

(Side note: Kind of counter-intuitive how 4 workers seem to give >4x speedup, but increasing beyond that barely improves speed. Might need to run some more tests to confirm - I've only tried 8 workers once, whereas I've tried 4 workers in several experiments.)

In case anyone is interested, the main code for this (in online_trainer.py while loop) is just:

```py
# be sure to change the seed for each worker
if (
    cfg.do_multitrain_model_merging
    and self._step >= num_seed_steps
    and self._step > 0
    and self._step % cfg.model_merge_freq == 0
    and cfg.model_merge_index != 0  # zeroth model is the 'control' model
):
    model_path = f'/mnt/tdmpc2/tdmpc2/temp_worker_models/{cfg.exp_name}-{cfg.model_merge_index}.pt'
    lock_path = f'{model_path}.lock'
    with FileLock(lock_path, timeout=200):  # prevent others from reading the model file while we are writing to it
        self.agent.save(model_path)

    agents = []
    for filename in os.listdir('/mnt/tdmpc2/tdmpc2/temp_worker_models/'):
        # check regex match for filename `{cfg.exp_name}-[0-9]+.pt`
        regex_match = f'^{cfg.exp_name}-[0-9]+\\.pt$'
        if re.match(regex_match, filename):
            file_path = os.path.join('/mnt/tdmpc2/tdmpc2/temp_worker_models/', filename)
            lock_path = f'{file_path}.lock'
            # wait to acquire lock before reading the model file
            with FileLock(lock_path, timeout=200):
                print("Loading model:", filename)
                a = TDMPC2(cfg)
                a.load(file_path)
                agents.append(a)

    # merge all the models into the current agent's model
    model_names = ["_encoder", "_dynamics", "_reward", "_pi", "_Qs", "_target_Qs"]
    for model_name in model_names:
        print("merging model_name:", model_name)
        models = [getattr(a.model, model_name) for a in agents]
        for model in models:
            model.eval()
        with torch.no_grad():
            for name, param in getattr(self.agent.model, model_name).named_parameters():
                params_to_merge = [model.state_dict()[name] for model in models]
                mean_param = torch.mean(torch.stack(params_to_merge, dim=0), dim=0)
                param.data.copy_(mean_param)
```

I'm quite happy with a ~5x speedup for such a small change to the code, and am now on the hunt for another 5x to get it under an hour.

I'm wondering if Nick or anyone else here has any thoughts on some more low-hanging fruit? I thought perhaps that a generalized reward component that effectively aims to teach the agent "inverse kinematics" in the very early stages might be helpful (basically giving it desired global coordinates for random body parts), and I've started experimenting with ways to do that, but no luck so far.

Any loose thoughts/tips/ideas would be useful - I'm a newbie and haven't yet got great intuitions for trimming the experiment search space.

nicklashansen commented 9 months ago

This looks really interesting! I'll have to think about this a bit more, but just to clarify: does each worker only perform model updates, or are they entirely separate training processes? I.e., does each worker collect its own data? I am planning to release an implementation soon that trades sample efficiency for better wall-clock time using environment parallelization (stepping in multiple copies of the environment in parallel). I imagine that env parallelization and training parallelization, as you seem to be proposing here, could be used simultaneously.
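
For readers unfamiliar with the idea, here is a toy sketch of stepping several environment copies in lockstep. `ToyEnv` and `VectorEnv` are hypothetical names made up for this illustration, not classes from this repo, and the copies are stepped sequentially here, whereas a real implementation would likely run them in parallel processes.

```python
import random


class ToyEnv:
    """Hypothetical 1-D environment: the action shifts the state, reward is -|state|."""

    def __init__(self, seed):
        self.rng = random.Random(seed)
        self.state = 0.0

    def reset(self):
        self.state = self.rng.uniform(-1.0, 1.0)
        return self.state

    def step(self, action):
        self.state += action
        return self.state, -abs(self.state)


class VectorEnv:
    """Steps several independent env copies in lockstep (sequentially in this sketch)."""

    def __init__(self, num_envs, base_seed=0):
        # each copy gets its own seed so the collected data differs
        self.envs = [ToyEnv(base_seed + i) for i in range(num_envs)]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs = [o for o, _ in results]
        rewards = [r for _, r in results]
        return obs, rewards


venv = VectorEnv(num_envs=4)
obs = venv.reset()
transitions = 0
for _ in range(10):
    actions = [-0.5 * o for o in obs]  # trivial policy: move halfway toward zero
    obs, rewards = venv.step(actions)
    transitions += len(rewards)  # one vectorized step yields num_envs transitions
```

Each vectorized step yields one transition per copy, so the same number of loop iterations collects `num_envs` times as much data per policy call.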

josephrocca commented 9 months ago

Yep, entirely separate training processes with different seeds. Their only interaction was via loading one another's weights from the filesystem and merging (via a simple torch.mean) with their own.
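
To spell out what the merge computes, here is a framework-free sketch of a parameter-wise mean across workers. It is only an illustration: `merge_state_dicts` is a hypothetical helper, and plain lists of floats stand in for the tensors that the real code combines with `torch.stack` and `torch.mean`.

```python
from statistics import fmean


def merge_state_dicts(state_dicts):
    """Return the parameter-wise mean of several {name: [floats]} dicts.

    Assumes every worker has the same parameter names and shapes.
    """
    merged = {}
    for name in state_dicts[0]:
        # group the i-th element of this parameter across all workers
        columns = zip(*(sd[name] for sd in state_dicts))
        merged[name] = [fmean(col) for col in columns]
    return merged


# three hypothetical workers, each holding the same two (flattened) parameters
workers = [
    {"_pi.weight": [1.0, 2.0], "_pi.bias": [0.0]},
    {"_pi.weight": [3.0, 4.0], "_pi.bias": [1.0]},
    {"_pi.weight": [5.0, 6.0], "_pi.bias": [2.0]},
]
merged = merge_state_dicts(workers)
```

Every worker then overwrites its own parameters with `merged`, so all workers continue from the same averaged weights until the next merge.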

It's hacky, but just for reference I used hydra --multirun like this:

pip install hydra-joblib-launcher --upgrade
# 1. change `submitit_local` to `joblib` in config.yaml
# 2. add something like `if cfg.do_multitrain_model_merging: seed += int(cfg.model_merge_index)` in train.py
# 3. add code block from my previous comment within while loop in online_trainer.py
rm /mnt/tdmpc2/tdmpc2/temp_worker_models/* # ensure prev worker models are deleted
cd /mnt/tdmpc2/tdmpc2 && python train.py --multirun do_multitrain_model_merging=true model_merge_index=0,1,2,3,4 exp_name=merge1

I did also try a version where they all collect (obs, action, reward)s and share those with each other (just via .pt of List of TensorDicts on filesystem), but it didn't perform quite as well (per step and per wall-clock time). Basically my guess (RE wall-clock at least) is that merging the model weights allows them to share the compute they did in both agent.act and agent.update calls (which each take around half the training step IIRC - env step is negligible), whereas sharing (obs, action, reward) with one another means that they only share the agent.act compute that they did.

> stepping in multiple copies of the environment in parallel

Very keen to try that out!

If you have any other ~orthogonal ideas that you think it might be worth me trying and reporting back here, please do let me know. I'd love to get humanoid-difficulty-level training below an hour on a 4090 - I'd basically be able to watch it learning in realtime!

Besides the "universal auto-curricula" (motor skills / inverse-kinematics during early learning, with motor task specified in dedicated obs values) thing, I was also pondering training a model on a large number of different morphologies/tasks such that it (hopefully) learns some internal abstractions that let it quickly pick up new morphologies/tasks. Akin to a "base model" in LLM land in the sense that it's not really good at anything as-is, but can be very quickly/easily fine-tuned to a task. I haven't played with multitask stuff at all yet - I should probably test that out as the first step in that direction.

nicklashansen commented 9 months ago

This all makes a lot of sense @josephrocca! I'd be very interested to see where this leads to :-)

> Very keen to try that out!

I now have it up and running internally but still need to run some performance benchmarks. I will let you know as soon as it is ~ready for external testing!

Edit: since you're optimizing for wall-time, it might be worth playing around with the planning as well. You could either try to reduce the # of planning iterations per step, and/or execute several planned actions at a time without replanning. I would expect data-efficiency to take a hit (since actions are more suboptimal) but that might be tolerable in your setting. I would also expect 5M and 19M parameter models to perform fairly similarly on these single-task problems, so using the 5M model should decrease computational cost a bit as well.
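
A rough sketch of the second idea (executing several planned actions before replanning), just to show the control flow. `plan` here is a hypothetical stand-in for the real planner, and the environment is a trivial deterministic one, so this is an assumption-laden toy rather than how the repo is structured.

```python
from collections import deque


def plan(obs, horizon):
    """Hypothetical stand-in planner: returns `horizon` actions that each halve obs."""
    actions = []
    for _ in range(horizon):
        a = -0.5 * obs
        actions.append(a)
        obs = obs + a  # the planner's own prediction of the next observation
    return actions


def run_episode(num_steps, horizon=3, execute_per_plan=1, obs=1.0):
    """Run a toy episode, replanning only when the action queue is empty."""
    queue = deque()
    planner_calls = 0
    for _ in range(num_steps):
        if not queue:
            # keep only the first `execute_per_plan` actions of the new plan
            queue.extend(plan(obs, horizon)[:execute_per_plan])
            planner_calls += 1
        obs = obs + queue.popleft()  # toy env: the action shifts the observation
    return obs, planner_calls


_, calls_replan_every_step = run_episode(12, execute_per_plan=1)
_, calls_open_loop = run_episode(12, execute_per_plan=3)
```

With `execute_per_plan=3` the planner runs a third as often over the same 12 steps. In a real environment the later actions of each plan are executed open-loop, which is where the expected hit to data-efficiency comes from.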

nicklashansen commented 9 months ago

@josephrocca I have pushed my implementation to branch vectorized_env. I verified that it solves easy tasks from DMControl but have yet to try harder tasks and figure out what the right balance of envs and model updates is; num_envs=4 steps_per_update=2 seems to be a good starting point.

nicklashansen commented 9 months ago

Benchmark results on task=walker-walk. Orange is default; green is vectorized (num_envs = steps_per_update = 4). Approx. 8 min vs. 24 min. to converge on a 2080-Ti.

[Screenshot: benchmark plot for task=walker-walk]

nicklashansen commented 9 months ago

task=humanoid-walk, again default vs. num_envs = steps_per_update = 4 on a 2080-Ti. Seems like it's also a ~3x speedup.

[Screenshot: benchmark plot for task=humanoid-walk]

josephrocca commented 9 months ago

Oh this looks great! I've just kicked off a few experiments including testing vectorized, playing with iterations, model size, adding in the merging on top of the vectorized env stuff, and some other random stuff. Will report back 🫡

Edit: I haven't forgotten about this - some other work has stolen me away from tdmpc stuff for now. I did run some tests already but I think there was a bug in my training code and haven't yet looked into it.

ShaneFlandermeyer commented 8 months ago

Hi guys. The improvements in this thread are awesome to see!

I am also interested in speeding up this algorithm, so I created a (mostly) 1:1 port in Jax/Flax. I am happy to report a 5-10x speedup for the planning/update steps, which reduces the training time to ~2.5 hours/million time steps! I imagine that we could get it down to an hour or less using multiple workers. If anyone is interested in trying to combine them, I have put my implementation in the repo below!

https://github.com/ShaneFlandermeyer/tdmpc2-jax

josephrocca commented 8 months ago

Wow, Shane! I should be able to get back to tdmpc experiments soon 🤞 can't wait to try out your code and test with the other ideas/improvements in this thread. What a time to be alive.