Open dsharlet opened 6 months ago
I had an interesting thought about this: if we say that buffers are always infinite rank, and the "missing" dimensions are always broadcast dimensions, then cropping the last dimension to extent 1 is equivalent to a slice, and we could just implement the crop dynamically that way. for_each_element and related helpers already implement this implicit broadcasting logic...
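To make the equivalence concrete, here is a minimal sketch using a toy buffer type. Everything in it (`dim`, `toy_buffer`, `crop_last`, `slice_last`) is a hypothetical stand-in invented for illustration, not the project's real API. It assumes the only rule that matters here: indices beyond a buffer's explicit rank are ignored, i.e. the missing dimensions behave as broadcasts.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy types for illustration only; not the project's real API.
struct dim {
  std::ptrdiff_t min, extent, stride;
};

struct toy_buffer {
  const int* base;        // address of the element at each dimension's min
  std::vector<dim> dims;  // explicit dimensions; all others are implicit broadcasts

  // Read an element. Indices beyond the explicit rank are ignored, which is
  // what treating the "missing" dimensions as broadcasts means.
  int at(const std::vector<std::ptrdiff_t>& idx) const {
    const int* p = base;
    for (std::size_t d = 0; d < dims.size() && d < idx.size(); ++d) {
      assert(idx[d] >= dims[d].min && idx[d] < dims[d].min + dims[d].extent);
      p += (idx[d] - dims[d].min) * dims[d].stride;
    }
    return *p;
  }
};

// Crop the last dimension to the single index i: the rank is unchanged.
toy_buffer crop_last(toy_buffer b, std::ptrdiff_t i) {
  assert(!b.dims.empty());
  dim& last = b.dims.back();
  assert(i >= last.min && i < last.min + last.extent);
  b.base += (i - last.min) * last.stride;
  last.min = i;
  last.extent = 1;
  return b;
}

// Slice the last dimension at index i: the dimension is removed.
toy_buffer slice_last(toy_buffer b, std::ptrdiff_t i) {
  assert(!b.dims.empty());
  const dim& last = b.dims.back();
  assert(i >= last.min && i < last.min + last.extent);
  b.base += (i - last.min) * last.stride;
  b.dims.pop_back();
  return b;
}
```

For a 3x2 buffer, `crop_last(b, 1)` keeps rank 2 and `slice_last(b, 1)` has rank 1, yet `at({x, 1})` reads the same element from both, because the sliced buffer simply ignores the index for the dimension it no longer has.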
In large real pipelines, there are many crops for pointwise dimensions, corresponding to "batch dimensions". For example:
This leaves the pointwise dimensions in place, so the callbacks get the full rank of the inputs, and then have to do work to skip over the batch dimensions.
If we could generate slices instead of crops, we would avoid this inefficiency. To do this, we'd need some kind of indication that the callback is going to treat some dimensions as batch dimensions, and we could only do this if all uses of the dimension are treated as batch dimensions by the callback.
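The legality condition above could be sketched roughly like this. The `dim_use` record and the `treats_as_batch` flag are hypothetical; how a callback would actually declare that it treats a dimension as a batch dimension is exactly the open question.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical record of one callback's use of a particular dimension.
struct dim_use {
  int callback_id;       // which callback touches the dimension (illustrative only)
  bool treats_as_batch;  // assumed annotation: the callback treats it as a batch dimension
};

// A crop of the dimension may be replaced by a slice only if every use of the
// dimension treats it as a batch dimension. Note that an empty list of uses
// trivially satisfies this.
bool can_slice_instead_of_crop(const std::vector<dim_use>& uses) {
  return std::all_of(uses.begin(), uses.end(),
                     [](const dim_use& u) { return u.treats_as_batch; });
}
```

A single use that does not treat the dimension as a batch dimension is enough to force the crop to stay a crop.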
In the meantime, I think we can at least optimize our buffer traversal helpers for the common case of many trailing dimensions.
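One possible shape for that optimization, as a rough sketch: I'm assuming here that the "trailing dimensions" of interest are the extent-1 dimensions that crops leave behind, and `effective_rank` and `for_each_offset` are hypothetical names rather than the existing helpers. The idea is to peel those dimensions off once, up front, so the loop nest only covers the dimensions that actually iterate.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy dimension type for illustration only.
struct dim {
  std::ptrdiff_t min, extent, stride;
};

// Number of leading dimensions that still need a loop after peeling off
// trailing dimensions of extent 1.
std::size_t effective_rank(const std::vector<dim>& dims) {
  std::size_t rank = dims.size();
  while (rank > 0 && dims[rank - 1].extent == 1) --rank;
  return rank;
}

// Call f with the offset (relative to the element at every dimension's min)
// of each element, looping only over the effective rank. Dimension 0 is the
// innermost loop.
template <typename F>
void for_each_offset(const std::vector<dim>& dims, F&& f) {
  const std::size_t rank = effective_rank(dims);
  for (std::size_t d = 0; d < rank; ++d) {
    if (dims[d].extent <= 0) return;  // empty buffer: nothing to visit
  }
  std::vector<std::ptrdiff_t> i(rank, 0);
  while (true) {
    std::ptrdiff_t offset = 0;
    for (std::size_t d = 0; d < rank; ++d) offset += i[d] * dims[d].stride;
    f(offset);
    // Advance the multi-dimensional counter, carrying into outer dimensions.
    std::size_t d = 0;
    for (; d < rank; ++d) {
      if (++i[d] < dims[d].extent) break;
      i[d] = 0;
    }
    if (d == rank) return;
  }
}
```

With this, a rank-4 buffer whose last two dimensions were cropped to extent 1 is traversed with the same loop nest as a rank-2 buffer.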