danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Training on Gym Pendulum #17

Closed EelcoHoogendoorn closed 1 year ago

EelcoHoogendoorn commented 1 year ago

Thanks for sharing code in general; and even more thanks for providing the crafter example, which 'just works' on the first try!

That said, I was trying to find the most minimal env that would let me run and step through the code. To that end I tried a bunch of things, such as replacing

  import crafter
  env = crafter.Env()  # Replace this with your Gym env.

with

  from gym.envs.classic_control import pendulum
  env = pendulum.PendulumEnv()

This, however, runs into an error, and not one that I have solved yet (see below [^1]).

What would be really nice is a gymnax wrapper: gymnax has a bunch of simple envs implemented in JAX, so in terms of getting something converging fast with minimal dependencies, I think that'd be ideal.

  from gymnax.environments.classic_control import Pendulum
  env = Pendulum()

It has a gym-like API, but I also did not get it to work as a drop-in replacement. I hope to get it working, though; if successful, I'll put in the work to turn it into a tidy PR.

[^1]: Here's the error on the gym classic-control pendulum. I suspect the default config pertaining to image vs. MLP inputs is incorrect, but I can't find any documentation for the config object.

│   214 │     self._mlp = MLP(None, mlp_layers, mlp_units, dist='none', **mlp_ │
│   215                                                                        │
│   216   def __call__(self, data):                                            │
│ ❱ 217 │   some_key, some_shape = list(self.shapes.items())[0]                │
│   218 │   batch_dims = data[some_key].shape[:-len(some_shape)]               │
│   219 │   data = {                                                           │
│   220 │   │   k: v.reshape((-1,) + v.shape[len(batch_dims):])                │
╰──────────────────────────────────────────────────────────────────────────────╯
IndexError: list index out of range
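
A plausible cause, consistent with the config fix suggested further down this thread: the encoder's `self.shapes` dict is empty because none of the env's observation keys matched the default cnn/mlp key patterns. A minimal sketch of that failure mode:

```python
# If none of the observation keys match the encoder's key patterns,
# `shapes` ends up empty and taking the first item fails exactly as above.
shapes = {}  # e.g. the default patterns expect an 'image' key the env lacks
try:
    some_key, some_shape = list(shapes.items())[0]
except IndexError as err:
    print('IndexError:', err)  # IndexError: list index out of range
```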
EelcoHoogendoorn commented 1 year ago

Implementing an inverted pendulum in JAX is also just a handful of LOCs; that might make for an even more transparent minimal example. I just need to reverse-engineer the config and the env interface to make that work, I suppose.

danijar commented 1 year ago

Maybe somebody wants to write a nanodreamer ;)

EelcoHoogendoorn commented 1 year ago

For the sake of dissemination of ideas, that'd be awesome indeed. Don't get me wrong: this is high up there in terms of code-publishing standards, but it's tied up in a lot of generality, and being able to see the most minimal implementation of the idea in action is very powerful, IMO.

EelcoHoogendoorn commented 1 year ago

Absent that, do you foresee any gotchas in implementing a self-contained minimal env, like a pendulum?

danijar commented 1 year ago

I don't see any issues with that. If you train with a single env instance, remove much of the logging, and use an existing library for the neural net layers then you can probably reasonably fit the whole agent code into a single file. Whether the env is included or loaded from somewhere doesn't make much of a difference IMO.

By the way, the pendulum environment should work. Are you sure you didn't point to a logdir with an existing incompatible checkpoint? Otherwise, it might be that the code expects actions to be a vector (possibly of length 1) rather than a scalar.
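
If the scalar-action theory is right, a small wrapper could bridge the gap by exposing a length-1 action vector to the agent. This is only a sketch, not part of the repo; the class name and the minimal `step`/`reset` interface are illustrative:

```python
import numpy as np

class VectorActionEnv:
    """Hypothetical sketch: present a length-1 action vector to the agent
    while forwarding a plain scalar to an env that expects one."""

    def __init__(self, env):
        self.env = env

    def step(self, action):
        # The agent emits e.g. np.array([0.3]); unwrap it to a Python scalar.
        scalar = float(np.asarray(action).reshape(-1)[0])
        return self.env.step(scalar)

    def reset(self):
        return self.env.reset()
```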

ChrisAGBlake commented 1 year ago

I was testing Bipedal Walker as an example of an env without an image as part of the state, which I think Pendulum is as well. I found I had to make the following changes.

In example.py, the changes to the config:

    config = config.update({
        'logdir': f'logdir/{int(time.time())}', # this was just changed to generate a new log dir every time for testing
        'run.train_ratio': 64,
        'run.log_every': 30,
        'batch_size': 16,
        'jax.prealloc': False,
        'encoder.mlp_keys': '.*',
        'decoder.mlp_keys': '.*',
        'encoder.cnn_keys': '$^',
        'decoder.cnn_keys': '$^',
        'jax.platform': 'cpu', # I don't have a gpu locally
    })

Changes to the env:

    env = gym.make("BipedalWalker-v3")  # this needs box2d-py installed also
    env = from_gym.FromGym(env, obs_key='state_vec')  # I found I had to specify a different obs_key than the default of 'image'

I hope this helps.
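
For what it's worth, the `mlp_keys`/`cnn_keys` values above are ordinary regexes matched against the observation key names: `'.*'` matches every key, while `'$^'` can never match a non-empty name. A quick sketch (the key names below are just illustrative):

```python
import re

obs_keys = ['state_vec', 'reward', 'is_first']  # illustrative key names

mlp_keys = [k for k in obs_keys if re.match('.*', k)]  # '.*' matches every key
cnn_keys = [k for k in obs_keys if re.match('$^', k)]  # '$^' matches none

print(mlp_keys)  # ['state_vec', 'reward', 'is_first']
print(cnn_keys)  # []
```

So with these settings, all observation keys are routed to the MLP encoder/decoder and none to the CNN.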

EelcoHoogendoorn commented 1 year ago

Thanks for the input; that did get me closer. Seems I have some mujoco install issues to take care of though.

On a related note: running this:

  from dm_control.suite.pendulum import SwingUp
  env = SwingUp()
  from embodied.envs import from_dm
  env = from_dm.FromDM(env)

TypeError: Task.observation_spec() missing 1 required positional argument: 'physics'

That's with dm-control 1.0.10.

Is the dm-wrapper not supposed to work with the dm-control subpackage?

Whether the env is included or loaded from somewhere doesn't make much of a difference IMO.

Well, except for the fact that those external dependencies will inevitably lead to lazy people such as myself whining about them not working here on your issue tracker :)

danijar commented 1 year ago

@EelcoHoogendoorn I think the SwingUp class you're using is an environment without a task. If you use suite.load() then it returns a full environment that can be used with FromDM. If that doesn't work, please let me know and I'll look into it further.

EelcoHoogendoorn commented 1 year ago

Now I have this with the same config as above:

  env = suite.load(
      "pendulum",
      "swingup",
      # task_kwargs={"time_limit": float("inf")},
  )

It gets me closer:

│ /Users/.../dreamerv3/dreamerv3/nets.py:92 in <lambda>    │
│                                                                              │
│    89 │   prev_state, prev_action = jax.tree_util.tree_map(                  │
│    90 │   │   lambda x: self._mask(x, 1.0 - is_first), (prev_state, prev_act │
│    91 │   prev_state = jax.tree_util.tree_map(                               │
│ ❱  92 │   │   lambda x, y: x + self._mask(y, is_first),                      │
│    93 │   │   prev_state, self.initial(len(is_first)))                       │
│    94 │   prior = self.img_step(prev_state, prev_action)                     │
│    95 │   x = jnp.concatenate([prior['deter'], embed], -1)                   │

TypeError: add got incompatible shapes for broadcasting: (1, 1024), (1, 512).

But not quite.

danijar commented 1 year ago

Maybe you're loading an incompatible checkpoint this time, where the model size changed? Otherwise, if you can print the shapes of all state entries in that line (of both prev_state and self.initial(len(is_first))), I might be able to guess what's going on.
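
One lightweight way to do that printing, as a sketch: assume the state is a nested dict of arrays, and use plain Python rather than jax.tree_util so it runs standalone. The state entries below are illustrative, with the deter sizes mirroring the (1, 1024) vs (1, 512) mismatch in the traceback above:

```python
import numpy as np

def tree_shapes(tree):
    """Recursively replace arrays in a nested dict with their shapes."""
    if isinstance(tree, dict):
        return {k: tree_shapes(v) for k, v in tree.items()}
    return np.shape(tree)

# Illustrative RSSM-like states (not the real checkpoint contents).
prev_state = {'deter': np.zeros((1, 1024)), 'stoch': np.zeros((1, 32, 32))}
initial = {'deter': np.zeros((1, 512)), 'stoch': np.zeros((1, 32, 32))}
print(tree_shapes(prev_state))  # {'deter': (1, 1024), 'stoch': (1, 32, 32)}
print(tree_shapes(initial))     # {'deter': (1, 512), 'stoch': (1, 32, 32)}
```

Comparing the two printouts entry by entry makes a model-size mismatch like this one obvious.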

EelcoHoogendoorn commented 1 year ago

Ah yes, the shape mismatch seems to have been a model-checkpoint problem; it seems to be running now!

hanshuo-shuo commented 1 year ago

Thanks for the config setup. It works! Do you know the difference between obs_key='state_vec' and 'vector'?

ChrisAGBlake commented 1 year ago

No difference, as far as I'm aware. I'm not sure why I had to change the obs_key at the time; I may have been using an older version of the code.