Eclectic-Sheep / sheeprl

Distributed Reinforcement Learning accelerated by Lightning Fabric
https://eclecticsheep.ai
Apache License 2.0

Dimension of batches in Dreamer-V1 implementation #92

Closed: elisaparga19 closed this issue 11 months ago

elisaparga19 commented 11 months ago

Hey!

I'm trying to understand the Dreamer-V1 implementation. When you train the agent, does data['actions'] have shape (sequence_length, batch_size, actions_dim)? Also, when you call embedded_obs = world_model.encoder(batch_obs) in your training loop, does batch_obs have shape (sequence_length, batch_size, obs_dim), or do you flatten those samples to process them with a neural network?

Thanks in advance!

belerico commented 11 months ago

Hi @elisaparga19!

  1. Yes, data["actions"] has exactly that shape: seq_len x batch_size x action_dims.
  2. In all the Dreamer implementations, both the encoder and the decoder flatten all the batch dimensions, run the forward pass, and then reshape the output back to the original batch dimensions. For example, given a tensor of shape [*, C, H, W], the encoder first flattens all the * dimensions (for Dreamer-V1 these are exactly seq_len x batch_size), runs the forward pass on the resulting 4D tensor (the shape nn.Conv2d expects), and reshapes the output back to [*, ...].

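
The flatten/forward/reshape pattern from point 2 can be sketched roughly as follows. This is a minimal illustration, not sheeprl's actual code: the wrapper class FlattenBatchDims, its num_input_dims parameter and the toy convolutional encoder are all hypothetical, assuming only that the wrapped module expects inputs of shape [N, C, H, W].

```python
import torch
from torch import nn


class FlattenBatchDims(nn.Module):
    """Hypothetical wrapper (not sheeprl's API): applies `module` to inputs with any
    number of leading batch dimensions, e.g. [seq_len, batch_size, C, H, W]."""

    def __init__(self, module: nn.Module, num_input_dims: int) -> None:
        super().__init__()
        self.module = module
        # Number of trailing dims the wrapped module expects per sample, e.g. 3 for (C, H, W)
        self.num_input_dims = num_input_dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        split = x.dim() - self.num_input_dims
        batch_shape = x.shape[:split]  # the "*" dims, e.g. (seq_len, batch_size)
        # Flatten all batch dims into one: [seq_len * batch_size, C, H, W]
        flat = x.reshape(-1, *x.shape[split:])
        out = self.module(flat)  # [seq_len * batch_size, ...]
        # Restore the original batch dims: [seq_len, batch_size, ...]
        return out.reshape(*batch_shape, *out.shape[1:])


encoder = FlattenBatchDims(
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=4, stride=2), nn.ReLU(), nn.Flatten()),
    num_input_dims=3,
)
batch_obs = torch.rand(5, 2, 3, 8, 8)  # [seq_len=5, batch_size=2, C=3, H=8, W=8]
embedded_obs = encoder(batch_obs)
print(embedded_obs.shape)  # torch.Size([5, 2, 72])
```

Because reshape flattens in row-major order, element [t, b] of the output is exactly what the wrapped encoder produces for observation [t, b] on its own; the convolution never sees the sequence dimension.
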
elisaparga19 commented 11 months ago

Thanks! So does the actor process the actions the same way the encoder does? That is, does it flatten the seq_len x batch_size x actions_dim tensor into (seq_len * batch_size, actions_dim) and then reshape the output back to the original dimensions?

belerico commented 11 months ago

Nope, the actor does not take the actions as input: it takes the latent state coming from either the prior or the posterior model, depending on the phase in which you call it:

  1. During behaviour learning, it takes the state from the prior model (the one that imagines).
  2. During environment interaction, it takes the state from the posterior model (the one that learns the environment dynamics, conditioned on the observation it receives and on its history).

The component that ties all of this together is the Player agent, which incorporates the recurrent model, the posterior model, the encoder and the actor: you can find it in the agent.py of Dreamer-V2.
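
To make the two phases concrete, here is a heavily simplified sketch of how a player-like module could wire these pieces together. It is not sheeprl's Player class: the name ToyPlayer, the methods get_action and imagine_step, the layer sizes, and the use of linear layers and a GRUCell instead of the real CNN encoder and RSSM are all assumptions for illustration. The point it shows is that the actor always receives the latent state (h, s), from the posterior when acting in the environment and from the prior when imagining.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyPlayer(nn.Module):
    """Hypothetical, heavily simplified player (not sheeprl's Player class).

    Combines an encoder, a recurrent model, posterior/prior models and an actor.
    The actor is fed the latent state (h, s), never the raw actions.
    """

    def __init__(self, obs_dim=8, embed_dim=16, action_dim=2, stoch_dim=4, recurrent_dim=32):
        super().__init__()
        self.action_dim, self.stoch_dim, self.recurrent_dim = action_dim, stoch_dim, recurrent_dim
        self.encoder = nn.Linear(obs_dim, embed_dim)
        # Recurrent model: h_t = f(h_{t-1}, s_{t-1}, a_{t-1})
        self.recurrent_model = nn.GRUCell(stoch_dim + action_dim, recurrent_dim)
        # Posterior: conditions on h_t and the embedded current observation
        self.posterior = nn.Linear(recurrent_dim + embed_dim, 2 * stoch_dim)
        # Prior: conditions on h_t only, so it can imagine without observations
        self.prior = nn.Linear(recurrent_dim, 2 * stoch_dim)
        self.actor = nn.Linear(recurrent_dim + stoch_dim, action_dim)

    @staticmethod
    def _sample(stats: torch.Tensor) -> torch.Tensor:
        # Reparameterized Gaussian sample from [mean, pre-std] statistics
        mean, pre_std = stats.chunk(2, dim=-1)
        std = F.softplus(pre_std) + 0.1
        return mean + std * torch.randn_like(std)

    def init_states(self, batch_size: int) -> None:
        self.h = torch.zeros(batch_size, self.recurrent_dim)
        self.s = torch.zeros(batch_size, self.stoch_dim)
        self.a = torch.zeros(batch_size, self.action_dim)

    @torch.no_grad()
    def get_action(self, obs: torch.Tensor) -> torch.Tensor:
        """Environment interaction: the actor sees the POSTERIOR state."""
        embedded = self.encoder(obs)
        self.h = self.recurrent_model(torch.cat([self.s, self.a], dim=-1), self.h)
        self.s = self._sample(self.posterior(torch.cat([self.h, embedded], dim=-1)))
        self.a = torch.tanh(self.actor(torch.cat([self.h, self.s], dim=-1)))
        return self.a

    def imagine_step(self, h: torch.Tensor, s: torch.Tensor):
        """Behaviour learning: the actor sees the state imagined by the PRIOR."""
        a = torch.tanh(self.actor(torch.cat([h, s], dim=-1)))
        h = self.recurrent_model(torch.cat([s, a], dim=-1), h)
        s = self._sample(self.prior(h))
        return h, s, a


player = ToyPlayer()
player.init_states(batch_size=3)
action = player.get_action(torch.rand(3, 8))  # one step for 3 parallel envs
h, s, imagined_action = player.imagine_step(player.h, player.s)
```

Note that the actor here only ever sees a [batch, recurrent_dim + stoch_dim] state, one time step at a time, so there is no seq_len x batch_size x actions_dim tensor for it to flatten.
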