hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

TD3 Policy and Target Policy Naming Conflict #469

Closed: nathanhjay closed this issue 5 years ago.

nathanhjay commented 5 years ago

Description

When I try to instantiate a TD3 model, I get the following error in the __init__ function, on line 136:

ValueError: Variable input/model/pi_fc0/w already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?

I may be misunderstanding how the variables are named, but won't the following lines from __init__() in td3.py always cause a naming conflict, given reuse=False in the variable scope?

with tf.variable_scope("input", reuse=False):
    self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space, **self.policy_kwargs)
    self.target_policy_tf = self.policy(self.sess, self.observation_space, self.action_space, **self.policy_kwargs)
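A plausible explanation, assuming the policies in stable_baselines.common.policies create their weights inside __init__ under a fixed "model" scope: constructing such a policy twice inside the same "input" scope requests input/model/pi_fc0/w twice, which TF1 rejects when reuse=False. The TD3 policies, by contrast, appear to leave weight creation to TD3, which builds the actor and target actor under separate scopes. The pure-Python toy below models only that TF1 naming rule; ToyGraph, build_policy_eagerly and build_actor_deferred are hypothetical stand-ins, not TensorFlow or stable-baselines APIs, and the variable names are illustrative.

```python
from contextlib import contextmanager


class ToyGraph:
    """Toy model of TF1 variable naming (NOT the TensorFlow API)."""

    def __init__(self):
        self.variables = {}  # full variable name -> dummy value
        self._scopes = []    # stack of (scope_name, reuse) pairs

    @contextmanager
    def variable_scope(self, name, reuse=False):
        self._scopes.append((name, reuse))
        try:
            yield
        finally:
            self._scopes.pop()

    def get_variable(self, name):
        full_name = "/".join([s for s, _ in self._scopes] + [name])
        # As in TF1, reuse=True on an outer scope is inherited by inner scopes.
        reuse = any(r for _, r in self._scopes)
        if full_name in self.variables:
            if not reuse:
                raise ValueError(
                    "Variable %s already exists, disallowed." % full_name)
            return self.variables[full_name]
        if reuse:
            raise ValueError("Variable %s does not exist." % full_name)
        self.variables[full_name] = 0.0
        return self.variables[full_name]


def build_policy_eagerly(graph):
    # Hypothetical policy whose constructor creates its weights right away
    # under a fixed "model" scope.
    with graph.variable_scope("model"):
        graph.get_variable("pi_fc0/w")


def build_actor_deferred(graph, scope):
    # Hypothetical policy whose constructor creates no weights; the algorithm
    # builds the actor later under a scope of its choosing.
    with graph.variable_scope(scope):
        graph.get_variable("fc0/w")


# Eager style: two constructions inside the same "input" scope collide.
clash_message = None
g = ToyGraph()
with g.variable_scope("input", reuse=False):
    build_policy_eagerly(g)      # stands in for policy_tf
    try:
        build_policy_eagerly(g)  # stands in for target_policy_tf
    except ValueError as err:
        clash_message = str(err)
print(clash_message)

# Deferred style: actor and target actor land in distinct scopes.
g2 = ToyGraph()
build_actor_deferred(g2, "model")
build_actor_deferred(g2, "target")
print(sorted(g2.variables))
```

The eager case reproduces the shape of the reported error message, while the deferred case never asks for the same name twice, which is why switching the import to td3.policies makes the conflict disappear.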

Code Example

You can reproduce the issue with the following code:

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines import TD3

class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse, feature_extraction="mlp", **_kwargs)

env = gym.make('CartPole-v0')
model = TD3(MyMlpPolicy, env)

System Info

araffin commented 5 years ago

Hello,

As mentioned in the documentation, you should be using the policies from td3.policies, and a continuous-action environment such as Pendulum-v0 rather than CartPole, because TD3 only supports continuous actions.

The following code works:

import gym

from stable_baselines.td3.policies import FeedForwardPolicy
from stable_baselines import TD3

class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse, feature_extraction="mlp", **_kwargs)

env = gym.make('Pendulum-v0')
model = TD3(MyMlpPolicy, env)

nathanhjay commented 5 years ago

Works now, thanks for the help.

Miffyli commented 5 years ago

Agents could check whether the action/observation spaces are of a supported type and throw a more informative exception. A quick PR for later :)
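Such a check might look like the sketch below. check_action_space is a hypothetical helper, not an existing stable-baselines function, and Box and Discrete are minimal stand-ins so the snippet runs without gym installed; in real code one would pass the gym.spaces classes instead.

```python
class Box:
    """Stand-in for gym.spaces.Box (continuous actions)."""


class Discrete:
    """Stand-in for gym.spaces.Discrete."""


def check_action_space(algo_name, action_space, supported_types):
    """Raise an informative error if action_space is not a supported type."""
    if not isinstance(action_space, tuple(supported_types)):
        supported = ", ".join(t.__name__ for t in supported_types)
        raise TypeError(
            "%s does not support %s action spaces (supported: %s)"
            % (algo_name, type(action_space).__name__, supported))


check_action_space("TD3", Box(), [Box])  # passes silently
try:
    check_action_space("TD3", Discrete(), [Box])  # e.g. CartPole's action space
except TypeError as err:
    print(err)  # TD3 does not support Discrete action spaces (supported: Box)
```

Passing the supported types as a parameter keeps the helper reusable across algorithms that accept different spaces.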

araffin commented 5 years ago

There is already a check for that ;) (at least for TD3)

Miffyli commented 5 years ago

Ah, my bad, I did not notice the issue was caused by using the wrong policies. Perhaps a check could be added for that, but it is already well documented with proper highlights ^^
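A policy-class check along those lines could be sketched as follows. check_policy_class is a hypothetical helper, and BasePolicy, ActorCriticPolicy and TD3Policy here are empty stand-ins named after the base classes of the two policy modules, not the real stable-baselines classes.

```python
class BasePolicy:
    """Stand-in for a shared policy base class."""


class ActorCriticPolicy(BasePolicy):
    """Stand-in for the base of stable_baselines.common.policies."""


class TD3Policy(BasePolicy):
    """Stand-in for the base of stable_baselines.td3.policies."""


def check_policy_class(algo_name, policy_cls, policy_base):
    """Raise an informative error if policy_cls does not derive from policy_base."""
    if not (isinstance(policy_cls, type) and issubclass(policy_cls, policy_base)):
        raise TypeError(
            "%s expects a subclass of %s, got %r; is the policy imported "
            "from the right module?" % (algo_name, policy_base.__name__, policy_cls))


class WrongPolicy(ActorCriticPolicy):  # like MyMlpPolicy built on common.policies
    pass


class RightPolicy(TD3Policy):  # like MyMlpPolicy built on td3.policies
    pass


check_policy_class("TD3", RightPolicy, TD3Policy)  # passes silently
try:
    check_policy_class("TD3", WrongPolicy, TD3Policy)
except TypeError as err:
    print(err)
```

Because both policy families share a common ancestor, checking against the algorithm-specific base class (rather than the shared one) is what catches the wrong import.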