pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] No real DDQN when using `delay_value` #1722

Closed Arlaz closed 11 months ago

Arlaz commented 12 months ago

Describe the bug

For an academic project, I wanted to compare a few variants of DQN, in particular a simple DQN with a target network and a real Double DQN (DDQN).

Looking into the TorchRL documentation, I found the delay_value argument, which is described as creating a target network in order to build a double DQN. This can mislead a user who is choosing between a simple DQN with a target network and a real DDQN.
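For context, here is how delay_value is typically used when building the loss. This is only a minimal sketch assuming an existing QValueActor called actor; the variable names are illustrative:

    from torchrl.objectives import DQNLoss, SoftUpdate

    # delay_value=True makes DQNLoss keep a detached copy of the value network
    # parameters (the "target network"); an updater then syncs them periodically.
    loss_fn = DQNLoss(value_network=actor, delay_value=True)
    target_updater = SoftUpdate(loss_fn, eps=0.995)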

I may not have fully understood all the intricacies of TorchRL, but after digging into the TorchRL code a bit, I think that using delay_value does not really create a Double DQN as described in the reference article (van Hasselt et al., 2016).
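To make the distinction concrete, here is a small plain-PyTorch sketch (not TorchRL code; q_online and q_target are illustrative networks mapping observations to per-action Q-values) of the two targets:

    import torch

    def td_targets(q_online, q_target, reward, done, next_obs, gamma=0.99):
        with torch.no_grad():
            # DQN with a target network: the target net both selects and evaluates
            #   y = r + gamma * max_a Q_target(s', a)
            dqn_y = reward + gamma * (1 - done.float()) * q_target(next_obs).max(-1).values

            # Double DQN: the online net selects the action, the target net evaluates it
            #   y = r + gamma * Q_target(s', argmax_a Q_online(s', a))
            a_star = q_online(next_obs).argmax(-1, keepdim=True)
            ddqn_y = reward + gamma * (1 - done.float()) * q_target(next_obs).gather(-1, a_star).squeeze(-1)
        return dqn_y, ddqn_y

As far as I can tell, with delay_value the target network plays both roles, i.e. the first formula rather than the second.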

Reason and Possible fixes

The issue is probably in the _next_value function in advantages.py. The current implementation uses the same network (target or value network) both to select the greedy next action and to evaluate its value, whereas Double DQN selects the action with the online network and evaluates it with the target network.

Instead of:

https://github.com/pytorch/rl/blob/2e7f574529fd4e6bd2f661b0d59bd22623e4fb49/torchrl/objectives/value/advantages.py#L427-L435

I would think of something like:

def _next_value(self, tensordict, target_params=None, action_select_params=None, kwargs=None):
    step_td = step_mdp(tensordict, keep_other=False)
    step_td_copy = step_td.clone(False)

    if self.value_network is not None:
        # Action selection: run the value network with the action-selection
        # parameters (the online network when action_select_params is None),
        # so step_td ends up holding the greedy next action.
        with hold_out_net(
            self.value_network
        ) if action_select_params is None else action_select_params.to_module(self.value_network):
            self.value_network(step_td)

        # Action evaluation: run the value network again with the target
        # parameters, so step_td_copy holds the target network's Q-values.
        with hold_out_net(
            self.value_network
        ) if target_params is None else target_params.to_module(self.value_network):
            self.value_network(step_td_copy)

    # Evaluate the selected action (one-hot) under the target Q-values.
    action = step_td.get(self.tensor_keys.action).to(torch.float)
    pred_val = step_td_copy.get(self.tensor_keys.action_value)

    next_value = (pred_val * action).sum(-1).unsqueeze(-1)

    return next_value

and in dqn.py:

target_value = self.value_estimator.value_estimate(
    td_copy,
    target_params=self.target_value_network_params,
    action_select_params=self.target_value_network_params if not self.double_q else None,
).squeeze(-1)
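With this change, when double_q is True the greedy action is selected with the current (online) value-network parameters and evaluated with the target parameters, which is the Double DQN target; when it is False, both roles fall back to the target parameters and we recover the usual DQN-with-target-network behaviour. As a purely hypothetical usage sketch (double_q is the flag proposed here, not an argument of the existing DQNLoss constructor):

    from torchrl.objectives import DQNLoss

    # hypothetical: double_q mirrors the proposed self.double_q attribute above
    loss_fn = DQNLoss(value_network=actor, delay_value=True, double_q=True)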

I made these changes (and a few others concerning the keys) and got results closer to what I can achieve with a DDQN in other libraries.

Am I wrong somewhere? Please tell me if I can help further, or even make a PR if needed. Thank you for this incredible work!


vmoens commented 12 months ago

Hi @Arlaz, thanks for your interest! Indeed, the doc is misleading; I'm very grateful that you reported this. We're refactoring DQN, so we should at least fix the doc.

Your fixes seem sensible to me, we'll integrate them (unless you want to make the PR).

cc @albertbou92 for context

vmoens commented 11 months ago

@Arlaz Would you be happy to review #1737? I trained a couple of models and it seems OK on my side.