jheek opened this issue 3 years ago
Might still be useful for people who don't use tf.data?
People will just implement an ad-hoc function with the exact same semantics. It's too convenient to do without. When you say "copy [...] takes a lot of time", it'd be good to supply some data to support the assertion.
Also I don't see how something like:
split_batch = shard(batch)
complicates much of anything.
The following pattern is common in user code and our examples:
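A minimal numpy sketch of that pattern and its semantics (the device count and batch contents are illustrative; `flax.common_utils.shard` itself maps a reshape over a pytree of arrays):

```python
import numpy as np

NUM_DEVICES = 8  # illustrative; in practice jax.local_device_count()

def shard(batch):
    """Reshape each array's leading axis [B, ...] into
    [NUM_DEVICES, B // NUM_DEVICES, ...], mirroring what
    flax.common_utils.shard does over a pytree."""
    return {k: v.reshape((NUM_DEVICES, -1) + v.shape[1:])
            for k, v in batch.items()}

# In a training loop the host batch is sharded just before the pmap'd step:
batch = {"image": np.zeros((64, 28, 28)), "label": np.zeros((64,))}
sharded = shard(batch)
# sharded["image"].shape == (8, 8, 28, 28)
```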
The shard utility is part of `flax.common_utils`. This pattern has 2 issues:
Proposed solution
Make sure input data already has the correct shape at the end of the input pipeline. In tf.data we can do this by using a double batch call:
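A sketch of the double batch call (`per_device_batch_size` and `num_devices` are placeholder names; `tf.data.Dataset.batch` supports chaining like this):

```python
import tensorflow as tf

num_devices = 4            # illustrative; in practice jax.local_device_count()
per_device_batch_size = 2

ds = tf.data.Dataset.range(16)
# The first batch call forms per-device batches; the second stacks them
# across devices, so each element has shape
# [num_devices, per_device_batch_size, ...] -- already shaped for pmap.
ds = ds.batch(per_device_batch_size, drop_remainder=True)
ds = ds.batch(num_devices, drop_remainder=True)

example = next(iter(ds))
# example.shape == (4, 2)
```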
The advantage here is that the input pipeline always provides us with correctly shaped data. No reshaping is done in the critical loop, and there is no confusion about numpy vs jax arrays.
TODO: