cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Jit/Vmap-able DataLoader #49

Open pharringtonp19 opened 2 years ago

pharringtonp19 commented 2 years ago

Continuing a discussion that I first posted here, I was wondering whether there were any thoughts about creating a JAX-specific data loader that could be nested inside a jit/vmap-able function.

One motivation for this proposal is that it would enable us to jit the entire training loop.
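To make the idea concrete, here is a minimal sketch of a fully jitted training loop. It assumes the whole dataset fits in device memory, so the "data loader" is just on-device indexing inside `jax.lax.scan`. The names `loss_fn`, `make_train_loop`, and the toy linear-regression data are hypothetical and only illustrate the shape of the proposal; they are not part of treex.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear-regression loss; a stand-in for a real model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def make_train_loop(batch_size, num_steps, lr):
    @jax.jit
    def train_loop(params, xs, ys, key):
        num_batches = xs.shape[0] // batch_size
        # Shuffle the example indices once, on device.
        perm = jax.random.permutation(key, xs.shape[0])

        def step(params, i):
            # The "data loader": slice the i-th batch of shuffled indices.
            start = (i % num_batches) * batch_size
            idx = jax.lax.dynamic_slice(perm, (start,), (batch_size,))
            x, y = xs[idx], ys[idx]
            loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
            # Plain SGD update applied leaf by leaf.
            params = jax.tree_util.tree_map(
                lambda p, g: p - lr * g, params, grads
            )
            return params, loss

        # scan runs every step inside a single compiled program.
        return jax.lax.scan(step, params, jnp.arange(num_steps))

    return train_loop


k_data, k_shuffle = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(k_data, (256, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5]) + 0.3
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

train_loop = make_train_loop(batch_size=32, num_steps=200, lr=0.1)
params, losses = train_loop(params, xs, ys, k_shuffle)
```

This only covers the in-memory case; a dataset that has to be streamed from the host would need some other mechanism to get batches into the compiled program.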

I would be happy to help work on such an implementation -- Thanks!

cgarciae commented 2 years ago

Hey @pharringtonp19! I don't know how this skipped my inbox 😱

What I got out of that experiment was that moving data into a JAX program via hcb.call was slower. For a single-device environment I think it's not hard to get a jit-compatible DataLoader interface; however, I am not sure we would gain anything if it's slower.
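For reference, a rough sketch of what pulling a batch from the host inside a jitted function can look like. Note that this swaps in `jax.experimental.io_callback` rather than `hcb.call`, since `jax.experimental.host_callback` has been deprecated in later JAX releases; it is not the code from the original experiment. `host_data`, `fetch_batch`, and `batch_sum` are made-up names for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical host-side dataset that never lives on the device as a whole.
host_data = np.arange(40, dtype=np.float32).reshape(10, 4)
BATCH_SIZE = 2


def fetch_batch(i):
    # Runs on the host as ordinary Python/NumPy code.
    start = int(i) * BATCH_SIZE
    return host_data[start:start + BATCH_SIZE]


@jax.jit
def batch_sum(i):
    # The result shape and dtype must be declared up front,
    # because the compiled program cannot inspect the callback.
    batch = io_callback(
        fetch_batch,
        jax.ShapeDtypeStruct((BATCH_SIZE, 4), jnp.float32),
        i,
    )
    return batch.sum()


print(batch_sum(1))  # rows 2 and 3 hold the values 8..15, so this prints 92.0
```

Every call to `batch_sum` leaves the compiled program, runs Python on the host, and copies the result back to the device, so a per-batch overhead of this kind is plausible even when the interface itself is jit-compatible.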

If you are still interested, can you recreate a benchmark? Maybe I missed something or had a bug; it's just that a 5x slowdown is pretty rough.
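A benchmark along these lines could start from a small timing helper like the sketch below. It assumes the two things that usually matter when timing JAX code: a warm-up call so compilation is not counted, and `jax.block_until_ready` so asynchronous dispatch does not hide the real cost. `time_fn` and `device_batch_mean` are hypothetical names; the host-fed variant under comparison would be passed to `time_fn` in the same way.

```python
import time

import jax
import jax.numpy as jnp

BATCH_SIZE = 32


@jax.jit
def device_batch_mean(xs, i):
    # Baseline: the dataset is already on device, so a batch is a slice.
    batch = jax.lax.dynamic_slice_in_dim(xs, i * BATCH_SIZE, BATCH_SIZE, axis=0)
    return batch.mean()


def time_fn(fn, *args, n_iters=50):
    # Warm-up call: triggers compilation so it is excluded from the timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        # Block on each result so we measure execution, not just dispatch.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iters


xs = jnp.ones((1024, 8))
seconds_per_call = time_fn(device_batch_mean, xs, 3)
print(f"{seconds_per_call * 1e6:.1f} us per call")
```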

pharringtonp19 commented 2 years ago

@cgarciae Thanks for the response. I will make a benchmark this week and get back to you -- much appreciated