GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Redundant calling of `update_masks` in default sampler #160

Closed: listar2000 closed this issue 6 months ago

listar2000 commented 6 months ago

Hi, I'm currently developing my own environment and following the training script from the examples (i.e., on-policy sampling with the forward policy), and I'm a bit confused by the following redundancy:

  1. During sequential sampling, at each step the new state calls the `update_masks` method in its initialization method, which sets up `forward_masks` and `backward_masks`.
  2. Once sampling is done, the default `samplers.py` batches the trajectories into a new `States` object:
    trajectories_states = env.States(tensor=trajectories_states)

    (P.S. I think this line has changed in the nightly version: https://github.com/GFNOrg/torchgfn/blob/3276492f6d5d31f2be9a6d21c4a2cf21bab0026d/src/gfn/samplers.py#L225)

The new version essentially does the same thing: the new `States` object is initialized and calls `update_masks` again on all the states in the trajectories, even though I believe those masks were already computed once in step 1. Why repeat this work instead of reusing the already computed masks?
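To make the redundancy concrete, here is a minimal pure-Python sketch. `ToyStates` and `sample_toy_trajectories` are hypothetical stand-ins, not torchgfn classes: states are integer positions on a line, masks are plain lists, and a class-level counter records how many individual states had their masks computed. For simplicity the trajectories are flattened into one batch, whereas the real sampler keeps a separate trajectory dimension.

```python
class ToyStates:
    """Hypothetical stand-in for a torchgfn States subclass (not the library API).

    Each state is an integer position on a line 0..max_pos. As described in the
    issue, masks are computed in __init__ by calling update_masks.
    """

    max_pos = 3
    mask_evaluations = 0  # Number of individual states whose masks were computed.

    def __init__(self, tensor):
        self.tensor = list(tensor)
        self.update_masks()

    def update_masks(self):
        # Forward actions: "step right" (only below max_pos) and "exit" (always).
        self.forward_masks = [[x < self.max_pos, True] for x in self.tensor]
        # Backward action "step left" is allowed everywhere except the origin.
        self.backward_masks = [[x > 0] for x in self.tensor]
        ToyStates.mask_evaluations += len(self.tensor)


def sample_toy_trajectories(n_trajectories, n_steps):
    """Move every trajectory right n_steps times, then rebuild one batch from raw tensors."""
    states = ToyStates([0] * n_trajectories)  # Masks computed for the initial states.
    trajectories = [states]
    for _ in range(n_steps):
        states = ToyStates([x + 1 for x in states.tensor])  # Masks computed per step.
        trajectories.append(states)
    # Like env.States(tensor=...), this discards the per-step masks and recomputes them.
    flat = [x for step_states in trajectories for x in step_states.tensor]
    return ToyStates(flat)


ToyStates.mask_evaluations = 0
batch = sample_toy_trajectories(n_trajectories=4, n_steps=3)
print(ToyStates.mask_evaluations)  # 32: 16 during sampling + 16 recomputed when batching
```

Half of the mask evaluations in this toy run are repeats of work already done during sampling, which is the overhead the question is about.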

Many thanks for any explanation :)

josephdviviano commented 6 months ago

Thanks for this great point. We need a `States.stack()` method or a `stack_states()` function that accepts a list of states and reuses their existing masks, avoiding this recomputation. Addressed in #161.

Pseudocode

```python
from gfn.containers.utils import stack_states  # Proposed helper (#161).

# Excerpt from the sampler's trajectory-sampling loop.
trajectories_states = [states]  # Keep every step's States object, masks included.
step = 0
while not all(dones):
    actions = env.actions_from_batch_shape((n_trajectories,))  # Dummy actions.
    valid_actions, actions_log_probs, estimator_outputs = self.sample_actions(
        env,
        states[~dones],
        save_estimator_outputs=save_estimator_outputs,
        calculate_logprobs=not skip_logprob_calculation,
        **policy_kwargs,
    )
    ...
    actions[~dones] = valid_actions
    ...
    if self.estimator.is_backward:
        new_states = env._backward_step(states, actions)
    else:
        new_states = env._step(states, actions)
    sink_states_mask = new_states.is_sink_state
    ...
    step += 1
    # A trajectory is done once it reaches s0 (backward) or the sink state (forward).
    new_dones = (
        new_states.is_initial_state if self.estimator.is_backward else sink_states_mask
    ) & ~dones
    trajectories_dones[new_dones] = step
    ...
    states = new_states
    dones = dones | new_dones

    trajectories_states.append(states)

# Stack along a new trajectory dimension, reusing the masks computed above
# instead of re-running update_masks on every state.
trajectories_states = stack_states(trajectories_states, dim=0)
```

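This `stack_states` function would extend all relevant attributes of the submitted states (the state tensor and its masks) along the trajectory dimension without calling `update_masks` again, and would return a single stacked `States` object from which the `Trajectories` object is then built.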
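A minimal sketch of such a helper, assuming a hypothetical `ToyStates` class (not the torchgfn API) whose constructor accepts precomputed masks and only calls `update_masks` when none are given. The `stack_states` below is likewise an illustrative toy, not the library function. Plain nested lists stand in for tensors; with real tensors the same idea would apply `torch.stack(..., dim=0)` to each attribute.

```python
class ToyStates:
    """Hypothetical stand-in for a torchgfn States batch (not the library API)."""

    max_pos = 3
    mask_evaluations = 0  # Number of individual states whose masks were computed.

    def __init__(self, tensor, forward_masks=None, backward_masks=None):
        self.tensor = tensor
        if forward_masks is None or backward_masks is None:
            self.update_masks()
        else:
            # Reuse masks that were already computed during sampling.
            self.forward_masks = forward_masks
            self.backward_masks = backward_masks

    def update_masks(self):
        # This toy only supports a flat (1-D) batch of integer positions.
        self.forward_masks = [[x < self.max_pos, True] for x in self.tensor]
        self.backward_masks = [[x > 0] for x in self.tensor]
        ToyStates.mask_evaluations += len(self.tensor)


def stack_states(states_list):
    """Hypothetical helper: stack per-step batches along a new leading dimension.

    The tensor and both masks gain a leading trajectory dimension of size
    len(states_list); update_masks is never called.
    """
    return ToyStates(
        tensor=[s.tensor for s in states_list],
        forward_masks=[s.forward_masks for s in states_list],
        backward_masks=[s.backward_masks for s in states_list],
    )


# Usage: 4 sampling steps, 2 trajectories per step.
steps = [ToyStates([t] * 2) for t in range(4)]  # Masks computed once per step.
evaluations_before = ToyStates.mask_evaluations
stacked = stack_states(steps)
print(len(stacked.tensor), len(stacked.tensor[0]))  # 4 2
print(ToyStates.mask_evaluations == evaluations_before)  # True: no recomputation
```

The design choice here is to let the constructor skip `update_masks` when masks are supplied, so stacking becomes pure data movement whose cost does not depend on how expensive the environment's mask logic is.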

josephdviviano commented 6 months ago

We're working on this in https://github.com/GFNOrg/torchgfn/pull/163

Edit: this issue is resolved!