ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Removed world model's dependency on gymnasium spaces #6

Closed edoust closed 3 months ago

edoust commented 4 months ago

I tested your tdmpc2 implementation on some JAX environments (brax, and gymnax using vmap on the step and reset functions).

One issue with training on non-gymnasium environments is the requirement to pass gymnasium spaces into the world model.

This PR removes that requirement, which makes it easier to create and run independent training scripts without gymnasium.
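The vmap-over-reset/step pattern mentioned above can be sketched with a toy pure environment. This is an illustrative stand-in, not the gymnax API itself (gymnax's real reset/step also take env_params):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a functional environment: pure reset/step functions
# keyed by a PRNG key, in the spirit of gymnax/brax.
def reset(key):
    return jax.random.normal(key, (4,))  # initial state/observation

def step(state, action):
    next_state = state + action          # trivial dynamics
    reward = -jnp.sum(next_state ** 2)   # scalar reward per env
    return next_state, reward

# Vectorize over a batch of 8 environments with vmap.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
states = jax.vmap(reset)(keys)                      # shape (8, 4)
actions = jnp.zeros((8, 4))
states, rewards = jax.vmap(step)(states, actions)   # (8, 4), (8,)
```

Because both functions are pure, the batched versions jit and compose with the rest of a JAX training loop without any gymnasium wrappers.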

ShaneFlandermeyer commented 4 months ago

Yeah, it definitely makes sense to decouple the implementation from any particular environment library. I made a few small changes to your commit.

Thoughts?

edoust commented 4 months ago

Looks good to me...

edwhu commented 4 months ago

Would these changes break envs with dictionary observation spaces? It seems like we're assuming the observation is a single vector.

edwhu commented 4 months ago

I guess the code currently doesn't support dictionary observation spaces anyways, so it's not a huge deal.

ShaneFlandermeyer commented 4 months ago

This shouldn't impose any constraints on the observation, but the ArrayLike type hint probably needs to change. Since the dummy observation is only processed by encoder_module.init, you should be able to modify the encoder layers to handle whatever observation structures you want.

I'll double check this tomorrow.

ShaneFlandermeyer commented 3 months ago

Apologies for the late response; I've been trying to finalize the other active PR before modifying the code for this one.

What do you all think about doing away with the encoder_module, encoder_optim and observation space (or dummy observation) args and just having the user define the encoder TrainState externally? I like this idea from a flexibility perspective. Since the encoder is the main design degree of freedom across state representations/modalities, I think it makes a lot of sense to give the user full control over its definition.

ShaneFlandermeyer commented 3 months ago

After toying with this more over the weekend, I'm leaning heavily towards separating the encoder and world model definitions entirely. It would add quite a bit of flexibility for custom environment types and exotic encoder models.

The commits above show the proposed changes integrated into the most recent develop branch push. A custom encoder can be defined as in lines 113-126 of train.py. Curious to hear your thoughts @edoust @edwhu

edwhu commented 3 months ago

Decoupling the encoder makes sense to me in general. This opens the door to using more flexible observation spaces like dictionaries.

However, I think the replay buffer (and maybe other parts of the codebase) still expect the observation to be a vector, which could be a problem.

ShaneFlandermeyer commented 3 months ago

The codebase actually already supports dict observations, and I've used it for some of my dict-based environments. If you look in the sequential buffer and the world model, all operations on the observations get tree mapped for this reason.

From there, I leave it up to the encoder definition to handle dict spaces as they want (e.g., concatenation, separate network heads, and so on).
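The tree-map pattern is why dict observations flow through the buffer the same way flat vectors do. A small sketch of the idea (illustrative keys, not the repo's buffer code):

```python
import jax
import jax.numpy as jnp

# A list of dict observations, e.g. as collected by a rollout.
obs_batch = [
    {'state': jnp.ones(17), 'goal': jnp.zeros(3)},
    {'state': jnp.ones(17) * 2, 'goal': jnp.ones(3)},
]

# tree_map applies the stacking leaf-wise, so the same line handles
# flat arrays and arbitrarily nested dicts alike.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *obs_batch)
```

The encoder then decides what to do with the resulting dict of batched arrays (concatenation, separate heads, etc.).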

edwhu commented 3 months ago

Great, didn't see the tree map. That's nice!

edoust commented 3 months ago

Nice, those are some good changes.

I made minor changes to remove a few lines that won't work with non-gymnasium environments. What remains is mainly the single_action_space attribute and the keys of the RecordEpisodeStatistics wrapper, which are not compatible.

I think that's fine, since training on other environments requires more changes anyways.

ShaneFlandermeyer commented 3 months ago

It appears that something in this pull request causes a significant performance degradation - the agent is stuck at ~3000 reward in the halfcheetah gym env. I get good performance in the latest commit to develop. Hoping this is not the case, but it's possible something broke in afcf3cb158adfceee920b068c2ff23db610a36f7. It will be a day or two before I can look into it further.

EDIT: I think there was actually a problem on my end. Looks good in my current run.