MRiabov opened this issue 2 months ago (status: open).
Hey. This is on my roadmap, but it's quite an involved algorithm. I have implemented it before, but some thought will need to go into how to do it cleanly. Unfortunately, I have no estimate of when I'll get around to it.
I'm using Dreamer myself, and looking at the code, a lot of it could simply be replaced with JAX alternatives like scan and vmap.
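For illustration, the kind of replacement described above might look like the following minimal sketch: a Python for loop over an imagination horizon rewritten with jax.lax.scan, then batched with jax.vmap. The dynamics function and every name here (dynamics_step, rollout_loop, rollout_scan, batched_rollout) are hypothetical placeholders, not code from Stoix or from either DreamerV3 repository, and the toy transition is deliberately much simpler than DreamerV3's RSSM.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy latent transition standing in for a world model's dynamics
# step. This is NOT DreamerV3's RSSM (no GRU, no stochastic latents); for
# simplicity the action is assumed to have the same size as the latent state.
def dynamics_step(state, action, params):
    w, b = params
    return jnp.tanh(state @ w + action + b)


def rollout_loop(init_state, actions, params):
    """Python-loop rollout: under jit, the loop is unrolled once per step."""
    state = init_state
    states = []
    for t in range(actions.shape[0]):
        state = dynamics_step(state, actions[t], params)
        states.append(state)
    return jnp.stack(states)


def rollout_scan(init_state, actions, params):
    """Same rollout expressed with jax.lax.scan: the body is traced once."""
    def body(state, action):
        next_state = dynamics_step(state, action, params)
        return next_state, next_state  # (new carry, per-step output)

    _, states = jax.lax.scan(body, init_state, actions)
    return states  # shape [horizon, latent_dim]


# vmap batches the single-trajectory rollout over starting states and action
# sequences; params are shared across the batch (in_axes=None).
batched_rollout = jax.jit(jax.vmap(rollout_scan, in_axes=(0, 0, None)))

key_w, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)
latent_dim, horizon, batch = 4, 15, 8
params = (0.1 * jax.random.normal(key_w, (latent_dim, latent_dim)), jnp.zeros(latent_dim))
init_states = jax.random.normal(key_s, (batch, latent_dim))
actions = jax.random.normal(key_a, (batch, horizon, latent_dim))

states = batched_rollout(init_states, actions, params)
print(states.shape)  # (8, 15, 4)
```

With scan, the loop body is traced once regardless of horizon length, which keeps compile time from growing with the horizon; vmap then lets the dynamics code be written for a single trajectory with no explicit batch dimension.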
Would you mind sharing the JAX implementation, by the way?
(Sent as an email reply to Edan Toledo's comment of Wed, 4 Sept 2024, 19:13.)
My code is quite messy right now. It was for a paper I submitted a while ago, covering a multi-agent use case with Dreamer and graph neural networks. The code itself isn't difficult; there are just a lot of details to incorporate if you want to represent the paper accurately. I can try to push this up my to-do list, but I can't say it's a priority right now. I'll leave the issue open as a reminder.
I would add that this repo has an implementation, although with some changes: https://github.com/symoon11/dreamerv3-flax/tree/main
That said, I guess I'll have to reimplement the paper myself for performance.
Please describe the purpose of the feature. Is it related to a problem?
Create a DreamerV3 implementation: there is no pure JAX implementation to date.
Describe the solution you'd like
A pure JAX Anakin/Sebulba implementation of DreamerV3, including pairing with native JAX environments.
Describe alternatives you've considered
The current implementation of DreamerV3 mixes NumPy and jax.numpy code, which is suboptimal.
How do we know when implementation of this feature is complete?
Checklist:
Additional context
DreamerV3 is currently among the most capable general-purpose RL algorithms: with fixed hyperparameters and no human data, it learned to collect diamonds in Minecraft. See the paper at https://arxiv.org/pdf/2301.04104v2; the current official implementation is https://github.com/danijar/dreamerv3/tree/main