BY571 / IQN-and-Extensions

PyTorch Implementation of Implicit Quantile Networks (IQN) for Distributional Reinforcement Learning with additional extensions like PER, Noisy layer, N-step bootstrapping, Dueling architecture and parallel env support.
MIT License

IQN-DQN.ipynb max over taus instead of max over actions? #9

Open IgorAherne opened 1 year ago

IgorAherne commented 1 year ago

Hello,

Given that `forward()` returns the tuple `out.view(batch_size, num_tau, self.num_actions), taus`, should we use `.max(1)` instead of `.max(2)`?

Currently it is:
`Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N)`

Maybe it should be:
`Q_targets_next = Q_targets_next.detach().max(1)[0].unsqueeze(1)  # (batch_size, 1, num_actions)`

In other words, should we take the maximum within every tau group, rather than across every action? Sorry if I misunderstood the process.
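For what it's worth, a quick shape sketch (NumPy standing in for torch's `.max(dim)[0]`, with hypothetical sizes `batch_size=4`, `num_tau=8`, `num_actions=2`; this is not the repo's code) shows what each reduction keeps:

```python
import numpy as np

# Hypothetical sizes, mirroring forward()'s (batch_size, num_tau, num_actions) layout.
batch_size, num_tau, num_actions = 4, 8, 2
q = np.arange(batch_size * num_tau * num_actions, dtype=float)
q = q.reshape(batch_size, num_tau, num_actions)

# torch's .max(2)[0] corresponds to NumPy's .max(axis=2):
# it reduces the last (action) dimension, keeping one value per tau sample.
max_over_actions = q.max(axis=2)   # shape: (batch_size, num_tau)

# .max(1)[0] would instead reduce the tau dimension,
# keeping one value per action.
max_over_taus = q.max(axis=1)      # shape: (batch_size, num_actions)

print(max_over_actions.shape)  # (4, 8)
print(max_over_taus.shape)     # (4, 2)
```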

```python
def learn(self, experiences):
    """Update value parameters using given batch of experience tuples.
    Params
    ======
        experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
        gamma (float): discount factor
    """
    self.optimizer.zero_grad()
    states, actions, rewards, next_states, dones = experiences
    # Get max predicted Q values (for next states) from target model
    Q_targets_next, _ = self.qnetwork_target(next_states)
    Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N)    <-------------------------------------- HERE

    # Compute Q targets for current states
    Q_targets = rewards.unsqueeze(-1) + (self.GAMMA**self.n_step * Q_targets_next * (1. - dones.unsqueeze(-1)))
    # Get expected Q values from local model
    Q_expected, taus = self.qnetwork_local(states)
    Q_expected = Q_expected.gather(2, actions.unsqueeze(-1).expand(self.BATCH_SIZE, 8, 1))

    # Quantile Huber loss
    td_error = Q_targets - Q_expected
    assert td_error.shape == (self.BATCH_SIZE, 8, 8), "wrong td error shape"
    huber_l = calculate_huber_loss(td_error, 1.0)
    quantil_l = abs(taus - (td_error.detach() < 0).float()) * huber_l / 1.0

    loss = quantil_l.sum(dim=1).mean(dim=1)  # keepdim=True if PER weights get multiplied
    loss = loss.mean()

    # Minimize the loss
    loss.backward()
    # clip_grad_norm_(self.qnetwork_local.parameters(), 1)
    self.optimizer.step()

    # ------------------- update target network ------------------- #
    self.soft_update(self.qnetwork_local, self.qnetwork_target)
    return loss.detach().cpu().numpy()
```
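For context, the broadcast that produces the asserted `(BATCH_SIZE, 8, 8)` shape can be traced with a NumPy sketch (hypothetical sizes; not the repo's code): `max(2)` keeps the tau dimension intact, and subtracting a `(batch, N, 1)` tensor from a `(batch, 1, N)` tensor yields the N×N grid of pairwise TD errors that the quantile Huber loss pairs with the taus.

```python
import numpy as np

# Hypothetical sizes; N is the number of quantile samples (8 in the repo).
batch_size, N, num_actions = 4, 8, 2

q_next = np.random.rand(batch_size, N, num_actions)

# .max(2)[0] -> (batch_size, N); .unsqueeze(1) -> (batch_size, 1, N)
q_targets_next = q_next.max(axis=2)[:, None, :]
assert q_targets_next.shape == (batch_size, 1, N)

# After the gather over the taken action, Q_expected is (batch_size, N, 1).
q_expected = np.random.rand(batch_size, N, 1)

# Broadcasting (batch, 1, N) - (batch, N, 1) covers every (tau_i, tau_j) pair.
td_error = q_targets_next - q_expected
assert td_error.shape == (batch_size, N, N)  # matches the assert in learn()
print(td_error.shape)  # (4, 8, 8)
```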
BY571 commented 1 year ago

Hey, I haven't looked at the code in a while. Can you reference the part of the paper that made you think it should be the max over the taus?

IgorAherne commented 1 year ago

Hi Sebastian,

From page 5 of this paper: https://arxiv.org/pdf/1806.06923.pdf. The equations are a bit tough for me, but looking at equations 2 and 3 from there:

[image: screenshot of equations from the IQN paper]
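For readers without the embedded screenshot, the equations being referenced are (from memory, so the numbering may differ; symbols as in the IQN paper) the risk-sensitive greedy policy and the sampled TD error:

```latex
% Greedy policy: note the argmax is taken over actions a, not over taus
\pi_\beta(x) = \operatorname*{arg\,max}_{a \in \mathcal{A}} Q_\beta(x, a),
\qquad
Q_\beta(x, a) := \mathbb{E}_{\tau \sim U([0,1])}\big[ Z_{\beta(\tau)}(x, a) \big]

% Sampled TD error for a pair of quantile samples tau_i, tau_j
\delta^{\tau_i, \tau_j}_t = r_t
  + \gamma\, Z_{\tau_j}\!\big(x_{t+1}, \pi_\beta(x_{t+1})\big)
  - Z_{\tau_i}(x_t, a_t)
```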