higgsfield / Imagination-Augmented-Agents

Building Agents with Imagination: pytorch step-by-step implementation

Imagination Core Full Rollout One Hot Encoding Incorrect #1

Open ASzot opened 6 years ago

ASzot commented 6 years ago

In imagination-augmented-agent.py, the imagination core does not correctly one-hot encode the actions for the environment model, as shown in the lines below:

        if self.full_rollout:
            state = state.unsqueeze(0).repeat(self.num_actions, 1, 1, 1, 1).view(-1, *self.in_shape)
            action = torch.LongTensor([[i] for i in range(self.num_actions)]*batch_size)
            rollout_batch_size = batch_size * self.num_actions
        else:
            action = self.distil_policy.act(Variable(state, volatile=True))
            action = action.data.cpu()
            rollout_batch_size = batch_size

        for step in range(self.num_rolouts):
            onehot_action = torch.zeros(rollout_batch_size, self.num_actions, *self.in_shape[1:])
            onehot_action[range(rollout_batch_size), action] = 1

For a full rollout, action becomes a column vector of shape [batch_size * num_actions, 1] rather than a flat vector. As a result, the line onehot_action[range(rollout_batch_size), action] = 1 is not correct (and takes a very long time). The fix is to add action = action.view(-1).
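To illustrate, here is a minimal standalone sketch of the indexing behaviour, not the repository's code. The sizes (batch_size, num_actions, height, width) and the helper one_hot_planes are hypothetical, and torch.arange stands in for range. My understanding is that with a recent PyTorch the [N] row index and the [N, 1] action index broadcast to an [N, N] grid, so every row gets every listed action set to 1 and the assignment touches N * N index pairs, which would also explain the slowdown.

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
batch_size, num_actions = 2, 3
height, width = 4, 4  # stand-in for self.in_shape[1:]
rollout_batch_size = batch_size * num_actions

# Same construction as in the quoted snippet: shape [rollout_batch_size, 1].
action = torch.LongTensor([[i] for i in range(num_actions)] * batch_size)


def one_hot_planes(action_index):
    """Build one-hot action planes the way the quoted loop body does."""
    onehot = torch.zeros(rollout_batch_size, num_actions, height, width)
    rows = torch.arange(rollout_batch_size)
    onehot[rows, action_index] = 1
    return onehot


# Column-vector index: rows [N] and actions [N, 1] broadcast to [N, N],
# so every action channel is set in every row.
buggy = one_hot_planes(action)

# Flattened index: rows and actions pair up element-wise, one channel per row.
fixed = one_hot_planes(action.view(-1))

print(buggy.sum(dim=1)[0, 0, 0].item())  # channels set per pixel, buggy case
print(fixed.sum(dim=1)[0, 0, 0].item())  # channels set per pixel, fixed case
```

With the flattened index, each row of the rollout batch has exactly one action channel filled, which is what the environment model is presumably meant to receive.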

higgsfield commented 6 years ago

Thanks! Can you open a pull request?