lowrollr / turbozero

fast + parallel AlphaZero in JAX

Confusion around jit in train loop #18

Closed: ConstantinRuhdorfer closed this issue 2 months ago

ConstantinRuhdorfer commented 2 months ago

Hi,

Lovely repo, thanks for providing it! I am interested in evaluating it with one of my custom pgx environments.

I just have a quick question: can you please clarify where the training loop is jitted? I checked the library's code and can't find any place where jit is used, though it looks like it was used a few months ago. Should I re-add the jit calls in my project?

Thanks!

lowrollr commented 2 months ago

Hey! Good question. The expensive parts of the training loop are still jitted, but via jax.pmap (https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which compiles the function like jit does while also mapping it across multiple GPUs.

The expensive parts of the training loop (train step and self-play) are jitted (via pmap) here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L352-L406
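For illustration, here's a minimal sketch of the general pattern, not turbozero's actual code: `loss_fn`, `train_step`, the axis name, and the toy shapes are all made up for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss and train step, just to illustrate the pattern;
# turbozero's real step lives in core/training/train.py.
def loss_fn(params, batch):
    preds = batch["x"] @ params
    return jnp.mean((preds - batch["y"]) ** 2)

def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across devices; "devices" names the pmap axis.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads

# pmap traces and XLA-compiles train_step (like jit) and replicates it
# across all local devices.
p_train_step = jax.pmap(train_step, axis_name="devices")

n = jax.local_device_count()
# Replicate params and shard the batch along a leading device axis.
params = jnp.broadcast_to(jnp.zeros((4,)), (n, 4))
batch = {
    "x": jnp.ones((n, 8, 4)),
    "y": jnp.zeros((n, 8)),
}
params = p_train_step(params, batch)
```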

Lmk if you have any other qs, thanks for taking a look at turbozero!

ConstantinRuhdorfer commented 2 months ago

Hi, thanks for getting back so quickly! I was not aware that pmap also jit-compiles under the hood.
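For anyone else landing here: a pmapped function is traced and XLA-compiled on first call just like a jitted one, so no extra jit wrapper is needed. A quick toy sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * 2.0

# pmap compiles f with XLA on the first call, just like jit,
# so wrapping it in jax.jit as well is unnecessary.
pf = jax.pmap(f)

x = jnp.ones((jax.local_device_count(), 3))
print(pf(x))  # compiled once, then cached for subsequent calls
```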