rll-research / url_benchmark


Should the encoder parameters be updated twice in each iteration? #11

Closed: kevinNejad closed 2 years ago

kevinNejad commented 2 years ago

Hi, thank you very much for such wonderful work and implementation. In the implementation, the encoder has its own separate optimiser and does a separate update on top of the agent optimiser step. Doesn't the ICM (or DDPG/AC) update the encoder parameters as well?

I'm wondering if there is any advantage to using a separate optimiser/update for the encoder, and whether it's necessary for the model?

Thank you

    def update(self, replay_iter, step):
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, extr_reward, discount, next_obs = utils.to_torch(
            batch, self.device)

        # augment and encode
        obs = self.aug_and_encode(obs)
        with torch.no_grad():
            next_obs = self.aug_and_encode(next_obs)
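            # next_obs is encoded without gradients, so the ICM loss only
            # backpropagates into the encoder through the current obs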

        if self.reward_free:
            metrics.update(self.update_icm(obs, action, next_obs, step))
            # ... (remainder of update() omitted in this excerpt)

    def update_icm(self, obs, action, next_obs, step):
        metrics = dict()

        forward_error, backward_error = self.icm(obs, action, next_obs)

        loss = forward_error.mean() + backward_error.mean()

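        # A single backward pass fills gradients for both the ICM head
        # and, through obs, the encoder; each optimiser then steps only
        # its own parameter set.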
        self.icm_opt.zero_grad(set_to_none=True)
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        loss.backward()
        self.icm_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()

        if self.use_tb or self.use_wandb:
            metrics['icm_loss'] = loss.item()

        return metrics
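To make the question concrete: `loss.backward()` in `update_icm` does populate encoder gradients, because `obs` is encoded with gradients enabled (only `next_obs` sits under `torch.no_grad()`). A runnable toy version of that gradient flow, using illustrative stand-in modules rather than the repo's actual classes:

    import torch
    import torch.nn as nn

    # Illustrative stand-ins for the CNN encoder and the ICM head.
    encoder = nn.Linear(8, 4)
    icm_head = nn.Linear(4, 4)

    obs = torch.randn(16, 8)
    loss = icm_head(encoder(obs)).pow(2).mean()
    loss.backward()

    # The ICM loss backpropagates through the encoder as well.
    assert encoder.weight.grad is not None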
kevinNejad commented 2 years ago

Just noticed that you pass only the parameters of the ICM to its optimiser, and there is no global optimiser for the entire network, so no parameter is updated twice. Closing the issue now.
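In other words, `icm_opt` and `encoder_opt` hold disjoint parameter sets, so the two `step()` calls in `update_icm` each update a different part of the network exactly once. A minimal sketch of that construction (module classes and learning rate are illustrative, not the repo's actual constructor code):

    import torch
    import torch.nn as nn

    # Illustrative modules standing in for the encoder and the ICM head.
    encoder = nn.Linear(8, 4)
    icm_head = nn.Linear(4, 4)

    # Each optimiser owns a disjoint parameter set, so no parameter is
    # ever touched by both step() calls.
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=1e-4)
    icm_opt = torch.optim.Adam(icm_head.parameters(), lr=1e-4)

    assert not (set(map(id, encoder.parameters()))
                & set(map(id, icm_head.parameters())))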