Open pharringtonp19 opened 2 years ago
Hey @pharringtonp19! I don't know how this skipped my inbox 😱
What I got out of that experiment was that moving data into a JAX program via hcb.call
was slower. For a single-device environment I think it's not hard to get a jit-compatible
DataLoader interface; however, I am not sure we would gain anything if it's slower.
If you are still interested, can you recreate a benchmark? Maybe I missed something or had a bug; it's just that a 5x slowdown is pretty rough.
@cgarciae Thanks for the response. I will make a benchmark this week and get back to you -- much appreciated
Continuing a discussion that I first posted here, I was wondering if there were any thoughts about creating a JAX-specific data loader that could be nested in a jit/vmap-able function.
One motivation for this proposal is that it would enable us to jit the entire training loop.
I would be happy to help work on such an implementation -- Thanks!
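To make the idea concrete, here is a minimal sketch of what a fully jitted training loop could look like. It assumes the whole dataset fits in device memory, so "data loading" reduces to sampling batch indices on device inside `jax.lax.scan`. The names `train` and `loss_fn`, the toy linear-regression problem, and the hyperparameters are all hypothetical and only for illustration; this is not an existing data loader API, and it does not cover the host-to-device streaming case discussed in this thread.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Mean squared error of a linear model (toy problem, assumption).
    return jnp.mean((x @ w - y) ** 2)


@partial(jax.jit, static_argnames=("batch_size", "num_steps"))
def train(w0, key, xs, ys, batch_size=32, num_steps=200, lr=0.1):
    """Run the whole training loop, including batching, inside one jit."""
    num_examples = xs.shape[0]

    def step(carry, _):
        w, key = carry
        key, subkey = jax.random.split(key)
        # "Data loading": sample batch indices on device (with replacement).
        idx = jax.random.randint(subkey, (batch_size,), 0, num_examples)
        x, y = xs[idx], ys[idx]
        loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
        return (w - lr * grad, key), loss

    # scan replaces the Python for-loop so the loop itself is compiled.
    (w, _), losses = jax.lax.scan(step, (w0, key), None, length=num_steps)
    return w, losses


# Synthetic, noise-free data that already lives on the device.
k_data, k_train = jax.random.split(jax.random.PRNGKey(0))
true_w = jnp.array([2.0, -1.0, 0.5])
xs = jax.random.normal(k_data, (1024, 3))
ys = xs @ true_w

w, losses = train(jnp.zeros(3), k_train, xs, ys)
```

Because the batch is gathered from arrays that are already on the device, nothing crosses the host boundary during training. A dataset that does not fit in device memory would still need some host transfer mechanism, which is where the slowdown mentioned above would have to be measured.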