pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

[BUG] GAE parameters (gamma, lmbda) seemed to get changed by ClipPPOLoss, advantage module does not calculate loss_critic #2462

Open therealjoker4u opened 1 month ago

therealjoker4u commented 1 month ago

Environment

- OS: Windows 11
- Python: CPython 3.10.14
- TorchRL version: 0.5.0
- PyTorch version: 2.4.1+cu124
- Gym Environment: a custom subclass of EnvBase (from torchrl.envs)

The project I'm working on is relatively complex, so I only include the parts of the code that I know are related to the bug described below. Here are the definitions of the actor, value (critic), advantage, and loss modules.

import torch
from tensordict.nn import TensorDictModule
from torchrl.data import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator, OneHotCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

self.action_spec = DiscreteTensorSpec(3, dtype=torch.int8)

# Actor
_actor = TensorDictModule(self.agent, in_keys=self.agent.in_keys,
                          out_keys=self.agent.out_keys).to(self.agent.device)

self.actor_module = ProbabilisticActor(
    _actor,
    in_keys=self.agent.out_keys,
    spec=self.action_spec,
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

# Critic
self.value_net = MyValueNetwork(device=agent.device)
self.value_module = ValueOperator(
    self.value_net, in_keys=self.value_net.in_keys, 
)

# Advantage
self.advantage_module = GAE(value_network=self.value_module,
      gamma=self.advantage_gamma,
      lmbda=self.advantage_lmbda,
      differentiable=True,
      average_gae=True,
      device=self.value_module.device,
)

# Loss
entropy_eps = 0.001
self.loss_module = ClipPPOLoss(
    actor_network=self.actor_module,
    critic_network=self.value_module,
    clip_epsilon=(0.2, ),
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

Training loop

My training loop receives batched data from a MultiSyncDataCollector, adds it to a replay buffer backed by a LazyTensorStorage storage, then samples from the buffer and passes the sample to the _optimize_sample function below:

def _optimize_sample(self, sample: TensorDict):
  self.actor_optimizer.zero_grad()
  self.critic_optimizer.zero_grad()

  # Lazily initialize the value network's weights on the first batch
  if not self.value_net.weights_initialized:
      self.value_net(sample["observation"])
      self.value_net.weights_initialized = True

  # Refresh sample_log_prob with the current actor, then detach it
  # (workaround for the "prev_log_prob requires grad" error described below)
  self.actor_module(sample)
  sample["sample_log_prob"] = sample["sample_log_prob"].detach()

  # Compute "advantage" and "value_target" in-place on the sample
  self.advantage_module(sample)

  loss_vals = self.loss_module(sample)
  total_loss = loss_vals["loss_entropy"] + \
      loss_vals["loss_objective"] + loss_vals["loss_critic"]

  total_loss.backward()

  torch.nn.utils.clip_grad_norm_(
      self.actor_module.parameters(), max_norm=0.1)
  torch.nn.utils.clip_grad_norm_(
      self.value_module.parameters(), max_norm=1.0)

  self.actor_optimizer.step()
  self.critic_optimizer.step()

  return loss_vals
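
For context, a minimal sketch of the outer loop described above (the buffer size, batch size, and reshape are my assumptions, not the actual code; my_collector is the MultiSyncDataCollector mentioned in this report):

from torchrl.data import LazyTensorStorage, ReplayBuffer

# Hypothetical outer loop: collect batches, store them in a replay buffer backed
# by LazyTensorStorage, then sample and optimize.
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=100_000))  # size is an assumption

for batch in my_collector:
    replay_buffer.extend(batch.reshape(-1))        # flatten batch/time dims before storing
    sample = replay_buffer.sample(batch_size=256)  # batch size is an assumption
    loss_vals = self._optimize_sample(sample)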

In the code above I got the error below; it is raised from loss_vals = self.loss_module(sample), because the preceding self.actor_module(sample) call rewrites sample_log_prob as a tensor that requires grad:

loss_vals = self.loss_module(sample)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\common.py", line 39, in new_forward
    return func(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 817, in forward
    log_weight, dist, kl_approx = self._log_weight(tensordict)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 473, in _log_weight
    raise RuntimeError("tensordict prev_log_prob requires grad.")
RuntimeError: tensordict prev_log_prob requires grad.

So I added sample["sample_log_prob"] = sample["sample_log_prob"].detach() to detach sample_log_prob from the computation graph, and the issue was resolved.
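
An equivalent variant of this workaround (just a sketch, assuming the external self.actor_module(sample) call is only there to refresh sample_log_prob before the loss is computed) is to run that call under torch.no_grad(), so the stored log-prob never enters the graph in the first place:

# Alternative to detaching afterwards: keep the refreshed sample_log_prob
# out of the autograd graph from the start.
with torch.no_grad():
    self.actor_module(sample)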

At this stage the model seems to converge, as both the objective and critic losses are decreasing:

Figure 1 - Objective/policy loss (exponential moving average, interval 100)

Figure 2 - Critic loss

The main issue

Up to this point everything appears to be fine, but the main issue occurs when I connect the actor (policy) module to the collector so that data is collected with the current policy rather than with randomly chosen actions:

train_kwargs["policy_device"] = self.agent.device
train_kwargs["policy"] = self.actor_module
my_collector = MultiSyncDataCollector(**train_kwargs)

And when I run it, I get the error below (thrown inside self.advantage_module(sample)):

self.advantage_module(sample)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 68, in new_func
    return fun(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 57, in new_fun
    return fun(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 1357, in forward
    adv, value_target = vec_generalized_advantage_estimate(
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 89, in transposed_fun
    out = fun(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 315, in vec_generalized_advantage_estimate
    return _fast_vec_gae(
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 250, in _fast_vec_gae
    advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\utils.py", line 78, in _custom_conv1d
    val_pad = torch.nn.functional.pad(tensor, [0, filter.shape[-2] - 1])
IndexError: tuple index out of range

I found that in torchrl\objectives\value\functional.py, inside vec_generalized_advantage_estimate (around line 307), the variable that should hold the matrix of accumulated gamma*lambda products (with a single column) is instead a 1-D vector of zeros with the length of the sample batch size, but only when the actor_module is connected to the collector. I also found that in this case the advantage module's buffers are reset: when the collector uses the actor module, gamma and lmbda are set to 0.0 (inside the training loop, print("Gamma: ", self.advantage_module.get_buffer("gamma")) outputs tensor(0.)).
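
A quick way to observe the symptom (a sketch, using only the modules defined above) is to read the GAE buffers back after the collector has been constructed with the policy attached:

# When the bug triggers, both buffers read back as zero instead of the values
# passed to GAE(...).
print("gamma:", self.advantage_module.get_buffer("gamma"))  # tensor(0.)
print("lmbda:", self.advantage_module.get_buffer("lmbda"))  # tensor(0.)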

So I added these two lines after the loss module definition:

self.advantage_module.register_buffer("gamma", torch.tensor(self.advantage_gamma))
self.advantage_module.register_buffer("lmbda", torch.tensor(self.advantage_lmbda))

Adding these two lines made the previous error vanish, but a new issue appeared:

loss_vals["loss_objective"] + loss_vals["loss_critic"]
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 335, in __getitem__
    result = self._get_tuple(idx_unravel, NO_DEFAULT)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2399, in _get_tuple
    first = self._get_str(key[0], default)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2395, in _get_str
    return self._default_get(first_key, default)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 4503, in _default_get
    raise KeyError(
KeyError: 'key "loss_critic" not found in TensorDict with keys [\'ESS\', \'clip_fraction\', \'entropy\', \'kl_approx\', \'loss_entropy\', \'loss_objective\']'

This clearly implies that the key "loss_critic" is not present in the tensordict returned by the loss module (whereas before connecting the actor module to the collector it was computed properly).
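
A small guard like the following (a sketch, not part of the original loop; the key names are taken from the KeyError above) makes this failure easier to spot before the losses are summed:

loss_vals = self.loss_module(sample)
# Fail early with a clear message if any expected loss term is missing.
expected = ("loss_objective", "loss_critic", "loss_entropy")
missing = [k for k in expected if k not in loss_vals.keys()]
if missing:
    raise RuntimeError(f"loss module did not return: {missing}")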

therealjoker4u commented 1 month ago

I found the root of the issue: when I connect the policy to the collector (train_kwargs["policy"] = self.policy_forward), the collector spawns a separate process for each iteration (step) and runs the policy there. Since my replay buffer and the policy end up in two separate processes, the advantage module's gamma parameter is zero in my process, which causes the error above.
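
One way to test this hypothesis (a sketch under my own assumptions, not from the report: create_env_fn, frames_per_batch, and total_frames stand in for the values already passed via train_kwargs) is to swap in a single-process SyncDataCollector and check whether the GAE buffers keep their configured values:

from torchrl.collectors import SyncDataCollector

# Single-process collector: the policy stays in the training process, so the
# advantage module's buffers should remain at the values passed to GAE(...).
single_proc_collector = SyncDataCollector(
    create_env_fn,                      # hypothetical: same env factory as before
    policy=self.actor_module,
    frames_per_batch=frames_per_batch,  # hypothetical: same batching settings
    total_frames=total_frames,
)
print(self.advantage_module.get_buffer("gamma"))  # expected: tensor(advantage_gamma)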