pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License
2.01k stars 269 forks source link

[BUG] `check_env_specs` + `PixelRenderTransform` does not tolerate "cuda" device #2236

Closed N00bcak closed 3 weeks ago

N00bcak commented 3 weeks ago

Describe the bug

Running `check_env_specs` on a `TransformedEnv` that contains a `PixelRenderTransform` fails when the base environment is created with `device="cuda"`.

To Reproduce

# Force the multiprocessing start method to 'spawn' before anything else in
# the program runs. This has to stay ahead of the remaining imports and of
# any environment creation below.
# NOTE(review): `force = True` presumably overrides a start method that was
# already set -- confirm against the multiprocessing documentation.
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

from copy import deepcopy
import tqdm
import numpy as np

import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import SyncDataCollector
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

# Small epsilon constant; it is not referenced anywhere in the visible part of this script.
EPS = 1e-7

# Main Function
# Script entry point: configuration constants and RNG seeding.
if __name__ == "__main__":
    # --- Environment / rollout settings (used by the environment factories below) ---
    NUM_AGENTS = 3              # passed as n_walkers and used to size the RewardSum key lists
    MAX_EPISODE_STEPS = 1000    # passed as max_cycles and as the StepCounter limit
    DEVICE = "cuda"             # device string without an explicit index

    # --- Training hyper-parameters (not referenced in the visible part of the script) ---
    NUM_CRITICS = 2
    NUM_EXPLORE_WORKERS = 1
    EXPLORATION_STEPS = 30000
    REPLAY_BUFFER_SIZE = 1_000_000
    VALUE_GAMMA = 0.99
    MAX_GRAD_NORM = 1.0
    BATCH_SIZE = 512
    LR = 3e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500
    WARMUP_STEPS = 300_000
    TRAIN_TIMESTEPS = 10_000_000
    EVAL_INTERVAL = 10
    EVAL_EPISODES = 20

    # --- Reproducibility: seed both the torch and the numpy global RNGs ---
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, parallel = True, rew_scale = True):
        """Return a zero-argument factory that builds a transformed multiwalker env.

        Args:
            mode: forwarded to ``PettingZooEnv`` as ``render_mode``.
            parallel: accepted for call-site compatibility; it is not read
                anywhere in this function (the PettingZoo ``parallel`` flag is
                always passed as ``True``).
            rew_scale: when truthy, use the custom reward constants
                (-3.0 / 2.5 / -3.0); otherwise use -100.0 / 1.0 / -10.0,
                which the original author describes as the PettingZoo defaults.

        Returns:
            A callable taking no arguments that creates the ``PettingZooEnv``
            and wraps it in a ``TransformedEnv`` with ``StepCounter`` and
            ``RewardSum`` transforms.
        """
        if rew_scale:
            terminate_scale, forward_scale, fall_scale = -3.0, 2.5, -3.0
        else:
            # Use the defaults from PZ
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def env_with_transforms():
            # Base multi-agent environment, created on the configured device.
            base_env = PettingZooEnv(
                task="multiwalker_v9",
                parallel=True,
                seed=42,
                n_walkers=NUM_AGENTS,
                terminate_reward=terminate_scale,
                forward_reward=forward_scale,
                fall_reward=fall_scale,
                shared_reward=False,
                max_cycles=MAX_EPISODE_STEPS,
                render_mode=mode,
                device=DEVICE,
            )

            step_counter = StepCounter(max_steps=MAX_EPISODE_STEPS)
            # Every key list holds NUM_AGENTS copies of the same entry.
            reward_sum = RewardSum(
                in_keys=[base_env.reward_key for _ in range(NUM_AGENTS)],
                out_keys=[("walker", "episode_reward") for _ in range(NUM_AGENTS)],
                reset_keys=["_reset" for _ in range(NUM_AGENTS)],
            )
            return TransformedEnv(base_env, Compose(step_counter, reward_sum))

        return env_with_transforms

    # Build the training environment without rendering (render mode None)
    # and start it if it reports itself as closed.
    make_train_env = env_fn(None, parallel=False)
    train_env = make_train_env()
    if train_env.is_closed:
        train_env.start()

    def create_eval_env(tag = "rendered"):
        """Build the evaluation environment with pixel rendering and video recording.

        The environment is created through ``env_fn`` with ``"rgb_array"`` as
        render mode and ``rew_scale=False``. A ``PixelRenderTransform`` writing
        to ``"pixels_record"`` is appended first, followed by a ``VideoRecorder``
        reading that same key and logging through a ``CSVLogger``.

        Args:
            tag: tag handed to the ``VideoRecorder``.

        Returns:
            The transformed evaluation environment, started if it reported
            itself as closed.
        """
        env = env_fn("rgb_array", parallel=False, rew_scale=False)()

        csv_logger = CSVLogger("multiwalker-toy-test", video_format="mp4")
        recorder = VideoRecorder(csv_logger, tag=tag, in_keys=["pixels_record"])

        # The render transform must precede the recorder: it produces the
        # "pixels_record" entry that the recorder consumes.
        pixel_renderer = PixelRenderTransform(out_keys=["pixels_record"])
        env.append_transform(pixel_renderer)
        env.append_transform(recorder)

        if env.is_closed:
            env.start()
        return env

    # Validate the specs of the training environment first.
    check_env_specs(train_env)
    # Then build the rendering/recording evaluation environment and validate it too.
    # According to the traceback that follows the script, the RuntimeError is
    # raised from this second check_env_specs call.
    eval_env = create_eval_env()
    check_env_specs(eval_env)

    # NOTE(review): only train_env is closed here; eval_env is not closed
    # anywhere in the visible script.
    train_env.close()
Traceback (most recent call last):
  File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 115, in <module>
    check_env_specs(eval_env)
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/utils.py", line 728, in check_env_specs
    fake_tensordict = env.fake_tensordict()
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 2922, in fake_tensordict
    observation_spec = self.observation_spec
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 1303, in observation_spec
    observation_spec = self.output_spec["full_observation_spec"]
                       ^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 748, in output_spec
    output_spec = self.transform.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1104, in transform_output_spec
    output_spec = t.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 376, in transform_output_spec
    output_spec["full_observation_spec"] = self.transform_observation_spec(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/record/recorder.py", line 501, in transform_observation_spec
    observation_spec[self.out_keys[0]] = spec
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3783, in __setitem__
    raise RuntimeError(
RuntimeError: Setting a new attribute (pixels_record) on another device (cuda:0 against cuda). All devices of CompositeSpec must match.

Expected behavior

check_env_specs succeeds and program terminates.

System info

Describe the characteristics of your environment:

>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun  5 2024, 10:27:27) [GCC 12.2.0] linux

Reason and Possible fixes

The device comparison appears to be strict: "cuda" and "cuda:0" are treated as different devices, which results in the error.

For consistency with PyTorch in general, one option is to substitute "cuda" with f"cuda:{torch.cuda.current_device()}".

Depending on whether an equivalent of current_device() is available for other device types, similar checks could be implemented for those too.

Checklist