pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Bug] With `MultiSyncDataCollector`, `tensors` cannot be instantiated on CUDA in child processes. #2235

Open N00bcak opened 3 weeks ago

N00bcak commented 3 weeks ago

Describe the bug

Despite applying the appropriate guards (`mp.set_start_method('spawn')`, `if __name__ == "__main__"`), using `MultiSyncDataCollector` with the `cuda` device causes the program to freeze.

To Reproduce

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import MultiSyncDataCollector
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

EPS = 1e-7
class SMACCNet(MultiAgentNetBase): 
    def __init__(self, 
                n_agent_inputs: int | None,
                n_agent_outputs: int,
                n_agents: int,
                centralised: bool,
                share_params: bool,
                device = 'cpu',
                activation_class = nn.Tanh,
                **kwargs):

        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralised = centralised
        self.activation_class = activation_class
        self.device = device

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device = device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        if self.centralised:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def init_net_params(self, net):
        def init_layer_params(layer):
            if isinstance(layer, nn.Linear):
                weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
                torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
                if 'bias' in layer.state_dict():
                    torch.nn.init.zeros_(layer.bias)
        net.apply(init_layer_params)
        return net
    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralised and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        model = nn.Sequential(
            nn.Linear(n_agent_inputs, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs)
        ).to(self.device) # We are not able to use MultiSyncDataCollector with the 'meta' device JUST YET!!!

        model = self.init_net_params(model)

        return model

class CustomTanhTransform(D.transforms.TanhTransform):

    def _inverse(self, y):
        # Yoinked from SB3!!!
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """

        y = y.clamp(-1. + EPS, 1. - EPS)
        return 0.5 * (y.log1p() - (-y).log1p())

    def log_abs_det_jacobian(self, x, y):
        # Yoinked from PyTorch TanhTransform!
        '''
        tl;dr log(1-tanh^2(x)) = log(sech^2(x)) 
                               = 2log(2/(e^x + e^(-x))) 
                               = 2(log2 - log(e^x/(1 + e^(-2x)))
                               = 2(log2 - x - log(1 + e^(-2x)))
                               = 2(log2 - x - softplus(-2x)) 
        '''

        return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))

class TanhNormalStable(D.TransformedDistribution):
    '''Numerically stable variant of TanhNormal. Employs clipping trick.'''
    def __init__(self, loc, scale, event_dims = 1):
        self._event_dims = event_dims
        self._t = [
            CustomTanhTransform()
        ]
        self.update(loc, scale)

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - D.utils._sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + D.utils._sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )

        log_prob = torch.clamp(log_prob, min = math.log10(EPS))
        return log_prob

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        self.loc = loc
        self.scale = scale
        if (
            hasattr(self, "base_dist")
            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
        ):
            self.base_dist.base_dist.loc = self.loc
            self.base_dist.base_dist.scale = self.scale
        else:
            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
            super().__init__(base, self._t)

    @property
    def mode(self):
        m = self.base_dist.base_dist.mean
        for t in self.transforms:
            m = t(m)
        return m

# Main Function
if __name__ == "__main__":    
    NUM_AGENTS = 3
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 30000
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda:0"
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 512
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500
    WARMUP_STEPS = int(3e5)
    TRAIN_TIMESTEPS = int(1e7)
    EVAL_INTERVAL = 10
    EVAL_EPISODES = 20

    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True, killswitch = False):

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = terminate_scale,
                                    forward_reward = forward_scale,
                                    fall_reward = fall_scale,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = DEVICE
                                )

        env = base_env_fn # noqa: E731

        def env_with_transforms():
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            if killswitch:
                breakpoint()
            return init_env

        return env_with_transforms

    train_env = env_fn(None, parallel = False)()

    if train_env.is_closed:
        train_env.start()

    def create_eval_env(tag = "rendered"):

        eval_env = env_fn("rgb_array", parallel = False, rew_scale = False)()
        video_recorder = VideoRecorder(
                                        CSVLogger("multiwalker-toy-test", video_format = "mp4"), 
                                        tag = tag, 
                                        in_keys = ["pixels_record"]
                                    )

        # Call the parent's render function
        eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
        eval_env.append_transform(video_recorder)

        if eval_env.is_closed:
            eval_env.start()
        return eval_env

    check_env_specs(train_env)

    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False,
                          share_params = True,
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    )

    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True, 
                          share_params = True, 
                          device = DEVICE,
                          activation_class = nn.LeakyReLU, 
                        )

    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )

    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        distribution_class = TanhNormalStable,
        return_log_prob = True,
    )

    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)

    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        ) 

    with torch.no_grad():
        reset_obs = train_env.reset()
        reset_obs_clean = deepcopy(reset_obs)
        action = policy_actor(reset_obs)
        state_action_value = critic_actor(action)
        reset_obs = train_env.reset()
        reset_obs["walker", "action"] = torch.zeros((*reset_obs["walker", "observation"].shape[:-1], action_dim))
        train_env.rand_action(reset_obs)
        action = train_env.step(reset_obs)
        print("As you can see, spawning a single environment on the main process is absolutely unproblematic.")

    collector = MultiSyncDataCollector(
                    [env_fn(None, parallel = False, killswitch = True) for _ in range(NUM_EXPLORE_WORKERS)], 
                    policy = policy_actor, # the exploration policy
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = 0,
                    total_frames = TRAIN_TIMESTEPS,
                    device = DEVICE,
                    reset_at_each_iter = False
                )

    for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
        pbar.write("Hey Hey!!! :D")

    collector.shutdown()
    train_env.close()

Execution output:

<program executes as usual>
As you can see, spawning a single environment on the main process is absolutely unproblematic.
<Program freezes indefinitely>

Terminating the program gives this traceback:

Traceback (most recent call last):
  File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 326, in <module>
    collector = MultiSyncDataCollector(
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1516, in __init__
    self._run_processes()
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1690, in _run_processes
    msg = pipe_parent.recv()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 395, in _recv
    chunk = read(handle, remaining)
            ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
^CException ignored in atexit callback: <function _exit_function at 0x7f4151e12e80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/multiprocessing/util.py", line 360, in _exit_function
    p.join()
  File "/usr/local/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt:

Expected behavior

After printing "As you can see, spawning a single environment on the main process is absolutely unproblematic.", the program should progress into the collector loop and print "Hey Hey!!! :D" repeatedly.

System info

Describe the characteristics of your environment:

>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun  5 2024, 10:27:27) [GCC 12.2.0] linux

Additional context

The problem was encountered as part of an effort to spawn multiple environments on the GPU. Any pointers in this direction are greatly appreciated.
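
One direction I am considering (only a sketch, under the assumption that the hang comes from CUDA being initialized inside the spawned workers) is to keep the workers entirely on CPU and move each collected batch to the GPU in the parent process. `env_fn_cpu` below is a hypothetical variant of `env_fn` that passes `device = "cpu"` to `PettingZooEnv`; everything else reuses the names from the reproduction script above.

# Hedged sketch: CPU-only workers, GPU transfer in the parent process.
# `env_fn_cpu` is a hypothetical variant of env_fn that builds the envs on CPU.
collector = MultiSyncDataCollector(
    [env_fn_cpu(None, parallel = False) for _ in range(NUM_EXPLORE_WORKERS)],
    policy = policy_actor,
    frames_per_batch = BATCH_SIZE,
    max_frames_per_traj = 0,
    total_frames = TRAIN_TIMESTEPS,
    device = "cpu",               # keep the child processes off CUDA
    reset_at_each_iter = False,
)

for i, tensordict in enumerate(collector):
    tensordict = tensordict.to(DEVICE)  # move the batch to cuda:0 in the parent
    ...                                 # training updates would happen here, on the GPU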

Proof of issue with tensors

By inserting a killswitch into `env_fn` at various positions, we can make the following observations:

Code (No tensor defined yet)


    def env_fn(mode, parallel = True, rew_scale = True, killswitch = False):

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = terminate_scale,
                                    forward_reward = forward_scale,
                                    fall_reward = fall_scale,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = DEVICE
                                )

        env = base_env_fn # noqa: E731

        def env_with_transforms():
            if killswitch:
                breakpoint() # Killswitch before env initialization
            init_env = env()
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                        )
                                    )
            return init_env

        return env_with_transforms

Result: The program crashes as expected when the child process hits the breakpoint.

Process _ProcessNoWarn-1:
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/_utils.py", line 668, in run
    return mp.Process.run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 2653, in _main_async_collector
    inner_collector = SyncDataCollector(
                      ^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 450, in __init__
    self.closed = True
                  ^^^^
  File "/usr/local/lib/python3.11/bdb.py", line 90, in trace_dispatch
    return self.dispatch_line(frame)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/bdb.py", line 115, in dispatch_line
    if self.quitting: raise BdbQuit
                      ^^^^^^^^^^^^^
bdb.BdbQuit
Traceback (most recent call last):
  File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 326, in <module>
    collector = MultiSyncDataCollector(
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1518, in __init__
    self._run_processes()
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1692, in _run_processes
    msg = pipe_parent.recv()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 399, in _recv
    raise EOFError
EOFError
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

Code: insert a CUDA tensor declaration into the killswitch clause


            if killswitch:
                torchy_mctorchface = torch.tensor([1,2,3,4,5], device = 'cuda:0')
                breakpoint()

Result: The program hangs indefinitely.
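
For reference, my guess is that the above reduces to the following TorchRL-free sketch (an assumption, not something from the traceback): simply creating a CUDA tensor in a spawned child process and reading back a value over a pipe. If this also hangs, the issue is below TorchRL.

# Minimal sketch, no TorchRL: does creating a CUDA tensor in a spawned
# child process hang on this machine?
import torch
from torch import multiprocessing as mp

def child(pipe):
    t = torch.tensor([1, 2, 3, 4, 5], device = "cuda:0")  # suspected hang point
    pipe.send(t.sum().item())

if __name__ == "__main__":
    mp.set_start_method("spawn", force = True)
    parent_end, child_end = mp.Pipe()
    proc = mp.Process(target = child, args = (child_end,))
    proc.start()
    print(parent_end.recv())  # if the child hangs on CUDA init, this never returns
    proc.join()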

PS

Since the error relates to tensors, would it be a good idea to rope in the PyTorch devs?


vmoens commented 3 weeks ago

I did not try the additional context block, but running the code above on my machine without the following lines works perfectly fine (the "Hey Hey" is displayed as expected):

            if killswitch:
                breakpoint()

If I don't remove that block, the program fails on my Python 3.10 env (even if the breakpoint is never reached).

Some further things we can look at to debug:

- What env variables are you setting, if any?
- What CUDA / PyTorch versions do you have? Does the CUDA of your PyTorch build match the CUDA on the machine? (A quick check is sketched below.)
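
A quick sanity check for the version questions (just the obvious torch queries, sketched here; compare the output against `nvidia-smi` on the host):

# Sketch: print the CUDA version PyTorch was built against and whether a
# device is actually visible from this interpreter.
import torch
print(torch.__version__)           # PyTorch version
print(torch.version.cuda)          # CUDA version PyTorch was compiled with
print(torch.cuda.is_available())   # can the driver/runtime be reached?
print(torch.cuda.device_count())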

N00bcak commented 2 weeks ago

tl;dr: this seems to be either a WSL2-Debian or a Python 3.11 quirk. Very interesting.

Part 1

My bad, I should have specified that I was on WSL2-Debian.

Here's some information regarding that:

Debian Version

> python3 -c "import sys, torch, torchrl, tensordict; print(sys.version, torch.__version__, torchrl.__version__, tensordict.__version__)"
3.11.9 (main, Jun  5 2024, 10:27:27) [GCC 12.2.0] 2.3.0+cu121 0.4.0 0.4.0
> lsb_release -a
No LSB modules are available.
Distributor ID: Debian
Description:    Debian GNU/Linux 12 (bookworm)
Release:        12
Codename:       bookworm
PS C:\Windows\system32> (get-item C:\windows\system32\wsl.exe).VersionInfo.FileVersion                                  
10.0.19041.3636 (WinBuild.160101.0800) 

Part 2

Strange. I am now using Python 3.10 on a different (single-boot Ubuntu) machine, and I cannot reproduce the bug there either.

This is my Python environment:

Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torchrl
>>> import tensordict
>>> torch.__version__, torchrl.__version__, tensordict.__version__
('2.3.0+cu121', '0.4.0', '0.4.0')

> What CUDA / PyTorch versions do you have? Does the CUDA of your PyTorch build match the CUDA on the machine?

Both of my machines use the CUDA that comes with PyTorch.

> What env variables are you setting, if any?

The offending scripts do not set any special environment variables.
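
A quick way to double-check, in case it matters (just a sketch over the usual suspects; the variable list is my guess at what could be relevant):

# Sketch: dump env vars that commonly affect CUDA / multiprocessing behaviour.
import os
for var in ("CUDA_VISIBLE_DEVICES", "CUDA_LAUNCH_BLOCKING", "PYTORCH_CUDA_ALLOC_CONF"):
    print(var, "=", os.environ.get(var))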