pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License
2.01k stars 269 forks source link

[BUG] A2C fails with functional=True and shifted=True for ValueEstimator #2265

Open jkrude opened 1 week ago

jkrude commented 1 week ago

Describe the bug

I'm not sure whether this is supported behavior, but if I set functional=True for the A2C loss and shifted=True for the TD0Estimator, I get an internal error (a ValueError raised from _call_value_nets).

To Reproduce

# Minimal reproduction for issue #2265: A2CLoss(functional=True) combined with
# a TD0 value estimator created with shifted=True.
# NOTE: in the original paste every line after `time_dim` was indented by four
# spaces (presumably copied out of a function body), which makes the script an
# IndentationError when run at module level; it is dedented here.
import gymnasium as gym
import torchrl.envs
import torch
import torchrl
from torchrl.objectives import ValueEstimators
from torchrl.objectives.value import TD0Estimator
from torchrl.modules import MLP, ValueOperator, ProbabilisticActor, Actor

# Maximum number of steps collected in the rollout.
time_dim = 4

gym_env = torchrl.envs.GymEnv("MountainCar-v0", device="cpu")
observation_shape = gym_env.observation_spec["observation"].shape[0]

# Mock networks: a single Linear layer each for the actor and the critic.
actor_net_mock = torch.nn.Linear(
    in_features=observation_shape,
    out_features=gym_env.action_spec.shape[-1],
)

value_net_mock = torch.nn.Linear(
    in_features=observation_shape,
    out_features=1,
)
probabilistic_actor = ProbabilisticActor(
    module=Actor(actor_net_mock, out_keys=["logits"]),
    in_keys=["logits"],
    distribution_class=torch.distributions.OneHotCategorical,
)
value_operator = ValueOperator(module=value_net_mock, in_keys=["observation"])

rollout = gym_env.rollout(max_steps=time_dim, policy=probabilistic_actor)
loss = torchrl.objectives.a2c.A2CLoss(
    probabilistic_actor,
    value_operator,
    functional=True,
)
# With shifted=True the estimator evaluates the values at t and t+1 in a single
# value-network call (single_call=True in the traceback below).
loss.make_value_estimator(ValueEstimators.TD0, gamma=0.9, shifted=True)

# Reported to raise, on torchrl 2024.6.23 and torchrl-nightly 2024.7.3:
#   ValueError: the value at t and t+1 cannot be retrieved in a single call
#   without recurring to vmap when both params and next params are passed.
rollout_loss = loss(rollout)
../venv-nightly/lib/python3.11/site-packages/torch/nn/modules/module.py:1657: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torch/nn/modules/module.py:1709: in _call_impl
    result = forward_call(*args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/tensordict/_contextlib.py:126: in decorate_context
    return func(*args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/tensordict/nn/common.py:289: in wrapper
    return func(_self, tensordict, *args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torchrl/objectives/a2c.py:470: in forward
    self.value_estimator(
../venv-nightly/lib/python3.11/site-packages/torch/nn/modules/module.py:1657: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torch/nn/modules/module.py:1668: in _call_impl
    return forward_call(*args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py:68: in new_func
    return fun(self, *args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py:57: in new_fun
    return fun(self, *args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/tensordict/nn/common.py:289: in wrapper
    return func(_self, tensordict, *args, **kwargs)
../venv-nightly/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py:632: in forward
    value, next_value = _call_value_nets(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

value_net = ValueOperator(
    module=Linear(in_features=2, out_features=1, bias=True),
    device=cpu,
    in_keys=['observation'],
    out_keys=['state_value'])
data = TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=F..., device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
params = TensorDict(
    fields={
        module: TensorDict(
            fields={
                bias: Tensor(shape=torch.Siz...       device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
next_params = TensorDict(
    fields={
        module: TensorDict(
            fields={
                bias: Tensor(shape=torch.Siz...       device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
single_call = True, value_key = 'state_value', detach_next = True
vmap_randomness = 'error'

    def _call_value_nets(
        value_net: TensorDictModuleBase,
        data: TensorDictBase,
        params: TensorDictBase,
        next_params: TensorDictBase,
        single_call: bool,
        value_key: NestedKey,
        detach_next: bool,
        vmap_randomness: str = "error",
    ):
        in_keys = value_net.in_keys
        if single_call:
            for i, name in enumerate(data.names):
                if name == "time":
                    ndim = i + 1
                    break
            else:
                ndim = None
            if ndim is not None:
                # get data at t and last of t+1
                idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),)
                idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
                idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
                data_in = torch.cat(
                    [
                        data.select(*in_keys, value_key, strict=False),
                        data.get("next").select(*in_keys, value_key, strict=False)[idx0],
                    ],
                    ndim - 1,
                )
            else:
                if RL_WARNINGS:
                    warnings.warn(
                        "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
                        "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
                    )
                ndim = data.ndim
                idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),)
                idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),)
                data_in = torch.cat(
                    [
                        data.select(*in_keys, value_key, strict=False),
                        data.get("next").select(*in_keys, value_key, strict=False),
                    ],
                    ndim - 1,
                )

            # next_params should be None or be identical to params
            if next_params is not None and next_params is not params:
>               raise ValueError(
                    "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
                )
E               ValueError: the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed.

../venv-nightly/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py:122: ValueError

Process finished with exit code 1

Expected behavior

The losses are calculated correctly, and the value network is called only once when computing the advantage.

System info

# Report the library, NumPy and interpreter versions plus the platform.
import sys

import numpy
import torchrl

print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.6.23 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux

Reason and Possible fixes

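The problem seems to be in this snippet from a2c.py, where detached parameters are passed as params. As a result, params is not the same object as target_params, so the identity check `next_params is not params` in _call_value_nets fails.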


self.value_estimator(
                tensordict,
                params=self._cached_detach_critic_network_params,
                target_params=self.target_critic_network_params,
            )
´´´´
## Checklist

- [x] I have checked that there is no similar issue in the repo (**required**)
- [x] I have read the [documentation](https://github.com/pytorch/rl/tree/main/docs/) (**required**)
- [x] I have provided a minimal working example to reproduce the bug (**required**)
vmoens commented 1 week ago

Thanks for reporting! I can't reproduce this error on the main branch with the tensordict and PyTorch nightlies; the script runs perfectly fine. Does this occur sporadically?

jkrude commented 1 week ago

For me, it happens deterministically every time with the versions I provided. I wanted to test the newer *-nightly versions, but found that torchrl-nightly from 2024.6.24 onwards only has wheels for Windows. The same goes for tensordict-nightly from 2024.6.20. Am I missing something?

vmoens commented 1 week ago

No, you're right: the nightlies are broken. I will fix that. In the meantime, you can install everything like this:

# PyTorch nightly wheels from the CPU index.
# Adapt this if you need CUDA, e.g. nightly/cu124
pip3 install --upgrade --pre --index-url https://download.pytorch.org/whl/nightly/cpu torch torchvision torchaudio
# tensordict and torchrl built from their GitHub repositories.
pip3 install --upgrade git+https://github.com/pytorch/tensordict
pip3 install --upgrade git+https://github.com/pytorch/rl

LMK if you can reproduce it after that!

vmoens commented 1 week ago

The nightly release should be available now (not for Windows, though): https://pypi.org/project/torchrl-nightly/#history

jkrude commented 6 days ago
Thanks for the quick fix to the nightly builds 👍 I still encounter the same error with the 2024.7.3 versions when using the scripts above. Here is my full pip list:

Package Version
cloudpickle 3.0.0
Farama-Notifications 0.0.4
filelock 3.13.1
fsspec 2024.6.1
gymnasium 0.29.1
Jinja2 3.1.4
MarkupSafe 2.1.5
mpmath 1.3.0
networkx 3.3
numpy 2.0.0
orjson 3.10.6
packaging 24.1
pip 23.2.1
setuptools 68.2.0
sympy 1.12.1
tensordict-nightly 2024.7.3
torch 2.5.0.dev20240703+cpu
torchrl-nightly 2024.7.3
typing_extensions 4.12.2
wheel 0.41.2

I am a bit surprised that it works on your side, as the relevant code snippets are still the same on the main branch. In a2c.py, the value estimator is called with both params and target_params. Isn't params a different object from target_params, since it is detached?

self.value_estimator(
                tensordict,
                params=self._cached_detach_critic_network_params,
                target_params=self.target_critic_network_params,
            )

This ultimately fails in advantages.py, in code that is also unchanged on the main branch:

if next_params is not None and next_params is not params:
            raise ValueError(
                "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
            )

Note that I am running completely on CPU without GPU support on the running machine, don't know if that makes any difference 🤷.