danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Ask about jit in ninjax.py #100

Closed: realjoenguyen closed this issue 5 months ago

realjoenguyen commented 9 months ago

Hi,

As I am a beginner in JAX, may I ask why you need to redefine jit in ninjax.py (here)?

Is it because you need to return state from your jitted functions? Sorry for a seemingly dumb question!
Thank you in advance!
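For readers with the same question, here is a minimal sketch of one common reason a stateful-module library might wrap jax.jit. It assumes that this is the motivation, which the thread does not confirm. jax.jit expects a pure function, and Python side effects inside it run only once, at trace time. Because of that, mutable state has to be passed in as an argument and returned as an output. STATE, pure_step, and stateful_call are hypothetical names for illustration, not ninjax's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level state, standing in for the hidden state a
# stateful-module library keeps between calls (not ninjax's real internals).
STATE = {"count": jnp.array(0), "total": jnp.array(0.0)}

def pure_step(state, x):
    # Pure function: state comes in as an argument and the updated state
    # goes out as a return value, so jax.jit can trace it safely.
    new_state = {
        "count": state["count"] + 1,
        "total": state["total"] + jnp.sum(x),
    }
    return new_state, new_state["total"]

compiled_step = jax.jit(pure_step)

def stateful_call(x):
    # Hypothetical wrapper: pass the current state in and store the returned
    # state. Assigning to STATE inside the jitted function would only happen
    # once, during tracing, so the write-back has to happen out here.
    global STATE
    STATE, out = compiled_step(STATE, x)
    return out

first = stateful_call(jnp.array([1.0, 2.0]))   # running total becomes 3.0
second = stateful_call(jnp.array([4.0]))       # running total becomes 7.0
```

With this design, the compiled function stays pure and the wrapper handles the state bookkeeping. That split is what a custom jit wrapper would hide from the caller.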

danijar commented 5 months ago

Just updated the code, which removes nj.jit.