GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Function to revert backward trajectories #109

Open saleml opened 1 year ago

saleml commented 1 year ago

In previous versions of the code, when actions were integers, we had a function that reverses backward trajectories. It isn't used anywhere in the codebase, but I remember using it for another project (probably GFN vs HVI). I just removed it (in an upcoming PR), but it would be nice to fix it and bring it back:

    @staticmethod
    def revert_backward_trajectories(trajectories: Trajectories) -> Trajectories:
        """Reverses a trajectory, but not compatible with continuous GFN. Remove."""
        # TODO: this isn't used anywhere - it doesn't work as it assumes that the
        # actions are ints. Do we need it?
        assert trajectories.is_backward
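        # One extra row holds the exit action appended to each reversed trajectory.
        # new_full keeps the dtype and device of the original actions tensor.
        new_actions = trajectories.actions.new_full(
            (trajectories.actions.shape[0] + 1, len(trajectories)), -1
        )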

        # env.sf should never be None unless something went wrong during class
        # instantiation.
        if trajectories.env.sf is None:
            raise AttributeError(
                "Something went wrong during the instantiation of environment {}".format(
                    trajectories.env
                )
            )

        new_states = trajectories.env.sf.repeat(
            trajectories.when_is_done.max() + 1, len(trajectories), 1
        )
        new_when_is_done = trajectories.when_is_done + 1

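        for i in range(len(trajectories)):
            # The reversed trajectory ends with the exit action (index n_actions - 1).
            new_actions[trajectories.when_is_done[i], i] = (
                trajectories.env.n_actions - 1
            )

            # Flip the valid (non-padding) actions of trajectory i in time.
            new_actions[: trajectories.when_is_done[i], i] = trajectories.actions[
                : trajectories.when_is_done[i], i
            ].flip(0)

            # Flip the visited states likewise; remaining rows keep the sink state sf.
            new_states[
                : trajectories.when_is_done[i] + 1, i
            ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
                0
            )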

        new_states = trajectories.env.States(new_states)

        return Trajectories(
            env=trajectories.env,
            states=new_states,
            actions=new_actions,
            log_probs=trajectories.log_probs,
            when_is_done=new_when_is_done,
            is_backward=False,
        )
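
For discussion, the core of the function above (flip the valid prefix of each padded trajectory, then append the exit action) can be reduced to plain tensors. The sketch below is only an illustration under stated assumptions: `reverse_padded_actions`, `lengths`, `exit_action`, and `pad` are hypothetical names, not part of the torchgfn API, and it handles integer actions only, which is the same limitation noted in the TODO.

```python
import torch


def reverse_padded_actions(
    actions: torch.Tensor, lengths: torch.Tensor, exit_action: int, pad: int = -1
) -> torch.Tensor:
    """Reverse each column of a padded (max_len, n_traj) integer action tensor.

    Hypothetical helper, not torchgfn API. Column i holds lengths[i] valid
    actions followed by padding. The result has one extra row per column:
    the reversed actions, then exit_action, then padding.
    """
    max_len, n_traj = actions.shape
    # new_full keeps the dtype and device of the input tensor.
    reversed_actions = actions.new_full((max_len + 1, n_traj), pad)
    for i in range(n_traj):
        n = int(lengths[i])
        # Flip only the valid prefix so padding stays at the end.
        reversed_actions[:n, i] = actions[:n, i].flip(0)
        # Terminate the now-forward trajectory with the exit action.
        reversed_actions[n, i] = exit_action
    return reversed_actions


# Two backward trajectories of lengths 3 and 1, padded with -1.
backward = torch.tensor([[2, 0], [1, -1], [0, -1]])
lengths = torch.tensor([3, 1])
forward = reverse_padded_actions(backward, lengths, exit_action=3)
print(forward.tolist())  # [[0, 0], [1, 3], [2, -1], [3, -1]]
```

Supporting continuous GFNs would presumably mean applying the same flip along the time dimension to action tensors with a trailing feature dimension, with the exit action and padding taken from the environment rather than hard-coded integers.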