While taking a close look at the imagine() function in the world model,
I wonder why the gradient from the input feature to the actor should be stopped.
WorldModel's imagine function (agent.py):
def imagine(self, policy, start, is_terminal, horizon):
    flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    start['feat'] = self.rssm.get_feat(start)
    start['action'] = tf.zeros_like(policy(start['feat']).mode())
    seq = {k: [v] for k, v in start.items()}
    for _ in range(horizon):
        action = policy(tf.stop_gradient(seq['feat'][-1])).sample()
In my opinion, to get the full gradient from the initial state through to the last step of the imagined sequence, shouldn't 'feat' flow through the computation graph without the stop gradient? I just wonder why the gradient is stopped here. Have you tried the code without the stop gradient, and if so, what was the result?
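To make the question concrete, here is a minimal sketch (not from the repo; the variables x and w are just stand-ins for 'feat' and the actor's parameters) of what tf.stop_gradient does at this call site: gradients still flow to the actor's own weights, but not back through its input.

```python
import tensorflow as tf

x = tf.Variable(3.0)  # stands in for 'feat' produced by the world model
w = tf.Variable(2.0)  # stands in for an actor parameter

with tf.GradientTape() as tape:
    # The actor sees its input as a constant, as in
    # policy(tf.stop_gradient(seq['feat'][-1]))
    y = w * tf.stop_gradient(x)

grad_w, grad_x = tape.gradient(y, [w, x])
print(grad_w.numpy())  # 3.0 -> gradient w.r.t. the actor's weight flows
print(grad_x)          # None -> no gradient flows back into 'feat'
```

So with the stop gradient in place, the actor loss cannot shape the world model's representations through this path.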
I'm struggling to figure out the reason for the stop gradient, so I'm asking here for help.
Thanks!