FrankTianTT opened this issue 10 months ago
Thanks for raising this, let me think about it, I guess we can do that in a vectorized way!
> Thanks for raising this, let me think about it, I guess we can do that in a vectorized way!
@vmoens Yes, I also think the efficiency of these loops may be low. However, I found another issue, which requires us to change the rollout of the env. Maybe the manipulation of the reward can be done in this new rollout function. The problem is that the default rollout function stops as soon as any one env throws done:
optim_tensordict = self.env.rollout(
    max_steps=self.planning_horizon,
    policy=policy,
    auto_reset=False,
    tensordict=optim_tensordict,
)
but we need a rollout function that only stops when all envs are done. A new rollout function could look like this (as a method of MPCPlannerBase, rather than EnvBase):
def reward_truncated_rollout(self, policy, tensordict):
    # requires: from torchrl.envs.utils import step_mdp
    tensordicts = []
    # tracks which sub-envs have already encountered a done
    ever_done = torch.zeros(*tensordict.batch_size, 1, dtype=torch.bool, device=self.device)
    for _ in range(self.planning_horizon):
        tensordict = policy(tensordict)
        tensordict = self.env.step(tensordict)
        # zero the reward of sub-envs that were already done
        tensordict.get(("next", "reward"))[ever_done] = 0
        tensordicts.append(tensordict)
        ever_done |= tensordict.get(("next", "done"))
        if ever_done.all():
            break
        # carry the "next" entries over to the root, as EnvBase.rollout does
        tensordict = step_mdp(tensordict, keep_other=True)
    batch_size = self.batch_size if tensordict is None else tensordict.batch_size
    out_td = torch.stack(tensordicts, len(batch_size)).contiguous()
    out_td.refine_names(..., "time")
    return out_td
The reward truncation is done in this new rollout in a vectorized way (based on ever_done). With this new rollout, the only change in planning is replacing optim_tensordict = self.env.rollout(...) with optim_tensordict = self.reward_truncated_rollout(...), which means no further manipulation of optim_tensordict.get(self.reward_key) is needed.
> The problem is that the default rollout function stops as soon as any one env throws done.
There is a break_when_any_done argument that can be used to manipulate this. Would that solve it?
> There is a break_when_any_done argument that can be used to manipulate this. Would that solve it?
@vmoens Unfortunately not: when break_when_any_done=False, all envs are reset when any one env throws done. However, we do not want to reset any of them (in fact, if rollout only reset the env that throws done, that would be acceptable, but it resets all of them).
> @vmoens Unfortunately not: when break_when_any_done=False, all envs are reset when any one env throws done.
In theory, rollout will only reset the envs that are done. You can check in the doc how this is done: we assign a "_reset" key corresponding to each done that indicates what should be reset. Which env are you referring to, exactly?
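Roughly, the pattern is something like this (a simplified sketch; td and env are placeholder names):

# partially resetting a batched env: the boolean "_reset" entry has the same
# shape as "done" and flags which sub-envs should be re-initialized
td.set("_reset", td.get(("next", "done")).clone())
td = env.reset(td)  # only the flagged sub-envs are reset, the others keep their state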
> However, we do not want to reset any of them (in fact, if rollout only reset the env that throws done, that would be acceptable, but it resets all of them).
Is it a problem if the envs are reset (assuming this is done properly)? Do you mean that env.step() should just stop being called once the sub-env is done?
We do not cover this as of now, but it could be a feature for a distant release.
It won't be easy to come by though (at least efficiently): "_reset" brings a lot of overhead that we hope to mitigate by calling "_reset" only rarely.
> In theory, rollout will only reset the envs that are done. You can check in the doc how this is done: we assign a "_reset" key corresponding to each done that indicates what should be reset.
Yep, you are right, my fault.
> Which env are you referring to, exactly?
ModelBasedEnvBase; more specifically, the example code in https://github.com/pytorch/rl/pull/1657, where we write a new _reset that does not support resetting individual envs:
def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
    # note: the incoming tensordict (and any "_reset" mask it carries) is ignored,
    # so every sub-env is re-initialized
    tensordict = TensorDict(
        {},
        batch_size=self.batch_size,
        device=self.device,
    )
    tensordict = tensordict.update(self.state_spec.rand())
    tensordict = tensordict.update(self.observation_spec.rand())
    return tensordict
It seems feasible to make the _reset function sensitive to the "_reset" key.
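For illustration, a rough sketch of what a mask-aware _reset could look like (I'm not sure about the exact keys and shapes ModelBasedEnvBase expects, so treat this as an assumption rather than the real implementation):

def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
    # draw fresh states/observations for every sub-env
    reset_td = TensorDict({}, batch_size=self.batch_size, device=self.device)
    reset_td = reset_td.update(self.state_spec.rand())
    reset_td = reset_td.update(self.observation_spec.rand())
    if tensordict is not None and "_reset" in tensordict.keys():
        # only overwrite the sub-envs flagged by the "_reset" mask,
        # keep the current values for the others
        mask = tensordict.get("_reset").squeeze(-1)
        for key, value in reset_td.items(include_nested=True, leaves_only=True):
            old = tensordict.get(key, None)
            if old is not None:
                value[~mask] = old[~mask]
    return reset_td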
> Is it a problem if the envs are reset (assuming this is done properly)? Do you mean that env.step() should just stop being called once the sub-env is done?
My expectation for the rollout function is that, when multiple environments are running at the same time, even if one environment throws done, it keeps running until 1) all environments have thrown done, or 2) max_steps is reached.
Note that even if we implement the "_reset" handling in the _reset function, the original rollout still differs from what I expect: it only stops when max_steps is reached and never stops early (because it has reset the done environments). This reduces the efficiency of CEM, because each rollout must run for all max_steps steps, while early termination is very common in some environments, such as Hopper-v4.
In short, all of the above problems come from trying to let MPCPlannerBase handle done. If the environment never throws done, the algorithm works fine as it is. To handle it, there are at least two issues:
- reset individual envs, rather than all of them, when one env throws done
- stop the rollout only when all envs have thrown done, and ignore the rewards of envs that are already done
In my opinion, a special rollout function of MPCPlannerBase seems reasonable, because these special behaviours may not be necessary for EnvBase.
BTW, to make planning more efficient, we should wrap the rollout in torch.no_grad(), because gradients are not needed in CEM.
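The call site in planning would then look roughly like this (a sketch of the proposal above, not the current torchrl implementation):

# hypothetical call site inside MPCPlannerBase.planning(), per the proposal above
with torch.no_grad():  # CEM only needs the rewards, not gradients
    optim_tensordict = self.reward_truncated_rollout(
        policy=policy,
        tensordict=optim_tensordict,
    )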
- We need to fix ModelBasedEnvBase to match the partial reset, this should have been done a long time ago.
- To make the rollout stop early we can design a special transform that keeps track of the dones, wdyt?

base_env = MyBatchedEnv()
env = TransformedEnv(base_env, AggregateDone(rule="all"))  # rule can be "all", in which case `done` is `True` when all envs have been done at least once, or "any", in which case just one done suffices
env.rollout(1_000)  # will stop at 1000 steps or when all envs have encountered a done, whichever comes first
> To make the rollout stop early we can design a special transform that keeps track of the dones, wdyt?
You are right, and we can even do the reward truncation in this special transform, and make just a tiny change to the planning function (add torch.no_grad()). I'm glad to submit a PR for that.
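For what it's worth, a rough sketch of what such a transform could look like. The Transform hook names and signatures vary across torchrl versions; I assume the _step(tensordict, next_tensordict) and _reset(tensordict, tensordict_reset) hooks of recent torchrl here, and AggregateDone is just the placeholder name from your example:

import torch
from torchrl.envs.transforms import Transform

class AggregateDone(Transform):
    """Sketch: remember which sub-envs have been done at least once, zero their
    subsequent rewards, and expose an aggregated done so rollout stops only when
    the chosen rule ("all" or "any") is satisfied."""

    def __init__(self, rule: str = "all"):
        super().__init__(in_keys=[], out_keys=[])
        assert rule in ("all", "any")
        self.rule = rule
        self._ever_done = None

    def _step(self, tensordict, next_tensordict):
        done = next_tensordict.get("done")
        if self._ever_done is None:
            self._ever_done = torch.zeros_like(done)
        # zero the reward of sub-envs that were already done before this step
        reward = next_tensordict.get("reward")
        next_tensordict.set("reward", reward * (~self._ever_done).to(reward.dtype))
        self._ever_done = self._ever_done | done
        # overwrite done with the aggregated flag so the rollout stopping rule sees it
        agg = self._ever_done.all() if self.rule == "all" else self._ever_done.any()
        next_tensordict.set("done", agg.expand_as(done).clone())
        return next_tensordict

    def _reset(self, tensordict, tensordict_reset):
        self._ever_done = None
        return tensordict_reset

With rule="all", the done seen by rollout would only flip to True once every sub-env has terminated, so rollout stops either at max_steps or when all sub-envs are done, and the zeroed rewards would give the same truncation as reward_truncated_rollout above.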
BTW, there is a tiny difference between the CEM in torchrl and the one in mbrl-lib: the latter updates the mean and std of the actions with momentum, rather than by direct assignment. Change from
for _ in range(self.optim_steps):
    ...
    container.set_(("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True))
    container.set_(("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True))
    ...
to
for _ in range(self.optim_steps):
    ...
    self.update_stats(
        best_actions.mean(dim=K_DIM, keepdim=True),
        best_actions.std(dim=K_DIM, keepdim=True),
        container,
    )
    ...

def update_stats(self, means, stds, container):
    self.alpha = 0.1  # should be set in __init__
    new_means = self.alpha * container.get(("stats", "_action_means")) + (1 - self.alpha) * means
    new_stds = self.alpha * container.get(("stats", "_action_stds")) + (1 - self.alpha) * stds
    container.set_(("stats", "_action_means"), new_means)
    container.set_(("stats", "_action_stds"), new_stds)
To restore the original behaviour, just set self.alpha=0. Do you think I should add these in the same PR, or create a new one?
A new PR would defo make a lot of sense for this!
Describe the bug

In the current implementation, all subclasses of MPCPlannerBase do not consider the done thrown by the env during the planning process, which means that MPC is invalid for a large class of environments, for example in CEM. Specifically, one type of environment indicates that the agent has entered a dangerous state by throwing done (usually the reward is positive in non-dangerous states); this includes many gym-mujoco environments, such as InvertedPendulum and Hopper. The MPC algorithm needs to identify the done thrown by the environment and find the action sequence that maximizes the cumulative reward before done.

To Reproduce
Just try CEM on InvertedPendulum.
Reason and Possible fixes
For CEM, a simple fix could be:
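Something along these lines (a rough sketch; the reward key and tensor shapes are assumptions and may not match the actual planner code exactly):

# zero every reward collected at or after a sub-env's first done, so the CEM
# return only accumulates rewards gathered before termination
rewards = optim_tensordict.get(self.reward_key)   # assumed shape: (*batch, horizon, 1)
dones = optim_tensordict.get(("next", "done"))    # assumed shape: (*batch, horizon, 1)
ever_done = dones.long().cumsum(dim=-2) > 0
# shift by one step so the reward of the terminating step itself is kept
ever_done = torch.cat([torch.zeros_like(ever_done[..., :1, :]), ever_done[..., :-1, :]], dim=-2)
optim_tensordict.set(self.reward_key, rewards * (~ever_done).to(rewards.dtype))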
I'm more than happy to submit my changes, but they may require further style uniformity and standardization. At the same time, it is likely that there is a more efficient way.