Joshuaalbert closed this 5 days ago.
It looks like the scan should perhaps even move outside the partitioned frequency axis. To be clear, partitioning N frequencies over M devices means that each device is responsible for a contiguous chunk of N//M frequencies, so N must be evenly divisible by M. These N//M frequencies are then processed with vectorised operations locally on each device, so no communication is required. It therefore makes sense to use scan(vmap(vmap)), as long as this introduces no unintentional communication between devices.
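A minimal sketch of that layout, assuming a hypothetical accumulation with a time axis (scanned), a frequency axis (sharded over a 1-D device mesh named "freq") and an antenna axis (vmapped). The sizes T, N, A, the mesh axis name and the per-element update are illustrative only. Because the update has no cross-frequency terms, XLA's sharding propagation should keep each device's work local, but that is worth confirming rather than assuming.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical sizes: T time steps (scanned), N frequencies (sharded), A antennas (vmapped).
T, N, A = 4, 8, 3

devices = np.array(jax.devices())
M = devices.size
assert N % M == 0, "N must be evenly divisible by the number of devices M"

# Contiguous chunking: row d lists the frequency indices owned by device d,
# i.e. [d * N // M, (d + 1) * N // M).
chunks = np.arange(N).reshape(M, N // M)

mesh = Mesh(devices, axis_names=("freq",))
# Shard only the frequency axis (axis 1); time and antenna axes are not split.
freq_sharding = NamedSharding(mesh, P(None, "freq", None))


def per_element(carry, x):
    # Pointwise update for one (frequency, antenna) pair; no cross-frequency terms.
    return carry + x**2


def step(carry, x_t):
    # carry, x_t: (N, A). Outer vmap over frequencies, inner vmap over antennas.
    return jax.vmap(jax.vmap(per_element))(carry, x_t), None


@jax.jit
def accumulate(xs):
    # xs: (T, N, A). zeros_like(xs[0]) lets the carry inherit the frequency sharding.
    init = jnp.zeros_like(xs[0])
    final, _ = jax.lax.scan(step, init, xs)
    return final


xs = jax.device_put(
    jnp.arange(T * N * A, dtype=jnp.float32).reshape(T, N, A), freq_sharding
)
out = accumulate(xs)  # shape (N, A), sum over time of xs**2
```

To check for unintentional communication, one option is to inspect `accumulate.lower(xs).compile().as_text()` on a multi-device setup and look for collective ops (all-gather, all-reduce, etc.).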
Done with multi_vmap, which is much more complex, but better. It could also allow automatically choosing the best ordering of scan and vmap, e.g. to minimise FLOPs.
This benchmark shows that scan(vmap) is faster, probably due to its contiguous memory access pattern. Thus, for performance, it makes more sense to place the scan outside the vmap in accumulation code.
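For anyone wanting to reproduce the comparison, here is a hypothetical harness (not the benchmark referred to above) that computes the same running sum both ways: scan(vmap) on time-major data (T, B, D) and vmap(scan) on batch-major data (B, T, D). The sizes and the trivial additive update are assumptions; no timings are claimed here, since they depend on hardware and backend.

```python
import time

import jax
import jax.numpy as jnp


def accumulate_one(xs_b):
    # Sequential accumulation over the leading (time) axis for one batch element.
    def body(carry, x):
        return carry + x, None

    total, _ = jax.lax.scan(body, jnp.zeros_like(xs_b[0]), xs_b)
    return total


@jax.jit
def vmap_of_scan(xs):
    # xs: (B, T, D). vmap over the batch; each element runs its own scan over T.
    return jax.vmap(accumulate_one)(xs)


@jax.jit
def scan_of_vmap(xs):
    # xs: (T, B, D). One scan over T; each step is a vectorised update of the whole batch.
    def body(carry, x_t):
        return jax.vmap(lambda c, x: c + x)(carry, x_t), None

    total, _ = jax.lax.scan(body, jnp.zeros_like(xs[0]), xs)
    return total


def bench(fn, x, reps=10):
    fn(x).block_until_ready()  # compile and warm up before timing
    start = time.perf_counter()
    for _ in range(reps):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / reps


B, T, D = 64, 128, 256
xs_tbd = jax.random.normal(jax.random.PRNGKey(0), (T, B, D))
xs_btd = jnp.swapaxes(xs_tbd, 0, 1)  # same data in batch-major layout for vmap(scan)

a = scan_of_vmap(xs_tbd)
b = vmap_of_scan(xs_btd)
t_scan_vmap = bench(scan_of_vmap, xs_tbd)
t_vmap_scan = bench(vmap_of_scan, xs_btd)
print(f"scan(vmap): {t_scan_vmap:.2e} s, vmap(scan): {t_vmap_scan:.2e} s")
```

My understanding is that JAX's batching rule for `lax.scan` already moves the vmapped dimension inside the loop body, so a measured gap between the two is more likely to come from data layout (scanning over a non-leading axis) than from a different loop structure. That would be consistent with, though not proof of, the memory-access explanation.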