In previous versions of the code, when actions were integers, we had this function that reverses backward trajectories. It is not used anywhere in the codebase, but I remember using it for another project (probably GFN vs HVI). I just removed it (in an upcoming PR), but it would be nice to fix it and bring it back:
    @staticmethod
    def revert_backward_trajectories(trajectories: Trajectories) -> Trajectories:
        """Reverses a trajectory, but not compatible with continuous GFN. Remove."""
        # TODO: this isn't used anywhere - it doesn't work as it assumes that the
        # actions are ints. Do we need it?
        assert trajectories.is_backward
        new_actions = torch.full_like(trajectories.actions, -1)
        new_actions = torch.cat(
            [new_actions, torch.full((1, len(trajectories)), -1)], dim=0
        )
        # env.sf should never be None unless something went wrong during class
        # instantiation.
        if trajectories.env.sf is None:
            raise AttributeError(
                "Something went wrong during the instantiation of environment {}".format(
                    trajectories.env
                )
            )
        new_states = trajectories.env.sf.repeat(
            trajectories.when_is_done.max() + 1, len(trajectories), 1
        )
        new_when_is_done = trajectories.when_is_done + 1
        for i in range(len(trajectories)):
            new_actions[trajectories.when_is_done[i], i] = (
                trajectories.env.n_actions - 1
            )
            new_actions[: trajectories.when_is_done[i], i] = trajectories.actions[
                : trajectories.when_is_done[i], i
            ].flip(0)
            new_states[
                : trajectories.when_is_done[i] + 1, i
            ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
                0
            )
        new_states = trajectories.env.States(new_states)
        return Trajectories(
            env=trajectories.env,
            states=new_states,
            actions=new_actions,
            log_probs=trajectories.log_probs,
            when_is_done=new_when_is_done,
            is_backward=False,
        )
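As a possible starting point for the fix, here is a rough sketch of the same reversal written against plain tensors, so that it does not depend on actions being ints. Everything in it is an assumption rather than the library's API: the function name `reverse_backward_batch` and the `sink_state`, `exit_action`, and `dummy_action` arguments are hypothetical stand-ins for whatever the environment exposes, and actions are assumed to carry a trailing action shape (e.g. `(T, N, 1)` for discrete actions, `(T, N, A)` for continuous ones). Unlike the function above, it allocates one extra state row so that every reversed trajectory ends in the sink state; that layout is also an assumption.

```python
import torch


def reverse_backward_batch(
    states, actions, when_is_done, sink_state, exit_action, dummy_action
):
    """Reverse a padded batch of backward trajectories (hypothetical helper).

    states:       (T + 1, N, *state_shape), backward order, padded past s0
    actions:      (T, N, *action_shape), padded with dummy_action
    when_is_done: (N,) long, number of backward steps per trajectory
    sink_state:   (*state_shape,), exit_action / dummy_action: (*action_shape,)
    """
    n_traj = when_is_done.shape[0]
    max_len = int(when_is_done.max())

    # The forward version takes one more step (the exit action) and then
    # sits in the sink state, hence max_len + 1 actions and max_len + 2 states.
    new_states = sink_state.expand(max_len + 2, n_traj, *sink_state.shape).clone()
    new_actions = dummy_action.expand(
        max_len + 1, n_traj, *dummy_action.shape
    ).clone()

    for i in range(n_traj):
        t = int(when_is_done[i])
        # Flip the valid prefix so the trajectory runs from s0 to the terminal state.
        new_states[: t + 1, i] = states[: t + 1, i].flip(0)
        new_actions[:t, i] = actions[:t, i].flip(0)
        # Append the exit action right after the last real action.
        new_actions[t, i] = exit_action

    return new_states, new_actions, when_is_done + 1


# Two backward trajectories of lengths 2 and 1, with 1-dimensional states and actions.
states = torch.tensor([[[2], [1]], [[1], [0]], [[0], [-1]]])
actions = torch.tensor([[[11], [7]], [[10], [-1]]])
when_is_done = torch.tensor([2, 1])

fwd_states, fwd_actions, fwd_when_is_done = reverse_backward_batch(
    states,
    actions,
    when_is_done,
    sink_state=torch.tensor([-1]),
    exit_action=torch.tensor([99]),
    dummy_action=torch.tensor([-1]),
)
```

Because the exit and padding actions are passed in as tensors instead of being derived from `n_actions - 1` and `-1`, the same loop should work for any action shape. Wrapping the result back into `States` and `Trajectories` objects, and deciding what to do with `log_probs`, is left out here.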