GFNOrg / torchgfn

GFlowNet library
https://torchgfn.readthedocs.io/en/latest/

Conserve LogZ value while using epsilon #148

Closed OMalenfantThuot closed 5 months ago

OMalenfantThuot commented 9 months ago

Hi! Thanks for your work on this library, it's much appreciated.

I'm working on a physics project with GFlowNets, and this repo was recommended to me by Bruno Rousseau from MILA. I've been able to port my project to it, but I got some unexpected results when using epsilon for epsilon-greedy exploration during training.

Using the implementation of to_probability_distribution from DiscretePolicyEstimator with a non-zero epsilon, I converge to a different value of LogZ at the end of training. This seems to be due to the epsilon directly modifying the PF values, which in turn modifies the probability of choosing some final states over others.
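
To illustrate what I think is happening, here is a minimal sketch in plain PyTorch (an illustration, not torchgfn's exact code; I'm assuming the epsilon mixing is roughly a convex combination of the forward policy with a uniform distribution over actions):

import torch
from torch.distributions import Categorical

# Sketch of the behaviour I'm seeing: epsilon mixes the forward-policy
# probabilities with a uniform distribution, and the log-probs of that mixed
# distribution are the ones that end up in the loss.
logits = torch.randn(8, 4)          # batch of 8 states, 4 actions each
pf_probs = torch.softmax(logits, dim=-1)

epsilon = 0.1
uniform = torch.full_like(pf_probs, 1.0 / pf_probs.shape[-1])
mixed_probs = (1 - epsilon) * pf_probs + epsilon * uniform

dist = Categorical(probs=mixed_probs)
actions = dist.sample()
log_probs = dist.log_prob(actions)  # epsilon-dependent, so LogZ absorbs the bias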

Here's an example. Physically, temperature and partition function are related and should have a 1-to-1 relation in the same system. But here, the higher the epsilon I use, the lower the found partition function is (sorry for the spaghetti in the figure; it's colored by temperature and there are three trainings per parameter combination).

[figure: found LogZ values for different epsilon values, colored by temperature, three trainings per parameter combination]

If this is not intended behavior, I've written another way of using epsilon that conserves LogZ, but it involves changing the Sampler class instead of the to_probability_distribution.

Here is my implementation:

import torch
from typing import Tuple

# Sampler, Env, States, Actions and BatchedFloatTensor are the torchgfn types
# this subclass relies on (their exact import paths depend on the torchgfn
# version in use).


class LogZConservingSampler(Sampler):
    """
    This Sampler replaces the default Sampler in order to conserve
    the physical LogZ of the system while using a finite epsilon value
    to modify the probabilities of choosing random actions at each step.
    """
    def sample_actions(
        self, env: Env, states: States
    ) -> Tuple[Actions, BatchedFloatTensor]:

        # Same as original Sampler
        module_output = self.estimator(states)
        dist = self.estimator.to_probability_distribution(
            states, module_output, **self.probability_distribution_kwargs
        )
        device = module_output.device

        with torch.no_grad():
            actions = dist.sample()

            # New block
            if self.estimator.epsilon > 0:
                # Choose which trajectories take a random action
                random_actions = (
                    torch.rand(states.tensor.shape[0], device=device)
                    < self.estimator.epsilon
                )
                if random_actions.any():
                    # Filter for only the available actions
                    masks = states[random_actions].forward_masks
                    available_random_actions = torch.tile(
                        torch.arange(
                            self.estimator.expected_output_dim(), device=device
                        ),
                        (masks.shape[0], 1),
                    )[masks].reshape(masks.shape[0], -1)
                    # Choose random index of the available actions
                    random_choices = torch.randint(
                        low=0,
                        high=available_random_actions.shape[-1],
                        size=(torch.sum(random_actions).item(),),
                        device=device,
                    )
                    # Replace the initial choice of actions where relevant
                    actions[random_actions] = available_random_actions[
                        torch.arange(random_choices.shape[0]), random_choices
                    ].unsqueeze(1)

        # Back to original behavior
        log_probs = dist.log_prob(actions)

        if torch.any(torch.isinf(log_probs)):
            raise RuntimeError("Log probabilities are inf. This should not happen.")

        return env.Actions(actions), log_probs

The original project is here; the change is not yet in master, it's in the torchgfn_porting branch. Thank you! Let me know if this is of interest to you.

josephdviviano commented 9 months ago

Thank you very much for this issue - we will be looking into this shortly.

I haven't had time yet to investigate your specific issue, but I suspect you are onto something. We also made many changes to this logic in https://github.com/GFNOrg/torchgfn/pull/147, which should be merged shortly. I think it would be productive to evaluate your issue in the context of that update.

josephdviviano commented 6 months ago

Just to update on this -- #147 was unfortunately only merged very recently due to some members of the team taking a very long vacation, but we are going to be working now on clearing out these outstanding issues. Could you update me on whether this problem still exists for you?

I will be looking at this shortly (for real this time)

josephdviviano commented 6 months ago

Yes @OMalenfantThuot - I now see what the issue is. Note that we re-worked the API to accomplish this behaviour (as of v1.2) - so your solution seems reasonable in light of the code you were provided - but we tried hard to improve things since then!

_(As a refresher for those who might find this in the future - the intended behaviour of off policy exploration is shown in this notebook (see function train_with_exploration()).)_

What we want is to sample actions from the exploration_dist but evaluate their logprobs under the policy_dist. If you do this, your logZ should be correct at the end of training.
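
As a minimal, environment-agnostic sketch of that principle (plain PyTorch with a toy Categorical, not the torchgfn API):

import torch
from torch.distributions import Categorical

# Sample actions from an epsilon-mixed exploration distribution, but score
# them under the untampered policy distribution.
logits = torch.randn(8, 4)                      # batch of 8 states, 4 actions
policy_dist = Categorical(logits=logits)

epsilon = 0.1
uniform = torch.full_like(policy_dist.probs, 1.0 / logits.shape[-1])
exploration_probs = (1 - epsilon) * policy_dist.probs + epsilon * uniform
exploration_dist = Categorical(probs=exploration_probs)

actions = exploration_dist.sample()             # exploration chooses the actions
log_probs = policy_dist.log_prob(actions)       # the loss sees policy log-probs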

In the Sampler (which has changed since you filed this issue some months ago), you can pass policy_kwargs to do exactly this - the kwargs define the off policy sampling behaviour (see https://github.com/GFNOrg/torchgfn/blob/3276492f6d5d31f2be9a6d21c4a2cf21bab0026d/src/gfn/samplers.py#L46).

In a DiscretePolicyEstimator, the epsilon kwarg passed via policy_kwargs will handle off policy sampling as shown here: https://github.com/GFNOrg/torchgfn/blob/3276492f6d5d31f2be9a6d21c4a2cf21bab0026d/src/gfn/modules.py#L188

When you then calculate the loss using the gflownet, you will hit this flag (for example, during the calculation of p_f and p_b for trajectory balance):

https://github.com/GFNOrg/torchgfn/blob/3276492f6d5d31f2be9a6d21c4a2cf21bab0026d/src/gfn/gflownet/base.py#L152.

This will calculate the policy distribution, without any tampering, to accumulate p_f and p_b.
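
For intuition, here is a toy sketch of the trajectory balance residual (toy numbers, not torchgfn's implementation): logZ is the value that balances the summed on-policy log P_F and log P_B against the log-reward, so any epsilon tampering in those log-probs would shift the logZ that minimizes the loss.

import torch

# Schematic trajectory-balance residual over a batch of two trajectories.
log_Z = torch.tensor(0.0, requires_grad=True)
log_pf_trajectory = torch.tensor([-2.3, -1.7])  # sum of log P_F along each trajectory
log_pb_trajectory = torch.tensor([-1.1, -0.9])  # sum of log P_B along each trajectory
log_reward = torch.tensor([-0.5, -0.2])

residual = log_Z + log_pf_trajectory - log_reward - log_pb_trajectory
tb_loss = (residual ** 2).mean()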

In other words, to accomplish this behaviour you must do the following. Note the gflownet must be initialized using the off_policy=True flag.

sampler = gfn.samplers.Sampler(pf_module)
gflownet = gfn.gflownet.TBGFlowNet(pf=pf_module, pb=pb_module, off_policy=True)
trajectories = sampler.sample_trajectories(
    env,
    n_trajectories=batch_size,
    off_policy=True,
    epsilon=0.05,
)

loss = gflownet.loss(env, trajectories)

Your solution works as well but it is harder to make general, particularly in the case that we don't know ahead of time what kind of exploration we need to support (this is particularly true for continuous environments).
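
For example, in a continuous environment there is no natural notion of "a random action"; exploration there usually means something like widening the policy distribution, and the sampled actions must still be scored under the untampered policy (a rough sketch in plain PyTorch, not the torchgfn API):

import torch
from torch.distributions import Normal

# Exploration as a widened copy of a Gaussian policy; log-probs still come
# from the original policy distribution.
policy_dist = Normal(loc=torch.zeros(8), scale=torch.ones(8))
exploration_dist = Normal(loc=policy_dist.loc, scale=policy_dist.scale * 2.0)

actions = exploration_dist.sample()
log_probs = policy_dist.log_prob(actions)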

I'd love to work with you to resolve this.

Let me know how I can be of assistance!

OMalenfantThuot commented 6 months ago

Hi Joseph! I'm busy with other things for a few weeks, but from a quick look, the separate exploration and policy distributions should make it possible to correct the problem without my version of the Sampler. @sblackburn-mila @rousseab This could be of interest to you. Thanks for the work!

josephdviviano commented 6 months ago

It was a pleasure! Sorry it took us so long to get back to you. Please let me know if you have any further issues -- I'll leave the issue open for now.

josephdviviano commented 5 months ago

I'm going to close this assuming it is resolved on your end -- please feel free to reopen if you have any trouble.

Note we made some minor changes to the handling of off_policy sampling (at the API level only).