shibhansh / loss-of-plasticity

Demonstrations of Loss of Plasticity and Implementation of Continual Backpropagation

Performance Question #8

Open tlaurie99 opened 4 days ago

tlaurie99 commented 4 days ago

Hello, and thank you for making this work open source. I am trying to replicate the results from the paper, but using a HalfCheetah env where I reduce a joint stiffness every 20M timesteps, as shown below. The stiffness goes 240 → 48 → 9.6 → 1.92; the first drop is the largest and produces the biggest drop in reward (~1k). What I am comparing is a PPO algorithm with a normal critic (which simply outputs a scalar value and is trained with an MSE loss) versus a CBP critic (which still outputs a scalar value and is trained with MSE, but whose neurons can be reinitialized via the CBP linear layers). I am seeing that both perform very similarly, with the CBP implementation slightly worse (~2% lower average reward) over 25 models of each, run with [1024, 1024] layer sizes.

My questions are: am I implementing the CBP layers as intended? Is there a better way to test these networks for relearning after a change to the environment? Do you have any recommendations on how to proceed with testing this?

Thank you!

Env. step:

    def step(self, action):
        self._current_step += 1
        self._total_step += 1
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]
        x_velocity = (x_position_after - x_position_before) / self.dt

        ctrl_cost = self.control_cost(action)
        forward_reward = self._forward_reward_weight * x_velocity

        observation = self._get_obs()
        reward = forward_reward - ctrl_cost
        info = {
            "x_position": x_position_after,
            "x_velocity": x_velocity,
            "reward_run": forward_reward,
            "reward_ctrl": -ctrl_cost,
        }

        if self.render_mode == "human":
            self.render()

        # End the episode once the per-episode step limit is reached
        terminated = self._current_step >= self._max_episode_time

        # Every 750k total environment steps, cut the stiffness of joint 3 by 5x
        # (240 -> 48 -> 9.6 -> 1.92), producing the non-stationarity described above
        if self._total_step % 750_000 == 0:
            self.model.jnt_stiffness[3] /= 5

        return observation, reward, terminated, False, info

CBP model:

# Imports assumed for this snippet (RLlib ModelV2 API; CBPLinear from the
# loss-of-plasticity repo, e.g. lop.algos.cbp_linear)
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.annotations import override
from lop.algos.cbp_linear import CBPLinear


class CBPModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.actor_fcnet = TorchFC(obs_space, action_space, action_space.shape[0] * 2,
                                   model_config, name + "_actor")
        hidden_layer_size = model_config['fcnet_hiddens'][0]

        self.act = nn.LeakyReLU()
        self.fc1 = nn.Linear(obs_space.shape[0], hidden_layer_size)
        self.fc2 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.fc3 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 1)

        self.cbp1 = CBPLinear(self.fc1, self.fc2, replacement_rate=1e-4, maturity_threshold=100, init='kaiming', act_type='leaky_relu')
        self.cbp2 = CBPLinear(self.fc2, self.fc3, replacement_rate=1e-4, maturity_threshold=100, init='kaiming', act_type='leaky_relu')
        self.cbp3 = CBPLinear(self.fc3, self.fc4, replacement_rate=1e-4, maturity_threshold=100, init='kaiming', act_type='leaky_relu')

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        means_clamped = torch.clamp(means, -1, 1)
        log_stds_clamped = torch.clamp(log_stds, -10, 0)
        logits = torch.cat((means_clamped, log_stds_clamped), dim = -1)

        # ----- CBP implementation for the critic network -----
        obs = input_dict['obs']
        x = self.act(self.fc1(obs))
        x = self.cbp1(x)
        x = self.act(self.fc2(x))
        x = self.cbp2(x)
        x = self.act(self.fc3(x))
        x = self.cbp3(x)
        self.value = self.fc4(x)  # cached for the value_function() override
        return logits, state
shibhansh commented 1 day ago

Hey, at first glance the code seems correct. I think there could be two issues. First, the hyper-parameters are not well chosen: for the ant experiment I used a maturity threshold of 10,000, and we also used a weight decay of 1e-4; see the config file here. Second, CBP is implemented only for the critic; it may be that reinitialization in the actor is critical for good performance.

I'd suggest that you first implement CBP for both the actor and the critic, tune the hyper-parameters, and see if the algorithm works. If it does, you can then remove CBP from the actor; that will tell you the importance of CBP for the actor and the critic separately.
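
A rough sketch of that suggestion, with CBP layers in the actor as well (this is a hypothetical standalone actor MLP; the layer sizes and replacement rate are illustrative, while the maturity threshold of 10,000 and weight decay of 1e-4 follow the values mentioned above, not a verified config):

    import torch
    import torch.nn as nn
    from lop.algos.cbp_linear import CBPLinear  # CBP layer as provided in this repo

    class CBPActor(nn.Module):
        """Illustrative actor MLP with CBP reinitialization between layers."""
        def __init__(self, obs_dim, act_dim, hidden=1024):
            super().__init__()
            self.act = nn.LeakyReLU()
            self.fc1 = nn.Linear(obs_dim, hidden)
            self.fc2 = nn.Linear(hidden, hidden)
            self.out = nn.Linear(hidden, act_dim * 2)  # means and log-stds

            # maturity_threshold follows the 10,000 suggested above
            self.cbp1 = CBPLinear(self.fc1, self.fc2, replacement_rate=1e-4,
                                  maturity_threshold=10_000, init='kaiming',
                                  act_type='leaky_relu')
            self.cbp2 = CBPLinear(self.fc2, self.out, replacement_rate=1e-4,
                                  maturity_threshold=10_000, init='kaiming',
                                  act_type='leaky_relu')

        def forward(self, obs):
            x = self.cbp1(self.act(self.fc1(obs)))
            x = self.cbp2(self.act(self.fc2(x)))
            return self.out(x)

    # Weight decay of 1e-4 (as in the ant config referenced above) can be set on the optimizer:
    # optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4, weight_decay=1e-4)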