GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

`Transitions` container should then have some `estimator_outputs` attribute to avoid duplicate computation. #156

Open josephdviviano opened 6 months ago

josephdviviano commented 6 months ago

In detailed_balance.py, we have:

        if not self.off_policy:
            valid_log_pf_actions = transitions.log_probs
        else:
            # Evaluate the log PF of the actions sampled off policy.
            # I suppose the Transitions container should then have some
            # estimator_outputs attribute as well, to avoid duplication here ?
            module_output = self.pf(states)  # TODO: Inefficient duplication.
            valid_log_pf_actions = self.pf.to_probability_distribution(
                states, module_output
            ).log_prob(
                actions.tensor
            )  # Actions sampled off policy.

We could avoid this second forward pass of .pf() by storing the estimator outputs in the Transitions class.

Ideally, both Trajectories and Transitions would be able to access the same estimator outputs in memory if there were ever a need to keep track of both.
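
To make the proposal concrete, here is a minimal sketch of the caching idea. It is not torchgfn code: `CountingEstimator`, `ToyTransitions`, and `off_policy_log_pf` are hypothetical names, and plain Python lists stand in for tensors so the snippet runs without PyTorch. The container holds an optional `estimator_outputs` field filled in at sampling time, and the loss only calls the estimator when that field is missing.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


class CountingEstimator:
    """Toy stand-in for the forward-policy estimator; counts forward passes."""

    def __init__(self) -> None:
        self.n_calls = 0

    def __call__(self, states: List[float]) -> List[List[float]]:
        self.n_calls += 1
        # Two logits per state; the formula itself is arbitrary.
        return [[s, -s] for s in states]


def log_prob(logits: List[float], action: int) -> float:
    """Log-softmax of `logits`, evaluated at index `action`."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[action] - log_z


@dataclass
class ToyTransitions:
    """Hypothetical container, loosely mirroring the proposal in this issue."""

    states: List[float]
    actions: List[int]
    # Assumed new field: raw estimator outputs saved when actions were sampled.
    estimator_outputs: Optional[List[List[float]]] = None


def off_policy_log_pf(pf: CountingEstimator, transitions: ToyTransitions) -> List[float]:
    """Reuse cached outputs if present; otherwise fall back to a forward pass."""
    outputs = transitions.estimator_outputs
    if outputs is None:
        outputs = pf(transitions.states)  # the duplicated computation
    return [log_prob(o, a) for o, a in zip(outputs, transitions.actions)]


pf = CountingEstimator()
states = [0.5, -1.0, 2.0]
actions = [0, 1, 0]

sampled_outputs = pf(states)  # forward pass performed while sampling
cached = ToyTransitions(states, actions, estimator_outputs=sampled_outputs)
uncached = ToyTransitions(states, actions)

cached_log_pf = off_policy_log_pf(pf, cached)
calls_after_cached = pf.n_calls  # no extra forward pass was needed

uncached_log_pf = off_policy_log_pf(pf, uncached)
calls_after_uncached = pf.n_calls  # one extra forward pass was needed

# The container holds a reference, not a copy, so a trajectories-style
# container could point at the same object in memory.
shares_memory = cached.estimator_outputs is sampled_outputs
```

Two caveats for a real implementation: the stored outputs would have to stay attached to the autograd graph for gradients to reach the estimator, and they would go stale if the estimator's parameters were updated between sampling and the loss computation.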