luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

Applications of meta-rl to DQN #1

Closed by pseudo-rnd-thoughts 1 year ago

pseudo-rnd-thoughts commented 1 year ago

Hi,

I have just read the blog post and think this is really cool work. I'm guessing that a full academic version of the work is coming out soon.

I have a couple of questions about Figure 5:

  1. Do you have any more insight into what is happening there? Why is this preferred over L2?
  2. I'm reminded of "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents", which found that the Huber loss also does a bit better than the standard L2 loss for DQN, so I would be interested in replicating this work for DQN and investigating the type of loss function found there (a quick sketch of the two losses follows this list). Is there any similarity in the asymmetry and non-convexity of the PPO loss function?
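
For reference, a minimal sketch (mine, not from the blog or this repo) of the two losses in JAX; the TD errors and `delta` threshold below are purely illustrative:

```python
import jax.numpy as jnp

def l2_loss(td_error):
    # Squared-error (L2) loss: quadratic everywhere, so large TD errors
    # produce proportionally large gradients.
    return 0.5 * jnp.square(td_error)

def huber_loss(td_error, delta=1.0):
    # Huber loss: quadratic for |td_error| <= delta, linear beyond it,
    # which caps the gradient magnitude on large errors at delta.
    abs_err = jnp.abs(td_error)
    quadratic = jnp.minimum(abs_err, delta)
    linear = abs_err - quadratic
    return 0.5 * jnp.square(quadratic) + delta * linear

td_errors = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])
print(l2_loss(td_errors))     # [4.5   0.125 0.    0.125 4.5  ]
print(huber_loss(td_errors))  # [2.5   0.125 0.    0.125 2.5  ]
```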

Thanks for any information

luchris429 commented 1 year ago

Thanks for the interest!

> I'm guessing that a full academic version of the work is coming out soon.

Yep! Also you somehow found the blog before it was fully ready for promotion 😅

> Do you have any more insight into what is happening there? Why is this preferred over L2?

Some collaborators and I are going to put out a paper specifically on this soon!

> replicating this work for DQN

Yep! A DQN version of this is 100% doable. I had an old version of it that I used for a previous paper. I will see if I can clean it up and incorporate it into this repo! Keep in mind that it might not scale as well as PPO because the experience buffer can consume a lot of GPU memory. (That being said, it's still more than scalable enough to get some cool results.)
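
To put rough numbers on the memory point (the sizes below are hypothetical, purely for illustration of an on-device buffer's footprint):

```python
import math

# Hypothetical buffer sizes, not the repo's actual configuration.
capacity = 1_000_000        # transitions stored in the replay buffer
obs_shape = (84, 84, 4)     # e.g. stacked image frames, uint8 (1 byte each)
obs_bytes = math.prod(obs_shape)

# Observations and next-observations dominate; actions/rewards/dones are tiny.
buffer_gb = 2 * capacity * obs_bytes / 1e9
print(f"~{buffer_gb:.1f} GB of device memory for observations alone")  # ~56.4 GB
```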

pseudo-rnd-thoughts commented 1 year ago

Amazing, I look forward to reading the paper.

> Keep in mind that it might not scale as well as PPO because the experience buffer can consume a lot of GPU memory.

I guess this is where `pmap` is amazing as well.
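
As a rough sketch (not from this repo) of what that could look like, the update can be replicated across devices with jax.pmap so each device holds its own shard of the batch; the "train step" below is just a placeholder for a real DQN/PPO update:

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    # Placeholder gradient, standing in for a full DQN/PPO update.
    grads = 2.0 * (params - batch.mean())
    # Average gradients across devices before applying the update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads

p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = jnp.zeros((n_dev,))  # one parameter copy per device
batch = jnp.arange(n_dev * 8, dtype=jnp.float32).reshape(n_dev, 8)  # sharded data
params = p_train_step(params, batch)
```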