hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.13k stars, 723 forks

how to customize policy network for SAC #349

Closed: yutao-li closed this issue 5 years ago

yutao-li commented 5 years ago


Describe the bug


I followed the instructions for customizing a policy network, but they do not work for SAC. Could you show a brief example of how to do that?

Code example

from stable_baselines import SAC
from stable_baselines.sac.policies import LnMlpPolicy

agent = SAC(LnMlpPolicy, "Pendulum-v0", policy_kwargs=dict(net_arch=[128, 128, dict(pi=[64], vf=[64])]))
Traceback (most recent call last):
  File "/datadrive/yutao/.pycharm_helpers/pydev/pydev_run_in_console.py", line 53, in run_file
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/datadrive/yutao/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/datadrive/yutao/scratch/scratch_8.py", line 4, in <module>
    agent = SAC(LnMlpPolicy, "Pendulum-v0", policy_kwargs=dict(net_arch=[128, 128, dict(pi=[64], vf=[64])]))
  File "/datadrive/yutao/anaconda3/lib/python3.7/site-packages/stable_baselines/sac/sac.py", line 123, in __init__
    self.setup_model()
  File "/datadrive/yutao/anaconda3/lib/python3.7/site-packages/stable_baselines/sac/sac.py", line 145, in setup_model
    **self.policy_kwargs)
  File "/datadrive/yutao/anaconda3/lib/python3.7/site-packages/stable_baselines/sac/policies.py", line 365, in __init__
    feature_extraction="mlp", layer_norm=True, **_kwargs)
  File "/datadrive/yutao/anaconda3/lib/python3.7/site-packages/stable_baselines/sac/policies.py", line 189, in __init__
    self._kwargs_check(feature_extraction, kwargs)
  File "/datadrive/yutao/anaconda3/lib/python3.7/site-packages/stable_baselines/common/policies.py", line 177, in _kwargs_check
    raise ValueError("Unknown keywords for policy: {}".format(kwargs))
ValueError: Unknown keywords for policy: {'net_arch': [128, 128, {'pi': [64], 'vf': [64]}]}


araffin commented 5 years ago

Hello, please read the SAC documentation carefully:

"The SAC model does not support stable_baselines.common.policies because it uses double q-values and value estimation, as a result it must use its own policy models (see SAC Policies)."

The net_arch keyword applies only to stable_baselines.common.policies; for the SAC policies, use the layers keyword instead.

yutao-li commented 5 years ago

ok, thanks