Zhendong-Wang / Diffusion-Policies-for-Offline-RL

Apache License 2.0
219 stars, 33 forks

Action resampled from 50 actions during evaluation? #3

Closed: ChenDRAG closed this issue 1 year ago

ChenDRAG commented 1 year ago

Hi, I noticed that in the implementation, the final action used for evaluation is actually sampled from 50 candidate actions generated by the actor network, with probabilities given by a softmax over their estimated Q values. Since this does not seem to be mentioned in the paper, I wonder how this resampling trick affects the algorithm's performance. Is this technique essential to the final performance of the algorithm?

https://github.com/Zhendong-Wang/Diffusion-Policies-for-Offline-RL/blob/35481c7981322ba31c5004b9e5d57f282ffd1876/agents/ql_diffusion.py#L187

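import torch
import torch.nn.functional as F

def sample_action(self, state):
    # Repeat the single state 50 times so the actor proposes 50 candidate actions.
    state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
    state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
    with torch.no_grad():
        action = self.actor.sample(state_rpt)
        # Score every candidate with the target critic's q_min.
        q_value = self.critic_target.q_min(state_rpt, action).flatten()
        # Draw one candidate index with probability softmax(Q); dim=0 made explicit.
        idx = torch.multinomial(F.softmax(q_value, dim=0), 1)
    return action[idx].cpu().data.numpy().flatten()
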
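
To study this question empirically, the selection step can be isolated as a small function. This is a minimal sketch, not the repository's code: `select_action` and its `mode` options are hypothetical names, and the inputs are assumed to be the 50 actor samples for one state plus their `q_min` estimates. `"softmax"` mirrors the quoted `sample_action`; `"argmax"` (greedy selection) and `"first"` (no resampling) are hypothetical ablations.

```python
import torch
import torch.nn.functional as F


def select_action(candidates: torch.Tensor, q_values: torch.Tensor, mode: str = "softmax") -> torch.Tensor:
    """Pick one action from N candidates proposed for a single state.

    candidates: (N, action_dim) actions sampled from the actor (assumed input).
    q_values:   (N,) critic estimates for those actions (assumed input).
    mode:
      "softmax" - sample an index with probability softmax(Q), like sample_action
      "argmax"  - greedily take the highest-Q candidate (hypothetical ablation)
      "first"   - ignore Q and keep candidate 0, i.e. no resampling (hypothetical ablation)
    """
    if mode == "softmax":
        probs = F.softmax(q_values, dim=0)
        idx = int(torch.multinomial(probs, 1).item())
    elif mode == "argmax":
        idx = int(torch.argmax(q_values).item())
    elif mode == "first":
        # If candidates are independent actor samples, this matches drawing one action.
        idx = 0
    else:
        raise ValueError(f"unknown mode: {mode}")
    return candidates[idx]


if __name__ == "__main__":
    torch.manual_seed(0)
    candidates = torch.randn(50, 3)  # 50 toy candidate actions, action_dim = 3
    q_values = torch.randn(50)       # toy Q estimates
    for mode in ("softmax", "argmax", "first"):
        print(mode, select_action(candidates, q_values, mode))
```

Switching `mode` inside an otherwise unchanged evaluation loop would give a direct ablation of the resampling trick; `"first"` costs the same as sampling a single action, assuming the 50 candidates are independent draws.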
Zhendong-Wang commented 1 year ago

Hi ChenDRAG,

First, thanks for your interest in our paper and for the good suggestions on reducing CPU usage. I will consider this as an option later, since it may require more GPU memory.

As for the action resampling, we inherited it from the BCQ framework. I didn't experiment with it much, since this evaluation setting was the default. Evaluation time doesn't change much, because the 50 candidates are processed in parallel on the GPU. From memory, this technique is not essential for the final performance on general tasks, but it may help on AntMaze tasks. I'm not completely sure about this, so if you are interested, feel free to investigate it. 👍
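
One detail that may matter for such an investigation: the quoted `sample_action` applies `F.softmax` directly to raw Q estimates, so how greedy the selection is depends on the absolute scale of Q, not just on the ranking of the candidates. As far as I know, BCQ's reference evaluation takes the argmax instead of sampling. The sketch below uses toy Q values and pure Python; `selection_probs`, `entropy`, and the `temperature` parameter are hypothetical illustrations, not something the repository exposes.

```python
import math


def selection_probs(q_values, temperature=1.0):
    """Softmax over Q estimates: the probability that each candidate is chosen.

    temperature=1.0 corresponds to F.softmax(q_value) in sample_action; other
    values are a hypothetical knob. Subtracting the max keeps exp() from overflowing.
    """
    scaled = [q / temperature for q in q_values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats: log(N) means uniform, 0 means a deterministic pick."""
    return -sum(p * math.log(p) for p in probs if p > 0)


if __name__ == "__main__":
    base = [1.0, 1.1, 0.9, 1.05]  # toy Q estimates for 4 candidates
    for scale in (1, 10, 100):
        q = [scale * v for v in base]  # same ranking, larger Q magnitude
        p = selection_probs(q)
        print(f"scale={scale:>3}  max_p={max(p):.3f}  entropy={entropy(p):.3f}")
```

Multiplying every Q by the same factor keeps the ranking but moves the selection from nearly uniform toward nearly greedy, so the effect of resampling could differ between tasks whose Q values have different magnitudes. Whether this explains the AntMaze observation is untested here.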

ChenDRAG commented 1 year ago

Thank you for your quick reply!