sampsyo / cs6120

advanced compilers
https://www.cs.cornell.edu/courses/cs6120/2023fa/

Project Proposal: Peephole Improvements for numpy using JAX #331

Closed atucker closed 1 year ago

atucker commented 2 years ago

What will you do? People sometimes encounter numerical stability issues when writing mathematically intensive code, which make it harder for a program to work consistently. There are various tricks for fixing these problems, but if you or your lab-mates don't already know them they can be hard to find. Of course, many libraries already implement common operations with these tricks built in, but when you're writing something yourself from basic primitives (for example, implementing something from a paper that gives you the mathematical specification but no code) it is hard to remember or even know about all of them.

For example, in neural networks it's pretty common to use log(sum(exp(x))) somewhere. However, sum(exp(x)) can get big very quickly and overflow, even if log(sum(exp(x))) would fit into your float. Luckily, c + log(sum(exp(x - c))) = log(sum(exp(x))) for any constant c. If you set c = max(x), you then sum exp(x - c) where x - c <= 0, which avoids the overflow. See this blog post for a proof.
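To make the identity concrete, here is a minimal numpy sketch of both forms. The function names `naive_logsumexp` and `stable_logsumexp` are hypothetical, chosen only for this illustration, and the input is picked so that exp overflows in float64.

```python
import numpy as np

def naive_logsumexp(x):
    # Direct translation of log(sum(exp(x))); exp overflows for large x.
    return np.log(np.sum(np.exp(x)))

def stable_logsumexp(x):
    # Shift by c = max(x) so every exponent is <= 0, then add c back.
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))

x = np.array([1000.0, 1000.0])

# Silence the overflow warning so the comparison is easy to read.
with np.errstate(over="ignore"):
    naive = naive_logsumexp(x)    # exp(1000) overflows, so this is inf

stable = stable_logsumexp(x)      # mathematically 1000 + log(2)
print(naive, stable)
```

On inputs small enough that nothing overflows, the two functions should agree up to rounding error, which is what makes the rewrite safe to apply.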

The goal of this project is to implement a few of these numerical stability tricks as peephole "optimizations" for numpy code. It is the same idea as a peephole optimization in that it looks at the code and finds small local tweaks that improve it, but the underlying goal is to improve numerical stability, not to replace the code with something equivalent and faster.

How will you do it? JAX is a library that combines Autograd and XLA to make it possible to automatically differentiate numpy code and run it on a GPU, by providing a drop-in replacement for numpy. It provides an intermediate representation called Jaxpr, which represents the mathematical operations in a function and can be targeted to different platforms. It also provides a tracing-based JIT compiler that lets users decorate a function to be compiled into a faster version.
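As a small illustration of the drop-in replacement, differentiation, and JIT features described above, the following sketch writes the log(sum(exp(x))) expression with `jax.numpy` and passes it through `jax.jit` and `jax.grad`. The function name `log_sum_exp` is hypothetical; it is not taken from the project code.

```python
import jax
import jax.numpy as jnp

def log_sum_exp(x):
    # Same expression as the numpy version, written against jax.numpy.
    return jnp.log(jnp.sum(jnp.exp(x)))

# Tracing-based JIT: the function is traced once and compiled.
fast_log_sum_exp = jax.jit(log_sum_exp)

# Automatic differentiation: the gradient of log-sum-exp is softmax(x).
grad_log_sum_exp = jax.grad(log_sum_exp)

x = jnp.zeros(4)
value = fast_log_sum_exp(x)      # log(4) for four zeros
gradient = grad_log_sum_exp(x)   # uniform softmax, 0.25 in each slot
print(value, gradient)
```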

My plan is to use JAX to create a Jaxpr representation of the code, which I can then search through for opportunities to make peephole improvements.
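The plan above might look roughly like the sketch below: trace a function with `jax.make_jaxpr`, then walk the resulting equations looking for a `log` whose input comes from a `reduce_sum` whose input comes from an `exp`. This is only an assumed shape for the search, not the project's actual implementation. The helper name `find_log_sum_exp` is hypothetical, and the primitive names `exp`, `reduce_sum`, and `log` are what JAX is assumed to emit for `jnp.exp`, `jnp.sum`, and `jnp.log`. The sketch only finds candidates; it does not rewrite them, and it does not look inside nested sub-jaxprs.

```python
import jax
import jax.numpy as jnp

def find_log_sum_exp(jaxpr):
    """Return (exp, reduce_sum, log) equation triples found in a jaxpr."""
    # Map each output variable (by identity) to the equation producing it.
    producer = {}
    for eqn in jaxpr.eqns:
        for var in eqn.outvars:
            producer[id(var)] = eqn

    def produced_by(var, name):
        # The equation that defines `var`, if it uses the named primitive.
        eqn = producer.get(id(var))
        if eqn is not None and eqn.primitive.name == name:
            return eqn
        return None

    matches = []
    for log_eqn in jaxpr.eqns:
        if log_eqn.primitive.name != "log":
            continue
        sum_eqn = produced_by(log_eqn.invars[0], "reduce_sum")
        if sum_eqn is None:
            continue
        exp_eqn = produced_by(sum_eqn.invars[0], "exp")
        if exp_eqn is not None:
            matches.append((exp_eqn, sum_eqn, log_eqn))
    return matches

def naive(x):
    return jnp.log(jnp.sum(jnp.exp(x)))

def lookalike(x):
    # Similar shape (log of a sum) but no exp, so it should not match.
    return jnp.log(jnp.sum(jnp.sin(x) + 2.0))

example = jnp.ones(3)
naive_matches = find_log_sum_exp(jax.make_jaxpr(naive)(example).jaxpr)
lookalike_matches = find_log_sum_exp(jax.make_jaxpr(lookalike)(example).jaxpr)
print(len(naive_matches), len(lookalike_matches))
```

Matching on data flow between equations, instead of on adjacent lines of the printed Jaxpr, means unrelated equations interleaved between the three operations do not hide the pattern.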

How will you empirically measure success? This project is going to emphasize correctness more than usefulness, so I will evaluate it by writing test code which checks three things: that the implementation makes its improvements, that it doesn't change code which looks similar but doesn't actually fit the pattern, and that the improvements do in fact improve numerical stability.

atucker commented 2 years ago

Code is here.