openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[NVIDIA] Optimize deterministic scalar scatter #17886

Open serach24 opened 1 week ago

serach24 commented 1 week ago

This PR is the first of two steps toward improving the performance of deterministic scatter. Currently, the scatter op is expanded into a deterministic form in xla/service/ScatterExpander.cc. However, because that expansion uses a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR rewrites scatter operations with scalar indices and updates, and leaves all other scatter operations to the original ScatterExpander. The second PR will handle non-scalar indices and updates.
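As a rough illustration of the idea, here is a minimal pure-Python sketch of a deterministic scalar scatter built on a segmented prefix scan. It is not the code in this PR or in ScatterExpander: the sort-then-scan structure, the choice of a Hillis-Steele scan, and the names `deterministic_scatter` and `sequential_scatter` are all assumptions made for illustration. `sequential_scatter` is a simplified stand-in for an expansion that applies updates one at a time in a loop.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def sequential_scatter(
    operand: Sequence[T],
    indices: Sequence[int],
    updates: Sequence[T],
    combine: Callable[[T, T], T],
) -> List[T]:
    """Hypothetical reference: one update per iteration, n serial steps."""
    result = list(operand)
    for idx, upd in zip(indices, updates):
        result[idx] = combine(result[idx], upd)
    return result


def deterministic_scatter(
    operand: Sequence[T],
    indices: Sequence[int],
    updates: Sequence[T],
    combine: Callable[[T, T], T],
) -> List[T]:
    """Hypothetical scan-based scalar scatter with a fixed reduction order.

    1. Stable-sort the updates by target index.
    2. Run a segmented inclusive prefix scan (Hillis-Steele) so each run of
       equal indices is reduced in a position-determined order.
    3. Write the last element of each run (the run's total) into the operand.
    """
    n = len(updates)
    # Step 1: stable sort by target index; ties keep their original order.
    order = sorted(range(n), key=lambda i: indices[i])
    keys = [indices[i] for i in order]
    vals = [updates[i] for i in order]

    # Step 2: each round combines element i with element i - d only if both
    # belong to the same run. The comprehension reads the previous round's
    # values only, so every round is an independent elementwise step.
    d = 1
    while d < n:
        vals = [
            combine(vals[i - d], vals[i])
            if i >= d and keys[i - d] == keys[i]
            else vals[i]
            for i in range(n)
        ]
        d *= 2

    # Step 3: the last element of every run holds that run's total, so each
    # target index is written exactly once and no writes collide.
    result = list(operand)
    for i in range(n):
        if i == n - 1 or keys[i + 1] != keys[i]:
            result[keys[i]] = combine(result[keys[i]], vals[i])
    return result
```

For example, `deterministic_scatter([0, 0, 0, 0], [2, 0, 2, 1, 2], [1, 2, 3, 4, 5], lambda a, b: a + b)` returns `[2, 4, 9, 0]`, the same as the sequential version. The scan needs about log2(n) rounds instead of n serial iterations. Because the operand pairing in each round depends only on positions, floating-point results are reproducible from run to run.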

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: https://github.com/google/jax/issues/17844

google-cla[bot] commented 1 week ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

serach24 commented 1 week ago
  1. Is it possible to squash this into one commit with a much more detailed commit message?

I think it is doable, but won't PRs be squashed on merge anyway?

  2. Could you provide microbenchmark results, especially comparing deterministic and non-deterministic scatter performance? If the performance is comparable, maybe we could even try to make it deterministic by default?

These results are in the evaluation section of the attached design doc.