GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Don't recompute masks #163

Closed: josephdviviano closed this 7 months ago

josephdviviano commented 7 months ago

This isn't working, @saleml. Please see the issue in samplers.py.

You can reproduce the error with tutorials/examples/train_hypergrid_simple.py.

saleml commented 7 months ago

Some debugging:

First, I changed batch_size to 3 in the script. Then, with a breakpoint at the assertion error, I saw that the forward masks at every step of the 3 trajectories in the batch are identical.

So here is the interesting part. If I add a breakpoint at lines 195 and 199 of samplers.py, I get the following:

[Screenshot: debugger output at lines 195 and 199 of samplers.py]

The first mask is OK: initially we have 3 copies of s0, so the masks should all be True. After we call _step, we still haven't explicitly modified the list trajectories_states_b, yet the forward_masks of its only element have changed, simply because we called _step.
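A minimal sketch of how this can happen, assuming a hypothetical ToyStates class and toy_step function (neither is the torchgfn API): the list stores a reference to the states object rather than a snapshot, so an in-place mask update made by the step shows up in the "stored" element too.

```python
import torch


class ToyStates:
    """Hypothetical stand-in for a States object; NOT the torchgfn class."""

    def __init__(self, forward_masks: torch.Tensor):
        self.forward_masks = forward_masks


def toy_step(states: ToyStates, action: int) -> ToyStates:
    # Hypothetical _step: disallows `action` for every batch element by
    # writing into the existing mask tensor (an in-place update).
    states.forward_masks[:, action] = False
    return states


batch_size, n_actions = 3, 4
states = ToyStates(torch.ones(batch_size, n_actions, dtype=torch.bool))  # s0: all True
trajectories_states_b = [states]  # holds a reference, not a snapshot

states = toy_step(states, action=0)

# The list was never touched, yet its only element now has column 0 set to False.
print(trajectories_states_b[0].forward_masks[:, 0].tolist())  # [False, False, False]
```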

So the problem happens here, in env.py:

[Screenshot: the relevant code in env.py]

And looking at the update_masks function implemented in Hypergrid:

[Screenshot: Hypergrid's update_masks implementation]

it seems to me that the masks are modified in place, and that the problem was introduced by https://github.com/GFNOrg/torchgfn/pull/149. I don't remember whether tests were passing in that PR (I didn't review it).

Let me know what you think.

josephdviviano commented 7 months ago

Good catch! These tests pass fine. I think the in-place update of masks is desirable behaviour, except in this case, where we want to accumulate a trajectory of states.

To reduce user error, the base States class could have a clone method, which returns deepcopy(self).
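A sketch of the proposed clone method on a hypothetical BaseStates class (the class name and its attributes are assumptions; the real torchgfn States class may be structured differently). Because clone returns deepcopy(self), the copy gets its own mask tensors, so later in-place updates on the original leave it untouched.

```python
from copy import deepcopy

import torch


class BaseStates:
    """Hypothetical base class; the real torchgfn States class may differ."""

    def __init__(self, tensor: torch.Tensor, forward_masks: torch.Tensor):
        self.tensor = tensor
        self.forward_masks = forward_masks

    def clone(self) -> "BaseStates":
        # Proposed helper: a fully independent copy, including the mask tensors.
        return deepcopy(self)


s = BaseStates(torch.zeros(3, 2), torch.ones(3, 3, dtype=torch.bool))
snapshot = s.clone()
s.forward_masks[:, 0] = False  # in-place update on the original only

print(snapshot.forward_masks.all().item())  # True: the clone is unaffected
```

Putting the deepcopy behind a named method means callers who want a snapshot don't have to remember which attributes are shared.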

When I was messing around, I tried copying the states. I should have used deepcopy instead, which prevents the stored forward masks from being updated in place.
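The difference between a plain copy and deepcopy, sketched with a hypothetical ToyStates class (not the torchgfn API) and Python's standard copy module: copy.copy creates a new object that still points at the same mask tensor, while copy.deepcopy also duplicates the tensor.

```python
import copy

import torch


class ToyStates:
    """Hypothetical stand-in for a States object; NOT the torchgfn class."""

    def __init__(self, forward_masks: torch.Tensor):
        self.forward_masks = forward_masks


states = ToyStates(torch.ones(2, 3, dtype=torch.bool))
shallow = copy.copy(states)   # new object, but the SAME mask tensor
deep = copy.deepcopy(states)  # new object AND a new mask tensor

states.forward_masks[:, 0] = False  # in-place update, as an update_masks would do

print(shallow.forward_masks[:, 0].tolist())  # [False, False]  (shared storage)
print(deep.forward_masks[:, 0].tolist())     # [True, True]    (independent)
```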

josephdviviano commented 7 months ago

OK, @saleml figured it out. Check line 413 here: https://github.com/GFNOrg/torchgfn/pull/163/commits/77e7e1b524a0a640f51b30a82a73a5ea8fee9e90

Since we are doing in-place operations, we must first reset all values to True (self.forward_masks[:] = True) before setting the False elements. Otherwise, False entries written on an earlier step persist, causing side effects over multiple steps.
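A simplified, hypothetical Hypergrid-style mask update illustrating the reset (HEIGHT, the function names, and the "last action is exit" layout are assumptions for illustration, not the library's code). Both versions write in place; only the one that resets first clears a stale False when the same mask buffer is later updated for a state that is no longer at the boundary.

```python
import torch

HEIGHT = 3  # hypothetical grid side: each coordinate lies in {0, 1, 2}


def update_masks_no_reset(forward_masks: torch.Tensor, states: torch.Tensor) -> None:
    # In place, but only ever writes False: entries set on an earlier call persist.
    forward_masks[..., :-1][states == HEIGHT - 1] = False


def update_masks_with_reset(forward_masks: torch.Tensor, states: torch.Tensor) -> None:
    # Reset every entry first, then disallow "increment" actions at the edge.
    forward_masks[:] = True
    forward_masks[..., :-1][states == HEIGHT - 1] = False
    # The last action (exit) stays allowed in both versions.


ndim = 2
masks_a = torch.ones(1, ndim + 1, dtype=torch.bool)
masks_b = torch.ones(1, ndim + 1, dtype=torch.bool)

# Update the same buffers for two different states in turn.
for states in (torch.tensor([[2, 0]]), torch.tensor([[0, 0]])):
    update_masks_no_reset(masks_a, states)
    update_masks_with_reset(masks_b, states)

print(masks_a.tolist())  # [[False, True, True]]  stale False left over from [2, 0]
print(masks_b.tolist())  # [[True, True, True]]
```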

This, plus using deepcopy where appropriate, fixes the issue, and we no longer recompute masks.
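Putting both pieces together in a toy sampling loop (ToyStates, toy_step, and HEIGHT are hypothetical; this sketches the idea, not the torchgfn sampler): each step deep-copies the previous states, then lets the states update their own masks in place with a reset first. The accumulated trajectory keeps a distinct, correct mask per step, so the sampler never has to recompute them.

```python
from copy import deepcopy

import torch

HEIGHT = 3  # hypothetical grid side: coordinates lie in {0, 1, 2}


class ToyStates:
    """Hypothetical Hypergrid-like states; NOT the torchgfn API."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor
        self.forward_masks = torch.ones(tensor.shape[0], tensor.shape[1] + 1, dtype=torch.bool)
        self.update_masks()

    def update_masks(self) -> None:
        # In-place update: reset to True first, then disallow increments at the edge.
        self.forward_masks[:] = True
        self.forward_masks[..., :-1][self.tensor == HEIGHT - 1] = False


def toy_step(states: ToyStates, dim: int) -> ToyStates:
    new_states = deepcopy(states)   # independent tensor and masks
    new_states.tensor[:, dim] += 1  # move one cell along `dim`
    new_states.update_masks()       # the step keeps the masks up to date in place
    return new_states


states = ToyStates(torch.zeros(1, 2, dtype=torch.long))  # s0 = [0, 0]
trajectory = [states]
for _ in range(2):
    states = toy_step(states, dim=0)
    trajectory.append(states)

for s in trajectory:
    print(s.tensor.tolist(), s.forward_masks.tolist())
# [[0, 0]] [[True, True, True]]
# [[1, 0]] [[True, True, True]]
# [[2, 0]] [[False, True, True]]
```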