rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

AWAC loss function #130

Closed · DuaneNielsen closed this issue 3 years ago

DuaneNielsen commented 3 years ago

I'm having a little trouble relating the Accelerating Online Reinforcement Learning with Offline Datasets paper to the code. Perhaps someone can help me.

The paper indicates the loss function for the policy as:

[image: policy update from the paper, θ_{k+1} = argmax_θ E_{s,a∼β}[ log π_θ(a|s) · exp(A^{π_k}(s,a) / λ) / Z(s) ]]

Then in the appendix it mentions that the Z(s) normalization was a bust, so it was thrown out. All good.

So my question is: which branches of the code below produce the cool AWAC results in the paper, and how does that relate to the theoretical result?

From what I can guess, it's...

normalize_over_state = 'advantage'
weight_loss = True
alpha = 0.0
use_awr_update = True
awr_weight = 1.0

Which gives us

score = q_adv - v_pi
weights = F.softmax(score / beta, dim=0)
policy_loss =  self.awr_weight * (-policy_logpp * len(weights)*weights.detach()).mean()
policy_loss = self.rl_weight * policy_loss

So from the equation:

policy_logpp = the log-probability of the action under the policy, score = Adv(s,a), and beta = lambda.

So instead we have something like

policy loss = - log_prob_policy_action * softmax(advantage/beta, dim=0) * batch_size

Can you confirm that this is the correct update?

Also, can you make the link back to the theoretical result?

Relevant section of code below...

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
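            # estimate the per-state normalizer Z(s) by sampling K actions from the
            # buffer policy and averaging exp(advantage / beta) over them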
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # if self._n_train_steps_total % 100 == 0:
            #     import ipdb; ipdb.set_trace()
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
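            # log Z(s) = logsumexp(adv/beta - log K), i.e. the log of the mean of exp(adv/beta)
            # over the K sampled actions; logS is then the Z-normalized log-weight of the dataset action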
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b/beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi)/beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0) # score / sum(score)
        else:
            error

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
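            # normalize_over_batch selects how the advantage score becomes a weight:
            # True = softmax over the batch, "whiten" = standardize then exponentiate,
            # "exp" = plain exp(score / beta), "step_fn" = indicator(score > 0), False = raw score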
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch == False:
                weights = score
            else:
                error
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp * len(weights)*weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (-q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
anair13 commented 3 years ago

Sorry for the late answer, github didn't tag me -

Yes, that's all exactly correct. As for linking to the theoretical result, the resulting policy update should basically be: policy loss = - log_prob_policy_action * exp(advantage/beta)

But we normalize per batch to get: policy loss = - log_prob_policy_action * softmax(advantage/beta, dim=0) * batch_size. Softmax is just exp() / sum(exp()), so this puts each batch on the same scale and keeps it from being overly affected by, e.g., a single large advantage. The last * batch_size just brings it back to the same scale as the other losses (e.g. BC).
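
For concreteness, here's a toy example of the two weightings side by side (just an illustration with made-up numbers, not the rlkit code; adv and policy_logpp stand for a batch of advantages and the log-probabilities of the dataset actions):

import torch
import torch.nn.functional as F

# made-up batch of advantages and action log-probabilities
adv = torch.tensor([0.5, -1.0, 2.0, 0.0])
policy_logpp = torch.tensor([-1.2, -0.8, -2.0, -1.5])
beta = 1.0

# theoretical update: weight each log-prob by exp(advantage / beta)
w_exp = torch.exp(adv / beta)
loss_exp = (-policy_logpp * w_exp.detach()).mean()

# batch-normalized update: softmax over the batch, rescaled by the batch size
w_soft = F.softmax(adv / beta, dim=0) * len(adv)
loss_soft = (-policy_logpp * w_soft.detach()).mean()

# the two losses differ only by the per-batch constant mean(exp(adv / beta)),
# so the gradient direction is the same up to that scale factor
print(loss_exp / loss_soft, w_exp.mean())  # both print the same value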

Does that answer the question?

The trainer kwargs used for the experiments in each domain are also in the experiment hyperparameters here: https://github.com/vitchyr/rlkit/tree/master/examples/awac

DuaneNielsen commented 3 years ago

Thanks for answering, the softmax batch normalization is a cool trick!

I ran a few experiments with AWAC on the gym LunarLander environment, just using the loss from the paper. On that environment, the results were amazing. Great results with only 1000 expert transitions and 500 steps of online tuning. Insane!

Results were so good that I re-read my code to make sure I had not initialized with an expert policy by mistake!

I also recovered working policies on Atari Breakout. But for Space Invaders, using manually generated actions, I couldn't get it to work.

One thing is for sure though: this algorithm shows that model-free offline RL is entirely possible. It has motivated me to tackle Conservative Q-Learning as my next project.

Thanks!

anair13 commented 2 years ago

I just saw your post, glad to hear it worked out! Excited to see what comes of it.