danijar / dreamerv2

Mastering Atari with Discrete World Models
https://danijar.com/dreamerv2
MIT License

Question about Plan2explore #23

Closed: TachikakaMin closed this issue 2 years ago

TachikakaMin commented 2 years ago

For Plan2Explore, the class Plan2Explore in expl.py holds a world model:

class Plan2Explore(common.Module):

  def __init__(self, config, act_space, wm, tfstep, reward):
    self.config = config
    self.reward = reward
    self.wm = wm
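
For context, here is a minimal sketch of the latent-disagreement idea behind Plan2Explore (simplified and hypothetical, not the repo's exact code): an ensemble of one-step predictors is trained on world-model features, and the variance of their predictions serves as the intrinsic reward.

import tensorflow as tf

# Hypothetical sketch of a disagreement-based intrinsic reward.
class DisagreementReward(tf.Module):

  def __init__(self, feat_size, ensemble=5):
    # Ensemble of one-step prediction heads over latent features.
    self.nets = [
        tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='elu'),
            tf.keras.layers.Dense(feat_size),
        ]) for _ in range(ensemble)]

  def __call__(self, feat):
    # Intrinsic reward: variance across the ensemble predictions,
    # averaged over feature dimensions. No environment reward is used.
    preds = tf.stack([net(feat) for net in self.nets], 0)
    return tf.reduce_mean(tf.math.reduce_variance(preds, 0), -1)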

This is the same WorldModel instance used by the DreamerV2 agent:

class Agent(common.Module):

  def __init__(self, config, obs_space, act_space, step):
    self.config = config
    self.obs_space = obs_space
    self.act_space = act_space['action']
    self.step = step
    self.tfstep = tf.Variable(int(self.step), tf.int64)
    self.wm = WorldModel(config, obs_space, self.tfstep)
    self._task_behavior = ActorCritic(config, self.act_space, self.tfstep)
    if config.expl_behavior == 'greedy':
      self._expl_behavior = self._task_behavior
    else:
      self._expl_behavior = getattr(expl, config.expl_behavior)(
          self.config, self.act_space, self.wm, self.tfstep,
          lambda seq: self.wm.heads['reward'](seq['feat']).mode())

For world-model training, the code passes all of the information, including the reward, through the encoder:

def loss(self, data, state=None):
    data = self.preprocess(data)
    embed = self.encoder(data)

def preprocess(self, obs):
    dtype = prec.global_policy().compute_dtype
    obs = obs.copy()
    for key, value in obs.items():
      if key.startswith('log_'):
        continue
      if value.dtype == tf.int32:
        value = value.astype(dtype)
      if value.dtype == tf.uint8:
        value = value.astype(dtype) / 255.0 - 0.5
      obs[key] = value
    obs['reward'] = {
        'identity': tf.identity,
        'sign': tf.sign,
        'tanh': tf.tanh,
    }[self.config.clip_rewards](obs['reward'])
    obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype)
    obs['discount'] *= self.config.discount
    return obs

class Encoder(common.Module):
  def _cnn(self, data):
    x = tf.concat(list(data.values()), -1)
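
As a quick illustration of the preprocessing above, the clip_rewards option only rescales the reward values (example with made-up rewards):

import tensorflow as tf

# The clip_rewards mapping from preprocess above, applied to
# made-up reward values for illustration.
clip_fns = {'identity': tf.identity, 'sign': tf.sign, 'tanh': tf.tanh}
rewards = tf.constant([-3.0, 0.0, 0.5, 7.0])
print(clip_fns['sign'](rewards).numpy())  # [-1.  0.  1.  1.]
print(clip_fns['tanh'](rewards).numpy())  # squashed into (-1, 1)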

But the Plan2Explore paper says the exploration policy should not use the environment reward.

danijar commented 2 years ago

The encoder line you're showing is only applied to image observations, not to the reward. You can choose which observation keys are used via these config options:

encoder.mlp_keys: '.*'
encoder.cnn_keys: '.*'
decoder.mlp_keys: '.*'
decoder.cnn_keys: '.*'
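
For example, a hypothetical override (assuming the regex-based key matching of these options and the repo's common.Config update semantics) that restricts the encoder and decoder to the image key only:

config = config.update({
    'encoder.cnn_keys': 'image',
    'encoder.mlp_keys': '$^',  # regex that matches no key
    'decoder.cnn_keys': 'image',
    'decoder.mlp_keys': '$^',
})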

But only images (rank-3 tensors) and vectors (rank-1 tensors) are supported. The reward is a scalar (rank-0 tensor), so it is never fed to the encoder.
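
A minimal sketch of that rank-based split (hypothetical, not the repo's exact code): image-shaped keys feed the CNN, vector keys feed the MLP, and a scalar reward matches neither.

import tensorflow as tf

# Hypothetical sketch of splitting observation keys by tensor rank
# (batch dimensions omitted): rank-3 entries go to the CNN, rank-1
# entries go to the MLP, and the rank-0 reward is never encoded.
obs = {
    'image': tf.zeros((64, 64, 3)),  # rank 3 -> CNN
    'ram': tf.zeros((128,)),         # rank 1 -> MLP
    'reward': tf.zeros(()),          # rank 0 -> ignored by the encoder
}
cnn_inputs = {k: v for k, v in obs.items() if len(v.shape) == 3}
mlp_inputs = {k: v for k, v in obs.items() if len(v.shape) == 1}
assert 'reward' not in cnn_inputs and 'reward' not in mlp_inputs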