pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] `LossModule`/`TensorDictSequential`/`ProbabilisticActor` cannot be compiled with `torch.compile` #2312

Open N00bcak opened 1 month ago

N00bcak commented 1 month ago

Describe the bug

Attempting to invoke `torch.compile` on any of the above-mentioned classes results in similar errors (see below).
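
In essence, the failing pattern is: wrap a plain `nn.Module` in the tensordict/torchrl classes, compile the wrapper, and call it. Below is a condensed sketch of what the full script under "To Reproduce" does (keys, shapes and the distribution are illustrative, and I have not verified that this stripped-down form raises the exact same error):

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal

# Plain network -> TensorDictModule -> ProbabilisticActor, then torch.compile.
net = nn.Sequential(nn.Linear(8, 2 * 4), NormalParamExtractor())
policy_module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    module=policy_module,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)

actor = torch.compile(actor)  # wrapping is lazy, so this line alone does not fail
actor(TensorDict({"observation": torch.randn(8)}, []))  # the error surfaces on the call
```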

To Reproduce

Steps to reproduce the behavior.

Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.

Please use the markdown code blocks for both code and stack traces.

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector 
from torchrl.objectives import SACLoss, ValueEstimators 
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter

EPS = 1e-7

class SMACCNet(MultiAgentNetBase):
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = 'cpu',
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        def init_layer_params(layer):
            if isinstance(layer, nn.Linear):
                weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
                torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
                if 'bias' in layer.state_dict():
                    torch.nn.init.zeros_(layer.bias)
        net.apply(init_layer_params)
        return net

    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs)
        ).to(self.device)

        model = self.init_net_params(model)

        return model

class CustomTanhTransform(D.transforms.TanhTransform):

    def _inverse(self, y):
        # Yoinked from SB3!!!
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """

        y = y.clamp(-1. + EPS, 1. - EPS)
        return 0.5 * (y.log1p() - (-y).log1p())

    def log_abs_det_jacobian(self, x, y):
        # Yoinked from PyTorch TanhTransform!
        '''
        tl;dr log(1-tanh^2(x)) = log(sech^2(x)) 
                               = 2log(2/(e^x + e^(-x))) 
                               = 2(log2 - log(e^x/(1 + e^(-2x)))
                               = 2(log2 - x - log(1 + e^(-2x)))
                               = 2(log2 - x - softplus(-2x)) 
        '''

        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))

class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.'''
    def __init__(self, loc, scale, event_dims = 1):
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )

        log_prob = torch.clamp(log_prob, min = math.log10(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m

# Main Function
if __name__ == "__main__":    
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 256
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda:0"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 256
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1
    WARMUP_STEPS = 0
    TRAIN_TIMESTEPS = int(1e7)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True):

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = terminate_scale,
                                    forward_reward = forward_scale,
                                    fall_reward = fall_scale,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = DEVICE
                                )

        env = base_env_fn # noqa: E731

        def env_with_transforms():
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            return init_env

        return env_with_transforms

    train_env = env_fn(None, parallel = False)()

    if train_env.is_closed:
        train_env.start()

    check_env_specs(train_env)

    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False,
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    )

    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True, 
                          share_params = True, 
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        )

    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )

    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormalStable,
        return_log_prob = True,
    )

    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        )

    # Can't compile these either...
    policy_actor = torch.compile(policy_actor)
    critic_actor = torch.compile(critic_actor)
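    # NOTE: torch.compile wraps lazily, so tracing (and thus any failure)
    # only actually happens on the first forward call below.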

    with torch.no_grad():
        reset_obs = train_env.reset()
        reset_obs_clean = deepcopy(reset_obs)
        action = policy_actor(reset_obs)
        state_action_value = critic_actor(action)
        reset_obs = train_env.reset()
        reset_obs["walker", "action"] = torch.zeros((*reset_obs["walker", "observation"].shape[:-1], action_dim))
        train_env.rand_action(reset_obs)
        action = train_env.step(reset_obs)

    collector = SyncDataCollector(
                    ParallelEnv(NUM_EXPLORE_WORKERS,
                        [
                            env_fn(None, parallel = False) 
                            for _ in range(NUM_EXPLORE_WORKERS)
                        ],
                        device = None,
                        mp_start_method = "spawn"
                    ), 
                    policy = policy_actor,
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = -1,
                    total_frames = TRAIN_TIMESTEPS,
                    device = 'cpu',
                    reset_at_each_iter = False
                )
    # Dummy loss module

    replay_buffer = TensorDictPrioritizedReplayBuffer(
        alpha = 0.7,
        beta = 0.9,
        storage = LazyMemmapStorage(
            1e5, 
            device = 'cpu',
            scratch_dir = "temp/"
        ), 
        priority_key = "td_error",
        batch_size = BATCH_SIZE, 
    )

    sac_loss = SACLoss(actor_network = policy_actor, 
                        qvalue_network = critic_actor, 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_actor = False,
                        delay_qvalue = True,
                        alpha_init = 0.1,
                        )

    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )
    sac_loss.make_value_estimator(
        value_type = ValueEstimators.TD0, 
        gamma = 0.99, 
    )

    # Compiling replay_buffer.sample works :D
    @torch.compile(mode = "reduce-overhead")
    def rb_sample():
        td_sample = replay_buffer.sample()
        if td_sample.device != torch.device(DEVICE):
            td_sample = td_sample.to(
                                    DEVICE, 
                                    non_blocking = False
                                )
        else:
            td_sample = td_sample.clone()

        return td_sample

    # This does not :P
    @torch.compile(disable = True)
    def test_compile():
        td_sample = rb_sample()
        return sac_loss(td_sample)

    samples = 0
    for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
        tensordict = tensordict.reshape(-1)
        samples += tensordict.numel()
        replay_buffer.extend(tensordict.to('cpu', non_blocking = True))
        pbar.write("Hey Hey!!! :D")
        a = test_compile()
        print(a)

    collector.shutdown()
    train_env.close()
  File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 382, in <module>
    a = test_compile()
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 373, in test_compile
    td_sample = rb_sample()
  File "/home/n00bcak/Desktop/<path_to_script>/torchrl_no_compile.py", line 374, in torch_dynamo_resume_in_test_compile_at_373
    return sac_loss(td_sample)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/nn/modules/module.py", line 1582, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 125, in decorate_context
    with ctx_factory():
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_contextlib.py", line 126, in torch_dynamo_resume_in_decorate_context_at_125
    return func(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/common.py", line 289, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 559, in forward
    loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 723, in _qvalue_v2_loss
    target_value = self._compute_target_v2(tensordict)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 687, in _compute_target_v2
    tensordict = tensordict.clone(False)
  File "/home/n00bcak/Desktop/<path_to_venv>/torchrl/objectives/sac.py", line 692, in torch_dynamo_resume_in__compute_target_v2_at_687
    ), self.actor_network_params.to_module(self.actor_network):
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/utils.py", line 1189, in new_func
    out = func(_self, *args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/base.py", line 949, in to_module
    return self._to_module(
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/params.py", line 174, in new_func
    out = getattr(self._param_td, name)(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 444, in _to_module
    if value.is_empty():
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
    if not item.is_empty():
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
    if not item.is_empty():
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 348, in is_empty
    if not item.is_empty():
  [Previous line repeated 1 more time]
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/base.py", line 3350, in is_empty
    for _ in self.keys(True, True):
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/convert_frame.py", line 295, in _convert_frame_assert
    cache_size = compute_cache_size(frame, cache_entry)
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/cache_size.py", line 142, in compute_cache_size
    if _has_same_id_matched_objs(frame, cache_entry):
  File "/home/n00bcak/Desktop/<path_to_venv>/torch/_dynamo/cache_size.py", line 123, in _has_same_id_matched_objs
    if weakref_from_frame != weakref_from_cache_entry:
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/nn/params.py", line 174, in new_func
    out = getattr(self._param_td, name)(*args, **kwargs)
  File "/home/n00bcak/Desktop/<path_to_venv>/tensordict/_td.py", line 505, in __ne__
    raise KeyError(
KeyError: "keys in TensorDict(<omitted for brevity>) mismatch, got {'2', '4', '1', '0', '3'} and {'module'}"
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

Expected behavior

Since the non-compiled version executes successfully, the compiled version is expected to succeed as well.
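
For reference, the same update step runs in eager mode, i.e. with the two `torch.compile` decorators removed (this sketch reuses `replay_buffer`, `DEVICE` and `sac_loss` from the script above):

```python
# Eager-mode counterparts of rb_sample/test_compile from the script above;
# these execute without error, which motivates the expectation.
def rb_sample_eager():
    td_sample = replay_buffer.sample()
    if td_sample.device != torch.device(DEVICE):
        return td_sample.to(DEVICE, non_blocking=False)
    return td_sample.clone()

def test_eager():
    return sac_loss(rb_sample_eager())
```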

System info

Describe the characteristics of your environment:

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.25.0 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] linux

Reason and Possible fixes

Perhaps it is due to the decorators you mentioned on Discord?
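
Judging from the stack trace, the failure goes through `actor_network_params.to_module(actor_network)` in `SACLoss._compute_target_v2`: while Dynamo checks its compile cache it compares weakrefs of the parameter TensorDict, which ends up in `TensorDictBase.__ne__` and raises the KeyError. A hypothetical, untested probe of that pattern in isolation (names and shapes are made up) might help narrow it down:

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictParams

# Recreate the pattern the traceback points at: swapping a parameter
# TensorDict into a module via to_module() inside a compiled region.
module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
params = TensorDictParams(TensorDict.from_module(module))

@torch.compile
def swap_and_run(td):
    with params.to_module(module):
        return module(td)

print(swap_and_run(TensorDict({"x": torch.randn(3)}, [])))
```

If this small version already trips Dynamo, the problem is presumably in how `to_module`/`is_empty` interact with the compile-time guards rather than anything specific to the multi-agent setup.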

Checklist