pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License
2.19k stars 289 forks source link

[BUG] `SACLoss` module does not allow stochastic modules (i.e. `Dropout`, etc.) due to `vmap` #2313

Closed N00bcak closed 1 month ago

N00bcak commented 1 month ago

Describe the bug

SACLoss has flawed checks for determining the nature of vmap_randomness. Therefore, stochastic modules cannot be used in constituent networks.

To Reproduce

Steps to reproduce the behavior.

Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.

Please use the markdown code blocks for both code and stack traces.

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector 
from torchrl.objectives import SACLoss, ValueEstimators 
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter

EPS = 1e-7
class SMACCNet(MultiAgentNetBase): 
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = 'cpu',
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        def init_layer_params(layer):
            if isinstance(layer, nn.Linear):
                weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
                torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
                if 'bias' in layer.state_dict():
                    torch.nn.init.zeros_(layer.bias)
        net.apply(init_layer_params)
        return net
    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Dropout(0.5), # <- The dropout is here!
            nn.Linear(300, self.n_agent_outputs)
        ).to(self.device)

        model = self.init_net_params(model)

        return model

class CustomTanhTransform(D.transforms.TanhTransform):

    def _inverse(self, y):
        # Yoinked from SB3!!!
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """

        y = y.clamp(-1. + EPS, 1. - EPS)
        return 0.5 * (y.log1p() - (-y).log1p())

    def log_abs_det_jacobian(self, x, y):
        # Yoinked from PyTorch TanhTransform!
        '''
        tl;dr log(1-tanh^2(x)) = log(sech^2(x)) 
                               = 2log(2/(e^x + e^(-x))) 
                               = 2(log2 - log(e^x/(1 + e^(-2x)))
                               = 2(log2 - x - log(1 + e^(-2x)))
                               = 2(log2 - x - softplus(-2x)) 
        '''

        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))

class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.'''
    def __init__(self, loc, scale, event_dims = 1):
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )

        log_prob = torch.clamp(log_prob, min = math.log10(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m

# Main Function
if __name__ == "__main__":    
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 256
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda:0"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1
    WARMUP_STEPS = 0
    TRAIN_TIMESTEPS = int(1e7)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True):

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = terminate_scale,
                                    forward_reward = forward_scale,
                                    fall_reward = fall_scale,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = DEVICE
                                )

        env = base_env_fn # noqa: E731

        def env_with_transforms():
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            return init_env

        return env_with_transforms

    train_env = env_fn(None, parallel = False)()

    if train_env.is_closed:
        train_env.start()

    check_env_specs(train_env)

    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False,
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    )

    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True, 
                          share_params = True, 
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        )

    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )

    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormalStable,
        return_log_prob = True,
    )

    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        )

    with torch.no_grad():
        reset_obs = train_env.reset()
        reset_obs_clean = deepcopy(reset_obs)
        action = policy_actor(reset_obs)
        state_action_value = critic_actor(action)
        reset_obs = train_env.reset()
        reset_obs["walker", "action"] = torch.zeros((*reset_obs["walker", "observation"].shape[:-1], action_dim))
        train_env.rand_action(reset_obs)
        action = train_env.step(reset_obs)

    collector = SyncDataCollector(
                    ParallelEnv(NUM_EXPLORE_WORKERS,
                        [
                            env_fn(None, parallel = False) 
                            for _ in range(NUM_EXPLORE_WORKERS)
                        ],
                        device = None,
                        mp_start_method = "spawn"
                    ), 
                    policy = policy_actor,
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = -1,
                    total_frames = TRAIN_TIMESTEPS,
                    device = 'cpu',
                    reset_at_each_iter = False
                )
    # Dummy loss module

    replay_buffer = TensorDictPrioritizedReplayBuffer(
        alpha = 0.7,
        beta = 0.9,
        storage = LazyMemmapStorage(
            1e5, 
            device = 'cpu',
            scratch_dir = "temp/"
        ), 
        priority_key = "td_error",
        batch_size = BATCH_SIZE, 
    )

    sac_loss = SACLoss(actor_network = policy_actor, 
                        qvalue_network = critic_actor, 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_actor = False,
                        delay_qvalue = True,
                        alpha_init = 0.1,
                        )

    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )
    sac_loss.make_value_estimator(
        value_type = ValueEstimators.TD0, 
        gamma = 0.99, 
    )

    # Compiling replay_buffer.sample works :D
    @torch.compile(mode = "reduce-overhead")
    def rb_sample():
        td_sample = replay_buffer.sample()
        if td_sample.device != torch.device(DEVICE):
            td_sample = td_sample.to(
                                    DEVICE, 
                                    non_blocking = False
                                )
        else:
            td_sample = td_sample.clone()

        return td_sample

    def test_compile():
        td_sample = rb_sample()
        return sac_loss(td_sample)

    samples = 0
    for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
        tensordict = tensordict.reshape(-1)
        samples += tensordict.numel()
        replay_buffer.extend(tensordict.to('cpu', non_blocking = True))
        pbar.write("Hey Hey!!! :D")
        a = test_compile()
        print(a)

    collector.shutdown()
    train_env.close()
Traceback (most recent call last):
  File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 381, in <module>
    a = test_compile()
  File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 373, in test_compile
    return sac_loss(td_sample)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 559, in forward
    loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 723, in _qvalue_v2_loss
    target_value = self._compute_target_v2(tensordict)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 700, in _compute_target_v2
    next_tensordict_expand = self._vmap_qnetworkN0(
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/utils.py", line 491, in decorated_module
    return module(*module_args)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/utils.py", line 261, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/sequence.py", line 428, in forward
    tensordict = self._run_module(module, tensordict, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/sequence.py", line 409, in _run_module
    tensordict = module(tensordict, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/utils.py", line 261, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 1224, in forward
    raise err from RuntimeError(
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 1198, in forward
    raise err
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 1184, in forward
    tensors = self._call_module(tensors, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 1141, in _call_module
    out = self.module(*tensors, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/modules/models/multiagent.py", line 115, in forward
    output = self._empty_net(inputs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/dropout.py", line 59, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/functional.py", line 1295, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

Expected behavior

SACLoss module performs forward passes successfully.

System info

Describe the characteristic of your environment:

import torchrl, numpy, sys
0.4.0 1.25.0 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] linux

Reason and Possible fixes

There are essentially two reasons for this error:

  1. Check appears to be flawed since only top-level modules' types are checked against the RANDOM_MODULE_LIST
    • Recursive checks could resolve this, maybe?
  2. Flawed check cannot be rectified manually via LossModule.set_vmap_randomness as self.vmap_randomness is accessed immediately during initialization time

Checklist

vmoens commented 1 month ago

modules is recursive, unlike children. I think you're right with 2. though, if we cache the vmap call as we do we can't get this to work. I'll push a fix shortly

N00bcak commented 1 month ago

modules is recursive, unlike children. I think you're right with 2. though, if we cache the vmap call as we do we can't get this to work. I'll push a fix shortly

Hmm, that's surprising. Were it truly recursive, I'd expect _vmap_randomness == "different" at the end of it, because the double-break should prevent the else clause from triggering.

I've got something running atm, so I can't provide proof of this just yet, but 1. was what I observed when stepping.

vmoens commented 1 month ago

I fixed a couple more things, but I can't try your example because i'm (as always) having problems with petting zoo dependencies maybe you can check that it works, or perhaps give an example that does not involve an extra lib?

N00bcak commented 1 month ago

I fixed a couple more things, but I can't try your example because i'm (as always) having problems with petting zoo dependencies maybe you can check that it works, or perhaps give an example that does not involve an extra lib?

FWIW, I was able to replicate this issue on "navigation" in VMAS (I figured you could run it since its featured with MAPPO as a tutorial :P):

For a little extra information, I patched LossModule.vmap_randomness:


@property
def vmap_randomness(self):
    modules = []
    if self._vmap_randomness is None:
        do_break = False
        for val in self.__dict__.values():
            if isinstance(val, torch.nn.Module):
                for module in val.modules():
                    modules.append(str(type(module)))
                    if isinstance(module, RANDOM_MODULE_LIST):
                        self._vmap_randomness = "different"
                        do_break = True
                        break
            if do_break:
                # double break
                break
        else:
            self._vmap_randomness = "error"
    print(','.join(modules))
    return self._vmap_randomness

This is the script proper:

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, VmasEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector 
from torchrl.objectives import SACLoss, ValueEstimators 
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter

EPS = 1e-7
class SMACCNet(MultiAgentNetBase): 
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = 'cpu',
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        def init_layer_params(layer):
            if isinstance(layer, nn.Linear):
                weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
                torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
                if 'bias' in layer.state_dict():
                    torch.nn.init.zeros_(layer.bias)
        net.apply(init_layer_params)
        return net
    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Dropout(0.5), # <- The dropout is here!
            nn.Linear(300, self.n_agent_outputs)
        ).to(self.device)

        model = self.init_net_params(model)

        return model

class CustomTanhTransform(D.transforms.TanhTransform):

    def _inverse(self, y):
        # Yoinked from SB3!!!
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """

        y = y.clamp(-1. + EPS, 1. - EPS)
        return 0.5 * (y.log1p() - (-y).log1p())

    def log_abs_det_jacobian(self, x, y):
        # Yoinked from PyTorch TanhTransform!
        '''
        tl;dr log(1-tanh^2(x)) = log(sech^2(x)) 
                               = 2log(2/(e^x + e^(-x))) 
                               = 2(log2 - log(e^x/(1 + e^(-2x)))
                               = 2(log2 - x - log(1 + e^(-2x)))
                               = 2(log2 - x - softplus(-2x)) 
        '''

        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))

