google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Hack jvp-matmul performance. #1203

Closed · axch closed this 1 year ago

axch commented 1 year ago

Specifically, the big problem was that the JVP of matmul was materializing O(n^3)-sized intermediates, namely the results of indexing into the input arrays. This change hacks the linearization of binary operations to recompute their inputs (instead of referring to them) when those inputs are array-indexing expressions. We should probably have a more cogent story for recomputation in the tangent; in particular, the decision should perhaps be centralized, rather than requiring every downstream consumer of a primal value in a tangent to individually invoke the recompute/refer logic. But this suffices to eliminate those intermediates, and it speeds up the jvp-matmul benchmark by about 2.7x.
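
For concreteness, here is a minimal, self-contained Haskell sketch of the recompute-vs-refer idea. It is not dex-lang's actual IR or linearization code; the types and helpers below (`Expr`, `referOrRecompute`, `linearizeMul`) are hypothetical, and only the shape of the decision mirrors the change.

```haskell
-- A toy expression language standing in for the compiler IR (hypothetical).
data Expr
  = Var String        -- reference to an already-computed value
  | Index Expr Expr   -- indexing into an array
  | Add Expr Expr
  | Mul Expr Expr
  deriving Show

-- When tangent code needs a primal operand, decide whether to refer to the
-- saved value or to re-emit the expression. Indexing is cheap to redo, and
-- re-emitting it avoids materializing the indexed values across the whole
-- loop nest (the O(n^3) intermediates in jvp-matmul).
referOrRecompute :: String -> Expr -> Expr
referOrRecompute savedName e = case e of
  Index _ _ -> e              -- recompute: just index the input array again
  _         -> Var savedName  -- refer: use the stored primal result

-- Linearization rule for multiplication: the primal is x*y, the tangent is
-- x*dy + dx*y, with both primal operands routed through referOrRecompute.
linearizeMul :: (String, Expr) -> (String, Expr) -> Expr -> Expr -> (Expr, Expr)
linearizeMul (xName, x) (yName, y) dx dy =
  ( Mul x y
  , Add (Mul (referOrRecompute xName x) dy)
        (Mul dx (referOrRecompute yName y)) )

-- Inside matmul's loop nest the operands are themselves indexing expressions
-- (x.i.l and y.l.j), so the tangent re-indexes the inputs instead of holding
-- saved copies of every indexed scalar.
main :: IO ()
main = print $ linearizeMul
  ("x_il", Index (Index (Var "x") (Var "i")) (Var "l"))
  ("y_lj", Index (Index (Var "y") (Var "l")) (Var "j"))
  (Var "dx_il") (Var "dy_lj")
```

The sketch centralizes the decision in one helper for clarity; as noted above, the change itself leaves the recompute/refer choice to each downstream consumer of a primal value.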

axch commented 1 year ago

To clarify: regular matmul remains embarrassingly slow (currently ~70-100x worse than numpy), but as of this change JVP no longer imposes the additional ~2.7x overhead it used to; it is now on par with forward matmul. VJP of matmul also speeds up by ~3.5x, bringing it to only ~20% slower than forward matmul.