Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
MIT License

[Question] How to do pre-training on the RecurrentPPO MlpLstmPolicy #248

Open iwishiwasaneagle opened 4 months ago

iwishiwasaneagle commented 4 months ago

❓ Question

I have been using the imitation library for imitation learning with sb3 PPO to great effect. However, my end goal is to do the same for RecurrentPPO. From testing, I found that the imitation library does not support the MlpLstmPolicy, so I took a supervised learning approach instead.

During development I ran into an issue that I believe originates in the underlying MlpLstmPolicy code. I fully understand that I may be using this class in an unforeseen manner, but I thought it prudent to ask for help here regardless.

The runtime error (below) is raised when running loss.backward(retain_graph=True), where loss is computed with th.nn.MSELoss(). Any advice would be much appreciated.

/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/ UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/ UserWarning: Error detected in MkldnnRnnLayerBackward0. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/", line 17, in <module>
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/traitlets/config/", line 1077, in launch_instance
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 737, in start
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/tornado/platform/", line 195, in start
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/", line 607, in run_forever
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/", line 1922, in _run_once
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/", line 80, in _run, *self._args)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 524, in dispatch_queue
    await self.process_one()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 513, in process_one
    await dispatch(*args)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 418, in dispatch_shell
    await result
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 758, in execute_request
    reply_content = await reply_content
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 426, in do_execute
    res = shell.run_cell(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 3048, in run_cell
    result = self._run_cell(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 3103, in _run_cell
    result = runner(coro)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 129, in _pseudo_sync_runner
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 3308, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 3490, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_977293/", line 13, in <module>
    train(student_policy, device, train_loader, optimizer, student._last_lstm_states)
  File "/tmp/ipykernel_977293/", line 13, in train
    action, _, _, lstm_state = model(data_obsi.reshape((1, 1,-1)), lstm_state, data_ep_starti)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/sb3_contrib/common/recurrent/", line 237, in forward
    latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/sb3_contrib/common/recurrent/", line 199, in _process_sequence
    hidden, lstm_states = lstm(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/", line 911, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

RuntimeError                              Traceback (most recent call last)
Cell In[15], line 13
     10 optimizer = th.optim.Adadelta(student_policy.parameters(), lr=1.0)
     12 for epoch in range(1, N_IMITATE):
---> 13     train(student_policy, device, train_loader, optimizer, student._last_lstm_states)
     14     test(student_policy, device, test_loader, student._last_lstm_states)
     16 # Implant the trained policy network back into the RL student agent

Cell In[12], line 17, in train(model, device, train_loader, optimizer, start_lstm_state)
     14 action_prediction = action.double()
     16 loss = criterion(action_prediction.view(-1), targeti)
---> 17 loss.backward(retain_graph=True)
     18 optimizer.step()

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1024, 3]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
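
For readers who hit the same message, here is a minimal sketch in plain PyTorch (not sb3_contrib code, and not a confirmed diagnosis of this issue) of one common way this error can arise: the LSTM state returned by one training step is fed into the next step while still attached to the previous graph, and optimizer.step() updates the weights in place in between. The reported shape [1024, 3] would be consistent with an LSTM input weight matrix for a hidden size of 256 and 3 input features, but that is an assumption. The function name run_steps, the tiny network, and the random data are all hypothetical.

```python
import torch as th


def run_steps(detach_state: bool, n_steps: int = 4) -> str:
    """Train a toy LSTM one sample at a time, carrying the state across steps.

    Returns "ok" if every step succeeds, otherwise the RuntimeError text.
    """
    th.manual_seed(0)
    lstm = th.nn.LSTM(input_size=3, hidden_size=8)
    head = th.nn.Linear(8, 1)
    params = list(lstm.parameters()) + list(head.parameters())
    optimizer = th.optim.Adadelta(params, lr=1.0)
    criterion = th.nn.MSELoss()

    # Initial (hidden, cell) state: (num_layers, batch, hidden_size)
    state = (th.zeros(1, 1, 8), th.zeros(1, 1, 8))
    try:
        for _ in range(n_steps):
            obs = th.randn(1, 1, 3)  # (seq_len, batch, features)
            target = th.randn(1)
            if detach_state:
                # Cut the link to the previous step's graph so backward()
                # never revisits weights that were updated in place.
                state = tuple(s.detach() for s in state)
            out, state = lstm(obs, state)
            loss = criterion(head(out).view(-1), target)
            optimizer.zero_grad()
            # Without detaching, retain_graph=True is needed to backprop
            # through earlier steps, which is where the version check fails.
            loss.backward(retain_graph=not detach_state)
            optimizer.step()  # modifies the parameters in place
    except RuntimeError as err:
        return f"RuntimeError: {err}"
    return "ok"


if __name__ == "__main__":
    print("state kept attached:", run_steps(detach_state=False)[:90])
    print("state detached:     ", run_steps(detach_state=True))
```

If this is the cause, detaching the carried state before each forward pass (or resetting it at sequence boundaries) removes the need for retain_graph=True. Whether that is appropriate for the training loop in this issue depends on how the sequences in train_loader are batched, which the post does not show.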
