rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

Investigate super-convergence on RL algorithms #53

Open redknightlois opened 5 years ago

redknightlois commented 5 years ago

I have been using these two routines to figure out the best learning rate to apply, with awesome results on SAC. However, the changes in the temperature alter those values along the way, so it would probably be a good idea to extend this further to do some sort of 'automatic' rediscovery of the LR every x epochs. This version will also mess up the gradients, so you cannot use the policy after you run it.
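
For reference, a minimal sketch of the sweep state these routines assume has been set up beforehand (the attribute names mirror the ones used below; the concrete values are only illustrative assumptions):

    # Hypothetical initialization for the LR-range sweep; the values are illustrative.
    # The policy / Q-function optimizers should also start at find_lr_current_lr.
    self.find_lr_batch_num = 0            # batches processed so far in the sweep
    self.find_lr_batches = 1000           # total batches to sweep over
    self.find_lr_beta = 0.98              # EMA smoothing factor for the loss
    self.find_lr_current_lr = 1e-7        # starting learning rate of the sweep
    self.find_lr_multiplier = (1e-1 / 1e-7) ** (1.0 / 1000)   # geometric LR step per batch
    self.find_lr_losses = []              # smoothed losses recorded during the sweep
    self.find_lr_log_lrs = []             # log10(lr) at each recorded point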

    def find_policy_lr_step(self, loss):

        self.find_lr_batch_num += 1

        if self.find_lr_batch_num == 1:
            self.find_lr_avg_loss = 0.0
            self.find_lr_worst_loss = loss.item()
            self.find_lr_best_loss = loss.item()
            self.find_lr_best_lr = self.policy_optimizer.param_groups[0]['lr']
            self.find_lr_worst_lr = self.policy_optimizer.param_groups[0]['lr']

        self.find_lr_avg_loss = self.find_lr_beta * self.find_lr_avg_loss + (1 - self.find_lr_beta) * loss.item()
        smoothed_loss = self.find_lr_avg_loss / (1 - self.find_lr_beta ** self.find_lr_batch_num)

        # Record the best loss (after a short warm-up)
        if self.find_lr_batch_num > self.find_lr_batches // 10 and smoothed_loss < self.find_lr_best_loss:
            self.find_lr_best_lr = self.find_lr_current_lr
            self.find_lr_best_loss = smoothed_loss

        # We only record the worst loss at the start (we don't care about the divergent part)
        if self.find_lr_batch_num < self.find_lr_batches // 5:
            self.find_lr_worst_loss = max(smoothed_loss, self.find_lr_worst_loss)

        # Stop once the sweep budget of batches has been exhausted
        if self.find_lr_batch_num > self.find_lr_batches:

            # Plot smoothed loss against log10(learning rate) to inspect the sweep
            import matplotlib.pyplot as plt
            plt.plot(self.find_lr_log_lrs, self.find_lr_losses)
            plt.show()

            # TODO: This is a simplistic heuristic until we do it properly via gradient analysis.
            printout(f'The best learning rate for the policy network could be around: {self.find_lr_best_lr / 10}')
            printout(f'Process will exit because running the learning-rate sweep degenerates the gradients')
            exit(0)

        # Store the values unless we are already diverging
        if smoothed_loss <= self.find_lr_worst_loss:
            self.find_lr_losses.append(smoothed_loss)
            self.find_lr_log_lrs.append(math.log10(self.policy_optimizer.param_groups[0]['lr']))

        # Update with the new learning rate.
        self.find_lr_current_lr *= self.find_lr_multiplier
        self.policy_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr

    def find_qfunc_lr_step(self, qf1_loss, qf2_loss):
        self.find_lr_batch_num += 1

        if self.find_lr_batch_num == 1:
            self.find_lr_avg_loss = 0.0
            self.find_lr_worst_loss = min(qf1_loss.item(), qf2_loss.item())
            self.find_lr_best_loss = min(qf1_loss.item(), qf2_loss.item())
            self.find_lr_best_lr = self.qf1_optimizer.param_groups[0]['lr']
            self.find_lr_worst_lr = self.qf1_optimizer.param_groups[0]['lr']

        self.find_lr_avg_loss = self.find_lr_beta * self.find_lr_avg_loss + (1 - self.find_lr_beta) * min(qf1_loss.item(), qf2_loss.item())
        smoothed_loss = self.find_lr_avg_loss / (1 - self.find_lr_beta ** self.find_lr_batch_num)

        # Record the best loss (after a short warm-up)
        if self.find_lr_batch_num > self.find_lr_batches // 10 and smoothed_loss < self.find_lr_best_loss:
            self.find_lr_best_lr = self.find_lr_current_lr
            self.find_lr_best_loss = smoothed_loss

        # We only record the worst loss at the start (we don't care about the divergent part)
        if self.find_lr_batch_num < self.find_lr_batches // 5:
            self.find_lr_worst_loss = max(smoothed_loss, self.find_lr_worst_loss)

        # Stop once the sweep budget of batches has been exhausted
        if self.find_lr_batch_num > self.find_lr_batches:

            # Plot smoothed loss against log10(learning rate) to inspect the sweep
            import matplotlib.pyplot as plt
            plt.plot(self.find_lr_log_lrs, self.find_lr_losses)
            plt.show()

            # TODO: This is a simplistic heuristic until we do it properly via gradient analysis.
            printout(f'The best learning rate for the Q-function approximator could be around: {self.find_lr_best_lr / 10}')
            printout(f'Process will exit because running the learning-rate sweep degenerates the gradients')
            exit(0)

        # Store the values unless we are already diverging
        if smoothed_loss <= self.find_lr_worst_loss:
            self.find_lr_losses.append(smoothed_loss)
            self.find_lr_log_lrs.append(math.log10(self.qf1_optimizer.param_groups[0]['lr']))

        # Update with the new learning rate.
        self.find_lr_current_lr *= self.find_lr_multiplier
        self.qf1_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr
        self.qf2_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr
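
A rough sketch of where these calls would go in the SAC training step, after the losses are computed and before the optimizers step (the flags and surrounding variable names here are illustrative, not existing rlkit options; run only one sweep at a time, since both routines share the find_lr_* state):

    # Hypothetical hook inside SAC's training step; find_policy_lr / find_qfunc_lr
    # are illustrative flags, not rlkit options.
    if self.find_policy_lr:
        self.find_policy_lr_step(policy_loss)
    elif self.find_qfunc_lr:
        self.find_qfunc_lr_step(qf1_loss, qf2_loss)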
vitchyr commented 5 years ago

Thanks for the post! I'm a bit unsure what you are asking. Are you asking that I or others try this out, or that this code be merged in? Or were you asking for feedback?

Also, if you have example plots showing the performance of this on specific environments, that would help.

redknightlois commented 5 years ago

I don't do research; there are probably a lot of things to do to achieve something worth publishing. I'm just letting you know that it shows promising results in my limited trials, so this is more of an observation.

vitchyr commented 5 years ago

I see. Thanks for sharing! Would you mind posting your results here?

redknightlois commented 5 years ago

Sorry, I would really like to, but I am under NDA for this stuff. What I can say (which is general enough) is that even though the source data is very difficult to get to converge using general methods (in the supervised case too), with super-convergence effects I was able to steer the policy quite rapidly, in the same way I am able to in the supervised case. I am now training supervised neural networks to the same accuracy in under 100 minutes that took multiple days six months ago.
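
For reference, the mechanism behind super-convergence is the 1cycle learning-rate schedule; a minimal, self-contained sketch with PyTorch's built-in scheduler (the network, loss, max_lr and step count below are placeholders, with max_lr meant to come from something like find_lr_best_lr / 10):

    # Sketch: drive an optimizer with a 1cycle schedule (super-convergence).
    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    policy = torch.nn.Linear(8, 2)                       # stand-in for the policy network
    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
    num_train_steps = 10000
    scheduler = OneCycleLR(optimizer, max_lr=3e-3, total_steps=num_train_steps)

    for step in range(num_train_steps):
        loss = policy(torch.randn(32, 8)).pow(2).mean()  # stand-in for the policy loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()                                 # advance the 1cycle LR each batch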