BY571 / CQL

PyTorch implementation of the Offline Reinforcement Learning algorithm CQL. Includes the versions DQN-CQL and SAC-CQL for discrete and continuous action spaces.

Potential typo in the CQL implementation #5

Closed (miladm12 closed this issue 1 year ago)

miladm12 commented 2 years ago

Hi, I noticed a potential typo in your CQL implementation for the Atari game. In the file "/CQL/CQL-DQN/agent.py", line 52, you subtract Q_a_s.mean() from the first term; however, according to the formulation in the original paper and the original TensorFlow implementation (https://github.com/aviralkumar2907/CQL/blob/master/atari/batch_rl/multi_head/quantile_agent.py), this term should be weighted by the actions actually taken in the mini-batch. Since you already compute Q_expected, you only need to replace Q_a_s with Q_expected, so line 52 becomes:

cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()

Please let me know if I'm wrong, but I did double check this with the source code and the formulations in the paper.
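
For concreteness, here is a small self-contained sketch of the two variants side by side (the tiny linear "Q-network" and the random batch below are placeholders for illustration only, not the repo's actual network or replay buffer):

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(8, 6)                    # placeholder Q-network: 8-dim states, 6 discrete actions
states = torch.randn(4, 8)               # dummy mini-batch of states
actions = torch.randint(0, 6, (4, 1))    # dataset actions, shape [batch, 1]

Q_a_s = net(states)                      # Q-values for all actions, shape [batch, 6]
Q_expected = Q_a_s.gather(1, actions)    # Q-values of the dataset actions only, shape [batch, 1]

# current line 52: second term averages over *all* actions
cql_loss_current = torch.logsumexp(Q_a_s, dim=1).mean() - Q_a_s.mean()
# proposed fix: second term uses only the actions taken in the batch
cql_loss_fixed = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
print(cql_loss_current.item(), cql_loss_fixed.item())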

jialianchen commented 2 years ago

I have the same question. I tried changing the code, but I got a bad result:

[Screenshot 2022-08-22 22:41:56]

I'm confused by these results.

miladm12 commented 2 years ago

I changed it to the following and it is working fine for me:

Q_a_s = self.net(states)
Q_expected = Q_a_s.gather(1, actions)
cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
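
For context, here is a rough sketch of how the corrected term fits into the full Q-update, following the paper's objective (an alpha-weighted CQL penalty plus 0.5 times the Bellman error); the function name, argument names, and constants are illustrative and may differ from the repo's actual agent code:

import torch
import torch.nn.functional as F

def cql_dqn_loss(net, target_net, states, actions, rewards, next_states, dones,
                 gamma=0.99, alpha=1.0):
    # Q-values of the online network for every action
    Q_a_s = net(states)                                   # [batch, n_actions]
    Q_expected = Q_a_s.gather(1, actions)                 # [batch, 1], dataset actions only

    # standard DQN target from the target network
    with torch.no_grad():
        Q_targets_next = target_net(next_states).max(dim=1, keepdim=True)[0]
        Q_targets = rewards + gamma * Q_targets_next * (1 - dones)

    # CQL(H) penalty: pushes down the soft-maximum over all actions
    # and pushes up the Q-values of the actions actually in the batch
    cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
    bellman_error = F.mse_loss(Q_expected, Q_targets)
    return alpha * cql_loss + 0.5 * bellman_error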

jialianchen commented 2 years ago

I ran it for more steps, and it seems to work:

[Screenshot 2022-08-29 15:20:45]

In principle we should be right, but more training steps are needed. If it is not fixed, the code behaves more like plain DQN, since the penalty no longer depends on the dataset actions and the conservative effect relative to the behavior actions is lost.

habanoz commented 1 year ago

(Quoting miladm12's original report above.)

I also checked against the reference implementation. This implementation's CQL loss calculation is wrong: the logsumexp part should include all actions, while the expected part should include only the actions found in the dataset.

This repo has a fair number of stars, which makes it visible, so it should be corrected.
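
For reference, the CQL(H) objective for discrete actions in Kumar et al. (2020) makes this explicit: the logsumexp runs over all actions, while the second expectation is taken under the behavior (dataset) actions,

$$
\min_Q \; \alpha\,\mathbb{E}_{s\sim\mathcal{D}}\Big[\log\sum_{a}\exp Q(s,a)\;-\;\mathbb{E}_{a\sim\hat{\pi}_\beta(a\mid s)}\big[Q(s,a)\big]\Big]\;+\;\tfrac{1}{2}\,\mathbb{E}_{(s,a,s')\sim\mathcal{D}}\Big[\big(Q(s,a)-\hat{\mathcal{B}}^{\pi}\hat{Q}(s,a)\big)^{2}\Big].
$$

In code, the first expectation is torch.logsumexp(Q_a_s, dim=1).mean() and the second is Q_expected.mean().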

BY571 commented 1 year ago

Hey @habanoz @jialianchen @miladm12! Sorry for the late response. I just updated the CQL-DQN loss and it should be correct now. I also tested its performance on CartPole-v0:

[Image: CartPole-v0 performance plot]