Closed: josephdviviano closed this issue 7 months ago
Some debugging:
First, I changed the batch_size to 3 in the script. Then, with a breakpoint at the assertion error, I see that the forward masks of all the steps within the batch of 3 trajectories are the same.
So here is the interesting part. If I add a breakpoint at lines 195 and 199 of `samplers.py`, I get the following:
The first mask is OK: initially we have 3 copies of s0, so the masks should all be True. Once we call `_step`, we haven't explicitly modified the list `trajectories_states_b`. Yet the `forward_masks` of its only element changed, just because we called `_step`.
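The aliasing can be reproduced outside the library. Below is a minimal sketch with hypothetical `State` and `_step` stand-ins (plain Python lists instead of tensors, since the reference semantics are the same): the list stores a reference to the state object, so an in-place mutation inside `_step` is visible through the list.

```python
# Minimal sketch of the aliasing (hypothetical State/_step stand-ins;
# plain lists instead of tensors, the reference semantics are the same).
class State:
    def __init__(self):
        self.forward_masks = [True, True, True]

def _step(state):
    # Mutates the mask in place, on the same object the list holds.
    state.forward_masks[0] = False

s = State()
trajectories_states_b = [s]  # the list stores a reference, not a copy
_step(s)
# The element we "saved" in the list changed too:
assert trajectories_states_b[0].forward_masks == [False, True, True]
```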
So the problem happens here in `env.py`. Looking at the `update_masks` function implemented in `Hypergrid`, it seems to me that the masks are changed in place and that the problem is due to https://github.com/GFNOrg/torchgfn/pull/149. I don't remember if tests were passing in that PR (I haven't reviewed that PR).
Let me know what you think.
Good catch! These tests pass fine. I think the in-place update of masks is desirable behaviour, except in this case, where we want to accumulate a trajectory of states.
To reduce user error, the base `States` class could have a `clone` method, which returns `deepcopy(self)`.
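The proposed helper could be as small as this (a sketch only; the `forward_masks` field here is illustrative, the real class holds tensors of states and masks):

```python
from copy import deepcopy

class States:
    # Illustrative field; the real class holds tensors of states/masks.
    def __init__(self, forward_masks):
        self.forward_masks = forward_masks

    def clone(self):
        """Return an independent copy, safe against in-place mask updates."""
        return deepcopy(self)

a = States([True, True])
b = a.clone()
b.forward_masks[0] = False        # mutate the clone in place
assert a.forward_masks == [True, True]  # the original is untouched
```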
When I was messing around, I tried copying the states. I should instead have used `deepcopy`, which prevents the forward masks from being updated in place.
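The distinction matters because a shallow copy produces a new object that still shares the underlying mask container; only a deep copy breaks the alias. A sketch:

```python
import copy

class States:
    def __init__(self):
        self.forward_masks = [True, True]

s = States()
shallow = copy.copy(s)    # new object, but the SAME forward_masks list
deep = copy.deepcopy(s)   # new object with its own forward_masks list

s.forward_masks[0] = False
assert shallow.forward_masks == [False, True]  # mutated through the alias
assert deep.forward_masks == [True, True]      # unaffected
```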
OK @saleml, figured it out - check line 413 here: https://github.com/GFNOrg/torchgfn/pull/163/commits/77e7e1b524a0a640f51b30a82a73a5ea8fee9e90
Since we are doing in-place operations, before setting the `False` elements we must first set all values to `True` (`self.forward_masks[:] = True`), to prevent side effects over multiple steps.
This, plus using `deepcopy` where appropriate, fixes the issue, and we no longer recompute masks.
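The reset-then-mask pattern can be illustrated with a toy, list-based version (a hypothetical `toy_update_masks`, not the real `Hypergrid` code; here a state is a list of coordinates and the last action is an always-allowed exit):

```python
# Hypothetical toy version of the fix (not the real Hypergrid code).
# A state is a list of coordinates; actions are "increment dim i" plus
# a final exit action, which is always allowed here.
def toy_update_masks(forward_masks, state, height):
    # The fix: reset everything to True first. Because we mutate the
    # mask in place, False entries written on a previous step would
    # otherwise survive into this one.
    forward_masks[:] = [True] * len(forward_masks)
    for i, coord in enumerate(state):
        if coord == height - 1:          # can't increment past the edge
            forward_masks[i] = False

masks = [True, True, True]
toy_update_masks(masks, [0, 3], height=4)
assert masks == [True, False, True]
# Next step: the state moved away from the edge. Without the reset,
# masks[1] would incorrectly stay False.
toy_update_masks(masks, [1, 2], height=4)
assert masks == [True, True, True]
```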
This isn't working @saleml -- please see the issue in `samplers.py`. You can reproduce the error with `tutorials/examples/train_hypergrid_simple.py`.