Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

[Question] How to do pre-training on the RecurrentPPO MlpLstmPolicy #248

Status: Open. iwishiwasaneagle opened this issue 4 months ago.

iwishiwasaneagle commented 4 months ago

❓ Question

I have been using the https://github.com/HumanCompatibleAI/imitation/ library for imitation learning with sb3 PPO to great effect. However, my end goal is to do the same for RecurrentPPO. From testing, I found that the imitation library does not support MlpLstmPolicy, so I went down the supervised-learning route instead.
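For context, a minimal sketch of that supervised (behavior-cloning) route for a recurrent policy is shown below. It uses a hypothetical plain-PyTorch `LstmPolicy` and a hypothetical `behaviour_clone` helper, not sb3-contrib's MlpLstmPolicy API, and it assumes expert data is stored as whole episodes of (observations, actions). Each episode is fed as one sequence starting from a zero LSTM state, so every loss has its own self-contained graph.

```python
import torch
import torch.nn as nn


class LstmPolicy(nn.Module):
    """Hypothetical stand-in for an LSTM actor: obs -> LSTM -> linear action head."""

    def __init__(self, obs_dim: int, act_dim: int, hidden_size: int = 256):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden_size)
        self.head = nn.Linear(hidden_size, act_dim)

    def forward(self, obs_seq: torch.Tensor) -> torch.Tensor:
        # obs_seq: (seq_len, batch, obs_dim). No state is passed, so the LSTM
        # starts from zeros, i.e. each sequence is treated as a fresh episode.
        out, _ = self.lstm(obs_seq)
        return self.head(out)  # (seq_len, batch, act_dim)


def behaviour_clone(policy, episodes, epochs=50, lr=1e-2):
    """Regress expert actions; episodes is a list of (obs (T, obs_dim), act (T, act_dim))."""
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []
    for _ in range(epochs):
        for obs, act in episodes:
            # Add a batch dimension of 1, then drop it again for the loss
            pred = policy(obs.unsqueeze(1)).squeeze(1)
            loss = criterion(pred, act)
            optimizer.zero_grad()
            loss.backward()  # one graph per episode, so no retain_graph is needed
            optimizer.step()
            losses.append(loss.item())
    return losses
```

Batching episodes of different lengths (padding or packed sequences) is left out for brevity.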

During development I ran into an issue with what I believe is the underlying MlpLstmPolicy code. I fully understand that I may be using this class in an unforeseen way, but thought it was prudent to ask for help here regardless.

The runtime error below is raised when calling loss.backward(retain_graph=True), where the loss is computed with criterion = th.nn.MSELoss(). Any advice would be much appreciated.
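A plausible cause, reproduced in plain PyTorch below: the LSTM state returned at one step is fed into the next step without being detached, so each loss's graph reaches back into earlier steps. `retain_graph=True` keeps those earlier graphs alive, but `optimizer.step()` updates the LSTM weights in place, so the next backward pass finds saved weights at a newer version and fails. The function name `train_steps`, the 2-dim action head and the sizes are illustrative assumptions chosen to mirror the traceback, not the original `train` function. Detaching the state at every step (truncated backpropagation through time) avoids the error and makes `retain_graph` unnecessary.

```python
import torch
import torch.nn as nn


def train_steps(detach_state: bool, steps: int = 3) -> None:
    """Feed one observation per step, carrying the LSTM state across steps."""
    torch.manual_seed(0)
    lstm = nn.LSTM(input_size=3, hidden_size=256)  # weight_ih_l0 has shape (1024, 3)
    head = nn.Linear(256, 2)  # hypothetical 2-dim action head
    params = list(lstm.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adadelta(params, lr=1.0)
    criterion = nn.MSELoss()

    # (h, c), each shaped (num_layers, batch, hidden_size)
    state = (torch.zeros(1, 1, 256), torch.zeros(1, 1, 256))
    for _ in range(steps):
        obs = torch.randn(1, 1, 3)  # (seq_len, batch, features)
        target = torch.randn(2)
        if detach_state:
            # Cut the graph so gradients stop at the start of this step
            state = tuple(s.detach() for s in state)
        out, state = lstm(obs, state)
        loss = criterion(head(out).view(-1), target)
        optimizer.zero_grad()
        # Without detaching, retain_graph=True keeps the previous step's graph
        # alive; the next backward walks into it and finds weights that
        # optimizer.step() has since modified in place.
        loss.backward(retain_graph=not detach_state)
        optimizer.step()


if __name__ == "__main__":
    try:
        train_steps(detach_state=False)
    except RuntimeError as err:
        print("fails:", "inplace" in str(err))
    train_steps(detach_state=True)  # completes without error
```

For the RNNStates object returned by the recurrent policy, the analogous (untested) step would be to detach every tensor in its `pi` and `vf` tuples before the next call, and to drop `retain_graph=True`.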

/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in MkldnnRnnLayerBackward0. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/traitlets/config/application.py", line 1077, in launch_instance
    app.start()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 737, in start
    self.io_loop.start()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 524, in dispatch_queue
    await self.process_one()
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 513, in process_one
    await dispatch(*args)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 418, in dispatch_shell
    await result
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 758, in execute_request
    reply_content = await reply_content
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 426, in do_execute
    res = shell.run_cell(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    result = self._run_cell(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    result = runner(coro)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_977293/1171147541.py", line 13, in <module>
    train(student_policy, device, train_loader, optimizer, student._last_lstm_states)
  File "/tmp/ipykernel_977293/3549058118.py", line 13, in train
    action, _, _, lstm_state = model(data_obsi.reshape((1, 1,-1)), lstm_state, data_ep_starti)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/sb3_contrib/common/recurrent/policies.py", line 237, in forward
    latent_pi, lstm_states_pi = self._process_sequence(pi_features, lstm_states.pi, episode_starts, self.lstm_actor)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/sb3_contrib/common/recurrent/policies.py", line 199, in _process_sequence
    hidden, lstm_states = lstm(
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 911, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 13
     10 optimizer = th.optim.Adadelta(student_policy.parameters(), lr=1.0)
     12 for epoch in range(1, N_IMITATE):
---> 13     train(student_policy, device, train_loader, optimizer, student._last_lstm_states)
     14     test(student_policy, device, test_loader, student._last_lstm_states)
     16 # Implant the trained policy network back into the RL student agent

Cell In[12], line 17, in train(model, device, train_loader, optimizer, start_lstm_state)
     14 action_prediction = action.double()
     16 loss = criterion(action_prediction.view(-1), targeti)
---> 17 loss.backward(retain_graph=True)
     18 optimizer.step()

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File /opt/miniconda3/envs/jpathgen/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1024, 3]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
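The tensor named in the error, `[torch.FloatTensor [1024, 3]]`, is plausibly an LSTM input-to-hidden weight: `nn.LSTM` stores `weight_ih_l0` with shape `(4 * hidden_size, input_size)`, so 1024 x 3 fits a hidden size of 256 (assuming sb3-contrib's default `lstm_hidden_size`) and 3-dimensional observation features (an assumption). The hypothetical helpers below match a shape from the error message against parameter names and use PyTorch's private `_version` counter to check that `optimizer.step()` modifies that weight in place.

```python
import torch
import torch.nn as nn


def params_with_shape(module: nn.Module, shape) -> list:
    """Names of parameters whose shape matches the tensor quoted in the error."""
    return [name for name, p in module.named_parameters() if tuple(p.shape) == tuple(shape)]


def step_bumps_version(lstm: nn.LSTM, name: str) -> bool:
    """Return True if optimizer.step() modifies the named parameter in place."""
    param = dict(lstm.named_parameters())[name]
    optimizer = torch.optim.Adadelta(lstm.parameters(), lr=1.0)
    before = param._version  # private autograd version counter
    out, _ = lstm(torch.randn(5, 1, lstm.input_size))  # (seq_len, batch, features)
    out.sum().backward()
    optimizer.step()
    return param._version > before


if __name__ == "__main__":
    lstm = nn.LSTM(input_size=3, hidden_size=256)
    print(params_with_shape(lstm, (1024, 3)))  # ['weight_ih_l0']
    print(step_bumps_version(lstm, "weight_ih_l0"))  # True
```

If these assumptions hold, running `params_with_shape` on the real policy should point at an LSTM weight such as `lstm_actor.weight_ih_l0`.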
