GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

`maskless_step` and `maskless_backward_step` should throw an error if you don't return a tensor #145

Closed: josephdviviano closed this issue 6 months ago

josephdviviano commented 10 months ago

When writing environments, it is natural to modify the `states` object in place and return the modified object. However, the `Env` class expects `maskless_step` (and `maskless_backward_step`) to return a tensor, not a `States` instance.

For example:

    def maskless_step(self, states: States, actions: Actions) -> TT["batch_shape", 1, torch.float]:
        states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze(-1)  # x position.
        states.tensor[..., 1] = states.tensor[..., 1] + 1  # Step counter.
        return states  # Bug: returns the States object instead of states.tensor.

raises an error downstream, in `Env.step`:

    /usr/local/lib/python3.10/dist-packages/gfn/env.py in step(self, states, actions)
        158         #     new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)
        159 
    --> 160         new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor
        161 
        162         return new_states

    TypeError: can't assign a LineStates to a torch.FloatTensor

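The fix is to return `states.tensor` rather than `states`. Because the resulting error surfaces far from the environment code, `maskless_step` and `maskless_backward_step` should instead raise a clear error when they return anything other than a tensor.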
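A check like the one requested in the title could look roughly like the sketch below. It is a minimal illustration, not torchgfn's API and not necessarily how the linked PR resolved the issue: `require_tensor_return`, `FakeTensor`, `FakeStates`, `BuggyEnv` and `FixedEnv` are all hypothetical names. The two `Fake*` classes stand in for `torch.Tensor` and a `States` object so the sketch runs without PyTorch; in real code you would pass `torch.Tensor` as the expected type.

```python
import functools


def require_tensor_return(tensor_type):
    """Hypothetical decorator: raise a clear TypeError when a step method
    returns something other than `tensor_type` (e.g. a States object)."""

    def decorator(step_fn):
        @functools.wraps(step_fn)
        def wrapper(self, states, actions):
            out = step_fn(self, states, actions)
            if not isinstance(out, tensor_type):
                raise TypeError(
                    f"{step_fn.__name__} must return a {tensor_type.__name__}, "
                    f"got {type(out).__name__}; did you mean to return states.tensor?"
                )
            return out

        return wrapper

    return decorator


# Stand-ins (assumptions) so the sketch runs without PyTorch or torchgfn.
class FakeTensor(list):
    """Stand-in for torch.Tensor."""


class FakeStates:
    """Stand-in for a States object that wraps a tensor."""

    def __init__(self, tensor):
        self.tensor = tensor


class BuggyEnv:
    @require_tensor_return(FakeTensor)
    def maskless_step(self, states, actions):
        states.tensor[0] += actions  # Modify the wrapped tensor in place...
        return states  # ...but return the States object (the bug).


class FixedEnv:
    @require_tensor_return(FakeTensor)
    def maskless_step(self, states, actions):
        states.tensor[0] += actions
        return states.tensor  # Return the underlying tensor instead.


if __name__ == "__main__":
    print(FixedEnv().maskless_step(FakeStates(FakeTensor([0.0, 0.0])), 1.5))
    try:
        BuggyEnv().maskless_step(FakeStates(FakeTensor([0.0, 0.0])), 1.5)
    except TypeError as err:
        print(err)
```

Running it prints `[1.5, 0.0]` for the fixed environment, then an error naming `maskless_step` and suggesting `states.tensor` for the buggy one. The same `isinstance` check could equally live inside `Env.step`, right after it calls `maskless_step`; either way the failure points at the offending method instead of at the later tensor assignment.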

josephdviviano commented 6 months ago

Addressed in https://github.com/GFNOrg/torchgfn/pull/165.