EelcoHoogendoorn closed this issue 1 year ago.
Implementing an inverted pendulum in JAX is also just a handful of lines of code; that might make for an even more transparent minimal example. I suppose I just need to reverse-engineer the config and the env interface to make that work.
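To give a sense of how small such an env can be, here is a sketch of pendulum dynamics. It assumes the constants and reward of Gym's classic Pendulum-v1 as I recall them (g = 10, unit mass and length, dt = 0.05, torque clipped to ±2, angle 0 = upright). It is written with NumPy so it runs anywhere; since it only uses clip, sin, and elementwise arithmetic, swapping np for jax.numpy should work too. The function names are hypothetical.

```python
import numpy as np

# Assumed constants, following Gym's classic Pendulum-v1 as I recall them.
G, MASS, LENGTH, DT = 10.0, 1.0, 1.0, 0.05
MAX_SPEED, MAX_TORQUE = 8.0, 2.0


def angle_normalize(th):
    """Wrap an angle into [-pi, pi)."""
    return ((th + np.pi) % (2 * np.pi)) - np.pi


def pendulum_step(th, thdot, u):
    """One semi-implicit Euler step; returns (new_th, new_thdot, reward)."""
    u = np.clip(u, -MAX_TORQUE, MAX_TORQUE)
    # Cost is computed on the pre-step state; th = 0 is the upright position.
    cost = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * u ** 2
    new_thdot = thdot + (3 * G / (2 * LENGTH) * np.sin(th)
                         + 3.0 / (MASS * LENGTH ** 2) * u) * DT
    new_thdot = np.clip(new_thdot, -MAX_SPEED, MAX_SPEED)
    new_th = th + new_thdot * DT
    return new_th, new_thdot, -cost


def observe(th, thdot):
    """Vector observation (cos th, sin th, thdot), so no image is involved."""
    return np.array([np.cos(th), np.sin(th), thdot], dtype=np.float32)
```

Feeding observe() into an MLP encoder rather than a CNN is what makes this a state-only env, which is where the config changes discussed further down come in.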
Maybe somebody wants to write a nanodreamer ;)
For the sake of disseminating ideas, that'd be awesome indeed. Don't get me wrong: this is high up there in terms of code publishing standards, but it's tied up in a lot of generality, and being able to see the most minimal implementation of the idea in action is very powerful, IMO.
Absent that, do you foresee any gotchas in implementing a self-contained minimal env, like a pendulum?
I don't see any issues with that. If you train with a single env instance, remove much of the logging, and use an existing library for the neural net layers then you can probably reasonably fit the whole agent code into a single file. Whether the env is included or loaded from somewhere doesn't make much of a difference IMO.
By the way, the pendulum environment should work. Are you sure you didn't point to a logdir with an existing incompatible checkpoint? Otherwise, it might be that the code expects actions to be a vector (possibly of length 1) rather than a scalar.
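If the scalar-vs-vector action issue is the culprit, a small coercion step at the env boundary should sidestep it. The helper below is a hypothetical sketch: it turns a scalar torque into a length-1 vector, matching a continuous action space of shape (1,).

```python
import numpy as np


def as_action_vector(action, dtype=np.float32):
    """Coerce a scalar or array-like action into a flat 1-D float vector.

    A scalar torque such as 0.5 becomes array([0.5]) with shape (1,).
    """
    return np.asarray(action, dtype=dtype).reshape(-1)
```

Going the other way, an env that insists on a scalar can take element 0 of the vector the agent produces.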
I was testing BipedalWalker as an example env without an image in its observation, which I think is also true of Pendulum. I found I had to make the following changes.
In example.py, the changes to the config:
config = config.update({
'logdir': f'logdir/{int(time.time())}', # this was just changed to generate a new log dir every time for testing
'run.train_ratio': 64,
'run.log_every': 30,
'batch_size': 16,
'jax.prealloc': False,
'encoder.mlp_keys': '.*',
'decoder.mlp_keys': '.*',
'encoder.cnn_keys': '$^',
'decoder.cnn_keys': '$^',
'jax.platform': 'cpu', # I don't have a gpu locally
})
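The encoder/decoder key settings are regular expressions over observation key names: '.*' matches every key, while '$^' (end-of-string followed by start-of-string) cannot match any non-empty key, so it effectively switches that input path off. The helper below is a hypothetical sketch of the routing as I understand it; the extra rank check (rank 3 for images, rank 1 for vectors) is an assumption about what the encoder does, not a quote of its code.

```python
import re


def split_keys(shapes, cnn_keys, mlp_keys):
    """Hypothetical sketch: route observation keys to CNN or MLP inputs.

    Assumes a key goes to the CNN if its name matches `cnn_keys` and it is an
    image (rank 3), and to the MLP if it matches `mlp_keys` and it is rank 1.
    """
    cnn = {k: v for k, v in shapes.items()
           if re.match(cnn_keys, k) and len(v) == 3}
    mlp = {k: v for k, v in shapes.items()
           if re.match(mlp_keys, k) and len(v) == 1}
    return cnn, mlp


# With the settings above, a 24-dim BipedalWalker state goes to the MLP only.
cnn, mlp = split_keys({"state_vec": (24,)}, cnn_keys="$^", mlp_keys=".*")
```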
Changes to the env:
env = gym.make("BipedalWalker-v3") # this needs box2d-py installed also
env = from_gym.FromGym(env, obs_key='state_vec') # I found I had to specify a different obs_key than the default of 'image'
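The obs_key matters because a plain Gym env returns a bare array, and the wrapper has to give it a name before the key patterns in the config can pick it up. The function below is a hypothetical sketch of that naming step, not the actual FromGym code: dict observations pass through, and a bare array is stored under obs_key.

```python
import numpy as np


def to_obs_dict(obs, obs_key="image"):
    """Hypothetical sketch: give a non-dict observation a key name.

    Dict observations pass through unchanged; a bare array is stored under
    `obs_key`, which is the name the encoder/decoder key patterns see.
    """
    if isinstance(obs, dict):
        return dict(obs)
    return {obs_key: np.asarray(obs)}
```

With the default 'image', a 24-dim state vector would be labeled as if it were an image, which is presumably why a different key name helped here.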
I hope this helps.
Thanks for the input; that did get me closer. It seems I have some MuJoCo install issues to take care of, though.
On a related note: running this:
from dm_control.suite.pendulum import SwingUp
env = SwingUp()
from embodied.envs import from_dm
env = from_dm.FromDM(env)
TypeError: Task.observation_spec() missing 1 required positional argument: 'physics'
That's with dm-control 1.0.10.
Is the dm wrapper not supposed to work with the dm-control subpackages?
Whether the env is included or loaded from somewhere doesn't make much of a difference IMO.
Well, except for the fact that those external dependencies will inevitably lead to lazy people such as myself whining about them not working here on your issue tracker :)
@EelcoHoogendoorn I think the SwingUp class you're using is a task without an environment: it has no physics attached, hence the missing 'physics' argument. If you use suite.load() instead, it returns a full environment that can be used with FromDM. If that doesn't work, please let me know and I'll look into it further.
Now I have this with the same config as above:
from dm_control import suite
from embodied.envs import from_dm

env = suite.load(
    "pendulum",
    "swingup",
    # task_kwargs={"time_limit": float("inf")},
)
env = from_dm.FromDM(env)
It gets me closer:
/Users/.../dreamerv3/dreamerv3/nets.py:92 in <lambda>

   89     prev_state, prev_action = jax.tree_util.tree_map(
   90         lambda x: self._mask(x, 1.0 - is_first), (prev_state, prev_act
   91     prev_state = jax.tree_util.tree_map(
❱  92         lambda x, y: x + self._mask(y, is_first),
   93         prev_state, self.initial(len(is_first)))
   94     prior = self.img_step(prev_state, prev_action)
   95     x = jnp.concatenate([prior['deter'], embed], -1)
TypeError: add got incompatible shapes for broadcasting: (1, 1024), (1, 512).
But not quite there yet.
Maybe you're loading an incompatible checkpoint this time, where the model size changed? Otherwise, if you can print the shapes of all state entries in that line (of both prev_state and self.initial(len(is_first))), I might be able to guess what's going on.
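In JAX itself, jax.tree_util.tree_map(lambda x: x.shape, tree) prints the shapes of every entry. The NumPy sketch below does the same without needing JAX and also reports which top-level entries disagree. The helper names and the example arrays are hypothetical; the example only mirrors the (1, 1024) vs (1, 512) shapes from the error above.

```python
import numpy as np


def tree_shapes(tree):
    """Map a nested dict/list/tuple of arrays to the same structure of shapes."""
    if isinstance(tree, dict):
        return {k: tree_shapes(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_shapes(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(tree_shapes(v) for v in tree)
    return np.shape(tree)


def shape_mismatches(a, b):
    """Return {key: (shape_a, shape_b)} for shared top-level dict keys that differ."""
    sa, sb = tree_shapes(a), tree_shapes(b)
    return {k: (sa[k], sb[k]) for k in sa if k in sb and sa[k] != sb[k]}


# Made-up states: a checkpoint built with a larger deterministic state size.
prev_state = {"deter": np.zeros((1, 1024)), "stoch": np.zeros((1, 32, 32))}
initial = {"deter": np.zeros((1, 512)), "stoch": np.zeros((1, 32, 32))}
print(shape_mismatches(prev_state, initial))
```

A mismatch confined to one size dimension like this points at a model-size change between the checkpoint and the current config, rather than at the env.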
Ah yes, the shape mismatch was a model checkpoint problem; it seems to be running now!
Thanks for the config setup. It works! Do you know the difference between obs_key='state_vec' and obs_key='vector'?
No difference that I'm aware of. I'm not sure why I had to change the obs_key at the time; I may have been using an older version of the code.
Thanks for sharing the code in general, and even more thanks for providing the Crafter example, which 'just works' on the first try!
That being said, I was trying to find the most minimal env, to let me run and step through the code. To that end I tried a bunch of things, replacing
With
This, however, runs into an error, and not one that I have solved yet (see below [^1]).
What would be really nice is a gymnax wrapper: gymnax has a bunch of simple envs implemented in JAX, so in terms of getting something converging fast with minimal dependencies, I think that'd be ideal.
It has a Gym-like API, but I also did not get it to work as a drop-in replacement. I hope to get it working, though; if successful, I'll put in the work to make a tidy PR out of it.
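The main mismatch is that gymnax is functional (state and PRNG key are passed in and returned), while the agent loop expects a stateful env that emits is_first/is_last flags. Below is a hypothetical adapter sketch. It assumes reset_fn(key, params) returns (obs, state) and step_fn(key, state, action, params) returns (obs, state, reward, done, info), and that next_key() hands out fresh keys (with gymnax that would wrap jax.random.split). The dict layout with reward, is_first, is_last, and is_terminal mirrors what I believe the embodied wrappers emit. Also note that gymnax's step auto-resets when an episode ends, so passing the non-resetting step_env/reset_env methods is probably the better fit; that is an assumption worth checking against the gymnax source.

```python
import numpy as np


class FunctionalEnvAdapter:
    """Hypothetical adapter from a gymnax-style functional API to a stateful one."""

    def __init__(self, reset_fn, step_fn, params, next_key, obs_key="vector"):
        self._reset_fn, self._step_fn = reset_fn, step_fn
        self._params, self._next_key = params, next_key
        self._obs_key = obs_key
        self._state = None
        self._done = True  # Forces a reset on the very first call.

    def _obs(self, obs, reward, is_first=False, is_last=False, is_terminal=False):
        return {
            self._obs_key: np.asarray(obs, dtype=np.float32),
            "reward": np.float32(reward),
            "is_first": is_first,
            "is_last": is_last,
            "is_terminal": is_terminal,
        }

    def step(self, action, reset=False):
        # Start a new episode when asked to, or after the previous one ended.
        if reset or self._done:
            obs, self._state = self._reset_fn(self._next_key(), self._params)
            self._done = False
            return self._obs(obs, 0.0, is_first=True)
        obs, self._state, reward, done, _ = self._step_fn(
            self._next_key(), self._state, action, self._params)
        self._done = bool(done)
        # Simplification: time-limit truncation is treated as termination here.
        return self._obs(obs, reward, is_last=self._done, is_terminal=self._done)
```

For Pendulum the episode end comes from a time limit, so a more careful version would report is_last without is_terminal in that case.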
[^1] Here's the error on the Gym classic Pendulum; I suspect the default config pertaining to image vs. MLP inputs is incorrect, but I can't find any documentation on the config object.