ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Issues & Questions about vision observation #10

Open: ziyc opened this issue 1 week ago

ziyc commented 1 week ago

Hi @ShaneFlandermeyer, thank you so much for bringing out this nice work! It's really fast!

I ran into an issue when I tried to use RGB observations with the following command:

python train.py env.backend=dmc env.env_id=dog-run env.dmc.obs_type=rgb

I got this error (the same command works fine with env.dmc.obs_type=state): [screenshot of the error]

Does this codebase currently support vision observations? If so, could you share some insight into how quickly it converges? For the dog-run task with state observations, training takes about 2 hours on my RTX 4090. Roughly how long might it take with RGB or depth-map observations?

Thank you for your time!

ShaneFlandermeyer commented 1 week ago

Hi,

While there's no built-in option for image inputs, it should be easy to change the encoder to a CNN by replacing lines 78-85 in train.py with your vision model of choice. I have done this for graph networks in my personal research with no trouble, but let me know if you run into problems!

EDIT: I just saw that your error occurs prior to any model creation. I can take a look at what's going on in the DMC env creation this weekend.

ziyc commented 1 week ago

Hi @ShaneFlandermeyer, thank you so much for the rapid reply!

Yes, the error occurs before model creation. I fixed it by modifying the arguments of the reset() function in PixelWrapper: [screenshot of the fix]
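The screenshot of that fix is not preserved. Purely as an illustration, assume PixelWrapper is meant to follow the Gymnasium API, where Wrapper.reset takes keyword-only seed and options arguments. Under that assumption, a reset() that accepts those keywords and forwards them might look like the sketch below. PixelWrapperSketch and render_fn are hypothetical names, not code from the repo.

```python
from typing import Any, Callable, Optional, Tuple


class PixelWrapperSketch:
    """Hypothetical pixel wrapper whose reset() mirrors Gymnasium's signature.

    Gymnasium wrappers define reset(self, *, seed=None, options=None); a
    wrapper that omits these keywords raises a TypeError when they are passed.
    """

    def __init__(self, env: Any, render_fn: Callable[[Any], Any]):
        self.env = env
        self.render_fn = render_fn  # assumed to return an image observation

    def reset(self, *, seed: Optional[int] = None,
              options: Optional[dict] = None) -> Tuple[Any, dict]:
        # Forward the keywords unchanged so seeding still reaches the inner env.
        _, info = self.env.reset(seed=seed, options=options)
        # Replace the state observation with a rendered frame.
        return self.render_fn(self.env), info
```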

But then I hit the next issue: [screenshot of the error]

Does this JAX version support RGB observations? From the code, this part seems to be implemented: https://github.com/ShaneFlandermeyer/tdmpc2-jax/blob/d9d95b51b04b39c45d6f16c0b907ace90871bcbf/tdmpc2_jax/envs/dmcontrol.py#L206-L208

but I can't get any command with env.dmc.obs_type=rgb to run successfully.

I'm currently benchmarking TD-MPC2 with RGB observations and would like to know how long a single task takes to converge. Your insights would be very helpful for my research. Thank you so much!

ShaneFlandermeyer commented 1 week ago

That appears to be an issue with dm_control rather than this repo. Have you tried the instructions in the "Rendering" section of their README?

https://github.com/google-deepmind/dm_control
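The dm_control Rendering instructions select an OpenGL backend through the MUJOCO_GL environment variable (glfw, egl, or osmesa). The variable has to be set before dm_control is imported. Below is a minimal sketch of doing that from a Python launch script. The helper name select_mujoco_gl is hypothetical.

```python
import os

# Backends listed in the "Rendering" section of the dm_control README.
_MUJOCO_GL_BACKENDS = ("glfw", "egl", "osmesa")


def select_mujoco_gl(backend: str = "egl") -> str:
    """Hypothetical helper: set MUJOCO_GL before dm_control/mujoco is imported."""
    if backend not in _MUJOCO_GL_BACKENDS:
        raise ValueError(f"unknown MUJOCO_GL backend: {backend!r}")
    os.environ["MUJOCO_GL"] = backend
    return backend


# egl: headless GPU rendering; osmesa: CPU software rendering; glfw: needs a display.
select_mujoco_gl("egl")
# from dm_control import suite  # import only after the variable is set
```

You can also set the variable in the shell, for example: MUJOCO_GL=egl python train.py env.backend=dmc env.env_id=dog-run env.dmc.obs_type=rgb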

ziyc commented 1 week ago

Thanks! I'll check it.