class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.'''
    def __init__(self, loc, scale, event_dims = 1):
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )

        log_prob = torch.clamp(log_prob, min = math.log10(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m

# Main Function
if __name__ == "__main__":    
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 256
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda:0"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1
    WARMUP_STEPS = 0
    TRAIN_TIMESTEPS = int(1e7)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn():

        def base_env_fn():
            return VmasEnv(
                  scenario="navigation",
                  num_envs=NUM_EXPLORE_WORKERS,
                  continuous_actions=True,
                  max_steps=200,
                  device="cpu",
                  seed=None,
                  # Scenario kwargs
                  n_agents=NUM_AGENTS,
              )

        env = base_env_fn # noqa: E731

        def env_with_transforms():
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("agents", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            return init_env

        return env_with_transforms

    train_env = env_fn()()

    if train_env.is_closed:
        train_env.start()

    check_env_specs(train_env)
    # print(train_env.full_observation_spec)
    # print(train_env.full_action_spec)
    print(train_env.done_spec)
    # breakpoint()

    obs_dim = train_env.full_observation_spec["agents", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["agents", "action"].shape[-1]

    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False,
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    ).to(DEVICE)

    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True, 
                          share_params = True, 
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        ).to(DEVICE)

    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("agents", "observation")],
                                            out_keys = [("agents", "loc"), ("agents", "scale")]
                                        )

    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("agents", "observation"), ("agents", "action")],
                                        out_keys = [("agents", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("agents", "obs_act")],
                                            out_keys = [("agents", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["agents", "action"],
        in_keys = [("agents", "loc"), ("agents", "scale")],
        out_keys = [("agents", "action")],
        distribution_class = TanhNormalStable,
        return_log_prob = True,
    )

    # with torch.no_grad():
    #     fake_td = train_env.fake_tensordict()
    #     policy_actor(fake_td)

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        )

    # with torch.no_grad():
    #     reset_obs = train_env.reset()
    #     reset_obs_clean = deepcopy(reset_obs)
    #     action = policy_actor(reset_obs)
    #     state_action_value = critic_actor(action)
    #     reset_obs = train_env.reset()
    #     reset_obs["agents", "action"] = torch.zeros((*reset_obs["agents", "observation"].shape[:-1], action_dim))
    #     train_env.rand_action(reset_obs)
    #     action = train_env.step(reset_obs)

    collector = SyncDataCollector(
                    ParallelEnv(NUM_EXPLORE_WORKERS,
                        [
                            env_fn() 
                            for _ in range(NUM_EXPLORE_WORKERS)
                        ],
                        device = None,
                        mp_start_method = "spawn"
                    ), 
                    policy = policy_actor,
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = -1,
                    total_frames = TRAIN_TIMESTEPS,
                    device = 'cpu',
                    policy_device = 'cpu',
                    reset_at_each_iter = False
                )
    # Dummy loss module

    replay_buffer = TensorDictPrioritizedReplayBuffer(
        alpha = 0.7,
        beta = 0.9,
        storage = LazyMemmapStorage(
            1e5, 
            device = 'cpu',
            scratch_dir = "googoogaagaa/"
        ), 
        priority_key = "td_error",
        batch_size = BATCH_SIZE, 
    )

    sac_loss = SACLoss(actor_network = policy_actor, 
                        qvalue_network = critic_actor, 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_actor = False,
                        delay_qvalue = True,
                        alpha_init = 0.1,
                        )

    sac_loss.set_keys(
        action = ("agents", "action"),
        state_action_value = ("agents", "state_action_value"),
        reward = ("agents", "reward"),
        done = ("agents", "done"),
        terminated = ("agents", "terminated"),
    )
    sac_loss.make_value_estimator(
        value_type = ValueEstimators.TD0, 
        gamma = 0.99, 
    )

    # Compiling replay_buffer.sample works :D
    @torch.compile(mode = "reduce-overhead")
    def rb_sample():
        td_sample = replay_buffer.sample()
        if td_sample.device != torch.device(DEVICE):
            td_sample = td_sample.to(
                                    DEVICE, 
                                    non_blocking = False
                                )
        else:
            td_sample = td_sample.clone()

        return td_sample

    def test_compile():
        td_sample = rb_sample()
        return sac_loss(td_sample)

    samples = 0
    for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
        tensordict.set(
            ("next", "agents", "done"),
            tensordict.get(("next", "done"))
            .unsqueeze(-1)
            .expand(tensordict.get_item_shape(("next", "agents", "reward"))),
        )
        tensordict.set(
            ("next", "agents", "terminated"),
            tensordict.get(("next", "terminated"))
            .unsqueeze(-1)
            .expand(tensordict.get_item_shape(("next", "agents", "reward"))),
        )
        tensordict = tensordict.reshape(-1)
        samples += tensordict.numel()
        replay_buffer.extend(tensordict.to('cpu', non_blocking = True))
        pbar.write("Hey Hey!!! :D")
        a = test_compile()
        print(a)

    collector.shutdown()
    train_env.close()

Running the script now yields

<class 'torchrl.modules.tensordict_module.actors.ProbabilisticActor'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'torch.nn.modules.container.Sequential'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>,<class 'tensordict.nn.distributions.continuous.NormalParamExtractor'>,<class 'torchrl.modules.tensordict_module.probabilistic.SafeProbabilisticModule'>,<class 'tensordict.nn.sequence.TensorDictSequential'>,<class 'torch.nn.modules.container.ModuleList'>,<class 'tensordict.nn.common.TensorDictModule'>,<class 'tensordict.nn.common.TensorDictModule'>,<class '__main__.SMACCNet'>,<class 'tensordict.nn.params.TensorDictParams'>

<...omitted for brevity...>
...
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/dropout.py", line 59, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/functional.py", line 1295, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap

Seems that the list of modules being checked does not go deep enough because DropoutNd is nowhere to be seen :P

vmoens commented 1 month ago

Got it, here your problem is that the dropout is hidden by the MARL model which does not register the inner module in a usual way. Should be somewhat easy to fix