Curt-Park / rainbow-is-all-you-need

Rainbow is all you need! A step-by-step tutorial from DQN to Rainbow

input state-action pair into Rainbow DQN #35

Closed junhuang-ifast closed 4 years ago

junhuang-ifast commented 4 years ago

Hi, I was thinking of incorporating the action (in addition to the state), so that a state-action pair is the input to the Rainbow DQN model, but I am unsure where to insert it. The code below shows four places where I am thinking of adding the action as an input to the model, but I am not sure whether it is appropriate to add it there. (Please see the "<----" markers.)

def _compute_dqn_loss(self, samples: Dict[str, np.ndarray], gamma: float) -> torch.Tensor:
        """Return categorical dqn loss."""
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"]).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # Categorical DQN algorithm
        delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

        with torch.no_grad():
            # Double DQN
            # torch.cat, not np.concatenate: these are tensors already on `device`;
            # the action is cast to float and unsqueezed so the result is (batch, obs_dim + 1)
            next_state_EDIT = torch.cat([next_state, action.float().unsqueeze(1)], dim=1)  # <---- concat action
            next_action = self.dqn(next_state_EDIT).argmax(1)                              # <---- edited state as input
            next_dist = self.dqn_target.dist(next_state_EDIT)                              # <---- edited state as input
            next_dist = next_dist[range(self.batch_size), next_action]

            t_z = reward + (1 - done) * gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            b = (t_z - self.v_min) / delta_z
            l = b.floor().long()
            u = b.ceil().long()

            offset = (
                torch.linspace(
                    0, (self.batch_size - 1) * self.atom_size, self.batch_size
                ).long()
                .unsqueeze(1)
                .expand(self.batch_size, self.atom_size)
                .to(self.device)
            )

            proj_dist = torch.zeros(next_dist.size(), device=self.device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
            )
        # same fix as above: concatenate tensors with torch.cat along the feature dimension
        state_EDIT = torch.cat([state, action.float().unsqueeze(1)], dim=1)  # <---- concat action
        dist = self.dqn.dist(state_EDIT)                                     # <---- edited state as input
        log_p = torch.log(dist[range(self.batch_size), action])
        elementwise_loss = -(proj_dist * log_p).sum(1)

        return elementwise_loss
def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input state."""
        # NoisyNet: no epsilon greedy action selection
        state_EDIT = np.concatenate([state, action])       # <---- concat action (unclear which action to use here, since we are still selecting it)
        selected_action = self.dqn(
            torch.FloatTensor(state_EDIT).to(self.device)   # <---- edited state as input
        ).argmax()
        ).argmax()
        selected_action = selected_action.detach().cpu().numpy()

        if not self.is_test:
            self.transition = [state, selected_action]

        return selected_action
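
A side note on the mechanics, separate from whether the idea makes sense: if the action really were concatenated to the state, the shapes and the network's input size would also have to change. A small standalone sketch of the shape bookkeeping involved (illustrative numbers only, not this repo's code):

import torch

# Illustrative shapes only: a batch of 32 states with obs_dim features
# and one discrete action index per state.
batch_size, obs_dim = 32, 8
state = torch.randn(batch_size, obs_dim)       # (32, 8)
action = torch.randint(0, 4, (batch_size,))    # (32,)

# The concatenation has to happen along the feature dimension, so the action
# needs a trailing dim and a float dtype before torch.cat:
state_action = torch.cat([state, action.float().unsqueeze(1)], dim=1)
print(state_action.shape)  # torch.Size([32, 9]) -> the network's input size must grow to obs_dim + 1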

I have seen state-action pairs used as input to the Q function of Soft Actor-Critic before, but not in DQN, so I am unsure whether it is logical to do this, especially in self.dqn.dist(state_EDIT) and selected_action = self.dqn(torch.FloatTensor(state_EDIT).to(self.device)).argmax().
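
For reference, this is roughly the kind of state-action critic I mean from continuous-control methods (a minimal sketch with my own made-up layer sizes, not code from this repo or any particular SAC implementation):

import torch
import torch.nn as nn

class QCritic(nn.Module):
    """Q(s, a) critic as used in continuous-control methods such as SAC:
    the action is part of the input and the output is a single value."""

    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # one scalar Q-value per (s, a) pair
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state, action], dim=1))

critic = QCritic(obs_dim=8, action_dim=2)
q_value = critic(torch.randn(32, 8), torch.randn(32, 2))  # shape (32, 1): one value per pair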

Any ideas on this? thanks :)

Curt-Park commented 4 years ago

Firstly, I would like to know why you want to use state-action pairs for DQN. DQN is a method for problems with a small, discrete action space, so it is designed to predict the values of all actions given an input state. Your approach (state-action input) is usually employed for problems with a continuous action space, where it is intractable to predict the values of all state-action pairs.
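
A rough sketch of the difference (made-up layer sizes, not this repository's Network class): a DQN-style network maps a state to one value per discrete action, so the action is never part of the input and action selection is just an argmax over the output vector.

import torch
import torch.nn as nn

class DQNHead(nn.Module):
    """Q(s) -> a vector of action values, one entry per discrete action."""

    def __init__(self, obs_dim: int, n_actions: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),  # all actions' values in one forward pass
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)

# Action selection needs only the state; the action is read off the output.
q = DQNHead(obs_dim=8, n_actions=4)
state = torch.randn(1, 8)
greedy_action = q(state).argmax(dim=1)  # no action is fed into the network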