pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License
2.21k stars 290 forks source link

[BUG] Using DiscreteSAC in Multi-Agent environments #1459

Closed hyerra closed 1 year ago

hyerra commented 1 year ago

Describe the bug

I am trying to use the DiscreteSACLoss in a multi-agent environment, following this tutorial. When I run my example, though, I get dimension errors. Could it be that the loss handles dimensions (unsqueezing/squeezing) in a way that doesn't generalize to multi-agent environments?

It is important to note that in my setup I am only using a single agent. However, the game environment I'm using (Unity) supports multiple agents, so the tensordicts/specs are all set up as if it were a multi-agent setup. For the purposes of this bug, though, we don't need to consider Unity; I provide an example below that is independent of Unity.

To Reproduce

You can run this simple script using the latest torchrl and tensordict installed from main:

# Reproduction script: DiscreteSACLoss applied to a tensordict whose data sits
# under an "agents" sub-tensordict with batch size [256, 1] (a single agent).
# As reported in this issue, the final loss_module(td) call raises the
# RuntimeError shown in the traceback that follows this script.
import math  # NOTE(review): not used anywhere in this script

import torch
from tensordict.nn import InteractionType
from torch import nn
from torchrl.modules.distributions import OneHotCategorical
from torchrl.data.tensor_specs import (
    CompositeSpec,
    DiscreteTensorSpec,
)
from torchrl.modules import ProbabilisticActor, SafeModule, ValueOperator
from torchrl.objectives import DiscreteSACLoss
from torchrl.modules import MultiAgentMLP
from tensordict import TensorDict

device = "cpu"
# Policy network: one agent, 128 input features -> 5 outputs (the logits),
# decentralised (centralised=False) and without parameter sharing.
actor_net = MultiAgentMLP(
    n_agent_inputs=128,
    num_cells=[256, 256],
    n_agent_outputs=5,
    centralised=False,
    n_agents=1,
    activation_class=nn.ReLU,
    share_params=False,
)

# Wrap the network so it reads ("agents", "encoder_vec") and writes
# ("agents", "logits").
actor_module = SafeModule(
    actor_net,
    in_keys=[("agents", "encoder_vec")],
    out_keys=[
        ("agents", "logits"),
    ],
)
# Action spec nested under "agents": 5 discrete actions, shape [1, 1], int64.
# NOTE(review): the actor below samples with OneHotCategorical and the batch
# stores a float one-hot action of shape (256, 1, 5), while this spec is an
# integer DiscreteTensorSpec -- confirm this mismatch is intended.
unbatched_action_spec = CompositeSpec(
    {
        "agents": CompositeSpec(
            {"action": DiscreteTensorSpec(n=5, shape=torch.Size([1, 1]), dtype=torch.int64)}
        )
    }
)
# Stochastic actor: turns ("agents", "logits") into ("agents", "action") via
# OneHotCategorical, with random interaction by default and no log-prob output.
actor = ProbabilisticActor(
    spec=unbatched_action_spec,
    in_keys=[("agents", "logits")],
    out_keys=[("agents", "action")],
    module=actor_module,
    distribution_class=OneHotCategorical,
    default_interaction_type=InteractionType.RANDOM,
    return_log_prob=False,
)

# Q-value network with the same layout as the policy: 128 inputs -> 5 outputs
# per agent (matching num_actions=5 passed to the loss below).
qvalue_net = MultiAgentMLP(
    n_agent_inputs=128,
    num_cells=[256, 256],
    n_agent_outputs=5,
    centralised=False,
    n_agents=1,
    activation_class=nn.ReLU,
    share_params=False,
)

# Reads ("agents", "encoder_vec") and writes ("agents", "state_value").
qvalue = ValueOperator(
    in_keys=[("agents", "encoder_vec")],
    out_keys=[("agents", "state_value")],
    module=qvalue_net,
)

model = torch.nn.ModuleList([actor, qvalue]).to(device)

# Discrete SAC loss with two Q-value networks, target entropy weight 0.2 and a
# smooth-L1 value loss.
loss_module = DiscreteSACLoss(
    actor_network=model[0],
    qvalue_network=model[1],
    num_actions=5,
    num_qvalue_nets=2,
    target_entropy_weight=0.2,
    loss_function="smooth_l1",
)
# Point the loss's action/reward/done/value/priority keys at entries nested
# under "agents".
loss_module.set_keys(action=("agents", "action"), reward=("agents", "reward"), done=("agents", "done"), value=("agents", "state_value"), priority=("agents", "td_error"))
loss_module.make_value_estimator(gamma=0.99)

