Closed edoust closed 3 months ago
Yeah, it definitely makes sense to decouple the implementation from any particular environment library. Made a few small changes to your commit:
`dummy_observation` to match the replay buffer. Thoughts?
Looks good to me...
Would these changes break envs with dictionary observation spaces? It seems like we're assuming the observation is a single vector.
I guess the code currently doesn't support dictionary observation spaces anyways, so it's not a huge deal.
This shouldn't impose any constraints on the observation, but the `ArrayLike` type hint probably needs to change. Since the dummy observation is only processed by `encoder_module.init`, you should be able to modify the encoder layers to handle whatever observation structures you want.
I'll double check this tomorrow.
Apologies for the late response; I've been trying to finalize the other active PR before modifying the code for this one.
What do you all think about doing away with the `encoder_module`, `encoder_optim`, and observation space (or dummy observation) args and just having the user define the encoder `TrainState` externally?

I like this idea from a flexibility perspective. Since the encoder is the main design degree of freedom across state representations/modalities, I think it makes a lot of sense to give the user full control over its definition.
After toying with this more over the weekend, I'm leaning heavily towards separating the encoder and world model definitions entirely. It would add quite a bit of flexibility for custom environment types and exotic encoder models.
The commits above show the proposed changes integrated into the most recent develop branch push. A custom encoder can be defined as in lines 113-126 of `train.py`. Curious to hear your thoughts @edoust @edwhu
Decoupling the encoder makes sense to me in general. This opens the door to using more flexible observation spaces like dictionaries.
However, I think the replay buffer (and maybe other parts of the codebase) still expect the observation to be a vector, which could be a problem.
The code base actually supports dict observations already, and I've used it for some of my dict-based environments. If you look in the sequential buffer and the world model, all operations on the observations get tree mapped for this reason.
From there, I leave it up to the encoder definition to handle dict spaces as they want (e.g., concatenation, separate network heads, and so on).
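As a rough sketch of the tree-mapping pattern described above (the buffer layout and shapes here are illustrative assumptions, not the actual sequential buffer code):

```python
import jax
import jax.numpy as jnp

# Hypothetical dict-structured buffer with capacity 100: one leaf per
# observation component, mirroring the observation's tree structure.
buffer = {
    "image": jnp.zeros((100, 8, 8)),
    "state": jnp.zeros((100, 4)),
}
obs = {"image": jnp.ones((8, 8)), "state": jnp.ones((4,))}

def insert(buffer, obs, idx):
    # tree_map pairs each buffer leaf with the matching observation leaf,
    # so this code never assumes the observation is a single flat vector.
    return jax.tree_util.tree_map(lambda b, o: b.at[idx].set(o), buffer, obs)

buffer = insert(buffer, obs, 0)
```

Because every operation goes through `tree_map`, the same buffer code works unchanged for flat-vector and dict observation spaces.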
Great, didn't see the tree map. That's nice!
Nice, those are some good changes.
I made minor changes to remove a few lines that won't work with non-gymnasium environments. Now mainly the `single_action_space` attribute and the keys of the `RecordEpisodeStatistics` wrapper remain, which are not compatible.
I think that's fine, since training on other environments requires more changes anyway.
It appears that something in this pull request causes a significant performance degradation - the agent is stuck at ~3000 reward in the halfcheetah gym env. I get good performance in the latest commit to develop. Hoping this is not the case, but it's possible something broke in afcf3cb158adfceee920b068c2ff23db610a36f7. It will be a day or two before I can look into it further.
EDIT: I think there was actually a problem on my end. Looks good in my current run.
I tested your tdmpc2 implementation on some JAX environments (`brax`, and `gymnax` using `vmap` on the `step` and `reset` functions).

One issue with training non-gymnasium environments is the requirement to pass gymnasium spaces into the world model. This PR removes that requirement, which facilitates creating and running independent training scripts without gymnasium.
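For reference, the vmap-based vectorization mentioned above looks roughly like this. This is a generic sketch with a toy functional environment, not the actual `gymnax` API:

```python
import jax
import jax.numpy as jnp

# Toy functional environment: pure reset/step functions, no gymnasium
# spaces or wrappers involved.
def reset(key):
    return jax.random.uniform(key, (4,))  # toy 4-dim observation

def step(state, action):
    next_state = state + action
    reward = -jnp.sum(jnp.square(next_state))
    return next_state, reward

num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)

# vmap turns the single-env functions into batched ones, giving a
# vectorized env without any gymnasium machinery.
states = jax.vmap(reset)(keys)            # shape (num_envs, 4)
actions = jnp.ones((num_envs, 4))
states, rewards = jax.vmap(step)(states, actions)
```

Since everything stays as pure JAX functions, the whole rollout can also be `jit`-compiled end to end.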