Open dev-goyal opened 6 months ago
If this sounds reasonable, and I haven't missed anything crucial, I'd be quite happy to make the PR. Perhaps we want to only make this configurable if `collate_fn` is provided?
@dev-goyal thanks for the suggestion! We would be happy to help shepherd the PR, let us know if we can help answer any questions.
I think it makes sense to override the current default batch type only if `collate_fn` is provided.
Description
Currently, `iter_torch_batches` itself calls `iter_batches` (https://github.com/ray-project/ray/blob/1d6983380b8adacd33f92588110a820e6587467c/python/ray/data/iterator.py#L393) without setting a value for `batch_format`, which means it falls back to the default `numpy` format (https://github.com/ray-project/ray/blob/1d6983380b8adacd33f92588110a820e6587467c/python/ray/data/iterator.py#L138). This means the incoming batch in `iter_torch_batches` is first converted to numpy before any processing is done on it. Especially when a `collate_fn` is provided, this conversion can be wasteful. Consider the following example. Thus, if `collate_fn` could accept pyarrow batches directly, this would be much easier. Moreover, given that `collate_fn` runs on the GPU, compared to something like `map_batches` on CPUs, it would make sense to offload as much computation as possible earlier in the pipeline.
Use case
See the bag of embeddings example above.