Closed: albertbou92 closed this issue 10 months ago
inplace ops are the devil :) Let's work on the patch in the "faster rnn" PR!
@albertbou92 can you write a smaller (more standalone and minimal) repro example that I could put in the tests?
What about something like this?
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import DiscreteSACLoss
def create_model(input_size, output_size, num_layers=3, out_key="logits"):
    # Build a Python-based GRU and return the step-wise module, the same
    # module in recurrent mode (for training over whole trajectories), and
    # the tensordict primer transform for the recurrent hidden state.
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=output_size,
        num_layers=num_layers,
        in_key="observation",
        out_key=out_key,
        python_based=True,
    )
    return (
        gru_module,
        gru_module.set_recurrent_mode(True),
        gru_module.make_tensordict_primer(),
    )


def test_python_gru(device):
    env_name = "CartPole-v1"
    test_env = GymEnv(env_name)
    observation_size = test_env.observation_spec["observation"].shape[-1]
    num_actions = int(test_env.action_spec.space.n)

    # Actor: one module for step-by-step inference and one in recurrent mode
    # for training.
    inference_actor, training_actor, rhs_transform = create_model(
        input_size=observation_size, output_size=num_actions
    )
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)

    # Critic: a second GRU that outputs one value per action.
    _, training_critic, _ = create_model(
        input_size=observation_size, output_size=num_actions, out_key="action_value"
    )
    training_critic = training_critic.to(device)

    def create_env_fn():
        env = GymEnv(env_name, device=device)
        env = TransformedEnv(env)
        env.append_transform(rhs_transform)
        env.append_transform(InitTracker())
        return env

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=10,
        total_frames=100,
    )
    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=num_actions,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )

    for data in collector:
        # Move the collected batch to the training device and backprop the loss.
        loss = loss_module(data.to(device))
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()
    collector.shutdown()
    print("Success!")


if __name__ == "__main__":
    test_python_gru("cuda" if torch.cuda.is_available() else "cpu")
I don't think we need to create any env, actor, distribution, or even spec to test that bug, so this isn't fit to be a unit test, unfortunately. But no worries, I will try to find a minimal example on my own :)
The issue isn't RNN-related; it comes from the fact that the actor and value nets share params and we don't clone the value params when we pass them. Hence, when we optimize the params, the graph breaks. I even wrote a comment about this: https://github.com/pytorch/rl/blob/0906206c2de07ec7bde99439708928b945503fb2/torchrl/objectives/sac.py#L1208-L1210
We should call clone() whenever the qvalue and actor nets share params. I will write a fix.
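For illustration, here is a minimal sketch of that failure mode outside of TorchRL (all modules and names below are made up for the example, not the actual SAC-loss code path): when a shared parameter enters a second graph without being cloned, an in-place update made before backward invalidates that graph, whereas cloning the parameter keeps it valid.

import torch
from torch import nn

# Illustrative sketch only, not TorchRL code: `shared` stands in for params
# shared between the actor and qvalue nets.
encoder = nn.Linear(4, 8)
shared = nn.Linear(8, 1)

x = torch.randn(16, 4)

# Without cloning: shared.weight is saved in the qvalue graph (its input
# requires grad through `encoder`), so updating it in place before backward
# bumps its version counter and autograd refuses to run.
q_loss = shared(encoder(x)).pow(2).mean()
with torch.no_grad():
    shared.weight.add_(0.1)  # e.g. an optimizer step taken through the actor path
try:
    q_loss.backward()
except RuntimeError as err:
    print("graph broken:", err)

# With cloning: the graph saves the clones, so the later in-place update of
# shared.weight no longer invalidates it, and gradients still reach
# shared.weight through the clones.
q_loss = nn.functional.linear(
    encoder(x), shared.weight.clone(), shared.bias.clone()
).pow(2).mean()
with torch.no_grad():
    shared.weight.add_(0.1)
q_loss.backward()
print("clone() keeps the graph valid")

The same idea applies here: cloning the shared params before they enter the qvalue graph prevents the actor-side update from breaking it.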
Describe the bug
Training with the Python-based GRU raises an error indicating that the current implementation performs some in-place operations that prevent correct backward computation.
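As a hypothetical illustration of that error class (this is not the GRUModule code, just a minimal standalone snippet): a tensor that autograd saved for the backward pass gets overwritten in place, so backward() aborts with the in-place modification error.

import torch

# Hypothetical sketch, not the torchrl GRUModule code: a gate-like activation
# is computed, used to update the hidden state, and then overwritten in place.
# Since sigmoid's backward needs its saved output, the in-place write bumps
# the tensor's version counter and backward() fails.
w = torch.nn.Parameter(torch.randn(4, 4))
h = torch.randn(1, 4)

z = torch.sigmoid(h @ w)   # sigmoid saves its output `z` for backward
h = z * h                  # hidden-state update that uses the gate
z += 1.0                   # in-place reuse of the buffer holding `z`

h.sum().backward()         # RuntimeError: one of the variables needed for
                           # gradient computation has been modified by an
                           # inplace operation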
@vmoens you were right, in-place operations cause problems here. I believe the issue is being addressed in https://github.com/pytorch/rl/pull/1732
To Reproduce
Checklist