ConstantinRuhdorfer closed this issue 2 months ago
Hey! Good question. The expensive parts of the training loop are still jitted, but via jax.pmap (https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) so that training can run on multiple GPUs.
The train step and self-play are jitted (via pmap) here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L352-L406
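For reference, here is a minimal sketch of the idea, not turbozero's actual code: `train_step`, `loss_fn`, and the toy data below are made up for illustration. `jax.pmap` XLA-compiles the wrapped function just like `jax.jit` does, and additionally runs one replica of it per device over the leading axis of the inputs:

```python
import jax
import jax.numpy as jnp

# Hypothetical train step (not from turbozero): pmap will compile this with XLA
# and execute one copy per device.
def train_step(params, batch):
    def loss_fn(p):
        # toy "loss": mean squared error of a linear model against zero targets
        return jnp.mean((batch @ p) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # average gradients across devices so every replica stays in sync
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads, loss

# pmap compiles train_step (no separate jax.jit needed) and maps it over devices
p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = jnp.zeros((4,))
# replicate params across devices; shard the batch along the leading axis
params = jnp.broadcast_to(params, (n_dev,) + params.shape)
batch = jnp.ones((n_dev, 8, 4))  # 8 toy examples per device
params, loss = p_train_step(params, batch)
```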
Lmk if you have any other questions, and thanks for taking a look at turbozero!
Hi, thanks for getting back so quickly! I was not aware that pmap also applies jit under the hood.
Hi,
Lovely repo, thanks for providing it! I am interested in evaluating it with one of my custom pgx environments.
I just have a quick question: can you clarify where the training loop is jitted? I checked the library's code and can't find any place where jit is used, although it looks like it was used a few months ago. Should I re-add the jit calls in my project?
Thanks!