# All-zeros dummy batch: root batch size [256]; the "agents" sub-tensordicts
# (at the root and under "next") have batch size [256, 1].
# NOTE(review): "next" holds only reward and done, with no
# ("agents", "encoder_vec"); a reply later in this thread points out that the
# next-step "encoder_vec" was missing -- presumably that applies here too.
td = TensorDict(
    source={
        "agents": TensorDict(
            source={
                "action": torch.zeros((256, 1, 5), dtype=torch.float32),
                "encoder_vec": torch.zeros((256, 1, 128), dtype=torch.float32),
                "logits": torch.zeros((256, 1, 5), dtype=torch.float32),
            },
            batch_size=torch.Size([256, 1])
        ),
        "next": TensorDict(
            source={
                "agents": TensorDict(
                    source={
                        "reward": torch.zeros((256, 1, 1), dtype=torch.float32),
                        "done": torch.zeros((256, 1, 1), dtype=torch.bool),
                    },
                    batch_size=torch.Size([256, 1])
                )
            },
            batch_size=torch.Size([256])
        )
    },
    batch_size=torch.Size([256])
)
loss_module(td)
Traceback (most recent call last):
/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/torchrl/objectives/common.py:33: UserWarning: No target network updater has been associated with this loss module, but target parameters have been found. While this is supported, it is expected that the target network updates will be manually performed. You can deactivate this warning by turning the RL_WARNINGS env variable to False.
  warnings.warn(
/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/torchrl/objectives/common.py:343: UserWarning: No target network updater has been associated with this loss module, but target parameters have been found. While this is supported, it is expected that the target network updates will be manually performed. You can deactivate this warning by turning the RL_WARNINGS env variable to False.
  warnings.warn(
Traceback (most recent call last):
  File "/Users/hyerra/Desktop/crew-algorithms/.../harish_algorithm/test.py", line 107, in <module>
    loss_module(td)
  File "/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/Users/hyerra/Library/Caches/pypoetry/virtualenvs/.../lib/python3.10/site-packages/torchrl/objectives/sac.py", line 1175, in forward
    td_error = (pred_val - target_value.expand_as(pred_val)).pow(2)
RuntimeError: The expanded size of the tensor (2) must match the existing size (256) at non-singleton dimension 0.  Target sizes: [2, 256].  Tensor sizes: [256, 1]

Expected behavior

Ideally, the loss should handle dimensions so that it also works in these multi-agent environments.

System info

Describe the characteristic of your environment: Using the latest torchrl and tensordict installed directly from GitHub source.

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

Output:

None 1.25.0 3.10.12 (main, Jul 28 2023, 18:34:01) [Clang 14.0.3 (clang-1403.0.22.14.1)] darwin

Checklist

cc: @matteobettini

matteobettini commented 1 year ago

Thanks for reporting. I am looking into it.

Do the other losses work (e.g. PPO)?

matteobettini commented 1 year ago

This does not depend on the multi-agent setup: it seems that SAC does not work with batch sizes that have more than one dimension.

I was able to reproduce it with the following single-agent code:


# Single-agent reproduction of the same failure: there is no "agents" nesting,
# but the tensordict still has a two-dimensional batch size [256, 1].
import torch
from tensordict.nn import InteractionType
from torch import nn
from torchrl.modules.distributions import OneHotCategorical
from torchrl.data.tensor_specs import (
    OneHotDiscreteTensorSpec,
)
from torchrl.modules import ProbabilisticActor, SafeModule, ValueOperator, MLP
from torchrl.objectives import DiscreteSACLoss
from tensordict import TensorDict

device = "cpu"
# Policy network: 128 input features -> 5 logits.
actor_net = MLP(
    in_features=128,
    num_cells=[256, 256],
    out_features=5,
    activation_class=nn.ReLU,
)

# Reads "encoder_vec" and writes "logits".
actor_module = SafeModule(
    actor_net,
    in_keys=["encoder_vec"],
    out_keys=[
        "logits",
    ],
)
# One-hot action spec over 5 actions (shape [5], int64).
unbatched_action_spec = OneHotDiscreteTensorSpec(
    n=5, shape=torch.Size([5]), dtype=torch.int64
)

# Stochastic actor: samples "action" from OneHotCategorical over "logits",
# with random interaction by default and no log-prob output.
actor = ProbabilisticActor(
    spec=unbatched_action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    module=actor_module,
    distribution_class=OneHotCategorical,
    default_interaction_type=InteractionType.RANDOM,
    return_log_prob=False,
)

# Q-value network: 128 inputs -> 5 outputs (matching num_actions=5 below).
qvalue_net = MLP(
    in_features=128,
    num_cells=[256, 256],
    out_features=5,
    activation_class=nn.ReLU,
)

# Reads "encoder_vec" and writes "state_value".
qvalue = ValueOperator(
    in_keys=["encoder_vec"],
    out_keys=["state_value"],
    module=qvalue_net,
)

model = torch.nn.ModuleList([actor, qvalue]).to(device)

# Same loss configuration as the multi-agent report: two Q-value networks,
# target entropy weight 0.2, smooth-L1 value loss.
loss_module = DiscreteSACLoss(
    actor_network=model[0],
    qvalue_network=model[1],
    num_actions=5,
    num_qvalue_nets=2,
    target_entropy_weight=0.2,
    loss_function="smooth_l1",
)

loss_module.make_value_estimator(gamma=0.99)
# All-zeros dummy batch with batch size [256, 1] at the root and in "next".
# NOTE: "next" has no "encoder_vec"; the maintainer's later reply in this thread
# notes that it was missing and that the previous implementation failed silently.
single_agent_td = TensorDict(
    source={
        "action": torch.zeros((256,1, 5), dtype=torch.float32),
        "encoder_vec": torch.zeros((256, 1,128), dtype=torch.float32),
        "logits": torch.zeros((256,1, 5), dtype=torch.float32),
        "next": TensorDict(
            source={
                "reward": torch.zeros((256,1, 1), dtype=torch.float32),
                "done": torch.zeros((256, 1,1), dtype=torch.bool),
            },
            batch_size=torch.Size([256,1]),
        ),
    },
    batch_size=torch.Size([256,1]),
)

loss_module(single_agent_td)

Now it is just a matter of reading the loss module's code and finding the point where it does not generalize.

matteobettini commented 1 year ago

@vmoens I might rewrite discrete SAC to align it with the regular SAC loss and to improve readability and modularity.

hyerra commented 1 year ago

Ah yep, I should have clarified this. The issue isn’t with your multi-agent components like MultiAgentMLP or anything like that. I only mentioned multi-agent as a motivating example, because that’s a common case where the batch size has more than one dimension.

matteobettini commented 1 year ago

I rewrote discrete SAC in #1461; try it out, it should be more flexible.

Here is a script showing how you can adapt yours:

# Adapted single-agent script for the rewritten DiscreteSACLoss from #1461.
# Compared with the earlier single-agent repro in this thread, the Q-value
# module writes "action_value" (instead of "state_value") and "next" also
# carries "encoder_vec".
import torch
from tensordict.nn import InteractionType
from torch import nn
from torchrl.modules.distributions import OneHotCategorical
from torchrl.data.tensor_specs import (
    OneHotDiscreteTensorSpec,
)
from torchrl.modules import ProbabilisticActor, SafeModule, ValueOperator, MLP
from torchrl.objectives import DiscreteSACLoss
from tensordict import TensorDict

device = "cpu"
# Policy network: 128 input features -> 5 logits.
actor_net = MLP(
    in_features=128,
    num_cells=[256, 256],
    out_features=5,
    activation_class=nn.ReLU,
)

# Reads "encoder_vec" and writes "logits".
actor_module = SafeModule(
    actor_net,
    in_keys=["encoder_vec"],
    out_keys=[
        "logits",
    ],
)
# One-hot action spec over 5 actions (shape [5], int64).
unbatched_action_spec = OneHotDiscreteTensorSpec(
    n=5, shape=torch.Size([5]), dtype=torch.int64
)

# Stochastic actor: samples "action" from OneHotCategorical over "logits",
# with random interaction by default and no log-prob output.
actor = ProbabilisticActor(
    spec=unbatched_action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    module=actor_module,
    distribution_class=OneHotCategorical,
    default_interaction_type=InteractionType.RANDOM,
    return_log_prob=False,
)

# Q-value network: 128 inputs -> 5 outputs (matching num_actions=5 below).
qvalue_net = MLP(
    in_features=128,
    num_cells=[256, 256],
    out_features=5,
    activation_class=nn.ReLU,
)

# Reads "encoder_vec" and writes "action_value".
# NOTE(review): presumably the key the rewritten loss in #1461 expects -- confirm.
qvalue = ValueOperator(
    in_keys=["encoder_vec"],
    out_keys=["action_value"],
    module=qvalue_net,
)

model = torch.nn.ModuleList([actor, qvalue]).to(device)

# Two Q-value networks, target entropy weight 0.2, smooth-L1 value loss.
loss_module = DiscreteSACLoss(
    actor_network=model[0],
    qvalue_network=model[1],
    num_actions=5,
    num_qvalue_nets=2,
    target_entropy_weight=0.2,
    loss_function="smooth_l1",
)

loss_module.make_value_estimator(gamma=0.99)
# All-zeros dummy batch with batch size [256, 1] at the root and in "next".
single_agent_td = TensorDict(
    source={
        "action": torch.zeros((256, 1, 5), dtype=torch.float32),
        "encoder_vec": torch.zeros((256, 1, 128), dtype=torch.float32),
        "logits": torch.zeros((256, 1, 5), dtype=torch.float32),
        "next": TensorDict(
            source={
                "reward": torch.zeros((256, 1, 1), dtype=torch.float32),
                "done": torch.zeros((256, 1, 1), dtype=torch.bool),
                # Next-step observation; absent from the earlier repro script.
                "encoder_vec": torch.zeros((256, 1, 128), dtype=torch.float32),
            },
            batch_size=torch.Size([256, 1]),
        ),
    },
    batch_size=torch.Size([256, 1]),
)

loss_module(single_agent_td)

Notably, your previous script was missing "encoder_vec" in "next", and the previous implementation was failing silently.