Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

Recurrent PPO Not Training Well on a Very Simple Environment #211

Open sreejank opened 11 months ago

sreejank commented 11 months ago

🐛 Bug

I've adapted the environment from this blog post; the exact code of the env is shown below. The original post implemented a recurrent A3C agent in TF1 (it was written a while ago).

It's a very simple "contextual bandit" environment where, for each episode, two random colors are given as the observation, with their order randomly flipped at each timestep. There are two actions, each corresponding to a particular color. The reward is associated with a particular color, so the agent has to figure out which color leads to reward and select the action based on that color. Upon environment reset, the task structure stays the same, but new pixel colors are chosen.

I figured SB3's RecurrentPPO should be able to solve this environment pretty easily. There is work showing that a recurrent policy network trained with A2C can solve a more complex 3D version of this task.

I have tried training for 1e6 to 1e8 timesteps and optimizing hyperparameters with Optuna (my search ranges are below). I'm wondering whether this task uncovers a hidden issue with RecurrentPPO in SB3 or whether it is just a deceptively difficult task. I have yet to run this environment with other packages to see whether the problem is specific to RecurrentPPO.

Any thoughts/insights?

Hyperparameter ranges:

gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
n_steps = trial.suggest_categorical("n_steps", [2, 4, 6, 8, 10, 12, 14])
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1, log=True)
lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
ent_coef = trial.suggest_float("ent_coef", 0, 1, log=False)
vf_coef = trial.suggest_float("vf_coef", 0, 1, log=False)
n_lstm = trial.suggest_categorical("n_lstm", [10, 30, 90, 120])
num_layers = trial.suggest_categorical("num_layers", [0, 1, 2, 3])
activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

clip_range_vf = None
batch_size = n_steps
gae_lambda = 1
n_epochs = 1
clip_range = 0.2
normalize_advantage = False
max_grad_norm = 0.5
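
For reference, this is roughly how I wire one trial's values into RecurrentPPO (a minimal sketch, assuming n_lstm maps to lstm_hidden_size and num_layers to the depth of the MLP net_arch; the concrete numbers are just one example draw, not tuned values, and ContextualBanditEnv is the class given under "Code example" below):

import torch.nn as nn
from sb3_contrib import RecurrentPPO

# One example draw from the ranges above (not the tuned values).
n_steps = 8
policy_kwargs = dict(
    lstm_hidden_size=90,     # assumed mapping for n_lstm
    n_lstm_layers=1,
    net_arch=[64, 64],       # assumed mapping for num_layers=2 (hypothetical width; 0 would mean [])
    activation_fn=nn.Tanh,
)

model = RecurrentPPO(
    "MlpLstmPolicy",
    ContextualBanditEnv(),   # defined under "Code example" below
    learning_rate=3e-4,
    n_steps=n_steps,
    batch_size=n_steps,
    n_epochs=1,
    gamma=0.99,
    gae_lambda=1,
    clip_range=0.2,
    clip_range_vf=None,
    normalize_advantage=False,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    policy_kwargs=policy_kwargs,
    verbose=1,
)
model.learn(total_timesteps=1_000_000)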

Code example

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Discrete


class ContextualBanditEnv(gym.Env):
    metadata = {'render_modes': ['human', 'rgb_array']}

    def __init__(self):
        self.num_actions = 2
        self.action_space = Discrete(self.num_actions)
        # observation = two colors (2 x RGB = 6 values) + one-hot previous action (2) + previous reward (1)
        self.observation_space = Box(low=0, high=1, shape=(9,), dtype=np.float64)
        self.prev_action = np.zeros(self.num_actions)
        self.prev_reward = 0
        self.num_correct = 0
        self.num_trials = 100

        self.reset()

    def get_state(self):
        # Randomly permute the two colors so their order changes at every timestep
        self.internal_state = np.random.permutation(self.choices)
        self.state = np.concatenate(np.reshape(self.internal_state, [2, 1, 1, 3]), axis=1)
        return np.concatenate([self.state.reshape((6,)), self.prev_action, [self.prev_reward]])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.timestep = 0
        # Draw a new random color and its complement; the first one is the rewarded color
        color = [np.random.uniform(), np.random.uniform(), np.random.uniform()]
        a = [np.reshape(np.array(color), [1, 1, 3]), np.reshape(1 - np.array(color), [1, 1, 3])]
        self.true = a[0]
        self.choices = a
        self.prev_action = np.zeros(self.num_actions)
        self.prev_reward = 0
        self.num_correct = 0
        return self.get_state(), {'t': self.timestep}

    def step(self, action):
        self.timestep += 1
        # Reward 1 if the chosen slot currently holds the rewarded color
        if (self.internal_state[action] == self.true).all():
            reward = 1.0
            self.num_correct += 1
        else:
            reward = 0.0
        # Note: the new observation is built before prev_action/prev_reward are updated below
        new_state = self.get_state()
        self.prev_action = np.zeros(self.num_actions)
        self.prev_action[action] = 1
        self.prev_reward = reward

        if self.timestep > self.num_trials:
            done = True
            # print("ACCURACY: ", float(self.num_correct) / self.timestep, self.timestep)
        else:
            done = False

        return new_state, reward, done, False, {'t': self.timestep}
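
For completeness, here is a quick sanity check of the environment itself (a minimal sketch using SB3's env checker plus a random rollout; with two actions, a random policy should be rewarded on roughly half of the ~100 steps per episode):

from stable_baselines3.common.env_checker import check_env

env = ContextualBanditEnv()
check_env(env, warn=True)  # may emit warnings (e.g. about float64 observations)

# Random-policy rollout as a chance-level baseline
obs, info = env.reset()
total_reward, done = 0.0, False
while not done:
    action = env.action_space.sample()
    obs, reward, done, truncated, info = env.step(action)
    total_reward += reward
print("random-policy episode reward:", total_reward)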

Relevant log output / Error message

No response

System Info

No response

Checklist