technocrat13 closed this 11 months ago
I'd be more than happy to send a pull request!
Please do =) It is indeed missing.
I think you just need to extend the PPO sampler (have a sample_ppo_lstm() that calls sample_ppo_params()).
"ppo_lstm": sample_ppo_params,
is actually already a quick and good solution.
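A minimal sketch of that quick fix, assuming a stripped-down sampler registry. The registry name HYPERPARAMS_SAMPLER, the StubTrial class and the single n_steps parameter are simplifications for illustration, not the zoo's actual code. The only point is that registering "ppo_lstm" under the PPO sampler removes the failing dictionary lookup.

```python
from typing import Any, Callable, Dict, Sequence


class StubTrial:
    """Hypothetical stand-in for optuna.Trial: always picks the first choice."""

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return choices[0]


def sample_ppo_params(trial: StubTrial) -> Dict[str, Any]:
    # Simplified: a real PPO sampler suggests many more hyperparameters.
    return {"n_steps": trial.suggest_categorical("n_steps", [128, 256, 512])}


# Algorithm name -> sampler. Without the "ppo_lstm" entry, looking up
# "ppo_lstm" raises a KeyError.
HYPERPARAMS_SAMPLER: Dict[str, Callable[[StubTrial], Dict[str, Any]]] = {
    "ppo": sample_ppo_params,
    "ppo_lstm": sample_ppo_params,  # the quick fix: reuse the PPO sampler
}

print(HYPERPARAMS_SAMPLER["ppo_lstm"](StubTrial()))  # {'n_steps': 128}
```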
In my implementation, where sample_ppo_lstm_params() calls sample_ppo_params(), I am running into a limitation of Optuna's trial.suggest_categorical().
I am sampling net_arch from ["tiny", "medium", "large"] for the LSTM, but vanilla PPO samples it from only ["medium", "large"].
Optuna cannot handle this: it does not support multiple categorical parameters with the same name but different value spaces, and it does not override the earlier one either. There are some discussions about implementing it, but they are old.
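The conflict can be reproduced with a toy model. ToyTrial below is hypothetical and is not Optuna's code; it only mimics the behavior described above: the first set of choices registered for a name wins, and re-suggesting the same name with different choices raises an error instead of overriding it.

```python
from typing import Any, Dict, Sequence, Tuple


class ToyTrial:
    """Toy model of the rule: one parameter name, one fixed set of choices."""

    def __init__(self) -> None:
        self.distributions: Dict[str, Tuple[Any, ...]] = {}
        self.params: Dict[str, Any] = {}

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        key = tuple(choices)
        if name in self.distributions:
            # Same name seen before: the value space must be identical.
            if self.distributions[name] != key:
                raise ValueError(f"{name!r}: categorical value space cannot change")
            return self.params[name]  # reuse the earlier value, no override
        self.distributions[name] = key
        self.params[name] = key[0]
        return self.params[name]


trial = ToyTrial()
trial.suggest_categorical("net_arch", ["medium", "large"])  # vanilla PPO sampler
try:
    trial.suggest_categorical("net_arch", ["tiny", "medium", "large"])  # LSTM sampler
except ValueError as err:
    print(err)  # 'net_arch': categorical value space cannot change
```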
Using a new name for the LSTM's net_arch (e.g. net_arch_lstm) wastes search space and is an inelegant solution.
I have three possible solutions to this:
1. Not including "tiny" in sample_ppo_lstm_params()
2. Including "tiny" in sample_ppo_params()
3. Passing a flag to sample_ppo_params()
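A rough sketch of the three options, assuming a simplified sampler that only suggests net_arch. The function names with _v1/_v2/_v3 suffixes, the lstm flag, the StubTrial class and the network sizes in NET_ARCH are all hypothetical. In each variant, the name "net_arch" keeps a single value space, as long as one study only optimizes one algorithm.

```python
from typing import Any, Dict, Sequence

# Hypothetical network sizes, for illustration only.
NET_ARCH = {"tiny": [32], "medium": [64, 64], "large": [256, 256]}


class StubTrial:
    """Hypothetical stand-in for optuna.Trial: always picks the first choice."""

    def suggest_categorical(self, name: str, choices: Sequence[str]) -> str:
        return choices[0]


# Option 1: the LSTM sampler does not offer "tiny"; it reuses the PPO choices.
def sample_ppo_params_v1(trial: StubTrial) -> Dict[str, Any]:
    choice = trial.suggest_categorical("net_arch", ["medium", "large"])
    return {"policy_kwargs": {"net_arch": list(NET_ARCH[choice])}}


def sample_ppo_lstm_params_v1(trial: StubTrial) -> Dict[str, Any]:
    return sample_ppo_params_v1(trial)


# Option 2: "tiny" is added to the shared PPO sampler, so both algorithms see it.
def sample_ppo_params_v2(trial: StubTrial) -> Dict[str, Any]:
    choice = trial.suggest_categorical("net_arch", ["tiny", "medium", "large"])
    return {"policy_kwargs": {"net_arch": list(NET_ARCH[choice])}}


# Option 3: a flag selects the value space for the algorithm being tuned.
def sample_ppo_params_v3(trial: StubTrial, lstm: bool = False) -> Dict[str, Any]:
    choices = ["tiny", "medium", "large"] if lstm else ["medium", "large"]
    choice = trial.suggest_categorical("net_arch", choices)
    return {"policy_kwargs": {"net_arch": list(NET_ARCH[choice])}}


print(sample_ppo_params_v3(StubTrial(), lstm=True))  # {'policy_kwargs': {'net_arch': [32]}}
```

Options 1 and 2 keep one list of choices for both algorithms, while option 3 lets each algorithm have its own list at the cost of an extra parameter on the shared sampler.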
All three assume that I am extending the existing function and updating "policy_kwargs" in the returned dictionary; if I just write a new, standalone function, these issues simply do not exist.
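A minimal sketch of the "extend and update policy_kwargs" pattern, assuming a simplified PPO sampler. lstm_hidden_size and enable_critic_lstm are keyword arguments of sb3-contrib's recurrent policies, but the Optuna parameter names, the choice lists, the network sizes and the StubTrial class are hypothetical.

```python
from typing import Any, Dict, Sequence


class StubTrial:
    """Hypothetical stand-in for optuna.Trial: always picks the first choice."""

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return choices[0]


def sample_ppo_params(trial: StubTrial) -> Dict[str, Any]:
    # Simplified PPO sampler: only a few entries are shown.
    n_steps = trial.suggest_categorical("n_steps", [128, 256, 512])
    net_arch = trial.suggest_categorical("net_arch", ["medium", "large"])
    sizes = {"medium": [64, 64], "large": [256, 256]}  # illustrative sizes
    return {"n_steps": n_steps, "policy_kwargs": {"net_arch": sizes[net_arch]}}


def sample_ppo_lstm_params(trial: StubTrial) -> Dict[str, Any]:
    """Reuse the PPO sampler, then add LSTM-specific policy kwargs."""
    hyperparams = sample_ppo_params(trial)
    # New parameter names, so they cannot clash with anything PPO suggests.
    lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [64, 128, 256])
    enable_critic_lstm = trial.suggest_categorical("enable_critic_lstm", [True, False])
    hyperparams["policy_kwargs"].update(
        lstm_hidden_size=lstm_hidden_size,
        enable_critic_lstm=enable_critic_lstm,
    )
    return hyperparams


print(sample_ppo_lstm_params(StubTrial()))
# {'n_steps': 128, 'policy_kwargs': {'net_arch': [64, 64], 'lstm_hidden_size': 64, 'enable_critic_lstm': True}}
```

Only the LSTM-specific entries get new names here; net_arch is inherited unchanged, which is why the "tiny" question above only arises if the wrapper tries to re-suggest it.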
Which solution do you think is ideal, @araffin? And do you have any suggestions of your own? I can look into implementing them.
Solution 1 or 2 is fine for me.
🐛 Bug
There is no direct way to optimize hyperparameters for ppo_lstm (RecurrentPPO) from the command line. It can only be achieved by amending hyperparams_opt.py to add a sampler entry for ppo_lstm.
Interestingly, ALGOS in utils.py does include ppo_lstm, but passing 'ppo_lstm' with -optimize/--optimize-hyperparameters raises a KeyError.
Am I missing something? As I understand it, the hyperparameters themselves do not change much going from PPO to RecurrentPPO, only their values, so the same sampling function can be used. Only if the policy_kwargs need to change does a sample_ppo_lstm_params() need to be implemented. I'd be more than happy to send a pull request!
To Reproduce
Relevant log output / Error message
System Info
Installed the zoo, sb3 and sb3-contrib with pip, and ensured the relevant files are up to date with the latest rl-baselines-zoo/main.
Checklist