GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

No more class factories #149

Closed: josephdviviano closed this 6 months ago

josephdviviano commented 9 months ago

To be merged after #147

As a result of this, multiple methods are offloaded from the States class into the Env, and make_random_states_tensor must be passed from the Env to the States class, which accounts for a large number of these diffs.
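
Roughly, the pattern looks like the sketch below. This is illustrative only: apart from make_random_states_tensor, the class and method names are placeholders rather than the final torchgfn API. The point is that the Env owns the environment-specific logic and hands the pieces the States class still needs to it directly, instead of a class factory stamping out a bespoke States subclass per environment.

# Illustrative sketch only -- simplified signatures, not the torchgfn API.
class States:
    def __init__(self, tensor, make_random_states_tensor=None):
        self.tensor = tensor
        # Supplied by the Env rather than baked in by a class factory.
        self.make_random_states_tensor = make_random_states_tensor


class Env:
    def __init__(self, s0):
        self.s0 = s0

    def make_random_states_tensor(self, batch_shape):
        # Environment-specific random-state sampling now lives on the Env.
        raise NotImplementedError

    def states_from_tensor(self, tensor):
        # The Env passes its sampler down to the States objects it constructs.
        return States(tensor, make_random_states_tensor=self.make_random_states_tensor)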

As an example, see the Env definition below for the Line environment, which is complete:

# Imports for a self-contained definition (module paths assume the current
# torchgfn layout; they may differ slightly across versions).
from typing import Literal

import torch
from torch.distributions import Normal
from torchtyping import TensorType as TT

from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States


class Line(Env):
    """Mixture of Gaussians Line environment."""
    def __init__(
        self,
        mus: list,
        sigmas: list,
        init_value: float,
        n_steps_per_trajectory: int = 5,
        device_str: Literal["cpu", "cuda"] = "cpu",
    ):
        assert len(mus) == len(sigmas)
        self.mus = torch.tensor(mus)
        self.sigmas = torch.tensor(sigmas)
        self.n_steps_per_trajectory = n_steps_per_trajectory
        self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)]

        s0 = torch.tensor([init_value, 0.0], device=torch.device(device_str))
        dummy_action = torch.tensor([float("inf")], device=torch.device(device_str))
        exit_action = torch.tensor([-float("inf")], device=torch.device(device_str))
        super().__init__(
            s0=s0,
            state_shape=(2,),  # [x_pos, step_counter].
            action_shape=(1,),  # [x_pos]
            dummy_action=dummy_action,
            exit_action=exit_action,
        )  # sf is -inf by default.

    def step(
        self, states: States, actions: Actions
    ) -> TT["batch_shape", 2, torch.float]:
        states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze(-1)  # x position.
        states.tensor[..., 1] = states.tensor[..., 1] + 1  # Step counter.
        return states.tensor

    def backward_step(
        self, states: States, actions: Actions
    ) -> TT["batch_shape", 2, torch.float]:
        states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze(-1)  # x position.
        states.tensor[..., 1] = states.tensor[..., 1] - 1  # Step counter.
        return states.tensor

    def is_action_valid(self, states: States, actions: Actions, backward: bool = False) -> bool:
        # Can't take a backward step at the beginning of a trajectory.
        if torch.any(states[~actions.is_exit].is_initial_state) and backward:
            return False

        return True

    def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
        s = final_states.tensor[..., 0]
        log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape)
        for i, m in enumerate(self.mixture):
            log_rewards[i] = m.log_prob(s)

        return torch.logsumexp(log_rewards, 0)

    @property
    def log_partition(self) -> float:
        """Log partition: the log of the number of Gaussians, since each
        component of the mixture integrates to 1."""
        return torch.tensor(float(len(self.mus))).log().item()
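
As a side note on log_partition: since each Normal component integrates to 1, the unnormalized reward R(x) = sum_i N(x; mu_i, sigma_i) integrates to the number of Gaussians, so log Z = log(len(mus)). A quick standalone check of that identity in plain PyTorch (the mixture parameters here are arbitrary, and this does not go through the torchgfn API):

import torch
from torch.distributions import Normal

# Mixture matching the Line environment's reward definition.
mus, sigmas = [-1.0, 1.0], [0.5, 0.5]
mixture = [Normal(torch.tensor(m), torch.tensor(s)) for m, s in zip(mus, sigmas)]

# log R(x) = logsumexp_i log N(x; mu_i, sigma_i), as in Line.log_reward above.
x = torch.linspace(-10.0, 10.0, 100_001)
log_r = torch.logsumexp(torch.stack([m.log_prob(x) for m in mixture]), dim=0)

# Numerically integrate R(x): each component integrates to 1, so Z ~= len(mus).
z = torch.trapezoid(log_r.exp(), x)
print(z.item())                              # ~2.0
print(torch.tensor(float(len(mus))).log())   # tensor(0.6931), the expected log_partition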

marpaia commented 9 months ago

FYI @josephdviviano I changed the "base" branch to rethinking_sampling instead of master. This allows us to view this PR's changes in isolation. When you merge #147, this PR will automatically update to be based off of master again! Alternatively, you can merge this PR into #147 and then merge #147 into master, which has the same effect. I would suggest merging #147 first, though, and then iterating on / merging this PR in isolation 🙌

josephdviviano commented 9 months ago

I implemented the renaming and also realized I needed to update the documentation, which is now fixed.