pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Python-based RNNs in place operations cause RuntimeError #1742

Closed: albertbou92 closed this issue 10 months ago

albertbou92 commented 11 months ago

Describe the bug

Training with the Python-based GRU raises the following error, which indicates that the current implementation has some in-place operations that prevent correct backward computation:

File "/home/abou/test_bug.py", line 157, in main() File "/home/abou/test_bug.py", line 148, in main loss = loss_module(batch.cuda()) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl result = forward_call(*args, *kwargs) File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context return func(args, kwargs) File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper return func(_self, tensordict, *args, kwargs) File "/home/abou/rl/torchrl/objectives/sac.py", line 1096, in forward loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) File "/home/abou/rl/torchrl/objectives/sac.py", line 1203, in _actor_loss dist = self.actor_network.get_dist(tensordict) File "/home/abou/tensordict/tensordict/nn/probabilistic.py", line 524, in get_dist tensordict_out = self.get_dist_params(tensordict, tensordict_out, kwargs) File "/home/abou/tensordict/tensordict/nn/probabilistic.py", line 515, in get_dist_params return tds(tensordict, tensordict_out, kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(args, kwargs) File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun return getattr(type(self), fun_name)(self, *args, kwargs) File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper return func(_self, tensordict, *args, *kwargs) File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context return func(args, kwargs) File "/home/abou/tensordict/tensordict/nn/utils.py", line 253, in wrapper return func(_self, tensordict, *args, kwargs) File "/home/abou/tensordict/tensordict/nn/sequence.py", line 426, in forward tensordict = self._run_module(module, tensordict, kwargs) File "/home/abou/tensordict/tensordict/nn/sequence.py", line 407, in _run_module tensordict = module(tensordict, kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(args, kwargs) File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun return getattr(type(self), fun_name)(self, *args, kwargs) File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper return func(_self, tensordict, *args, *kwargs) File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context return func(args, kwargs) File "/home/abou/tensordict/tensordict/nn/utils.py", line 253, in wrapper return func(_self, tensordict, *args, kwargs) File "/home/abou/tensordict/tensordict/nn/sequence.py", line 426, in forward tensordict = self._run_module(module, tensordict, kwargs) File "/home/abou/tensordict/tensordict/nn/sequence.py", line 407, in _run_module tensordict = module(tensordict, kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", 
line 1518, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(args, kwargs) File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun return getattr(type(self), fun_name)(self, *args, kwargs) File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 1346, in forward val, hidden = self._gru(value, batch, steps, device, dtype, hidden) File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 1386, in _gru y, hidden = self.gru(input, hidden) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(args, kwargs) File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun return getattr(type(self), fun_name)(self, *args, **kwargs) File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 990, in forward result = self._gru(input, hx) File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 945, in _gru h_t[layer] = self._gru_cell( File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 904, in _gru_cell gate_h = F.linear(hx, weight_hh, bias_hh) (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 28it [00:13, 2.10it/s] Traceback (most recent call last): File "/home/abou/test_bug.py", line 157, in main() File "/home/abou/test_bug.py", line 150, in main loss_sum.backward() File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward torch.autograd.backward( File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

@vmoens you were right, in-place operations cause problems here. I believe the issue is being addressed in https://github.com/pytorch/rl/pull/1732
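For reference, the autograd mechanism behind this error can be shown without any torchrl machinery. Below is a minimal sketch (hypothetical tensors, with shapes chosen to mirror the `gate_h = F.linear(hx, weight_hh, bias_hh)` frame above, not torchrl's actual code): once a tensor has been saved for the backward pass, any in-place modification bumps its version counter and backward then refuses to use it.

import torch
import torch.nn.functional as F

weight_hh = torch.randn(256, 256, requires_grad=True)  # stand-in for a recurrent weight
hx = torch.zeros(1, 256)                                # stand-in for the hidden state

gate_h = F.linear(hx, weight_hh)  # autograd saves `hx` to compute the gradient of `weight_hh`
hx.add_(1.0)                      # in-place update: hx's version counter is bumped
gate_h.sum().backward()           # RuntimeError: one of the variables needed for gradient
                                  # computation has been modified by an inplace operation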

To Reproduce

import tqdm
import torch
import random
import numpy as np
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs.libs.gym import GymEnv
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.modules.distributions import OneHotCategorical
from torchrl.modules import ProbabilisticActor, GRUModule, MLP
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import DiscreteSACLoss
from torchrl.envs import (
    ParallelEnv,
    TransformedEnv,
    InitTracker,
    StepCounter,
    RewardSum,
)

def create_model(input_size, output_size, hidden_size=256, num_layers=3, out_key="logits"):

    embedding_module = TensorDictModule(
        in_keys=["observation"],
        out_keys=["embed"],
        module=torch.nn.Linear(input_size, input_size),  # including this embedding module triggers the RuntimeError
    )
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        in_key="embed",
        out_key="features",
        python_based=True,
    )
    mlp = TensorDictModule(
        MLP(
            in_features=hidden_size,
            out_features=output_size,
            num_cells=[],
        ),
        in_keys=["features"],
        out_keys=[out_key],
    )

    inference_model = TensorDictSequential(embedding_module, gru_module, mlp)
    training_model = TensorDictSequential(embedding_module, gru_module.set_recurrent_mode(), mlp)

    return inference_model, training_model

def create_rhs_transform(input_size, hidden_size=256, num_layers=3):
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        in_key="observation",
        out_key="features",
    )
    return gru_module.make_tensordict_primer()

def main():

    # Set seeds
    seed = 2024
    random.seed(int(seed))
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    torch.autograd.set_detect_anomaly(True)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    test_env = GymEnv("CartPole-v1", device=device, categorical_action_encoding=True)
    action_spec = test_env.action_spec.space
    observation_spec = test_env.observation_spec["observation"]

    def create_env_fn():
        env = GymEnv("CartPole-v1", device=device)
        env = TransformedEnv(env)
        env.append_transform(create_rhs_transform(input_size=observation_spec.shape[-1]))
        env.append_transform(InitTracker())
        return env

    # Models
    ##################

    inference_actor, training_actor = create_model(input_size=observation_spec.shape[-1], output_size=action_spec.n)
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)
    _, training_critic = create_model(input_size=observation_spec.shape[-1], output_size=action_spec.n, out_key="action_value")
    training_critic = training_critic.to(device)

    # Collector
    ##################

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=100,
        total_frames=5000,
        device=device,
        storing_device=device,
        max_frames_per_traj=-1,
        split_trajs=False,
    )

    # Buffer
    ##################

    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(100),
        batch_size=1,
    )

    # Loss
    ##################

    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=action_spec.n,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )
    loss_module.make_value_estimator(gamma=0.99)

    # Collection loop
    ##################

    for data in tqdm.tqdm(collector):
        buffer.extend(data.cpu())
        batch = buffer.sample()
        loss = loss_module(batch.cuda())
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()

    collector.shutdown()
    print("Success!")

if __name__ == "__main__":
    main()


vmoens commented 11 months ago

inplace ops are the devil :) Let's work on the patch in the "faster rnn" PR!

vmoens commented 10 months ago

@albertbou92 can you write a smaller (more standalone and minimal) repro example that I could put in the tests?

albertbou92 commented 10 months ago

What about something like this?


import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import DiscreteSACLoss

def create_model(input_size, output_size, num_layers=3, out_key="logits"):
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=output_size,
        num_layers=num_layers,
        in_key="observation",
        out_key=out_key,
        python_based=True,
    )
    return (
        gru_module,
        gru_module.set_recurrent_mode(True),
        gru_module.make_tensordict_primer(),
    )

def test_python_gru(device):

    env_name = "CartPole-v1"
    test_env = GymEnv(env_name)
    observation_size = test_env.observation_spec["observation"].shape[-1]
    num_actions = int(test_env.action_spec.space.n)

    inference_actor, training_actor, rhs_transform = create_model(
        input_size=observation_size, output_size=num_actions
    )
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)
    _, training_critic, _ = create_model(
        input_size=observation_size, output_size=num_actions, out_key="action_value"
    )
    training_critic = training_critic.to(device)

    def create_env_fn():
        env = GymEnv(env_name, device=device)
        env = TransformedEnv(env)
        env.append_transform(rhs_transform)
        env.append_transform(InitTracker())
        return env

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=10,
        total_frames=100,
    )

    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=num_actions,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )

    for data in collector:
        loss = loss_module(data.to(device))
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()

    collector.shutdown()
    print("Success!")
vmoens commented 10 months ago

I don't think we need to create any env, actor, distribution, or even spec to test that bug. This isn't fit to be a unit test, unfortunately. But no worries, I will try to find a minimal example on my own :)

vmoens commented 10 months ago

The issue isn't RNN-related; it comes from the fact that the actor and value networks share params and we don't clone them when we pass the value params. Hence, when we optimize the params, the graph breaks. I even wrote a comment about this: https://github.com/pytorch/rl/blob/0906206c2de07ec7bde99439708928b945503fb2/torchrl/objectives/sac.py#L1208-L1210

We should call clone whenever the qvalue and actor nets share params. I will write a fix.
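For illustration, here is a generic sketch of that hazard and of why cloning helps (simplified tensors, not the actual SAC/torchrl code path): a parameter shared by two computation paths is saved for backward by the first path, and any later in-place write to it, without a clone, invalidates that saved tensor.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, requires_grad=True)
shared_w = torch.nn.Parameter(torch.randn(4, 4))  # parameter shared by an actor-like and a value-like path

# First path: `shared_w` is saved for its backward pass.
loss_a = F.linear(x, shared_w).sum()

# Any in-place write to the shared parameter afterwards invalidates the saved tensor.
with torch.no_grad():
    shared_w.add_(0.1)

try:
    loss_a.backward()
except RuntimeError as err:
    print(err)  # ... has been modified by an inplace operation

# With a clone, the clone is what gets saved, so a later in-place write is harmless
# and gradients still reach `shared_w` through the clone.
loss_b = F.linear(x, shared_w.clone()).sum()
with torch.no_grad():
    shared_w.add_(0.1)
loss_b.backward()
print(shared_w.grad.shape)  # torch.Size([4, 4])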